3 changes: 3 additions & 0 deletions Cargo.toml
@@ -16,3 +16,6 @@ regex = "1.11"
memchr = "2.7"
patricia_tree = "0.9"
log = "0.4"
fancy-regex = "0.14.0"
ggus = "0.4"
memmap2 = "0.9"
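The three new dependencies map onto the new features: memmap2 memory-maps a GGUF file, ggus parses its metadata (the GGufMetaMapExt accessors that Bpe::from_gguf consumes), and fancy-regex drives the GPT-2 pre-tokenizer. A minimal loading sketch, assuming the crate is published as tokeneer, that ggus exposes a GGuf::new(&[u8]) parser, and a placeholder model path:

```rust
use memmap2::Mmap;
use std::fs::File;
use tokeneer::{Bpe, Tokeneer};

fn load_tokenizer(path: &str) -> Tokeneer<Bpe> {
    let file = File::open(path).expect("cannot open gguf file");
    // Safety: the mapping is only valid while the file is not
    // truncated or modified by another process.
    let data = unsafe { Mmap::map(&file).expect("mmap failed") };
    // `GGuf::new` is assumed to be the ggus entry point for parsing
    // an in-memory GGUF image; check the ggus docs for the exact API.
    let gguf = ggus::GGuf::new(&data).expect("failed to parse gguf");
    Bpe::from_gguf(&gguf)
}
```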
191 changes: 173 additions & 18 deletions src/bpe/mod.rs
@@ -2,12 +2,22 @@

mod algorithm;

use ggus::{GGmlTokenType, GGufMetaMapExt};
use log::warn;

use crate::{
Method, utok,
Method, Tokeneer,
common::TOKENIZER_PRE_QWEN,
utils::{llama_decode_text, unicode_byte_to_utf8},
utok,
vocab::{CollectedVocab, CompressedVocab, TokenType},
};
use std::{collections::HashSet, iter::zip, ops::Deref, pin::Pin, ptr::NonNull};

use std::{borrow::Cow, collections::HashSet, iter::zip, ops::Deref, pin::Pin, ptr::NonNull};
// Used only for tokenization; whether to stop is left to the Tokeneer. Added for GPT-2 support.
pub enum Model {
GPT2(fancy_regex::Regex),
LLaMa,
}
pub struct Bpe {
/// Stores the string contents of all tokens; kept as raw u8 so no alignment is needed and space usage stays low
_vocabs: Pin<Box<[u8]>>,
@@ -22,6 +32,7 @@
special: Box<[utok]>,
/// token: <unk>
unk: utok,
modeltype: Model,
}

struct TokenMeta {
@@ -46,7 +57,7 @@
}

impl Bpe {
/// Parses a tokenizer.model file and builds a BPE tokenizer.
/// Parses a tokenizer.model file and builds a BPE tokenizer. GPT-2 style tokenizer models are not supported here yet.
pub fn from_tokenizer_model(model: &[u8]) -> Self {
// Walk the file and record the position of every vocabulary entry
let offsets = (0..)
@@ -80,6 +91,7 @@
0,
),
scores,
Model::LLaMa,
)
}

@@ -88,20 +100,27 @@
scores: impl IntoIterator<Item = f32>,
token_type: impl IntoIterator<Item = TokenType>,
unk: utok,
modeltype: Model,
) -> Self {
Self::from_collected_vocab(
CollectedVocab::collect(vocabs.into_iter().map(|s| s.as_bytes()), token_type, unk),
scores,
modeltype,
)
}

fn from_collected_vocab(vocab: CollectedVocab, scores: impl IntoIterator<Item = f32>) -> Self {
fn from_collected_vocab(
vocab: CollectedVocab,
scores: impl IntoIterator<Item = f32>,
modeltype: Model,
) -> Self {
let CollectedVocab {
vocabs,
total_len,
bytes,
special,
unk,
..
} = vocab;
let CompressedVocab { vocabs, slices } = CompressedVocab::new(&vocabs, total_len);
// Collect the token scores
@@ -140,9 +159,16 @@
bytes,
special,
unk,
modeltype,
};
let inaccessible = ans.inaccessible();
ans.special = ans.special.into_iter().chain(inaccessible).collect();
match ans.modeltype {
// In GPT-2 this would treat normal tokens as special ones (e.g. 1052 Ġthere) and break decoding
Model::GPT2(_) => {}
Model::LLaMa => {
let inaccessible = ans.inaccessible();
ans.special = ans.special.into_iter().chain(inaccessible).collect();
}
}
ans
}

@@ -181,6 +207,68 @@
fn token(&self, token: utok) -> &TokenMeta {
&self.tokens[token as usize]
}
pub fn from_gguf<T: GGufMetaMapExt>(gguf: &T) -> Tokeneer<Bpe> {
let tokens = gguf.tokenizer_ggml_tokens().unwrap();
let vocab_type = gguf.tokenizer_ggml_model().unwrap();
let scores = gguf.tokenizer_ggml_scores();
let token_type = gguf.tokenizer_ggml_token_type().unwrap();
let token_type = token_type.map(|ty| {
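// Assumption: every token-type value in the GGUF is a valid GGmlTokenType discriminant; transmuting an out-of-range i32 into the enum would be undefined behavior.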
match unsafe { std::mem::transmute::<i32, ggus::GGmlTokenType>(ty.unwrap()) } {
GGmlTokenType::Normal => TokenType::Normal,
GGmlTokenType::Unknown => TokenType::Unknown,
GGmlTokenType::Control => TokenType::Control,
GGmlTokenType::User => TokenType::UserDefined,
GGmlTokenType::Unused => TokenType::Normal,
GGmlTokenType::Byte => TokenType::Byte,
}
});
let token_len = tokens.len();
let vocabs = tokens.map(|piece| piece.unwrap());
// Some models may lack this field
let unk = match gguf.tokenizer_ggml_unknown_token_id() {
Ok(id) => id,
Err(_) => {
warn!("tokenizer_ggml_unknown_token_id not found");
u32::MAX
}
};
match vocab_type {
"llama" => {
let scores = scores.unwrap();
assert_eq!(token_len, scores.len());
let scores = scores.map(|s| s.unwrap());
Tokeneer::new(Bpe::new(vocabs, scores, token_type, unk, Model::LLaMa))
}
"gpt2" => {
let pre_type = gguf.get_str("tokenizer.ggml.pre").unwrap();
let regex_str = match pre_type {
"qwen2" | "deepseek-r1-qwen" => TOKENIZER_PRE_QWEN,
_ => unimplemented!("unsupported pre_type {}", pre_type),
};
match scores {
Ok(scores) => {
assert_eq!(token_len, scores.len());
let scores = scores.map(|s| s.unwrap());
Tokeneer::new(Bpe::new(
vocabs,
scores,
token_type,
unk,
Model::GPT2(fancy_regex::Regex::new(regex_str).unwrap()),
))
}
Err(_) => Tokeneer::new(Bpe::new(
vocabs,
std::iter::repeat(0.0).take(token_len),
Check failure (Code scanning / clippy): this repeat().take() can be written more concisely
token_type,
unk,
Model::GPT2(fancy_regex::Regex::new(regex_str).unwrap()),
)),
}
}
_ => unreachable!("unsupported model"),
}
}
}

impl Method for Bpe {
@@ -201,13 +289,79 @@
}
#[inline]
fn encode(&self, text: &str) -> impl IntoIterator<Item = utok> + '_ {
let mut tokenizer = self.begin_merge(text);
while tokenizer.merge() {}
tokenizer.into_iter()
let text = self.pre_encode(text);
let mut vocab = Vec::new();
match &self.modeltype {
Model::GPT2(regex) => {
// Split the text with the regular expression
let mut result = Vec::new();
let mut last_end = 0;
// Byte-encode each regex match and append it to the list
let mut push_result = |s: &str| {
let r = s.bytes().map(unicode_byte_to_utf8).collect::<String>();
result.push(r);
};
for cap in regex.captures_iter(&text).flatten() {
if let Some(m) = cap.get(0) {
// If there is unmatched text before this match, add it to the result
if m.start() > last_end {
push_result(&text[last_end..m.start()]);
}
// Add the matched text
push_result(&text[m.start()..m.end()]);
last_end = m.end();
}
}

// Add the trailing unmatched text
if last_end < text.len() {
push_result(&text[last_end..]);
}
for r in result {
let mut tokenizer = self.begin_merge(&r);
while tokenizer.merge() {}
vocab.extend(tokenizer.into_iter());
}
}
Model::LLaMa => {
let mut tokenizer = self.begin_merge(&text);
while tokenizer.merge() {}
vocab.extend(tokenizer);
}
}
vocab.into_iter()
}
#[inline]
fn decode(&self, token: utok) -> &[u8] {
self.token(token)
fn decode(&self, token: utok) -> Cow<'_, [u8]> {
match &self.modeltype {
Model::GPT2(_) => match self.special.contains(&token) {
true => {
// Special tokens are returned as-is
std::borrow::Cow::Borrowed(self.token(token))
}
false => llama_decode_text(&String::from_utf8_lossy(self.token(token)))
.into_bytes()
.into(),
},

Model::LLaMa => {
let token_str = String::from_utf8_lossy(self.token(token));
let decoded_str = self.pre_decode(&token_str);
Cow::Owned(decoded_str.into_owned().into_bytes())
}
}
}
fn pre_encode<'s>(&self, text: &'s str) -> Cow<'s, str> {
match &self.modeltype {
Model::GPT2(_) => text.into(),
Model::LLaMa => text.replace(" ", "\u{2581}").into(),
}
}
fn pre_decode<'s>(&self, text: &'s str) -> Cow<'s, str> {
match &self.modeltype {
Model::GPT2(_) => text.into(),
Model::LLaMa => text.replace("\u{2581}", " ").into(),
}
}
}
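Byte-level BPE stores every vocab piece over printable characters only, which is why encode routes each regex match through unicode_byte_to_utf8 and decode undoes the mapping with llama_decode_text. Both helpers live in utils, outside this diff; a sketch of the standard GPT-2 byte-to-char table they are assumed to implement:

```rust
/// A sketch of the standard GPT-2 `bytes_to_unicode` mapping (an
/// assumption about what `unicode_byte_to_utf8` does): printable bytes
/// map to themselves, all others shift into codepoints from U+0100 up.
fn byte_to_char(b: u8) -> char {
    const KEEP: [(u8, u8); 3] = [(b'!', b'~'), (0xA1, 0xAC), (0xAE, 0xFF)];
    let printable = |x: u8| KEEP.iter().any(|&(lo, hi)| (lo..=hi).contains(&x));
    if printable(b) {
        b as char
    } else {
        // Count the non-printable bytes below `b`: space (0x20) has 32
        // of them, so it maps to 256 + 32 = U+0120, rendered as 'Ġ'.
        let n = (0..b).filter(|&x| !printable(x)).count() as u32;
        char::from_u32(256 + n).unwrap()
    }
}
```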

@@ -291,6 +445,7 @@
],
[TokenType::Normal; 10],
0,
Model::LLaMa,
)
}

@@ -316,10 +471,10 @@
#[test]
fn test_bpe_decode() {
let bpe = test_bpe();
assert_eq!(bpe.decode(3), b"c");
assert_eq!(bpe.decode(6), b"ac");
assert_eq!(bpe.decode(9), b"bcd");
assert_eq!(bpe.decode(0), b"<unk>");
assert_eq!(&*bpe.decode(3), b"c");
assert_eq!(&*bpe.decode(6), b"ac");
assert_eq!(&*bpe.decode(9), b"bcd");
assert_eq!(&*bpe.decode(0), b"<unk>");
}

#[test]
Expand All @@ -332,7 +487,7 @@

let decoded: Vec<_> = encoded
.iter()
.flat_map(|&t| bpe.decode(t).iter().copied())
.flat_map(|&t| bpe.decode(t).iter().copied().collect::<Vec<_>>())
.collect();
assert_eq!(std::str::from_utf8(&decoded), Ok("abcd<unk>"))
}
@@ -376,7 +531,7 @@
TokenType::Byte,
TokenType::Byte,
];
let bpe = Bpe::new(vocabs, scores, token_type, 0);
let bpe = Bpe::new(vocabs, scores, token_type, 0, Model::LLaMa);

let encoded: Vec<_> = bpe.encode("aAB").into_iter().collect();
assert_eq!(encoded, [0, 2, 3], "Expected 3 tokens for input 'aAB'")
1 change: 1 addition & 0 deletions src/common/mod.rs
@@ -0,0 +1 @@
pub static TOKENIZER_PRE_QWEN: &str = "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+";
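This pre-tokenizer splits contractions, letter runs, single digits, punctuation, and whitespace before BPE merging. A sanity-check sketch, assuming the pattern as reconstructed above (the exact splits should be verified against llama.cpp's qwen2 pre-tokenizer):

```rust
#[test]
fn qwen_pre_split() {
    let re = fancy_regex::Regex::new(TOKENIZER_PRE_QWEN).unwrap();
    let pieces: Vec<&str> = re
        .find_iter("I'm fine")
        .map(|m| m.unwrap().as_str())
        .collect();
    // Contractions split off, and a leading space sticks to its word.
    assert_eq!(pieces, ["I", "'m", " fine"]);
}
```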
15 changes: 12 additions & 3 deletions src/lib.rs
@@ -1,23 +1,32 @@
#![deny(warnings)]

mod bpe;
mod common;
mod lpe;
mod tokeneer;
mod utils;
mod vocab;
use std::borrow::Cow;

pub use bpe::Bpe;
pub use bpe::Model;
pub use lpe::Lpe;
pub use tokeneer::Tokeneer;
pub use vocab::TokenType;

/// `utok` for token id.
#[allow(non_camel_case_types)]
pub type utok = u32;

// Add a function for detecting stop tokens
pub trait Method {
fn unk_token(&self) -> utok;
fn vocab_size(&self) -> usize;
fn internal_special(&self) -> impl IntoIterator<Item = (&str, utok)>;
fn encode(&self, text: &str) -> impl IntoIterator<Item = utok> + '_;
fn decode(&self, token: utok) -> &[u8];
fn decode(&self, token: utok) -> Cow<[u8]>;
fn pre_encode<'s>(&self, text: &'s str) -> Cow<'s, str> {
text.into()
}
fn pre_decode<'s>(&self, text: &'s str) -> Cow<'s, str> {
text.into()
}
}
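Switching decode from &[u8] to Cow<[u8]> lets transforming decoders (the GPT-2 byte remapping above) return owned bytes while lookup-only decoders such as Lpe keep borrowing. A minimal sketch of an implementer under the new signature; DummyVocab and its contents are hypothetical, and the tokeneer crate name is assumed:

```rust
use std::borrow::Cow;
use tokeneer::{Method, utok};

struct DummyVocab {
    pieces: Vec<Vec<u8>>,
}

impl Method for DummyVocab {
    fn unk_token(&self) -> utok {
        0
    }
    fn vocab_size(&self) -> usize {
        self.pieces.len()
    }
    fn internal_special(&self) -> impl IntoIterator<Item = (&str, utok)> {
        std::iter::empty()
    }
    fn encode(&self, _text: &str) -> impl IntoIterator<Item = utok> + '_ {
        std::iter::empty()
    }
    fn decode(&self, token: utok) -> Cow<[u8]> {
        // Borrowed in the common case: no allocation unless a decoder
        // actually has to transform the bytes.
        Cow::Borrowed(self.pieces[token as usize].as_slice())
    }
}
```

pre_encode and pre_decode keep their identity defaults, so only tokenizers that transform text (like the LLaMA space-to-U+2581 mapping) need to override them.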
5 changes: 3 additions & 2 deletions src/lpe/mod.rs
@@ -63,6 +63,7 @@ impl Lpe {
bytes,
special,
unk,
..
} = vocab;

let CompressedVocab { vocabs, slices } = if map_utf8 {
Expand Down Expand Up @@ -164,8 +165,8 @@ impl Method for Lpe {
tokens
}
#[inline]
fn decode(&self, token: utok) -> &[u8] {
self.token(token)
fn decode(&self, token: utok) -> Cow<'_, [u8]> {
std::borrow::Cow::Borrowed(self.token(token))
}
}

2 changes: 1 addition & 1 deletion src/tokeneer.rs
@@ -61,7 +61,7 @@ impl<M: Method> Tokeneer<M> {
pub fn decode(&self, tokens: &[utok]) -> String {
let mut ans = Vec::new();
for &t in tokens {
ans.extend_from_slice(self.method.decode(t))
ans.extend_from_slice(&self.method.decode(t))
}
String::from_utf8(ans).unwrap()
}