[torchlib] Fix aten::diagonal (#1755)
Turn aten::diagonal into a trace-only function and fix its logic by explicitly
converting Python constants to ONNX constants. This was needed because the
exporter logic did not yet handle the type conversion correctly.
justinchuby authored Jul 31, 2024
1 parent a72f048 commit efe674d
Showing 1 changed file with 14 additions and 28 deletions.
42 changes: 14 additions & 28 deletions onnxscript/function_libs/torch_lib/ops/core.py
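To illustrate what "converting Python constants to ONNX constants" means here: inside the exported function the shape-derived sizes are ONNX tensors, so mixing them with plain Python ints (for example `dim1_size + offset`) left the dtype promotion to the exporter. The sketch below shows the replacement pattern using onnxscript's public `@script`/eager API as described in its README; the function name and signature are illustrative only and are not part of this commit or of torch_lib.

```python
# Hypothetical sketch, not part of this commit: the offset arrives as an INT64 tensor
# (or is wrapped via op.Constant), so the arithmetic stays inside the ONNX graph
# instead of relying on the exporter to promote a Python int.
import numpy as np

from onnxscript import FLOAT, INT64, script
from onnxscript import opset18 as op


@script()
def diagonal_length(x: FLOAT["M", "N"], offset_val: INT64[1]) -> INT64[1]:
    neg_1 = op.Constant(value_ints=[-1])
    # row = shape(x)[0], kept as a 1-D INT64 tensor, mirroring dim1_size in the diff
    row = op.Reshape(op.Gather(op.Shape(x), op.Constant(value_ints=[0])), neg_1)
    return op.Add(row, offset_val)  # "row + offset" expressed with ONNX ops only


print(diagonal_length(np.zeros((4, 6), dtype=np.float32), np.array([-2], dtype=np.int64)))  # [2]
```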
@@ -2542,19 +2542,11 @@ def aten_diagonal(self: TReal, offset: int = 0, dim1: int = 0, dim2: int = 1) ->
     # This is because computing diagonal sum is on dim2 after transpose by perm
     axes = [self_rank - 2]
 
-    return _aten_diagonal_onnx(self, offset, dim1, dim2, perm, axes)
-
-
-@torch_op("aten::diagonal", private=True, traceable=True)
-def _aten_diagonal_onnx(
-    self: TTensor, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
-) -> TTensor:
     neg_1 = op.Constant(value_ints=[-1])
     dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1)  # row
     dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1)  # col
     mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
-    tmp_tensor = op.ConstantOfShape(mask_shape)
-    mask = op.EyeLike(tmp_tensor, k=offset)
+    mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
     mask = op.CastLike(mask, self)
     self_t = op.Transpose(self, perm=perm)
     result = op.Mul(self_t, mask)
@@ -2580,18 +2572,19 @@ def _aten_diagonal_onnx(
     # 6 0 4 0
 
     # From above table, we can get the logic below
+    offset_val = op.Constant(value_ints=[offset])
     if offset < 0:
         # row + offset
-        length = dim1_size + offset
+        length = op.Add(dim1_size, offset_val)
         start = op.Constant(value_ints=[0])
     else:  # offset >= 0
         # col - offset
-        length = dim2_size - offset
-        start = op.Reshape(op.Constant(value_int=offset), neg_1)
+        length = op.Sub(dim2_size, offset_val)
+        start = offset_val
 
     # max(min(length, min(row, col)), 0)
-    length = op.Max(op.Min(length, min_dim_size), 0)
-    end = start + length
+    length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
+    end = op.Add(start, length)
     result = op.Slice(result, start, end, axes=axes)
 
     return result
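As a sanity check of the offset table and the slice bookkeeping above (for the default dim1=0, dim2=1 case), the same start/length rule can be reproduced in NumPy, with `np.eye(..., k=offset)` standing in for `EyeLike`; this is an illustration, not the exporter code path:

```python
# Illustrative NumPy check only (default dims dim1=0, dim2=1); not exporter code.
import numpy as np

a = np.arange(1, 25, dtype=np.float32).reshape(4, 6)
row, col = a.shape
for offset in range(-5, 8):
    mask = np.eye(row, col, k=offset, dtype=a.dtype)   # plays the role of EyeLike(k=offset)
    summed = (a * mask).sum(axis=0)                    # keeps only the shifted diagonal entries
    if offset < 0:
        length, start = row + offset, 0                # row + offset
    else:
        length, start = col - offset, offset           # col - offset
    length = max(min(length, min(row, col)), 0)        # max(min(length, min(row, col)), 0)
    np.testing.assert_array_equal(
        summed[start : start + length], np.diagonal(a, offset=offset)
    )
```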
Expand Down Expand Up @@ -2621,19 +2614,11 @@ def aten_diagonal_bool(self: BOOL, offset: int = 0, dim1: int = 0, dim2: int = 1
# This is because computing diagonal sum is on dim2 after transpose by perm
axes = [self_rank - 2]

return _aten_diagonal_bool_onnx(self, offset, dim1, dim2, perm, axes)


@torch_op("aten::diagonal", private=True)
def _aten_diagonal_bool_onnx(
self: BOOL, offset: int, dim1: int, dim2: int, perm: Sequence[int], axes: Sequence[int]
) -> BOOL:
neg_1 = op.Constant(value_ints=[-1])
dim1_size = op.Reshape(op.Gather(op.Shape(self), dim1), neg_1) # row
dim2_size = op.Reshape(op.Gather(op.Shape(self), dim2), neg_1) # col
mask_shape = op.Concat(dim1_size, dim2_size, axis=0)
tmp_tensor = op.ConstantOfShape(mask_shape)
mask = op.EyeLike(tmp_tensor, k=offset)
mask = op.EyeLike(op.ConstantOfShape(mask_shape), k=offset)
self_int = op.Cast(self, to=INT64.dtype)
mask_int = op.Cast(mask, to=INT64.dtype)
self_int_t = op.Transpose(self_int, perm=perm)
@@ -2660,18 +2645,19 @@ def _aten_diagonal_bool_onnx(
     # 6 0 4 0
 
     # From above table, we can get the logic below
+    offset_val = op.Constant(value_ints=[offset])
     if offset < 0:
         # row + offset
-        length = dim1_size + offset
+        length = op.Add(dim1_size, offset_val)
        start = op.Constant(value_ints=[0])
     else:  # offset >= 0
         # col - offset
-        length = dim2_size - offset
-        start = op.Reshape(op.Constant(value_int=offset), neg_1)
+        length = op.Sub(dim2_size, offset_val)
+        start = offset_val
 
     # max(min(length, min(row, col)), 0)
-    length = op.Max(op.Min(length, min_dim_size), 0)
-    end = start + length
+    length = op.Max(op.Min(length, min_dim_size), op.Constant(value_ints=[0]))
+    end = op.Add(start, length)
     result = op.Slice(result, start, end, axes=axes)
     result = op.Cast(result, to=BOOL.dtype)
 
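The boolean variant routes the masked arithmetic through INT64 and casts the sliced result back to BOOL at the end. A quick plain-PyTorch check (no ONNX export involved) that this integer round trip preserves `torch.diagonal` semantics on boolean inputs:

```python
# Illustrative check only; plain PyTorch, no ONNX export involved.
import torch

x = torch.rand(4, 6) > 0.5  # random boolean matrix
for offset in (-3, -1, 0, 2, 5):
    via_int64 = torch.diagonal(x.to(torch.int64), offset=offset).to(torch.bool)
    assert torch.equal(via_int64, torch.diagonal(x, offset=offset))
```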