add phi3 support
guschmue committed Apr 28, 2024
1 parent 305be89 commit c967e5f
Showing 66 changed files with 33,057 additions and 78,620 deletions.
45 changes: 15 additions & 30 deletions chat/chat.js
@@ -186,18 +186,18 @@ const MODELS = {
     "tinyllama": { name: "tinyllama", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-int4" },
     "tinyllama_fp16": { name: "tinyllama-fp16", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-fp16", externaldata: true },
     "phi2": { name: "phi2", path: "schmuell/phi2-int4" },
+    "phi3": { name: "phi3", path: "schmuell/phi3-int4", externaldata: true },
     "stablelm": { name: "stablelm", path: "schmuell/stablelm-2-zephyr-1_6b-int4" },
 }
 
 function getConfig() {
     const query = window.location.search.substring(1);
     var config = {
-        model: "tinyllama",
+        model: "phi3",
         provider: "webgpu",
         profiler: 0,
        verbose: 0,
         threads: 1,
-        trace: 0,
         csv: 0,
         max_tokens: 512,
         local: 0,
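The rest of getConfig sits outside this hunk. A minimal sketch, assuming it overlays URL query parameters onto the defaults above (the helper name `applyQueryOverrides` is hypothetical), so a URL like `chat.html?model=phi2&max_tokens=256` would override the new phi3 default:

```js
// Hypothetical sketch of the query-string override step; numeric defaults
// are coerced so ?max_tokens=256 stays a number rather than a string.
function applyQueryOverrides(config, queryString) {
    for (const [key, value] of new URLSearchParams(queryString)) {
        if (key in config) {
            config[key] = (typeof config[key] === "number") ? Number(value) : value;
        }
    }
    return config;
}

// applyQueryOverrides({ model: "phi3", max_tokens: 512 }, "model=phi2&max_tokens=256")
// → { model: "phi2", max_tokens: 256 }
```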
@@ -246,14 +246,14 @@ async function fetchAndCache(url) {
 class LLM {
     sess = undefined;
     profiler = false;
-    trace = false;
     feed = {};
     output_tokens = [];
     eos = 2;
     need_position_ids = true;
     stop = false;
     kv_dims = [];
     dtype = "float16";
+    max_tokens = 256;
 
     constructor() {
     }
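Inside generate (outside this hunk), the new class default presumably yields to a per-call override; a one-line sketch of that precedence, with names assumed:

```js
// Hypothetical: options.max_tokens (per call) wins over the class default of 256.
const max_tokens = options.max_tokens ? options.max_tokens : this.max_tokens;
```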
@@ -263,18 +263,22 @@ class LLM {
         const verbose = options.verbose;
         const local = options.local;
         this.profiler = options.profiler;
-        this.trace = options.trace;
 
 
         const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main";
 
         log(`loading... ${model.name}, ${provider}`);
         const json_bytes = await fetchAndCache(model_path + "/config.json");
         let textDecoder = new TextDecoder();
         const model_config = JSON.parse(textDecoder.decode(json_bytes));
 
         const model_bytes = await fetchAndCache(model_path + "/onnx/decoder_model_merged.onnx");
-        log(`model size ${Math.round(model_bytes.byteLength / 1024 / 1024)} MB`);
-        const externaldata = (model.externalData) ? await fetchAndCache(model_path + '/onnx/decoder_model_merged.onnx.data') : false;
+        const externaldata = (model.externaldata) ? await fetchAndCache(model_path + '/onnx/decoder_model_merged.onnx.data') : false;
+        let modelSize = model_bytes.byteLength;
+        if (externaldata) {
+            modelSize += externaldata.byteLength;
+        }
+        log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`);
 
 
         const opt = {
             executionProviders: [provider],
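Note the quiet bug fix in this hunk: the MODELS table spells its key `externaldata` (all lowercase), so the old `model.externalData` check was always undefined and the weights file was never fetched. A hedged sketch of how such a fetched buffer is typically handed to onnxruntime-web (the `externalData` session option, available in recent ort-web releases, takes `{data, path}` pairs; the real wiring here happens in the `externaldata` branch below):

```js
// Sketch only: ort-web matches `path` against the external-data location
// recorded inside the .onnx file when resolving the weights.
const opt = {
    executionProviders: ["webgpu"],
    externalData: [{ data: externaldata, path: "decoder_model_merged.onnx.data" }],
};
const sess = await ort.InferenceSession.create(model_bytes, opt);
```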
@@ -291,11 +295,6 @@
                     opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
                 }
                 break;
-            case "webnn":
-                if (!("ml" in navigator)) {
-                    throw new Error("webnn is NOT supported");
-                }
-                break;
         }
 
         if (externaldata !== undefined) {
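The loop driving the gpu-buffer assignments starts above this hunk; for orientation, a sketch of the presumed full webgpu branch (reading `num_hidden_layers` from the model's config.json is an assumption):

```js
// Keep every present.*.key/value KV-cache output resident on the GPU so the
// next decode step can feed it back as past_key_values without a CPU round-trip.
opt.preferredOutputLocation = {};
for (let i = 0; i < model_config.num_hidden_layers; ++i) {
    opt.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
    opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
}
```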
@@ -320,13 +319,6 @@ class LLM {
         }
 
         this.sess = await ort.InferenceSession.create(model_bytes, opt);
-
-        if (this.trace) {
-            ort.env.trace = true;
-            ort.env.webgpu.profiling.ondata = (version, inputsMetadata, outputsMetadata, kernelId, kernelType,
-                kernelName, programName, startTime, endTime) => { };
-        }
-
         this.eos = model_config.eos_token_id;
         this.kv_dims = [1, model_config.num_key_value_heads, 0, model_config.hidden_size / model_config.num_attention_heads];
         this.dtype = config.model.dtype || "float16";
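kv_dims encodes [batch, num_key_value_heads, past_seq_len, head_dim] with head_dim = hidden_size / num_attention_heads (for a Phi-3-mini-style config, 3072 / 32 = 96) and an initial past length of 0. A hedged sketch of how such dims could seed the empty KV-cache feeds for the first run; the actual seeding lives elsewhere in chat.js:

```js
// float16 tensors in ort-web are backed by Uint16Array; a zero-length buffer
// matches the past_seq_len of 0 encoded in kv_dims.
for (let i = 0; i < model_config.num_hidden_layers; ++i) {
    const empty = (this.dtype === "float16") ? new Uint16Array(0) : new Float32Array(0);
    this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims);
    this.feed[`past_key_values.${i}.value`] = new ort.Tensor(this.dtype, empty, this.kv_dims);
}
```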
@@ -410,14 +402,7 @@ class LLM {
         while (last_token != this.eos && seqlen < max_tokens && !this.stop) {
             seqlen = this.output_tokens.length;
             feed['attention_mask'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, () => 1n), [1, seqlen]);
-            let outputs;
-            if (this.trace) {
-                console.timeStamp("RUN-BEGIN");
-                outputs = await this.sess.run(feed);
-                console.timeStamp("RUN-END");
-            } else {
-                outputs = await this.sess.run(feed);
-            }
+            const outputs = await this.sess.run(feed);
             last_token = BigInt(this.argmax(outputs.logits));
             this.output_tokens.push(last_token);
             if (callback && !this.profiler) {
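this.argmax is defined outside this hunk; a minimal sketch of the greedy pick it presumably performs over the last position of the [batch, seqlen, vocab] logits. It assumes float32 logits for simplicity; a float16 model's raw Uint16Array would need decoding first:

```js
// Greedy decoding: return the vocab index with the highest score at the
// final sequence position of a [batch, seqlen, vocab] logits tensor.
function argmax(t) {
    const [, seqlen, vocab] = t.dims;
    const start = vocab * (seqlen - 1);  // offset of the last token's row
    let maxidx = 0;
    for (let i = 1; i < vocab; i++) {
        if (t.data[start + i] > t.data[start + maxidx]) {
            maxidx = i;
        }
    }
    return maxidx;
}
```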
@@ -470,7 +455,7 @@ async function Query(query, cb) {
     const start_timer = performance.now();
     const output_tokens = await llm.generate(input_ids, (output_tokens) => {
         cb(token_to_text(tokenizer, output_tokens, input_ids.length));
-    }, {});
+    }, {max_tokens: config.max_tokens});
 
     const took = (performance.now() - start_timer) / 1000;
     const txt = token_to_text(tokenizer, output_tokens, input_ids.length);
@@ -490,8 +475,8 @@ async function LoadModel() {
             provider: config.provider,
             profiler: config.profiler,
             verbose: config.verbose,
-            trace: config.trace,
             local: config.local,
+            max_tokens: config.max_tokens,
         });
         log("Ready.");
     } catch (error) {
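Taken together, max_tokens now flows from the URL query through LoadModel into llm.load and, per request, from Query into generate. An illustrative call sequence (the answer element id is an assumption, not from this diff):

```js
// Illustrative only: load the new phi3 default, then stream a reply capped
// at config.max_tokens tokens.
await LoadModel();
await Query("What is the capital of France?", (text) => {
    document.getElementById("answer").innerText = text;  // element id assumed
});
```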