Skip to content

Commit

Permalink
feat (ai/core): add experimental transform option to streamText (#4074)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Dec 12, 2024
1 parent 3ce210f commit cc16a83
Show file tree
Hide file tree
Showing 16 changed files with 391 additions and 14 deletions.
5 changes: 5 additions & 0 deletions .changeset/giant-ducks-live.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): add smoothStream helper
5 changes: 5 additions & 0 deletions .changeset/kind-cougars-check.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

feat (ai/core): add experimental transform option to streamText
18 changes: 18 additions & 0 deletions content/docs/03-ai-sdk-core/05-generating-text.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,24 @@ for await (const part of result.fullStream) {
}
```

### Stream transformation

You can use the `experimental_transform` option to transform the stream.
This is useful for e.g. filtering, changing, or smoothing the text stream.

The AI SDK Core provides a [`smoothStream` function](/docs/reference/ai-sdk-core/smooth-stream) that
can be used to smooth out text streaming.

```tsx highlight="6"
import { smoothStream, streamText } from 'ai';

const result = streamText({
model,
prompt,
experimental_transform: smoothStream(),
});
```

## Generating Long Text

Most language models have an output limit that is much shorter than their context window.
Expand Down
6 changes: 6 additions & 0 deletions content/docs/07-reference/01-ai-sdk-core/02-stream-text.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,12 @@ To see `streamText` in action, check out [these examples](#examples).
description:
'Enable streaming of tool call deltas as they are generated. Disabled by default.',
},
{
name: 'experimental_transform',
type: 'TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>',
isOptional: true,
description: 'Optional transformation that is applied to the stream.',
},
{
name: 'experimental_providerMetadata',
type: 'Record<string,Record<string,JSONValue>> | undefined',
Expand Down
52 changes: 52 additions & 0 deletions content/docs/07-reference/01-ai-sdk-core/80-smooth-stream.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
---
title: smoothStream
description: Helper function for smoothing text streaming output
---

# `smoothStream()`

`smoothStream` is a utility function that creates a TransformStream
for the `streamText` `transform` option
to smooth out text streaming by buffering and releasing complete words with configurable delays.
This creates a more natural reading experience when streaming text responses.

```ts highlight={"6-8"}
import { smoothStream, streamText } from 'ai';

