From cc16a830c65073d60d5403835c9a8d762589f066 Mon Sep 17 00:00:00 2001 From: Lars Grammel Date: Thu, 12 Dec 2024 15:43:49 +0100 Subject: [PATCH] feat (ai/core): add experimental transform option to streamText (#4074) --- .changeset/giant-ducks-live.md | 5 + .changeset/kind-cougars-check.md | 5 + .../03-ai-sdk-core/05-generating-text.mdx | 18 +++ .../01-ai-sdk-core/02-stream-text.mdx | 6 + .../01-ai-sdk-core/80-smooth-stream.mdx | 52 +++++++ .../07-reference/01-ai-sdk-core/index.mdx | 5 + .../src/stream-text/anthropic-smooth.ts | 21 +++ .../ai-core/src/stream-text/azure-smooth.ts | 21 +++ .../data-stream/create-data-stream.test.ts | 5 +- packages/ai/core/generate-text/index.ts | 1 + .../generate-text/parse-tool-call.test.ts | 2 - .../core/generate-text/smooth-stream.test.ts | 127 ++++++++++++++++++ .../ai/core/generate-text/smooth-stream.ts | 52 +++++++ .../ai/core/generate-text/stream-text.test.ts | 36 +++++ packages/ai/core/generate-text/stream-text.ts | 43 ++++-- .../ai/core/util/create-stitchable-stream.ts | 6 +- 16 files changed, 391 insertions(+), 14 deletions(-) create mode 100644 .changeset/giant-ducks-live.md create mode 100644 .changeset/kind-cougars-check.md create mode 100644 content/docs/07-reference/01-ai-sdk-core/80-smooth-stream.mdx create mode 100644 examples/ai-core/src/stream-text/anthropic-smooth.ts create mode 100644 examples/ai-core/src/stream-text/azure-smooth.ts create mode 100644 packages/ai/core/generate-text/smooth-stream.test.ts create mode 100644 packages/ai/core/generate-text/smooth-stream.ts diff --git a/.changeset/giant-ducks-live.md b/.changeset/giant-ducks-live.md new file mode 100644 index 000000000000..bb2c596ef8cb --- /dev/null +++ b/.changeset/giant-ducks-live.md @@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +feat (ai/core): add smoothStream helper diff --git a/.changeset/kind-cougars-check.md b/.changeset/kind-cougars-check.md new file mode 100644 index 000000000000..01174adfe307 --- /dev/null +++ b/.changeset/kind-cougars-check.md 
@@ -0,0 +1,5 @@ +--- +'ai': patch +--- + +feat (ai/core): add experimental transform option to streamText diff --git a/content/docs/03-ai-sdk-core/05-generating-text.mdx b/content/docs/03-ai-sdk-core/05-generating-text.mdx index d128aa045921..294c100bee1c 100644 --- a/content/docs/03-ai-sdk-core/05-generating-text.mdx +++ b/content/docs/03-ai-sdk-core/05-generating-text.mdx @@ -200,6 +200,24 @@ for await (const part of result.fullStream) { } ``` +### Stream transformation + +You can use the `experimental_transform` option to transform the stream. +This is useful for e.g. filtering, changing, or smoothing the text stream. + +The AI SDK Core provides a [`smoothStream` function](/docs/reference/ai-sdk-core/smooth-stream) that +can be used to smooth out text streaming. + +```tsx highlight="6" +import { smoothStream, streamText } from 'ai'; + +const result = streamText({ + model, + prompt, + experimental_transform: smoothStream(), +}); +``` + ## Generating Long Text Most language models have an output limit that is much shorter than their context window. diff --git a/content/docs/07-reference/01-ai-sdk-core/02-stream-text.mdx b/content/docs/07-reference/01-ai-sdk-core/02-stream-text.mdx index e4274899578d..3a35592f3b36 100644 --- a/content/docs/07-reference/01-ai-sdk-core/02-stream-text.mdx +++ b/content/docs/07-reference/01-ai-sdk-core/02-stream-text.mdx @@ -469,6 +469,12 @@ To see `streamText` in action, check out [these examples](#examples). description: 'Enable streaming of tool call deltas as they are generated. 
Disabled by default.', }, + { + name: 'experimental_transform', + type: 'TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>', + isOptional: true, + description: 'Optional transformation that is applied to the stream.', + }, { name: 'experimental_providerMetadata', type: 'Record> | undefined', diff --git a/content/docs/07-reference/01-ai-sdk-core/80-smooth-stream.mdx b/content/docs/07-reference/01-ai-sdk-core/80-smooth-stream.mdx new file mode 100644 index 000000000000..e0faac1a3085 --- /dev/null +++ b/content/docs/07-reference/01-ai-sdk-core/80-smooth-stream.mdx @@ -0,0 +1,52 @@ +--- +title: smoothStream +description: Helper function for smoothing text streaming output +--- + +# `smoothStream()` + +`smoothStream` is a utility function that creates a TransformStream +for the `streamText` `experimental_transform` option +to smooth out text streaming by buffering and releasing complete words with configurable delays. +This creates a more natural reading experience when streaming text responses. + +```ts highlight={"6-8"} +import { smoothStream, streamText } from 'ai'; + +const result = streamText({ + model, + prompt, + experimental_transform: smoothStream({ + delayInMs: 40, // optional: defaults to 40ms + }), +}); +``` + +## Import + + + +## API Signature + +### Parameters + + + +### Returns + +Returns a `TransformStream` that: + +- Buffers incoming text chunks +- Releases complete words when whitespace is encountered +- Adds configurable delays between words for smooth output +- Passes through non-text chunks immediately, flushing any buffered text before a step-finish event diff --git a/content/docs/07-reference/01-ai-sdk-core/index.mdx b/content/docs/07-reference/01-ai-sdk-core/index.mdx index cfa4098a71e6..51b276399ed7 100644 --- a/content/docs/07-reference/01-ai-sdk-core/index.mdx +++ b/content/docs/07-reference/01-ai-sdk-core/index.mdx @@ -81,5 +81,10 @@ It also contains the following helper functions: 'Calculates the cosine similarity between two vectors, e.g. 
embeddings.', href: '/docs/reference/ai-sdk-core/cosine-similarity', }, + { + title: 'smoothStream()', + description: 'Smooths text streaming output.', + href: '/docs/reference/ai-sdk-core/smooth-stream', + }, ]} /> diff --git a/examples/ai-core/src/stream-text/anthropic-smooth.ts b/examples/ai-core/src/stream-text/anthropic-smooth.ts new file mode 100644 index 000000000000..45d4c297cd88 --- /dev/null +++ b/examples/ai-core/src/stream-text/anthropic-smooth.ts @@ -0,0 +1,21 @@ +import { anthropic } from '@ai-sdk/anthropic'; +import { smoothStream, streamText } from 'ai'; +import 'dotenv/config'; + +async function main() { + const result = streamText({ + model: anthropic('claude-3-5-sonnet-20240620'), + prompt: 'Invent a new holiday and describe its traditions.', + experimental_transform: smoothStream(), + }); + + for await (const textPart of result.textStream) { + process.stdout.write(textPart); + } + + console.log(); + console.log('Token usage:', await result.usage); + console.log('Finish reason:', await result.finishReason); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/stream-text/azure-smooth.ts b/examples/ai-core/src/stream-text/azure-smooth.ts new file mode 100644 index 000000000000..12bb1570bd02 --- /dev/null +++ b/examples/ai-core/src/stream-text/azure-smooth.ts @@ -0,0 +1,21 @@ +import { azure } from '@ai-sdk/azure'; +import { smoothStream, streamText } from 'ai'; +import 'dotenv/config'; + +async function main() { + const result = streamText({ + model: azure('gpt-4o'), // use your own deployment + prompt: 'Invent a new holiday and describe its traditions.', + experimental_transform: smoothStream(), + }); + + for await (const textPart of result.textStream) { + process.stdout.write(textPart); + } + + console.log(); + console.log('Token usage:', await result.usage); + console.log('Finish reason:', await result.finishReason); +} + +main().catch(console.error); diff --git a/packages/ai/core/data-stream/create-data-stream.test.ts 
b/packages/ai/core/data-stream/create-data-stream.test.ts index c5ac18167f25..ea1421df1eb2 100644 --- a/packages/ai/core/data-stream/create-data-stream.test.ts +++ b/packages/ai/core/data-stream/create-data-stream.test.ts @@ -1,11 +1,10 @@ import { convertReadableStreamToArray } from '@ai-sdk/provider-utils/test'; import { formatDataStreamPart } from '@ai-sdk/ui-utils'; import { expect, it } from 'vitest'; -import { createDataStream } from './create-data-stream'; -import { DataStreamWriter } from './data-stream-writer'; import { delay } from '../../util/delay'; -import { createResolvablePromise } from '../../util/create-resolvable-promise'; import { DelayedPromise } from '../../util/delayed-promise'; +import { createDataStream } from './create-data-stream'; +import { DataStreamWriter } from './data-stream-writer'; describe('createDataStream', () => { it('should send single data json and close the stream', async () => { diff --git a/packages/ai/core/generate-text/index.ts b/packages/ai/core/generate-text/index.ts index 22646f4768f6..25bf7c854bbe 100644 --- a/packages/ai/core/generate-text/index.ts +++ b/packages/ai/core/generate-text/index.ts @@ -5,6 +5,7 @@ export type { StepResult } from './step-result'; export { streamText } from './stream-text'; export type { StreamTextResult, TextStreamPart } from './stream-text-result'; export type { ToolCallRepairFunction } from './tool-call-repair'; +export { smoothStream } from './smooth-stream'; // TODO 4.1: rename to ToolCall and ToolResult, deprecate old names export type { diff --git a/packages/ai/core/generate-text/parse-tool-call.test.ts b/packages/ai/core/generate-text/parse-tool-call.test.ts index 5461bb1bac02..a57b8c198cce 100644 --- a/packages/ai/core/generate-text/parse-tool-call.test.ts +++ b/packages/ai/core/generate-text/parse-tool-call.test.ts @@ -234,7 +234,5 @@ describe('tool call repair', () => { originalError: expect.any(InvalidToolArgumentsError), }); expect(repairToolCall).toHaveBeenCalledTimes(1); - - 
console.log('xxx'); }); }); diff --git a/packages/ai/core/generate-text/smooth-stream.test.ts b/packages/ai/core/generate-text/smooth-stream.test.ts new file mode 100644 index 000000000000..55b9bb93adc9 --- /dev/null +++ b/packages/ai/core/generate-text/smooth-stream.test.ts @@ -0,0 +1,127 @@ +import { describe, expect, it } from 'vitest'; +import { convertArrayToReadableStream } from '../../test'; +import { smoothStream } from './smooth-stream'; + +describe('smoothStream', () => { + it('should combine partial words', async () => { + const events: any[] = []; + + const stream = convertArrayToReadableStream([ + { textDelta: 'Hello', type: 'text-delta' }, + { textDelta: ', ', type: 'text-delta' }, + { textDelta: 'world!', type: 'text-delta' }, + { type: 'step-finish' }, + { type: 'finish' }, + ]).pipeThrough( + smoothStream({ + delayInMs: 10, + _internal: { + delay: () => { + events.push('delay'); + return Promise.resolve(); + }, + }, + }), + ); + + // Get a reader and read chunks + const reader = stream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) break; + events.push(value); + } + + expect(events).toEqual([ + 'delay', + { + textDelta: 'Hello, ', + type: 'text-delta', + }, + { + textDelta: 'world!', + type: 'text-delta', + }, + { + type: 'step-finish', + }, + { + type: 'finish', + }, + ]); + }); + + it('should split larger text chunks', async () => { + const events: any[] = []; + + const stream = convertArrayToReadableStream([ + { + textDelta: 'Hello, World! 
This is an example text.', + type: 'text-delta', + }, + { type: 'step-finish' }, + { type: 'finish' }, + ]).pipeThrough( + smoothStream({ + delayInMs: 10, + _internal: { + delay: () => { + events.push('delay'); + return Promise.resolve(); + }, + }, + }), + ); + + // Get a reader and read chunks + const reader = stream.getReader(); + while (true) { + const { done, value } = await reader.read(); + if (done) break; + events.push(value); + } + + expect(events).toEqual([ + 'delay', + { + textDelta: 'Hello, ', + type: 'text-delta', + }, + 'delay', + { + textDelta: 'World! ', + type: 'text-delta', + }, + 'delay', + { + textDelta: 'This ', + type: 'text-delta', + }, + 'delay', + { + textDelta: 'is ', + type: 'text-delta', + }, + 'delay', + { + textDelta: 'an ', + type: 'text-delta', + }, + 'delay', + { + textDelta: 'example ', + type: 'text-delta', + }, + { + textDelta: 'text.', + type: 'text-delta', + }, + { + type: 'step-finish', + }, + { + type: 'finish', + }, + ]); + }); +}); diff --git a/packages/ai/core/generate-text/smooth-stream.ts b/packages/ai/core/generate-text/smooth-stream.ts new file mode 100644 index 000000000000..5617221a4e62 --- /dev/null +++ b/packages/ai/core/generate-text/smooth-stream.ts @@ -0,0 +1,52 @@ +import { delay as originalDelay } from '../../util/delay'; +import { CoreTool } from '../tool/tool'; +import { TextStreamPart } from './stream-text-result'; + +export function smoothStream>({ + delayInMs = 40, + _internal: { delay = originalDelay } = {}, +}: { + delayInMs?: number; + + /** + * Internal. For test use only. May change without notice. 
+ */ + _internal?: { + delay?: (delayInMs: number) => Promise; + }; +} = {}): TransformStream, TextStreamPart> { + let buffer = ''; + + return new TransformStream, TextStreamPart>({ + async transform(chunk, controller) { + if (chunk.type === 'step-finish') { + if (buffer.length > 0) { + controller.enqueue({ type: 'text-delta', textDelta: buffer }); + buffer = ''; + } + + controller.enqueue(chunk); + return; + } + + if (chunk.type !== 'text-delta') { + controller.enqueue(chunk); + return; + } + + buffer += chunk.textDelta; + + // Stream out complete words when whitespace is found + while (buffer.match(/\s/)) { + const whitespaceIndex = buffer.search(/\s/); + const word = buffer.slice(0, whitespaceIndex + 1); + controller.enqueue({ type: 'text-delta', textDelta: word }); + buffer = buffer.slice(whitespaceIndex + 1); + + if (delayInMs > 0) { + await delay(delayInMs); + } + } + }, + }); +} diff --git a/packages/ai/core/generate-text/stream-text.test.ts b/packages/ai/core/generate-text/stream-text.test.ts index 942b47ac1fb8..95d7fc5c3f8d 100644 --- a/packages/ai/core/generate-text/stream-text.test.ts +++ b/packages/ai/core/generate-text/stream-text.test.ts @@ -3257,4 +3257,40 @@ describe('streamText', () => { ]); }); }); + + describe('options.transform', () => { + it('should transform the stream', async () => { + const result = streamText({ + model: new MockLanguageModelV1({ + doStream: async () => ({ + stream: convertArrayToReadableStream([ + { type: 'text-delta', textDelta: 'Hello' }, + { type: 'text-delta', textDelta: ', ' }, + { type: 'text-delta', textDelta: `world!` }, + { + type: 'finish', + finishReason: 'stop', + logprobs: undefined, + usage: { completionTokens: 10, promptTokens: 3 }, + }, + ]), + rawCall: { rawPrompt: 'prompt', rawSettings: {} }, + }), + }), + experimental_transform: new TransformStream({ + transform(chunk, controller) { + if (chunk.type === 'text-delta') { + chunk.textDelta = chunk.textDelta.toUpperCase(); + } + controller.enqueue(chunk); + 
}, + }), + prompt: 'test-input', + }); + + expect( + await convertAsyncIterableToArray(result.textStream), + ).toStrictEqual(['HELLO', ', ', 'WORLD!']); + }); + }); }); diff --git a/packages/ai/core/generate-text/stream-text.ts b/packages/ai/core/generate-text/stream-text.ts index 6c462e98ead1..54e9e8a38456 100644 --- a/packages/ai/core/generate-text/stream-text.ts +++ b/packages/ai/core/generate-text/stream-text.ts @@ -116,6 +116,7 @@ export function streamText>({ experimental_toolCallStreaming: toolCallStreaming = false, experimental_activeTools: activeTools, experimental_repairToolCall: repairToolCall, + experimental_transform: transform, onChunk, onFinish, onStepFinish, @@ -186,6 +187,14 @@ Enable streaming of tool call deltas as they are generated. Disabled by default. */ experimental_toolCallStreaming?: boolean; + /** +Optional transformation that is applied to the stream. + */ + experimental_transform?: TransformStream< + TextStreamPart, + TextStreamPart + >; + /** Callback that is called for each chunk of the stream. The stream processing will pause until the callback promise is resolved. */ @@ -245,6 +254,7 @@ Details for all steps. 
tools, toolChoice, toolCallStreaming, + transform, activeTools, repairToolCall, maxSteps, @@ -293,8 +303,13 @@ class DefaultStreamTextResult> Awaited['steps']> >(); - private readonly stitchableStream = - createStitchableStream>(); + private readonly addStream: ( + stream: ReadableStream>, + ) => void; + + private readonly closeStream: () => void; + + private baseStream: ReadableStream>; constructor({ model, @@ -309,6 +324,7 @@ class DefaultStreamTextResult> tools, toolChoice, toolCallStreaming, + transform, activeTools, repairToolCall, maxSteps, @@ -333,6 +349,9 @@ class DefaultStreamTextResult> tools: TOOLS | undefined; toolChoice: CoreToolChoice | undefined; toolCallStreaming: boolean; + transform: + | TransformStream, TextStreamPart> + | undefined; activeTools: Array | undefined; repairToolCall: ToolCallRepairFunction | undefined; maxSteps: number; @@ -375,6 +394,14 @@ class DefaultStreamTextResult> }); } + // initialize the stitchable stream and the transformed stream: + const stitchableStream = createStitchableStream>(); + this.addStream = stitchableStream.addStream; + this.closeStream = stitchableStream.close; + this.baseStream = transform + ? 
stitchableStream.stream.pipeThrough(transform) + : stitchableStream.stream; + const { maxRetries, retry } = prepareRetries({ maxRetries: maxRetriesArg, }); @@ -571,7 +598,7 @@ class DefaultStreamTextResult> await onChunk?.({ chunk }); } - self.stitchableStream.addStream( + self.addStream( transformedStream.pipeThrough( new TransformStream< SingleRequestTextStreamPart, @@ -892,7 +919,7 @@ class DefaultStreamTextResult> }); // close the stitchable stream - self.stitchableStream.close(); + self.closeStream(); // Add response information to the root span: rootSpan.setAttributes( @@ -977,7 +1004,7 @@ class DefaultStreamTextResult> }, }).catch(error => { // add an error stream part and close the streams: - self.stitchableStream.addStream( + self.addStream( new ReadableStream({ start(controller) { controller.enqueue({ type: 'error', error }); @@ -985,7 +1012,7 @@ class DefaultStreamTextResult> }, }), ); - self.stitchableStream.close(); + self.closeStream(); }); } @@ -1038,8 +1065,8 @@ Note: this leads to buffering the stream content on the server. However, the LLM results are expected to be small enough to not cause issues. */ private teeStream() { - const [stream1, stream2] = this.stitchableStream.stream.tee(); - this.stitchableStream.stream = stream2; + const [stream1, stream2] = this.baseStream.tee(); + this.baseStream = stream2; return stream1; } diff --git a/packages/ai/core/util/create-stitchable-stream.ts b/packages/ai/core/util/create-stitchable-stream.ts index 0f0182395676..9f470a81d4f6 100644 --- a/packages/ai/core/util/create-stitchable-stream.ts +++ b/packages/ai/core/util/create-stitchable-stream.ts @@ -6,7 +6,11 @@ import { createResolvablePromise } from '../../util/create-resolvable-promise'; * @template T - The type of values emitted by the streams. * @returns {Object} An object containing the stitchable stream and control methods. 
*/ -export function createStitchableStream() { +export function createStitchableStream(): { + stream: ReadableStream; + addStream: (innerStream: ReadableStream) => void; + close: () => void; +} { let innerStreamReaders: ReadableStreamDefaultReader[] = []; let controller: ReadableStreamDefaultController | null = null; let isClosed = false;