Skip to content

Commit

Permalink
[torchlib] Simplify aten_sum_dim_IntList (#1958)
Browse files Browse the repository at this point in the history
Simplify aten_sum_dim_IntList by removing the script functions.
  • Loading branch information
justinchuby authored Nov 19, 2024
1 parent 5c62178 commit db83343
Showing 1 changed file with 20 additions and 38 deletions.
58 changes: 20 additions & 38 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8022,53 +8022,35 @@ def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return aten_sub(self, other, alpha=alpha)


@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
def aten_sum_dim_IntList(
self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
) -> TReal:
"""sum(Tensor self, SymInt dim, bool keepdim, *, ScalarType? dtype=None) -> Tensor"""

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.

# TODO: Combine the overloads when OptionalHasElement() works
if dim is None:
result = _aten_sum_dim_none(self, keepdim=keepdim)
@torch_op("aten::sum", trace_only=True)
def aten_sum(self: TReal, dtype: int = -1) -> TReal:
"""sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
if len(self.shape) == 0:
result = op.Identity(self)
else:
result = _aten_sum_dim_onnx(self, dim, keepdim=keepdim)

if dtype != -1:
result = op.ReduceSum(self, keepdims=False)
if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)

return result


@torch_op("aten::sum", private=True, traceable=True)
def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))

if IsScalar(dim):
@torch_op("aten::sum.dim_IntList", trace_only=True)
def aten_sum_dim_IntList(
self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
) -> TReal:
"""sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
if len(self.shape) == 0:
result = op.Identity(self)
elif dim is None:
result = op.ReduceSum(self, keepdims=keepdim)
else:
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
dim = op.Cast(dim, to=INT64.dtype)
result = op.ReduceSum(self, dim, keepdims=keepdim)

if self_is_scalar:
result = op.Squeeze(result)
return result
result = op.ReduceSum(self, dim, keepdims=keepdim)

if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)

@torch_op("aten::sum", private=True)
def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal:
self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))

result = op.ReduceSum(self, keepdims=keepdim)

if self_is_scalar:
result = op.Squeeze(result)
return result


Expand Down

0 comments on commit db83343

Please sign in to comment.