Skip to content

Commit

Permalink
Can specify the IP address the server binds to.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Jul 25, 2023
1 parent 565d266 commit 5712c6c
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::{
ffi::OsStr,
fs::File,
io::{BufReader, Read},
net::SocketAddr,
net::{Ipv4Addr, SocketAddr},
path::PathBuf,
str::FromStr,
};
Expand Down Expand Up @@ -130,11 +130,6 @@ fn model_task(
tokenizer: PathBuf,
receiver: Receiver<ThreadRequest>,
) -> Result<()> {
simple_logger::SimpleLogger::new()
.with_level(log::LevelFilter::Warn)
.with_module_level("ai00_server", log::LevelFilter::Trace)
.init()?;

let tokenizer = load_tokenizer(tokenizer)?;
let model = load_model(&env, model)?;

Expand Down Expand Up @@ -193,7 +188,7 @@ fn model_task(
RequestKind::Embedding(request) => request.into(),
};

log::info!("{:#?}", sampler);
log::trace!("{:#?}", sampler);

let state = model.create_state();
let remain = {
Expand All @@ -205,11 +200,11 @@ fn model_task(
.get_mut(&prefix[..])
.and_then(|(backed, count)| state.load(backed).ok().and(Some(count)))
{
log::info!("state cache hit: {count}");
log::trace!("state cache hit: {count}");
*count = 0;
remain.split_off(prefix.len())
} else {
log::info!("state cache miss");
log::trace!("state cache miss");
remain
}
};
Expand Down Expand Up @@ -290,6 +285,8 @@ fn model_task(
keys_to_remove.push(key);
}
}

log::trace!("state cache evicted: {}", keys_to_remove.len());
for key in keys_to_remove {
state_cache.remove(key);
}
Expand All @@ -306,12 +303,19 @@ struct Args {
model: Option<String>,
#[arg(long, short, value_name = "FILE")]
tokenizer: Option<String>,
#[arg(long, short)]
ip: Option<Ipv4Addr>,
#[arg(long, short, default_value_t = 3000)]
port: u16,
}

#[tokio::main]
async fn main() -> Result<()> {
simple_logger::SimpleLogger::new()
.with_level(log::LevelFilter::Warn)
.with_module_level("ai00_server", log::LevelFilter::Trace)
.init()?;

let args = Args::parse();
let model_path = PathBuf::from(
args.model
Expand Down Expand Up @@ -345,7 +349,9 @@ async fn main() -> Result<()> {
.layer(CorsLayer::permissive())
.with_state(ThreadState { sender, model_name });

let addr = SocketAddr::from(([127, 0, 0, 1], args.port));
let addr = SocketAddr::from((args.ip.unwrap_or(Ipv4Addr::new(0, 0, 0, 0)), args.port));
log::info!("server started at http://{addr}");

let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;

Expand Down

0 comments on commit 5712c6c

Please sign in to comment.