diff --git a/src/services/assistant.ts b/src/services/assistant.ts
index fd5e0b0..2d4b5f2 100644
--- a/src/services/assistant.ts
+++ b/src/services/assistant.ts
@@ -177,6 +177,7 @@ export default class {
         }
       } catch (error) {
+        console.error('Error while generating text', error)
        if (error.name !== 'AbortError') {
           if (error.status === 401 || error.message.includes('401') || error.message.toLowerCase().includes('apikey')) {
             message.setText('You need to enter your API key in the Models tab of Settings in order to chat.')
diff --git a/src/services/google.ts b/src/services/google.ts
index 89b01eb..3823cad 100644
--- a/src/services/google.ts
+++ b/src/services/google.ts
@@ -3,7 +3,8 @@ import { Message } from '../types/index.d'
 import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmToolCall, LlmEventCallback } from '../types/llm.d'
 import { EngineConfig, Configuration } from '../types/config.d'
 import LlmEngine from './engine'
-import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI } from '@google/generative-ai'
+import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI, Part } from '@google/generative-ai'
+import { getFileContents } from './download'
 
 export const isGoogleReady = (engineConfig: EngineConfig): boolean => {
   return engineConfig.apiKey?.length > 0
@@ -27,11 +28,11 @@ export default class extends LlmEngine {
   }
 
   getVisionModels(): string[] {
-    return []//['gemini-pro-vision', '*vision*']
+    return ['models/gemini-1.5-pro-latest', 'models/gemini-pro-vision', '*vision*']
   }
 
   isVisionModel(model: string): boolean {
-    return this.getVisionModels().includes(model) || model.includes('vision')
+    return this.getVisionModels().includes(model)
   }
 
   getRountingModel(): string | null {
@@ -50,7 +51,7 @@
       { id: 'models/gemini-1.5-pro-latest', name: 'Gemini 1.5 Pro' },
       //{ id: 'gemini-1.5-flash', name: 'Gemini 1.5 Flash' },
       { id: 'models/gemini-pro', name: 'Gemini 1.0 Pro' },
-      { id: 'models/gemini-pro-vision', name: 'Gemini Pro Vision' },
+      //{ id: 'models/gemini-pro-vision', name: 'Gemini Pro Vision' },
     ]
   }
 
@@ -66,7 +67,7 @@
     })
 
     // done
-    const result = await chat.sendMessage(thread[thread.length-1].content)
+    const result = await chat.sendMessage(this.getPrompt(thread))
     return {
       type: 'text',
       content: result.response.text()
@@ -92,26 +93,62 @@
     })
 
     // done
-    const result = await this.currentChat.sendMessageStream(payload[payload.length-1].content)
+    const result = await this.currentChat.sendMessageStream(this.getPrompt(thread))
     return result.stream
   }
 
   // eslint-disable-next-line @typescript-eslint/no-unused-vars
   getModel(model: string, instructions: string): GenerativeModel {
+
+    // not all models have all features
+    const hasInstructions = !(['models/gemini-pro'].includes(model))
+    const hasTools = false
+
     return this.client.getGenerativeModel({
       model: model,
-      //systemInstruction: instructions
-      // tools: [{
-      //   functionDeclarations: this.getAvailableTools().map((tool) => {
-      //     return tool.function
-      //   })
-      // }],
+      systemInstruction: hasInstructions ? instructions : null,
+      tools: hasTools ? [{
+        functionDeclarations: this.getAvailableTools().map((tool) => {
+          return tool.function
+        })
+      }] : null,
     }, {
       apiVersion: 'v1beta'
     })
   }
 
+  getPrompt(thread: Message[]): Array<string | Part> {
+
+    // init
+    const prompt = []
+    const lastMessage = thread[thread.length-1]
+
+    // content
+    prompt.push(lastMessage.content)
+
+    // image
+    if (lastMessage.attachment) {
+
+      // load if no contents
+      if (!lastMessage.attachment.contents) {
+        lastMessage.attachment.contents = getFileContents(lastMessage.attachment.url).contents
+      }
+
+      // add inline
+      prompt.push({
+        inlineData: {
+          mimeType: 'image/png',
+          data: lastMessage.attachment.contents,
+        }
+      })
+
+    }
+
+    // done
+    return prompt
+  }
+
   messageToContent(message: any): Content {
     return {
       role: message.role == 'assistant' ? 'model' : message.role,
diff --git a/src/services/openai.ts b/src/services/openai.ts
index fab147d..18328c1 100644
--- a/src/services/openai.ts
+++ b/src/services/openai.ts
@@ -37,7 +37,7 @@ export default class extends LlmEngine {
   }
 
   isVisionModel(model: string): boolean {
-    return this.getVisionModels().includes(model) || model.includes('vision')
+    return this.getVisionModels().includes(model)
   }
 
   getRountingModel(): string | null {
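
For context (not part of the diff): getPrompt() returns the Array<string | Part> shape that the @google/generative-ai ChatSession.sendMessage() and sendMessageStream() methods accept, mixing a text prompt with an inline base64 image. A minimal sketch with placeholder values:

    // Hypothetical usage sketch; names and values are placeholders
    import { Part } from '@google/generative-ai'

    const base64Png = 'iVBORw0KGgo...'  // placeholder: attachment.contents, base64-encoded
    const prompt: Array<string | Part> = [
      'Describe this image',            // lastMessage.content
      {
        inlineData: {
          mimeType: 'image/png',        // hard-coded in getPrompt() for now
          data: base64Png,
        },
      },
    ]
    // both entry points accept this shape:
    // await chat.sendMessage(prompt)
    // await chat.sendMessageStream(prompt)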