Add layer norm onnx op support (#1680)
laggui authored Apr 23, 2024
1 parent 1718da5 commit e6b1b7a
Showing 11 changed files with 319 additions and 4 deletions.
2 changes: 1 addition & 1 deletion crates/burn-import/SUPPORTED-ONNX-OPS.md
@@ -87,7 +87,7 @@ represent the corresponding Burn Op.
| [InstanceNormalization][79] |||
| [IsInf][80] |||
| [IsNaN][81] |||
| [LayerNormalization][82] | ❌ | ✅ |
| [LayerNormalization][82] | ✅ | ✅ |
| [LeakyRelu][83] |||
| [Less][84] |||
| [LessOrEqual][85] |||
1 change: 1 addition & 0 deletions crates/burn-import/onnx-tests/build.rs
@@ -27,6 +27,7 @@ fn main() {
.input("tests/gather/gather.onnx")
.input("tests/gelu/gelu.onnx")
.input("tests/global_avr_pool/global_avr_pool.onnx")
.input("tests/layer_norm/layer_norm.onnx")
.input("tests/linear/linear.onnx")
.input("tests/log_softmax/log_softmax.onnx")
.input("tests/log/log.onnx")
Binary file added (not shown): crates/burn-import/onnx-tests/tests/layer_norm/layer_norm.onnx
41 changes: 41 additions & 0 deletions crates/burn-import/onnx-tests/tests/layer_norm/layer_norm.py
@@ -0,0 +1,41 @@
#!/usr/bin/env python3

# used to generate model: onnx-tests/tests/layer_norm/layer_norm.onnx

import torch
import torch.nn as nn


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.norm = nn.LayerNorm(4)

def forward(self, x):
return self.norm(x)


def main():
# Set random seed for reproducibility
torch.manual_seed(0)

# Export to onnx
model = Model()
model.eval()
device = torch.device("cpu")

onnx_name = "layer_norm.onnx"
test_input = torch.arange(24, dtype=torch.float, device=device).reshape(2, 3, 4)
    # LayerNormalization is only available starting with ONNX opset 17
torch.onnx.export(model, test_input, onnx_name, verbose=False, opset_version=17)

print(f"Finished exporting model to {onnx_name}")

# Output some test data for use in the test
print(f"Test input data: {test_input}")
output = model.forward(test_input)
print(f"Test output data: {output}")


if __name__ == "__main__":
main()
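
As a quick check of the arithmetic: with nn.LayerNorm's default affine parameters (weight of ones, bias of zeros), every length-4 slice [4k, 4k+1, 4k+2, 4k+3] of the arange input normalizes to the same four values, which is why the expected tensor in the Rust test below repeats a single row. For the first slice [0, 1, 2, 3]:

\mu = \tfrac{0 + 1 + 2 + 3}{4} = 1.5, \qquad
\sigma^2 = \tfrac{(0 - 1.5)^2 + (1 - 1.5)^2 + (2 - 1.5)^2 + (3 - 1.5)^2}{4} = 1.25

\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \approx [-1.3416,\; -0.4472,\; 0.4472,\; 1.3416] \qquad (\epsilon = 10^{-5})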
35 changes: 35 additions & 0 deletions crates/burn-import/onnx-tests/tests/onnx_tests.rs
@@ -36,6 +36,7 @@ include_models!(
gather,
gelu,
global_avr_pool,
layer_norm,
leaky_relu,
linear,
log_softmax,
@@ -600,6 +601,40 @@
assert!(expected_sum.approx_eq(output_sum, (1.0e-8, 2)));
}

#[test]
fn layer_norm() {
let device = Default::default();
let model: layer_norm::Model<Backend> = layer_norm::Model::default();

        // Run the model with the same arange(24) input (reshaped to [2, 3, 4]) used to generate the ONNX file
let input = Tensor::<Backend, 3>::from_floats(
[
[[0., 1., 2., 3.], [4., 5., 6., 7.], [8., 9., 10., 11.]],
[
[12., 13., 14., 15.],
[16., 17., 18., 19.],
[20., 21., 22., 23.],
],
],
&device,
);
let output = model.forward(input);
let expected = Data::from([
[
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
],
[
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
[-1.3416, -0.4472, 0.4472, 1.3416],
],
]);

output.to_data().assert_approx_eq(&expected, 4);
}

#[test]
fn leaky_relu() {
// Initialize the model without weights (because the exported file does not contain them)
4 changes: 4 additions & 0 deletions crates/burn-import/src/burn/node/base.rs
@@ -1,3 +1,4 @@
use super::layer_norm::LayerNormNode;
use super::mask_where::WhereNode;
use super::unsqueeze::UnsqueezeNode;
use super::{
@@ -87,6 +88,7 @@ pub enum Node<PS: PrecisionSettings> {
Dropout(DropoutNode),
Gather(GatherNode),
GlobalAvgPool(GlobalAvgPoolNode),
LayerNorm(LayerNormNode<PS>),
Linear(LinearNode<PS>),
Matmul(MatmulNode),
MaxPool2d(MaxPool2dNode),
@@ -112,6 +114,7 @@ macro_rules! match_all {
Node::Dropout(node) => $func(node),
Node::Gather(node) => $func(node),
Node::GlobalAvgPool(node) => $func(node),
Node::LayerNorm(node) => $func(node),
Node::Linear(node) => $func(node),
Node::Matmul(node) => $func(node),
Node::MaxPool2d(node) => $func(node),
@@ -147,6 +150,7 @@ impl<PS: PrecisionSettings> Node<PS> {
Node::Dropout(_) => "dropout",
Node::Gather(_) => "gather",
Node::GlobalAvgPool(_) => "global_avg_pool",
Node::LayerNorm(_) => "layer_norm",
Node::Linear(_) => "linear",
Node::Matmul(_) => "matmul",
Node::MaxPool2d(_) => "max_pool2d",
177 changes: 177 additions & 0 deletions crates/burn-import/src/burn/node/layer_norm.rs
@@ -0,0 +1,177 @@
use super::{Node, NodeCodegen, SerializationBackend};
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
use burn::{
module::{ConstantRecord, Param, ParamId},
nn::{LayerNormConfig, LayerNormRecord},
record::{PrecisionSettings, Record},
tensor::{DataSerialize, Tensor},
};
use proc_macro2::TokenStream;
use quote::quote;
use serde::Serialize;

#[derive(Debug, Clone)]
pub struct LayerNormNode<PS: PrecisionSettings> {
pub field: OtherType,
pub input: TensorType,
pub output: TensorType,
pub gamma: DataSerialize<PS::FloatElem>, // Scale
pub beta: Option<DataSerialize<PS::FloatElem>>, // Bias (B)
pub config: LayerNormConfig,
pub full_precision: bool,
}

impl<PS: PrecisionSettings> LayerNormNode<PS> {
pub fn new<S: AsRef<str>>(
name: S,
input: TensorType,
output: TensorType,
gamma: DataSerialize<PS::FloatElem>,
beta: Option<DataSerialize<PS::FloatElem>>,
config: LayerNormConfig,
full_precision: bool,
) -> Self {
Self {
field: OtherType::new(
name,
quote! {
LayerNorm<B>
},
),
input,
output,
gamma,
beta,
config,
full_precision,
}
}
}

impl<PS: PrecisionSettings> NodeCodegen<PS> for LayerNormNode<PS> {
fn input_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.input.clone())]
}
fn output_types(&self) -> Vec<Type> {
vec![Type::Tensor(self.output.clone())]
}
fn field_type(&self) -> Option<Type> {
Some(Type::Other(self.field.clone()))
}

fn field_init(&self) -> Option<TokenStream> {
let name = &self.field.name;
let num_features = self.config.d_model.to_tokens();
let epsilon = self.config.epsilon;

let tokens = quote! {
let #name = LayerNormConfig::new(#num_features)
.with_epsilon(#epsilon)
.init(device);
};

Some(tokens)
}

fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let device = Default::default();
let record = LayerNormRecord::<SerializationBackend> {
gamma: Param::initialized(
ParamId::new(),
Tensor::from_data(self.gamma.clone().convert(), &device),
),
beta: Param::initialized(
ParamId::new(),
if let Some(beta) = self.beta.clone() {
Tensor::from_data(beta.convert(), &device)
} else {
Tensor::zeros([self.config.d_model], &device)
},
),
epsilon: ConstantRecord::new(),
};

let item = Record::into_item::<PS>(record);
item.serialize(serializer)
}

fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let field = &self.field.name;

// TODO: handle self.full_precision
quote! {
let #output = self.#field.forward(#input);
}
}
fn register_imports(&self, imports: &mut BurnImports) {
imports.register("burn::nn::LayerNorm");
imports.register("burn::nn::LayerNormConfig");
}

fn into_node(self) -> Node<PS> {
Node::LayerNorm(self)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
use burn::{record::FullPrecisionSettings, tensor::Data};

#[test]
fn test_codegen() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(LayerNormNode::new(
"norm",
TensorType::new_float("input", 4),
TensorType::new_float("output", 4),
Data::from([2.]).serialize(),
Some(Data::from([2.]).serialize()),
LayerNormConfig::new(128),
true, // full_precision isn't taken into account
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);

let expected = quote! {
use burn::{
module::Module,
tensor::{backend::Backend, Tensor},
};
use burn::nn::LayerNorm;
use burn::nn::LayerNormConfig;

#[derive(Module, Debug)]
pub struct Model <B: Backend> {
norm: LayerNorm<B>,
phantom: core::marker::PhantomData<B>,
}

impl<B: Backend> Model <B> {
#[allow(unused_variables)]
pub fn new(device: &B::Device) -> Self {
let norm = LayerNormConfig::new(128)
.with_epsilon(0.00001f64)
.init(device);

Self {
norm,
phantom: core::marker::PhantomData,
}
}
#[allow(clippy::let_and_return, clippy::approx_constant)]
pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 4> {
let output = self.norm.forward(input);

output
}
}
};

assert_tokens(graph.codegen(), expected);
}
}
1 change: 1 addition & 0 deletions crates/burn-import/src/burn/node/mod.rs
@@ -12,6 +12,7 @@ pub(crate) mod conv_transpose_2d;
pub(crate) mod dropout;
pub(crate) mod gather;
pub(crate) mod global_avg_pool;
pub(crate) mod layer_norm;
pub(crate) mod linear;
pub(crate) mod mask_where;
pub(crate) mod matmul;
1 change: 1 addition & 0 deletions crates/burn-import/src/onnx/dim_inference.rs
@@ -33,6 +33,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
NodeType::GatherElements => same_as_input(node),
NodeType::GlobalAveragePool => same_as_input(node),
NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
NodeType::LayerNormalization => same_as_input(node),
NodeType::Linear => linear_update_outputs(node),
NodeType::Log => same_as_input(node),
NodeType::LogSoftmax => same_as_input(node),
42 changes: 39 additions & 3 deletions crates/burn-import/src/onnx/op_configuration.rs
@@ -1,8 +1,8 @@
use burn::nn::{
conv::Conv1dConfig,
conv::{Conv2dConfig, ConvTranspose2dConfig},
conv::{Conv1dConfig, Conv2dConfig, ConvTranspose2dConfig},
pool::{AvgPool2dConfig, MaxPool2dConfig},
BatchNormConfig, DropoutConfig, LinearConfig, PaddingConfig1d, PaddingConfig2d,
BatchNormConfig, DropoutConfig, LayerNormConfig, LinearConfig, PaddingConfig1d,
PaddingConfig2d,
};

use super::ir::{ArgType, AttributeValue, Data, Node};
@@ -465,6 +465,42 @@ pub fn batch_norm_config(node: &Node) -> BatchNormConfig {
.with_momentum(momentum as f64)
}

/// Create a LayerNormConfig from the attributes of the node
pub fn layer_norm_config(node: &Node) -> (LayerNormConfig, bool) {
// Extract the shape of the weight tensor
let tensor_type = if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty {
tensor_type
} else {
panic!("LayerNorm: weight tensor must be present");
};

let num_features: usize = tensor_type.shape.clone().unwrap()[0];

// When `stash_type` is `1` (default), perform operations in 32-bit float and
// cast the results back to original dtype
let mut stash_type = 1;
let mut axis = -1;
let mut epsilon = 1e-5;

for (key, value) in node.attrs.iter() {
match key.as_str() {
"axis" => axis = value.clone().into_i64(),
"epsilon" => epsilon = value.clone().into_f32(),
"stash_type" => stash_type = value.clone().into_i64(),
_ => {}
}
}

if axis != -1 && axis != tensor_type.dim as i64 - 1 {
panic!("LayerNorm: normalization is only supported on the last axis right now")
}

(
LayerNormConfig::new(num_features).with_epsilon(epsilon as f64),
stash_type == 1,
)
}
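
// For the model exported by layer_norm.py above, the scale input has shape [4]
// and PyTorch's default eps of 1e-5 is carried over as the `epsilon` attribute,
// so, assuming `stash_type` is left at its default of 1, the call resolves to
// roughly the following (a sketch, not code from this diff):
//
//     layer_norm_config(&node) == (LayerNormConfig::new(4).with_epsilon(1e-5), true)
//
// where the boolean becomes the `full_precision` flag passed to `LayerNormNode`.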

/// Calculate the padding configuration for a 2D operations such as Convolution and Pooling.
///
/// # Arguments
1 more changed file not shown.