Skip to content

Commit 7acc4bf

Browse files
authored
Simplify attribute signatures and use tuples for repeating attrs (#186)
Always return a tuple for repeated attributes so that users do not need to convert the type again when comparing values. Note that previously the methods are annotated to return Sequence, and they return list as a concreate type. This PR updates the methods to return tuples, which may be a BC breaking behavior, but should be good in the long run to reflect the immutable nature of the attributes. --------- Signed-off-by: Justin Chu <[email protected]>
1 parent 8ba415c commit 7acc4bf

File tree

3 files changed

+32
-32
lines changed

3 files changed

+32
-32
lines changed

src/onnx_ir/_convenience/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def convert_attributes(
226226
... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
227227
... }
228228
>>> convert_attributes(attrs)
229-
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, (1, 2, 3)), Attr('floats', FLOATS, (1.0, 2.0, 3.0)), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
229+
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, (1, 2, 3)), Attr('floats', FLOATS, (1.0, 2.0, 3.0)), Attr('strings', STRINGS, ('hello', 'world')), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
230230
name='graph0',
231231
inputs=(
232232
<BLANKLINE>
@@ -235,7 +235,7 @@ def convert_attributes(
235235
<BLANKLINE>
236236
),
237237
len()=0
238-
)), Attr('graphs', GRAPHS, [Graph(
238+
)), Attr('graphs', GRAPHS, (Graph(
239239
name='graph1',
240240
inputs=(
241241
<BLANKLINE>
@@ -253,7 +253,7 @@ def convert_attributes(
253253
<BLANKLINE>
254254
),
255255
len()=0
256-
)]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
256+
))), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, (Tensor(FLOAT), Tensor(FLOAT)))]
257257
258258
.. important::
259259
An empty sequence should be created with an explicit type by initializing

