diff --git a/src/main.rs b/src/main.rs index 7c06a4b6..84c11458 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,6 +20,7 @@ use std::{ path::PathBuf, str::FromStr, }; +use tower_http::cors::CorsLayer; use web_rwkv::{BackedModelState, Environment, Model, Tokenizer}; mod chat; @@ -225,6 +226,10 @@ fn model_task( token_counter.total_tokens = tokens.len(); for _ in 0..max_tokens { + if token_sender.is_disconnected() { + break 'run; + } + let mut logits = model.run(&tokens, &state).unwrap_or_default(); for (&token, &count) in &occurrences { let penalty = @@ -324,6 +329,7 @@ async fn main() -> Result<()> { .route("/v1/chat/completions", post(chat::chat_completions)) .route("/embeddings", post(embedding::embeddings)) .route("/v1/embeddings", post(embedding::embeddings)) + .layer(CorsLayer::permissive()) .with_state(ThreadState { sender, model_name }); let addr = SocketAddr::from(([127, 0, 0, 1], args.port));