diff --git a/src/models/replicate.js b/src/models/replicate.js
index aa296c57d..d38e8fea8 100644
--- a/src/models/replicate.js
+++ b/src/models/replicate.js
@@ -2,7 +2,7 @@
 import Replicate from 'replicate';
 import { toSinglePrompt } from '../utils/text.js';
 import { getKey } from '../utils/keys.js';
-// llama, mistral
+// llama, mistral, gemini
 export class ReplicateAPI {
     static prefix = 'replicate';
     constructor(model_name, url, params) {
@@ -24,24 +24,60 @@ export class ReplicateAPI {
         const prompt = toSinglePrompt(turns, null, stop_seq);
         let model_name = this.model_name || 'meta/meta-llama-3-70b-instruct';
 
-        const input = {
-            prompt,
-            system_prompt: systemMessage,
-            ...(this.params || {})
-        };
+        // Detect model type to use correct input format
+        const isGemini = model_name.includes('gemini');
+
+        let input;
+        if (isGemini) {
+            // Gemini models on Replicate ignore system_prompt field
+            // Combine system message into the main prompt instead
+            const fullPrompt = systemMessage + '\n\n' + prompt;
+            input = {
+                prompt: fullPrompt,
+                ...(this.params || {})
+            };
+        } else {
+            // Llama and other models use system_prompt
+            input = {
+                prompt,
+                system_prompt: systemMessage,
+                ...(this.params || {})
+            };
+        }
+
         let res = null;
         try {
             console.log('Awaiting Replicate API response...');
-            let result = '';
-            for await (const event of this.replicate.stream(model_name, { input })) {
-                result += event;
-                if (result === '') break;
-                if (result.includes(stop_seq)) {
-                    result = result.slice(0, result.indexOf(stop_seq));
-                    break;
+
+            if (isGemini) {
+                // Gemini doesn't stream well on Replicate, use run() instead
+                const output = await this.replicate.run(model_name, { input });
+                // Output might be a string or an array
+                if (Array.isArray(output)) {
+                    res = output.join('');
+                } else if (typeof output === 'string') {
+                    res = output;
+                } else {
+                    res = String(output);
                 }
+            } else {
+                // Use streaming for other models
+                let result = '';
+                for await (const event of this.replicate.stream(model_name, { input })) {
+                    result += event;
+                    if (result === '') break;
+                    if (result.includes(stop_seq)) {
+                        result = result.slice(0, result.indexOf(stop_seq));
+                        break;
+                    }
+                }
+                res = result;
+            }
+
+            // Trim stop sequence if present
+            if (res && res.includes(stop_seq)) {
+                res = res.slice(0, res.indexOf(stop_seq));
             }
-            res = result;
         } catch (err) {
             console.log(err);
             res = 'My brain disconnected, try again.';
@@ -51,10 +87,50 @@
     }
 
     async embed(text) {
-        const output = await this.replicate.run(
-            this.model_name || "mark3labs/embeddings-gte-base:d619cff29338b9a37c3d06605042e1ff0594a8c3eff0175fd6967f5643fc4d47",
-            { input: {text} }
+        // Always use a dedicated embedding model, not the chat model
+        const DEFAULT_EMBEDDING_MODEL = "mark3labs/embeddings-gte-base:d619cff29338b9a37c3d06605042e1ff0594a8c3eff0175fd6967f5643fc4d47";
+
+        // Validate text input
+        if (!text || typeof text !== 'string') {
+            throw new Error('Text is required for embedding');
+        }
+
+        // Check if model_name is an embedding model or a chat model
+        // Chat models (like meta/meta-llama-3-70b-instruct) won't work for embeddings
+        const isEmbeddingModel = this.model_name && (
+            this.model_name.includes('embed') ||
+            this.model_name.includes('gte') ||
+            this.model_name.includes('e5-')
         );
-        return output.vectors;
+        const embeddingModel = isEmbeddingModel ? this.model_name : DEFAULT_EMBEDDING_MODEL;
+
+        // Helper to extract embedding from various output formats
+        const extractEmbedding = (output) => {
+            if (output.vectors) {
+                return output.vectors;
+            } else if (Array.isArray(output)) {
+                return output;
+            } else if (output.embedding) {
+                return output.embedding;
+            } else if (output.embeddings) {
+                return Array.isArray(output.embeddings[0]) ? output.embeddings[0] : output.embeddings;
+            }
+            return null;
+        };
+
+        try {
+            const output = await this.replicate.run(
+                embeddingModel,
+                { input: { text } }
+            );
+            const embedding = extractEmbedding(output);
+            if (embedding) {
+                return embedding;
+            }
+            throw new Error('Unknown embedding output format');
+        } catch (err) {
+            console.error('Replicate embed error:', err.message || err);
+            throw err;
+        }
     }
 }
\ No newline at end of file