diff --git a/packages/ai/src/model/ModelRegistry.ts b/packages/ai/src/model/ModelRegistry.ts index 83377b7..aa0e05c 100644 --- a/packages/ai/src/model/ModelRegistry.ts +++ b/packages/ai/src/model/ModelRegistry.ts @@ -10,48 +10,60 @@ import { ModelRepository, Task2ModelPrimaryKey } from "./ModelRepository"; // temporary model registry that is synchronous until we have a proper model repository -class FallbackModelRegistry { - models: Model[] = []; - task2models: Task2ModelPrimaryKey[] = []; +export class FallbackModelRegistry { + private models: Model[] = []; + private task2models: Task2ModelPrimaryKey[] = []; constructor() { console.warn("Using FallbackModelRegistry"); } - public async addModel(model: Model) { - if (this.models.some((m) => m.name === model.name)) { - this.models = this.models.filter((m) => m.name !== model.name); + public async addModel(model: Model): Promise { + const existingIndex = this.models.findIndex((m) => m.name === model.name); + if (existingIndex !== -1) { + this.models[existingIndex] = model; + } else { + this.models.push(model); } - - this.models.push(model); } - public async findModelsByTask(task: string) { - return this.task2models + + public async findModelsByTask(task: string): Promise { + const models = this.task2models .filter((t2m) => t2m.task === task) - .map((t2m) => this.models.find((m) => m.name === t2m.model)) - .filter((m) => m !== undefined); + .map((t2m) => this.models.find((m) => m.name === t2m.model)); + + if (models.some((m) => m === undefined)) { + console.warn(`Some models for task ${task} were not found`); + } + + const found = models.filter((m): m is Model => m !== undefined) + return found.length > 0 ? found : undefined; } - public async findTasksByModel(name: string) { + public async findTasksByModel(name: string): Promise { return this.task2models.filter((t2m) => t2m.model === name).map((t2m) => t2m.task); } - public async findByName(name: string) { + public async findByName(name: string): Promise { return this.models.find((m) => m.name === name); } - public async enumerateAllModels() { - return this.models; + + public async enumerateAllModels(): Promise { + return [...this.models]; } - public async enumerateAllTasks() { - const tasks = new Set(); - for (const t2m of this.task2models) { - tasks.add(t2m.task); - } - return Array.from(tasks); + + public async enumerateAllTasks(): Promise { + return Array.from(new Set(this.task2models.map(t2m => t2m.task))); } - public async size() { + + public async size(): Promise { return this.models.length; } - public async connectTaskToModel(task: string, model: string) { - this.task2models.push({ task, model }); + + public async connectTaskToModel(task: string, modelName: string): Promise { + if (!this.findByName(modelName)) { + throw new Error(`Model ${modelName} not found when connecting to task ${task}`); + } + + this.task2models.push({ task, model: modelName }); } }