44
44
import numpy as np
45
45
from typing_extensions import TypeIs
46
46
47
- import onnxscript
47
+ import onnx_ir
48
48
from onnx_ir import (
49
49
_display ,
50
50
_enums ,
@@ -186,7 +186,7 @@ def display(self, *, page: bool = False) -> None:
186
186
187
187
status_manager = rich .status .Status (f"Computing tensor stats for { self !r} " )
188
188
189
- from onnxscript ._thirdparty import ( # pylint: disable=import-outside-toplevel
189
+ from onnx_ir ._thirdparty import ( # pylint: disable=import-outside-toplevel
190
190
asciichartpy ,
191
191
)
192
192
@@ -582,7 +582,7 @@ def __init__(
582
582
# NOTE: Do not verify the location by default. This is because the location field
583
583
# in the tensor proto can be anything and we would like deserialization from
584
584
# proto to IR to not fail.
585
- if onnxscript .DEBUG :
585
+ if onnx_ir .DEBUG :
586
586
if os .path .isabs (location ):
587
587
raise ValueError (
588
588
"The location must be a relative path. Please specify base_dir as well."
@@ -2052,7 +2052,7 @@ def const_value(
2052
2052
self ,
2053
2053
value : _protocols .TensorProtocol | None ,
2054
2054
) -> None :
2055
- if onnxscript .DEBUG :
2055
+ if onnx_ir .DEBUG :
2056
2056
if value is not None and not isinstance (value , _protocols .TensorProtocol ):
2057
2057
raise TypeError (
2058
2058
f"Expected value to be a TensorProtocol or None, got '{ type (value )} '"
@@ -2469,7 +2469,7 @@ def sort(self) -> None:
2469
2469
ValueError: If the graph contains a cycle, making topological sorting impossible.
2470
2470
"""
2471
2471
# Obtain all nodes from the graph and its subgraphs for sorting
2472
- nodes = list (onnxscript . ir .traversal .RecursiveGraphIterator (self ))
2472
+ nodes = list (onnx_ir .traversal .RecursiveGraphIterator (self ))
2473
2473
# Store the sorted nodes of each subgraph
2474
2474
sorted_nodes_by_graph : dict [Graph , list [Node ]] = {
2475
2475
graph : [] for graph in {node .graph for node in nodes if node .graph is not None }
@@ -2858,7 +2858,7 @@ def graphs(self) -> Iterable[Graph]:
2858
2858
"""Get all graphs and subgraphs in the model.
2859
2859
2860
2860
This is a convenience method to traverse the model. Consider using
2861
- `onnxscript.ir .traversal.RecursiveGraphIterator` for more advanced
2861
+ `onnx_ir .traversal.RecursiveGraphIterator` for more advanced
2862
2862
traversals on nodes.
2863
2863
"""
2864
2864
# NOTE(justinchuby): Given
@@ -2868,7 +2868,7 @@ def graphs(self) -> Iterable[Graph]:
2868
2868
# I created this method as a core method instead of an iterator in
2869
2869
# `traversal.py`.
2870
2870
seen_graphs : set [Graph ] = set ()
2871
- for node in onnxscript . ir .traversal .RecursiveGraphIterator (self .graph ):
2871
+ for node in onnx_ir .traversal .RecursiveGraphIterator (self .graph ):
2872
2872
if node .graph is not None and node .graph not in seen_graphs :
2873
2873
seen_graphs .add (node .graph )
2874
2874
yield node .graph
@@ -3226,7 +3226,7 @@ def as_strings(self) -> Sequence[str]:
3226
3226
"""Get the attribute value as a sequence of strings."""
3227
3227
if not isinstance (self .value , Sequence ):
3228
3228
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3229
- if onnxscript .DEBUG :
3229
+ if onnx_ir .DEBUG :
3230
3230
if not all (isinstance (x , str ) for x in self .value ):
3231
3231
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence of strings." )
3232
3232
# Create a copy of the list to prevent mutation
@@ -3236,7 +3236,7 @@ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
3236
3236
"""Get the attribute value as a sequence of tensors."""
3237
3237
if not isinstance (self .value , Sequence ):
3238
3238
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3239
- if onnxscript .DEBUG :
3239
+ if onnx_ir .DEBUG :
3240
3240
if not all (isinstance (x , _protocols .TensorProtocol ) for x in self .value ):
3241
3241
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence of tensors." )
3242
3242
# Create a copy of the list to prevent mutation
@@ -3246,7 +3246,7 @@ def as_graphs(self) -> Sequence[Graph]:
3246
3246
"""Get the attribute value as a sequence of graphs."""
3247
3247
if not isinstance (self .value , Sequence ):
3248
3248
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3249
- if onnxscript .DEBUG :
3249
+ if onnx_ir .DEBUG :
3250
3250
if not all (isinstance (x , Graph ) for x in self .value ):
3251
3251
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence of graphs." )
3252
3252
# Create a copy of the list to prevent mutation
0 commit comments