Skip to content

Commit b9ecb98

Browse files
authored
Ensure int/float attributes are stored as Python int/float (#178)
Ensure int/float attributes are stored as Python int/float by doing a quick type conversion at Attr initialization. Even though this makes Attr initialization slightly more costly, it allows us to maintain an invariance where an INT or FLOAT value can only be a plain python number, making downstream usage and comparison much easier. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent d5c9cb5 commit b9ecb98

File tree

5 files changed

+38
-22
lines changed

5 files changed

+38
-22
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def convert_attributes(
226226
... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
227227
... }
228228
>>> convert_attributes(attrs)
229-
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
229+
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, (1, 2, 3)), Attr('floats', FLOATS, (1.0, 2.0, 3.0)), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
230230
name='graph0',
231231
inputs=(
232232
<BLANKLINE>

src/onnx_ir/_core.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3397,6 +3397,24 @@ def __init__(
33973397
*,
33983398
doc_string: str | None = None,
33993399
) -> None:
3400+
# Quick checks to ensure that INT and FLOAT attributes are stored as int and float,
3401+
# not np.int32, np.float32, bool, etc.
3402+
# This also allows errors to be raised at the time of construction instead of later
3403+
# during serialization.
3404+
# TODO(justinchuby): Use case matching when we drop support for Python 3.9
3405+
if value is None:
3406+
# Value can be None for reference attributes or when it is used as a
3407+
# placeholder for schemas
3408+
pass
3409+
elif type == _enums.AttributeType.INT:
3410+
value = int(value)
3411+
elif type == _enums.AttributeType.FLOAT:
3412+
value = float(value)
3413+
elif type == _enums.AttributeType.INTS:
3414+
value = tuple(int(v) for v in value)
3415+
elif type == _enums.AttributeType.FLOATS:
3416+
value = tuple(float(v) for v in value)
3417+
34003418
self._name = name
34013419
self._type = type
34023420
self._value = value
@@ -3472,17 +3490,17 @@ def as_float(self) -> float:
34723490
raise TypeError(
34733491
f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
34743492
)
3475-
# Do not use isinstance check because it may prevent np.float32 etc. from being used
3476-
return float(self.value)
3493+
# value is guaranteed to be a float in the constructor
3494+
return self.value
34773495

34783496
def as_int(self) -> int:
34793497
"""Get the attribute value as an int."""
34803498
if self.type != _enums.AttributeType.INT:
34813499
raise TypeError(
34823500
f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
34833501
)
3484-
# Do not use isinstance check because it may prevent np.int32 etc. from being used
3485-
return int(self.value)
3502+
# value is guaranteed to be an int in the constructor
3503+
return self.value
34863504

34873505
def as_string(self) -> str:
34883506
"""Get the attribute value as a string."""
@@ -3522,9 +3540,8 @@ def as_floats(self) -> Sequence[float]:
35223540
)
35233541
if not isinstance(self.value, Sequence):
35243542
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3525-
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
3526-
# Create a copy of the list to prevent mutation
3527-
return [float(v) for v in self.value]
3543+
# value is guaranteed to be a sequence of float in the constructor
3544+
return self.value
35283545

35293546
def as_ints(self) -> Sequence[int]:
35303547
"""Get the attribute value as a sequence of ints."""
@@ -3534,9 +3551,8 @@ def as_ints(self) -> Sequence[int]:
35343551
)
35353552
if not isinstance(self.value, Sequence):
35363553
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3537-
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
3538-
# Create a copy of the list to prevent mutation
3539-
return list(self.value)
3554+
# value is guaranteed to be a sequence of int in the constructor
3555+
return self.value
35403556

