Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Nov 19, 2024
1 parent 0429f87 commit 87b09ea
Showing 1 changed file with 3 additions and 16 deletions.
19 changes: 3 additions & 16 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,14 +1585,14 @@ def aten_cdist(
raise NotImplementedError()


@torch_op("aten::ceil")
@torch_op("aten::ceil", traceable=True)
def aten_ceil(self: TFloat) -> TFloat:
"""ceil(Tensor self) -> Tensor"""

return op.Ceil(self)


@torch_op("math::ceil")
@torch_op("math::ceil", traceable=True)
def python_math_ceil(self: TFloat) -> TInt:
"""ceil(Tensor self) -> Tensor"""
ceil = op.Ceil(self)
Expand Down Expand Up @@ -1764,13 +1764,6 @@ def aten_combinations(
raise NotImplementedError()


@torch_op("aten::complex", private=True)
def _aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""Non-broadcasting complex constructor."""

return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::complex", trace_only=True)
def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""complex(Tensor real, Tensor imag) -> Tensor"""
Expand All @@ -1780,7 +1773,7 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
real = op.Expand(real, broadcasted_shape)
imag = op.Expand(imag, broadcasted_shape)

return _aten_complex(real, imag)
return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::conj", trace_only=True)
Expand All @@ -1790,7 +1783,6 @@ def aten_conj(self: TTensor) -> TTensor:
return op.Identity(self)


@torch_op("aten::conj", complex=True, private=True)
def _complex_conjugate(self: TFloat) -> TFloat:
zero = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
Expand All @@ -1809,8 +1801,6 @@ def _complex_conjugate(self: TFloat) -> TFloat:
def aten_conj_complex(self: TFloat) -> TFloat:
"""conj(Tensor(a) self) -> Tensor(a)"""

# TODO(#834): Allow calling scripted functions from other
# scripted functions and remove trace only.
return _complex_conjugate(self)


Expand Down Expand Up @@ -5071,9 +5061,6 @@ def aten_mH(self: TRealOrUInt8) -> TRealOrUInt8:
def aten_mH_complex(self: TFloat) -> TFloat:
"""mH(Tensor(a) self) -> Tensor(a)"""

# TODO(#834): Allow calling scripted functions from other
# scripted functions and remove trace only.

# c is the last dimension being the real and imaginary parts
trasposed = op.Einsum(self, equation="...ijc->...jic")
return _complex_conjugate(trasposed)
Expand Down

0 comments on commit 87b09ea

Please sign in to comment.