Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions agents/src/tts/stream_adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ export class StreamAdapterWrapper extends SynthesizeStream {

async #run() {
const forwardInput = async () => {
for await (const input of this.input) {
while (true) {
const { done, value: input } = await this.inputReader.read();
if (done) break;
if (input === SynthesizeStream.FLUSH_SENTINEL) {
this.#sentenceStream.flush();
} else {
Expand All @@ -65,10 +67,10 @@ export class StreamAdapterWrapper extends SynthesizeStream {
const synthesize = async () => {
for await (const ev of this.#sentenceStream) {
for await (const audio of this.#tts.synthesize(ev.token)) {
this.output.put(audio);
this.outputWriter.write(audio);
}
}
this.output.put(SynthesizeStream.END_OF_STREAM);
this.outputWriter.write(SynthesizeStream.END_OF_STREAM);
};

Promise.all([forwardInput(), synthesize()]);
Expand Down
146 changes: 113 additions & 33 deletions agents/src/tts/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import type { AudioFrame } from '@livekit/rtc-node';
import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter';
import { EventEmitter } from 'node:events';
import type { ReadableStream } from 'node:stream/web';
import { log } from '../log.js';
import type { TTSMetrics } from '../metrics/base.js';
import { AsyncIterableQueue, mergeFrames } from '../utils.js';
import { DeferredReadableStream } from '../stream/deferred_stream.js';
import { IdentityTransform } from '../stream/identity_transform.js';
import { mergeFrames } from '../utils.js';

/** SynthesizedAudio is a packet of speech synthesis as returned by the TTS. */
export interface SynthesizedAudio {
Expand Down Expand Up @@ -105,22 +109,73 @@ export abstract class SynthesizeStream
{
protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL');
static readonly END_OF_STREAM = Symbol('END_OF_STREAM');
protected input = new AsyncIterableQueue<string | typeof SynthesizeStream.FLUSH_SENTINEL>();
protected queue = new AsyncIterableQueue<
protected inputReader: ReadableStreamDefaultReader<
string | typeof SynthesizeStream.FLUSH_SENTINEL
>;
protected outputWriter: WritableStreamDefaultWriter<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>();
protected output = new AsyncIterableQueue<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>();
>;
protected closed = false;
abstract label: string;
#tts: TTS;
#metricsPendingTexts: string[] = [];
#metricsText = '';
#monitorMetricsTask?: Promise<void>;

private deferredInputStream: DeferredReadableStream<
string | typeof SynthesizeStream.FLUSH_SENTINEL
>;
protected metricsStream: ReadableStream<SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM>;
private input = new IdentityTransform<string | typeof SynthesizeStream.FLUSH_SENTINEL>();
private output = new IdentityTransform<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>();
private inputWriter: WritableStreamDefaultWriter<string | typeof SynthesizeStream.FLUSH_SENTINEL>;
private outputReader: ReadableStreamDefaultReader<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>;
private logger = log();
private inputClosed = false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is duplicative. The Readable/WritableStreamDefaultWriter internally tracks whether it's closed but doesn't expose it. The only way to know is when you try to write to or close an already-closed writer, it throws an error.

The other option would be to

    try {
      this.inputWriter.write(SynthesizeStream.FLUSH_SENTINEL);
    } catch (error) {
      throw new Error('Input is closed');
    }

everywhere, which doesn't seem much better. Let me know what you thoughts are @lukasIO @toubatbrian


constructor(tts: TTS) {
this.#tts = tts;
this.deferredInputStream = new DeferredReadableStream();

this.inputWriter = this.input.writable.getWriter();
this.inputReader = this.input.readable.getReader();
this.outputWriter = this.output.writable.getWriter();

const [outputStream, metricsStream] = this.output.readable.tee();
this.outputReader = outputStream.getReader();
this.metricsStream = metricsStream;

this.pumpDeferredStream();
this.monitorMetrics();
}

/**
* Reads from the deferred input stream and forwards chunks to the input writer.
*
* Note: we can't just do this.deferredInputStream.stream.pipeTo(this.input.writable)
* because the inputWriter locks the this.input.writable stream. All writes must go through
* the inputWriter.
*/
private async pumpDeferredStream() {
const reader = this.deferredInputStream.stream.getReader();
try {
while (true) {
const { done, value } = await reader.read();
if (done || value === SynthesizeStream.FLUSH_SENTINEL) {
break;
}
this.inputWriter.write(value);
}
} catch (error) {
this.logger.error(error, 'Error reading deferred input stream');
} finally {
reader.releaseLock();
this.flush();
this.endInput();
}
}

protected async monitorMetrics() {
Expand Down Expand Up @@ -148,9 +203,11 @@ export abstract class SynthesizeStream
}
};

for await (const audio of this.queue) {
this.output.put(audio);
if (audio === SynthesizeStream.END_OF_STREAM) continue;
const metricsReader = this.metricsStream.getReader();

while (true) {
const { done, value: audio } = await metricsReader.read();
if (done || audio === SynthesizeStream.END_OF_STREAM) break;
requestId = audio.requestId;
if (!ttfb) {
ttfb = process.hrtime.bigint() - startTime;
Expand All @@ -164,23 +221,24 @@ export abstract class SynthesizeStream
if (requestId) {
emit();
}
this.output.close();
}

updateInputStream(text: ReadableStream<string>) {
this.deferredInputStream.setSource(text);
}

/** Push a string of text to the TTS */
/** @deprecated Use `updateInputStream` instead */
pushText(text: string) {
if (!this.#monitorMetricsTask) {
this.#monitorMetricsTask = this.monitorMetrics();
}
this.#metricsText += text;

if (this.input.closed) {
if (this.inputClosed) {
throw new Error('Input is closed');
}
if (this.closed) {
throw new Error('Stream is closed');
}
this.input.put(text);
this.inputWriter.write(text);
}

/** Flush the TTS, causing it to process all pending text */
Expand All @@ -189,34 +247,41 @@ export abstract class SynthesizeStream
this.#metricsPendingTexts.push(this.#metricsText);
this.#metricsText = '';
}
if (this.input.closed) {
if (this.inputClosed) {
throw new Error('Input is closed');
}
if (this.closed) {
throw new Error('Stream is closed');
}
this.input.put(SynthesizeStream.FLUSH_SENTINEL);
this.inputWriter.write(SynthesizeStream.FLUSH_SENTINEL);
}

/** Mark the input as ended and forbid additional pushes */
endInput() {
if (this.input.closed) {
if (this.inputClosed) {
throw new Error('Input is closed');
}
if (this.closed) {
throw new Error('Stream is closed');
}
this.input.close();
this.inputClosed = true;
this.inputWriter.close();
}

next(): Promise<IteratorResult<SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM>> {
return this.output.next();
return this.outputReader.read().then(({ done, value }) => {
if (done) {
return { done: true, value: undefined };
}
return { done: false, value };
});
}

/** Close both the input and output of the TTS stream */
close() {
this.input.close();
this.output.close();
if (!this.inputClosed) {
this.inputWriter.close();
}
this.closed = true;
}

Expand All @@ -240,17 +305,26 @@ export abstract class SynthesizeStream
* exports its own child ChunkedStream class, which inherits this class's methods.
*/
export abstract class ChunkedStream implements AsyncIterableIterator<SynthesizedAudio> {
protected queue = new AsyncIterableQueue<SynthesizedAudio>();
protected output = new AsyncIterableQueue<SynthesizedAudio>();
protected outputWriter: WritableStreamDefaultWriter<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>;
protected closed = false;
abstract label: string;
#text: string;
#tts: TTS;
private output = new IdentityTransform<SynthesizedAudio>();
private outputReader: ReadableStreamDefaultReader<SynthesizedAudio>;
private metricsStream: ReadableStream<SynthesizedAudio>;

constructor(text: string, tts: TTS) {
this.#text = text;
this.#tts = tts;

this.outputWriter = this.output.writable.getWriter();
const [outputStream, metricsStream] = this.output.readable.tee();
this.outputReader = outputStream.getReader();
this.metricsStream = metricsStream;

this.monitorMetrics();
}

Expand All @@ -260,15 +334,18 @@ export abstract class ChunkedStream implements AsyncIterableIterator<Synthesized
let ttfb: bigint | undefined;
let requestId = '';

for await (const audio of this.queue) {
this.output.put(audio);
const metricsReader = this.metricsStream.getReader();

while (true) {
const { done, value: audio } = await metricsReader.read();
if (done) break;

requestId = audio.requestId;
if (!ttfb) {
ttfb = process.hrtime.bigint() - startTime;
}
audioDuration += audio.frame.samplesPerChannel / audio.frame.sampleRate;
}
this.output.close();

const duration = process.hrtime.bigint() - startTime;
const metrics: TTSMetrics = {
Expand All @@ -294,14 +371,17 @@ export abstract class ChunkedStream implements AsyncIterableIterator<Synthesized
return mergeFrames(frames);
}

next(): Promise<IteratorResult<SynthesizedAudio>> {
return this.output.next();
async next(): Promise<IteratorResult<SynthesizedAudio>> {
const { done, value } = await this.outputReader.read();
if (done) {
return { done: true, value: undefined };
}
return { done: false, value };
}

/** Close both the input and output of the TTS stream */
close() {
this.queue.close();
this.output.close();
this.outputWriter.close();
this.closed = true;
}

Expand Down
17 changes: 10 additions & 7 deletions agents/src/vad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,22 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter<VADCal

export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL');
protected input = new IdentityTransform<AudioFrame | typeof VADStream.FLUSH_SENTINEL>();
protected output = new IdentityTransform<VADEvent>();
protected inputWriter: WritableStreamDefaultWriter<AudioFrame | typeof VADStream.FLUSH_SENTINEL>;

protected inputReader: ReadableStreamDefaultReader<AudioFrame | typeof VADStream.FLUSH_SENTINEL>;
protected outputWriter: WritableStreamDefaultWriter<VADEvent>;
protected outputReader: ReadableStreamDefaultReader<VADEvent>;
protected closed = false;
protected inputClosed = false;

#vad: VAD;
#lastActivityTime = BigInt(0);
private logger = log();
private deferredInputStream: DeferredReadableStream<AudioFrame>;

private input = new IdentityTransform<AudioFrame | typeof VADStream.FLUSH_SENTINEL>();
private output = new IdentityTransform<VADEvent>();
private metricsStream: ReadableStream<VADEvent>;
private outputReader: ReadableStreamDefaultReader<VADEvent>;
private inputWriter: WritableStreamDefaultWriter<AudioFrame | typeof VADStream.FLUSH_SENTINEL>;
Comment on lines +97 to +101
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making access modifiers more restrictive . Forgot to do it in #390


constructor(vad: VAD) {
this.#vad = vad;
this.deferredInputStream = new DeferredReadableStream<AudioFrame>();
Expand Down Expand Up @@ -207,7 +208,7 @@ export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
throw new Error('Stream is closed');
}
this.inputClosed = true;
this.input.writable.close();
this.inputWriter.close();
}

async next(): Promise<IteratorResult<VADEvent>> {
Expand All @@ -220,7 +221,9 @@ export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
}

close() {
this.input.writable.close();
if (!this.inputClosed) {
this.inputWriter.close();
}
Comment on lines +224 to +226
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to do in #390

this.closed = true;
}

Expand Down
14 changes: 8 additions & 6 deletions plugins/cartesia/src/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export class ChunkedStream extends tts.ChunkedStream {
(res) => {
res.on('data', (chunk) => {
for (const frame of bstream.write(chunk)) {
this.queue.put({
this.outputWriter.write({
requestId,
frame,
final: false,
Expand All @@ -117,14 +117,14 @@ export class ChunkedStream extends tts.ChunkedStream {
});
res.on('close', () => {
for (const frame of bstream.flush()) {
this.queue.put({
this.outputWriter.write({
requestId,
frame,
final: false,
segmentId: requestId,
});
}
this.queue.close();
this.outputWriter.close();
});
},
);
Expand Down Expand Up @@ -178,7 +178,9 @@ export class SynthesizeStream extends tts.SynthesizeStream {
};

const inputTask = async () => {
for await (const data of this.input) {
while (true) {
const { done, value: data } = await this.inputReader.read();
if (done) break;
if (data === SynthesizeStream.FLUSH_SENTINEL) {
this.#tokenizer.flush();
continue;
Expand All @@ -195,7 +197,7 @@ export class SynthesizeStream extends tts.SynthesizeStream {
let lastFrame: AudioFrame | undefined;
const sendLastFrame = (segmentId: string, final: boolean) => {
if (lastFrame) {
this.queue.put({ requestId, segmentId, frame: lastFrame, final });
this.outputWriter.write({ requestId, segmentId, frame: lastFrame, final });
lastFrame = undefined;
}
};
Expand All @@ -215,7 +217,7 @@ export class SynthesizeStream extends tts.SynthesizeStream {
lastFrame = frame;
}
sendLastFrame(segmentId, true);
this.queue.put(SynthesizeStream.END_OF_STREAM);
this.outputWriter.write(SynthesizeStream.END_OF_STREAM);

if (segmentId === requestId) {
closing = true;
Expand Down
Loading
Loading