@@ -3414,6 +3414,13 @@ def __init__(
3414
3414
value = tuple (int (v ) for v in value )
3415
3415
elif type == _enums .AttributeType .FLOATS :
3416
3416
value = tuple (float (v ) for v in value )
3417
+ elif type in {
3418
+ _enums .AttributeType .STRINGS ,
3419
+ _enums .AttributeType .TENSORS ,
3420
+ _enums .AttributeType .GRAPHS ,
3421
+ _enums .AttributeType .TYPE_PROTOS ,
3422
+ }:
3423
+ value = tuple (value )
3417
3424
3418
3425
self ._name = name
3419
3426
self ._type = type
@@ -3508,93 +3515,86 @@ def as_string(self) -> str:
3508
3515
raise TypeError (
3509
3516
f"Attribute '{ self .name } ' is not of type STRING. Actual type: { self .type } "
3510
3517
)
3511
- if not isinstance (self .value , str ):
3518
+ value = self .value
3519
+ if not isinstance (value , str ):
3512
3520
raise TypeError (f"Value of attribute '{ self !r} ' is not a string." )
3513
- return self . value
3521
+ return value
3514
3522
3515
3523
def as_tensor (self ) -> _protocols .TensorProtocol :
3516
3524
"""Get the attribute value as a tensor."""
3517
3525
if self .type != _enums .AttributeType .TENSOR :
3518
3526
raise TypeError (
3519
3527
f"Attribute '{ self .name } ' is not of type TENSOR. Actual type: { self .type } "
3520
3528
)
3521
- if not isinstance (self .value , _protocols .TensorProtocol ):
3529
+ value = self .value
3530
+ if not isinstance (value , _protocols .TensorProtocol ):
3522
3531
raise TypeError (f"Value of attribute '{ self !r} ' is not a tensor." )
3523
- return self . value
3532
+ return value
3524
3533
3525
3534
def as_graph (self ) -> Graph :
3526
3535
"""Get the attribute value as a graph."""
3527
3536
if self .type != _enums .AttributeType .GRAPH :
3528
3537
raise TypeError (
3529
3538
f"Attribute '{ self .name } ' is not of type GRAPH. Actual type: { self .type } "
3530
3539
)
3531
- if not isinstance (self .value , Graph ):
3540
+ value = self .value
3541
+ if not isinstance (value , Graph ):
3532
3542
raise TypeError (f"Value of attribute '{ self !r} ' is not a graph." )
3533
- return self . value
3543
+ return value
3534
3544
3535
- def as_floats (self ) -> Sequence [float ]:
3545
+ def as_floats (self ) -> tuple [float , ... ]:
3536
3546
"""Get the attribute value as a sequence of floats."""
3537
3547
if self .type != _enums .AttributeType .FLOATS :
3538
3548
raise TypeError (
3539
3549
f"Attribute '{ self .name } ' is not of type FLOATS. Actual type: { self .type } "
3540
3550
)
3541
- if not isinstance (self .value , Sequence ):
3542
- raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3543
3551
# value is guaranteed to be a sequence of float in the constructor
3544
3552
return self .value
3545
3553
3546
- def as_ints (self ) -> Sequence [int ]:
3554
+ def as_ints (self ) -> tuple [int , ... ]:
3547
3555
"""Get the attribute value as a sequence of ints."""
3548
3556
if self .type != _enums .AttributeType .INTS :
3549
3557
raise TypeError (
3550
3558
f"Attribute '{ self .name } ' is not of type INTS. Actual type: { self .type } "
3551
3559
)
3552
- if not isinstance (self .value , Sequence ):
3553
- raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3554
3560
# value is guaranteed to be a sequence of int in the constructor
3555
3561
return self .value
3556
3562
3557
- def as_strings (self ) -> Sequence [str ]:
3563
+ def as_strings (self ) -> tuple [str , ... ]:
3558
3564
"""Get the attribute value as a sequence of strings."""
3559
3565
if self .type != _enums .AttributeType .STRINGS :
3560
3566
raise TypeError (
3561
3567
f"Attribute '{ self .name } ' is not of type STRINGS. Actual type: { self .type } "
3562
3568
)
3563
- if not isinstance (self .value , Sequence ):
3564
- raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3565
3569
if onnx_ir .DEBUG :
3566
3570
if not all (isinstance (x , str ) for x in self .value ):
3567
3571
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence of strings." )
3568
- # Create a copy of the list to prevent mutation
3569
- return list ( self .value )
3572
+ # value is guaranteed to be a sequence in the constructor
3573
+ return self .value
3570
3574
3571
- def as_tensors (self ) -> Sequence [_protocols .TensorProtocol ]:
3575
+ def as_tensors (self ) -> tuple [_protocols .TensorProtocol , ... ]:
3572
3576
"""Get the attribute value as a sequence of tensors."""
3573
3577
if self .type != _enums .AttributeType .TENSORS :
3574
3578
raise TypeError (
3575
3579
f"Attribute '{ self .name } ' is not of type TENSORS. Actual type: { self .type } "
3576
3580
)
3577
- if not isinstance (self .value , Sequence ):
3578
- raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3579
3581
if onnx_ir .DEBUG :
3580
3582
if not all (isinstance (x , _protocols .TensorProtocol ) for x in self .value ):
3581
3583
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence of tensors." )
3582
- # Create a copy of the list to prevent mutation
3583
- return list (self .value )
3584
+ # value is guaranteed to be a sequence in the constructor
3585
+ return tuple (self .value )
3584
3586
3585
- def as_graphs (self ) -> Sequence [Graph ]:
3587
+ def as_graphs (self ) -> tuple [Graph , ... ]:
3586
3588
"""Get the attribute value as a sequence of graphs."""
3587
3589
if self .type != _enums .AttributeType .GRAPHS :
3588
3590
raise TypeError (
3589
3591
f"Attribute '{ self .name } ' is not of type GRAPHS. Actual type: { self .type } "
3590
3592
)
3591
- if not isinstance (self .value , Sequence ):
3592
- raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence." )
3593
3593
if onnx_ir .DEBUG :
3594
3594
if not all (isinstance (x , Graph ) for x in self .value ):
3595
3595
raise TypeError (f"Value of attribute '{ self !r} ' is not a Sequence of graphs." )
3596
- # Create a copy of the list to prevent mutation
3597
- return list (self .value )
3596
+ # value is guaranteed to be a sequence in the constructor
3597
+ return tuple (self .value )
3598
3598
3599
3599
3600
3600
# NOTE: The following functions are just for convenience
0 commit comments