From 1b29fa8f9e8baf4643e63b2f119128988b898fa1 Mon Sep 17 00:00:00 2001 From: Steven Roussey Date: Thu, 30 Jan 2025 17:23:23 -0800 Subject: [PATCH] feat: get all the task names or all models from model repo --- packages/ai/src/model/ModelRegistry.ts | 17 +++++++++++++++++ packages/ai/src/model/ModelRepository.ts | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/packages/ai/src/model/ModelRegistry.ts b/packages/ai/src/model/ModelRegistry.ts index 1200da2..83377b7 100644 --- a/packages/ai/src/model/ModelRegistry.ts +++ b/packages/ai/src/model/ModelRegistry.ts @@ -14,6 +14,10 @@ class FallbackModelRegistry { models: Model[] = []; task2models: Task2ModelPrimaryKey[] = []; + constructor() { + console.warn("Using FallbackModelRegistry"); + } + public async addModel(model: Model) { if (this.models.some((m) => m.name === model.name)) { this.models = this.models.filter((m) => m.name !== model.name); @@ -33,6 +37,19 @@ class FallbackModelRegistry { public async findByName(name: string) { return this.models.find((m) => m.name === name); } + public async enumerateAllModels() { + return this.models; + } + public async enumerateAllTasks() { + const tasks = new Set(); + for (const t2m of this.task2models) { + tasks.add(t2m.task); + } + return Array.from(tasks); + } + public async size() { + return this.models.length; + } public async connectTaskToModel(task: string, model: string) { this.task2models.push({ task, model }); } diff --git a/packages/ai/src/model/ModelRepository.ts b/packages/ai/src/model/ModelRepository.ts index 531238f..6006f36 100644 --- a/packages/ai/src/model/ModelRepository.ts +++ b/packages/ai/src/model/ModelRepository.ts @@ -144,6 +144,27 @@ export abstract class ModelRepository { return junctions.map((junction) => junction.task); } + /** + * Enumerates all tasks in the repository + * @returns Promise resolving to an array of task identifiers + */ + async enumerateAllTasks() { + const junctions = await this.task2ModelKvRepository.getAll(); + if (!junctions || junctions.length === 0) return undefined; + const uniqueTasks = [...new Set(junctions.map((junction) => junction.task))]; + return uniqueTasks; + } + + /** + * Enumerates all models in the repository + * @returns Promise resolving to an array of model instances + */ + async enumerateAllModels() { + const models = await this.modelKvRepository.getAll(); + if (!models || models.length === 0) return undefined; + return models; + } + /** * Creates an association between a task and a model * @param task - The task identifier