diff --git a/.gitignore b/.gitignore index 20cf4dd4e..7f4152128 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,11 @@ data-in/ data-out/ node_modules/ .DS_Store/ +tsconfig.tsbuildinfo +dist +.nx/cache +*.d.ts +storybook-static/ +packages/cli-web/ +packages/cli-sec/ +packages/ngraph/ \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 79d355444..b37cb81ac 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,6 +10,7 @@ "search.exclude": { "data-files/**": true, "data-out/**": true, + "dist/**": true, "**/package-lock.json": true }, "editor.codeActionsOnSave": { @@ -27,5 +28,6 @@ "html.format.enable": true, "json.format.enable": true, "javascript.format.enable": true, - "editor.wordWrapColumn": 100 + "editor.wordWrapColumn": 100, + "prettier.printWidth": 100 } diff --git a/bun.lockb b/bun.lockb index bf24743cc..2ade2c80c 100755 Binary files a/bun.lockb and b/bun.lockb differ diff --git a/docs/01_motivations.md b/docs/01_motivations.md index f7d1689b6..957f50faf 100644 --- a/docs/01_motivations.md +++ b/docs/01_motivations.md @@ -16,4 +16,4 @@ I have two very different side projects that I am working on, where intelligent This project is an attempt to build a framework that can be used to build intelligent retrieval systems. The main requirements are: - I need the ability to test and iterate quickly on ideas, models, and data using both local and cloud resources. And compare the results of different approaches. -- Be able to put into production and decide to change approaches later without having to rewrite everything. +- Be able to put into production and decide to change approaches later without having to rewrite everything. So switching embedding models part way through a project should be easy. diff --git a/docs/03_matrix_operations.md b/docs/03_matrix_operations.md index bb56ad762..3a9ef3f74 100644 --- a/docs/03_matrix_operations.md +++ b/docs/03_matrix_operations.md @@ -27,6 +27,10 @@ We need to lookup data based on the query. - We will use the same embedding model as the data embeddings - Different pre-processing methods of the query to embed. We will call this a query rewriter. This does not need to use the same generative model as the data rewriter, though it likely would. +## Reranking + +We need to re-rank the results based on the query. + ## Storage and Retrieval We can start with these options: diff --git a/docs/05_task_chains.md b/docs/05_task_chains.md index 004f0c755..eeb6fa94b 100644 --- a/docs/05_task_chains.md +++ b/docs/05_task_chains.md @@ -1,4 +1,4 @@ -# Tasks, Task Lists, and Strategies +# Tasks and Task Graphs ## Requirements @@ -76,12 +76,12 @@ Uses: - ApplyPromptTask - TextGenerationTask -### TextRewriterTaskList +### TextRewriter with Multiple Models Inputs - content -- model +- model[] - parameters - prompt @@ -93,7 +93,7 @@ Uses: - TextRewriterTask -### TextEmbeddingStrategy +### TextEmbedding Strategy Inputs @@ -118,14 +118,7 @@ Example: ```ts new TextEmbeddingStrategy({ content: "This is a test", - embedding_model: [ - { - name: "Xenova/distilbert-base-uncased", - model_parameters: { - temperature: 0.7, - }, - }, - ], + embedding_model: name: "Xenova/distilbert-base-uncased" rewriter: [ { prompt_model: "Xenova/gpt2", @@ -145,12 +138,6 @@ A task is a single step in the chain where most tasks output will be input for t Tasks get posted to a job queue and are run by a job queue runner. -## TaskList - -A strategy is a list of tasks that are chained together to look like a single task. - -## Strategy - -A strategy is a list of tasks that are chained together to look like a single task. Parts can be run in series or in parallel. It orchestrates variations of the same task. +## CompoundTask -Strategies get a name and are saved in the database, both as a parent all the variations. The variation names are based on the spefic parameters used rather than the parent name. +A compound task is a groups of tasks (in DAG format) that are chained together to look like a single task. diff --git a/docs/06_run_graph_orchestration.md b/docs/06_run_graph_orchestration.md index 2d888499c..18bf74504 100644 --- a/docs/06_run_graph_orchestration.md +++ b/docs/06_run_graph_orchestration.md @@ -14,19 +14,66 @@ The pipline DAG is defined by the end user and saved in the database (nodes and The graph is a DAG. It is a list of nodes and a list of edges. The nodes are the tasks and the edges are the inputs and outputs of the tasks plus some other instrumetation data. +We might want to have events based on what happens in the graph (and a suspend/resume for bulk creation/etc). This will be needed to keep UI in sync with the as it runs. + ### Node - Task -- TaskList -- Strategy +- SimpleTaks +- CompoundTask (has a sub-graph) + +Notes about requirements for the nodes: + +- Must have input list and output list + - the input or output will have a type object that JS can read, and not a TS type (though that should get derived from the type object) +- We need to convert the inputs/outputs to a TypeScript type ### Edge -- Input -- Output +- DataFlow - Instrumentation -- Events + +Notes about requirements for the edges: + +- There can be multiple outputs that go to multiple inputs + - I.g., there can and will be multiple edges between two nodes ### Graph Runner -The graph runner is a simple recursive function that takes a graph and a node and runs the node. If the node is a task, it runs the task. If the node is a TaskList or Strategy, it runs the subgraph. +The graph runner is a simple recursive function that takes a graph and a node and runs the node. If the node is a task, it runs the task. If the node is a CompoundTask, it runs the subgraph. + +# User Task Graph + +```mermaid +erDiagram + TaskGraph ||--o{ Task : nodes + TaskGraph ||--o{ DataFlow : edges + Task ||--o{ TaskInput : inputs + Task ||--o{ TaskOutput: outputs + TaskInput ||--|| ValueType : valueType + TaskOutput ||--|| ValueType : valueType + DataFlow ||--|| TaskInput : handle + DataFlow ||--|| TaskOutput : handle + DataFlow ||--|| Task : source + DataFlow ||--|| Task : target + + TaskGraph{ + Task[] nodes + DataFlow[] edges + } + + Task{ + string name + string id + TaskInput[] inputs + TaskOutput[] outputs + } + + DataFlow{ + string id + Task sourceTaskId + Task targetTaskId + TaskInput sourceTaskInput + TaskOutput targetTaskOutput + } +``` diff --git a/lerna.json b/lerna.json new file mode 100644 index 000000000..4560d2f43 --- /dev/null +++ b/lerna.json @@ -0,0 +1,5 @@ +{ + "$schema": "node_modules/lerna/schemas/lerna-schema.json", + "version": "0.0.0", + "packages": ["packages/*"] +} diff --git a/nx.json b/nx.json new file mode 100644 index 000000000..f3114c05a --- /dev/null +++ b/nx.json @@ -0,0 +1,13 @@ +{ + "targetDefaults": { + "build": { + "cache": true, + "dependsOn": [], + "outputs": ["{projectRoot}/dist"] + }, + "test": { + "cache": true, + "dependsOn": [] + } + } +} diff --git a/package.json b/package.json index ee29b67d1..11a7247f7 100644 --- a/package.json +++ b/package.json @@ -1,23 +1,50 @@ { "name": "ellmers", - "module": "ellmers", "type": "module", - "devDependencies": { - "@types/bun": "^1.0.4", - "@types/uuid": "^9.0.7" - }, - "peerDependencies": { - "typescript": "^5.3.3" + "version": "0.0.1", + "description": "Ellmers is a tool for building and running DAG pipelines of AI tasks.", + "workspaces": [ + "./packages/*" + ], + "scripts": { + "build": "lerna run build", + "clean": "rm -rf node_modules packages/*/node_modules packages/*/dist", + "watch": "lerna run watch --parallel --stream", + "docs": "typedoc", + "format": "eslint \"{packages}/*/src/**/*.{js,ts,tsx,json}\" --fix && prettier \"{packages}/*/src/**/*.{js,ts,tsx,json}\" --check --write", + "release": "npm run build && npm publish", + "test": "jest" }, "dependencies": { - "@mediapipe/tasks-text": "^0.10.9", - "@sroussey/transformers": "^2.14.3", - "@sroussey/typescript-graph": "^0.3.6", + "@sroussey/typescript-graph": "^0.3.12", + "@xyflow/react": "12.0.0-next.9", "chalk": "^5.3.0", "commander": "^11.1.0", "eventemitter3": "^5.0.1", - "listr2": "^8.0.1", + "listr2": "^8.0.2", + "nanoid": "^5.0.6", + "postcss": "^8.4.35", + "react-hotkeys-hook": "^4.5.0", + "react-icons": "^5.0.1", "rxjs": "^7.8.1", "uuid": "^9.0.1" + }, + "devDependencies": { + "@types/bun": "^1.0.6", + "@types/uuid": "^9.0.8", + "autoprefixer": "^10.4.17", + "lerna": "^8.1.2", + "react": "^18.2.0", + "react-dom": "^18.2.0", + "tailwindcss": "^3.4.1", + "typescript": "^5.3.3", + "vite": "^5.1.4" + }, + "peerDependencies": { + "@mediapipe/tasks-text": "^0.10.9", + "@sroussey/transformers": "^2.15.1" + }, + "engines": { + "bun": "^1.0.5" } } diff --git a/packages/cli/package.json b/packages/cli/package.json new file mode 100644 index 000000000..b51455e03 --- /dev/null +++ b/packages/cli/package.json @@ -0,0 +1,18 @@ +{ + "name": "ellmers-cli", + "type": "module", + "version": "0.0.1", + "description": "Ellmers is a tool for building and running DAG pipelines of AI tasks.", + "scripts": { + "watch": "", + "build": "", + "test": "echo \"Error: no test specified\" && exit 1" + }, + "bin": "src/elmers.js", + "files": [ + "src" + ], + "dependencies": { + "ellmers-core": "workspace:*" + } +} diff --git a/src-examples/TaskCLI.ts b/packages/cli/src/TaskCLI.ts similarity index 50% rename from src-examples/TaskCLI.ts rename to packages/cli/src/TaskCLI.ts index 580868e3e..673754aa9 100644 --- a/src-examples/TaskCLI.ts +++ b/packages/cli/src/TaskCLI.ts @@ -6,63 +6,54 @@ // ******************************************************************************* import { Command } from "commander"; -import { runTaskToListr } from "./TaskStreamToListr2"; -import { ParallelTaskList } from "#/Task"; -import { - EmbeddingTask, - RewriterTask, - SummarizeTask, - DownloadTask, -} from "#/tasks/FactoryTasks"; -import { - EmbeddingStrategy, - RewriterEmbeddingStrategy, - RewriterStrategy, - SummarizeStrategy, -} from "#/tasks/Strategies"; -import { sleep } from "#/util/Misc"; -import { JsonStrategy, TaskJsonInput } from "#/tasks/JsonTask"; -import { ModelUseCaseEnum } from "#/Model"; +import { runTask } from "./TaskStreamToListr2"; +import "@sroussey/transformers"; import { findAllModels, findModelByName, findModelByUseCase, -} from "#/storage/InMemoryStorage"; - -async function runTask(task: any) { - if (process.stdout.isTTY) { - await runTaskToListr(task); - await sleep(100); - console.log(task.output); - } else { - await task.run({}); - process.stdout.write(JSON.stringify(task.output)); - } -} - -export function AddSampleCommand(program: Command) { + EmbeddingTask, + TextRewriterTask, + SummarizeTask, + DownloadTask, + ModelUseCaseEnum, + EmbeddingMultiModelTask, + registerHuggingfaceLocalTasks, + registerMediaPipeTfJsLocalTasks, + DownloadMultiModelTask, + TextRewriterMultiModelTask, + SummarizeMultiModelTask, + TaskGraph, + JsonTaskArray, + JsonTask, +} from "ellmers-core"; + +registerHuggingfaceLocalTasks(); +registerMediaPipeTfJsLocalTasks(); + +export function AddBaseCommands(program: Command) { program .command("download") .description("download models") .option("--model ", "model to download") .action(async (options) => { - let models = findAllModels(); + const models = findAllModels(); + const graph = new TaskGraph(); if (options.model) { const model = findModelByName(options.model); if (model) { - models = [model]; + graph.addTask(new DownloadTask({ input: { model: model.name } })); } else { program.error(`Unknown model ${options.model}`); } + } else { + graph.addTask( + new DownloadMultiModelTask({ + input: { model: models.map((m) => m.name) }, + }) + ); } - - const task = new ParallelTaskList( - { name: "Download Models" }, - models.map((model) => new DownloadTask({}, { model })) - ); - await runTaskToListr(task); - - await sleep(100); + await runTask(graph); }); program @@ -70,24 +61,25 @@ export function AddSampleCommand(program: Command) { .description("get a embedding vector for a piece of text") .argument("", "text to embed") .option("--model ", "model to use") - .action(async (text, options) => { - let task; + .action(async (text: string, options) => { + const graph = new TaskGraph(); if (options.model) { const model = findModelByName(options.model); if (model) { - task = new EmbeddingTask({}, { model, text }); + graph.addTask(new EmbeddingTask({ input: { model: model.name, text } })); } else { program.error(`Unknown model ${options.model}`); } } else { let models = findModelByUseCase(ModelUseCaseEnum.TEXT_EMBEDDING); - task = new EmbeddingStrategy( - { name: "Embed several" }, - { text, models } + graph.addTask( + new EmbeddingMultiModelTask({ + name: "Embed several", + input: { text, model: models.map((m) => m.name) }, + }) ); } - - await runTask(task); + await runTask(graph); }); program @@ -96,20 +88,23 @@ export function AddSampleCommand(program: Command) { .argument("", "text to embed") .option("--model ", "model to use") .action(async (text, options) => { - let task; + const graph = new TaskGraph(); if (options.model) { const model = findModelByName(options.model); if (model) { - task = new SummarizeTask({}, { model, text }); + graph.addTask(new SummarizeTask({ input: { model: model.name, text } })); } else { program.error(`Unknown model ${options.model}`); } } else { let models = findModelByUseCase(ModelUseCaseEnum.TEXT_SUMMARIZATION); - task = new SummarizeStrategy({}, { text, models }); + graph.addTask( + new SummarizeMultiModelTask({ + input: { text, model: models.map((m) => m.name) }, + }) + ); } - - await runTask(task); + await runTask(graph); }); program @@ -119,92 +114,68 @@ export function AddSampleCommand(program: Command) { .option("--instruction ", "instruction for how to rewrite", "") .option("--model ", "model to use") .action(async (text, options) => { - let task; + const graph = new TaskGraph(); if (options.model) { const model = findModelByName(options.model); if (model) { - task = new RewriterTask( - { name: "Rewrite" }, - { model, text, prompt: options.instruction } + graph.addTask( + new TextRewriterTask({ + input: { model: model.name, text, prompt: options.instruction }, + }) ); } else { program.error(`Unknown model ${options.model}`); } } else { let models = findModelByUseCase(ModelUseCaseEnum.TEXT_GENERATION); - task = new RewriterStrategy( - { name: "Rewrite" }, - { - text, - prompt: options.instruction, - model: models, - } + graph.addTask( + new TextRewriterMultiModelTask({ + input: { + text, + prompt: options.instruction, + model: models.map((m) => m.name), + }, + }) ); } - await runTask(task); - }); - - program - .command("rewrite-embedding") - .description("rewrite based on internal prompt list, then embed") - .argument("", "text to rewrite and vectorize") - .action(async (text) => { - const prompt = [ - "Rewrite the following text:", - "Rewrite the following and make it more descriptive:", - ]; - const prompt_model = findModelByUseCase(ModelUseCaseEnum.TEXT_GENERATION); - - const embed_model = findModelByUseCase(ModelUseCaseEnum.TEXT_EMBEDDING); - - const task = new RewriterEmbeddingStrategy( - {}, - { - text, - prompt, - prompt_model, - embed_model, - } - ); - - await runTask(task); + await runTask(graph); }); program .command("json") .description("run based on json input") .argument("[json]", "json text to rewrite and vectorize") - .action(async (jsonText) => { - if (!jsonText) { - const exampleJson: TaskJsonInput[] = [ + .action(async (json) => { + if (!json) { + const exampleJson: JsonTaskArray = [ { - run: "RewriterTask", - config: { - output_name: "results", - }, + id: "1", + type: "DownloadTask", input: { - text: "The quick brown fox jumps over the lazy dog.", - prompt: "Rewrite the following text in reverse:", model: "Xenova/LaMini-Flan-T5-783M", }, }, { - run: "RenameTask", + id: "2", + type: "TextRewriterTask", input: { - output_remap_array: [{ from: "results", to: "reverse" }], + text: "The quick brown fox jumps over the lazy dog.", + prompt: "Rewrite the following text in reverse:", + }, + dependencies: { + model: { + id: "1", + output: "model", + }, }, }, ]; - jsonText = JSON.stringify(exampleJson); + json = JSON.stringify(exampleJson); } - const json = JSON.parse(jsonText); - const task = new JsonStrategy({ name: "Test JSON" }, json); - - await runTask(task); + const task = new JsonTask({ name: "Test JSON", input: { json } }); + const graph = new TaskGraph(); + graph.addTask(task); + await runTask(graph); }); - - program.command("test").action(async () => { - // - }); } diff --git a/src-examples/TaskHelper.ts b/packages/cli/src/TaskHelper.ts similarity index 100% rename from src-examples/TaskHelper.ts rename to packages/cli/src/TaskHelper.ts diff --git a/packages/cli/src/TaskStreamToListr2.ts b/packages/cli/src/TaskStreamToListr2.ts new file mode 100644 index 000000000..84dcc6c77 --- /dev/null +++ b/packages/cli/src/TaskStreamToListr2.ts @@ -0,0 +1,181 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { Listr, ListrTask, PRESET_TIMER } from "listr2"; +import { Observable } from "rxjs"; +import { + type TaskStream, + TaskStatus, + sleep, + TaskGraph, + TaskGraphRunner, + type Task, + DataFlow, + TaskInputDefinition, +} from "ellmers-core"; +import { createBar } from "./TaskHelper"; + +type TaskTree = { + task: Task; + children?: TaskTree; +}[]; + +/** + * Convert the DAG to a tree for use in a UI like listr2 + * Obviously a DAG can't be turned into a tree, but we can + * skip some edges and make it look like a tree if we want + */ +function convertToTree(runner: TaskGraphRunner) { + const sortedNodes = runner.dag.topologicallySortedNodes(); + // const allEdges = this.dag.getEdges().map(([s, t, e]) => e); + runner.assignLayers(sortedNodes); + const taskToDependency = new Map(); + runner.layers.forEach((nodes, layerNumber) => { + // console.log(`Layer ${layerNumber}`); + nodes.forEach((node) => { + if (layerNumber >= 0) { + const incomingEdges = runner.dag + .inEdges(node.config.id) + .map(([sourceNodeId]: [sourceNodeId: string]) => sourceNodeId); + // console.log(` ${node.config.name} <- ${incomingEdges.join(", ")}`); + for (const sourceNodeId of incomingEdges) { + if (!taskToDependency.has(node)) { + const sourceNode = runner.dag.getNode(sourceNodeId); + if (runner.layers.get(layerNumber - 1)?.find((n) => n == sourceNode)) { + taskToDependency.set(node, sourceNode!); + } + } + } + } + }); + }); + // reverse the map + const dependencyToTask = new Map(); + taskToDependency.forEach((dependency, task) => { + if (!dependencyToTask.has(dependency)) { + dependencyToTask.set(dependency, []); + } + dependencyToTask.get(dependency)?.push(task); + }); + const startNodes = runner.layers.get(0); + + // convert to tree + const convertToTree = (nodes: TaskStream): TaskTree => { + const tree: TaskTree = []; + nodes.forEach((node) => { + const children = dependencyToTask.get(node); + if (children) { + tree.push({ task: node, children: convertToTree(children) }); + } else { + tree.push({ task: node }); + } + }); + return tree; + }; + return convertToTree(startNodes!); +} + +const taskTreeToListr = ( + tree: TaskTree = [], + options: Record = { concurrent: false, exitOnError: true } +) => { + const list: ListrTask[] = []; + + for (const { task, children } of tree) { + list.push({ + title: task.config.name, + task: async (_, t) => { + if (children) { + return t.newListr(taskTreeToListr(children, options), options); + } else if (task.status == TaskStatus.COMPLETED || task.status == TaskStatus.FAILED) { + return; + } + return new Observable((observer) => { + const start = Date.now(); + let lastUpdate = start; + task.on("progress", (progress: any, file: string) => { + const timeSinceLast = Date.now() - lastUpdate; + const timeSinceStart = Date.now() - start; + if (timeSinceLast > 250 || timeSinceStart > 100) { + observer.next(createBar(progress / 100 || 0, 30) + " " + (file || "")); + } + }); + task.on("complete", () => { + observer.complete(); + }); + task.on("error", (error) => { + observer.complete(); + }); + }); + }, + }); + } + return list; +}; + +const flattenCompoundGraph = (dag: TaskGraph) => { + const nodes: Task[] = []; + const edges: DataFlow[] = []; + edges.push(...dag.getDataFlows()); + dag.getNodes().forEach((node) => { + if (node.isCompound) { + const { nodes: subNodes, edges: subEdges } = flattenCompoundGraph(node.subGraph); + // const inputNode = new SingleTask({ name: node.config.name, id: node.config.id }); + nodes.push(node); + nodes.push(...subNodes); + edges.push(...subEdges); + const inputs = (node.constructor as any).inputs as TaskInputDefinition[]; + subNodes.forEach((subNode) => { + inputs.forEach((input) => { + edges.push(new DataFlow(node.config.id, input.id, subNode.config.id, input.id)); + }); + }); + // const outputs = (node.constructor as any).outputs as TaskOutputDefinition[]; + // const outputNode = new SingleTask(); + // subNodes.forEach((subNode) => { + // outputs.forEach((output) => { + // edges.push(new DataFlow(subNode.config.id, output.id, node.config.id, output.id)); + // }); + // }); + } else { + nodes.push(node); + } + }); + return { nodes, edges }; +}; + +const runTaskToListr = async (runner: TaskGraphRunner) => { + const { nodes, edges } = flattenCompoundGraph(runner.dag); + const displayGraph = new TaskGraph(); + displayGraph.addTasks(nodes); + displayGraph.addDataFlows(edges); + const flatRunner = new TaskGraphRunner(displayGraph); + const tree = convertToTree(flatRunner); + const options = { + exitOnError: true, + concurrent: true, + rendererOptions: { timer: PRESET_TIMER }, + }; + const listrTasks = taskTreeToListr(tree, options); + const listr = new Listr(listrTasks, options); + + listr.run({}); + await sleep(250); + const result = await runner.runGraph(); + await sleep(250); + console.log("Result", result); +}; + +export async function runTask(dag: TaskGraph) { + const runner = new TaskGraphRunner(dag); + if (process.stdout.isTTY) { + await runTaskToListr(runner); + } else { + const result = await runner.runGraph(); + console.log(JSON.stringify(result, null, 2)); + } +} diff --git a/elmers.ts b/packages/cli/src/ellmers.ts similarity index 52% rename from elmers.ts rename to packages/cli/src/ellmers.ts index 87b9802a7..97140ae77 100755 --- a/elmers.ts +++ b/packages/cli/src/ellmers.ts @@ -2,12 +2,10 @@ import { program } from "commander"; import { argv } from "process"; -import { AddSecCommands } from "./src-examples/ExampleSEC"; -import { AddSampleCommand } from "./src-examples/TaskCLI"; +import { AddBaseCommands } from "./TaskCLI"; program.version("1.0.0").description("A CLI to run Ellmers."); -AddSecCommands(program); -AddSampleCommand(program); +AddBaseCommands(program); await program.parseAsync(argv); diff --git a/packages/cli/src/lib.ts b/packages/cli/src/lib.ts new file mode 100644 index 000000000..5461620c6 --- /dev/null +++ b/packages/cli/src/lib.ts @@ -0,0 +1,9 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +export * from "./TaskHelper"; +export * from "./TaskStreamToListr2"; diff --git a/packages/cli/tsconfig.json b/packages/cli/tsconfig.json new file mode 100644 index 000000000..21c004a3f --- /dev/null +++ b/packages/cli/tsconfig.json @@ -0,0 +1,14 @@ +{ + "extends": "../../tsconfig.json", + "include": ["src/**/*"], + "files": ["src/lib.ts", "src/ellmers.ts"], + "exclude": ["**/*.test.ts"], + "compilerOptions": { + "outDir": "dist", + "baseUrl": "./src", + "rootDir": "./src", + "paths": { + "#/*": ["./src/*"] + } + } +} diff --git a/packages/core/package.json b/packages/core/package.json new file mode 100644 index 000000000..0212665fa --- /dev/null +++ b/packages/core/package.json @@ -0,0 +1,26 @@ +{ + "name": "ellmers-core", + "type": "module", + "version": "0.0.1", + "description": "Ellmers is a tool for building and running DAG pipelines of AI tasks.", + "scripts": { + "watch": "bunx concurrently 'bun run watch-types' 'bun run watch-js'", + "watch-js": "bun build --watch --target=browser --sourcemap=external --external @sroussey/transformers --outdir ./dist ./src/lib.ts", + "watch-types": "tsc --watch", + "build": "bun run build-clean && bun run build-types && bun run build-js && bun run build-types-map", + "build-clean": "rm -fr dist/* tsconfig.tsbuildinfo", + "build-js": "bun build --target=browser --minify-whitespace --minify-syntax --sourcemap=external --external @sroussey/transformers --outdir ./dist ./src/lib.ts", + "build-types": "tsc", + "build-types-map": "tsc --declarationMap", + "test": "echo \"Error: no test specified\" && exit 1" + }, + "module": "dist/lib.js", + "exports": { + ".": { + "import": "./dist/lib.js" + } + }, + "files": [ + "dist" + ] +} diff --git a/packages/core/src/lib.ts b/packages/core/src/lib.ts new file mode 100644 index 000000000..beac8f2e6 --- /dev/null +++ b/packages/core/src/lib.ts @@ -0,0 +1,24 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +export * from "./source/Document"; +export * from "./task/Task"; +export * from "./task/TaskRegistry"; +export * from "./task/BasicTasks"; +export * from "./task/ArrayTask"; +export * from "./task/TaskIOTypes"; +export * from "./task/ModelFactory"; +export * from "./task/ModelFactoryTasks"; +export * from "./task/TaskGraph"; +export * from "./task/TaskGraphRunner"; +export * from "./task/JsonTask"; +export * from "./task/exec/ml/HuggingFaceLocalTaskRun"; +export * from "./task/exec/ml/MediaPipeLocalTaskRun"; +export * from "./model/Model"; +export * from "./model/HuggingFaceModel"; +export * from "./storage/InMemoryStorage"; +export * from "./util/Misc"; diff --git a/src/Instruct.ts b/packages/core/src/model/HuggingFaceModel.ts similarity index 51% rename from src/Instruct.ts rename to packages/core/src/model/HuggingFaceModel.ts index c035c0681..fd1c325c7 100644 --- a/src/Instruct.ts +++ b/packages/core/src/model/HuggingFaceModel.ts @@ -5,26 +5,16 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import type { Model } from "./Model"; - -export class Instruct { - public queryInstruction: string = ""; - public storageInstruction: string = ""; - public model: Model | null = null; - public parameters: Record = {}; +import { Model, ModelProcessorEnum, ModelUseCaseEnum } from "./Model"; +export class ONNXTransformerJsModel extends Model { constructor( - public name: string, - public description: string, - options?: Partial< - Pick< - Instruct, - "queryInstruction" | "storageInstruction" | "model" | "parameters" - > - > + name: string, + useCase: ModelUseCaseEnum[], + public pipeline: string, + options?: Partial> ) { - Object.assign(this, options); + super(name, useCase, options); } + readonly type = ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS; } - -export type InstructList = Instruct[]; diff --git a/packages/core/src/model/MediaPipeModel.ts b/packages/core/src/model/MediaPipeModel.ts new file mode 100644 index 000000000..b8ca7b368 --- /dev/null +++ b/packages/core/src/model/MediaPipeModel.ts @@ -0,0 +1,15 @@ +import { Model, ModelProcessorEnum, ModelUseCaseEnum } from "./Model"; + +export class MediaPipeTfJsModel extends Model { + constructor( + name: string, + useCase: ModelUseCaseEnum[], + public url: string, + options?: Partial< + Pick + > + ) { + super(name, useCase, options); + } + readonly type = ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL; +} diff --git a/src/Model.ts b/packages/core/src/model/Model.ts similarity index 100% rename from src/Model.ts rename to packages/core/src/model/Model.ts diff --git a/src/query/InMemoryQuery.ts b/packages/core/src/query/InMemoryQuery.ts similarity index 97% rename from src/query/InMemoryQuery.ts rename to packages/core/src/query/InMemoryQuery.ts index a13b15cbd..07983d984 100644 --- a/src/query/InMemoryQuery.ts +++ b/packages/core/src/query/InMemoryQuery.ts @@ -5,7 +5,7 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { NodeEmbedding, QueryText, TextNode } from "#/Document"; +import { NodeEmbedding, QueryText, TextNode } from "../source/Document"; // import { inner, cosine } from "simsimd"; export function inner(arr1: number[], arr2: number[]) { diff --git a/src/Document.ts b/packages/core/src/source/Document.ts similarity index 100% rename from src/Document.ts rename to packages/core/src/source/Document.ts diff --git a/src/storage/InMemoryStorage.ts b/packages/core/src/storage/InMemoryStorage.ts similarity index 58% rename from src/storage/InMemoryStorage.ts rename to packages/core/src/storage/InMemoryStorage.ts index 148e64b95..884028af3 100644 --- a/src/storage/InMemoryStorage.ts +++ b/packages/core/src/storage/InMemoryStorage.ts @@ -5,11 +5,9 @@ // * Licensed under the Apache License, Version 2.0 (the "License"); * // ******************************************************************************* -import { Instruct, InstructList } from "#/Instruct"; -import { Model, ModelUseCaseEnum } from "#/Model"; -import { StrategyList } from "#/Strategy"; -import { ONNXTransformerJsModel } from "#/tasks/HuggingFaceLocalTasks"; -import { MediaPipeTfJsModel } from "#/tasks/MediaPipeLocalTasks"; +import { ONNXTransformerJsModel } from "../model/HuggingFaceModel"; +import { MediaPipeTfJsModel } from "../model/MediaPipeModel"; +import { Model, ModelUseCaseEnum } from "../model/Model"; export const universal_sentence_encoder = new MediaPipeTfJsModel( "Universal Sentence Encoder", @@ -72,13 +70,12 @@ export const xenovaDistilbertMnli = new ONNXTransformerJsModel( "zero-shot-classification" ); -export const stentancetransformerMultiQaMpnetBaseDotV1 = - new ONNXTransformerJsModel( - "Xenova/multi-qa-mpnet-base-dot-v1", - [ModelUseCaseEnum.TEXT_EMBEDDING], - "feature-extraction", - { dimensions: 768 } - ); +export const stentancetransformerMultiQaMpnetBaseDotV1 = new ONNXTransformerJsModel( + "Xenova/multi-qa-mpnet-base-dot-v1", + [ModelUseCaseEnum.TEXT_EMBEDDING], + "feature-extraction", + { dimensions: 768 } +); export const gpt2 = new ONNXTransformerJsModel( "Xenova/gpt2", @@ -110,70 +107,8 @@ export const distilbartCnn = new ONNXTransformerJsModel( "summarization" ); -export const instructPlain = new Instruct( - "Plain", - "The plain version does nothing extra to queries or storage" -); - -export const instructHighTemp = new Instruct( - "HighTemp", - "This is similar to plain but with a higher temperature and four versions averaged together", - { parameters: { temperature: 2.5, versions: 4 } } // no model, so inert for now -); - -export const instructQuestion = new Instruct( - "EverythingIsAQuestion", - "This converts storage into questions", - { storageInstruction: "Rephrase the following as a question: ", model: gpt2 } -); - -export const instructSummarize = new Instruct( - "Summarize", - "This converts storage into summaries", - { model: distilbartCnn } -); - -export const instructRepresent = new Instruct( - "Represent", - "This tries to coax the model into representing the query or passage", - { - queryInstruction: "Represent this query for searching relevant passages: ", - storageInstruction: "Represent this passage for later retrieval: ", - } // no model, so inert -); - -export const instructKeywords = new Instruct( - "Keywords", - "Try and pull keywords and concepts from both query and storage", - { - queryInstruction: - "What are the most important keywords and concepts that represent the following: ", - storageInstruction: - "What are the most important keywords and concepts that represent the following: ", - model: xenovaDistilbertMnli, // doesn't work - } -); - -export const instructList: InstructList = [ - instructPlain, - // instructHighTemp, - // instructQuestion, - // instructRepresent, - instructSummarize, -]; - -export const strategyAllPairs: StrategyList = []; - -for (const feModel of findModelByUseCase(ModelUseCaseEnum.TEXT_EMBEDDING)) { - for (const instruct of instructList) { - strategyAllPairs.push({ - embeddingModel: feModel, - instruct, - }); - } -} - export function findModelByName(name: string) { + if (typeof name != "string") return undefined; return Model.all.find((m) => m.name.toLowerCase() == name.toLowerCase()); } diff --git a/packages/core/src/task/ArrayTask.ts b/packages/core/src/task/ArrayTask.ts new file mode 100644 index 000000000..fb542dbcd --- /dev/null +++ b/packages/core/src/task/ArrayTask.ts @@ -0,0 +1,76 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { ModelFactory } from "./ModelFactory"; +import { CompoundTask, TaskInput, TaskConfig, TaskOutput, TaskTypeName } from "./Task"; +import { TaskInputDefinition, TaskOutputDefinition } from "./TaskIOTypes"; +import { TaskRegistry } from "./TaskRegistry"; + +export type ConvertToArrays = { + [P in keyof T]: P extends K ? Array : T[P]; +}; + +type Writeable = { -readonly [P in keyof T]: T[P] }; + +function convertToArray(io: D[], id: string) { + const results: D[] = []; + for (const item of io) { + const newItem: Writeable = { ...item }; + if (newItem.id === id) { + newItem.isArray = true; + } + results.push(newItem); + } + return results as D[]; +} + +export function arrayTaskFactory< + PluralInputType extends TaskInput, + PluralOutputType extends TaskOutput, +>(taskClass: typeof ModelFactory, inputMakeArray: string, outputMakeArray: string, name?: string) { + const inputs = convertToArray(Array.from(taskClass.inputs), inputMakeArray); + const outputs = convertToArray( + Array.from(taskClass.outputs), + outputMakeArray + ); + + const nameWithoutTask = taskClass.type.slice(0, -4); + const capitalized = inputMakeArray.charAt(0).toUpperCase() + inputMakeArray.slice(1); + name ??= nameWithoutTask + "Multi" + capitalized + "Task"; + + class ArrayTask extends CompoundTask { + static readonly displayName = name!; // this is for debuggers as they can't infer the name from code + static readonly type: TaskTypeName = name!; + static readonly category = (taskClass.constructor as any).category; + declare runInputData: PluralInputType; + declare runOutputData: PluralOutputType; + declare defaults: Partial; + + itemClass = taskClass; + + static inputs = inputs; + static outputs = outputs; + constructor(config: TaskConfig & { input?: PluralInputType } = {}) { + super(config); + this.generateGraph(); + } + generateGraph() { + if (Array.isArray(this.runInputData[inputMakeArray])) { + this.runInputData[inputMakeArray].forEach((prop: any) => { + const input = { ...this.runInputData, [inputMakeArray]: prop }; + const current = new taskClass({ input }); + this.subGraph.addTask(current); + }); + } + } + } + TaskRegistry.registerTask(ArrayTask); + + return ArrayTask; +} + +export type ArrayTaskType = ReturnType; diff --git a/packages/core/src/task/BasicTasks.ts b/packages/core/src/task/BasicTasks.ts new file mode 100644 index 000000000..ffc5fa276 --- /dev/null +++ b/packages/core/src/task/BasicTasks.ts @@ -0,0 +1,98 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { TaskConfig } from "./Task"; +import { SingleTask, TaskInput, TaskOutput } from "./Task"; +import { CreateMappedType } from "./TaskIOTypes"; +import { TaskRegistry } from "./TaskRegistry"; + +export interface RenameTaskInput { + output_remap_array: { + from: string; + to: string; + }[]; +} + +// =============================================================================== + +export type LambdaTaskInput = + | CreateMappedType + | { run: (input: TaskInput) => Promise }; +export type LambdaTaskOutput = CreateMappedType; + +export class LambdaTask extends SingleTask { + static readonly type = "LambdaTask"; + static readonly category = "Utility"; + declare runOutputData: TaskOutput; + public static inputs = [ + { + id: "run", + name: "Run Function", + valueType: "text", + }, + ] as const; + public static outputs = [ + { + id: "output", + name: "Output", + valueType: "any", + }, + ] as const; + constructor(config: TaskConfig & LambdaTaskInput) { + super(config); + } + runSyncOnly() { + if (!this.runInputData.run) { + throw new Error("No runner provided"); + } + if (typeof this.runInputData.run === "string") { + const fnText = this.runInputData.run; + const fn = new Function(fnText); + try { + fn(); + this.runInputData.run = fn; + } catch (e) {} + } + if (typeof this.runInputData.run === "function") { + this.runOutputData.output = this.runInputData.run(this.runInputData); + console.log("lambda output", this.runOutputData); + } else { + console.error("error", "Runner is not a function"); + } + return this.runOutputData; + } +} +TaskRegistry.registerTask(LambdaTask); + +export type DebugLogTaskInput = CreateMappedType; +export type DebugLogTaskOutput = CreateMappedType; + +export class DebugLogTask extends SingleTask { + static readonly type: string = "DebugLogTask"; + static readonly category = "Utility"; + declare runInputData: DebugLogTaskInput; + declare runOutputData: DebugLogTaskOutput; + public static inputs = [ + { + id: "message", + name: "Input", + valueType: "any", + }, + { + id: "level", + name: "Level", + valueType: "log_level", + defaultValue: "info", + }, + ] as const; + public static outputs = [] as const; + runSyncOnly() { + console[this.runInputData.level || "log"](this.runInputData.message); + return this.runOutputData; + } +} +TaskRegistry.registerTask(DebugLogTask); diff --git a/packages/core/src/task/JsonTask.ts b/packages/core/src/task/JsonTask.ts new file mode 100644 index 000000000..f3a9f873e --- /dev/null +++ b/packages/core/src/task/JsonTask.ts @@ -0,0 +1,95 @@ +// // ******************************************************************************* +// // * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// // * * +// // * Copyright Steven Roussey * +// // * Licensed under the Apache License, Version 2.0 (the "License"); * +// // ******************************************************************************* + +import { CompoundTask, Task, TaskConfig, TaskInput } from "./Task"; +import { DataFlow, TaskGraph } from "./TaskGraph"; +import { CreateMappedType } from "./TaskIOTypes"; +import { TaskRegistry } from "./TaskRegistry"; + +export type JsonTaskArray = Array; +export type JsonTaskItem = { + id: string; + type: string; + name?: string; + input?: TaskInput; + dependencies?: { + [x: string]: JsonTaskDependecy; + }; +}; +type JsonTaskDependecy = { + id: string; + output: string; +}; + +type JsonTaskInput = CreateMappedType; +type JsonTaskOutput = CreateMappedType; + +export class JsonTask extends CompoundTask { + public static inputs = [ + { + id: "json", + name: "JSON", + valueType: "text", + }, + ] as const; + public static outputs = [] as const; + + declare runInputData: JsonTaskInput; + declare runOutputData: JsonTaskOutput; + declare defaults: Partial; + constructor(config: TaskConfig & { input?: JsonTaskInput }) { + super(config); + if (config?.input?.json) { + this.generateGraph(); + } + } + public setInputData(...overrides: (Partial | undefined)[]): void { + let changed = false; + for (const override of overrides) { + if (override) { + if (override.json !== undefined && override.json != this.runInputData.json) changed = true; + } + } + super.setInputData(...overrides); + if (changed) this.generateGraph(); + } + public generateGraph() { + if (!this.runInputData.json) return; + let data = JSON.parse(this.runInputData.json); + if (!Array.isArray(data)) data = [data]; + const jsonItems: JsonTaskArray = data as JsonTaskArray; + // create the task nodes + this.subGraph = new TaskGraph(); + for (const item of jsonItems) { + if (!item.id) throw new Error("Task id required"); + if (!item.type) throw new Error("Task type required"); + if (item.input && Array.isArray(item.input)) throw new Error("Task input must be an object"); + + const taskClass = TaskRegistry.all.get(item.type); + if (!taskClass) throw new Error(`Task type ${item.type} not found`); + + const taskConfig = { id: item.id, name: item.name, input: item.input ?? {} }; + const task = new taskClass(taskConfig); + this.subGraph.addTask(task); + } + // create the data flow edges + for (const item of jsonItems) { + if (!item.dependencies) continue; + for (const [input, dependency] of Object.entries(item.dependencies)) { + const sourceTask = this.subGraph.getTask(dependency.id); + if (!sourceTask) { + throw new Error(`Dependency id ${dependency.id} not found`); + } + const df = new DataFlow(sourceTask.config.id, dependency.output, item.id, input); + this.subGraph.addDataFlow(df); + } + } + } + + static readonly type = "JsonTask"; + static readonly category = "Utility"; +} diff --git a/packages/core/src/task/ModelFactory.ts b/packages/core/src/task/ModelFactory.ts new file mode 100644 index 000000000..345a04354 --- /dev/null +++ b/packages/core/src/task/ModelFactory.ts @@ -0,0 +1,78 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +/** + * @file ModelFactory.ts + * @description This file contains the implementation of the ModelFactory class and its derived classes. + * The ModelFactory class is responsible for creating and running tasks based on different models. + * It provides a common interface for running tasks and handles the execution logic. + * Each derived class defines its own input and output types and implements the run() method to perform the task-specific logic. + */ + +import { findModelByName } from "../storage/InMemoryStorage"; +import { ModelProcessorEnum } from "../model/Model"; +import { SingleTask, TaskInput, TaskConfig, TaskOutput } from "./Task"; +import { TaskInputDefinition, TaskOutputDefinition } from "./TaskIOTypes"; +import { TaskRegistry } from "./TaskRegistry"; + +export class ModelFactory extends SingleTask { + public static inputs: readonly TaskInputDefinition[]; + public static outputs: readonly TaskOutputDefinition[]; + static readonly type: string = "ModelFactory"; + declare runOutputData: TaskOutput; + + static runFnRegistry: Record< + string, + Record Promise> + > = {}; + static registerRunFn( + baseClass: typeof ModelFactory, + modelType: ModelProcessorEnum, + runFn: (task: any, runInputData: any) => Promise + ) { + if (!ModelFactory.runFnRegistry[baseClass.type]) + ModelFactory.runFnRegistry[baseClass.type] = {}; + ModelFactory.runFnRegistry[baseClass.type][modelType] = runFn; + } + static getRunFn(taskClassName: string, modelType: ModelProcessorEnum) { + return ModelFactory.runFnRegistry[taskClassName]?.[modelType]; + } + + constructor(config: TaskConfig = {}) { + config.name ||= `${new.target.name}${config.input?.model ? " with model " + config.input?.model : ""}`; + super(config); + } + + async run(): Promise { + this.emit("start"); + this.runOutputData = {}; + let results; + debugger; + try { + const taskClass = TaskRegistry.all.get(this.constructor.name); + if (!taskClass) throw new Error("ModelFactoryTask: No task class found"); + const modelname = this.runInputData["model"]; + if (!modelname) throw new Error("ModelFactoryTask: No model name found"); + const model = findModelByName(modelname); + if (!model) throw new Error("ModelFactoryTask: No model found"); + const runFn = ModelFactory.getRunFn(this.constructor.name, model.type); + if (!runFn) throw new Error("ModelFactoryTask: No run function found"); + results = await runFn(this, this.runInputData); + } catch (err) { + this.emit("error", err); + console.error(err); + return {}; + } + this.emit("complete"); + this.runOutputData = results ?? {}; + this.runOutputData = this.runSyncOnly(); + return this.runOutputData; + } + runSyncOnly(): TaskOutput { + return this.runOutputData; + } +} diff --git a/packages/core/src/task/ModelFactoryTasks.ts b/packages/core/src/task/ModelFactoryTasks.ts new file mode 100644 index 000000000..095cb8d20 --- /dev/null +++ b/packages/core/src/task/ModelFactoryTasks.ts @@ -0,0 +1,240 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { ConvertToArrays, arrayTaskFactory } from "./ArrayTask"; +import { CreateMappedType } from "./TaskIOTypes"; +import { TaskRegistry } from "./TaskRegistry"; +import { ModelFactory } from "./ModelFactory"; +import { TaskConfig } from "./Task"; + +// =============================================================================== + +export type DownloadTaskInput = CreateMappedType; +export type DownloadTaskOutput = CreateMappedType; + +export class DownloadTask extends ModelFactory { + public static inputs = [ + { + id: "model", + name: "Model", + valueType: "model", + }, + ] as const; + public static outputs = [ + { + id: "model", + name: "Model", + valueType: "model", + }, + ] as const; + + declare runInputData: DownloadTaskInput; + declare runOutputData: DownloadTaskOutput; + declare defaults: Partial; + constructor(config: TaskConfig & { input?: DownloadTaskInput }) { + super(config); + } + static readonly type = "DownloadTask"; + static readonly category = "Text Model"; +} +TaskRegistry.registerTask(DownloadTask); + +export const DownloadMultiModelTask = arrayTaskFactory< + ConvertToArrays, + ConvertToArrays +>(DownloadTask, "model", "text"); + +// =============================================================================== + +export type EmbeddingTaskInput = CreateMappedType; +export type EmbeddingTaskOutput = CreateMappedType; + +/** + * This is a task that generates an embedding for a single piece of text + */ +export class EmbeddingTask extends ModelFactory { + public static inputs = [ + { + id: "text", + name: "Text", + valueType: "text", + }, + { + id: "model", + name: "Model", + valueType: "text_embedding_model", + }, + ] as const; + public static outputs = [{ id: "vector", name: "Embedding", valueType: "vector" }] as const; + + declare runInputData: EmbeddingTaskInput; + declare runOutputData: EmbeddingTaskOutput; + declare defaults: Partial; + static readonly type = "EmbeddingTask"; + static readonly category = "Text Model"; +} +TaskRegistry.registerTask(EmbeddingTask); + +export const EmbeddingMultiModelTask = arrayTaskFactory< + ConvertToArrays, + ConvertToArrays +>(EmbeddingTask, "model", "text"); + +// =============================================================================== + +export type TextGenerationTaskInput = CreateMappedType; +export type TextGenerationTaskOutput = CreateMappedType; + +/** + * This generates text from a prompt + */ +export class TextGenerationTask extends ModelFactory { + public static inputs = [ + { + id: "prompt", + name: "Prompt", + valueType: "text", + }, + { + id: "model", + name: "Model", + valueType: "text_generation_model", + }, + ] as const; + public static outputs = [{ id: "text", name: "Text", valueType: "text" }] as const; + + declare runInputData: TextGenerationTaskInput; + declare runOutputData: TextGenerationTaskOutput; + declare defaults: Partial; + static readonly type = "TextGenerationTask"; + static readonly category = "Text Model"; +} +TaskRegistry.registerTask(TextGenerationTask); + +export const TextGenerationMultiModelTask = arrayTaskFactory< + ConvertToArrays, + ConvertToArrays +>(TextGenerationTask, "model", "text"); + +// =============================================================================== + +export type SummarizeTaskInput = CreateMappedType; +export type SummarizeTaskOutput = CreateMappedType; + +/** + * This summarizes a piece of text + */ + +export class SummarizeTask extends ModelFactory { + public static inputs = [ + { + id: "text", + name: "Text", + valueType: "text", + }, + { + id: "model", + name: "Model", + valueType: "text_summarization_model", + }, + ] as const; + public static outputs = [{ id: "text", name: "Text", valueType: "text" }] as const; + + declare runInputData: SummarizeTaskInput; + declare runOutputData: SummarizeTaskOutput; + declare defaults: Partial; + static readonly type = "SummarizeTask"; + static readonly category = "Text Model"; +} +TaskRegistry.registerTask(SummarizeTask); + +export const SummarizeMultiModelTask = arrayTaskFactory< + ConvertToArrays, + ConvertToArrays +>(SummarizeTask, "model", "text"); + +// =============================================================================== + +export type TextRewriterTaskInput = CreateMappedType; +export type TextRewriterTaskOutput = CreateMappedType; + +/** + * This is a special case of text generation that takes a prompt and text to rewrite + */ + +export class TextRewriterTask extends ModelFactory { + public static inputs = [ + { + id: "text", + name: "Text", + valueType: "text", + }, + { + id: "prompt", + name: "Prompt", + valueType: "text", + }, + { + id: "model", + name: "Model", + valueType: "text_generation_model", + }, + ] as const; + public static outputs = [{ id: "text", name: "Text", valueType: "text" }] as const; + + declare runInputData: TextRewriterTaskInput; + declare runOutputData: TextRewriterTaskOutput; + declare defaults: Partial; + static readonly type = "TextRewriterTask"; + static readonly category = "Text Model"; +} +TaskRegistry.registerTask(TextRewriterTask); + +export const TextRewriterMultiModelTask = arrayTaskFactory< + ConvertToArrays, + ConvertToArrays +>(TextRewriterTask, "model", "text"); + +// =============================================================================== +export type QuestionAnswerTaskInput = CreateMappedType; +export type QuestionAnswerTaskOutput = CreateMappedType; + +/** + * This is a special case of text generation that takes a context and a question + */ +export class QuestionAnswerTask extends ModelFactory { + public static inputs = [ + { + id: "context", + name: "Context", + valueType: "text", + }, + { + id: "question", + name: "Question", + valueType: "text", + }, + { + id: "model", + name: "Model", + valueType: "text_question_answering_model", + }, + ] as const; + public static outputs = [{ id: "answer", name: "Answer", valueType: "text" }] as const; + + declare runInputData: QuestionAnswerTaskInput; + declare runOutputData: QuestionAnswerTaskOutput; + declare defaults: Partial; + static readonly type = "QuestionAnswerTask"; + static readonly category = "Text Model"; +} +TaskRegistry.registerTask(QuestionAnswerTask); + +export const QuestionAnswerMultiModelTask = arrayTaskFactory< + ConvertToArrays, + ConvertToArrays +>(TextRewriterTask, "model", "answer"); diff --git a/packages/core/src/task/Task.ts b/packages/core/src/task/Task.ts new file mode 100644 index 000000000..6b2ec4cbc --- /dev/null +++ b/packages/core/src/task/Task.ts @@ -0,0 +1,196 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { EventEmitter } from "eventemitter3"; +import { TaskGraph } from "./TaskGraph"; +import { TaskGraphRunner } from "./TaskGraphRunner"; +import type { TaskInputDefinition, TaskOutputDefinition } from "./TaskIOTypes"; + +export enum TaskStatus { + PENDING = "NEW", + PROCESSING = "PROCESSING", + COMPLETED = "COMPLETED", + FAILED = "FAILED", +} + +/** + * TaskEvents + * + * There is no job queue at the moement. + */ +export type TaskEvents = "start" | "complete" | "error" | "progress"; + +export interface TaskInput { + [key: string]: any; +} +export interface TaskOutput { + [key: string]: any; +} + +export interface ITaskSimple { + readonly isCompound: false; +} +export interface ITaskCompound { + readonly isCompound: true; + subGraph: TaskGraph; +} + +export type ITask = ITaskSimple | ITaskCompound; + +export type TaskTypeName = string; + +export type TaskConfig = Partial & { input?: TaskInput }; + +// =============================================================================== + +export interface IConfig { + id: string; + name?: string; +} + +abstract class TaskBase { + // information about the task that should be overriden by the subclasses + static readonly type: TaskTypeName = "TaskBase"; + static readonly category: string = "Hidden"; + + events = new EventEmitter(); + on(name: TaskEvents, fn: (...args: any[]) => void) { + this.events.on.call(this.events, name, fn); + } + off(name: TaskEvents, fn: (...args: any[]) => void) { + this.events.off.call(this.events, name, fn); + } + emit(name: TaskEvents, ...args: any[]) { + this.events.emit.call(this.events, name, ...args); + } + /** + * Does this task have subtasks? + */ + abstract readonly isCompound: boolean; + /** + * Configuration for the task, might include things like name and id for the database + */ + config: IConfig; + status: TaskStatus = TaskStatus.PENDING; + progress: number = 0; + createdAt: Date = new Date(); + startedAt?: Date; + completedAt?: Date; + error: string | undefined = undefined; + + constructor(config: TaskConfig = {}) { + // pull out input data from the config + const { input = {}, ...rest } = config; + this.defaults = input; + this.setInputData(); + + // setup the configuration + const name = (this.constructor as any).type ?? this.constructor.name; + this.config = Object.assign( + { + id: name + ":" + Math.random().toString(36).substring(2, 9), + name: name, + }, + rest + ); + // setup the events + this.setupEvents(); + } + + public setupEvents() { + Object.defineProperty(this, "events", { enumerable: false }); // in case it is serialized + this.on("start", () => { + this.startedAt = new Date(); + this.progress = 0; + this.status = TaskStatus.PROCESSING; + }); + this.on("complete", () => { + this.completedAt = new Date(); + this.progress = 100; + this.status = TaskStatus.COMPLETED; + }); + this.on("error", (error) => { + this.completedAt = new Date(); + this.progress = 100; + this.status = TaskStatus.FAILED; + this.error = error; + }); + } + /** + * The defaults for the task. If no overrides at run time, then this would be equal to the + * input + */ + defaults: TaskInput; + /** + * The input to the task at the time of the task run. This takes defaults from construction + * time and overrides from run time. It is the input that created the output. + */ + runInputData: TaskInput = {}; + /** + * The output of the task at the time of the task run. This is the result of the task. + * The the defaults and overrides are combined to match the required input of the task. + */ + runOutputData: TaskOutput = {}; + public static inputs: readonly TaskInputDefinition[]; + public static outputs: readonly TaskOutputDefinition[]; + + /** + * + * This calculates the input to the task at the time of the task run. This takes defaults from + * construction and applies run time overrides (which may be output from a previous run if this + * is a serial task or strategy). Caller needs to decide if should set to this classes input + * or not. + */ + setInputData(...overrides: (Partial | undefined)[]) { + this.runInputData = Object.assign({}, this.defaults, ...overrides) as T; + } + runWithInput(input: T) { + this.setInputData(input); + return this.run(); + } + async run(): Promise { + return this.runSyncOnly(); + } + runSyncOnly(): TaskOutput { + return this.runOutputData; + } +} + +export type TaskIdType = TaskBase["config"]["id"]; + +export class SingleTask extends TaskBase implements ITaskSimple { + static readonly type: TaskTypeName = "SingleTask"; + readonly isCompound = false; +} + +export class CompoundTask extends TaskBase implements ITaskCompound { + static readonly type: TaskTypeName = "CompoundTask"; + readonly isCompound = true; + _subGraph: TaskGraph | null = null; + set subGraph(subGraph: TaskGraph) { + this._subGraph = subGraph; + } + get subGraph() { + if (!this._subGraph) { + this._subGraph = new TaskGraph(); + } + return this._subGraph; + } + async run(): Promise { + this.emit("start"); + const runner = new TaskGraphRunner(this.subGraph); + this.runOutputData = await runner.runGraph(); + this.runOutputData = this.runSyncOnly(); + this.emit("complete"); + return this.runOutputData; + } +} + +// =============================================================================== + +export type Task = SingleTask | CompoundTask; +export type TaskStream = Task[]; diff --git a/packages/core/src/task/TaskGraph.ts b/packages/core/src/task/TaskGraph.ts new file mode 100644 index 000000000..e0eb72700 --- /dev/null +++ b/packages/core/src/task/TaskGraph.ts @@ -0,0 +1,117 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { DirectedAcyclicGraph } from "@sroussey/typescript-graph"; +import { TaskIdType } from "./Task"; +import { Task, TaskStream } from "./Task"; + +export type IDataFlow = { + sourceTaskId: TaskIdType; + sourceTaskOutputId: string; + targetTaskId: TaskIdType; + targetTaskInputId: string; + id: string; +}; + +export type DataFlowIdType = IDataFlow["id"]; + +export class DataFlow implements IDataFlow { + constructor( + public sourceTaskId: TaskIdType, + public sourceTaskOutputId: string, + public targetTaskId: TaskIdType, + public targetTaskInputId: string, + public id: string = `${sourceTaskId}.${sourceTaskOutputId} -> ${targetTaskId}.${targetTaskInputId}` + ) {} +} + +type NodeNodeEdgeTuples = Array< + [sourceTask: TaskIdType, targetTask: TaskIdType, edge?: IDataFlow | undefined] +>; + +export class TaskGraph extends DirectedAcyclicGraph { + constructor() { + super( + (task: Task) => task.config.id, + (dataFlow: IDataFlow) => dataFlow.id + ); + } + public getTask(id: TaskIdType): Task | undefined { + return super.getNode(id); + } + public addTask(task: Task) { + return super.addNode(task); + } + public addTasks(tasks: Task[]) { + return super.addNodes(tasks); + } + public addDataFlow(dataflow: DataFlow) { + return super.addEdge(dataflow.sourceTaskId, dataflow.targetTaskId, dataflow); + } + public addDataFlows(dataflows: DataFlow[]) { + const addedEdges = dataflows.map<[s: string, t: string, e: IDataFlow]>((edge) => { + return [edge.sourceTaskId, edge.targetTaskId, edge]; + }); + return super.addEdges(addedEdges); + } + public getDataFlow(id: DataFlowIdType): IDataFlow | undefined { + for (const i in this.adjacency) { + for (const j in this.adjacency[i]) { + const maybeEdges = this.adjacency[i][j]; + if (maybeEdges !== null) { + for (const edge of maybeEdges) { + if (this.edgeIdentity(edge, "", "") == id) { + return edge; + } + } + } + } + } + } + public getDataFlows(): IDataFlow[] { + return this.getEdges().map((edge) => edge[2]); + } +} + +/** + * Super simple helper if you know the input and output handles, and there is only one each + * + * @param tasks TaskStream + * @param inputHandle TaskIdType + * @param outputHandle TaskIdType + * @returns + */ +function serialGraphEdges( + tasks: TaskStream, + inputHandle: string, + outputHandle: string +): IDataFlow[] { + const edges: IDataFlow[] = []; + for (let i = 0; i < tasks.length - 1; i++) { + edges.push(new DataFlow(tasks[i].config.id, inputHandle, tasks[i + 1].config.id, outputHandle)); + } + return edges; +} + +/** + * Super simple helper if you know the input and output handles, and there is only one each + * + * @param tasks TaskStream + * @param inputHandle TaskIdType + * @param outputHandle TaskIdType + * @returns + */ +export function serialGraph( + tasks: TaskStream, + inputHandle: string, + outputHandle: string +): TaskGraph { + const graph = new TaskGraph(); + graph.addTasks(tasks); + graph.addDataFlows(serialGraphEdges(tasks, inputHandle, outputHandle)); + return graph; +} diff --git a/packages/core/src/task/TaskGraphRunner.ts b/packages/core/src/task/TaskGraphRunner.ts new file mode 100644 index 000000000..4457d0e63 --- /dev/null +++ b/packages/core/src/task/TaskGraphRunner.ts @@ -0,0 +1,97 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { TaskInput, Task, TaskOutput } from "task/Task"; +import { TaskGraph } from "task/TaskGraph"; + +export class TaskGraphRunner { + public layers: Map; + + constructor(public dag: TaskGraph) { + this.dag = dag; + this.layers = new Map(); + } + + public assignLayers(sortedNodes: Task[]) { + this.layers = new Map(); + const nodeToLayer = new Map(); + + sortedNodes.forEach((node, _index) => { + let maxLayer = -1; + + // Get all incoming edges (dependencies) of the node + const incomingEdges = this.dag.inEdges(node.config.id).map(([from]) => from); + + incomingEdges.forEach((from) => { + // Find the layer of the dependency + const layer = nodeToLayer.get(from); + if (layer !== undefined) { + maxLayer = Math.max(maxLayer, layer); + } + }); + + // Assign the node to the next layer after the maximum layer of its dependencies + const assignedLayer = maxLayer + 1; + nodeToLayer.set(node.config.id, assignedLayer); + + if (!this.layers.has(assignedLayer)) { + this.layers.set(assignedLayer, []); + } + + this.layers.get(assignedLayer)?.push(node); + }); + } + + public async runTasksAsync() { + let results: TaskOutput[] = []; + for (const [_layerNumber, nodes] of this.layers.entries()) { + const layerPromises = nodes.map(async (node) => { + const results = await node.run(); + this.dag.outEdges(node.config.id).forEach(([, , dataFlow]) => { + const toInput: TaskInput = {}; + const targetNode = this.dag.getNode(dataFlow.targetTaskId); + if (results[dataFlow.sourceTaskOutputId] !== undefined) + toInput[dataFlow.targetTaskInputId] = results[dataFlow.sourceTaskOutputId]; + targetNode!.setInputData(targetNode!.runInputData, toInput); + }); + return results; + }); + results = await Promise.all(layerPromises); + } + return results; + } + + public runTasksSync() { + let results: TaskOutput[] = []; + for (const [_layerNumber, nodes] of this.layers.entries()) { + results = nodes.map((node) => { + const results = node.runSyncOnly(); + this.dag.outEdges(node.config.id).forEach(([, , dataFlow]) => { + const toInput: TaskInput = {}; + const targetNode = this.dag.getNode(dataFlow.targetTaskId); + if (results[dataFlow.sourceTaskOutputId] !== undefined) + toInput[dataFlow.targetTaskInputId] = results[dataFlow.sourceTaskOutputId]; + targetNode!.setInputData(targetNode!.runInputData, toInput); + }); + return results; + }); + } + return results; + } + + public async runGraph() { + const sortedNodes = this.dag.topologicallySortedNodes(); + this.assignLayers(sortedNodes); + return await this.runTasksAsync(); + } + + public runGraphSyncOnly() { + const sortedNodes = this.dag.topologicallySortedNodes(); + this.assignLayers(sortedNodes); + return this.runTasksSync(); + } +} diff --git a/packages/core/src/task/TaskIOTypes.ts b/packages/core/src/task/TaskIOTypes.ts new file mode 100644 index 000000000..e04654a4a --- /dev/null +++ b/packages/core/src/task/TaskIOTypes.ts @@ -0,0 +1,101 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +export type Vector = number[] | Float32Array; + +export const valueTypes = { + any: { + name: "Any", + tsType: "any", + }, + text: { + name: "Text", + tsType: "string", + defaultValue: "", + }, + number: { + name: "Number", + tsType: "number", + defaultValue: 0, + }, + vector: { + name: "Vector", + tsType: "Vector", + defaultValue: [], + }, + model: { + name: "Model", + tsType: "string", + }, + text_embedding_model: { + name: "Embedding Model", + tsType: "string", + }, + text_generation_model: { + name: "Generation Model", + tsType: "string", + }, + text_summarization_model: { + name: "Summarization Model", + tsType: "string", + }, + text_question_answering_model: { + name: "Q&A Model", + tsType: "string", + }, + log_level: { + name: "Log Level", + tsType: "log_level", + }, +} as const; + +// Provided lookup type +type TsTypes = { + any: any; + string: string; + number: number; + Vector: Vector; + log_level: "debug" | "info" | "warn" | "error"; +}; + +// Extract TypeScript type for a given value type +export type ExtractTsType = + TsTypes[(typeof valueTypes)[VT]["tsType"]]; + +type InputType = { + id: string | number; + valueType: keyof typeof valueTypes; + isArray?: boolean; +}; + +type MappedType = T["isArray"] extends true + ? { [K in T["id"]]: Array> } + : { [K in T["id"]]: ExtractTsType }; + +export type CreateMappedType> = { + [P in T[number] as P["id"]]: MappedType

