Add unique op (#1547)
Add support for exporting `torch.unique` following the conclusion of
pytorch/pytorch#113118.

---------

Co-authored-by: Justin Chu <[email protected]>
a-gardner1 and justinchuby authored Mar 7, 2025
1 parent ddce766 commit 4c1cda2
Showing 3 changed files with 132 additions and 2 deletions.
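
For context before the diffs: the new torch_lib functions let the dynamo-based ONNX exporter handle models that call `torch.unique`. A minimal usage sketch (not part of the commit; assumes a recent PyTorch whose dynamo exporter dispatches `aten::_unique2` and `aten::unique_dim` to these functions):

import torch


class UniqueModel(torch.nn.Module):
    def forward(self, x):
        # torch.unique with return_counts and no dim dispatches to aten::_unique2.
        return torch.unique(x, sorted=True, return_inverse=True, return_counts=True)


# dynamo=True selects the exporter backed by the torch_lib functions below.
onnx_program = torch.onnx.export(
    UniqueModel(), (torch.tensor([1, 3, 2, 3, 1]),), dynamo=True
)
onnx_program.save("unique.onnx")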
72 changes: 70 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -8591,16 +8591,84 @@ def aten_unique_consecutive(
    raise NotImplementedError()


@torch_op("aten::_unique", trace_only=True)
def aten__unique(
self: TensorType,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
) -> tuple[TensorType, TensorType]:
"""_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)"""

unique_values, _, inverse_indices, _ = op.Unique(self, axis=None, sorted=True)
input_size = op.Shape(self)
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
inverse_indices = op.ConstantOfShape([0])
inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
return unique_values, inverse_indices


@torch_op("aten::_unique2", trace_only=True)
def aten__unique2(
self: TensorType,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
return_counts: bool = False,
) -> tuple[TensorType, TensorType, TensorType]:
"""_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

unique_values, _, inverse_indices, counts = op.Unique(self, axis=None, sorted=True)
input_size = op.Shape(self)
if return_inverse:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
input_numel = op.ReduceProd(input_size, keepdims=False)
if input_numel == 0:
inverse_indices = op.Reshape(inverse_indices, input_size)
else:
inverse_indices = op.ConstantOfShape([0])
inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
if not return_counts:
counts = op.ConstantOfShape([0])
counts = op.Cast(counts, to=INT64.dtype)
return unique_values, inverse_indices, counts


@torch_op("aten::unique_dim", trace_only=True)
def aten_unique_dim(
self: TensorType,
dim: int,
sorted: bool = True,
sorted: bool = True, # pylint: disable=unused-argument
return_inverse: bool = False,
return_counts: bool = False,
) -> tuple[TensorType, TensorType, TensorType]:
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""

raise NotImplementedError()
unique_values, _, inverse_indices, counts = op.Unique(self, axis=dim, sorted=True)
input_size = op.Shape(self)
# Normalize dim to be non-negative
input_ndim = op.Max(op.Size(input_size), op.Constant(value_ints=[1]))
dim = op.Mod(dim, input_ndim)
if return_inverse:
inverse_indices = op.Reshape(
inverse_indices,
op.Reshape(op.Slice(input_size, dim, dim + 1), op.Constant(value_ints=[-1])),
)
else:
inverse_indices = op.ConstantOfShape([0])
inverse_indices = op.Cast(inverse_indices, to=INT64.dtype)
if return_counts:
output_size = op.Shape(unique_values)
counts = op.Reshape(counts, op.Reshape(op.Slice(output_size, dim, dim + 1), [-1]))
else:
counts = op.ConstantOfShape([0])
counts = op.Cast(counts, to=INT64.dtype)
return unique_values, inverse_indices, counts


def aten_unique_dim_consecutive(
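For reference, ONNX `Unique` returns four outputs (unique values, first-occurrence indices, the inverse mapping, and counts); the functions above select the subset each ATen overload exposes and reshape the flat inverse mapping back to the input's shape to match `torch.unique`. A small NumPy sketch of that mapping, with illustrative data not taken from the commit:

import numpy as np

x = np.array([[2, 1], [1, 3]])

# Roughly what op.Unique(self, axis=None, sorted=True) produces.
unique_values, first_indices, inverse_indices, counts = np.unique(
    x, return_index=True, return_inverse=True, return_counts=True
)

# aten__unique2 reshapes the inverse mapping to the input's shape, matching
# torch.unique(..., return_inverse=True).
inverse_indices = inverse_indices.reshape(x.shape)

print(unique_values)    # [1 2 3]
print(inverse_indices)  # [[1 0]
                        #  [0 2]]
print(counts)           # [2 1 1]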
53 changes: 53 additions & 0 deletions tests/function_libs/torch_lib/extra_opinfo.py
@@ -1950,6 +1950,35 @@ def shape(size, rank, with_batch_channel=True):
    )


def sample_inputs__unique(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        return_counts = sample.kwargs.pop("return_counts", None)
        dim = sample.kwargs.pop("dim", None)
        # take only those samples that do not ask for counts or a dim
        if not return_counts and dim is None:
            yield sample


def sample_inputs__unique2(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # take only those samples that do not ask for a dim
        if sample.kwargs.pop("dim", None) is None:
            yield sample


def sample_inputs_unique_dim(op_info, device, dtype, requires_grad, **kwargs):
    for sample in common_methods_invocations.sample_inputs_unique(
        op_info, device, dtype, requires_grad, **kwargs
    ):
        # take only those samples that ask for a dim
        if sample.kwargs.get("dim") is not None:
            yield sample


def sample_inputs_upsample_trilinear3d_vec(op_info, device, dtype, requires_grad, **kwargs):
    del op_info
    del kwargs
@@ -2504,6 +2533,30 @@ def __init__(self):
        sample_inputs_func=sample_inputs_unfold,
        supports_out=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten._unique.default",
        aten_name="_unique.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs__unique,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten._unique2.default",
        aten_name="_unique2.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs__unique2,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.unique_dim.default",
        aten_name="unique_dim.default",
        dtypes=common_dtype.floating_types_and(torch.float16, torch.int64, torch.int8),
        sample_inputs_func=sample_inputs_unique_dim,
        supports_out=False,
        supports_autograd=False,
    ),
    opinfo_core.OpInfo(
        "ops.aten.upsample_bicubic2d.default",
        aten_name="upsample_bicubic2d",
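The three sample-input generators above reuse `sample_inputs_unique` from torch's `common_methods_invocations` and filter its samples by the kwargs each ATen overload accepts. A condensed restatement of that routing (the helper name is illustrative, not from the commit):

def overloads_covering(sample_kwargs):
    # Which of the three overloads a given sample exercises under the filters above.
    dim = sample_kwargs.get("dim")
    return {
        "aten::_unique": not sample_kwargs.get("return_counts") and dim is None,
        "aten::_unique2": dim is None,
        "aten::unique_dim": dim is not None,
    }


# A sample that asks for counts along a dim only exercises unique_dim.
print(overloads_covering({"return_counts": True, "dim": 0}))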
9 changes: 9 additions & 0 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -2068,6 +2068,15 @@ def _where_input_wrangler(
    ),  # Custom from extra_opinfo
    TorchLibOpInfo("transpose", core_ops.aten_transpose),
    TorchLibOpInfo("transpose", core_ops.aten_transpose_complex, complex=True),
    TorchLibOpInfo("ops.aten._unique.default", core_ops.aten__unique),
    TorchLibOpInfo("ops.aten._unique2.default", core_ops.aten__unique2),
    TorchLibOpInfo("ops.aten.unique_dim.default", core_ops.aten_unique_dim).skip(
        device_type="cpu",
        reason=(
            "ops.aten.unique_dim.default returns different shapes for optional outputs on CPU/CUDA. "
            "Our implementation is based on that for CUDA"
        ),
    ),
    TorchLibOpInfo(
        "ops.prims.var.default", prims_ops.prims_var, tolerance={torch.float16: (1e-3, 5e-2)}
    ),
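With these registrations in place, the new functions are covered by the shared torch_lib op tests; something like `python -m pytest tests/function_libs/torch_lib/ops_test.py -k unique` should exercise them (the test module path is assumed from the repository layout shown above).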
