Pin burn git rev and add tch/wgpu feature flags
laggui committed May 29, 2024
1 parent 93f9ecc commit 553f3dd
Showing 3 changed files with 112 additions and 24 deletions.
16 changes: 9 additions & 7 deletions llama-burn/Cargo.toml
@@ -13,12 +13,15 @@ pretrained = ["burn/network", "dep:dirs"]
llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
tiny = ["dep:tokenizers"]

# Example feature flags (backend selection)
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]

[dependencies]
# Note: default-features = false is needed to disable std
burn = { path = "../../burn/crates/burn", default-features = false }
burn-import = { path = "../../burn/crates/burn-import" }
# burn = { version = "0.13.0", default-features = false }
# burn-import = { version = "0.13.0" }
burn = { git = "https://github.com/tracel-ai/burn", rev = "e4836241e1e5d7391aa278f32a8fffeb1cdbe12a", default-features = false }
burn-import = { git = "https://github.com/tracel-ai/burn", rev = "e4836241e1e5d7391aa278f32a8fffeb1cdbe12a" }
itertools = { version = "0.12.1", default-features = false, features = [
"use_alloc",
] }
@@ -41,6 +44,5 @@ rand = { version = "0.8.5", default-features = false, features = [
] } # std_rng is for no_std

[dev-dependencies]
burn = { path = "../../burn/crates/burn", features = ["tch"] }
clap = { version = "4.5.4", features = ["derive"] }
# burn = { version = "0.13.0", features = ["wgpu"] }
burn = { git = "https://github.com/tracel-ai/burn", rev = "e4836241e1e5d7391aa278f32a8fffeb1cdbe12a" }
clap = { version = "4.5.4", features = ["derive"] }
51 changes: 46 additions & 5 deletions llama-burn/README.md
@@ -45,23 +45,64 @@ The [chat completion example](examples/chat.rs) initializes a Llama model from the
file and generates a sequence of text based on the input prompt. The instruction-tuned model is
loaded for dialogue applications, so the prompt is automatically formatted for chat completion.

You can run the example with the following command:
The example can be executed with the `tch` backend (CUDA or CPU) or the `wgpu` backend.

### LLama 3
| Argument        | Description                                                                                                     |
|:----------------|:----------------------------------------------------------------------------------------------------------------|
| `-p`            | The prompt or question to pass to the LLM (default: `"How many helicopters can a human eat in one sitting?"`).   |
| `-n`            | The number of new tokens to generate (default: `50`).                                                             |
| `--top-p`       | Top-p probability threshold (default: `0.9`).                                                                      |
| `--temperature` | Temperature value for controlling randomness in sampling (default: `0.6`).                                        |
| `--max-seq-len` | Maximum sequence length for input text (default: `128`).                                                           |
| `--seed`        | The seed to use when generating random samples (default: `42`).                                                    |

Any of the commands below accepts the listed arguments by appending `[-- <arguments>]`. For example, you can provide your own prompt/question with `-- -p "How many llamas does it take to change a lightbulb?"`, as in the sketch below.
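
A complete invocation might then look like the following sketch (the feature selection and argument values are only illustrative; any of the feature combinations listed below works the same way):

```sh
cargo run --release --features llama3,wgpu --example chat -- -p "How many llamas does it take to change a lightbulb?" -n 100
```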

#### Llama 3

Using the `tch` backend with CUDA:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --features llama3,tch-gpu --example chat
```

Using the `tch` backend with CPU:

```sh
cargo run --release --features llama3,tch-cpu --example chat
```

Using the `wgpu` backend:

```sh
cargo run --release --features llama3 --example chat [-- --prompt "<your question/prompt here>"]
cargo run --release --features llama3,wgpu --example chat
```

**Built with Meta Llama 3.** This example uses the
[Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
instruction-tuned model. Note that the [base pre-trained Llama-3 model](./src/pretrained.rs#L77) is
also available if you wish to use it in your application.

### TinyLlama
#### TinyLlama

Using the `tch` backend with CUDA:

```sh
export TORCH_CUDA_VERSION=cu121
cargo run --release --features tiny,tch-gpu --example chat
```

Using the `tch` backend with CPU:

```sh
cargo run --release --features tiny,tch-cpu --example chat
```

Using the `wgpu` backend:

```sh
cargo run --release --features tiny --example chat [-- --prompt "<your question/prompt here>"]
cargo run --release --features tiny,wgpu --example chat
```

This example uses the
69 changes: 57 additions & 12 deletions llama-burn/examples/chat.rs
@@ -1,9 +1,6 @@
use std::time::Instant;

use burn::{
backend::{libtorch::LibTorchDevice, LibTorch},
tensor::{backend::Backend, f16},
};
use burn::tensor::{backend::Backend, Device};
use clap::Parser;
use llama_burn::{
llama::{Llama, LlamaConfig},
@@ -25,7 +22,7 @@ pub struct Config {
temperature: f64,

/// Maximum sequence length for input text.
#[arg(long, default_value_t = 512)]
#[arg(long, default_value_t = 128)]
max_seq_len: usize,

/// The number of new tokens to generate (i.e., the number of generation steps to take).
@@ -66,13 +63,7 @@ pub fn generate<B: Backend, T: Tokenizer>(
);
}

pub fn main() {
type B = LibTorch<f16>;

// Parse arguments
let args = Config::parse();

let device = LibTorchDevice::Cuda(0);
pub fn chat<B: Backend>(args: Config, device: Device<B>) {
let mut prompt = args.prompt;

// Sampling strategy
@@ -122,3 +113,57 @@ pub fn main() {
);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use super::*;
use burn::{
backend::{libtorch::LibTorchDevice, LibTorch},
tensor::f16,
};

pub fn run(args: Config) {
#[cfg(not(target_os = "macos"))]
let device = LibTorchDevice::Cuda(0);
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

chat::<LibTorch<f16>>(args, device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use super::*;
use burn::backend::{libtorch::LibTorchDevice, LibTorch};

pub fn run(args: Config) {
let device = LibTorchDevice::Cpu;

chat::<LibTorch>(args, device);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use super::*;
use burn::backend::wgpu::{Wgpu, WgpuDevice};

pub fn run(args: Config) {
let device = WgpuDevice::default();

chat::<Wgpu>(args, device);
}
}

pub fn main() {
// Parse arguments
let args = Config::parse();

#[cfg(feature = "tch-gpu")]
tch_gpu::run(args);
#[cfg(feature = "tch-cpu")]
tch_cpu::run(args);
#[cfg(feature = "wgpu")]
wgpu::run(args);
}