src/onnx_ir/_core.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3414,6 +3414,13 @@ def __init__(
34143414
value = tuple(int(v) for v in value)
34153415
elif type == _enums.AttributeType.FLOATS:
34163416
value = tuple(float(v) for v in value)
3417+
elif type in {
3418+
_enums.AttributeType.STRINGS,
3419+
_enums.AttributeType.TENSORS,
3420+
_enums.AttributeType.GRAPHS,
3421+
_enums.AttributeType.TYPE_PROTOS,
3422+
}:
3423+
value = tuple(value)
34173424

34183425
self._name = name
34193426
self._type = type
@@ -3508,93 +3515,86 @@ def as_string(self) -> str:
35083515
raise TypeError(
35093516
f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
35103517
)
3511-
if not isinstance(self.value, str):
3518+
value = self.value
3519+
if not isinstance(value, str):
35123520
raise TypeError(f"Value of attribute '{self!r}' is not a string.")
3513-
return self.value
3521+
return value
35143522

35153523
def as_tensor(self) -> _protocols.TensorProtocol:
35163524
"""Get the attribute value as a tensor."""
35173525
if self.type != _enums.AttributeType.TENSOR:
35183526
raise TypeError(
35193527
f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
35203528
)
3521-
if not isinstance(self.value, _protocols.TensorProtocol):
3529+
value = self.value
3530+
if not isinstance(value, _protocols.TensorProtocol):
35223531
raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
3523-
return self.value
3532+
return value
35243533

35253534
def as_graph(self) -> Graph:
35263535
"""Get the attribute value as a graph."""
35273536
if self.type != _enums.AttributeType.GRAPH:
35283537
raise TypeError(
35293538
f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
35303539
)
3531-
if not isinstance(self.value, Graph):
3540+
value = self.value
3541+
if not isinstance(value, Graph):
35323542
raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
3533-
return self.value
3543+
return value
35343544

3535-
def as_floats(self) -> Sequence[float]:
3545+
def as_floats(self) -> tuple[float, ...]:
35363546
"""Get the attribute value as a sequence of floats."""
35373547
if self.type != _enums.AttributeType.FLOATS:
35383548
raise TypeError(
35393549
f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
35403550
)
3541-
if not isinstance(self.value, Sequence):
3542-
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
35433551
# value is guaranteed to be a sequence of float in the constructor
35443552
return self.value
35453553

3546-
def as_ints(self) -> Sequence[int]:
3554+
def as_ints(self) -> tuple[int, ...]:
35473555
"""Get the attribute value as a sequence of ints."""
35483556
if self.type != _enums.AttributeType.INTS:
35493557
raise TypeError(
35503558
f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
35513559
)
3552-
if not isinstance(self.value, Sequence):
3553-
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
35543560
# value is guaranteed to be a sequence of int in the constructor
35553561
return self.value
35563562

3557-
def as_strings(self) -> Sequence[str]:
3563+
def as_strings(self) -> tuple[str, ...]:
35583564
"""Get the attribute value as a sequence of strings."""
35593565
if self.type != _enums.AttributeType.STRINGS:
35603566
raise TypeError(
35613567
f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
35623568
)
3563-
if not isinstance(self.value, Sequence):
3564-
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
35653569
if onnx_ir.DEBUG:
35663570
if not all(isinstance(x, str) for x in self.value):
35673571
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
3568-
# Create a copy of the list to prevent mutation
3569-
return list(self.value)
3572+
# value is guaranteed to be a sequence in the constructor
3573+
return self.value
35703574

3571-
def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
3575+
def as_tensors(self) -> tuple[_protocols.TensorProtocol, ...]:
35723576
"""Get the attribute value as a sequence of tensors."""
35733577
if self.type != _enums.AttributeType.TENSORS:
35743578
raise TypeError(
35753579
f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
35763580
)
3577-
if not isinstance(self.value, Sequence):
3578-
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
35793581
if onnx_ir.DEBUG:
35803582
if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
35813583
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
3582-
# Create a copy of the list to prevent mutation
3583-
return list(self.value)
3584+
# value is guaranteed to be a sequence in the constructor
3585+
return tuple(self.value)
35843586

3585-
def as_graphs(self) -> Sequence[Graph]:
3587+
def as_graphs(self) -> tuple[Graph, ...]:
35863588
"""Get the attribute value as a sequence of graphs."""
35873589
if self.type != _enums.AttributeType.GRAPHS:
35883590
raise TypeError(
35893591
f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
35903592
)
3591-
if not isinstance(self.value, Sequence):
3592-
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
35933593
if onnx_ir.DEBUG:
35943594
if not all(isinstance(x, Graph) for x in self.value):
35953595
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
3596-
# Create a copy of the list to prevent mutation
3597-
return list(self.value)
3596+
# value is guaranteed to be a sequence in the constructor
3597+
return tuple(self.value)
35983598

35993599

36003600
# NOTE: The following functions are just for convenience

src/onnx_ir/_core_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,7 @@ def test_attributes_get_strings(self):
972972
inputs=(),
973973
attributes=[_core.AttrStrings("test_attr", ["a", "b", "c"])],
974974
)
975-
self.assertEqual(node.attributes.get_strings("test_attr"), ["a", "b", "c"])
975+
self.assertEqual(node.attributes.get_strings("test_attr"), ("a", "b", "c"))
976976
self.assertIsNone(node.attributes.get_strings("non_existent_attr"))
977977
self.assertEqual(
978978
node.attributes.get_strings("non_existent_attr", ["default"]), ["default"]
@@ -1979,7 +1979,7 @@ def test_as_ints(self):
19791979

19801980
def test_as_strings(self):
19811981
attr = _core.Attr("test", ir.AttributeType.STRINGS, ["test string", ""])
1982-
self.assertEqual(attr.as_strings(), ["test string", ""])
1982+
self.assertEqual(attr.as_strings(), ("test string", ""))
19831983

19841984
def test_as_tensors(self):
19851985
attr = _core.Attr("test", ir.AttributeType.TENSORS, [ir.tensor([42.0])])

0 commit comments

Comments
 (0)