chore: linters #222

Merged (1 commit, Jul 30, 2024)
52 changes: 26 additions & 26 deletions onnx2torch/node_converters/arg_extrema.py
@@ -1,12 +1,10 @@
+# pylint: disable=missing-docstring
 __all__ = [
-    "OnnxArgExtremumOld",
-    "OnnxArgExtremum",
+    'OnnxArgExtremumOld',
+    'OnnxArgExtremum',
 ]
 
-from typing import Optional
-
 import torch
-import torch.nn.functional as F
 from torch import nn
 
 from onnx2torch.node_converters.registry import add_converter
@@ -21,31 +19,31 @@
 DEFAULT_SELECT_LAST_INDEX = 0
 
 _TORCH_FUNCTION_FROM_ONNX_TYPE = {
-    "ArgMax": torch.argmax,
-    "ArgMin": torch.argmin,
+    'ArgMax': torch.argmax,
+    'ArgMin': torch.argmin,
 }
 
 
-class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-docstring
+class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule):
     def __init__(self, operation_type: str, axis: int, keepdims: int):
         super().__init__()
         self.axis = axis
         self.keepdims = bool(keepdims)
         self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]
 
-    def forward(self, data: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
+    def forward(self, data: torch.Tensor) -> torch.Tensor:
         return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)
 
 
-class OnnxArgExtremum(nn.Module, OnnxToTorchModule):  # pylint: disable=missing-class-docstring
+class OnnxArgExtremum(nn.Module, OnnxToTorchModule):
     def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_index: int):
         super().__init__()
         self.axis = axis
         self.keepdims = bool(keepdims)
         self.select_last_index = bool(select_last_index)
         self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]
 
-    def forward(self, data: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
+    def forward(self, data: torch.Tensor) -> torch.Tensor:
         if self.select_last_index:
             # torch's argmax does not handle the select_last_index attribute from Onnx.
             # We flip the data, call the normal argmax, then map it back to the original
@@ -54,34 +52,36 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:  # pylint: disable=missing-function-docstring
             extremum_index_flipped = self.extremum_function(flipped, dim=self.axis, keepdim=self.keepdims)
             extremum_index_original = data.size(dim=self.axis) - 1 - extremum_index_flipped
             return extremum_index_original
-        else:
-            return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)
+
+        return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)
 
 
-@add_converter(operation_type="ArgMax", version=12)
-@add_converter(operation_type="ArgMax", version=13)
-@add_converter(operation_type="ArgMin", version=12)
-@add_converter(operation_type="ArgMin", version=13)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+@add_converter(operation_type='ArgMax', version=12)
+@add_converter(operation_type='ArgMax', version=13)
+@add_converter(operation_type='ArgMin', version=12)
+@add_converter(operation_type='ArgMin', version=13)
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     return OperationConverterResult(
         torch_module=OnnxArgExtremum(
             operation_type=node.operation_type,
-            axis=node.attributes.get("axis", DEFAULT_AXIS),
-            keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
-            select_last_index=node.attributes.get("select_last_index", DEFAULT_SELECT_LAST_INDEX),
+            axis=node.attributes.get('axis', DEFAULT_AXIS),
+            keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS),
+            select_last_index=node.attributes.get('select_last_index', DEFAULT_SELECT_LAST_INDEX),
         ),
         onnx_mapping=onnx_mapping_from_node(node=node),
     )
 
 
-@add_converter(operation_type="ArgMax", version=11)
-@add_converter(operation_type="ArgMin", version=11)
-def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:  # pylint: disable=unused-argument
+@add_converter(operation_type='ArgMax', version=11)
+@add_converter(operation_type='ArgMin', version=11)
+def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
+    del graph
     return OperationConverterResult(
         torch_module=OnnxArgExtremumOld(
             operation_type=node.operation_type,
-            axis=node.attributes.get("axis", DEFAULT_AXIS),
-            keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
+            axis=node.attributes.get('axis', DEFAULT_AXIS),
+            keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS),
        ),
         onnx_mapping=onnx_mapping_from_node(node=node),
     )
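
Side note: the select_last_index emulation above is easy to sanity-check in isolation. A minimal sketch of the flip-and-remap trick (standalone illustration, not code from this PR; the helper name argmax_last is invented):

import torch

def argmax_last(data: torch.Tensor, dim: int) -> torch.Tensor:
    # torch.argmax returns the index of the first maximal value; ONNX's
    # select_last_index=1 asks for the last. Flipping along `dim` turns the
    # last occurrence into the first, then the subtraction maps it back.
    flipped = torch.flip(data, dims=(dim,))
    index_flipped = torch.argmax(flipped, dim=dim)
    return data.size(dim) - 1 - index_flipped

x = torch.tensor([1.0, 3.0, 3.0, 2.0])
assert torch.argmax(x, dim=0).item() == 1  # first occurrence of the maximum
assert argmax_last(x, dim=0).item() == 2   # last occurrence, matching ONNX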
2 changes: 1 addition & 1 deletion onnx2torch/utils/custom_export_to_onnx.py
@@ -57,7 +57,7 @@ def export(cls, forward_function: Callable, *args) -> Any:
         return cls.apply(*args)
 
     @staticmethod
-    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument
+    def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any:  # pylint: disable=unused-argument, arguments-differ
         """Applies custom forward function."""
         if CustomExportToOnnx._NEXT_FORWARD_FUNCTION is None:
             raise RuntimeError('Forward function is not set')
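
For context on the added arguments-differ suppression: pylint commonly raises this warning on torch.autograd.Function subclasses, because concrete overrides are expected to narrow the base class's generic forward(ctx, *args, **kwargs) signature. A minimal illustration of the pattern (hypothetical Square function, not from this repo):

import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:  # pylint: disable=arguments-differ
        # The concrete signature differs from the base class's generic one,
        # which is what triggers the warning being suppressed here.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.tensor([3.0], requires_grad=True)
Square.apply(x).backward()
assert torch.allclose(x.grad, torch.tensor([6.0]))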
40 changes: 11 additions & 29 deletions tests/node_converters/arg_extrema_test.py
@@ -1,10 +1,11 @@
+# pylint: disable=missing-docstring
 from pathlib import Path
 
 import numpy as np
 import onnx
-from onnx.helper import make_tensor_value_info
 import pytest
 import torch
+from onnx.helper import make_tensor_value_info
 
 from tests.utils.common import check_onnx_model
 from tests.utils.common import make_model_from_nodes
@@ -51,7 +52,7 @@
     "select_last_index",
     (0, 1),
 )
-def test_arg_max_arg_min(  # pylint: disable=missing-function-docstring
+def test_arg_max_arg_min(
     op_type: str,
     opset_version: int,
     dims: int,
Expand Down Expand Up @@ -95,7 +96,7 @@ class ArgMaxModel(torch.nn.Module):
def __init__(self, axis: int, keepdims: bool):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)
self.keepdims = keepdims

def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.argmax(data, dim=self.axis, keepdim=self.keepdims)
@@ -105,29 +106,16 @@ class ArgMinModel(torch.nn.Module):
     def __init__(self, axis: int, keepdims: bool):
         super().__init__()
         self.axis = axis
-        self.keepdims = bool(keepdims)
+        self.keepdims = keepdims
 
     def forward(self, data: torch.Tensor) -> torch.Tensor:
         return torch.argmin(data, dim=self.axis, keepdim=self.keepdims)
 
 
-@pytest.mark.parametrize(
-    "op_type",
-    (
-        "ArgMax",
-        "ArgMin",
-    ),
-)
-@pytest.mark.parametrize(
-    "opset_version",
-    (
-        11,
-        12,
-        13,
-    ),
-)
+@pytest.mark.parametrize("op_type", ["ArgMax", "ArgMin"])
+@pytest.mark.parametrize("opset_version", [11, 12, 13])
 @pytest.mark.parametrize(
-    "dims,axis",
+    "dims, axis",
     (
         (1, 0),
         (2, 0),
@@ -141,19 +129,13 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
         (4, 3),
     ),
 )
-@pytest.mark.parametrize(
-    "keepdims",
-    (
-        0,
-        1,
-    ),
-)
+@pytest.mark.parametrize("keepdims", [True, False])
 def test_start_from_torch_module(
     op_type: str,
     opset_version: int,
     dims: int,
     axis: int,
-    keepdims: int,
+    keepdims: bool,
     tmp_path: Path,
 ) -> None:
     """
@@ -179,7 +161,7 @@ def test_start_from_torch_module(
         input_names=input_names,
         output_names=output_names,
         do_constant_folding=False,
-        training=torch._C._onnx.TrainingMode.TRAINING,
+        opset_version=opset_version,
     )
 
     # load the exported onnx file
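
The edited export call above now pins opset_version instead of forcing TRAINING mode. For reference, a condensed sketch of the round trip this test performs (assuming onnx2torch's public convert entry point and an invented file path):

from pathlib import Path

import onnx
import torch
from onnx2torch import convert

model = ArgMaxModel(axis=1, keepdims=True)  # module defined in the test file above
data = torch.randn(2, 3, 4)

onnx_path = Path('/tmp/argmax.onnx')  # hypothetical location
torch.onnx.export(
    model,
    (data,),
    str(onnx_path),
    input_names=['data'],
    output_names=['indices'],
    do_constant_folding=False,
    opset_version=13,
)

restored = convert(onnx.load(str(onnx_path)))  # ONNX graph back to a torch module
assert torch.equal(model(data), restored(data))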
5 changes: 3 additions & 2 deletions tests/node_converters/conv_test.py
@@ -1,5 +1,6 @@
 from itertools import chain
 from itertools import product
+from typing import Literal
 from typing import Tuple
 
 import numpy as np
@@ -10,7 +11,7 @@
 
 
 def _test_conv(
-    op_type: str,
+    op_type: Literal['Conv', 'ConvTranspose'],
     in_channels: int,
     out_channels: int,
     kernel_shape: Tuple[int, int],
@@ -23,7 +24,7 @@ def _test_conv(
     x = np.random.uniform(low=-1.0, high=1.0, size=x_shape).astype(np.float32)
     if op_type == 'Conv':
         weights_shape = (out_channels, in_channels // group) + kernel_shape
-    elif op_type == 'ConvTranspose':
+    else:  # ConvTranspose
         weights_shape = (in_channels, out_channels // group) + kernel_shape
     weights = np.random.uniform(low=-1.0, high=1.0, size=weights_shape).astype(np.float32)
 

Expand Down
Loading
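
The simplified branch encodes the two weight layouts: Conv expects (out_channels, in_channels // group, *kernel_shape) while ConvTranspose expects (in_channels, out_channels // group, *kernel_shape). A quick check against torch's own modules (illustrative sketch, not part of the PR):

import torch

in_channels, out_channels, group = 8, 16, 2
kernel_shape = (3, 3)

conv = torch.nn.Conv2d(in_channels, out_channels, kernel_shape, groups=group)
assert tuple(conv.weight.shape) == (out_channels, in_channels // group, *kernel_shape)

deconv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_shape, groups=group)
assert tuple(deconv.weight.shape) == (in_channels, out_channels // group, *kernel_shape)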