Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

name: Lint

on:
Expand Down Expand Up @@ -67,7 +63,7 @@ jobs:
if ! lintrunner --force-color --all-files --tee-json=lint.json -v; then
echo ""
echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m"
echo -e "\e[1m\e[36mSee https://github.com/microsoft/onnxscript#coding-style for setup instructions.\e[0m"
echo -e "\e[1m\e[36mSee https://github.com/onnx/onnx_ir/blob/main/CONTRIBUTING.md for setup instructions.\e[0m"
exit 1
fi
- name: Produce SARIF
Expand Down
4 changes: 0 additions & 4 deletions .github/workflows/scorecard.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

# This workflow uses actions that are not certified by GitHub. They are provided
# by a third-party and are governed by separate terms of service, privacy
# policy, and support documentation.
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
global-exclude *_test.py
1 change: 0 additions & 1 deletion REUSE.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

version = 1
Expand Down
26 changes: 15 additions & 11 deletions src/onnx_ir/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""In-memory intermediate representation for ONNX graphs."""

__all__ = [
Expand Down Expand Up @@ -81,11 +81,13 @@
# IO
"load",
"save",
# Flags
"DEBUG",
]

from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal
from onnxscript.ir._convenience._constructors import node, tensor
from onnxscript.ir._core import (
from onnx_ir import convenience, external_data, passes, serde, tape, traversal
from onnx_ir._convenience._constructors import node, tensor
from onnx_ir._core import (
Attr,
AttrFloat32,
AttrFloat32s,
Expand Down Expand Up @@ -121,12 +123,12 @@
TypeAndShape,
Value,
)
from onnxscript.ir._enums import (
from onnx_ir._enums import (
AttributeType,
DataType,
)
from onnxscript.ir._io import load, save
from onnxscript.ir._protocols import (
from onnx_ir._io import load, save
from onnx_ir._protocols import (
ArrayCompatible,
AttributeProtocol,
DLPackCompatible,
Expand All @@ -145,14 +147,16 @@
TypeProtocol,
ValueProtocol,
)
from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto

DEBUG = False

def __set_module() -> None:
"""Set the module of all functions in this module to this public module."""
global_dict = globals()
for name in __all__:
global_dict[name].__module__ = __name__

if hasattr(global_dict[name], "__module__"):
# Set the module of the function to this module
global_dict[name].__module__ = __name__

__set_module()
11 changes: 5 additions & 6 deletions src/onnx_ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Convenience methods for constructing and manipulating the IR.

This is an internal only module. We should choose to expose some of the methods
Expand All @@ -20,7 +20,7 @@

import onnx

from onnxscript.ir import _core, _enums, _protocols, serde
from onnx_ir import _core, _enums, _protocols, serde

SupportedAttrTypes = Union[
str,
Expand Down Expand Up @@ -188,7 +188,7 @@ def convert_attributes(
types are: int, float, str, Sequence[int], Sequence[float], Sequence[str],
:class:`_core.Tensor`, and :class:`_core.Attr`::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> import onnx
>>> import numpy as np
>>> attrs = {
Expand Down Expand Up @@ -269,7 +269,7 @@ def replace_all_uses_with(

We want to replace the node A with a new node D::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> input = ir.Input("input")
>>> node_a = ir.Node("", "A", [input])
>>> node_b = ir.Node("", "B", node_a.outputs)
Expand Down Expand Up @@ -355,7 +355,6 @@ def replace_nodes_and_values(
old_values: The values to replace.
new_values: The values to replace with.
"""

for old_value, new_value in zip(old_values, new_values):
# Propagate relevant info from old value to new value
# TODO(Rama): Perhaps this should be a separate utility function. Also, consider
Expand Down
12 changes: 6 additions & 6 deletions src/onnx_ir/_convenience/_constructors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Convenience constructors for IR objects."""

from __future__ import annotations
Expand All @@ -15,12 +15,12 @@
import numpy as np
import onnx

from onnxscript.ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters

if typing.TYPE_CHECKING:
import numpy.typing as npt

from onnxscript import ir
import onnx_ir as ir


def tensor(
Expand All @@ -39,7 +39,7 @@ def tensor(

Example::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> import numpy as np
>>> import ml_dtypes
>>> import onnx
Expand Down Expand Up @@ -162,7 +162,7 @@ def node(

Example::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
>>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
>>> node = ir.node(
Expand Down
8 changes: 4 additions & 4 deletions src/onnx_ir/_convenience/_constructors_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for the _constructors module."""

import unittest

import numpy as np

from onnxscript import ir
from onnxscript.ir._convenience import _constructors
import onnx_ir as ir
from onnx_ir._convenience import _constructors


