Skip to content

Commit ebaa562

Browse files
committed
Add testing
Signed-off-by: Justin Chu <[email protected]>
1 parent 259aef7 commit ebaa562

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed

src/onnx_ir/testing.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,196 @@
11
# Copyright (c) ONNX Project Contributors
22
# SPDX-License-Identifier: Apache-2.0
33
"""Utilities for testing."""
4+
5+
from __future__ import annotations
6+
7+
__all__ = [
8+
"assert_onnx_proto_equal",
9+
]
10+
11+
import difflib
12+
import math
13+
from typing import Any, Collection, Sequence
14+
15+
import google.protobuf.message
16+
import onnx
17+
18+
19+
def _opset_import_key(opset_import: onnx.OperatorSetIdProto) -> tuple[str, int]:
20+
return (opset_import.domain, opset_import.version)
21+
22+
23+
def _value_info_key(value_info: onnx.ValueInfoProto) -> str:
24+
return value_info.name
25+
26+
27+
def _function_key(function: onnx.FunctionProto) -> tuple[str, str, str]:
28+
return (function.domain, function.name, getattr(function, "overload", ""))
29+
30+
31+
def _find_duplicates(with_duplicates: Collection[Any]) -> list[Any]:
32+
"""Return a list of duplicated elements in a collection."""
33+
seen = set()
34+
duplicates = []
35+
for x in with_duplicates:
36+
if x in seen:
37+
duplicates.append(x)
38+
seen.add(x)
39+
return duplicates
40+
41+
42+
def assert_onnx_proto_equal(
43+
actual: google.protobuf.message.Message | Any,
44+
expected: google.protobuf.message.Message | Any,
45+
ignore_initializer_value_proto: bool = False,
46+
) -> None:
47+
"""Assert that two ONNX protos are equal.
48+
49+
Equality is defined as having the same fields with the same values. When
50+
a field takes the default value, it is considered equal to the field
51+
not being set.
52+
53+
Sequential fields with name `opset_import`, `value_info`, and `functions` are
54+
compared disregarding the order of their elements.
55+
56+
Args:
57+
actual: The first ONNX proto.
58+
expected: The second ONNX proto.
59+
ignore_initializer_value_proto: Ignore value protos for initializers if there
60+
are extra ones in the actual proto.
61+
"""
62+
assert type(actual) is type(expected), (
63+
f"Type not equal: {type(actual)} != {type(expected)}"
64+
)
65+
66+
a_fields = {field.name: value for field, value in actual.ListFields()}
67+
b_fields = {field.name: value for field, value in expected.ListFields()}
68+
all_fields = sorted(set(a_fields.keys()) | set(b_fields.keys()))
69+
if isinstance(actual, onnx.GraphProto) and isinstance(expected, onnx.GraphProto):
70+
actual_initializer_names = {i.name for i in actual.initializer}
71+
expected_initializer_names = {i.name for i in expected.initializer}
72+
else:
73+
actual_initializer_names = set()
74+
expected_initializer_names = set()
75+
76+
# Record and report all errors
77+
errors = []
78+
for field in all_fields: # pylint: disable=too-many-nested-blocks
79+
# Obtain the default value if the field is not set. This way we can compare the two fields.
80+
a_value = getattr(actual, field)
81+
b_value = getattr(expected, field)
82+
if (
83+
isinstance(a_value, Sequence)
84+
and isinstance(b_value, Sequence)
85+
and not isinstance(a_value, (str, bytes))
86+
and not isinstance(b_value, (str, bytes))
87+
):
88+
# Check length first
89+
a_keys: list[Any] = []
90+
b_keys: list[Any] = []
91+
if field == "opset_import":
92+
a_value = sorted(a_value, key=_opset_import_key)
93+
b_value = sorted(b_value, key=_opset_import_key)
94+
a_keys = [_opset_import_key(opset_import) for opset_import in a_value]
95+
b_keys = [_opset_import_key(opset_import) for opset_import in b_value]
96+
elif field == "value_info":
97+
if (
98+
ignore_initializer_value_proto
99+
and isinstance(actual, onnx.GraphProto)
100+
and isinstance(expected, onnx.GraphProto)
101+
):
102+
# Filter out initializers from the value_info list
103+
a_value = [
104+
value_info
105+
for value_info in a_value
106+
if value_info.name not in actual_initializer_names
107+
]
108+
b_value = [
109+
value_info
110+
for value_info in b_value
111+
if value_info.name not in expected_initializer_names
112+
]
113+
a_value = sorted(a_value, key=_value_info_key)
114+
b_value = sorted(b_value, key=_value_info_key)
115+
a_keys = [_value_info_key(value_info) for value_info in a_value]
116+
b_keys = [_value_info_key(value_info) for value_info in b_value]
117+
elif field == "functions":
118+
a_value = sorted(a_value, key=_function_key)
119+
b_value = sorted(b_value, key=_function_key)
120+
a_keys = [_function_key(functions) for functions in a_value]
121+
b_keys = [_function_key(functions) for functions in b_value]
122+
123+
if a_keys != b_keys:
124+
keys_only_in_actual = set(a_keys) - set(b_keys)
125+
keys_only_in_expected = set(b_keys) - set(a_keys)
126+
error_message = (
127+
f"Field {field} not equal: keys_only_in_actual={keys_only_in_actual}, keys_only_in_expected={keys_only_in_expected}. "
128+
f"Field type: {type(a_value)}. "
129+
f"Duplicated a_keys: {_find_duplicates(a_keys)}, duplicated b_keys: {_find_duplicates(b_keys)}"
130+
)
131+
errors.append(error_message)
132+
elif len(a_value) != len(b_value):
133+
error_message = (
134+
f"Field {field} not equal: len(a)={len(a_value)}, len(b)={len(b_value)} "
135+
f"Field type: {type(a_value)}"
136+
)
137+
errors.append(error_message)
138+
else:
139+
# Check every element
140+
for i in range(len(a_value)): # pylint: disable=consider-using-enumerate
141+
actual_value_i = a_value[i]
142+
expected_value_i = b_value[i]
143+
if isinstance(
144+
actual_value_i, google.protobuf.message.Message
145+
) and isinstance(expected_value_i, google.protobuf.message.Message):
146+
try:
147+
assert_onnx_proto_equal(
148+
actual_value_i,
149+
expected_value_i,
150+
ignore_initializer_value_proto=ignore_initializer_value_proto,
151+
)
152+
except AssertionError as e:
153+
error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}, actual_value_i: {actual_value_i}, expected_value_i: {expected_value_i}"
154+
error_message = (
155+
str(e) + "\n\nCaused by the above error\n\n" + error_message
156+
)
157+
errors.append(error_message)
158+
elif actual_value_i != expected_value_i:
159+
if (
160+
isinstance(actual_value_i, float)
161+
and isinstance(expected_value_i, float)
162+
and math.isnan(actual_value_i)
163+
and math.isnan(expected_value_i)
164+
):
165+
# Consider NaNs equal
166+
continue
167+
error_message = f"Field {field} index {i} in sequence not equal. type(actual_value_i): {type(actual_value_i)}, type(expected_value_i): {type(expected_value_i)}"
168+
for line in difflib.ndiff(
169+
str(actual_value_i).splitlines(),
170+
str(expected_value_i).splitlines(),
171+
):
172+
error_message += "\n" + line
173+
errors.append(error_message)
174+
elif isinstance(a_value, google.protobuf.message.Message) and isinstance(
175+
b_value, google.protobuf.message.Message
176+
):
177+
assert_onnx_proto_equal(
178+
a_value, b_value, ignore_initializer_value_proto=ignore_initializer_value_proto
179+
)
180+
elif a_value != b_value:
181+
if (
182+
isinstance(a_value, float)
183+
and isinstance(b_value, float)
184+
and math.isnan(a_value)
185+
and math.isnan(b_value)
186+
):
187+
# Consider NaNs equal
188+
continue
189+
error_message = (
190+
f"Field {field} not equal. field_actual: {a_value}, field_expected: {b_value}"
191+
)
192+
errors.append(error_message)
193+
if errors:
194+
raise AssertionError(
195+
f"Protos not equal: {type(actual)} != {type(expected)}\n" + "\n".join(errors)
196+
)

0 commit comments

Comments
 (0)