Add pretrained model/tokenizer download
laggui committed May 22, 2024
1 parent ed1277c commit fb57388
Showing 5 changed files with 136 additions and 12 deletions.
5 changes: 2 additions & 3 deletions llama-burn/Cargo.toml
@@ -11,8 +11,8 @@ default = ["std"] # TODO: remove std default
 std = []
 pretrained = ["burn/network", "std", "dep:dirs"]
 
-llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
-tiny = ["dep:rust_tokenizers"]
+llama3 = ["pretrained", "dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
+tiny = ["pretrained", "dep:rust_tokenizers"]
 
 [dependencies]
 # Note: default-features = false is needed to disable std
@@ -37,7 +37,6 @@ rustc-hash = {version = "1.1.0", optional = true }
 # SentencePiece tokenizer (tiny llama / llama 2)
 rust_tokenizers = { version = "8.1.1", optional = true }
 
-# Temporary; will be removed once we have the multinomial distribution
 rand = { version = "0.8.5", default-features = false, features = [
     "std_rng",
 ] } # std_rng is for no_std
8 changes: 4 additions & 4 deletions llama-burn/examples/generate.rs
@@ -93,8 +93,8 @@ pub fn main() {
 
     #[cfg(feature = "tiny")]
     {
-        let mut llama =
-            LlamaConfig::load_tiny_llama::<B>(&args.model, &args.tokenizer, &device).unwrap();
+        let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap();
+
         generate(
             &mut llama,
             &args.prompt,
@@ -106,8 +106,8 @@
 
     #[cfg(feature = "llama3")]
     {
-        let mut llama =
-            LlamaConfig::load_llama3_8b::<B>(&args.model, &args.tokenizer, &device).unwrap();
+        let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(&device).unwrap();
+
         generate(
             &mut llama,
             &args.prompt,
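For reference, the file-based loaders the example no longer calls (`load_tiny_llama`, `load_llama3_8b`) remain public, so a locally stored checkpoint can still be used. A minimal sketch, assuming the `llama3` feature and an `NdArray` backend (the backend choice and the paths below are illustrative placeholders, not part of this commit):

```rust
use burn::backend::NdArray;
use llama_burn::llama::LlamaConfig;

fn main() {
    let device = Default::default();
    // Hypothetical local paths; any converted checkpoint/tokenizer pair works.
    let llama = LlamaConfig::load_llama3_8b::<NdArray>(
        "/path/to/model.mpk",
        "/path/to/tokenizer.model",
        &device,
    )
    .expect("local checkpoint should load");
    let _ = llama;
}
```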
1 change: 1 addition & 0 deletions llama-burn/src/lib.rs
@@ -1,5 +1,6 @@
 pub(crate) mod cache;
 pub mod llama;
+pub mod pretrained;
 pub mod sampling;
 pub mod tokenizer;
 mod transformer;
52 changes: 47 additions & 5 deletions llama-burn/src/llama.rs
@@ -17,13 +17,13 @@ use crate::{
     transformer::{KeyValueCache, Transformer, TransformerConfig},
 };
 
+#[cfg(feature = "pretrained")]
+use crate::pretrained::{self, ModelMeta};
 #[cfg(feature = "tiny")]
 use crate::tokenizer::SentiencePieceTokenizer;
 #[cfg(feature = "llama3")]
 use crate::tokenizer::Tiktoken;
 
-const LLAMA3_VOCAB_SIZE: usize = 128256;
-
 #[derive(Config, Debug)]
 pub struct LlamaConfig {
     /// The size of the model.
@@ -58,7 +58,7 @@ impl LlamaConfig {
     /// Llama-3-8B configuration.
     pub fn llama3_8b(tokenizer_path: &str) -> Self {
         // hidden_size = 14336; vocab_size = 128256
-        Self::new(14336, LLAMA3_VOCAB_SIZE, tokenizer_path.to_string())
+        Self::new(14336, 128256, tokenizer_path.to_string())
             .with_num_key_value_heads(Some(8))
             .with_rope_theta(500000.0)
     }
@@ -82,6 +82,27 @@ impl LlamaConfig {
         Ok(llama)
     }
 
+    /// Load pre-trained Llama-3-8B model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer.
+    #[cfg(feature = "llama3")]
+    pub fn llama3_8b_pretrained<B: Backend>(
+        device: &Device<B>,
+    ) -> Result<Llama<B, Tiktoken>, String> {
+        // Download checkpoint and tokenizer
+        let model = pretrained::Llama::Llama3.pretrained();
+        let checkpoint = model
+            .download_weights()
+            .map_err(|err| format!("Could not download weights.\nError: {err}"))?;
+        let tokenizer = model
+            .download_tokenizer()
+            .map_err(|err| format!("Could not download tokenizer.\nError: {err}"))?;
+
+        Self::load_llama3_8b(
+            checkpoint.to_str().unwrap(),
+            tokenizer.to_str().unwrap(),
+            device,
+        )
+    }
+
     /// TinyLlama-1.1B Chat v1.0 configuration.
     pub fn tiny_llama(tokenizer_path: &str) -> Self {
         // hidden_size = 5632; vocab_size = 32000
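As a usage sketch of the new entry point just added above: with the `llama3` feature enabled, the call below downloads the checkpoint and tokenizer on first use, then loads them from the local cache (the `NdArray` backend here is an illustrative assumption, not mandated by the commit):

```rust
use burn::backend::NdArray;
use llama_burn::llama::LlamaConfig;

fn main() {
    let device = Default::default();
    // First run downloads several GB of weights to ~/.cache/llama-burn/Llama-3-8B.
    let llama = LlamaConfig::llama3_8b_pretrained::<NdArray>(&device)
        .expect("pretrained Llama-3-8B should download and load");
    let _ = llama;
}
```

`tiny_llama_pretrained`, added further down, follows the same shape for TinyLlama-1.1B.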
@@ -91,6 +112,7 @@ impl LlamaConfig {
             .with_num_key_value_heads(Some(4))
             .with_rope_theta(10000.0)
     }
+
     /// Load pre-trained TinyLlama-1.1B Chat v1.0 model with [SentencePiece](https://github.com/google/sentencepiece) tokenizer.
     #[cfg(feature = "tiny")]
     pub fn load_tiny_llama<B: Backend>(
@@ -110,6 +132,27 @@ impl LlamaConfig {
         Ok(llama)
     }
 
+    /// Load pre-trained TinyLlama-1.1B Chat v1.0 model with [SentencePiece](https://github.com/google/sentencepiece) tokenizer.
+    #[cfg(feature = "tiny")]
+    pub fn tiny_llama_pretrained<B: Backend>(
+        device: &Device<B>,
+    ) -> Result<Llama<B, SentiencePieceTokenizer>, String> {
+        // Download checkpoint and tokenizer
+        let model = pretrained::Llama::TinyLlama.pretrained();
+        let checkpoint = model
+            .download_weights()
+            .map_err(|err| format!("Could not download weights.\nError: {err}"))?;
+        let tokenizer = model
+            .download_tokenizer()
+            .map_err(|err| format!("Could not download tokenizer.\nError: {err}"))?;
+
+        Self::load_tiny_llama(
+            checkpoint.to_str().unwrap(),
+            tokenizer.to_str().unwrap(),
+            device,
+        )
+    }
+
     /// Initialize a new [Llama](Llama) module.
     pub fn init<B: Backend, T: Tokenizer>(
         &self,
@@ -160,7 +203,7 @@ impl LlamaConfig {
         // Load weights from torch state_dict
         let mut load_args = LoadArgs::new(checkpoint.into());
 
-        if self.vocab_size == LLAMA3_VOCAB_SIZE {
+        if !cfg!(feature = "tiny") {
             load_args = load_args
                 // Map layers.[i].feed_forward.w1.* -> layers.[i].feed_forward.swiglu.linear_inner.*
                 .with_key_remap(
@@ -175,7 +218,6 @@ impl LlamaConfig {
                 // Map norm.weight -> norm.gamma for all layers
                 .with_key_remap("(.*)norm\\.weight", "${1}norm.gamma");
         } else {
-            // We assume Tiny Llama when != LLAMA3_VOCAB_SIZE
             load_args = load_args
                 // Map lm_head.* -> output.*
                 .with_key_remap("lm_head\\.(.+)", "output.$1")
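The key remapping above is a set of plain regex capture-group rewrites applied to tensor names from the state dict. A standalone sketch of what one rule does, using the `regex` crate directly (an illustration of the pattern, not the loader's internal code path):

```rust
use regex::Regex;

fn main() {
    // Same rule as the diff above: map `*norm.weight` -> `*norm.gamma`.
    let re = Regex::new(r"(.*)norm\.weight").unwrap();
    let renamed = re.replace("layers.0.attention_norm.weight", "${1}norm.gamma");
    assert_eq!(renamed, "layers.0.attention_norm.gamma");
}
```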
82 changes: 82 additions & 0 deletions llama-burn/src/pretrained.rs
@@ -0,0 +1,82 @@
+/// Pre-trained model metadata.
+pub struct Pretrained {
+    pub(super) name: &'static str,
+    pub(super) model: &'static str,
+    pub(super) tokenizer: &'static str,
+}
+
+#[cfg(feature = "pretrained")]
+mod downloader {
+    use super::*;
+    use burn::data::network::downloader;
+    use std::fs::{create_dir_all, File};
+    use std::io::Write;
+    use std::path::PathBuf;
+
+    impl Pretrained {
+        /// Download the file to the local cache directory.
+        fn download(&self, url: &str, file: &str) -> Result<PathBuf, std::io::Error> {
+            // Model cache directory
+            let model_dir = dirs::home_dir()
+                .expect("Should be able to get home directory")
+                .join(".cache")
+                .join("llama-burn")
+                .join(self.name);
+
+            if !model_dir.exists() {
+                create_dir_all(&model_dir)?;
+            }
+
+            let file_name = model_dir.join(file);
+            if !file_name.exists() {
+                // Download file content
+                let bytes = downloader::download_file_as_bytes(url, file);
+
+                // Write content to file
+                let mut output_file = File::create(&file_name)?;
+                output_file.write_all(&bytes)?; // write_all is not OS limited (files over 2GB)
+            }
+
+            Ok(file_name)
+        }
+
+        /// Download the pre-trained model weights to the local cache directory.
+        pub fn download_weights(&self) -> Result<PathBuf, std::io::Error> {
+            self.download(self.model, "model.mpk")
+        }
+
+        /// Download the tokenizer to the local cache directory.
+        pub fn download_tokenizer(&self) -> Result<PathBuf, std::io::Error> {
+            self.download(self.tokenizer, "tokenizer.model")
+        }
+    }
+}
+
+pub trait ModelMeta {
+    fn pretrained(&self) -> Pretrained;
+}
+
+/// Llama pre-trained weights.
+pub enum Llama {
+    /// Llama-3-8B.
+    Llama3,
+    /// TinyLlama-1.1B.
+    TinyLlama,
+}
+
+impl ModelMeta for Llama {
+    fn pretrained(&self) -> Pretrained {
+        match self {
+            Self::Llama3 => Pretrained {
+                name: "Llama-3-8B",
+                model: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/model.mpk?download=true",
+                tokenizer: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/tokenizer.model?download=true",
+            },
+            Self::TinyLlama => Pretrained {
+                name: "TinyLlama-1.1B",
+                model: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/model.mpk?download=true",
+                tokenizer: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.model?download=true",
+            },
+        }
+    }
+}
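Taken together, the new `pretrained` module exposes a small download-and-cache API that the `*_pretrained` constructors build on. A direct-usage sketch, assuming the `pretrained` feature is enabled:

```rust
use llama_burn::pretrained::{Llama, ModelMeta};

fn main() -> Result<(), std::io::Error> {
    let model = Llama::TinyLlama.pretrained();
    // Files are cached under ~/.cache/llama-burn/TinyLlama-1.1B and only
    // downloaded when missing.
    let weights = model.download_weights()?;
    let tokenizer = model.download_tokenizer()?;
    println!("weights: {}", weights.display());
    println!("tokenizer: {}", tokenizer.display());
    Ok(())
}
```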
