[ResNet] Upgrade to Burn 0.13.0 (#25)
* Change to Relu and remove init_with methods

* Refactor residual blocks with enum

* Change burn-rs -> tracel-ai links

* Add with_classes output layer init method

* Upgrade to burn 0.13.0
laggui committed Apr 16, 2024
1 parent 323a00d commit 5677dc0
Showing 7 changed files with 149 additions and 298 deletions.
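
The most visible breaking change for downstream users is that Burn 0.13.0 drops the `init_with(record)` constructors, which is why this commit deletes every `init_with` method below: modules are now built with `init(device)` and weights are applied afterwards via `Module::load_record`. A minimal sketch of the migration, using Burn's built-in `Linear` rather than code from this repo (`LinearRecord` is the record type generated by the `Module` derive; where the record comes from, e.g. a recorder loading pre-trained weights, is assumed):

```rust
use burn::{
    backend::NdArray,
    module::Module,
    nn::{Linear, LinearConfig, LinearRecord},
};

// Burn 0.12 and earlier: `LinearConfig::new(2, 3).init_with(record)`.
// Burn 0.13: initialize on a device first, then load the record explicitly.
fn load_linear(record: LinearRecord<NdArray>) -> Linear<NdArray> {
    let device = Default::default();
    LinearConfig::new(2, 3).init(&device).load_record(record)
}
```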
2 changes: 1 addition & 1 deletion bert-burn/README.md
@@ -11,7 +11,7 @@ Add this to your `Cargo.toml`:

```toml
[dependencies]
-bert-burn = { git = "https://github.com/burn-rs/models", package = "bert-burn", default-features = false }
+bert-burn = { git = "https://github.com/tracel-ai/models", package = "bert-burn", default-features = false }
```

## Example Usage
10 changes: 4 additions & 6 deletions resnet-burn/Cargo.toml
@@ -12,16 +12,14 @@ pretrained = ["burn/network", "std", "dep:dirs"]

[dependencies]
# Note: default-features = false is needed to disable std
-burn = { git = "https://github.com/tracel-ai/burn.git", rev = "9a2cbadd41161c8aac142bbcb9c2ceaf5ffd6edd", default-features = false }
-burn-import = { git = "https://github.com/tracel-ai/burn.git", rev = "9a2cbadd41161c8aac142bbcb9c2ceaf5ffd6edd" }
+burn = { version = "0.13.0", default-features = false }
+burn-import = "0.13.0"
dirs = { version = "5.0.1", optional = true }
serde = { version = "1.0.192", default-features = false, features = [
"derive",
"alloc",
] } # alloc is for no_std, derive is needed

[dev-dependencies]
-burn = { git = "https://github.com/tracel-ai/burn.git", rev = "9a2cbadd41161c8aac142bbcb9c2ceaf5ffd6edd", features = [
-    "ndarray",
-] }
-image = { version = "0.24.7", features = ["png", "jpeg"] }
+burn = { version = "0.13.0", features = ["ndarray"] }
+image = { version = "0.24.9", features = ["png", "jpeg"] }
4 changes: 2 additions & 2 deletions resnet-burn/README.md
@@ -14,14 +14,14 @@ Add this to your `Cargo.toml`:

```toml
[dependencies]
-resnet-burn = { git = "https://github.com/burn-rs/models", package = "resnet-burn", default-features = false }
+resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-burn", default-features = false }
```

If you want to get the pre-trained ImageNet weights, enable the `pretrained` feature flag.

```toml
[dependencies]
-resnet-burn = { git = "https://github.com/burn-rs/models", package = "resnet-burn", features = ["pretrained"] }
+resnet-burn = { git = "https://github.com/tracel-ai/models", package = "resnet-burn", features = ["pretrained"] }
```

**Important:** this feature requires `std`.
2 changes: 1 addition & 1 deletion resnet-burn/examples/inference.rs
@@ -29,7 +29,7 @@ pub fn main() {

    // Create ResNet-18
    let device = Default::default();
-    let model: ResNet<NdArray, _> =
+    let model: ResNet<NdArray> =
        ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device)
            .map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
            .unwrap();
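
Since the diff shows only a fragment, here is a hedged sketch of the complete example under 0.13; the import paths are assumptions, not part of this commit. The dropped second type parameter (`ResNet<NdArray, _>` becoming `ResNet<NdArray>`) falls out of the block refactor below: residual blocks are now an enum inside the model instead of a generic parameter.

```rust
use burn::backend::NdArray;
// Assumed module layout for the crate; the actual paths may differ.
use resnet_burn::model::{resnet::ResNet, weights};

pub fn main() {
    // Create ResNet-18 with pre-trained ImageNet weights; the backend is
    // the only remaining type parameter in 0.13.
    let device = Default::default();
    let model: ResNet<NdArray> =
        ResNet::resnet18_pretrained(weights::ResNet18::ImageNet1kV1, &device)
            .map_err(|err| format!("Failed to load pre-trained weights.\nError: {err}"))
            .unwrap();
    // Image loading and the forward pass follow in the full example
    // (truncated in this diff).
    let _ = model;
}
```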
231 changes: 71 additions & 160 deletions resnet-burn/src/model/block.rs
@@ -1,19 +1,56 @@
use core::f64::consts::SQRT_2;
-use core::marker::PhantomData;

use alloc::vec::Vec;

use burn::{
    config::Config,
    module::Module,
    nn::{
        conv::{Conv2d, Conv2dConfig},
-        BatchNorm, BatchNormConfig, Initializer, PaddingConfig2d, ReLU,
+        BatchNorm, BatchNormConfig, Initializer, PaddingConfig2d, Relu,
    },
    tensor::{backend::Backend, Device, Tensor},
};

-pub trait ResidualBlock<B: Backend> {
-    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4>;
+#[derive(Module, Debug)]
+pub enum ResidualBlock<B: Backend> {
+    /// A bottleneck residual block.
+    Bottleneck(Bottleneck<B>),
+    /// A basic residual block.
+    Basic(BasicBlock<B>),
}

+impl<B: Backend> ResidualBlock<B> {
+    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
+        match self {
+            Self::Basic(block) => block.forward(input),
+            Self::Bottleneck(block) => block.forward(input),
+        }
+    }
+}
+
+#[derive(Config)]
+struct ResidualBlockConfig {
+    in_channels: usize,
+    out_channels: usize,
+    stride: usize,
+    bottleneck: bool,
+}
+
+impl ResidualBlockConfig {
+    fn init<B: Backend>(&self, device: &Device<B>) -> ResidualBlock<B> {
+        if self.bottleneck {
+            ResidualBlock::Bottleneck(
+                BottleneckConfig::new(self.in_channels, self.out_channels, self.stride)
+                    .init(device),
+            )
+        } else {
+            ResidualBlock::Basic(
+                BasicBlockConfig::new(self.in_channels, self.out_channels, self.stride)
+                    .init(device),
+            )
+        }
+    }
+}

/// ResNet [basic residual block](https://paperswithcode.com/method/residual-block) implementation.
@@ -22,13 +59,13 @@ pub trait ResidualBlock<B: Backend> {
pub struct BasicBlock<B: Backend> {
    conv1: Conv2d<B>,
    bn1: BatchNorm<B, 2>,
-    relu: ReLU,
+    relu: Relu,
    conv2: Conv2d<B>,
    bn2: BatchNorm<B, 2>,
    downsample: Option<Downsample<B>>,
}

-impl<B: Backend> ResidualBlock<B> for BasicBlock<B> {
+impl<B: Backend> BasicBlock<B> {
    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let identity = input.clone();

@@ -63,15 +100,15 @@ impl<B: Backend> ResidualBlock<B> for BasicBlock<B> {
pub struct Bottleneck<B: Backend> {
    conv1: Conv2d<B>,
    bn1: BatchNorm<B, 2>,
-    relu: ReLU,
+    relu: Relu,
    conv2: Conv2d<B>,
    bn2: BatchNorm<B, 2>,
    conv3: Conv2d<B>,
    bn3: BatchNorm<B, 2>,
    downsample: Option<Downsample<B>>,
}

-impl<B: Backend> ResidualBlock<B> for Bottleneck<B> {
+impl<B: Backend> Bottleneck<B> {
    fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let identity = input.clone();

@@ -114,12 +151,11 @@ impl<B: Backend> Downsample<B> {

/// Collection of sequential residual blocks.
#[derive(Module, Debug)]
-pub struct LayerBlock<B: Backend, M> {
-    blocks: Vec<M>,
-    _backend: PhantomData<B>,
+pub struct LayerBlock<B: Backend> {
+    blocks: Vec<ResidualBlock<B>>,
}

-impl<B: Backend, M: ResidualBlock<B>> LayerBlock<B, M> {
+impl<B: Backend> LayerBlock<B> {
    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
        let mut out = input;
        for block in &self.blocks {
@@ -187,7 +223,7 @@ impl BasicBlockConfig {
                .with_initializer(initializer.clone())
                .init(device),
            bn1: self.bn1.init(device),
-            relu: ReLU::new(),
+            relu: Relu::new(),
            conv2: self
                .conv2
                .clone()
@@ -197,24 +233,6 @@ impl BasicBlockConfig {
            downsample: self.downsample.as_ref().map(|d| d.init(device)),
        }
    }
-
-    /// Initialize a new [basic residual block](BasicBlock) module with a [record](BasicBlockRecord).
-    fn init_with<B: Backend>(&self, record: BasicBlockRecord<B>) -> BasicBlock<B> {
-        BasicBlock {
-            conv1: self.conv1.init_with(record.conv1),
-            bn1: self.bn1.init_with(record.bn1),
-            relu: ReLU::new(),
-            conv2: self.conv2.init_with(record.conv2),
-            bn2: self.bn2.init_with(record.bn2),
-            downsample: self.downsample.as_ref().map(|d| {
-                d.init_with(
-                    record
-                        .downsample
-                        .expect("Should initialize downsample block with record."),
-                )
-            }),
-        }
-    }
}

/// [Bottleneck residual block](Bottleneck) configuration.
@@ -286,7 +304,7 @@ impl BottleneckConfig {
                .with_initializer(initializer.clone())
                .init(device),
            bn1: self.bn1.init(device),
-            relu: ReLU::new(),
+            relu: Relu::new(),
            conv2: self
                .conv2
                .clone()
@@ -302,26 +320,6 @@ impl BottleneckConfig {
            downsample: self.downsample.as_ref().map(|d| d.init(device)),
        }
    }
-
-    /// Initialize a new [bottleneck residual block](Bottleneck) module with a [record](BottleneckRecord).
-    fn init_with<B: Backend>(&self, record: BottleneckRecord<B>) -> Bottleneck<B> {
-        Bottleneck {
-            conv1: self.conv1.init_with(record.conv1),
-            bn1: self.bn1.init_with(record.bn1),
-            relu: ReLU::new(),
-            conv2: self.conv2.init_with(record.conv2),
-            bn2: self.bn2.init_with(record.bn2),
-            conv3: self.conv3.init_with(record.conv3),
-            bn3: self.bn3.init_with(record.bn3),
-            downsample: self.downsample.as_ref().map(|d| {
-                d.init_with(
-                    record
-                        .downsample
-                        .expect("Should initialize downsample block with record."),
-                )
-            }),
-        }
-    }
}

/// [Downsample](Downsample) configuration.
@@ -356,132 +354,45 @@ impl DownsampleConfig {
            bn: self.bn.init(device),
        }
    }
-
-    /// Initialize a new [downsample](Downsample) module with a [record](DownsampleRecord).
-    fn init_with<B: Backend>(&self, record: DownsampleRecord<B>) -> Downsample<B> {
-        Downsample {
-            conv: self.conv.init_with(record.conv),
-            bn: self.bn.init_with(record.bn),
-        }
-    }
}

/// [Residual layer block](LayerBlock) configuration.
-pub struct LayerBlockConfig<M> {
+#[derive(Config)]
+pub struct LayerBlockConfig {
    num_blocks: usize,
    in_channels: usize,
    out_channels: usize,
    stride: usize,
-    _block: PhantomData<M>,
+    bottleneck: bool,
}

-impl<M> LayerBlockConfig<M> {
-    /// Create a new instance of the layer block [config](LayerBlockConfig).
-    pub fn new(num_blocks: usize, in_channels: usize, out_channels: usize, stride: usize) -> Self {
-        Self {
-            num_blocks,
-            in_channels,
-            out_channels,
-            stride,
-            _block: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> LayerBlockConfig<BasicBlock<B>> {
-    /// Initialize a new [LayerBlock](LayerBlock) module with [basic residual blocks](BasicBlock).
-    pub fn init(&self, device: &Device<B>) -> LayerBlock<B, BasicBlock<B>> {
+impl LayerBlockConfig {
+    /// Initialize a new [LayerBlock](LayerBlock) module.
+    pub fn init<B: Backend>(&self, device: &Device<B>) -> LayerBlock<B> {
        let blocks = (0..self.num_blocks)
            .map(|b| {
                if b == 0 {
                    // First block uses the specified stride
-                    BasicBlockConfig::new(self.in_channels, self.out_channels, self.stride)
-                        .init(device)
+                    ResidualBlockConfig::new(
+                        self.in_channels,
+                        self.out_channels,
+                        self.stride,
+                        self.bottleneck,
+                    )
+                    .init(device)
                } else {
                    // Other blocks use a stride of 1
-                    BasicBlockConfig::new(self.out_channels, self.out_channels, 1).init(device)
+                    ResidualBlockConfig::new(
+                        self.out_channels,
+                        self.out_channels,
+                        1,
+                        self.bottleneck,
+                    )
+                    .init(device)
                }
            })
            .collect();

-        LayerBlock {
-            blocks,
-            _backend: PhantomData,
-        }
-    }
-
-    /// Initialize a new [LayerBlock](LayerBlock) module with a [record](LayerBlockRecord) for
-    /// [basic residual blocks](BasicBlock).
-    pub fn init_with(
-        &self,
-        record: LayerBlockRecord<B, BasicBlock<B>>,
-    ) -> LayerBlock<B, BasicBlock<B>> {
-        let blocks = (0..self.num_blocks)
-            .zip(record.blocks)
-            .map(|(b, rec)| {
-                if b == 0 {
-                    // First block uses the specified stride
-                    BasicBlockConfig::new(self.in_channels, self.out_channels, self.stride)
-                        .init_with(rec)
-                } else {
-                    // Other blocks use a stride of 1
-                    BasicBlockConfig::new(self.out_channels, self.out_channels, 1).init_with(rec)
-                }
-            })
-            .collect();
-
-        LayerBlock {
-            blocks,
-            _backend: PhantomData,
-        }
-    }
-}
-
-impl<B: Backend> LayerBlockConfig<Bottleneck<B>> {
-    /// Initialize a new [LayerBlock](LayerBlock) module with [bottleneck residual blocks](Bottleneck).
-    pub fn init(&self, device: &Device<B>) -> LayerBlock<B, Bottleneck<B>> {
-        let blocks = (0..self.num_blocks)
-            .map(|b| {
-                if b == 0 {
-                    // First block uses the specified stride
-                    BottleneckConfig::new(self.in_channels, self.out_channels, self.stride)
-                        .init(device)
-                } else {
-                    // Other blocks use a stride of 1
-                    BottleneckConfig::new(self.out_channels, self.out_channels, 1).init(device)
-                }
-            })
-            .collect();
-
-        LayerBlock {
-            blocks,
-            _backend: PhantomData,
-        }
-    }
-
-    /// Initialize a new [LayerBlock](LayerBlock) module with a [record](LayerBlockRecord) for
-    /// [bottleneck residual blocks](Bottleneck).
-    pub fn init_with(
-        &self,
-        record: LayerBlockRecord<B, Bottleneck<B>>,
-    ) -> LayerBlock<B, Bottleneck<B>> {
-        let blocks = (0..self.num_blocks)
-            .zip(record.blocks)
-            .map(|(b, rec)| {
-                if b == 0 {
-                    // First block uses the specified stride
-                    BottleneckConfig::new(self.in_channels, self.out_channels, self.stride)
-                        .init_with(rec)
-                } else {
-                    // Other blocks use a stride of 1
-                    BottleneckConfig::new(self.out_channels, self.out_channels, 1).init_with(rec)
-                }
-            })
-            .collect();
-
-        LayerBlock {
-            blocks,
-            _backend: PhantomData,
-        }
+        LayerBlock { blocks }
    }
}
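
After the refactor, a single `LayerBlockConfig` serves both block families: the new `bottleneck` flag selects the `ResidualBlock` variant at init time, replacing the old compile-time choice through a generic parameter, and Burn 0.13's `Module` derive on enums keeps the whole thing recordable. A hedged usage sketch, written as if inside `block.rs` (the helper and channel sizes are illustrative, following the standard ResNet-18 and ResNet-50 first layers):

```rust
use burn::{backend::NdArray, tensor::Device};

// Hypothetical helper for illustration; not part of this commit.
fn first_layers(device: &Device<NdArray>) -> (LayerBlock<NdArray>, LayerBlock<NdArray>) {
    // ResNet-18 layer1: two BasicBlocks, 64 -> 64 channels, stride 1.
    let basic = LayerBlockConfig::new(2, 64, 64, 1, false).init(device);
    // ResNet-50 layer1: three Bottlenecks, 64 -> 256 channels, stride 1.
    let bottleneck = LayerBlockConfig::new(3, 64, 256, 1, true).init(device);
    (basic, bottleneck)
}
```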
