google tests and fixes
nbonamy committed May 16, 2024
1 parent 6280cac commit a535f7a
Showing 5 changed files with 105 additions and 43 deletions.
2 changes: 1 addition & 1 deletion build/build_number.txt
@@ -1 +1 @@
-190
+191
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -1,7 +1,7 @@
 {
   "name": "witsy",
   "productName": "Witsy",
-  "version": "1.7.0",
+  "version": "1.7.2",
   "description": "Witsy: desktop AI assistant",
   "repository": {
     "type": "git",
12 changes: 6 additions & 6 deletions src/services/google.ts
@@ -56,8 +56,8 @@ export default class extends LlmEngine {
   async complete(thread: Message[], opts: LlmCompletionOpts): Promise<LlmResponse> {
 
     // call
-    const modelName = opts?.model || this.config.engines.openai.model.chat
-    console.log(`[openai] prompting model ${modelName}`)
+    const modelName = opts?.model || this.config.engines.google.model.chat
+    console.log(`[google] prompting model ${modelName}`)
     const model = this.getModel(modelName, thread[0].content)
     const chat = model.startChat({
       history: this.threadToHistory(thread, modelName)
@@ -80,7 +80,7 @@ export default class extends LlmEngine {
     this.toolCalls = []
 
     // call
-    console.log(`[openai] prompting model ${modelName}`)
+    console.log(`[google] prompting model ${modelName}`)
     const model = this.getModel(modelName, thread[0].content)
     this.currentChat = model.startChat({
       history: this.threadToHistory(thread, modelName)
@@ -143,8 +143,8 @@ export default class extends LlmEngine {
   }
 
   threadToHistory(thread: Message[], modelName: string): Content[] {
-    const payload = this.buildPayload(thread, modelName)
-    return payload.slice(1, -1).map((message) => this.messageToContent(message))
+    const payload = this.buildPayload(thread.slice(1, -1), modelName)
+    return payload.map((message) => this.messageToContent(message))
   }
 
   messageToContent(payload: LLmCompletionPayload): Content {
@@ -226,7 +226,7 @@ export default class extends LlmEngine {
     // // now execute
     // const args = JSON.parse(toolCall.args)
     // const content = await this.callTool(toolCall.function, args)
-    // console.log(`[openai] tool call ${toolCall.function} with ${JSON.stringify(args)} => ${JSON.stringify(content).substring(0, 128)}`)
+    // console.log(`[google] tool call ${toolCall.function} with ${JSON.stringify(args)} => ${JSON.stringify(content).substring(0, 128)}`)
 
     // // send
     // this.currentChat.sendMessageStream([
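A note on the threadToHistory change above: the slice now happens on the raw thread, before buildPayload runs, so the system instruction (thread[0], passed to Gemini separately as systemInstruction) and the in-flight user prompt are excluded by their position in the conversation rather than by position in the built payload, which buildPayload may have reshaped. A minimal sketch of the intended behavior, with simplified stand-in types in place of the repo's Message and Content:

type Role = 'system' | 'user' | 'assistant'
interface SimpleMessage { role: Role, content: string }
interface SimpleContent { role: 'user' | 'model', parts: { text: string }[] }

// sketch only: mirrors the history expectations in the tests below,
// not the engine's actual implementation
function threadToHistory(thread: SimpleMessage[]): SimpleContent[] {
  // drop thread[0] (the system instruction) and the last message
  // (the prompt that sendMessage / sendMessageStream will carry)
  return thread.slice(1, -1).map((m) => ({
    role: m.role === 'assistant' ? 'model' : 'user',
    parts: [ { text: m.content } ],
  }))
}

// a four-message thread yields a two-entry history, as the new tests assert:
// threadToHistory([system, prompt1, response1, prompt2])
//   => [ { role: 'user', parts: [ { text: 'prompt1' } ] },
//        { role: 'model', parts: [ { text: 'response1' } ] } ]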
128 changes: 95 additions & 33 deletions tests/unit/engine_google.test.ts
@@ -1,46 +1,33 @@

-import { vi, beforeEach, expect, test } from 'vitest'
+import { vi, beforeAll, beforeEach, expect, test } from 'vitest'
 import { store } from '../../src/services/store'
 import defaults from '../../defaults/settings.json'
 import Message from '../../src/models/message'
 import Google from '../../src/services/google'
 import { loadGoogleModels } from '../../src/services/llm'
-import { EnhancedGenerateContentResponse, FinishReason } from '@google/generative-ai'
-import { Model } from '../../src/types/config.d'
+import { EnhancedGenerateContentResponse, FunctionCall, FinishReason } from '@google/generative-ai'
+import * as _Google from '@google/generative-ai'
 
 vi.mock('@google/generative-ai', async() => {
-  return {
-    GoogleGenerativeAI: vi.fn((apiKey) => {
-      return {
-        apiKey: apiKey,
-        getGenerativeModel: vi.fn(() => {
-          return {
-            startChat: vi.fn(() => {
-              return {
-                sendMessage: vi.fn(() => {
-                  return {
-                    response: {
-                      text: () => 'response'
-                    }
-                  }
-                }),
-                sendMessageStream: vi.fn(() => {
-                  return {
-                    stream: vi.fn()
-                  }
-                })
-              }
-            })
-          }
-        }),
-      }
-    })
-  }
+  const GoogleChat = vi.fn()
+  GoogleChat.prototype.sendMessage = vi.fn(() => { return { response: { text: () => 'response' } } })
+  GoogleChat.prototype.sendMessageStream = vi.fn(() => { return { stream: vi.fn() } })
+  const GoogleModel = vi.fn()
+  GoogleModel.prototype.startChat = vi.fn(() => new GoogleChat())
+  const GoogleGenerativeAI = vi.fn()
+  GoogleGenerativeAI.prototype.apiKey = '123'
+  GoogleGenerativeAI.prototype.getGenerativeModel = vi.fn(() => new GoogleModel())
+  return { GoogleGenerativeAI, GoogleModel, GoogleChat, default: GoogleGenerativeAI }
 })
 
-beforeEach(() => {
+beforeAll(() => {
   store.config = defaults
   store.config.engines.google.apiKey = '123'
+  store.config.engines.google.model.chat = 'models/gemini-1.5-pro-latest'
 })
 
+beforeEach(() => {
+  vi.clearAllMocks()
+})
+
 test('Google Load Models', async () => {
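A note on the rewritten mock at the top of this file: attaching vi.fn() methods to the prototype of a vi.fn() constructor means every instance the engine creates internally shares the same mock functions, so the tests can assert on calls without ever holding a reference to the instance. A stripped-down illustration of the pattern (standalone, not the repo's code):

import { vi, expect, test } from 'vitest'

test('prototype mocks are observable without the instance', () => {
  // vi.fn() works as a constructor; methods on its prototype are
  // shared by every instance created from it
  const GoogleChat = vi.fn()
  GoogleChat.prototype.sendMessage = vi.fn(() => ({ response: { text: () => 'response' } }))

  // pretend this instance is created deep inside the code under test
  const chat = new GoogleChat()
  chat.sendMessage(['prompt'])

  // the call is still visible through the shared prototype
  expect(GoogleChat.prototype.sendMessage).toHaveBeenCalledWith(['prompt'])
})

The companion beforeEach(() => vi.clearAllMocks()) matters here: since the prototype mocks are shared across the whole suite, call history must be reset between tests.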
@@ -69,6 +56,10 @@ test('Google completion', async () => {
     new Message('system', 'instruction'),
     new Message('user', 'prompt'),
   ], null)
+  expect(_Google.GoogleGenerativeAI).toHaveBeenCalled()
+  expect(_Google.GoogleGenerativeAI.prototype.getGenerativeModel).toHaveBeenCalled()
+  expect(_Google.GoogleModel.prototype.startChat).toHaveBeenCalledWith({ history: []})
+  expect(_Google.GoogleChat.prototype.sendMessage).toHaveBeenCalledWith(['prompt'])
   expect(response).toStrictEqual({
     type: 'text',
     content: 'response'
@@ -81,7 +72,10 @@ test('Google stream', async () => {
     new Message('system', 'instruction'),
     new Message('user', 'prompt'),
   ], null)
-  //expect(response.controller).toBeDefined()
+  expect(_Google.GoogleGenerativeAI).toHaveBeenCalled()
+  expect(_Google.GoogleGenerativeAI.prototype.getGenerativeModel).toHaveBeenCalled()
+  expect(_Google.GoogleModel.prototype.startChat).toHaveBeenCalledWith({ history: []})
+  expect(_Google.GoogleChat.prototype.sendMessageStream).toHaveBeenCalledWith(['prompt'])
   await google.stop(response)
   //expect(response.controller.abort).toHaveBeenCalled()
 })
@@ -95,7 +89,7 @@ test('Google streamChunkToLlmChunk Text', async () => {
       //finishReason: FinishReason.STOP,
     } ],
     text: vi.fn(() => 'response'),
-    functionCalls: vi.fn(() => []),
+    functionCalls: vi.fn((): FunctionCall[] => []),
     functionCall: null,
   }
   const llmChunk1 = await google.streamChunkToLlmChunk(streamChunk, null)
@@ -109,3 +103,71 @@
   //expect(streamChunk.functionCalls).toHaveBeenCalled()
   expect(llmChunk2).toStrictEqual({ text: '', done: true })
 })
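Judging from the assertions in this test, streamChunkToLlmChunk maps an SDK chunk to a { text, done } pair, flagging done when the candidate reports a STOP finish reason. A hedged sketch of that mapping, inferred from the expectations only (the engine's real method takes additional arguments and also deals with function calls):

import { EnhancedGenerateContentResponse, FinishReason } from '@google/generative-ai'

function streamChunkToLlmChunk(chunk: EnhancedGenerateContentResponse): { text: string, done: boolean } {
  const done = chunk.candidates?.[0]?.finishReason === FinishReason.STOP
  // per the test above, a STOP chunk yields empty text with done: true
  return { text: done ? '' : chunk.text(), done }
}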

+test('Google History Complete', async () => {
+  const google = new Google(store.config)
+  await google.complete([
+    new Message('system', 'instruction'),
+    new Message('user', 'prompt1'),
+    new Message('assistant', 'response1'),
+    new Message('user', 'prompt2'),
+  ], null)
+  expect(_Google.GoogleGenerativeAI.prototype.getGenerativeModel).toHaveBeenCalledWith({
+    model: 'models/gemini-1.5-pro-latest',
+    systemInstruction: 'instruction',
+  }, { apiVersion: 'v1beta' })
+  expect(_Google.GoogleModel.prototype.startChat).toHaveBeenCalledWith({ history: [
+    { role: 'user', parts: [ { text: 'prompt1' } ] },
+    { role: 'model', parts: [ { text: 'response1' } ] },
+  ]})
+  expect(_Google.GoogleChat.prototype.sendMessage).toHaveBeenCalledWith(['prompt2'])
+})
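The getGenerativeModel assertion above also pins down the SDK configuration: the system message travels as systemInstruction rather than as chat history, and apiVersion 'v1beta' is requested, presumably because system instructions were only available on the beta API surface when this was written. The call the engine is expected to make, reconstructed from the assertion (values from the test, not the repo's code):

import { GoogleGenerativeAI } from '@google/generative-ai'

const genAI = new GoogleGenerativeAI('123') // apiKey from config
const model = genAI.getGenerativeModel({
  model: 'models/gemini-1.5-pro-latest',
  systemInstruction: 'instruction', // thread[0].content
}, { apiVersion: 'v1beta' })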

+test('Google History Stream', async () => {
+  const google = new Google(store.config)
+  await google.stream([
+    new Message('system', 'instruction'),
+    new Message('user', 'prompt1'),
+    new Message('assistant', 'response1'),
+    new Message('user', 'prompt2'),
+  ], null)
+  expect(_Google.GoogleGenerativeAI.prototype.getGenerativeModel).toHaveBeenCalledWith({
+    model: 'models/gemini-1.5-pro-latest',
+    systemInstruction: 'instruction',
+  }, { apiVersion: 'v1beta' })
+  expect(_Google.GoogleModel.prototype.startChat).toHaveBeenCalledWith({ history: [
+    { role: 'user', parts: [ { text: 'prompt1' } ] },
+    { role: 'model', parts: [ { text: 'response1' } ] },
+  ]})
+  expect(_Google.GoogleChat.prototype.sendMessageStream).toHaveBeenCalledWith(['prompt2'])
+})
+
+test('Google Text Attachments', async () => {
+  const google = new Google(store.config)
+  await google.stream([
+    new Message('system', 'instruction'),
+    new Message('user', { role: 'user', type: 'text', content: 'prompt1', attachment: { url: '', format: 'txt', contents: 'text1', downloaded: true } } ),
+    new Message('assistant', 'response1'),
+    new Message('user', { role: 'user', type: 'text', content: 'prompt2', attachment: { url: '', format: 'txt', contents: 'text2', downloaded: true } } ),
+  ], null)
+  expect(_Google.GoogleModel.prototype.startChat).toHaveBeenCalledWith({ history: [
+    { role: 'user', parts: [ { text: 'prompt1\n\ntext1' } ] },
+    { role: 'model', parts: [ { text: 'response1' } ] },
+  ]})
+  expect(_Google.GoogleChat.prototype.sendMessageStream).toHaveBeenCalledWith(['prompt2', 'text2'])
+})
+
+test('Google Image Attachments', async () => {
+  const google = new Google(store.config)
+  await google.stream([
+    new Message('system', 'instruction'),
+    new Message('user', { role: 'user', type: 'text', content: 'prompt1', attachment: { url: '', format: 'png', contents: 'image', downloaded: true } } ),
+    new Message('assistant', 'response1'),
+    new Message('user', { role: 'user', type: 'text', content: 'prompt2', attachment: { url: '', format: 'png', contents: 'image', downloaded: true } } ),
+  ], null)
+  expect(_Google.GoogleModel.prototype.startChat).toHaveBeenCalledWith({ history: [
+    { role: 'user', parts: [ { text: 'prompt1' }, { inlineData: { data: 'image', mimeType: 'image/png' }} ] },
+    { role: 'model', parts: [ { text: 'response1' } ] },
+  ]})
+  expect(_Google.GoogleChat.prototype.sendMessageStream).toHaveBeenCalledWith(['prompt2', { inlineData: { data: 'image', mimeType: 'image/png' }}])
+})
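Taken together, the two attachment tests imply a mapping convention: a text attachment is folded into the message text (inline for history, as a second array element for the outgoing prompt), while an image attachment becomes a Gemini inlineData part with a mime type derived from its format. A hedged sketch of that convention for the outgoing prompt; the helper name and types are illustrative, not the repo's:

interface AttachmentLike { url: string, format: string, contents: string, downloaded: boolean }
type PromptPart = string | { inlineData: { data: string, mimeType: string } }

// inferred from the sendMessageStream expectations above
function promptParts(prompt: string, attachment?: AttachmentLike): PromptPart[] {
  if (!attachment?.downloaded) return [ prompt ]
  if (attachment.format === 'txt') {
    return [ prompt, attachment.contents ] // => ['prompt2', 'text2']
  }
  return [ prompt, {
    inlineData: { data: attachment.contents, mimeType: `image/${attachment.format}` },
  } ] // => ['prompt2', { inlineData: { ... } }]
}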
