Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchlib] Make binary comparison ops and more traceable #1957

Merged
merged 8 commits into from
Nov 19, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3759,7 +3759,8 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:


@torch_op(
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge")
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
traceable=True,
)
def aten_ge(self: TReal, other: TReal) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""
Expand All @@ -3768,7 +3769,8 @@ def aten_ge(self: TReal, other: TReal) -> BOOL:


@torch_op(
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge")
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
traceable=True,
)
def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""
Expand Down Expand Up @@ -3904,14 +3906,14 @@ def aten_gru_cell(
raise NotImplementedError()


@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"))
@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), traceable=True)
def aten_gt(self: TReal, other: TReal) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Greater(self, other)


@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"))
@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"), traceable=True)
def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""
# self, other, self > other
Expand Down Expand Up @@ -4695,14 +4697,14 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), traceable=True)
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.LessOrEqual(self, other)


@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"), traceable=True)
def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down Expand Up @@ -5002,14 +5004,14 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"))
@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), traceable=True)
def aten_lt(self: TReal, other: TReal) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Less(self, other)


@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"))
@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"), traceable=True)
def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

Expand Down
Loading