Skip to content

Commit

Permalink
Merge pull request #2 from josStorer/main
Browse files Browse the repository at this point in the history
Enable CORS and stop generating when interrupted
  • Loading branch information
cryscan committed Jul 22, 2023
2 parents 860058b + e9a7145 commit cd3a43a
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use std::{
path::PathBuf,
str::FromStr,
};
use tower_http::cors::CorsLayer;
use web_rwkv::{BackedModelState, Environment, Model, Tokenizer};

mod chat;
Expand Down Expand Up @@ -225,6 +226,10 @@ fn model_task(
token_counter.total_tokens = tokens.len();

for _ in 0..max_tokens {
if token_sender.is_disconnected() {
break 'run;
}

let mut logits = model.run(&tokens, &state).unwrap_or_default();
for (&token, &count) in &occurrences {
let penalty =
Expand Down Expand Up @@ -324,6 +329,7 @@ async fn main() -> Result<()> {
.route("/v1/chat/completions", post(chat::chat_completions))
.route("/embeddings", post(embedding::embeddings))
.route("/v1/embeddings", post(embedding::embeddings))
.layer(CorsLayer::permissive())
.with_state(ThreadState { sender, model_name });

let addr = SocketAddr::from(([127, 0, 0, 1], args.port));
Expand Down

0 comments on commit cd3a43a

Please sign in to comment.