
Commit

Merge branch 'main' into justinchu/torch-tensor
titaiwangms authored Nov 25, 2024
2 parents 4e1f9dd + e282467 commit 70f50f6
Showing 9 changed files with 251 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
@@ -77,7 +77,7 @@ jobs:
CREATE_REPRODUCTION_REPORT: "${{ matrix.os == 'ubuntu-latest' && '1' || '0' }}"
- name: Upload coverage to Codecov
if: always()
uses: codecov/codecov-action@v4
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
- name: Upload test results to Codecov
19 changes: 19 additions & 0 deletions onnxscript/_framework_apis/torch_2_6.py
@@ -10,7 +10,10 @@
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
"torchlib_opset",
]
from typing import TYPE_CHECKING

from onnxscript import ir, optimizer
from onnxscript._framework_apis.torch_2_5 import (
check_model,
@@ -19,8 +22,24 @@
save_model_with_external_data,
)

if TYPE_CHECKING:
from onnxscript.onnx_opset._impl.opset18 import Opset18


def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""
optimizer.optimize_ir(model)
return model


def torchlib_opset() -> Opset18:
"""Return the default opset for torchlib."""
import onnxscript # pylint: disable=import-outside-toplevel

return onnxscript.opset18 # type: ignore


def torchlib_opset_version() -> int:
"""Return the default opset version for torchlib."""

return torchlib_opset().version
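Aside (not part of the diff): a minimal sketch of how the new torch_2_6 helpers added above might be used, assuming the module and attributes shown in the hunk.

# Hypothetical usage sketch, not part of this commit.
from onnxscript._framework_apis.torch_2_6 import torchlib_opset, torchlib_opset_version

op18 = torchlib_opset()  # the default torchlib opset (ONNX opset 18)
assert torchlib_opset_version() == op18.version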
143 changes: 68 additions & 75 deletions onnxscript/function_libs/torch_lib/ops/core.py
Expand Up @@ -1585,14 +1585,14 @@ def aten_cdist(
raise NotImplementedError()


@torch_op("aten::ceil")
@torch_op("aten::ceil", traceable=True)
def aten_ceil(self: TFloat) -> TFloat:
"""ceil(Tensor self) -> Tensor"""

return op.Ceil(self)


@torch_op("math::ceil")
@torch_op("math::ceil", traceable=True)
def python_math_ceil(self: TFloat) -> TInt:
"""ceil(Tensor self) -> Tensor"""
ceil = op.Ceil(self)
@@ -1764,13 +1764,6 @@ def aten_combinations(
raise NotImplementedError()


@torch_op("aten::complex", private=True)
def _aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""Non-broadcasting complex constructor."""

return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::complex", trace_only=True)
def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
"""complex(Tensor real, Tensor imag) -> Tensor"""
@@ -1780,7 +1773,7 @@ def aten_complex(real: TFloat, imag: TFloat) -> TFloat:
real = op.Expand(real, broadcasted_shape)
imag = op.Expand(imag, broadcasted_shape)

return _aten_complex(real, imag)
return op.Concat(op.Unsqueeze(real, axes=[-1]), op.Unsqueeze(imag, axes=[-1]), axis=-1)


@torch_op("aten::conj", trace_only=True)
@@ -1790,7 +1783,6 @@ def aten_conj(self: TTensor) -> TTensor:
return op.Identity(self)


@torch_op("aten::conj", complex=True, private=True)
def _complex_conjugate(self: TFloat) -> TFloat:
zero = op.Constant(value_ints=[0])
one = op.Constant(value_ints=[1])
@@ -1809,8 +1801,6 @@ def _complex_conjugate(self: TFloat) -> TFloat:
def aten_conj_complex(self: TFloat) -> TFloat:
"""conj(Tensor(a) self) -> Tensor(a)"""

# TODO(#834): Allow calling scripted functions from other
# scripted functions and remove trace only.
return _complex_conjugate(self)
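Aside (not part of the diff): torchlib packs complex tensors with a trailing real/imaginary dimension of size 2 (see aten_complex above), so conjugation negates the imaginary slice. A rough PyTorch analogue of that behaviour, as a hedged sketch:

# Hypothetical illustration, not part of this commit.
import torch
z = torch.complex(torch.randn(3), torch.randn(3))
packed = torch.view_as_real(z)              # [..., 2] = [real, imag]
conj_packed = torch.view_as_real(z.conj())
assert torch.allclose(conj_packed[..., 0], packed[..., 0])   # real part unchanged
assert torch.allclose(conj_packed[..., 1], -packed[..., 1])  # imaginary part negated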


@@ -3273,7 +3263,7 @@ def aten_empty_quantized(
raise NotImplementedError()


@torch_op("aten::empty_strided")
@torch_op("aten::empty_strided", traceable=True)
def aten_empty_strided(
size: INT64,
stride: INT64,
@@ -3290,14 +3280,14 @@ def aten_empty_strided(
return op.Expand(zero, size)


@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"))
@torch_op(("aten::eq", "aten::eq.Tensor", "aten::eq.Scalar", "_operator::eq"), traceable=True)
def aten_eq(self: TTensor, other: TTensor) -> BOOL:
"""eq.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Equal(self, other)


@torch_op("aten::equal")
@torch_op("aten::equal", traceable=True)
def aten_equal(self: TTensor, other: TTensor) -> BOOL:
"""equal(Tensor self, Tensor other) -> bool"""

@@ -3759,7 +3749,8 @@ def aten_gcd(self: TensorType, other: TensorType) -> TensorType:


@torch_op(
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge")
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
traceable=True,
)
def aten_ge(self: TReal, other: TReal) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -3768,7 +3759,8 @@ def aten_ge(self: TReal, other: TReal) -> BOOL:


@torch_op(
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge")
("aten::ge.Tensor", "aten::ge.Scalar", "aten::greater_equal.Tensor", "_operator::ge"),
traceable=True,
)
def aten_ge_bool(self: BOOL, other: BOOL) -> BOOL:
"""ge.Tensor(Tensor self, Tensor other) -> Tensor"""
@@ -3904,14 +3896,20 @@ def aten_gru_cell(
raise NotImplementedError()


@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"))
@torch_op(
("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"),
traceable=True,
)
def aten_gt(self: TReal, other: TReal) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Greater(self, other)


@torch_op(("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"))
@torch_op(
("aten::gt.Tensor", "aten::gt.Scalar", "aten::greater.Tensor", "_operator::gt"),
traceable=True,
)
def aten_gt_bool(self: BOOL, other: BOOL) -> BOOL:
"""gt.Tensor(Tensor self, Tensor other) -> Tensor"""
# self, other, self > other
@@ -3949,7 +3947,7 @@ def aten_hardshrink_backward(
raise NotImplementedError()


@torch_op("aten::heaviside")
@torch_op("aten::heaviside", traceable=True)
def aten_heaviside(self: TReal, values: TReal) -> TReal:
"""heaviside(Tensor self, Tensor values) -> Tensor"""

@@ -4393,7 +4391,10 @@ def aten_instance_norm(
), "running_mean and running_var must be provided when use_input_stats is False"

batch_size = op.Shape(input, start=0, end=1)
bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0))
bn_input = op.Reshape(
input,
op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
)
weight = op.Tile(weight, batch_size)
bias = op.Tile(bias, batch_size)
running_mean = op.Tile(running_mean, batch_size)
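Aside (not part of the diff): the corrected Concat builds the target shape [1, N*C, *spatial] so that instance norm can be computed as batch norm over a reshaped input. A rough PyTorch analogue, assuming a 4-D input:

# Hypothetical illustration, not part of this commit.
import torch
x = torch.randn(2, 3, 4, 5)                # [N, C, H, W]
bn_input = x.reshape(1, -1, *x.shape[2:])  # [1, N*C, H, W]; mirrors Concat([1, -1], Shape(x)[2:])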
@@ -4695,14 +4696,20 @@ def aten_ldexp(self: TensorType, other: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(
("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"),
traceable=True,
)
def aten_le(self: TReal, other: TReal) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.LessOrEqual(self, other)


@torch_op(("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"))
@torch_op(
("aten::le.Tensor", "aten::le.Scalar", "aten::less_equal.Tensor", "_operator::le"),
traceable=True,
)
def aten_le_bool(self: BOOL, other: BOOL) -> BOOL:
"""le.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5002,14 +5009,20 @@ def aten_lstm_mps_backward(
raise NotImplementedError()


@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"))
@torch_op(
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
traceable=True,
)
def aten_lt(self: TReal, other: TReal) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Less(self, other)


@torch_op(("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"))
@torch_op(
("aten::lt.Tensor", "aten::lt.Scalar", "aten::less.Tensor", "_operator::lt"),
traceable=True,
)
def aten_lt_bool(self: BOOL, other: BOOL) -> BOOL:
"""lt.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5051,9 +5064,6 @@ def aten_mH(self: TRealOrUInt8) -> TRealOrUInt8:
def aten_mH_complex(self: TFloat) -> TFloat:
"""mH(Tensor(a) self) -> Tensor(a)"""

# TODO(#834): Allow calling scripted functions from other
# scripted functions and remove trace only.

# c is the last dimension being the real and imaginary parts
transposed = op.Einsum(self, equation="...ijc->...jic")
return _complex_conjugate(transposed)
@@ -5218,9 +5228,8 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
if IsScalar(self):
result = self
else:
if IsScalar(dim):
dim = op.Unsqueeze(dim, axes=0)
result = op.ReduceMean(self, dim, keepdims=keepdim)
dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
result = op.ReduceMean(self, dims, keepdims=keepdim)
return result


@@ -6218,14 +6227,14 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"))
@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor", "_operator::ne"), traceable=True)
def aten_ne(self: TReal, other: TReal) -> BOOL:
"""ne.Tensor(Tensor self, Tensor other) -> Tensor"""

return op.Not(op.Equal(self, other))


@torch_op(("aten::neg", "_operator::neg"))
@torch_op(("aten::neg", "_operator::neg"), traceable=True)
def aten_neg(self: TReal) -> TReal:
"""neg(Tensor self) -> Tensor"""

Expand Down Expand Up @@ -7067,7 +7076,7 @@ def aten_real(self: TensorType) -> TensorType:
raise NotImplementedError()


@torch_op("aten::reciprocal")
@torch_op("aten::reciprocal", traceable=True)
def aten_reciprocal(self: TFloat) -> TFloat:
"""reciprocal(Tensor self) -> Tensor"""

@@ -7086,7 +7095,7 @@ def aten_refine_names(self: TensorType, names: Sequence[str]) -> TensorType:
raise NotImplementedError()


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"))
@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar"), traceable=True)
def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -7099,7 +7108,9 @@ def aten_remainder(self: TFloat, other: TFloat) -> TFloat:
return op.Sub(self, op.Mul(rounded_quotient, other))


@torch_op(("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"))
@torch_op(
("aten::remainder.Tensor", "aten::remainder.Scalar", "_operator::mod"), traceable=True
)
def aten_remainder_int(self: TInt, other: TInt) -> TInt:
"""remainder.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -8013,53 +8024,35 @@ def aten_sub_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
return aten_sub(self, other, alpha=alpha)


@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
def aten_sum_dim_IntList(
self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
) -> TReal:
"""sum(Tensor self, SymInt dim, bool keepdim, *, ScalarType? dtype=None) -> Tensor"""

# NOTE: trace_only because both if branches need to be the same type, but we have
# a cast in the if branch.

# TODO: Combine the overloads when OptionalHasElement() works
if dim is None:
result = _aten_sum_dim_none(self, keepdim=keepdim)
@torch_op("aten::sum", trace_only=True)
def aten_sum(self: TReal, dtype: int = -1) -> TReal:
"""sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"""
if len(self.shape) == 0:
result = op.Identity(self)
else:
result = _aten_sum_dim_onnx(self, dim, keepdim=keepdim)

if dtype != -1:
result = op.ReduceSum(self, keepdims=False)
if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)

return result


@torch_op("aten::sum", private=True, traceable=True)
def _aten_sum_dim_onnx(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))

if IsScalar(dim):
@torch_op("aten::sum.dim_IntList", trace_only=True)
def aten_sum_dim_IntList(
self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
) -> TReal:
"""sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"""
if len(self.shape) == 0:
result = op.Identity(self)
elif dim is None:
result = op.ReduceSum(self, keepdims=keepdim)
else:
dim = op.Reshape(dim, op.Constant(value_ints=[-1]))
dim = op.Cast(dim, to=INT64.dtype)
result = op.ReduceSum(self, dim, keepdims=keepdim)

if self_is_scalar:
result = op.Squeeze(result)
return result
result = op.ReduceSum(self, dim, keepdims=keepdim)

if dtype != -1 and dtype is not None:
result = op.Cast(result, to=dtype)

@torch_op("aten::sum", private=True)
def _aten_sum_dim_none(self: TReal, keepdim: bool = False) -> TReal:
self_is_scalar = IsScalar(self)
if self_is_scalar:
self = op.Reshape(self, op.Constant(value_ints=[-1]))

result = op.ReduceSum(self, keepdims=keepdim)

if self_is_scalar:
result = op.Squeeze(result)
return result
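Aside (not part of the diff): the split trace-only overloads are intended to mirror torch.sum semantics. A small sanity check of that intended behaviour, assuming PyTorch is available:

# Hypothetical check, not part of this commit.
import torch
x = torch.arange(6.0).reshape(2, 3)
assert torch.equal(torch.sum(x), torch.tensor(15.0))     # corresponds to aten::sum
assert torch.equal(torch.sum(x, dim=[1], keepdim=True),  # corresponds to aten::sum.dim_IntList
                   torch.tensor([[3.0], [12.0]]))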


3 changes: 2 additions & 1 deletion onnxscript/ir/__init__.py
@@ -5,6 +5,7 @@
__all__ = [
# Modules
"serde",
"convenience",
# IR classes
"Tensor",
"ExternalTensor",
@@ -77,7 +78,7 @@
"save",
]

from onnxscript.ir import passes, serde, traversal
from onnxscript.ir import convenience, passes, serde, traversal
from onnxscript.ir._convenience import tensor
from onnxscript.ir._core import (
Attr,
(Remaining changed files not shown.)
