From 5712c6c447b4df326792f2ba50303f582e525b52 Mon Sep 17 00:00:00 2001 From: cryscan Date: Tue, 25 Jul 2023 13:48:38 +0800 Subject: [PATCH] Can specify ip. --- src/main.rs | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/main.rs b/src/main.rs index 6fe5c03a..873e2fa9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,7 +11,7 @@ use std::{ ffi::OsStr, fs::File, io::{BufReader, Read}, - net::SocketAddr, + net::{Ipv4Addr, SocketAddr}, path::PathBuf, str::FromStr, }; @@ -130,11 +130,6 @@ fn model_task( tokenizer: PathBuf, receiver: Receiver, ) -> Result<()> { - simple_logger::SimpleLogger::new() - .with_level(log::LevelFilter::Warn) - .with_module_level("ai00_server", log::LevelFilter::Trace) - .init()?; - let tokenizer = load_tokenizer(tokenizer)?; let model = load_model(&env, model)?; @@ -193,7 +188,7 @@ fn model_task( RequestKind::Embedding(request) => request.into(), }; - log::info!("{:#?}", sampler); + log::trace!("{:#?}", sampler); let state = model.create_state(); let remain = { @@ -205,11 +200,11 @@ fn model_task( .get_mut(&prefix[..]) .and_then(|(backed, count)| state.load(backed).ok().and(Some(count))) { - log::info!("state cache hit: {count}"); + log::trace!("state cache hit: {count}"); *count = 0; remain.split_off(prefix.len()) } else { - log::info!("state cache miss"); + log::trace!("state cache miss"); remain } }; @@ -290,6 +285,8 @@ fn model_task( keys_to_remove.push(key); } } + + log::trace!("state cache evicted: {}", keys_to_remove.len()); for key in keys_to_remove { state_cache.remove(key); } @@ -306,12 +303,19 @@ struct Args { model: Option, #[arg(long, short, value_name = "FILE")] tokenizer: Option, + #[arg(long, short)] + ip: Option, #[arg(long, short, default_value_t = 3000)] port: u16, } #[tokio::main] async fn main() -> Result<()> { + simple_logger::SimpleLogger::new() + .with_level(log::LevelFilter::Warn) + .with_module_level("ai00_server", log::LevelFilter::Trace) + .init()?; + let args = Args::parse(); let model_path = PathBuf::from( args.model @@ -345,7 +349,9 @@ async fn main() -> Result<()> { .layer(CorsLayer::permissive()) .with_state(ThreadState { sender, model_name }); - let addr = SocketAddr::from(([127, 0, 0, 1], args.port)); + let addr = SocketAddr::from((args.ip.unwrap_or(Ipv4Addr::new(0, 0, 0, 0)), args.port)); + log::info!("server started at http://{addr}"); + let listener = tokio::net::TcpListener::bind(addr).await?; axum::serve(listener, app).await?;