Skip to content

Commit

Permalink
Support V5.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Oct 11, 2023
1 parent 14e853f commit 7e858dc
Show file tree
Hide file tree
Showing 8 changed files with 458 additions and 407 deletions.
250 changes: 116 additions & 134 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ai00_server"
version = "0.2.1"
version = "0.2.2"
edition = "2021"
authors = ["Gu ZhenNiu <[email protected]>", "Zhang Zhenyuan <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand All @@ -18,7 +18,8 @@ axum = { git = "https://github.com/cryscan/axum", branch = "sse-leading-space" }
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.4", features = ["full"] }
tokio = { version = "1", features = ["full"] }
web-rwkv = "0.2.0"
# web-rwkv = "0.3.2"
web-rwkv = { git = "https://github.com/cryscan/web-rwkv", branch = "main" }
memmap = "0.7"
bytemuck = "1"
regex = "1.8"
Expand Down
8 changes: 4 additions & 4 deletions src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,11 @@ pub struct ChatResponse {
}

async fn chat_completions_one(
State(ThreadState { sender, model_name }): State<ThreadState>,
State(ThreadState(sender)): State<ThreadState>,
Json(request): Json<ChatRequest>,
) -> Json<ChatResponse> {
let (token_sender, token_receiver) = flume::unbounded();
let model_name = model_name.read().unwrap().clone();
let model_name = String::new();

let request = request.into();
let _ = sender.send(ThreadRequest::Generate {
Expand Down Expand Up @@ -219,11 +219,11 @@ pub struct PartialChatResponse {
}

async fn chat_completions_stream(
State(ThreadState { sender, model_name }): State<ThreadState>,
State(ThreadState(sender)): State<ThreadState>,
Json(request): Json<ChatRequest>,
) -> Sse<impl Stream<Item = Result<Event>>> {
let (token_sender, token_receiver) = flume::unbounded();
let model_name = model_name.read().unwrap().clone();
let model_name = String::new();

let request = request.into();
let _ = sender.send(ThreadRequest::Generate {
Expand Down
8 changes: 4 additions & 4 deletions src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ pub struct CompletionResponse {
}

async fn completions_one(
State(ThreadState { sender, model_name }): State<ThreadState>,
State(ThreadState(sender)): State<ThreadState>,
Json(request): Json<CompletionRequest>,
) -> Json<CompletionResponse> {
let (token_sender, token_receiver) = flume::unbounded();
let model_name = model_name.read().unwrap().clone();
let model_name = String::new();

let request = GenerateRequest::from(request);
let _ = sender.send(ThreadRequest::Generate {
Expand Down Expand Up @@ -167,11 +167,11 @@ pub struct PartialCompletionResponse {
}

async fn completions_stream(
State(ThreadState { sender, model_name }): State<ThreadState>,
State(ThreadState(sender)): State<ThreadState>,
Json(request): Json<CompletionRequest>,
) -> Sse<impl Stream<Item = Result<Event>>> {
let (token_sender, token_receiver) = flume::unbounded();
let model_name = model_name.read().unwrap().clone();
let model_name = String::new();

let request = GenerateRequest::from(request);
let _ = sender.send(ThreadRequest::Generate {
Expand Down
6 changes: 2 additions & 4 deletions src/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,11 @@ pub struct EmbeddingResponse {
}

pub async fn embeddings(
State(ThreadState {
sender, model_name, ..
}): State<ThreadState>,
State(ThreadState(sender)): State<ThreadState>,
Json(request): Json<EmbeddingRequest>,
) -> Json<EmbeddingResponse> {
let (token_sender, token_receiver) = flume::unbounded();
let model_name = model_name.read().unwrap().clone();
let model_name = String::new();

let _ = sender.send(ThreadRequest::Generate {
request: request.into(),
Expand Down
120 changes: 80 additions & 40 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::{
collections::HashMap,
convert::Infallible,
fs::File,
io::{BufReader, Cursor, Read},
net::{Ipv4Addr, SocketAddr},
path::{Path, PathBuf},
sync::{Arc, Mutex, RwLock},
sync::Arc,
time::Duration,
};

Expand All @@ -15,13 +16,17 @@ use axum::{
};
use clap::{Parser, ValueEnum};
use dialoguer::{theme::ColorfulTheme, Select};
use flume::Receiver;
use flume::{Receiver, Sender};
use memmap::Mmap;
use run::RuntimeUntyped;
use serde::{Deserialize, Serialize};
use tower_http::{cors::CorsLayer, services::ServeDir};
use web_rwkv::{
context::{Context, ContextBuilder, Instance},
model::{LayerFlags, Model, ModelBuilder, ModelState, Quantization},
model::{
loader::Loader, FromBuilder, LayerFlags, Model, ModelBuilder, ModelState, ModelVersion,
Quantization, StateBuilder,
},
tokenizer::Tokenizer,
wgpu::PowerPreference,
};
Expand All @@ -39,8 +44,7 @@ mod run;
mod sampler;

pub const MAX_TOKENS: usize = 4096;
pub const MAX_PENALTY_COUNT: usize = 1024;
pub const STATE_CACHE_LRU: usize = 16;
pub const STATE_CHUNK_SIZE: usize = 4;

#[derive(Debug)]
pub enum Token {
Expand Down Expand Up @@ -91,7 +95,7 @@ pub enum ThreadRequest {
Reload(ReloadRequest),
Generate {
request: GenerateRequest,
token_sender: flume::Sender<Token>,
token_sender: Sender<Token>,
},
}

Expand Down Expand Up @@ -146,10 +150,7 @@ pub struct TokenCounter {
}

#[derive(Clone)]
pub struct ThreadState {
pub sender: flume::Sender<ThreadRequest>,
pub model_name: Arc<RwLock<String>>,
}
pub struct ThreadState(pub Sender<ThreadRequest>);

async fn create_context(args: &Args) -> Result<Context> {
let instance = Instance::new();
Expand Down Expand Up @@ -185,9 +186,12 @@ fn load_tokenizer(path: impl AsRef<Path>) -> Result<Tokenizer> {
Ok(Tokenizer::new(&contents)?)
}

fn load_model<'a>(context: &Context, request: ReloadRequest) -> Result<Model<'a>> {
fn load_model<'a, M, S>(context: &Context, request: ReloadRequest, data: &'a [u8]) -> Result<(M, S)>
where
S: ModelState + FromBuilder<Builder<'a> = StateBuilder, Error = Infallible>,
M: Model<ModelState = S> + FromBuilder<Builder<'a> = ModelBuilder<'a>, Error = anyhow::Error>,
{
let ReloadRequest {
path,
quant,
token_chunk_size,
head_chunk_size,
Expand All @@ -203,14 +207,16 @@ fn load_model<'a>(context: &Context, request: ReloadRequest) -> Result<Model<'a>
Quantization::Int8(layers)
};

let file = File::open(path)?;
let map = unsafe { Mmap::map(&file)? };

ModelBuilder::new(context, &map)
let model: M = ModelBuilder::new(context, data)
.with_quant(quant)
.with_token_chunk_size(token_chunk_size)
.with_head_chunk_size(head_chunk_size)
.build()
.build()?;
let state: S = StateBuilder::new(context, model.info())
.with_max_batch(request.max_batch)
.with_chunk_size(STATE_CHUNK_SIZE)
.build();
Ok((model, state))
}

fn load_web(path: impl AsRef<Path>, target: &Path) -> Result<()> {
Expand Down Expand Up @@ -449,44 +455,74 @@ fn model_route(
tokenizer: Tokenizer,
receiver: Receiver<ThreadRequest>,
) -> Result<()> {
let runtime: Arc<Mutex<Option<Runtime>>> = Default::default();
let mut runtime: Option<Arc<RuntimeUntyped>> = None;
let mut pending = Vec::new();

let sender = {
let (sender, receiver) = flume::unbounded();
let runtime = runtime.clone();
let tokenizer = tokenizer.clone();
std::thread::spawn(move || run::run(runtime, tokenizer, receiver));
std::thread::spawn(move || run::run(tokenizer, receiver));
sender
};

let queue = |context| -> Vec<GenerateContext> {
fn queue<'a>(
runtime: &Option<Arc<RuntimeUntyped<'a>>>,
context: GenerateContext,
sender: &Sender<Option<Arc<RuntimeUntyped<'a>>>>,
) -> Vec<GenerateContext> {
let mut pending = Vec::new();
let mut runtime = runtime.lock().unwrap();
match &mut *runtime {
match &runtime {
Some(runtime) => match runtime.queue(context) {
SlotResult::Success(batch) => log::info!("queued task at {batch}"),
SlotResult::Fault(batch) => log::info!("swapped task at {batch}"),
SlotResult::Failure(context) => pending.push(*context),
},
None => pending.push(context),
}
let _ = sender.send(());
let _ = sender.send(runtime.clone());
pending
};
}

let listen = |pending: &mut Vec<GenerateContext>| -> Result<()> {
fn listen<'a>(
runtime: &mut Option<Arc<RuntimeUntyped<'a>>>,
pending: &mut Vec<GenerateContext>,
context: &Context,
tokenizer: &Tokenizer,
receiver: &Receiver<ThreadRequest>,
sender: &Sender<Option<Arc<RuntimeUntyped<'a>>>>,
) -> Result<()> {
match receiver.recv()? {
ThreadRequest::Reload(request) => {
let max_runtime_batch = request.max_runtime_batch;
let max_batch = request.max_batch;
let embed_layer = request.embed_layer;

let model = Arc::new(load_model(&context, request)?);
let info = model.info();
let state = ModelState::new(&context, info, max_batch);
let mut runtime = runtime.lock().unwrap();
let _ = runtime.replace(Runtime::new(model, state, max_runtime_batch, embed_layer));
let file = File::open(&request.path)?;
let data = unsafe { Mmap::map(&file)? };
let info = Loader::info(&data)?;
log::info!("{:#?}", info);

let rt = match info.version {
ModelVersion::V4 => {
let (model, state) = load_model(context, request, &data)?;
RuntimeUntyped::V4(Runtime::new(
model,
state,
max_runtime_batch,
embed_layer,
))
}
ModelVersion::V5 => {
let (model, state) = load_model(context, request, &data)?;
RuntimeUntyped::V5(Runtime::new(
model,
state,
max_runtime_batch,
embed_layer,
))
}
};
runtime.replace(Arc::new(rt));
let _ = sender.send(runtime.clone());
}
ThreadRequest::Generate {
request,
Expand Down Expand Up @@ -529,21 +565,28 @@ fn model_route(
embed,
sender: token_sender,
};
pending.append(&mut queue(context));
pending.append(&mut queue(runtime, context, sender));
}
};
Ok(())
};
}

loop {
if let Err(err) = listen(&mut pending) {
if let Err(err) = listen(
&mut runtime,
&mut pending,
&context,
&tokenizer,
&receiver,
&sender,
) {
log::error!("{err}");
}

while !pending.is_empty() {
let mut temp = Vec::new();
for context in pending.drain(..) {
temp.append(&mut queue(context));
temp.append(&mut queue(&runtime, context, &sender));
}
std::mem::swap(&mut pending, &mut temp);
std::thread::sleep(Duration::from_secs(1));
Expand Down Expand Up @@ -628,10 +671,7 @@ async fn main() {
.route("/v1/embeddings", post(embedding::embeddings))
.fallback_service(ServeDir::new(temp_path.join("www")))
.layer(CorsLayer::permissive())
.with_state(ThreadState {
sender,
model_name: Default::default(),
});
.with_state(ThreadState(sender));

let addr = SocketAddr::from((args.ip.unwrap_or(Ipv4Addr::new(0, 0, 0, 0)), args.port));
let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
Expand Down
8 changes: 3 additions & 5 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,17 @@ pub struct ModelResponse {
pub data: Vec<ModelChoice>,
}

pub async fn models(
State(ThreadState { model_name, .. }): State<ThreadState>,
) -> Json<ModelResponse> {
pub async fn models(State(ThreadState(_sender)): State<ThreadState>) -> Json<ModelResponse> {
Json(ModelResponse {
data: vec![ModelChoice {
object: "models".into(),
id: model_name.read().unwrap().clone(),
id: "".into(),
}],
})
}

pub async fn load(
State(ThreadState { sender, .. }): State<ThreadState>,
State(ThreadState(sender)): State<ThreadState>,
Json(request): Json<ReloadRequest>,
) {
let _ = sender.send(ThreadRequest::Reload(request));
Expand Down
Loading

0 comments on commit 7e858dc

Please sign in to comment.