|
2 | 2 | # SPDX-License-Identifier: Apache-2.0
|
3 | 3 | import itertools
|
4 | 4 | import unittest
|
| 5 | +import warnings |
5 | 6 |
|
6 | 7 | import google.protobuf.text_format
|
7 | 8 | import ml_dtypes
|
@@ -531,6 +532,33 @@ def test_deserialize_builds_correct_value_connections_for_subgraphs_that_referen
|
531 | 532 | )
|
532 | 533 |
|
533 | 534 |
|
| 535 | +class SerializationTest(unittest.TestCase): |
| 536 | + @parameterized.parameterized.expand( |
| 537 | + [ |
| 538 | + ("float", ir.AttributeType.FLOAT, 1.5, 1.5), |
| 539 | + ("int_as_float", ir.AttributeType.FLOAT, 1, 1.0), |
| 540 | + ("int", ir.AttributeType.INT, 42, 42), |
| 541 | + ("bool", ir.AttributeType.INT, True, 1), |
| 542 | + ("ints", ir.AttributeType.INTS, [1, 2, 3], [1, 2, 3]), |
| 543 | + ("floats", ir.AttributeType.FLOATS, [1.0, 2.0, 3.0], [1.0, 2.0, 3.0]), |
| 544 | + ("string", ir.AttributeType.STRING, "test_string", "test_string"), |
| 545 | + ] |
| 546 | + ) |
| 547 | + def test_serialize_attribute(self, _: str, typ: ir.AttributeType, value, expected): |
| 548 | + attr = ir.Attr("test_attr", typ, value) |
| 549 | + with warnings.catch_warnings(record=True) as w: |
| 550 | + # Ensure all warnings are caught, not just the default ones |
| 551 | + warnings.simplefilter("always") |
| 552 | + attr_proto = serde.serialize_attribute(attr) |
| 553 | + self.assertEqual( |
| 554 | + len(w), 0, f"Unexpected warnings: {[str(warn.message) for warn in w]}" |
| 555 | + ) |
| 556 | + deserialized_attr = serde.deserialize_attribute(attr_proto) |
| 557 | + self.assertEqual(deserialized_attr.name, attr.name) |
| 558 | + self.assertEqual(deserialized_attr.type, attr.type) |
| 559 | + self.assertEqual(deserialized_attr.value, expected) |
| 560 | + |
| 561 | + |
534 | 562 | class QuantizationAnnotationTest(unittest.TestCase):
|
535 | 563 | """Test that quantization annotations are correctly serialized and deserialized."""
|
536 | 564 |
|
|
0 commit comments