Skip to content

Commit

Permalink
Print module part3 - Update book (#1940)
Browse files Browse the repository at this point in the history
* Update book example guide

* Update Module book section on module display
  • Loading branch information
antimora authored Jul 1, 2024
1 parent 3a9367d commit 6f2ba34
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 23 deletions.
36 changes: 35 additions & 1 deletion burn-book/src/basic-workflow/model.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ Next, we need to instantiate the model for training.
# linear2: Linear<B>,
# activation: Relu,
# }
#
#
#[derive(Config, Debug)]
pub struct ModelConfig {
num_classes: usize,
Expand All @@ -217,6 +217,40 @@ impl ModelConfig {
}
```


At a glance, you can view the model configuration by printing the model instance:

```rust , ignore
use burn::backend::Wgpu;
use guide::model::ModelConfig;

fn main() {
type MyBackend = Wgpu<f32, i32>;

let device = Default::default();
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);

println!("{}", model);
}
```

Output:

```rust , ignore
Model {
conv1: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 80}
conv2: Conv2d {stride: [1, 1], kernel_size: [3, 3], dilation: [1, 1], groups: 1, padding: Valid, params: 1168}
pool: AdaptiveAvgPool2d {output_size: [8, 8]}
dropout: Dropout {prob: 0.5}
linear1: Linear {d_input: 1024, d_output: 512, bias: true, params: 524800}
linear2: Linear {d_input: 512, d_output: 10, bias: true, params: 5130}
activation: Relu
params: 531178
}
```



<details>
<summary><strong>🦀 References</strong></summary>

Expand Down
73 changes: 64 additions & 9 deletions burn-book/src/building-blocks/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ the `Module` derive, you need to be careful to achieve the behavior you want.
These methods are available for all modules.

| Burn API | PyTorch Equivalent |
|-----------------------------------------|------------------------------------------|
| --------------------------------------- | ---------------------------------------- |
| `module.devices()` | N/A |
| `module.fork(device)` | Similar to `module.to(device).detach()` |
| `module.to_device(device)` | `module.to(device)` |
Expand All @@ -69,7 +69,7 @@ Similar to the backend trait, there is also the `AutodiffModule` trait to signif
autodiff support.

| Burn API | PyTorch Equivalent |
|------------------|--------------------|
| ---------------- | ------------------ |
| `module.valid()` | `module.eval()` |

## Visitor & Mapper
Expand All @@ -96,7 +96,62 @@ pub trait ModuleVisitor<B: Backend> {
/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a tensor in the module.
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) ->
Tensor<B, D>;
}
```

## Module Display

Burn provides a simple way to display the structure of a module and its configuration at a glance.
You can print the module to see its structure, which is useful for debugging and tracking changes
across different versions of a module. (See the print output of the
[Basic Workflow Model](../basic-workflow/model.md) example.)

To customize the display of a module, you can implement the `ModuleDisplay` trait for your module.
This will change the default display settings for the module and its children. Note that
`ModuleDisplay` is automatically implemented for all modules, but you can override it to customize
the display by annotating the module with `#[module(custom_display)]`.

```rust
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: Linear<B>,
linear_outer: Linear<B>,
dropout: Dropout,
gelu: Gelu,
}

impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
/// Custom settings for the display of the module.
/// If `None` is returned, the default settings will be used.
fn custom_settings(&self) -> Option<burn::module::DisplaySettings> {
DisplaySettings::new()
// Will show all attributes (default is false)
.with_show_all_attributes(false)
// Will show each attribute on a new line (default is true)
.with_new_line_after_attribute(true)
// Will show the number of parameters (default is true)
.with_show_num_parameters(true)
// Will indent by 2 spaces (default is 2)
.with_indentation_size(2)
// Will show the parameter ID (default is false)
.with_show_param_id(false)
// Convenience method to wrap settings in Some()
.optional()
}

/// Custom content to be displayed.
/// If `None` is returned, the default content will be used
/// (all attributes of the module)
fn custom_content(&self, content: Content) -> Option<Content> {
content
.add("linear_inner", &self.linear_inner)
.add("linear_outer", &self.linear_outer)
.add("anything", "anything_else")
.optional()
}
}
```

Expand All @@ -107,7 +162,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### General

| Burn API | PyTorch Equivalent |
|----------------|-----------------------------------------------|
| -------------- | --------------------------------------------- |
| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. |
| `Dropout` | `nn.Dropout` |
| `Embedding` | `nn.Embedding` |
Expand All @@ -125,7 +180,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Convolutions

