Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Nov 12, 2024
1 parent c1b2358 commit a204407
Showing 1 changed file with 54 additions and 0 deletions.
54 changes: 54 additions & 0 deletions onnxscript/ir/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,5 +1061,59 @@ def test_composite_type_is_comparable(self, _: str, type_: ir.TypeProtocol):
self.assertEqual(type_, copy.deepcopy(type_))


class AttrTest(unittest.TestCase):
"""Test the Attr class."""

def test_init(self):
attr = _core.Attr("test", ir.AttributeType.INT, 42, doc_string="test string")
self.assertEqual(attr.name, "test")
self.assertEqual(attr.value, 42)
self.assertEqual(attr.type, ir.AttributeType.INT)
self.assertEqual(attr.doc_string, "test string")

def test_as_float(self):
attr = _core.Attr("test", ir.AttributeType.FLOAT, 42.0)
self.assertEqual(attr.as_float(), 42.0)

attr_int_value = _core.Attr("test", ir.AttributeType.FLOAT, 42)
self.assertEqual(attr_int_value.as_float(), 42.0)

def test_as_int(self):
attr = _core.Attr("test", ir.AttributeType.INT, 0)
self.assertEqual(attr.as_int(), 0)

def test_as_string(self):
attr = _core.Attr("test", ir.AttributeType.STRING, "test string")
self.assertEqual(attr.as_string(), "test string")

def test_as_tensor(self):
attr = _core.Attr("test", ir.AttributeType.TENSOR, ir.tensor([42.0]))
np.testing.assert_equal(attr.as_tensor().numpy(), np.array([42.0]))

def test_as_graph(self):
attr = _core.Attr("test", ir.AttributeType.GRAPH, _core.Graph((), (), nodes=()))
self.assertIsInstance(attr.as_graph(), _core.Graph)

def test_as_floats(self):
attr = _core.Attr("test", ir.AttributeType.FLOATS, [42.0])
self.assertEqual(attr.as_floats(), [42.0])

def test_as_ints(self):
attr = _core.Attr("test", ir.AttributeType.INTS, [42])
self.assertEqual(attr.as_ints(), [42])

def test_as_strings(self):
attr = _core.Attr("test", ir.AttributeType.STRINGS, ["test string", ""])
self.assertEqual(attr.as_strings(), ["test string", ""])

def test_as_tensors(self):
attr = _core.Attr("test", ir.AttributeType.TENSORS, [ir.tensor([42.0])])
np.testing.assert_equal(attr.as_tensors()[0].numpy(), np.array([42.0]))

def test_as_graphs(self):
attr = _core.Attr("test", ir.AttributeType.GRAPHS, [_core.Graph((), (), nodes=())])
self.assertIsInstance(attr.as_graphs()[0], _core.Graph)


if __name__ == "__main__":
unittest.main()

0 comments on commit a204407

Please sign in to comment.