Skip to content

Commit 453783f

Browse files
committed
Remove try-catch block and apply fixes to enable torch.onnx.dynamo_export to succeed
1 parent ebc1b96 commit 453783f

File tree

1 file changed

+16
-32
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+16
-32
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

+16-32
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from __future__ import annotations
1414

1515
import math
16-
import re
1716
from typing import Any, Optional, Sequence, Tuple, Union
1817

1918
from onnxscript import (
@@ -8360,13 +8359,6 @@ def aten_unique_consecutive(
83608359
raise NotImplementedError()
83618360

83628361

8363-
_NOT_IMPLEMENTED_UNIQUE = re.compile(
8364-
r"NOT_IMPLEMENTED\s*:\s*Could\s+not\s+find\s+an\s+implementation\s+for\s+Unique"
8365-
)
8366-
"""
8367-
A pattern to detect an unsupported (not implemented) Unique operator
8368-
"""
8369-
83708362
@torch_op("aten::unique", trace_only=True)
83718363
def aten_unique(
83728364
self: TensorType,
@@ -8377,18 +8369,10 @@ def aten_unique(
83778369
) -> tuple[TensorType, TensorType, TensorType]:
83788370
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor?, Tensor?)"""
83798371

8380-
try:
8381-
if dim is None:
8382-
unique_values, inverse_indices, counts = aten_unique2(self, sorted, return_inverse, return_counts)
8383-
else:
8384-
unique_values, inverse_indices, counts = aten_unique_dim(self, dim, sorted, return_inverse, return_counts)
8385-
except Exception as e:
8386-
# try to provide a more informative error message
8387-
if _NOT_IMPLEMENTED_UNIQUE.search(str(e)) is not None:
8388-
raise NotImplementedError(
8389-
f"'onnxruntime' does not yet support Unique(11) operator with dtype={self.dtype}'"
8390-
) from e
8391-
raise
8372+
if dim is None:
8373+
unique_values, inverse_indices, counts = aten_unique2(self, sorted, return_inverse, return_counts)
8374+
else:
8375+
unique_values, inverse_indices, counts = aten_unique_dim(self, dim, sorted, return_inverse, return_counts)
83928376
if return_inverse:
83938377
if return_counts:
83948378
result = unique_values, inverse_indices, counts
@@ -8410,7 +8394,11 @@ def aten_unique2(
84108394
) -> tuple[TensorType, TensorType, TensorType]:
84118395
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
84128396

8413-
unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
8397+
unique_values, indices, inverse_indices, counts = op.Unique(self, axis=None, sorted=sorted)
8398+
# HACK: force indices to be in the graph so that it gets a name during optimization
8399+
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8400+
indices_size = op.Shape(indices)
8401+
counts = op.Reshape(counts, indices_size)
84148402
input_size = op.Shape(self)
84158403
inverse_indices = op.Reshape(inverse_indices, input_size)
84168404
return unique_values, inverse_indices, counts
@@ -8426,19 +8414,15 @@ def aten_unique_dim(
84268414
) -> tuple[TensorType, TensorType, TensorType]:
84278415
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
84288416

8429-
unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=sorted)
8417+
unique_values, indices, inverse_indices, counts = op.Unique(self, axis=dim, sorted=sorted)
8418+
# HACK: force indices to be in the graph so that it gets a name during optimization
8419+
# Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8420+
indices_size = op.Shape(indices)
8421+
counts = op.Reshape(counts, indices_size)
84308422
input_size = op.Shape(self)
8431-
# PyTorch accepts negative dim as reversed counting
8432-
input_rank = op.Size(input_size)
8433-
dim = input_rank + dim
8434-
dim = dim % input_rank
8435-
starts = op.Reshape(dim, [-1])
8436-
ends = op.Reshape(dim + 1, [-1])
8437-
input_dim_size = op.Slice(input_size, starts=starts, ends=ends)
8438-
inverse_indices = op.Reshape(inverse_indices, input_dim_size)
8423+
inverse_indices = op.Reshape(inverse_indices, op.Reshape(input_size[dim], [-1]))
84398424
output_size = op.Shape(unique_values)
8440-
output_dim_size = op.Slice(output_size, starts=starts, ends=ends)
8441-
counts = op.Reshape(counts, output_dim_size)
8425+
counts = op.Reshape(counts, op.Reshape(output_size[dim], [-1]))
84428426
return unique_values, inverse_indices, counts
84438427

84448428

0 commit comments

Comments
 (0)