[P["id"]]; +}; + +export type TaskInputDefinition = { + readonly id: string; + readonly name: string; + readonly valueType: keyof typeof valueTypes; + readonly isArray?: boolean; + readonly defaultValue?: ExtractTsType; +}; + +export type TaskOutputDefinition = { + readonly id: string; + readonly name: string; + readonly valueType: keyof typeof valueTypes; + readonly isArray?: boolean; +}; + +export interface TaskNodeIO { + readonly inputs: TaskInputDefinition[]; + readonly outputs: TaskOutputDefinition[]; +} diff --git a/packages/core/src/task/TaskRegistry.ts b/packages/core/src/task/TaskRegistry.ts new file mode 100644 index 000000000..4c8a5a334 --- /dev/null +++ b/packages/core/src/task/TaskRegistry.ts @@ -0,0 +1,19 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { CompoundTask, SingleTask } from "./Task"; + +const all = new Map(); + +const registerTask = (baseClass: typeof SingleTask | typeof CompoundTask) => { + all.set(baseClass.type, baseClass); +}; + +export const TaskRegistry = { + registerTask, + all, +}; diff --git a/packages/core/src/task/exec/ml/HuggingFaceLocalTaskRun.ts b/packages/core/src/task/exec/ml/HuggingFaceLocalTaskRun.ts new file mode 100644 index 000000000..5a4b9467d --- /dev/null +++ b/packages/core/src/task/exec/ml/HuggingFaceLocalTaskRun.ts @@ -0,0 +1,256 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { + pipeline, + type PipelineType, + type FeatureExtractionPipeline, + type TextGenerationPipeline, + type TextGenerationSingle, + type SummarizationPipeline, + type SummarizationSingle, + type QuestionAnsweringPipeline, + type DocumentQuestionAnsweringSingle, + env, +} from "@sroussey/transformers"; +import { ModelFactory } from "../../ModelFactory"; +import { + DownloadTask, + DownloadTaskInput, + EmbeddingTask, + EmbeddingTaskInput, + QuestionAnswerTask, + QuestionAnswerTaskInput, + TextRewriterTaskInput, + SummarizeTask, + SummarizeTaskInput, + TextGenerationTask, + TextGenerationTaskInput, + TextRewriterTask, + DownloadTaskOutput, + EmbeddingTaskOutput, + TextGenerationTaskOutput, + TextRewriterTaskOutput, + SummarizeTaskOutput, + QuestionAnswerTaskOutput, +} from "../../ModelFactoryTasks"; +import { findModelByName } from "../../../storage/InMemoryStorage"; +import { ONNXTransformerJsModel } from "../../../model/HuggingFaceModel"; +import { ModelProcessorEnum } from "../../../model/Model"; +import { Vector } from "../../TaskIOTypes"; +import { SingleTask } from "../../Task"; + +env.backends.onnx.logLevel = "error"; +env.backends.onnx.debug = false; + +/** + * + * This is a helper function to get a pipeline for a model and assign a + * progress callback to the task. + * + * @param task + * @param model + * @param options + */ +const getPipeline = async ( + task: SingleTask, + model: ONNXTransformerJsModel, + { quantized, config }: { quantized: boolean; config: any } = { + quantized: true, + config: null, + } +) => { + return await pipeline(model.pipeline as PipelineType, model.name, { + quantized, + config, + progress_callback: (details: { + file: string; + status: string; + name: string; + progress: number; + loaded: number; + total: number; + }) => { + const { progress, file } = details; + task.progress = progress; + task.emit("progress", progress, file); + }, + }); +}; + +// =============================================================================== + +/** + * This is a task that downloads and caches an onnx model. + */ + +export async function HuggingFaceLocal_DownloadTask( + task: DownloadTask, + runInputData: DownloadTaskInput +): Promise { + const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + await getPipeline(task, model!); + return { model: model.name }; +} + +/** + * This is a task that generates an embedding for a single piece of text + * + * Model pipeline must be "feature-extraction" + */ +export async function HuggingFaceLocal_EmbeddingTask( + task: EmbeddingTask, + runInputData: EmbeddingTaskInput +): Promise { + const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + const generateEmbedding = (await getPipeline(task, model)) as FeatureExtractionPipeline; + + var vector = await generateEmbedding(runInputData.text, { + pooling: "mean", + normalize: model.normalize, + }); + + if (vector.size !== model.dimensions) { + throw `Embedding vector length does not match model dimensions v${vector.size} != m${model.dimensions}`; + } + return { vector: vector.data as Vector }; +} + +/** + * This generates text from a prompt + * + * Model pipeline must be "text-generation" or "text2text-generation" + */ +export async function HuggingFaceLocal_TextGenerationTask( + task: TextGenerationTask, + runInputData: TextGenerationTaskInput +): Promise { + const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + + const generateText = (await getPipeline(task, model)) as TextGenerationPipeline; + + let results = await generateText(runInputData.prompt); + if (!Array.isArray(results)) { + results = [results]; + } + return { + text: (results[0] as TextGenerationSingle)?.generated_text, + }; +} + +/** + * This is a special case of text generation that takes a prompt and text to rewrite + * + * Model pipeline must be "text-generation" or "text2text-generation" + */ +export async function HuggingFaceLocal_TextRewriterTask( + task: TextRewriterTask, + runInputData: TextRewriterTaskInput +): Promise { + const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + + const generateText = (await getPipeline(task, model)) as TextGenerationPipeline; + + // This lib doesn't support this kind of rewriting with a separate prompt vs text + const promptedtext = (runInputData.prompt ? runInputData.prompt + "\n" : "") + runInputData.text; + let results = await generateText(promptedtext); + if (!Array.isArray(results)) { + results = [results]; + } + + const text = (results[0] as TextGenerationSingle)?.generated_text; + if (text == promptedtext) { + throw "Rewriter failed to generate new text"; + } + + return { text }; +} + +/** + * This summarizes a piece of text + * + * Model pipeline must be "summarization" + */ + +export async function HuggingFaceLocal_SummarizeTask( + task: SummarizeTask, + runInputData: SummarizeTaskInput +): Promise { + const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + + const generateSummary = (await getPipeline(task, model)) as SummarizationPipeline; + + let results = await generateSummary(runInputData.text); + if (!Array.isArray(results)) { + results = [results]; + } + + return { + text: (results[0] as SummarizationSingle)?.summary_text, + }; +} + +/** + * This is a special case of text generation that takes a context and a question + * + * Model pipeline must be "question-answering" + */ +export async function HuggingFaceLocal_QuestionAnswerTask( + task: QuestionAnswerTask, + runInputData: QuestionAnswerTaskInput +): Promise { + const model = findModelByName(runInputData.model) as ONNXTransformerJsModel; + + const generateAnswer = (await getPipeline(task, model)) as QuestionAnsweringPipeline; + + let results = await generateAnswer(runInputData.question, runInputData.context); + if (!Array.isArray(results)) { + results = [results]; + } + + return { + answer: (results[0] as DocumentQuestionAnsweringSingle)?.answer, + }; +} + +export async function registerHuggingfaceLocalTasks() { + ModelFactory.registerRunFn( + DownloadTask, + ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + HuggingFaceLocal_DownloadTask + ); + + ModelFactory.registerRunFn( + EmbeddingTask, + ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + HuggingFaceLocal_EmbeddingTask + ); + + ModelFactory.registerRunFn( + TextGenerationTask, + ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + HuggingFaceLocal_TextGenerationTask + ); + + ModelFactory.registerRunFn( + TextRewriterTask, + ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + HuggingFaceLocal_TextRewriterTask + ); + + ModelFactory.registerRunFn( + SummarizeTask, + ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + HuggingFaceLocal_SummarizeTask + ); + + ModelFactory.registerRunFn( + QuestionAnswerTask, + ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS, + HuggingFaceLocal_QuestionAnswerTask + ); +} diff --git a/packages/core/src/task/exec/ml/MediaPipeLocalTaskRun.ts b/packages/core/src/task/exec/ml/MediaPipeLocalTaskRun.ts new file mode 100644 index 000000000..bdf42a4c1 --- /dev/null +++ b/packages/core/src/task/exec/ml/MediaPipeLocalTaskRun.ts @@ -0,0 +1,81 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { FilesetResolver, TextEmbedder } from "@mediapipe/tasks-text"; +import { ModelFactory } from "../../ModelFactory"; +import { + DownloadTask, + DownloadTaskInput, + EmbeddingTask, + EmbeddingTaskInput, +} from "../../ModelFactoryTasks"; +import { findModelByName } from "../../../storage/InMemoryStorage"; +import { MediaPipeTfJsModel } from "../../../model/MediaPipeModel"; +import { ModelProcessorEnum } from "../../../model/Model"; + +/** + * This is a task that downloads and caches a MediaPipe TFJS model. + */ +export async function MediaPipeTfJsLocal_DownloadTask( + task: DownloadTask, + runInputData: DownloadTaskInput +) { + const textFiles = await FilesetResolver.forTextTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm" + ); + const model = findModelByName(runInputData.model) as MediaPipeTfJsModel; + const results = await TextEmbedder.createFromOptions(textFiles, { + baseOptions: { + modelAssetPath: model.url, + }, + quantize: true, + }); + + return results; +} + +/** + * This is a task that generates an embedding for a single piece of text + * using a MediaPipe TFJS model. + */ +export async function MediaPipeTfJsLocal_EmbeddingTask( + task: EmbeddingTask, + runInputData: EmbeddingTaskInput +) { + const textFiles = await FilesetResolver.forTextTasks( + "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm" + ); + const model = findModelByName(runInputData.model) as MediaPipeTfJsModel; + const textEmbedder = await TextEmbedder.createFromOptions(textFiles, { + baseOptions: { + modelAssetPath: model.url, + }, + quantize: true, + }); + + const output = textEmbedder.embed(runInputData.text); + const vector = output.embeddings[0].floatEmbedding; + + if (vector?.length !== model.dimensions) { + throw `Embedding vector length does not match model dimensions v${vector?.length} != m${model.dimensions}`; + } + return { vector }; +} + +export const registerMediaPipeTfJsLocalTasks = () => { + ModelFactory.registerRunFn( + DownloadTask, + ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL, + MediaPipeTfJsLocal_DownloadTask + ); + + ModelFactory.registerRunFn( + DownloadTask, + ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL, + MediaPipeTfJsLocal_EmbeddingTask + ); +}; diff --git a/src/util/Misc.ts b/packages/core/src/util/Misc.ts similarity index 100% rename from src/util/Misc.ts rename to packages/core/src/util/Misc.ts diff --git a/packages/core/tests/Task.test.ts b/packages/core/tests/Task.test.ts new file mode 100644 index 000000000..b0475c0ef --- /dev/null +++ b/packages/core/tests/Task.test.ts @@ -0,0 +1,69 @@ +import { describe, expect, it } from "bun:test"; +import { SingleTask, CompoundTask } from "../src/task/Task"; +import { TaskGraph } from "../src/task/TaskGraph"; +import { TaskOutput } from "../dist/lib"; + +class TestTask extends SingleTask { + static readonly type = "TestTask"; + runSyncOnly(): TaskOutput { + return { syncOnly: true }; + } + async run(): Promise { + return { all: true }; + } +} + +class TestCompoundTask extends CompoundTask { + static readonly type = "TestTask"; + runSyncOnly(): TaskOutput { + return Object.assign(this.runOutputData, this.runInputData, { syncOnly: true }); + } + async run(): Promise { + return Object.assign(this.runOutputData, this.runInputData, { all: true }); + } +} + +describe("Task", () => { + describe("SingleTask", () => { + it("should set input data and run the task", async () => { + const node = new TestTask(); + const input = { key: "value" }; + const output = await node.runWithInput(input); + expect(output).toEqual({ all: true }); + expect(node.runInputData).toEqual(input); + }); + + it("should run the task synchronously", () => { + const node = new TestTask(); + const output = node.runSyncOnly(); + expect(output).toEqual({ syncOnly: true }); + }); + }); + + describe("CompoundTask", () => { + it("should create a CompoundTask", () => { + const node = new TestCompoundTask(); + expect(node).toBeInstanceOf(CompoundTask); + }); + + it("should create a subgraph for the CompoundTask", () => { + const node = new TestCompoundTask(); + const subGraph = node.subGraph; + expect(subGraph).toBeInstanceOf(TaskGraph); + }); + + it("should set input data and run the task", async () => { + const node = new TestCompoundTask(); + const input = { key: "value" }; + const output = await node.runWithInput(input); + expect(output).toEqual({ key: "value", all: true }); + expect(node.runInputData).toEqual(input); + }); + + it("should run the task synchronously", () => { + const node = new TestCompoundTask({ input: { key: "value2" } }); + const output = node.runSyncOnly(); + expect(output).toEqual({ key: "value2", syncOnly: true }); + }); + }); +}); diff --git a/packages/core/tests/TaskGraph.test.ts b/packages/core/tests/TaskGraph.test.ts new file mode 100644 index 000000000..8c2a73f33 --- /dev/null +++ b/packages/core/tests/TaskGraph.test.ts @@ -0,0 +1,70 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it, beforeEach } from "bun:test"; +import { SingleTask, Task, TaskOutput } from "../src/task/Task"; +import { TaskGraph, DataFlow, serialGraph } from "../src/task/TaskGraph"; + +class TestTask extends SingleTask { + static readonly type = "TestTask"; + runSyncOnly(): TaskOutput { + return {}; + } + async run(): Promise { + return {}; + } +} + +describe("TaskGraph", () => { + let graph = new TaskGraph(); + let tasks: Task[]; + + beforeEach(() => { + graph = new TaskGraph(); + tasks = [ + new TestTask({ id: "task1" }), + new TestTask({ id: "task2" }), + new TestTask({ id: "task3" }), + ]; + }); + + it("should add nodes to the graph", () => { + graph.addTasks(tasks); + + expect(graph.getTask("task1")).toBeDefined(); + expect(graph.getTask("task2")).toBeDefined(); + expect(graph.getTask("task3")).toBeDefined(); + }); + + it("should add edges to the graph", () => { + const edges: DataFlow[] = [ + new DataFlow("task1", "output1", "task2", "input1"), + new DataFlow("task2", "output2", "task3", "input2"), + ]; + + graph.addTasks(tasks); + graph.addDataFlows(edges); + + expect(graph.getDataFlow("task1.output1 -> task2.input1")).toBeDefined(); + expect(graph.getDataFlow("task2.output2 -> task3.input2")).toBeDefined(); + }); + + it("should create a serial graph", () => { + const inputHandle = "input"; + const outputHandle = "output"; + + const expectedDataFlows: DataFlow[] = [ + new DataFlow("task1", inputHandle, "task2", outputHandle), + new DataFlow("task2", inputHandle, "task3", outputHandle), + ]; + + const result = serialGraph(tasks, inputHandle, outputHandle); + + expect(result).toBeInstanceOf(TaskGraph); + expect(result.getDataFlows()).toEqual(expectedDataFlows); + }); +}); diff --git a/packages/core/tests/TaskGraphRunner.test.ts b/packages/core/tests/TaskGraphRunner.test.ts new file mode 100644 index 000000000..8bc10f8a8 --- /dev/null +++ b/packages/core/tests/TaskGraphRunner.test.ts @@ -0,0 +1,100 @@ +// ******************************************************************************* +// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * +// * * +// * Copyright Steven Roussey * +// * Licensed under the Apache License, Version 2.0 (the "License"); * +// ******************************************************************************* + +import { describe, expect, it, beforeEach, spyOn } from "bun:test"; +import { TaskGraphRunner } from "../src/task/TaskGraphRunner"; +import { Task, SingleTask, TaskOutput } from "../src/task/Task"; +import { DataFlow, TaskGraph } from "../src/task/TaskGraph"; + +class TestTask extends SingleTask { + static readonly type = "TestTask"; + runSyncOnly(): TaskOutput { + return {}; + } + async run(): Promise { + return {}; + } +} + +describe("TaskGraphRunner", () => { + let runner: TaskGraphRunner; + let graph: TaskGraph; + let nodes: Task[]; + + beforeEach(() => { + graph = new TaskGraph(); + nodes = [ + new TestTask({ id: "task1" }), + new TestTask({ id: "task2" }), + new TestTask({ id: "task3" }), + ]; + graph.addTasks(nodes); + runner = new TaskGraphRunner(graph); + }); + + describe("assignLayers same layer", () => { + it("should assign layers to nodes based on dependencies", () => { + runner.assignLayers(nodes); + + expect(runner.layers.size).toBe(1); + expect(runner.layers.get(0)?.[0]).toEqual(nodes[0]); + expect(runner.layers.get(0)?.[1]).toEqual(nodes[1]); + expect(runner.layers.get(0)?.[2]).toEqual(nodes[2]); + }); + }); + + describe("assignLayers different layers", () => { + it("should assign layers to nodes based on dependencies", () => { + graph.addDataFlows([ + new DataFlow("task1", "output", "task2", "input"), + new DataFlow("task2", "output", "task3", "input"), + ]); + runner.assignLayers(nodes); + + expect(runner.layers.size).toBe(3); + expect(runner.layers.get(0)).toEqual([nodes[0]]); + expect(runner.layers.get(1)).toEqual([nodes[1]]); + expect(runner.layers.get(2)).toEqual([nodes[2]]); + }); + }); + + describe("runNodesAsync", () => { + it("should run nodes in each layer asynchronously", async () => { + const runSpy = spyOn(nodes[0], "run"); + + runner.assignLayers(nodes); + await runner.runTasksAsync(); + + expect(runSpy).toHaveBeenCalledTimes(1); + }); + }); + + describe("runNodesSync", () => { + it("should run nodes in each layer synchronously", () => { + const runSyncOnlySpy = spyOn(nodes[0], "runSyncOnly"); + + runner.assignLayers(nodes); + runner.runTasksSync(); + + expect(runSyncOnlySpy).toHaveBeenCalledTimes(1); + }); + }); + + describe("runGraph", () => { + it("should run the graph in the correct order", async () => { + const assignLayersSpy = spyOn(runner, "assignLayers"); + const runNodesSyncSpy = spyOn(runner, "runTasksSync"); + const runNodesAsyncSpy = spyOn(runner, "runTasksAsync"); + + await runner.runGraph(); + + expect(assignLayersSpy).toHaveBeenCalled(); + expect(runNodesSyncSpy).toHaveBeenCalled(); + expect(runNodesAsyncSpy).toHaveBeenCalled(); + }); + }); +}); diff --git a/packages/core/tsconfig.json b/packages/core/tsconfig.json new file mode 100644 index 000000000..ef0e5f056 --- /dev/null +++ b/packages/core/tsconfig.json @@ -0,0 +1,14 @@ +{ + "extends": "../../tsconfig.json", + "include": ["src/**/*", "tests/*/*.ts"], + "files": ["src/lib.ts"], + "exclude": ["**/*.test.ts", "dist"], + "compilerOptions": { + "outDir": "dist", + "baseUrl": "./src", + "rootDir": "./src", + "paths": { + "#/*": ["./src/*"] + } + } +} diff --git a/src-examples/ExampleSEC.ts b/src-examples/ExampleSEC.ts deleted file mode 100644 index 0146ddf23..000000000 --- a/src-examples/ExampleSEC.ts +++ /dev/null @@ -1,309 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { Command, InvalidArgumentError } from "commander"; -import { readFile } from "fs/promises"; -import { Listr, PRESET_TIMER } from "listr2"; -import { TaskHelper } from "./TaskHelper"; -import { Document, TextDocument, TextNode } from "#/Document"; -import { getPipeline } from "#/embeddings/TransformerJsService"; -import { - strategyAllPairs, - baaiBgeSmallEnV15, - instructPlain, - supabaseGteSmall, - // instructRepresent, - instructQuestion, - xenovaDistilbert, - whereIsAIUAELargeV1, - gpt2, - xenovaDistilbertMnli, - distilbartCnn, -} from "#/storage/InMemoryStorage"; -import { readFileSync, writeFileSync, mkdirSync } from "fs"; -import { getTopKEmbeddings } from "#/query/InMemoryQuery"; -import { Observable } from "rxjs"; -import { generateDocumentEmbeddings } from "#/embeddings/GenerateEmbeddings"; - -interface Filing { - cik: number; - accession_number: string; - primary_doc: string; - report_date: string; - filing_date: string; - acceptance_date: string; - form: string; - file_number: string; - film_number: string; - documents?: TextDocument[]; -} - -const loadSecAccessionDocument = async ( - cik: number, - accession_number: string, - primary_doc: string -) => { - const filepath = `./data-in/sec/cik-${cik}/files/${accession_number}:${primary_doc}`; - const docraw = await readFile(filepath, "utf8"); - return docraw; -}; - -const processSingleFiling = async (cik: number, accession_number: string) => { - const filing = await getFiling(cik, accession_number); - - const doc = await loadSecAccessionDocument( - cik, - accession_number, - filing.primary_doc - ); - - switch (filing.form) { - case "10-K": - // await process10K(doc); - break; - case "8-K": - return JSON.parse(doc); - case "10-Q": - // await process10Q(doc); - break; - default: - throw new Error(`Form ${filing.form} not supported`); - } - return doc; -}; - -async function getFiling(cik: number, accession_number: string) { - const filings = await getFilingsForCik(cik); - const filing = filings.find( - (filing) => - filing.cik === cik && filing.accession_number === accession_number - ); - if (!filing) throw new Error("Filing not found"); - return filing; -} - -let filingsCache: Filing[]; -async function getFilingsForCik(this: any, cik: number): Promise { - if (filingsCache) return filingsCache; - const filingspath = `./data-in/sec/cik-${cik}/filings.json`; - const filingsraw = await readFile(filingspath, "utf8"); - filingsCache = JSON.parse(filingsraw) as Filing[]; - return filingsCache; -} - -function myParseInt(value: string, dummyPrevious: number) { - // parseInt takes a string and a radix - const parsedValue = parseInt(value, 10); - if (isNaN(parsedValue)) { - throw new InvalidArgumentError("Not a number."); - } - return parsedValue; -} - -export function AddSecCommands(program: Command) { - program - .command("sec-index") - .description("process sec filings") - .argument("", "Run for only one cik", myParseInt) - .argument("[accession]", "Run for only one accession document") - .option("--debug", "Show debug messages") - .option("--form [name]", "Only certain forms") - .action(async (cik, accession, options) => { - const listrTasks = new Listr( - [ - { - title: "Prepare pipelines", - task: () => { - return new Observable((observer) => { - function updateProgress(stat: any) { - const { status, name, file, progress } = stat; - observer.next(`${name} ${file} ${status} ${progress}`); - } - async function run() { - await getPipeline(whereIsAIUAELargeV1, updateProgress); - await getPipeline(baaiBgeSmallEnV15, updateProgress); - await getPipeline(supabaseGteSmall, updateProgress); - await getPipeline(gpt2, updateProgress); - await getPipeline(xenovaDistilbertMnli, updateProgress); - await getPipeline(distilbartCnn, updateProgress); - observer.complete(); - } - run(); - }); - }, - }, - { - title: "Process SEC filings", - task: async (ctx, task) => { - let filings = await getFilingsForCik(cik); - - if (options.form) { - filings = filings.filter( - (f) => f.form === options.form.toUpperCase() - ); - } - - if (accession) { - filings = filings.filter( - (f) => f.accession_number === accession - ); - } - - const helper = new TaskHelper(task, filings.length); - for (const filing of filings) { - const cikStr = cik.toString().padStart(10, "0"); - await helper.onIteration(async () => { - const sections = await processSingleFiling( - filing.cik, - filing.accession_number - ); - filing.documents = [ - new TextDocument( - `${cikStr}:${filing.accession_number}:${filing.primary_doc}`, - Object.values(sections as object) - ), - ]; - await filing.documents.reduce(async (acc, document) => { - await acc; - await generateDocumentEmbeddings( - strategyAllPairs, - document - ); - return acc; - }, Promise.resolve()); - }, `Processing ${cikStr} ${filing.accession_number}`); - } - - mkdirSync(`./data-out/sec/cik-${cik}`, { recursive: true }); - writeFileSync( - `./data-out/sec/cik-${cik}/embeddings.json`, - JSON.stringify(filings, null, 2) - ); - }, - }, - ], - { - exitOnError: true, - concurrent: false, - rendererOptions: { timer: PRESET_TIMER }, - } - ); - await listrTasks.run({ cik, accession, debug: options.debug }); - }); - - program - .command("sec-search") - .description("search sec filings") - .argument("", "Run for only one cik", myParseInt) - .argument("", "Question to ask") - .option("--debug", "Show debug messages") - .option("--form [name]", "Only certain forms") - .action(async (cik, query, options) => { - const listrTasks = new Listr( - [ - { - title: "Prepare pipelines", - task: () => { - return new Observable((observer) => { - function updateProgress(stat: any) { - const { status, name, file, progress } = stat; - observer.next(`${name} ${file} ${status} ${progress}`); - } - async function run() { - await getPipeline(xenovaDistilbert, updateProgress); - await getPipeline(baaiBgeSmallEnV15, updateProgress); - await getPipeline(supabaseGteSmall, updateProgress); - observer.complete(); - } - run(); - }); - }, - }, - { - title: "Search SEC filings", - task: async (ctx, task) => { - let filings = JSON.parse( - readFileSync( - `./data-out/sec/cik-${cik}/embeddings.json`, - "utf8" - ) - ) as Filing[]; - - // filings.forEach((f) => { - // f.documents?.forEach((d) => { - // d.nodes?.forEach((n) => { - // n.embeddings.forEach( - // (e) => (e.vector = new Float32Array(e.vector)) - // ); - // }); - // }); - // }); - - if (options.form) { - filings = filings.filter( - (f) => f.form === options.form.toUpperCase() - ); - } - - const docs = filings.reduce((acc, f) => { - if (!f.documents) return acc; - return acc.concat(f.documents); - }, []); - - const nodes = docs.reduce((acc, d) => { - if (!d.nodes) return acc; - return acc.concat(d.nodes as TextNode[]); - }, []); - - const queryDocument = new TextDocument("query", query); - await generateDocumentEmbeddings( - // strategyAllPairs, - [ - { - embeddingModel: baaiBgeSmallEnV15, - instruct: instructPlain, - }, - ], - queryDocument - ); - - const similarities = getTopKEmbeddings( - queryDocument.nodes[0], - nodes, - 3 - ); - - const answerer = await getPipeline(xenovaDistilbert); - - const context = similarities - .map((s) => s.node.content) - .join("\n\n"); - const output = await answerer(query, context); - - console.log( - output, - // similarities.map((s) => { - // return { - // model: s.embedding.modelName, - // instruct: s.embedding.instructName, - // similarity: s.similarity, - // }; - // }), - "\n\n\n\n\n" - ); - }, - }, - ], - { - exitOnError: true, - concurrent: false, - rendererOptions: { timer: PRESET_TIMER }, - } - ); - await listrTasks.run({ cik, query, debug: options.debug }); - }); -} diff --git a/src-examples/TaskStreamToListr2.ts b/src-examples/TaskStreamToListr2.ts deleted file mode 100644 index 97ec459df..000000000 --- a/src-examples/TaskStreamToListr2.ts +++ /dev/null @@ -1,93 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { - TaskStreamable, - type TaskStream, - TaskStatus, - TaskListOrdering, -} from "#/Task"; -import { Listr, ListrTask } from "listr2"; -import { createBar } from "./TaskHelper"; -import { PRESET_TIMER } from "listr2"; -import { Observable } from "rxjs"; - -const taskArrayToListr = ( - tasks: TaskStream, - options: Record = { concurrent: false, exitOnError: true } -): Listr => { - const list: ListrTask[] = []; - - for (const task of tasks) { - switch (task.kind) { - case "TASK": - list.push({ - title: task.config.name, - task: async (_, t) => { - if ( - task.status == TaskStatus.COMPLETED || - task.status == TaskStatus.FAILED - ) { - return; - } - return new Observable((observer) => { - const start = Date.now(); - let lastUpdate = start; - task.on("progress", (progress, file) => { - const timeSinceLast = Date.now() - lastUpdate; - const timeSinceStart = Date.now() - start; - if (timeSinceLast > 250 || timeSinceStart > 100) { - observer.next( - createBar(progress / 100 || 0, 30) + " " + (file || "") - ); - } - }); - task.on("complete", () => { - observer.complete(); - }); - task.on("error", () => { - observer.complete(); - }); - }); - }, - }); - break; - case "TASK_LIST": - list.push({ - title: task.config.name, - task: async (_, t) => { - return taskArrayToListr(task.tasks, { - concurrent: task.ordering == TaskListOrdering.PARALLEL, - exitOnError: task.ordering == TaskListOrdering.SERIAL, - }); - }, - }); - break; - case "STRATEGY": - list.push({ - title: task.config.name, - task: async (_, t) => { - return taskArrayToListr(task.tasks); - }, - }); - break; - } - } - const listr = new Listr(list, options); - return listr; -}; - -export const runTaskToListr = async (task: TaskStreamable) => { - const listrTasks = taskArrayToListr([task], { - exitOnError: true, - concurrent: false, - rendererOptions: { timer: PRESET_TIMER }, - }); - listrTasks.run({}); - await new Promise((resolve) => setTimeout(resolve, 100)); - await task.run({}); -}; diff --git a/src/Flow.ts b/src/Flow.ts deleted file mode 100644 index a897ea107..000000000 --- a/src/Flow.ts +++ /dev/null @@ -1,39 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { Task, TaskInput, TaskOutput } from "./Task"; -import { DirectedAcyclicGraph } from "@sroussey/typescript-graph"; - -class Transfer { - input?: TaskInput; - output?: TaskOutput; -} -type FlowGraph = DirectedAcyclicGraph; - -export class Flow { - #graph: FlowGraph; - constructor() { - this.#graph = new DirectedAcyclicGraph( - (task) => task.config.id - ); - } - addTask(task: Task) { - this.#graph.insert(task); - } - getTasks(): Task[] { - return this.#graph.topologicallySortedNodes(); - } - removeTask(task: Task) { - this.#graph.remove(task.config.id); - } - addTransfer(from: Task, to: Task, edge: Transfer) { - this.#graph.addEdge(from.config.id, to.config.id, edge); - } - removeTransfer(from: Task, to: Task) { - this.#graph.removeEdge(from.config.id, to.config.id); - } -} diff --git a/src/JobQueue.ts b/src/JobQueue.ts deleted file mode 100644 index 2425cf13f..000000000 --- a/src/JobQueue.ts +++ /dev/null @@ -1,204 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import uuid from "uuid"; - -export enum JobStatus { - PENDING = "NEW", - PROCESSING = "PROCESSING", - COMPLETED = "COMPLETED", - FAILED = "FAILED", -} - -// =============================================================================== - -export abstract class Job { - constructor( - public readonly id: unknown, - public readonly queue: string, - public readonly taskName: string, - public readonly input: any, - public readonly maxRetries: number, - public readonly createdAt: Date - ) { - this.runAfter = createdAt; - } - public status: JobStatus = JobStatus.PENDING; - public runAfter: Date; - public output: any = null; - public retries: number = 0; - public ranAt: Date | null = null; - public completedAt: Date | null = null; - public error: string | undefined = undefined; -} - -export abstract class JobQueue { - public abstract add(job: Job): void; - public abstract next(): Job | undefined; - public abstract size(): number; - public abstract complete(id: unknown, output: any, error?: string): void; -} - -// =============================================================================== -// Local Version -// =============================================================================== - -export class LocalJob extends Job { - constructor(queue: string, taskName: string, input: any) { - const id = uuid.v4(); - const createdAt = new Date(); - const maxRetries = 10; - super(id, queue, taskName, input, maxRetries, createdAt); - } -} - -export class LocalJobQueue extends JobQueue { - private readonly queue: LocalJob[] = []; - - #reorderQueue(): void { - this.queue - .filter((job) => job.status === JobStatus.PENDING) - .filter((job) => job.runAfter.getTime() <= Date.now()) - .sort((a, b) => a.createdAt.getTime() - b.createdAt.getTime()); - } - - public add(job: Job): void { - this.queue.push(job); - } - - public next(): Job | undefined { - this.#reorderQueue(); - - const job = this.queue[0]; - job.status = JobStatus.PROCESSING; - return job; - } - - public size(): number { - return this.queue.length; - } - - public complete(id: unknown, output: any, error?: string): void { - const job = this.queue.find((j) => j.id === id); - if (!job) { - throw new Error(`Job ${id} not found`); - } - job.completedAt = new Date(); - if (error) { - job.status = JobStatus.FAILED; - job.error = error; - } else { - job.status = JobStatus.COMPLETED; - job.output = output; - } - } -} - -// =============================================================================== -// PostgreSQL Version (idea for, never executed, todo) -// =============================================================================== - -/* - -CREATE TABLE IF NOT EXISTS job_queue ( - id bigint SERIAL NOT NULL, - fingerprint text NOT NULL, - queue text NOT NULL, - status job_status NOT NULL default 'new', - payload jsonb, - output jsonb, - retries integer default 0, - max_retries integer default 23, - run_after timestamp with time zone DEFAULT now(), - ran_at timestamp with time zone, - created_at timestamp with time zone DEFAULT now(), - error text -); - -CREATE INDEX IF NOT EXISTS job_fetcher_idx ON job_queue (id, status, run_after); -CREATE INDEX IF NOT EXISTS job_queue_fetcher_idx ON job_queue (queue, status, run_after); -CREATE INDEX IF NOT EXISTS job_queue_fingerprint_idx ON job_queue (fingerprint, status); - ---- we could return results from the existing job instead of queuing a new one if we wanted to, not sure, weird way to cache, todo -CREATE UNIQUE INDEX IF NOT EXISTS jobs_fingerprint_unique_idx ON job_queue (fingerprint, status) WHERE NOT (status = 'processed'); - -*/ - -export class PostgresqlJob extends Job { - constructor(queue: string, taskName: string, input: any) { - const id = uuid.v4(); - const createdAt = new Date(); - const maxRetries = 10; - super(id, queue, taskName, input, maxRetries, createdAt); - } -} - -// Do not use "AUTOCOMMIT" mode for any of the below, put inside a transaction - -export class PostgresqlJobQueue extends JobQueue { - public add(job: Job): void { - const AddQuery = ` - INSERT INTO job_queue(queue, fingerprint, payload, run_after, deadline, max_retries) - VALUES ($1, $2, $3, $4, $5, $6) - RETURNING id`; - } - - public get(): Job | undefined { - const JobQuery = ` - SELECT id, fingerprint, queue, status, deadline, payload, retries, max_retries, run_after, ran_at, created_at, error - FROM job_queue - WHERE id = $1 - FOR UPDATE SKIP LOCKED - LIMIT 1`; - - return; - } - - public peek(num: number = 100): Job | undefined { - num = Number(num) || 100; - const FutureJobQuery = ` - SELECT id, fingerprint, queue, status, deadline, payload, retries, max_retries, run_after, ran_at, created_at, error - FROM job_queue - WHERE queue = $1 - AND status = 'new' - AND run_after > NOW() - ORDER BY run_after ASC - LIMIT ${num} - FOR UPDATE SKIP LOCKED`; - - return; - } - - public next(): Job | undefined { - const PendingJobIDQuery = ` - SELECT id - FROM job_queue - WHERE queue = $1 - AND status = 'new' - AND run_after <= NOW() - FOR UPDATE SKIP LOCKED - LIMIT 1`; - - return; - } - - public size(): number { - // why do we even care about size? - return 0; - } - - #update(id: unknown, output: any, error?: string): void { - const UpdateQuery = ` - UPDATE job_queue - SET ... - WHERE id = ...`; - } - - public complete(id: unknown, output: any, error?: string): void { - // - } -} diff --git a/src/Strategy.ts b/src/Strategy.ts deleted file mode 100644 index 4dfac5baf..000000000 --- a/src/Strategy.ts +++ /dev/null @@ -1,27 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { Instruct } from "./Instruct"; -import { Model } from "./Model"; - -/** - * A strategy is a combination of a model and an instruction for that model. - * This combination is used to generate embeddings for a document, and later - * to generate embeddings for a query. - * - * A node (a block of text or clip of image) can have multiple embeddings, - * though perferably only one in use at a time. This allows for updating - * an embedding strategy while keeping the old one around for live use. - * - * It also allows for testing multiple strategies for a given dataset. - */ -export interface Strategy { - embeddingModel: Model; - instruct: Instruct; -} - -export type StrategyList = Strategy[]; diff --git a/src/Task.ts b/src/Task.ts deleted file mode 100644 index 15835e414..000000000 --- a/src/Task.ts +++ /dev/null @@ -1,330 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { EventEmitter } from "eventemitter3"; -import { deepEqual } from "./util/Misc"; - -class InputOutput { - provanance: string[] = []; - constructor(public input: I, public output: O, task: T) { - this.provanance.push(task.type); - } -} - -/** - * WARNING! - * TODO! - * - * Task input and output is not type safe. It is super brittle and hacky. I am thinking about a visual - * UI editor for tasks where you can map inputs and outputs, see what will run before you run it, etc. - * - * Also, task provenance is not tracked which is terrible for keeping state and caching intermediate results. - */ - -export enum TaskStatus { - PENDING = "NEW", - PROCESSING = "PROCESSING", - COMPLETED = "COMPLETED", - FAILED = "FAILED", -} - -/** - * TaskEvents - * - * There is no job queue at the moement. - */ -export type TaskEvents = "start" | "complete" | "error" | "progress"; - -// =============================================================================== - -export type StreamableTaskKind = "TASK" | "TASK_LIST" | "STRATEGY"; -export type StreamableTaskType = string; - -// =============================================================================== - -export type TaskStreamable = Task | TaskList | Strategy; -export type TaskStream = TaskStreamable[]; - -// =============================================================================== - -export interface ITaskSimple { - isCompound: false; -} -export interface ITaskCompound { - isCompound: true; - tasks: TaskStream; -} -export type ITask = ITaskSimple & ITaskCompound; - -export interface TaskConfig { - name?: string; - id?: unknown; - output_name?: string; -} - -type TaskConfigFull = TaskConfig & { output_name: string }; - -export interface TaskInput { - [key: string]: any; -} -export interface TaskOutput { - [key: string]: any; -} - -abstract class TaskBase { - events = new EventEmitter(); - on(name: TaskEvents, fn: (...args: any[]) => void) { - this.events.on.call(this.events, name, fn); - } - off(name: TaskEvents, fn: (...args: any[]) => void) { - this.events.off.call(this.events, name, fn); - } - emit(name: TaskEvents, ...args: any[]) { - this.events.emit.call(this.events, name, ...args); - } - /** - * Does this task have subtasks? - */ - abstract isCompound: boolean; - /** - * The defaults for the task. If no overrides at run time, then this would be equal to the - * input - */ - defaults: TaskInput = {}; - /** - * The input to the task at the time of the task run. This takes defaults from construction - * time and overrides from run time. It is the input that created the output. - */ - input: TaskInput = {}; - /** - * The output of the task at the time of the task run. This is the result of the task. - * The the defaults and overrides are combined to match the required input of the task. - */ - output: TaskInput = {}; - /** - * Configuration for the task, might include things like name and id for the database - */ - config: TaskConfigFull = { output_name: "out" }; - status: TaskStatus = TaskStatus.PENDING; - progress: number = 0; - createdAt: Date = new Date(); - completedAt: Date | null = null; - error: string | undefined = undefined; - - /** - * - * This calculates the input to the task at the time of the task run. This takes defaults from - * construction and applies run time overrides (which may be output from a previous run if this - * is a serial task or strategy). Caller needs to decide if should set to this classes input - * or not. - */ - withDefaults(...overrides: (Partial | undefined)[]): T { - return Object.assign({}, this.defaults, ...overrides) as T; - } - - constructor(config: TaskConfig = {}, defaults: TaskInput = {}) { - Object.defineProperty(this, "events", { enumerable: false }); - this.defaults = defaults; - this.input = this.withDefaults(); - this.config = Object.assign( - { - id: - this.constructor.name + - ":" + - Math.random().toString(36).substring(2, 9), - name: this.constructor.name, - }, - this.config, - config - ); - this.on("start", () => { - this.status = TaskStatus.PROCESSING; - }); - this.on("complete", () => { - this.completedAt = new Date(); - this.status = TaskStatus.COMPLETED; - }); - this.on("error", (error) => { - this.completedAt = new Date(); - this.status = TaskStatus.FAILED; - this.error = error; - }); - } - - abstract run(overrides?: TaskInput): Promise; -} - -export abstract class Task extends TaskBase implements ITaskSimple { - readonly kind = "TASK"; - readonly type: StreamableTaskType = "Task"; - readonly isCompound = false; -} - -// =============================================================================== - -export enum TaskListOrdering { - SERIAL = "SERIAL", - PARALLEL = "PARALLEL", -} - -export abstract class MultiTaskBase extends TaskBase implements ITaskCompound { - readonly isCompound = true; - abstract ordering: TaskListOrdering; - tasks: TaskStream = []; - started = 0; - completed = 0; - total = 0; - errors = 0; - - constructor( - config: TaskConfig = {}, - tasks: TaskStream = [], - defaults: TaskInput = {} - ) { - super(config, defaults); - this.setTasks(tasks); - } - - setTasks(tasks: TaskStream) { - if (this.tasks.length) { - this.tasks.forEach((task) => { - task.off("complete", this.#completeTask); - task.off("error", this.#errorTask); - }); - } - this.tasks = tasks; - tasks.forEach((task) => { - task.on("complete", this.#completeTask); - task.on("error", this.#errorTask); - }); - } - - generateTasks(_tasks?: TaskStream) {} - - async #run_serial(overrides?: TaskInput) { - try { - this.emit("start"); - this.input = this.withDefaults(overrides); - // TODO: dont regenerate if defaults are the same as input (only check what matters) - if (this.generateTasks && !deepEqual(this.input, this.defaults)) - this.generateTasks(); // only strategy should do this - const total = this.tasks.length; - let taskInput = {}; - for (const task of this.tasks) { - await task.run(taskInput); - if (this.tasks[this.tasks.length - 1] == task) { - // if last task, their result is our result - this.output = task.output; - break; - } - taskInput = Object.assign({}, task.output); - this.emit("progress", this.completed / total); - if (this.errors) { - this.emit("error", this.error); - break; - } - } - this.emit("complete"); - return this.output; - } catch (e) { - this.emit("error", String(e)); - return this.output; - } - } - - async #run_parallel(overrides?: TaskInput) { - this.emit("start"); - - this.input = this.withDefaults(overrides); - // TODO: dont regenerate if defaults are the same as input (only check what matters) - if (this.generateTasks && !deepEqual(this.input, this.defaults)) - this.generateTasks(); // only strategy should do this - - let taskInput = {}; - - const total = this.tasks.length; - await Promise.all( - this.tasks.map(async (task) => { - await task.run(taskInput); - this.emit("progress", this.completed / total); - }) - ); - - const outputs = this.tasks.map((task) => task.output || {}) || []; - const result: TaskInput = {}; - outputs.forEach((item) => { - Object.keys(item).forEach((key) => { - if (!result[key]) { - result[key] = []; - } - result[key].push(item[key]); - }); - }); - - this.output = result; - if (this.errors === total) this.emit("error", this.error); - this.emit("complete"); - return this.output; - } - - async run(overrides?: TaskInput) { - if (this.ordering === TaskListOrdering.SERIAL) { - return this.#run_serial(overrides); - } else { - return this.#run_parallel(overrides); - } - } - - #completeTask() { - this.completed++; - this.emit("progress", this.completed / this.total); - } - - #errorTask(error: string) { - this.errors++; - this.error = this.error ? this.error + " & " + error : error; - } -} - -abstract class TaskList extends MultiTaskBase { - readonly kind = "TASK_LIST"; - readonly type: StreamableTaskType = "TaskList"; - declare _tasks: Task[]; -} - -export class SerialTaskList extends TaskList { - readonly type: StreamableTaskType = "SerialTaskList"; - ordering = TaskListOrdering.SERIAL; -} - -export class ParallelTaskList extends TaskList { - readonly type: StreamableTaskType = "ParallelTaskList"; - ordering = TaskListOrdering.PARALLEL; -} - -// =============================================================================== - -abstract class Strategy extends MultiTaskBase { - readonly kind = "STRATEGY"; - readonly type: StreamableTaskType = "Strategy"; - ordering = TaskListOrdering.SERIAL; - constructor(config: TaskConfig = {}, defaults: TaskInput = {}) { - super(config, [], defaults); - this.generateTasks(); - } - abstract generateTasks(): void; -} - -export abstract class SerialStrategy extends Strategy { - readonly type: StreamableTaskType = "SerialStrategy"; - ordering = TaskListOrdering.SERIAL; -} - -export abstract class ParallelStrategy extends Strategy { - readonly type: StreamableTaskType = "ParallelStrategy"; - ordering = TaskListOrdering.PARALLEL; -} diff --git a/src/embeddings/GenerateEmbeddings.ts b/src/embeddings/GenerateEmbeddings.ts deleted file mode 100644 index b55867d98..000000000 --- a/src/embeddings/GenerateEmbeddings.ts +++ /dev/null @@ -1,61 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { ModelProcessorEnum } from "#/Model"; -import { TextDocument } from "#/Document"; -import type { StrategyList } from "#/Strategy"; -import { - generateTransformerJsEmbedding, - generateTransformerJsRewrite, -} from "./TransformerJsService"; - -export async function generateEmbeddings( - strategies: StrategyList, - document: TextDocument, - isQuery: boolean -) { - for (const node of document.nodes) { - for (const { embeddingModel, instruct } of strategies) { - let text = node.content; - if (instruct.model) { - switch (instruct.model.type) { - case ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS: - text = await generateTransformerJsRewrite(node, instruct, isQuery); - break; - default: - throw new Error("Instruct Model type not supported yet"); - } - } - switch (embeddingModel.type) { - case ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS: - await generateTransformerJsEmbedding( - node, - text, - embeddingModel, - instruct - ); - break; - default: - throw new Error("Embedding Model type not supported yet"); - } - } - } -} - -export async function generateDocumentEmbeddings( - strategies: StrategyList, - document: TextDocument -) { - return generateEmbeddings(strategies, document, false); -} - -export async function generateQueryEmbeddings( - strategies: StrategyList, - document: TextDocument -) { - return generateEmbeddings(strategies, document, true); -} diff --git a/src/embeddings/TransformerJsService.ts b/src/embeddings/TransformerJsService.ts deleted file mode 100644 index 9dd7b4c33..000000000 --- a/src/embeddings/TransformerJsService.ts +++ /dev/null @@ -1,99 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { pipeline, type PipelineType } from "@sroussey/transformers"; -import { Model } from "#/Model"; -import type { Instruct } from "#/Instruct"; -import { NodeEmbedding, TextNode } from "#/Document"; -import { ONNXTransformerJsModel } from "#/tasks/HuggingFaceLocalTasks"; - -const modelPipelinesCache: Record = {}; - -export const getPipeline = async ( - model: ONNXTransformerJsModel, - progress_callback?: (progress: any) => void -) => { - if (!modelPipelinesCache[model.name]) { - modelPipelinesCache[model.name] = await pipeline( - model.pipeline as PipelineType, - model.name, - { - progress_callback, - } - ); - } - return modelPipelinesCache[model.name]; -}; - -export async function generateTransformerJsEmbedding( - node: TextNode, - rewrittenText: string, - model: Model, - instruct: Instruct -) { - const generateEmbedding = await getPipeline(model as ONNXTransformerJsModel); - - const text = rewrittenText || node.content; - - const output = await generateEmbedding(text, { - pooling: "mean", - normalize: model.normalize, - temperature: instruct.parameters?.temperature, - }); - - const vector = Array.from(output.data); - - if (vector.length !== model.dimensions) { - throw new Error( - `Embedding vector length does not match model dimensions v${vector.length} != m${model.dimensions}` - ); - } - - node.embeddings.push( - new NodeEmbedding(model.name, instruct.name, text, vector, model.normalize) - ); -} - -export async function generateTransformerJsRewrite( - node: TextNode, - instruct: Instruct, - query: boolean -): Promise { - let instruction = query - ? instruct.queryInstruction - : instruct.storageInstruction; - if (!instruct.model) { - return node.content; - } else { - instruction = instruction ? instruction + ":\n" : ""; - } - - const rewriter = await getPipeline(instruct.model as ONNXTransformerJsModel); - - const output = await rewriter(node.content); - - let result = ""; - - switch ((instruct.model as ONNXTransformerJsModel).pipeline) { - case "text-generation": - result = output.generated_text; - break; - case "zero-shot-classification": - result = output.labels.join(", "); - break; - case "question-answering": - result = output.answer; - break; - case "summarization": - result = output?.[0]?.summary_text; - break; - default: - throw new Error("rewrite model pipeline not supported yet"); - } - - return result; -} diff --git a/src/stuff.ts b/src/stuff.ts deleted file mode 100644 index a7fc36e2e..000000000 --- a/src/stuff.ts +++ /dev/null @@ -1,55 +0,0 @@ -interface StorageService { - getDocument(documentId: number): Document; - getDocumentNode(documentId: number, nodeId: number): DocumentNode; - getDocumentNodeEmbedding( - documentId: number, - nodeId: number, - instructId: number, - modelId: number - ): DocumentNodeEmbedding; - getModels(): Model[]; - getInstructs(): Instruct[]; -} - -enum InvokationEventType { - START, - STOP, - EMBEDDING, - FIRST_TOKEN, - LAST_TOKEN, - TOKENS, - TOKENS_PER_SECOND, -} - -interface InvokationEvent { - type: InvokationEventType; - instructId: number; - modelId: number; - token: string; - tokens: string[]; - tokensPerSecond: number; -} - -interface DocumentInvokationEvent extends InvokationEvent { - documentId: number; - nodeId: number; -} - -interface PromptInvokationEvent extends InvokationEvent { - prompt: string; -} - -interface EmbeddingService { - instructId: number; - modelId: number; - transform: (document: Document, node: DocumentNode) => DocumentNodeEmbedding; -} - -// Configuration for Deno runtime -// env.useBrowserCache = false; -// env.allowLocalModels = false; - -// const generateEmbedding = await pipeline( -// "feature-extraction", -// "Supabase/gte-small" -// ); diff --git a/src/tasks/BasicTasks.ts b/src/tasks/BasicTasks.ts deleted file mode 100644 index d317776e5..000000000 --- a/src/tasks/BasicTasks.ts +++ /dev/null @@ -1,70 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { TaskConfig, Task, TaskInput } from "#/Task"; - -export interface RenameTaskInput { - output_remap_array: { - from: string; - to: string; - }[]; -} - -/** - * Uses config to map multiple values from inputs to outputs - */ -export class RenameTask extends Task { - readonly type: string = "RenameTask"; - constructor( - config: TaskConfig = {}, - defaults: RenameTaskInput = { output_remap_array: [] } - ) { - config.name ||= - "RenameTask" + - (defaults?.output_remap_array?.length - ? defaults?.output_remap_array - .map(({ from, to }) => `: from ${from} to ${to}`) - .join(", ") - : ""); - super(config, defaults); - } - async run(overrides?: RenameTaskInput) { - this.emit("start"); - this.input = this.withDefaults(overrides); - this.output = {}; - for (const { from, to } of this.input.output_remap_array) { - if (from != "output_remap_array") { - this.output[to] = this.input[from]; - } - } - this.emit("complete"); - return this.output; - } -} - -// =============================================================================== - -export class LambdaTask extends Task { - #runner: (input: TaskInput) => Promise; - readonly type: string = "LambdaTask"; - constructor( - config: TaskConfig & { - run: () => Promise; - }, - defaults: TaskInput = {} - ) { - super(config, defaults); - this.#runner = config.run; - } - async run(overrides?: TaskInput) { - this.emit("start"); - this.input = this.withDefaults(overrides); - this.output = await this.#runner(this.input); - this.emit("complete"); - return this.output; - } -} diff --git a/src/tasks/FactoryTasks.ts b/src/tasks/FactoryTasks.ts deleted file mode 100644 index 75440ab10..000000000 --- a/src/tasks/FactoryTasks.ts +++ /dev/null @@ -1,156 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { Model } from "#/Model"; -import { TaskConfig, Task, TaskInput, StreamableTaskType } from "#/Task"; -import { - ONNXTransformerJsModel, - HuggingFaceLocal_EmbeddingTask, - HuggingFaceLocal_QuestionAnswerTask, - HuggingFaceLocal_SummarizationTask, - HuggingFaceLocal_TextGenerationTask, - HuggingFaceLocal_TextRewriterTask, - HuggingFaceLocal_DownloadTask, -} from "./HuggingFaceLocalTasks"; -import { - MediaPipeTfJsLocal_DownloadTask, - MediaPipeTfJsLocal_EmbeddingTask, - MediaPipeTfJsModel, -} from "./MediaPipeLocalTasks"; - -export interface ModelFactoryTaskInput { - model: Model; -} - -abstract class ModelFactoryTask extends Task { - declare input: ModelFactoryTaskInput; - constructor(config: TaskConfig = {}, defaults: ModelFactoryTaskInput) { - super(config, defaults); - } - - run(overrides?: TaskInput): Promise { - throw new Error("ModelFactoryTask:run() method not implemented."); - } -} - -interface DownloadTaskInput { - model: Model; -} - -export class DownloadTask extends ModelFactoryTask { - declare input: DownloadTaskInput; - readonly type: StreamableTaskType = "DownloadTask"; - constructor(config: TaskConfig = {}, defaults: DownloadTaskInput) { - super(config, defaults); - const { model } = this.input; - if (model instanceof ONNXTransformerJsModel) { - return new HuggingFaceLocal_DownloadTask(this.config, { model }); - } - if (model instanceof MediaPipeTfJsModel) { - return new MediaPipeTfJsLocal_DownloadTask(this.config, { model }); - } - } -} - -export interface EmbeddingTaskInput { - text: string; - model: Model; -} -/** - * This is a task that generates an embedding for a single piece of text - */ -export class EmbeddingTask extends ModelFactoryTask { - declare input: EmbeddingTaskInput; - readonly type: StreamableTaskType = "EmbeddingTask"; - constructor(config: TaskConfig = {}, defaults: EmbeddingTaskInput) { - super(config, defaults); - const { text, model } = this.input; - if (model instanceof ONNXTransformerJsModel) { - return new HuggingFaceLocal_EmbeddingTask(this.config, { text, model }); - } - if (model instanceof MediaPipeTfJsModel) { - return new MediaPipeTfJsLocal_EmbeddingTask(this.config, { text, model }); - } - } -} - -export interface TextGenerationTaskInput { - text: string; - model: Model; -} -export class TextGenerationTask extends ModelFactoryTask { - readonly type: StreamableTaskType = "TextGenerationTask"; - declare input: TextGenerationTaskInput; - constructor(config: TaskConfig = {}, input: TextGenerationTaskInput) { - super(config, input); - const { text, model } = this.input; - if (model instanceof ONNXTransformerJsModel) { - return new HuggingFaceLocal_TextGenerationTask(this.config, { - text, - model, - }); - } - } -} - -export class SummarizeTask extends ModelFactoryTask { - declare input: TextGenerationTaskInput; - readonly type: StreamableTaskType = "SummarizeTask"; - constructor(config: TaskConfig = {}, input: TextGenerationTaskInput) { - super(config, input); - const { text, model } = this.input; - if (model instanceof ONNXTransformerJsModel) { - return new HuggingFaceLocal_SummarizationTask(this.config, { - text, - model, - }); - } - } -} - -export interface RewriterTaskInput { - text: string; - prompt: string; - model: Model; -} - -export class RewriterTask extends ModelFactoryTask { - readonly type: StreamableTaskType = "RewriterTask"; - declare input: RewriterTaskInput; - constructor(config: TaskConfig = {}, input: RewriterTaskInput) { - super(config, input); - const { text, model, prompt } = this.input; - if (model instanceof ONNXTransformerJsModel) { - return new HuggingFaceLocal_TextRewriterTask(this.config, { - text, - prompt, - model, - }); - } - } -} - -export interface QuestionAnswerTaskInput { - text: string; - context: string; - model: Model; -} -export class QuestionAnswerTask extends ModelFactoryTask { - declare input: QuestionAnswerTaskInput; - readonly type: StreamableTaskType = "QuestionAnswerTask"; - constructor(config: TaskConfig = {}, input: QuestionAnswerTaskInput) { - super(config, input); - const { text, model, context } = this.input; - if (model instanceof ONNXTransformerJsModel) { - return new HuggingFaceLocal_QuestionAnswerTask(this.config, { - text, - context, - model, - }); - } - } -} diff --git a/src/tasks/HuggingFaceLocalTasks.ts b/src/tasks/HuggingFaceLocalTasks.ts deleted file mode 100644 index 87ecf8208..000000000 --- a/src/tasks/HuggingFaceLocalTasks.ts +++ /dev/null @@ -1,331 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { Model, ModelProcessorEnum, ModelUseCaseEnum } from "#/Model"; -import { StreamableTaskType, Task, TaskConfig } from "#/Task"; -import { - pipeline, - type PipelineType, - type FeatureExtractionPipeline, - type TextGenerationPipeline, - type TextGenerationSingle, - type SummarizationPipeline, - type SummarizationSingle, - type QuestionAnsweringPipeline, - type DocumentQuestionAnsweringSingle, - env, -} from "@sroussey/transformers"; - -env.backends.onnx.logLevel = "error"; -env.backends.onnx.debug = false; - -export class ONNXTransformerJsModel extends Model { - constructor( - name: string, - useCase: ModelUseCaseEnum[], - public pipeline: string, - options?: Partial> - ) { - super(name, useCase, options); - } - readonly type = ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS; -} - -/** - * - * This is a helper function to get a pipeline for a model and assign a - * progress callback to the task. - * - * @param task - * @param model - * @param options - */ -const getPipeline = async ( - task: Task, - model: ONNXTransformerJsModel, - { quantized, config }: { quantized: boolean; config: any } = { - quantized: true, - config: null, - } -) => { - return await pipeline(model.pipeline as PipelineType, model.name, { - quantized, - config, - progress_callback: (details: { - file: string; - status: string; - name: string; - progress: number; - loaded: number; - total: number; - }) => { - const { progress, file } = details; - task.progress = progress; - task.emit("progress", progress, file); - }, - }); -}; - -// =============================================================================== - -interface DownloadTaskInput { - model: ONNXTransformerJsModel; -} -export class HuggingFaceLocal_DownloadTask extends Task { - declare input: DownloadTaskInput; - declare defaults: Partial; - readonly type: StreamableTaskType = "DownloadTask"; - constructor(config: TaskConfig = {}, defaults: DownloadTaskInput) { - config.name ||= `Downloading ${defaults.model.name}`; - super(config, defaults); - } - - public async run(overrides?: Partial) { - this.input = this.withDefaults(overrides); - try { - this.emit("start"); - await getPipeline(this, this.input.model); - this.emit("complete"); - } catch (e) { - this.emit("error", String(e)); - } - return this.output; - } -} - -// =============================================================================== - -interface EmbeddingTaskInput { - text: string; - model: ONNXTransformerJsModel; -} -/** - * This is a task that generates an embedding for a single piece of text - * - * Model pipeline must be "feature-extraction" - */ -export class HuggingFaceLocal_EmbeddingTask extends Task { - declare input: EmbeddingTaskInput; - declare defaults: Partial; - readonly type: StreamableTaskType = "EmbeddingTask"; - constructor(config: TaskConfig = {}, defaults: EmbeddingTaskInput) { - config.name ||= `Embedding content via ${defaults.model.name}`; - config.output_name ||= "vector"; - super(config, defaults); - } - - public async run(overrides?: Partial) { - this.input = this.withDefaults(overrides); - - this.emit("start"); - - const generateEmbedding = (await getPipeline( - this, - this.input.model - )) as FeatureExtractionPipeline; - - var vector = await generateEmbedding(this.input.text, { - pooling: "mean", - normalize: this.input.model.normalize, - }); - - if (vector.size !== this.input.model.dimensions) { - this.emit( - "error", - `Embedding vector length does not match model dimensions v${vector.size} != m${this.input.model.dimensions}` - ); - } else { - this.output = { [this.config.output_name]: vector.data }; - this.emit("complete"); - } - return this.output; - } -} - -// =============================================================================== - -interface TextGenerationTaskInput { - text: string; - model: ONNXTransformerJsModel; -} -abstract class TextGenerationTaskBase extends Task { - declare input: TextGenerationTaskInput; - constructor(config: TaskConfig = {}, input: TextGenerationTaskInput) { - config.name ||= `Text generation content via ${input.model.name} : ${input.model.pipeline}`; - config.output_name ||= "text"; - super(config, input); - } -} - -// =============================================================================== - -/** - * This generates text from a prompt - * - * Model pipeline must be "text-generation" or "text2text-generation" - */ -export class HuggingFaceLocal_TextGenerationTask extends TextGenerationTaskBase { - readonly type: StreamableTaskType = "TextGenerationTask"; - public async run(overrides?: Partial) { - this.input = this.withDefaults(overrides); - - this.emit("start"); - - const generateText = (await getPipeline( - this, - this.input.model - )) as TextGenerationPipeline; - - let results = await generateText(this.input.text); - if (!Array.isArray(results)) { - results = [results]; - } - - this.output = { - [this.config.output_name]: (results[0] as TextGenerationSingle) - ?.generated_text, - }; - this.emit("complete"); - return this.output; - } -} - -// =============================================================================== - -interface RewriterTaskInput { - text: string; - prompt: string; - model: ONNXTransformerJsModel; -} - -/** - * This is a special case of text generation that takes a prompt and text to rewrite - * - * Model pipeline must be "text-generation" or "text2text-generation" - */ -export class HuggingFaceLocal_TextRewriterTask extends TextGenerationTaskBase { - declare input: RewriterTaskInput; - readonly type: StreamableTaskType = "RewriterTask"; - constructor(config: TaskConfig = {}, input: RewriterTaskInput) { - const { model } = input; - config.name ||= `Text to text rewriting content via ${model.name} : ${model.pipeline}`; - config.output_name ||= "text"; - super(config, input); - } - - public async run(overrides?: Partial) { - this.input = this.withDefaults(overrides); - this.emit("start"); - - const generateText = (await getPipeline( - this, - this.input.model - )) as TextGenerationPipeline; - - // This lib doesn't support this kind of rewriting with a separate prompt vs text - const promptedtext = - (this.input.prompt ? this.input.prompt + "\n" : "") + this.input.text; - let results = await generateText(promptedtext); - if (!Array.isArray(results)) { - results = [results]; - } - - const text = (results[0] as TextGenerationSingle)?.generated_text; - if (text == promptedtext) { - this.output = {}; - this.emit("error", "Rewriter failed to generate new text"); - } else { - this.output = { [this.config.output_name]: text }; - this.emit("complete"); - } - - return this.output; - } -} - -// =============================================================================== - -/** - * This is a special case of text generation that takes a context and a question - * - * Model pipeline must be "summarization" - */ - -export class HuggingFaceLocal_SummarizationTask extends TextGenerationTaskBase { - readonly type: StreamableTaskType = "SummarizeTask"; - public async run(overrides?: Partial) { - this.emit("start"); - - this.input = this.withDefaults(overrides); - - const generateSummary = (await getPipeline( - this, - this.input.model - )) as SummarizationPipeline; - - let results = await generateSummary(this.input.text); - if (!Array.isArray(results)) { - results = [results]; - } - - this.output = { - [this.config.output_name]: (results[0] as SummarizationSingle) - ?.summary_text, - }; - this.emit("complete"); - return this.output; - } -} - -// =============================================================================== - -interface QuestionAnswerTaskInput { - text: string; - context: string; - model: ONNXTransformerJsModel; - topk?: number; -} -/** - * This is a special case of text generation that takes a context and a question - * - * Model pipeline must be "question-answering" - */ -export class HuggingFaceLocal_QuestionAnswerTask extends TextGenerationTaskBase { - declare input: QuestionAnswerTaskInput; - readonly type: StreamableTaskType = "QuestionAnswerTask"; - constructor(config: TaskConfig = {}, input: QuestionAnswerTaskInput) { - config.name = - config.name || `Question and Answer content via ${input.model.name}`; - config.output_name ||= "text"; - super(config, input); - } - - public async run(overrides?: Partial) { - this.emit("start"); - - this.input = this.withDefaults(overrides); - - const generateAnswer = (await getPipeline( - this, - this.input.model - )) as QuestionAnsweringPipeline; - - let results = await generateAnswer(this.input.text, this.input.context, { - topk: this.input.topk ?? 1, - }); - if (!Array.isArray(results)) { - results = [results]; - } - - this.output = { - [this.config.output_name]: (results[0] as DocumentQuestionAnsweringSingle) - ?.answer, - }; - this.emit("complete"); - return this.output; - } -} diff --git a/src/tasks/JsonTask.ts b/src/tasks/JsonTask.ts deleted file mode 100644 index 30c0c2572..000000000 --- a/src/tasks/JsonTask.ts +++ /dev/null @@ -1,184 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { - ParallelTaskList, - SerialStrategy, - SerialTaskList, - StreamableTaskType, - TaskConfig, - TaskStreamable, -} from "#/Task"; -import { findModelByName } from "#/storage/InMemoryStorage"; -import { RenameTask, RenameTaskInput } from "./BasicTasks"; -import { - EmbeddingTask, - EmbeddingTaskInput, - QuestionAnswerTask, - QuestionAnswerTaskInput, - RewriterTask, - RewriterTaskInput, - SummarizeTask, - TextGenerationTask, - TextGenerationTaskInput, -} from "./FactoryTasks"; -import { - EmbeddingStrategy, - EmbeddingStrategyInput, - RewriterEmbeddingStrategy, - RewriterEmbeddingStrategyInput, - RewriterStrategy, - RewriterStrategyInput, - SummarizeStrategy, - SummarizeStrategyInput, -} from "./Strategies"; - -const AllRegisteredTasks = new Map(); - -AllRegisteredTasks.set("SerialTaskList", SerialTaskList); -AllRegisteredTasks.set("ParallelTaskList", ParallelTaskList); - -AllRegisteredTasks.set("RenameTask", RenameTask); - -AllRegisteredTasks.set("EmbeddingTask", EmbeddingTask); -AllRegisteredTasks.set("RewriterTask", RewriterTask); -AllRegisteredTasks.set("TextGenerationTask", TextGenerationTask); -AllRegisteredTasks.set("SummarizeTask", SummarizeTask); -AllRegisteredTasks.set("QuestionAnswerTask", QuestionAnswerTask); - -AllRegisteredTasks.set("EmbeddingStrategy", EmbeddingStrategy); -AllRegisteredTasks.set("RewriterStrategy", RewriterStrategy); -AllRegisteredTasks.set("SummarizeStrategy", SummarizeStrategy); -AllRegisteredTasks.set("RewriterEmbeddingStrategy", RewriterEmbeddingStrategy); - -type TaskListJsonInput = { - run: "SerialTaskList" | "ParallelTaskList"; - config?: TaskConfig; - tasks: TaskJsonInput[]; -}; - -type SimpleTasks = { - run: "RenameTask"; - config?: TaskConfig; - input: RenameTaskInput; -}; - -type ChangeToString = { - [P in keyof T]: P extends K[number] - ? string - : T[P] extends object - ? ChangeToString - : T[P]; -}; -type ChangeToStringArray = { - [P in keyof T]: P extends K[number] - ? string | string[] - : T[P] extends object - ? ChangeToString - : T[P]; -}; -type FactoryHelper = { - run: R; - config?: TaskConfig; - input: ChangeToString; -}; - -type StrategyHelper = { - run: R; - config?: TaskConfig; - input: ChangeToStringArray< - T, - ["model", "models", "prompt_model", "embed_model"] - >; -}; - -type FactoryTasksJsonInput = - | FactoryHelper - | FactoryHelper - | FactoryHelper - | FactoryHelper - | FactoryHelper; - -type StrategyJSONInput = - | StrategyHelper - | StrategyHelper - | StrategyHelper - | StrategyHelper; - -export type TaskJsonInput = - | StrategyJSONInput - | TaskListJsonInput - | SimpleTasks - | FactoryTasksJsonInput; - -function makeArrayOfModel(model: string | string[] | undefined) { - if (!model) return undefined; - const modelstrs = Array.isArray(model) ? model : [model]; - const models = modelstrs.map((s) => { - const found = findModelByName(s); - if (!found) throw new Error(`Model not found: ${s}`); - return found; - }); - return models; -} -function convertJson(json: TaskJsonInput): TaskStreamable { - const { run, config } = json; - const runTask = AllRegisteredTasks.get(run); - if (!runTask) throw new Error("Task not found"); - let result: TaskStreamable; - if (run == "SerialTaskList" || run == "ParallelTaskList") { - const tasks = json.tasks.map(convertJson); - result = new runTask(config, tasks); - } else if ( - run == "EmbeddingTask" || - run == "RewriterTask" || - run == "SummarizeTask" || - run == "TextGenerationTask" || - run == "QuestionAnswerTask" - ) { - const input = json.input; - const model = findModelByName(input.model); - if (!model) throw new Error(`Model not found: ${input.model}`); - result = new runTask(config, { ...input, model }); - } else if (run == "RewriterStrategy") { - const input = json.input; - const model = makeArrayOfModel(input.model); - result = new runTask(config, { ...input, model }); - } else if (run == "RewriterEmbeddingStrategy") { - const input = json.input; - const embed_model = makeArrayOfModel(input.embed_model); - const prompt_model = makeArrayOfModel(input.prompt_model); - result = new runTask(config, { ...input, embed_model, prompt_model }); - } else if (run == "RenameTask") { - result = new runTask(config, json.input); - } else { - throw new Error(`Unknown task type: ${run}`); - } - return result; -} - -export class JsonStrategy extends SerialStrategy { - declare input: { tasks: TaskJsonInput[] }; - readonly type: StreamableTaskType = "JsonStrategy"; - constructor( - config: TaskConfig = {}, - defaults?: TaskJsonInput | TaskJsonInput[] - ) { - const tasks = Array.isArray(defaults) ? defaults : [defaults]; - super(config, { tasks }); - } - generateTasks() { - let tasks: TaskStreamable[]; - try { - tasks = this.input.tasks.map(convertJson); - } catch (e) { - throw new Error(`Error converting json: ${String(e)}`); - } - if (!tasks) throw new Error("Task not found"); - this.setTasks(tasks); - } -} diff --git a/src/tasks/MediaPipeLocalTasks.ts b/src/tasks/MediaPipeLocalTasks.ts deleted file mode 100644 index 17538c550..000000000 --- a/src/tasks/MediaPipeLocalTasks.ts +++ /dev/null @@ -1,111 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -import { Model, ModelProcessorEnum, ModelUseCaseEnum } from "#/Model"; -import { StreamableTaskType, Task, TaskConfig } from "#/Task"; -import { FilesetResolver, TextEmbedder } from "@mediapipe/tasks-text"; - -export class MediaPipeTfJsModel extends Model { - constructor( - name: string, - useCase: ModelUseCaseEnum[], - public url: string, - options?: Partial< - Pick - > - ) { - super(name, useCase, options); - } - readonly type = ModelProcessorEnum.MEDIA_PIPE_TFJS_MODEL; -} - -// =============================================================================== - -interface DownloadTaskInput { - model: MediaPipeTfJsModel; -} -export class MediaPipeTfJsLocal_DownloadTask extends Task { - declare input: DownloadTaskInput; - declare defaults: Partial; - readonly type: StreamableTaskType = "DownloadTask"; - constructor(config: TaskConfig = {}, defaults: DownloadTaskInput) { - config.name ||= `Downloading ${defaults.model.name}`; - super(config, defaults); - } - - public async run(overrides?: Partial) { - this.input = this.withDefaults(overrides); - try { - this.emit("start"); - const textFiles = await FilesetResolver.forTextTasks( - "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm/" - ); - await TextEmbedder.createFromOptions(textFiles, { - baseOptions: { - modelAssetPath: this.input.model.url, - }, - quantize: true, - }); - this.emit("complete"); - } catch (e) { - this.emit("error", String(e)); - } - return this.output; - } -} - -// =============================================================================== - -interface EmbeddingTaskInput { - text: string; - model: MediaPipeTfJsModel; -} -/** - * This is a task that generates an embedding for a single piece of text - * - * Model pipeline must be "feature-extraction" - */ -export class MediaPipeTfJsLocal_EmbeddingTask extends Task { - declare input: EmbeddingTaskInput; - declare defaults: Partial; - readonly type: StreamableTaskType = "EmbeddingTask"; - constructor(config: TaskConfig = {}, defaults: EmbeddingTaskInput) { - config.name ||= `Embedding content via ${defaults.model.name}`; - config.output_name ||= "vector"; - super(config, defaults); - } - - public async run(overrides?: Partial) { - this.input = this.withDefaults(overrides); - - this.emit("start"); - - const textFiles = await FilesetResolver.forTextTasks( - "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-text@latest/wasm/" - ); - const textEmbedder = await TextEmbedder.createFromOptions(textFiles, { - baseOptions: { - modelAssetPath: this.input.model.url, - }, - quantize: true, - }); - - const output = textEmbedder.embed(this.input.text); - const vector = output.embeddings[0].floatEmbedding; - - if (vector?.length !== this.input.model.dimensions) { - this.emit( - "error", - `Embedding vector length does not match model dimensions v${vector?.length} != m${this.input.model.dimensions}` - ); - } else { - this.output = { [this.config.output_name]: vector }; - this.emit("complete"); - } - return this.output; - } -} diff --git a/src/tasks/Strategies.ts b/src/tasks/Strategies.ts deleted file mode 100644 index 7968fc1b1..000000000 --- a/src/tasks/Strategies.ts +++ /dev/null @@ -1,166 +0,0 @@ -// ******************************************************************************* -// * ELLMERS: Embedding Large Language Model Experiential Retrieval Service * -// * * -// * Copyright Steven Roussey * -// * Licensed under the Apache License, Version 2.0 (the "License"); * -// ******************************************************************************* - -/* - -TODO: Still need to save the task tree to disc to restore later, though the JSON -strategy might be the thing to use. - -*/ - -import { Model } from "#/Model"; -import { - ParallelStrategy, - ParallelTaskList, - SerialTaskList, - StreamableTaskType, - TaskConfig, - TaskStream, -} from "#/Task"; -import { forceArray } from "#/util/Misc"; -import { EmbeddingTask, RewriterTask, SummarizeTask } from "./FactoryTasks"; - -export interface EmbeddingStrategyInput { - text: string; - models: Model[]; -} -export class EmbeddingStrategy extends ParallelStrategy { - declare input: EmbeddingStrategyInput; - readonly type: StreamableTaskType = "EmbeddingStrategy"; - constructor(config: TaskConfig = {}, defaults?: EmbeddingStrategyInput) { - super(config, defaults); - } - - generateTasks() { - const tasks = this.input.models.map( - (model) => new EmbeddingTask({}, { text: this.input.text, model }) - ); - - this.setTasks(tasks); - } -} - -export interface SummarizeStrategyInput { - text: string; - models: Model[]; -} -export class SummarizeStrategy extends ParallelStrategy { - declare input: SummarizeStrategyInput; - readonly type: StreamableTaskType = "SummarizeStrategy"; - - constructor(config: TaskConfig = {}, defaults?: SummarizeStrategyInput) { - super(config, defaults); - } - - generateTasks() { - const tasks = this.input.models.map( - (model) => new SummarizeTask({}, { text: this.input.text, model }) - ); - this.setTasks(tasks); - } -} - -export interface RewriterStrategyInput { - text: string; - prompt?: string | string[]; - model?: Model | Model[]; - prompt_model_pair?: { prompt: string; model: Model }[]; -} -export class RewriterStrategy extends ParallelStrategy { - declare input: RewriterStrategyInput; - readonly type: StreamableTaskType = "RewriterStrategy"; - - constructor(config: TaskConfig = {}, defaults?: RewriterStrategyInput) { - super(config, defaults); - } - - generateTasks() { - const name = this.config.name || `Vary Rewriter content`; - const { text, prompt_model_pair, model, prompt } = this.input; - let pairs: { prompt: string; model: Model }[] = []; - if (prompt_model_pair) { - pairs = forceArray(prompt_model_pair); - } else { - if (!prompt || !model) throw new Error("Invalid input"); - const models = forceArray(model); - const prompts = forceArray(prompt); - for (const model of models) { - for (const prompt of prompts) { - pairs.push({ prompt, model }); - } - } - } - const tasks = pairs.map( - ({ prompt, model }) => new RewriterTask({}, { text, prompt, model }) - ); - - this.setTasks(tasks); - } -} - -export interface RewriterEmbeddingStrategyInput { - text: string; - prompt?: string | string[]; - prompt_model?: Model | Model[]; - embed_model?: Model | Model[]; - prompt_model_tuple?: { - prompt: string; - prompt_model: Model; - embed_model: Model; - }[]; -} - -export class RewriterEmbeddingStrategy extends ParallelStrategy { - declare input: RewriterEmbeddingStrategyInput; - readonly type: StreamableTaskType = "RewriterEmbeddingStrategy"; - constructor(config: TaskConfig, defaults: RewriterEmbeddingStrategyInput) { - super(config, defaults); - } - - generateTasks() { - const name = this.config.name || `RewriterEmbeddingStrategy`; - const { text, prompt_model_tuple, prompt, embed_model, prompt_model } = - this.input; - let tasks: TaskStream = []; - if (prompt_model_tuple) { - const tuples = forceArray(prompt_model_tuple); - tasks = tuples.map(({ prompt, prompt_model, embed_model }) => { - return new SerialTaskList({ name }, [ - new RewriterTask( - { name: name + " Rewriter" }, - { text, prompt, model: prompt_model } - ), - new EmbeddingTask( - { name: name + " Embedding" }, - { text, model: embed_model } - ), - ]); - }); - } else { - if (!prompt || !prompt_model || !embed_model) - throw new Error("Invalid input"); - const prompt_models = forceArray(prompt_model); - const embed_models = forceArray(embed_model); - const prompts = forceArray(prompt); - for (const prompt of prompts) { - tasks.push( - new ParallelTaskList({ name }, [ - new RewriterStrategy( - { name: name + " Rewriter" }, - { text, prompt, model: prompt_models } - ), - new EmbeddingStrategy( - { name: name + " Embedding" }, - { text, models: embed_models } - ), - ]) - ); - } - } - this.setTasks(tasks); - } -} diff --git a/tsconfig.json b/tsconfig.json index 216523716..05206ff66 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -6,19 +6,15 @@ "moduleResolution": "bundler", "moduleDetection": "force", "allowImportingTsExtensions": true, - "noEmit": true, "composite": true, "strict": true, "downlevelIteration": true, "skipLibCheck": true, - "jsx": "react-jsx", "allowSyntheticDefaultImports": true, "forceConsistentCasingInFileNames": true, "allowJs": true, - "types": ["bun-types"], - "baseUrl": "./src", - "paths": { - "#/*": ["./*"] - } + "declaration": true, + "emitDeclarationOnly": true, + "types": ["bun-types"] } }