Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/Ai00-X/ai00_server
Browse files Browse the repository at this point in the history
  • Loading branch information
cgisky1980 committed May 14, 2024
2 parents fed750a + c870f69 commit ae416bf
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 81 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ zip-extract = "0.1"
# path = "../web-rwkv"
default-features = false
features = ["native"]
version = "0.8.8"
version = "0.8.9"

[dependencies.salvo]
default-features = true
Expand Down
115 changes: 49 additions & 66 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ pub struct TokenCounter {
pub prompt_tokens: usize,
pub completion_tokens: usize,
pub total_tokens: usize,
pub duration: Duration,
}

#[derive(Clone)]
Expand Down Expand Up @@ -419,38 +420,31 @@ async fn load_runtime(

let context = context.clone();
let reload = reload.clone();
match (info.version, reload.precision) {
(ModelVersion::V4, Precision::Fp16) => {
let model = Build::<v4::Model>::build(builder).await?;
let builder = v4::ModelRuntime::<f16>::new(model, max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V5, Precision::Fp16) => {
let model = Build::<v5::Model>::build(builder).await?;
let builder = v5::ModelRuntime::<f16>::new(model, max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V6, Precision::Fp16) => {
let model = Build::<v6::Model>::build(builder).await?;
let builder = v6::ModelRuntime::<f16>::new(model, max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V4, Precision::Fp32) => {
let model = Build::<v4::Model>::build(builder).await?;
let builder = v4::ModelRuntime::<f32>::new(model, max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V5, Precision::Fp32) => {
let model = Build::<v5::Model>::build(builder).await?;
let builder = v5::ModelRuntime::<f32>::new(model, max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V6, Precision::Fp32) => {
let model = Build::<v6::Model>::build(builder).await?;
let builder = v6::ModelRuntime::<f32>::new(model, max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await

// Expands a table of (ModelVersion, Precision, model type, runtime type)
// tuples into one exhaustive `match` over `($v, $p)`. Each arm builds the
// concrete `$model` from SafeTensors via `Build::build`, wraps it in the
// matching `$runtime` (f16 or f32), and starts a `Runtime`.
//
// NOTE(review): the expansion references `builder`, `max_batch`, `context`,
// `reload`, `states`, `tokenizer` and `vocab` as free identifiers, so the
// macro is only usable inside the enclosing loader scope where those
// bindings (and `.await?`) are available.
macro_rules! match_safe_tensors {
    (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $runtime:ty)),+ }) => {
        match ($v, $p) {
            $(
                ($version, $precision) => {
                    // Build the weights for this version, then pick the
                    // float precision via the runtime wrapper type.
                    let model = Build::<$model>::build(builder).await?;
                    let builder = <$runtime>::new(model, max_batch);
                    Runtime::new(context, builder, reload, states, tokenizer, vocab).await
                }
            )+
        }
    }
}
match_safe_tensors!(
(info.version, reload.precision),
{
(ModelVersion::V4, Precision::Fp16, v4::Model, v4::ModelRuntime::<f16>),
(ModelVersion::V5, Precision::Fp16, v5::Model, v5::ModelRuntime::<f16>),
(ModelVersion::V6, Precision::Fp16, v6::Model, v6::ModelRuntime::<f16>),
(ModelVersion::V4, Precision::Fp32, v4::Model, v4::ModelRuntime::<f32>),
(ModelVersion::V5, Precision::Fp32, v5::Model, v5::ModelRuntime::<f32>),
(ModelVersion::V6, Precision::Fp32, v6::Model, v6::ModelRuntime::<f32>)
}
)
}
LoadType::Prefab => {
use cbor4ii::{core::utils::SliceReader, serde::Deserializer};
Expand All @@ -460,44 +454,32 @@ async fn load_runtime(

let context = context.clone();
let reload = reload.clone();
match (info.version, reload.precision) {
(ModelVersion::V4, Precision::Fp16) => {
let seed: Seed<_, v4::Model> = Seed::new(&context);
let model = seed.deserialize(&mut deserializer)?;
let builder = v4::ModelRuntime::<f16>::new(model, reload.max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V5, Precision::Fp16) => {
let seed: Seed<_, v5::Model> = Seed::new(&context);
let model = seed.deserialize(&mut deserializer)?;
let builder = v5::ModelRuntime::<f16>::new(model, reload.max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V6, Precision::Fp16) => {
let seed: Seed<_, v6::Model> = Seed::new(&context);
let model = seed.deserialize(&mut deserializer)?;
let builder = v6::ModelRuntime::<f16>::new(model, reload.max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V4, Precision::Fp32) => {
let seed: Seed<_, v4::Model> = Seed::new(&context);
let model = seed.deserialize(&mut deserializer)?;
let builder = v4::ModelRuntime::<f32>::new(model, reload.max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V5, Precision::Fp32) => {
let seed: Seed<_, v5::Model> = Seed::new(&context);
let model = seed.deserialize(&mut deserializer)?;
let builder = v5::ModelRuntime::<f32>::new(model, reload.max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await
}
(ModelVersion::V6, Precision::Fp32) => {
let seed: Seed<_, v6::Model> = Seed::new(&context);
let model = seed.deserialize(&mut deserializer)?;
let builder = v6::ModelRuntime::<f32>::new(model, reload.max_batch);
Runtime::new(context, builder, reload, states, tokenizer, vocab).await

// Prefab (CBOR-serialized) counterpart of `match_safe_tensors`: expands a
// table of (ModelVersion, Precision, model type, runtime type) tuples into
// one exhaustive `match` over `($v, $p)`. Each arm deserializes the model
// through a typed `Seed`, wraps it in the matching `$runtime`, and starts a
// `Runtime`.
//
// NOTE(review): the expansion captures `context`, `deserializer`, `reload`,
// `states`, `tokenizer` and `vocab` from the enclosing scope and uses
// `.await` / `?`, so it only compiles inside the surrounding async loader.
macro_rules! match_prefab {
    (($v:expr, $p:expr), { $(($version:path, $precision:path, $model:ty, $runtime:ty)),+ }) => {
        match ($v, $p) {
            $(
                ($version, $precision) => {
                    // Deserialize into the concrete model type for this
                    // version, then select precision via the runtime type.
                    let seed: Seed<_, $model> = Seed::new(&context);
                    let model = seed.deserialize(&mut deserializer)?;
                    let builder = <$runtime>::new(model, reload.max_batch);
                    Runtime::new(context, builder, reload, states, tokenizer, vocab).await
                }
            )+
        }
    }
}
}
match_prefab!(
(info.version, reload.precision),
{
(ModelVersion::V4, Precision::Fp16, v4::Model, v4::ModelRuntime::<f16>),
(ModelVersion::V5, Precision::Fp16, v5::Model, v5::ModelRuntime::<f16>),
(ModelVersion::V6, Precision::Fp16, v6::Model, v6::ModelRuntime::<f16>),
(ModelVersion::V4, Precision::Fp32, v4::Model, v4::ModelRuntime::<f32>),
(ModelVersion::V5, Precision::Fp32, v5::Model, v5::ModelRuntime::<f32>),
(ModelVersion::V6, Precision::Fp32, v6::Model, v6::ModelRuntime::<f32>)
}
)
}
};

Expand Down Expand Up @@ -664,6 +646,7 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
buffer: Default::default(),
model_tokens: Default::default(),
bnf_sampler: None,
instant: None,
request,
sender: token_sender,
};
Expand Down
29 changes: 17 additions & 12 deletions src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ pub struct GenerateContext {
pub model_tokens: Vec<u16>,
/// Compiled BNF schema, if any.
pub bnf_sampler: Option<Arc<RwLock<BnfSampler>>>,
/// For measuring time used.
pub instant: Option<Instant>,
/// Generate request provided by the caller.
pub request: GenerateRequest,
/// To send back generated tokens.
Expand Down Expand Up @@ -809,6 +811,7 @@ impl Runtime {
continue;
};

let instant = context.instant.get_or_insert(Instant::now());
let prefix = std::mem::take(&mut context.prefix);
let suffix = std::mem::take(&mut context.suffix);
let model_tokens = [prefix.0, suffix.0].concat();
Expand Down Expand Up @@ -853,20 +856,22 @@ impl Runtime {
context.buffer.append(&mut word);
context.model_tokens.push(token);

let count_tokens = || {
let prompt_tokens = context.prompt_tokens.len();
let completion_tokens = context.model_tokens.len();
let total_tokens = prompt_tokens + completion_tokens;
TokenCounter {
prompt_tokens,
completion_tokens,
total_tokens,
}
};

let mut done = false;
let mut finish = |reason| {
let _ = context.sender.send(Token::Stop(reason, count_tokens()));
let counter = {
let prompt_tokens = context.prompt_tokens.len();
let completion_tokens = context.model_tokens.len();
let total_tokens = prompt_tokens + completion_tokens;
let duration = instant.elapsed();
TokenCounter {
prompt_tokens,
completion_tokens,
total_tokens,
duration,
}
};

let _ = context.sender.send(Token::Stop(reason, counter));
let _ = context.sender.send(Token::Done);
done = true;
};
Expand Down

0 comments on commit ae416bf

Please sign in to comment.