Commit
Merge branch 'main' into rama/irutils
gramalingam authored Dec 3, 2024
2 parents c94db79 + 99cf79f commit 1461ece
Showing 6 changed files with 697 additions and 17 deletions.
18 changes: 5 additions & 13 deletions onnxscript/_framework_apis/torch_2_5.py
@@ -17,7 +17,7 @@
 import pathlib
 from typing import Callable
 
-from onnxscript import ir, optimizer
+from onnxscript import ir, optimizer, version_converter
 from onnxscript.function_libs.torch_lib import registration
 from onnxscript.ir import _external_data
 
@@ -51,18 +51,10 @@ def optimize(model: ir.Model) -> ir.Model:
 
 def convert_version(model: ir.Model, target_version: int) -> ir.Model:
     """Convert the model to the specified ONNX opset version."""
-    # model_version = model.opset_import.get("")
-    # if model_version == target_version:
-    #     # No conversion needed
-    #     return model
-
-    # # FIXME(justinchuby): version_converter does not support functions
-    # proto = ir.serde.serialize_model(model)
-    # proto = onnx.version_converter.convert_version(proto, target_version)
-    # return ir.serde.deserialize_model(proto)
-    # TODO(justinchuby): This function needs to be carefully implemented
-    # to handle large models. For now, we just return the model.
-    del target_version  # Unused
+    # Internal flag. Will go away.
+    enabled = os.getenv("TORCH_ONNX_ENABLE_VERSION_CONVERSION") == "1"
+    if enabled:
+        version_converter.convert_version(model, target_version)
     return model


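For reference, a rough usage sketch of the flag-gated path added above. Only the environment variable, the module path, and the convert_version signature come from the diff; ir.load, the file name, and the target opset are illustrative assumptions.

import os

from onnxscript import ir
from onnxscript._framework_apis import torch_2_5

# Opt in to the experimental conversion path; without the flag the model is returned unchanged.
os.environ["TORCH_ONNX_ENABLE_VERSION_CONVERSION"] = "1"

model = ir.load("model.onnx")  # hypothetical input file
model = torch_2_5.convert_version(model, target_version=21)  # placeholder target opset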
19 changes: 19 additions & 0 deletions onnxscript/_framework_apis/torch_2_6.py
@@ -10,7 +10,10 @@
     "get_torchlib_ops",
     "optimize",
     "save_model_with_external_data",
+    "torchlib_opset",
 ]
+from typing import TYPE_CHECKING
+
 from onnxscript import ir, optimizer
 from onnxscript._framework_apis.torch_2_5 import (
     check_model,
@@ -19,8 +22,24 @@
     save_model_with_external_data,
 )
 
+if TYPE_CHECKING:
+    from onnxscript.onnx_opset._impl.opset18 import Opset18
+
 
 def optimize(model: ir.Model) -> ir.Model:
     """Optimize the model."""
     optimizer.optimize_ir(model)
     return model
+
+
+def torchlib_opset() -> Opset18:
+    """Return the default opset for torchlib."""
+    import onnxscript  # pylint: disable=import-outside-toplevel
+
+    return onnxscript.opset18  # type: ignore
+
+
+def torchlib_opset_version() -> int:
+    """Return the default opset version for torchlib."""
+
+    return torchlib_opset().version
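A quick sketch of how the two new helpers can be called, assuming only what the diff above shows:

from onnxscript._framework_apis import torch_2_6

opset = torch_2_6.torchlib_opset()         # the Opset18 instance backing torchlib
print(torch_2_6.torchlib_opset_version())  # 18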
10 changes: 6 additions & 4 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4391,7 +4391,10 @@ def aten_instance_norm(
     ), "running_mean and running_var must be provided when use_input_stats is False"
 
     batch_size = op.Shape(input, start=0, end=1)
-    bn_input = op.Reshape(input, op.Concat([1, -1], op.Shape(input, start=2), axis=0))
+    bn_input = op.Reshape(
+        input,
+        op.Concat(op.Constant(value_ints=[1, -1]), op.Shape(input, start=2), axis=0),
+    )
     weight = op.Tile(weight, batch_size)
     bias = op.Tile(bias, batch_size)
     running_mean = op.Tile(running_mean, batch_size)
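The rewritten Reshape/Concat builds the shape [1, N*C, *spatial], folding the batch into the channel axis so per-instance statistics can be computed. A small numpy sketch of the shape arithmetic (numpy's reshape with -1 mirrors ONNX Reshape's inferred dimension; the input shape is made up):

import numpy as np

x = np.zeros((2, 3, 4, 5), dtype=np.float32)  # hypothetical NCHW input
target = np.concatenate([np.array([1, -1]), np.array(x.shape[2:])])
print(target)                   # [ 1 -1  4  5]
print(x.reshape(target).shape)  # (1, 6, 4, 5): batch folded into the channel axis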
@@ -5225,9 +5228,8 @@ def aten_mean_dim(self: TReal, dim: INT64, keepdim: bool = False) -> TReal:
     if IsScalar(self):
         result = self
     else:
-        if IsScalar(dim):
-            dim = op.Unsqueeze(dim, axes=0)
-        result = op.ReduceMean(self, dim, keepdims=keepdim)
+        dims = op.Reshape(dim, op.Constant(value_ints=[-1]))
+        result = op.ReduceMean(self, dims, keepdims=keepdim)
     return result


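The aten_mean_dim change above replaces the graph-level scalar check with a single Reshape to [-1], which normalizes both 0-D and 1-D dim inputs into the 1-D axes tensor that ReduceMean expects. A numpy sketch of the same shape behavior (numpy stands in for the ONNX ops here):

import numpy as np

print(np.reshape(np.array(1), [-1]))       # [1]    a 0-D scalar becomes a 1-D tensor
print(np.reshape(np.array([0, 2]), [-1]))  # [0 2]  a 1-D input is left unchanged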
21 changes: 21 additions & 0 deletions onnxscript/version_converter/__init__.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+__all__ = [
+    # Functions
+    "convert_version",
+]
+
+from onnxscript import ir
+from onnxscript.optimizer import _inliner
+from onnxscript.version_converter import _version_converter
+
+
+def convert_version(model: ir.Model, target_version: int) -> None:
+    """Convert the model to the specified ONNX opset version."""
+
+    # In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
+    # Hence, we inline all the functions.
+    _inliner.inline(model)
+    _version_converter.convert_version(model, target_version)
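A minimal usage sketch for the new public entry point. The convert_version signature and its in-place behavior come from the diff; ir.load/ir.save, the file names, and the target opset are assumptions for illustration.

from onnxscript import ir, version_converter

model = ir.load("model_opset18.onnx")  # hypothetical input file
version_converter.convert_version(model, target_version=21)  # modifies the model in place
ir.save(model, "model_opset21.onnx")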
