Skip to content

Commit

Permalink
Google engine: enable Gemini vision models, image attachments in prompts, and system instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
nbonamy committed May 15, 2024
1 parent 7d4d06f commit 2e7dedb
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/services/assistant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ export default class {
}

} catch (error) {
console.error('Error while generating text', error)
if (error.name !== 'AbortError') {
if (error.status === 401 || error.message.includes('401') || error.message.toLowerCase().includes('apikey')) {
message.setText('You need to enter your API key in the Models tab of <a href="#settings">Settings</a> in order to chat.')
Expand Down
61 changes: 49 additions & 12 deletions src/services/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import { Message } from '../types/index.d'
import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmToolCall, LlmEventCallback } from '../types/llm.d'
import { EngineConfig, Configuration } from '../types/config.d'
import LlmEngine from './engine'
import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI } from '@google/generative-ai'
import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI, Part } from '@google/generative-ai'
import { getFileContents } from './download'

export const isGoogleReady = (engineConfig: EngineConfig): boolean => {
return engineConfig.apiKey?.length > 0
Expand All @@ -27,11 +28,11 @@ export default class extends LlmEngine {
}

// Models (or wildcard patterns) that accept image input for this engine.
// The old removed line (`return []`) was left in by the diff paste and made
// the updated list unreachable; only the updated return survives here.
getVisionModels(): string[] {
  return ['models/gemini-1.5-pro-latest', 'models/gemini-pro-vision', '*vision*']
}

// True when `model` supports image input. Membership is now decided solely
// by getVisionModels() — the removed-line residue with the extra
// `|| model.includes('vision')` fallback shadowed the intended new behavior
// and has been dropped.
isVisionModel(model: string): boolean {
  return this.getVisionModels().includes(model)
}

getRountingModel(): string | null {
Expand All @@ -50,7 +51,7 @@ export default class extends LlmEngine {
{ id: 'models/gemini-1.5-pro-latest', name: 'Gemini 1.5 Pro' },
//{ id: 'gemini-1.5-flash', name: 'Gemini 1.5 Flash' },
{ id: 'models/gemini-pro', name: 'Gemini 1.0 Pro' },
{ id: 'models/gemini-pro-vision', name: 'Gemini Pro Vision' },
//{ id: 'models/gemini-pro-vision', name: 'Gemini Pro Vision' },
]
}

Expand All @@ -66,7 +67,7 @@ export default class extends LlmEngine {
})

// done
const result = await chat.sendMessage(thread[thread.length-1].content)
const result = await chat.sendMessage(this.getPrompt(thread))
return {
type: 'text',
content: result.response.text()
Expand All @@ -92,26 +93,62 @@ export default class extends LlmEngine {
})

// done
const result = await this.currentChat.sendMessageStream(payload[payload.length-1].content)
const result = await this.currentChat.sendMessageStream(this.getPrompt(thread))
return result.stream

}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
/**
 * Builds a GenerativeModel handle for `model`, attaching the system
 * instructions and (when enabled) the available tools.
 * The stale commented-out `systemInstruction`/`tools` lines left over from
 * the diff paste have been removed; only the live configuration remains.
 *
 * @param model        full model id (e.g. 'models/gemini-1.5-pro-latest')
 * @param instructions system prompt to attach when the model supports it
 */
getModel(model: string, instructions: string): GenerativeModel {

  // not all models have all features: gemini-pro rejects system instructions,
  // and tool calling is disabled for now (hasTools is hard-coded to false,
  // so the tools declaration below is currently never sent)
  const hasInstructions = !(['models/gemini-pro'].includes(model))
  const hasTools = false

  return this.client.getGenerativeModel({
    model: model,
    systemInstruction: hasInstructions ? instructions : null,
    tools: hasTools ? [{
      functionDeclarations: this.getAvailableTools().map((tool) => {
        return tool.function
      })
    }] : null,
  }, {
    // system instructions require the beta API surface
    apiVersion: 'v1beta'
  })
}

getPrompt(thread: Message[]): Array<string | Part> {

// init
const prompt = []
const lastMessage = thread[thread.length-1]

// content
prompt.push(lastMessage.content)

// image
if (lastMessage.attachment) {

// load if no contents
if (!lastMessage.attachment.contents) {
lastMessage.attachment.contents = getFileContents(lastMessage.attachment.url).contents
}

// add inline
prompt.push({
inlineData: {
mimeType: 'image/png',
data: lastMessage.attachment.contents,
}
})

}

// done
return prompt
}

messageToContent(message: any): Content {
return {
role: message.role == 'assistant' ? 'model' : message.role,
Expand Down
2 changes: 1 addition & 1 deletion src/services/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ export default class extends LlmEngine {
}

// True when `model` supports image input. Membership is now decided solely
// by getVisionModels() — the removed-line residue with the extra
// `|| model.includes('vision')` fallback shadowed the intended new behavior
// and has been dropped.
isVisionModel(model: string): boolean {
  return this.getVisionModels().includes(model)
}

getRountingModel(): string | null {
Expand Down

0 comments on commit 2e7dedb

Please sign in to comment.