| Burn API | PyTorch Equivalent |
|-------------------|----------------------|
| ----------------- | -------------------- |
| `Conv1d` | `nn.Conv1d` |
| `Conv2d` | `nn.Conv2d` |
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
Expand All @@ -134,7 +189,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Pooling

| Burn API | PyTorch Equivalent |
|---------------------|------------------------|
| ------------------- | ---------------------- |
| `AdaptiveAvgPool1d` | `nn.AdaptiveAvgPool1d` |
| `AdaptiveAvgPool2d` | `nn.AdaptiveAvgPool2d` |
| `AvgPool1d` | `nn.AvgPool1d` |
Expand All @@ -145,15 +200,15 @@ Burn comes with built-in modules that you can use to build your own modules.
### RNNs

| Burn API | PyTorch Equivalent |
|------------------|------------------------|
| ---------------- | ---------------------- |
| `Gru` | `nn.GRU` |
| `Lstm`/`BiLstm` | `nn.LSTM` |
| `GateController` | _No direct equivalent_ |

### Transformer

| Burn API | PyTorch Equivalent |
|----------------------|-------------------------|
| -------------------- | ----------------------- |
| `MultiHeadAttention` | `nn.MultiheadAttention` |
| `TransformerDecoder` | `nn.TransformerDecoder` |
| `TransformerEncoder` | `nn.TransformerEncoder` |
Expand All @@ -163,7 +218,7 @@ Burn comes with built-in modules that you can use to build your own modules.
### Loss

| Burn API | PyTorch Equivalent |
|--------------------|-----------------------|
| ------------------ | --------------------- |
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |
19 changes: 17 additions & 2 deletions examples/guide/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,21 @@ This example corresponds to the [book's guide](https://burn.dev/book/basic-workf

## Example Usage


### Training

```sh
cargo run --bin train --release
```

### Inference

```sh
cargo run --bin infer --release
```

### Print the model

```sh
cargo run --example guide
```
cargo run --bin print --release
```
2 changes: 1 addition & 1 deletion examples/guide/examples/guide.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use std::process::Command;

fn main() {
Command::new("cargo")
.args(["run", "--bin", "guide"])
.args(["run", "--bin", "train", "--release"])
.status()
.expect("guide example should run");
}
20 changes: 20 additions & 0 deletions examples/guide/src/bin/infer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use burn::{backend::Wgpu, data::dataset::Dataset};
use guide::inference;

fn main() {
type MyBackend = Wgpu<f32, i32>;

let device = burn::backend::wgpu::WgpuDevice::default();

// All the training artifacts are saved in this directory
let artifact_dir = "/tmp/guide";

// Infer the model
inference::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::vision::MnistDataset::test()
.get(42)
.unwrap(),
);
}
11 changes: 11 additions & 0 deletions examples/guide/src/bin/print.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use burn::backend::Wgpu;
use guide::model::ModelConfig;

fn main() {
type MyBackend = Wgpu<f32, i32>;

let device = Default::default();
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);

println!("{}", model);
}
22 changes: 14 additions & 8 deletions examples/guide/src/main.rs → examples/guide/src/bin/train.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
mod data;
mod inference;
mod model;
mod training;

use crate::{model::ModelConfig, training::TrainingConfig};
use burn::{
backend::{Autodiff, Wgpu},
data::dataset::Dataset,
optim::AdamConfig,
};
use guide::{
inference,
model::ModelConfig,
training::{self, TrainingConfig},
};

fn main() {
type MyBackend = Wgpu<f32, i32>;
type MyAutodiffBackend = Autodiff<MyBackend>;

// Create a default Wgpu device
let device = burn::backend::wgpu::WgpuDevice::default();

// All the training artifacts will be saved in this directory
let artifact_dir = "/tmp/guide";
crate::training::train::<MyAutodiffBackend>(

// Train the model
training::train::<MyAutodiffBackend>(
artifact_dir,
TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
device.clone(),
);
crate::inference::infer::<MyBackend>(

// Infer the model
inference::infer::<MyBackend>(
artifact_dir,
device,
burn::data::dataset::vision::MnistDataset::test()
Expand Down
4 changes: 2 additions & 2 deletions examples/guide/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use burn::{

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {
let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))
.expect("Config should exist for the model");
.expect("Config should exist for the model; run train first");
let record = CompactRecorder::new()
.load(format!("{artifact_dir}/model").into(), &device)
.expect("Trained model should exist");
.expect("Trained model should exist; run train first");

let model: Model<B> = config.model.init(&device).load_record(record);

Expand Down

0 comments on commit 6f2ba34

Please sign in to comment.