Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(@ai-sdk/azure): Don't submit logprobs to azure if not specified #2025

Closed
wants to merge 8 commits into from
28 changes: 28 additions & 0 deletions examples/ai-core/src/stream-text/azure-fullstream-logprobs.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import { azure } from '@ai-sdk/azure';
import { streamText } from 'ai';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
const result = await streamText({
model: azure('gpt-4o', { logprobs: true }),
prompt: 'Invent a new holiday and describe its traditions.',
});

for await (const part of result.fullStream) {
switch (part.type) {
case 'text-delta': {
console.log('Text delta:', part.textDelta);
break;
}

case 'finish': {
console.log(`finishReason: ${part.finishReason}`);
console.log('Logprobs:', part.logprobs); // object: { string, number, array}
}
}
}
}

main().catch(console.error);
2 changes: 1 addition & 1 deletion packages/openai/src/openai-chat-language-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ describe('doGenerate', () => {
expect(await server.getRequestBodyJson()).toStrictEqual({
model: 'gpt-3.5-turbo',
messages: [{ role: 'user', content: 'Hello' }],
logprobs: true,
logprobs: 2,
top_logprobs: 2,
logit_bias: { 50256: -100 },
parallel_tool_calls: false,
Expand Down
31 changes: 29 additions & 2 deletions packages/openai/src/openai-chat-language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,13 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
// model specific settings:
logit_bias: this.settings.logitBias,
logprobs:
this.settings.logprobs === true ||
typeof this.settings.logprobs === 'number',
this.settings.logprobs === undefined
? false
: (this.settings.logprobs === this.settings.logprobs) === null ||
this.settings.logprobs === true ||
typeof this.settings.logprobs === 'number'
? this.settings.logprobs
: false,
top_logprobs:
typeof this.settings.logprobs === 'number'
? this.settings.logprobs
Expand Down Expand Up @@ -146,6 +151,17 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
): Promise<Awaited<ReturnType<LanguageModelV1['doGenerate']>>> {
const args = this.getArgs(options);

// For azure, only include logprobs if it's defined. (#2024):
if (
this.config.compatibility === 'compatible' &&
this.provider === 'azure-openai.chat' &&
(args.logprobs === undefined || args.logprobs === false)
) {
if ('logprobs' in args) {
delete (args as { logprobs?: boolean | number }).logprobs;
}
}

const { responseHeaders, value: response } = await postJsonToApi({
url: this.config.url({
path: '/chat/completions',
Expand Down Expand Up @@ -189,6 +205,17 @@ export class OpenAIChatLanguageModel implements LanguageModelV1 {
): Promise<Awaited<ReturnType<LanguageModelV1['doStream']>>> {
const args = this.getArgs(options);

// For azure, only include logprobs if it's defined. (#2024):
if (
this.config.compatibility === 'compatible' &&
this.provider === 'azure-openai.chat' &&
(args.logprobs === undefined || args.logprobs === false)
) {
if ('logprobs' in args) {
delete (args as { logprobs?: boolean | number }).logprobs;
}
}

const { responseHeaders, value: response } = await postJsonToApi({
url: this.config.url({
path: '/chat/completions',
Expand Down
Loading