diff --git a/.changeset/thirty-planets-wait.md b/.changeset/thirty-planets-wait.md new file mode 100644 index 000000000000..ee2fc3bee45a --- /dev/null +++ b/.changeset/thirty-planets-wait.md @@ -0,0 +1,5 @@ +--- +'@ai-sdk/google-vertex': patch +--- + +feat (provider/vertex): tool choice support & object generation with tool mode diff --git a/packages/google-vertex/src/convert-json-schema-to-openapi-schema.test.ts b/packages/google-vertex/src/convert-json-schema-to-openapi-schema.test.ts new file mode 100644 index 000000000000..dc0c6bdd7fb7 --- /dev/null +++ b/packages/google-vertex/src/convert-json-schema-to-openapi-schema.test.ts @@ -0,0 +1,468 @@ +import { JSONSchema7 } from 'json-schema'; +import { convertJSONSchemaToOpenAPISchema } from './convert-json-schema-to-openapi-schema'; + +it('should remove additionalProperties and $schema', () => { + const input: JSONSchema7 = { + $schema: 'http://json-schema.org/draft-07/schema#', + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' }, + }, + additionalProperties: false, + }; + + const expected = { + type: 'object', + properties: { + name: { type: 'string' }, + age: { type: 'number' }, + }, + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should handle nested objects and arrays', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + users: { + type: 'array', + items: { + type: 'object', + properties: { + id: { type: 'number' }, + name: { type: 'string' }, + }, + additionalProperties: false, + }, + }, + }, + additionalProperties: false, + }; + + const expected = { + type: 'object', + properties: { + users: { + type: 'array', + items: { + type: 'object', + properties: { + id: { type: 'number' }, + name: { type: 'string' }, + }, + }, + }, + }, + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should convert "const" to "enum" with a single value', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + status: { const: 'active' }, + }, + }; + + const expected = { + type: 'object', + properties: { + status: { enum: ['active'] }, + }, + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should handle allOf, anyOf, and oneOf', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + allOfProp: { allOf: [{ type: 'string' }, { minLength: 5 }] }, + anyOfProp: { anyOf: [{ type: 'string' }, { type: 'number' }] }, + oneOfProp: { oneOf: [{ type: 'boolean' }, { type: 'null' }] }, + }, + }; + + const expected = { + type: 'object', + properties: { + allOfProp: { + allOf: [{ type: 'string' }, { minLength: 5 }], + }, + anyOfProp: { + anyOf: [{ type: 'string' }, { type: 'number' }], + }, + oneOfProp: { + oneOf: [{ type: 'boolean' }, { type: 'null' }], + }, + }, + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should convert "format: date-time" to "format: date-time"', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + timestamp: { type: 'string', format: 'date-time' }, + }, + }; + + const expected = { + type: 'object', + properties: { + timestamp: { type: 'string', format: 'date-time' }, + }, + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should handle required properties', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + id: { type: 'number' }, + name: { type: 'string' }, + }, + required: ['id'], + }; + + const expected = { + type: 'object', + properties: { + id: { type: 'number' }, + name: { type: 'string' }, + }, + required: ['id'], + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should convert deeply nested "const" to "enum"', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + nested: { + type: 'object', + properties: { + deeplyNested: { + anyOf: [ + { + type: 'object', + properties: { + value: { + const: 'specific value', + }, + }, + }, + { + type: 'string', + }, + ], + }, + }, + }, + }, + }; + + const expected = { + type: 'object', + properties: { + nested: { + type: 'object', + properties: { + deeplyNested: { + anyOf: [ + { + type: 'object', + properties: { + value: { + enum: ['specific value'], + }, + }, + }, + { + type: 'string', + }, + ], + }, + }, + }, + }, + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should correctly convert a complex schema with nested const and anyOf', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + name: { + type: 'string', + }, + age: { + type: 'number', + }, + contact: { + anyOf: [ + { + type: 'object', + properties: { + type: { + type: 'string', + const: 'email', + }, + value: { + type: 'string', + }, + }, + required: ['type', 'value'], + additionalProperties: false, + }, + { + type: 'object', + properties: { + type: { + type: 'string', + const: 'phone', + }, + value: { + type: 'string', + }, + }, + required: ['type', 'value'], + additionalProperties: false, + }, + ], + }, + occupation: { + anyOf: [ + { + type: 'object', + properties: { + type: { + type: 'string', + const: 'employed', + }, + company: { + type: 'string', + }, + position: { + type: 'string', + }, + }, + required: ['type', 'company', 'position'], + additionalProperties: false, + }, + { + type: 'object', + properties: { + type: { + type: 'string', + const: 'student', + }, + school: { + type: 'string', + }, + grade: { + type: 'number', + }, + }, + required: ['type', 'school', 'grade'], + additionalProperties: false, + }, + { + type: 'object', + properties: { + type: { + type: 'string', + const: 'unemployed', + }, + }, + required: ['type'], + additionalProperties: false, + }, + ], + }, + }, + required: ['name', 'age', 'contact', 'occupation'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }; + + const expected = { + type: 'object', + properties: { + name: { + type: 'string', + }, + age: { + type: 'number', + }, + contact: { + anyOf: [ + { + type: 'object', + properties: { + type: { + type: 'string', + enum: ['email'], + }, + value: { + type: 'string', + }, + }, + required: ['type', 'value'], + }, + { + type: 'object', + properties: { + type: { + type: 'string', + enum: ['phone'], + }, + value: { + type: 'string', + }, + }, + required: ['type', 'value'], + }, + ], + }, + occupation: { + anyOf: [ + { + type: 'object', + properties: { + type: { + type: 'string', + enum: ['employed'], + }, + company: { + type: 'string', + }, + position: { + type: 'string', + }, + }, + required: ['type', 'company', 'position'], + }, + { + type: 'object', + properties: { + type: { + type: 'string', + enum: ['student'], + }, + school: { + type: 'string', + }, + grade: { + type: 'number', + }, + }, + required: ['type', 'school', 'grade'], + }, + { + type: 'object', + properties: { + type: { + type: 'string', + enum: ['unemployed'], + }, + }, + required: ['type'], + }, + ], + }, + }, + required: ['name', 'age', 'contact', 'occupation'], + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should handle null type correctly', () => { + const input: JSONSchema7 = { + type: 'object', + properties: { + nullableField: { + type: ['string', 'null'], + }, + explicitNullField: { + type: 'null', + }, + }, + }; + + const expected = { + type: 'object', + properties: { + nullableField: { + type: 'string', + nullable: true, + }, + explicitNullField: { + type: 'null', + }, + }, + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); + +it('should handle descriptions', () => { + const input: JSONSchema7 = { + type: 'object', + description: 'A user object', + properties: { + id: { + type: 'number', + description: 'The user ID', + }, + name: { + type: 'string', + description: "The user's full name", + }, + email: { + type: 'string', + format: 'email', + description: "The user's email address", + }, + }, + required: ['id', 'name'], + }; + + const expected = { + type: 'object', + description: 'A user object', + properties: { + id: { + type: 'number', + description: 'The user ID', + }, + name: { + type: 'string', + description: "The user's full name", + }, + email: { + type: 'string', + format: 'email', + description: "The user's email address", + }, + }, + required: ['id', 'name'], + }; + + expect(convertJSONSchemaToOpenAPISchema(input)).toEqual(expected); +}); diff --git a/packages/google-vertex/src/convert-json-schema-to-openapi-schema.ts b/packages/google-vertex/src/convert-json-schema-to-openapi-schema.ts new file mode 100644 index 000000000000..fa55dd0d2fae --- /dev/null +++ b/packages/google-vertex/src/convert-json-schema-to-openapi-schema.ts @@ -0,0 +1,82 @@ +import { JSONSchema7Definition } from 'json-schema'; + +/** + * Converts JSON Schema 7 to OpenAPI Schema 3.0 + */ +export function convertJSONSchemaToOpenAPISchema( + jsonSchema: JSONSchema7Definition, +): unknown { + if (typeof jsonSchema === 'boolean') { + return { type: 'boolean', properties: {} }; + } + + const { + type, + description, + required, + properties, + items, + allOf, + anyOf, + oneOf, + format, + const: constValue, + minLength, + } = jsonSchema; + + const result: Record = {}; + + if (description) result.description = description; + if (required) result.required = required; + if (format) result.format = format; + + if (constValue !== undefined) { + result.enum = [constValue]; + } + + // Handle type + if (type) { + if (Array.isArray(type)) { + if (type.includes('null')) { + result.type = type.filter(t => t !== 'null')[0]; + result.nullable = true; + } else { + result.type = type; + } + } else if (type === 'null') { + result.type = 'null'; + } else { + result.type = type; + } + } + + if (properties) { + result.properties = Object.entries(properties).reduce( + (acc, [key, value]) => { + acc[key] = convertJSONSchemaToOpenAPISchema(value); + return acc; + }, + {} as Record, + ); + } + + if (items) { + result.items = Array.isArray(items) + ? items.map(convertJSONSchemaToOpenAPISchema) + : convertJSONSchemaToOpenAPISchema(items); + } + + if (allOf) { + result.allOf = allOf.map(convertJSONSchemaToOpenAPISchema); + } + if (anyOf) { + result.anyOf = anyOf.map(convertJSONSchemaToOpenAPISchema); + } + if (oneOf) { + result.oneOf = oneOf.map(convertJSONSchemaToOpenAPISchema); + } + + if (minLength !== undefined) result.minLength = minLength; + + return result; +} diff --git a/packages/google-vertex/src/google-vertex-language-model.test.ts b/packages/google-vertex/src/google-vertex-language-model.test.ts index 74330269ae72..bb8780bae783 100644 --- a/packages/google-vertex/src/google-vertex-language-model.test.ts +++ b/packages/google-vertex/src/google-vertex-language-model.test.ts @@ -165,21 +165,19 @@ describe('doGenerate', () => { description: '', name: 'test-tool', parameters: { - description: undefined, properties: { value: { - description: undefined, - required: undefined, - type: 'STRING', + type: 'string', }, }, required: ['value'], - type: 'OBJECT', + type: 'object', }, }, ], }, ], + toolConfig: undefined, safetySettings: undefined, }); @@ -243,6 +241,7 @@ describe('doGenerate', () => { stopSequences: ['abc', 'def'], }, tools: undefined, + toolConfig: undefined, safetySettings: undefined, }); }); @@ -274,6 +273,7 @@ describe('doGenerate', () => { topP: undefined, }, tools: [{ googleSearchRetrieval: {} }], + toolConfig: undefined, safetySettings: undefined, }); }); @@ -362,6 +362,84 @@ describe('doGenerate', () => { expect(response.text).toStrictEqual('{"value":"Spark"}'); }); + + it('should support object-tool mode', async () => { + const { model, mockVertexAI } = createModel({ + generateContent: prepareResponse({ + parts: [ + { + functionCall: { + name: 'test-tool', + args: { value: 'Spark' }, + }, + }, + ], + }), + }); + + const result = await model.doGenerate({ + inputFormat: 'prompt', + mode: { + type: 'object-tool', + tool: { + type: 'function', + name: 'test-tool', + description: 'test description', + parameters: { + type: 'object', + properties: { value: { type: 'string' } }, + required: ['value'], + additionalProperties: false, + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, + }, + prompt: TEST_PROMPT, + }); + + expect(mockVertexAI.lastModelParams).toStrictEqual({ + model: 'gemini-1.0-pro-002', + generationConfig: { + frequencyPenalty: undefined, + maxOutputTokens: undefined, + responseMimeType: undefined, + temperature: undefined, + topK: undefined, + topP: undefined, + stopSequences: undefined, + }, + tools: [ + { + functionDeclarations: [ + { + description: 'test description', + name: 'test-tool', + parameters: { + properties: { + value: { + type: 'string', + }, + }, + required: ['value'], + type: 'object', + }, + }, + ], + }, + ], + toolConfig: { functionCallingConfig: { mode: 'ANY' } }, + safetySettings: undefined, + }); + + expect(result.toolCalls).toStrictEqual([ + { + args: '{"value":"Spark"}', + toolCallId: 'test-id', + toolCallType: 'function', + toolName: 'test-tool', + }, + ]); + }); }); describe('doStream', () => { diff --git a/packages/google-vertex/src/google-vertex-language-model.ts b/packages/google-vertex/src/google-vertex-language-model.ts index b284b32e26eb..2879d34e8b52 100644 --- a/packages/google-vertex/src/google-vertex-language-model.ts +++ b/packages/google-vertex/src/google-vertex-language-model.ts @@ -5,25 +5,26 @@ import { LanguageModelV1FinishReason, LanguageModelV1StreamPart, NoContentGeneratedError, - UnsupportedFunctionalityError, } from '@ai-sdk/provider'; import { convertAsyncGeneratorToReadableStream } from '@ai-sdk/provider-utils'; import { + FunctionCallingMode, + FunctionDeclarationSchema, GenerateContentResponse, GenerationConfig, Part, SafetySetting, + Tool, + ToolConfig, VertexAI, - Tool as GoogleTool, } from '@google-cloud/vertexai'; +import { convertJSONSchemaToOpenAPISchema } from './convert-json-schema-to-openapi-schema'; import { convertToGoogleVertexContentRequest } from './convert-to-google-vertex-content-request'; import { GoogleVertexModelId, GoogleVertexSettings, } from './google-vertex-settings'; import { mapGoogleVertexFinishReason } from './map-google-vertex-finish-reason'; -import { prepareFunctionDeclarationSchema } from './prepare-function-declaration-schema'; - type GoogleVertexAIConfig = { vertexAI: VertexAI; generateId: () => string; @@ -32,7 +33,7 @@ type GoogleVertexAIConfig = { export class GoogleVertexLanguageModel implements LanguageModelV1 { readonly specificationVersion = 'v1'; readonly provider = 'google-vertex'; - readonly defaultObjectGenerationMode = 'json'; + readonly defaultObjectGenerationMode = 'tool'; readonly supportsImageUrls = false; readonly modelId: GoogleVertexModelId; @@ -105,10 +106,10 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 { switch (type) { case 'regular': { - const conf = { + const configuration = { model: this.modelId, generationConfig, - tools: prepareTools({ + ...prepareToolsAndToolConfig({ mode, useSearchGrounding: this.settings.useSearchGrounding ?? false, }), @@ -118,7 +119,7 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 { }; return { - model: this.config.vertexAI.getGenerativeModel(conf), + model: this.config.vertexAI.getGenerativeModel(configuration), contentRequest: convertToGoogleVertexContentRequest(prompt), warnings, }; @@ -142,9 +143,35 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 { } case 'object-tool': { - throw new UnsupportedFunctionalityError({ - functionality: 'object-tool mode', - }); + const configuration = { + model: this.modelId, + generationConfig, + tools: [ + { + functionDeclarations: [ + { + name: mode.tool.name, + description: mode.tool.description ?? '', + parameters: convertJSONSchemaToOpenAPISchema( + mode.tool.parameters, + ) as FunctionDeclarationSchema, + }, + ], + }, + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.ANY }, + }, + safetySettings: this.settings.safetySettings as + | undefined + | Array, + }; + + return { + model: this.config.vertexAI.getGenerativeModel(configuration), + contentRequest: convertToGoogleVertexContentRequest(prompt), + warnings, + }; } default: { @@ -291,7 +318,7 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 { } } -function prepareTools({ +function prepareToolsAndToolConfig({ useSearchGrounding, mode, }: { @@ -299,40 +326,82 @@ function prepareTools({ mode: Parameters[0]['mode'] & { type: 'regular'; }; -}): GoogleTool[] | undefined { +}): { + tools: Tool[] | undefined; + toolConfig: ToolConfig | undefined; +} { // when the tools array is empty, change it to undefined to prevent errors: const tools = mode.tools?.length ? mode.tools : undefined; - const toolChoice = mode.toolChoice; - if (toolChoice?.type === 'none') { - return undefined; + const mappedTools: Tool[] = + tools == null + ? [] + : [ + { + functionDeclarations: tools.map(tool => ({ + name: tool.name, + description: tool.description ?? '', + parameters: convertJSONSchemaToOpenAPISchema( + tool.parameters, + ) as FunctionDeclarationSchema, + })), + }, + ]; + + if (useSearchGrounding) { + mappedTools.push({ googleSearchRetrieval: {} }); } - if (toolChoice == null || toolChoice.type === 'auto') { - const mappedTools: GoogleTool[] = - tools != null - ? [ - { - functionDeclarations: tools.map(tool => ({ - name: tool.name, - description: tool.description ?? '', - parameters: prepareFunctionDeclarationSchema(tool.parameters), - })), - }, - ] - : []; + const finalTools = mappedTools.length > 0 ? mappedTools : undefined; - if (useSearchGrounding) { - mappedTools.push({ googleSearchRetrieval: {} }); - } + const toolChoice = mode.toolChoice; - return mappedTools.length > 0 ? mappedTools : undefined; + if (toolChoice == null) { + return { + tools: finalTools, + toolConfig: undefined, + }; } - // forcing tool calls or a specific tool call is not supported by Vertex: - throw new UnsupportedFunctionalityError({ - functionality: `toolChoice: ${toolChoice.type}`, - }); + const type = toolChoice.type; + + switch (type) { + case 'auto': + return { + tools: finalTools, + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.AUTO }, + }, + }; + case 'none': + return { + tools: finalTools, + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE }, + }, + }; + case 'required': + return { + tools: finalTools, + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.ANY }, + }, + }; + case 'tool': + return { + tools: finalTools, + toolConfig: { + functionCallingConfig: { + mode: FunctionCallingMode.ANY, + allowedFunctionNames: [toolChoice.toolName], + }, + }, + }; + default: { + const _exhaustiveCheck: never = type; + throw new Error(`Unsupported tool choice type: ${_exhaustiveCheck}`); + } + } } function getToolCallsFromParts({ diff --git a/packages/google-vertex/src/prepare-function-declaration-schema.test.ts b/packages/google-vertex/src/prepare-function-declaration-schema.test.ts deleted file mode 100644 index a4594c346e75..000000000000 --- a/packages/google-vertex/src/prepare-function-declaration-schema.test.ts +++ /dev/null @@ -1,282 +0,0 @@ -import { - FunctionDeclarationSchema, - FunctionDeclarationSchemaType, -} from '@google-cloud/vertexai'; -import { JSONSchema7 } from 'json-schema'; -import { prepareFunctionDeclarationSchema } from './prepare-function-declaration-schema'; - -it('should convert a string property', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - testProperty: { type: 'string' }, - }, - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - testProperty: { type: FunctionDeclarationSchemaType.STRING }, - }, - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert number property', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - testProperty: { type: 'number' }, - }, - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - testProperty: { type: FunctionDeclarationSchemaType.NUMBER }, - }, - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert integer property', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - testProperty: { type: 'integer' }, - }, - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - testProperty: { type: FunctionDeclarationSchemaType.INTEGER }, - }, - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert boolean property', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - testProperty: { type: 'boolean' }, - }, - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - testProperty: { type: FunctionDeclarationSchemaType.BOOLEAN }, - }, - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert property description', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - testProperty: { type: 'string', description: 'test-description' }, - }, - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - testProperty: { - type: FunctionDeclarationSchemaType.STRING, - description: 'test-description', - }, - }, - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert an object type with several properties', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - name: { type: 'string' }, - age: { type: 'integer' }, - }, - required: ['name'], - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - description: undefined, - properties: { - name: { type: FunctionDeclarationSchemaType.STRING }, - age: { type: FunctionDeclarationSchemaType.INTEGER }, - }, - required: ['name'], - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert a nested object type', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - name: { type: 'string' }, - address: { - type: 'object', - properties: { - street: { type: 'string' }, - city: { type: 'string' }, - }, - }, - }, - required: ['name'], - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - description: undefined, - properties: { - name: { - type: FunctionDeclarationSchemaType.STRING, - description: undefined, - }, - address: { - type: FunctionDeclarationSchemaType.OBJECT, - description: undefined, - properties: { - street: { - type: FunctionDeclarationSchemaType.STRING, - description: undefined, - properties: {}, - }, - city: { - type: FunctionDeclarationSchemaType.STRING, - description: undefined, - properties: {}, - }, - }, - }, - }, - required: ['name'], - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert a nested object type with description', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - name: { type: 'string' }, - address: { - type: 'object', - description: 'Address description', - properties: { - street: { type: 'string', description: 'Street description' }, - city: { type: 'string', description: 'City description' }, - }, - }, - }, - required: ['name'], - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - description: undefined, - properties: { - name: { - type: FunctionDeclarationSchemaType.STRING, - description: undefined, - }, - address: { - type: FunctionDeclarationSchemaType.OBJECT, - description: 'Address description', - properties: { - street: { - type: FunctionDeclarationSchemaType.STRING, - description: 'Street description', - properties: {}, - }, - city: { - type: FunctionDeclarationSchemaType.STRING, - description: 'City description', - properties: {}, - }, - }, - }, - }, - required: ['name'], - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert an array of strings', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - names: { - type: 'array', - items: { - type: 'string', - }, - }, - }, - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - names: { - type: FunctionDeclarationSchemaType.ARRAY, - items: { - type: FunctionDeclarationSchemaType.STRING, - properties: {}, - }, - }, - }, - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); - -it('should convert an array of objects', () => { - const jsonSchema: JSONSchema7 = { - type: 'object', - properties: { - people: { - type: 'array', - items: { - type: 'object', - properties: { - name: { type: 'string' }, - age: { type: 'integer' }, - }, - }, - }, - }, - }; - - const expected: FunctionDeclarationSchema = { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - people: { - type: FunctionDeclarationSchemaType.ARRAY, - items: { - type: FunctionDeclarationSchemaType.OBJECT, - properties: { - name: { type: FunctionDeclarationSchemaType.STRING }, - age: { type: FunctionDeclarationSchemaType.INTEGER }, - }, - }, - }, - }, - }; - - expect(prepareFunctionDeclarationSchema(jsonSchema)).toEqual(expected); -}); diff --git a/packages/google-vertex/src/prepare-function-declaration-schema.ts b/packages/google-vertex/src/prepare-function-declaration-schema.ts deleted file mode 100644 index 7c891cd8d704..000000000000 --- a/packages/google-vertex/src/prepare-function-declaration-schema.ts +++ /dev/null @@ -1,134 +0,0 @@ -import { UnsupportedFunctionalityError } from '@ai-sdk/provider'; -import { - FunctionDeclarationSchema, - FunctionDeclarationSchemaProperty, - FunctionDeclarationSchemaType, -} from '@google-cloud/vertexai'; -import { JSONSchema7Definition } from 'json-schema'; - -const primitiveTypes = { - string: FunctionDeclarationSchemaType.STRING, - number: FunctionDeclarationSchemaType.NUMBER, - integer: FunctionDeclarationSchemaType.INTEGER, - boolean: FunctionDeclarationSchemaType.BOOLEAN, -}; - -/** -Converts the tool parameters JSON schema to the format required by Vertex AI. - */ -export function prepareFunctionDeclarationSchema( - jsonSchema: JSONSchema7Definition, -): FunctionDeclarationSchema { - if (typeof jsonSchema === 'boolean') { - return { - type: FunctionDeclarationSchemaType.BOOLEAN, - properties: {}, - }; - } - - const type = jsonSchema.type; - switch (type) { - case 'number': - case 'integer': - case 'boolean': - case 'string': - return { - type: primitiveTypes[type], - description: jsonSchema.description, - required: jsonSchema.required, - properties: {}, - }; - - case 'object': - return { - type: FunctionDeclarationSchemaType.OBJECT, - properties: Object.entries(jsonSchema.properties ?? {}).reduce( - (acc, [key, value]) => { - acc[key] = prepareFunctionDeclarationSchemaProperty(value); - return acc; - }, - {} as Record, - ), - description: jsonSchema.description, - required: jsonSchema.required, - }; - - case 'array': - throw new UnsupportedFunctionalityError({ - functionality: - 'arrays are not supported as root or as array parameters', - }); - - default: { - throw new UnsupportedFunctionalityError({ - functionality: `json schema type: ${type}`, - }); - } - } -} - -function prepareFunctionDeclarationSchemaProperty( - jsonSchema: JSONSchema7Definition, -): FunctionDeclarationSchemaProperty { - if (typeof jsonSchema === 'boolean') { - return { - type: FunctionDeclarationSchemaType.BOOLEAN, - }; - } - - const type = jsonSchema.type; - - switch (type) { - // primitive types: - case 'number': - case 'integer': - case 'boolean': - case 'string': { - return { - type: primitiveTypes[type], - description: jsonSchema.description, - required: jsonSchema.required, - }; - } - // array: - case 'array': { - const items = jsonSchema.items; - - if (items == null) { - throw new UnsupportedFunctionalityError({ - functionality: - 'Array without items is not supported in tool parameters', - }); - } - - if (Array.isArray(items)) { - throw new UnsupportedFunctionalityError({ - functionality: 'Tuple arrays are not supported in tool parameters', - }); - } - - return { - type: FunctionDeclarationSchemaType.ARRAY, - description: jsonSchema.description, - required: jsonSchema.required, - items: prepareFunctionDeclarationSchema(items), - }; - } - // nested object: - case 'object': - return { - type: FunctionDeclarationSchemaType.OBJECT, - properties: Object.entries(jsonSchema.properties ?? {}).reduce( - (acc, [key, value]) => { - acc[key] = prepareFunctionDeclarationSchema(value); - return acc; - }, - {} as Record, - ), - description: jsonSchema.description, - required: jsonSchema.required, - }; - default: - throw new Error(`Unsupported type: ${type}`); - } -}