diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py
index b01bbee9..3ed8cc25 100644
--- a/onnx2torch/node_converters/__init__.py
+++ b/onnx2torch/node_converters/__init__.py
@@ -20,6 +20,7 @@
 from onnx2torch.node_converters.global_average_pool import *
 from onnx2torch.node_converters.identity import *
 from onnx2torch.node_converters.instance_norm import *
+from onnx2torch.node_converters.layer_norm import *
 from onnx2torch.node_converters.logical import *
 from onnx2torch.node_converters.lrn import *
 from onnx2torch.node_converters.matmul import *
diff --git a/onnx2torch/node_converters/layer_norm.py b/onnx2torch/node_converters/layer_norm.py
new file mode 100644
index 00000000..e1928a4d
--- /dev/null
+++ b/onnx2torch/node_converters/layer_norm.py
@@ -0,0 +1,66 @@
+__all__ = [
+    'OnnxLayerNorm',
+]
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from onnx2torch.node_converters.registry import add_converter
+from onnx2torch.onnx_graph import OnnxGraph
+from onnx2torch.onnx_node import OnnxNode
+from onnx2torch.utils.common import OnnxMapping
+from onnx2torch.utils.common import OnnxToTorchModule
+from onnx2torch.utils.common import OperationConverterResult
+from onnx2torch.utils.common import get_shape_from_value_info
+from onnx2torch.utils.common import onnx_mapping_from_node
+
+
+class OnnxLayerNorm(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-docstring
+    def __init__(self, axis: int, epsilon: float, stash_type: int):
+        super().__init__()
+        self.axis = axis
+        self.epsilon = epsilon
+        # stash_type selects the dtype of the intermediate ONNX computation;
+        # F.layer_norm computes in the input dtype, so the value is stored but unused.
+        self.stash_type = stash_type
+
+    def forward(  # pylint: disable=missing-function-docstring
+        self,
+        input_data: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor,
+    ) -> torch.Tensor:
+        # ONNX LayerNormalization normalizes over all dimensions from `axis` to the last one,
+        # so the normalized shape is derived from the input at runtime.
+        normalized_shape = input_data.shape[self.axis:]
+        return F.layer_norm(input_data, normalized_shape, weight=weight, bias=bias, eps=self.epsilon)
+
+
+@add_converter(operation_type='LayerNormalization', version=17)
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    node_attributes = node.attributes
+    axis = node_attributes.get('axis', -1)
+    epsilon = node_attributes.get('epsilon', 1e-5)
+    stash_type = node_attributes.get('stash_type', 1)
+    if all(value_name in graph.initializers for value_name in node.input_values[1:]):
+        # Scale and bias are initializers (bias is assumed present): fold them into nn.LayerNorm.
+        input_value_info = graph.value_info[node.input_values[0]]
+        input_shape = get_shape_from_value_info(input_value_info)
+
+        scale_value_name = node.input_values[1]
+        bias_value_name = node.input_values[2]
+
+        # Pass the whole trailing shape [axis:], not a single dimension.
+        torch_module = nn.LayerNorm(input_shape[axis:], eps=epsilon, elementwise_affine=True)
+
+        with torch.no_grad():
+            torch_module.weight.data = graph.initializers[scale_value_name].to_torch()
+            torch_module.bias.data = graph.initializers[bias_value_name].to_torch()
+
+        onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
+    else:
+        # Scale and bias arrive as runtime inputs: defer to the functional wrapper.
+        torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon, stash_type=stash_type)
+        onnx_mapping = onnx_mapping_from_node(node)
+
+    return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
diff --git a/operators.md b/operators.md
index f2db4396..5610c245 100644
--- a/operators.md
+++ b/operators.md
@@ -63,6 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
 | InstanceNormalization | Y | |
 | IsInf | N | |
 | IsNaN | N | |
+| LayerNormalization | Y | The optional "Mean" and "InvStdDev" outputs are not implemented |
 | LRN | Y | |
 | LSTM | N | |
 | LeakyRelu | Y | |
diff --git a/tests/node_converters/layer_norm_test.py b/tests/node_converters/layer_norm_test.py
new file mode 100644
index 00000000..54807b94
--- /dev/null
+++ b/tests/node_converters/layer_norm_test.py
@@ -0,0 +1,70 @@
+from typing import List
+
+import numpy as np
+import onnx
+import pytest
+
+from tests.utils.common import check_onnx_model
+from tests.utils.common import make_model_from_nodes
+
+
+def _test_layer_norm(
+    x: np.ndarray,
+    parameters_as_inputs: bool,
+) -> None:
+    normalized_shape = calculate_normalized_shape(x.shape, -1)
+    scale = np.random.randn(*normalized_shape).astype(np.float32)
+    bias = np.random.randn(*normalized_shape).astype(np.float32)
+
+    inputs = {'input': x}
+    parameters = {'scale': scale, 'bias': bias}
+    initializers = {}
+
+    if parameters_as_inputs:
+        inputs.update(parameters)
+    else:
+        initializers.update(parameters)
+
+    node = onnx.helper.make_node(
+        op_type='LayerNormalization',
+        inputs=['input', 'scale', 'bias'],
+        outputs=['y'],
+    )
+    model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17)
+    check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6)
+
+
+# The parametrization over parameters_as_inputs is disabled for now: only the
+# initializer path is exercised, and the default argument below keeps it off.
+# @pytest.mark.parametrize(
+#     'parameters_as_inputs',
+#     (True, False),
+# )
+@pytest.mark.parametrize(
+    'input_shape',
+    (
+        # 1d
+        [2, 3, 16],
+        [2, 1, 7],
+        # 2d
+        [2, 3, 16, 16],
+        [2, 1, 7, 16],
+        # 3d
+        [2, 3, 16, 16, 16],
+        [2, 1, 16, 7, 16],
+    ),
+)
+def test_layer_norm(  # pylint: disable=missing-function-docstring
+    input_shape: List[int],
+    parameters_as_inputs: bool = False,
+) -> None:
+    x = np.random.randn(*input_shape).astype(np.float32)
+
+    _test_layer_norm(x=x, parameters_as_inputs=parameters_as_inputs)
+
+
+def calculate_normalized_shape(x_shape, axis):  # pylint: disable=missing-function-docstring
+    x_rank = len(x_shape)
+    if axis < 0:
+        axis = axis + x_rank
+    return x_shape[axis:]
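
For reference, below is a minimal sketch (not part of the diff) of how this converter is exercised end to end through the public onnx2torch.convert API. It uses only standard onnx.helper / onnx.numpy_helper calls; the tensor names ('input', 'scale', 'bias', 'y'), the graph name, and the shapes are illustrative.

import numpy as np
import torch
from onnx import TensorProto, checker, helper, numpy_helper

from onnx2torch import convert

# Single-node model: LayerNormalization over the last axis, opset 17.
x_info = helper.make_tensor_value_info('input', TensorProto.FLOAT, [2, 3, 16])
y_info = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 3, 16])
scale = numpy_helper.from_array(np.ones(16, dtype=np.float32), name='scale')
bias = numpy_helper.from_array(np.zeros(16, dtype=np.float32), name='bias')
node = helper.make_node('LayerNormalization', inputs=['input', 'scale', 'bias'], outputs=['y'])
graph = helper.make_graph([node], 'ln_graph', [x_info], [y_info], initializer=[scale, bias])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 17)])
checker.check_model(model)

# Scale and bias are initializers, so the converter folds them into nn.LayerNorm.
torch_model = convert(model)
print(torch_model(torch.randn(2, 3, 16)).shape)  # torch.Size([2, 3, 16])

Declaring 'scale' and 'bias' as graph inputs instead of initializers would route conversion through the OnnxLayerNorm fallback, which applies them at call time.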