feat: LayerNormalization #143

Closed
wants to merge 5 commits into from
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
@@ -20,6 +20,7 @@
from onnx2torch.node_converters.global_average_pool import *
from onnx2torch.node_converters.identity import *
from onnx2torch.node_converters.instance_norm import *
from onnx2torch.node_converters.layer_norm import *
from onnx2torch.node_converters.logical import *
from onnx2torch.node_converters.lrn import *
from onnx2torch.node_converters.matmul import *
64 changes: 64 additions & 0 deletions onnx2torch/node_converters/layer_norm.py
@@ -0,0 +1,64 @@
__all__ = [
'OnnxLayerNorm',
]

from typing import List

import torch
import torch.nn.functional as F
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxMapping
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import get_shape_from_value_info
from onnx2torch.utils.common import onnx_mapping_from_node


class OnnxLayerNorm(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring
def __init__(self, axis: int, epsilon: float, stash_type: int):
super().__init__()
self.axis = axis
self.epsilon = epsilon
self.stash_type = stash_type

    def forward(  # pylint: disable=missing-function-docstring
        self,
        input_data: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
    ) -> torch.Tensor:
        # The ONNX node supplies only (X, Scale, B); the normalized shape is derived from
        # the `axis` attribute: all dimensions from `axis` onward are normalized.
        normalized_shape: List[int] = list(input_data.shape[self.axis :])
        return F.layer_norm(input_data, normalized_shape, weight=weight, bias=bias, eps=self.epsilon)


@add_converter(operation_type='LayerNormalization', version=17)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
node_attributes = node.attributes
axis = node_attributes.get('axis', -1)
epsilon = node_attributes.get('epsilon', 1e-5)
stash_type = node_attributes.get('stash_type', 1)
if all(value_name in graph.initializers for value_name in node.input_values[1:]):
input_value_info = graph.value_info[node.input_values[0]]
input_shape = get_shape_from_value_info(input_value_info)

scale_value_name = node.input_values[1]
bias_value_name = node.input_values[2]

        # Normalize over all dimensions from `axis` onward, as required by ONNX LayerNormalization.
        torch_module = nn.LayerNorm(input_shape[axis:], eps=epsilon, elementwise_affine=True)

with torch.no_grad():
torch_module.weight.data = graph.initializers[scale_value_name].to_torch()
torch_module.bias.data = graph.initializers[bias_value_name].to_torch()

onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
else:
input_value_info = graph.value_info[node.input_values[0]]
input_shape = get_shape_from_value_info(input_value_info)
torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon, stash_type=stash_type)
onnx_mapping = onnx_mapping_from_node(node)

return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
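For context (not part of the diff): a minimal sketch of the equivalence the converter relies on, assuming the default `axis=-1` and illustrative tensor shapes. The module path mirrors the initializer branch above; the functional call mirrors `OnnxLayerNorm`.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Hypothetical input: batch of 2 x 3 sequences with 16 features, axis=-1, epsilon=1e-5.
x = torch.randn(2, 3, 16)

# Initializer branch: a regular nn.LayerNorm over the trailing dimension(s).
module = nn.LayerNorm(x.shape[-1:], eps=1e-5, elementwise_affine=True)
y_module = module(x)

# Dynamic branch: the functional call used by OnnxLayerNorm.forward.
y_functional = F.layer_norm(x, list(x.shape[-1:]), weight=module.weight, bias=module.bias, eps=1e-5)

assert torch.allclose(y_module, y_functional, atol=1e-6)
```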
1 change: 1 addition & 0 deletions operators.md
@@ -63,6 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
| InstanceNormalization | Y | |
| IsInf | N | |
| IsNaN | N | |
| LayerNormalization | Y | Optional outputs "Mean" and "InvStdDev" are not implemented |
| LRN | Y | |
| LSTM | N | |
| LeakyRelu | Y | |
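For reference, the optional `Mean` and `InvStdDev` outputs mentioned above are defined by the ONNX specification as the per-sample mean and reciprocal standard deviation over the normalized axes. A rough sketch (illustrative values, not code from this PR):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16)  # hypothetical input, axis=-1
epsilon = 1e-5

mean = x.mean(dim=-1, keepdim=True)
inv_std_dev = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + epsilon)
y = (x - mean) * inv_std_dev  # multiply by Scale and add B to obtain the first output

# The normalized result matches torch's layer_norm up to numerical tolerance.
assert torch.allclose(y, F.layer_norm(x, [16], eps=epsilon), atol=1e-6)
```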
68 changes: 68 additions & 0 deletions tests/node_converters/layer_norm_test.py
@@ -0,0 +1,68 @@
from typing import List

import numpy as np
import onnx
import pytest

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


def _test_layer_norm(
x: np.ndarray,
parameters_as_inputs: bool,
) -> None:
normalized_shape = calculate_normalized_shape(x.shape, -1)
scale = np.random.randn(*normalized_shape).astype(np.float32)
bias = np.random.randn(*normalized_shape).astype(np.float32)

inputs = {'input': x}
parameters = {'scale': scale, 'bias': bias}
initializers = {}

if parameters_as_inputs:
inputs.update(parameters)
else:
initializers.update(parameters)

node = onnx.helper.make_node(
op_type='LayerNormalization',
inputs=['input', 'scale', 'bias'],
outputs=['y'],
)
model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17)
check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6)


# @pytest.mark.parametrize(
# 'parameters_as_inputs',
# (True, False),
# )
@pytest.mark.parametrize(
'input_shape',
(
# 1d
[2, 3, 16],
[2, 1, 7],
        # 2d
[2, 3, 16, 16],
[2, 1, 7, 16],
        # 3d
[2, 3, 16, 16, 16],
[2, 1, 16, 7, 16],
),
)
def test_layer_norm( # pylint: disable=missing-function-docstring
input_shape: List[int],
parameters_as_inputs: bool = False,
) -> None:
x = np.random.randn(*input_shape).astype(np.float32)

_test_layer_norm(x=x, parameters_as_inputs=parameters_as_inputs)


def calculate_normalized_shape(x_shape, axis): # pylint: disable=missing-function-docstring
x_rank = len(x_shape)
if axis < 0:
axis = axis + x_rank
return x_shape[axis:]
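A quick check of the helper above (hypothetical shapes, assuming the function as defined in this file): `axis=2` on a rank-4 shape normalizes the last two dimensions, and a negative axis counts from the end.

```python
assert calculate_normalized_shape((2, 3, 16, 16), 2) == (16, 16)
assert calculate_normalized_shape((2, 3, 16), -1) == (16,)
```

The tests themselves can be run with `pytest tests/node_converters/layer_norm_test.py`.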