Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Nov 19, 2024
1 parent 97979f5 commit deb4411
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8028,17 +8028,15 @@ def aten_sum_dim_IntList(
) -> TReal:
"""sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""

self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))
if len(self.shape) == 0:
return self

if dim is None:
result = op.ReduceSum(self, keepdims=keepdim)
else:
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
dim = op.Cast(dim, to=INT64.dtype)
result = op.ReduceSum(self, dim, keepdims=keepdim)
if self_is_scalar:
result = op.Squeeze(result)

if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)
Expand Down

0 comments on commit deb4411

Please sign in to comment.