diff --git a/src/models/attachment.ts b/src/models/attachment.ts
index ae49e15..53a2883 100644
--- a/src/models/attachment.ts
+++ b/src/models/attachment.ts
@@ -1,4 +1,7 @@
 
+export const textFormats = ['pdf', 'txt', 'docx', 'pptx', 'xlsx']
+export const imageFormats = ['jpeg', 'jpg', 'png', 'webp']
+
 export default class Attachment {
 
   url: string
@@ -20,6 +23,14 @@ export default class Attachment {
     }
   }
 
+  isText(): boolean {
+    return textFormats.includes(this.format)
+  }
+
+  isImage(): boolean {
+    return imageFormats.includes(this.format)
+  }
+
   extractText(): void {
     const rawText = window.api.file.extractText(this.contents, this.format)
     //console.log('Raw text:', rawText)
diff --git a/src/services/engine.ts b/src/services/engine.ts
index 12665fb..4a279e1 100644
--- a/src/services/engine.ts
+++ b/src/services/engine.ts
@@ -7,7 +7,6 @@ import Plugin from '../plugins/plugin'
 import BrowsePlugin from '../plugins/browse'
 import TavilyPlugin from '../plugins/tavily'
 import PythonPlugin from '../plugins/python'
-import { textFormats, imageFormats } from './llm'
 import { PluginParameter } from '../types/plugin.d'
 import { minimatch } from 'minimatch'
 
@@ -98,7 +97,7 @@ export default class LlmEngine {
     }
 
     // check if amy of the messages in the thread have an attachment
-    return thread.some((msg) => msg.attachment && imageFormats.includes(msg.attachment.format))
+    return thread.some((msg) => msg.attachment && msg.attachment.isImage())
 
   }
 
@@ -152,12 +151,12 @@
       }
 
       // text formats
-      if (textFormats.includes(msg.attachment.format)) {
+      if (msg.attachment.isText()) {
        payload.content += `\n\n${msg.attachment.contents}`
       }
 
       // image formats
-      if (imageFormats.includes(msg.attachment.format)) {
+      if (msg.attachment.isImage()) {
        if (!imageAttached && this.isVisionModel(model)) {
          this.addImageToPayload(msg, payload)
          imageAttached = true
diff --git a/src/services/google.ts b/src/services/google.ts
index 5d898d8..15a2db7 100644
--- a/src/services/google.ts
+++ b/src/services/google.ts
@@ -144,12 +144,16 @@ export default class extends LlmEngine {
     }
 
     // add inline
-    prompt.push({
-      inlineData: {
-        mimeType: 'image/png',
-        data: lastMessage.attachment.contents,
-      }
-    })
+    if (lastMessage.attachment.isImage()) {
+      prompt.push({
+        inlineData: {
+          mimeType: 'image/png',
+          data: lastMessage.attachment.contents,
+        }
+      })
+    } else if (lastMessage.attachment.isText()) {
+      prompt.push(lastMessage.attachment.contents)
+    }
 
   }
 
diff --git a/src/services/llm.ts b/src/services/llm.ts
index c5eb508..9f48ab8 100644
--- a/src/services/llm.ts
+++ b/src/services/llm.ts
@@ -1,5 +1,6 @@
 import { Model, EngineConfig, Configuration } from '../types/config.d'
 
+import { imageFormats, textFormats } from '../models/attachment'
 import { store } from './store'
 import OpenAI, { isOpenAIReady } from './openai'
 import Ollama, { isOllamaReady } from './ollama'
@@ -10,8 +11,6 @@ import Groq, { isGroqReady } from './groq'
 import LlmEngine from './engine'
 
 export const availableEngines = ['openai', 'ollama', 'anthropic', 'mistralai', 'google', 'groq']
-export const textFormats = ['pdf', 'txt', 'docx', 'pptx', 'xlsx']
-export const imageFormats = ['jpeg', 'jpg', 'png', 'webp']
 
 export const isEngineReady = (engine: string) => {
   if (engine === 'openai') return isOpenAIReady(store.config.engines.openai)
diff --git a/src/types/index.d.ts b/src/types/index.d.ts
index 8e4ff00..3c01d64 100644
--- a/src/types/index.d.ts
+++ b/src/types/index.d.ts
@@ -41,6 +41,8 @@ interface Attachment {
   format: string
   contents: string
   downloaded: boolean
+  isText(): boolean
+  isImage(): boolean
 }
 
 interface Command {
diff --git a/tests/unit/engine.test.ts b/tests/unit/engine.test.ts
index 6bef444..bbf1d96 100644
--- a/tests/unit/engine.test.ts
+++ b/tests/unit/engine.test.ts
@@ -4,6 +4,7 @@ import { isEngineReady, igniteEngine, hasVisionModels, isVisionModel, loadOpenAI
 import { store } from '../../src/services/store'
 import defaults from '../../defaults/settings.json'
 import Message from '../../src/models/message'
+import Attachment from '../../src/models/attachment'
 import OpenAI from '../../src/services/openai'
 import Ollama from '../../src/services/ollama'
 import MistralAI from '../../src/services/mistralai'
@@ -11,7 +12,6 @@ import Anthropic from '../../src/services/anthropic'
 import Google from '../../src/services/google'
 import Groq from '../../src/services/groq'
 import { Model } from '../../src/types/config.d'
-import { text } from 'stream/consumers'
 
 const model = [{ id: 'llava:latest', name: 'llava:latest', meta: {} }]
 
@@ -31,6 +31,15 @@ vi.mock('openai', async() => {
   return { default: OpenAI }
 })
 
+window.api = {
+  base64: {
+    decode: (data: string) => data
+  },
+  file: {
+    extractText: (contents) => contents
+  }
+}
+
 beforeEach(() => {
   store.config = defaults
 })
 
@@ -166,7 +175,7 @@ test('Build payload with text attachment', async () => {
     new Message('system', { role: 'system', type: 'text', content: 'instructions' }),
     new Message('user', { role: 'user', type: 'text', content: 'prompt1' }),
   ]
-  messages[1].attachFile({ format: 'txt', contents: 'attachment', downloaded: true, url: '' })
+  messages[1].attachFile(new Attachment('', 'txt', 'attachment', true))
   expect(openai.buildPayload(messages, 'gpt-model1')).toStrictEqual([
     { role: 'system', content: 'instructions' },
     { role: 'user', content: 'prompt1\n\nattachment' },
@@ -179,7 +188,7 @@ test('Build payload with image attachment', async () => {
     new Message('system', { role: 'system', type: 'text', content: 'instructions' }),
     new Message('user', { role: 'user', type: 'text', content: 'prompt1' }),
   ]
-  messages[1].attachFile({ format: 'png', contents: 'attachment', downloaded: true, url: '' })
+  messages[1].attachFile(new Attachment('', 'png', 'attachment', true))
  expect(openai.buildPayload(messages, 'gpt-model1')).toStrictEqual([
     { role: 'system', content: 'instructions' },
     { role: 'user', content: 'prompt1' },
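
Reviewer note: a minimal usage sketch of the new Attachment helpers introduced above, not part of the diff. The constructor argument order (url, format, contents, downloaded) is inferred from the updated test call new Attachment('', 'txt', 'attachment', true); the import path and variable names are illustrative only.

import Attachment from './src/models/attachment'

// 'txt' appears in textFormats, so isText() should report true
const doc = new Attachment('', 'txt', 'some extracted text', true)
console.log(doc.isText())   // true
console.log(doc.isImage())  // false

// 'png' appears in imageFormats, so isImage() should report true
const pic = new Attachment('', 'png', '<base64 image data>', true)
console.log(pic.isImage())  // true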