13
13
from __future__ import annotations
14
14
15
15
import math
16
- import re
17
16
from typing import Any , Optional , Sequence , Tuple , Union
18
17
19
18
from onnxscript import (
@@ -8360,13 +8359,6 @@ def aten_unique_consecutive(
8360
8359
raise NotImplementedError ()
8361
8360
8362
8361
8363
- _NOT_IMPLEMENTED_UNIQUE = re .compile (
8364
- r"NOT_IMPLEMENTED\s*:\s*Could\s+not\s+find\s+an\s+implementation\s+for\s+Unique"
8365
- )
8366
- """
8367
- A pattern to detect an unsupported (not implemented) Unique operator
8368
- """
8369
-
8370
8362
@torch_op ("aten::unique" , trace_only = True )
8371
8363
def aten_unique (
8372
8364
self : TensorType ,
@@ -8377,18 +8369,10 @@ def aten_unique(
8377
8369
) -> tuple [TensorType , TensorType , TensorType ]:
8378
8370
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor?, Tensor?)"""
8379
8371
8380
- try :
8381
- if dim is None :
8382
- unique_values , inverse_indices , counts = aten_unique2 (self , sorted , return_inverse , return_counts )
8383
- else :
8384
- unique_values , inverse_indices , counts = aten_unique_dim (self , dim , sorted , return_inverse , return_counts )
8385
- except Exception as e :
8386
- # try to provide a more informative error message
8387
- if _NOT_IMPLEMENTED_UNIQUE .search (str (e )) is not None :
8388
- raise NotImplementedError (
8389
- f"'onnxruntime' does not yet support Unique(11) operator with dtype={ self .dtype } '"
8390
- ) from e
8391
- raise
8372
+ if dim is None :
8373
+ unique_values , inverse_indices , counts = aten_unique2 (self , sorted , return_inverse , return_counts )
8374
+ else :
8375
+ unique_values , inverse_indices , counts = aten_unique_dim (self , dim , sorted , return_inverse , return_counts )
8392
8376
if return_inverse :
8393
8377
if return_counts :
8394
8378
result = unique_values , inverse_indices , counts
@@ -8410,7 +8394,11 @@ def aten_unique2(
8410
8394
) -> tuple [TensorType , TensorType , TensorType ]:
8411
8395
"""unique(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8412
8396
8413
- unique_values , _ , inverse_indices , counts = op .Unique (self , axis = None , sorted = sorted )
8397
+ unique_values , indices , inverse_indices , counts = op .Unique (self , axis = None , sorted = sorted )
8398
+ # HACK: force indices to be in the graph so that it gets a name during optimization
8399
+ # Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8400
+ indices_size = op .Shape (indices )
8401
+ counts = op .Reshape (counts , indices_size )
8414
8402
input_size = op .Shape (self )
8415
8403
inverse_indices = op .Reshape (inverse_indices , input_size )
8416
8404
return unique_values , inverse_indices , counts
@@ -8426,19 +8414,15 @@ def aten_unique_dim(
8426
8414
) -> tuple [TensorType , TensorType , TensorType ]:
8427
8415
"""unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)"""
8428
8416
8429
- unique_values , _ , inverse_indices , counts = op .Unique (self , axis = dim , sorted = sorted )
8417
+ unique_values , indices , inverse_indices , counts = op .Unique (self , axis = dim , sorted = sorted )
8418
+ # HACK: force indices to be in the graph so that it gets a name during optimization
8419
+ # Otherwise an error will be raised in `onnxscript.Scope.lookup_or_create`
8420
+ indices_size = op .Shape (indices )
8421
+ counts = op .Reshape (counts , indices_size )
8430
8422
input_size = op .Shape (self )
8431
- # PyTorch accepts negative dim as reversed counting
8432
- input_rank = op .Size (input_size )
8433
- dim = input_rank + dim
8434
- dim = dim % input_rank
8435
- starts = op .Reshape (dim , [- 1 ])
8436
- ends = op .Reshape (dim + 1 , [- 1 ])
8437
- input_dim_size = op .Slice (input_size , starts = starts , ends = ends )
8438
- inverse_indices = op .Reshape (inverse_indices , input_dim_size )
8423
+ inverse_indices = op .Reshape (inverse_indices , op .Reshape (input_size [dim ], [- 1 ]))
8439
8424
output_size = op .Shape (unique_values )
8440
- output_dim_size = op .Slice (output_size , starts = starts , ends = ends )
8441
- counts = op .Reshape (counts , output_dim_size )
8425
+ counts = op .Reshape (counts , op .Reshape (output_size [dim ], [- 1 ]))
8442
8426
return unique_values , inverse_indices , counts
8443
8427
8444
8428
0 commit comments