8 changes: 2 additions & 6 deletions .github/workflows/lint.yml
@@ -1,7 +1,3 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

name: Lint

on:
@@ -57,7 +53,7 @@ jobs:
# Install dependencies
python -m pip install --upgrade pip
python -m pip install --upgrade setuptools
python -m pip install --upgrade lintrunner lintrunner-adapters
python -m pip install -r requirements-dev.txt
# Install packages
python -m pip install -e .
lintrunner init
@@ -67,7 +63,7 @@ jobs:
if ! lintrunner --force-color --all-files --tee-json=lint.json -v; then
echo ""
echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m"
echo -e "\e[1m\e[36mSee https://github.com/microsoft/onnxscript#coding-style for setup instructions.\e[0m"
echo -e "\e[1m\e[36mSee https://github.com/onnx/onnx_ir/blob/main/CONTRIBUTING.md for setup instructions.\e[0m"
exit 1
fi
- name: Produce SARIF
4 changes: 0 additions & 4 deletions .github/workflows/scorecard.yml
@@ -1,7 +1,3 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

# This workflow uses actions that are not certified by GitHub. They are provided
# by a third-party and are governed by separate terms of service, privacy
# policy, and support documentation.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
global-exclude *_test.py
2 changes: 1 addition & 1 deletion REUSE.toml
@@ -1,5 +1,4 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

version = 1
@@ -18,6 +17,7 @@ path = [
"**/*.toml",
"**/*.yml",
"CODEOWNERS",
"MANIFEST.in",
"requirements/**/*.txt",
]
precedence = "aggregate"
34 changes: 34 additions & 0 deletions requirements-dev.txt
@@ -0,0 +1,34 @@
setuptools>=70.0.0
onnxruntime>=1.17.0
rich>=13.7.1

# Docs site
furo
jax[cpu]
matplotlib
myst-parser[linkify]
sphinx-copybutton
sphinx-exec-code
sphinx-gallery
sphinx>=6
myst_nb
chardet

# Testing
expecttest==0.1.6
hypothesis
packaging
parameterized
pytest-cov
pytest-randomly
pytest-subtests
pytest-xdist
pytest!=7.1.0
pyyaml
torch>=2.3
torchvision>=0.18.0
transformers>=4.37.2

# Lint
lintrunner>=0.10.7
lintrunner_adapters>=0.12.0
26 changes: 16 additions & 10 deletions src/onnx_ir/__init__.py
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""In-memory intermediate representation for ONNX graphs."""

__all__ = [
@@ -81,11 +81,13 @@
# IO
"load",
"save",
# Flags
"DEBUG",
]

from onnxscript.ir import convenience, external_data, passes, serde, tape, traversal
from onnxscript.ir._convenience._constructors import node, tensor
from onnxscript.ir._core import (
from onnx_ir import convenience, external_data, passes, serde, tape, traversal
from onnx_ir._convenience._constructors import node, tensor
from onnx_ir._core import (
Attr,
AttrFloat32,
AttrFloat32s,
@@ -121,12 +123,12 @@
TypeAndShape,
Value,
)
from onnxscript.ir._enums import (
from onnx_ir._enums import (
AttributeType,
DataType,
)
from onnxscript.ir._io import load, save
from onnxscript.ir._protocols import (
from onnx_ir._io import load, save
from onnx_ir._protocols import (
ArrayCompatible,
AttributeProtocol,
DLPackCompatible,
@@ -145,14 +147,18 @@
TypeProtocol,
ValueProtocol,
)
from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto

DEBUG = False


def __set_module() -> None:
"""Set the module of all functions in this module to this public module."""
global_dict = globals()
for name in __all__:
global_dict[name].__module__ = __name__
if hasattr(global_dict[name], "__module__"):
# Set the module of the function to this module
global_dict[name].__module__ = __name__


__set_module()
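
A note on the new hasattr guard in __set_module: DEBUG is a plain bool added to __all__, and unlike the functions and classes re-exported here it does not expose a writable __module__, so the guard lets the loop skip it instead of raising. A minimal sketch of that behavior (the names below are illustrative, not part of the package)::

    def load():  # stand-in for a public function re-exported from a private module
        ...

    DEBUG = False  # a plain value exported alongside the functions and classes

    for obj in (load, DEBUG):
        if hasattr(obj, "__module__"):
            # Works for functions and classes, which carry a writable __module__
            obj.__module__ = "onnx_ir"
        # Plain values such as DEBUG are skipped, avoiding an AttributeError
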
11 changes: 5 additions & 6 deletions src/onnx_ir/_convenience/__init__.py
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Convenience methods for constructing and manipulating the IR.

This is an internal only module. We should choose to expose some of the methods
@@ -20,7 +20,7 @@

import onnx

from onnxscript.ir import _core, _enums, _protocols, serde
from onnx_ir import _core, _enums, _protocols, serde

SupportedAttrTypes = Union[
str,
@@ -188,7 +188,7 @@ def convert_attributes(
types are: int, float, str, Sequence[int], Sequence[float], Sequence[str],
:class:`_core.Tensor`, and :class:`_core.Attr`::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> import onnx
>>> import numpy as np
>>> attrs = {
@@ -269,7 +269,7 @@ def replace_all_uses_with(

We want to replace the node A with a new node D::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> input = ir.Input("input")
>>> node_a = ir.Node("", "A", [input])
>>> node_b = ir.Node("", "B", node_a.outputs)
@@ -355,7 +355,6 @@ def replace_nodes_and_values(
old_values: The values to replace.
new_values: The values to replace with.
"""

for old_value, new_value in zip(old_values, new_values):
# Propagate relevant info from old value to new value
# TODO(Rama): Perhaps this should be a separate utility function. Also, consider
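
For reference, the convert_attributes docstring above now uses the onnx_ir import directly. A hedged usage sketch (the attribute values and the returned list of Attr objects are assumptions, not taken from this diff)::

    from onnx_ir._convenience import convert_attributes  # internal module; API may change

    attrs = {"axis": 0, "alpha": 1.0, "mode": "constant", "pads": [0, 0, 1, 1]}
    ir_attrs = convert_attributes(attrs)  # expected to return a list of Attr objects (assumed)
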
12 changes: 6 additions & 6 deletions src/onnx_ir/_convenience/_constructors.py
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Convenience constructors for IR objects."""

from __future__ import annotations
@@ -15,12 +15,12 @@
import numpy as np
import onnx

from onnxscript.ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters

if typing.TYPE_CHECKING:
import numpy.typing as npt

from onnxscript import ir
import onnx_ir as ir


def tensor(
Expand All @@ -39,7 +39,7 @@ def tensor(

Example::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> import numpy as np
>>> import ml_dtypes
>>> import onnx
@@ -162,7 +162,7 @@ def node(

Example::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
>>> input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
>>> node = ir.node(
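
The constructor docstrings above now import onnx_ir directly as well. A minimal combined sketch of ir.node and ir.tensor (the keyword arguments are assumptions based on the excerpts shown here)::

    import numpy as np
    import onnx_ir as ir

    input_a = ir.Input("A", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
    input_b = ir.Input("B", shape=ir.Shape([1, 2]), type=ir.TensorType(ir.DataType.INT32))
    add = ir.node("Add", inputs=[input_a, input_b])  # op type first, then input values (assumed)
    weights = ir.tensor(np.array([[1, 2, 3]], dtype=np.int32), name="weights")  # assumed signature
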
8 changes: 4 additions & 4 deletions src/onnx_ir/_convenience/_constructors_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for the _constructors module."""

import unittest

import numpy as np

from onnxscript import ir
from onnxscript.ir._convenience import _constructors
import onnx_ir as ir
from onnx_ir._convenience import _constructors


class ConstructorsTest(unittest.TestCase):
35 changes: 17 additions & 18 deletions src/onnx_ir/_core.py
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""data structures for the intermediate representation."""

# NOTES for developers:
@@ -44,8 +44,8 @@
import numpy as np
from typing_extensions import TypeIs

import onnxscript
from onnxscript.ir import (
import onnx_ir
from onnx_ir import (
_display,
_enums,
_graph_containers,
@@ -186,7 +186,7 @@ def display(self, *, page: bool = False) -> None:

status_manager = rich.status.Status(f"Computing tensor stats for {self!r}")

from onnxscript._thirdparty import ( # pylint: disable=import-outside-toplevel
from onnx_ir._thirdparty import ( # pylint: disable=import-outside-toplevel
asciichartpy,
)

Expand Down Expand Up @@ -582,7 +582,7 @@ def __init__(
# NOTE: Do not verify the location by default. This is because the location field
# in the tensor proto can be anything and we would like deserialization from
# proto to IR to not fail.
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if os.path.isabs(location):
raise ValueError(
"The location must be a relative path. Please specify base_dir as well."
@@ -852,7 +852,7 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
Example::

>>> import numpy as np
>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> weights = np.array([[1, 2, 3]])
>>> def create_tensor(): # Delay applying transformations to the weights
... weights_t = weights.transpose()
@@ -1039,7 +1039,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):

Example::

>>> from onnxscript import ir
>>> import onnx_ir as ir
>>> shape = ir.Shape(["B", None, 3])
>>> shape.rank()
3
@@ -1279,7 +1279,7 @@ def _short_tensor_str_for_node(x: Value) -> str:


def _normalize_domain(domain: str) -> str:
"""Normalize 'ai.onnx' to ''"""
"""Normalize 'ai.onnx' to ''."""
return "" if domain == "ai.onnx" else domain


@@ -1709,7 +1709,7 @@ def dtype(self, value: _enums.DataType) -> None:

@property
def elem_type(self) -> _enums.DataType:
"""Return the element type of the tensor type"""
"""Return the element type of the tensor type."""
return self.dtype

def __hash__(self) -> int:
@@ -2052,7 +2052,7 @@ def const_value(
self,
value: _protocols.TensorProtocol | None,
) -> None:
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if value is not None and not isinstance(value, _protocols.TensorProtocol):
raise TypeError(
f"Expected value to be a TensorProtocol or None, got '{type(value)}'"
@@ -2099,7 +2099,6 @@ def Input(

This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
"""

# NOTE: The function name is capitalized to maintain API backward compatibility.

return Value(name=name, shape=shape, type=type, doc_string=doc_string)
@@ -2469,7 +2468,7 @@ def sort(self) -> None:
ValueError: If the graph contains a cycle, making topological sorting impossible.
"""
# Obtain all nodes from the graph and its subgraphs for sorting
nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
nodes = list(onnx_ir.traversal.RecursiveGraphIterator(self))
# Store the sorted nodes of each subgraph
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
@@ -2858,7 +2857,7 @@ def graphs(self) -> Iterable[Graph]:
"""Get all graphs and subgraphs in the model.

This is a convenience method to traverse the model. Consider using
`onnxscript.ir.traversal.RecursiveGraphIterator` for more advanced
`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
traversals on nodes.
"""
# NOTE(justinchuby): Given
@@ -2868,7 +2867,7 @@ def graphs(self) -> Iterable[Graph]:
# I created this method as a core method instead of an iterator in
# `traversal.py`.
seen_graphs: set[Graph] = set()
for node in onnxscript.ir.traversal.RecursiveGraphIterator(self.graph):
for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
if node.graph is not None and node.graph not in seen_graphs:
seen_graphs.add(node.graph)
yield node.graph
@@ -3226,7 +3225,7 @@ def as_strings(self) -> Sequence[str]:
"""Get the attribute value as a sequence of strings."""
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if not all(isinstance(x, str) for x in self.value):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
# Create a copy of the list to prevent mutation
@@ -3236,7 +3235,7 @@ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
"""Get the attribute value as a sequence of tensors."""
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
# Create a copy of the list to prevent mutation
@@ -3246,7 +3245,7 @@ def as_graphs(self) -> Sequence[Graph]:
"""Get the attribute value as a sequence of graphs."""
if not isinstance(self.value, Sequence):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
if onnxscript.DEBUG:
if onnx_ir.DEBUG:
if not all(isinstance(x, Graph) for x in self.value):
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
# Create a copy of the list to prevent mutation
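
The onnx_ir.DEBUG flag referenced throughout this file gates the optional runtime validation shown above (external-data locations, const_value types, attribute sequences). A hedged sketch of turning it on (the Value constructor call is an assumption)::

    import onnx_ir as ir

    ir.DEBUG = True  # opt in to the extra runtime checks
    value = ir.Value(name="weight")
    value.const_value = "not a tensor"  # with DEBUG enabled, the setter above raises TypeError
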