Skip to content

Commit

Permalink
Fix cache and slot scheduler.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Oct 18, 2023
1 parent 357be66 commit b505f75
Showing 1 changed file with 90 additions and 27 deletions.
117 changes: 90 additions & 27 deletions src/run.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Borrow,
cmp::Ordering,
collections::{HashMap, HashSet},
convert::Infallible,
sync::{Arc, Mutex, RwLock},
Expand Down Expand Up @@ -46,6 +47,34 @@ impl Default for SlotState {
}
}

/// A candidate slot for an incoming request, used as a priority key when picking a slot.
///
/// Priority (highest first): `Continue` > `Empty` > `Back`. Two `Continue`s are ranked by
/// the length of the content already held by the slot; the slot index never takes part
/// in comparisons, so callers can break ties with a secondary key.
///
/// Equality is deliberately defined through [`Ord::cmp`] so that `a == b` holds exactly
/// when `a.cmp(&b) == Ordering::Equal`, as the standard library requires. Use pattern
/// matching, not `==`, to tell two choices for different slots apart.
#[derive(Debug)]
enum SlotChoice {
    /// The slot's non-empty content is a prefix of the request: `(slot index, content length)`.
    Continue(usize, usize),
    /// The slot holds non-empty content that is not a prefix of the request: `(slot index)`.
    Back(usize),
    /// The slot holds no content: `(slot index)`.
    Empty(usize),
}

impl std::cmp::Ord for SlotChoice {
    /// Ranks choices by kind first, then (for `Continue` only) by content length.
    fn cmp(&self, other: &Self) -> Ordering {
        use SlotChoice::{Back, Continue, Empty};
        match (self, other) {
            // Longer reusable content wins among continuations.
            (Continue(_, x), Continue(_, y)) => x.cmp(y),
            // A continuation beats every other kind of slot.
            (Continue(..), _) => Ordering::Greater,
            (_, Continue(..)) => Ordering::Less,
            // Same kind, no further criterion at this level.
            (Empty(_), Empty(_)) | (Back(_), Back(_)) => Ordering::Equal,
            // An empty slot is preferred over one whose content must be replaced.
            (Empty(_), Back(_)) => Ordering::Greater,
            (Back(_), Empty(_)) => Ordering::Less,
        }
    }
}

impl std::cmp::PartialOrd for SlotChoice {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl std::cmp::PartialEq for SlotChoice {
    /// Two choices are equal when they have the same priority (see [`Ord::cmp`]).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl std::cmp::Eq for SlotChoice {}

#[derive(Debug, Default)]
enum Payload {
#[default]
Expand Down Expand Up @@ -247,14 +276,42 @@ where
let choice = slots
.iter()
.enumerate()
.filter_map(|(index, slot)| match slot {
SlotState::Idle(content, time) => match tokens.starts_with(content) {
true => Some((index, true, content.len(), time.elapsed().as_micros())),
false => Some((index, false, 0, time.elapsed().as_micros())),
},
.filter_map(|(batch, slot)| match slot {
SlotState::Idle(content, time) => {
let delta = time.elapsed().as_millis();
match (content.is_empty(), tokens.starts_with(content)) {
(true, _) => Some((SlotChoice::Empty(batch), delta)),
(false, true) => Some((SlotChoice::Continue(batch, content.len()), delta)),
(false, false) => Some((SlotChoice::Back(batch), delta)),
}
}
_ => None,
})
.max_by(|&lhs, &rhs| lhs.2.cmp(&rhs.2).then(lhs.3.cmp(&rhs.3)));
.max_by(|lhs, rhs| lhs.0.cmp(&rhs.0).then(lhs.1.cmp(&rhs.1)));

// Checks out the backed state cached for the longest cached prefix of `tokens`.
// Returns that prefix together with the state to load into the slot; `batch` is only
// used for logging here.
let mut checkout = |batch: usize| -> (Vec<u16>, B) {
    let prefix = cache.longest_common_prefix(&tokens);
    // NOTE(review): presumably the common prefix need not itself be a stored key, which
    // is why we search for the longest leading part of it that is. The range is
    // inclusive so the full common prefix is tried first.
    let len = (0..=prefix.len())
        .rev()
        .find(|len| cache.contains_key(prefix[0..*len].as_token_slice()))
        .unwrap_or_default();
    log::info!("slot {} checks out backed cache of length {}", batch, len);

    // Callers treat the returned prefix as tokens already covered by the returned state,
    // so it must be cut down to the part that actually has a cached state. Returning the
    // untruncated prefix alongside a freshly built state would skip those tokens.
    let prefix = prefix[..len].to_vec();
    let reload = cache
        .remove(prefix[..].as_token_slice())
        .unwrap_or_else(|| {
            // Nothing cached for this prefix: start from a newly built state.
            let context = self.model.context();
            let info = self.model.info();
            StateBuilder::new(context, info)
                .with_max_batch(1)
                .with_chunk_size(STATE_CHUNK_SIZE)
                .build_backed()
        });
    // Re-insert so the entry stays available to later requests.
    cache.insert(Tokens(prefix.clone()), reload.clone());
    (prefix, reload)
};

match choice {
None => SlotResult::Failure(
GenerateContext {
Expand All @@ -264,12 +321,11 @@ where
}
.into(),
),
Some((batch, false, _, _)) => {
let prefix = cache.longest_common_prefix(&tokens);
let len = match cache.contains_key(prefix) {
true => prefix.len(),
false => 0,
};
Some((SlotChoice::Back(batch), _)) => {
log::info!("start at non-empty slot {}", batch);
let (prefix, reload) = checkout(batch);

let len = prefix.len();
let mut state = SlotState::Wait(
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
Expand All @@ -279,18 +335,6 @@ where
.into(),
);

let prefix = prefix.to_vec();
let reload = cache
.remove(prefix[..].as_token_slice())
.unwrap_or_else(|| {
let context = self.model.context();
let info = self.model.info();
StateBuilder::new(context, info)
.with_max_batch(1)
.with_chunk_size(STATE_CHUNK_SIZE)
.build_backed()
});

std::mem::swap(&mut state, &mut slots[batch]);
match state {
SlotState::Idle(content, _) => {
Expand All @@ -302,7 +346,26 @@ where
_ => unreachable!(),
}
}
Some((id, true, len, _)) => {
Some((SlotChoice::Empty(batch), _)) => {
log::info!("start at empty slot {}", batch);
let (prefix, reload) = checkout(batch);

let len = prefix.len();
let state = SlotState::Wait(
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
..context
}
.into(),
);
slots[batch] = state;

self.state.load_batch(&reload, batch).expect("load state");
SlotResult::Fault(batch)
}
Some((SlotChoice::Continue(batch, len), _)) => {
log::info!("continue at slot {}", batch);
let state = SlotState::Wait(
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
Expand All @@ -311,8 +374,8 @@ where
}
.into(),
);
let _ = std::mem::replace(&mut slots[id], state);
SlotResult::Success(id)
slots[batch] = state;
SlotResult::Success(batch)
}
}
}
Expand Down

0 comments on commit b505f75

Please sign in to comment.