Skip to content

Commit 859a2c4

Browse files
committed
More fixes
Signed-off-by: Justin Chu <[email protected]>
1 parent a792f1c commit 859a2c4

32 files changed

+63
-822
lines changed

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ jobs:
6363
if ! lintrunner --force-color --all-files --tee-json=lint.json -v; then
6464
echo ""
6565
echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner\`.\e[0m"
66-
echo -e "\e[1m\e[36mSee https://github.com/microsoft/onnxscript#coding-style for setup instructions.\e[0m"
66+
echo -e "\e[1m\e[36mSee https://github.com/onnx/onnx_ir/blob/main/CONTRIBUTING.md for setup instructions.\e[0m"
6767
exit 1
6868
fi
6969
- name: Produce SARIF

src/onnx_ir/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@
8484
]
8585

8686
from onnx_ir import convenience, external_data, passes, serde, tape, traversal
87-
from onnxscript.ir._convenience._constructors import node, tensor
88-
from onnxscript.ir._core import (
87+
from onnx_ir._convenience._constructors import node, tensor
88+
from onnx_ir._core import (
8989
Attr,
9090
AttrFloat32,
9191
AttrFloat32s,
@@ -121,12 +121,12 @@
121121
TypeAndShape,
122122
Value,
123123
)
124-
from onnxscript.ir._enums import (
124+
from onnx_ir._enums import (
125125
AttributeType,
126126
DataType,
127127
)
128-
from onnxscript.ir._io import load, save
129-
from onnxscript.ir._protocols import (
128+
from onnx_ir._io import load, save
129+
from onnx_ir._protocols import (
130130
ArrayCompatible,
131131
AttributeProtocol,
132132
DLPackCompatible,
@@ -145,7 +145,7 @@
145145
TypeProtocol,
146146
ValueProtocol,
147147
)
148-
from onnxscript.ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
148+
from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
149149

150150

151151
def __set_module() -> None:

src/onnx_ir/_convenience/_constructors_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
import onnx_ir as ir
10-
from onnxscript.ir._convenience import _constructors
10+
from onnx_ir._convenience import _constructors
1111

1212

1313
class ConstructorsTest(unittest.TestCase):

src/onnx_ir/_core.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
import numpy as np
4545
from typing_extensions import TypeIs
4646

47-
import onnxscript
47+
import onnx_ir
4848
from onnx_ir import (
4949
_display,
5050
_enums,
@@ -186,7 +186,7 @@ def display(self, *, page: bool = False) -> None:
186186

187187
status_manager = rich.status.Status(f"Computing tensor stats for {self!r}")
188188

189-
from onnxscript._thirdparty import ( # pylint: disable=import-outside-toplevel
189+
from onnx_ir._thirdparty import ( # pylint: disable=import-outside-toplevel
190190
asciichartpy,
191191
)
192192

@@ -582,7 +582,7 @@ def __init__(
582582
# NOTE: Do not verify the location by default. This is because the location field
583583
# in the tensor proto can be anything and we would like deserialization from
584584
# proto to IR to not fail.
585-
if onnxscript.DEBUG:
585+
if onnx_ir.DEBUG:
586586
if os.path.isabs(location):
587587
raise ValueError(
588588
"The location must be a relative path. Please specify base_dir as well."
@@ -2052,7 +2052,7 @@ def const_value(
20522052
self,
20532053
value: _protocols.TensorProtocol | None,
20542054
) -> None:
2055-
if onnxscript.DEBUG:
2055+
if onnx_ir.DEBUG:
20562056
if value is not None and not isinstance(value, _protocols.TensorProtocol):
20572057
raise TypeError(
20582058
f"Expected value to be a TensorProtocol or None, got '{type(value)}'"
@@ -2469,7 +2469,7 @@ def sort(self) -> None:
24692469
ValueError: If the graph contains a cycle, making topological sorting impossible.
24702470
"""
24712471
# Obtain all nodes from the graph and its subgraphs for sorting
2472-
nodes = list(onnxscript.ir.traversal.RecursiveGraphIterator(self))
2472+
nodes = list(onnx_ir.traversal.RecursiveGraphIterator(self))
24732473
# Store the sorted nodes of each subgraph
24742474
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
24752475
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
@@ -2858,7 +2858,7 @@ def graphs(self) -> Iterable[Graph]:
28582858
"""Get all graphs and subgraphs in the model.
28592859
28602860
This is a convenience method to traverse the model. Consider using
2861-
`onnxscript.ir.traversal.RecursiveGraphIterator` for more advanced
2861+
`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
28622862
traversals on nodes.
28632863
"""
28642864
# NOTE(justinchuby): Given
@@ -2868,7 +2868,7 @@ def graphs(self) -> Iterable[Graph]:
28682868
# I created this method as a core method instead of an iterator in
28692869
# `traversal.py`.
28702870
seen_graphs: set[Graph] = set()
2871-
for node in onnxscript.ir.traversal.RecursiveGraphIterator(self.graph):
2871+
for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
28722872
if node.graph is not None and node.graph not in seen_graphs:
28732873
seen_graphs.add(node.graph)
28742874
yield node.graph
@@ -3226,7 +3226,7 @@ def as_strings(self) -> Sequence[str]:
32263226
"""Get the attribute value as a sequence of strings."""
32273227
if not isinstance(self.value, Sequence):
32283228
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3229-
if onnxscript.DEBUG:
3229+
if onnx_ir.DEBUG:
32303230
if not all(isinstance(x, str) for x in self.value):
32313231
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
32323232
# Create a copy of the list to prevent mutation
@@ -3236,7 +3236,7 @@ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
32363236
"""Get the attribute value as a sequence of tensors."""
32373237
if not isinstance(self.value, Sequence):
32383238
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3239-
if onnxscript.DEBUG:
3239+
if onnx_ir.DEBUG:
32403240
if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
32413241
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
32423242
# Create a copy of the list to prevent mutation
@@ -3246,7 +3246,7 @@ def as_graphs(self) -> Sequence[Graph]:
32463246
"""Get the attribute value as a sequence of graphs."""
32473247
if not isinstance(self.value, Sequence):
32483248
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3249-
if onnxscript.DEBUG:
3249+
if onnx_ir.DEBUG:
32503250
if not all(isinstance(x, Graph) for x in self.value):
32513251
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
32523252
# Create a copy of the list to prevent mutation

src/onnx_ir/_display_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import numpy as np
99

10-
import onnxscript.ir as ir
10+
import onnx_ir as ir
1111

1212

1313
class DisplayTest(unittest.TestCase):

src/onnx_ir/_graph_containers.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
import collections
1515
from typing import TYPE_CHECKING, Iterable, SupportsIndex
1616

17-
import onnxscript
18-
1917
if TYPE_CHECKING:
2018
from onnx_ir import _core
2119

@@ -132,7 +130,7 @@ class GraphInputs(_GraphIO):
132130

133131
def _check_invariance(self) -> None:
134132
"""Check the invariance of the graph."""
135-
if not onnxscript.DEBUG:
133+
if not onnx_ir.DEBUG:
136134
return
137135
for value in self.data:
138136
if value._graph is self._graph:
@@ -170,7 +168,7 @@ class GraphOutputs(_GraphIO):
170168

171169
def _check_invariance(self) -> None:
172170
"""Check the invariance of the graph."""
173-
if not onnxscript.DEBUG:
171+
if not onnx_ir.DEBUG:
174172
return
175173
for value in self.data:
176174
if value._graph is self._graph:

src/onnx_ir/_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from onnx_ir import _core, serde
1414
from onnx_ir import external_data as _external_data
15-
from onnxscript.ir._polyfill import zip
15+
from onnx_ir._polyfill import zip
1616

1717

1818
def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:

0 commit comments

Comments
 (0)