diff --git a/Cargo.toml b/Cargo.toml index 904432a..fc19b6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,3 +16,6 @@ regex = "1.11" memchr = "2.7" patricia_tree = "0.9" log = "0.4" +fancy-regex = "0.14.0" +ggus = "0.4" +memmap2 = "0.9" diff --git a/src/bpe/mod.rs b/src/bpe/mod.rs index e486a36..906ea8c 100644 --- a/src/bpe/mod.rs +++ b/src/bpe/mod.rs @@ -2,12 +2,22 @@ mod algorithm; +use ggus::{GGmlTokenType, GGufMetaMapExt}; +use log::warn; + use crate::{ - Method, utok, + Method, Tokeneer, + common::TOKENIZER_PRE_QWEN, + utils::{llama_decode_text, unicode_byte_to_utf8}, + utok, vocab::{CollectedVocab, CompressedVocab, TokenType}, }; -use std::{collections::HashSet, iter::zip, ops::Deref, pin::Pin, ptr::NonNull}; - +use std::{borrow::Cow, collections::HashSet, iter::zip, ops::Deref, pin::Pin, ptr::NonNull}; +// 只用于分词,是否判断停止让 tokeneer 做,如果要支持gpt2要添加 +pub enum Model { + GPT2(fancy_regex::Regex), + LLaMa, +} pub struct Bpe { /// 保存所有词的字符串内容,以 u8 为单位所以不需要对齐,占用空间少 _vocabs: Pin>, @@ -22,6 +32,7 @@ pub struct Bpe { special: Box<[utok]>, /// token: unk: utok, + modeltype: Model, } struct TokenMeta { @@ -46,7 +57,7 @@ impl Deref for TokenMeta { } impl Bpe { - /// 解析 tokenizer.model 文件并构造一个 bpe 分词器。 + /// 解析 tokenizer.model 文件并构造一个 bpe 分词器。暂时不支持gpt2格式的分词模型 pub fn from_tokenizer_model(model: &[u8]) -> Self { // 遍历文件,标记所有词汇的位置 let offsets = (0..) @@ -80,6 +91,7 @@ impl Bpe { 0, ), scores, + Model::LLaMa, ) } @@ -88,20 +100,27 @@ impl Bpe { scores: impl IntoIterator, token_type: impl IntoIterator, unk: utok, + modeltype: Model, ) -> Self { Self::from_collected_vocab( CollectedVocab::collect(vocabs.into_iter().map(|s| s.as_bytes()), token_type, unk), scores, + modeltype, ) } - fn from_collected_vocab(vocab: CollectedVocab, scores: impl IntoIterator) -> Self { + fn from_collected_vocab( + vocab: CollectedVocab, + scores: impl IntoIterator, + modeltype: Model, + ) -> Self { let CollectedVocab { vocabs, total_len, bytes, special, unk, + .. 
} = vocab; let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len); // 收集合词评分 @@ -140,9 +159,16 @@ impl Bpe { bytes, special, unk, + modeltype, }; - let inaccessible = ans.inaccessible(); - ans.special = ans.special.into_iter().chain(inaccessible).collect(); + match ans.modeltype { + // 在gpt2中会把正常字符当作特殊字符 例如 1052 Ġthere 影响解码 + Model::GPT2(_) => {} + Model::LLaMa => { + let inaccessible = ans.inaccessible(); + ans.special = ans.special.into_iter().chain(inaccessible).collect(); + } + } ans } @@ -181,6 +207,68 @@ impl Bpe { fn token(&self, token: utok) -> &TokenMeta { &self.tokens[token as usize] } + pub fn from_gguf(gguf: &T) -> Tokeneer { + let tokens = gguf.tokenizer_ggml_tokens().unwrap(); + let vocab_type = gguf.tokenizer_ggml_model().unwrap(); + let scores = gguf.tokenizer_ggml_scores(); + let token_type = gguf.tokenizer_ggml_token_type().unwrap(); + let token_type = token_type.map(|ty| { + match unsafe { std::mem::transmute::(ty.unwrap()) } { + GGmlTokenType::Normal => TokenType::Normal, + GGmlTokenType::Unknown => TokenType::Unknown, + GGmlTokenType::Control => TokenType::Control, + GGmlTokenType::User => TokenType::UserDefined, + GGmlTokenType::Unused => TokenType::Normal, + GGmlTokenType::Byte => TokenType::Byte, + } + }); + let token_len = tokens.len(); + let vocabs = tokens.map(|piece| piece.unwrap()); + // 部分模型可能无该字段 + let unk = match gguf.tokenizer_ggml_unknown_token_id() { + Ok(id) => id, + Err(_) => { + warn!("tokenizer_ggml_unknown_token_id not found"); + u32::MAX + } + }; + match vocab_type { + "llama" => { + let scores = scores.unwrap(); + assert_eq!(token_len, scores.len()); + let scores = scores.map(|s| s.unwrap()); + Tokeneer::new(Bpe::new(vocabs, scores, token_type, unk, Model::LLaMa)) + } + "gpt2" => { + let pre_type = gguf.get_str("tokenizer.ggml.pre").unwrap(); + let regex_str = match pre_type { + "qwen2" | "deepseek-r1-qwen" => TOKENIZER_PRE_QWEN, + _ => unimplemented!("not supported pre_type {}", pre_type), + }; + 
match scores { + Ok(scores) => { + assert_eq!(token_len, scores.len()); + let scores = scores.map(|s| s.unwrap()); + Tokeneer::new(Bpe::new( + vocabs, + scores, + token_type, + unk, + Model::GPT2(fancy_regex::Regex::new(regex_str).unwrap()), + )) + } + Err(_) => Tokeneer::new(Bpe::new( + vocabs, + std::iter::repeat(0.0).take(token_len), + token_type, + unk, + Model::GPT2(fancy_regex::Regex::new(regex_str).unwrap()), + )), + } + } + _ => unreachable!("not supported model"), + } + } } impl Method for Bpe { @@ -201,13 +289,79 @@ impl Method for Bpe { } #[inline] fn encode(&self, text: &str) -> impl IntoIterator + '_ { - let mut tokenizer = self.begin_merge(text); - while tokenizer.merge() {} - tokenizer.into_iter() + let text = self.pre_encode(text); + let mut vocab = Vec::new(); + match &self.modeltype { + Model::GPT2(regex) => { + // 使用正则表达式分割文本 + let mut result = Vec::new(); + let mut last_end = 0; + // 将正则表达式中的匹配结果编码之后,添加到列表中 + let mut push_result = |s: &str| { + let r = s.bytes().map(unicode_byte_to_utf8).collect::(); + result.push(r); + }; + for cap in regex.captures_iter(&text).flatten() { + if let Some(m) = cap.get(0) { + // 如果匹配前有未匹配的文本,添加到结果中 + if m.start() > last_end { + push_result(&text[last_end..m.start()]); + } + // 添加匹配的文本 + push_result(&text[m.start()..m.end()]); + last_end = m.end(); + } + } + + // 添加最后一部分未匹配的文本 + if last_end < text.len() { + push_result(&text[last_end..]); + } + for r in result { + let mut tokenizer = self.begin_merge(&r); + while tokenizer.merge() {} + vocab.extend(tokenizer.into_iter()); + } + } + Model::LLaMa => { + let mut tokenizer = self.begin_merge(&text); + while tokenizer.merge() {} + vocab.extend(tokenizer); + } + } + vocab.into_iter() } #[inline] - fn decode(&self, token: utok) -> &[u8] { - self.token(token) + fn decode(&self, token: utok) -> Cow<'_, [u8]> { + match &self.modeltype { + Model::GPT2(_) => match self.special.contains(&token) { + true => { + // 特殊token 直接返回 + std::borrow::Cow::Borrowed(self.token(token)) + } 
+ false => llama_decode_text(&String::from_utf8_lossy(self.token(token))) + .into_bytes() + .into(), + }, + + Model::LLaMa => { + let token_str = String::from_utf8_lossy(self.token(token)); + let decoded_str = self.pre_decode(&token_str); + Cow::Owned(decoded_str.into_owned().into_bytes()) + } + } + } + fn pre_encode<'s>(&self, text: &'s str) -> Cow<'s, str> { + match &self.modeltype { + Model::GPT2(_) => text.into(), + Model::LLaMa => text.replace(" ", "\u{2581}").into(), + } + } + fn pre_decode<'s>(&self, text: &'s str) -> Cow<'s, str> { + match &self.modeltype { + Model::GPT2(_) => text.into(), + Model::LLaMa => text.replace("\u{2581}", " ").into(), + } } } @@ -291,6 +445,7 @@ mod bpe_tests { ], [TokenType::Normal; 10], 0, + Model::LLaMa, ) } @@ -316,10 +471,10 @@ mod bpe_tests { #[test] fn test_bpe_decode() { let bpe = test_bpe(); - assert_eq!(bpe.decode(3), b"c"); - assert_eq!(bpe.decode(6), b"ac"); - assert_eq!(bpe.decode(9), b"bcd"); - assert_eq!(bpe.decode(0), b""); + assert_eq!(&*bpe.decode(3), b"c"); + assert_eq!(&*bpe.decode(6), b"ac"); + assert_eq!(&*bpe.decode(9), b"bcd"); + assert_eq!(&*bpe.decode(0), b""); } #[test] @@ -332,7 +487,7 @@ mod bpe_tests { let decoded: Vec<_> = encoded .iter() - .flat_map(|&t| bpe.decode(t).iter().copied()) + .flat_map(|&t| bpe.decode(t).iter().copied().collect::>()) .collect(); assert_eq!(std::str::from_utf8(&decoded), Ok("abcd")) } @@ -376,7 +531,7 @@ mod bpe_tests { TokenType::Byte, TokenType::Byte, ]; - let bpe = Bpe::new(vocabs, scores, token_type, 0); + let bpe = Bpe::new(vocabs, scores, token_type, 0, Model::LLaMa); let encoded: Vec<_> = bpe.encode("aAB").into_iter().collect(); assert_eq!(encoded, [0, 2, 3], "Expected 3 tokens for input 'aAB'") diff --git a/src/common/mod.rs b/src/common/mod.rs new file mode 100644 index 0000000..50c355e --- /dev/null +++ b/src/common/mod.rs @@ -0,0 +1 @@ +pub static TOKENIZER_PRE_QWEN: &str = 
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\nÒA-Za-zÑ0-9]?[ÒA-Za-z]+|[Ñ0-9]| ?[^\\sÒA-Za-zÑ0-9]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"; diff --git a/src/lib.rs b/src/lib.rs index 295fd6d..2246a60 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,11 +1,14 @@ #![deny(warnings)] - mod bpe; +mod common; mod lpe; mod tokeneer; +mod utils; mod vocab; +use std::borrow::Cow; pub use bpe::Bpe; +pub use bpe::Model; pub use lpe::Lpe; pub use tokeneer::Tokeneer; pub use vocab::TokenType; @@ -13,11 +16,17 @@ pub use vocab::TokenType; /// `utok` for token id. #[allow(non_camel_case_types)] pub type utok = u32; - +// 添加判断终止符的函数 pub trait Method { fn unk_token(&self) -> utok; fn vocab_size(&self) -> usize; fn internal_special(&self) -> impl IntoIterator; fn encode(&self, text: &str) -> impl IntoIterator + '_; - fn decode(&self, token: utok) -> &[u8]; + fn decode(&self, token: utok) -> Cow<[u8]>; + fn pre_encode<'s>(&self, text: &'s str) -> Cow<'s, str> { + text.into() + } + fn pre_decode<'s>(&self, text: &'s str) -> Cow<'s, str> { + text.into() + } } diff --git a/src/lpe/mod.rs b/src/lpe/mod.rs index da71dd6..3eb304a 100644 --- a/src/lpe/mod.rs +++ b/src/lpe/mod.rs @@ -63,6 +63,7 @@ impl Lpe { bytes, special, unk, + .. 
} = vocab; let CompressedVocab { vocabs, slices } = if map_utf8 { @@ -164,8 +165,8 @@ impl Method for Lpe { tokens } #[inline] - fn decode(&self, token: utok) -> &[u8] { - self.token(token) + fn decode(&self, token: utok) -> Cow<'_, [u8]> { + std::borrow::Cow::Borrowed(self.token(token)) } } diff --git a/src/tokeneer.rs b/src/tokeneer.rs index 4337e2f..e7d8fa7 100644 --- a/src/tokeneer.rs +++ b/src/tokeneer.rs @@ -61,7 +61,7 @@ impl Tokeneer { pub fn decode(&self, tokens: &[utok]) -> String { let mut ans = Vec::new(); for &t in tokens { - ans.extend_from_slice(self.method.decode(t)) + ans.extend_from_slice(&self.method.decode(t)) } String::from_utf8(ans).unwrap() } diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000..2e8ba6f --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,99 @@ +use std::{collections::HashMap, sync::OnceLock}; +static BYTE_TO_UTF8: OnceLock> = OnceLock::new(); +static UTF8_TO_BYTE: OnceLock> = OnceLock::new(); +/// 创建一个从字节到 UTF-8 字符串的映射,主要用于gpt2 +fn unicode_byte_to_utf8_map() -> HashMap { + let mut map = HashMap::new(); + + // 映射 ASCII 可打印字符 '!' 到 '~' + for ch in 0x21..=0x7E { + map.insert(ch as u8, char::from_u32(ch).unwrap()); + } + + // 映射拉丁字符 '¡' 到 '¬' + for ch in 0xA1..=0xAC { + map.insert(ch as u8, char::from_u32(ch).unwrap()); + } + + // 映射拉丁字符 '®' 到 'ÿ' + for ch in 0xAE..=0xFF { + map.insert(ch as u8, char::from_u32(ch).unwrap()); + } + + // 为剩余的字节值分配映射 + let mut n = 0; + for ch in 0..256 { + if let std::collections::hash_map::Entry::Vacant(e) = map.entry(ch as u8) { + e.insert(char::from_u32(256 + n).unwrap()); + n += 1; + } + } + map +} +/// 创建一个从字节到 UTF-8 字符串的映射,主要用于gpt2 +fn unicode_utf8_to_byte_map() -> HashMap { + let mut map = HashMap::new(); + + // 映射 ASCII 可打印字符 '!' 
到 '~' + for ch in 0x21..=0x7E { + map.insert(char::from_u32(ch).unwrap(), ch as u8); + } + + // 映射拉丁字符 '¡' 到 '¬' + for ch in 0xA1..=0xAC { + map.insert(char::from_u32(ch).unwrap(), ch as u8); + } + + // 映射拉丁字符 '®' 到 'ÿ' + for ch in 0xAE..=0xFF { + map.insert(char::from_u32(ch).unwrap(), ch as u8); + } + + // 为剩余的字节值分配映射 + let mut n = 0; + for ch in 0..256 { + if !map.contains_key(&char::from_u32(ch).unwrap()) { + map.insert(char::from_u32(256 + n).unwrap(), ch as u8); + n += 1; + } + } + map +} + +pub fn unicode_byte_to_utf8(byte: u8) -> char { + *BYTE_TO_UTF8 + .get_or_init(unicode_byte_to_utf8_map) + .get(&byte) + .unwrap() +} +pub fn unicode_utf8_to_byte(utf8: char) -> u8 { + *UTF8_TO_BYTE + .get_or_init(unicode_utf8_to_byte_map) + .get(&utf8) + .unwrap() +} + +pub fn llama_decode_text(text: &str) -> String { + let bytes: Vec = text.chars().map(unicode_utf8_to_byte).collect(); + + String::from_utf8_lossy(&bytes).to_string() +} +#[cfg(test)] +mod test_tokoneer { + use crate::utils::{llama_decode_text, unicode_byte_to_utf8}; + + #[test] + fn bpe_from_gguf() { + let s = String::from("你好"); + let p: String = s + .into_bytes() + .iter() + .map(|s| unicode_byte_to_utf8(*s)) + .collect(); + print!("dsf {:?}", p); + assert!(p == "ä½łå¥½"); + let a = llama_decode_text("Ġthere"); + println!("dsf {:?}", a); + // println!("ds {:?}",b.get(&'▁').unwrap()); + } +} diff --git a/src/vocab.rs b/src/vocab.rs index 7415387..6745b87 100644 --- a/src/vocab.rs +++ b/src/vocab.rs @@ -1,8 +1,14 @@ -//! 这个模块提供对词表的预处理功能,这些功能适用于多种不同算法的分词器。 +//! 
这个模块提供对词表的预处理功能,这些功能适用于多种不同算法的分词器。 use crate::utok; use log::trace; -use std::{iter::zip, pin::Pin, slice::from_ref, str::from_utf8_unchecked}; +use std::{ + collections::{HashSet, hash_set}, + iter::zip, + pin::Pin, + slice::from_ref, + str::from_utf8_unchecked, +}; /// 收集和预处理词表。 /// @@ -22,6 +28,9 @@ pub(crate) struct CollectedVocab<'s> { pub bytes: Box<[utok; 256]>, /// 特殊词汇 pub special: Box<[utok]>, + // 判断对话终止字符,预留的收集终止推理的token + #[allow(dead_code)] + pub eog: HashSet, /// 填充词 pub unk: utok, } @@ -47,6 +56,7 @@ impl<'s> CollectedVocab<'s> { let mut vocabs = Vec::new(); let mut special = Vec::new(); + let mut eog = hash_set::HashSet::new(); for (i, (piece, tt)) in zip(vocabs_, token_type).enumerate() { let piece = match tt { TokenType::Byte => { @@ -63,20 +73,20 @@ impl<'s> CollectedVocab<'s> { trace!("find {tt:?}: {} @ {i}", unsafe { from_utf8_unchecked(piece) }); + if Self::is_eog(tt, piece) { + eog.insert(i as _); + } special.push(i as _); piece } - _ => { - let piece = match as_byte_token(piece) { - Some(b) => { - let b = b as usize; - bytes[b] = i as _; - from_ref(&BYTES[b]) - } - None => piece, - }; - piece - } + _ => match as_byte_token(piece) { + Some(b) => { + let b = b as usize; + bytes[b] = i as _; + from_ref(&BYTES[b]) + } + None => piece, + }, }; vocabs.push(piece); total_len += piece.len() @@ -86,9 +96,26 @@ impl<'s> CollectedVocab<'s> { total_len, bytes, special: special.into_boxed_slice(), + eog, unk, } } + pub fn is_eog(tt: TokenType, piece: &[u8]) -> bool { + match tt { + TokenType::Control => { + let key = unsafe { from_utf8_unchecked(piece) }; + key == "<|eot_id|>" + || key == "<|im_end|>" + || key == "<|end|>" + || key == "" + || key == "<|endoftext|>" + || key == "<|eom_id|>" + || key == "< EOT >" + || key == "_< EOT >" + } + _ => false, + } + } } /// 利用词表中的重复部分压缩词表。