class ConstructorsTest(unittest.TestCase):
Expand Down
35 changes: 17 additions & 18 deletions src/onnx_ir/_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""data structures for the intermediate representation."""

# NOTES for developers:
Expand Down Expand Up @@ -44,8 +44,8 @@
import numpy as np
from typing_extensions import TypeIs

import onnxscript
from onnxscript.ir import (
import onnx_ir
from onnx_ir import (
_display,
_enums,
_graph_containers,
Expand Down Expand Up @@ -186,7 +186,7 @@ def display(self, *, page: bool = False) -> None:

status_manager = rich.status.Status(f"Computing tensor stats for {self!r}")

from onnxscript._thirdparty import ( # pylint: disable=import-outside-toplevel
from onnx_ir._thirdparty import ( # pylint: disable=import-outside-toplevel
asciichartpy,
)

Expand Down Expand Up @@ -582,7 +582,7 @@ def __init__(
# NOTE: Do not verify the location by default. This is because the location field
# in the tensor proto can be anything and we would like deserialization from
# proto to IR to not fail.
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if os.path.isabs(location):
raise ValueError(
"The location must be a relative path. Please specify base_dir as well."
Expand Down Expand Up @@ -852,7 +852,7 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
Example::

>>> import numpy as np
>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> weights = np.array([[1, 2, 3]])
>>> def create_tensor(): # Delay applying transformations to the weights
... weights_t = weights.transpose()
Expand Down Expand Up @@ -1039,7 +1039,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):

Example::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> shape = ir.Shape(["B", None, 3])
>>> shape.rank()
3
Expand Down Expand Up @@ -1279,7 +1279,7 @@ def _short_tensor_str_for_node(x: Value) -> str:


def _normalize_domain(domain: str) -> str:
"""Normalize 'ai.onnx' to ''"""
"""Normalize 'ai.onnx' to ''."""
return "" if domain == "ai.onnx" else domain


Expand Down Expand Up @@ -1709,7 +1709,7 @@ def dtype(self, value: _enums.DataType) -> None:

@property
def elem_type(self) -> _enums.DataType:
"""Return the element type of the tensor type"""
"""Return the element type of the tensor type."""
return self.dtype

def __hash__(self) -> int:
Expand Down Expand Up @@ -2052,7 +2052,7 @@ def const_value(
self,
value: _protocols.TensorProtocol | None,
) -> None:
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if value is not None and not isinstance(value, _protocols.TensorProtocol):
raise TypeError(
f"Expected value to be a TensorProtocol or None, got '{type(value)}'"
Expand Down Expand Up @@ -2099,7 +2099,6 @@ def Input(

This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
"""

# NOTE: The function name is capitalized to maintain API backward compatibility.

return Value(name=name, shape=shape, type=type, doc_string=doc_string)
Expand Down Expand Up @@ -2469,7 +2468,7 @@ def sort(self) -> None:
ValueError: If the graph contains a cycle, making topological sorting impossible.
"""
# Obtain all nodes from the graph and its subgraphs for sorting
nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
nodes = list(onnx_ir.traversal.RecursiveGraphIterator(self))
# Store the sorted nodes of each subgraph
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
Expand Down Expand Up @@ -2858,7 +2857,7 @@ def graphs(self) -> Iterable[Graph]:
"""Get all graphs and subgraphs in the model.

This is a convenience method to traverse the model. Consider using
`onnxscript.ir.traversal.RecursiveGraphIterator` for more advanced
`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
traversals on nodes.
"""
# NOTE(justinchuby): Given
Expand All @@ -2868,7 +2867,7 @@ def graphs(self) -> Iterable[Graph]:
# I created this method as a core method instead of an iterator in
# `traversal.py`.
seen_graphs: set[Graph] = set()
for node in onnxscript.ir.traversal.RecursiveGraphIterator(self.graph):
for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
if node.graph is not None and node.graph not in seen_graphs:
seen_graphs.add(node.graph)
yield node.graph
Expand Down Expand Up @@ -3226,7 +3225,7 @@ def as_strings(self) -> Sequence[str]:
"""Get the attribute value as a sequence of strings."""
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if not all(isinstance(x, str) for x in self.value):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
# Create a copy of the list to prevent mutation
Expand All @@ -3236,7 +3235,7 @@ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
"""Get the attribute value as a sequence of tensors."""
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
# Create a copy of the list to prevent mutation
Expand All @@ -3246,7 +3245,7 @@ def as_graphs(self) -> Sequence[Graph]:
"""Get the attribute value as a sequence of graphs."""
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if not all(isinstance(x, Graph) for x in self.value):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
# Create a copy of the list to prevent mutation
Expand Down
8 changes: 4 additions & 4 deletions src/onnx_ir/_core_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import copy
Expand All @@ -15,8 +15,8 @@
import parameterized
import torch

from onnxscript import ir
from onnxscript.ir import _core
import onnx_ir as ir
from onnx_ir import _core


class TensorTest(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions src/onnx_ir/_display.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Internal utilities for displaying the intermediate representation of a model.

NOTE: All third-party imports should be scoped and imported only when used to avoid
Expand Down
6 changes: 3 additions & 3 deletions src/onnx_ir/_display_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Test display() methods in various classes."""

import contextlib
import unittest

import numpy as np

import onnxscript.ir as ir
import onnx_ir as ir


class DisplayTest(unittest.TestCase):
Expand Down
4 changes: 2 additions & 2 deletions src/onnx_ir/_enums.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""ONNX IR enums that matches the ONNX spec."""

from __future__ import annotations
Expand Down
Loading
Loading