Commit 3302062
fix: wrong behaviour
senysenyseny16 committed Jul 5, 2023
1 parent 663f9b9 commit 3302062
Showing 3 changed files with 67 additions and 45 deletions.
onnx2torch/node_converters/layer_norm.py (29 additions, 15 deletions)

@@ -2,7 +2,7 @@
     'OnnxLayerNorm',
 ]
 
-from typing import List
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -17,48 +17,62 @@
 from onnx2torch.utils.common import get_shape_from_value_info
 from onnx2torch.utils.common import onnx_mapping_from_node
 
+AXIS_DEFAULT_VALUE = -1
+EPSILON_DEFAULT_VALUE = 1e-5
+
 
 class OnnxLayerNorm(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-docstring
-    def __init__(self, axis: int, epsilon: float, stash_type: int):
+    def __init__(self, axis: int, epsilon: float):
         super().__init__()
         self.axis = axis
         self.epsilon = epsilon
-        self.stash_type = stash_type
 
     def forward(  # pylint: disable=missing-function-docstring
         self,
         input_data: torch.Tensor,
-        normalized_shape: List[int],
-        weight: torch.Tensor,
-        bias: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        return F.layer_norm(input_data, normalized_shape, weight=weight, bias=bias, eps=self.epsilon)
+        normalized_shape = input_data.shape[self.axis :]
+        return F.layer_norm(
+            input=input_data,
+            normalized_shape=normalized_shape,
+            weight=weight,
+            bias=bias,
+            eps=self.epsilon,
+        )
 
 
 @add_converter(operation_type='LayerNormalization', version=17)
 def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
     node_attributes = node.attributes
-    axis = node_attributes.get('axis', -1)
-    epsilon = node_attributes.get('epsilon', 1e-5)
-    stash_type = node_attributes.get('stash_type', 1)
+
+    axis = node_attributes.get('axis', AXIS_DEFAULT_VALUE)
+    epsilon = node_attributes.get('epsilon', EPSILON_DEFAULT_VALUE)
 
     if all(value_name in graph.initializers for value_name in node.input_values[1:]):
         input_value_info = graph.value_info[node.input_values[0]]
         input_shape = get_shape_from_value_info(input_value_info)
 
-        scale_value_name = node.input_values[1]
-        bias_value_name = node.input_values[2]
-
-        torch_module = nn.LayerNorm(input_shape[axis], eps=epsilon, elementwise_affine=True)
+        torch_module = nn.LayerNorm(
+            normalized_shape=input_shape[axis:],
+            eps=epsilon,
+            elementwise_affine=True,
+        )
+
+        scale_value_name = node.input_values[1]
+        bias_value_name = node.input_values[2] if len(node.input_values) > 2 else None
 
         with torch.no_grad():
             torch_module.weight.data = graph.initializers[scale_value_name].to_torch()
-            torch_module.bias.data = graph.initializers[bias_value_name].to_torch()
+            if bias_value_name is not None:
+                torch_module.bias.data = graph.initializers[bias_value_name].to_torch()
 
         onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values)
     else:
         input_value_info = graph.value_info[node.input_values[0]]
         input_shape = get_shape_from_value_info(input_value_info)
-        torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon, stash_type=stash_type)
+        torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon)
        onnx_mapping = onnx_mapping_from_node(node)

     return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping)
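
For context on the fix: ONNX LayerNormalization computes its statistics over every dimension from axis through the last, so normalized_shape must be the full trailing shape, not a single dimension size. A minimal sketch (not part of the commit) of the equivalence the new normalized_shape = input_data.shape[self.axis :] line relies on:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 16, 16)
axis, eps = 1, 1e-5

# What the fixed forward computes: joint normalization over dims 1, 2 and 3.
y = F.layer_norm(x, x.shape[axis:], eps=eps)

# Reference per the ONNX spec: mean and variance over dims axis..rank-1.
dims = tuple(range(axis, x.dim()))
mean = x.mean(dim=dims, keepdim=True)
var = x.var(dim=dims, unbiased=False, keepdim=True)
assert torch.allclose(y, (x - mean) / torch.sqrt(var + eps), atol=1e-5)

The same slicing explains the converter change: nn.LayerNorm(input_shape[axis], ...) normalized only one trailing dimension, while normalized_shape=input_shape[axis:] normalizes all of them.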
operators.md (1 addition, 1 deletion)

@@ -63,7 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
 | InstanceNormalization | Y | |
 | IsInf | N | |
 | IsNaN | N | |
-| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" is not implemented |
+| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" are not implemented |
 | LRN | Y | |
 | LSTM | N | |
 | LeakyRelu | Y | |
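
A note on the table entry: ONNX LayerNormalization can optionally emit two extra outputs beyond Y, which this converter does not produce. A hedged sketch of what they denote per the ONNX spec (the helper below is illustrative, not part of the repo):

import torch

def layer_norm_extra_outputs(x: torch.Tensor, axis: int, eps: float = 1e-5):
    # Optional second and third ONNX outputs, "Mean" and "InvStdDev",
    # reduced over the normalized dimensions with dims kept for broadcasting.
    dims = tuple(range(axis, x.dim()))
    mean = x.mean(dim=dims, keepdim=True)
    inv_std_dev = torch.rsqrt(x.var(dim=dims, unbiased=False, keepdim=True) + eps)
    return mean, inv_std_dev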
tests/node_converters/layer_norm_test.py (37 additions, 29 deletions)

@@ -1,4 +1,6 @@
 # pylint: disable=missing-function-docstring
 from typing import List
+from typing import Optional
 
 import numpy as np
 import onnx
@@ -10,14 +12,16 @@
 
 def _test_layer_norm(
     x: np.ndarray,
+    scale: np.ndarray,
+    bias: Optional[np.ndarray],
+    axis: int,
     parameters_as_inputs: bool,
 ) -> None:
-    normalized_shape = calculate_normalized_shape(x.shape, -1)
-    scale = np.random.randn(*normalized_shape).astype(np.float32)
-    bias = np.random.randn(*normalized_shape).astype(np.float32)
 
     inputs = {'input': x}
-    parameters = {'scale': scale, 'bias': bias}
+    parameters = {'scale': scale}
+    if bias is not None:
+        parameters['bias'] = bias
 
     initializers = {}
 
     if parameters_as_inputs:
@@ -27,42 +31,46 @@ def _test_layer_norm(
 
     node = onnx.helper.make_node(
         op_type='LayerNormalization',
-        inputs=['input', 'scale', 'bias'],
+        inputs=['input', 'scale', 'bias'] if bias is not None else ['input', 'scale'],
         outputs=['y'],
+        axis=axis,
     )
     model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17)
-    check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6)
+    check_onnx_model(
+        onnx_model=model,
+        onnx_inputs=inputs,
+        atol_onnx_torch=1e-5,
+        atol_torch_cpu_cuda=1e-5,
+        atol_onnx_torch2onnx=1e-5,
+    )
 
 
-# @pytest.mark.parametrize(
-#     'parameters_as_inputs',
-#     (True, False),
-# )
+@pytest.mark.parametrize('parameters_as_inputs', (True, False))
 @pytest.mark.parametrize(
     'input_shape',
     (
-        # 1d
-        [2, 3, 16],
-        [2, 1, 7],
-        # # 2d
-        [2, 3, 16, 16],
-        [2, 1, 7, 16],
-        # # 3d
-        [2, 3, 16, 16, 16],
-        [2, 1, 16, 7, 16],
+        [3, 1, 224],
+        [4, 3, 16, 16],
+        [5, 1, 32, 32],
+        [6, 3, 16, 16, 8],
+        [7, 1, 7, 7, 16],
     ),
 )
-def test_layer_norm(  # pylint: disable=missing-function-docstring
-    input_shape: List[int],
-    parameters_as_inputs: bool = False,
-) -> None:
+def test_layer_norm(input_shape: List[int], parameters_as_inputs: bool) -> None:
     x = np.random.randn(*input_shape).astype(np.float32)
 
-    _test_layer_norm(x=x, parameters_as_inputs=parameters_as_inputs)
-
-
-def calculate_normalized_shape(x_shape, axis):  # pylint: disable=missing-function-docstring
-    x_rank = len(x_shape)
-    if axis < 0:
-        axis = axis + x_rank
-    return x_shape[axis:]
+    for axis in [*range(len(input_shape))] + [-1]:
+        normalized_shape = input_shape[axis:]
+
+        scale = np.random.randn(*normalized_shape).astype(np.float32)
+        bias = np.random.randn(*normalized_shape).astype(np.float32)
+
+        for bias_ in [bias, None]:
+            _test_layer_norm(
+                x=x,
+                scale=scale,
+                bias=bias_,
+                axis=axis,
+                parameters_as_inputs=parameters_as_inputs,
+            )
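
For readers reproducing the bias-less case outside the test harness, a hedged end-to-end sketch using plain onnx.helper and onnx2torch's documented convert entry point; shapes and tensor names here are illustrative assumptions, not taken from the tests:

import numpy as np
import onnx
import torch
from onnx import numpy_helper
from onnx2torch import convert

x = np.random.randn(2, 3, 16).astype(np.float32)
scale = np.random.randn(16).astype(np.float32)

# LayerNormalization with the optional `bias` input omitted: the case the
# `len(node.input_values) > 2` check in the converter now handles.
node = onnx.helper.make_node(
    op_type='LayerNormalization',
    inputs=['input', 'scale'],
    outputs=['y'],
    axis=-1,
)
graph = onnx.helper.make_graph(
    nodes=[node],
    name='layer_norm_no_bias',
    inputs=[onnx.helper.make_tensor_value_info('input', onnx.TensorProto.FLOAT, x.shape)],
    outputs=[onnx.helper.make_tensor_value_info('y', onnx.TensorProto.FLOAT, x.shape)],
    initializer=[numpy_helper.from_array(scale, name='scale')],
)
model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid('', 17)])

torch_model = convert(model)
torch_out = torch_model(torch.from_numpy(x)).detach().numpy()

# Sanity check against a manual layer norm over the last dimension.
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
assert np.allclose(torch_out, (x - mean) / np.sqrt(var + 1e-5) * scale, atol=1e-5)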
