Skip to content

Commit

Permalink
feat (provider/vertex): add schema support (#3147)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Sep 27, 2024
1 parent 988707c commit 465189a
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 18 deletions.
5 changes: 5 additions & 0 deletions .changeset/chilly-cows-taste.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@ai-sdk/google-vertex': patch
---

feat (provider/vertex): add schema support
55 changes: 49 additions & 6 deletions content/providers/01-ai-sdk-providers/11-google-vertex.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ const model = vertex('gemini-1.5-pro', {

The following optional settings are available for Google Vertex models:

- **structuredOutputs** _boolean_

Optional. Enable structured output. Default is true.

This is useful when the JSON Schema contains elements that are
not supported by the OpenAPI schema version that
Google Vertex uses. You can use this to disable
structured outputs if you need to.

See [Troubleshooting: Schema Limitations](#troubleshooting-schema-limitations) for more details.

- **safetySettings** _Array\<\{ category: string; threshold: string \}\>_

Optional. Safety settings for the model.
Expand Down Expand Up @@ -207,14 +218,46 @@ const { text } = await generateText({
});
```

### Troubleshooting: Schema Limitations

The Google Vertex API uses a subset of the OpenAPI 3.0 schema,
which does not support features such as unions.
The errors that you get in this case look like this:

`GenerateContentRequest.generation_config.response_schema.properties[occupation].type: must be specified`

By default, structured outputs are enabled (and for tool calling they are required).
You can disable structured outputs for object generation as a workaround:

```ts highlight="3,8"
const result = await generateObject({
model: vertex('gemini-1.5-pro', {
structuredOutputs: false,
}),
schema: z.object({
name: z.string(),
age: z.number(),
contact: z.union([
z.object({
type: z.literal('email'),
value: z.string(),
}),
z.object({
type: z.literal('phone'),
value: z.string(),
}),
]),
}),
prompt: 'Generate an example person for testing.',
});
```

### Model Capabilities

| Model | Image Input | Object Generation | Tool Usage | Tool Streaming |
| ----------------------- | ------------------- | ------------------- | ------------------- | ------------------- |
| `gemini-1.5-flash` | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
| `gemini-1.5-pro` | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
| `gemini-1.0-pro-vision` | <Check size={18} /> | <Check size={18} /> | <Cross size={18} /> | <Cross size={18} /> |
| `gemini-1.0-pro` | <Cross size={18} /> | <Check size={18} /> | <Cross size={18} /> | <Cross size={18} /> |
| Model | Image Input | Object Generation | Tool Usage | Tool Streaming |
| ------------------ | ------------------- | ------------------- | ------------------- | ------------------- |
| `gemini-1.5-flash` | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |
| `gemini-1.5-pro` | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> | <Check size={18} /> |

<Note>
The table above lists popular models. You can also pass any available provider
Expand Down
31 changes: 31 additions & 0 deletions examples/ai-core/src/generate-object/google-vertex-tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import { vertex } from '@ai-sdk/google-vertex';
import { generateObject } from 'ai';
import 'dotenv/config';
import { z } from 'zod';

async function main() {
const result = await generateObject({
model: vertex('gemini-1.5-pro'),
mode: 'tool',
schema: z.object({
recipe: z.object({
name: z.string(),
ingredients: z.array(
z.object({
name: z.string(),
amount: z.string(),
}),
),
steps: z.array(z.string()),
}),
}),
prompt: 'Generate a lasagna recipe.',
});

console.log(JSON.stringify(result.object.recipe, null, 2));
console.log();
console.log('Token usage:', result.usage);
console.log('Finish reason:', result.finishReason);
}

main().catch(console.error);
82 changes: 71 additions & 11 deletions packages/google-vertex/src/google-vertex-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ describe('doGenerate', () => {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: undefined,
responseSchema: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
Expand Down Expand Up @@ -234,6 +235,7 @@ describe('doGenerate', () => {
generationConfig: {
maxOutputTokens: 100,
responseMimeType: undefined,
responseSchema: undefined,
temperature: 0.5,
topK: 0.1,
topP: 0.9,
Expand Down Expand Up @@ -267,6 +269,7 @@ describe('doGenerate', () => {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: undefined,
responseSchema: undefined,
stopSequences: undefined,
temperature: undefined,
topK: undefined,
Expand Down Expand Up @@ -321,46 +324,102 @@ describe('doGenerate', () => {
});
});

it('should set name & description in object-json mode', async () => {
it('should pass specification in object-json mode with structuredOutputs = true (default)', async () => {
const { model, mockVertexAI } = createModel({
modelId: 'test-model',
generateContent: prepareResponse({
parts: [{ text: '{"value":"Spark"}' }],
text: '{"property1":"value1","property2":"value2"}',
}),
});

const response = await model.doGenerate({
const result = await model.doGenerate({
inputFormat: 'prompt',
mode: {
type: 'object-json',
name: 'test-name',
description: 'test description',
schema: {
type: 'object',
properties: { value: { type: 'string' } },
required: ['value'],
properties: {
property1: { type: 'string' },
property2: { type: 'number' },
},
required: ['property1', 'property2'],
additionalProperties: false,
},
},
prompt: TEST_PROMPT,
});

expect(mockVertexAI.lastModelParams).toStrictEqual({
generationConfig: {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: 'application/json',
responseSchema: {
properties: {
property1: { type: 'string' },
property2: { type: 'number' },
},
required: ['property1', 'property2'],
type: 'object',
},
stopSequences: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
},
model: 'gemini-1.0-pro-002',
safetySettings: undefined,
});

expect(result.text).toStrictEqual(
'{"property1":"value1","property2":"value2"}',
);
});

it('should not pass specification in object-json mode with structuredOutputs = false', async () => {
const { model, mockVertexAI } = createModel({
generateContent: prepareResponse({
text: '{"property1":"value1","property2":"value2"}',
}),
settings: {
structuredOutputs: false,
},
});

const result = await model.doGenerate({
inputFormat: 'prompt',
mode: {
type: 'object-json',
schema: {
type: 'object',
properties: {
property1: { type: 'string' },
property2: { type: 'number' },
},
required: ['property1', 'property2'],
additionalProperties: false,
$schema: 'http://json-schema.org/draft-07/schema#',
},
},
prompt: TEST_PROMPT,
});

expect(mockVertexAI.lastModelParams).toStrictEqual({
model: 'test-model',
generationConfig: {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: 'application/json',
responseSchema: undefined,
stopSequences: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
},
model: 'gemini-1.0-pro-002',
safetySettings: undefined,
});

expect(response.text).toStrictEqual('{"value":"Spark"}');
expect(result.text).toStrictEqual(
'{"property1":"value1","property2":"value2"}',
);
});

it('should support object-tool mode', async () => {
Expand Down Expand Up @@ -403,6 +462,7 @@ describe('doGenerate', () => {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: undefined,
responseSchema: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
Expand Down
28 changes: 27 additions & 1 deletion packages/google-vertex/src/google-vertex-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
GenerateContentResponse,
GenerationConfig,
Part,
ResponseSchema,
SafetySetting,
Tool,
ToolConfig,
Expand All @@ -33,9 +34,13 @@ type GoogleVertexAIConfig = {
export class GoogleVertexLanguageModel implements LanguageModelV1 {
readonly specificationVersion = 'v1';
readonly provider = 'google-vertex';
readonly defaultObjectGenerationMode = 'tool';
readonly defaultObjectGenerationMode = 'json';
readonly supportsImageUrls = false;

get supportsObjectGeneration() {
return this.settings.structuredOutputs !== false;
}

readonly modelId: GoogleVertexModelId;
readonly settings: GoogleVertexSettings;

Expand Down Expand Up @@ -98,8 +103,20 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 {
temperature,
topP,
stopSequences,

// response format:
responseMimeType:
responseFormat?.type === 'json' ? 'application/json' : undefined,
responseSchema:
responseFormat?.type === 'json' &&
responseFormat.schema != null &&
// Google Vertex does not support all OpenAPI Schema features,
// so this is needed as an escape hatch:
this.supportsObjectGeneration
? (convertJSONSchemaToOpenAPISchema(
responseFormat.schema,
) as ResponseSchema)
: undefined,
};

const type = mode.type;
Expand Down Expand Up @@ -132,6 +149,15 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 {
generationConfig: {
...generationConfig,
responseMimeType: 'application/json',
responseSchema:
mode.schema != null &&
// Google Vertex does not support all OpenAPI Schema features,
// so this is needed as an escape hatch:
this.supportsObjectGeneration
? (convertJSONSchemaToOpenAPISchema(
mode.schema,
) as ResponseSchema)
: undefined,
},
safetySettings: this.settings.safetySettings as
| undefined
Expand Down
10 changes: 10 additions & 0 deletions packages/google-vertex/src/google-vertex-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ Models running with nucleus sampling don't allow topK setting.
*/
topK?: number;

/**
* Optional. Enable structured output. Default is true.
*
* This is useful when the JSON Schema contains elements that are
* not supported by the OpenAPI schema version that
* Google Generative AI uses. You can use this to disable
* structured outputs if you need to.
*/
structuredOutputs?: boolean;

/**
Optional. A list of unique safety settings for blocking unsafe content.
*/
Expand Down

0 comments on commit 465189a

Please sign in to comment.