From fc913a0fbf966de484ac201f4e230b6cdd2bf67c Mon Sep 17 00:00:00 2001
From: cryscan
Date: Tue, 25 Jul 2023 00:39:20 +0800
Subject: [PATCH] - Add LRU to state cache.

- Fix derail.
---
 Cargo.toml  |  2 +-
 src/main.rs | 36 ++++++++++++++++++++++++------------
 2 files changed, 25 insertions(+), 13 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index 63e1d69a..a34a5b14 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "ai00_server"
-version = "0.1.4"
+version = "0.1.5"
 edition = "2021"
 authors = ["Gu ZhenNiu <448885@qq.com>"]
 license = "MIT OR Apache-2.0"
diff --git a/src/main.rs b/src/main.rs
index f6992bdf..239a3b91 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,7 +10,7 @@ use std::{
     collections::{HashMap, HashSet},
     ffi::OsStr,
     fs::File,
-    io::{BufReader, Read, Write},
+    io::{BufReader, Read},
     net::SocketAddr,
     path::PathBuf,
     str::FromStr,
@@ -29,6 +29,7 @@ use crate::{
 
 pub const MAX_TOKENS: usize = 4096;
 pub const MAX_PENALTY_COUNT: usize = 1024;
+pub const STATE_CACHE_LRU: usize = 16;
 
 #[derive(Debug)]
 pub enum Token {
@@ -152,7 +153,7 @@ fn model_task(
         set
     };
 
-    let mut state_cache = Trie::<&[u8], BackedModelState>::new();
+    let mut state_cache = Trie::<&[u8], (BackedModelState, usize)>::new();
 
     loop {
         let ThreadRequest {
@@ -196,14 +197,16 @@
             let state = model.create_state();
 
             let remain = {
-                let prefix = state_cache.longest_common_prefix(prompt.as_bytes());
+                let prefix = state_cache
+                    .longest_common_prefix(prompt.as_bytes())
+                    .to_vec();
                 let mut remain = prompt.as_bytes().to_vec();
-                if state_cache
-                    .get(prefix)
-                    .and_then(|backed| state.load(backed).ok())
-                    .is_some()
+                if let Some(count) = state_cache
+                    .get_mut(&prefix[..])
+                    .and_then(|(backed, count)| state.load(backed).ok().and(Some(count)))
                 {
-                    log::info!("state cache hit");
+                    log::info!("state cache hit: {count}");
+                    *count = 0;
                     remain.split_off(prefix.len())
                 } else {
                     log::info!("state cache miss");
@@ -243,7 +246,6 @@
                     .unwrap_or_default();
 
                 print!("{word}");
-                std::io::stdout().flush()?;
 
                 model_text += &word;
                 token_counter.completion_tokens += 1;
@@ -255,7 +257,7 @@
                 let _ = token_sender.send(Token::Token(word));
                 tokens = vec![token];
 
-                if stop.iter().any(|x| model_text.contains(x)) {
+                if token == 0 || stop.iter().any(|x| model_text.contains(x)) {
                     let _ = token_sender.send(Token::Stop(FinishReason::Stop, token_counter));
                     break 'run;
                 }
@@ -265,7 +267,6 @@
             }
 
             print!("\n\n");
-            std::io::stdout().flush()?;
 
             if let Ok(back) = state.back() {
                 if embedding {
@@ -278,7 +279,18 @@
                 let mut prompt = prompt.as_bytes().to_vec();
                 let mut model_text = model_text.as_bytes().to_vec();
                 prompt.append(&mut model_text);
-                state_cache.insert(prompt.leak(), back);
+                state_cache.insert(prompt.leak(), (back, 0));
+
+                let mut keys_to_remove = vec![];
+                for (&key, (_, count)) in state_cache.iter_mut() {
+                    *count += 1;
+                    if *count > STATE_CACHE_LRU {
+                        keys_to_remove.push(key);
+                    }
+                }
+                for key in keys_to_remove {
+                    state_cache.remove(key);
+                }
             }
 
             let _ = token_sender.send(Token::Done);
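
Note on the caching scheme: despite the constant's name, this is an aging
policy rather than a classical LRU list. Each cached state carries a counter
that is reset to zero on a cache hit and incremented once per insertion of a
new entry; any entry whose counter exceeds STATE_CACHE_LRU is evicted. Below
is a minimal self-contained sketch of the same policy, assuming a std HashMap
in place of the patch's byte-keyed Trie and a String in place of
BackedModelState; the StateCache type and its method names are illustrative
stand-ins, not code from this repository.

use std::collections::HashMap;

const STATE_CACHE_LRU: usize = 16;

/// Sketch of the patch's aging cache: each value is (state, age).
struct StateCache {
    entries: HashMap<Vec<u8>, (String, usize)>,
}

impl StateCache {
    fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Find the longest cached key that is a prefix of `prompt` and reset
    /// its age, mirroring the `*count = 0` on a hit in the patch. Returns
    /// the cached state and the length of the matched prefix.
    fn lookup(&mut self, prompt: &[u8]) -> Option<(&String, usize)> {
        let key = self
            .entries
            .keys()
            .filter(|k| prompt.starts_with(k.as_slice()))
            .max_by_key(|k| k.len())?
            .clone();
        let len = key.len();
        let (state, age) = self.entries.get_mut(&key)?;
        *age = 0;
        Some((state, len))
    }

    /// Insert a fresh entry, then age every entry by one and evict the
    /// stale ones; `retain` compacts the patch's two loops into one pass.
    fn insert(&mut self, key: Vec<u8>, state: String) {
        self.entries.insert(key, (state, 0));
        self.entries.retain(|_, (_, age)| {
            *age += 1;
            *age <= STATE_CACHE_LRU
        });
    }
}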
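
A hypothetical driver for the sketch above, showing why the reset-on-hit
matters: prefixes that keep getting hits stay at age zero and survive, while
states that go STATE_CACHE_LRU insertions without a hit are dropped.

fn main() {
    let mut cache = StateCache::new();
    cache.insert(b"User: hi".to_vec(), "state-A".into());

    // A longer prompt sharing the cached prefix counts as a hit; only the
    // remaining suffix of the prompt would need to be recomputed.
    if let Some((state, len)) = cache.lookup(b"User: hi\nAssistant:") {
        println!("hit: {state}, reusing {len} prompt bytes");
    }
}

On the "Fix derail" half of the commit: the new token == 0 stop condition
presumably treats token 0 as the tokenizer's end-of-text marker, so
generation now halts there instead of running on (derailing) until some
textual stop string happens to match.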