Skip to content

Commit

Permalink
more google attachment fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nbonamy committed May 15, 2024
1 parent 413d491 commit 6280cac
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 44 deletions.
22 changes: 19 additions & 3 deletions src/models/attachment.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@

import { Attachment } from '../types/index.d'

export const textFormats = ['pdf', 'txt', 'docx', 'pptx', 'xlsx']
export const imageFormats = ['jpeg', 'jpg', 'png', 'webp']

export default class Attachment {
export default class implements Attachment {

url: string
format: string
contents: string
downloaded: boolean

constructor(url: string, format = "", contents = "", downloaded = false) {
this.url = url
constructor(url: string|object, format = "", contents = "", downloaded = false) {

if (url != null && typeof url === 'object') {
this.fromJson(url)
return
}

// default
this.url = url as string
this.format = format.toLowerCase()
this.contents = contents
this.downloaded = downloaded
Expand All @@ -23,6 +32,13 @@ export default class Attachment {
}
}

// Rehydrate this attachment from a plain JSON object (e.g. persisted history).
// Only the four known fields are copied; extra keys on `obj` are ignored.
fromJson(obj: any) {
  const { url, format, contents, downloaded } = obj
  this.url = url
  this.format = format
  this.contents = contents
  this.downloaded = downloaded
}

// True when this attachment's format is one of the recognized text formats.
isText(): boolean {
  return textFormats.some((fmt) => fmt === this.format)
}
Expand Down
3 changes: 3 additions & 0 deletions src/models/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ export default class implements Chat {
engine: string
model: string
messages: Message[]
deleted: boolean

constructor(obj?: any) {

Expand All @@ -29,6 +30,7 @@ export default class implements Chat {
this.engine = null
this.model = null
this.messages = []
this.deleted = false

}

Expand All @@ -40,6 +42,7 @@ export default class implements Chat {
this.engine = obj.engine || 'openai'
this.model = obj.model
this.messages = []
this.deleted = false
for (const msg of obj.messages) {
const message = new Message(msg.role, msg)
this.messages.push(message)
Expand Down
7 changes: 4 additions & 3 deletions src/models/message.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

import { Message, Attachment } from '../types/index.d'
import { LlmRole, LlmChunk } from '../types/llm.d';
import { Message } from '../types/index.d'
import { LlmRole, LlmChunk } from '../types/llm.d'
import { v4 as uuidv4 } from 'uuid'
import Attachment from './attachment'

export default class implements Message {

Expand Down Expand Up @@ -41,7 +42,7 @@ export default class implements Message {
this.role = obj.role
this.type = obj.type
this.content = obj.content
this.attachment = obj.attachment
this.attachment = obj.attachment ? new Attachment(obj.attachment) : null
this.transient = false
this.toolCall = null
}
Expand Down
2 changes: 1 addition & 1 deletion src/services/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ export default class LlmEngine {

// we only want to upload the last image attachment
// so build messages in reverse order
// and then rerse the array
// and then reverse the array

let imageAttached = false
return thread.toReversed().filter((msg) => msg.type === 'text' && msg.content !== null).map((msg): LLmCompletionPayload => {
Expand Down
77 changes: 47 additions & 30 deletions src/services/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { EngineConfig, Configuration } from '../types/config.d'
import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmToolCall, LlmEventCallback } from '../types/llm.d'
import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI, ModelParams, Part } from '@google/generative-ai'
import { getFileContents } from './download'
import Attachment from '../models/attachment'
import LlmEngine from './engine'

export const isGoogleReady = (engineConfig: EngineConfig): boolean => {
Expand Down Expand Up @@ -59,7 +60,7 @@ export default class extends LlmEngine {
console.log(`[openai] prompting model ${modelName}`)
const model = this.getModel(modelName, thread[0].content)
const chat = model.startChat({
history: thread.slice(1, -1).map((message) => this.messageToContent(message))
history: this.threadToHistory(thread, modelName)
})

// done
Expand All @@ -78,14 +79,11 @@ export default class extends LlmEngine {
// reset
this.toolCalls = []

// save the message thread
const payload = this.buildPayload(thread, modelName)

// call
console.log(`[openai] prompting model ${modelName}`)
const model = this.getModel(modelName, payload[0].content)
const model = this.getModel(modelName, thread[0].content)
this.currentChat = model.startChat({
history: payload.slice(1, -1).map((message) => this.messageToContent(message))
history: this.threadToHistory(thread, modelName)
})

// done
Expand Down Expand Up @@ -135,44 +133,63 @@ export default class extends LlmEngine {
// content
prompt.push(lastMessage.content)

// image
// attachment
if (lastMessage.attachment) {

// load if no contents
if (!lastMessage.attachment.contents) {
lastMessage.attachment.contents = getFileContents(lastMessage.attachment.url).contents
}

// add inline
if (lastMessage.attachment.isImage()) {
prompt.push({
inlineData: {
mimeType: 'image/png',
data: lastMessage.attachment.contents,
}
})
} else if (lastMessage.attachment.isText()) {
prompt.push(lastMessage.attachment.contents)
}

this.addAttachment(prompt, lastMessage.attachment)
}

// done
return prompt
}

messageToContent(message: any): Content {
return {
role: message.role == 'assistant' ? 'model' : message.role,
parts: [ { text: message.content } ]
// Convert a message thread into the Google chat history format.
// The first and last payload entries are omitted (slice(1, -1)) —
// presumably the system prompt and the message being sent; verify against callers.
threadToHistory(thread: Message[], modelName: string): Content[] {
  const history: Content[] = []
  for (const payload of this.buildPayload(thread, modelName).slice(1, -1)) {
    history.push(this.messageToContent(payload))
  }
  return history
}

// Map one completion payload onto a Google `Content` entry.
// Google's API uses the role 'model' where the payload uses 'assistant'.
messageToContent(payload: LLmCompletionPayload): Content {
  const content: Content = {
    role: payload.role == 'assistant' ? 'model' : payload.role,
    parts: [ { text: payload.content } ]
  }
  // for..of over the values (original for..in iterated string indices and
  // any enumerable inherited props); `?? []` keeps the original's tolerance
  // of a missing `images` field
  // NOTE(review): mimeType is hard-coded to image/png for every image — confirm
  // attachments are always PNG (jpeg/jpg/webp are also accepted formats)
  for (const data of payload.images ?? []) {
    content.parts.push({
      inlineData: {
        mimeType: 'image/png',
        data,
      }
    })
  }
  return content
}

// Append an attachment to a Google prompt parts array.
// Mutates both `parts` (pushes the new part) and `attachment`
// (lazily fills in `contents` from disk on first use).
addAttachment(parts: Array<string | Part>, attachment: Attachment) {

// load if no contents
if (!attachment.contents) {
attachment.contents = getFileContents(attachment.url).contents
}

// add inline
if (attachment.isImage()) {
// NOTE(review): mimeType is hard-coded to image/png — confirm non-PNG
// image formats (jpeg, webp) are handled correctly by the API
parts.push({
inlineData: {
mimeType: 'image/png',
data: attachment.contents,
}
})
} else if (attachment.isText()) {
// text attachments are passed as a plain string part
parts.push(attachment.contents)
}
// attachments that are neither image nor text are silently dropped

}

// Intentional no-op: aborting the stream is commented out — presumably the
// Google SDK stream exposes no usable abort controller here; TODO confirm.
// eslint-disable-next-line @typescript-eslint/no-unused-vars
async stop(stream: AsyncGenerator<any>) {
//await stream?.controller?.abort()
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
async streamChunkToLlmChunk(chunk: EnhancedGenerateContentResponse, eventCallback: LlmEventCallback): Promise<LlmChunk|null> {

//console.log('[google] chunk:', chunk)
Expand Down Expand Up @@ -241,7 +258,7 @@ export default class extends LlmEngine {

// eslint-disable-next-line @typescript-eslint/no-unused-vars
addImageToPayload(message: Message, payload: LLmCompletionPayload) {
//TODO
payload.images = [message.attachment.contents]
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
Expand Down
7 changes: 5 additions & 2 deletions src/types/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ interface Chat {
engine: string
model: string
messages: Message[]
fromJson(jsonChat: any): void
deleted: boolean
fromJson(json: any): void
patchFromJson(jsonChat: any): boolean
setEngineModel(engine: string, model: string): void
addMessage(message: Message): void
Expand All @@ -28,7 +29,7 @@ interface Message {
content: string
attachment: Attachment
transient: boolean
fromJson(jsonMessage: any): void
fromJson(json: any): void
setText(text: string|null): void
setImage(url: string): void
appendText(chunk: LlmChunk): void
Expand All @@ -41,8 +42,10 @@ interface Attachment {
format: string
contents: string
downloaded: boolean
fromJson(json: any): void
isText(): boolean
isImage(): boolean
extractText(): void
}

interface Command {
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/engine_google.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import defaults from '../../defaults/settings.json'
import Message from '../../src/models/message'
import Google from '../../src/services/google'
import { loadGoogleModels } from '../../src/services/llm'
import { EnhancedGenerateContentResponse } from '@google/generative-ai'
import { EnhancedGenerateContentResponse, FinishReason } from '@google/generative-ai'
import { Model } from '../../src/types/config.d'

vi.mock('@google/generative-ai', async() => {
Expand Down Expand Up @@ -102,7 +102,7 @@ test('Google streamChunkToLlmChunk Text', async () => {
expect(streamChunk.text).toHaveBeenCalled()
//expect(streamChunk.functionCalls).toHaveBeenCalled()
expect(llmChunk1).toStrictEqual({ text: 'response', done: false })
streamChunk.candidates[0].finishReason = 'STOP'
streamChunk.candidates[0].finishReason = 'STOP' as FinishReason
streamChunk.text = vi.fn(() => '')
const llmChunk2 = await google.streamChunkToLlmChunk(streamChunk, null)
expect(streamChunk.text).toHaveBeenCalled()
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/message.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import { expect, test } from 'vitest'
import Message from '../../src/models/message'

// Minimal stub of the preload `window.api` bridge for this test file —
// presumably Message/Attachment call these during construction; both are
// identity functions so test inputs pass through unchanged.
window.api = {
base64: {
decode: (data: string) => data
},
file: {
extractText: (contents) => contents
}
}

test('Build from role and text', () => {
const message = new Message('user', 'content')
expect(message.uuid).not.toBe(null)
Expand Down
8 changes: 5 additions & 3 deletions tests/unit/store.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

import { vi, expect, test, beforeEach } from 'vitest'
import { Command } from '../../src/types/index.d'
import { store } from '../../src/services/store'
import Chat from '../../src/models/chat'
import Message from '../../src/models/message'
Expand Down Expand Up @@ -28,7 +29,7 @@ window.api = {
save: vi.fn(),
},
commands: {
load: vi.fn(() => defaultCommands),
load: vi.fn(() => defaultCommands as Command[]),
},
prompts: {
load: vi.fn(() => defaultPrompts),
Expand Down Expand Up @@ -95,9 +96,10 @@ test('Save history', async () => {
uuid: '123',
engine: 'engine',
model: 'model',
deleted: false,
messages: [
{ uuid: 1, role: 'system', content: 'Hi', toolCall: null, transient: false },
{ uuid: 2, role: 'user', content: 'Hello', toolCall: null, transient: false }
{ uuid: 1, role: 'system', content: 'Hi', toolCall: null, attachment: null, transient: false },
{ uuid: 2, role: 'user', content: 'Hello', toolCall: null, attachment: null, transient: false }
]
}])
})
Expand Down

0 comments on commit 6280cac

Please sign in to comment.