Skip to content

Commit

Permalink
[ONNX] Use the torchlib opset number and fix opset import logic (pyto…
Browse files Browse the repository at this point in the history
…rch#141413)

- Update the ONNX IR `add_opset_imports` pass to remove the heuristic of taking the `max` of the seen opsets. Instead, it uses the torchlib default opset version for the model's `opset_import`. The version converter is able to take the true opset versions in the nodes and convert the model to the correct version.
- Replace every hard-coded use of opset 18 with a query for the default torchlib opset from onnxscript, introduced in microsoft/onnxscript#1963

Fixes pytorch#141260
Pull Request resolved: pytorch#141413
Approved by: https://github.com/titaiwangms
  • Loading branch information
justinchuby authored and pobin6 committed Dec 5, 2024
1 parent 34ce0d4 commit efbbced
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 65 deletions.
33 changes: 7 additions & 26 deletions torch/onnx/_internal/_exporter_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
import logging
import warnings
from collections import defaultdict
from typing import Any, Callable, Final, Mapping, Sequence, TYPE_CHECKING, TypeVar
from typing import Any, Callable, Mapping, Sequence, TYPE_CHECKING, TypeVar

import torch
import torch._ops
import torch.utils._pytree as pytree
from torch.onnx import errors
from torch.onnx._internal import io_adapter
from torch.onnx._internal._lazy_import import onnxscript_ir as ir
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.exporter import _onnx_program
from torch.onnx._internal.fx import (
Expand All @@ -49,11 +49,6 @@
from torch._subclasses import fake_tensor
from torch.onnx._internal.fx import diagnostics

_DEFAULT_OPSET_VERSION: Final[int] = 18
"""The default ONNX opset version the exporter will use if one is not specified explicitly
through :class:`ExportOptions`. This should NEVER be accessed outside of this module! Users
should reference :attr:`ExportOptions.opset_version`."""

_PYTORCH_GITHUB_ISSUES_URL = "https://github.com/pytorch/pytorch/issues"
"""The URL to the PyTorch GitHub issues page."""

Expand Down Expand Up @@ -102,9 +97,7 @@ def __init__(self) -> None:
defaultdict(list)
)

# opset_version is unused for now, since torchlib only supports opset18.
# TODO: get opset version from torchlib
self._opset_version = _DEFAULT_OPSET_VERSION
self._opset_version = onnxscript_apis.torchlib_opset_version()
warnings.warn(
f"torch.onnx.dynamo_export only implements opset version {self._opset_version} for now. If you need to use a "
"different opset version, please register them with register_custom_op."
Expand All @@ -114,9 +107,7 @@ def __init__(self) -> None:

@property
def opset_version(self) -> int:
"""The ONNX opset version the exporter should target. Defaults to the latest
supported ONNX opset version: 18. The default version will increment over time as
ONNX continues to evolve."""
"""The ONNX opset version the exporter should target."""

return self._opset_version

Expand All @@ -126,8 +117,6 @@ def _initiate_registry_from_torchlib(self) -> None:
Args:
torchlib_registry: The torchlib registry to use for populating the registry.
"""
import onnxscript._framework_apis.torch_2_6 as onnxscript_apis

for meta in onnxscript_apis.get_torchlib_ops():
internal_name_instance = registration.OpName.from_qualified_name(
meta.qualified_name
Expand Down Expand Up @@ -587,26 +576,18 @@ def export(self) -> _onnx_program.ONNXProgram:
onnx_model = onnxscript_graph.to_model_proto(
self.options.onnx_registry.opset_version,
)
ir_model = ir.serde.deserialize_model(onnx_model)

try:
from onnxscript import optimizer

onnx_model = optimizer.optimize(onnx_model)
except ImportError:
warnings.warn(
"ONNXScript optimizer is not available. Skipping optimization. "
"Please `pip install onnxscript -U` to enable post-export optimization."
)
ir_model = onnxscript_apis.optimize(ir_model)
except Exception as e:
warnings.warn(
"ONNXScript optimizer failed. Skipping optimization. "
"\n\nPLEASE REPORT A BUG AT https://github.com/microsoft/onnxscript/issues "
f"\n\nDetail:\n{e}"
)

return _onnx_program.ONNXProgram(
ir.serde.deserialize_model(onnx_model), None
)
return _onnx_program.ONNXProgram(ir_model, None)

def _assert_fake_tensor_mode(self):
"""Asserts that the model and its input do not contain fake tensors."""
Expand Down
3 changes: 1 addition & 2 deletions torch/onnx/_internal/exporter/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,7 @@ def export_compat(
**_,
) -> _onnx_program.ONNXProgram:
if opset_version is None:
# TODO(justinchuby): Change the hardcoded opset version for it to be flexible
opset_version = 18
opset_version = onnxscript_apis.torchlib_opset_version()

if isinstance(model, torch.export.ExportedProgram):
# We know the model is already exported program, so the args, kwargs, and dynamic_shapes
Expand Down
46 changes: 26 additions & 20 deletions torch/onnx/_internal/exporter/_ir_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from __future__ import annotations

import logging
from typing import Mapping, Sequence
from typing import Sequence

from onnxscript import ir
from torch.onnx._internal._lazy_import import onnxscript_apis, onnxscript_ir as ir


_MIN_ONNX_OPSET_VERSION = 18

# The opset domain for ONNX operators
_ONNX_DOMAIN = ""

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,32 +44,38 @@ def add_torchlib_common_imports(model: ir.Model) -> None:
logger.exception("Failed to add torchlib common imports to the model.")


def _get_opset_version(node: ir.Node, opset_imports: Mapping[str, int]) -> int:
    """Return the opset version to record for the domain of ``node``.

    Args:
        node: The node whose ``domain`` and ``version`` are inspected.
        opset_imports: The opset versions collected so far, keyed by domain.

    Returns:
        The version to store for the node's domain. A node without a version
        is treated as version 1.
    """
    node_domain = node.domain
    node_version = 1 if node.version is None else node.version

    if node_domain == "":
        # Default ONNX domain: never report less than the minimum opset.
        return max(node_version, _MIN_ONNX_OPSET_VERSION)

    if node_domain not in opset_imports:
        # First time this domain is seen; take the node's version as is.
        return node_version

    # Heuristic: keep the latest version seen for this domain.
    return max(node_version, opset_imports[node_domain])
def _maybe_set_opset_version(
    opset_imports: dict[str, int], domain: str, version: int | None
) -> None:
    """Record an opset version for ``domain`` unless one is already set.

    An existing entry is kept as is unless its value is 1; an entry of 1 is
    overwritten by the rules below.

    Args:
        opset_imports: Mapping of domain to opset version, updated in place.
        domain: The operator domain to record a version for.
        version: The version observed for the domain, or ``None`` if unknown.
    """
    if opset_imports.get(domain, 1) != 1:
        # Already set to something other than 1; leave it alone.
        return

    if domain == _ONNX_DOMAIN:
        # ONNX operators always get the default torchlib opset version.
        chosen = onnxscript_apis.torchlib_opset_version()
    elif version is None:
        # We don't know the opset version, so set it to 1.
        # This is valid for the custom function domains like "pkg.torch.__subgraph__"
        chosen = 1
    else:
        # Use the known opset version for the domain.
        chosen = version

    opset_imports[domain] = chosen


def add_opset_imports(model: ir.Model) -> None:
"""Collect all opsets used and add opset imports to the model and functions."""
for node in ir.traversal.RecursiveGraphIterator(model.graph):
domain = node.domain
model.opset_imports[domain] = _get_opset_version(node, model.opset_imports)
_maybe_set_opset_version(model.opset_imports, domain, node.version)

for function in model.functions.values():
for node in ir.traversal.RecursiveGraphIterator(function):
domain = node.domain
function.opset_imports[domain] = _get_opset_version(
node, function.opset_imports
)
for opset, version in function.opset_imports.items():
_maybe_set_opset_version(function.opset_imports, domain, node.version)
for domain, version in function.opset_imports.items():
# Add all opsets used in the function to the model, because ONNX Runtime
# does not handle adding the opset imports to the model after inlining during inference.
# This should happen after all opsets are collected for the function from its nodes.
model.opset_imports[opset] = max(version, model.opset_imports.get(opset, 1))
_maybe_set_opset_version(model.opset_imports, domain, version)
15 changes: 2 additions & 13 deletions torch/onnx/_internal/exporter/_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@
from torch.onnx._internal.exporter._torchlib import _torchlib_registry


_DEFAULT_OPSET_VERSION = 18


TorchOp: TypeAlias = Union[torch._ops.OpOverload, types.BuiltinFunctionType, Callable]

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -115,20 +112,12 @@ class ONNXRegistry:

def __init__(self) -> None:
"""Initializes the registry"""

# TODO: Design multi-opset version support
self._opset_version = _DEFAULT_OPSET_VERSION

self._opset_version = onnxscript_apis.torchlib_opset_version()
self.functions: dict[TorchOp | str, list[OnnxDecompMeta]] = {}

@property
def opset_version(self) -> int:
"""The ONNX opset version the exporter should target.
Defaults to the latest supported ONNX opset version: 18.
The default version will increment over time as ONNX continues to evolve.
"""

"""The ONNX opset version the exporter should target."""
return self._opset_version

@classmethod
Expand Down
13 changes: 9 additions & 4 deletions torch/onnx/_internal/fx/fx_onnx_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import re
from typing import Callable, Sequence

import onnxscript # type: ignore[import]
from onnxscript.function_libs.torch_lib import ( # type: ignore[import]
import onnxscript
from onnxscript.function_libs.torch_lib import (
graph_building as onnxscript_graph_building,
)

import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal._lazy_import import onnxscript_apis
from torch.onnx._internal.fx import (
_pass,
diagnostics,
Expand Down Expand Up @@ -144,7 +145,9 @@ def _retrieve_or_adapt_input_to_graph_set(
# Since tensors with rank=0 (i.e., scalar) cannot be concated, all
# scalars are promoted to tensors with shape (1,).
with onnxscript.evaluator.default_as(tracer):
element_value = onnxscript.opset18.Reshape(element_value, [1]) # type: ignore[arg-type, type-var]
element_value = onnxscript_apis.torchlib_opset().Reshape(
element_value, [1]
) # type: ignore[arg-type, type-var]
sequence_mixed_elements.append(element_value)
elif isinstance(tensor, int):
# NOTE: op.Concat doesn't support scalar, so we need to wrap it with
Expand All @@ -165,7 +168,9 @@ def _retrieve_or_adapt_input_to_graph_set(
# onnx-script auto wraps python number with op.Constants,
# so we don't need to specifically process them.
with onnxscript.evaluator.default_as(tracer):
output = onnxscript.opset18.Concat(*sequence_mixed_elements, axis=0) # type: ignore[type-var]
output = onnxscript_apis.torchlib_opset().Concat(
*sequence_mixed_elements, axis=0
) # type: ignore[type-var]
output.dtype = torch.int64 # type: ignore[union-attr]
output.shape = [len(sequence_mixed_elements)] # type: ignore[union-attr]
return output
Expand Down

0 comments on commit efbbced

Please sign in to comment.