From 041a0c68ab7c4701ce45028b81fd88cbb646ea3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Fran=C3=A7ois=20Nguyen?= Date: Tue, 12 Sep 2023 14:14:57 +0200 Subject: [PATCH] csr.reg: migrate to lib.wiring interfaces. --- amaranth_soc/csr/reg.py | 184 +++++++++++++++++++++++++++------------- tests/test_csr_field.py | 8 -- tests/test_csr_reg.py | 124 ++++++++++++++++++--------- 3 files changed, 210 insertions(+), 106 deletions(-) diff --git a/amaranth_soc/csr/reg.py b/amaranth_soc/csr/reg.py index 1f2cfd9..176674f 100644 --- a/amaranth_soc/csr/reg.py +++ b/amaranth_soc/csr/reg.py @@ -1,6 +1,7 @@ from collections.abc import Mapping, Sequence -import enum from amaranth import * +from amaranth.lib import enum, wiring +from amaranth.lib.wiring import In, Out from ..memory import MemoryMap from .bus import Element, Multiplexer @@ -9,7 +10,7 @@ __all__ = ["FieldPort", "Field", "FieldMap", "FieldArray", "Register", "RegisterMap", "Bridge"] -class FieldPort: +class FieldPort(wiring.Interface): class Access(enum.Enum): """Field access mode.""" R = "r" @@ -22,64 +23,135 @@ def readable(self): def writable(self): return self == self.W or self == self.RW + class Signature(wiring.Signature): + """CSR register field port signature. + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. + access : :class:`FieldPort.Access` + Field access mode. + + Interface attributes + -------------------- + r_data : Signal(shape) + Read data. Must always be valid, and is sampled when ``r_stb`` is asserted. + r_stb : Signal() + Read strobe. Fields with read side effects should perform them when this strobe is + asserted. + w_data : Signal(shape) + Write data. Valid only when ``w_stb`` is asserted. + w_stb : Signal() + Write strobe. Fields should update their value or perform the write side effect when + this strobe is asserted. + + Raises + ------ + See :meth:`FieldPort.Signature.check_parameters`. + """ + def __init__(self, shape, access): + self.check_parameters(shape, access) + + self._shape = Shape.cast(shape) + self._access = FieldPort.Access(access) + + members = { + "r_data": Out(self.shape), + "r_stb": In(1), + "w_data": In(self.shape), + "w_stb": In(1), + } + super().__init__(members) + + @property + def shape(self): + return self._shape + + @property + def access(self): + return self._access + + @classmethod + def check_parameters(cls, shape, access): + """Validate signature parameters. + + Raises + ------ + :exc:`TypeError` + If ``shape`` is not a shape-castable object. + :exc:`ValueError` + If ``access`` is not a member of :class:`FieldPort.Access`. + """ + try: + Shape.cast(shape) + except TypeError as e: + raise TypeError(f"Field shape must be a shape-castable object, not {shape!r}") from e + # TODO(py3.9): Remove this. Python 3.8 and below use cls.__name__ in the error message + # instead of cls.__qualname__. + # FieldPort.Access(access) + try: + FieldPort.Access(access) + except ValueError as e: + raise ValueError(f"{access!r} is not a valid FieldPort.Access") from e + + def create(self, *, path=()): + """Create a compatible interface. + + See :meth:`wiring.Signature.create` for details. + + Returns + ------- + A :class:`FieldPort` object using this signature. + """ + return FieldPort(self, path=path) + + def __eq__(self, other): + """Compare signatures. + + Two signatures are equal if they have the same shape and field access mode. + """ + return (isinstance(other, FieldPort.Signature) and + Shape.cast(self.shape) == Shape.cast(other.shape) and + self.access == other.access) + + def __repr__(self): + return f"csr.FieldPort.Signature({self.members!r})" + """CSR register field port. An interface between a CSR register and one of its fields. Parameters ---------- - shape : :ref:`shape-castable ` - Shape of the field. - access : :class:`FieldPort.Access` - Field access mode. - - Attributes - ---------- - r_data : Signal(shape) - Read data. Must always be valid, and is sampled when ``r_stb`` is asserted. - r_stb : Signal() - Read strobe. Fields with read side effects should perform them when this strobe is - asserted. - w_data : Signal(shape) - Write data. Valid only when ``w_stb`` is asserted. - w_stb : Signal() - Write strobe. Fields should update their value or perform the write side effect when - this strobe is asserted. + signature : :class:`FieldPort.Signature` + Field port signature. + path : iter(:class:`str`) + Path to the field port. Optional. See :class:`wiring.Interface`. Raises ------ :exc:`TypeError` If ``shape`` is not a shape-castable object. - :exc:`ValueError` + :exc:`TypeError` If ``access`` is not a member of :class:`FieldPort.Access`. """ - def __init__(self, shape, access): - try: - shape = Shape.cast(shape) - except TypeError as e: - raise TypeError("Field shape must be a shape-castable object, not {!r}" - .format(shape)) from e - if not isinstance(access, FieldPort.Access) and access not in ("r", "w", "rw"): - raise ValueError("Access mode must be one of \"r\", \"w\", or \"rw\", not {!r}" - .format(access)) - self._shape = shape - self._access = FieldPort.Access(access) - - self.r_data = Signal(shape) - self.r_stb = Signal() - self.w_data = Signal(shape) - self.w_stb = Signal() + def __init__(self, signature, *, path=()): + if not isinstance(signature, FieldPort.Signature): + raise TypeError(f"This interface requires a csr.FieldPort.Signature, not " + f"{signature!r}") + super().__init__(signature, path=path) @property def shape(self): - return self._shape + return self.signature.shape @property def access(self): - return self._access + return self.signature.access def __repr__(self): - return "FieldPort({}, {})".format(self.shape, self.access) + return f"csr.FieldPort({self.signature!r})" class Field(Elaboratable): @@ -109,7 +181,7 @@ class Field(Elaboratable): attributes="") def __init__(self, shape, access): - self.port = FieldPort(shape, access) + self.port = FieldPort.Signature(shape, access).create(path=("port",)) @property def shape(self): @@ -296,7 +368,7 @@ class Register(Elaboratable): Raises ------ - :exc:`ValueError` + :exc:`TypeError` If ``access`` is not a member of :class:`Element.Access`. :exc:`TypeError` If ``fields`` is not ``None`` or a :class:`FieldMap` or a :class:`FieldArray`. @@ -305,39 +377,37 @@ class Register(Elaboratable): :exc:`ValueError` If ``access`` is not writable and at least one field is writable. """ - def __init__(self, access, fields=None): + def __init__(self, access="rw", fields=None): if not isinstance(access, Element.Access) and access not in ("r", "w", "rw"): - raise ValueError("Access mode must be one of \"r\", \"w\", or \"rw\", not {!r}" - .format(access)) + raise TypeError(f"Access mode must be one of \"r\", \"w\", or \"rw\", not {access!r}") access = Element.Access(access) if hasattr(self, "__annotations__"): - annotation_fields = {} + annot_fields = {} for key, value in self.__annotations__.items(): if isinstance(value, (Field, FieldMap, FieldArray)): - annotation_fields[key] = value + annot_fields[key] = value if fields is None: - fields = FieldMap(annotation_fields) - elif annotation_fields: - raise ValueError("Field collection {} cannot be provided in addition to field annotations: {}" - .format(fields, ", ".join(annotation_fields.keys()))) + fields = FieldMap(annot_fields) + elif annot_fields: + raise ValueError(f"Field collection {fields} cannot be provided in addition to " + f"field annotations: {', '.join(annot_fields)}") if not isinstance(fields, (FieldMap, FieldArray)): - raise TypeError("Field collection must be a FieldMap or a FieldArray, not {!r}" - .format(fields)) + raise TypeError(f"Field collection must be a FieldMap or a FieldArray, not {fields!r}") width = 0 for field_name, field in fields.flatten(): width += Shape.cast(field.shape).width if field.access.readable() and not access.readable(): - raise ValueError("Field {} is readable, but register access mode is '{}'" - .format("__".join(field_name), access)) + raise ValueError(f"Field {'__'.join(field_name)} is readable, but register access " + f"mode is {access!r}") if field.access.writable() and not access.writable(): - raise ValueError("Field {} is writable, but register access mode is '{}'" - .format("__".join(field_name), access)) + raise ValueError(f"Field {'__'.join(field_name)} is writable, but register access " + f"mode is {access!r}") - self.element = Element(width, access) + self.element = Element.Signature(width, access).create(path=("element",)) self._fields = fields @property diff --git a/tests/test_csr_field.py b/tests/test_csr_field.py index 6e1d1d6..7415295 100644 --- a/tests/test_csr_field.py +++ b/tests/test_csr_field.py @@ -10,7 +10,6 @@ class RTestCase(unittest.TestCase): def test_simple(self): f = field.R(unsigned(4)) - self.assertEqual(repr(f.port), "FieldPort(unsigned(4), Access.R)") self.assertEqual(f.r_data.shape(), unsigned(4)) def test_sim(self): @@ -30,7 +29,6 @@ def process(): class WTestCase(unittest.TestCase): def test_simple(self): f = field.W(unsigned(4)) - self.assertEqual(repr(f.port), "FieldPort(unsigned(4), Access.W)") self.assertEqual(f.w_data.shape(), unsigned(4)) def test_sim(self): @@ -51,10 +49,8 @@ class RWTestCase(unittest.TestCase): def test_simple(self): f4 = field.RW(unsigned(4), reset=0x5) f8 = field.RW(signed(8)) - self.assertEqual(repr(f4.port), "FieldPort(unsigned(4), Access.RW)") self.assertEqual(f4.data.shape(), unsigned(4)) self.assertEqual(f4.reset, 0x5) - self.assertEqual(repr(f8.port), "FieldPort(signed(8), Access.RW)") self.assertEqual(f8.data.shape(), signed(8)) self.assertEqual(f8.reset, 0) @@ -82,11 +78,9 @@ class RW1CTestCase(unittest.TestCase): def test_simple(self): f4 = field.RW1C(unsigned(4), reset=0x5) f8 = field.RW1C(signed(8)) - self.assertEqual(repr(f4.port), "FieldPort(unsigned(4), Access.RW)") self.assertEqual(f4.data.shape(), unsigned(4)) self.assertEqual(f4.set .shape(), unsigned(4)) self.assertEqual(f4.reset, 0x5) - self.assertEqual(repr(f8.port), "FieldPort(signed(8), Access.RW)") self.assertEqual(f8.data.shape(), signed(8)) self.assertEqual(f8.set .shape(), signed(8)) self.assertEqual(f8.reset, 0) @@ -122,11 +116,9 @@ class RW1STestCase(unittest.TestCase): def test_simple(self): f4 = field.RW1S(unsigned(4), reset=0x5) f8 = field.RW1S(signed(8)) - self.assertEqual(repr(f4.port), "FieldPort(unsigned(4), Access.RW)") self.assertEqual(f4.data .shape(), unsigned(4)) self.assertEqual(f4.clear.shape(), unsigned(4)) self.assertEqual(f4.reset, 0x5) - self.assertEqual(repr(f8.port), "FieldPort(signed(8), Access.RW)") self.assertEqual(f8.data .shape(), signed(8)) self.assertEqual(f8.clear.shape(), signed(8)) self.assertEqual(f8.reset, 0) diff --git a/tests/test_csr_reg.py b/tests/test_csr_reg.py index a823cc7..3f2281a 100644 --- a/tests/test_csr_reg.py +++ b/tests/test_csr_reg.py @@ -2,62 +2,100 @@ import unittest from amaranth import * +from amaranth.lib.wiring import * from amaranth.sim import * from amaranth_soc.csr.reg import * from amaranth_soc.csr import field -class FieldPortTestCase(unittest.TestCase): +class FieldPortSignatureTestCase(unittest.TestCase): def test_shape_1_ro(self): - port = FieldPort(1, "r") - self.assertEqual(port.shape, unsigned(1)) - self.assertEqual(port.access, FieldPort.Access.R) - self.assertEqual(port.r_data.shape(), unsigned(1)) - self.assertEqual(port.r_stb .shape(), unsigned(1)) - self.assertEqual(port.w_data.shape(), unsigned(1)) - self.assertEqual(port.w_stb .shape(), unsigned(1)) - self.assertEqual(repr(port), "FieldPort(unsigned(1), Access.R)") + sig = FieldPort.Signature(1, "r") + self.assertEqual(sig.shape, unsigned(1)) + self.assertEqual(sig.access, FieldPort.Access.R) + self.assertEqual(sig.members, Signature({ + "r_data": Out(unsigned(1)), + "r_stb": In(1), + "w_data": In(unsigned(1)), + "w_stb": In(1), + }).members) def test_shape_8_rw(self): - port = FieldPort(8, "rw") - self.assertEqual(port.shape, unsigned(8)) - self.assertEqual(port.access, FieldPort.Access.RW) - self.assertEqual(port.r_data.shape(), unsigned(8)) - self.assertEqual(port.r_stb .shape(), unsigned(1)) - self.assertEqual(port.w_data.shape(), unsigned(8)) - self.assertEqual(port.w_stb .shape(), unsigned(1)) - self.assertEqual(repr(port), "FieldPort(unsigned(8), Access.RW)") + sig = FieldPort.Signature(8, "rw") + self.assertEqual(sig.shape, unsigned(8)) + self.assertEqual(sig.access, FieldPort.Access.RW) + self.assertEqual(sig.members, Signature({ + "r_data": Out(unsigned(8)), + "r_stb": In(1), + "w_data": In(unsigned(8)), + "w_stb": In(1), + }).members) def test_shape_10_wo(self): - port = FieldPort(10, "w") - self.assertEqual(port.shape, unsigned(10)) - self.assertEqual(port.access, FieldPort.Access.W) - self.assertEqual(port.r_data.shape(), unsigned(10)) - self.assertEqual(port.r_stb .shape(), unsigned(1)) - self.assertEqual(port.w_data.shape(), unsigned(10)) - self.assertEqual(port.w_stb .shape(), unsigned(1)) - self.assertEqual(repr(port), "FieldPort(unsigned(10), Access.W)") + sig = FieldPort.Signature(10, "w") + self.assertEqual(sig.shape, unsigned(10)) + self.assertEqual(sig.access, FieldPort.Access.W) + self.assertEqual(sig.members, Signature({ + "r_data": Out(unsigned(10)), + "r_stb": In(1), + "w_data": In(unsigned(10)), + "w_stb": In(1), + }).members) def test_shape_0_rw(self): - port = FieldPort(0, "rw") - self.assertEqual(port.shape, unsigned(0)) + sig = FieldPort.Signature(0, "w") + self.assertEqual(sig.shape, unsigned(0)) + self.assertEqual(sig.access, FieldPort.Access.W) + self.assertEqual(sig.members, Signature({ + "r_data": Out(unsigned(0)), + "r_stb": In(1), + "w_data": In(unsigned(0)), + "w_stb": In(1), + }).members) + + def test_create(self): + sig = FieldPort.Signature(unsigned(8), "rw") + port = sig.create(path=("foo", "bar")) + self.assertIsInstance(port, FieldPort) + self.assertEqual(port.shape, unsigned(8)) self.assertEqual(port.access, FieldPort.Access.RW) - self.assertEqual(port.r_data.shape(), unsigned(0)) - self.assertEqual(port.r_stb .shape(), unsigned(1)) - self.assertEqual(port.w_data.shape(), unsigned(0)) - self.assertEqual(port.w_stb .shape(), unsigned(1)) - self.assertEqual(repr(port), "FieldPort(unsigned(0), Access.RW)") + self.assertEqual(port.r_stb.name, "foo__bar__r_stb") + self.assertIs(port.signature, sig) + + def test_eq(self): + self.assertEqual(FieldPort.Signature(8, "r"), FieldPort.Signature(8, "r")) + self.assertEqual(FieldPort.Signature(8, "r"), FieldPort.Signature(8, FieldPort.Access.R)) + # different shape + self.assertNotEqual(FieldPort.Signature(8, "r"), FieldPort.Signature(1, "r")) + # different access mode + self.assertNotEqual(FieldPort.Signature(8, "r"), FieldPort.Signature(8, "w")) + self.assertNotEqual(FieldPort.Signature(8, "r"), FieldPort.Signature(8, "rw")) + self.assertNotEqual(FieldPort.Signature(8, "w"), FieldPort.Signature(8, "rw")) - def test_shape_wrong(self): + def test_wrong_shape(self): with self.assertRaisesRegex(TypeError, r"Field shape must be a shape-castable object, not 'foo'"): - port = FieldPort("foo", "rw") + port = FieldPort.Signature("foo", "rw") - def test_access_wrong(self): - with self.assertRaisesRegex(ValueError, - r"Access mode must be one of \"r\", \"w\", or \"rw\", not 'wo'"): - port = FieldPort(8, "wo") + def test_wrong_access(self): + with self.assertRaisesRegex(ValueError, r"'wo' is not a valid FieldPort.Access"): + port = FieldPort.Signature(8, "wo") + + +class FieldPortTestCase(unittest.TestCase): + def test_simple(self): + sig = FieldPort.Signature(unsigned(8), "rw") + port = FieldPort(sig, path=("foo", "bar")) + self.assertEqual(port.shape, unsigned(8)) + self.assertEqual(port.access, FieldPort.Access.RW) + self.assertEqual(port.r_stb.name, "foo__bar__r_stb") + self.assertIs(port.signature, sig) + + def test_wrong_signature(self): + with self.assertRaisesRegex(TypeError, + r"This interface requires a csr\.FieldPort\.Signature, not 'foo'"): + FieldPort("foo") def _compatible_fields(a, b): @@ -70,7 +108,6 @@ def test_simple(self): field = Field(unsigned(4), "rw") self.assertEqual(field.shape, unsigned(4)) self.assertEqual(field.access, FieldPort.Access.RW) - self.assertEqual(repr(field.port), "FieldPort(unsigned(4), Access.RW)") def test_compatible(self): self.assertTrue(_compatible_fields(Field(unsigned(4), "rw"), @@ -86,8 +123,7 @@ def test_wrong_shape(self): Field("foo", "rw") def test_wrong_access(self): - with self.assertRaisesRegex(ValueError, - r"Access mode must be one of \"r\", \"w\", or \"rw\", not 'wo'"): + with self.assertRaisesRegex(ValueError, r"'wo' is not a valid FieldPort.Access"): Field(8, "wo") @@ -614,6 +650,8 @@ def test_memory_map(self): self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") self.assertEqual(registers[3][2], (4, 6)) + Fragment.get(dut, platform=None) # silence UnusedElaboratable + def test_wrong_register_map(self): with self.assertRaisesRegex(TypeError, r"Register map must be an instance of RegisterMap, not 'foo'"): @@ -659,6 +697,8 @@ def test_register_addr(self): self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") self.assertEqual(registers[3][2], (0x22, 0x24)) + Fragment.get(dut, platform=None) # silence UnusedElaboratable + def test_register_alignment(self): reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4)})) reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8)})) @@ -699,6 +739,8 @@ def test_register_alignment(self): self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") self.assertEqual(registers[3][2], (16, 18)) + Fragment.get(dut, platform=None) # silence UnusedElaboratable + def test_register_out_of_bounds(self): reg_rw_24 = Register("rw", FieldMap({"a": field.RW(24)})) register_map = RegisterMap()