
Commit

Use sentencepiece from hf tokenizers
laggui committed May 23, 2024
1 parent fb57388 commit edbf3a6
Showing 4 changed files with 32 additions and 37 deletions.

llama-burn/src/llama.rs (4 changes: 2 additions & 2 deletions)

@@ -332,7 +332,7 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
     ) -> GenerationOutput {
         let mut tokens = self.tokenize(prompt);
         let prompt_len = tokens.dims()[0];
-        let eos_token = self.tokenizer.eos_id();
+        let eos_token = self.tokenizer.eos_id() as i64;
 
         let mut num_tokens: usize = 0;
         let mut input_pos = Tensor::<B, 1, Int>::arange(0..tokens.dims()[0] as i64, &self.device);
@@ -368,7 +368,7 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
 
         let tokens = tokens.into_data().value[prompt_len..]
             .iter()
-            .map(|t| t.elem::<i64>())
+            .map(|t| t.elem::<u32>())
             .collect::<Vec<_>>();
 
         let generated = self.tokenizer.decode(tokens);
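
Not part of the commit, just context on the first hunk: `eos_id()` now returns `u32`, while the sampled token read back from the backend's `Int` tensor is an `i64`, hence the one-time `as i64` widening before the stop check. A minimal sketch of that comparison (names are illustrative, not the crate's API):

// Sketch only: widen the u32 EOS id once and compare it against i64 sampled ids.
fn reached_eos(sampled: &[i64], eos_id: u32) -> bool {
    let eos_token = eos_id as i64;
    sampled.iter().any(|&t| t == eos_token)
}

fn main() {
    // Hypothetical token ids; 2 stands in for the SentencePiece EOS id.
    assert!(reached_eos(&[15043, 3186, 2], 2));
    assert!(!reached_eos(&[15043, 3186], 2));
}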

llama-burn/src/tokenizer/base.rs (8 changes: 4 additions & 4 deletions)

@@ -5,14 +5,14 @@ pub trait Tokenizer {
         Self: Sized;
 
     /// Encode a string into a list of token identifiers.
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<i64>;
+    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<u32>;
 
     /// Decode a list of token identifiers into a string.
-    fn decode(&self, tokens: Vec<i64>) -> String;
+    fn decode(&self, tokens: Vec<u32>) -> String;
 
     /// Beginning of sentence token identifier.
-    fn bos_id(&self) -> i64;
+    fn bos_id(&self) -> u32;
 
     /// End of sentence token identifier.
-    fn eos_id(&self) -> i64;
+    fn eos_id(&self) -> u32;
 }
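
An aside (not from the diff): with this change token ids stay `u32` from `encode` through `decode`, so implementors no longer cast to `i64` at the trait boundary. A small generic helper over the trait, assuming it is exported as `llama_burn::tokenizer::Tokenizer` and using a made-up function name:

use llama_burn::tokenizer::Tokenizer;

/// Encode with BOS (no EOS), sanity-check the leading id, and decode back.
fn round_trip<T: Tokenizer>(tok: &T, text: &str) -> String {
    let ids: Vec<u32> = tok.encode(text, true, false);
    debug_assert_eq!(ids.first().copied(), Some(tok.bos_id()));
    tok.decode(ids)
}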

llama-burn/src/tokenizer/sentence_piece.rs (43 changes: 19 additions & 24 deletions)

@@ -1,23 +1,20 @@
-use rust_tokenizers::tokenizer::{
-    SentencePieceBpeTokenizer, Tokenizer as BaseTokenizer, TruncationStrategy,
-};
+use tokenizers::Tokenizer as BaseTokenizer;
 
 use super::Tokenizer;
 
-const BOS_TOKEN_ID: i64 = 1;
-const EOS_TOKEN_ID: i64 = 2;
+const BOS_TOKEN_ID: u32 = 1;
+const EOS_TOKEN_ID: u32 = 2;
 
 pub struct SentiencePieceTokenizer {
-    bpe: SentencePieceBpeTokenizer,
-    bos_token_id: i64,
-    eos_token_id: i64,
+    bpe: BaseTokenizer,
+    bos_token_id: u32,
+    eos_token_id: u32,
 }
 
 impl Tokenizer for SentiencePieceTokenizer {
     /// Load the [SentenciePiece](https://github.com/google/sentencepiece) tokenizer.
     fn new(tokenizer_path: &str) -> Result<Self, String> {
-        let bpe = SentencePieceBpeTokenizer::from_file(tokenizer_path, false)
-            .map_err(|e| e.to_string())?;
+        let bpe = BaseTokenizer::from_file(tokenizer_path).map_err(|e| e.to_string())?;
 
         Ok(Self {
             bpe,
@@ -26,34 +23,32 @@ impl Tokenizer for SentiencePieceTokenizer {
         })
     }
 
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<i64> {
+    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<u32> {
         let bos_token = if bos { vec![self.bos_token_id] } else { vec![] };
         let eos_token = if eos { vec![self.eos_token_id] } else { vec![] };
 
-        // No text combination
-        let tokens = self
-            .bpe
-            .encode(text, None, usize::MAX, &TruncationStrategy::LongestFirst, 0)
-            .token_ids;
+        let tokens = self.bpe.encode(text, false).unwrap().get_ids().to_vec();
 
         [bos_token, tokens, eos_token]
             .into_iter()
             .flat_map(|t| t.into_iter())
-            .map(|t| t as i64)
             .collect()
     }
 
-    fn decode(&self, tokens: Vec<i64>) -> String {
+    fn decode(&self, tokens: Vec<u32>) -> String {
         self.bpe
-            .decode(&tokens, true, false)
-            .replace("<0x0A>", "\n")
+            .decode(
+                &tokens.into_iter().map(|t| t as u32).collect::<Vec<u32>>(),
+                true,
+            )
+            .unwrap()
     }
 
-    fn bos_id(&self) -> i64 {
-        self.bos_token_id as i64
+    fn bos_id(&self) -> u32 {
+        self.bos_token_id
     }
 
-    fn eos_id(&self) -> i64 {
-        self.eos_token_id as i64
+    fn eos_id(&self) -> u32 {
+        self.eos_token_id
     }
 }
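
For reference (assumptions: a local `tokenizer.json` and a `tokenizers` crate version whose `decode` takes `&[u32]`, as the call above implies), the hf API this file now relies on boils down to:

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load the serialized tokenizer; the path is illustrative.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;

    // Encode without special tokens; BOS/EOS are prepended/appended by hand above.
    let ids: Vec<u32> = tokenizer.encode("Hello, Llama!", false)?.get_ids().to_vec();

    // Decode back to text, skipping special tokens.
    let text = tokenizer.decode(&ids, true)?;
    println!("{ids:?} -> {text}");
    Ok(())
}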

llama-burn/src/tokenizer/tiktoken.rs (14 changes: 7 additions & 7 deletions)

@@ -85,7 +85,7 @@ impl Tokenizer for Tiktoken {
         })
     }
 
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<i64> {
+    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<u32> {
         let bos_token = if bos { vec![self.bos_token_id] } else { vec![] };
         let eos_token = if eos { vec![self.eos_token_id] } else { vec![] };
 
@@ -95,21 +95,21 @@ impl Tokenizer for Tiktoken {
         [bos_token, tokens, eos_token]
             .into_iter()
             .flat_map(|t| t.into_iter())
-            .map(|t| t as i64)
+            .map(|t| t as u32)
             .collect()
     }
 
-    fn decode(&self, tokens: Vec<i64>) -> String {
+    fn decode(&self, tokens: Vec<u32>) -> String {
         self.bpe
             .decode(tokens.into_iter().map(|t| t as usize).collect())
             .expect("Should decode tokens")
     }
 
-    fn bos_id(&self) -> i64 {
-        self.bos_token_id as i64
+    fn bos_id(&self) -> u32 {
+        self.bos_token_id as u32
     }
 
-    fn eos_id(&self) -> i64 {
-        self.eos_token_id as i64
+    fn eos_id(&self) -> u32 {
+        self.eos_token_id as u32
     }
 }
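
Side note, not from the commit: the tiktoken-style BPE above still decodes from `usize` ranks, so the `u32` trait ids are widened at that boundary and narrowed back to `u32` after encoding. A trivial sketch of those casts:

// Illustrative only: mirror the u32 <-> usize casts at the BPE boundary.
fn ids_to_ranks(ids: &[u32]) -> Vec<usize> {
    ids.iter().map(|&t| t as usize).collect()
}

fn ranks_to_ids(ranks: &[usize]) -> Vec<u32> {
    ranks.iter().map(|&r| r as u32).collect()
}

fn main() {
    let ids: Vec<u32> = vec![128000, 9906, 128001];
    assert_eq!(ranks_to_ids(&ids_to_ranks(&ids)), ids);
}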
