diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index a955583e9..4c22df181 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -5225,9 +5225,8 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal: if IsScalar(self): result = self else: - if IsScalar(dim): - dim = op.Unsqueeze(dim, axes=0) - result = op.ReduceMean(self, dim, keepdims=keepdim) + dims = op.Reshape(dim, op.Constant(value_ints=[-1])) + result = op.ReduceMean(self, dims, keepdims=keepdim) return result