From f5bfa2262665cb1c2fe7e69ea32114f045375158 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Fran=C3=A7ois=20Nguyen?= Date: Fri, 15 Dec 2023 13:41:24 +0100 Subject: [PATCH] csr.reg: defer Field creation until Register.__init__(). Before this commit, a Register whose fields were variable annotations couldn't be instantiated more than once, as the latter belonged to the global scope. The Field class is now a factory for user-defined field components, instead of a base class. Every `FieldMap` or `FieldArray` sharing a Field will use a different instance of its component, obtained from Field.create(). Also: * update to follow RFCs 37 and 38 * docstring and diagnostics improvements. --- amaranth_soc/csr/field.py | 150 ++++++----- amaranth_soc/csr/reg.py | 342 ++++++++++++++---------- tests/test_csr_reg.py | 539 ++++++++++++++++++++++++-------------- 3 files changed, 633 insertions(+), 398 deletions(-) diff --git a/amaranth_soc/csr/field.py b/amaranth_soc/csr/field.py index 33df5a0..144d2a8 100644 --- a/amaranth_soc/csr/field.py +++ b/amaranth_soc/csr/field.py @@ -1,25 +1,33 @@ from amaranth import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out -from .reg import Field +from .reg import FieldPort __all__ = ["R", "W", "RW", "RW1C", "RW1S"] -class R(Field): - __doc__ = Field._doc_template.format( - description=""" - A read-only field. - """.strip(), - parameters="", - attributes=""" +class R(wiring.Component): + """A read-only field. + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. + + Interface attributes + -------------------- + port : :class:`FieldPort` + Field port. r_data : Signal(shape) Read data. Drives the :attr:`~FieldPort.r_data` signal of ``port``. - """.strip()) - + """ def __init__(self, shape): - super().__init__(shape, access="r") - self.r_data = Signal(shape) + super().__init__({ + "port": In(FieldPort.Signature(shape, access="r")), + "r_data": In(shape), + }) def elaborate(self, platform): m = Module() @@ -27,20 +35,26 @@ def elaborate(self, platform): return m -class W(Field): - __doc__ = Field._doc_template.format( - description=""" - A write-only field. - """.strip(), - parameters="", - attributes=""" +class W(wiring.Component): + """A write-only field. + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. + + Interface attributes + -------------------- + port : :class:`FieldPort` + Field port. w_data : Signal(shape) Write data. Driven by the :attr:`~FieldPort.w_data` signal of ``port``. - """.strip()) - + """ def __init__(self, shape): - super().__init__(shape, access="w") - self.w_data = Signal(shape) + super().__init__({ + "port": In(FieldPort.Signature(shape, access="w")), + "w_data": Out(shape), + }) def elaborate(self, platform): m = Module() @@ -48,26 +62,31 @@ def elaborate(self, platform): return m -class RW(Field): - __doc__ = Field._doc_template.format( - description=""" - A read/write field with built-in storage. +class RW(wiring.Component): + """A read/write field with built-in storage. Storage is updated with the value of ``port.w_data`` one clock cycle after ``port.w_stb`` is asserted. - """.strip(), - parameters=""" + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. reset : :class:`int` Storage reset value. - """, - attributes=""" + + Interface attributes + -------------------- + port : :class:`FieldPort` + Field port. data : Signal(shape) Storage output. - """.strip()) - + """ def __init__(self, shape, *, reset=0): - super().__init__(shape, access="rw") - self.data = Signal(shape) + super().__init__({ + "port": In(FieldPort.Signature(shape, access="rw")), + "data": Out(shape), + }) self._storage = Signal(shape, reset=reset) self._reset = reset @@ -89,32 +108,37 @@ def elaborate(self, platform): return m -class RW1C(Field): - __doc__ = Field._doc_template.format( - description=""" - A read/write-one-to-clear field with built-in storage. +class RW1C(wiring.Component): + """A read/write-one-to-clear field with built-in storage. Storage bits are: * cleared by high bits in ``port.w_data``, one clock cycle after ``port.w_stb`` is asserted; * set by high bits in ``set``, one clock cycle after they are asserted. If a storage bit is set and cleared on the same clock cycle, setting it has precedence. - """.strip(), - parameters=""" + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. reset : :class:`int` Storage reset value. - """, - attributes=""" + + Interface attributes + -------------------- + port : :class:`FieldPort` + Field port. data : Signal(shape) Storage output. set : Signal(shape) Mask to set storage bits. - """.strip()) - + """ def __init__(self, shape, *, reset=0): - super().__init__(shape, access="rw") - self.data = Signal(shape) - self.set = Signal(shape) + super().__init__({ + "port": In(FieldPort.Signature(shape, access="rw")), + "data": Out(shape), + "set": In(shape), + }) self._storage = Signal(shape, reset=reset) self._reset = reset @@ -139,31 +163,37 @@ def elaborate(self, platform): return m -class RW1S(Field): - __doc__ = Field._doc_template.format( - description=""" - A read/write-one-to-set field with built-in storage. +class RW1S(wiring.Component): + """A read/write-one-to-set field with built-in storage. Storage bits are: * set by high bits in ``port.w_data``, one clock cycle after ``port.w_stb`` is asserted; * cleared by high bits in ``clear``, one clock cycle after they are asserted. If a storage bit is set and cleared on the same clock cycle, setting it has precedence. - """.strip(), - parameters=""" + + Parameters + ---------- + shape : :ref:`shape-castable ` + Shape of the field. reset : :class:`int` Storage reset value. - """, - attributes=""" + + Interface attributes + -------------------- + port : :class:`FieldPort` + Field port. data : Signal(shape) Storage output. clear : Signal(shape) Mask to clear storage bits. - """.strip()) + """ def __init__(self, shape, *, reset=0): - super().__init__(shape, access="rw") - self.data = Signal(shape) - self.clear = Signal(shape) + super().__init__({ + "port": In(FieldPort.Signature(shape, access="rw")), + "clear": In(shape), + "data": Out(shape), + }) self._storage = Signal(shape, reset=reset) self._reset = reset diff --git a/amaranth_soc/csr/reg.py b/amaranth_soc/csr/reg.py index af68d44..4cf7144 100644 --- a/amaranth_soc/csr/reg.py +++ b/amaranth_soc/csr/reg.py @@ -4,7 +4,7 @@ from amaranth.lib.wiring import In, Out, connect, flipped from ..memory import MemoryMap -from .bus import Element, Multiplexer +from .bus import Element, Signature, Multiplexer __all__ = ["FieldPort", "Field", "FieldMap", "FieldArray", "Register", "RegisterMap", "Bridge"] @@ -52,17 +52,15 @@ class Signature(wiring.Signature): """ def __init__(self, shape, access): self.check_parameters(shape, access) - self._shape = Shape.cast(shape) self._access = FieldPort.Access(access) - members = { + super().__init__({ "r_data": In(self.shape), "r_stb": Out(1), "w_data": Out(self.shape), "w_stb": Out(1), - } - super().__init__(members) + }) @property def shape(self): @@ -95,7 +93,7 @@ def check_parameters(cls, shape, access): except ValueError as e: raise ValueError(f"{access!r} is not a valid FieldPort.Access") from e - def create(self, *, path=()): + def create(self, *, path=None, src_loc_at=0): """Create a compatible interface. See :meth:`wiring.Signature.create` for details. @@ -104,7 +102,7 @@ def create(self, *, path=()): ------- A :class:`FieldPort` object using this signature. """ - return FieldPort(self, path=path) + return FieldPort(self, path=path, src_loc_at=1 + src_loc_at) def __eq__(self, other): """Compare signatures. @@ -129,18 +127,23 @@ def __repr__(self): path : iter(:class:`str`) Path to the field port. Optional. See :class:`wiring.PureInterface`. + Attributes + ---------- + shape : :ref:`shape-castable ` + Shape of the field. See :class:`FieldPort.Signature`. + access : :class:`FieldPort.Access` + Field access mode. See :class:`FieldPort.Signature`. + Raises ------ :exc:`TypeError` - If ``shape`` is not a shape-castable object. - :exc:`TypeError` - If ``access`` is not a member of :class:`FieldPort.Access`. + If ``signature`` is not a :class:`FieldPort.Signature`. """ - def __init__(self, signature, *, path=()): + def __init__(self, signature, *, path=None, src_loc_at=0): if not isinstance(signature, FieldPort.Signature): raise TypeError(f"This interface requires a csr.FieldPort.Signature, not " f"{signature!r}") - super().__init__(signature, path=path) + super().__init__(signature, path=path, src_loc_at=1 + src_loc_at) @property def shape(self): @@ -154,51 +157,34 @@ def __repr__(self): return f"csr.FieldPort({self.signature!r})" -class Field(wiring.Component): - _doc_template = """ - {description} +class Field: + """Register field factory. Parameters ---------- - shape : :ref:`shape-castable ` - Shape of the field. - access : :class:`FieldPort.Access` - Field access mode. - {parameters} - - Attributes - ---------- - port : :class:`FieldPort` - Field port. - {attributes} + field_cls : :class:`type` + The field type instantiated by :meth:`Field.create`. It must be a :class:`wiring.Component` + subclass. ``field_cls`` instances must have a signature containing a member named "port", + which must be an input :class:`FieldPort.Signature`. + *args : :class:`tuple` + Positional arguments passed to ``field_cls.__init__``. + **kwargs : :class:`dict` + Keyword arguments passed to ``field_cls.__init__``. """ + def __init__(self, field_cls, *args, **kwargs): + self._field_cls = field_cls + self._args = args + self._kwargs = kwargs - __doc__ = _doc_template.format( - description=""" - A generic register field. - """.strip(), - parameters="", - attributes="") - - def __init__(self, shape, access): - FieldPort.Signature.check_parameters(shape, access) - self._shape = Shape.cast(shape) - self._access = FieldPort.Access(access) - super().__init__() - - @property - def shape(self): - return self._shape - - @property - def access(self): - return self._access + def create(self): + """Create a field instance. - @property - def signature(self): - return wiring.Signature({ - "port": Out(FieldPort.Signature(self._shape, self._access)), - }) + Returns + ------- + :class:`object` + The instance returned by ``field_cls(*args, **kwargs)``. + """ + return self._field_cls(*self._args, **self._kwargs) class FieldMap(Mapping): @@ -206,23 +192,42 @@ class FieldMap(Mapping): Parameters ---------- - fields : dict of :class:`str` to one of :class:`Field` or :class:`FieldMap`. + fields : :class:`dict` of :class:`str` to (:class:`Field` or :class:`dict` or :class:`list`) + Field map members. A :class:`FieldMap` stores an instance of :class:`Field` members (see + :meth:`Field.create`). :class:`dict` members are cast to :class:`FieldMap`. :class:`list` + members are cast to :class:`FieldArray`. + + Raises + ------ + :exc:`TypeError` + If ``fields`` is not a non-empty dict. + :exc:`TypeError` + If ``fields`` has a key that is not a non-empty string. + :exc:`TypeError` + If ``fields`` has a value that is neither a :class:`Field` object or a dict or list of + :class:`Field` objects. """ def __init__(self, fields): self._fields = {} - if not isinstance(fields, Mapping) or len(fields) == 0: - raise TypeError("Fields must be provided as a non-empty mapping, not {!r}" - .format(fields)) + if not isinstance(fields, dict) or len(fields) == 0: + raise TypeError(f"Fields must be provided as a non-empty dict, not {fields!r}") for key, field in fields.items(): if not isinstance(key, str) or not key: - raise TypeError("Field name must be a non-empty string, not {!r}" - .format(key)) - if not isinstance(field, (Field, FieldMap, FieldArray)): - raise TypeError("Field must be a Field or a FieldMap or a FieldArray, not {!r}" - .format(field)) - self._fields[key] = field + raise TypeError(f"Field name must be a non-empty string, not {key!r}") + + if isinstance(field, Field): + field_inst = field.create() + elif isinstance(field, dict): + field_inst = FieldMap(field) + elif isinstance(field, list): + field_inst = FieldArray(field) + else: + raise TypeError(f"{field!r} must be a Field object or a collection of Field " + f"objects") + + self._fields[key] = field_inst def __getitem__(self, key): """Access a field by name or index. @@ -257,13 +262,11 @@ def __getattr__(self, name): try: item = self[name] except KeyError: - raise AttributeError("Field map does not have a field {!r}; " - "did you mean one of: {}?" - .format(name, ", ".join(repr(name) for name in self.keys()))) + raise AttributeError(f"Field map does not have a field {name!r}; did you mean one of: " + f"{', '.join(f'{name!r}' for name in self.keys())}?") if name.startswith("_"): - raise AttributeError("Field map field {!r} has a reserved name and may only be " - "accessed by indexing" - .format(name)) + raise AttributeError(f"Field map field {name!r} has a reserved name and may only be " + f"accessed by indexing") return item def __iter__(self): @@ -277,6 +280,13 @@ def __iter__(self): yield from self._fields def __len__(self): + """Field map length. + + Returns + ------- + :class:`int` + The number of items in the map. + """ return len(self._fields) def flatten(self): @@ -286,17 +296,15 @@ def flatten(self): ------ iter(:class:`str`) Path of the field. It is prefixed by the name of every nested field collection. - :class:`Field` + :class:`wiring.Component` Register field. """ for key, field in self.items(): - if isinstance(field, Field): - yield (key,), field - elif isinstance(field, (FieldMap, FieldArray)): + if isinstance(field, (FieldMap, FieldArray)): for sub_path, sub_field in field.flatten(): yield (key, *sub_path), sub_field else: - assert False # :nocov: + yield (key,), field class FieldArray(Sequence): @@ -304,16 +312,37 @@ class FieldArray(Sequence): Parameters ---------- - fields : iter(:class:`Field` or :class:`FieldMap` or :class:`FieldArray`) - Field array members. + fields : :class:`list` of (:class:`Field` or :class:`dict` or :class:`list`) + Field array members. A :class:`FieldArray` stores an instance of :class:`Field` members + (see :meth:`Field.create`). :class:`dict` members are cast to :class:`FieldMap`. + :class:`list` members are cast to :class:`FieldArray`. + + Raises + ------ + :exc:`TypeError` + If ``fields`` is not a non-empty list. + :exc:`TypeError` + If ``fields`` has an item that is neither a :class:`Field` object or a dict or list of + :class:`Field` objects. """ def __init__(self, fields): - fields = tuple(fields) + self._fields = [] + + if not isinstance(fields, list) or len(fields) == 0: + raise TypeError(f"Fields must be provided as a non-empty list, not {fields!r}") + for field in fields: - if not isinstance(field, (Field, FieldMap, FieldArray)): - raise TypeError("Field must be a Field or a FieldMap or a FieldArray, not {!r}" - .format(field)) - self._fields = fields + if isinstance(field, Field): + field_inst = field.create() + elif isinstance(field, dict): + field_inst = FieldMap(field) + elif isinstance(field, list): + field_inst = FieldArray(field) + else: + raise TypeError(f"{field!r} must be a Field object or a collection of Field " + f"objects") + + self._fields.append(field_inst) def __getitem__(self, key): """Access a field by index. @@ -342,17 +371,15 @@ def flatten(self): ------ iter(:class:`str`) Path of the field. It is prefixed by the name of every nested field collection. - :class:`Field` + :class:`wiring.Component` Register field. """ for key, field in enumerate(self._fields): - if isinstance(field, Field): - yield (key,), field - elif isinstance(field, (FieldMap, FieldArray)): + if isinstance(field, (FieldMap, FieldArray)): for sub_path, sub_field in field.flatten(): yield (key, *sub_path), sub_field else: - assert False # :nocov: + yield (key,), field class Register(wiring.Component): @@ -362,14 +389,21 @@ class Register(wiring.Component): ---------- access : :class:`Element.Access` Register access mode. - fields : :class:`FieldMap` or :class:`FieldArray` - Collection of register fields. If ``None`` (default), a :class:`FieldMap` is created - from Python :term:`variable annotations `. + fields : :class:`dict` or :class:`list` + Collection of register fields. If ``None`` (default), a :class:`dict` is populated from + Python :term:`variable annotations `. If ``fields`` is a + :class:`dict`, it is cast to a :class:`FieldMap`; if ``fields`` is a :class:`list`, it is + cast to a :class`FieldArray`. + + Interface attributes + -------------------- + element : :class:`Element` + Interface between this register and a CSR bus primitive. Attributes ---------- - element : :class:`Element` - Interface between this register and a CSR bus primitive. + access : :class:`Element.Access` + Register access mode. fields : :class:`FieldMap` or :class:`FieldArray` Collection of register fields. f : :class:`FieldMap` or :class:`FieldArray` @@ -380,7 +414,9 @@ class Register(wiring.Component): :exc:`TypeError` If ``access`` is not a member of :class:`Element.Access`. :exc:`TypeError` - If ``fields`` is not ``None`` or a :class:`FieldMap` or a :class:`FieldArray`. + If ``fields`` is not ``None`` or a :class:`dict` or a :class:`list`. + :exc:`ValueError` + If ``fields`` is not ``None`` and at least one variable annotation is a :class:`Field`. :exc:`ValueError` If ``access`` is not readable and at least one field is readable. :exc:`ValueError` @@ -389,37 +425,60 @@ class Register(wiring.Component): def __init__(self, access="rw", fields=None): if not isinstance(access, Element.Access) and access not in ("r", "w", "rw"): raise TypeError(f"Access mode must be one of \"r\", \"w\", or \"rw\", not {access!r}") - access = Element.Access(access) + self._access = Element.Access(access) if hasattr(self, "__annotations__"): - annot_fields = {} - for key, value in self.__annotations__.items(): - if isinstance(value, (Field, FieldMap, FieldArray)): - annot_fields[key] = value + def filter_dict(d): + fields = {} + for key, value in d.items(): + if isinstance(value, Field): + fields[key] = value + elif isinstance(value, dict): + if sub_fields := filter_dict(value): + fields[key] = sub_fields + elif isinstance(value, list): + if sub_fields := filter_list(value): + fields[key] = sub_fields + return fields + + def filter_list(l): + fields = [] + for item in l: + if isinstance(item, Field): + fields.append(item) + elif isinstance(item, dict): + if sub_fields := filter_dict(item): + fields.append(sub_fields) + elif isinstance(item, list): + if sub_fields := filter_list(item): + fields.append(sub_fields) + return fields + + annot_fields = filter_dict(self.__annotations__) if fields is None: - fields = FieldMap(annot_fields) + fields = annot_fields elif annot_fields: raise ValueError(f"Field collection {fields} cannot be provided in addition to " f"field annotations: {', '.join(annot_fields)}") - if not isinstance(fields, (FieldMap, FieldArray)): - raise TypeError(f"Field collection must be a FieldMap or a FieldArray, not {fields!r}") + self._fields = FieldMap(fields) width = 0 - for field_path, field in fields.flatten(): - width += Shape.cast(field.shape).width - if field.access.readable() and not access.readable(): + for field_path, field in self._fields.flatten(): + width += Shape.cast(field.port.shape).width + if field.port.access.readable() and not self._access.readable(): raise ValueError(f"Field {'__'.join(field_path)} is readable, but register access " - f"mode is {access!r}") - if field.access.writable() and not access.writable(): + f"mode is {self._access!r}") + if field.port.access.writable() and not self._access.writable(): raise ValueError(f"Field {'__'.join(field_path)} is writable, but register access " - f"mode is {access!r}") + f"mode is {self._access!r}") + + super().__init__({"element": Out(Element.Signature(width, self._access))}) - self._width = width - self._access = access - self._fields = fields - super().__init__() + @property + def access(self): + return self._access @property def fields(self): @@ -429,12 +488,6 @@ def fields(self): def f(self): return self._fields - @property - def signature(self): - return wiring.Signature({ - "element": Out(Element.Signature(self._width, self._access)), - }) - def __iter__(self): """Recursively iterate over the field collection. @@ -453,16 +506,17 @@ def elaborate(self, platform): field_start = 0 for field_path, field in self.fields.flatten(): - m.submodules["__".join(str(key) for key in field_path)] = field + field_width = Shape.cast(field.port.shape).width + field_slice = slice(field_start, field_start + field_width) - field_slice = slice(field_start, field_start + Shape.cast(field.shape).width) + m.submodules["__".join(str(key) for key in field_path)] = field - if field.access.readable(): + if field.port.access.readable(): m.d.comb += [ self.element.r_data[field_slice].eq(field.port.r_data), field.port.r_stb.eq(self.element.r_stb), ] - if field.access.writable(): + if field.port.access.writable(): m.d.comb += [ field.port.w_data.eq(self.element.w_data[field_slice]), field.port.w_stb .eq(self.element.w_stb), @@ -518,13 +572,12 @@ def add_register(self, register, *, name): if self._frozen: raise ValueError("Register map is frozen") if not isinstance(register, Register): - raise TypeError("Register must be an instance of csr.Register, not {!r}" - .format(register)) + raise TypeError(f"Register must be an instance of csr.Register, not {register!r}") if not isinstance(name, str) or not name: - raise TypeError("Name must be a non-empty string, not {!r}".format(name)) + raise TypeError(f"Name must be a non-empty string, not {name!r}") if name in self._namespace: - raise ValueError("Name '{}' is already used by {!r}".format(name, self._namespace[name])) + raise ValueError(f"Name '{name}' is already used by {self._namespace[name]!r}") self._registers[id(register)] = register, name self._namespace[name] = register @@ -572,13 +625,12 @@ def add_cluster(self, cluster, *, name): if self._frozen: raise ValueError("Register map is frozen") if not isinstance(cluster, RegisterMap): - raise TypeError("Cluster must be an instance of csr.RegisterMap, not {!r}" - .format(cluster)) + raise TypeError(f"Cluster must be an instance of csr.RegisterMap, not {cluster!r}") if not isinstance(name, str) or not name: - raise TypeError("Name must be a non-empty string, not {!r}".format(name)) + raise TypeError(f"Name must be a non-empty string, not {name!r}") if name in self._namespace: - raise ValueError("Name '{}' is already used by {!r}".format(name, self._namespace[name])) + raise ValueError(f"Name '{name}' is already used by {self._namespace[name]!r}") self._clusters[id(cluster)] = cluster, name self._namespace[name] = cluster @@ -639,8 +691,7 @@ def get_path(self, register, *, _path=()): If ``register` is not in the register map. """ if not isinstance(register, Register): - raise TypeError("Register must be an instance of csr.Register, not {!r}" - .format(register)) + raise TypeError(f"Register must be an instance of csr.Register, not {register!r}") if id(register) in self._registers: _, name = self._registers[id(register)] @@ -682,7 +733,7 @@ def get_register(self, path): raise ValueError("Path must be a non-empty iterable") for name in path: if not isinstance(name, str) or not name: - raise TypeError("Path must contain non-empty strings, not {!r}".format(name)) + raise TypeError(f"Path must contain non-empty strings, not {name!r}") name, *rest = path @@ -717,9 +768,9 @@ class Bridge(wiring.Component): name : :class:`str` Window name. Optional. register_addr : :class:`dict` - Register address mapping. Optional, defaults to ``None``. + Register address assignments. Optional, defaults to ``None``. register_alignment : :class:`dict` - Register alignment mapping. Optional, defaults to ``None``. + Register alignment assignments. Optional, defaults to ``None``. Attributes ---------- @@ -733,15 +784,15 @@ class Bridge(wiring.Component): :exc:`TypeError` If ``register_map`` is not an instance of :class:`RegisterMap`. :exc:`TypeError` - If ``register_addr`` is a not a mapping. + If ``register_addr`` is a not a dict. :exc:`TypeError` - If ``register_alignment`` is a not a mapping. + If ``register_alignment`` is a not a dict. """ def __init__(self, register_map, *, addr_width, data_width, alignment=0, name=None, register_addr=None, register_alignment=None): if not isinstance(register_map, RegisterMap): - raise TypeError("Register map must be an instance of RegisterMap, not {!r}" - .format(register_map)) + raise TypeError(f"Register map must be an instance of RegisterMap, not " + f"{register_map!r}") memory_map = MemoryMap(addr_width=addr_width, data_width=data_width, alignment=alignment, name=name) @@ -752,9 +803,9 @@ def get_register_param(path, root, kind): for name in path: if node is None: break - if not isinstance(node, Mapping): - raise TypeError("Register {}{} must be a mapping, not {!r}" - .format(kind, "" if not prev else f" {tuple(prev)}", node)) + if not isinstance(node, dict): + raise TypeError(f"Register {kind}{'' if not prev else f' {tuple(prev)}'} must " + f"be a dict, not {node!r}") prev.append(name) node = node.get(name, None) return node @@ -771,20 +822,21 @@ def get_register_param(path, root, kind): self._map = register_map self._mux = Multiplexer(memory_map) - super().__init__() + + super().__init__({"bus": In(Signature(addr_width=addr_width, data_width=data_width))}) + self.bus.memory_map = self._mux.bus.memory_map @property def register_map(self): return self._map - @property - def signature(self): - return self._mux.signature - def elaborate(self, platform): m = Module() + + m.submodules.mux = self._mux for register, path in self.register_map.flatten(): m.submodules["__".join(path)] = register - m.submodules.mux = self._mux - connect(m, flipped(self), self._mux) + + connect(m, flipped(self.bus), self._mux.bus) + return m diff --git a/tests/test_csr_reg.py b/tests/test_csr_reg.py index a77f93e..81e32a6 100644 --- a/tests/test_csr_reg.py +++ b/tests/test_csr_reg.py @@ -2,11 +2,21 @@ import unittest from amaranth import * -from amaranth.lib.wiring import * +from amaranth.lib import wiring +from amaranth.lib.wiring import In, Out from amaranth.sim import * from amaranth_soc.csr.reg import * -from amaranth_soc.csr import field +from amaranth_soc.csr import field, Element + + +def _silence_unused(*objs): + for obj in objs: + Fragment.get(obj, platform=None) + + +def _compatible_fields(a, b): + return a.port.shape == b.port.shape and a.port.access == b.port.access class FieldPortSignatureTestCase(unittest.TestCase): @@ -14,18 +24,24 @@ def test_shape_1_ro(self): sig = FieldPort.Signature(1, "r") self.assertEqual(sig.shape, unsigned(1)) self.assertEqual(sig.access, FieldPort.Access.R) - self.assertEqual(sig.members, Signature({ + self.assertEqual(sig.members, wiring.Signature({ "r_data": In(unsigned(1)), "r_stb": Out(1), "w_data": Out(unsigned(1)), "w_stb": Out(1), }).members) + self.assertEqual(repr(sig), + "csr.FieldPort.Signature(SignatureMembers({" + "'r_data': In(unsigned(1)), " + "'r_stb': Out(1), " + "'w_data': Out(unsigned(1)), " + "'w_stb': Out(1)}))") def test_shape_8_rw(self): sig = FieldPort.Signature(8, "rw") self.assertEqual(sig.shape, unsigned(8)) self.assertEqual(sig.access, FieldPort.Access.RW) - self.assertEqual(sig.members, Signature({ + self.assertEqual(sig.members, wiring.Signature({ "r_data": In(unsigned(8)), "r_stb": Out(1), "w_data": Out(unsigned(8)), @@ -36,7 +52,7 @@ def test_shape_10_wo(self): sig = FieldPort.Signature(10, "w") self.assertEqual(sig.shape, unsigned(10)) self.assertEqual(sig.access, FieldPort.Access.W) - self.assertEqual(sig.members, Signature({ + self.assertEqual(sig.members, wiring.Signature({ "r_data": In(unsigned(10)), "r_stb": Out(1), "w_data": Out(unsigned(10)), @@ -47,7 +63,7 @@ def test_shape_0_rw(self): sig = FieldPort.Signature(0, "w") self.assertEqual(sig.shape, unsigned(0)) self.assertEqual(sig.access, FieldPort.Access.W) - self.assertEqual(sig.members, Signature({ + self.assertEqual(sig.members, wiring.Signature({ "r_data": In(unsigned(0)), "r_stb": Out(1), "w_data": Out(unsigned(0)), @@ -91,6 +107,12 @@ def test_simple(self): self.assertEqual(port.access, FieldPort.Access.RW) self.assertEqual(port.r_stb.name, "foo__bar__r_stb") self.assertIs(port.signature, sig) + self.assertEqual(repr(port), + "csr.FieldPort(csr.FieldPort.Signature(SignatureMembers({" + "'r_data': In(unsigned(8)), " + "'r_stb': Out(1), " + "'w_data': Out(unsigned(8)), " + "'w_stb': Out(1)})))") def test_wrong_signature(self): with self.assertRaisesRegex(TypeError, @@ -98,141 +120,177 @@ def test_wrong_signature(self): FieldPort("foo") -def _compatible_fields(a, b): - return isinstance(a, Field) and type(a) == type(b) and \ - a.shape == b.shape and a.access == b.access +class FieldTestCase(unittest.TestCase): + def test_create(self): + class MockField(wiring.Component): + def __init__(self, shape, *, reset): + super().__init__({"port": Out(FieldPort.Signature(shape, "rw"))}) + self.reset = reset + def elaborate(self, platform): + return Module() -class FieldTestCase(unittest.TestCase): - def test_simple(self): - field = Field(unsigned(4), "rw") - self.assertEqual(field.shape, unsigned(4)) - self.assertEqual(field.access, FieldPort.Access.RW) - - def test_compatible(self): - self.assertTrue(_compatible_fields(Field(unsigned(4), "rw"), - Field(unsigned(4), FieldPort.Access.RW))) - self.assertFalse(_compatible_fields(Field(unsigned(3), "r" ), Field(unsigned(4), "r"))) - self.assertFalse(_compatible_fields(Field(unsigned(4), "rw"), Field(unsigned(4), "w"))) - self.assertFalse(_compatible_fields(Field(unsigned(4), "rw"), Field(unsigned(4), "r"))) - self.assertFalse(_compatible_fields(Field(unsigned(4), "r" ), Field(unsigned(4), "w"))) + field_u8 = Field(MockField, unsigned(8), reset=1).create() + self.assertEqual(field_u8.port.shape, unsigned(8)) + self.assertEqual(field_u8.reset, 1) + _silence_unused(field_u8) - def test_wrong_shape(self): - with self.assertRaisesRegex(TypeError, - r"Field shape must be a shape-castable object, not 'foo'"): - Field("foo", "rw") + def test_create_multiple(self): + class MockField(wiring.Component): + port: Out(FieldPort.Signature(unsigned(8), "rw")) - def test_wrong_access(self): - with self.assertRaisesRegex(ValueError, r"'wo' is not a valid FieldPort.Access"): - Field(8, "wo") + def elaborate(self, platform): + return Module() + + field_1 = Field(MockField).create() + field_2 = Field(MockField).create() + self.assertIsNot(field_1, field_2) + _silence_unused(field_1, field_2) class FieldMapTestCase(unittest.TestCase): def test_simple(self): field_map = FieldMap({ - "a": Field(unsigned(1), "r"), - "b": Field(signed(3), "rw"), - "c": FieldMap({ - "d": Field(unsigned(4), "rw"), - }), + "a": Field(field.R, unsigned(1)), + "b": Field(field.RW, signed(3)), + "c": {"d": Field(field.RW, unsigned(4))}, }) - self.assertTrue(_compatible_fields(field_map["a"], Field(unsigned(1), "r"))) - self.assertTrue(_compatible_fields(field_map["b"], Field(signed(3), "rw"))) - self.assertTrue(_compatible_fields(field_map["c"]["d"], Field(unsigned(4), "rw"))) - self.assertTrue(_compatible_fields(field_map.a, Field(unsigned(1), "r"))) - self.assertTrue(_compatible_fields(field_map.b, Field(signed(3), "rw"))) - self.assertTrue(_compatible_fields(field_map.c.d, Field(unsigned(4), "rw"))) + field_r_u1 = Field(field.R, unsigned(1)).create() + field_rw_s3 = Field(field.RW, signed(3)).create() + field_rw_u4 = Field(field.RW, unsigned(4)).create() + + self.assertTrue(_compatible_fields(field_map["a"], field_r_u1)) + self.assertTrue(_compatible_fields(field_map["b"], field_rw_s3)) + self.assertTrue(_compatible_fields(field_map["c"]["d"], field_rw_u4)) + + self.assertTrue(_compatible_fields(field_map.a, field_r_u1)) + self.assertTrue(_compatible_fields(field_map.b, field_rw_s3)) + self.assertTrue(_compatible_fields(field_map.c.d, field_rw_u4)) self.assertEqual(len(field_map), 3) + _silence_unused(*(v for k, v in field_map.flatten())) + _silence_unused(field_r_u1, field_rw_s3, field_rw_u4) + def test_iter(self): field_map = FieldMap({ - "a": Field(unsigned(1), "r"), - "b": Field(signed(3), "rw") + "a": Field(field.R, unsigned(1)), + "b": Field(field.RW, signed(3)) }) self.assertEqual(list(field_map.items()), [ ("a", field_map["a"]), ("b", field_map["b"]), ]) + _silence_unused(*(v for k, v in field_map.flatten())) def test_flatten(self): field_map = FieldMap({ - "a": Field(unsigned(1), "r"), - "b": Field(signed(3), "rw"), - "c": FieldMap({ - "d": Field(unsigned(4), "rw"), - }), + "a": Field(field.R, unsigned(1)), + "b": Field(field.RW, signed(3)), + "c": {"d": Field(field.RW, unsigned(4))}, }) self.assertEqual(list(field_map.flatten()), [ (("a",), field_map["a"]), (("b",), field_map["b"]), (("c", "d"), field_map["c"]["d"]), ]) + _silence_unused(*(v for k, v in field_map.flatten())) - def test_wrong_mapping(self): + def test_wrong_dict(self): with self.assertRaisesRegex(TypeError, - r"Fields must be provided as a non-empty mapping, not 'foo'"): + r"Fields must be provided as a non-empty dict, not 'foo'"): FieldMap("foo") def test_wrong_field_key(self): with self.assertRaisesRegex(TypeError, r"Field name must be a non-empty string, not 1"): - FieldMap({1: Field(unsigned(1), "rw")}) + FieldMap({1: Field(field.RW, unsigned(1))}) with self.assertRaisesRegex(TypeError, r"Field name must be a non-empty string, not ''"): - FieldMap({"": Field(unsigned(1), "rw")}) + FieldMap({"": Field(field.RW, unsigned(1))}) def test_wrong_field_value(self): with self.assertRaisesRegex(TypeError, - r"Field must be a Field or a FieldMap or a FieldArray, not unsigned\(1\)"): + r"unsigned\(1\) must be a Field object or a collection of Field objects"): FieldMap({"a": unsigned(1)}) def test_getitem_wrong_key(self): + field_map = FieldMap({"a": Field(field.RW, unsigned(1))}) with self.assertRaises(KeyError): - FieldMap({"a": Field(unsigned(1), "rw")})["b"] + field_map["b"] + _silence_unused(*(v for k, v in field_map.flatten())) + + def test_getitem_reserved(self): + field_map = FieldMap({"_reserved": Field(field.RW, unsigned(1))}) + field_rw_u1 = Field(field.RW, unsigned(1)).create() + self.assertTrue(_compatible_fields(field_map["_reserved"], field_rw_u1)) + _silence_unused(*(v for k, v in field_map.flatten())) + _silence_unused(field_rw_u1) + + def test_getattr_missing(self): + field_map = FieldMap({"a": Field(field.RW, unsigned(1)), + "b": Field(field.RW, unsigned(1))}) + with self.assertRaisesRegex(AttributeError, + r"Field map does not have a field 'c'; did you mean one of: 'a', 'b'?"): + field_map.c + _silence_unused(*(v for k, v in field_map.flatten())) + + def test_getattr_reserved(self): + field_map = FieldMap({"_reserved": Field(field.RW, unsigned(1))}) + with self.assertRaisesRegex(AttributeError, + r"Field map field '_reserved' has a reserved name and may only be accessed by " + r"indexing"): + field_map._reserved + _silence_unused(*(v for k, v in field_map.flatten())) class FieldArrayTestCase(unittest.TestCase): def test_simple(self): - field_array = FieldArray([Field(unsigned(2), "rw") for _ in range(8)]) + field_array = FieldArray([Field(field.RW, unsigned(2)) for _ in range(8)]) + field_rw_u2 = Field(field.RW, unsigned(2)).create() self.assertEqual(len(field_array), 8) for i in range(8): - self.assertTrue(_compatible_fields(field_array[i], Field(unsigned(2), "rw"))) + self.assertTrue(_compatible_fields(field_array[i], field_rw_u2)) + _silence_unused(*(v for k, v in field_array.flatten())) + _silence_unused(field_rw_u2) def test_dim_2(self): - field_array = FieldArray([FieldArray([Field(unsigned(1), "rw") for _ in range(4)]) + field_array = FieldArray([[Field(field.RW, unsigned(1)) for _ in range(4)] for _ in range(4)]) + field_rw_u1 = Field(field.RW, unsigned(1)).create() self.assertEqual(len(field_array), 4) for i in range(4): for j in range(4): - self.assertTrue(_compatible_fields(field_array[i][j], Field(1, "rw"))) + self.assertTrue(_compatible_fields(field_array[i][j], field_rw_u1)) + _silence_unused(*(v for k, v in field_array.flatten())) + _silence_unused(field_rw_u1) def test_nested(self): - field_array = FieldArray([ - FieldMap({ - "a": Field(unsigned(4), "rw"), - "b": FieldArray([Field(unsigned(1), "rw") for _ in range(4)]), - }) for _ in range(4)]) + field_array = FieldArray([{"a": Field(field.RW, unsigned(4)), + "b": [Field(field.RW, unsigned(1)) for _ in range(4)]} + for _ in range(4)]) + field_rw_u4 = Field(field.RW, unsigned(4)).create() + field_rw_u1 = Field(field.RW, unsigned(1)).create() self.assertEqual(len(field_array), 4) for i in range(4): - self.assertTrue(_compatible_fields(field_array[i]["a"], Field(unsigned(4), "rw"))) + self.assertTrue(_compatible_fields(field_array[i]["a"], field_rw_u4)) for j in range(4): - self.assertTrue(_compatible_fields(field_array[i]["b"][j], - Field(unsigned(1), "rw"))) + self.assertTrue(_compatible_fields(field_array[i]["b"][j], field_rw_u1)) + _silence_unused(*(v for k, v in field_array.flatten())) + _silence_unused(field_rw_u4, field_rw_u1) def test_iter(self): - field_array = FieldArray([Field(1, "rw") for _ in range(3)]) + field_array = FieldArray([Field(field.RW, 1) for _ in range(3)]) self.assertEqual(list(field_array), [ field_array[i] for i in range(3) ]) + _silence_unused(*(v for k, v in field_array.flatten())) def test_flatten(self): - field_array = FieldArray([ - FieldMap({ - "a": Field(4, "rw"), - "b": FieldArray([Field(1, "rw") for _ in range(2)]), - }) for _ in range(2)]) + field_array = FieldArray([{"a": Field(field.RW, 4), + "b": [Field(field.RW, 1) for _ in range(2)]} + for _ in range(2)]) self.assertEqual(list(field_array.flatten()), [ ((0, "a"), field_array[0]["a"]), ((0, "b", 0), field_array[0]["b"][0]), @@ -241,60 +299,124 @@ def test_flatten(self): ((1, "b", 0), field_array[1]["b"][0]), ((1, "b", 1), field_array[1]["b"][1]), ]) + _silence_unused(*(v for k, v in field_array.flatten())) + + def test_wrong_fields(self): + with self.assertRaisesRegex(TypeError, + r"Fields must be provided as a non-empty list, not 'foo'"): + FieldArray("foo") + with self.assertRaisesRegex(TypeError, + r"Fields must be provided as a non-empty list, not \[\]"): + FieldArray([]) def test_wrong_field(self): with self.assertRaisesRegex(TypeError, - r"Field must be a Field or a FieldMap or a FieldArray, not 'foo'"): - FieldArray([Field(1, "rw"), "foo"]) + r"'foo' must be a Field object or a collection of Field objects"): + FieldArray(["foo", Field(field.RW, 1)]) class RegisterTestCase(unittest.TestCase): def test_simple(self): - reg = Register("rw", FieldMap({ - "a": field.R(unsigned(1)), - "b": field.RW1C(unsigned(3)), - "c": FieldMap({"d": field.RW(signed(2))}), - "e": FieldArray([field.W(unsigned(1)) for _ in range(2)]) - })) - - self.assertTrue(_compatible_fields(reg.f.a, field.R(unsigned(1)))) - self.assertTrue(_compatible_fields(reg.f.b, field.RW1C(unsigned(3)))) - self.assertTrue(_compatible_fields(reg.f.c.d, field.RW(signed(2)))) - self.assertTrue(_compatible_fields(reg.f.e[0], field.W(unsigned(1)))) - self.assertTrue(_compatible_fields(reg.f.e[1], field.W(unsigned(1)))) + reg = Register("rw", { + "a": Field(field.R, unsigned(1)), + "b": Field(field.RW1C, unsigned(3)), + "c": {"d": Field(field.RW, signed(2))}, + "e": [Field(field.W, unsigned(1)) for _ in range(2)] + }) + field_r_u1 = Field(field.R, unsigned(1)).create() + field_rw1c_u3 = Field(field.RW1C, unsigned(3)).create() + field_rw_s2 = Field(field.RW, signed(2)).create() + field_w_u1 = Field(field.W, unsigned(1)).create() + + self.assertTrue(_compatible_fields(reg.f.a, field_r_u1)) + self.assertTrue(_compatible_fields(reg.f.b, field_rw1c_u3)) + self.assertTrue(_compatible_fields(reg.f.c.d, field_rw_s2)) + self.assertTrue(_compatible_fields(reg.f.e[0], field_w_u1)) + self.assertTrue(_compatible_fields(reg.f.e[1], field_w_u1)) + + self.assertEqual(reg.access, Element.Access.RW) self.assertEqual(reg.element.width, 8) self.assertEqual(reg.element.access.readable(), True) self.assertEqual(reg.element.access.writable(), True) + _silence_unused(reg, field_r_u1, field_rw1c_u3, field_rw_s2, field_w_u1) + def test_annotations(self): class MockRegister(Register): - a: field.R(unsigned(1)) - b: field.RW1C(unsigned(3)) - c: FieldMap({"d": field.RW(signed(2))}) - e: FieldArray([field.W(unsigned(1)) for _ in range(2)]) + a: Field(field.R, unsigned(1)) + b: {"c": Field(field.RW1C, unsigned(3)), + "d": [Field(field.W, unsigned(1)) for _ in range(2)]} + e: [{"f": Field(field.RW, signed(2))} for _ in range(2)] + [ + [Field(field.RW, signed(2))]] + g: {"x": "foo", "y": dict(), "z": "bar"} + h: ["foo", [], dict()] foo: unsigned(42) reg = MockRegister("rw") - self.assertTrue(_compatible_fields(reg.f.a, field.R(unsigned(1)))) - self.assertTrue(_compatible_fields(reg.f.b, field.RW1C(unsigned(3)))) - self.assertTrue(_compatible_fields(reg.f.c.d, field.RW(signed(2)))) - self.assertTrue(_compatible_fields(reg.f.e[0], field.W(unsigned(1)))) - self.assertTrue(_compatible_fields(reg.f.e[1], field.W(unsigned(1)))) + field_r_u1 = Field(field.R, unsigned(1)).create() + field_rw1c_u3 = Field(field.RW1C, unsigned(3)).create() + field_w_u1 = Field(field.W, unsigned(1)).create() + field_rw_s2 = Field(field.RW, signed(2)).create() - self.assertEqual(reg.element.width, 8) + self.assertTrue(_compatible_fields(reg.f.a, field_r_u1)) + self.assertTrue(_compatible_fields(reg.f.b.c, field_rw1c_u3)) + self.assertTrue(_compatible_fields(reg.f.b.d[0], field_w_u1)) + self.assertTrue(_compatible_fields(reg.f.b.d[1], field_w_u1)) + self.assertTrue(_compatible_fields(reg.f.e[0].f, field_rw_s2)) + self.assertTrue(_compatible_fields(reg.f.e[1].f, field_rw_s2)) + self.assertTrue(_compatible_fields(reg.f.e[2][0], field_rw_s2)) + + self.assertEqual(reg.element.width, 12) self.assertEqual(reg.element.access.readable(), True) self.assertEqual(reg.element.access.writable(), True) + _silence_unused(reg, field_r_u1, field_rw1c_u3, field_w_u1, field_rw_s2) + + def test_annotations_conflict(self): + class MockRegister(Register): + a: Field(field.R, unsigned(1)) + with self.assertRaisesRegex(ValueError, + r"Field collection \{'b': <.*>\} cannot be provided in addition to field " + r"annotations: a"): + MockRegister("rw", {"b": Field(field.W, unsigned(1))}) + + def test_annotations_other(self): + class MockRegister(Register): + foo: "bar" + reg = MockRegister("rw", {"a": Field(field.R, unsigned(1))}) + field_r_u1 = Field(field.R, unsigned(1)).create() + self.assertTrue(_compatible_fields(reg.f.a, field_r_u1)) + self.assertEqual(reg.element.width, 1) + _silence_unused(reg, field_r_u1) + + def test_wrong_access(self): + with self.assertRaisesRegex(TypeError, + r"Access mode must be one of \"r\", \"w\", or \"rw\", not 'foo'"): + Register(access="foo") + + def test_access_mismatch(self): + class _UnusedField(Field): + def create(self): + obj = super().create() + _silence_unused(obj) + return obj + with self.assertRaisesRegex(ValueError, + r"Field a__b is readable, but register access mode is \"): + Register("w", {"a": {"b": _UnusedField(field.RW, unsigned(1))}}) + with self.assertRaisesRegex(ValueError, + r"Field a__b is writable, but register access mode is \"): + Register("r", {"a": {"b": _UnusedField(field.RW, unsigned(1))}}) + def test_iter(self): - reg = Register("rw", FieldMap({ - "a": field.R(unsigned(1)), - "b": field.RW1C(unsigned(3)), - "c": FieldMap({"d": field.RW(signed(2))}), - "e": FieldArray([field.W(unsigned(1)) for _ in range(2)]) - })) + reg = Register("rw", { + "a": Field(field.R, unsigned(1)), + "b": Field(field.RW1C, unsigned(3)), + "c": {"d": Field(field.RW, signed(2))}, + "e": [Field(field.W, unsigned(1)) for _ in range(2)] + }) self.assertEqual(list(reg), [ (("a",), reg.f.a), (("b",), reg.f.b), @@ -302,15 +424,16 @@ def test_iter(self): (("e", 0), reg.f.e[0]), (("e", 1), reg.f.e[1]), ]) + _silence_unused(reg) def test_sim(self): - dut = Register("rw", FieldMap({ - "a": field.R(unsigned(1)), - "b": field.RW1C(unsigned(3), reset=0b111), - "c": FieldMap({"d": field.RW(signed(2), reset=-1)}), - "e": FieldArray([field.W(unsigned(1)) for _ in range(2)]), - "f": field.RW1S(unsigned(3)), - })) + dut = Register("rw", { + "a": Field(field.R, unsigned(1)), + "b": Field(field.RW1C, unsigned(3), reset=0b111), + "c": {"d": Field(field.RW, signed(2), reset=-1)}, + "e": [Field(field.W, unsigned(1)) for _ in range(2)], + "f": Field(field.RW1S, unsigned(3)), + }) def process(): # Check reset values: @@ -438,91 +561,97 @@ def process(): class RegisterMapTestCase(unittest.TestCase): def setUp(self): - self.dut = RegisterMap() + self.map = RegisterMap() def test_add_register(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) - self.assertIs(self.dut.add_register(reg_rw_a, name="reg_rw_a"), reg_rw_a) + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) + self.assertIs(self.map.add_register(reg_rw_a, name="reg_rw_a"), reg_rw_a) + _silence_unused(reg_rw_a) def test_add_register_frozen(self): - self.dut.freeze() - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + self.map.freeze() + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) with self.assertRaisesRegex(ValueError, r"Register map is frozen"): - self.dut.add_register(reg_rw_a, name="reg_rw_a") + self.map.add_register(reg_rw_a, name="reg_rw_a") + _silence_unused(reg_rw_a) def test_add_register_wrong_type(self): with self.assertRaisesRegex(TypeError, r"Register must be an instance of csr\.Register, not 'foo'"): - self.dut.add_register("foo", name="foo") + self.map.add_register("foo", name="foo") def test_add_register_wrong_name(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) with self.assertRaisesRegex(TypeError, r"Name must be a non-empty string, not None"): - self.dut.add_register(reg_rw_a, name=None) + self.map.add_register(reg_rw_a, name=None) + _silence_unused(reg_rw_a) def test_add_register_empty_name(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) with self.assertRaisesRegex(TypeError, r"Name must be a non-empty string, not ''"): - self.dut.add_register(reg_rw_a, name="") + self.map.add_register(reg_rw_a, name="") + _silence_unused(reg_rw_a) def test_add_cluster(self): cluster = RegisterMap() - self.assertIs(self.dut.add_cluster(cluster, name="cluster"), cluster) + self.assertIs(self.map.add_cluster(cluster, name="cluster"), cluster) def test_add_cluster_frozen(self): - self.dut.freeze() + self.map.freeze() cluster = RegisterMap() with self.assertRaisesRegex(ValueError, r"Register map is frozen"): - self.dut.add_cluster(cluster, name="cluster") + self.map.add_cluster(cluster, name="cluster") def test_add_cluster_wrong_type(self): with self.assertRaisesRegex(TypeError, r"Cluster must be an instance of csr\.RegisterMap, not 'foo'"): - self.dut.add_cluster("foo", name="foo") + self.map.add_cluster("foo", name="foo") def test_add_cluster_wrong_name(self): cluster = RegisterMap() with self.assertRaisesRegex(TypeError, r"Name must be a non-empty string, not None"): - self.dut.add_cluster(cluster, name=None) + self.map.add_cluster(cluster, name=None) def test_add_cluster_empty_name(self): cluster = RegisterMap() with self.assertRaisesRegex(TypeError, r"Name must be a non-empty string, not ''"): - self.dut.add_cluster(cluster, name="") + self.map.add_cluster(cluster, name="") def test_namespace_collision(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) - reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) + reg_rw_b = Register("rw", {"b": Field(field.RW, 1)}) cluster_0 = RegisterMap() cluster_1 = RegisterMap() - self.dut.add_register(reg_rw_a, name="reg_rw_a") - self.dut.add_cluster(cluster_0, name="cluster_0") + self.map.add_register(reg_rw_a, name="reg_rw_a") + self.map.add_cluster(cluster_0, name="cluster_0") with self.assertRaisesRegex(ValueError, # register/register r"Name 'reg_rw_a' is already used by *"): - self.dut.add_register(reg_rw_b, name="reg_rw_a") + self.map.add_register(reg_rw_b, name="reg_rw_a") with self.assertRaisesRegex(ValueError, # register/cluster r"Name 'reg_rw_a' is already used by *"): - self.dut.add_cluster(cluster_1, name="reg_rw_a") + self.map.add_cluster(cluster_1, name="reg_rw_a") with self.assertRaisesRegex(ValueError, # cluster/cluster r"Name 'cluster_0' is already used by *"): - self.dut.add_cluster(cluster_1, name="cluster_0") + self.map.add_cluster(cluster_1, name="cluster_0") with self.assertRaisesRegex(ValueError, # cluster/register r"Name 'cluster_0' is already used by *"): - self.dut.add_register(reg_rw_b, name="cluster_0") + self.map.add_register(reg_rw_b, name="cluster_0") + + _silence_unused(reg_rw_a, reg_rw_b) def test_iter_registers(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) - reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) - self.dut.add_register(reg_rw_a, name="reg_rw_a") - self.dut.add_register(reg_rw_b, name="reg_rw_b") + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) + reg_rw_b = Register("rw", {"b": Field(field.RW, 1)}) + self.map.add_register(reg_rw_a, name="reg_rw_a") + self.map.add_register(reg_rw_b, name="reg_rw_b") - registers = list(self.dut.registers()) + registers = list(self.map.registers()) self.assertEqual(len(registers), 2) self.assertIs(registers[0][0], reg_rw_a) @@ -530,13 +659,15 @@ def test_iter_registers(self): self.assertIs(registers[1][0], reg_rw_b) self.assertEqual(registers[1][1], "reg_rw_b") + _silence_unused(reg_rw_a, reg_rw_b) + def test_iter_clusters(self): cluster_0 = RegisterMap() cluster_1 = RegisterMap() - self.dut.add_cluster(cluster_0, name="cluster_0") - self.dut.add_cluster(cluster_1, name="cluster_1") + self.map.add_cluster(cluster_0, name="cluster_0") + self.map.add_cluster(cluster_1, name="cluster_1") - clusters = list(self.dut.clusters()) + clusters = list(self.map.clusters()) self.assertEqual(len(clusters), 2) self.assertIs(clusters[0][0], cluster_0) @@ -545,18 +676,18 @@ def test_iter_clusters(self): self.assertEqual(clusters[1][1], "cluster_1") def test_iter_flatten(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) - reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) + reg_rw_b = Register("rw", {"b": Field(field.RW, 1)}) cluster_0 = RegisterMap() cluster_1 = RegisterMap() cluster_0.add_register(reg_rw_a, name="reg_rw_a") cluster_1.add_register(reg_rw_b, name="reg_rw_b") - self.dut.add_cluster(cluster_0, name="cluster_0") - self.dut.add_cluster(cluster_1, name="cluster_1") + self.map.add_cluster(cluster_0, name="cluster_0") + self.map.add_cluster(cluster_1, name="cluster_1") - registers = list(self.dut.flatten()) + registers = list(self.map.flatten()) self.assertEqual(len(registers), 2) self.assertIs(registers[0][0], reg_rw_a) @@ -564,63 +695,80 @@ def test_iter_flatten(self): self.assertIs(registers[1][0], reg_rw_b) self.assertEqual(registers[1][1], ("cluster_1", "reg_rw_b")) + _silence_unused(reg_rw_a, reg_rw_b) + def test_get_path(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) - reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) - cluster_0 = RegisterMap() + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) + reg_rw_b = Register("rw", {"b": Field(field.RW, 1)}) + reg_rw_c = Register("rw", {"c": Field(field.RW, 1)}) + cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_a, name="reg_rw_a") - self.dut.add_cluster(cluster_0, name="cluster_0") - self.dut.add_register(reg_rw_b, name="reg_rw_b") - self.assertEqual(self.dut.get_path(reg_rw_a), ("cluster_0", "reg_rw_a")) - self.assertEqual(self.dut.get_path(reg_rw_b), ("reg_rw_b",)) + cluster_1 = RegisterMap() + cluster_1.add_register(reg_rw_c, name="reg_rw_c") + + self.map.add_cluster(cluster_0, name="cluster_0") + self.map.add_register(reg_rw_b, name="reg_rw_b") + self.map.add_cluster(cluster_1, name="cluster_1") + + self.assertEqual(self.map.get_path(reg_rw_a), ("cluster_0", "reg_rw_a")) + self.assertEqual(self.map.get_path(reg_rw_b), ("reg_rw_b",)) + self.assertEqual(self.map.get_path(reg_rw_c), ("cluster_1", "reg_rw_c")) + + _silence_unused(reg_rw_a, reg_rw_b, reg_rw_c) def test_get_path_wrong_register(self): with self.assertRaisesRegex(TypeError, r"Register must be an instance of csr\.Register, not 'foo'"): - self.dut.get_path("foo") + self.map.get_path("foo") def test_get_path_unknown_register(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) with self.assertRaises(KeyError): - self.dut.get_path(reg_rw_a) + self.map.get_path(reg_rw_a) + _silence_unused(reg_rw_a) def test_get_register(self): - reg_rw_a = Register("rw", FieldMap({"a": field.RW(1)})) - reg_rw_b = Register("rw", FieldMap({"b": field.RW(1)})) + reg_rw_a = Register("rw", {"a": Field(field.RW, 1)}) + reg_rw_b = Register("rw", {"b": Field(field.RW, 1)}) cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_a, name="reg_rw_a") - self.dut.add_cluster(cluster_0, name="cluster_0") - self.dut.add_register(reg_rw_b, name="reg_rw_b") + self.map.add_cluster(cluster_0, name="cluster_0") + self.map.add_register(reg_rw_b, name="reg_rw_b") + + self.assertIs(self.map.get_register(("cluster_0", "reg_rw_a")), reg_rw_a) + self.assertIs(self.map.get_register(("reg_rw_b",)), reg_rw_b) - self.assertIs(self.dut.get_register(("cluster_0", "reg_rw_a")), reg_rw_a) - self.assertIs(self.dut.get_register(("reg_rw_b",)), reg_rw_b) + _silence_unused(reg_rw_a, reg_rw_b) def test_get_register_empty_path(self): with self.assertRaisesRegex(ValueError, r"Path must be a non-empty iterable"): - self.dut.get_register(()) + self.map.get_register(()) def test_get_register_wrong_path(self): with self.assertRaisesRegex(TypeError, r"Path must contain non-empty strings, not 0"): - self.dut.get_register(("cluster_0", 0)) + self.map.get_register(("cluster_0", 0)) with self.assertRaisesRegex(TypeError, r"Path must contain non-empty strings, not ''"): - self.dut.get_register(("", "reg_rw_a")) + self.map.get_register(("", "reg_rw_a")) def test_get_register_unknown_path(self): + self.map.add_cluster(RegisterMap(), name="cluster_0") + with self.assertRaises(KeyError): + self.map.get_register(("reg_rw_a",)) with self.assertRaises(KeyError): - self.dut.get_register(("reg_rw_a",)) + self.map.get_register(("cluster_0", "reg_rw_a")) class BridgeTestCase(unittest.TestCase): def test_memory_map(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4)})) - reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8)})) - reg_rw_12 = Register("rw", FieldMap({"a": field.RW(12)})) - reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4)}) + reg_rw_8 = Register("rw", {"a": Field(field.RW, 8)}) + reg_rw_12 = Register("rw", {"a": Field(field.RW, 12)}) + reg_rw_16 = Register("rw", {"a": Field(field.RW, 16)}) cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_12, name="reg_rw_12") @@ -650,7 +798,7 @@ def test_memory_map(self): self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") self.assertEqual(registers[3][2], (4, 6)) - Fragment.get(dut, platform=None) # silence UnusedElaboratable + _silence_unused(dut) def test_wrong_register_map(self): with self.assertRaisesRegex(TypeError, @@ -658,10 +806,10 @@ def test_wrong_register_map(self): dut = Bridge("foo", addr_width=16, data_width=8) def test_register_addr(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4)})) - reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8)})) - reg_rw_12 = Register("rw", FieldMap({"a": field.RW(12)})) - reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4)}) + reg_rw_8 = Register("rw", {"a": Field(field.RW, 8)}) + reg_rw_12 = Register("rw", {"a": Field(field.RW, 12)}) + reg_rw_16 = Register("rw", {"a": Field(field.RW, 16)}) cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_12, name="reg_rw_12") @@ -697,13 +845,13 @@ def test_register_addr(self): self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") self.assertEqual(registers[3][2], (0x22, 0x24)) - Fragment.get(dut, platform=None) # silence UnusedElaboratable + _silence_unused(dut) def test_register_alignment(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4)})) - reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8)})) - reg_rw_12 = Register("rw", FieldMap({"a": field.RW(12)})) - reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4)}) + reg_rw_8 = Register("rw", {"a": Field(field.RW, 8)}) + reg_rw_12 = Register("rw", {"a": Field(field.RW, 12)}) + reg_rw_16 = Register("rw", {"a": Field(field.RW, 16)}) cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_12, name="reg_rw_12") @@ -739,57 +887,62 @@ def test_register_alignment(self): self.assertEqual(registers[3][1], "cluster_0__reg_rw_16") self.assertEqual(registers[3][2], (16, 18)) - Fragment.get(dut, platform=None) # silence UnusedElaboratable + _silence_unused(dut) def test_register_out_of_bounds(self): - reg_rw_24 = Register("rw", FieldMap({"a": field.RW(24)})) + reg_rw_24 = Register("rw", {"a": Field(field.RW, 24)}) register_map = RegisterMap() register_map.add_register(reg_rw_24, name="reg_rw_24") with self.assertRaisesRegex(ValueError, r"Address range 0x0\.\.0x3 out of bounds for memory map spanning " r"range 0x0\.\.0x2 \(1 address bits\)"): dut = Bridge(register_map, addr_width=1, data_width=8) + _silence_unused(reg_rw_24) def test_wrong_register_address(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4)}) register_map = RegisterMap() register_map.add_register(reg_rw_4, name="reg_rw_4") - with self.assertRaisesRegex(TypeError, r"Register address must be a mapping, not 'foo'"): + with self.assertRaisesRegex(TypeError, r"Register address must be a dict, not 'foo'"): dut = Bridge(register_map, addr_width=1, data_width=8, register_addr="foo") + _silence_unused(reg_rw_4) def test_wrong_cluster_address(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4)}) cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_4, name="reg_rw_4") register_map = RegisterMap() register_map.add_cluster(cluster_0, name="cluster_0") with self.assertRaisesRegex(TypeError, - r"Register address \('cluster_0',\) must be a mapping, not 'foo'"): + r"Register address \('cluster_0',\) must be a dict, not 'foo'"): dut = Bridge(register_map, addr_width=1, data_width=8, register_addr={"cluster_0": "foo"}) + _silence_unused(reg_rw_4) def test_wrong_register_alignment(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4)}) register_map = RegisterMap() register_map.add_register(reg_rw_4, name="reg_rw_4") - with self.assertRaisesRegex(TypeError, r"Register alignment must be a mapping, not 'foo'"): + with self.assertRaisesRegex(TypeError, r"Register alignment must be a dict, not 'foo'"): dut = Bridge(register_map, addr_width=1, data_width=8, register_alignment="foo") + _silence_unused(reg_rw_4) def test_wrong_cluster_alignment(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW(4)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4)}) cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_4, name="reg_rw_4") register_map = RegisterMap() register_map.add_cluster(cluster_0, name="cluster_0") with self.assertRaisesRegex(TypeError, - r"Register alignment \('cluster_0',\) must be a mapping, not 'foo'"): + r"Register alignment \('cluster_0',\) must be a dict, not 'foo'"): dut = Bridge(register_map, addr_width=1, data_width=8, register_alignment={"cluster_0": "foo"}) + _silence_unused(reg_rw_4) def test_sim(self): - reg_rw_4 = Register("rw", FieldMap({"a": field.RW( 4, reset=0x0)})) - reg_rw_8 = Register("rw", FieldMap({"a": field.RW( 8, reset=0x11)})) - reg_rw_16 = Register("rw", FieldMap({"a": field.RW(16, reset=0x3322)})) + reg_rw_4 = Register("rw", {"a": Field(field.RW, 4, reset=0x0)}) + reg_rw_8 = Register("rw", {"a": Field(field.RW, 8, reset=0x11)}) + reg_rw_16 = Register("rw", {"a": Field(field.RW, 16, reset=0x3322)}) cluster_0 = RegisterMap() cluster_0.add_register(reg_rw_16, name="reg_rw_16")