From abfcd4cb21403a4106b7f67936b7335b1d4227a8 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 08:54:26 -0700 Subject: [PATCH 01/16] Update package namespace to onnx_ir Signed-off-by: Justin Chu --- MANIFEST.in | 1 + src/onnx_ir/__init__.py | 2 +- src/onnx_ir/_convenience/__init__.py | 6 +- src/onnx_ir/_convenience/_constructors.py | 8 +- .../_convenience/_constructors_test.py | 2 +- src/onnx_ir/_core.py | 6 +- src/onnx_ir/_core_test.py | 4 +- src/onnx_ir/_enums_test.py | 2 +- src/onnx_ir/_graph_comparison.py | 2 +- src/onnx_ir/_graph_containers.py | 2 +- src/onnx_ir/_internal/version_utils.py | 118 +++++++ src/onnx_ir/_io.py | 4 +- src/onnx_ir/_io_test.py | 4 +- src/onnx_ir/_linked_list_test.py | 2 +- src/onnx_ir/_name_authority.py | 2 +- src/onnx_ir/_name_authority_test.py | 4 +- src/onnx_ir/_protocols.py | 2 +- src/onnx_ir/_schemas.py | 2 +- src/onnx_ir/_schemas_test.py | 2 +- src/onnx_ir/_tape.py | 6 +- src/onnx_ir/_tape_test.py | 2 +- src/onnx_ir/_thirdparty/asciichartpy.py | 313 ++++++++++++++++++ src/onnx_ir/_type_casting_test.py | 2 +- src/onnx_ir/external_data.py | 4 +- src/onnx_ir/external_data_test.py | 4 +- src/onnx_ir/passes/_pass_infra.py | 2 +- src/onnx_ir/passes/_pass_infra_test.py | 2 +- src/onnx_ir/passes/common/_c_api_utils.py | 2 +- .../common/clear_metadata_and_docstring.py | 2 +- .../clear_metadata_and_docstring_test.py | 2 +- .../passes/common/constant_manipulation.py | 2 +- .../common/constant_manipulation_test.py | 2 +- src/onnx_ir/passes/common/inliner.py | 2 +- src/onnx_ir/passes/common/inliner_test.py | 2 +- src/onnx_ir/passes/common/onnx_checker.py | 2 +- .../passes/common/onnx_checker_test.py | 2 +- src/onnx_ir/passes/common/shape_inference.py | 2 +- .../passes/common/shape_inference_test.py | 2 +- src/onnx_ir/passes/common/topological_sort.py | 2 +- .../passes/common/topological_sort_test.py | 2 +- src/onnx_ir/passes/common/unused_removal.py | 2 +- .../passes/common/unused_removal_test.py | 2 +- src/onnx_ir/serde.py | 2 +- src/onnx_ir/serde_test.py | 4 +- src/onnx_ir/tensor_adapters.py | 6 +- src/onnx_ir/tensor_adapters_test.py | 2 +- src/onnx_ir/traversal.py | 2 +- src/onnx_ir/traversal_test.py | 4 +- tests/ir/graph_view_test.py | 2 +- tests/ir/serde_roundtrip_test.py | 2 +- tools/model_zoo_test/model_zoo_test.py | 2 +- 51 files changed, 499 insertions(+), 67 deletions(-) create mode 100644 MANIFEST.in create mode 100644 src/onnx_ir/_internal/version_utils.py create mode 100644 src/onnx_ir/_thirdparty/asciichartpy.py diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 00000000..aa2a5b26 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +global-exclude *_test.py diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index b5daebe2..a73baae5 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -83,7 +83,7 @@ "save", ] -from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal +from onnx_ir import convenience, external_data, passes, serde, tape, traversal from onnxscript.ir._convenience._constructors import node, tensor from onnxscript.ir._core import ( Attr, diff --git a/src/onnx_ir/_convenience/__init__.py b/src/onnx_ir/_convenience/__init__.py index 839c5d33..42ff3e55 100644 --- a/src/onnx_ir/_convenience/__init__.py +++ b/src/onnx_ir/_convenience/__init__.py @@ -20,7 +20,7 @@ import onnx -from onnxscript.ir import _core, _enums, _protocols, serde +from onnx_ir import _core, _enums, _protocols, serde SupportedAttrTypes = Union[ str, @@ -188,7 +188,7 @@ def convert_attributes( types are: int, float, str, Sequence[int], Sequence[float], Sequence[str], :class:`_core.Tensor`, and :class:`_core.Attr`:: - >>> from onnxscript import ir + >>> import onnx_ir as ir >>> import onnx >>> import numpy as np >>> attrs = { @@ -269,7 +269,7 @@ def replace_all_uses_with( We want to replace the node A with a new node D:: - >>> from onnxscript import ir + >>> import onnx_ir as ir >>> input = ir.Input("input") >>> node_a = ir.Node("", "A", [input]) >>> node_b = ir.Node("", "B", node_a.outputs) diff --git a/src/onnx_ir/_convenience/_constructors.py b/src/onnx_ir/_convenience/_constructors.py index 33b738e5..7e66fa24 100644 --- a/src/onnx_ir/_convenience/_constructors.py +++ b/src/onnx_ir/_convenience/_constructors.py @@ -15,12 +15,12 @@ import numpy as np import onnx -from onnxscript.ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters +from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters if typing.TYPE_CHECKING: import numpy.typing as npt - from onnxscript import ir + import onnx_ir as ir def tensor( @@ -39,7 +39,7 @@ def tensor( Example:: - >>> from onnxscript import ir + >>> import onnx_ir as ir >>> import numpy as np >>> import ml_dtypes >>> import onnx @@ -162,7 +162,7 @@ def node( Example:: - >>> from onnxscript import ir + >>> import onnx_ir as ir >>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32)) >>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32)) >>> node = ir.node( diff --git a/src/onnx_ir/_convenience/_constructors_test.py b/src/onnx_ir/_convenience/_constructors_test.py index 6f291d81..0723a619 100644 --- a/src/onnx_ir/_convenience/_constructors_test.py +++ b/src/onnx_ir/_convenience/_constructors_test.py @@ -6,7 +6,7 @@ import numpy as np -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir._convenience import _constructors diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index f699916f..f36ceed6 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -45,7 +45,7 @@ from typing_extensions import TypeIs import onnxscript -from onnxscript.ir import ( +from onnx_ir import ( _display, _enums, _graph_containers, @@ -852,7 +852,7 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too- Example:: >>> import numpy as np - >>> from onnxscript import ir + >>> import onnx_ir as ir >>> weights = np.array([[1, 2, 3]]) >>> def create_tensor(): # Delay applying transformations to the weights ... weights_t = weights.transpose() @@ -1039,7 +1039,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable): Example:: - >>> from onnxscript import ir + >>> import onnx_ir as ir >>> shape = ir.Shape(["B", None, 3]) >>> shape.rank() 3 diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index 2af10646..8fd3830d 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -15,8 +15,8 @@ import parameterized import torch -from onnxscript import ir -from onnxscript.ir import _core +import onnx_ir as ir +from onnx_ir import _core class TensorTest(unittest.TestCase): diff --git a/src/onnx_ir/_enums_test.py b/src/onnx_ir/_enums_test.py index 906bf7b5..96f32b06 100644 --- a/src/onnx_ir/_enums_test.py +++ b/src/onnx_ir/_enums_test.py @@ -9,7 +9,7 @@ import onnx._custom_element_types import parameterized -from onnxscript.ir import _enums +from onnx_ir import _enums class DataTypeTest(unittest.TestCase): diff --git a/src/onnx_ir/_graph_comparison.py b/src/onnx_ir/_graph_comparison.py index e13b8ba4..b0877433 100644 --- a/src/onnx_ir/_graph_comparison.py +++ b/src/onnx_ir/_graph_comparison.py @@ -4,7 +4,7 @@ from __future__ import annotations -from onnxscript.ir import _core +from onnx_ir import _core # NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs # NOTE(justinchuby): A graph may be specified with a set of inputs and outputs diff --git a/src/onnx_ir/_graph_containers.py b/src/onnx_ir/_graph_containers.py index 620e73e8..a5b6796b 100644 --- a/src/onnx_ir/_graph_containers.py +++ b/src/onnx_ir/_graph_containers.py @@ -17,7 +17,7 @@ import onnxscript if TYPE_CHECKING: - from onnxscript.ir import _core + from onnx_ir import _core class _GraphIO(collections.UserList["_core.Value"]): diff --git a/src/onnx_ir/_internal/version_utils.py b/src/onnx_ir/_internal/version_utils.py new file mode 100644 index 00000000..390f7ee3 --- /dev/null +++ b/src/onnx_ir/_internal/version_utils.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Version utils for testing.""" + +from __future__ import annotations + +import warnings +from typing import Callable, Sequence + +import packaging.version + + +def onnx_older_than(version: str) -> bool: + """Returns True if the ONNX version is older than the given version.""" + import onnx # pylint: disable=import-outside-toplevel + + return ( + packaging.version.parse(onnx.__version__).release + < packaging.version.parse(version).release + ) + + +def torch_older_than(version: str) -> bool: + """Returns True if the torch version is older than the given version.""" + import torch # pylint: disable=import-outside-toplevel + + return ( + packaging.version.parse(torch.__version__).release + < packaging.version.parse(version).release + ) + + +def transformers_older_than(version: str) -> bool | None: + """Returns True if the transformers version is older than the given version.""" + try: + import transformers # pylint: disable=import-outside-toplevel + except ImportError: + return None + + return ( + packaging.version.parse(transformers.__version__).release + < packaging.version.parse(version).release + ) + + +def is_onnxruntime_training() -> bool: + """Returns True if the onnxruntime is onnxruntime-training.""" + try: + from onnxruntime import training # pylint: disable=import-outside-toplevel + + assert training + except ImportError: + # onnxruntime not training + return False + + try: + from onnxruntime.capi.onnxruntime_pybind11_state import ( # pylint: disable=import-outside-toplevel + OrtValueVector, + ) + except ImportError: + return False + + return hasattr(OrtValueVector, "push_back_batch") + + +def onnxruntime_older_than(version: str) -> bool: + """Returns True if the onnxruntime version is older than the given version.""" + import onnxruntime # pylint: disable=import-outside-toplevel + + return ( + packaging.version.parse(onnxruntime.__version__).release + < packaging.version.parse(version).release + ) + + +def numpy_older_than(version: str) -> bool: + """Returns True if the numpy version is older than the given version.""" + import numpy # pylint: disable=import-outside-toplevel + + return ( + packaging.version.parse(numpy.__version__).release + < packaging.version.parse(version).release + ) + + +def has_transformers(): + """Tells if transformers is installed.""" + try: + import transformers # pylint: disable=import-outside-toplevel + + assert transformers + return True # noqa + except ImportError: + return False + + +def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable: # type: ignore[arg-type] + """Catches warnings. + + Args: + warns: warnings to ignore + + Returns: + decorated function + """ + + def wrapper(fct): + if warns is None: + raise AssertionError(f"warns cannot be None for '{fct}'.") + + def call_f(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", warns) # type: ignore[arg-type] + return fct(self) + + return call_f + + return wrapper diff --git a/src/onnx_ir/_io.py b/src/onnx_ir/_io.py index a83cfdbd..f1b9cdc2 100644 --- a/src/onnx_ir/_io.py +++ b/src/onnx_ir/_io.py @@ -10,8 +10,8 @@ import onnx -from onnxscript.ir import _core, serde -from onnxscript.ir import external_data as _external_data +from onnx_ir import _core, serde +from onnx_ir import external_data as _external_data from onnxscript.ir._polyfill import zip diff --git a/src/onnx_ir/_io_test.py b/src/onnx_ir/_io_test.py index 6473827b..52feab80 100644 --- a/src/onnx_ir/_io_test.py +++ b/src/onnx_ir/_io_test.py @@ -8,8 +8,8 @@ import numpy as np -from onnxscript import ir -from onnxscript.ir import _io +import onnx_ir as ir +from onnx_ir import _io def _create_initializer(tensor: ir.TensorProtocol) -> ir.Value: diff --git a/src/onnx_ir/_linked_list_test.py b/src/onnx_ir/_linked_list_test.py index 00f03e71..bdc51d64 100644 --- a/src/onnx_ir/_linked_list_test.py +++ b/src/onnx_ir/_linked_list_test.py @@ -8,7 +8,7 @@ import parameterized -from onnxscript.ir import _linked_list +from onnx_ir import _linked_list class _TestElement: diff --git a/src/onnx_ir/_name_authority.py b/src/onnx_ir/_name_authority.py index ab12be53..0f058f67 100644 --- a/src/onnx_ir/_name_authority.py +++ b/src/onnx_ir/_name_authority.py @@ -4,7 +4,7 @@ from __future__ import annotations -from onnxscript.ir import _core +from onnx_ir import _core class NameAuthority: diff --git a/src/onnx_ir/_name_authority_test.py b/src/onnx_ir/_name_authority_test.py index 1a0fed80..c6661639 100644 --- a/src/onnx_ir/_name_authority_test.py +++ b/src/onnx_ir/_name_authority_test.py @@ -2,8 +2,8 @@ # Licensed under the MIT License. import unittest -from onnxscript import ir -from onnxscript.ir import _name_authority +import onnx_ir as ir +from onnx_ir import _name_authority class NameAuthorityTest(unittest.TestCase): diff --git a/src/onnx_ir/_protocols.py b/src/onnx_ir/_protocols.py index fbc2c7c0..688c4226 100644 --- a/src/onnx_ir/_protocols.py +++ b/src/onnx_ir/_protocols.py @@ -45,7 +45,7 @@ Tuple, ) -from onnxscript.ir import _enums +from onnx_ir import _enums if typing.TYPE_CHECKING: import numpy as np diff --git a/src/onnx_ir/_schemas.py b/src/onnx_ir/_schemas.py index d4d88ab5..6cfa15c8 100644 --- a/src/onnx_ir/_schemas.py +++ b/src/onnx_ir/_schemas.py @@ -13,7 +13,7 @@ import onnx import onnxscript -from onnxscript import ir +import onnx_ir as ir logger = logging.getLogger(__name__) diff --git a/src/onnx_ir/_schemas_test.py b/src/onnx_ir/_schemas_test.py index c134bd7a..79b9ae1f 100644 --- a/src/onnx_ir/_schemas_test.py +++ b/src/onnx_ir/_schemas_test.py @@ -10,7 +10,7 @@ import onnxscript import onnxscript.testing from onnxscript import FLOAT, INT64, ir -from onnxscript.ir import _schemas +from onnx_ir import _schemas _TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) _TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) diff --git a/src/onnx_ir/_tape.py b/src/onnx_ir/_tape.py index fbcfcb42..4997e2f1 100644 --- a/src/onnx_ir/_tape.py +++ b/src/onnx_ir/_tape.py @@ -12,8 +12,8 @@ Tuple, ) -from onnxscript import ir -from onnxscript.ir import _convenience +import onnx_ir as ir +from onnx_ir import _convenience # A type representing the domains/versions used in creating nodes in IR. UsedOpsets = set[Tuple[str, Optional[int]]] @@ -27,7 +27,7 @@ class Tape: Example:: - from onnxscript import ir + import onnx_ir as ir tape = ir.tape.Tape() a = tape.initializer(ir.tensor([1, 2, 3], name="a")) diff --git a/src/onnx_ir/_tape_test.py b/src/onnx_ir/_tape_test.py index 46cbcc23..31e3d2c1 100644 --- a/src/onnx_ir/_tape_test.py +++ b/src/onnx_ir/_tape_test.py @@ -4,7 +4,7 @@ import unittest -from onnxscript import ir +import onnx_ir as ir class TestTape(unittest.TestCase): diff --git a/src/onnx_ir/_thirdparty/asciichartpy.py b/src/onnx_ir/_thirdparty/asciichartpy.py new file mode 100644 index 00000000..88c46202 --- /dev/null +++ b/src/onnx_ir/_thirdparty/asciichartpy.py @@ -0,0 +1,313 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# +# Copyright © 2016 Igor Kroitor +# +# MIT License +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Module to generate ascii charts. + +This module provides a single function `plot` that can be used to generate an +ascii chart from a series of numbers. The chart can be configured via several +options to tune the output. +""" + +from __future__ import annotations + +from math import ceil, floor, isnan +from typing import Mapping + +black = "\033[30m" +red = "\033[31m" +green = "\033[32m" +yellow = "\033[33m" +blue = "\033[34m" +magenta = "\033[35m" +cyan = "\033[36m" +lightgray = "\033[37m" +default = "\033[39m" +darkgray = "\033[90m" +lightred = "\033[91m" +lightgreen = "\033[92m" +lightyellow = "\033[93m" +lightblue = "\033[94m" +lightmagenta = "\033[95m" +lightcyan = "\033[96m" +white = "\033[97m" +reset = "\033[0m" + + +__all__ = [ + "plot", + "black", + "red", + "green", + "yellow", + "blue", + "magenta", + "cyan", + "lightgray", + "default", + "darkgray", + "lightred", + "lightgreen", + "lightyellow", + "lightblue", + "lightmagenta", + "lightcyan", + "white", + "reset", +] + + +# Python 3.2 has math.isfinite, which could have been used, but to support older +# versions, this little helper is shorter than having to keep doing not isnan(), +# plus the double-negative of "not is not a number" is confusing, so this should +# help with readability. +def _isnum(n): + return not isnan(n) + + +def colored(char, color): + if not color: + return char + else: + return color + char + reset + + +_DEFAULT_SYMBOLS = ("┼", "┤", "╶", "╴", "─", "╰", "╭", "╮", "╯", "│") + + +def plot(series, *, bin_edges=None, cfg=None): + """Generate an ascii chart for a series of numbers. + + `series` should be a list of ints or floats. Missing data values in the + series can be specified as a NaN. In Python versions less than 3.5, use + float("nan") to specify an NaN. With 3.5 onwards, use math.nan to specify a + NaN. + + >>> series = [1,2,3,4,float("nan"),4,3,2,1] + >>> print(plot(series)) + 4.00 ┤ ╭╴╶╮ + 3.00 ┤ ╭╯ ╰╮ + 2.00 ┤╭╯ ╰╮ + 1.00 ┼╯ ╰ + + `series` can also be a list of lists to support multiple data series. + + >>> series = [[10,20,30,40,30,20,10], [40,30,20,10,20,30,40]] + >>> print(plot(series, cfg={'height': 3})) + 40.00 ┤╮ ╭╮ ╭ + 30.00 ┤╰╮╯╰╭╯ + 20.00 ┤╭╰╮╭╯╮ + 10.00 ┼╯ ╰╯ ╰ + + `bin_edges` is an optional list of bin edges to display on the x-axis. If + provided, the x-axis will be labeled with the bin edges. If there are too + many bin edges to fit on the x-axis, some labels will be dropped and they + will be spaced out evenly to fit the width of the chart. + The labels will be formatted using the `x_format` option in `cfg`. + + `cfg` is an optional dictionary of various parameters to tune the appearance + of the chart. `min` and `max` will clamp the y-axis and all values: + + >>> series = [1,2,3,4,float("nan"),4,3,2,1] + >>> print(plot(series, cfg={'min': 0})) + 4.00 ┼ ╭╴╶╮ + 3.00 ┤ ╭╯ ╰╮ + 2.00 ┤╭╯ ╰╮ + 1.00 ┼╯ ╰ + 0.00 ┤ + + >>> print(plot(series, cfg={'min': 2})) + 4.00 ┤ ╭╴╶╮ + 3.00 ┤ ╭╯ ╰╮ + 2.00 ┼─╯ ╰─ + + >>> print(plot(series, cfg={'min': 2, 'max': 3})) + 3.00 ┤ ╭─╴╶─╮ + 2.00 ┼─╯ ╰─ + + `height` specifies the number of rows the graph should occupy. It can be + used to scale down a graph with large data values: + + >>> series = [10,20,30,40,50,40,30,20,10] + >>> print(plot(series, cfg={'height': 4})) + 50.00 ┤ ╭╮ + 40.00 ┤ ╭╯╰╮ + 30.00 ┤ ╭╯ ╰╮ + 20.00 ┤╭╯ ╰╮ + 10.00 ┼╯ ╰ + + `format` specifies a Python format string used to format the labels on the + y-axis. The default value is "{:8.2f} ". This can be used to remove the + decimal point: + + >>> series = [10,20,30,40,50,40,30,20,10] + >>> print(plot(series, cfg={'height': 4, 'format':'{:8.0f}'})) + 50 ┤ ╭╮ + 40 ┤ ╭╯╰╮ + 30 ┤ ╭╯ ╰╮ + 20 ┤╭╯ ╰╮ + 10 ┼╯ ╰ + """ + if len(series) == 0: + return "" + + if not isinstance(series[0], list): + if all(isnan(n) for n in series): + return "" + else: + series = [series] + + if cfg is not None and not isinstance(cfg, Mapping): + raise TypeError("cfg must be a dictionary or None") + + cfg = cfg or {} + + colors = cfg.get("colors", [None]) + + minimum = cfg.get("min", min(filter(_isnum, [j for i in series for j in i]))) + maximum = cfg.get("max", max(filter(_isnum, [j for i in series for j in i]))) + + symbols = cfg.get("symbols", _DEFAULT_SYMBOLS) + + if minimum > maximum: + raise ValueError("The min value cannot exceed the max value.") + + interval = maximum - minimum + offset = cfg.get("offset", 3) + height = cfg.get("height", interval) + ratio = height / interval if interval > 0 else 1 + + min2 = floor(minimum * ratio) + max2 = ceil(maximum * ratio) + + def clamp(n): + return min(max(n, minimum), maximum) + + def scaled(y): + return int(round(clamp(y) * ratio) - min2) + + rows = max2 - min2 + + width = 0 + for series_i in series: + width = max(width, len(series_i)) + width += offset + + placeholder = cfg.get("format", "{:8.2f} ") + x_placeholder = cfg.get("x_format", "{:4.4f}") + + result = [[" "] * width for i in range(rows + 1)] + + # axis and labels + for y in range(min2, max2 + 1): + label = placeholder.format(maximum - ((y - min2) * interval / (rows if rows else 1))) + result[y - min2][max(offset - len(label), 0)] = label + result[y - min2][offset - 1] = symbols[0] if y == 0 else symbols[1] # zero tick mark + + # first value is a tick mark across the y-axis + d0 = series[0][0] + if _isnum(d0): + result[rows - scaled(d0)][offset - 1] = symbols[0] + + for i, series_i in enumerate(series): + color = colors[i % len(colors)] + + # plot the line + for x in range(len(series_i) - 1): + d0 = series_i[x + 0] + d1 = series_i[x + 1] + + if isnan(d0) and isnan(d1): + continue + + if isnan(d0) and _isnum(d1): + result[rows - scaled(d1)][x + offset] = colored(symbols[2], color) + continue + + if _isnum(d0) and isnan(d1): + result[rows - scaled(d0)][x + offset] = colored(symbols[3], color) + continue + + y0 = scaled(d0) + y1 = scaled(d1) + if y0 == y1: + result[rows - y0][x + offset] = colored(symbols[4], color) + continue + + result[rows - y1][x + offset] = ( + colored(symbols[5], color) if y0 > y1 else colored(symbols[6], color) + ) + result[rows - y0][x + offset] = ( + colored(symbols[7], color) if y0 > y1 else colored(symbols[8], color) + ) + + start = min(y0, y1) + 1 + end = max(y0, y1) + for y in range(start, end): + result[rows - y][x + offset] = colored(symbols[9], color) + + the_plot = "\n".join(["".join(row).rstrip() for row in result]) + + if bin_edges is None or len(bin_edges) == 0: + return the_plot + + # Plot x axis labels + current_location = 0 + # Compute the amount of leading space for the first x-label using the old label size + leading_space = offset + len(label) + # Obtain the first x-label to compute its size + x_label = x_placeholder.format(bin_edges[0]) + # Initialize the x-label text with the leading space. We allow the first label to + # recess so that the center of it is aligned with the first tick mark. + x_label_size = len(x_label) + x_leading_space = max(0, leading_space - x_label_size) + + x_labels = [] + # This is the amount of space we have to fit the x-labels. It can overflow the width + # by half of the x-label size + workable_width = width + x_label_size // 2 + # Compute the spacing between x-labels + # If we fit labels and space them by 2 characters, we can fit this many labels: + min_spacing = 2 + num_labels_can_fit = width // (x_label_size + min_spacing) + labels_count = len(bin_edges) + # Find out the actual number of labels we need to display + num_labels_to_display = min(labels_count, num_labels_can_fit) + num_spaces = num_labels_to_display - 1 + spacing = max( + min_spacing, + (workable_width - num_labels_to_display * x_label_size) // num_spaces, + ) + # Now start placing labels + while current_location < workable_width: + # Find the current label that would be suitable for the current location + bin_index = int((current_location / workable_width) * labels_count) + x_label = x_placeholder.format(bin_edges[bin_index]) + x_labels.append(x_label) + # Move to the next location + current_location += len(x_label) + spacing + # Create the x-label row + x_labels_text = " " * x_leading_space + (" " * spacing).join(x_labels) + + return the_plot + "\n" + x_labels_text diff --git a/src/onnx_ir/_type_casting_test.py b/src/onnx_ir/_type_casting_test.py index abe4923e..f9ce9ca9 100644 --- a/src/onnx_ir/_type_casting_test.py +++ b/src/onnx_ir/_type_casting_test.py @@ -5,7 +5,7 @@ import numpy as np import parameterized -from onnxscript.ir import _type_casting +from onnx_ir import _type_casting class TypeCastingTest(unittest.TestCase): diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index 4ca9ca50..b6570f2c 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -17,8 +17,8 @@ import os from typing import Iterator, Sequence -from onnxscript.ir import _core, _enums, _protocols -from onnxscript.ir import traversal as _traversal +from onnx_ir import _core, _enums, _protocols +from onnx_ir import traversal as _traversal from onnxscript.ir._polyfill import zip # Note: If needed in future, add these as parameters to the function calls diff --git a/src/onnx_ir/external_data_test.py b/src/onnx_ir/external_data_test.py index 11de6285..ee803e6d 100644 --- a/src/onnx_ir/external_data_test.py +++ b/src/onnx_ir/external_data_test.py @@ -10,8 +10,8 @@ import onnx import onnx.external_data_helper -from onnxscript import ir -from onnxscript.ir import external_data +import onnx_ir as ir +from onnx_ir import external_data class ExternalDataTest(unittest.TestCase): diff --git a/src/onnx_ir/passes/_pass_infra.py b/src/onnx_ir/passes/_pass_infra.py index 18e5c871..1a0195ad 100644 --- a/src/onnx_ir/passes/_pass_infra.py +++ b/src/onnx_ir/passes/_pass_infra.py @@ -34,7 +34,7 @@ import abc -from onnxscript import ir +import onnx_ir as ir logger = logging.getLogger(__name__) diff --git a/src/onnx_ir/passes/_pass_infra_test.py b/src/onnx_ir/passes/_pass_infra_test.py index 7f916bae..a60beb97 100644 --- a/src/onnx_ir/passes/_pass_infra_test.py +++ b/src/onnx_ir/passes/_pass_infra_test.py @@ -5,7 +5,7 @@ import unittest -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes import _pass_infra diff --git a/src/onnx_ir/passes/common/_c_api_utils.py b/src/onnx_ir/passes/common/_c_api_utils.py index bb2715c7..4f728262 100644 --- a/src/onnx_ir/passes/common/_c_api_utils.py +++ b/src/onnx_ir/passes/common/_c_api_utils.py @@ -7,7 +7,7 @@ import logging from typing import TYPE_CHECKING, Callable, TypeVar -from onnxscript import ir +import onnx_ir as ir if TYPE_CHECKING: import onnx diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring.py index 0c1fa48c..542423e9 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring.py @@ -10,7 +10,7 @@ import logging -from onnxscript import ir +import onnx_ir as ir logger = logging.getLogger(__name__) diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py index 7707a87f..3a1a40a4 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py @@ -6,7 +6,7 @@ import numpy as np -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import clear_metadata_and_docstring diff --git a/src/onnx_ir/passes/common/constant_manipulation.py b/src/onnx_ir/passes/common/constant_manipulation.py index b76c3c08..2308e2e9 100644 --- a/src/onnx_ir/passes/common/constant_manipulation.py +++ b/src/onnx_ir/passes/common/constant_manipulation.py @@ -15,7 +15,7 @@ import numpy as np -from onnxscript import ir +import onnx_ir as ir logger = logging.getLogger(__name__) diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py index d0293313..60899715 100644 --- a/src/onnx_ir/passes/common/constant_manipulation_test.py +++ b/src/onnx_ir/passes/common/constant_manipulation_test.py @@ -7,7 +7,7 @@ import numpy as np import parameterized -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import constant_manipulation diff --git a/src/onnx_ir/passes/common/inliner.py b/src/onnx_ir/passes/common/inliner.py index 3a4f97a8..69bebe31 100644 --- a/src/onnx_ir/passes/common/inliner.py +++ b/src/onnx_ir/passes/common/inliner.py @@ -12,7 +12,7 @@ from typing import Iterable, List, Sequence, Tuple import onnxscript.ir.convenience as _ir_convenience -from onnxscript import ir +import onnx_ir as ir # A replacement for a node specifies a list of nodes that replaces the original node, # and a list of values that replaces the original node's outputs. diff --git a/src/onnx_ir/passes/common/inliner_test.py b/src/onnx_ir/passes/common/inliner_test.py index 1a4be6ce..ed98ba51 100644 --- a/src/onnx_ir/passes/common/inliner_test.py +++ b/src/onnx_ir/passes/common/inliner_test.py @@ -9,7 +9,7 @@ import onnx -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import inliner diff --git a/src/onnx_ir/passes/common/onnx_checker.py b/src/onnx_ir/passes/common/onnx_checker.py index b8156296..668333a4 100644 --- a/src/onnx_ir/passes/common/onnx_checker.py +++ b/src/onnx_ir/passes/common/onnx_checker.py @@ -12,7 +12,7 @@ import onnx -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import _c_api_utils diff --git a/src/onnx_ir/passes/common/onnx_checker_test.py b/src/onnx_ir/passes/common/onnx_checker_test.py index 14422541..5033c1c0 100644 --- a/src/onnx_ir/passes/common/onnx_checker_test.py +++ b/src/onnx_ir/passes/common/onnx_checker_test.py @@ -4,7 +4,7 @@ import unittest -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import onnx_checker diff --git a/src/onnx_ir/passes/common/shape_inference.py b/src/onnx_ir/passes/common/shape_inference.py index 586fa5b4..158fe341 100644 --- a/src/onnx_ir/passes/common/shape_inference.py +++ b/src/onnx_ir/passes/common/shape_inference.py @@ -13,7 +13,7 @@ import onnx -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import _c_api_utils logger = logging.getLogger(__name__) diff --git a/src/onnx_ir/passes/common/shape_inference_test.py b/src/onnx_ir/passes/common/shape_inference_test.py index 5a2f02c6..f887e1dd 100644 --- a/src/onnx_ir/passes/common/shape_inference_test.py +++ b/src/onnx_ir/passes/common/shape_inference_test.py @@ -6,7 +6,7 @@ import numpy as np -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import _c_api_utils, shape_inference diff --git a/src/onnx_ir/passes/common/topological_sort.py b/src/onnx_ir/passes/common/topological_sort.py index 9be183cf..56bee7ff 100644 --- a/src/onnx_ir/passes/common/topological_sort.py +++ b/src/onnx_ir/passes/common/topological_sort.py @@ -9,7 +9,7 @@ ] -from onnxscript import ir +import onnx_ir as ir class TopologicalSortPass(ir.passes.InPlacePass): diff --git a/src/onnx_ir/passes/common/topological_sort_test.py b/src/onnx_ir/passes/common/topological_sort_test.py index 8680761f..e050fad4 100644 --- a/src/onnx_ir/passes/common/topological_sort_test.py +++ b/src/onnx_ir/passes/common/topological_sort_test.py @@ -4,7 +4,7 @@ import unittest -from onnxscript import ir +import onnx_ir as ir from onnxscript.ir.passes.common import topological_sort diff --git a/src/onnx_ir/passes/common/unused_removal.py b/src/onnx_ir/passes/common/unused_removal.py index fe9cc28b..8fb10a6c 100644 --- a/src/onnx_ir/passes/common/unused_removal.py +++ b/src/onnx_ir/passes/common/unused_removal.py @@ -12,7 +12,7 @@ import onnx -from onnxscript import ir +import onnx_ir as ir logger = logging.getLogger(__name__) diff --git a/src/onnx_ir/passes/common/unused_removal_test.py b/src/onnx_ir/passes/common/unused_removal_test.py index 04d55455..3c4fd2d6 100644 --- a/src/onnx_ir/passes/common/unused_removal_test.py +++ b/src/onnx_ir/passes/common/unused_removal_test.py @@ -6,7 +6,7 @@ import parameterized import onnxscript.optimizer -from onnxscript import ir +import onnx_ir as ir @parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) diff --git a/src/onnx_ir/serde.py b/src/onnx_ir/serde.py index b5be445a..62e1a9e8 100644 --- a/src/onnx_ir/serde.py +++ b/src/onnx_ir/serde.py @@ -68,7 +68,7 @@ import onnx import onnx.external_data_helper -from onnxscript.ir import _core, _enums, _protocols, _type_casting +from onnx_ir import _core, _enums, _protocols, _type_casting if typing.TYPE_CHECKING: import google.protobuf.internal.containers as proto_containers diff --git a/src/onnx_ir/serde_test.py b/src/onnx_ir/serde_test.py index 303f0276..9a7f13a3 100644 --- a/src/onnx_ir/serde_test.py +++ b/src/onnx_ir/serde_test.py @@ -8,9 +8,9 @@ import onnx import parameterized -from onnxscript import ir +import onnx_ir as ir from onnxscript._internal import version_utils -from onnxscript.ir import serde +from onnx_ir import serde class ConvenienceFunctionsTest(unittest.TestCase): diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 0a74e0a7..aedc7767 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -9,7 +9,7 @@ Example:: import torch - from onnxscript import ir + import onnx_ir as ir # Create a PyTorch tensor torch_tensor = torch.tensor([1, 2, 3]) @@ -37,8 +37,8 @@ import numpy.typing as npt -from onnxscript import ir -from onnxscript.ir import _core +import onnx_ir as ir +from onnx_ir import _core if TYPE_CHECKING: import torch diff --git a/src/onnx_ir/tensor_adapters_test.py b/src/onnx_ir/tensor_adapters_test.py index 4898cb42..262e2b3e 100644 --- a/src/onnx_ir/tensor_adapters_test.py +++ b/src/onnx_ir/tensor_adapters_test.py @@ -12,7 +12,7 @@ import parameterized import torch -from onnxscript.ir import tensor_adapters +from onnx_ir import tensor_adapters def skip_if_no(module_name: str): diff --git a/src/onnx_ir/traversal.py b/src/onnx_ir/traversal.py index 5fa9a9ac..1fed530e 100644 --- a/src/onnx_ir/traversal.py +++ b/src/onnx_ir/traversal.py @@ -12,7 +12,7 @@ from typing_extensions import Self -from onnxscript.ir import _core, _enums +from onnx_ir import _core, _enums GraphLike = Union[_core.Graph, _core.Function, _core.GraphView] diff --git a/src/onnx_ir/traversal_test.py b/src/onnx_ir/traversal_test.py index 5ed4d314..72f3687b 100644 --- a/src/onnx_ir/traversal_test.py +++ b/src/onnx_ir/traversal_test.py @@ -6,8 +6,8 @@ import parameterized -from onnxscript import ir -from onnxscript.ir import traversal +import onnx_ir as ir +from onnx_ir import traversal class RecursiveGraphIteratorTest(unittest.TestCase): diff --git a/tests/ir/graph_view_test.py b/tests/ir/graph_view_test.py index 83a51cda..542d487a 100644 --- a/tests/ir/graph_view_test.py +++ b/tests/ir/graph_view_test.py @@ -5,7 +5,7 @@ import onnx -from onnxscript import ir +import onnx_ir as ir class GraphViewTest(unittest.TestCase): diff --git a/tests/ir/serde_roundtrip_test.py b/tests/ir/serde_roundtrip_test.py index 69d23d69..8e747205 100644 --- a/tests/ir/serde_roundtrip_test.py +++ b/tests/ir/serde_roundtrip_test.py @@ -11,7 +11,7 @@ import parameterized import onnxscript.testing -from onnxscript import ir +import onnx_ir as ir model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" onnx_backend_test_path = pathlib.Path(onnx.backend.test.__file__).parent / "data" diff --git a/tools/model_zoo_test/model_zoo_test.py b/tools/model_zoo_test/model_zoo_test.py index 82d7a540..f32e891c 100644 --- a/tools/model_zoo_test/model_zoo_test.py +++ b/tools/model_zoo_test/model_zoo_test.py @@ -23,7 +23,7 @@ from onnx import hub import onnxscript.testing -from onnxscript import ir +import onnx_ir as ir def test_model(model_info: hub.ModelInfo) -> float: From 51b608580a1f7bdfb4d51ca409b675500918842f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 08:56:53 -0700 Subject: [PATCH 02/16] License Signed-off-by: Justin Chu --- .github/workflows/lint.yml | 4 ---- .github/workflows/scorecard.yml | 4 ---- REUSE.toml | 1 - src/onnx_ir/__init__.py | 4 ++-- src/onnx_ir/_convenience/__init__.py | 4 ++-- src/onnx_ir/_convenience/_constructors.py | 4 ++-- src/onnx_ir/_convenience/_constructors_test.py | 4 ++-- src/onnx_ir/_core.py | 4 ++-- src/onnx_ir/_core_test.py | 4 ++-- src/onnx_ir/_display.py | 4 ++-- src/onnx_ir/_display_test.py | 4 ++-- src/onnx_ir/_enums.py | 4 ++-- src/onnx_ir/_enums_test.py | 4 ++-- src/onnx_ir/_graph_comparison.py | 4 ++-- src/onnx_ir/_graph_containers.py | 4 ++-- src/onnx_ir/_internal/version_utils.py | 4 ++-- src/onnx_ir/_io.py | 4 ++-- src/onnx_ir/_io_test.py | 4 ++-- src/onnx_ir/_linked_list.py | 4 ++-- src/onnx_ir/_linked_list_test.py | 4 ++-- src/onnx_ir/_metadata.py | 4 ++-- src/onnx_ir/_name_authority.py | 4 ++-- src/onnx_ir/_name_authority_test.py | 4 ++-- src/onnx_ir/_polyfill.py | 4 ++-- src/onnx_ir/_protocols.py | 4 ++-- src/onnx_ir/_schemas.py | 4 ++-- src/onnx_ir/_schemas_test.py | 4 ++-- src/onnx_ir/_tape.py | 4 ++-- src/onnx_ir/_tape_test.py | 4 ++-- src/onnx_ir/_thirdparty/asciichartpy.py | 4 ++-- src/onnx_ir/_type_casting.py | 4 ++-- src/onnx_ir/_type_casting_test.py | 4 ++-- src/onnx_ir/convenience.py | 4 ++-- src/onnx_ir/external_data.py | 4 ++-- src/onnx_ir/external_data_test.py | 4 ++-- src/onnx_ir/passes/__init__.py | 4 ++-- src/onnx_ir/passes/_pass_infra.py | 4 ++-- src/onnx_ir/passes/_pass_infra_test.py | 4 ++-- src/onnx_ir/passes/common/__init__.py | 4 ++-- src/onnx_ir/passes/common/_c_api_utils.py | 4 ++-- src/onnx_ir/passes/common/clear_metadata_and_docstring.py | 4 ++-- .../passes/common/clear_metadata_and_docstring_test.py | 4 ++-- src/onnx_ir/passes/common/constant_manipulation.py | 4 ++-- src/onnx_ir/passes/common/constant_manipulation_test.py | 4 ++-- src/onnx_ir/passes/common/inliner.py | 4 ++-- src/onnx_ir/passes/common/inliner_test.py | 4 ++-- src/onnx_ir/passes/common/onnx_checker.py | 4 ++-- src/onnx_ir/passes/common/onnx_checker_test.py | 4 ++-- src/onnx_ir/passes/common/shape_inference.py | 4 ++-- src/onnx_ir/passes/common/shape_inference_test.py | 4 ++-- src/onnx_ir/passes/common/topological_sort.py | 4 ++-- src/onnx_ir/passes/common/topological_sort_test.py | 4 ++-- src/onnx_ir/passes/common/unused_removal.py | 4 ++-- src/onnx_ir/passes/common/unused_removal_test.py | 4 ++-- src/onnx_ir/serde.py | 4 ++-- src/onnx_ir/serde_test.py | 4 ++-- src/onnx_ir/tape.py | 4 ++-- src/onnx_ir/tensor_adapters.py | 4 ++-- src/onnx_ir/tensor_adapters_test.py | 4 ++-- src/onnx_ir/testing.py | 3 +++ src/onnx_ir/traversal.py | 4 ++-- src/onnx_ir/traversal_test.py | 4 ++-- tests/ir/graph_view_test.py | 4 ++-- tests/ir/public_api_test.py | 4 ++-- tests/ir/serde_roundtrip_test.py | 4 ++-- tools/model_zoo_test/model_zoo_test.py | 4 ++-- 66 files changed, 127 insertions(+), 133 deletions(-) create mode 100644 src/onnx_ir/testing.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index f37555eb..712042bf 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,7 +1,3 @@ -# Copyright (c) ONNX Project Contributors -# -# SPDX-License-Identifier: Apache-2.0 - name: Lint on: diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml index 875aae03..163e31a3 100644 --- a/.github/workflows/scorecard.yml +++ b/.github/workflows/scorecard.yml @@ -1,7 +1,3 @@ -# Copyright (c) ONNX Project Contributors -# -# SPDX-License-Identifier: Apache-2.0 - # This workflow uses actions that are not certified by GitHub. They are provided # by a third-party and are governed by separate terms of service, privacy # policy, and support documentation. diff --git a/REUSE.toml b/REUSE.toml index 21763642..49735f3b 100644 --- a/REUSE.toml +++ b/REUSE.toml @@ -1,5 +1,4 @@ # Copyright (c) ONNX Project Contributors -# # SPDX-License-Identifier: Apache-2.0 version = 1 diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index a73baae5..46f92eec 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """In-memory intermediate representation for ONNX graphs.""" __all__ = [ diff --git a/src/onnx_ir/_convenience/__init__.py b/src/onnx_ir/_convenience/__init__.py index 42ff3e55..053ae39f 100644 --- a/src/onnx_ir/_convenience/__init__.py +++ b/src/onnx_ir/_convenience/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Convenience methods for constructing and manipulating the IR. This is an internal only module. We should choose to expose some of the methods diff --git a/src/onnx_ir/_convenience/_constructors.py b/src/onnx_ir/_convenience/_constructors.py index 7e66fa24..312727d2 100644 --- a/src/onnx_ir/_convenience/_constructors.py +++ b/src/onnx_ir/_convenience/_constructors.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Convenience constructors for IR objects.""" from __future__ import annotations diff --git a/src/onnx_ir/_convenience/_constructors_test.py b/src/onnx_ir/_convenience/_constructors_test.py index 0723a619..f133aa03 100644 --- a/src/onnx_ir/_convenience/_constructors_test.py +++ b/src/onnx_ir/_convenience/_constructors_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Unit tests for the _constructors module.""" import unittest diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index f36ceed6..5d650bcc 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """data structures for the intermediate representation.""" # NOTES for developers: diff --git a/src/onnx_ir/_core_test.py b/src/onnx_ir/_core_test.py index 8fd3830d..ceaa3706 100644 --- a/src/onnx_ir/_core_test.py +++ b/src/onnx_ir/_core_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import copy diff --git a/src/onnx_ir/_display.py b/src/onnx_ir/_display.py index 2fc62114..8bb8c800 100644 --- a/src/onnx_ir/_display.py +++ b/src/onnx_ir/_display.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Internal utilities for displaying the intermediate representation of a model. NOTE: All third-party imports should be scoped and imported only when used to avoid diff --git a/src/onnx_ir/_display_test.py b/src/onnx_ir/_display_test.py index ee745b48..8f831bf0 100644 --- a/src/onnx_ir/_display_test.py +++ b/src/onnx_ir/_display_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Test display() methods in various classes.""" import contextlib diff --git a/src/onnx_ir/_enums.py b/src/onnx_ir/_enums.py index 9ecce9fe..26ecaa47 100644 --- a/src/onnx_ir/_enums.py +++ b/src/onnx_ir/_enums.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """ONNX IR enums that matches the ONNX spec.""" from __future__ import annotations diff --git a/src/onnx_ir/_enums_test.py b/src/onnx_ir/_enums_test.py index 96f32b06..1e0ea85f 100644 --- a/src/onnx_ir/_enums_test.py +++ b/src/onnx_ir/_enums_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 # pylint: disable=protected-access import unittest diff --git a/src/onnx_ir/_graph_comparison.py b/src/onnx_ir/_graph_comparison.py index b0877433..9dec423f 100644 --- a/src/onnx_ir/_graph_comparison.py +++ b/src/onnx_ir/_graph_comparison.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Utilities for comparing IR graphs.""" from __future__ import annotations diff --git a/src/onnx_ir/_graph_containers.py b/src/onnx_ir/_graph_containers.py index a5b6796b..cb95ea01 100644 --- a/src/onnx_ir/_graph_containers.py +++ b/src/onnx_ir/_graph_containers.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Tracked containers for graph.""" # pylint: disable=protected-access diff --git a/src/onnx_ir/_internal/version_utils.py b/src/onnx_ir/_internal/version_utils.py index 390f7ee3..f68e242c 100644 --- a/src/onnx_ir/_internal/version_utils.py +++ b/src/onnx_ir/_internal/version_utils.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Version utils for testing.""" from __future__ import annotations diff --git a/src/onnx_ir/_io.py b/src/onnx_ir/_io.py index f1b9cdc2..74f5323e 100644 --- a/src/onnx_ir/_io.py +++ b/src/onnx_ir/_io.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Load and save ONNX models.""" from __future__ import annotations diff --git a/src/onnx_ir/_io_test.py b/src/onnx_ir/_io_test.py index 52feab80..f9772035 100644 --- a/src/onnx_ir/_io_test.py +++ b/src/onnx_ir/_io_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Unit tests for the _io module.""" import os diff --git a/src/onnx_ir/_linked_list.py b/src/onnx_ir/_linked_list.py index 0db770e2..398d5876 100644 --- a/src/onnx_ir/_linked_list.py +++ b/src/onnx_ir/_linked_list.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Mutable list for nodes in a graph with safe mutation properties.""" from __future__ import annotations diff --git a/src/onnx_ir/_linked_list_test.py b/src/onnx_ir/_linked_list_test.py index bdc51d64..0ff80449 100644 --- a/src/onnx_ir/_linked_list_test.py +++ b/src/onnx_ir/_linked_list_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Unit tests for the _linked_list module.""" from __future__ import annotations diff --git a/src/onnx_ir/_metadata.py b/src/onnx_ir/_metadata.py index 77db7cc4..6e4597bf 100644 --- a/src/onnx_ir/_metadata.py +++ b/src/onnx_ir/_metadata.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Class for storing metadata about the IR objects.""" from __future__ import annotations diff --git a/src/onnx_ir/_name_authority.py b/src/onnx_ir/_name_authority.py index 0f058f67..6713422f 100644 --- a/src/onnx_ir/_name_authority.py +++ b/src/onnx_ir/_name_authority.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Auxiliary class for managing names in the IR.""" from __future__ import annotations diff --git a/src/onnx_ir/_name_authority_test.py b/src/onnx_ir/_name_authority_test.py index c6661639..9cb14fa3 100644 --- a/src/onnx_ir/_name_authority_test.py +++ b/src/onnx_ir/_name_authority_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 import unittest import onnx_ir as ir diff --git a/src/onnx_ir/_polyfill.py b/src/onnx_ir/_polyfill.py index fb6008db..236c6dc8 100644 --- a/src/onnx_ir/_polyfill.py +++ b/src/onnx_ir/_polyfill.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Polyfill for Python builtin functions.""" import sys diff --git a/src/onnx_ir/_protocols.py b/src/onnx_ir/_protocols.py index 688c4226..493e657e 100644 --- a/src/onnx_ir/_protocols.py +++ b/src/onnx_ir/_protocols.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Protocols for the ONNX IR. This file defines the interfaces for tools to interact with the IR. The interfaces diff --git a/src/onnx_ir/_schemas.py b/src/onnx_ir/_schemas.py index 6cfa15c8..6849a640 100644 --- a/src/onnx_ir/_schemas.py +++ b/src/onnx_ir/_schemas.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import collections.abc diff --git a/src/onnx_ir/_schemas_test.py b/src/onnx_ir/_schemas_test.py index 79b9ae1f..184061d2 100644 --- a/src/onnx_ir/_schemas_test.py +++ b/src/onnx_ir/_schemas_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import unittest diff --git a/src/onnx_ir/_tape.py b/src/onnx_ir/_tape.py index 4997e2f1..f4ba2297 100644 --- a/src/onnx_ir/_tape.py +++ b/src/onnx_ir/_tape.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Convenience methods for constructing the IR.""" from __future__ import annotations diff --git a/src/onnx_ir/_tape_test.py b/src/onnx_ir/_tape_test.py index 31e3d2c1..d58bdbc4 100644 --- a/src/onnx_ir/_tape_test.py +++ b/src/onnx_ir/_tape_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import unittest diff --git a/src/onnx_ir/_thirdparty/asciichartpy.py b/src/onnx_ir/_thirdparty/asciichartpy.py index 88c46202..62e9764c 100644 --- a/src/onnx_ir/_thirdparty/asciichartpy.py +++ b/src/onnx_ir/_thirdparty/asciichartpy.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 # # Copyright © 2016 Igor Kroitor # diff --git a/src/onnx_ir/_type_casting.py b/src/onnx_ir/_type_casting.py index 20bab690..a6d460aa 100644 --- a/src/onnx_ir/_type_casting.py +++ b/src/onnx_ir/_type_casting.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Numpy utilities for non-native type operation.""" # TODO(justinchuby): Upstream the logic to onnx diff --git a/src/onnx_ir/_type_casting_test.py b/src/onnx_ir/_type_casting_test.py index f9ce9ca9..31fde139 100644 --- a/src/onnx_ir/_type_casting_test.py +++ b/src/onnx_ir/_type_casting_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 import unittest import numpy as np diff --git a/src/onnx_ir/convenience.py b/src/onnx_ir/convenience.py index 480ff603..d1ffb1c5 100644 --- a/src/onnx_ir/convenience.py +++ b/src/onnx_ir/convenience.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Convenience methods for constructing and manipulating the IR.""" from __future__ import annotations diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index b6570f2c..8525e84f 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """External data related utilities.""" from __future__ import annotations diff --git a/src/onnx_ir/external_data_test.py b/src/onnx_ir/external_data_test.py index ee803e6d..f778b513 100644 --- a/src/onnx_ir/external_data_test.py +++ b/src/onnx_ir/external_data_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 import os import sys import tempfile diff --git a/src/onnx_ir/passes/__init__.py b/src/onnx_ir/passes/__init__.py index 8a18c1b7..9b8516dd 100644 --- a/src/onnx_ir/passes/__init__.py +++ b/src/onnx_ir/passes/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 __all__ = [ "PassBase", diff --git a/src/onnx_ir/passes/_pass_infra.py b/src/onnx_ir/passes/_pass_infra.py index 1a0195ad..cf0fad1b 100644 --- a/src/onnx_ir/passes/_pass_infra.py +++ b/src/onnx_ir/passes/_pass_infra.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 # # This module implements some APIs described in # https://pytorch.org/executorch/stable/compiler-custom-compiler-passes.html diff --git a/src/onnx_ir/passes/_pass_infra_test.py b/src/onnx_ir/passes/_pass_infra_test.py index a60beb97..087b1196 100644 --- a/src/onnx_ir/passes/_pass_infra_test.py +++ b/src/onnx_ir/passes/_pass_infra_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index d1b4f176..07da4c55 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 __all__ = [ "AddInitializersToInputsPass", diff --git a/src/onnx_ir/passes/common/_c_api_utils.py b/src/onnx_ir/passes/common/_c_api_utils.py index 4f728262..ef345d13 100644 --- a/src/onnx_ir/passes/common/_c_api_utils.py +++ b/src/onnx_ir/passes/common/_c_api_utils.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Utilities for interfacing with onnx C APIs.""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring.py index 542423e9..501e46b8 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Clear all metadata and docstring from the model, graphs, nodes, and functions.""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py index 3a1a40a4..3d9283c2 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import unittest diff --git a/src/onnx_ir/passes/common/constant_manipulation.py b/src/onnx_ir/passes/common/constant_manipulation.py index 2308e2e9..008d145f 100644 --- a/src/onnx_ir/passes/common/constant_manipulation.py +++ b/src/onnx_ir/passes/common/constant_manipulation.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Lift constants to initializers.""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py index 60899715..a65a74d2 100644 --- a/src/onnx_ir/passes/common/constant_manipulation_test.py +++ b/src/onnx_ir/passes/common/constant_manipulation_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import unittest diff --git a/src/onnx_ir/passes/common/inliner.py b/src/onnx_ir/passes/common/inliner.py index 69bebe31..0e50b569 100644 --- a/src/onnx_ir/passes/common/inliner.py +++ b/src/onnx_ir/passes/common/inliner.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Implementation of an inliner for onnxscript.ir""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/inliner_test.py b/src/onnx_ir/passes/common/inliner_test.py index ed98ba51..edccf928 100644 --- a/src/onnx_ir/passes/common/inliner_test.py +++ b/src/onnx_ir/passes/common/inliner_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Tests for the inliner pass.""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/onnx_checker.py b/src/onnx_ir/passes/common/onnx_checker.py index 668333a4..981dfdb6 100644 --- a/src/onnx_ir/passes/common/onnx_checker.py +++ b/src/onnx_ir/passes/common/onnx_checker.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Passes for debugging purposes.""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/onnx_checker_test.py b/src/onnx_ir/passes/common/onnx_checker_test.py index 5033c1c0..bdf3f5e7 100644 --- a/src/onnx_ir/passes/common/onnx_checker_test.py +++ b/src/onnx_ir/passes/common/onnx_checker_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import unittest diff --git a/src/onnx_ir/passes/common/shape_inference.py b/src/onnx_ir/passes/common/shape_inference.py index 158fe341..b4f0c2bc 100644 --- a/src/onnx_ir/passes/common/shape_inference.py +++ b/src/onnx_ir/passes/common/shape_inference.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Shape inference pass using onnx.shape_inference.""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/shape_inference_test.py b/src/onnx_ir/passes/common/shape_inference_test.py index f887e1dd..55a7b034 100644 --- a/src/onnx_ir/passes/common/shape_inference_test.py +++ b/src/onnx_ir/passes/common/shape_inference_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import unittest diff --git a/src/onnx_ir/passes/common/topological_sort.py b/src/onnx_ir/passes/common/topological_sort.py index 56bee7ff..0cedca0c 100644 --- a/src/onnx_ir/passes/common/topological_sort.py +++ b/src/onnx_ir/passes/common/topological_sort.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Pass for topologically sorting the graphs.""" from __future__ import annotations diff --git a/src/onnx_ir/passes/common/topological_sort_test.py b/src/onnx_ir/passes/common/topological_sort_test.py index e050fad4..6a287d26 100644 --- a/src/onnx_ir/passes/common/topological_sort_test.py +++ b/src/onnx_ir/passes/common/topological_sort_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Unit tests for the TopologicalSortPass.""" import unittest diff --git a/src/onnx_ir/passes/common/unused_removal.py b/src/onnx_ir/passes/common/unused_removal.py index 8fb10a6c..456e1bea 100644 --- a/src/onnx_ir/passes/common/unused_removal.py +++ b/src/onnx_ir/passes/common/unused_removal.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations __all__ = [ diff --git a/src/onnx_ir/passes/common/unused_removal_test.py b/src/onnx_ir/passes/common/unused_removal_test.py index 3c4fd2d6..d1dd06a3 100644 --- a/src/onnx_ir/passes/common/unused_removal_test.py +++ b/src/onnx_ir/passes/common/unused_removal_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 import unittest import onnx diff --git a/src/onnx_ir/serde.py b/src/onnx_ir/serde.py index 62e1a9e8..bb0f1af9 100644 --- a/src/onnx_ir/serde.py +++ b/src/onnx_ir/serde.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Serialize and deserialize the intermediate representation to/from ONNX protos.""" # NOTES for developers: diff --git a/src/onnx_ir/serde_test.py b/src/onnx_ir/serde_test.py index 9a7f13a3..7484bae8 100644 --- a/src/onnx_ir/serde_test.py +++ b/src/onnx_ir/serde_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 import unittest import google.protobuf.text_format diff --git a/src/onnx_ir/tape.py b/src/onnx_ir/tape.py index 9270dcdc..96cc98f9 100644 --- a/src/onnx_ir/tape.py +++ b/src/onnx_ir/tape.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Taping module to facilitate building IR graphs.""" # NOTE: Be *selective* about what this module exports because it is part of the public API. diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index aedc7767..fcee58f7 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Compatible adapters implementing the TensorProtocol interface for various framework tensor types. This module provides public classes that implement the :class:`onnxscript.ir.TensorProtocol` diff --git a/src/onnx_ir/tensor_adapters_test.py b/src/onnx_ir/tensor_adapters_test.py index 262e2b3e..970e106c 100644 --- a/src/onnx_ir/tensor_adapters_test.py +++ b/src/onnx_ir/tensor_adapters_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Unit tests for the tensor_adapters module.""" from __future__ import annotations diff --git a/src/onnx_ir/testing.py b/src/onnx_ir/testing.py new file mode 100644 index 00000000..10e60984 --- /dev/null +++ b/src/onnx_ir/testing.py @@ -0,0 +1,3 @@ +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 +"""Utilities for testing.""" diff --git a/src/onnx_ir/traversal.py b/src/onnx_ir/traversal.py index 1fed530e..80d72cad 100644 --- a/src/onnx_ir/traversal.py +++ b/src/onnx_ir/traversal.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Utilities for traversing the IR graph.""" from __future__ import annotations diff --git a/src/onnx_ir/traversal_test.py b/src/onnx_ir/traversal_test.py index 72f3687b..03cc7496 100644 --- a/src/onnx_ir/traversal_test.py +++ b/src/onnx_ir/traversal_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import unittest diff --git a/tests/ir/graph_view_test.py b/tests/ir/graph_view_test.py index 542d487a..9812666d 100644 --- a/tests/ir/graph_view_test.py +++ b/tests/ir/graph_view_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 import pathlib import unittest diff --git a/tests/ir/public_api_test.py b/tests/ir/public_api_test.py index ac2655cf..65e3381d 100644 --- a/tests/ir/public_api_test.py +++ b/tests/ir/public_api_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 # Adapted from # https://github.com/pytorch/pytorch/blob/b505e8647547f029d0f7df408ee5f2968f757f89/test/test_public_bindings.py#L523 # Original code PyTorch license https://github.com/pytorch/pytorch/blob/main/LICENSE diff --git a/tests/ir/serde_roundtrip_test.py b/tests/ir/serde_roundtrip_test.py index 8e747205..c29c8ed3 100644 --- a/tests/ir/serde_roundtrip_test.py +++ b/tests/ir/serde_roundtrip_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 # pylint: disable=import-outside-toplevel from __future__ import annotations diff --git a/tools/model_zoo_test/model_zoo_test.py b/tools/model_zoo_test/model_zoo_test.py index f32e891c..8db93352 100644 --- a/tools/model_zoo_test/model_zoo_test.py +++ b/tools/model_zoo_test/model_zoo_test.py @@ -1,5 +1,5 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. +# Copyright (c) ONNX Project Contributors +# SPDX-License-Identifier: Apache-2.0 """Test IR roundtrip with ONNX model zoo. Usage: From 79400ec0e674a8b136a95ce297b1683b966ee994 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 09:31:11 -0700 Subject: [PATCH 03/16] More fixes Signed-off-by: Justin Chu --- .github/workflows/lint.yml | 2 +- src/onnx_ir/__init__.py | 12 +- .../_convenience/_constructors_test.py | 2 +- src/onnx_ir/_core.py | 20 +- src/onnx_ir/_display_test.py | 2 +- src/onnx_ir/_graph_containers.py | 6 +- src/onnx_ir/_io.py | 2 +- src/onnx_ir/_schemas.py | 548 ------------------ src/onnx_ir/_schemas_test.py | 176 ------ .../version_utils.py => _version_utils.py} | 27 - src/onnx_ir/convenience.py | 2 +- src/onnx_ir/external_data.py | 2 +- src/onnx_ir/passes/__init__.py | 2 +- src/onnx_ir/passes/_pass_infra_test.py | 2 +- src/onnx_ir/passes/common/__init__.py | 14 +- .../clear_metadata_and_docstring_test.py | 2 +- .../common/constant_manipulation_test.py | 2 +- src/onnx_ir/passes/common/inliner.py | 4 +- src/onnx_ir/passes/common/inliner_test.py | 2 +- src/onnx_ir/passes/common/onnx_checker.py | 2 +- .../passes/common/onnx_checker_test.py | 2 +- src/onnx_ir/passes/common/shape_inference.py | 2 +- .../passes/common/shape_inference_test.py | 2 +- .../passes/common/topological_sort_test.py | 2 +- .../passes/common/unused_removal_test.py | 14 +- src/onnx_ir/serde.py | 4 +- src/onnx_ir/serde_test.py | 6 +- src/onnx_ir/tape.py | 2 +- src/onnx_ir/tensor_adapters.py | 2 +- tests/ir/public_api_test.py | 10 +- tests/ir/serde_roundtrip_test.py | 4 +- tools/model_zoo_test/model_zoo_test.py | 4 +- 32 files changed, 63 insertions(+), 822 deletions(-) delete mode 100644 src/onnx_ir/_schemas.py delete mode 100644 src/onnx_ir/_schemas_test.py rename src/onnx_ir/{_internal/version_utils.py => _version_utils.py} (81%) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 712042bf..d8c342cf 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -63,7 +63,7 @@ jobs: if ! lintrunner --force-color --all-files --tee-json=lint.json -v; then echo "" echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m" - echo -e "\e[1m\e[36mSee https://github.com/microsoft/onnxscript#coding-style for setup instructions.\e[0m" + echo -e "\e[1m\e[36mSee https://github.com/onnx/onnx_ir/blob/main/CONTRIBUTING.md for setup instructions.\e[0m" exit 1 fi - name: Produce SARIF diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index 46f92eec..3adb36b9 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -84,8 +84,8 @@ ] from onnx_ir import convenience, external_data, passes, serde, tape, traversal -from onnxscript.ir._convenience._constructors import node, tensor -from onnxscript.ir._core import ( +from onnx_ir._convenience._constructors import node, tensor +from onnx_ir._core import ( Attr, AttrFloat32, AttrFloat32s, @@ -121,12 +121,12 @@ TypeAndShape, Value, ) -from onnxscript.ir._enums import ( +from onnx_ir._enums import ( AttributeType, DataType, ) -from onnxscript.ir._io import load, save -from onnxscript.ir._protocols import ( +from onnx_ir._io import load, save +from onnx_ir._protocols import ( ArrayCompatible, AttributeProtocol, DLPackCompatible, @@ -145,7 +145,7 @@ TypeProtocol, ValueProtocol, ) -from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto +from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto def __set_module() -> None: diff --git a/src/onnx_ir/_convenience/_constructors_test.py b/src/onnx_ir/_convenience/_constructors_test.py index f133aa03..ab5e00e1 100644 --- a/src/onnx_ir/_convenience/_constructors_test.py +++ b/src/onnx_ir/_convenience/_constructors_test.py @@ -7,7 +7,7 @@ import numpy as np import onnx_ir as ir -from onnxscript.ir._convenience import _constructors +from onnx_ir._convenience import _constructors class ConstructorsTest(unittest.TestCase): diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index 5d650bcc..f794b904 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -44,7 +44,7 @@ import numpy as np from typing_extensions import TypeIs -import onnxscript +import onnx_ir from onnx_ir import ( _display, _enums, @@ -186,7 +186,7 @@ def display(self, *, page: bool = False) -> None: status_manager = rich.status.Status(f"Computing tensor stats for {self!r}") - from onnxscript._thirdparty import ( # pylint: disable=import-outside-toplevel + from onnx_ir._thirdparty import ( # pylint: disable=import-outside-toplevel asciichartpy, ) @@ -582,7 +582,7 @@ def __init__( # NOTE: Do not verify the location by default. This is because the location field # in the tensor proto can be anything and we would like deserialization from # proto to IR to not fail. - if onnxscript.DEBUG: + if onnx_ir.DEBUG: if os.path.isabs(location): raise ValueError( "The location must be a relative path. Please specify base_dir as well." @@ -2052,7 +2052,7 @@ def const_value( self, value: _protocols.TensorProtocol | None, ) -> None: - if onnxscript.DEBUG: + if onnx_ir.DEBUG: if value is not None and not isinstance(value, _protocols.TensorProtocol): raise TypeError( f"Expected value to be a TensorProtocol or None, got '{type(value)}'" @@ -2469,7 +2469,7 @@ def sort(self) -> None: ValueError: If the graph contains a cycle, making topological sorting impossible. """ # Obtain all nodes from the graph and its subgraphs for sorting - nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self)) + nodes = list(onnx_ir.traversal.RecursiveGraphIterator(self)) # Store the sorted nodes of each subgraph sorted_nodes_by_graph: dict[Graph, list[Node]] = { graph: [] for graph in {node.graph for node in nodes if node.graph is not None} @@ -2858,7 +2858,7 @@ def graphs(self) -> Iterable[Graph]: """Get all graphs and subgraphs in the model. This is a convenience method to traverse the model. Consider using - `onnxscript.ir.traversal.RecursiveGraphIterator` for more advanced + `onnx_ir.traversal.RecursiveGraphIterator` for more advanced traversals on nodes. """ # NOTE(justinchuby): Given @@ -2868,7 +2868,7 @@ def graphs(self) -> Iterable[Graph]: # I created this method as a core method instead of an iterator in # `traversal.py`. seen_graphs: set[Graph] = set() - for node in onnxscript.ir.traversal.RecursiveGraphIterator(self.graph): + for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph): if node.graph is not None and node.graph not in seen_graphs: seen_graphs.add(node.graph) yield node.graph @@ -3226,7 +3226,7 @@ def as_strings(self) -> Sequence[str]: """Get the attribute value as a sequence of strings.""" if not isinstance(self.value, Sequence): raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - if onnxscript.DEBUG: + if onnx_ir.DEBUG: if not all(isinstance(x, str) for x in self.value): raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.") # Create a copy of the list to prevent mutation @@ -3236,7 +3236,7 @@ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]: """Get the attribute value as a sequence of tensors.""" if not isinstance(self.value, Sequence): raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - if onnxscript.DEBUG: + if onnx_ir.DEBUG: if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value): raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.") # Create a copy of the list to prevent mutation @@ -3246,7 +3246,7 @@ def as_graphs(self) -> Sequence[Graph]: """Get the attribute value as a sequence of graphs.""" if not isinstance(self.value, Sequence): raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.") - if onnxscript.DEBUG: + if onnx_ir.DEBUG: if not all(isinstance(x, Graph) for x in self.value): raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.") # Create a copy of the list to prevent mutation diff --git a/src/onnx_ir/_display_test.py b/src/onnx_ir/_display_test.py index 8f831bf0..c77929f0 100644 --- a/src/onnx_ir/_display_test.py +++ b/src/onnx_ir/_display_test.py @@ -7,7 +7,7 @@ import numpy as np -import onnxscript.ir as ir +import onnx_ir as ir class DisplayTest(unittest.TestCase): diff --git a/src/onnx_ir/_graph_containers.py b/src/onnx_ir/_graph_containers.py index cb95ea01..59168e51 100644 --- a/src/onnx_ir/_graph_containers.py +++ b/src/onnx_ir/_graph_containers.py @@ -14,8 +14,6 @@ import collections from typing import TYPE_CHECKING, Iterable, SupportsIndex -import onnxscript - if TYPE_CHECKING: from onnx_ir import _core @@ -132,7 +130,7 @@ class GraphInputs(_GraphIO): def _check_invariance(self) -> None: """Check the invariance of the graph.""" - if not onnxscript.DEBUG: + if not onnx_ir.DEBUG: return for value in self.data: if value._graph is self._graph: @@ -170,7 +168,7 @@ class GraphOutputs(_GraphIO): def _check_invariance(self) -> None: """Check the invariance of the graph.""" - if not onnxscript.DEBUG: + if not onnx_ir.DEBUG: return for value in self.data: if value._graph is self._graph: diff --git a/src/onnx_ir/_io.py b/src/onnx_ir/_io.py index 74f5323e..044ba368 100644 --- a/src/onnx_ir/_io.py +++ b/src/onnx_ir/_io.py @@ -12,7 +12,7 @@ from onnx_ir import _core, serde from onnx_ir import external_data as _external_data -from onnxscript.ir._polyfill import zip +from onnx_ir._polyfill import zip def load(path: str | os.PathLike, format: str | None = None) -> _core.Model: diff --git a/src/onnx_ir/_schemas.py b/src/onnx_ir/_schemas.py deleted file mode 100644 index 6849a640..00000000 --- a/src/onnx_ir/_schemas.py +++ /dev/null @@ -1,548 +0,0 @@ -# Copyright (c) ONNX Project Contributors -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -import collections.abc -import dataclasses -import inspect -import logging -import types -import typing -from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union - -import onnx - -import onnxscript -import onnx_ir as ir - -logger = logging.getLogger(__name__) - - -# A special value to indicate that the default value is not specified -class _Empty: - def __repr__(self): - return "_EMPTY_DEFAULT" - - -_EMPTY_DEFAULT = _Empty() - -# Map from python type to corresponding ONNX AttributeProto type -_PY_TYPE_TO_ATTR_TYPE = { - float: ir.AttributeType.FLOAT, - int: ir.AttributeType.INT, - str: ir.AttributeType.STRING, - bool: ir.AttributeType.INT, - ir.Tensor: ir.AttributeType.TENSOR, - ir.TensorProtocol: ir.AttributeType.TENSOR, - ir.Graph: ir.AttributeType.GRAPH, - ir.GraphProtocol: ir.AttributeType.GRAPH, -} - -# Map from python type to corresponding ONNX AttributeProto type, -# for repeated (i.e., list of) values -_LIST_TYPE_TO_ATTR_TYPE = { - float: ir.AttributeType.FLOATS, - int: ir.AttributeType.INTS, - str: ir.AttributeType.STRINGS, - bool: ir.AttributeType.INTS, - ir.Tensor: ir.AttributeType.TENSORS, - ir.TensorProtocol: ir.AttributeType.TENSORS, - ir.Graph: ir.AttributeType.GRAPHS, - ir.GraphProtocol: ir.AttributeType.GRAPHS, -} - -_ALL_VALUE_TYPES = ( - {ir.TensorType(dtype) for dtype in ir.DataType} - | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType} - | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType} -) - -# TypeAnnotationValue represents the (value of) valid type-annotations recognized -# by ONNX Script. Currently, it supports -# - float, int, str (primitive attribute types) -# - Sequence[float], Sequence[int], Sequence[str] (attribute types) -# - Tensor types -# - Sequence[Tensor] types -# - Union of above 2 -# - TypeVars with above bounds -# - Above types with annotation attached -TypeAnnotationValue = Any - - -@dataclasses.dataclass(frozen=True) -class TypeConstraintParam: - """Type constraint for a parameter. - - Attributes: - name: Name of the parameter. E.g. "TFloat" - allowed_types: Allowed types for the parameter. - """ - - name: str - allowed_types: set[ir.TypeProtocol] - description: str = "" - - def __hash__(self) -> int: - return hash((self.name, tuple(self.allowed_types))) - - def __str__(self) -> str: - allowed_types_str = " | ".join(str(t) for t in self.allowed_types) - return f"{self.name}={allowed_types_str}" - - @classmethod - def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam: - return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description) - - @classmethod - def any_value(cls, name: str, description: str = "") -> TypeConstraintParam: - return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type] - - -@dataclasses.dataclass(frozen=True) -class Parameter: - """A formal parameter of an operator.""" - - name: str - type_constraint: TypeConstraintParam - required: bool - variadic: bool - default: Any = _EMPTY_DEFAULT - # TODO: Add other properties too - - def __str__(self) -> str: - type_str = self.type_constraint.name - if self.has_default(): - return f"{self.name}: {type_str} = {self.default}" - return f"{self.name}: {type_str}" - - def has_default(self) -> bool: - return self.default is not _EMPTY_DEFAULT - - -@dataclasses.dataclass(frozen=True) -class AttributeParameter: - """A parameter in the function signature that represents an ONNX attribute.""" - - name: str - type: ir.AttributeType - required: bool - default: ir.Attr | None = None - - def __str__(self) -> str: - type_str = self.type.name - if self.has_default(): - return f"{self.name}: {type_str} = {self.default}" - return f"{self.name}: {type_str}" - - def has_default(self) -> bool: - return self.default is not None - - -def _get_type_from_str( - type_str: str, -) -> ir.TensorType | ir.SequenceType | ir.OptionalType: - """Converter a type_str from ONNX OpSchema to ir.TypeProtocol. - - A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))". - """ - # Split the type_str a sequence types and dtypes - # 1. Remove the ending ")" - striped = type_str.rstrip(")") - # 2. Split the type_str by "(" - type_parts = striped.split("(") - - # Convert the dtype to ir.DataType - dtype = ir.DataType[type_parts[-1].upper()] - - # Create a place holder type first - type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED) - - # Construct the type - for type_part in reversed(type_parts[:-1]): - if type_part == "tensor": - type_ = ir.TensorType(dtype) - elif type_part == "seq": - type_ = ir.SequenceType(type_) - elif type_part == "optional": - type_ = ir.OptionalType(type_) - else: - raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'") - return type_ # type: ignore[return-value] - - -def _convert_formal_parameter( - param: onnx.defs.OpSchema.FormalParameter, - type_constraints: Mapping[str, TypeConstraintParam], -) -> Parameter: - """Convert a formal parameter from ONNX OpSchema to Parameter.""" - if param.type_str in type_constraints: - type_constraint = type_constraints[param.type_str] - else: - # param.type_str can be a plain type like 'int64'. - type_constraint = TypeConstraintParam( - name=param.name, - allowed_types={_get_type_from_str(param.type_str)}, - ) - return Parameter( - name=param.name, - type_constraint=type_constraint, - required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, - variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, - ) - - -def _is_optional(type_: type) -> bool: - """Returns whether a type_ is an Optional.""" - origin_type = typing.get_origin(type_) - if origin_type is Union and type(None) in typing.get_args(type_): - # Python < 3.10 - return True - if origin_type is Optional: - # Python >= 3.10 - return True - if ( - hasattr(types, "UnionType") - and origin_type is types.UnionType - and type(None) in typing.get_args(type_) - ): - # Python >= 3.10 - return True - return False - - -def _get_attr_type(type_: type) -> ir.AttributeType: - """Obtain the type of the attribute from a Python class.""" - try: - if type_ in _PY_TYPE_TO_ATTR_TYPE: - return _PY_TYPE_TO_ATTR_TYPE[type_] - origin_type = typing.get_origin(type_) - if origin_type is None: - return ir.AttributeType.UNDEFINED - if origin_type in ( - collections.abc.Sequence, - Sequence, - typing.List, - list, - typing.Tuple, - tuple, - ): - inner_type = typing.get_args(type_)[0] - if inner_type in _LIST_TYPE_TO_ATTR_TYPE: - return _LIST_TYPE_TO_ATTR_TYPE[inner_type] - except TypeError: - logger.warning("TypeError when checking %s.", type_, exc_info=True) - return ir.AttributeType.UNDEFINED - - -def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None: - """Returns the name of the type constraint for a given type annotation. - - Args: - type_: A Python type. - - Returns: - The name of the type constraint if it is a TypeVar. - - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. - """ - if isinstance(type_, TypeVar): - return type_.__name__ - if _is_optional(type_): - subtypes = typing.get_args(type_) - for subtype in subtypes: - if subtype is type(None): - continue - type_param_name = _get_type_constraint_name(subtype) - return type_param_name if type_param_name else None - origin_type = typing.get_origin(type_) - if isinstance(origin_type, type) and issubclass(origin_type, Sequence): - subtypes = typing.get_args(type_) - type_param_name = _get_type_constraint_name(subtypes[0]) - return f"Sequence_{type_param_name}" if type_param_name else None - return None - - -def _get_allowed_types_from_type_annotation( - type_: TypeAnnotationValue, -) -> set[ir.TypeProtocol]: - """Obtain the allowed types from a type annotation.""" - if type_ is onnxscript.onnx_types.TensorType: - # Any tensor type - return {ir.TensorType(dtype) for dtype in ir.DataType} - - allowed_types: set[ir.TypeProtocol] - - if isinstance(type_, TypeVar): - allowed_types = set() - if constraints := type_.__constraints__: - for constraint in constraints: - allowed_types.update(_get_allowed_types_from_type_annotation(constraint)) - else: - bound = type_.__bound__ - if bound is None: - allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment] - else: - allowed_types.update(_get_allowed_types_from_type_annotation(bound)) - return allowed_types - if hasattr(type_, "dtype"): - # A single tensor type like INT64, FLOAT, etc. - return {ir.TensorType(ir.DataType(type_.dtype))} - if _is_optional(type_): - allowed_types = set() - subtypes = typing.get_args(type_) - for subtype in subtypes: - if subtype is type(None): - continue - allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) - # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful. - return allowed_types - - origin_type = typing.get_origin(type_) - if origin_type is Union: - allowed_types = set() - subtypes = typing.get_args(type_) - for subtype in subtypes: - assert subtype is not type(None), ( - "Union should not contain None type because it is handled by _is_optional." - ) - allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) - return allowed_types - - if isinstance(origin_type, type) and issubclass(origin_type, Sequence): - subtypes = typing.get_args(type_) - return { - ir.SequenceType(t) for t in _get_allowed_types_from_type_annotation(subtypes[0]) - } - - # Allow everything by default - return _ALL_VALUE_TYPES # type: ignore[return-value] - - -@dataclasses.dataclass -class OpSignature: - """Schema for an operator. - - Attributes: - domain: Domain of the operator. E.g. "". - name: Name of the operator. E.g. "Add". - overload: Overload name of the operator. - params: Input parameters. When the op is an ONNX function definition, - the order is according to the function signature. This mean we can - interleave ONNX inputs and ONNX attributes in the list. - outputs: Output parameters. - """ - - domain: str - name: str - overload: str - params: Sequence[Parameter | AttributeParameter] - outputs: Sequence[Parameter] - params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( - init=False, repr=False - ) - - def __post_init__(self): - self.params_map = {param.name: param for param in self.params} - - def get(self, name: str) -> Parameter | AttributeParameter: - return self.params_map[name] - - def __contains__(self, name: str) -> bool: - return name in self.params_map - - def __iter__(self) -> Iterator[Parameter | AttributeParameter]: - return iter(self.params) - - def __str__(self) -> str: - domain = self.domain or "''" - # TODO: Double check the separator for overload - overload = f"::{self.overload}" if self.overload else "" - params = ", ".join(str(param) for param in self.params) - outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs) - type_constraints = {} - for param in self.params: - if isinstance(param, Parameter): - type_constraints[param.type_constraint.name] = param.type_constraint - for param in self.outputs: - type_constraints[param.type_constraint.name] = param.type_constraint - type_constraints_str = ", ".join( - str(type_constraint) for type_constraint in type_constraints.values() - ) - return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" - - @classmethod - def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: - """Produce an OpSignature from an ONNX OpSchema.""" - type_constraints = { - constraint.type_param_str: TypeConstraintParam( - name=constraint.type_param_str, - allowed_types={ - _get_type_from_str(type_str) for type_str in constraint.allowed_type_strs - }, - description=constraint.description, - ) - for constraint in op_schema.type_constraints - } - - params = [ - _convert_formal_parameter(param, type_constraints) for param in op_schema.inputs - ] - - for param in op_schema.attributes.values(): - default_attr = ( - ir.serde.deserialize_attribute(param.default_value) - if param.default_value is not None - else None - ) - if default_attr is not None: - # Set the name of the default attribute because it may have a different name from the parameter - default_attr.name = param.name - params.append( - AttributeParameter( - name=param.name, - type=ir.AttributeType(param.type), # type: ignore[arg-type] - required=param.required, - default=default_attr, # type: ignore[arg-type] - ) - ) - - outputs = [ - _convert_formal_parameter(param, type_constraints) for param in op_schema.outputs - ] - - return cls( - domain=op_schema.domain, - name=op_schema.name, - overload="", - params=params, - outputs=outputs, - ) - - @classmethod - def from_function( - cls, func, domain: str, name: str | None = None, overload: str = "" - ) -> OpSignature: - """Produce an OpSignature from a function using type annotation.""" - - py_signature = inspect.signature(func) - # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases - # https://github.com/python/cpython/issues/102405 - type_hints = typing.get_type_hints(func) - - params: list[Parameter | AttributeParameter] = [] - # Create a mapping from type to a unique name - type_constraints: dict[str, TypeConstraintParam] = {} - - for param in py_signature.parameters.values(): - if param.name not in type_hints: - logger.warning( - "Missing annotation for parameter '%s' from %s. Treating as an Input.", - param.name, - py_signature, - ) - type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") - type_constraints[param.name] = type_constraint - params.append( - Parameter( - name=param.name, - type_constraint=type_constraint, - required=param.default is inspect.Parameter.empty, - # TODO: Handle variadic - variadic=False, - default=param.default - if param.default is not inspect.Parameter.empty - else _EMPTY_DEFAULT, - ) - ) - else: - type_ = type_hints[param.name] - if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: - # Construct the default attribute - if param.default is not inspect.Parameter.empty: - # TODO: Use ir_convenience instead to handle int as float - default = ir.Attr(param.name, attr_type, param.default) - else: - default = None - params.append( - AttributeParameter( - name=param.name, - type=attr_type, - required=param.default is inspect.Parameter.empty, - default=default, - ) - ) - else: - # Obtain the type constraint from the type annotation - - # 1. Get a type constraint name from the type annotation - # If the type annotation is a TypeVar or Optional[TypeVar], get its name - # Otherwise, name it T_{param.name} - type_constraint_name = _get_type_constraint_name(type_) - if type_constraint_name is None: - type_constraint_name = f"T_{param.name}" - - # 2. If the type constraint param is already initialized, use it - if type_constraint_name in type_constraints: - type_constraint = type_constraints[type_constraint_name] - else: - # 3. Otherwise, create a new TypeConstraintParam - type_constraint = TypeConstraintParam( - name=type_constraint_name, - allowed_types=_get_allowed_types_from_type_annotation(type_), - ) - type_constraints[type_constraint_name] = type_constraint - # 4. Create Parameter - params.append( - Parameter( - name=param.name, - type_constraint=type_constraint, - required=param.default is inspect.Parameter.empty, - # TODO: Handle variadic - variadic=False, - default=param.default - if param.default is not inspect.Parameter.empty - else _EMPTY_DEFAULT, - ) - ) - - return_type = type_hints.get("return") - - outputs = [] - if return_type is None: - # No returns - pass - else: - if typing.get_origin(return_type) is tuple: - # Multiple returns - return_types = typing.get_args(return_type) - else: - return_types = [return_type] # type: ignore[assignment] - - for i, return_type_i in enumerate(return_types): - if ( - return_param_name := _get_type_constraint_name(return_type_i) - ) in type_constraints: - type_constraint = type_constraints[return_param_name] - else: - return_param_name = f"TReturn{i}" - type_constraint = TypeConstraintParam( - name=return_param_name, - allowed_types=_get_allowed_types_from_type_annotation(return_type_i), - ) - type_constraints[return_param_name] = type_constraint - outputs.append( - Parameter( - name=return_param_name, - type_constraint=type_constraint, - required=True, - variadic=False, - default=_EMPTY_DEFAULT, - ) - ) - - return cls( - domain=domain, - name=name or func.__name__, - overload=overload, - params=params, - outputs=outputs, - ) diff --git a/src/onnx_ir/_schemas_test.py b/src/onnx_ir/_schemas_test.py deleted file mode 100644 index 184061d2..00000000 --- a/src/onnx_ir/_schemas_test.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) ONNX Project Contributors -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -import unittest -from typing import Any, Optional, Sequence, TypeVar, Union - -import parameterized - -import onnxscript -import onnxscript.testing -from onnxscript import FLOAT, INT64, ir -from onnx_ir import _schemas - -_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) -_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) -_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) - - -class TypeConversionFunctionsTest(unittest.TestCase): - @parameterized.parameterized.expand( - [ - ( - "tensor_type_all", - onnxscript.onnx_types.TensorType, - {ir.TensorType(dtype) for dtype in ir.DataType}, - ), - ("tensor_type", INT64, {ir.TensorType(ir.DataType.INT64)}), - ( - "tensor_type_union", - Union[INT64, FLOAT], - {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, - ), - ( - "tensor_type_variadic_shape", - INT64[...], - {ir.TensorType(ir.DataType.INT64)}, - ), - ("tensor_type_shape", INT64[10], {ir.TensorType(ir.DataType.INT64)}), - ( - "type_var_constraints", - _TestTypeVarConstraints, - {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, - ), - ( - "type_bound_one", - _TestTypeVarOneBound, - {ir.TensorType(ir.DataType.INT64)}, - ), - ( - "type_bound_two", - _TestTypeVarTwoBound, - {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, - ), - ( - "optional_tensor_type_all", - Optional[onnxscript.onnx_types.TensorType], - {ir.TensorType(dtype) for dtype in ir.DataType}, - ), - ( - "optional_tensor_type", - Optional[INT64], - {ir.TensorType(ir.DataType.INT64)}, - ), - ( - "optional_tensor_type_union", - Optional[Union[INT64, FLOAT]], - {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, - ), - ( - "optional_tensor_type_variadic_shape", - Optional[INT64[...]], - {ir.TensorType(ir.DataType.INT64)}, - ), - ( - "optional_tensor_type_shape", - Optional[INT64[10]], - {ir.TensorType(ir.DataType.INT64)}, - ), - ( - "optional_type_var_constraints", - Optional[_TestTypeVarConstraints], - {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, - ), - ( - "optional_type_bound_one", - Optional[_TestTypeVarOneBound], - {ir.TensorType(ir.DataType.INT64)}, - ), - ( - "optional_type_bound_two", - Optional[_TestTypeVarTwoBound], - {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, - ), - ( - "sequence_type_all", - Sequence[onnxscript.onnx_types.TensorType], - {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}, - ), - ( - "sequence_type", - Sequence[INT64], - {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, - ), - ( - "union_sequence_type", - Union[Sequence[INT64], Sequence[FLOAT]], - { - ir.SequenceType(ir.TensorType(ir.DataType.INT64)), - ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), - }, - ), - ( - "sequence_type_variadic_shape", - Sequence[INT64[...]], - {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, - ), - ( - "sequence_type_shape", - Sequence[INT64[10]], - {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, - ), - ( - "sequence_type_var_constraints", - Sequence[_TestTypeVarConstraints], - { - ir.SequenceType(ir.TensorType(ir.DataType.INT64)), - ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), - }, - ), - ( - "sequence_type_bound_one", - Sequence[_TestTypeVarOneBound], - {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, - ), - ( - "sequence_type_bound_two", - Sequence[_TestTypeVarTwoBound], - { - ir.SequenceType(ir.TensorType(ir.DataType.INT64)), - ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), - }, - ), - ] - ) - def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]): - self.assertEqual(_schemas._get_allowed_types_from_type_annotation(pytype), expected) # pylint: disable=protected-access - - @parameterized.parameterized.expand( - [ - ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), - ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), - ( - "optional_type_var", - Optional[_TestTypeVarOneBound], - "_TestTypeVarOneBound", - ), - ( - "sequence_type_var", - Sequence[_TestTypeVarOneBound], - "Sequence__TestTypeVarOneBound", - ), - ("normal_type", INT64, None), - ("union_type", Union[INT64, FLOAT], None), - ("optional_type", Optional[INT64], None), - ("sequence_type", Sequence[INT64], None), - ("optional_sequence_type", Optional[Sequence[INT64]], None), - ("optional_union_type", Optional[Union[INT64, FLOAT]], None), - ] - ) - def test_get_type_constraint_name(self, _: str, pytype: Any, expected: str | None): - self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) # pylint: disable=protected-access - - -if __name__ == "__main__": - unittest.main() diff --git a/src/onnx_ir/_internal/version_utils.py b/src/onnx_ir/_version_utils.py similarity index 81% rename from src/onnx_ir/_internal/version_utils.py rename to src/onnx_ir/_version_utils.py index f68e242c..759d91d6 100644 --- a/src/onnx_ir/_internal/version_utils.py +++ b/src/onnx_ir/_version_utils.py @@ -4,9 +4,6 @@ from __future__ import annotations -import warnings -from typing import Callable, Sequence - import packaging.version @@ -92,27 +89,3 @@ def has_transformers(): return True # noqa except ImportError: return False - - -def ignore_warnings(warns: Warning | Sequence[Warning]) -> Callable: # type: ignore[arg-type] - """Catches warnings. - - Args: - warns: warnings to ignore - - Returns: - decorated function - """ - - def wrapper(fct): - if warns is None: - raise AssertionError(f"warns cannot be None for '{fct}'.") - - def call_f(self): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", warns) # type: ignore[arg-type] - return fct(self) - - return call_f - - return wrapper diff --git a/src/onnx_ir/convenience.py b/src/onnx_ir/convenience.py index d1ffb1c5..2d6bffcc 100644 --- a/src/onnx_ir/convenience.py +++ b/src/onnx_ir/convenience.py @@ -12,7 +12,7 @@ "create_value_mapping", ] -from onnxscript.ir._convenience import ( +from onnx_ir._convenience import ( convert_attribute, convert_attributes, create_value_mapping, diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index 8525e84f..a94c2661 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -19,7 +19,7 @@ from onnx_ir import _core, _enums, _protocols from onnx_ir import traversal as _traversal -from onnxscript.ir._polyfill import zip +from onnx_ir._polyfill import zip # Note: If needed in future, add these as parameters to the function calls # align_offset: Offset will always be page aligned and alloction granularity aligned for mmap support. This is done by padding previous tensor data with zeros keeping same length. Tensor data will be aligned if > align_threshold diff --git a/src/onnx_ir/passes/__init__.py b/src/onnx_ir/passes/__init__.py index 9b8516dd..7f971ea3 100644 --- a/src/onnx_ir/passes/__init__.py +++ b/src/onnx_ir/passes/__init__.py @@ -15,7 +15,7 @@ "PassError", ] -from onnxscript.ir.passes._pass_infra import ( +from onnx_ir.passes._pass_infra import ( FunctionalPass, InPlacePass, InvariantError, diff --git a/src/onnx_ir/passes/_pass_infra_test.py b/src/onnx_ir/passes/_pass_infra_test.py index 087b1196..68dd5dbc 100644 --- a/src/onnx_ir/passes/_pass_infra_test.py +++ b/src/onnx_ir/passes/_pass_infra_test.py @@ -6,7 +6,7 @@ import unittest import onnx_ir as ir -from onnxscript.ir.passes import _pass_infra +from onnx_ir.passes import _pass_infra class PassBaseTest(unittest.TestCase): diff --git a/src/onnx_ir/passes/common/__init__.py b/src/onnx_ir/passes/common/__init__.py index 07da4c55..8b53813b 100644 --- a/src/onnx_ir/passes/common/__init__.py +++ b/src/onnx_ir/passes/common/__init__.py @@ -16,20 +16,20 @@ "TopologicalSortPass", ] -from onnxscript.ir.passes.common.clear_metadata_and_docstring import ( +from onnx_ir.passes.common.clear_metadata_and_docstring import ( ClearMetadataAndDocStringPass, ) -from onnxscript.ir.passes.common.constant_manipulation import ( +from onnx_ir.passes.common.constant_manipulation import ( AddInitializersToInputsPass, LiftConstantsToInitializersPass, LiftSubgraphInitializersToMainGraphPass, RemoveInitializersFromInputsPass, ) -from onnxscript.ir.passes.common.inliner import InlinePass -from onnxscript.ir.passes.common.onnx_checker import CheckerPass -from onnxscript.ir.passes.common.shape_inference import ShapeInferencePass -from onnxscript.ir.passes.common.topological_sort import TopologicalSortPass -from onnxscript.ir.passes.common.unused_removal import ( +from onnx_ir.passes.common.inliner import InlinePass +from onnx_ir.passes.common.onnx_checker import CheckerPass +from onnx_ir.passes.common.shape_inference import ShapeInferencePass +from onnx_ir.passes.common.topological_sort import TopologicalSortPass +from onnx_ir.passes.common.unused_removal import ( RemoveUnusedFunctionsPass, RemoveUnusedNodesPass, RemoveUnusedOpsetsPass, diff --git a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py index 3d9283c2..5463cbad 100644 --- a/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py +++ b/src/onnx_ir/passes/common/clear_metadata_and_docstring_test.py @@ -7,7 +7,7 @@ import numpy as np import onnx_ir as ir -from onnxscript.ir.passes.common import clear_metadata_and_docstring +from onnx_ir.passes.common import clear_metadata_and_docstring class TestClearMetadataAndDocStringPass(unittest.TestCase): diff --git a/src/onnx_ir/passes/common/constant_manipulation_test.py b/src/onnx_ir/passes/common/constant_manipulation_test.py index a65a74d2..f237e017 100644 --- a/src/onnx_ir/passes/common/constant_manipulation_test.py +++ b/src/onnx_ir/passes/common/constant_manipulation_test.py @@ -8,7 +8,7 @@ import parameterized import onnx_ir as ir -from onnxscript.ir.passes.common import constant_manipulation +from onnx_ir.passes.common import constant_manipulation class TestLiftConstantsToInitializersPass(unittest.TestCase): diff --git a/src/onnx_ir/passes/common/inliner.py b/src/onnx_ir/passes/common/inliner.py index 0e50b569..3bf65466 100644 --- a/src/onnx_ir/passes/common/inliner.py +++ b/src/onnx_ir/passes/common/inliner.py @@ -1,6 +1,6 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 -"""Implementation of an inliner for onnxscript.ir""" +"""Implementation of an inliner for onnx_ir""" from __future__ import annotations @@ -11,7 +11,7 @@ from collections import defaultdict from typing import Iterable, List, Sequence, Tuple -import onnxscript.ir.convenience as _ir_convenience +import onnx_ir.convenience as _ir_convenience import onnx_ir as ir # A replacement for a node specifies a list of nodes that replaces the original node, diff --git a/src/onnx_ir/passes/common/inliner_test.py b/src/onnx_ir/passes/common/inliner_test.py index edccf928..4ab228cc 100644 --- a/src/onnx_ir/passes/common/inliner_test.py +++ b/src/onnx_ir/passes/common/inliner_test.py @@ -10,7 +10,7 @@ import onnx import onnx_ir as ir -from onnxscript.ir.passes.common import inliner +from onnx_ir.passes.common import inliner def _name_checker(renameable: Sequence[str] | None) -> Callable[[str, str], bool]: diff --git a/src/onnx_ir/passes/common/onnx_checker.py b/src/onnx_ir/passes/common/onnx_checker.py index 981dfdb6..a8b26390 100644 --- a/src/onnx_ir/passes/common/onnx_checker.py +++ b/src/onnx_ir/passes/common/onnx_checker.py @@ -13,7 +13,7 @@ import onnx import onnx_ir as ir -from onnxscript.ir.passes.common import _c_api_utils +from onnx_ir.passes.common import _c_api_utils class CheckerPass(ir.passes.PassBase): diff --git a/src/onnx_ir/passes/common/onnx_checker_test.py b/src/onnx_ir/passes/common/onnx_checker_test.py index bdf3f5e7..2c016666 100644 --- a/src/onnx_ir/passes/common/onnx_checker_test.py +++ b/src/onnx_ir/passes/common/onnx_checker_test.py @@ -5,7 +5,7 @@ import unittest import onnx_ir as ir -from onnxscript.ir.passes.common import onnx_checker +from onnx_ir.passes.common import onnx_checker class TestCheckerPass(unittest.TestCase): diff --git a/src/onnx_ir/passes/common/shape_inference.py b/src/onnx_ir/passes/common/shape_inference.py index b4f0c2bc..365cb0d5 100644 --- a/src/onnx_ir/passes/common/shape_inference.py +++ b/src/onnx_ir/passes/common/shape_inference.py @@ -14,7 +14,7 @@ import onnx import onnx_ir as ir -from onnxscript.ir.passes.common import _c_api_utils +from onnx_ir.passes.common import _c_api_utils logger = logging.getLogger(__name__) diff --git a/src/onnx_ir/passes/common/shape_inference_test.py b/src/onnx_ir/passes/common/shape_inference_test.py index 55a7b034..8981217c 100644 --- a/src/onnx_ir/passes/common/shape_inference_test.py +++ b/src/onnx_ir/passes/common/shape_inference_test.py @@ -7,7 +7,7 @@ import numpy as np import onnx_ir as ir -from onnxscript.ir.passes.common import _c_api_utils, shape_inference +from onnx_ir.passes.common import _c_api_utils, shape_inference class TestShapeInferencePass(unittest.TestCase): diff --git a/src/onnx_ir/passes/common/topological_sort_test.py b/src/onnx_ir/passes/common/topological_sort_test.py index 6a287d26..38faca8b 100644 --- a/src/onnx_ir/passes/common/topological_sort_test.py +++ b/src/onnx_ir/passes/common/topological_sort_test.py @@ -5,7 +5,7 @@ import unittest import onnx_ir as ir -from onnxscript.ir.passes.common import topological_sort +from onnx_ir.passes.common import topological_sort class TopologicalSortPassTest(unittest.TestCase): diff --git a/src/onnx_ir/passes/common/unused_removal_test.py b/src/onnx_ir/passes/common/unused_removal_test.py index d1dd06a3..0554da0a 100644 --- a/src/onnx_ir/passes/common/unused_removal_test.py +++ b/src/onnx_ir/passes/common/unused_removal_test.py @@ -3,23 +3,17 @@ import unittest import onnx -import parameterized -import onnxscript.optimizer import onnx_ir as ir +import onnx_ir.passes.common -@parameterized.parameterized_class(("using_ir",), [(False,), (True,)]) class RemoveUnusedTest(unittest.TestCase): - using_ir: bool def remove_unused_nodes(self, model: onnx.ModelProto): - if self.using_ir: - model_ir = ir.serde.deserialize_model(model) - onnxscript.optimizer.remove_unused_nodes(model_ir) - model = ir.serde.serialize_model(model_ir) - return model - onnxscript.optimizer.remove_unused_nodes(model) + model_ir = ir.serde.deserialize_model(model) + onnx_ir.passes.common.RemoveUnusedNodesPass()(model_ir) + model = ir.serde.serialize_model(model_ir) return model def test_remove_unused_nodes(self): diff --git a/src/onnx_ir/serde.py b/src/onnx_ir/serde.py index bb0f1af9..f73fcc84 100644 --- a/src/onnx_ir/serde.py +++ b/src/onnx_ir/serde.py @@ -77,7 +77,7 @@ logger = logging.getLogger(__name__) _PLEASE_CONTRIBUTE = ( - "Please contribute by creating a PR at https://github.com/microsoft/onnxscript." + "Please contribute by creating a PR at https://github.com/onnx/onnx-ir." ) _FUNCTION_VALUE_INFO_SUPPORTED_VERSION = ( 10 # ONNX IR version where value info in functions was introduced @@ -321,7 +321,7 @@ def numpy(self) -> np.ndarray: specification. External tensors are not supported by this class. Use - :class:`onnxscript.ir.ExternalTensor` instead. + :class:`onnx_ir.ExternalTensor` instead. Raises: ValueError: If the data type is UNDEFINED. diff --git a/src/onnx_ir/serde_test.py b/src/onnx_ir/serde_test.py index 7484bae8..fd054d5c 100644 --- a/src/onnx_ir/serde_test.py +++ b/src/onnx_ir/serde_test.py @@ -9,7 +9,7 @@ import parameterized import onnx_ir as ir -from onnxscript._internal import version_utils +from onnx_ir import _version_utils from onnx_ir import serde @@ -81,12 +81,12 @@ def test_tensor_proto_tensor(self, _: str, dtype: int): array_from_raw_data = onnx.numpy_helper.to_array(tensor_proto_from_raw_data) np.testing.assert_array_equal(array_from_raw_data, expected_array) # Test dlpack - if dtype == onnx.TensorProto.BOOL and version_utils.numpy_older_than("1.25"): + if dtype == onnx.TensorProto.BOOL and _version_utils.numpy_older_than("1.25"): self.skipTest("numpy<1.25 does not support bool dtype in from_dlpack") np.testing.assert_array_equal(np.from_dlpack(tensor), tensor.numpy()) @unittest.skipIf( - version_utils.onnx_older_than("1.17"), + _version_utils.onnx_older_than("1.17"), "numpy_helper.to_array was not correctly implemented in onnx<1.17", ) def test_tensor_proto_tensor_bfloat16(self): diff --git a/src/onnx_ir/tape.py b/src/onnx_ir/tape.py index 96cc98f9..a34ee617 100644 --- a/src/onnx_ir/tape.py +++ b/src/onnx_ir/tape.py @@ -10,6 +10,6 @@ "Tape", ] -from onnxscript.ir._tape import Tape +from onnx_ir._tape import Tape Tape.__module__ = __name__ diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index fcee58f7..fc61742d 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Compatible adapters implementing the TensorProtocol interface for various framework tensor types. -This module provides public classes that implement the :class:`onnxscript.ir.TensorProtocol` +This module provides public classes that implement the :class:`onnx_ir.TensorProtocol` interface for various tensor types from popular deep learning frameworks. You can use these classes to create tensors and use them in the IR graph like any other tensor. diff --git a/tests/ir/public_api_test.py b/tests/ir/public_api_test.py index 65e3381d..e4d0e687 100644 --- a/tests/ir/public_api_test.py +++ b/tests/ir/public_api_test.py @@ -14,9 +14,9 @@ import unittest from typing import Iterable -import onnxscript.ir +import onnx_ir -IR_NAMESPACE = "onnxscript.ir" +IR_NAMESPACE = "onnx_ir" def _find_all_importables(pkg): @@ -83,7 +83,7 @@ def check_one_element(elem, modname, mod, *, is_public, is_all): if not why_not_looks_public and not elem_modname_starts_with_mod: why_not_looks_public = ( f"because its `__module__` attribute (`{elem_module}`) is not within the " - f"onnxscript.ir library or does not start with the submodule where it is defined (`{modname}`)" + f"onnx_ir library or does not start with the submodule where it is defined (`{modname}`)" ) # elem's name must NOT begin with an `_` and it's module name # SHOULD start with it's current module since it's a public API @@ -152,11 +152,11 @@ def check_one_element(elem, modname, mod, *, is_public, is_all): class TestPublicApiNamespace(unittest.TestCase): - tested_modules = (IR_NAMESPACE, *(_find_all_importables(onnxscript.ir))) + tested_modules = (IR_NAMESPACE, *(_find_all_importables(onnx_ir))) def test_correct_module_names(self): """ - An API is considered public, if its `__module__` starts with `onnxscript.ir` + An API is considered public, if its `__module__` starts with `onnx_ir` and there is no name in `__module__` or the object itself that starts with "_". Each public package should either: - (preferred) Define `__all__` and all callables and classes in there must have their diff --git a/tests/ir/serde_roundtrip_test.py b/tests/ir/serde_roundtrip_test.py index c29c8ed3..431f99aa 100644 --- a/tests/ir/serde_roundtrip_test.py +++ b/tests/ir/serde_roundtrip_test.py @@ -10,7 +10,7 @@ import onnx.backend.test import parameterized -import onnxscript.testing +import onnx_ir.testing import onnx_ir as ir model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" @@ -38,7 +38,7 @@ def test_serialization_deserialization_produces_same_model( ir_model = ir.serde.deserialize_model(model) serialized = ir.serde.serialize_model(ir_model) - onnxscript.testing.assert_onnx_proto_equal(serialized, model) + onnx_ir.testing.assert_onnx_proto_equal(serialized, model) onnx.checker.check_model(serialized) diff --git a/tools/model_zoo_test/model_zoo_test.py b/tools/model_zoo_test/model_zoo_test.py index 8db93352..e0a125db 100644 --- a/tools/model_zoo_test/model_zoo_test.py +++ b/tools/model_zoo_test/model_zoo_test.py @@ -22,7 +22,7 @@ import tqdm from onnx import hub -import onnxscript.testing +import onnx_ir.testing import onnx_ir as ir @@ -43,7 +43,7 @@ def test_model(model_info: hub.ModelInfo) -> float: ir_model = ir.serde.deserialize_model(model) serialized = ir.serde.serialize_model(ir_model) end = time.time() - onnxscript.testing.assert_onnx_proto_equal( + onnx_ir.testing.assert_onnx_proto_equal( serialized, model, ignore_initializer_value_proto=True ) onnx.checker.check_model(serialized) From cbb52d3bf7882c432e0d873eeace8eedeed45e6a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 09:34:46 -0700 Subject: [PATCH 04/16] Fix Signed-off-by: Justin Chu --- src/onnx_ir/__init__.py | 8 ++++++-- src/onnx_ir/_convenience/__init__.py | 1 - src/onnx_ir/_core.py | 5 ++--- src/onnx_ir/_graph_containers.py | 2 ++ src/onnx_ir/external_data.py | 3 +-- src/onnx_ir/passes/_pass_infra.py | 1 - src/onnx_ir/passes/common/_c_api_utils.py | 1 - src/onnx_ir/passes/common/inliner.py | 4 ++-- src/onnx_ir/serde_test.py | 3 +-- tests/ir/public_api_test.py | 7 ++++--- tests/ir/serde_roundtrip_test.py | 2 +- tools/model_zoo_test/model_zoo_test.py | 2 +- 12 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index 3adb36b9..12c89f2f 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -81,6 +81,8 @@ # IO "load", "save", + # Flags + "DEBUG", ] from onnx_ir import convenience, external_data, passes, serde, tape, traversal @@ -147,12 +149,14 @@ ) from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto +DEBUG = False def __set_module() -> None: """Set the module of all functions in this module to this public module.""" global_dict = globals() for name in __all__: - global_dict[name].__module__ = __name__ - + if hasattr(global_dict[name], "__module__"): + # Set the module of the function to this module + global_dict[name].__module__ = __name__ __set_module() diff --git a/src/onnx_ir/_convenience/__init__.py b/src/onnx_ir/_convenience/__init__.py index 053ae39f..240b9f12 100644 --- a/src/onnx_ir/_convenience/__init__.py +++ b/src/onnx_ir/_convenience/__init__.py @@ -355,7 +355,6 @@ def replace_nodes_and_values( old_values: The values to replace. new_values: The values to replace with. """ - for old_value, new_value in zip(old_values, new_values): # Propagate relevant info from old value to new value # TODO(Rama): Perhaps this should be a separate utility function. Also, consider diff --git a/src/onnx_ir/_core.py b/src/onnx_ir/_core.py index f794b904..e0a9daf7 100644 --- a/src/onnx_ir/_core.py +++ b/src/onnx_ir/_core.py @@ -1279,7 +1279,7 @@ def _short_tensor_str_for_node(x: Value) -> str: def _normalize_domain(domain: str) -> str: - """Normalize 'ai.onnx' to ''""" + """Normalize 'ai.onnx' to ''.""" return "" if domain == "ai.onnx" else domain @@ -1709,7 +1709,7 @@ def dtype(self, value: _enums.DataType) -> None: @property def elem_type(self) -> _enums.DataType: - """Return the element type of the tensor type""" + """Return the element type of the tensor type.""" return self.dtype def __hash__(self) -> int: @@ -2099,7 +2099,6 @@ def Input( This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``. """ - # NOTE: The function name is capitalized to maintain API backward compatibility. return Value(name=name, shape=shape, type=type, doc_string=doc_string) diff --git a/src/onnx_ir/_graph_containers.py b/src/onnx_ir/_graph_containers.py index 59168e51..2f268457 100644 --- a/src/onnx_ir/_graph_containers.py +++ b/src/onnx_ir/_graph_containers.py @@ -14,6 +14,8 @@ import collections from typing import TYPE_CHECKING, Iterable, SupportsIndex +import onnx_ir + if TYPE_CHECKING: from onnx_ir import _core diff --git a/src/onnx_ir/external_data.py b/src/onnx_ir/external_data.py index a94c2661..200cdef9 100644 --- a/src/onnx_ir/external_data.py +++ b/src/onnx_ir/external_data.py @@ -35,8 +35,7 @@ @dataclasses.dataclass class _ExternalDataInfo: - """ - A class that stores information about a tensor that is to be stored as external data. + """A class that stores information about a tensor that is to be stored as external data. Attributes: name: The name of the tensor that is to be stored as external data. diff --git a/src/onnx_ir/passes/_pass_infra.py b/src/onnx_ir/passes/_pass_infra.py index cf0fad1b..2510e69e 100644 --- a/src/onnx_ir/passes/_pass_infra.py +++ b/src/onnx_ir/passes/_pass_infra.py @@ -71,7 +71,6 @@ class PassResult: class PassBase(abc.ABC): """Base class for all passes. - ``in_place`` and ``changes_input`` properties and what they mean: +------------+------------------+----------------------------+ diff --git a/src/onnx_ir/passes/common/_c_api_utils.py b/src/onnx_ir/passes/common/_c_api_utils.py index ef345d13..a4a75a1a 100644 --- a/src/onnx_ir/passes/common/_c_api_utils.py +++ b/src/onnx_ir/passes/common/_c_api_utils.py @@ -34,7 +34,6 @@ def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R: Returns: The resulting ModelProto that contains the result of the API call. """ - # Store the original initializer values so they can be restored initializer_values = tuple(model.graph.initializers.values()) tensors = {v.name: v.const_value for v in initializer_values} diff --git a/src/onnx_ir/passes/common/inliner.py b/src/onnx_ir/passes/common/inliner.py index 3bf65466..a49e9a4c 100644 --- a/src/onnx_ir/passes/common/inliner.py +++ b/src/onnx_ir/passes/common/inliner.py @@ -1,6 +1,6 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 -"""Implementation of an inliner for onnx_ir""" +"""Implementation of an inliner for onnx_ir.""" from __future__ import annotations @@ -11,8 +11,8 @@ from collections import defaultdict from typing import Iterable, List, Sequence, Tuple -import onnx_ir.convenience as _ir_convenience import onnx_ir as ir +import onnx_ir.convenience as _ir_convenience # A replacement for a node specifies a list of nodes that replaces the original node, # and a list of values that replaces the original node's outputs. diff --git a/src/onnx_ir/serde_test.py b/src/onnx_ir/serde_test.py index fd054d5c..e25b42a2 100644 --- a/src/onnx_ir/serde_test.py +++ b/src/onnx_ir/serde_test.py @@ -9,8 +9,7 @@ import parameterized import onnx_ir as ir -from onnx_ir import _version_utils -from onnx_ir import serde +from onnx_ir import _version_utils, serde class ConvenienceFunctionsTest(unittest.TestCase): diff --git a/tests/ir/public_api_test.py b/tests/ir/public_api_test.py index e4d0e687..f177254c 100644 --- a/tests/ir/public_api_test.py +++ b/tests/ir/public_api_test.py @@ -1,5 +1,3 @@ -# Copyright (c) ONNX Project Contributors -# SPDX-License-Identifier: Apache-2.0 # Adapted from # https://github.com/pytorch/pytorch/blob/b505e8647547f029d0f7df408ee5f2968f757f89/test/test_public_bindings.py#L523 # Original code PyTorch license https://github.com/pytorch/pytorch/blob/main/LICENSE @@ -21,6 +19,7 @@ def _find_all_importables(pkg): """Find all importables in the project. + Return them in order. """ return sorted( @@ -34,6 +33,7 @@ def _find_all_importables(pkg): def _discover_path_importables(pkg_path: os.PathLike, pkg_name: str) -> Iterable[str]: """Yield all importables under a given path and package. + This is like pkgutil.walk_packages, but does *not* skip over namespace packages. Taken from https://stackoverflow.com/questions/41203765/init-py-required-for-pkgutil-walk-packages-in-python3 """ @@ -155,7 +155,8 @@ class TestPublicApiNamespace(unittest.TestCase): tested_modules = (IR_NAMESPACE, *(_find_all_importables(onnx_ir))) def test_correct_module_names(self): - """ + """Test module names are correct. + An API is considered public, if its `__module__` starts with `onnx_ir` and there is no name in `__module__` or the object itself that starts with "_". Each public package should either: diff --git a/tests/ir/serde_roundtrip_test.py b/tests/ir/serde_roundtrip_test.py index 431f99aa..9bae4f47 100644 --- a/tests/ir/serde_roundtrip_test.py +++ b/tests/ir/serde_roundtrip_test.py @@ -10,8 +10,8 @@ import onnx.backend.test import parameterized -import onnx_ir.testing import onnx_ir as ir +import onnx_ir.testing model_folder_path = pathlib.Path(__file__).resolve().parent.parent.parent / "testdata" onnx_backend_test_path = pathlib.Path(onnx.backend.test.__file__).parent / "data" diff --git a/tools/model_zoo_test/model_zoo_test.py b/tools/model_zoo_test/model_zoo_test.py index e0a125db..2aa02d52 100644 --- a/tools/model_zoo_test/model_zoo_test.py +++ b/tools/model_zoo_test/model_zoo_test.py @@ -22,8 +22,8 @@ import tqdm from onnx import hub -import onnx_ir.testing import onnx_ir as ir +import onnx_ir.testing def test_model(model_info: hub.ModelInfo) -> float: From b9f585de4c50923d445066476bef0a164e8677a4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 09:39:14 -0700 Subject: [PATCH 05/16] Format Signed-off-by: Justin Chu --- src/onnx_ir/__init__.py | 2 ++ src/onnx_ir/passes/common/unused_removal_test.py | 1 - src/onnx_ir/serde.py | 4 +--- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index 12c89f2f..2571822e 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -151,6 +151,7 @@ DEBUG = False + def __set_module() -> None: """Set the module of all functions in this module to this public module.""" global_dict = globals() @@ -159,4 +160,5 @@ def __set_module() -> None: # Set the module of the function to this module global_dict[name].__module__ = __name__ + __set_module() diff --git a/src/onnx_ir/passes/common/unused_removal_test.py b/src/onnx_ir/passes/common/unused_removal_test.py index 0554da0a..7e34b4b7 100644 --- a/src/onnx_ir/passes/common/unused_removal_test.py +++ b/src/onnx_ir/passes/common/unused_removal_test.py @@ -9,7 +9,6 @@ class RemoveUnusedTest(unittest.TestCase): - def remove_unused_nodes(self, model: onnx.ModelProto): model_ir = ir.serde.deserialize_model(model) onnx_ir.passes.common.RemoveUnusedNodesPass()(model_ir) diff --git a/src/onnx_ir/serde.py b/src/onnx_ir/serde.py index f73fcc84..bb9c3957 100644 --- a/src/onnx_ir/serde.py +++ b/src/onnx_ir/serde.py @@ -76,9 +76,7 @@ logger = logging.getLogger(__name__) -_PLEASE_CONTRIBUTE = ( - "Please contribute by creating a PR at https://github.com/onnx/onnx-ir." -) +_PLEASE_CONTRIBUTE = "Please contribute by creating a PR at https://github.com/onnx/onnx-ir." _FUNCTION_VALUE_INFO_SUPPORTED_VERSION = ( 10 # ONNX IR version where value info in functions was introduced ) From 259aef7276b51c0a74513c923a29a0c085556a62 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 09:40:05 -0700 Subject: [PATCH 06/16] MANIFEST Signed-off-by: Justin Chu --- REUSE.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/REUSE.toml b/REUSE.toml index 49735f3b..37aad2ac 100644 --- a/REUSE.toml +++ b/REUSE.toml @@ -17,6 +17,7 @@ path = [ "**/*.toml", "**/*.yml", "CODEOWNERS", + "MANIFEST.in", "requirements/**/*.txt", ] precedence = "aggregate" From ebaa562d90cfb4c25e119823d3671f3e9867b7aa Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 10:11:18 -0700 Subject: [PATCH 07/16] Add testing Signed-off-by: Justin Chu --- src/onnx_ir/testing.py | 193 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 193 insertions(+) diff --git a/src/onnx_ir/testing.py b/src/onnx_ir/testing.py index 10e60984..6fccd450 100644 --- a/src/onnx_ir/testing.py +++ b/src/onnx_ir/testing.py @@ -1,3 +1,196 @@ # Copyright (c) ONNX Project Contributors # SPDX-License-Identifier: Apache-2.0 """Utilities for testing.""" + +from __future__ import annotations + +__all__ = [ + "assert_onnx_proto_equal", +] + +import difflib +import math +from typing import Any, Collection, Sequence + +import google.protobuf.message +import onnx + + +def _opset_import_key(opset_import: onnx.OperatorSetIdProto) -> tuple[str, int]: + return (opset_import.domain, opset_import.version) + + +def _value_info_key(value_info: onnx.ValueInfoProto) -> str: + return value_info.name + + +def _function_key(function: onnx.FunctionProto) -> tuple[str, str, str]: + return (function.domain, function.name, getattr(function, "overload", "")) + + +def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]: + """Return a list of duplicated elements in a collection.""" + seen = set() + duplicates = [] + for x in with_duplicates: + if x in seen: + duplicates.append(x) + seen.add(x) + return duplicates + + +def assert_onnx_proto_equal( + actual: google.protobuf.message.Message | Any, + expected: google.protobuf.message.Message | Any, + ignore_initializer_value_proto: bool = False, +) -> None: + """Assert that two ONNX protos are equal. + + Equality is defined as having the same fields with the same values. When + a field takes the default value, it is considered equal to the field + not being set. + + Sequential fields with name `opset_import`, `value_info`, and `functions` are + compared disregarding the order of their elements. + + Args: + actual: The first ONNX proto. + expected: The second ONNX proto. + ignore_initializer_value_proto: Ignore value protos for initializers if there + are extra ones in the actual proto. + """ + assert type(actual) is type(expected), ( + f"Type not equal: {type(actual)} != {type(expected)}" + ) + + a_fields = {field.name: value for field, value in actual.ListFields()} + b_fields = {field.name: value for field, value in expected.ListFields()} + all_fields = sorted(set(a_fields.keys()) | set(b_fields.keys())) + if isinstance(actual, onnx.GraphProto) and isinstance(expected, onnx.GraphProto): + actual_initializer_names = {i.name for i in actual.initializer} + expected_initializer_names = {i.name for i in expected.initializer} + else: + actual_initializer_names = set() + expected_initializer_names = set() + + # Record and report all errors + errors = [] + for field in all_fields: # pylint: disable=too-many-nested-blocks + # Obtain the default value if the field is not set. This way we can compare the two fields. + a_value = getattr(actual, field) + b_value = getattr(expected, field) + if ( + isinstance(a_value, Sequence) + and isinstance(b_value, Sequence) + and not isinstance(a_value, (str, bytes)) + and not isinstance(b_value, (str, bytes)) + ): + # Check length first + a_keys: list[Any] = [] + b_keys: list[Any] = [] + if field == "opset_import": + a_value = sorted(a_value, key=_opset_import_key) + b_value = sorted(b_value, key=_opset_import_key) + a_keys = [_opset_import_key(opset_import) for opset_import in a_value] + b_keys = [_opset_import_key(opset_import) for opset_import in b_value] + elif field == "value_info": + if ( + ignore_initializer_value_proto + and isinstance(actual, onnx.GraphProto) + and isinstance(expected, onnx.GraphProto) + ): + # Filter out initializers from the value_info list + a_value = [ + value_info + for value_info in a_value + if value_info.name not in actual_initializer_names + ] + b_value = [ + value_info + for value_info in b_value + if value_info.name not in expected_initializer_names + ] + a_value = sorted(a_value, key=_value_info_key) + b_value = sorted(b_value, key=_value_info_key) + a_keys = [_value_info_key(value_info) for value_info in a_value] + b_keys = [_value_info_key(value_info) for value_info in b_value] + elif field == "functions": + a_value = sorted(a_value, key=_function_key) + b_value = sorted(b_value, key=_function_key) + a_keys = [_function_key(functions) for functions in a_value] + b_keys = [_function_key(functions) for functions in b_value] + + if a_keys != b_keys: + keys_only_in_actual = set(a_keys) - set(b_keys) + keys_only_in_expected = set(b_keys) - set(a_keys) + error_message = ( + f"Field {field} not equal: keys_only_in_actual={keys_only_in_actual}, keys_only_in_expected={keys_only_in_expected}. " + f"Field type: {type(a_value)}. " + f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}" + ) + errors.append(error_message) + elif len(a_value) != len(b_value): + error_message = ( + f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} " + f"Field type: {type(a_value)}" + ) + errors.append(error_message) + else: + # Check every element + for i in range(len(a_value)): # pylint: disable=consider-using-enumerate + actual_value_i = a_value[i] + expected_value_i = b_value[i] + if isinstance( + actual_value_i, google.protobuf.message.Message + ) and isinstance(expected_value_i, google.protobuf.message.Message): + try: + assert_onnx_proto_equal( + actual_value_i, + expected_value_i, + ignore_initializer_value_proto=ignore_initializer_value_proto, + ) + except AssertionError as e: + error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}, actual_value_i: {actual_value_i}, expected_value_i: {expected_value_i}" + error_message = ( + str(e) + "\n\nCaused by the above error\n\n" + error_message + ) + errors.append(error_message) + elif actual_value_i != expected_value_i: + if ( + isinstance(actual_value_i, float) + and isinstance(expected_value_i, float) + and math.isnan(actual_value_i) + and math.isnan(expected_value_i) + ): + # Consider NaNs equal + continue + error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}" + for line in difflib.ndiff( + str(actual_value_i).splitlines(), + str(expected_value_i).splitlines(), + ): + error_message += "\n" + line + errors.append(error_message) + elif isinstance(a_value, google.protobuf.message.Message) and isinstance( + b_value, google.protobuf.message.Message + ): + assert_onnx_proto_equal( + a_value, b_value, ignore_initializer_value_proto=ignore_initializer_value_proto + ) + elif a_value != b_value: + if ( + isinstance(a_value, float) + and isinstance(b_value, float) + and math.isnan(a_value) + and math.isnan(b_value) + ): + # Consider NaNs equal + continue + error_message = ( + f"Field {field} not equal. field_actual: {a_value}, field_expected: {b_value}" + ) + errors.append(error_message) + if errors: + raise AssertionError( + f"Protos not equal: {type(actual)} != {type(expected)}\n" + "\n".join(errors) + ) From 5d9640430625cf8e3d66cf0202744a33b0266782 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 10:18:49 -0700 Subject: [PATCH 08/16] req Signed-off-by: Justin Chu --- .github/workflows/lint.yml | 2 +- requirements-dev.txt | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 requirements-dev.txt diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d8c342cf..01e6b574 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -53,7 +53,7 @@ jobs: # Install dependencies python -m pip install --upgrade pip python -m pip install --upgrade setuptools - python -m pip install --upgrade lintrunner lintrunner-adapters + python -m pip install -r requirements-dev.txt # Install packages python -m pip install -e . lintrunner init diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..93409eda --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,33 @@ +setuptools>=70.0.0 +onnxruntime>=1.17.0 +rich>=13.7.1 + +# Docs site +furo +jax[cpu] +matplotlib +myst-parser[linkify] +sphinx-copybutton +sphinx-exec-code +sphinx-gallery +sphinx>=6 +myst_nb +chardet + +# Testing +expecttest==0.1.6 +hypothesis +parameterized +pytest-cov +pytest-randomly +pytest-subtests +pytest-xdist +pytest!=7.1.0 +pyyaml +torch>=2.3 +torchvision>=0.18.0 +transformers>=4.37.2 + +# Lint +lintrunner>=0.10.7 +lintrunner_adapters>=0.12.0 From 671f92f33b2b3fe1581898d97baf5c11008b8acd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 10:21:54 -0700 Subject: [PATCH 09/16] req Signed-off-by: Justin Chu --- requirements-dev.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 93409eda..a6c1c625 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,6 +17,7 @@ chardet # Testing expecttest==0.1.6 hypothesis +packaging parameterized pytest-cov pytest-randomly From 94d1ab2514aa444027e9265fe4f794323bf116ff Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 10:25:00 -0700 Subject: [PATCH 10/16] Move tests Signed-off-by: Justin Chu --- tests/{ir => }/graph_view_test.py | 0 tests/{ir => }/public_api_test.py | 0 tests/{ir => }/serde_roundtrip_test.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/{ir => }/graph_view_test.py (100%) rename tests/{ir => }/public_api_test.py (100%) rename tests/{ir => }/serde_roundtrip_test.py (100%) diff --git a/tests/ir/graph_view_test.py b/tests/graph_view_test.py similarity index 100% rename from tests/ir/graph_view_test.py rename to tests/graph_view_test.py diff --git a/tests/ir/public_api_test.py b/tests/public_api_test.py similarity index 100% rename from tests/ir/public_api_test.py rename to tests/public_api_test.py diff --git a/tests/ir/serde_roundtrip_test.py b/tests/serde_roundtrip_test.py similarity index 100% rename from tests/ir/serde_roundtrip_test.py rename to tests/serde_roundtrip_test.py From b7b34d3b74f1acc0f31d8274f515d5e162f9913f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 11:54:37 -0700 Subject: [PATCH 11/16] license bsd Signed-off-by: Justin Chu --- LICENSES/BSD-3-Clause.txt | 27 +++++++++ REUSE.toml | 74 +++++++++++++++++++++++++ src/onnx_ir/_thirdparty/asciichartpy.py | 3 - 3 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 LICENSES/BSD-3-Clause.txt diff --git a/LICENSES/BSD-3-Clause.txt b/LICENSES/BSD-3-Clause.txt new file mode 100644 index 00000000..1e47ac7c --- /dev/null +++ b/LICENSES/BSD-3-Clause.txt @@ -0,0 +1,27 @@ +Copyright (c) [year], [fullname] +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of [project] nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/REUSE.toml b/REUSE.toml index 37aad2ac..fe5a1c4a 100644 --- a/REUSE.toml +++ b/REUSE.toml @@ -23,3 +23,77 @@ path = [ precedence = "aggregate" SPDX-FileCopyrightText = "Copyright (c) ONNX Project Contributors" SPDX-License-Identifier = "Apache-2.0" + +[[annotations]] +path = "src/onnx_ir/_thirdparty/asciichartpy.py" +precedence = "aggregate" +SPDX-FileCopyrightText = [ + "2016 Igor Kroitor", +] +SPDX-License-Identifier = "MIT" + +[[annotations]] +path = "tests/public_api_test.py" +precedence = "aggregate" +SPDX-FileCopyrightText = [ + """ + From PyTorch: + + Copyright (c) 2016- Facebook, Inc (Adam Paszke) + Copyright (c) 2014- Facebook, Inc (Soumith Chintala) + Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) + Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) + Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) + Copyright (c) 2011-2013 NYU (Clement Farabet) + Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) + Copyright (c) 2006 Idiap Research Institute (Samy Bengio) + Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + + From Caffe2: + + Copyright (c) 2016-present, Facebook Inc. All rights reserved. + + All contributions by Facebook: + Copyright (c) 2016 Facebook Inc. + + All contributions by Google: + Copyright (c) 2015 Google Inc. + All rights reserved. + + All contributions by Yangqing Jia: + Copyright (c) 2015 Yangqing Jia + All rights reserved. + + All contributions by Kakao Brain: + Copyright 2019-2020 Kakao Brain + + All contributions by Cruise LLC: + Copyright (c) 2022 Cruise LLC. + All rights reserved. + + All contributions by Tri Dao: + Copyright (c) 2024 Tri Dao. + All rights reserved. + + All contributions by Arm: + Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates + + All contributions from Caffe: + Copyright(c) 2013, 2014, 2015, the respective contributors + All rights reserved. + + All other contributions: + Copyright(c) 2015, 2016 the respective contributors + All rights reserved. + + Caffe2 uses a copyright model similar to Caffe: each contributor holds + copyright over their contributions to Caffe2. The project versioning records + all such contribution and copyright details. If a contributor wants to further + mark their specific copyright on a particular contribution, they should + indicate their copyright solely in the commit message of the change when it is + committed. + + All rights reserved. + """ +] +SPDX-License-Identifier = "BSD-3-Clause" diff --git a/src/onnx_ir/_thirdparty/asciichartpy.py b/src/onnx_ir/_thirdparty/asciichartpy.py index 62e9764c..a4af661d 100644 --- a/src/onnx_ir/_thirdparty/asciichartpy.py +++ b/src/onnx_ir/_thirdparty/asciichartpy.py @@ -1,6 +1,3 @@ -# Copyright (c) ONNX Project Contributors -# SPDX-License-Identifier: Apache-2.0 -# # Copyright © 2016 Igor Kroitor # # MIT License From 5e23a04c684351ab41255fcff9b49ab74d662cce Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 11:56:06 -0700 Subject: [PATCH 12/16] reuse Signed-off-by: Justin Chu --- REUSE.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/REUSE.toml b/REUSE.toml index fe5a1c4a..521e4c37 100644 --- a/REUSE.toml +++ b/REUSE.toml @@ -18,6 +18,7 @@ path = [ "**/*.yml", "CODEOWNERS", "MANIFEST.in", + "requirements*.txt", "requirements/**/*.txt", ] precedence = "aggregate" From c99a9a39f961b3f36bd8787e13d55acc73827f0b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 12:07:29 -0700 Subject: [PATCH 13/16] mit Signed-off-by: Justin Chu --- LICENSES/mit.txt | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSES/mit.txt diff --git a/LICENSES/mit.txt b/LICENSES/mit.txt new file mode 100644 index 00000000..8aa26455 --- /dev/null +++ b/LICENSES/mit.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) [year] [fullname] + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. From 1b4cdc3083ce2d5c253e3755fd4400b6030466e7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 12:34:54 -0700 Subject: [PATCH 14/16] rename Signed-off-by: Justin Chu --- LICENSES/{mit.txt => MIT.txt} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename LICENSES/{mit.txt => MIT.txt} (100%) diff --git a/LICENSES/mit.txt b/LICENSES/MIT.txt similarity index 100% rename from LICENSES/mit.txt rename to LICENSES/MIT.txt From 991e3935589248186c953e23ad4cae912b634e1e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 12:40:42 -0700 Subject: [PATCH 15/16] Build docs Signed-off-by: Justin Chu --- docs/conf.py | 1 - docs/index.md | 2 +- src/onnx_ir/__init__.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index ca3f2c5c..56cc294e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,7 +24,6 @@ "myst_nb", "sphinx_copybutton", "sphinx_exec_code", - "sphinx_gallery.gen_gallery", "sphinx.ext.autodoc", "sphinx.ext.autosummary", "sphinx.ext.doctest", diff --git a/docs/index.md b/docs/index.md index 035070f4..ce165558 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,5 +20,5 @@ An in-memory IR that supports the full ONNX spec, designed for graph constructio Overview getting_started tensors -ir_api/index +api/index ``` diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index 2571822e..178b68cd 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -149,6 +149,7 @@ ) from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto +__version__ = "0.0.1" DEBUG = False From 259103bb28b7e90692b755e04035695c05af4358 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 15 May 2025 13:15:51 -0700 Subject: [PATCH 16/16] Update version Signed-off-by: Justin Chu --- README.md | 15 +++++++++------ pyproject.toml | 6 +++++- src/onnx_ir/__init__.py | 2 +- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index dae5c09a..0ceb592b 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,8 @@ # ONNX IR +[![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) + An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation. ## Features ✨ @@ -14,9 +17,9 @@ An in-memory IR that supports the full ONNX spec, designed for graph constructio ## Code Organization 🗺️ -- [`_protocols.py`](_protocols.py): Interfaces defined for all entities in the IR. -- [`_core.py`](_core.py): Implementation of the core entities in the IR, including `Model`, `Graph`, `Node`, `Value`, and others. -- [`_enums.py`](_enums.py): Definition of the type enums that correspond to the `DataType` and `AttributeType` in `onnx.proto`. -- [`_name_authority.py`](_name_authority.py): The authority for giving names to entities in the graph, used internally. -- [`_linked_list.py`](_linked_list.py): The data structure as the node container in the graph that supports robust iteration and mutation. Internal. -- [`_metadata.py`](_metadata.py): Metadata store for all entities in the IR. +- [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR. +- [`_core.py`](src/onnx_ir/_core.py): Implementation of the core entities in the IR, including `Model`, `Graph`, `Node`, `Value`, and others. +- [`_enums.py`](src/onnx_ir/_enums.py): Definition of the type enums that correspond to the `DataType` and `AttributeType` in `onnx.proto`. +- [`_name_authority.py`](src/onnx_ir/_name_authority.py): The authority for giving names to entities in the graph, used internally. +- [`_linked_list.py`](src/onnx_ir/_linked_list.py): The data structure as the node container in the graph that supports robust iteration and mutation. Internal. +- [`_metadata.py`](src/onnx_ir/_metadata.py): Metadata store for all entities in the IR. diff --git a/pyproject.toml b/pyproject.toml index 537a72e0..10bc05a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "onnx-ir" -version = "0.0.1" +dynamic = ["version"] authors = [ {name = "ONNX Contributors", email = "onnx-technical-discuss@lists.lfaidata.foundation"}, ] @@ -20,8 +20,12 @@ dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"] [project.urls] Homepage = "https://onnx.ai/onnx-ir" +Issues = "https://github.com/onnx/onnx-ir/issues" Repository = "https://github.com/onnx/onnx-ir" +[tool.setuptools.dynamic] +version = {attr = "onnx_ir.__version__"} + [tool.pytest.ini_options] addopts = "--tb=short --color=yes" diff --git a/src/onnx_ir/__init__.py b/src/onnx_ir/__init__.py index 178b68cd..e0a1ff77 100644 --- a/src/onnx_ir/__init__.py +++ b/src/onnx_ir/__init__.py @@ -149,7 +149,7 @@ ) from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto -__version__ = "0.0.1" +__version__ = "0.1.0" DEBUG = False