From 0181c95ebaa3911309c6597ab6ba35dfca71bb00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Fran=C3=A7ois=20Nguyen?= Date: Fri, 1 Sep 2023 18:51:40 +0200 Subject: [PATCH] csr.reg: implement amaranth-lang/rfcs#16. --- amaranth_soc/csr/__init__.py | 1 + amaranth_soc/csr/field.py | 188 ++++++++ amaranth_soc/csr/reg.py | 701 +++++++++++++++++++++++++++++ tests/test_csr_field.py | 158 +++++++ tests/test_csr_reg.py | 843 +++++++++++++++++++++++++++++++++++ 5 files changed, 1891 insertions(+) create mode 100644 amaranth_soc/csr/field.py create mode 100644 amaranth_soc/csr/reg.py create mode 100644 tests/test_csr_field.py create mode 100644 tests/test_csr_reg.py diff --git a/amaranth_soc/csr/__init__.py b/amaranth_soc/csr/__init__.py index 35bfd29..fc66607 100644 --- a/amaranth_soc/csr/__init__.py +++ b/amaranth_soc/csr/__init__.py @@ -1,2 +1,3 @@ from .bus import * from .event import * +from .reg import * diff --git a/amaranth_soc/csr/field.py b/amaranth_soc/csr/field.py new file mode 100644 index 0000000..33df5a0 --- /dev/null +++ b/amaranth_soc/csr/field.py @@ -0,0 +1,188 @@ +from amaranth import * + +from .reg import Field + + +__all__ = ["R", "W", "RW", "RW1C", "RW1S"] + + +class R(Field): + __doc__ = Field._doc_template.format( + description=""" + A read-only field. + """.strip(), + parameters="", + attributes=""" + r_data : Signal(shape) + Read data. Drives the :attr:`~FieldPort.r_data` signal of ``port``. + """.strip()) + + def __init__(self, shape): + super().__init__(shape, access="r") + self.r_data = Signal(shape) + + def elaborate(self, platform): + m = Module() + m.d.comb += self.port.r_data.eq(self.r_data) + return m + + +class W(Field): + __doc__ = Field._doc_template.format( + description=""" + A write-only field. + """.strip(), + parameters="", + attributes=""" + w_data : Signal(shape) + Write data. Driven by the :attr:`~FieldPort.w_data` signal of ``port``. + """.strip()) + + def __init__(self, shape): + super().__init__(shape, access="w") + self.w_data = Signal(shape) + + def elaborate(self, platform): + m = Module() + m.d.comb += self.w_data.eq(self.port.w_data) + return m + + +class RW(Field): + __doc__ = Field._doc_template.format( + description=""" + A read/write field with built-in storage. + + Storage is updated with the value of ``port.w_data`` one clock cycle after ``port.w_stb`` is + asserted. + """.strip(), + parameters=""" + reset : :class:`int` + Storage reset value. + """, + attributes=""" + data : Signal(shape) + Storage output. + """.strip()) + + def __init__(self, shape, *, reset=0): + super().__init__(shape, access="rw") + self.data = Signal(shape) + self._storage = Signal(shape, reset=reset) + self._reset = reset + + @property + def reset(self): + return self._reset + + def elaborate(self, platform): + m = Module() + + with m.If(self.port.w_stb): + m.d.sync += self._storage.eq(self.port.w_data) + + m.d.comb += [ + self.port.r_data.eq(self._storage), + self.data.eq(self._storage), + ] + + return m + + +class RW1C(Field): + __doc__ = Field._doc_template.format( + description=""" + A read/write-one-to-clear field with built-in storage. + + Storage bits are: + * cleared by high bits in ``port.w_data``, one clock cycle after ``port.w_stb`` is asserted; + * set by high bits in ``set``, one clock cycle after they are asserted. + + If a storage bit is set and cleared on the same clock cycle, setting it has precedence. + """.strip(), + parameters=""" + reset : :class:`int` + Storage reset value. + """, + attributes=""" + data : Signal(shape) + Storage output. + set : Signal(shape) + Mask to set storage bits. + """.strip()) + + def __init__(self, shape, *, reset=0): + super().__init__(shape, access="rw") + self.data = Signal(shape) + self.set = Signal(shape) + self._storage = Signal(shape, reset=reset) + self._reset = reset + + @property + def reset(self): + return self._reset + + def elaborate(self, platform): + m = Module() + + for i, storage_bit in enumerate(self._storage): + with m.If(self.port.w_stb & self.port.w_data[i]): + m.d.sync += storage_bit.eq(0) + with m.If(self.set[i]): + m.d.sync += storage_bit.eq(1) + + m.d.comb += [ + self.port.r_data.eq(self._storage), + self.data.eq(self._storage), + ] + + return m + + +class RW1S(Field): + __doc__ = Field._doc_template.format( + description=""" + A read/write-one-to-set field with built-in storage. + + Storage bits are: + * set by high bits in ``port.w_data``, one clock cycle after ``port.w_stb`` is asserted; + * cleared by high bits in ``clear``, one clock cycle after they are asserted. + + If a storage bit is set and cleared on the same clock cycle, setting it has precedence. + """.strip(), + parameters=""" + reset : :class:`int` + Storage reset value. + """, + attributes=""" + data : Signal(shape) + Storage output. + clear : Signal(shape) + Mask to clear storage bits. + """.strip()) + def __init__(self, shape, *, reset=0): + super().__init__(shape, access="rw") + self.data = Signal(shape) + self.clear = Signal(shape) + self._storage = Signal(shape, reset=reset) + self._reset = reset + + @property + def reset(self): + return self._reset + + def elaborate(self, platform): + m = Module() + + for i, storage_bit in enumerate(self._storage): + with m.If(self.clear[i]): + m.d.sync += storage_bit.eq(0) + with m.If(self.port.w_stb & self.port.w_data[i]): + m.d.sync += storage_bit.eq(1) + + m.d.comb += [ + self.port.r_data.eq(self._storage), + self.data.eq(self._storage), + ] + + return m diff --git a/amaranth_soc/csr/reg.py b/amaranth_soc/csr/reg.py new file mode 100644 index 0000000..1f2cfd9 --- /dev/null +++ b/amaranth_soc/csr/reg.py @@ -0,0 +1,701 @@ +from collections.abc import Mapping, Sequence +import enum +from amaranth import * + +from ..memory import MemoryMap +from .bus import Element, Multiplexer + + +__all__ = ["FieldPort", "Field", "FieldMap", "FieldArray", "Register", "RegisterMap", "Bridge"] + + +class FieldPort: + class Access(enum.Enum): + """Field access mode.""" + R = "r" + W = "w" + RW = "rw" + + def readable(self): + return self == self.R or self == self.RW + + def writable(self): + return self == self.W or self == self.RW + + """CSR register field port. + + An interface between a CSR register and one of its fields. + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. + access : :class:`FieldPort.Access` + Field access mode. + + Attributes + ---------- + r_data : Signal(shape) + Read data. Must always be valid, and is sampled when ``r_stb`` is asserted. + r_stb : Signal() + Read strobe. Fields with read side effects should perform them when this strobe is + asserted. + w_data : Signal(shape) + Write data. Valid only when ``w_stb`` is asserted. + w_stb : Signal() + Write strobe. Fields should update their value or perform the write side effect when + this strobe is asserted. + + Raises + ------ + :exc:`TypeError` + If ``shape`` is not a shape-castable object. + :exc:`ValueError` + If ``access`` is not a member of :class:`FieldPort.Access`. + """ + def __init__(self, shape, access): + try: + shape = Shape.cast(shape) + except TypeError as e: + raise TypeError("Field shape must be a shape-castable object, not {!r}" + .format(shape)) from e + if not isinstance(access, FieldPort.Access) and access not in ("r", "w", "rw"): + raise ValueError("Access mode must be one of \"r\", \"w\", or \"rw\", not {!r}" + .format(access)) + self._shape = shape + self._access = FieldPort.Access(access) + + self.r_data = Signal(shape) + self.r_stb = Signal() + self.w_data = Signal(shape) + self.w_stb = Signal() + + @property + def shape(self): + return self._shape + + @property + def access(self): + return self._access + + def __repr__(self): + return "FieldPort({}, {})".format(self.shape, self.access) + + +class Field(Elaboratable): + _doc_template = """ + {description} + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. + access : :class:`FieldPort.Access` + Field access mode. + {parameters} + + Attributes + ---------- + port : :class:`FieldPort` + Field port. + {attributes} + """ + + __doc__ = _doc_template.format( + description=""" + A generic register field. + """.strip(), + parameters="", + attributes="") + + def __init__(self, shape, access): + self.port = FieldPort(shape, access) + + @property + def shape(self): + return self.port.shape + + @property + def access(self): + return self.port.access + + +class FieldMap(Mapping): + """A mapping of CSR register fields. + + Parameters + ---------- + fields : dict of :class:`str` to one of :class:`Field` or :class:`FieldMap`. + """ + def __init__(self, fields): + self._fields = {} + + if not isinstance(fields, Mapping) or len(fields) == 0: + raise TypeError("Fields must be provided as a non-empty mapping, not {!r}" + .format(fields)) + + for key, field in fields.items(): + if not isinstance(key, str) or not key: + raise TypeError("Field name must be a non-empty string, not {!r}" + .format(key)) + if not isinstance(field, (Field, FieldMap, FieldArray)): + raise TypeError("Field must be a Field or a FieldMap or a FieldArray, not {!r}" + .format(field)) + self._fields[key] = field + + def __getitem__(self, key): + """Access a field by name or index. + + Returns + -------- + :class:`Field` or :class:`FieldMap` or :class:`FieldArray` + The field associated with ``key``. + + Raises + ------ + :exc:`KeyError` + If there is no field associated with ``key``. + """ + return self._fields[key] + + def __getattr__(self, name): + """Access a field by name. + + Returns + ------- + :class:`Field` or :class:`FieldMap` or :class:`FieldArray` + The field associated with ``name``. + + Raises + ------ + :exc:`AttributeError` + If the field map does not have a field associated with ``name``. + :exc:`AttributeError` + If ``name`` is reserved (i.e. starts with an underscore). + """ + try: + item = self[name] + except KeyError: + raise AttributeError("Field map does not have a field {!r}; " + "did you mean one of: {}?" + .format(name, ", ".join(repr(name) for name in self.keys()))) + if name.startswith("_"): + raise AttributeError("Field map field {!r} has a reserved name and may only be " + "accessed by indexing" + .format(name)) + return item + + def __iter__(self): + """Iterate over the field map. + + Yields + ------ + :class:`str` + Key (name) for accessing the field. + """ + yield from self._fields + + def __len__(self): + return len(self._fields) + + def flatten(self): + """Recursively iterate over the field map. + + Yields + ------ + iter(:class:`str`) + Name of the field. It is prefixed by the name of every nested field collection. + :class:`Field` + Register field. + """ + for key, field in self.items(): + if isinstance(field, Field): + yield (key,), field + elif isinstance(field, (FieldMap, FieldArray)): + for sub_name, sub_field in field.flatten(): + yield (key, *sub_name), sub_field + else: + assert False # :nocov: + + +class FieldArray(Sequence): + """An array of CSR register fields. + + Parameters + ---------- + fields : iter(:class:`Field` or :class:`FieldMap` or :class:`FieldArray`) + Field array members. + """ + def __init__(self, fields): + fields = tuple(fields) + for field in fields: + if not isinstance(field, (Field, FieldMap, FieldArray)): + raise TypeError("Field must be a Field or a FieldMap or a FieldArray, not {!r}" + .format(field)) + self._fields = fields + + def __getitem__(self, key): + """Access a field by index. + + Returns + ------- + :class:`Field` or :class:`FieldMap` or :class:`FieldArray` + The field associated with ``key``. + """ + return self._fields[key] + + def __len__(self): + """Field array length. + + Returns + ------- + :class:`int` + The number of fields in the array. + """ + return len(self._fields) + + def flatten(self): + """Iterate recursively over the field array. + + Yields + ------ + iter(:class:`str`) + Name of the field. It is prefixed by the name of every nested field collection. + :class:`Field` + Register field. + """ + for key, field in enumerate(self._fields): + if isinstance(field, Field): + yield (key,), field + elif isinstance(field, (FieldMap, FieldArray)): + for sub_name, sub_field in field.flatten(): + yield (key, *sub_name), sub_field + else: + assert False # :nocov: + + +class Register(Elaboratable): + """CSR register. + + Parameters + ---------- + access : :class:`Element.Access` + Register access mode. + fields : :class:`FieldMap` or :class:`FieldArray` + Collection of register fields. If ``None`` (default), a :class:`FieldMap` is created + from Python :term:`variable annotations `. + + Attributes + ---------- + element : :class:`Element` + Interface between this register and a CSR bus primitive. + fields : :class:`FieldMap` or :class:`FieldArray` + Collection of register fields. + f : :class:`FieldMap` or :class:`FieldArray` + Shorthand for :attr:`Register.fields`. + + Raises + ------ + :exc:`ValueError` + If ``access`` is not a member of :class:`Element.Access`. + :exc:`TypeError` + If ``fields`` is not ``None`` or a :class:`FieldMap` or a :class:`FieldArray`. + :exc:`ValueError` + If ``access`` is not readable and at least one field is readable. + :exc:`ValueError` + If ``access`` is not writable and at least one field is writable. + """ + def __init__(self, access, fields=None): + if not isinstance(access, Element.Access) and access not in ("r", "w", "rw"): + raise ValueError("Access mode must be one of \"r\", \"w\", or \"rw\", not {!r}" + .format(access)) + access = Element.Access(access) + + if hasattr(self, "__annotations__"): + annotation_fields = {} + for key, value in self.__annotations__.items(): + if isinstance(value, (Field, FieldMap, FieldArray)): + annotation_fields[key] = value + + if fields is None: + fields = FieldMap(annotation_fields) + elif annotation_fields: + raise ValueError("Field collection {} cannot be provided in addition to field annotations: {}" + .format(fields, ", ".join(annotation_fields.keys()))) + + if not isinstance(fields, (FieldMap, FieldArray)): + raise TypeError("Field collection must be a FieldMap or a FieldArray, not {!r}" + .format(fields)) + + width = 0 + for field_name, field in fields.flatten(): + width += Shape.cast(field.shape).width + if field.access.readable() and not access.readable(): + raise ValueError("Field {} is readable, but register access mode is '{}'" + .format("__".join(field_name), access)) + if field.access.writable() and not access.writable(): + raise ValueError("Field {} is writable, but register access mode is '{}'" + .format("__".join(field_name), access)) + + self.element = Element(width, access) + self._fields = fields + + @property + def fields(self): + return self._fields + + @property + def f(self): + return self._fields + + def __iter__(self): + """Recursively iterate over the field collection. + + Yields + ------ + iter(:class:`str`) + Name of the field. It is prefixed by the name of every nested field collection. + :class:`Field` + Register field. + """ + yield from self.fields.flatten() + + def elaborate(self, platform): + m = Module() + + field_start = 0 + + for field_name, field in self.fields.flatten(): + m.submodules["__".join(str(key) for key in field_name)] = field + + field_slice = slice(field_start, field_start + Shape.cast(field.shape).width) + + if field.access.readable(): + m.d.comb += [ + self.element.r_data[field_slice].eq(field.port.r_data), + field.port.r_stb.eq(self.element.r_stb), + ] + if field.access.writable(): + m.d.comb += [ + field.port.w_data.eq(self.element.w_data[field_slice]), + field.port.w_stb .eq(self.element.w_stb), + ] + + field_start = field_slice.stop + + return m + + +class RegisterMap: + """A collection of CSR registers.""" + def __init__(self): + self._registers = dict() + self._clusters = dict() + self._namespace = dict() + self._frozen = False + + def freeze(self): + """Freeze the cluster. + + Once the cluster is frozen, its visible state becomes immutable. Registers and clusters + cannot be added anymore. + """ + self._frozen = True + + def add_register(self, register, *, name): + """Add a register. + + Arguments + --------- + register : :class:`Register` + Register. + name : :class:`str` + Name of the register. + + Returns + ------- + :class:`Register` + ``register``, which is added to the register map. + + Raises + ------ + :exc:`ValueError` + If the register map is frozen. + :exc:`TypeError` + If ``register` is not an instance of :class:`Register`. + :exc:`TypeError` + If ``name`` is not a string. + :exc:`ValueError` + If ``name`` is already used. + """ + if self._frozen: + raise ValueError("Register map is frozen") + if not isinstance(register, Register): + raise TypeError("Register must be an instance of csr.Register, not {!r}" + .format(register)) + + if not isinstance(name, str) or not name: + raise TypeError("Name must be a non-empty string, not {!r}".format(name)) + if name in self._namespace: + raise ValueError("Name '{}' is already used by {!r}".format(name, self._namespace[name])) + + self._registers[id(register)] = register, name + self._namespace[name] = register + return register + + def registers(self): + """Iterate local registers. + + Yields + ------ + :class:`Register` + Register. + :class:`str` + Name of the register. + """ + for register, name in self._registers.values(): + yield register, name + + def add_cluster(self, cluster, *, name): + """Add a cluster of registers. + + Arguments + --------- + cluster : :class:`RegisterMap` + Cluster of registers. + name : :class:`str` + Name of the cluster. + + Returns + ------- + :class:`RegisterMap` + ``cluster``, which is added to the register map. + + Raises + ------ + :exc:`ValueError` + If the register map is frozen. + :exc:`TypeError` + If ``cluster` is not an instance of :class:`RegisterMap`. + :exc:`TypeError` + If ``name`` is not a string. + :exc:`ValueError` + If ``name`` is already used. + """ + if self._frozen: + raise ValueError("Register map is frozen") + if not isinstance(cluster, RegisterMap): + raise TypeError("Cluster must be an instance of csr.RegisterMap, not {!r}" + .format(cluster)) + + if not isinstance(name, str) or not name: + raise TypeError("Name must be a non-empty string, not {!r}".format(name)) + if name in self._namespace: + raise ValueError("Name '{}' is already used by {!r}".format(name, self._namespace[name])) + + self._clusters[id(cluster)] = cluster, name + self._namespace[name] = cluster + return cluster + + def clusters(self): + """Iterate local clusters of registers. + + Yields + ------ + :class:`RegisterMap` + Cluster of registers. + :class:`str` + Name of the cluster. + """ + for cluster, name in self._clusters.values(): + yield cluster, name + + def flatten(self, *, _path=()): + """Recursively iterate over all registers. + + Yields + ------ + :class:`Register` + Register. + iter(:class:`str`) + Path of the register. It contains its name, prefixed by the name of parent clusters up + to this register map. + """ + for name, assignment in self._namespace.items(): + path = (*_path, name) + if id(assignment) in self._registers: + yield assignment, path + elif id(assignment) in self._clusters: + yield from assignment.flatten(_path=path) + else: + assert False # :nocov: + + def get_path(self, register, *, _path=()): + """Get the path of a register. + + Arguments + --------- + register : :class:`Register` + A register of the register map. + + Returns + ------- + iter(:class:`str`) + Path of the register. It contains its name, prefixed by the name of parent clusters up + to this register map. + + Raises + ------ + :exc:`TypeError` + If ``register` is not an instance of :class:`Register`. + :exc:`KeyError` + If ``register` is not in the register map. + """ + if not isinstance(register, Register): + raise TypeError("Register must be an instance of csr.Register, not {!r}" + .format(register)) + + if id(register) in self._registers: + _, name = self._registers[id(register)] + return (*_path, name) + + for cluster, name in self._clusters.values(): + try: + return cluster.get_path(register, _path=(*_path, name)) + except KeyError: + pass + + raise KeyError(register) + + def get_register(self, path): + """Get a register. + + Arguments + --------- + path : iter(:class:`str`) + Path of the register. It contains its name, prefixed by the name of parent clusters up + to this register map. + + Returns + ------- + :class:`Register` + The register assigned to ``path``. + + Raises + ------ + :exc:`ValueError` + If ``path`` is empty. + :exc:`TypeError` + If ``path`` is not composed of non-empty strings. + :exc:`KeyError` + If ``path`` is not assigned to a register. + """ + path = tuple(path) + if not path: + raise ValueError("Path must be a non-empty iterable") + for name in path: + if not isinstance(name, str) or not name: + raise TypeError("Path must contain non-empty strings, not {!r}".format(name)) + + name, *rest = path + + if name in self._namespace: + assignment = self._namespace[name] + if not rest: + assert id(assignment) in self._registers + return assignment + else: + assert id(assignment) in self._clusters + try: + return assignment.get_register(rest) + except KeyError: + pass + + raise KeyError(path) + + +class Bridge(Elaboratable): + """CSR bridge. + + Parameters + ---------- + register_map : :class:`RegisterMap` + Register map. + addr_width : :class:`int` + Address width. See :class:`Interface`. + data_width : :class:`int` + Data width. See :class:`Interface`. + alignment : log2 of :class:`int` + Register alignment. Optional, defaults to ``0``. See :class:`..memory.MemoryMap`. + name : :class:`str` + Window name. Optional. + register_addr : :class:`dict` + Register address mapping. Optional, defaults to ``None``. + register_alignment : :class:`dict` + Register alignment mapping. Optional, defaults to ``None``. + + Attributes + ---------- + register_map : :class:`RegisterMap` + Register map. + bus : :class:`Interface` + CSR bus providing access to registers. + + Raises + ------ + :exc:`TypeError` + If ``register_map`` is not an instance of :class:`RegisterMap`. + :exc:`TypeError` + If ``register_addr`` is a not a mapping. + :exc:`TypeError` + If ``register_alignment`` is a not a mapping. + """ + def __init__(self, register_map, *, addr_width, data_width, alignment=0, name=None, + register_addr=None, register_alignment=None): + if not isinstance(register_map, RegisterMap): + raise TypeError("Register map must be an instance of RegisterMap, not {!r}" + .format(register_map)) + + memory_map = MemoryMap(addr_width=addr_width, data_width=data_width, alignment=alignment, + name=name) + + def get_register_param(path, root, kind): + node = root + prev = [] + for name in path: + if node is None: + break + if not isinstance(node, Mapping): + raise TypeError("Register {}{} must be a mapping, not {!r}" + .format(kind, "" if not prev else f" {tuple(prev)}", node)) + prev.append(name) + node = node.get(name, None) + return node + + register_map.freeze() + + for register, path in register_map.flatten(): + elem_size = (register.element.width + data_width - 1) // data_width + elem_name = "__".join(path) + elem_addr = get_register_param(path, register_addr, "address") + elem_alignment = get_register_param(path, register_alignment, "alignment") + memory_map.add_resource(register.element, name=elem_name, size=elem_size, + addr=elem_addr, alignment=elem_alignment) + + self._map = register_map + self._mux = Multiplexer(memory_map) + + @property + def register_map(self): + return self._map + + @property + def bus(self): + return self._mux.bus + + def elaborate(self, platform): + m = Module() + for register, path in self.register_map.flatten(): + m.submodules["__".join(path)] = register + m.submodules.mux = self._mux + return m diff --git a/tests/test_csr_field.py b/tests/test_csr_field.py new file mode 100644 index 0000000..6e1d1d6 --- /dev/null +++ b/tests/test_csr_field.py @@ -0,0 +1,158 @@ +# amaranth: UnusedElaboratable=no + +import unittest +from amaranth import * +from amaranth.sim import * + +from amaranth_soc.csr import field + + +class RTestCase(unittest.TestCase): + def test_simple(self): + f = field.R(unsigned(4)) + self.assertEqual(repr(f.port), "FieldPort(unsigned(4), Access.R)") + self.assertEqual(f.r_data.shape(), unsigned(4)) + + def test_sim(self): + dut = field.R(unsigned(4)) + + def process(): + yield dut.r_data.eq(0xa) + yield Settle() + self.assertEqual((yield dut.port.r_data), 0xa) + + sim = Simulator(dut) + sim.add_process(process) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() + + +class WTestCase(unittest.TestCase): + def test_simple(self): + f = field.W(unsigned(4)) + self.assertEqual(repr(f.port), "FieldPort(unsigned(4), Access.W)") + self.assertEqual(f.w_data.shape(), unsigned(4)) + + def test_sim(self): + dut = field.W(unsigned(4)) + + def process(): + yield dut.port.w_data.eq(0xa) + yield Settle() + self.assertEqual((yield dut.w_data), 0xa) + + sim = Simulator(dut) + sim.add_process(process) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() + + +class RWTestCase(unittest.TestCase): + def test_simple(self): + f4 = field.RW(unsigned(4), reset=0x5) + f8 = field.RW(signed(8)) + self.assertEqual(repr(f4.port), "FieldPort(unsigned(4), Access.RW)") + self.assertEqual(f4.data.shape(), unsigned(4)) + self.assertEqual(f4.reset, 0x5) + self.assertEqual(repr(f8.port), "FieldPort(signed(8), Access.RW)") + self.assertEqual(f8.data.shape(), signed(8)) + self.assertEqual(f8.reset, 0) + + def test_sim(self): + dut = field.RW(unsigned(4), reset=0x5) + + def process(): + self.assertEqual((yield dut.port.r_data), 0x5) + self.assertEqual((yield dut.data), 0x5) + yield dut.port.w_stb .eq(1) + yield dut.port.w_data.eq(0xa) + yield + yield Settle() + self.assertEqual((yield dut.port.r_data), 0xa) + self.assertEqual((yield dut.data), 0xa) + + sim = Simulator(dut) + sim.add_clock(1e-6) + sim.add_sync_process(process) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() + + +class RW1CTestCase(unittest.TestCase): + def test_simple(self): + f4 = field.RW1C(unsigned(4), reset=0x5) + f8 = field.RW1C(signed(8)) + self.assertEqual(repr(f4.port), "FieldPort(unsigned(4), Access.RW)") + self.assertEqual(f4.data.shape(), unsigned(4)) + self.assertEqual(f4.set .shape(), unsigned(4)) + self.assertEqual(f4.reset, 0x5) + self.assertEqual(repr(f8.port), "FieldPort(signed(8), Access.RW)") + self.assertEqual(f8.data.shape(), signed(8)) + self.assertEqual(f8.set .shape(), signed(8)) + self.assertEqual(f8.reset, 0) + + def test_sim(self): + dut = field.RW1C(unsigned(4), reset=0xf) + + def process(): + self.assertEqual((yield dut.port.r_data), 0xf) + self.assertEqual((yield dut.data), 0xf) + yield dut.port.w_stb .eq(1) + yield dut.port.w_data.eq(0x5) + yield + yield Settle() + self.assertEqual((yield dut.port.r_data), 0xa) + self.assertEqual((yield dut.data), 0xa) + + yield dut.port.w_data.eq(0x3) + yield dut.set.eq(0x4) + yield + yield Settle() + self.assertEqual((yield dut.port.r_data), 0xc) + self.assertEqual((yield dut.data), 0xc) + + sim = Simulator(dut) + sim.add_clock(1e-6) + sim.add_sync_process(process) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() + + +class RW1STestCase(unittest.TestCase): + def test_simple(self): + f4 = field.RW1S(unsigned(4), reset=0x5) + f8 = field.RW1S(signed(8)) + self.assertEqual(repr(f4.port), "FieldPort(unsigned(4), Access.RW)") + self.assertEqual(f4.data .shape(), unsigned(4)) + self.assertEqual(f4.clear.shape(), unsigned(4)) + self.assertEqual(f4.reset, 0x5) + self.assertEqual(repr(f8.port), "FieldPort(signed(8), Access.RW)") + self.assertEqual(f8.data .shape(), signed(8)) + self.assertEqual(f8.clear.shape(), signed(8)) + self.assertEqual(f8.reset, 0) + + def test_sim(self): + dut = field.RW1S(unsigned(4), reset=0x5) + + def process(): + self.assertEqual((yield dut.port.r_data), 0x5) + self.assertEqual((yield dut.data), 0x5) + yield dut.port.w_stb .eq(1) + yield dut.port.w_data.eq(0xa) + yield + yield Settle() + self.assertEqual((yield dut.port.r_data), 0xf) + self.assertEqual((yield dut.data), 0xf) + + yield dut.port.w_data.eq(0x3) + yield dut.clear.eq(0x7) + yield + yield Settle() + self.assertEqual((yield dut.port.r_data), 0xb) + self.assertEqual((yield dut.data), 0xb) + + sim = Simulator(dut) + sim.add_clock(1e-6) + sim.add_sync_process(process) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() diff --git a/tests/test_csr_reg.py b/tests/test_csr_reg.py new file mode 100644 index 0000000..a823cc7 --- /dev/null +++ b/tests/test_csr_reg.py @@ -0,0 +1,843 @@ +# amaranth: UnusedElaboratable=no + +import unittest +from amaranth import * +from amaranth.sim import * + +from amaranth_soc.csr.reg import * +from amaranth_soc.csr import field + + +class FieldPortTestCase(unittest.TestCase): + def test_shape_1_ro(self): + port = FieldPort(1, "r") + self.assertEqual(port.shape, unsigned(1)) + self.assertEqual(port.access, FieldPort.Access.R) + self.assertEqual(port.r_data.shape(), unsigned(1)) + self.assertEqual(port.r_stb .shape(), unsigned(1)) + self.assertEqual(port.w_data.shape(), unsigned(1)) + self.assertEqual(port.w_stb .shape(), unsigned(1)) + self.assertEqual(repr(port), "FieldPort(unsigned(1), Access.R)") + + def test_shape_8_rw(self): + port = FieldPort(8, "rw") + self.assertEqual(port.shape, unsigned(8)) + self.assertEqual(port.access, FieldPort.Access.RW) + self.assertEqual(port.r_data.shape(), unsigned(8)) + self.assertEqual(port.r_stb .shape(), unsigned(1)) + self.assertEqual(port.w_data.shape(), unsigned(8)) + self.assertEqual(port.w_stb .shape(), unsigned(1)) + self.assertEqual(repr(port), "FieldPort(unsigned(8), Access.RW)") + + def test_shape_10_wo(self): + port = FieldPort(10, "w") + self.assertEqual(port.shape, unsigned(10)) + self.assertEqual(port.access, FieldPort.Access.W) + self.assertEqual(port.r_data.shape(), unsigned(10)) + self.assertEqual(port.r_stb .shape(), unsigned(1)) + self.assertEqual(port.w_data.shape(), unsigned(10)) + self.assertEqual(port.w_stb .shape(), unsigned(1)) + self.assertEqual(repr(port), "FieldPort(unsigned(10), Access.W)") + + def test_shape_0_rw(self): + port = FieldPort(0, "rw") + self.assertEqual(port.shape, unsigned(0)) + self.assertEqual(port.access, FieldPort.Access.RW) + self.assertEqual(port.r_data.shape(), unsigned(0)) + self.assertEqual(port.r_stb .shape(), unsigned(1)) + self.assertEqual(port.w_data.shape(), unsigned(0)) + self.assertEqual(port.w_stb .shape(), unsigned(1)) + self.assertEqual(repr(port), "FieldPort(unsigned(0), Access.RW)") + + def test_shape_wrong(self): + with self.assertRaisesRegex(TypeError, + r"Field shape must be a shape-castable object, not 'foo'"): + port = FieldPort("foo", "rw") + + def test_access_wrong(self): + with self.assertRaisesRegex(ValueError, + r"Access mode must be one of \"r\", \"w\", or \"rw\", not 'wo'"): + port = FieldPort(8, "wo") + + +def _compatible_fields(a, b): + return isinstance(a, Field) and type(a) == type(b) and \ + a.shape == b.shape and a.access == b.access + + +class FieldTestCase(unittest.TestCase): + def test_simple(self): + field = Field(unsigned(4), "rw") + self.assertEqual(field.shape, unsigned(4)) + self.assertEqual(field.access, FieldPort.Access.RW) + self.assertEqual(repr(field.port), "FieldPort(unsigned(4), Access.RW)") + + def test_compatible(self): + self.assertTrue(_compatible_fields(Field(unsigned(4), "rw"), + Field(unsigned(4), FieldPort.Access.RW))) + self.assertFalse(_compatible_fields(Field(unsigned(3), "r" ), Field(unsigned(4), "r"))) + self.assertFalse(_compatible_fields(Field(unsigned(4), "rw"), Field(unsigned(4), "w"))) + self.assertFalse(_compatible_fields(Field(unsigned(4), "rw"), Field(unsigned(4), "r"))) + self.assertFalse(_compatible_fields(Field(unsigned(4), "r" ), Field(unsigned(4), "w"))) + + def test_wrong_shape(self): + with self.assertRaisesRegex(TypeError, + r"Field shape must be a shape-castable object, not 'foo'"): + Field("foo", "rw") + + def test_wrong_access(self): + with self.assertRaisesRegex(ValueError, + r"Access mode must be one of \"r\", \"w\", or \"rw\", not 'wo'"): + Field(8, "wo") + + +class FieldMapTestCase(unittest.TestCase): + def test_simple(self): + field_map = FieldMap({ + "a": Field(unsigned(1), "r"), + "b": Field(signed(3), "rw"), + "c": FieldMap({ + "d": Field(unsigned(4), "rw"), + }), + }) + self.assertTrue(_compatible_fields(field_map["a"], Field(unsigned(1), "r"))) + self.assertTrue(_compatible_fields(field_map["b"], Field(signed(3), "rw"))) + self.assertTrue(_compatible_fields(field_map["c"]["d"], Field(unsigned(4), "rw"))) + + self.assertTrue(_compatible_fields(field_map.a, Field(unsigned(1), "r"))) + self.assertTrue(_compatible_fields(field_map.b, Field(signed(3), "rw"))) + self.assertTrue(_compatible_fields(field_map.c.d, Field(unsigned(4), "rw"))) + + self.assertEqual(len(field_map), 3) + + def test_iter(self): + field_map = FieldMap({ + "a": Field(unsigned(1), "r"), + "b": Field(signed(3), "rw") + }) + self.assertEqual(list(field_map.items()), [ + ("a", field_map["a"]), + ("b", field_map["b"]), + ]) + + def test_flatten(self): + field_map = FieldMap({ + "a": Field(unsigned(1), "r"), + "b": Field(signed(3), "rw"), + "c": FieldMap({ + "d": Field(unsigned(4), "rw"), + }), + }) + self.assertEqual(list(field_map.flatten()), [ + (("a",), field_map["a"]), + (("b",), field_map["b"]), + (("c", "d"), field_map["c"]["d"]), + ]) + + def test_wrong_mapping(self): + with self.assertRaisesRegex(TypeError, + r"Fields must be provided as a non-empty mapping, not 'foo'"): + FieldMap("foo") + + def test_wrong_field_key(self): + with self.assertRaisesRegex(TypeError, + r"Field name must be a non-empty string, not 1"): + FieldMap({1: Field(unsigned(1), "rw")}) + with self.assertRaisesRegex(TypeError, + r"Field name must be a non-empty string, not ''"): + FieldMap({"": Field(unsigned(1), "rw")}) + + def test_wrong_field_value(self): + with self.assertRaisesRegex(TypeError, + r"Field must be a Field or a FieldMap or a FieldArray, not unsigned\(1\)"): + FieldMap({"a": unsigned(1)}) + + def test_getitem_wrong_key(self): + with self.assertRaises(KeyError): + FieldMap({"a": Field(unsigned(1), "rw")})["b"] + + +class FieldArrayTestCase(unittest.TestCase): + def test_simple(self): + field_array = FieldArray([Field(unsigned(2), "rw") for _ in range(8)]) + self.assertEqual(len(field_array), 8) + for i in range(8): + self.assertTrue(_compatible_fields(field_array[i], Field(unsigned(2), "rw"))) + + def test_dim_2(self): + field_array = FieldArray([FieldArray([Field(unsigned(1), "rw") for _ in range(4)]) + for _ in range(4)]) + self.assertEqual(len(field_array), 4) + for i in range(4): + for j in range(4): + self.assertTrue(_compatible_fields(field_array[i][j], Field(1, "rw"))) + + def test_nested(self): + field_array = FieldArray([ + FieldMap({ + "a": Field(unsigned(4), "rw"), + "b": FieldArray([Field(unsigned(1), "rw") for _ in range(4)]), + }) for _ in range(4)]) + self.assertEqual(len(field_array), 4) + for i in range(4): + self.assertTrue(_compatible_fields(field_array[i]["a"], Field(unsigned(4), "rw"))) + for j in range(4): + self.assertTrue(_compatible_fields(field_array[i]["b"][j], + Field(unsigned(1), "rw"))) + + def test_iter(self): + field_array = FieldArray([Field(1, "rw") for _ in range(3)]) + self.assertEqual(list(field_array), [ + field_array[i] for i in range(3) + ]) + + def test_flatten(self): + field_array = FieldArray([ + FieldMap({ + "a": Field(4, "rw"), + "b": FieldArray([Field(1, "rw") for _ in range(2)]), + }) for _ in range(2)]) + self.assertEqual(list(field_array.flatten()), [ + ((0, "a"), field_array[0]["a"]), + ((0, "b", 0), field_array[0]["b"][0]), + ((0, "b", 1), field_array[0]["b"][1]), + ((1, "a"), field_array[1]["a"]), + ((1, "b", 0), field_array[1]["b"][0]), + ((1, "b", 1), field_array[1]["b"][1]), + ]) + + def test_wrong_field(self): + with self.assertRaisesRegex(TypeError, + r"Field must be a Field or a FieldMap or a FieldArray, not 'foo'"): + FieldArray([Field(1, "rw"), "foo"]) + + +class RegisterTestCase(unittest.TestCase): + def test_simple(self): + reg = Register("rw", FieldMap({ + "a": field.R(unsigned(1)), + "b": field.RW1C(unsigned(3)), + "c": FieldMap({"d": field.RW(signed(2))}), + "e": FieldArray([field.W(unsigned(1)) for _ in range(2)]) + })) + + self.assertTrue(_compatible_fields(reg.f.a, field.R(unsigned(1)))) + self.assertTrue(_compatible_fields(reg.f.b, field.RW1C(unsigned(3)))) + self.assertTrue(_compatible_fields(reg.f.c.d, field.RW(signed(2)))) + self.assertTrue(_compatible_fields(reg.f.e[0], field.W(unsigned(1)))) + self.assertTrue(_compatible_fields(reg.f.e[1], field.W(unsigned(1)))) + + self.assertEqual(reg.element.width, 8) + self.assertEqual(reg.element.access.readable(), True) + self.assertEqual(reg.element.access.writable(), True) + + def test_annotations(self): + class MockRegister(Register): + a: field.R(unsigned(1)) + b: field.RW1C(unsigned(3)) + c: FieldMap({"d": field.RW(signed(2))}) + e: FieldArray([field.W(unsigned(1)) for _ in range(2)]) + + foo: unsigned(42) + + reg = MockRegister("rw") + + self.assertTrue(_compatible_fields(reg.f.a, field.R(unsigned(1)))) + self.assertTrue(_compatible_fields(reg.f.b, field.RW1C(unsigned(3)))) + self.assertTrue(_compatible_fields(reg.f.c.d, field.RW(signed(2)))) + self.assertTrue(_compatible_fields(reg.f.e[0], field.W(unsigned(1)))) + self.assertTrue(_compatible_fields(reg.f.e[1], field.W(unsigned(1)))) + + self.assertEqual(reg.element.width, 8) + self.assertEqual(reg.element.access.readable(), True) + self.assertEqual(reg.element.access.writable(), True) + + def test_iter(self): + reg = Register("rw", FieldMap({ + "a": field.R(unsigned(1)), + "b": field.RW1C(unsigned(3)), + "c": FieldMap({"d": field.RW(signed(2))}), + "e": FieldArray([field.W(unsigned(1)) for _ in range(2)]) + })) + self.assertEqual(list(reg), [ + (("a",), reg.f.a), + (("b",), reg.f.b), + (("c", "d"), reg.f.c.d), + (("e", 0), reg.f.e[0]), + (("e", 1), reg.f.e[1]), + ]) + + def test_sim(self): + dut = Register("rw", FieldMap({ + "a": field.R(unsigned(1)), + "b": field.RW1C(unsigned(3), reset=0b111), + "c": FieldMap({"d": field.RW(signed(2), reset=-1)}), + "e": FieldArray([field.W(unsigned(1)) for _ in range(2)]), + "f": field.RW1S(unsigned(3)), + })) + + def process(): + # Check reset values: + + self.assertEqual((yield dut.f.b .data), 0b111) + self.assertEqual((yield dut.f.c.d.data), -1) + self.assertEqual((yield dut.f.f .data), 0b000) + + self.assertEqual((yield dut.f.b .port.r_data), 0b111) + self.assertEqual((yield dut.f.c.d .port.r_data), -1) + self.assertEqual((yield dut.f.f .port.r_data), 0b000) + + # Initiator read: + + yield dut.element.r_stb.eq(1) + yield Delay() + + self.assertEqual((yield dut.f.a.port.r_stb), 1) + self.assertEqual((yield dut.f.b.port.r_stb), 1) + self.assertEqual((yield dut.f.f.port.r_stb), 1) + + yield dut.element.r_stb.eq(0) + + # Initiator write: + + yield dut.element.w_stb .eq(1) + yield dut.element.w_data.eq(Cat( + Const(0b1, 1), # a + Const(0b010, 3), # b + Const(0b00, 2), # c.d + Const(0b00, 2), # e + Const(0b110, 3), # f + )) + yield Settle() + + self.assertEqual((yield dut.f.a .port.w_stb), 0) + self.assertEqual((yield dut.f.b .port.w_stb), 1) + self.assertEqual((yield dut.f.c.d .port.w_stb), 1) + self.assertEqual((yield dut.f.e[0].port.w_stb), 1) + self.assertEqual((yield dut.f.e[1].port.w_stb), 1) + self.assertEqual((yield dut.f.f .port.w_stb), 1) + + self.assertEqual((yield dut.f.b .port.w_data), 0b010) + self.assertEqual((yield dut.f.c.d .port.w_data), 0b00) + self.assertEqual((yield dut.f.e[0].port.w_data), 0b0) + self.assertEqual((yield dut.f.e[1].port.w_data), 0b0) + self.assertEqual((yield dut.f.f .port.w_data), 0b110) + + self.assertEqual((yield dut.f.e[0].w_data), 0b0) + self.assertEqual((yield dut.f.e[1].w_data), 0b0) + + yield + yield dut.element.w_stb.eq(0) + yield Settle() + + self.assertEqual((yield dut.f.b .data), 0b101) + self.assertEqual((yield dut.f.c.d.data), 0b00) + self.assertEqual((yield dut.f.f .data), 0b110) + + # User write: + + yield dut.f.a.r_data.eq(0b1) + yield dut.f.b.set .eq(0b010) + yield dut.f.f.clear .eq(0b010) + yield Settle() + + self.assertEqual((yield dut.element.r_data), + Const.cast(Cat( + Const(0b1, 1), # a + Const(0b101, 3), # b + Const(0b00, 2), # c.d + Const(0b00, 2), # e + Const(0b110, 3), # f + )).value) + + yield + yield dut.f.a.r_data.eq(0b0) + yield dut.f.b.set .eq(0b000) + yield dut.f.f.clear .eq(0b000) + yield Settle() + + self.assertEqual((yield dut.element.r_data), + Const.cast(Cat( + Const(0b0, 1), # a + Const(0b111, 3), # b + Const(0b00, 2), # c.d + Const(0b00, 2), # e + Const(0b100, 3), # f + )).value) + + # Concurrent writes: + + yield dut.element.w_stb .eq(1) + yield dut.element.w_data.eq(Cat( + Const(0b0, 1), # a + Const(0b111, 3), # b + Const(0b00, 2), # c.d + Const(0b00, 2), # e + Const(0b111, 3), # f + )) + + yield dut.f.b.set .eq(0b001) + yield dut.f.f.clear.eq(0b111) + yield + yield Settle() + + self.assertEqual((yield dut.element.r_data), + Const.cast(Cat( + Const(0b0, 1), # a + Const(0b001, 3), # b + Const(0b00, 2), # c.d + Const(0b00, 2), # e + Const(0b111, 3), # f + )).value) + + self.assertEqual((yield dut.f.b.data), 0b001) + self.assertEqual((yield dut.f.f.data), 0b111) + + sim = Simulator(dut) + sim.add_clock(1e-6) + sim.add_sync_process(process) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() + + +class RegisterMapTestCase(unittest.TestCase): + def setUp(self): + self.dut = RegisterMap() + + def test_add_register(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + self.assertIs(self.dut.add_register(reg_rw_a, name="reg_rw_a"), reg_rw_a) + + def test_add_register_frozen(self): + self.dut.freeze() + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + with self.assertRaisesRegex(ValueError, r"Register map is frozen"): + self.dut.add_register(reg_rw_a, name="reg_rw_a") + + def test_add_register_wrong_type(self): + with self.assertRaisesRegex(TypeError, + r"Register must be an instance of csr\.Register, not 'foo'"): + self.dut.add_register("foo", name="foo") + + def test_add_register_wrong_name(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + with self.assertRaisesRegex(TypeError, + r"Name must be a non-empty string, not None"): + self.dut.add_register(reg_rw_a, name=None) + + def test_add_register_empty_name(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + with self.assertRaisesRegex(TypeError, + r"Name must be a non-empty string, not ''"): + self.dut.add_register(reg_rw_a, name="") + + def test_add_cluster(self): + cluster = RegisterMap() + self.assertIs(self.dut.add_cluster(cluster, name="cluster"), cluster) + + def test_add_cluster_frozen(self): + self.dut.freeze() + cluster = RegisterMap() + with self.assertRaisesRegex(ValueError, r"Register map is frozen"): + self.dut.add_cluster(cluster, name="cluster") + + def test_add_cluster_wrong_type(self): + with self.assertRaisesRegex(TypeError, + r"Cluster must be an instance of csr\.RegisterMap, not 'foo'"): + self.dut.add_cluster("foo", name="foo") + + def test_add_cluster_wrong_name(self): + cluster = RegisterMap() + with self.assertRaisesRegex(TypeError, + r"Name must be a non-empty string, not None"): + self.dut.add_cluster(cluster, name=None) + + def test_add_cluster_empty_name(self): + cluster = RegisterMap() + with self.assertRaisesRegex(TypeError, + r"Name must be a non-empty string, not ''"): + self.dut.add_cluster(cluster, name="") + + def test_namespace_collision(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + cluster_0 = RegisterMap() + cluster_1 = RegisterMap() + + self.dut.add_register(reg_rw_a, name="reg_rw_a") + self.dut.add_cluster(cluster_0, name="cluster_0") + + with self.assertRaisesRegex(ValueError, # register/register + r"Name 'reg_rw_a' is already used by *"): + self.dut.add_register(reg_rw_b, name="reg_rw_a") + with self.assertRaisesRegex(ValueError, # register/cluster + r"Name 'reg_rw_a' is already used by *"): + self.dut.add_cluster(cluster_1, name="reg_rw_a") + with self.assertRaisesRegex(ValueError, # cluster/cluster + r"Name 'cluster_0' is already used by *"): + self.dut.add_cluster(cluster_1, name="cluster_0") + with self.assertRaisesRegex(ValueError, # cluster/register + r"Name 'cluster_0' is already used by *"): + self.dut.add_register(reg_rw_b, name="cluster_0") + + def test_iter_registers(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + self.dut.add_register(reg_rw_a, name="reg_rw_a") + self.dut.add_register(reg_rw_b, name="reg_rw_b") + + registers = list(self.dut.registers()) + + self.assertEqual(len(registers), 2) + self.assertIs(registers[0][0], reg_rw_a) + self.assertEqual(registers[0][1], "reg_rw_a") + self.assertIs(registers[1][0], reg_rw_b) + self.assertEqual(registers[1][1], "reg_rw_b") + + def test_iter_clusters(self): + cluster_0 = RegisterMap() + cluster_1 = RegisterMap() + self.dut.add_cluster(cluster_0, name="cluster_0") + self.dut.add_cluster(cluster_1, name="cluster_1") + + clusters = list(self.dut.clusters()) + + self.assertEqual(len(clusters), 2) + self.assertIs(clusters[0][0], cluster_0) + self.assertEqual(clusters[0][1], "cluster_0") + self.assertIs(clusters[1][0], cluster_1) + self.assertEqual(clusters[1][1], "cluster_1") + + def test_iter_flatten(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + cluster_0 = RegisterMap() + cluster_1 = RegisterMap() + + cluster_0.add_register(reg_rw_a, name="reg_rw_a") + cluster_1.add_register(reg_rw_b, name="reg_rw_b") + + self.dut.add_cluster(cluster_0, name="cluster_0") + self.dut.add_cluster(cluster_1, name="cluster_1") + + registers = list(self.dut.flatten()) + + self.assertEqual(len(registers), 2) + self.assertIs(registers[0][0], reg_rw_a) + self.assertEqual(registers[0][1], ("cluster_0", "reg_rw_a")) + self.assertIs(registers[1][0], reg_rw_b) + self.assertEqual(registers[1][1], ("cluster_1", "reg_rw_b")) + + def test_get_path(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + cluster_0 = RegisterMap() + + cluster_0.add_register(reg_rw_a, name="reg_rw_a") + self.dut.add_cluster(cluster_0, name="cluster_0") + self.dut.add_register(reg_rw_b, name="reg_rw_b") + + self.assertEqual(self.dut.get_path(reg_rw_a), ("cluster_0", "reg_rw_a")) + self.assertEqual(self.dut.get_path(reg_rw_b), ("reg_rw_b",)) + + def test_get_path_wrong_register(self): + with self.assertRaisesRegex(TypeError, + r"Register must be an instance of csr\.Register, not 'foo'"): + self.dut.get_path("foo") + + def test_get_path_unknown_register(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + with self.assertRaises(KeyError): + self.dut.get_path(reg_rw_a) + + def test_get_register(self): + reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + cluster_0 = RegisterMap() + + cluster_0.add_register(reg_rw_a, name="reg_rw_a") + self.dut.add_cluster(cluster_0, name="cluster_0") + self.dut.add_register(reg_rw_b, name="reg_rw_b") + + self.assertIs(self.dut.get_register(("cluster_0", "reg_rw_a")), reg_rw_a) + self.assertIs(self.dut.get_register(("reg_rw_b",)), reg_rw_b) + + def test_get_register_empty_path(self): + with self.assertRaisesRegex(ValueError, r"Path must be a non-empty iterable"): + self.dut.get_register(()) + + def test_get_register_wrong_path(self): + with self.assertRaisesRegex(TypeError, + r"Path must contain non-empty strings, not 0"): + self.dut.get_register(("cluster_0", 0)) + with self.assertRaisesRegex(TypeError, + r"Path must contain non-empty strings, not ''"): + self.dut.get_register(("", "reg_rw_a")) + + def test_get_register_unknown_path(self): + with self.assertRaises(KeyError): + self.dut.get_register(("reg_rw_a",)) + + +class BridgeTestCase(unittest.TestCase): + def test_memory_map(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4)})) + reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8)})) + reg_rw_12 = Register("rw", FieldMap({"a": field.RW(12)})) + reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16)})) + + cluster_0 = RegisterMap() + cluster_0.add_register(reg_rw_12, name="reg_rw_12") + cluster_0.add_register(reg_rw_16, name="reg_rw_16") + + register_map = RegisterMap() + register_map.add_register(reg_rw_4, name="reg_rw_4") + register_map.add_register(reg_rw_8, name="reg_rw_8") + register_map.add_cluster(cluster_0, name="cluster_0") + + dut = Bridge(register_map, addr_width=16, data_width=8) + registers = list(dut.bus.memory_map.resources()) + + self.assertIs(registers[0][0], reg_rw_4.element) + self.assertEqual(registers[0][1], "reg_rw_4") + self.assertEqual(registers[0][2], (0, 1)) + + self.assertIs(registers[1][0], reg_rw_8.element) + self.assertEqual(registers[1][1], "reg_rw_8") + self.assertEqual(registers[1][2], (1, 2)) + + self.assertIs(registers[2][0], reg_rw_12.element) + self.assertEqual(registers[2][1], "cluster_0__reg_rw_12") + self.assertEqual(registers[2][2], (2, 4)) + + self.assertIs(registers[3][0], reg_rw_16.element) + self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") + self.assertEqual(registers[3][2], (4, 6)) + + def test_wrong_register_map(self): + with self.assertRaisesRegex(TypeError, + r"Register map must be an instance of RegisterMap, not 'foo'"): + dut = Bridge("foo", addr_width=16, data_width=8) + + def test_register_addr(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4)})) + reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8)})) + reg_rw_12 = Register("rw", FieldMap({"a": field.RW(12)})) + reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16)})) + + cluster_0 = RegisterMap() + cluster_0.add_register(reg_rw_12, name="reg_rw_12") + cluster_0.add_register(reg_rw_16, name="reg_rw_16") + + register_map = RegisterMap() + register_map.add_register(reg_rw_4, name="reg_rw_4") + register_map.add_register(reg_rw_8, name="reg_rw_8") + register_map.add_cluster(cluster_0, name="cluster_0") + + register_addr = { + "reg_rw_4": 0x10, + "reg_rw_8": None, + "cluster_0": { + "reg_rw_12": 0x20, + "reg_rw_16": None, + }, + } + + dut = Bridge(register_map, addr_width=16, data_width=8, + register_addr=register_addr) + registers = list(dut.bus.memory_map.resources()) + + self.assertEqual(registers[0][1], "reg_rw_4") + self.assertEqual(registers[0][2], (0x10, 0x11)) + + self.assertEqual(registers[1][1], "reg_rw_8") + self.assertEqual(registers[1][2], (0x11, 0x12)) + + self.assertEqual(registers[2][1], "cluster_0__reg_rw_12") + self.assertEqual(registers[2][2], (0x20, 0x22)) + + self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") + self.assertEqual(registers[3][2], (0x22, 0x24)) + + def test_register_alignment(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4)})) + reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8)})) + reg_rw_12 = Register("rw", FieldMap({"a": field.RW(12)})) + reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16)})) + + cluster_0 = RegisterMap() + cluster_0.add_register(reg_rw_12, name="reg_rw_12") + cluster_0.add_register(reg_rw_16, name="reg_rw_16") + + register_map = RegisterMap() + register_map.add_register(reg_rw_4, name="reg_rw_4") + register_map.add_register(reg_rw_8, name="reg_rw_8") + register_map.add_cluster(cluster_0, name="cluster_0") + + register_alignment = { + "reg_rw_4": None, + "reg_rw_8": None, + "cluster_0": { + "reg_rw_12": 3, + "reg_rw_16": None, + }, + } + + dut = Bridge(register_map, addr_width=16, data_width=8, alignment=1, + register_alignment=register_alignment) + registers = list(dut.bus.memory_map.resources()) + + self.assertEqual(registers[0][1], "reg_rw_4") + self.assertEqual(registers[0][2], (0, 2)) + + self.assertEqual(registers[1][1], "reg_rw_8") + self.assertEqual(registers[1][2], (2, 4)), + + self.assertEqual(registers[2][1], "cluster_0__reg_rw_12") + self.assertEqual(registers[2][2], (8, 16)) + + self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") + self.assertEqual(registers[3][2], (16, 18)) + + def test_register_out_of_bounds(self): + reg_rw_24 = Register("rw", FieldMap({"a": field.RW(24)})) + register_map = RegisterMap() + register_map.add_register(reg_rw_24, name="reg_rw_24") + with self.assertRaisesRegex(ValueError, + r"Address range 0x0\.\.0x3 out of bounds for memory map spanning " + r"range 0x0\.\.0x2 \(1 address bits\)"): + dut = Bridge(register_map, addr_width=1, data_width=8) + + def test_wrong_register_address(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + register_map = RegisterMap() + register_map.add_register(reg_rw_4, name="reg_rw_4") + with self.assertRaisesRegex(TypeError, r"Register address must be a mapping, not 'foo'"): + dut = Bridge(register_map, addr_width=1, data_width=8, register_addr="foo") + + def test_wrong_cluster_address(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + cluster_0 = RegisterMap() + cluster_0.add_register(reg_rw_4, name="reg_rw_4") + register_map = RegisterMap() + register_map.add_cluster(cluster_0, name="cluster_0") + with self.assertRaisesRegex(TypeError, + r"Register address \('cluster_0',\) must be a mapping, not 'foo'"): + dut = Bridge(register_map, addr_width=1, data_width=8, + register_addr={"cluster_0": "foo"}) + + def test_wrong_register_alignment(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + register_map = RegisterMap() + register_map.add_register(reg_rw_4, name="reg_rw_4") + with self.assertRaisesRegex(TypeError, r"Register alignment must be a mapping, not 'foo'"): + dut = Bridge(register_map, addr_width=1, data_width=8, register_alignment="foo") + + def test_wrong_cluster_alignment(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + cluster_0 = RegisterMap() + cluster_0.add_register(reg_rw_4, name="reg_rw_4") + register_map = RegisterMap() + register_map.add_cluster(cluster_0, name="cluster_0") + with self.assertRaisesRegex(TypeError, + r"Register alignment \('cluster_0',\) must be a mapping, not 'foo'"): + dut = Bridge(register_map, addr_width=1, data_width=8, + register_alignment={"cluster_0": "foo"}) + + def test_sim(self): + reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4, reset=0x0)})) + reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8, reset=0x11)})) + reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16, reset=0x3322)})) + + cluster_0 = RegisterMap() + cluster_0.add_register(reg_rw_16, name="reg_rw_16") + + register_map = RegisterMap() + register_map.add_register(reg_rw_4, name="reg_rw_4") + register_map.add_register(reg_rw_8, name="reg_rw_8") + register_map.add_cluster(cluster_0, name="cluster_0") + + dut = Bridge(register_map, addr_width=16, data_width=8) + + def process(): + yield dut.bus.addr.eq(0) + yield dut.bus.r_stb.eq(1) + yield dut.bus.w_stb.eq(1) + yield dut.bus.w_data.eq(0xa) + yield + yield Settle() + self.assertEqual((yield dut.bus.r_data), 0x0) + self.assertEqual((yield reg_rw_4 .f.a.port.r_stb), 1) + self.assertEqual((yield reg_rw_8 .f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_16.f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_4 .f.a.port.w_stb), 1) + self.assertEqual((yield reg_rw_8 .f.a.port.w_stb), 0) + self.assertEqual((yield reg_rw_16.f.a.port.w_stb), 0) + yield dut.bus.r_stb.eq(0) + yield dut.bus.w_stb.eq(0) + yield + yield Settle() + self.assertEqual((yield reg_rw_4.f.a.data), 0xa) + + yield dut.bus.addr.eq(1) + yield dut.bus.r_stb.eq(1) + yield dut.bus.w_stb.eq(1) + yield dut.bus.w_data.eq(0xbb) + yield + yield Settle() + self.assertEqual((yield dut.bus.r_data), 0x11) + self.assertEqual((yield reg_rw_4 .f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_8 .f.a.port.r_stb), 1) + self.assertEqual((yield reg_rw_16.f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_4 .f.a.port.w_stb), 0) + self.assertEqual((yield reg_rw_8 .f.a.port.w_stb), 1) + self.assertEqual((yield reg_rw_16.f.a.port.w_stb), 0) + yield dut.bus.r_stb.eq(0) + yield dut.bus.w_stb.eq(0) + yield + yield Settle() + self.assertEqual((yield reg_rw_8.f.a.data), 0xbb) + + yield dut.bus.addr.eq(2) + yield dut.bus.r_stb.eq(1) + yield dut.bus.w_stb.eq(1) + yield dut.bus.w_data.eq(0xcc) + yield + yield Settle() + self.assertEqual((yield dut.bus.r_data), 0x22) + self.assertEqual((yield reg_rw_4 .f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_8 .f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_16.f.a.port.r_stb), 1) + self.assertEqual((yield reg_rw_4 .f.a.port.w_stb), 0) + self.assertEqual((yield reg_rw_8 .f.a.port.w_stb), 0) + self.assertEqual((yield reg_rw_16.f.a.port.w_stb), 0) + yield dut.bus.r_stb.eq(0) + yield dut.bus.w_stb.eq(0) + yield + yield Settle() + self.assertEqual((yield reg_rw_16.f.a.data), 0x3322) + + yield dut.bus.addr.eq(3) + yield dut.bus.r_stb.eq(1) + yield dut.bus.w_stb.eq(1) + yield dut.bus.w_data.eq(0xdd) + yield + yield Settle() + self.assertEqual((yield dut.bus.r_data), 0x33) + self.assertEqual((yield reg_rw_4 .f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_8 .f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_16.f.a.port.r_stb), 0) + self.assertEqual((yield reg_rw_4 .f.a.port.w_stb), 0) + self.assertEqual((yield reg_rw_8 .f.a.port.w_stb), 0) + self.assertEqual((yield reg_rw_16.f.a.port.w_stb), 1) + yield dut.bus.r_stb.eq(0) + yield dut.bus.w_stb.eq(0) + yield + yield Settle() + self.assertEqual((yield reg_rw_16.f.a.data), 0xddcc) + + sim = Simulator(dut) + sim.add_clock(1e-6) + sim.add_sync_process(process) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run()