diff --git a/agents/src/llm/llm.ts b/agents/src/llm/llm.ts
index 1d4fec5bf..162fca6e4 100644
--- a/agents/src/llm/llm.ts
+++ b/agents/src/llm/llm.ts
@@ -3,8 +3,9 @@
 // SPDX-License-Identifier: Apache-2.0
 import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter';
 import { EventEmitter } from 'node:events';
+import type { ReadableStream } from 'node:stream/web';
 import type { LLMMetrics } from '../metrics/base.js';
-import { AsyncIterableQueue } from '../utils.js';
+import { IdentityTransform } from '../stream/identity_transform.js';
 import type { ChatContext, ChatRole } from './chat_context.js';
 import type { FunctionCallInfo, FunctionContext } from './function_context.js';
 
@@ -59,8 +60,7 @@ export abstract class LLM extends (EventEmitter as new () => TypedEmitter<LLMCallbacks>) {
 export abstract class LLMStream implements AsyncIterableIterator<ChatChunk> {
-  protected output = new AsyncIterableQueue<ChatChunk>();
-  protected queue = new AsyncIterableQueue<ChatChunk>();
+  protected outputWriter: WritableStreamDefaultWriter<ChatChunk>;
   protected closed = false;
   protected _functionCalls: FunctionCallInfo[] = [];
 
   abstract label: string;
 
@@ -68,11 +68,21 @@ export abstract class LLMStream implements AsyncIterableIterator<ChatChunk> {
   #llm: LLM;
   #chatCtx: ChatContext;
   #fncCtx?: FunctionContext;
+  private output: IdentityTransform<ChatChunk>;
+  private outputReader: ReadableStreamDefaultReader<ChatChunk>;
+  private metricsStream: ReadableStream<ChatChunk>;
 
   constructor(llm: LLM, chatCtx: ChatContext, fncCtx?: FunctionContext) {
     this.#llm = llm;
     this.#chatCtx = chatCtx;
     this.#fncCtx = fncCtx;
+
+    this.output = new IdentityTransform<ChatChunk>();
+    this.outputWriter = this.output.writable.getWriter();
+    const [outputStream, metricsStream] = this.output.readable.tee();
+    this.outputReader = outputStream.getReader();
+    this.metricsStream = metricsStream;
+
     this.monitorMetrics();
   }
 
@@ -82,8 +92,11 @@ export abstract class LLMStream implements AsyncIterableIterator<ChatChunk> {
     let requestId = '';
     let usage: CompletionUsage | undefined;
 
-    for await (const ev of this.queue) {
-      this.output.put(ev);
+    const metricsReader = this.metricsStream.getReader();
+    while (true) {
+      const { done, value: ev } = await metricsReader.read();
+      if (done) break;
+
       requestId = ev.requestId;
       if (!ttft) {
         ttft = process.hrtime.bigint() - startTime;
@@ -92,7 +105,7 @@ export abstract class LLMStream implements AsyncIterableIterator<ChatChunk> {
         usage = ev.usage;
       }
     }
-    this.output.close();
+    metricsReader.releaseLock();
 
     const duration = process.hrtime.bigint() - startTime;
     const metrics: LLMMetrics = {
@@ -139,12 +152,18 @@ export abstract class LLMStream implements AsyncIterableIterator<ChatChunk> {
   }
 
   next(): Promise<IteratorResult<ChatChunk>> {
-    return this.output.next();
+    return this.outputReader.read().then(({ done, value }) => {
+      if (done) {
+        return { done: true, value: undefined };
+      }
+      return { done: false, value };
+    });
   }
 
   close() {
-    this.output.close();
-    this.queue.close();
+    if (!this.closed) {
+      this.outputWriter.close();
+    }
     this.closed = true;
   }
 
diff --git a/plugins/openai/src/llm.ts b/plugins/openai/src/llm.ts
index 87c16fde8..80a6a6310 100644
--- a/plugins/openai/src/llm.ts
+++ b/plugins/openai/src/llm.ts
@@ -469,12 +469,12 @@ export class LLMStream extends llm.LLMStream {
         for (const choice of chunk.choices) {
           const chatChunk = this.#parseChoice(chunk.id, choice);
           if (chatChunk) {
-            this.queue.put(chatChunk);
+            this.outputWriter.write(chatChunk);
           }
 
           if (chunk.usage) {
             const usage = chunk.usage;
-            this.queue.put({
+            this.outputWriter.write({
               requestId: chunk.id,
               choices: [],
               usage: {
@@ -487,7 +487,7 @@ export class LLMStream extends llm.LLMStream {
         }
       }
     } finally {
-      this.queue.close();
+      this.close();
     }
   }
 
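Note: the patch imports `IdentityTransform` from `agents/src/stream/identity_transform.ts` but does not show its definition. A minimal sketch of what such a class plausibly looks like, assuming it is nothing more than a typed pass-through `TransformStream` (the real implementation may differ in detail):

```ts
import { TransformStream } from 'node:stream/web';

// Hypothetical sketch of IdentityTransform: a TransformStream that
// re-enqueues every chunk unchanged, acting as a typed pipe with a
// `writable` end for producers and a `readable` end for consumers.
export class IdentityTransform<T> extends TransformStream<T, T> {
  constructor() {
    super({ transform: (chunk, controller) => controller.enqueue(chunk) });
  }
}
```

Under that assumption, the `output.writable.getWriter()` and `output.readable.tee()` calls in the new constructor are just the standard web-streams API surface.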
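The constructor change is the heart of the refactor: previously `monitorMetrics()` sat in the data path, pumping each chunk from `queue` into `output` and closing `output` when it finished, whereas now chunks are written once and `tee()` fans them out, leaving the metrics branch as a passive observer. A self-contained demo of that pattern with string chunks instead of `ChatChunk` (top-level await, runnable with a TypeScript runner such as `tsx`):

```ts
import { TransformStream } from 'node:stream/web';

// A TransformStream with no transformer is already an identity pipe.
const output = new TransformStream<string, string>();
const writer = output.writable.getWriter();

// tee() yields two independent branches; each receives every chunk.
const [consumerStream, metricsStream] = output.readable.tee();

const consumer = (async () => {
  const reader = consumerStream.getReader();
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    console.log('consumer saw:', value);
  }
})();

const monitor = (async () => {
  let chunks = 0;
  const reader = metricsStream.getReader();
  while (true) {
    const { done } = await reader.read();
    if (done) break;
    chunks += 1;
  }
  reader.releaseLock(); // mirrors metricsReader.releaseLock() in the diff
  console.log(`metrics saw ${chunks} chunks`);
})();

await writer.write('hello');
await writer.write('world');
await writer.close(); // both branches then observe done: true
await Promise.all([consumer, monitor]);
```

One caveat of `tee()`: the slower branch buffers chunks without limit, and backpressure reaches the writer only when both branches stall. That is acceptable here because the metrics loop does trivial per-chunk work and always drains its branch.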