From db833431aeb2c92f93cae905ea1e478217ffb0ed Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 19 Nov 2024 10:51:16 -0800 Subject: [PATCH] [torchlib] Simplify aten_sum_dim_IntList (#1958) Simplify aten_sum_dim_IntList by removing the script functions. --- .../function_libs/torch_lib/ops/core.py | 58 +++++++------------ 1 file changed, 20 insertions(+), 38 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index b2138d4e6..a955583e9 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -8022,53 +8022,35 @@ def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: return aten_sub(self, other, alpha=alpha) -@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True) -def aten_sum_dim_IntList( - self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1 -) -> TReal: - """sum(Tensor self, SymInt dim, bool keepdim, *, ScalarType? dtype=None) -> Tensor""" - - # NOTE: trace_only because both if branches need to be the same type, but we have - # a cast in the if branch. - - # TODO: Combine the overloads when OptionalHasElement() works - if dim is None: - result = _aten_sum_dim_none(self, keepdim=keepdim) +@torch_op("aten::sum", trace_only=True) +def aten_sum(self: TReal, dtype: int = -1) -> TReal: + """sum(Tensor self, *, ScalarType? dtype=None) -> Tensor""" + if len(self.shape) == 0: + result = op.Identity(self) else: - result = _aten_sum_dim_onnx(self, dim, keepdim=keepdim) - - if dtype != -1: + result = op.ReduceSum(self, keepdims=False) + if dtype != -1 and dtype is not None: result = op.Cast(result, to=dtype) - return result -@torch_op("aten::sum", private=True, traceable=True) -def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Reshape(self, op.Constant(value_ints=[-1])) - - if IsScalar(dim): +@torch_op("aten::sum.dim_IntList", trace_only=True) +def aten_sum_dim_IntList( + self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1 +) -> TReal: + """sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor""" + if len(self.shape) == 0: + result = op.Identity(self) + elif dim is None: + result = op.ReduceSum(self, keepdims=keepdim) + else: dim = op.Reshape(dim, op.Constant(value_ints=[-1])) dim = op.Cast(dim, to=INT64.dtype) - result = op.ReduceSum(self, dim, keepdims=keepdim) - - if self_is_scalar: - result = op.Squeeze(result) - return result + result = op.ReduceSum(self, dim, keepdims=keepdim) + if dtype != -1 and dtype is not None: + result = op.Cast(result, to=dtype) -@torch_op("aten::sum", private=True) -def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal: - self_is_scalar = IsScalar(self) - if self_is_scalar: - self = op.Reshape(self, op.Constant(value_ints=[-1])) - - result = op.ReduceSum(self, keepdims=keepdim) - - if self_is_scalar: - result = op.Squeeze(result) return result