From 87b09ead26912922b8208ab1158b59cb264b8f2a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 19 Nov 2024 18:21:03 +0000 Subject: [PATCH] more --- .../function_libs/torch_lib/ops/core.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 535da66d0..b2138d4e6 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1585,14 +1585,14 @@ def aten_cdist( raise NotImplementedError() -@torch_op("aten::ceil") +@torch_op("aten::ceil", traceable=True) def aten_ceil(self: TFloat) -> TFloat: """ceil(Tensor self) -> Tensor""" return op.Ceil(self) -@torch_op("math::ceil") +@torch_op("math::ceil", traceable=True) def python_math_ceil(self: TFloat) -> TInt: """ceil(Tensor self) -> Tensor""" ceil = op.Ceil(self) @@ -1764,13 +1764,6 @@ def aten_combinations( raise NotImplementedError() -@torch_op("aten::complex", private=True) -def _aten_complex(real: TFloat, imag: TFloat) -> TFloat: - """Non-broadcasting complex constructor.""" - - return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1) - - @torch_op("aten::complex", trace_only=True) def aten_complex(real: TFloat, imag: TFloat) -> TFloat: """complex(Tensor real, Tensor imag) -> Tensor""" @@ -1780,7 +1773,7 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat: real = op.Expand(real, broadcasted_shape) imag = op.Expand(imag, broadcasted_shape) - return _aten_complex(real, imag) + return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1) @torch_op("aten::conj", trace_only=True) @@ -1790,7 +1783,6 @@ def aten_conj(self: TTensor) -> TTensor: return op.Identity(self) -@torch_op("aten::conj", complex=True, private=True) def _complex_conjugate(self: TFloat) -> TFloat: zero = op.Constant(value_ints=[0]) one = op.Constant(value_ints=[1]) @@ -1809,8 +1801,6 @@ def _complex_conjugate(self: TFloat) -> TFloat: def aten_conj_complex(self: TFloat) -> TFloat: """conj(Tensor(a) self) -> Tensor(a)""" - # TODO(#834): Allow calling scripted functions from other - # scripted functions and remove trace only. return _complex_conjugate(self) @@ -5071,9 +5061,6 @@ def aten_mH(self: TRealOrUInt8) -> TRealOrUInt8: def aten_mH_complex(self: TFloat) -> TFloat: """mH(Tensor(a) self) -> Tensor(a)""" - # TODO(#834): Allow calling scripted functions from other - # scripted functions and remove trace only. - # c is the last dimension being the real and imaginary parts trasposed = op.Einsum(self, equation="...ijc->...jic") return _complex_conjugate(trasposed)