diff --git a/src/services/assistant.ts b/src/services/assistant.ts
index fd5e0b0..2d4b5f2 100644
--- a/src/services/assistant.ts
+++ b/src/services/assistant.ts
@@ -177,6 +177,7 @@ export default class {
}
} catch (error) {
+ console.error('Error while generating text', error)
if (error.name !== 'AbortError') {
if (error.status === 401 || error.message.includes('401') || error.message.toLowerCase().includes('apikey')) {
message.setText('You need to enter your API key in the Models tab of Settings in order to chat.')
diff --git a/src/services/google.ts b/src/services/google.ts
index 89b01eb..3823cad 100644
--- a/src/services/google.ts
+++ b/src/services/google.ts
@@ -3,7 +3,8 @@ import { Message } from '../types/index.d'
import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmToolCall, LlmEventCallback } from '../types/llm.d'
import { EngineConfig, Configuration } from '../types/config.d'
import LlmEngine from './engine'
-import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI } from '@google/generative-ai'
+import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI, Part } from '@google/generative-ai'
+import { getFileContents } from './download'
export const isGoogleReady = (engineConfig: EngineConfig): boolean => {
return engineConfig.apiKey?.length > 0
@@ -27,11 +28,11 @@ export default class extends LlmEngine {
}
getVisionModels(): string[] {
- return []//['gemini-pro-vision', '*vision*']
+ return ['models/gemini-1.5-pro-latest', 'models/gemini-pro-vision', '*vision*']
}
isVisionModel(model: string): boolean {
- return this.getVisionModels().includes(model) || model.includes('vision')
+ return this.getVisionModels().includes(model)
}
getRountingModel(): string | null {
@@ -50,7 +51,7 @@ export default class extends LlmEngine {
{ id: 'models/gemini-1.5-pro-latest', name: 'Gemini 1.5 Pro' },
//{ id: 'gemini-1.5-flash', name: 'Gemini 1.5 Flash' },
{ id: 'models/gemini-pro', name: 'Gemini 1.0 Pro' },
- { id: 'models/gemini-pro-vision', name: 'Gemini Pro Vision' },
+ //{ id: 'models/gemini-pro-vision', name: 'Gemini Pro Vision' },
]
}
@@ -66,7 +67,7 @@ export default class extends LlmEngine {
})
// done
- const result = await chat.sendMessage(thread[thread.length-1].content)
+ const result = await chat.sendMessage(this.getPrompt(thread))
return {
type: 'text',
content: result.response.text()
@@ -92,26 +93,62 @@ export default class extends LlmEngine {
})
// done
- const result = await this.currentChat.sendMessageStream(payload[payload.length-1].content)
+ const result = await this.currentChat.sendMessageStream(this.getPrompt(thread))
return result.stream
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
getModel(model: string, instructions: string): GenerativeModel {
+
+ // not all models have all features
+ const hasInstructions = !(['models/gemini-pro'].includes(model))
+ const hasTools = false
+
return this.client.getGenerativeModel({
model: model,
- //systemInstruction: instructions
- // tools: [{
- // functionDeclarations: this.getAvailableTools().map((tool) => {
- // return tool.function
- // })
- // }],
+      systemInstruction: hasInstructions ? instructions : undefined,
+ tools: hasTools ? [{
+ functionDeclarations: this.getAvailableTools().map((tool) => {
+ return tool.function
+ })
+      }] : undefined,
}, {
apiVersion: 'v1beta'
})
}
+  getPrompt(thread: Message[]): Array<string | Part> {
+
+ // init
+    const prompt: Array<string | Part> = []
+ const lastMessage = thread[thread.length-1]
+
+ // content
+ prompt.push(lastMessage.content)
+
+ // image
+ if (lastMessage.attachment) {
+
+ // load if no contents
+ if (!lastMessage.attachment.contents) {
+ lastMessage.attachment.contents = getFileContents(lastMessage.attachment.url).contents
+ }
+
+ // add inline
+ prompt.push({
+ inlineData: {
+ mimeType: 'image/png',
+ data: lastMessage.attachment.contents,
+ }
+ })
+
+ }
+
+ // done
+ return prompt
+ }
+
messageToContent(message: any): Content {
return {
role: message.role == 'assistant' ? 'model' : message.role,
diff --git a/src/services/openai.ts b/src/services/openai.ts
index fab147d..18328c1 100644
--- a/src/services/openai.ts
+++ b/src/services/openai.ts
@@ -37,7 +37,7 @@ export default class extends LlmEngine {
}
isVisionModel(model: string): boolean {
- return this.getVisionModels().includes(model) || model.includes('vision')
+ return this.getVisionModels().includes(model)
}
getRountingModel(): string | null {