From 402f5792f0c15655abda1c4f60002a1fc6f5609f Mon Sep 17 00:00:00 2001 From: Nicolas Bonamy Date: Wed, 15 May 2024 09:21:33 -0500 Subject: [PATCH] gemini flash support --- src/services/google.ts | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/services/google.ts b/src/services/google.ts index 3823cad..2daa448 100644 --- a/src/services/google.ts +++ b/src/services/google.ts @@ -1,10 +1,10 @@ import { Message } from '../types/index.d' -import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmToolCall, LlmEventCallback } from '../types/llm.d' import { EngineConfig, Configuration } from '../types/config.d' -import LlmEngine from './engine' -import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI, Part } from '@google/generative-ai' +import { LLmCompletionPayload, LlmChunk, LlmCompletionOpts, LlmResponse, LlmStream, LlmToolCall, LlmEventCallback } from '../types/llm.d' +import { ChatSession, Content, EnhancedGenerateContentResponse, GenerativeModel, GoogleGenerativeAI, ModelParams, Part } from '@google/generative-ai' import { getFileContents } from './download' +import LlmEngine from './engine' export const isGoogleReady = (engineConfig: EngineConfig): boolean => { return engineConfig.apiKey?.length > 0 @@ -28,7 +28,7 @@ export default class extends LlmEngine { } getVisionModels(): string[] { - return ['models/gemini-1.5-pro-latest', 'models/gemini-pro-vision', '*vision*'] + return ['gemini-1.5-flash-latest', 'models/gemini-1.5-pro-latest', 'models/gemini-pro-vision', '*vision*'] } isVisionModel(model: string): boolean { @@ -49,7 +49,7 @@ export default class extends LlmEngine { // do it return [ { id: 'models/gemini-1.5-pro-latest', name: 'Gemini 1.5 Pro' }, - //{ id: 'gemini-1.5-flash', name: 'Gemini 1.5 Flash' }, + { id: 'gemini-1.5-flash-latest', name: 'Gemini 1.5 Flash' }, { id: 'models/gemini-pro', name: 'Gemini 1.0 Pro' }, //{ id: 'models/gemini-pro-vision', name: 'Gemini Pro Vision' }, ] @@ -102,18 +102,30 @@ export default class extends LlmEngine { getModel(model: string, instructions: string): GenerativeModel { // not all models have all features - const hasInstructions = !(['models/gemini-pro'].includes(model)) + const hasInstructions = !(['models/gemini-pro', 'gemini-1.5-flash'].includes(model)) const hasTools = false - return this.client.getGenerativeModel({ + // model params + const modelParams: ModelParams = { model: model, - systemInstruction: hasInstructions ? instructions : null, - tools: hasTools ? [{ + } + + // add instructions + if (hasInstructions) { + modelParams.systemInstruction = instructions + } + + // add tools + if (hasTools) { + modelParams.tools = [{ functionDeclarations: this.getAvailableTools().map((tool) => { return tool.function }) - }] : null, - }, { + }] + } + + // call + return this.client.getGenerativeModel( modelParams, { apiVersion: 'v1beta' }) }