Skip to content

Commit

Permalink
feat: get all the task names or all models from model repo
Browse files Browse the repository at this point in the history
  • Loading branch information
sroussey committed Jan 31, 2025
1 parent a07e709 commit 1b29fa8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
17 changes: 17 additions & 0 deletions packages/ai/src/model/ModelRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ class FallbackModelRegistry {
models: Model[] = [];
task2models: Task2ModelPrimaryKey[] = [];

constructor() {
console.warn("Using FallbackModelRegistry");
}

public async addModel(model: Model) {
if (this.models.some((m) => m.name === model.name)) {
this.models = this.models.filter((m) => m.name !== model.name);
Expand All @@ -33,6 +37,19 @@ class FallbackModelRegistry {
public async findByName(name: string) {
return this.models.find((m) => m.name === name);
}
public async enumerateAllModels() {
return this.models;
}
public async enumerateAllTasks() {
const tasks = new Set<string>();
for (const t2m of this.task2models) {
tasks.add(t2m.task);
}
return Array.from(tasks);
}
public async size() {
return this.models.length;
}
public async connectTaskToModel(task: string, model: string) {
this.task2models.push({ task, model });
}
Expand Down
21 changes: 21 additions & 0 deletions packages/ai/src/model/ModelRepository.ts
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,27 @@ export abstract class ModelRepository {
return junctions.map((junction) => junction.task);
}

/**
* Enumerates all tasks in the repository
* @returns Promise resolving to an array of task identifiers
*/
async enumerateAllTasks() {
const junctions = await this.task2ModelKvRepository.getAll();
if (!junctions || junctions.length === 0) return undefined;
const uniqueTasks = [...new Set(junctions.map((junction) => junction.task))];
return uniqueTasks;
}

/**
* Enumerates all models in the repository
* @returns Promise resolving to an array of model instances
*/
async enumerateAllModels() {
const models = await this.modelKvRepository.getAll();
if (!models || models.length === 0) return undefined;
return models;
}

/**
* Creates an association between a task and a model
* @param task - The task identifier
Expand Down

0 comments on commit 1b29fa8

Please sign in to comment.