diff --git a/src/models/attachment.ts b/src/models/attachment.ts
index 53a2883..e9de4d1 100644
--- a/src/models/attachment.ts
+++ b/src/models/attachment.ts
@@ -1,16 +1,25 @@
+import { Attachment } from '../types/index.d'
+
 export const textFormats = ['pdf', 'txt', 'docx', 'pptx', 'xlsx']
 export const imageFormats = ['jpeg', 'jpg', 'png', 'webp']
 
-export default class Attachment {
+export default class implements Attachment {
 
   url: string
   format: string
   contents: string
   downloaded: boolean
 
-  constructor(url: string, format = "", contents = "", downloaded = false) {
-    this.url = url
+  constructor(url: string|object, format = "", contents = "", downloaded = false) {
+
+    if (url != null && typeof url === 'object') {
+      this.fromJson(url)
+      return
+    }
+
+    // default
+    this.url = url as string
     this.format = format.toLowerCase()
     this.contents = contents
     this.downloaded = downloaded
 
@@ -23,6 +32,13 @@ export default class Attachment {
     }
   }
 
+  fromJson(obj: any) {
+    this.url = obj.url
+    this.format = obj.format
+    this.contents = obj.contents
+    this.downloaded = obj.downloaded
+  }
+
   isText(): boolean {
     return textFormats.includes(this.format)
   }
diff --git a/src/models/chat.ts b/src/models/chat.ts
index 333b2db..03d9d27 100644
--- a/src/models/chat.ts
+++ b/src/models/chat.ts
@@ -12,6 +12,7 @@ export default class implements Chat {
   engine: string
   model: string
   messages: Message[]
+  deleted: boolean
 
   constructor(obj?: any) {
 
@@ -29,6 +30,7 @@ export default class implements Chat {
     this.engine = null
     this.model = null
     this.messages = []
+    this.deleted = false
 
   }
 
@@ -40,6 +42,7 @@ export default class implements Chat {
     this.engine = obj.engine || 'openai'
     this.model = obj.model
     this.messages = []
+    this.deleted = false
     for (const msg of obj.messages) {
       const message = new Message(msg.role, msg)
       this.messages.push(message)
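Review note on the `Attachment` change: the constructor is now polymorphic, so one class covers both fresh attachments (URL string) and attachments rehydrated from persisted history (plain object routed through `fromJson()`, with an early return so the string-based initialization is skipped). A minimal sketch of both paths — the import path, file path, and format are illustrative values, not from the repo:

```ts
import Attachment from './src/models/attachment'

// string path: behaves exactly like the old constructor
const fresh = new Attachment('file:///tmp/scan.png', 'PNG')
console.log(fresh.format)        // 'png' — still lowercased

// object path: e.g. restoring JSON-serialized history;
// fromJson() copies all four fields off the plain object
const restored = new Attachment(JSON.parse(JSON.stringify(fresh)))
console.log(restored.isImage())  // true — format survives the round-trip
```
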
diff --git a/src/models/message.ts b/src/models/message.ts
index 498b72b..f7b04d2 100644
--- a/src/models/message.ts
+++ b/src/models/message.ts
@@ -1,7 +1,8 @@
-import { Message, Attachment } from '../types/index.d'
-import { LlmRole, LlmChunk } from '../types/llm.d';
+import { Message } from '../types/index.d'
+import { LlmRole, LlmChunk } from '../types/llm.d'
 import { v4 as uuidv4 } from 'uuid'
+import Attachment from './attachment'
 
 export default class implements Message {
 
@@ -41,7 +42,7 @@ export default class implements Message {
     this.role = obj.role
     this.type = obj.type
     this.content = obj.content
-    this.attachment = obj.attachment
+    this.attachment = obj.attachment ? new Attachment(obj.attachment) : null
     this.transient = false
     this.toolCall = null
   }
diff --git a/src/services/engine.ts b/src/services/engine.ts
index 4a279e1..4c501cc 100644
--- a/src/services/engine.ts
+++ b/src/services/engine.ts
@@ -136,7 +136,7 @@ export default class LlmEngine {
 
     // we only want to upload the last image attachment
     // so build messages in reverse order
-    // and then rerse the array
+    // and then reverse the array
     let imageAttached = false
     return thread.toReversed().filter((msg) => msg.type === 'text' && msg.content !== null).map((msg): LLmCompletionPayload => {
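The comment fixed in `buildPayload()` above describes a pattern worth spelling out: to keep only the most recent image, the thread is walked in reverse so the first image encountered is the last one chronologically, then the array is flipped back. A standalone sketch of the idiom — the `Msg` shape is a stand-in for the real payload type, and `toReversed()` needs Node 20+ / ES2023:

```ts
type Msg = { content: string, image?: string }

function keepLastImageOnly(thread: Msg[]): Msg[] {
  let imageAttached = false
  return thread.toReversed().map((msg): Msg => {
    if (!msg.image) return msg
    if (imageAttached) return { content: msg.content }  // strip older images
    imageAttached = true  // first hit in reverse order = newest image overall
    return msg
  }).toReversed()
}

const pruned = keepLastImageOnly([
  { content: 'first', image: 'img1' },
  { content: 'second' },
  { content: 'third', image: 'img2' },
])
// only 'third' keeps its image; 'first' is sent as text only
```
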
diff --git a/src/services/google.ts b/src/services/google.ts
index 15a2db7..1c2aa47 100644
--- a/src/services/google.ts
+++ b/src/services/google.ts
@@ -4,6 +4,7 @@ import { EngineConfig, Configuration } from '../types/config.d'
 import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmToolCall, LlmEventCallback } from '../types/llm.d'
 import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI, ModelParams, Part } from '@google/generative-ai'
 import { getFileContents } from './download'
+import Attachment from '../models/attachment'
 import LlmEngine from './engine'
 
 export const isGoogleReady = (engineConfig: EngineConfig): boolean => {
@@ -59,7 +60,7 @@ export default class extends LlmEngine {
     console.log(`[openai] prompting model ${modelName}`)
     const model = this.getModel(modelName, thread[0].content)
     const chat = model.startChat({
-      history: thread.slice(1, -1).map((message) => this.messageToContent(message))
+      history: this.threadToHistory(thread, modelName)
     })
 
     // done
@@ -78,14 +79,11 @@ export default class extends LlmEngine {
     // reset
     this.toolCalls = []
 
-    // save the message thread
-    const payload = this.buildPayload(thread, modelName)
-
     // call
     console.log(`[openai] prompting model ${modelName}`)
-    const model = this.getModel(modelName, payload[0].content)
+    const model = this.getModel(modelName, thread[0].content)
     this.currentChat = model.startChat({
-      history: payload.slice(1, -1).map((message) => this.messageToContent(message))
+      history: this.threadToHistory(thread, modelName)
     })
 
     // done
@@ -135,37 +133,55 @@ export default class extends LlmEngine {
 
     // content
     prompt.push(lastMessage.content)
 
-    // image
+    // attachment
     if (lastMessage.attachment) {
-
-      // load if no contents
-      if (!lastMessage.attachment.contents) {
-        lastMessage.attachment.contents = getFileContents(lastMessage.attachment.url).contents
-      }
-
-      // add inline
-      if (lastMessage.attachment.isImage()) {
-        prompt.push({
-          inlineData: {
-            mimeType: 'image/png',
-            data: lastMessage.attachment.contents,
-          }
-        })
-      } else if (lastMessage.attachment.isText()) {
-        prompt.push(lastMessage.attachment.contents)
-      }
-
+      this.addAttachment(prompt, lastMessage.attachment)
     }
 
     // done
     return prompt
   }
 
-  messageToContent(message: any): Content {
-    return {
-      role: message.role == 'assistant' ? 'model' : message.role,
-      parts: [ { text: message.content } ]
+  threadToHistory(thread: Message[], modelName: string): Content[] {
+    const payload = this.buildPayload(thread, modelName)
+    return payload.slice(1, -1).map((message) => this.messageToContent(message))
+  }
+
+  messageToContent(payload: LLmCompletionPayload): Content {
+    const content: Content = {
+      role: payload.role == 'assistant' ? 'model' : payload.role,
+      parts: [ { text: payload.content } ]
+    }
+    for (const index in payload.images) {
+      content.parts.push({
+        inlineData: {
+          mimeType: 'image/png',
+          data: payload.images[index],
+        }
+      })
     }
+    return content
+  }
+
+  addAttachment(parts: Array<string|Part>, attachment: Attachment) {
+
+    // load if no contents
+    if (!attachment.contents) {
+      attachment.contents = getFileContents(attachment.url).contents
+    }
+
+    // add inline
+    if (attachment.isImage()) {
+      parts.push({
+        inlineData: {
+          mimeType: 'image/png',
+          data: attachment.contents,
+        }
+      })
+    } else if (attachment.isText()) {
+      parts.push(attachment.contents)
+    }
+  }
 
   // eslint-disable-next-line @typescript-eslint/no-unused-vars
@@ -173,6 +189,7 @@ export default class extends LlmEngine {
     //await stream?.controller?.abort()
   }
 
+  // eslint-disable-next-line @typescript-eslint/no-unused-vars
   async streamChunkToLlmChunk(chunk: EnhancedGenerateContentResponse, eventCallback: LlmEventCallback): Promise<LlmChunk|null> {
 
     //console.log('[google] chunk:', chunk)
@@ -241,7 +258,7 @@ export default class extends LlmEngine {
 
   // eslint-disable-next-line @typescript-eslint/no-unused-vars
   addImageToPayload(message: Message, payload: LLmCompletionPayload) {
-    //TODO
+    payload.images = [message.attachment.contents]
   }
 
   // eslint-disable-next-line @typescript-eslint/no-unused-vars
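Two things happen in `threadToHistory()` that are easy to miss: `slice(1, -1)` drops the first message (the instructions already handed to `getModel()`) and the last one (sent separately as the live prompt), and `messageToContent()` renames the role, since Gemini history expects `user`/`model` where OpenAI-style threads say `assistant`. A reduced sketch of the mapping — the payload literal is illustrative, its shape following `LLmCompletionPayload`:

```ts
import { Content } from '@google/generative-ai'

const payload = { role: 'assistant', content: 'Here you go', images: ['iVBORw0KGgo...'] }

const content: Content = {
  // Gemini history roles are 'user' and 'model', never 'assistant'
  role: payload.role === 'assistant' ? 'model' : payload.role,
  parts: [{ text: payload.content }],
}
// each base64 image rides along as an inlineData part
for (const data of payload.images ?? []) {
  content.parts.push({ inlineData: { mimeType: 'image/png', data } })
}
```
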
diff --git a/src/types/index.d.ts b/src/types/index.d.ts
index 3c01d64..d081109 100644
--- a/src/types/index.d.ts
+++ b/src/types/index.d.ts
@@ -11,7 +11,8 @@ interface Chat {
   engine: string
   model: string
   messages: Message[]
-  fromJson(jsonChat: any): void
+  deleted: boolean
+  fromJson(json: any): void
   patchFromJson(jsonChat: any): boolean
   setEngineModel(engine: string, model: string): void
   addMessage(message: Message): void
@@ -28,7 +29,7 @@ interface Message {
   content: string
   attachment: Attachment
   transient: boolean
-  fromJson(jsonMessage: any): void
+  fromJson(json: any): void
   setText(text: string|null): void
   setImage(url: string): void
   appendText(chunk: LlmChunk): void
@@ -41,8 +42,10 @@ interface Attachment {
   format: string
   contents: string
   downloaded: boolean
+  fromJson(json: any): void
   isText(): boolean
   isImage(): boolean
+  extractText(): void
 }
 
 interface Command {
diff --git a/tests/unit/engine_google.test.ts b/tests/unit/engine_google.test.ts
index 95a6f91..13008f1 100644
--- a/tests/unit/engine_google.test.ts
+++ b/tests/unit/engine_google.test.ts
@@ -5,7 +5,7 @@ import defaults from '../../defaults/settings.json'
 import Message from '../../src/models/message'
 import Google from '../../src/services/google'
 import { loadGoogleModels } from '../../src/services/llm'
-import { EnhancedGenerateContentResponse } from '@google/generative-ai'
+import { EnhancedGenerateContentResponse, FinishReason } from '@google/generative-ai'
 import { Model } from '../../src/types/config.d'
 
 vi.mock('@google/generative-ai', async() => {
@@ -102,7 +102,7 @@ test('Google streamChunkToLlmChunk Text', async () => {
   expect(streamChunk.text).toHaveBeenCalled()
   //expect(streamChunk.functionCalls).toHaveBeenCalled()
   expect(llmChunk1).toStrictEqual({ text: 'response', done: false })
-  streamChunk.candidates[0].finishReason = 'STOP'
+  streamChunk.candidates[0].finishReason = 'STOP' as FinishReason
   streamChunk.text = vi.fn(() => '')
   const llmChunk2 = await google.streamChunkToLlmChunk(streamChunk, null)
   expect(streamChunk.text).toHaveBeenCalled()
diff --git a/tests/unit/message.test.ts b/tests/unit/message.test.ts
index eab9e46..9e4f083 100644
--- a/tests/unit/message.test.ts
+++ b/tests/unit/message.test.ts
@@ -1,6 +1,15 @@
 import { expect, test } from 'vitest'
 import Message from '../../src/models/message'
 
+window.api = {
+  base64: {
+    decode: (data: string) => data
+  },
+  file: {
+    extractText: (contents) => contents
+  }
+}
+
 test('Build from role and text', () => {
   const message = new Message('user', 'content')
   expect(message.uuid).not.toBe(null)
diff --git a/tests/unit/store.test.ts b/tests/unit/store.test.ts
index b723794..8f6269c 100644
--- a/tests/unit/store.test.ts
+++ b/tests/unit/store.test.ts
@@ -1,5 +1,6 @@
 import { vi, expect, test, beforeEach } from 'vitest'
 
+import { Command } from '../../src/types/index.d'
 import { store } from '../../src/services/store'
 import Chat from '../../src/models/chat'
 import Message from '../../src/models/message'
@@ -28,7 +29,7 @@ window.api = {
     save: vi.fn(),
   },
   commands: {
-    load: vi.fn(() => defaultCommands),
+    load: vi.fn(() => defaultCommands as Command[]),
   },
   prompts: {
     load: vi.fn(() => defaultPrompts),
@@ -95,9 +96,10 @@ test('Save history', async () => {
     uuid: '123',
     engine: 'engine',
     model: 'model',
+    deleted: false,
     messages: [
-      { uuid: 1, role: 'system', content: 'Hi', toolCall: null, transient: false },
-      { uuid: 2, role: 'user', content: 'Hello', toolCall: null, transient: false }
+      { uuid: 1, role: 'system', content: 'Hi', toolCall: null, attachment: null, transient: false },
+      { uuid: 2, role: 'user', content: 'Hello', toolCall: null, attachment: null, transient: false }
     ]
   }])
 })
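On the test changes: `Message.fromJson()` now constructs real `Attachment` instances, so model-level tests need a `window.api` preload bridge to exist — presumably because the new `Attachment.extractText()` path calls into it. The stub added in `message.test.ts` above is the minimal shape; a variant with `vi.fn()` spies (showing only the slice of the api surface these tests touch, not the whole bridge) also lets tests assert the bridge was hit:

```ts
import { vi, test, expect } from 'vitest'

// relies on the repo's global Window typing for `api`
window.api = {
  base64: { decode: vi.fn((data: string) => data) },
  file: { extractText: vi.fn((contents: string) => contents) },
}

test('text extraction goes through the stubbed bridge', () => {
  expect(window.api.file.extractText('hello')).toBe('hello')
  expect(window.api.file.extractText).toHaveBeenCalledOnce()
})
```
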