Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel committed Sep 27, 2024
1 parent 988707c commit b8b93ff
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 12 deletions.
82 changes: 71 additions & 11 deletions packages/google-vertex/src/google-vertex-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ describe('doGenerate', () => {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: undefined,
responseSchema: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
Expand Down Expand Up @@ -234,6 +235,7 @@ describe('doGenerate', () => {
generationConfig: {
maxOutputTokens: 100,
responseMimeType: undefined,
responseSchema: undefined,
temperature: 0.5,
topK: 0.1,
topP: 0.9,
Expand Down Expand Up @@ -267,6 +269,7 @@ describe('doGenerate', () => {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: undefined,
responseSchema: undefined,
stopSequences: undefined,
temperature: undefined,
topK: undefined,
Expand Down Expand Up @@ -321,46 +324,102 @@ describe('doGenerate', () => {
});
});

it('should set name & description in object-json mode', async () => {
it('should pass specification in object-json mode with structuredOutputs = true (default)', async () => {
const { model, mockVertexAI } = createModel({
modelId: 'test-model',
generateContent: prepareResponse({
parts: [{ text: '{"value":"Spark"}' }],
text: '{"property1":"value1","property2":"value2"}',
}),
});

const response = await model.doGenerate({
const result = await model.doGenerate({
inputFormat: 'prompt',
mode: {
type: 'object-json',
name: 'test-name',
description: 'test description',
schema: {
type: 'object',
properties: { value: { type: 'string' } },
required: ['value'],
properties: {
property1: { type: 'string' },
property2: { type: 'number' },
},
required: ['property1', 'property2'],
additionalProperties: false,
},
},
prompt: TEST_PROMPT,
});

expect(mockVertexAI.lastModelParams).toStrictEqual({
generationConfig: {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: 'application/json',
responseSchema: {
properties: {
property1: { type: 'string' },
property2: { type: 'number' },
},
required: ['property1', 'property2'],
type: 'object',
},
stopSequences: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
},
model: 'gemini-1.0-pro-002',
safetySettings: undefined,
});

expect(result.text).toStrictEqual(
'{"property1":"value1","property2":"value2"}',
);
});

it('should not pass specification in object-json mode with structuredOutputs = false', async () => {
const { model, mockVertexAI } = createModel({
generateContent: prepareResponse({
text: '{"property1":"value1","property2":"value2"}',
}),
settings: {
structuredOutputs: false,
},
});

const result = await model.doGenerate({
inputFormat: 'prompt',
mode: {
type: 'object-json',
schema: {
type: 'object',
properties: {
property1: { type: 'string' },
property2: { type: 'number' },
},
required: ['property1', 'property2'],
additionalProperties: false,
$schema: 'http://json-schema.org/draft-07/schema#',
},
},
prompt: TEST_PROMPT,
});

expect(mockVertexAI.lastModelParams).toStrictEqual({
model: 'test-model',
generationConfig: {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: 'application/json',
responseSchema: undefined,
stopSequences: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
},
model: 'gemini-1.0-pro-002',
safetySettings: undefined,
});

expect(response.text).toStrictEqual('{"value":"Spark"}');
expect(result.text).toStrictEqual(
'{"property1":"value1","property2":"value2"}',
);
});

it('should support object-tool mode', async () => {
Expand Down Expand Up @@ -403,6 +462,7 @@ describe('doGenerate', () => {
frequencyPenalty: undefined,
maxOutputTokens: undefined,
responseMimeType: undefined,
responseSchema: undefined,
temperature: undefined,
topK: undefined,
topP: undefined,
Expand Down
28 changes: 27 additions & 1 deletion packages/google-vertex/src/google-vertex-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
GenerateContentResponse,
GenerationConfig,
Part,
ResponseSchema,
SafetySetting,
Tool,
ToolConfig,
Expand All @@ -33,9 +34,13 @@ type GoogleVertexAIConfig = {
export class GoogleVertexLanguageModel implements LanguageModelV1 {
readonly specificationVersion = 'v1';
readonly provider = 'google-vertex';
readonly defaultObjectGenerationMode = 'tool';
readonly defaultObjectGenerationMode = 'json';
readonly supportsImageUrls = false;

get supportsObjectGeneration() {
return this.settings.structuredOutputs !== false;
}

readonly modelId: GoogleVertexModelId;
readonly settings: GoogleVertexSettings;

Expand Down Expand Up @@ -98,8 +103,20 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 {
temperature,
topP,
stopSequences,

// response format:
responseMimeType:
responseFormat?.type === 'json' ? 'application/json' : undefined,
responseSchema:
responseFormat?.type === 'json' &&
responseFormat.schema != null &&
// Google Vertex does not support all OpenAPI Schema features,
// so this is needed as an escape hatch:
this.supportsObjectGeneration
? (convertJSONSchemaToOpenAPISchema(
responseFormat.schema,
) as ResponseSchema)
: undefined,
};

const type = mode.type;
Expand Down Expand Up @@ -132,6 +149,15 @@ export class GoogleVertexLanguageModel implements LanguageModelV1 {
generationConfig: {
...generationConfig,
responseMimeType: 'application/json',
responseSchema:
mode.schema != null &&
// Google Vertex does not support all OpenAPI Schema features,
// so this is needed as an escape hatch:
this.supportsObjectGeneration
? (convertJSONSchemaToOpenAPISchema(
mode.schema,
) as ResponseSchema)
: undefined,
},
safetySettings: this.settings.safetySettings as
| undefined
Expand Down
10 changes: 10 additions & 0 deletions packages/google-vertex/src/google-vertex-settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ Models running with nucleus sampling don't allow topK setting.
*/
topK?: number;

/**
* Optional. Enable structured output. Default is true.
*
* This is useful when the JSON Schema contains elements that are
* not supported by the OpenAPI schema version that
* Google Generative AI uses. You can use this to disable
* structured outputs if you need to.
*/
structuredOutputs?: boolean;

/**
Optional. A list of unique safety settings for blocking unsafe content.
*/
Expand Down

0 comments on commit b8b93ff

Please sign in to comment.