diff --git a/Cargo.lock b/Cargo.lock index 2b6932a..ef2fb36 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1659,9 +1659,9 @@ dependencies = [ [[package]] name = "kbnf" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d1722638274535080c14a131457152942c92534ed8989980d53aab4cf9da8a" +checksum = "480e80d8fadff72f7a7267646177ac7ac3e29ada0f0a02241c35182913f406c4" dependencies = [ "ahash 0.8.11", "displaydoc", diff --git a/crates/ai00-core/Cargo.toml b/crates/ai00-core/Cargo.toml index f3a3c11..d20c55c 100644 --- a/crates/ai00-core/Cargo.toml +++ b/crates/ai00-core/Cargo.toml @@ -16,7 +16,7 @@ bytemuck = "1" cbor4ii = { version = "0.3.2", features = ["serde1"] } fastrand = "2" half = "2.4" -kbnf = "0.1.2" +kbnf = "0.1.3" qp-trie = "0.8" rustc-hash = "1.1.0" uuid = { version = "1.8.0", features = ["serde", "v4"] } diff --git a/crates/ai00-core/src/sampler/bnf.rs b/crates/ai00-core/src/sampler/bnf.rs index bb048db..0b0858f 100644 --- a/crates/ai00-core/src/sampler/bnf.rs +++ b/crates/ai00-core/src/sampler/bnf.rs @@ -15,12 +15,14 @@ impl BnfSampler { .token_index_to_bytes() .iter() .enumerate() + .filter(|(_, v)| !v.is_empty()) .map(|(k, v)| (k as u32, Token(v.clone().into_boxed_slice()))) .collect(); let strings = tokenizer .token_index_to_bytes() .iter() .enumerate() + .filter(|(_, v)| !v.is_empty()) .map(|(k, v)| (k as u32, String::from_utf8_lossy(v).to_string())) .collect(); let vocab = Vocabulary::new(tokens, strings)?; @@ -31,6 +33,7 @@ impl BnfSampler { impl Transformer for BnfSampler { fn transform(&self, output: &mut [f32]) { + let output = &mut output[..self.0.vocab().vocab_size()]; self.0.mask_logits(output).expect("bnf transform error") }