Skip to content

Commit

Permalink
Make BNF a general transformer.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed May 22, 2024
1 parent 567d2db commit abd291f
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 39 deletions.
8 changes: 4 additions & 4 deletions crates/ai00-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -679,10 +679,10 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
prompt_cached: false,
prefix: Default::default(),
suffix: tokens,
model_text: Default::default(),
buffer: Default::default(),
model_tokens: Default::default(),
bnf_sampler: None,
model_text: vec![],
buffer: vec![],
model_tokens: vec![],
transformers: vec![],
instant: None,
request,
sender: token_sender,
Expand Down
46 changes: 20 additions & 26 deletions crates/ai00-core/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
};

use anyhow::Result;
use bnf_sampler::{grammar::Grammar, sampler::AcceptTokenResult, vocabulary::Vocabulary};
use bnf_sampler::{grammar::Grammar, vocabulary::Vocabulary};
use derivative::Derivative;
use flume::{Receiver, Sender};
use itertools::Itertools;
Expand Down Expand Up @@ -208,7 +208,8 @@ impl AsTokenSlice for [u16] {
}
}

#[derive(Debug, Clone)]
#[derive(Derivative, Clone)]
#[derivative(Debug)]
pub struct GenerateContext {
/// Tokens that are provided at first.
pub prompt_tokens: Vec<u16>,
Expand All @@ -225,7 +226,8 @@ pub struct GenerateContext {
/// Tokens that are output by the model.
pub model_tokens: Vec<u16>,
/// Output transformers applied to the model's logits (e.g. a compiled BNF schema), if any.
pub bnf_sampler: Option<Arc<RwLock<BnfSampler>>>,
#[derivative(Debug = "ignore")]
pub transformers: Vec<Arc<RwLock<dyn Transformer + Send + Sync>>>,
/// For measuring time used.
pub instant: Option<Instant>,
/// Generate request provided by the caller.
Expand Down Expand Up @@ -562,14 +564,13 @@ impl Runtime {
};

// compile the BNF schema.
let bnf_sampler = if let Some(schema) = context.request.bnf_schema.clone() {
let mut transformers = Vec::<Arc<RwLock<dyn Transformer + Send + Sync>>>::new();
if let Some(schema) = context.request.bnf_schema.clone() {
match self.compile_bnf_schema(schema).await {
Ok(bnf_sampler) => Some(Arc::new(RwLock::new(bnf_sampler))),
Ok(bnf) => transformers.push(Arc::new(RwLock::new(bnf))),
Err(err) => return Ok(SlotResult::Error(err.to_string())),
}
} else {
None
};
}

// find the best idle slot by:
// 1. find the slot that matches the context (continue)
Expand Down Expand Up @@ -599,7 +600,7 @@ impl Runtime {
GenerateContext {
prefix: Default::default(),
suffix: Tokens([tokens, vec![last]].concat()),
bnf_sampler,
transformers,
..context
}
.into(),
Expand All @@ -616,7 +617,7 @@ impl Runtime {
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
bnf_sampler,
transformers,
..context
}
.into(),
Expand All @@ -637,7 +638,7 @@ impl Runtime {
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
bnf_sampler,
transformers,
..context
}
.into(),
Expand All @@ -653,7 +654,7 @@ impl Runtime {
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
bnf_sampler,
transformers,
..context
}
.into(),
Expand Down Expand Up @@ -773,7 +774,7 @@ impl Runtime {
(Payload::Busy(context), output) if output.size() > 0 => {
let num_vocab = self.info.num_vocab;
let output = output.0.clone();
let bnf = context.bnf_sampler.clone();
let transformers = context.transformers.clone();
let sampler = context.request.sampler.clone();
let bias = context.request.bias.clone();
set.spawn(async move {
Expand All @@ -784,8 +785,8 @@ impl Runtime {
for (token, bias) in bias.iter() {
data[*token as usize] += *bias;
}
if let Some(bnf) = bnf {
bnf.read().await.transform(&mut data);
for transformer in transformers {
transformer.read().await.transform(&mut data);
}

(batch, data)
Expand Down Expand Up @@ -906,18 +907,11 @@ impl Runtime {
done = true;
};

// update the BNF state
// update the transformer (BNF) state
let mut exhausted = false;
if let Some(bnf) = context.bnf_sampler.clone() {
let mut bnf = bnf.write().await;
match bnf.update(token) {
AcceptTokenResult::Continue => {}
AcceptTokenResult::End => exhausted = true,
AcceptTokenResult::Failed => {
log::warn!("slot {batch} bnf failure");
exhausted = true;
}
}
for transformer in context.transformers.iter() {
let mut transformer = transformer.write().await;
exhausted |= transformer.update(token);
}

// here we detect if there is a stop word in our buffer
Expand Down
11 changes: 6 additions & 5 deletions crates/ai00-core/src/sampler/bnf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ impl std::ops::DerefMut for BnfSampler {
}

impl Transformer for BnfSampler {
type Output = AcceptTokenResult;

fn transform(&self, output: &mut [f32]) {
output
.iter_mut()
Expand All @@ -52,13 +50,16 @@ impl Transformer for BnfSampler {
.for_each(|(_, logits)| *logits = f32::MIN)
}

fn update(&mut self, token: u16) -> AcceptTokenResult {
fn update(&mut self, token: u16) -> bool {
let token = Some(token as u32);
let res = self.accept_a_token(token).expect("invalid input token");
let accept = self.accept_a_token(token).expect("invalid input token");
self.current_token_ids = match self.sampler.all_possible_next_tokens(None) {
Ok(PossibleTokensResult::Continue(tokens)) => tokens.clone(),
_ => BitSet::new(),
};
res
match accept {
AcceptTokenResult::Continue => false,
AcceptTokenResult::End | AcceptTokenResult::Failed => true,
}
}
}
6 changes: 2 additions & 4 deletions crates/ai00-core/src/sampler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ pub trait Sampler {
}

pub trait Transformer {
type Output;

/// Update the raw model output.
fn transform(&self, output: &mut [f32]);
/// Update the internal state after a token is chosen.
fn update(&mut self, token: u16) -> Self::Output;
/// Update the internal state after a token is chosen. Returns `true` if the state machine has halted.
fn update(&mut self, token: u16) -> bool;
}

0 comments on commit abd291f

Please sign in to comment.