From 6d8e4b17fbd3783fd0e391ece4ca7ab45a3ed7dc Mon Sep 17 00:00:00 2001 From: Mason Ma Date: Thu, 23 Feb 2023 15:36:58 +0800 Subject: [PATCH 1/5] feat: EyeLike --- onnx2torch/node_converters/__init__.py | 1 + onnx2torch/node_converters/eye.py | 44 ++++++++++++++++++++++++++ tests/node_converters/eye_test.py | 34 ++++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 onnx2torch/node_converters/eye.py create mode 100644 tests/node_converters/eye_test.py diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py index b01bbee9..571bf82b 100644 --- a/onnx2torch/node_converters/__init__.py +++ b/onnx2torch/node_converters/__init__.py @@ -13,6 +13,7 @@ from onnx2torch.node_converters.dropout import * from onnx2torch.node_converters.einsum import * from onnx2torch.node_converters.expand import * +from onnx2torch.node_converters.eye import * from onnx2torch.node_converters.flatten import * from onnx2torch.node_converters.functions import * from onnx2torch.node_converters.gather import * diff --git a/onnx2torch/node_converters/eye.py b/onnx2torch/node_converters/eye.py new file mode 100644 index 00000000..c2d0293a --- /dev/null +++ b/onnx2torch/node_converters/eye.py @@ -0,0 +1,44 @@ +__all__ = [ + 'OnnxEyeLike', +] + +import torch +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxToTorchModule, OperationConverterResult, onnx_mapping_from_node +from torch import nn + + +class OnnxEyeLike(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring + def __init__(self, k: int, dtype: torch.dtype = None): + super().__init__() + self.dtype = dtype + self.k = k + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + if len(x.shape) != 2: + raise ValueError('OnnxEyeLike only support 2-D tensor') + + n = x.size(dim=0) + m = x.size(dim=1) + if self.k > n: + raise ValueError(f'Error EyeLike Attribute k value, the k value is {self.k}, but x shape is {(n,m)}') + + if self.k == 0: + return torch.eye(n, m, dtype=self.dtype) + + k_tensor = torch.zeros(n, self.k, dtype=self.dtype) + eye_tensor = torch.eye(n, m - self.k, dtype=self.dtype) + return torch.concat([k_tensor, eye_tensor], axis=1) + + +@add_converter(operation_type='EyeLike', version=9) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument + node_attributes = node.attributes + k = node_attributes.get('k', 0) + dtype = node_attributes.get('dtype', torch.float32) + return OperationConverterResult( + torch_module=OnnxEyeLike(dtype=dtype, k=k), + onnx_mapping=onnx_mapping_from_node(node=node), + ) diff --git a/tests/node_converters/eye_test.py b/tests/node_converters/eye_test.py new file mode 100644 index 00000000..8cbb3140 --- /dev/null +++ b/tests/node_converters/eye_test.py @@ -0,0 +1,34 @@ +from typing import Tuple + +import numpy as np +import onnx +import pytest +from tests.utils.common import check_onnx_model, make_model_from_nodes + + +@pytest.mark.parametrize('dtype', ((None),)) +@pytest.mark.parametrize('k', [0, 1, 2]) +@pytest.mark.parametrize( + 'input_shapes', + ( + ((2, 3)), + ((3, 4)), + ((3, 3)), + ), +) +def test_eye( # pylint: disable=missing-function-docstring + input_shapes: Tuple[int], + dtype: str, + k: int, +) -> None: + input_values = np.random.randint(0, 100, size=input_shapes) + + test_inputs = {'x': input_values} + + node = onnx.helper.make_node(op_type='EyeLike', inputs=['x'], outputs=['z'], k=k, dtype=dtype) + model = make_model_from_nodes( + nodes=node, + initializers={}, + inputs_example=test_inputs, + ) + check_onnx_model(model, test_inputs) From b41cb5f138605e782fefbb309b51f68e32d93dc9 Mon Sep 17 00:00:00 2001 From: Mason Ma Date: Thu, 23 Feb 2023 16:07:23 +0800 Subject: [PATCH 2/5] reformatted code --- onnx2torch/node_converters/eye.py | 30 +++++++++++++++++------------- tests/node_converters/eye_test.py | 6 +++--- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/onnx2torch/node_converters/eye.py b/onnx2torch/node_converters/eye.py index c2d0293a..ec96e286 100644 --- a/onnx2torch/node_converters/eye.py +++ b/onnx2torch/node_converters/eye.py @@ -3,42 +3,46 @@ ] import torch +from torch import nn + from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode from onnx2torch.utils.common import OnnxToTorchModule, OperationConverterResult, onnx_mapping_from_node -from torch import nn class OnnxEyeLike(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring - def __init__(self, k: int, dtype: torch.dtype = None): + def __init__(self, eyelike_k: int, dtype: torch.dtype = None): super().__init__() self.dtype = dtype - self.k = k + self.eyelike_k = eyelike_k def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring if len(x.shape) != 2: raise ValueError('OnnxEyeLike only support 2-D tensor') - n = x.size(dim=0) - m = x.size(dim=1) - if self.k > n: - raise ValueError(f'Error EyeLike Attribute k value, the k value is {self.k}, but x shape is {(n,m)}') + input_value_n = x.size(dim=0) + input_value_m = x.size(dim=1) + if self.eyelike_k > input_value_n: + raise ValueError( + f'Error EyeLike Attribute k value, the k value is {self.eyelike_k},' + 'but x shape is {(input_value_n, input_value_m)}' + ) - if self.k == 0: - return torch.eye(n, m, dtype=self.dtype) + if self.eyelike_k == 0: + return torch.eye(input_value_n, input_value_m, dtype=self.dtype) - k_tensor = torch.zeros(n, self.k, dtype=self.dtype) - eye_tensor = torch.eye(n, m - self.k, dtype=self.dtype) + k_tensor = torch.zeros(input_value_n, self.eyelike_k, dtype=self.dtype) + eye_tensor = torch.eye(input_value_n, input_value_m - self.eyelike_k, dtype=self.dtype) return torch.concat([k_tensor, eye_tensor], axis=1) @add_converter(operation_type='EyeLike', version=9) def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument node_attributes = node.attributes - k = node_attributes.get('k', 0) + eyelike_k = node_attributes.get('k', 0) dtype = node_attributes.get('dtype', torch.float32) return OperationConverterResult( - torch_module=OnnxEyeLike(dtype=dtype, k=k), + torch_module=OnnxEyeLike(dtype=dtype, eyelike_k=eyelike_k), onnx_mapping=onnx_mapping_from_node(node=node), ) diff --git a/tests/node_converters/eye_test.py b/tests/node_converters/eye_test.py index 8cbb3140..55e96164 100644 --- a/tests/node_converters/eye_test.py +++ b/tests/node_converters/eye_test.py @@ -7,7 +7,7 @@ @pytest.mark.parametrize('dtype', ((None),)) -@pytest.mark.parametrize('k', [0, 1, 2]) +@pytest.mark.parametrize('eyelike_k', [0, 1, 2]) @pytest.mark.parametrize( 'input_shapes', ( @@ -19,13 +19,13 @@ def test_eye( # pylint: disable=missing-function-docstring input_shapes: Tuple[int], dtype: str, - k: int, + eyelike_k: int, ) -> None: input_values = np.random.randint(0, 100, size=input_shapes) test_inputs = {'x': input_values} - node = onnx.helper.make_node(op_type='EyeLike', inputs=['x'], outputs=['z'], k=k, dtype=dtype) + node = onnx.helper.make_node(op_type='EyeLike', inputs=['x'], outputs=['z'], k=eyelike_k, dtype=dtype) model = make_model_from_nodes( nodes=node, initializers={}, From 702fa311bf50007b4110b38e3377807c896b4520 Mon Sep 17 00:00:00 2001 From: Mason Ma Date: Thu, 23 Feb 2023 16:28:20 +0800 Subject: [PATCH 3/5] reformatted code --- onnx2torch/node_converters/eye.py | 4 +++- tests/node_converters/eye_test.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/onnx2torch/node_converters/eye.py b/onnx2torch/node_converters/eye.py index ec96e286..457bbaa0 100644 --- a/onnx2torch/node_converters/eye.py +++ b/onnx2torch/node_converters/eye.py @@ -8,7 +8,9 @@ from onnx2torch.node_converters.registry import add_converter from onnx2torch.onnx_graph import OnnxGraph from onnx2torch.onnx_node import OnnxNode -from onnx2torch.utils.common import OnnxToTorchModule, OperationConverterResult, onnx_mapping_from_node +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import onnx_mapping_from_node class OnnxEyeLike(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring diff --git a/tests/node_converters/eye_test.py b/tests/node_converters/eye_test.py index 55e96164..afa48a80 100644 --- a/tests/node_converters/eye_test.py +++ b/tests/node_converters/eye_test.py @@ -3,7 +3,9 @@ import numpy as np import onnx import pytest -from tests.utils.common import check_onnx_model, make_model_from_nodes + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes @pytest.mark.parametrize('dtype', ((None),)) From f6668ed7ab38c4c4a0ddbab1e9e9a29b002c7aa9 Mon Sep 17 00:00:00 2001 From: Mason Ma Date: Fri, 24 Feb 2023 13:24:57 +0800 Subject: [PATCH 4/5] feat: LayerNorm --- onnx2torch/node_converters/__init__.py | 1 + onnx2torch/node_converters/layer_norm.py | 70 ++++++++++++++++++++++++ operators.md | 1 + tests/node_converters/layer_norm_test.py | 67 +++++++++++++++++++++++ 4 files changed, 139 insertions(+) create mode 100644 onnx2torch/node_converters/layer_norm.py create mode 100644 tests/node_converters/layer_norm_test.py diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py index 571bf82b..f6d88d64 100644 --- a/onnx2torch/node_converters/__init__.py +++ b/onnx2torch/node_converters/__init__.py @@ -21,6 +21,7 @@ from onnx2torch.node_converters.global_average_pool import * from onnx2torch.node_converters.identity import * from onnx2torch.node_converters.instance_norm import * +from onnx2torch.node_converters.layer_norm import * from onnx2torch.node_converters.logical import * from onnx2torch.node_converters.lrn import * from onnx2torch.node_converters.matmul import * diff --git a/onnx2torch/node_converters/layer_norm.py b/onnx2torch/node_converters/layer_norm.py new file mode 100644 index 00000000..7942946c --- /dev/null +++ b/onnx2torch/node_converters/layer_norm.py @@ -0,0 +1,70 @@ +__all__ = [ + 'OnnxLayerNorm', +] + +import torch +import torch.nn.functional as F +from torch import nn + +from onnx2torch.node_converters.registry import add_converter +from onnx2torch.onnx_graph import OnnxGraph +from onnx2torch.onnx_node import OnnxNode +from onnx2torch.utils.common import OnnxMapping +from onnx2torch.utils.common import OnnxToTorchModule +from onnx2torch.utils.common import OperationConverterResult +from onnx2torch.utils.common import get_shape_from_value_info +from onnx2torch.utils.common import onnx_mapping_from_node + + +class OnnxLayerNorm(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring + def __init__(self, axis: int, epsilon: float, stash_type: int): + super().__init__() + self.axis = axis + self.epsilon = epsilon + self.stash_type = stash_type + + def forward( # pylint: disable=missing-function-docstring + self, + input_data: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + ) -> torch.Tensor: + return F.layer_norm( + input_data, + normalized_shape, + weight = weight, + bias = bias, + eps = self.epsilon + ) + + +@add_converter(operation_type='LayerNormalization', version=17) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + node_attributes = node.attributes + axis = node_attributes.get('axis', -1) + epsilon = node_attributes.get('epsilon', 1e-5) + stash_type = node_attributes.get('stash_type', 1) + if all(value_name in graph.initializers for value_name in node.input_values[1:]): + input_value_info = graph.value_info[node.input_values[0]] + input_shape = get_shape_from_value_info(input_value_info) + + scale_value_name = node.input_values[1] + bias_value_name = node.input_values[2] + + scale = graph.initializers[scale_value_name].to_torch() + torch_module = nn.LayerNorm( + input_shape[axis], + eps=epsilon, + elementwise_affine=True + ) + + with torch.no_grad(): + torch_module.weight.data = graph.initializers[scale_value_name].to_torch() + torch_module.bias.data = graph.initializers[bias_value_name].to_torch() + + onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values) + else: + torch_module = OnnxLayerNorm(momentum=momentum, axis=axis, epsilon=epsilon, stash_type=stash_type) + onnx_mapping = onnx_mapping_from_node(node) + + return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping) diff --git a/operators.md b/operators.md index f2db4396..5610c245 100644 --- a/operators.md +++ b/operators.md @@ -63,6 +63,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops | InstanceNormalization | Y | | | IsInf | N | | | IsNaN | N | | +| LayerNormalization | Y | LayerNormalization outputs "Mean" and "InvStdDev" is not implemented | | LRN | Y | | | LSTM | N | | | LeakyRelu | Y | | diff --git a/tests/node_converters/layer_norm_test.py b/tests/node_converters/layer_norm_test.py new file mode 100644 index 00000000..f1594e78 --- /dev/null +++ b/tests/node_converters/layer_norm_test.py @@ -0,0 +1,67 @@ +from typing import List + +import numpy as np +import onnx +import pytest + +from tests.utils.common import check_onnx_model +from tests.utils.common import make_model_from_nodes + +def _test_layer_norm( + x: np.ndarray, + parameters_as_inputs: bool, +) -> None: + normalized_shape = calculate_normalized_shape(x.shape, -1) + scale = np.random.randn(*normalized_shape).astype(np.float32) + bias = np.random.randn(*normalized_shape).astype(np.float32) + + inputs = {'input': x} + parameters = {'scale': scale,'bias': bias} + initializers = {} + + if parameters_as_inputs: + inputs.update(parameters) + else: + initializers.update(parameters) + + node = onnx.helper.make_node( + op_type='LayerNormalization', + inputs=['input', 'scale', 'bias'], + outputs=['y'], + ) + model = make_model_from_nodes(nodes=node, initializers=initializers, inputs_example=inputs, opset_version=17) + check_onnx_model(onnx_model=model, onnx_inputs=inputs, atol_onnx_torch=1e-6, atol_torch_cpu_cuda=1e-6) + + +# @pytest.mark.parametrize( + # 'parameters_as_inputs', + # (True, False), +# ) +@pytest.mark.parametrize( + 'input_shape', + ( + # 1d + [2, 3, 16], + [2, 1, 7], + # # 2d + [2, 3, 16, 16], + [2, 1, 7, 16], + # # 3d + [2, 3, 16, 16, 16], + [2, 1, 16, 7, 16], + ), +) +def test_layer_norm( # pylint: disable=missing-function-docstring + input_shape: List[int], + parameters_as_inputs: bool = False, +) -> None: + x = np.random.randn(*input_shape).astype(np.float32) + + _test_layer_norm(x=x, parameters_as_inputs=parameters_as_inputs) + + +def calculate_normalized_shape(X_shape, axis): # type: ignore + X_rank = len(X_shape) + if axis < 0: + axis = axis + X_rank + return X_shape[axis:] From 2a9443aa21f8a5d7ef9a1c15c12cc2e341c6f0f1 Mon Sep 17 00:00:00 2001 From: Mason Ma Date: Fri, 24 Feb 2023 13:39:40 +0800 Subject: [PATCH 5/5] resolving conflict --- onnx2torch/node_converters/__init__.py | 1 - onnx2torch/node_converters/eye.py | 50 ------------------------ onnx2torch/node_converters/layer_norm.py | 22 ++++------- tests/node_converters/eye_test.py | 36 ----------------- tests/node_converters/layer_norm_test.py | 15 +++---- 5 files changed, 16 insertions(+), 108 deletions(-) delete mode 100644 onnx2torch/node_converters/eye.py delete mode 100644 tests/node_converters/eye_test.py diff --git a/onnx2torch/node_converters/__init__.py b/onnx2torch/node_converters/__init__.py index f6d88d64..3ed8cc25 100644 --- a/onnx2torch/node_converters/__init__.py +++ b/onnx2torch/node_converters/__init__.py @@ -13,7 +13,6 @@ from onnx2torch.node_converters.dropout import * from onnx2torch.node_converters.einsum import * from onnx2torch.node_converters.expand import * -from onnx2torch.node_converters.eye import * from onnx2torch.node_converters.flatten import * from onnx2torch.node_converters.functions import * from onnx2torch.node_converters.gather import * diff --git a/onnx2torch/node_converters/eye.py b/onnx2torch/node_converters/eye.py deleted file mode 100644 index 457bbaa0..00000000 --- a/onnx2torch/node_converters/eye.py +++ /dev/null @@ -1,50 +0,0 @@ -__all__ = [ - 'OnnxEyeLike', -] - -import torch -from torch import nn - -from onnx2torch.node_converters.registry import add_converter -from onnx2torch.onnx_graph import OnnxGraph -from onnx2torch.onnx_node import OnnxNode -from onnx2torch.utils.common import OnnxToTorchModule -from onnx2torch.utils.common import OperationConverterResult -from onnx2torch.utils.common import onnx_mapping_from_node - - -class OnnxEyeLike(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring - def __init__(self, eyelike_k: int, dtype: torch.dtype = None): - super().__init__() - self.dtype = dtype - self.eyelike_k = eyelike_k - - def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring - if len(x.shape) != 2: - raise ValueError('OnnxEyeLike only support 2-D tensor') - - input_value_n = x.size(dim=0) - input_value_m = x.size(dim=1) - if self.eyelike_k > input_value_n: - raise ValueError( - f'Error EyeLike Attribute k value, the k value is {self.eyelike_k},' - 'but x shape is {(input_value_n, input_value_m)}' - ) - - if self.eyelike_k == 0: - return torch.eye(input_value_n, input_value_m, dtype=self.dtype) - - k_tensor = torch.zeros(input_value_n, self.eyelike_k, dtype=self.dtype) - eye_tensor = torch.eye(input_value_n, input_value_m - self.eyelike_k, dtype=self.dtype) - return torch.concat([k_tensor, eye_tensor], axis=1) - - -@add_converter(operation_type='EyeLike', version=9) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument - node_attributes = node.attributes - eyelike_k = node_attributes.get('k', 0) - dtype = node_attributes.get('dtype', torch.float32) - return OperationConverterResult( - torch_module=OnnxEyeLike(dtype=dtype, eyelike_k=eyelike_k), - onnx_mapping=onnx_mapping_from_node(node=node), - ) diff --git a/onnx2torch/node_converters/layer_norm.py b/onnx2torch/node_converters/layer_norm.py index 7942946c..e1928a4d 100644 --- a/onnx2torch/node_converters/layer_norm.py +++ b/onnx2torch/node_converters/layer_norm.py @@ -2,6 +2,8 @@ 'OnnxLayerNorm', ] +from typing import List + import torch import torch.nn.functional as F from torch import nn @@ -26,16 +28,11 @@ def __init__(self, axis: int, epsilon: float, stash_type: int): def forward( # pylint: disable=missing-function-docstring self, input_data: torch.Tensor, + normalized_shape: List[int], weight: torch.Tensor, bias: torch.Tensor, ) -> torch.Tensor: - return F.layer_norm( - input_data, - normalized_shape, - weight = weight, - bias = bias, - eps = self.epsilon - ) + return F.layer_norm(input_data, normalized_shape, weight=weight, bias=bias, eps=self.epsilon) @add_converter(operation_type='LayerNormalization', version=17) @@ -51,12 +48,7 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: scale_value_name = node.input_values[1] bias_value_name = node.input_values[2] - scale = graph.initializers[scale_value_name].to_torch() - torch_module = nn.LayerNorm( - input_shape[axis], - eps=epsilon, - elementwise_affine=True - ) + torch_module = nn.LayerNorm(input_shape[axis], eps=epsilon, elementwise_affine=True) with torch.no_grad(): torch_module.weight.data = graph.initializers[scale_value_name].to_torch() @@ -64,7 +56,9 @@ def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: onnx_mapping = OnnxMapping(inputs=(node.input_values[0],), outputs=node.output_values) else: - torch_module = OnnxLayerNorm(momentum=momentum, axis=axis, epsilon=epsilon, stash_type=stash_type) + input_value_info = graph.value_info[node.input_values[0]] + input_shape = get_shape_from_value_info(input_value_info) + torch_module = OnnxLayerNorm(axis=axis, epsilon=epsilon, stash_type=stash_type) onnx_mapping = onnx_mapping_from_node(node) return OperationConverterResult(torch_module=torch_module, onnx_mapping=onnx_mapping) diff --git a/tests/node_converters/eye_test.py b/tests/node_converters/eye_test.py deleted file mode 100644 index afa48a80..00000000 --- a/tests/node_converters/eye_test.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Tuple - -import numpy as np -import onnx -import pytest - -from tests.utils.common import check_onnx_model -from tests.utils.common import make_model_from_nodes - - -@pytest.mark.parametrize('dtype', ((None),)) -@pytest.mark.parametrize('eyelike_k', [0, 1, 2]) -@pytest.mark.parametrize( - 'input_shapes', - ( - ((2, 3)), - ((3, 4)), - ((3, 3)), - ), -) -def test_eye( # pylint: disable=missing-function-docstring - input_shapes: Tuple[int], - dtype: str, - eyelike_k: int, -) -> None: - input_values = np.random.randint(0, 100, size=input_shapes) - - test_inputs = {'x': input_values} - - node = onnx.helper.make_node(op_type='EyeLike', inputs=['x'], outputs=['z'], k=eyelike_k, dtype=dtype) - model = make_model_from_nodes( - nodes=node, - initializers={}, - inputs_example=test_inputs, - ) - check_onnx_model(model, test_inputs) diff --git a/tests/node_converters/layer_norm_test.py b/tests/node_converters/layer_norm_test.py index f1594e78..54807b94 100644 --- a/tests/node_converters/layer_norm_test.py +++ b/tests/node_converters/layer_norm_test.py @@ -7,6 +7,7 @@ from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes + def _test_layer_norm( x: np.ndarray, parameters_as_inputs: bool, @@ -16,7 +17,7 @@ def _test_layer_norm( bias = np.random.randn(*normalized_shape).astype(np.float32) inputs = {'input': x} - parameters = {'scale': scale,'bias': bias} + parameters = {'scale': scale, 'bias': bias} initializers = {} if parameters_as_inputs: @@ -34,8 +35,8 @@ def _test_layer_norm( # @pytest.mark.parametrize( - # 'parameters_as_inputs', - # (True, False), +# 'parameters_as_inputs', +# (True, False), # ) @pytest.mark.parametrize( 'input_shape', @@ -60,8 +61,8 @@ def test_layer_norm( # pylint: disable=missing-function-docstring _test_layer_norm(x=x, parameters_as_inputs=parameters_as_inputs) -def calculate_normalized_shape(X_shape, axis): # type: ignore - X_rank = len(X_shape) +def calculate_normalized_shape(x_shape, axis): # pylint: disable=missing-function-docstring + x_rank = len(x_shape) if axis < 0: - axis = axis + X_rank - return X_shape[axis:] + axis = axis + x_rank + return x_shape[axis:]