diff --git a/Cargo.lock b/Cargo.lock index 0a0c8cf9..e1570b19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4093,9 +4093,9 @@ dependencies = [ [[package]] name = "web-rwkv" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aaa2dae444cdfbf00718b97a7dffa79d7ff73df5b0175f5f8f4caea371b8b7e" +checksum = "0dc832976e594cc006a34b3da4d1d151e1c12fee1e529b588ea43a9c8e7b8361" dependencies = [ "ahash 0.8.11", "anyhow", diff --git a/Cargo.toml b/Cargo.toml index 64b1910e..e5374c4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ zip-extract = "0.1" # path = "../web-rwkv" default-features = false features = ["native"] -version = "0.8.8" +version = "0.8.9" [dependencies.salvo] default-features = true diff --git a/src/middleware.rs b/src/middleware.rs index a6076664..08b07d18 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -241,6 +241,7 @@ pub struct TokenCounter { pub prompt_tokens: usize, pub completion_tokens: usize, pub total_tokens: usize, + pub duration: Duration, } #[derive(Clone)] @@ -419,38 +420,31 @@ async fn load_runtime( let context = context.clone(); let reload = reload.clone(); - match (info.version, reload.precision) { - (ModelVersion::V4, Precision::Fp16) => { - let model = Build::::build(builder).await?; - let builder = v4::ModelRuntime::::new(model, max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V5, Precision::Fp16) => { - let model = Build::::build(builder).await?; - let builder = v5::ModelRuntime::::new(model, max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V6, Precision::Fp16) => { - let model = Build::::build(builder).await?; - let builder = v6::ModelRuntime::::new(model, max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V4, Precision::Fp32) => { - let model = Build::::build(builder).await?; - let builder = v4::ModelRuntime::::new(model, max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V5, Precision::Fp32) => { - let model = Build::::build(builder).await?; - let builder = v5::ModelRuntime::::new(model, max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V6, Precision::Fp32) => { - let model = Build::::build(builder).await?; - let builder = v6::ModelRuntime::::new(model, max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await + + macro_rules! match_safe_tensors { + (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $runtime:ty)),+ }) => { + match ($v, $p) { + $( + ($version, $precision) => { + let model = Build::<$model>::build(builder).await?; + let builder = <$runtime>::new(model, max_batch); + Runtime::new(context, builder, reload, states, tokenizer, vocab).await + } + )+ + } } } + match_safe_tensors!( + (info.version, reload.precision), + { + (ModelVersion::V4, Precision::Fp16, v4::Model, v4::ModelRuntime::), + (ModelVersion::V5, Precision::Fp16, v5::Model, v5::ModelRuntime::), + (ModelVersion::V6, Precision::Fp16, v6::Model, v6::ModelRuntime::), + (ModelVersion::V4, Precision::Fp32, v4::Model, v4::ModelRuntime::), + (ModelVersion::V5, Precision::Fp32, v5::Model, v5::ModelRuntime::), + (ModelVersion::V6, Precision::Fp32, v6::Model, v6::ModelRuntime::) + } + ) } LoadType::Prefab => { use cbor4ii::{core::utils::SliceReader, serde::Deserializer}; @@ -460,44 +454,32 @@ async fn load_runtime( let context = context.clone(); let reload = reload.clone(); - match (info.version, reload.precision) { - (ModelVersion::V4, Precision::Fp16) => { - let seed: Seed<_, v4::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v4::ModelRuntime::::new(model, reload.max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V5, Precision::Fp16) => { - let seed: Seed<_, v5::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v5::ModelRuntime::::new(model, reload.max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V6, Precision::Fp16) => { - let seed: Seed<_, v6::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v6::ModelRuntime::::new(model, reload.max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V4, Precision::Fp32) => { - let seed: Seed<_, v4::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v4::ModelRuntime::::new(model, reload.max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V5, Precision::Fp32) => { - let seed: Seed<_, v5::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v5::ModelRuntime::::new(model, reload.max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await - } - (ModelVersion::V6, Precision::Fp32) => { - let seed: Seed<_, v6::Model> = Seed::new(&context); - let model = seed.deserialize(&mut deserializer)?; - let builder = v6::ModelRuntime::::new(model, reload.max_batch); - Runtime::new(context, builder, reload, states, tokenizer, vocab).await + + macro_rules! match_prefab { + (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $runtime:ty)),+ }) => { + match ($v, $p) { + $( + ($version, $precision) => { + let seed: Seed<_, $model> = Seed::new(&context); + let model = seed.deserialize(&mut deserializer)?; + let builder = <$runtime>::new(model, reload.max_batch); + Runtime::new(context, builder, reload, states, tokenizer, vocab).await + } + )+ + } } } + match_prefab!( + (info.version, reload.precision), + { + (ModelVersion::V4, Precision::Fp16, v4::Model, v4::ModelRuntime::), + (ModelVersion::V5, Precision::Fp16, v5::Model, v5::ModelRuntime::), + (ModelVersion::V6, Precision::Fp16, v6::Model, v6::ModelRuntime::), + (ModelVersion::V4, Precision::Fp32, v4::Model, v4::ModelRuntime::), + (ModelVersion::V5, Precision::Fp32, v5::Model, v5::ModelRuntime::), + (ModelVersion::V6, Precision::Fp32, v6::Model, v6::ModelRuntime::) + } + ) } }; @@ -664,6 +646,7 @@ pub async fn model_route(receiver: Receiver) -> Result<()> { buffer: Default::default(), model_tokens: Default::default(), bnf_sampler: None, + instant: None, request, sender: token_sender, }; diff --git a/src/run.rs b/src/run.rs index 24ef99dd..9a55bfe1 100644 --- a/src/run.rs +++ b/src/run.rs @@ -228,6 +228,8 @@ pub struct GenerateContext { pub model_tokens: Vec, /// Compiled BNF schema, if any. pub bnf_sampler: Option>>, + /// For measuring time used. + pub instant: Option, /// Generate request provided by the caller. pub request: GenerateRequest, /// To send back generated tokens. @@ -809,6 +811,7 @@ impl Runtime { continue; }; + let instant = context.instant.get_or_insert(Instant::now()); let prefix = std::mem::take(&mut context.prefix); let suffix = std::mem::take(&mut context.suffix); let model_tokens = [prefix.0, suffix.0].concat(); @@ -853,20 +856,22 @@ impl Runtime { context.buffer.append(&mut word); context.model_tokens.push(token); - let count_tokens = || { - let prompt_tokens = context.prompt_tokens.len(); - let completion_tokens = context.model_tokens.len(); - let total_tokens = prompt_tokens + completion_tokens; - TokenCounter { - prompt_tokens, - completion_tokens, - total_tokens, - } - }; - let mut done = false; let mut finish = |reason| { - let _ = context.sender.send(Token::Stop(reason, count_tokens())); + let counter = { + let prompt_tokens = context.prompt_tokens.len(); + let completion_tokens = context.model_tokens.len(); + let total_tokens = prompt_tokens + completion_tokens; + let duration = instant.elapsed(); + TokenCounter { + prompt_tokens, + completion_tokens, + total_tokens, + duration, + } + }; + + let _ = context.sender.send(Token::Stop(reason, counter)); let _ = context.sender.send(Token::Done); done = true; };