refactor: rename LOCAL_ONNX_TRANSFORMERJS to ONNX_TRANSFORMERJS
sroussey committed Mar 5, 2025
1 parent 11192a0 commit f4ecc27
Showing 8 changed files with 33 additions and 38 deletions.
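
The rename touches both the exported identifier and its string value ("LOCAL_ONNX_TRANSFORMERJS" becomes "ONNX_TRANSFORMERJS"), so anything keyed on the constant, such as registered queue names, changes with it. A minimal sketch of a downstream call site after this commit, using only the import path and names that appear in the diffs below:

// Before: import { LOCAL_ONNX_TRANSFORMERJS } from "@ellmers/ai-provider/hf-transformers";
// After:
import { ONNX_TRANSFORMERJS, registerHuggingfaceLocalTasks } from "@ellmers/ai-provider/hf-transformers";

// Registers the local Hugging Face run functions under the renamed provider key.
registerHuggingfaceLocalTasks();
console.log(ONNX_TRANSFORMERJS); // "ONNX_TRANSFORMERJS" (was "LOCAL_ONNX_TRANSFORMERJS")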
docs/developers/01_getting_started.md (6 changes: 3 additions & 3 deletions)

@@ -156,7 +156,7 @@ await getGlobalModelRepository().addModel({
   url: "Xenova/LaMini-Flan-T5-783M",
   availableOnBrowser: true,
   availableOnServer: true,
-  provider: LOCAL_ONNX_TRANSFORMERJS,
+  provider: ONNX_TRANSFORMERJS,
   pipeline: "text2text-generation",
 });
 await getGlobalModelRepository().connectTaskToModel(
@@ -171,12 +171,12 @@ await getGlobalModelRepository().connectTaskToModel(
 const aiProviderRegistry = getAiProviderRegistry();
 aiProviderRegistry.registerRunFn(
   DownloadModelTask.type,
-  LOCAL_ONNX_TRANSFORMERJS,
+  ONNX_TRANSFORMERJS,
   HuggingFaceLocal_DownloadRun
 );
 aiProviderRegistry.registerRunFn(
   TextRewriterTask.type,
-  LOCAL_ONNX_TRANSFORMERJS,
+  ONNX_TRANSFORMERJS,
   HuggingFaceLocal_TextRewriterRun
 );
 const jobQueue = new JobQueue<TaskInput, TaskOutput>("test", Job, {
examples/web/src/App.tsx (4 changes: 2 additions & 2 deletions)

@@ -10,7 +10,7 @@ import { env } from "@huggingface/transformers";
 import { ReactFlowProvider } from "@xyflow/react";
 import { AiJob } from "@ellmers/ai";
 import {
-  LOCAL_ONNX_TRANSFORMERJS,
+  ONNX_TRANSFORMERJS,
   registerHuggingfaceLocalTasks,
 } from "@ellmers/ai-provider/hf-transformers";
 import {
@@ -43,7 +43,7 @@ const queueRegistry = getTaskQueueRegistry();

 registerHuggingfaceLocalTasks();
 queueRegistry.registerQueue(
-  new JobQueue<TaskInput, TaskOutput>(LOCAL_ONNX_TRANSFORMERJS, AiJob<TaskInput, TaskOutput>, {
+  new JobQueue<TaskInput, TaskOutput>(ONNX_TRANSFORMERJS, AiJob<TaskInput, TaskOutput>, {
     limiter: new ConcurrencyLimiter(1, 10),
   })
 );

@@ -17,50 +17,50 @@ import {
   HuggingFaceLocal_TextSummaryRun,
   HuggingFaceLocal_TextTranslationRun,
 } from "../provider/HuggingFaceLocal_TaskRun";
-import { LOCAL_ONNX_TRANSFORMERJS } from "../model/ONNXTransformerJsModel";
+import { ONNX_TRANSFORMERJS } from "../model/ONNXTransformerJsModel";

 export async function registerHuggingfaceLocalTasks() {
   const ProviderRegistry = getAiProviderRegistry();

   ProviderRegistry.registerRunFn(
     DownloadModelTask.type,
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     HuggingFaceLocal_DownloadRun
   );

   ProviderRegistry.registerRunFn(
     TextEmbeddingTask.type,
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     HuggingFaceLocal_EmbeddingRun
   );

   ProviderRegistry.registerRunFn(
     TextGenerationTask.type,
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     HuggingFaceLocal_TextGenerationRun
   );

   ProviderRegistry.registerRunFn(
     TextTranslationTask.type,
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     HuggingFaceLocal_TextTranslationRun
   );

   ProviderRegistry.registerRunFn(
     TextRewriterTask.type,
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     HuggingFaceLocal_TextRewriterRun
   );

   ProviderRegistry.registerRunFn(
     TextSummaryTask.type,
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     HuggingFaceLocal_TextSummaryRun
   );

   ProviderRegistry.registerRunFn(
     TextQuestionAnswerTask.type,
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     HuggingFaceLocal_TextQuestionAnswerRun
   );
 }

@@ -5,7 +5,7 @@
 // * Licensed under the Apache License, Version 2.0 (the "License"); *
 // *******************************************************************************

-export const LOCAL_ONNX_TRANSFORMERJS = "LOCAL_ONNX_TRANSFORMERJS";
+export const ONNX_TRANSFORMERJS = "ONNX_TRANSFORMERJS";

 export enum QUANTIZATION_DATA_TYPES {
   auto = "auto", // Auto-detect based on environment

@@ -22,7 +22,7 @@ import {
 } from "@ellmers/task-graph";
 import { Sqlite, sleep } from "@ellmers/util";
 import { registerHuggingfaceLocalTasks } from "../bindings/registerTasks";
-import { LOCAL_ONNX_TRANSFORMERJS } from "../model/ONNXTransformerJsModel";
+import { ONNX_TRANSFORMERJS } from "../model/ONNXTransformerJsModel";
 import { AiProviderInput } from "@ellmers/ai";
 import { SqliteQueueStorage } from "@ellmers/storage";

@@ -37,7 +37,7 @@ describe("HFTransformersBinding", () => {
   registerHuggingfaceLocalTasks();
   const queueRegistry = getTaskQueueRegistry();
   const jobQueue = new JobQueue<AiProviderInput<TaskInput>, TaskOutput>(
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     AiJob<TaskInput, TaskOutput>,
     {
       limiter: new ConcurrencyLimiter(1, 10),
@@ -50,7 +50,7 @@
       url: "Xenova/LaMini-Flan-T5-783M",
       availableOnBrowser: true,
       availableOnServer: true,
-      provider: LOCAL_ONNX_TRANSFORMERJS,
+      provider: ONNX_TRANSFORMERJS,
       pipeline: "text2text-generation",
     };

@@ -59,9 +59,9 @@
     await getGlobalModelRepository().connectTaskToModel("TextGenerationTask", model.name);
     await getGlobalModelRepository().connectTaskToModel("TextRewriterTask", model.name);

-    const queue = queueRegistry.getQueue(LOCAL_ONNX_TRANSFORMERJS);
+    const queue = queueRegistry.getQueue(ONNX_TRANSFORMERJS);
     expect(queue).toBeDefined();
-    expect(queue!.queueName).toEqual(LOCAL_ONNX_TRANSFORMERJS);
+    expect(queue!.queueName).toEqual(ONNX_TRANSFORMERJS);

     const workflow = new Workflow();
     workflow.DownloadModel({
@@ -80,7 +80,7 @@
     registerHuggingfaceLocalTasks();
     const queueRegistry = getTaskQueueRegistry();
     const jobQueue = new JobQueue<AiProviderInput<TaskInput>, TaskOutput>(
-      LOCAL_ONNX_TRANSFORMERJS,
+      ONNX_TRANSFORMERJS,
       AiJob<TaskInput, TaskOutput>,
       {
        storage: new SqliteQueueStorage<AiProviderInput<TaskInput>, TaskOutput>(db, "test"),
@@ -97,17 +97,17 @@
       url: "Xenova/LaMini-Flan-T5-783M",
       availableOnBrowser: true,
       availableOnServer: true,
-      provider: LOCAL_ONNX_TRANSFORMERJS,
+      provider: ONNX_TRANSFORMERJS,
       pipeline: "text2text-generation",
     };

     await getGlobalModelRepository().addModel(model);
     await getGlobalModelRepository().connectTaskToModel("TextGenerationTask", model.name);
     await getGlobalModelRepository().connectTaskToModel("TextRewriterTask", model.name);

-    const queue = queueRegistry.getQueue(LOCAL_ONNX_TRANSFORMERJS);
+    const queue = queueRegistry.getQueue(ONNX_TRANSFORMERJS);
     expect(queue).toBeDefined();
-    expect(queue?.queueName).toEqual(LOCAL_ONNX_TRANSFORMERJS);
+    expect(queue?.queueName).toEqual(ONNX_TRANSFORMERJS);

     const workflow = new Workflow();
     workflow.DownloadModel({

@@ -10,7 +10,7 @@ import { getGlobalModelRepository } from "../../ModelRegistry";
 import { ModelRepository } from "../../ModelRepository";
 import { setGlobalModelRepository } from "../../ModelRegistry";

-const LOCAL_ONNX_TRANSFORMERJS = "LOCAL_ONNX_TRANSFORMERJS";
+const ONNX_TRANSFORMERJS = "ONNX_TRANSFORMERJS";

 export const runGenericModelRepositoryTests = (
   createRepository: () => Promise<ModelRepository>
@@ -32,7 +32,7 @@
       url: "Xenova/LaMini-Flan-T5-783M",
       availableOnBrowser: true,
       availableOnServer: true,
-      provider: LOCAL_ONNX_TRANSFORMERJS,
+      provider: ONNX_TRANSFORMERJS,
       pipeline: "text2text-generation",
     });

@@ -50,7 +50,7 @@
       url: "Xenova/LaMini-Flan-T5-783M",
       availableOnBrowser: true,
       availableOnServer: true,
-      provider: LOCAL_ONNX_TRANSFORMERJS,
+      provider: ONNX_TRANSFORMERJS,
       pipeline: "text2text-generation",
     });
     await getGlobalModelRepository().connectTaskToModel(
@@ -76,7 +76,7 @@
       url: "Xenova/LaMini-Flan-T5-783M",
       availableOnBrowser: true,
       availableOnServer: true,
-      provider: LOCAL_ONNX_TRANSFORMERJS,
+      provider: ONNX_TRANSFORMERJS,
       pipeline: "text2text-generation",
     });

@@ -89,7 +89,7 @@
     expect(models).toBeDefined();
     expect(models?.length).toEqual(1);
     expect(models?.[0].name).toEqual("onnx:Xenova/LaMini-Flan-T5-783M:q8");
-    expect(models?.[0].provider).toEqual(LOCAL_ONNX_TRANSFORMERJS);
+    expect(models?.[0].provider).toEqual(ONNX_TRANSFORMERJS);
     expect(models?.[0].pipeline).toEqual("text2text-generation");
   });
 };
packages/test/src/index.ts (8 changes: 3 additions & 5 deletions)

@@ -7,7 +7,7 @@

 import { AiJob, AiProviderInput } from "@ellmers/ai";
 import {
-  LOCAL_ONNX_TRANSFORMERJS,
+  ONNX_TRANSFORMERJS,
   registerHuggingfaceLocalTasks,
 } from "@ellmers/ai-provider/hf-transformers";
 import {
@@ -23,12 +23,10 @@ export * from "./sample/ONNXModelSamples";
 export async function registerHuggingfaceLocalTasksInMemory() {
   registerHuggingfaceLocalTasks();
   const jobQueue = new JobQueue<AiProviderInput<TaskInput>, TaskOutput>(
-    LOCAL_ONNX_TRANSFORMERJS,
+    ONNX_TRANSFORMERJS,
     AiJob<TaskInput, TaskOutput>,
     {
-      storage: new InMemoryQueueStorage<AiProviderInput<TaskInput>, TaskOutput>(
-        LOCAL_ONNX_TRANSFORMERJS
-      ),
+      storage: new InMemoryQueueStorage<AiProviderInput<TaskInput>, TaskOutput>(ONNX_TRANSFORMERJS),
       limiter: new ConcurrencyLimiter(1, 10),
     }
   );
packages/test/src/sample/ONNXModelSamples.ts (7 changes: 2 additions & 5 deletions)

@@ -1,14 +1,11 @@
-import {
-  LOCAL_ONNX_TRANSFORMERJS,
-  QUANTIZATION_DATA_TYPES,
-} from "@ellmers/ai-provider/hf-transformers";
+import { ONNX_TRANSFORMERJS, QUANTIZATION_DATA_TYPES } from "@ellmers/ai-provider/hf-transformers";
 import { getGlobalModelRepository, Model } from "@ellmers/ai";

 export async function addONNXModel(info: Partial<Model>, tasks: string[]) {
   const model = Object.assign(
     {
       name: "onnx:" + info.url + ":" + (info.quantization ?? QUANTIZATION_DATA_TYPES.q8),
-      provider: LOCAL_ONNX_TRANSFORMERJS,
+      provider: ONNX_TRANSFORMERJS,
       quantization: QUANTIZATION_DATA_TYPES.q8,
       normalize: true,
       contextWindow: 4096,
