Update README and features deps
laggui committed May 24, 2024
1 parent 67740e3 · commit 9492e20
Showing 3 changed files with 32 additions and 17 deletions.
9 changes: 4 additions & 5 deletions llama-burn/Cargo.toml

```diff
@@ -7,12 +7,11 @@ edition = "2021"
 description = "Llama 3 large language model with Burn"
 
 [features]
-default = ["std"] # TODO: remove std default
-std = []
-pretrained = ["burn/network", "std", "dep:dirs"]
+default = ["pretrained"]
+pretrained = ["burn/network", "dep:dirs"]
 
-llama3 = ["pretrained", "dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
-tiny = ["pretrained", "dep:tokenizers"]
+llama3 = ["dep:tiktoken-rs", "dep:rustc-hash", "dep:base64"]
+tiny = ["dep:tokenizers"]
 
 [dependencies]
 # Note: default-features = false is needed to disable std
```
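
After this restructure, `llama3` and `tiny` no longer pull in `pretrained` themselves; that feature now arrives only through the crate's default features. A hypothetical downstream manifest that wants the Llama 3 tokenizer stack without the weight-download dependencies could therefore disable default features (a sketch, not part of this commit):

```toml
[dependencies]
# Hypothetical consumer: keep the `llama3` tokenizer dependencies
# (tiktoken-rs, rustc-hash, base64) but skip `pretrained`, which is
# only enabled via default features after this commit.
llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", default-features = false, features = ["llama3"] }
```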
36 changes: 26 additions & 10 deletions llama-burn/README.md

````diff
@@ -2,10 +2,8 @@
 
 Llama-3 implementation.
 
-You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the YOLOX variants in
-[src/model/yolox.rs](src/model/yolox.rs).
-
-The model is [no_std compatible](https://docs.rust-embedded.org/book/intro/no-std.html).
+You can find the [Burn](https://github.com/tracel-ai/burn) implementation for the Llama variants in
+[src/llama.rs](src/llama.rs).
 
 ## Usage
 
@@ -18,22 +16,40 @@ Add this to your `Cargo.toml`:
 llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", default-features = false }
 ```
 
-If you want to get the COCO pre-trained weights, enable the `pretrained` feature flag.
+If you want to use Llama 3 or TinyLlama (including pre-trained weights if default features are
+active), enable the corresponding feature flag.
+
+**Important:** these features require `std`.
+
+#### Llama 3
 
 ```toml
 [dependencies]
-llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["pretrained"] }
+llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["llama3"] }
 ```
 
-**Important:** this feature requires `std`.
+#### TinyLlama
+
+```toml
+[dependencies]
+llama-burn = { git = "https://github.com/tracel-ai/models", package = "llama-burn", features = ["tiny"] }
+```
 
 ### Example Usage
 
-The [text generation example](examples/generate.rs) initializes a Llama-3-8B from the provided
-weights file with the `Wgpu` backend and generates a sequence of text based on the input prompt.
+The [text generation example](examples/generate.rs) initializes a Llama model from the provided
+weights file and generates a sequence of text based on the input prompt.
 
 You can run the example with the following command:
 
+### Llama 3
+
+```sh
+cargo run --features llama3 --example generate --release
+```
+
+### TinyLlama
+
 ```sh
-cargo run --example generate --release -- --model Meta-Llama-3-8B/consolidated.00.pth --tokenizer Meta-Llama-3-8B/tokenizer.model
+cargo run --features tiny --example generate --release
 ```
````
4 changes: 2 additions & 2 deletions llama-burn/src/llama.rs

```diff
@@ -83,7 +83,7 @@ impl LlamaConfig {
     }
 
     /// Load pre-trained Llama-3-8B model with [Tiktoken](https://github.com/openai/tiktoken) tokenizer.
-    #[cfg(feature = "llama3")]
+    #[cfg(all(feature = "llama3", feature = "pretrained"))]
     pub fn llama3_8b_pretrained<B: Backend>(
         device: &Device<B>,
     ) -> Result<Llama<B, Tiktoken>, String> {
@@ -133,7 +133,7 @@ impl LlamaConfig {
     }
 
     /// Load pre-trained TinyLlama-1.1B Chat v1.0 model with [SentencePiece](https://github.com/google/sentencepiece) tokenizer.
-    #[cfg(feature = "tiny")]
+    #[cfg(all(feature = "tiny", feature = "pretrained"))]
     pub fn tiny_llama_pretrained<B: Backend>(
         device: &Device<B>,
     ) -> Result<Llama<B, SentiencePieceTokenizer>, String> {
```
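
Since the loaders are now gated behind `cfg(all(...))`, they exist only when the model feature and `pretrained` are both enabled (the latter is on by default). A minimal caller sketch against the signature in the diff above — the `Wgpu` backend and the `llama_burn::llama` module path are assumptions, not shown in this commit:

```rust
// Builds only with the `llama3` feature plus the default `pretrained`
// feature; with `default-features = false` this function does not exist.
use burn::backend::Wgpu;
use llama_burn::llama::LlamaConfig;

fn main() {
    // Backend device; Wgpu is an illustrative choice of Burn backend.
    let device = Default::default();
    // Associated function from the diff: Result<Llama<B, Tiktoken>, String>.
    let llama = LlamaConfig::llama3_8b_pretrained::<Wgpu>(&device)
        .expect("failed to load pre-trained Llama-3-8B");
    // `llama` is ready for generation (see examples/generate.rs).
    drop(llama);
}
```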
