Skip to content

Commit

Permalink
- Add LRU to state cache.
Browse files Browse the repository at this point in the history
- Fix derail.
  • Loading branch information
cryscan committed Jul 24, 2023
1 parent a5f6e4e commit fc913a0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "ai00_server"
version = "0.1.4"
version = "0.1.5"
edition = "2021"
authors = ["Gu ZhenNiu <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand Down
36 changes: 24 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::{
collections::{HashMap, HashSet},
ffi::OsStr,
fs::File,
io::{BufReader, Read, Write},
io::{BufReader, Read},
net::SocketAddr,
path::PathBuf,
str::FromStr,
Expand All @@ -29,6 +29,7 @@ use crate::{

pub const MAX_TOKENS: usize = 4096;
pub const MAX_PENALTY_COUNT: usize = 1024;
pub const STATE_CACHE_LRU: usize = 16;

#[derive(Debug)]
pub enum Token {
Expand Down Expand Up @@ -152,7 +153,7 @@ fn model_task(
set
};

let mut state_cache = Trie::<&[u8], BackedModelState>::new();
let mut state_cache = Trie::<&[u8], (BackedModelState, usize)>::new();

loop {
let ThreadRequest {
Expand Down Expand Up @@ -196,14 +197,16 @@ fn model_task(

let state = model.create_state();
let remain = {
let prefix = state_cache.longest_common_prefix(prompt.as_bytes());
let prefix = state_cache
.longest_common_prefix(prompt.as_bytes())
.to_vec();
let mut remain = prompt.as_bytes().to_vec();
if state_cache
.get(prefix)
.and_then(|backed| state.load(backed).ok())
.is_some()
if let Some(count) = state_cache
.get_mut(&prefix[..])
.and_then(|(backed, count)| state.load(backed).ok().and(Some(count)))
{
log::info!("state cache hit");
log::info!("state cache hit: {count}");
*count = 0;
remain.split_off(prefix.len())
} else {
log::info!("state cache miss");
Expand Down Expand Up @@ -243,7 +246,6 @@ fn model_task(
.unwrap_or_default();

print!("{word}");
std::io::stdout().flush()?;

model_text += &word;
token_counter.completion_tokens += 1;
Expand All @@ -255,7 +257,7 @@ fn model_task(
let _ = token_sender.send(Token::Token(word));
tokens = vec![token];

if stop.iter().any(|x| model_text.contains(x)) {
if token == 0 || stop.iter().any(|x| model_text.contains(x)) {
let _ = token_sender.send(Token::Stop(FinishReason::Stop, token_counter));
break 'run;
}
Expand All @@ -265,7 +267,6 @@ fn model_task(
}

print!("\n\n");
std::io::stdout().flush()?;

if let Ok(back) = state.back() {
if embedding {
Expand All @@ -278,7 +279,18 @@ fn model_task(
let mut prompt = prompt.as_bytes().to_vec();
let mut model_text = model_text.as_bytes().to_vec();
prompt.append(&mut model_text);
state_cache.insert(prompt.leak(), back);
state_cache.insert(prompt.leak(), (back, 0));

let mut keys_to_remove = vec![];
for (&key, (_, count)) in state_cache.iter_mut() {
*count += 1;
if *count > STATE_CACHE_LRU {
keys_to_remove.push(key);
}
}
for key in keys_to_remove {
state_cache.remove(key);
}
}

let _ = token_sender.send(Token::Done);
Expand Down

0 comments on commit fc913a0

Please sign in to comment.