Fix TinyLlama query/key weights and add chat mode
laggui committed May 24, 2024
1 parent edbf3a6 commit 67740e3
Showing 5 changed files with 93 additions and 30 deletions.
4 changes: 2 additions & 2 deletions llama-burn/Cargo.toml
@@ -12,7 +12,7 @@ std = []
pretrained = ["burn/network", "std", "dep:dirs"]

llama3 = ["pretrained", "dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
tiny = ["pretrained", "dep:rust_tokenizers"]
tiny = ["pretrained", "dep:tokenizers"]

[dependencies]
# Note: default-features = false is needed to disable std
@@ -35,7 +35,7 @@ base64 = { version = "0.22.1", optional = true }
rustc-hash = {version = "1.1.0", optional = true }

# SentencePiece tokenizer (tiny llama / llama 2)
-rust_tokenizers = { version = "8.1.1", optional = true }
+tokenizers = { version = "0.19.1", default-features = false, features = ["onig"], optional = true }

rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
40 changes: 27 additions & 13 deletions llama-burn/examples/generate.rs
@@ -11,20 +11,11 @@ use llama_burn::{
tokenizer::Tokenizer,
};

const DEFAULT_PROMPT: &str = "I believe the meaning of life is";
const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?";

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Config {
-// TODO: download checkpoint from HF hub.
-/// Model checkpoint path (automatically downloaded from the web if not present).
-#[arg(short, long)]
-model: String,
-
-/// Tokenizer path.
-#[arg(short, long)]
-tokenizer: String,
-
/// Top-p probability threshold.
#[arg(long, default_value_t = 0.9)]
top_p: f64,
@@ -48,6 +39,10 @@ pub struct Config {
/// The input prompt.
#[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))]
prompt: String,

+/// Chat assistant mode.
+#[arg(short, long, default_value_t = cfg!(feature = "tiny"))]
+chat: bool,
}

pub fn generate<B: Backend, T: Tokenizer>(
@@ -57,7 +52,6 @@ pub fn generate<B: Backend, T: Tokenizer>(
temperature: f64,
sampler: &mut Sampler,
) {
println!("Processing prompt: {}", prompt);
let now = Instant::now();
let generated = llama.generate(prompt, sample_len, temperature, sampler);
let elapsed = now.elapsed().as_secs();
@@ -83,6 +77,7 @@ pub fn main() {
let args = Config::parse();

let device = LibTorchDevice::Cuda(0);
+let prompt = args.prompt;

// Sampling strategy
let mut sampler = if args.temperature > 0.0 {
@@ -94,10 +89,21 @@
#[cfg(feature = "tiny")]
{
let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap();
println!("Processing prompt: {}", prompt);

+let prompt = if args.chat {
+// Prompt formatting for chat model
+format!(
+"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
+)
+} else {
+// Prompt with BOS token
+format!("{}{prompt}", llama.tokenizer.bos())
+};

generate(
&mut llama,
-&args.prompt,
+&prompt,
args.sample_len,
args.temperature,
&mut sampler,
@@ -107,10 +113,18 @@
#[cfg(feature = "llama3")]
{
let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(&device).unwrap();
println!("Processing prompt: {}", prompt);

+let prompt = if args.chat {
+panic!("Llama-8B-Instruct is not available yet.");
+} else {
+// Prompt with BOS token
+format!("{}{prompt}", llama.tokenizer.bos())
+};

generate(
&mut llama,
-&args.prompt,
+&prompt,
args.sample_len,
args.temperature,
&mut sampler,
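
For reference, the chat branch above wraps the raw prompt in the Zephyr-style template used by TinyLlama-1.1B-Chat (role headers plus </s> end-of-sequence markers), while the non-chat branch only prepends the BOS token. A standalone sketch of the same formatting, with an illustrative prompt:

fn main() {
    let prompt = "How many helicopters can a human eat in one sitting?";
    // Zephyr-style chat template used by TinyLlama-1.1B-Chat:
    // each turn is introduced by a <|role|> header and terminated with </s>.
    let chat_prompt = format!(
        "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
    );
    println!("{chat_prompt}");
}
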
52 changes: 43 additions & 9 deletions llama-burn/src/llama.rs
@@ -14,7 +14,7 @@ use burn_import::pytorch::{LoadArgs, PyTorchFileRecorder};
use crate::{
sampling::Sampler,
tokenizer::Tokenizer,
-transformer::{KeyValueCache, Transformer, TransformerConfig},
+transformer::{KeyValueCache, Transformer, TransformerConfig, TransformerRecord},
};

#[cfg(feature = "pretrained")]
@@ -275,12 +275,47 @@ impl LlamaConfig {
}
println!("Loading record...");
let now = Instant::now();
-let record = PyTorchFileRecorder::<HalfPrecisionSettings>::new()
+let mut record: TransformerRecord<B> = PyTorchFileRecorder::<HalfPrecisionSettings>::new()
.load(load_args, device)
.map_err(|e| e.to_string())?;
let elapsed = now.elapsed().as_secs();
println!("Loaded in {}s", elapsed);

if cfg!(feature = "tiny") {
// TinyLlama weights from HuggingFace use a different rotary positional encoding
// which requires weight permutation:
// https://github.com/huggingface/transformers/issues/25199#issuecomment-1687720247
// https://github.com/jzhang38/TinyLlama/issues/24
let n_heads = self.num_attention_heads;
let n_kv_heads = self.num_key_value_heads.unwrap_or(n_heads);
let wk_dim = self.d_model * n_kv_heads / n_heads;
let permute = |w: Tensor<B, 2>, n_heads: usize, dim1: usize, dim2: usize| {
let w = w // [2048, 256]
.reshape([dim1, n_heads, 2, dim2 / n_heads / 2]) // [2048, 4, 2, 32]
.swap_dims(2, 3) // [2048, 4, 32, 2]
.reshape([dim1, dim2]);
w
};

record.layers = record
.layers
.into_iter()
.map(|mut layer| {
layer.attention.wq.weight = layer
.attention
.wq
.weight
.map(|w| permute(w, n_heads, self.d_model, self.d_model));
layer.attention.wk.weight = layer
.attention
.wk
.weight
.map(|w| permute(w, n_kv_heads, self.d_model, wk_dim));
layer
})
.collect::<Vec<_>>();
}

llama.model = llama.model.load_record(record);
println!("Llama record loaded");

@@ -301,14 +336,14 @@ pub struct GenerationOutput {
/// Meta Llama large language model and tokenizer.
pub struct Llama<B: Backend, T: Tokenizer> {
/// The tokenizer.
-tokenizer: T,
+pub tokenizer: T,
/// Llama decoder-only transformer.
-model: Transformer<B>,
+pub model: Transformer<B>,
/// Key-value cache for each transformer block.
-cache: Vec<KeyValueCache<B>>,
+pub cache: Vec<KeyValueCache<B>>,
/// Rotary positional encoding (RoPE).
-rope: RotaryEncoding<B>,
-device: Device<B>,
+pub rope: RotaryEncoding<B>,
+pub device: Device<B>,
}

impl<B: Backend, T: Tokenizer> Llama<B, T> {
@@ -346,7 +381,6 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
.slice([0..batch_size, seq_len - 1..seq_len])
.squeeze(1); // [batch_size=1, vocab_size]

-// TODO: naive sampling w/o cumsum tensor op to first test llama implementation correctness
if temperature > 0.0 {
next_token_logits = softmax(next_token_logits / temperature, 1);
};
@@ -383,7 +417,7 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {

/// Encode a string into a tensor of tokens.
fn tokenize(&self, text: &str) -> Tensor<B, 1, Int> {
-let tokens = self.tokenizer.encode(text, true, false);
+let tokens = self.tokenizer.encode(text, false, false);

let shape = Shape::new([tokens.len()]);
Tensor::<B, 1, Int>::from_data(Data::new(tokens, shape).convert(), &self.device)
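
The permute closure in load_pretrained above only reorders columns within each attention head: HuggingFace checkpoints store a head's channels as two contiguous halves (the rotate_half RoPE convention), while the interleaved (even, odd) pairing of the original Meta implementation is presumably what the RotaryEncoding here expects. A standalone sketch of the index mapping for one head (head_dim = 8, labels purely illustrative):

fn main() {
    let head_dim = 8;
    // One head's output channels in the HuggingFace "two halves" order: [a0..a3 | b0..b3].
    let hf = ["a0", "a1", "a2", "a3", "b0", "b1", "b2", "b3"];

    // reshape([.., 2, head_dim / 2]) followed by swap_dims(2, 3) sends the channel at
    // (half, j) to position 2 * j + half, i.e. it interleaves the two halves.
    let mut interleaved = [""; 8];
    for half in 0..2 {
        for j in 0..head_dim / 2 {
            interleaved[2 * j + half] = hf[half * (head_dim / 2) + j];
        }
    }

    assert_eq!(interleaved, ["a0", "b0", "a1", "b1", "a2", "b2", "a3", "b3"]);
    println!("{interleaved:?}");
}

For TinyLlama's wk this runs once per key/value head on a [2048, 256] weight, matching the shape comments in the diff; wq gets the same treatment per query head.
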
17 changes: 11 additions & 6 deletions llama-burn/src/pretrained.rs
@@ -15,7 +15,7 @@ mod downloader {

impl Pretrained {
/// Download the file to the local cache directory.
-fn download(&self, url: &str, file: &str) -> Result<PathBuf, std::io::Error> {
+fn download(&self, url: &str) -> Result<PathBuf, std::io::Error> {
// Model cache directory
let model_dir = dirs::home_dir()
.expect("Should be able to get home directory")
@@ -27,10 +27,15 @@
create_dir_all(&model_dir)?;
}

-let file_name = model_dir.join(file);
+let file_base_name = url
+.rsplit_once('/')
+.unwrap()
+.1
+.replace("?download=true", "");
+let file_name = model_dir.join(&file_base_name);
if !file_name.exists() {
// Download file content
-let bytes = downloader::download_file_as_bytes(url, file);
+let bytes = downloader::download_file_as_bytes(url, &file_base_name);

// Write content to file
let mut output_file = File::create(&file_name)?;
@@ -42,12 +47,12 @@

/// Download the pre-trained model weights to the local cache directory.
pub fn download_weights(&self) -> Result<PathBuf, std::io::Error> {
self.download(self.model, "model.mpk")
self.download(self.model)
}

/// Download the tokenizer to the local cache directory.
pub fn download_tokenizer(&self) -> Result<PathBuf, std::io::Error> {
-self.download(self.tokenizer, "tokenizer.model")
+self.download(self.tokenizer)
}
}
}
@@ -75,7 +80,7 @@ impl ModelMeta for Llama {
Self::TinyLlama => Pretrained {
name: "TinyLlama-1.1B",
model: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/model.mpk?download=true",
tokenizer: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.model?download=true",
tokenizer: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.json?download=true",
},
}
}
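
With this change the cached file name is derived from the download URL (so the tokenizer is stored as tokenizer.json instead of a hard-coded tokenizer.model), which is what lets the TinyLlama entry point at the new HuggingFace tokenizer file. A quick standalone check of the derivation:

fn main() {
    let url = "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.json?download=true";
    // Same logic as `download`: keep what follows the last '/' and drop the ?download=true suffix.
    let file_base_name = url.rsplit_once('/').unwrap().1.replace("?download=true", "");
    assert_eq!(file_base_name, "tokenizer.json");
    println!("{file_base_name}");
}
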
10 changes: 10 additions & 0 deletions llama-burn/src/tokenizer/base.rs
@@ -10,9 +10,19 @@ pub trait Tokenizer {
/// Decode a list of token identifiers into a string.
fn decode(&self, tokens: Vec<u32>) -> String;

+/// Beginning of sentence token.
+fn bos(&self) -> String {
+self.decode(vec![self.bos_id()])
+}
+
/// Beginning of sentence token identifier.
fn bos_id(&self) -> u32;

+/// End of sentence token.
+fn eos(&self) -> String {
+self.decode(vec![self.eos_id()])
+}
+
/// End of sentence token identifier.
fn eos_id(&self) -> u32;
}
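
bos() and eos() are default methods built on decode and the existing id getters, so every Tokenizer implementation picks them up without changes; generate.rs uses bos() to prefix non-chat prompts. A toy implementation (hypothetical — the trait is trimmed and the token ids/strings are illustrative, not the crate's real tokenizers) just to exercise the defaults:

trait Tokenizer {
    fn decode(&self, tokens: Vec<u32>) -> String;
    /// Beginning of sentence token.
    fn bos(&self) -> String {
        self.decode(vec![self.bos_id()])
    }
    fn bos_id(&self) -> u32;
    /// End of sentence token.
    fn eos(&self) -> String {
        self.decode(vec![self.eos_id()])
    }
    fn eos_id(&self) -> u32;
}

// Toy tokenizer that only knows the two special tokens (ids and strings are illustrative).
struct ToyTokenizer;

impl Tokenizer for ToyTokenizer {
    fn decode(&self, tokens: Vec<u32>) -> String {
        tokens
            .iter()
            .map(|&t| match t {
                1 => "<s>",
                2 => "</s>",
                _ => "?",
            })
            .collect()
    }
    fn bos_id(&self) -> u32 {
        1
    }
    fn eos_id(&self) -> u32 {
        2
    }
}

fn main() {
    let tok = ToyTokenizer;
    // Mirrors the non-chat branch of examples/generate.rs: prepend the BOS token.
    let prompt = format!("{}{}", tok.bos(), "I believe the meaning of life is");
    println!("{prompt}");
}
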
