diff --git a/amaranth_soc/csr/reg.py b/amaranth_soc/csr/reg.py index 176674f..1fe6b4c 100644 --- a/amaranth_soc/csr/reg.py +++ b/amaranth_soc/csr/reg.py @@ -1,7 +1,7 @@ from collections.abc import Mapping, Sequence from amaranth import * from amaranth.lib import enum, wiring -from amaranth.lib.wiring import In, Out +from amaranth.lib.wiring import In, Out, connect, flipped from ..memory import MemoryMap from .bus import Element, Multiplexer @@ -57,10 +57,10 @@ def __init__(self, shape, access): self._access = FieldPort.Access(access) members = { - "r_data": Out(self.shape), - "r_stb": In(1), - "w_data": In(self.shape), - "w_stb": In(1), + "r_data": In(self.shape), + "r_stb": Out(1), + "w_data": Out(self.shape), + "w_stb": Out(1), } super().__init__(members) @@ -154,7 +154,7 @@ def __repr__(self): return f"csr.FieldPort({self.signature!r})" -class Field(Elaboratable): +class Field(wiring.Component): _doc_template = """ {description} @@ -181,15 +181,24 @@ class Field(Elaboratable): attributes="") def __init__(self, shape, access): - self.port = FieldPort.Signature(shape, access).create(path=("port",)) + FieldPort.Signature.check_parameters(shape, access) + self._shape = Shape.cast(shape) + self._access = FieldPort.Access(access) + super().__init__() @property def shape(self): - return self.port.shape + return self._shape @property def access(self): - return self.port.access + return self._access + + @property + def signature(self): + return wiring.Signature({ + "port": Out(FieldPort.Signature(self._shape, self._access)), + }) class FieldMap(Mapping): @@ -346,7 +355,7 @@ def flatten(self): assert False # :nocov: -class Register(Elaboratable): +class Register(wiring.Component): """CSR register. Parameters @@ -407,8 +416,10 @@ def __init__(self, access="rw", fields=None): raise ValueError(f"Field {'__'.join(field_name)} is writable, but register access " f"mode is {access!r}") - self.element = Element.Signature(width, access).create(path=("element",)) + self._width = width + self._access = access self._fields = fields + super().__init__() @property def fields(self): @@ -418,6 +429,12 @@ def fields(self): def f(self): return self._fields + @property + def signature(self): + return wiring.Signature({ + "element": Out(Element.Signature(self._width, self._access)), + }) + def __iter__(self): """Recursively iterate over the field collection. @@ -684,7 +701,7 @@ def get_register(self, path): raise KeyError(path) -class Bridge(Elaboratable): +class Bridge(wiring.Component): """CSR bridge. Parameters @@ -754,18 +771,20 @@ def get_register_param(path, root, kind): self._map = register_map self._mux = Multiplexer(memory_map) + super().__init__() @property def register_map(self): return self._map @property - def bus(self): - return self._mux.bus + def signature(self): + return self._mux.signature def elaborate(self, platform): m = Module() for register, path in self.register_map.flatten(): m.submodules["__".join(path)] = register m.submodules.mux = self._mux + connect(m, flipped(self), self._mux) return m diff --git a/tests/test_csr_reg.py b/tests/test_csr_reg.py index 3f2281a..a77f93e 100644 --- a/tests/test_csr_reg.py +++ b/tests/test_csr_reg.py @@ -15,10 +15,10 @@ def test_shape_1_ro(self): self.assertEqual(sig.shape, unsigned(1)) self.assertEqual(sig.access, FieldPort.Access.R) self.assertEqual(sig.members, Signature({ - "r_data": Out(unsigned(1)), - "r_stb": In(1), - "w_data": In(unsigned(1)), - "w_stb": In(1), + "r_data": In(unsigned(1)), + "r_stb": Out(1), + "w_data": Out(unsigned(1)), + "w_stb": Out(1), }).members) def test_shape_8_rw(self): @@ -26,10 +26,10 @@ def test_shape_8_rw(self): self.assertEqual(sig.shape, unsigned(8)) self.assertEqual(sig.access, FieldPort.Access.RW) self.assertEqual(sig.members, Signature({ - "r_data": Out(unsigned(8)), - "r_stb": In(1), - "w_data": In(unsigned(8)), - "w_stb": In(1), + "r_data": In(unsigned(8)), + "r_stb": Out(1), + "w_data": Out(unsigned(8)), + "w_stb": Out(1), }).members) def test_shape_10_wo(self): @@ -37,10 +37,10 @@ def test_shape_10_wo(self): self.assertEqual(sig.shape, unsigned(10)) self.assertEqual(sig.access, FieldPort.Access.W) self.assertEqual(sig.members, Signature({ - "r_data": Out(unsigned(10)), - "r_stb": In(1), - "w_data": In(unsigned(10)), - "w_stb": In(1), + "r_data": In(unsigned(10)), + "r_stb": Out(1), + "w_data": Out(unsigned(10)), + "w_stb": Out(1), }).members) def test_shape_0_rw(self): @@ -48,10 +48,10 @@ def test_shape_0_rw(self): self.assertEqual(sig.shape, unsigned(0)) self.assertEqual(sig.access, FieldPort.Access.W) self.assertEqual(sig.members, Signature({ - "r_data": Out(unsigned(0)), - "r_stb": In(1), - "w_data": In(unsigned(0)), - "w_stb": In(1), + "r_data": In(unsigned(0)), + "r_stb": Out(1), + "w_data": Out(unsigned(0)), + "w_stb": Out(1), }).members) def test_create(self):