35413557
def as_strings(self) -> Sequence[str]:
35423558
"""Get the attribute value as a sequence of strings."""
@@ -3605,7 +3621,7 @@ def RefAttr(
36053621
return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
36063622

36073623

3608-
def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
3624+
def AttrFloat32(name: str, value: float | np.floating, doc_string: str | None = None) -> Attr:
36093625
"""Create a float attribute."""
36103626
# NOTE: The function name is capitalized to maintain API backward compatibility.
36113627
return Attr(
@@ -3616,7 +3632,7 @@ def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
36163632
)
36173633

36183634

3619-
def AttrInt64(name: str, value: int, doc_string: str | None = None) -> Attr:
3635+
def AttrInt64(name: str, value: int | np.integer, doc_string: str | None = None) -> Attr:
36203636
"""Create an int attribute."""
36213637
# NOTE: The function name is capitalized to maintain API backward compatibility.
36223638
return Attr(

src/onnx_ir/_core_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -950,7 +950,7 @@ def test_attributes_get_ints(self):
950950
inputs=(),
951951
attributes=[_core.AttrInt64s("test_attr", [1, 2, 3])],
952952
)
953-
self.assertEqual(node.attributes.get_ints("test_attr"), [1, 2, 3])
953+
self.assertEqual(node.attributes.get_ints("test_attr"), (1, 2, 3))
954954
self.assertIsNone(node.attributes.get_ints("non_existent_attr"))
955955
self.assertEqual(node.attributes.get_ints("non_existent_attr", [42]), [42])
956956

@@ -961,7 +961,7 @@ def test_attributes_get_floats(self):
961961
inputs=(),
962962
attributes=[_core.AttrFloat32s("test_attr", [1.0, 2.0, 3.0])],
963963
)
964-
self.assertEqual(node.attributes.get_floats("test_attr"), [1.0, 2.0, 3.0])
964+
self.assertEqual(node.attributes.get_floats("test_attr"), (1.0, 2.0, 3.0))
965965
self.assertIsNone(node.attributes.get_floats("non_existent_attr"))
966966
self.assertEqual(node.attributes.get_floats("non_existent_attr", [42.0]), [42.0])
967967

@@ -1971,11 +1971,11 @@ def test_as_graph(self):
19711971

19721972
def test_as_floats(self):
19731973
attr = _core.Attr("test", ir.AttributeType.FLOATS, [42.0])
1974-
self.assertEqual(attr.as_floats(), [42.0])
1974+
self.assertEqual(tuple(attr.as_floats()), (42.0,))
19751975

19761976
def test_as_ints(self):
19771977
attr = _core.Attr("test", ir.AttributeType.INTS, [42])
1978-
self.assertEqual(attr.as_ints(), [42])
1978+
self.assertEqual(tuple(attr.as_ints()), (42,))
19791979

19801980
def test_as_strings(self):
19811981
attr = _core.Attr("test", ir.AttributeType.STRINGS, ["test string", ""])

src/onnx_ir/serde.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,12 +1784,11 @@ def _fill_in_value_for_attribute(
17841784
) -> None:
17851785
if type_ == _enums.AttributeType.INT:
17861786
# value: int
1787-
# Cast bool to int, for example
1788-
attribute_proto.i = int(value)
1787+
attribute_proto.i = value
17891788
attribute_proto.type = onnx.AttributeProto.INT
17901789
elif type_ == _enums.AttributeType.FLOAT:
17911790
# value: float
1792-
attribute_proto.f = float(value)
1791+
attribute_proto.f = value
17931792
attribute_proto.type = onnx.AttributeProto.FLOAT
17941793
elif type_ == _enums.AttributeType.STRING:
17951794
# value: str

src/onnx_ir/serde_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,9 @@ class SerializationTest(unittest.TestCase):
539539
("int_as_float", ir.AttributeType.FLOAT, 1, 1.0),
540540
("int", ir.AttributeType.INT, 42, 42),
541541
("bool", ir.AttributeType.INT, True, 1),
542-
("ints", ir.AttributeType.INTS, [1, 2, 3], [1, 2, 3]),
543-
("floats", ir.AttributeType.FLOATS, [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]),
542+
("ints", ir.AttributeType.INTS, [1, 2, 3], (1, 2, 3)),
543+
("floats", ir.AttributeType.FLOATS, [1.0, 2.0, 3.0], (1.0, 2.0, 3.0)),
544+
("bools", ir.AttributeType.INTS, [True, False], (1, 0)),
544545
("string", ir.AttributeType.STRING, "test_string", "test_string"),
545546
]
546547
)

0 commit comments

Comments
 (0)