Skip to content

Commit 3ec5479

Browse files
committed
Fix
Signed-off-by: Justin Chu <[email protected]>
1 parent 859a2c4 commit 3ec5479

File tree

12 files changed

+20
-19
lines changed

12 files changed

+20
-19
lines changed

src/onnx_ir/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181
# IO
8282
"load",
8383
"save",
84+
# Flags
85+
"DEBUG",
8486
]
8587

8688
from onnx_ir import convenience, external_data, passes, serde, tape, traversal
@@ -147,12 +149,14 @@
147149
)
148150
from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_proto
149151

152+
DEBUG = False
150153

151154
def __set_module() -> None:
152155
"""Set the module of all functions in this module to this public module."""
153156
global_dict = globals()
154157
for name in __all__:
155-
global_dict[name].__module__ = __name__
156-
158+
if hasattr(global_dict[name], "__module__"):
159+
# Set the module of the function to this module
160+
global_dict[name].__module__ = __name__
157161

158162
__set_module()

src/onnx_ir/_convenience/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,6 @@ def replace_nodes_and_values(
355355
old_values: The values to replace.
356356
new_values: The values to replace with.
357357
"""
358-
359358
for old_value, new_value in zip(old_values, new_values):
360359
# Propagate relevant info from old value to new value
361360
# TODO(Rama): Perhaps this should be a separate utility function. Also, consider

src/onnx_ir/_core.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,7 +1279,7 @@ def _short_tensor_str_for_node(x: Value) -> str:
12791279

12801280

12811281
def _normalize_domain(domain: str) -> str:
1282-
"""Normalize 'ai.onnx' to ''"""
1282+
"""Normalize 'ai.onnx' to ''."""
12831283
return "" if domain == "ai.onnx" else domain
12841284

12851285

@@ -1709,7 +1709,7 @@ def dtype(self, value: _enums.DataType) -> None:
17091709

17101710
@property
17111711
def elem_type(self) -> _enums.DataType:
1712-
"""Return the element type of the tensor type"""
1712+
"""Return the element type of the tensor type."""
17131713
return self.dtype
17141714

17151715
def __hash__(self) -> int:
@@ -2099,7 +2099,6 @@ def Input(
20992099
21002100
This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
21012101
"""
2102-
21032102
# NOTE: The function name is capitalized to maintain API backward compatibility.
21042103

21052104
return Value(name=name, shape=shape, type=type, doc_string=doc_string)

src/onnx_ir/_graph_containers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import collections
1515
from typing import TYPE_CHECKING, Iterable, SupportsIndex
1616

17+
import onnx_ir
18+
1719
if TYPE_CHECKING:
1820
from onnx_ir import _core
1921

src/onnx_ir/external_data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@
3535

3636
@dataclasses.dataclass
3737
class _ExternalDataInfo:
38-
"""
39-
A class that stores information about a tensor that is to be stored as external data.
38+
"""A class that stores information about a tensor that is to be stored as external data.
4039
4140
Attributes:
4241
name: The name of the tensor that is to be stored as external data.

src/onnx_ir/passes/_pass_infra.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ class PassResult:
7171
class PassBase(abc.ABC):
7272
"""Base class for all passes.
7373
74-
7574
``in_place`` and ``changes_input`` properties and what they mean:
7675
7776
+------------+------------------+----------------------------+

src/onnx_ir/passes/common/_c_api_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
3434
Returns:
3535
The resulting ModelProto that contains the result of the API call.
3636
"""
37-
3837
# Store the original initializer values so they can be restored
3938
initializer_values = tuple(model.graph.initializers.values())
4039
tensors = {v.name: v.const_value for v in initializer_values}

src/onnx_ir/passes/common/inliner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) ONNX Project Contributors
22
# SPDX-License-Identifier: Apache-2.0
3-
"""Implementation of an inliner for onnx_ir"""
3+
"""Implementation of an inliner for onnx_ir."""
44

55
from __future__ import annotations
66

@@ -11,8 +11,8 @@
1111
from collections import defaultdict
1212
from typing import Iterable, List, Sequence, Tuple
1313

14-
import onnx_ir.convenience as _ir_convenience
1514
import onnx_ir as ir
15+
import onnx_ir.convenience as _ir_convenience
1616

1717
# A replacement for a node specifies a list of nodes that replaces the original node,
1818
# and a list of values that replaces the original node's outputs.

src/onnx_ir/serde_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
import parameterized
1010

1111
import onnx_ir as ir
12-
from onnx_ir import _version_utils
13-
from onnx_ir import serde
12+
from onnx_ir import _version_utils, serde
1413

1514

1615
class ConvenienceFunctionsTest(unittest.TestCase):

tests/ir/public_api_test.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# Copyright (c) ONNX Project Contributors
2-
# SPDX-License-Identifier: Apache-2.0
31
# Adapted from
42
# https://github.com/pytorch/pytorch/blob/b505e8647547f029d0f7df408ee5f2968f757f89/test/test_public_bindings.py#L523
53
# Original code PyTorch license https://github.com/pytorch/pytorch/blob/main/LICENSE
@@ -21,6 +19,7 @@
2119

2220
def _find_all_importables(pkg):
2321
"""Find all importables in the project.
22+
2423
Return them in order.
2524
"""
2625
return sorted(
@@ -34,6 +33,7 @@ def _find_all_importables(pkg):
3433

3534
def _discover_path_importables(pkg_path: os.PathLike, pkg_name: str) -> Iterable[str]:
3635
"""Yield all importables under a given path and package.
36+
3737
This is like pkgutil.walk_packages, but does *not* skip over namespace
3838
packages. Taken from https://stackoverflow.com/questions/41203765/init-py-required-for-pkgutil-walk-packages-in-python3
3939
"""
@@ -155,7 +155,8 @@ class TestPublicApiNamespace(unittest.TestCase):
155155
tested_modules = (IR_NAMESPACE, *(_find_all_importables(onnx_ir)))
156156

157157
def test_correct_module_names(self):
158-
"""
158+
"""Test module names are correct.
159+
159160
An API is considered public, if its `__module__` starts with `onnx_ir`
160161
and there is no name in `__module__` or the object itself that starts with "_".
161162
Each public package should either:

0 commit comments

Comments
 (0)