const result = streamText({
model,
prompt,
experimental_transform: smoothStream({
delayInMs: 40, // optional: defaults to 40ms
}),
});
```

## Import

<Snippet text={`import { smoothStream } from "ai"`} prompt={false} />

## API Signature

### Parameters

<PropertiesTable
content={[
{
name: 'delayInMs',
type: 'number',
isOptional: true,
description:
'The delay in milliseconds between outputting each word. Defaults to 40ms. Set to 0 to disable delays.',
},
]}
/>

### Returns

Returns a `TransformStream` that:

- Buffers incoming text chunks
- Releases complete words when whitespace is encountered
- Adds configurable delays between words for smooth output
- Passes through non-text chunks (like step-finish events) immediately
5 changes: 5 additions & 0 deletions content/docs/07-reference/01-ai-sdk-core/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -81,5 +81,10 @@ It also contains the following helper functions:
'Calculates the cosine similarity between two vectors, e.g. embeddings.',
href: '/docs/reference/ai-sdk-core/cosine-similarity',
},
{
title: 'smoothStream()',
description: 'Smooths text streaming output.',
href: '/docs/reference/ai-sdk-core/smooth-stream',
},
]}
/>
21 changes: 21 additions & 0 deletions examples/ai-core/src/stream-text/anthropic-smooth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { anthropic } from '@ai-sdk/anthropic';
import { smoothStream, streamText } from 'ai';
import 'dotenv/config';

async function main() {
const result = streamText({
model: anthropic('claude-3-5-sonnet-20240620'),
prompt: 'Invent a new holiday and describe its traditions.',
experimental_transform: smoothStream(),
});

for await (const textPart of result.textStream) {
process.stdout.write(textPart);
}

console.log();
console.log('Token usage:', await result.usage);
console.log('Finish reason:', await result.finishReason);
}

main().catch(console.error);
21 changes: 21 additions & 0 deletions examples/ai-core/src/stream-text/azure-smooth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import { azure } from '@ai-sdk/azure';
import { smoothStream, streamText } from 'ai';
import 'dotenv/config';

async function main() {
const result = streamText({
model: azure('gpt-4o'), // use your own deployment
prompt: 'Invent a new holiday and describe its traditions.',
experimental_transform: smoothStream(),
});

for await (const textPart of result.textStream) {
process.stdout.write(textPart);
}

console.log();
console.log('Token usage:', await result.usage);
console.log('Finish reason:', await result.finishReason);
}

main().catch(console.error);
5 changes: 2 additions & 3 deletions packages/ai/core/data-stream/create-data-stream.test.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import { convertReadableStreamToArray } from '@ai-sdk/provider-utils/test';
import { formatDataStreamPart } from '@ai-sdk/ui-utils';
import { expect, it } from 'vitest';
import { createDataStream } from './create-data-stream';
import { DataStreamWriter } from './data-stream-writer';
import { delay } from '../../util/delay';
import { createResolvablePromise } from '../../util/create-resolvable-promise';
import { DelayedPromise } from '../../util/delayed-promise';
import { createDataStream } from './create-data-stream';
import { DataStreamWriter } from './data-stream-writer';

describe('createDataStream', () => {
it('should send single data json and close the stream', async () => {
Expand Down
1 change: 1 addition & 0 deletions packages/ai/core/generate-text/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export type { StepResult } from './step-result';
export { streamText } from './stream-text';
export type { StreamTextResult, TextStreamPart } from './stream-text-result';
export type { ToolCallRepairFunction } from './tool-call-repair';
export { smoothStream } from './smooth-stream';

// TODO 4.1: rename to ToolCall and ToolResult, deprecate old names
export type {
Expand Down
2 changes: 0 additions & 2 deletions packages/ai/core/generate-text/parse-tool-call.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,5 @@ describe('tool call repair', () => {
originalError: expect.any(InvalidToolArgumentsError),
});
expect(repairToolCall).toHaveBeenCalledTimes(1);

console.log('xxx');
});
});
127 changes: 127 additions & 0 deletions packages/ai/core/generate-text/smooth-stream.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import { describe, expect, it } from 'vitest';
import { convertArrayToReadableStream } from '../../test';
import { smoothStream } from './smooth-stream';

describe('smoothStream', () => {
it('should combine partial words', async () => {
const events: any[] = [];

const stream = convertArrayToReadableStream([
{ textDelta: 'Hello', type: 'text-delta' },
{ textDelta: ', ', type: 'text-delta' },
{ textDelta: 'world!', type: 'text-delta' },
{ type: 'step-finish' },
{ type: 'finish' },
]).pipeThrough(
smoothStream({
delayInMs: 10,
_internal: {
delay: () => {
events.push('delay');
return Promise.resolve();
},
},
}),
);

// Get a reader and read chunks
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) break;
events.push(value);
}

expect(events).toEqual([
'delay',
{
textDelta: 'Hello, ',
type: 'text-delta',
},
{
textDelta: 'world!',
type: 'text-delta',
},
{
type: 'step-finish',
},
{
type: 'finish',
},
]);
});

it('should split larger text chunks', async () => {
const events: any[] = [];

const stream = convertArrayToReadableStream([
{
textDelta: 'Hello, World! This is an example text.',
type: 'text-delta',
},
{ type: 'step-finish' },
{ type: 'finish' },
]).pipeThrough(
smoothStream({
delayInMs: 10,
_internal: {
delay: () => {
events.push('delay');
return Promise.resolve();
},
},
}),
);

// Get a reader and read chunks
const reader = stream.getReader();
while (true) {
const { done, value } = await reader.read();
if (done) break;
events.push(value);
}

expect(events).toEqual([
'delay',
{
textDelta: 'Hello, ',
type: 'text-delta',
},
'delay',
{
textDelta: 'World! ',
type: 'text-delta',
},
'delay',
{
textDelta: 'This ',
type: 'text-delta',
},
'delay',
{
textDelta: 'is ',
type: 'text-delta',
},
'delay',
{
textDelta: 'an ',
type: 'text-delta',
},
'delay',
{
textDelta: 'example ',
type: 'text-delta',
},
{
textDelta: 'text.',
type: 'text-delta',
},
{
type: 'step-finish',
},
{
type: 'finish',
},
]);
});
});
52 changes: 52 additions & 0 deletions packages/ai/core/generate-text/smooth-stream.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import { delay as originalDelay } from '../../util/delay';
import { CoreTool } from '../tool/tool';
import { TextStreamPart } from './stream-text-result';

export function smoothStream<TOOLS extends Record<string, CoreTool>>({
delayInMs = 40,
_internal: { delay = originalDelay } = {},
}: {
delayInMs?: number;

/**
* Internal. For test use only. May change without notice.
*/
_internal?: {
delay?: (delayInMs: number) => Promise<void>;
};
} = {}): TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>> {
let buffer = '';

return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
async transform(chunk, controller) {
if (chunk.type === 'step-finish') {
if (buffer.length > 0) {
controller.enqueue({ type: 'text-delta', textDelta: buffer });
buffer = '';
}

controller.enqueue(chunk);
return;
}

if (chunk.type !== 'text-delta') {
controller.enqueue(chunk);
return;
}

buffer += chunk.textDelta;

// Stream out complete words when whitespace is found
while (buffer.match(/\s/)) {
const whitespaceIndex = buffer.search(/\s/);
const word = buffer.slice(0, whitespaceIndex + 1);
controller.enqueue({ type: 'text-delta', textDelta: word });
buffer = buffer.slice(whitespaceIndex + 1);

if (delayInMs > 0) {
await delay(delayInMs);
}
}
},
});
}
Loading

0 comments on commit cc16a83

Please sign in to comment.