From fb573887a29e6135bba726e34b0ae443a513e855 Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Wed, 22 May 2024 15:46:19 -0400
Subject: [PATCH] Add pretrained model/tokenizer download

---
 llama-burn/Cargo.toml           |  5 +-
 llama-burn/examples/generate.rs |  8 ++--
 llama-burn/src/lib.rs           |  1 +
 llama-burn/src/llama.rs         | 52 +++++++++++++++++++--
 llama-burn/src/pretrained.rs    | 82 +++++++++++++++++++++++++++++++++
 5 files changed, 136 insertions(+), 12 deletions(-)
 create mode 100644 llama-burn/src/pretrained.rs

diff --git a/llama-burn/Cargo.toml b/llama-burn/Cargo.toml
index 0d7a4ee..6d65632 100644
--- a/llama-burn/Cargo.toml
+++ b/llama-burn/Cargo.toml
@@ -11,8 +11,8 @@ default = ["std"]
 # TODO: remove std default
 std = []
 pretrained = ["burn/network", "std", "dep:dirs"]
-llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
-tiny = ["dep:rust_tokenizers"]
+llama3 = ["pretrained", "dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
+tiny = ["pretrained", "dep:rust_tokenizers"]
 
 [dependencies]
 # Note: default-features = false is needed to disable std
@@ -37,7 +37,6 @@ rustc-hash = {version = "1.1.0", optional = true }
 # SentencePiece tokenizer (tiny llama / llama 2)
 rust_tokenizers = { version = "8.1.1", optional = true }
 
-# Temporary; will be removed once we have the multinomial distribution
 rand = { version = "0.8.5", default-features = false, features = [
     "std_rng",
 ] } # std_rng is for no_std
diff --git a/llama-burn/examples/generate.rs b/llama-burn/examples/generate.rs
index f63c63f..1c166c5 100644
--- a/llama-burn/examples/generate.rs
+++ b/llama-burn/examples/generate.rs
@@ -93,8 +93,8 @@ pub fn main() {
 
     #[cfg(feature = "tiny")]
     {
-        let mut llama =
-            LlamaConfig::load_tiny_llama::<B>(&args.model, &args.tokenizer, &device).unwrap();
+        let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap();
+
         generate(
             &mut llama,
             &args.prompt,
@@ -106,8 +106,8 @@ pub fn main() {
 
     #[cfg(feature = "llama3")]
     {
-        let mut llama =
-            LlamaConfig::load_llama3_8b::<B>(&args.model, &args.tokenizer, &device).unwrap();
+        let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(&device).unwrap();
+
         generate(
             &mut llama,
             &args.prompt,
diff --git a/llama-burn/src/lib.rs b/llama-burn/src/lib.rs
index 3d2fd76..1f385e9 100644
--- a/llama-burn/src/lib.rs
+++ b/llama-burn/src/lib.rs
@@ -1,5 +1,6 @@
 pub(crate) mod cache;
 pub mod llama;
+pub mod pretrained;
 pub mod sampling;
 pub mod tokenizer;
 mod transformer;
diff --git a/llama-burn/src/llama.rs b/llama-burn/src/llama.rs
index 63846af..e1d15c9 100644
--- a/llama-burn/src/llama.rs
+++ b/llama-burn/src/llama.rs
@@ -17,13 +17,13 @@ use crate::{
     transformer::{KeyValueCache, Transformer, TransformerConfig},
 };
 
+#[cfg(feature = "pretrained")]
+use crate::pretrained::{self, ModelMeta};
 #[cfg(feature = "tiny")]
 use crate::tokenizer::SentiencePieceTokenizer;
 #[cfg(feature = "llama3")]
 use crate::tokenizer::Tiktoken;
 
-const LLAMA3_VOCAB_SIZE: usize = 128256;
-
 #[derive(Config, Debug)]
 pub struct LlamaConfig {
     /// The size of the model.
@@ -58,7 +58,7 @@ impl LlamaConfig {
     /// Llama-3-8B configuration.
     pub fn llama3_8b(tokenizer_path: &str) -> Self {
         // hidden_size = 14336; vocab_size = 128256
-        Self::new(14336, LLAMA3_VOCAB_SIZE, tokenizer_path.to_string())
+        Self::new(14336, 128256, tokenizer_path.to_string())
             .with_num_key_value_heads(Some(8))
             .with_rope_theta(500000.0)
     }
@@ -82,6 +82,27 @@ impl LlamaConfig {
         Ok(llama)
     }
 
+    /// Load pre-trained Llama-3-8B model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer.
+    #[cfg(feature = "llama3")]
+    pub fn llama3_8b_pretrained<B: Backend>(
+        device: &Device<B>,
+    ) -> Result<Llama<B, Tiktoken>, String> {
+        // Download checkpoint and tokenizer
+        let model = pretrained::Llama::Llama3.pretrained();
+        let checkpoint = model
+            .download_weights()
+            .map_err(|err| format!("Could not download weights.\nError: {err}"))?;
+        let tokenizer = model
+            .download_tokenizer()
+            .map_err(|err| format!("Could not download tokenizer.\nError: {err}"))?;
+
+        Self::load_llama3_8b(
+            checkpoint.to_str().unwrap(),
+            tokenizer.to_str().unwrap(),
+            device,
+        )
+    }
+
     /// TinyLlama-1.1B Chat v1.0 configuration.
     pub fn tiny_llama(tokenizer_path: &str) -> Self {
         // hidden_size = 5632; vocab_size = 32000
@@ -91,6 +112,7 @@ impl LlamaConfig {
             .with_num_key_value_heads(Some(4))
             .with_rope_theta(10000.0)
     }
+
     /// Load pre-trained TinyLlama-1.1B Chat v1.0 model with [SentencePiece](https://github.com/google/sentencepiece) tokenizer.
     #[cfg(feature = "tiny")]
     pub fn load_tiny_llama<B: Backend>(
@@ -110,6 +132,27 @@ impl LlamaConfig {
         Ok(llama)
     }
 
+    /// Load pre-trained TinyLlama-1.1B Chat v1.0 model with [SentencePiece](https://github.com/google/sentencepiece) tokenizer.
+    #[cfg(feature = "tiny")]
+    pub fn tiny_llama_pretrained<B: Backend>(
+        device: &Device<B>,
+    ) -> Result<Llama<B, SentiencePieceTokenizer>, String> {
+        // Download checkpoint and tokenizer
+        let model = pretrained::Llama::TinyLlama.pretrained();
+        let checkpoint = model
+            .download_weights()
+            .map_err(|err| format!("Could not download weights.\nError: {err}"))?;
+        let tokenizer = model
+            .download_tokenizer()
+            .map_err(|err| format!("Could not download tokenizer.\nError: {err}"))?;
+
+        Self::load_tiny_llama(
+            checkpoint.to_str().unwrap(),
+            tokenizer.to_str().unwrap(),
+            device,
+        )
+    }
+
     /// Initialize a new [Llama](Llama) module.
     pub fn init<B: Backend, T: Tokenizer>(
         &self,
@@ -160,7 +203,7 @@ impl LlamaConfig {
         // Load weights from torch state_dict
         let mut load_args = LoadArgs::new(checkpoint.into());
 
-        if self.vocab_size == LLAMA3_VOCAB_SIZE {
+        if !cfg!(feature = "tiny") {
             load_args = load_args
                 // Map layers.[i].feed_forward.w1.* -> layers.[i].feed_forward.swiglu.linear_inner.*
                 .with_key_remap(
@@ -175,7 +218,6 @@ impl LlamaConfig {
                 // Map norm.weight -> norm.gamma for all layers
                 .with_key_remap("(.*)norm\\.weight", "${1}norm.gamma");
         } else {
-            // We assume Tiny Llama when != LLAMA3_VOCAB_SIZE
             load_args = load_args
                 // Map lm_head.* -> output.*
                 .with_key_remap("lm_head\\.(.+)", "output.$1")
diff --git a/llama-burn/src/pretrained.rs b/llama-burn/src/pretrained.rs
new file mode 100644
index 0000000..75cc21e
--- /dev/null
+++ b/llama-burn/src/pretrained.rs
@@ -0,0 +1,82 @@
+/// Pre-trained model metadata.
+pub struct Pretrained {
+    pub(super) name: &'static str,
+    pub(super) model: &'static str,
+    pub(super) tokenizer: &'static str,
+}
+
+#[cfg(feature = "pretrained")]
+mod downloader {
+    use super::*;
+    use burn::data::network::downloader;
+    use std::fs::{create_dir_all, File};
+    use std::io::Write;
+    use std::path::PathBuf;
+
+    impl Pretrained {
+        /// Download the file to the local cache directory.
+        fn download(&self, url: &str, file: &str) -> Result<PathBuf, std::io::Error> {
+            // Model cache directory
+            let model_dir = dirs::home_dir()
+                .expect("Should be able to get home directory")
+                .join(".cache")
+                .join("llama-burn")
+                .join(self.name);
+
+            if !model_dir.exists() {
+                create_dir_all(&model_dir)?;
+            }
+
+            let file_name = model_dir.join(file);
+            if !file_name.exists() {
+                // Download file content
+                let bytes = downloader::download_file_as_bytes(url, file);
+
+                // Write content to file
+                let mut output_file = File::create(&file_name)?;
+                output_file.write_all(&bytes)?; // write_all is not OS limited (files over 2GB)
+            }
+
+            Ok(file_name)
+        }
+
+        /// Download the pre-trained model weights to the local cache directory.
+        pub fn download_weights(&self) -> Result<PathBuf, std::io::Error> {
+            self.download(self.model, "model.mpk")
+        }
+
+        /// Download the tokenizer to the local cache directory.
+        pub fn download_tokenizer(&self) -> Result<PathBuf, std::io::Error> {
+            self.download(self.tokenizer, "tokenizer.model")
+        }
+    }
+}
+
+pub trait ModelMeta {
+    fn pretrained(&self) -> Pretrained;
+}
+
+/// Llama pre-trained weights.
+pub enum Llama {
+    /// Llama-3-8B.
+    Llama3,
+    /// TinyLlama-1.1B.
+    TinyLlama,
+}
+
+impl ModelMeta for Llama {
+    fn pretrained(&self) -> Pretrained {
+        match self {
+            Self::Llama3 => Pretrained {
+                name: "Llama-3-8B",
+                model: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/model.mpk?download=true",
+                tokenizer: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/tokenizer.model?download=true",
+            },
+            Self::TinyLlama => Pretrained {
+                name: "TinyLlama-1.1B",
+                model: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/model.mpk?download=true",
+                tokenizer: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.model?download=true",
+            },
+        }
+    }
+}
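
Example usage: a minimal sketch of how the new entry points are intended to be
called, built with the `tiny` feature enabled. The Wgpu backend and its import
paths are assumptions for illustration; `tiny_llama_pretrained`, `ModelMeta`,
`pretrained()`, and `download_weights` are the APIs added by this patch. On
first run the files are downloaded to ~/.cache/llama-burn/TinyLlama-1.1B/
(model.mpk and tokenizer.model, per `Pretrained::name` and the download targets
above) and reused on subsequent runs.

    use burn::backend::{wgpu::WgpuDevice, Wgpu}; // assumed backend; any burn Backend should work
    use llama_burn::llama::LlamaConfig;
    use llama_burn::pretrained::{Llama, ModelMeta};

    fn main() -> Result<(), String> {
        let device = WgpuDevice::default();

        // High-level path: download the checkpoint and tokenizer if missing,
        // then load the model in one call (requires the `tiny` feature).
        let _llama = LlamaConfig::tiny_llama_pretrained::<Wgpu>(&device)?;

        // Lower-level path: fetch the cached artifacts directly.
        let model = Llama::TinyLlama.pretrained();
        let weights = model
            .download_weights()
            .map_err(|err| format!("Could not download weights.\nError: {err}"))?;
        println!("weights cached at {}", weights.display());

        Ok(())
    }

The same flow applies to Llama-3-8B through `llama3_8b_pretrained` with the
`llama3` feature enabled.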