Skip to content

Commit

Permalink
Add support for depth_to_space operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
sushanthr committed Jul 9, 2023
1 parent eb86d59 commit 0855f89
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from onnx2torch.node_converters.constant_of_shape import *
from onnx2torch.node_converters.conv import *
from onnx2torch.node_converters.cumsum import *
from onnx2torch.node_converters.depth_to_space import *
from onnx2torch.node_converters.dropout import *
from onnx2torch.node_converters.einsum import *
from onnx2torch.node_converters.expand import *
Expand Down
33 changes: 33 additions & 0 deletions onnx2torch/node_converters/depth_to_space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
__all__ = [
'OnnxDepthToSpace',
]

import torch
from torch import nn

from onnx2torch.node_converters.base_element_wise import OnnxBaseElementWise
from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node
from onnx2torch.utils.common import OnnxToTorchModule


class OnnxDepthToSpace(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring
def __init__(self, blocksize):
super().__init__()
self._upscale_factor = blocksize

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
return torch.pixel_shuffle(input_tensor, self._upscale_factor);


@add_converter(operation_type='DepthToSpace', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
if node.attributes.get("mode") != "CRD":
raise NotImplementedError('DepthToSpace for mode other than CRD is not implemented')
return OperationConverterResult(
torch_module=OnnxDepthToSpace(node.attributes.get("blocksize")),
onnx_mapping=onnx_mapping_from_node(node=node),
)
2 changes: 1 addition & 1 deletion operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Minimal tested opset version 9, maximum tested opset version 16, recommended ops
| Cos | Y | |
| Cosh | N | |
| CumSum | Y | |
| DepthToSpace | N | |
| DepthToSpace | Y | Partial support for CRD mode |
| DequantizeLinear | N | |
| Det | N | |
| Div | Y | |
Expand Down

0 comments on commit 0855f89

Please sign in to comment.