Skip to content

Commit

Permalink
Rename Transformer to Formatter.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Jun 15, 2024
1 parent 2b1150c commit 1c165bc
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 19 deletions.
2 changes: 1 addition & 1 deletion crates/ai00-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
model_text: vec![],
buffer: vec![],
model_tokens: vec![],
transformers: vec![],
formatters: vec![],
instant: None,
request,
sender: token_sender,
Expand Down
30 changes: 15 additions & 15 deletions crates/ai00-core/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use web_rwkv::{
};

use crate::{
sampler::{bnf::BnfSampler, Transformer},
sampler::{bnf::BnfSampler, Formatter},
Environment, FinishReason, GenerateKind, GenerateRequest, ReloadRequest, Token, TokenCounter,
};

Expand Down Expand Up @@ -229,7 +229,7 @@ pub struct GenerateContext {
pub model_tokens: Vec<u16>,
/// Compiled BNF schema, if any.
#[derivative(Debug = "ignore")]
pub transformers: Vec<Arc<RwLock<dyn Transformer + Send + Sync>>>,
pub formatters: Vec<Arc<RwLock<dyn Formatter + Send + Sync>>>,
/// For measuring time used.
pub instant: Option<Instant>,
/// Generate request provided by the caller.
Expand Down Expand Up @@ -578,10 +578,10 @@ impl Runtime {
}

// compile the BNF schema.
let mut transformers = Vec::<Arc<RwLock<dyn Transformer + Send + Sync>>>::new();
let mut formatters = Vec::<Arc<RwLock<dyn Formatter + Send + Sync>>>::new();
if let Some(schema) = context.request.bnf_schema.clone() {
match self.compile_bnf_schema(schema).await {
Ok(bnf) => transformers.push(Arc::new(RwLock::new(bnf))),
Ok(bnf) => formatters.push(Arc::new(RwLock::new(bnf))),
Err(err) => return Ok(SlotResult::Error(err.to_string())),
}
}
Expand Down Expand Up @@ -614,7 +614,7 @@ impl Runtime {
GenerateContext {
prefix: Default::default(),
suffix: Tokens(tokens),
transformers,
formatters,
..context
}
.into(),
Expand All @@ -634,7 +634,7 @@ impl Runtime {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
output: checkout.output,
transformers,
formatters,
..context
}
.into(),
Expand All @@ -658,7 +658,7 @@ impl Runtime {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
output: checkout.output,
transformers,
formatters,
..context
}
.into(),
Expand All @@ -681,7 +681,7 @@ impl Runtime {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
output: checkout.output,
transformers,
formatters,
..context
}
.into(),
Expand Down Expand Up @@ -803,7 +803,7 @@ impl Runtime {

let num_vocab = self.info.num_vocab;
let output = output;
let transformers = context.transformers.clone();
let formatters = context.formatters.clone();
let sampler = context.request.sampler.clone();
let bias = context.request.bias.clone();
set.spawn(async move {
Expand All @@ -814,8 +814,8 @@ impl Runtime {
for (token, bias) in bias.iter() {
data[*token as usize] += *bias;
}
for transformer in transformers {
transformer.read().await.transform(&mut data);
for formatter in formatters {
formatter.read().await.transform(&mut data);
}

(batch, data)
Expand Down Expand Up @@ -934,11 +934,11 @@ impl Runtime {
done = true;
};

// update the transformer (BNF) state
// update the formatter (BNF) state
let mut halt = false;
for transformer in context.transformers.iter() {
let mut transformer = transformer.write().await;
halt |= transformer.update(token);
for formatter in context.formatters.iter() {
let mut formatter = formatter.write().await;
halt |= formatter.update(token);
}

// here we detect if there is a stop word in our buffer
Expand Down
4 changes: 2 additions & 2 deletions crates/ai00-core/src/sampler/bnf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use kbnf::{
};
use web_rwkv::tokenizer::Tokenizer;

use super::Transformer;
use super::Formatter;

#[derive(Debug)]
pub struct BnfSampler(Engine);
Expand All @@ -31,7 +31,7 @@ impl BnfSampler {
}
}

impl Transformer for BnfSampler {
impl Formatter for BnfSampler {
fn transform(&self, output: &mut [f32]) {
let output = &mut output[..self.0.vocab().vocab_size()];
self.0.mask_logits(output).expect("bnf transform error")
Expand Down
2 changes: 1 addition & 1 deletion crates/ai00-core/src/sampler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ pub trait Sampler {
fn sample(&mut self, probs: &[f32]) -> u16;
}

pub trait Transformer {
pub trait Formatter {
/// Update the raw model output.
fn transform(&self, output: &mut [f32]);
/// Update the internal state after a token is chosen. Return if the state machine is halt.
Expand Down

0 comments on commit 1c165bc

Please sign in to comment.