Skip to content

Commit

Permalink
fix: default model registry
Browse files Browse the repository at this point in the history
  • Loading branch information
sroussey committed Jan 31, 2025
1 parent 1b29fa8 commit a54a998
Showing 1 changed file with 37 additions and 25 deletions.
62 changes: 37 additions & 25 deletions packages/ai/src/model/ModelRegistry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,48 +10,60 @@ import { ModelRepository, Task2ModelPrimaryKey } from "./ModelRepository";

// temporary model registry that is synchronous until we have a proper model repository

class FallbackModelRegistry {
models: Model[] = [];
task2models: Task2ModelPrimaryKey[] = [];
export class FallbackModelRegistry {
private models: Model[] = [];
private task2models: Task2ModelPrimaryKey[] = [];

constructor() {
console.warn("Using FallbackModelRegistry");
}

public async addModel(model: Model) {
if (this.models.some((m) => m.name === model.name)) {
this.models = this.models.filter((m) => m.name !== model.name);
public async addModel(model: Model): Promise<void> {
const existingIndex = this.models.findIndex((m) => m.name === model.name);
if (existingIndex !== -1) {
this.models[existingIndex] = model;
} else {
this.models.push(model);
}

this.models.push(model);
}
public async findModelsByTask(task: string) {
return this.task2models

public async findModelsByTask(task: string): Promise<Model[]|undefined> {
const models = this.task2models
.filter((t2m) => t2m.task === task)
.map((t2m) => this.models.find((m) => m.name === t2m.model))
.filter((m) => m !== undefined);
.map((t2m) => this.models.find((m) => m.name === t2m.model));

if (models.some((m) => m === undefined)) {
console.warn(`Some models for task ${task} were not found`);
}

const found = models.filter((m): m is Model => m !== undefined)
return found.length > 0 ? found : undefined;
}
public async findTasksByModel(name: string) {
public async findTasksByModel(name: string): Promise<string[]> {
return this.task2models.filter((t2m) => t2m.model === name).map((t2m) => t2m.task);
}
public async findByName(name: string) {
public async findByName(name: string): Promise<Model | undefined> {
return this.models.find((m) => m.name === name);
}
public async enumerateAllModels() {
return this.models;

public async enumerateAllModels(): Promise<Model[]> {
return [...this.models];
}
public async enumerateAllTasks() {
const tasks = new Set<string>();
for (const t2m of this.task2models) {
tasks.add(t2m.task);
}
return Array.from(tasks);

public async enumerateAllTasks(): Promise<string[]> {
return Array.from(new Set(this.task2models.map(t2m => t2m.task)));
}
public async size() {

public async size(): Promise<number> {
return this.models.length;
}
public async connectTaskToModel(task: string, model: string) {
this.task2models.push({ task, model });

public async connectTaskToModel(task: string, modelName: string): Promise<void> {
if (!this.findByName(modelName)) {
throw new Error(`Model ${modelName} not found when connecting to task ${task}`);
}

this.task2models.push({ task, model: modelName });
}
}

Expand Down

0 comments on commit a54a998

Please sign in to comment.