Skip to content

Commit

Permalink
fix: Fixup after accepting PR too quickly
Browse files Browse the repository at this point in the history
- Modify task JSON serialization and deserialization methods (closes issue DRY Tasks from JSON #69)
- Rename default task ports to 'console' for consistency
- Refactor task graph creation from JSON with improved type safety
- Simplify task input and configuration handling
  • Loading branch information
sroussey committed Mar 3, 2025
1 parent 028af70 commit 08e007c
Show file tree
Hide file tree
Showing 17 changed files with 203 additions and 211 deletions.
69 changes: 29 additions & 40 deletions docs/developers/01_getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ workflow
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
})
.rename("text", "message")
.rename("text", "console")
.DebugLog();
await workflow.run();

Expand Down Expand Up @@ -89,19 +89,19 @@ registerHuggingfaceLocalTasksInMemory();

// build and run graph
const graph = new TaskGraph();
graph.addTask(new DownloadModel({ model: "onnx:Xenova/LaMini-Flan-T5-783M:q8" }, { id: "1" }));
graph.addTask(
new DownloadModel({ id: "1", input: { model: "onnx:Xenova/LaMini-Flan-T5-783M:q8" } })
);
graph.addTask(
new TextRewriterCompoundTask({
id: "2",
input: {
new TextRewriterCompoundTask(
{
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
},
})
{
id: "2",
}
)
);
graph.addTask(new DebugLog({ id: "3" }));
graph.addTask(new DebugLog({}, { id: "3" }));
graph.addDataflow(
new Dataflow({
sourceTaskId: "1",
Expand All @@ -115,7 +115,7 @@ graph.addDataflow(
sourceTaskId: "2",
sourceTaskPortId: "text",
targetTaskId: "3",
targetTaskPortId: "message",
targetTaskPortId: "console",
})
);

Expand Down Expand Up @@ -191,19 +191,17 @@ jobQueue.start();

// build and run graph
const graph = new TaskGraph();
graph.addTask(new DownloadModel({ model: "onnx:Xenova/LaMini-Flan-T5-783M:q8" }, { id: "1" }));
graph.addTask(
new DownloadModel({ id: "1", input: { model: "onnx:Xenova/LaMini-Flan-T5-783M:q8" } })
);
graph.addTask(
new TextRewriterCompoundTask({
id: "2",
input: {
new TextRewriterCompoundTask(
{
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
},
})
{ id: "2" }
)
);
graph.addTask(new DebugLog({ id: "3" }));
graph.addTask(new DebugLog({}, { id: "3" }));
graph.addDataflow(
new Dataflow({
sourceTaskId: "1",
Expand All @@ -217,7 +215,7 @@ graph.addDataflow(
sourceTaskId: "2",
sourceTaskPortId: "text",
targetTaskId: "3",
targetTaskPortId: "message",
targetTaskPortId: "console",
})
);

Expand Down Expand Up @@ -292,7 +290,7 @@ workflow
.TextEmbedding({
text: "The quick brown fox jumps over the lazy dog.",
});
.rename("vector", "message")
.rename("*", "console")
.DebugLog();
await workflow.run();
```
Expand Down Expand Up @@ -368,7 +366,7 @@ The JSON above is a good example as it shows how to use a compound task with mul
```ts
import { JSONTask } from "@ellmers/task-graph";
const json = require("./example.json");
const task = new JSONTask({ input: { json } });
const task = new JSONTask({ json });
await task.run();
```

Expand All @@ -380,11 +378,8 @@ To use a task, instantiate it with some input and call `run()`:

```ts
const task = new TextEmbeddingTask({
id: "1",
input: {
model: "onnx:Xenova/LaMini-Flan-T5-783M:q8",
text: "The quick brown fox jumps over the lazy dog.",
},
model: "onnx:Xenova/LaMini-Flan-T5-783M:q8",
text: "The quick brown fox jumps over the lazy dog.",
});
const result = await task.run();
console.log(result);
Expand All @@ -402,11 +397,9 @@ Example:
const graph = new TaskGraph();
graph.addTask(
new TextRewriterCompoundTask({
input: {
model: "onnx:Xenova/LaMini-Flan-T5-783M:q8",
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
},
model: "onnx:Xenova/LaMini-Flan-T5-783M:q8",
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
})
);
```
Expand All @@ -420,20 +413,16 @@ Example, adding a data flow to the graph similar to above:
```ts
const graph = new TaskGraph();
graph.addTask(
new TextRewriterCompoundTask({
id: "1",
input: {
new TextRewriterCompoundTask(
{
model: "onnx:Xenova/LaMini-Flan-T5-783M:q8",
text: "The quick brown fox jumps over the lazy dog.",
prompt: ["Rewrite the following text in reverse:", "Rewrite this to sound like a pirate:"],
},
})
);
graph.addTask(
new DebugLogTask({
id: "2",
})
{ id: "1" }
)
);
graph.addTask(new DebugLogTask({}, { id: "2" }));
graph.addDataflow(
new Dataflow({
sourceTaskId: "1",
Expand Down
10 changes: 5 additions & 5 deletions docs/developers/03_extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Here we will write an example of a simple Task that prints a message to the cons

```ts
export class SimpleDebugLogTask extends SimpleTask {
run() {
runFull() {
console.dir(<something>, { depth: null });
}
}
Expand Down Expand Up @@ -50,12 +50,12 @@ export class SimpleDebugLogTask extends SimpleTask {
] as const;
declare defaults: Partial<SimpleDebugLogTaskInputs>;
declare runInputData: SimpleDebugLogTaskInputs;
run() {
runFull() {
console.dir(this.runInputData.message, { depth: null });
}
}

new SimpleDebugLogTask({ input: { message: "hello world" } }).run();
new SimpleDebugLogTask({ message: "hello world" }).run();
```

Since the code itself can't read the TypeScript types at runtime, we need to describe the inputs in the static `inputs` property. We still create a type `SimpleDebugLogTaskInputs` to help us since we are writing TypeScript code. We use it to re-type (`declare`) the `defaults` and `runInputData` properties.
Expand Down Expand Up @@ -92,14 +92,14 @@ export class SimpleDebugLogTask extends SimpleTask {
},
] as const;
declare runOutputData: SimpleDebugLogTaskOutputs;
run() {
runFull() {
console.dir(this.runInputData.message, { depth: null });
this.runOutputData.output = this.runInputData.message;
return this.runOutputData;
}
}

new SimpleDebugLogTask({ input: { message: "hello world" } }).run();
new SimpleDebugLogTask({ message: "hello world" }).run();
```

In the above code, we added an output to the Task. We also added the `static sideeffects` flag to tell the system that this Task has side effects. This is important because the system needs to know whether it can cache the output of the Task. If a Task has side effects, it should not be cached.
Expand Down
2 changes: 1 addition & 1 deletion examples/cli/src/TaskCLI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ export function AddBaseCommands(program: Command) {
.TextEmbedding({
text: "The quick brown fox jumps over the lazy dog.",
})
.rename("*", "messages")
.rename("*", "console")
.DebugLog();

try {
Expand Down
47 changes: 23 additions & 24 deletions examples/web/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,10 @@
// * Licensed under the Apache License, Version 2.0 (the "License"); *
// *******************************************************************************

import React, { useCallback, useEffect, useState } from "react";
import { useCallback, useEffect, useState } from "react";
import { env } from "@huggingface/transformers";
import { ReactFlowProvider } from "@xyflow/react";
import { JsonTask } from "@ellmers/tasks";
import {
JsonTaskItem,
TaskGraph,
Workflow,
TaskInput,
TaskOutput,
getTaskQueueRegistry,
IndexedDbTaskGraphRepository,
IndexedDbTaskOutputRepository,
} from "@ellmers/task-graph";
import { ConcurrencyLimiter, JobQueue } from "@ellmers/job-queue";
import { AiJob } from "@ellmers/ai";
import {
LOCAL_ONNX_TRANSFORMERJS,
registerHuggingfaceLocalTasks,
Expand All @@ -27,16 +17,25 @@ import {
MEDIA_PIPE_TFJS_MODEL,
registerMediaPipeTfJsLocalTasks,
} from "@ellmers/ai-provider/tf-mediapipe";
import { registerMediaPipeTfJsLocalModels, registerHuggingfaceLocalModels } from "@ellmers/test";
import { env } from "@huggingface/transformers";
import { AiJob } from "@ellmers/ai";

import { RunGraphFlow } from "./RunGraphFlow";
import { ConcurrencyLimiter, JobQueue } from "@ellmers/job-queue";
import {
getTaskQueueRegistry,
IndexedDbTaskGraphRepository,
IndexedDbTaskOutputRepository,
JsonTaskItem,
TaskGraph,
TaskInput,
TaskOutput,
Workflow,
} from "@ellmers/task-graph";
import { JsonTask } from "@ellmers/tasks";
import { registerHuggingfaceLocalModels, registerMediaPipeTfJsLocalModels } from "@ellmers/test";
import { GraphStoreStatus } from "./GraphStoreStatus";
import { JsonEditor } from "./JsonEditor";
import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from "./Resize";
import { QueuesStatus } from "./QueueStatus";
import { OutputRepositoryStatus } from "./OutputRepositoryStatus";
import { GraphStoreStatus } from "./GraphStoreStatus";
import { QueuesStatus } from "./QueueStatus";
import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from "./Resize";
import { RunGraphFlow } from "./RunGraphFlow";

env.backends.onnx.wasm.proxy = true;

Expand Down Expand Up @@ -89,8 +88,8 @@ const resetGraph = () => {
source_lang: "en",
target_lang: "es",
})
.rename("*", "messages")
.rename("*", "messages", -2)
.rename("*", "console")
.rename("*", "console", -2)
.DebugLog({ log_level: "info" });
taskGraphRepo.saveTaskGraph("default", workflow.graph);
};
Expand Down Expand Up @@ -168,7 +167,7 @@ export const App = () => {
}, []);

const setNewJson = useCallback((json: string) => {
const task = new JsonTask({ input: { json: json } });
const task = new JsonTask({ json });
workflow.graph = task.subGraph;
setJsonData(json);
}, []);
Expand Down
2 changes: 1 addition & 1 deletion examples/web/src/JsonEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ export const JsonEditor: React.FC<PopupProps> = ({
// this will throw an error if the JSON is invalid
JSON.parse(jsonString);
// this will throw an error if the JSON is not a valid task graph
new JsonTask({ name: "Test JSON", input: { json: jsonString } });
new JsonTask({ json: jsonString }, { name: "Test JSON" });

setIsValidJSON(true);
setCode(jsonString);
Expand Down
2 changes: 1 addition & 1 deletion examples/web/src/main.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ console.log(
workflow.%cDownloadModel%c({ %cmodel%c: [%c'ONNX Xenova/LaMini-Flan-T5-783M q8']%c });
workflow.%cTextRewriter%c({ %ctext%c: %c'The quick brown fox jumps over the lazy dog.'%c, %cprompt%c: [%c'Rewrite the following text in reverse:'%c, %c'Rewrite this to sound like a pirate:'%c] });
workflow.%crename%c(%c'text'%c, %c'message'%c);
workflow.%crename%c(%c'*'%c, %c'console'%c);
workflow.%cDebugLog%c({ %clevel%c: %c'info'%c });
console.log(JSON.stringify(workflow.toJSON(),null,2));
Expand Down
64 changes: 2 additions & 62 deletions packages/task-graph/src/storage/taskgraph/TaskGraphRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@

import { EventEmitter, EventParameters } from "@ellmers/util";
import type { TabularRepository } from "@ellmers/storage";
import { Dataflow } from "../../task-graph/Dataflow";
import { TaskGraph } from "../../task-graph/TaskGraph";
import { TaskGraphItemJson, TaskGraphJson } from "task/TaskJSON";
import { TaskRegistry } from "../../task/TaskRegistry";
import { TaskConfigurationError } from "../../task/TaskError";
import { createGraphFromGraphJSON } from "../../task/TaskJSON";

/**
* Events that can be emitted by the TaskGraphRepository
Expand Down Expand Up @@ -85,62 +82,6 @@ export abstract class TaskGraphRepository {
return this.events.emitted(name) as Promise<TaskGraphEventParameters<Event>>;
}

/**
 * Creates a task instance from a task graph item JSON representation.
 *
 * The item is validated before the task class is looked up in the
 * TaskRegistry: `id` and `type` are required, and `input` / `provenance`,
 * when present, must be plain (non-array) objects.
 *
 * @param item The JSON representation of the task
 * @returns A new task instance, with its subgraph attached when the task is
 *   compound and the item carries one
 * @throws TaskConfigurationError if required fields are missing or invalid,
 *   if the task type is not registered, or if a subgraph is supplied for a
 *   non-compound task type
 */
private createTaskFromJSON(item: TaskGraphItemJson) {
  if (!item.id) throw new TaskConfigurationError("Task id required");
  if (!item.type) throw new TaskConfigurationError("Task type required");
  // Validate the input itself. The previous check looked at item.provenance
  // here, so non-object inputs slipped through and a bad provenance was
  // reported as an input error.
  if (item.input && (Array.isArray(item.input) || typeof item.input !== "object"))
    throw new TaskConfigurationError("Task input must be an object");
  if (item.provenance && (Array.isArray(item.provenance) || typeof item.provenance !== "object"))
    throw new TaskConfigurationError("Task provenance must be an object");

  const taskClass = TaskRegistry.all.get(item.type);
  if (!taskClass) throw new TaskConfigurationError(`Task type ${item.type} not found`);
  if (!taskClass.isCompound && item.subgraph) {
    throw new TaskConfigurationError("Subgraph is only supported for CompoundTasks");
  }

  const taskConfig = {
    id: item.id,
    name: item.name,
    provenance: item.provenance ?? {},
  };

  // Input is passed first and configuration second.
  const task = new taskClass(item.input ?? {}, taskConfig);
  if (task.isCompound && item.subgraph) {
    task.subGraph = this.createSubGraph(item.subgraph);
  }
  return task;
}

/**
 * Builds a TaskGraph from its JSON representation.
 *
 * Every entry in `nodes` is turned into a task and added to the graph first;
 * afterwards every entry in `edges` is added as a Dataflow between tasks.
 *
 * @param graphJsonObj The JSON representation of the task graph
 * @returns A new TaskGraph instance with all tasks and data flows
 */
public createSubGraph(graphJsonObj: TaskGraphJson) {
  const graph = new TaskGraph();
  const { nodes, edges } = graphJsonObj;

  for (const node of nodes) {
    const task = this.createTaskFromJSON(node);
    graph.addTask(task);
  }

  for (const edge of edges) {
    const { sourceTaskId, sourceTaskPortId, targetTaskId, targetTaskPortId } = edge;
    const flow = new Dataflow(sourceTaskId, sourceTaskPortId, targetTaskId, targetTaskPortId);
    graph.addDataflow(flow);
  }

  return graph;
}

/**
* Saves a task graph to persistent storage
* @param key The unique identifier for the task graph
Expand All @@ -166,8 +107,7 @@ export abstract class TaskGraphRepository {
return undefined;
}
const jsonObj = JSON.parse(value);

const graph = this.createSubGraph(jsonObj);
const graph = createGraphFromGraphJSON(jsonObj);

this.events.emit("graph_retrieved", key);
return graph;
Expand Down
Loading

0 comments on commit 08e007c

Please sign in to comment.