
Llama #35

Merged · 34 commits · Aug 12, 2024
Commits
ac9a032
Add Llama-3-8b
laggui May 3, 2024
d15e363
Fix next_token_logits softmax dim
laggui May 7, 2024
3fa0fbc
Fix RoPE dim and mask cache_seq_len
laggui May 7, 2024
2500ef0
Move RotaryEncoding outside of Transformer block so pytorch checkpoin…
laggui May 9, 2024
78fbc2e
Fix sampling sort dim
laggui May 10, 2024
f8a934c
Fix attention mask
laggui May 10, 2024
f958603
Add load/save model
laggui May 10, 2024
ffab4ae
Add tiny llama w/ sentencepiece tokenizer
laggui May 14, 2024
4ddf2bc
Add tiny llama checkpoint loading w/ keys remap
laggui May 14, 2024
7623cf4
Fix top-p sampling index
laggui May 14, 2024
dff88b4
Add tiny llama and llama3 feature flags
laggui May 15, 2024
ed1277c
Add default prompt and replace sentencepiece newline tokens
laggui May 21, 2024
fb57388
Add pretrained model/tokenizer download
laggui May 22, 2024
edbf3a6
Use sentencepiece from hf tokenizers
laggui May 23, 2024
67740e3
Fix TinyLlama query/key weights and add chat mode
laggui May 24, 2024
9492e20
Update README and features deps
laggui May 24, 2024
59061d1
Switch to f16
laggui May 27, 2024
fe35d98
Use BinFileRecorder
laggui May 27, 2024
98be063
Fix tiktoken special tokens index offset (could lead to decoding error)
laggui May 28, 2024
456c262
Fix prompt format for different tokenizers
laggui May 28, 2024
86812fd
Add llama-3-8b-instruct
laggui May 28, 2024
9b61726
Add readme with chat example
laggui May 28, 2024
8c9d9e2
Add llama burn generated image
laggui May 28, 2024
7fa84c0
Add llama to repo README
laggui May 28, 2024
7946acd
Link to models root not readme file
laggui May 28, 2024
93f9ecc
Fix typo
laggui May 28, 2024
553f3dd
Pin burn git rev and add tch/wgpu feature flags
laggui May 29, 2024
46f91ee
Update readme and change rev to fix zip version
laggui May 29, 2024
7b3daae
Add bin weights note
laggui Jul 4, 2024
b5ddcc1
Update TensorData usage
laggui Jul 8, 2024
ee7bb49
Switch to NamedMpkFileRecorder
laggui Aug 6, 2024
86eb69f
Cleanup todo
laggui Aug 6, 2024
f72647b
Fix squeezenet half precision flag
laggui Aug 6, 2024
c3ec71e
Update burn pinned revision
laggui Aug 6, 2024
15 changes: 8 additions & 7 deletions README.md
@@ -5,13 +5,14 @@ examples constructed using the [Burn](https://github.com/burn-rs/burn) deep lear

## Collection of Official Models

| Model | Description | Repository Link |
|-------------------------------------------------|----------------------------------------------------------|------------------------------------------------|
| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/README.md) |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/README.md) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/README.md) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/README.md) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/README.md) |
| Model | Description | Repository Link |
|-------------------------------------------------|----------------------------------------------------------|---------------------------------------|
| [Llama](https://github.com/meta-llama/llama3) | Llama 3 and TinyLlama large language models. | [llama-burn](llama-burn/) |
| [MobileNetV2](https://arxiv.org/abs/1801.04381) | A CNN model targeted at mobile devices. | [mobilenetv2-burn](mobilenetv2-burn/) |
| [SqueezeNet](https://arxiv.org/abs/1602.07360) | A small CNN-based model for image classification. | [squeezenet-burn](squeezenet-burn/) |
| [ResNet](https://arxiv.org/abs/1512.03385) | A CNN based on residual blocks with skip connections. | [resnet-burn](resnet-burn/) |
| [RoBERTa](https://arxiv.org/abs/1907.11692) | A robustly optimized BERT pretraining approach. | [bert-burn](bert-burn/) |
| [YOLOX](https://arxiv.org/abs/2107.08430) | A single-stage object detector based on the YOLO series. | [yolox-burn](yolox-burn/) |

## Community Contributions

50 changes: 50 additions & 0 deletions llama-burn/Cargo.toml
@@ -0,0 +1,50 @@
[package]
authors = ["guillaumelagrange <[email protected]>"]
license = "MIT OR Apache-2.0"
name = "llama-burn"
version = "0.1.0"
edition = "2021"
description = "Llama 3 large language model with Burn"

[features]
default = ["pretrained"]
pretrained = ["burn/network", "dep:dirs"]

llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
tiny = ["dep:tokenizers"]

# Example feature flags (backend selection)
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c", default-features = false }
burn-import = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
itertools = { version = "0.12.1", default-features = false, features = [
"use_alloc",
] }
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

# Tiktoken tokenizer (llama 3)
tiktoken-rs = { version = "0.5.8", optional = true }
base64 = { version = "0.22.1", optional = true }
rustc-hash = { version = "1.1.0", optional = true }

# SentencePiece tokenizer (tiny llama / llama 2)
tokenizers = { version = "0.19.1", default-features = false, features = [
"onig",
], optional = true }

rand = { version = "0.8.5", default-features = false, features = [
"std_rng",
] } # std_rng is for no_std

[dev-dependencies]
burn = { git = "https://github.com/tracel-ai/burn", rev = "a53f459f205889a22ecea3713bbae12d3de7eb0c" }
clap = { version = "4.5.4", features = ["derive"] }
14 changes: 14 additions & 0 deletions llama-burn/NOTICES.md
@@ -0,0 +1,14 @@
# NOTICES AND INFORMATION

This file contains notices and information required by libraries that this repository copied or
derived from. The use of the following resources complies with the licenses provided.

## Implementation

The model implementation was adapted from the original
[Llama 3 implementation](https://github.com/meta-llama/llama3), which is distributed under the
[Meta Llama 3 Community License Agreement](https://github.com/meta-llama/llama3/blob/main/LICENSE).

The TinyLlama implementation is derived from the same code, but its weights and tokenizers were
adapted from the [original implementation](https://github.com/jzhang38/TinyLlama) distributed under
the [Apache 2.0](https://github.com/jzhang38/TinyLlama/blob/main/LICENSE) open source license.
114 changes: 114 additions & 0 deletions llama-burn/README.md
@@ -0,0 +1,114 @@
# Llama Burn

<img src="./assets/llama-burn.jpeg" alt="An image of a llama surrounded by fiery colors and a gust of fire" width="500px"/>

The popular Llama LLM is here!

This repository contains the [Llama 3](https://github.com/meta-llama/llama3) and
[TinyLlama](https://github.com/jzhang38/TinyLlama) implementations with their corresponding
tokenizers. You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama
variants in [src/llama.rs](src/llama.rs).

## Usage

### `Cargo.toml`

Add this to your `Cargo.toml`:

```toml
[dependencies]
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", default-features = false }
```

If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are
active), enable the corresponding feature flag.

> **Important:** these features require `std`. Note that the weights have been saved in a binary
> format, which is more compact and faster to save & load, but might not be compatible in future
> versions if the Burn data schema were to evolve (a minimal save/load sketch follows the snippets
> below).

#### Llama 3

```toml
[dependencies]
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["llama3"] }
```

#### TinyLlama

```toml
[dependencies]
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["tiny"] }
```
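
As an aside on the binary-weight note above, the sketch below illustrates the save/load round trip
with Burn's `NamedMpkFileRecorder` (the recorder this PR switches to). The `Toy` module, file path,
and half-precision settings are illustrative assumptions; `llama-burn` wraps this mechanism behind
its own load/save helpers.

```rust
use burn::{
    module::Module,
    nn::{Linear, LinearConfig},
    record::{HalfPrecisionSettings, NamedMpkFileRecorder},
    tensor::backend::Backend,
};

// Hypothetical single-layer module standing in for the Llama model.
#[derive(Module, Debug)]
struct Toy<B: Backend> {
    linear: Linear<B>,
}

fn roundtrip<B: Backend>(device: &B::Device) {
    let model = Toy {
        linear: LinearConfig::new(8, 8).init(device),
    };

    // Named MessagePack records with f16 tensors: compact and fast to load,
    // but tied to the current Burn data schema (hence the note above).
    let recorder = NamedMpkFileRecorder::<HalfPrecisionSettings>::new();
    model
        .save_file("/tmp/toy-model", &recorder)
        .expect("failed to save weights");

    // Loading applies the saved record onto a freshly initialized module.
    let _restored = Toy {
        linear: LinearConfig::new(8, 8).init(device),
    }
    .load_file("/tmp/toy-model", &recorder, device)
    .expect("failed to load weights");
}
```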

### Example Usage

The [chat completion example](examples/chat.rs) initializes a Llama model from the provided weights
file and generates a sequence of text based on the input prompt. The instruction-tuned model is
loaded for dialogue applications, so the prompt is automatically formatted for chat completion.

The example can be executed on the `tch` backend (CUDA or CPU) or `wgpu`.

| Argument | Description |
| :-------------- | :------------------------------------------------------------------------------------------------------------- |
| `-p` | The prompt or question to pass to the LLM (default: `"How many helicopters can a human eat in one sitting?"`). |
| `-n` | The number of new tokens to generate (default: `50`). |
| `--top-p` | Top-p probability threshold (default: `0.9`). |
| `--temperature` | Temperature value for controlling randomness in sampling (default: `0.6`).                                      |
| `--max-seq-len` | Maximum sequence length for input text (default: `128`).                                                        |
| `--seed`        | The seed to use when generating random samples (default: `42`).                                                 |

Any of the commands below can be used with the listed arguments by appending `[-- <arguments>]`.
For example, you can provide your own prompt/question with
`-- -p "How many llamas does it take to change a lightbulb?"`.

#### Llama 3

Using the `tch` backend with CUDA:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --features llama3,tch-gpu --example chat
```

Using the `tch` backend with CPU:

```sh
cargo run --release --features llama3,tch-cpu --example chat
```

Using the `wgpu` backend:

```sh
cargo run --release --features llama3,wgpu --example chat
```

**Built with Meta Llama 3.** This example uses the
[Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
instruction-tuned model. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
also available if you wish to use it in your application.

#### TinyLlama

Using the `tch` backend with CUDA:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --features tiny,tch-gpu --example chat
```

Using the `tch` backend with CPU:

```sh
cargo run --release --features tiny,tch-cpu --example chat
```

Using the `wgpu` backend:

```sh
cargo run --release --features tiny,wgpu --example chat
```

This example uses the
[TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
instruction-tuned model based on the Llama 2 architecture and tokenizer.
Binary file added llama-burn/assets/llama-burn.jpeg
169 changes: 169 additions & 0 deletions llama-burn/examples/chat.rs
@@ -0,0 +1,169 @@
use std::time::Instant;

use burn::tensor::{backend::Backend, Device};
use clap::Parser;
use llama_burn::{
llama::{Llama, LlamaConfig},
sampling::{Sampler, TopP},
tokenizer::Tokenizer,
};

const DEFAULT_PROMPT: &str = "How many helicopters can a human eat in one sitting?";

#[derive(Parser, Debug)]
#[command(version, about, long_about = None)]
pub struct Config {
/// Top-p probability threshold.
#[arg(long, default_value_t = 0.9)]
top_p: f64,

/// Temperature value for controlling randomness in sampling.
#[arg(long, default_value_t = 0.6)]
temperature: f64,

/// Maximum sequence length for input text.
#[arg(long, default_value_t = 128)]
max_seq_len: usize,

/// The number of new tokens to generate (i.e., the number of generation steps to take).
#[arg(long, short = 'n', default_value_t = 50)]
sample_len: usize,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 42)]
seed: u64,

/// The input prompt.
#[arg(short, long, default_value_t = String::from(DEFAULT_PROMPT))]
prompt: String,
}

pub fn generate<B: Backend, T: Tokenizer>(
llama: &mut Llama<B, T>,
prompt: &str,
sample_len: usize,
temperature: f64,
sampler: &mut Sampler,
) {
let now = Instant::now();
let generated = llama.generate(prompt, sample_len, temperature, sampler);
let elapsed = now.elapsed().as_secs();

println!("> {}\n", generated.text);
println!(
"{} tokens generated ({:.4} tokens/s)\n",
generated.tokens,
generated.tokens as f64 / generated.time
);

println!(
"Generation completed in {}m{}s",
(elapsed / 60),
elapsed % 60
);
}

pub fn chat<B: Backend>(args: Config, device: Device<B>) {
let mut prompt = args.prompt;

// Sampling strategy
let mut sampler = if args.temperature > 0.0 {
Sampler::TopP(TopP::new(args.top_p, args.seed))
} else {
Sampler::Argmax
};

#[cfg(feature = "tiny")]
{
// TinyLlama-1.1B Chat v1.0
let mut llama = LlamaConfig::tiny_llama_pretrained::<B>(&device).unwrap();
println!("Processing prompt: {}", prompt);

// Prompt formatting for chat model
prompt = format!(
"<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
);

generate(
&mut llama,
&prompt,
args.sample_len,
args.temperature,
&mut sampler,
);
}

#[cfg(feature = "llama3")]
{
// Llama-3-8B-Instruct
let mut llama = LlamaConfig::llama3_8b_pretrained::<B>(true, &device).unwrap();
println!("Processing prompt: {}", prompt);

// Prompt formatting for chat model
prompt = format!(
"<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
);

generate(
&mut llama,
&prompt,
args.sample_len,
args.temperature,
&mut sampler,
);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use super::*;
use burn::{
backend::{libtorch::LibTorchDevice, LibTorch},
tensor::f16,
};

pub fn run(args: Config) {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

chat::<LibTorch<f16>>(args, device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use super::*;
use burn::backend::{libtorch::LibTorchDevice, LibTorch};

pub fn run(args: Config) {
let device = LibTorchDevice::Cpu;

chat::<LibTorch>(args, device);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use super::*;
use burn::backend::wgpu::{Wgpu, WgpuDevice};

pub fn run(args: Config) {
let device = WgpuDevice::default();

chat::<Wgpu>(args, device);
}
}

pub fn main() {
// Parse arguments
let args = Config::parse();

#[cfg(feature = "tch-gpu")]
tch_gpu::run(args);
#[cfg(feature = "tch-cpu")]
tch_cpu::run(args);
#[cfg(feature = "wgpu")]
wgpu::run(args);
}