Skip to content

Commit

Permalink
Merge pull request #23 from sroussey/remove-type-lookups
Browse files Browse the repository at this point in the history
refactor: Unroll the auto type system and make explicit
  • Loading branch information
sroussey authored Jan 31, 2025
2 parents 5cc1c5f + 56c34cc commit 37c8922
Show file tree
Hide file tree
Showing 44 changed files with 407 additions and 426 deletions.
36 changes: 20 additions & 16 deletions docs/developers/01_getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,22 +131,24 @@ And unrolling the config helpers, we get the following equivalent code:

```ts
import {
DownloadModelTask,
TextRewriterCompoundTask,
DebugLog,
DataFlow,
TaskGraph,
TaskGraphRunner,
getProviderRegistry,
HuggingFaceLocal_DownloadRun,
HuggingFaceLocal_TextRewriterRun,
InMemoryJobQueue,
ModelProcessorEnum,
ConcurrencyLimiter,
TaskInput,
TaskOutput,
} from "ellmers-core";

import { DownloadModelTask, TextRewriterCompoundTask, getAiProviderRegistry } from "ellmers-ai";

import {
HuggingFaceLocal_DownloadRun,
HuggingFaceLocal_TextRewriterRun,
} from "ellmers-ai-provider/hf-transformers";

import { InMemoryJobQueue } from "ellmers-storage/inmemory";

// config and start up
getGlobalModelRepository(new InMemoryModelRepository());
await getGlobalModelRepository().addModel({
Expand All @@ -166,7 +168,7 @@ await getGlobalModelRepository().connectTaskToModel(
"ONNX Xenova/LaMini-Flan-T5-783M q8"
);

const ProviderRegistry = getProviderRegistry();
const ProviderRegistry = getAiProviderRegistry();
ProviderRegistry.registerRunFn(
DownloadModelTask.type,
ModelProcessorEnum.LOCAL_ONNX_TRANSFORMERJS,
Expand All @@ -187,7 +189,9 @@ jobQueue.start();

// build and run graph
const graph = new TaskGraph();
graph.addTask(new DownloadModel({ id: "1", input: { model: "Xenova/LaMini-Flan-T5-783M" } }));
graph.addTask(
new DownloadModel({ id: "1", input: { model: "ONNX Xenova/LaMini-Flan-T5-783M q8" } })
);
graph.addTask(
new TextRewriterCompoundTask({
id: "2",
Expand Down Expand Up @@ -256,7 +260,7 @@ An example is TextEmbeddingTask and TextEmbeddingCompoundTask. The first takes a
import { TaskGraphBuilder } from "ellmers-core";
const builder = new TaskGraphBuilder();
builder.TextEmbedding({
model: "Xenova/LaMini-Flan-T5-783M",
model: "ONNX Xenova/LaMini-Flan-T5-783M q8",
text: "The quick brown fox jumps over the lazy dog.",
});
await builder.run();
Expand All @@ -268,7 +272,7 @@ OR
import { TaskGraphBuilder } from "ellmers-core";
const builder = new TaskGraphBuilder();
builder.TextEmbedding({
model: ["Xenova/LaMini-Flan-T5-783M", "Universal Sentence Encoder"],
model: ["ONNX Xenova/LaMini-Flan-T5-783M q8", "Universal Sentence Encoder"],
text: "The quick brown fox jumps over the lazy dog.",
});
await builder.run();
Expand All @@ -281,7 +285,7 @@ import { TaskGraphBuilder } from "ellmers-core";
const builder = new TaskGraphBuilder();
builder
.DownloadModel({
model: ["Xenova/LaMini-Flan-T5-783M", "Universal Sentence Encoder"],
model: ["ONNX Xenova/LaMini-Flan-T5-783M q8", "Universal Sentence Encoder"],
})
.TextEmbedding({
text: "The quick brown fox jumps over the lazy dog.",
Expand Down Expand Up @@ -316,7 +320,7 @@ There is a JSONTask that can be used to build a graph. This is useful for saving
"dependencies": {
"model": {
"id": "1",
"output": "text_generation_model"
"output": "generation_model"
}
}
},
Expand Down Expand Up @@ -376,7 +380,7 @@ To use a task, instantiate it with some input and call `run()`:
const task = new TextEmbeddingTask({
id: "1",
input: {
model: "Xenova/LaMini-Flan-T5-783M",
model: "ONNX Xenova/LaMini-Flan-T5-783M q8",
text: "The quick brown fox jumps over the lazy dog.",
},
});
Expand All @@ -397,7 +401,7 @@ const graph = new TaskGraph();
graph.addTask(
new TextRewriterCompoundTask({
input: {
model: "Xenova/LaMini-Flan-T5-783M",
model: "ONNX Xenova/LaMini-Flan-T5-783M q8",
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
},
Expand All @@ -417,7 +421,7 @@ graph.addTask(
new TextRewriterCompoundTask({
id: "1",
input: {
model: "Xenova/LaMini-Flan-T5-783M",
model: "ONNX Xenova/LaMini-Flan-T5-783M q8",
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
},
Expand Down
4 changes: 2 additions & 2 deletions examples/cli/src/ellmers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import { program } from "commander";
import { argv } from "process";
import { AddBaseCommands } from "./TaskCLI";
import { getProviderRegistry } from "ellmers-ai";
import { getAiProviderRegistry } from "ellmers-ai";
import {
registerHuggingfaceLocalModels,
registerHuggingfaceLocalTasksInMemory,
Expand All @@ -24,4 +24,4 @@ registerMediaPipeTfJsLocalInMemory();

await program.parseAsync(argv);

getProviderRegistry().stopQueues();
getAiProviderRegistry().stopQueues();
8 changes: 4 additions & 4 deletions examples/web/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { QueuesStatus } from "./QueueSatus";
import { OutputRepositoryStatus } from "./OutputRepositoryStatus";
import { GraphStoreStatus } from "./GraphStoreStatus";
import { InMemoryJobQueue } from "ellmers-storage/inmemory";
import { getProviderRegistry } from "ellmers-ai";
import { getAiProviderRegistry } from "ellmers-ai";
import {
LOCAL_ONNX_TRANSFORMERJS,
registerHuggingfaceLocalTasks,
Expand All @@ -38,7 +38,7 @@ import { env } from "@huggingface/transformers";
env.backends.onnx.wasm.proxy = true;
env.allowLocalModels = true;

const ProviderRegistry = getProviderRegistry();
const ProviderRegistry = getAiProviderRegistry();

registerHuggingfaceLocalTasks();
ProviderRegistry.registerQueue(
Expand Down Expand Up @@ -77,8 +77,8 @@ const resetGraph = () => {
})
.TextTranslation({
model: "ONNX Xenova/m2m100_418M q8",
source: "en",
target: "es",
source_lang: "en",
target_lang: "es",
})
.rename("text", "message")
.rename("text", "message", -2)
Expand Down
6 changes: 3 additions & 3 deletions examples/web/src/QueueSatus.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { JobStatus } from "ellmers-core";
import { getProviderRegistry } from "ellmers-ai";
import { getAiProviderRegistry } from "ellmers-ai";
import { useCallback, useEffect, useState } from "react";

export function QueueStatus({ queueType }: { queueType: string }) {
const queue = getProviderRegistry().getQueue(queueType);
const queue = getAiProviderRegistry().getQueue(queueType);
const [pending, setPending] = useState<number>(0);
const [processing, setProcessing] = useState<number>(0);
const [completed, setCompleted] = useState<number>(0);
Expand Down Expand Up @@ -51,7 +51,7 @@ export function QueueStatus({ queueType }: { queueType: string }) {
}

export function QueuesStatus() {
const queues = getProviderRegistry().queues;
const queues = getAiProviderRegistry().queues;
const queueKeys = Array.from(queues.keys());

return (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import {
getProviderRegistry,
getAiProviderRegistry,
DownloadModelTask,
TextEmbeddingTask,
TextGenerationTask,
Expand All @@ -20,7 +20,7 @@ import {
import { LOCAL_ONNX_TRANSFORMERJS } from "../model/ONNXTransformerJsModel";

export async function registerHuggingfaceLocalTasks() {
const ProviderRegistry = getProviderRegistry();
const ProviderRegistry = getAiProviderRegistry();

ProviderRegistry.registerRunFn(
DownloadModelTask.type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ import {
type TranslationSingle,
TextStreamer,
} from "@huggingface/transformers";
import { ElVector } from "ellmers-core";
import { getGlobalModelRepository } from "ellmers-ai";
import { ElVector, getGlobalModelRepository } from "ellmers-ai";
import type {
JobQueueLlmTask,
DownloadModelTask,
Expand Down Expand Up @@ -229,8 +228,8 @@ export async function HuggingFaceLocal_TextTranslationRun(
});

let results = await translate(runInputData.text, {
src_lang: runInputData.source,
tgt_lang: runInputData.target,
src_lang: runInputData.source_lang,
tgt_lang: runInputData.target_lang,
streamer,
} as any);
if (!Array.isArray(results)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import { describe, expect, it } from "bun:test";
import { ConcurrencyLimiter, TaskGraphBuilder, TaskInput, TaskOutput } from "ellmers-core";
import {
getProviderRegistry,
getAiProviderRegistry,
getGlobalModelRepository,
setGlobalModelRepository,
} from "ellmers-ai";
Expand All @@ -23,7 +23,7 @@ const HFQUEUE = "local_hf";
describe("HFTransformersBinding", () => {
describe("InMemoryJobQueue", () => {
it("Should have an item queued", async () => {
const providerRegistry = getProviderRegistry();
const providerRegistry = getAiProviderRegistry();
const jobQueue = new InMemoryJobQueue<TaskInput, TaskOutput>(
HFQUEUE,
new ConcurrencyLimiter(1, 10),
Expand Down Expand Up @@ -85,7 +85,7 @@ describe("HFTransformersBinding", () => {
"TextRewritingTask",
"ONNX Xenova/LaMini-Flan-T5-783M q8"
);
const providerRegistry = getProviderRegistry();
const providerRegistry = getAiProviderRegistry();
const jobQueue = new SqliteJobQueue<TaskInput, TaskOutput>(
getDatabase(":memory:"),
HFQUEUE,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getProviderRegistry } from "ellmers-ai";
import { getAiProviderRegistry } from "ellmers-ai";
import { DownloadModelTask, TextEmbeddingTask } from "ellmers-ai";
import {
MediaPipeTfJsLocal_Download,
Expand All @@ -7,7 +7,7 @@ import {
import { MEDIA_PIPE_TFJS_MODEL } from "..";

export const registerMediaPipeTfJsLocalTasks = () => {
const ProviderRegistry = getProviderRegistry();
const ProviderRegistry = getAiProviderRegistry();

ProviderRegistry.registerRunFn(
DownloadModelTask.type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
// *******************************************************************************

import { FilesetResolver, TextEmbedder } from "@mediapipe/tasks-text";
import { ElVector } from "ellmers-core";
import {
DownloadModelTask,
DownloadModelTaskInput,
TextEmbeddingTask,
TextEmbeddingTaskInput,
getGlobalModelRepository,
ElVector,
} from "ellmers-ai";

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

// import { describe, expect, it } from "bun:test";
// import { ConcurrencyLimiter, TaskGraphBuilder, TaskInput, TaskOutput } from "ellmers-core";
// import { getGlobalModelRepository, getProviderRegistry, Model } from "ellmers-ai";
// import { getGlobalModelRepository, getAiProviderRegistry, Model } from "ellmers-ai";
// import { InMemoryJobQueue } from "ellmers-storage/inmemory";
// import { SqliteJobQueue } from "../../../../storage/dist/bun/sqlite";
// import { registerMediaPipeTfJsLocalTasks } from "../bindings/registerTasks";
Expand Down Expand Up @@ -35,7 +35,7 @@
// universal_sentence_encoder.name
// );
// registerMediaPipeTfJsLocalTasks();
// const ProviderRegistry = getProviderRegistry();
// const ProviderRegistry = getAiProviderRegistry();
// const jobQueue = new InMemoryJobQueue<TaskInput, TaskOutput>(
// TFQUEUE,
// new ConcurrencyLimiter(1, 10),
Expand Down Expand Up @@ -74,7 +74,7 @@
// universal_sentence_encoder.name
// );
// registerMediaPipeTfJsLocalTasks();
// const ProviderRegistry = getProviderRegistry();
// const ProviderRegistry = getAiProviderRegistry();
// const jobQueue = new SqliteJobQueue<TaskInput, TaskOutput>(
// getDatabase(":memory:"),
// TFQUEUE,
Expand Down
5 changes: 4 additions & 1 deletion packages/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ export * from "./task";
export * from "./model/Model";
export * from "./model/ModelRegistry";
export * from "./model/ModelRepository";
export * from "./provider/ProviderRegistry";
export * from "./source/Document";
export * from "./source/DocumentConverterText";
export * from "./source/DocumentConverterMarkdown";
export * from "./provider/AiProviderRegistry";
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ProviderJob<Input, Output> extends Job<Input, Output> {
* Handles the registration, retrieval, and execution of task processing functions
* for different model providers and task types.
*/
export class ProviderRegistry<Input, Output> {
export class AiProviderRegistry<Input, Output> {
// Registry of task execution functions organized by task type and model provider
runFnRegistry: Record<string, Record<string, (task: any, runInputData: any) => Promise<Output>>> =
{};
Expand Down Expand Up @@ -135,11 +135,11 @@ export class ProviderRegistry<Input, Output> {
}

// Singleton instance management for the AiProviderRegistry
let providerRegistry: ProviderRegistry<TaskInput, TaskOutput>;
export function getProviderRegistry() {
if (!providerRegistry) providerRegistry = new ProviderRegistry();
let providerRegistry: AiProviderRegistry<TaskInput, TaskOutput>;
export function getAiProviderRegistry() {
if (!providerRegistry) providerRegistry = new AiProviderRegistry();
return providerRegistry;
}
export function setProviderRegistry(pr: ProviderRegistry<TaskInput, TaskOutput>) {
export function setAiProviderRegistry(pr: AiProviderRegistry<TaskInput, TaskOutput>) {
providerRegistry = pr;
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ enum DocumentType {
TABLE = "table",
}

const doc_variants = [
"tree",
"flat",
"tree-paragraphs",
"flat-paragraphs",
"tree-sentences",
"flat-sentences",
] as const;
type DocVariant = (typeof doc_variants)[number];
const doc_parsers = ["txt", "md"] as const; // | "html" | "pdf" | "csv";
type DocParser = (typeof doc_parsers)[number];

export interface DocumentMetadata {
title: string;
}
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
21 changes: 9 additions & 12 deletions packages/ai/src/task/DocumentSplitterTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
// * Licensed under the Apache License, Version 2.0 (the "License"); *
// *******************************************************************************

import {
Document,
DocumentFragment,
SingleTask,
TaskGraphBuilder,
TaskGraphBuilderHelper,
CreateMappedType,
TaskRegistry,
} from "ellmers-core";

export type DocumentSplitterTaskInput = CreateMappedType<typeof DocumentSplitterTask.inputs>;
export type DocumentSplitterTaskOutput = CreateMappedType<typeof DocumentSplitterTask.outputs>;
import { SingleTask, TaskGraphBuilder, TaskGraphBuilderHelper, TaskRegistry } from "ellmers-core";
import { Document, DocumentFragment } from "../source/Document";
export type DocumentSplitterTaskInput = {
parser: "txt" | "md";
file: Document;
};
export type DocumentSplitterTaskOutput = {
texts: string[];
};

export class DocumentSplitterTask extends SingleTask {
static readonly type: string = "DocumentSplitterTask";
Expand Down
Loading

0 comments on commit 37c8922

Please sign in to comment.