-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add pretrained model/tokenizer download
- Loading branch information
Showing
5 changed files
with
136 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
pub(crate) mod cache; | ||
pub mod llama; | ||
pub mod pretrained; | ||
pub mod sampling; | ||
pub mod tokenizer; | ||
mod transformer; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/// Pre-trained model metadata. | ||
pub struct Pretrained { | ||
pub(super) name: &'static str, | ||
pub(super) model: &'static str, | ||
pub(super) tokenizer: &'static str, | ||
} | ||
|
||
#[cfg(feature = "pretrained")]
mod downloader {
    use super::*;
    use burn::data::network::downloader;
    use std::fs::{create_dir_all, File};
    use std::io::Write;
    use std::path::PathBuf;

    impl Pretrained {
        /// Download `url` into the local cache directory and return the path
        /// to the cached file.
        ///
        /// Files are cached under `~/.cache/llama-burn/<name>/<file>`; a file
        /// that already exists is NOT re-downloaded.
        ///
        /// # Errors
        ///
        /// Returns an error if the home directory cannot be determined, the
        /// cache directory cannot be created, or the file cannot be written.
        fn download(&self, url: &str, file: &str) -> Result<PathBuf, std::io::Error> {
            // Per-model cache directory: ~/.cache/llama-burn/<name>
            let model_dir = dirs::home_dir()
                // Surface a recoverable error instead of panicking: library
                // code should not abort the caller when no home dir exists.
                .ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::NotFound,
                        "could not determine the home directory",
                    )
                })?
                .join(".cache")
                .join("llama-burn")
                .join(self.name);

            if !model_dir.exists() {
                create_dir_all(&model_dir)?;
            }

            let file_name = model_dir.join(file);
            if !file_name.exists() {
                // Download file content
                let bytes = downloader::download_file_as_bytes(url, file);

                // Write content to file
                let mut output_file = File::create(&file_name)?;
                output_file.write_all(&bytes)?; // write_all is not OS limited (files over 2GB)
            }

            Ok(file_name)
        }

        /// Download the pre-trained model weights to the local cache directory.
        ///
        /// # Errors
        ///
        /// See [`Pretrained::download`].
        pub fn download_weights(&self) -> Result<PathBuf, std::io::Error> {
            self.download(self.model, "model.mpk")
        }

        /// Download the tokenizer to the local cache directory.
        ///
        /// # Errors
        ///
        /// See [`Pretrained::download`].
        pub fn download_tokenizer(&self) -> Result<PathBuf, std::io::Error> {
            self.download(self.tokenizer, "tokenizer.model")
        }
    }
}
pub trait ModelMeta { | ||
fn pretrained(&self) -> Pretrained; | ||
} | ||
|
||
/// Llama pre-trained weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Llama {
    /// Llama-3-8B.
    Llama3,
    /// TinyLlama-1.1B.
    TinyLlama,
}
impl ModelMeta for Llama { | ||
fn pretrained(&self) -> Pretrained { | ||
match self { | ||
Self::Llama3 => Pretrained { | ||
name: "Llama-3-8B", | ||
model: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/model.mpk?download=true", | ||
tokenizer: "https://huggingface.co/tracel-ai/llama-3-8b-burn/resolve/main/tokenizer.model?download=true", | ||
}, | ||
Self::TinyLlama => Pretrained { | ||
name: "TinyLlama-1.1B", | ||
model: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/model.mpk?download=true", | ||
tokenizer: "https://huggingface.co/tracel-ai/tiny-llama-1.1b-burn/resolve/main/tokenizer.model?download=true", | ||
}, | ||
} | ||
} | ||
} |