From ca132c324a5a46764ed1a017df6074032caa1033 Mon Sep 17 00:00:00 2001 From: Arseny Date: Tue, 30 Jul 2024 21:42:28 +0200 Subject: [PATCH] chore: linters --- onnx2torch/node_converters/arg_extrema.py | 52 +++++++++++------------ onnx2torch/utils/custom_export_to_onnx.py | 2 +- tests/node_converters/arg_extrema_test.py | 40 +++++------------ tests/node_converters/conv_test.py | 5 ++- 4 files changed, 41 insertions(+), 58 deletions(-) diff --git a/onnx2torch/node_converters/arg_extrema.py b/onnx2torch/node_converters/arg_extrema.py index 6f1c7fb4..b342b91e 100644 --- a/onnx2torch/node_converters/arg_extrema.py +++ b/onnx2torch/node_converters/arg_extrema.py @@ -1,12 +1,10 @@ +# pylint: disable=missing-docstring __all__ = [ - "OnnxArgExtremumOld", - "OnnxArgExtremum", + 'OnnxArgExtremumOld', + 'OnnxArgExtremum', ] -from typing import Optional - import torch -import torch.nn.functional as F from torch import nn from onnx2torch.node_converters.registry import add_converter @@ -21,23 +19,23 @@ DEFAULT_SELECT_LAST_INDEX = 0 _TORCH_FUNCTION_FROM_ONNX_TYPE = { - "ArgMax": torch.argmax, - "ArgMin": torch.argmin, + 'ArgMax': torch.argmax, + 'ArgMin': torch.argmin, } -class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring +class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule): def __init__(self, operation_type: str, axis: int, keepdims: int): super().__init__() self.axis = axis self.keepdims = bool(keepdims) self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] - def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + def forward(self, data: torch.Tensor) -> torch.Tensor: return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims) -class OnnxArgExtremum(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring +class OnnxArgExtremum(nn.Module, OnnxToTorchModule): def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_index: int): super().__init__() self.axis = axis @@ -45,7 +43,7 @@ def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_in self.select_last_index = bool(select_last_index) self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type] - def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring + def forward(self, data: torch.Tensor) -> torch.Tensor: if self.select_last_index: # torch's argmax does not handle the select_last_index attribute from Onnx. # We flip the data, call the normal argmax, then map it back to the original @@ -54,34 +52,36 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missin extremum_index_flipped = self.extremum_function(flipped, dim=self.axis, keepdim=self.keepdims) extremum_index_original = data.size(dim=self.axis) - 1 - extremum_index_flipped return extremum_index_original - else: - return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims) + + return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims) -@add_converter(operation_type="ArgMax", version=12) -@add_converter(operation_type="ArgMax", version=13) -@add_converter(operation_type="ArgMin", version=12) -@add_converter(operation_type="ArgMin", version=13) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +@add_converter(operation_type='ArgMax', version=12) +@add_converter(operation_type='ArgMax', version=13) +@add_converter(operation_type='ArgMin', version=12) +@add_converter(operation_type='ArgMin', version=13) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph return OperationConverterResult( torch_module=OnnxArgExtremum( operation_type=node.operation_type, - axis=node.attributes.get("axis", DEFAULT_AXIS), - keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS), - select_last_index=node.attributes.get("select_last_index", DEFAULT_SELECT_LAST_INDEX), + axis=node.attributes.get('axis', DEFAULT_AXIS), + keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS), + select_last_index=node.attributes.get('select_last_index', DEFAULT_SELECT_LAST_INDEX), ), onnx_mapping=onnx_mapping_from_node(node=node), ) -@add_converter(operation_type="ArgMax", version=11) -@add_converter(operation_type="ArgMin", version=11) -def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument +@add_converter(operation_type='ArgMax', version=11) +@add_converter(operation_type='ArgMin', version=11) +def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: + del graph return OperationConverterResult( torch_module=OnnxArgExtremumOld( operation_type=node.operation_type, - axis=node.attributes.get("axis", DEFAULT_AXIS), - keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS), + axis=node.attributes.get('axis', DEFAULT_AXIS), + keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS), ), onnx_mapping=onnx_mapping_from_node(node=node), ) diff --git a/onnx2torch/utils/custom_export_to_onnx.py b/onnx2torch/utils/custom_export_to_onnx.py index 78e80297..bd0cae4b 100644 --- a/onnx2torch/utils/custom_export_to_onnx.py +++ b/onnx2torch/utils/custom_export_to_onnx.py @@ -57,7 +57,7 @@ def export(cls, forward_function: Callable, *args) -> Any: return cls.apply(*args) @staticmethod - def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument + def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument, arguments-differ """Applies custom forward function.""" if CustomExportToOnnx._NEXT_FORWARD_FUNCTION is None: raise RuntimeError('Forward function is not set') diff --git a/tests/node_converters/arg_extrema_test.py b/tests/node_converters/arg_extrema_test.py index 4b2fd9dc..5c66ca8f 100644 --- a/tests/node_converters/arg_extrema_test.py +++ b/tests/node_converters/arg_extrema_test.py @@ -1,10 +1,11 @@ +# pylint: disable=missing-docstring from pathlib import Path import numpy as np import onnx -from onnx.helper import make_tensor_value_info import pytest import torch +from onnx.helper import make_tensor_value_info from tests.utils.common import check_onnx_model from tests.utils.common import make_model_from_nodes @@ -51,7 +52,7 @@ "select_last_index", (0, 1), ) -def test_arg_max_arg_min( # pylint: disable=missing-function-docstring +def test_arg_max_arg_min( op_type: str, opset_version: int, dims: int, @@ -95,7 +96,7 @@ class ArgMaxModel(torch.nn.Module): def __init__(self, axis: int, keepdims: bool): super().__init__() self.axis = axis - self.keepdims = bool(keepdims) + self.keepdims = keepdims def forward(self, data: torch.Tensor) -> torch.Tensor: return torch.argmax(data, dim=self.axis, keepdim=self.keepdims) @@ -105,29 +106,16 @@ class ArgMinModel(torch.nn.Module): def __init__(self, axis: int, keepdims: bool): super().__init__() self.axis = axis - self.keepdims = bool(keepdims) + self.keepdims = keepdims def forward(self, data: torch.Tensor) -> torch.Tensor: return torch.argmin(data, dim=self.axis, keepdim=self.keepdims) +@pytest.mark.parametrize("op_type", ["ArgMax", "ArgMin"]) +@pytest.mark.parametrize("opset_version", [11, 12, 13]) @pytest.mark.parametrize( - "op_type", - ( - "ArgMax", - "ArgMin", - ), -) -@pytest.mark.parametrize( - "opset_version", - ( - 11, - 12, - 13, - ), -) -@pytest.mark.parametrize( - "dims,axis", + "dims, axis", ( (1, 0), (2, 0), @@ -141,19 +129,13 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: (4, 3), ), ) -@pytest.mark.parametrize( - "keepdims", - ( - 0, - 1, - ), -) +@pytest.mark.parametrize("keepdims", [True, False]) def test_start_from_torch_module( op_type: str, opset_version: int, dims: int, axis: int, - keepdims: int, + keepdims: bool, tmp_path: Path, ) -> None: """ @@ -179,7 +161,7 @@ def test_start_from_torch_module( input_names=input_names, output_names=output_names, do_constant_folding=False, - training=torch._C._onnx.TrainingMode.TRAINING, + opset_version=opset_version, ) # load the exported onnx file diff --git a/tests/node_converters/conv_test.py b/tests/node_converters/conv_test.py index 72dcf3ff..9efa892e 100644 --- a/tests/node_converters/conv_test.py +++ b/tests/node_converters/conv_test.py @@ -1,5 +1,6 @@ from itertools import chain from itertools import product +from typing import Literal from typing import Tuple import numpy as np @@ -10,7 +11,7 @@ def _test_conv( - op_type: str, + op_type: Literal['Conv', 'ConvTranspose'], in_channels: int, out_channels: int, kernel_shape: Tuple[int, int], @@ -23,7 +24,7 @@ def _test_conv( x = np.random.uniform(low=-1.0, high=1.0, size=x_shape).astype(np.float32) if op_type == 'Conv': weights_shape = (out_channels, in_channels // group) + kernel_shape - elif op_type == 'ConvTranspose': + else: # ConvTranspose weights_shape = (in_channels, out_channels // group) + kernel_shape weights = np.random.uniform(low=-1.0, high=1.0, size=weights_shape).astype(np.float32)