Expose the tokenizer to clients #622

Open
wants to merge 1 commit into base: main
src/engine.ts: 20 changes (17 additions, 3 deletions)
@@ -70,6 +70,7 @@ import {
 } from "./error";
 import { asyncLoadTokenizer } from "./cache_util";
 import { EmbeddingPipeline } from "./embedding";
+import { Tokenizer } from "@mlc-ai/web-tokenizers";
 
 /**
  * Creates `MLCEngine`, and loads `modelId` onto WebGPU.
@@ -131,6 +132,7 @@ export class MLCEngine implements MLCEngineInterface {
   private logitProcessorRegistry?: Map<string, LogitProcessor>;
   private initProgressCallback?: InitProgressCallback;
   private appConfig: AppConfig;
+  private tokenizer: Tokenizer | null = null;
 
   // Signals and flags
   private interruptSignal = false;
@@ -359,7 +361,7 @@ export class MLCEngine implements MLCEngineInterface {
     });
     tvm.initWebGPU(gpuDetectOutput.device);
 
-    const tokenizer = await asyncLoadTokenizer(
+    this.tokenizer = await asyncLoadTokenizer(
       modelUrl,
       curModelConfig,
       this.appConfig,
@@ -379,11 +381,11 @@ export class MLCEngine implements MLCEngineInterface {
     // embedding model, and prompt user to use ModelRecord.model_type
     let newPipeline: LLMChatPipeline | EmbeddingPipeline;
     if (modelRecord.model_type === ModelType.embedding) {
-      newPipeline = new EmbeddingPipeline(tvm, tokenizer, curModelConfig);
+      newPipeline = new EmbeddingPipeline(tvm, this.tokenizer, curModelConfig);
     } else {
       newPipeline = new LLMChatPipeline(
         tvm,
-        tokenizer,
+        this.tokenizer,
         curModelConfig,
         logitProcessor,
       );
@@ -1333,4 +1335,16 @@ export class MLCEngine implements MLCEngineInterface {
   async decode(pipeline: LLMChatPipeline, genConfig?: GenerationConfig) {
     return pipeline.decodeStep(genConfig);
   }
+
+  //-----------------------------------------------
+  // 8. Expose tokenizer
+  //-----------------------------------------------
+
+  async tokenize(text: string) {
+    return this.tokenizer!.encode(text);
+  }
+
+  async decodeTokens(ids: Int32Array) {
+    return this.tokenizer!.decode(ids);
+  }
 }
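
With the tokenizer held on the engine, a client can encode and decode text through the same MLCEngine that already loaded the model, instead of fetching the tokenizer a second time. Below is a minimal usage sketch, not part of the diff: it assumes this PR's tokenize/decodeTokens methods, uses web-llm's existing CreateMLCEngine factory, picks an illustrative model ID, and relies on a model having been loaded (reload is what populates the tokenizer).

// Usage sketch (assumes this PR; the model ID is illustrative).
import { CreateMLCEngine } from "@mlc-ai/web-llm";

async function main() {
  // Loading the model also loads its tokenizer; tokenize/decodeTokens reuse it.
  const engine = await CreateMLCEngine("Llama-3.1-8B-Instruct-q4f32_1-MLC");

  const ids = await engine.tokenize("Hello, world!"); // Int32Array of token IDs
  console.log("token count:", ids.length);

  const text = await engine.decodeTokens(ids); // round-trip back to text
  console.log("decoded:", text);
}

main();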
src/message.ts: 11 changes (10 additions, 1 deletion)
@@ -34,7 +34,9 @@ type RequestKind =
   | "customRequest"
   | "keepAlive"
   | "setLogLevel"
-  | "setAppConfig";
+  | "setAppConfig"
+  | "tokenize"
+  | "decodeTokens";
 
 // eslint-disable-next-line @typescript-eslint/no-unused-vars
 type ResponseKind = "return" | "throw" | "initProgressCallback";
@@ -58,6 +60,12 @@ export interface ForwardTokensAndSampleParams {
   isPrefill: boolean;
   modelId?: string;
 }
+export interface TokenizeParams {
+  text: string;
+}
+export interface DecodeTokensParams {
+  inputIds: Int32Array;
+}
 
 // Notes on the following Params with modelId and chatOpts:
 // These fields are the model and chatOpts that the frontend engine expects the backend
@@ -128,6 +136,7 @@ export type MessageContent =
   | CreateEmbeddingResponse
   | Completion
   | AppConfig
+  | Int32Array
   | void;
 /**
  * The message used in exchange between worker
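
On the worker path, the new request kinds travel as ordinary WorkerRequest messages. A sketch of their shapes, not part of the diff; the uuid strings and token-ID values are illustrative:

// Message shapes for the new request kinds (uuid and token IDs are illustrative).
const tokenizeRequest = {
  kind: "tokenize" as const,           // new RequestKind member
  uuid: "d290f1ee-6c54-4b01-90e6-d701748f0851",
  content: { text: "Hello, world!" },  // TokenizeParams
};

const decodeTokensRequest = {
  kind: "decodeTokens" as const,       // new RequestKind member
  uuid: "c56a4180-65aa-42ec-a945-5fd21dec0538",
  content: { inputIds: new Int32Array([101, 202, 303]) }, // DecodeTokensParams
};

// The handler replies with a "return" message whose content is the result:
// an Int32Array for "tokenize" (hence MessageContent now includes Int32Array)
// and a string for "decodeTokens".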
src/types.ts: 6 changes (6 additions, 0 deletions)
@@ -172,6 +172,12 @@ export interface MLCEngineInterface {
    */
   embedding(request: EmbeddingCreateParams): Promise<CreateEmbeddingResponse>;
 
+  /**
+   * Exposes the tokenizer for clients to avoid needing to load it twice
+   */
+  tokenize(input: string): Promise<Int32Array>;
+  decodeTokens(input: Int32Array): Promise<string>;
+
   /**
    * @returns A text summarizing the runtime stats.
    * @param modelId Only required when multiple models are loaded.
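
Because the methods are declared on MLCEngineInterface, client code can be written against the interface and work with either the in-page engine or a worker-backed one. A sketch of the kind of check this enables, not part of the diff; the 4096-token budget and helper name are illustrative, not values from web-llm:

// Sketch: prompt budgeting against a token limit, written against the interface.
import type { MLCEngineInterface } from "@mlc-ai/web-llm";

async function fitsBudget(
  engine: MLCEngineInterface,
  prompt: string,
  maxTokens = 4096, // illustrative budget, not read from any model config
): Promise<boolean> {
  const ids = await engine.tokenize(prompt); // count tokens without re-loading a tokenizer
  return ids.length <= maxTokens;
}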
src/web_worker.ts: 44 changes (44 additions, 0 deletions)
@@ -26,6 +26,8 @@ import {
   MessageContent,
   ReloadParams,
   ForwardTokensAndSampleParams,
+  TokenizeParams,
+  DecodeTokensParams,
   ChatCompletionNonStreamingParams,
   ChatCompletionStreamInitParams,
   ResetChatParams,
@@ -345,6 +347,24 @@ export class WebWorkerMLCEngineHandler {
         onComplete?.(null);
         return;
       }
+      case "decodeTokens": {
+        this.handleTask(msg.uuid, async () => {
+          const params = msg.content as DecodeTokensParams;
+          const res = await this.engine.decodeTokens(params.inputIds);
+          onComplete?.(res);
+          return res;
+        });
+        return;
+      }
+      case "tokenize": {
+        this.handleTask(msg.uuid, async () => {
+          const params = msg.content as TokenizeParams;
+          const res = await this.engine.tokenize(params.text);
+          onComplete?.(res);
+          return res;
+        });
+        return;
+      }
       default: {
         if (msg.kind && msg.content) {
           onError?.();
@@ -633,6 +653,30 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
     return await this.getPromise<number>(msg);
   }
 
+  async tokenize(text: string, modelId?: string) {
+    const msg: WorkerRequest = {
+      kind: "tokenize",
+      uuid: crypto.randomUUID(),
+      content: {
+        text,
+        modelId: modelId,
+      },
+    };
+    return await this.getPromise<Int32Array>(msg);
+  }
+
+  async decodeTokens(ids: Int32Array, modelId?: string) {
+    const msg: WorkerRequest = {
+      kind: "decodeTokens",
+      uuid: crypto.randomUUID(),
+      content: {
+        inputIds: Array.from(ids),
+        modelId: modelId,
+      },
+    };
+    return await this.getPromise<string>(msg);
+  }
+
   /**
    * Every time the generator is called, we post a message to the worker asking it to
    * decode one step, and we expect to receive a message of `ChatCompletionChunk` from
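
With these handlers in place, WebWorkerMLCEngine.tokenize/decodeTokens post the new message kinds to the worker, where WebWorkerMLCEngineHandler dispatches them to the underlying engine. A sketch of the two sides, not part of the diff, following web-llm's documented worker setup; the worker file name and model ID are illustrative:

// ---- worker.ts (illustrative file name): runs the engine and answers messages ----
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
  handler.onmessage(msg); // now also covers "tokenize" and "decodeTokens"
};

// ---- main.ts: the frontend engine proxies tokenize/decodeTokens to the worker ----
import { CreateWebWorkerMLCEngine } from "@mlc-ai/web-llm";

const engine = await CreateWebWorkerMLCEngine(
  new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
  "Llama-3.1-8B-Instruct-q4f32_1-MLC", // illustrative model ID
);
const ids = await engine.tokenize("Hello from the main thread");
const text = await engine.decodeTokens(ids);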