Skip to content

Commit a706e17

Browse files
authored
[serde] Cast attribute values to int or float (#175)
Cast attribute values to int or float based on the attribute type when serializing. Otherwise protobuf will complain > Field onnx.AttributeProto.i: Expected an int, got a boolean. This will be rejected in 7.34.0, please fix it before that This should fix pytorch/pytorch#161941. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent bfab676 commit a706e17

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/onnx_ir/serde.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,11 +1784,12 @@ def _fill_in_value_for_attribute(
17841784
) -> None:
17851785
if type_ == _enums.AttributeType.INT:
17861786
# value: int
1787-
attribute_proto.i = value
1787+
# Cast bool to int, for example
1788+
attribute_proto.i = int(value)
17881789
attribute_proto.type = onnx.AttributeProto.INT
17891790
elif type_ == _enums.AttributeType.FLOAT:
17901791
# value: float
1791-
attribute_proto.f = value
1792+
attribute_proto.f = float(value)
17921793
attribute_proto.type = onnx.AttributeProto.FLOAT
17931794
elif type_ == _enums.AttributeType.STRING:
17941795
# value: str

src/onnx_ir/serde_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33
import itertools
44
import unittest
5+
import warnings
56

67
import google.protobuf.text_format
78
import ml_dtypes
@@ -531,6 +532,33 @@ def test_deserialize_builds_correct_value_connections_for_subgraphs_that_referen
531532
)
532533

533534

535+
class SerializationTest(unittest.TestCase):
536+
@parameterized.parameterized.expand(
537+
[
538+
("float", ir.AttributeType.FLOAT, 1.5, 1.5),
539+
("int_as_float", ir.AttributeType.FLOAT, 1, 1.0),
540+
("int", ir.AttributeType.INT, 42, 42),
541+
("bool", ir.AttributeType.INT, True, 1),
542+
("ints", ir.AttributeType.INTS, [1, 2, 3], [1, 2, 3]),
543+
("floats", ir.AttributeType.FLOATS, [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]),
544+
("string", ir.AttributeType.STRING, "test_string", "test_string"),
545+
]
546+
)
547+
def test_serialize_attribute(self, _: str, typ: ir.AttributeType, value, expected):
548+
attr = ir.Attr("test_attr", typ, value)
549+
with warnings.catch_warnings(record=True) as w:
550+
# Ensure all warnings are caught, not just the default ones
551+
warnings.simplefilter("always")
552+
attr_proto = serde.serialize_attribute(attr)
553+
self.assertEqual(
554+
len(w), 0, f"Unexpected warnings: {[str(warn.message) for warn in w]}"
555+
)
556+
deserialized_attr = serde.deserialize_attribute(attr_proto)
557+
self.assertEqual(deserialized_attr.name, attr.name)
558+
self.assertEqual(deserialized_attr.type, attr.type)
559+
self.assertEqual(deserialized_attr.value, expected)
560+
561+
534562
class QuantizationAnnotationTest(unittest.TestCase):
535563
"""Test that quantization annotations are correctly serialized and deserialized."""
536564

0 commit comments

Comments
 (0)