From 39194ac3b107bfde366b7f65f20e5f3134025469 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jean-Fran=C3=A7ois=20Nguyen?= Date: Mon, 24 Jul 2023 11:34:15 +0200 Subject: [PATCH] csr.bus: redesign Multiplexer shadow registers. Before this commit, csr.Multiplexer had separate shadows for every element in its memory map. The same shadow was shared for read and write accesses to an element; a combined read/write transaction was impossible despite being allowed by the CSR interface. After this commit, csr.Multiplexer has separate shadows for read and write accesses, but both shadows are shared by every element using them. For multiplexers with many elements, this approach also results in significant resource savings. --- amaranth_soc/csr/bus.py | 279 ++++++++++++++++++++++++++++++++++------ tests/test_csr_bus.py | 243 ++++++++++++++++++---------------- 2 files changed, 378 insertions(+), 144 deletions(-) diff --git a/amaranth_soc/csr/bus.py b/amaranth_soc/csr/bus.py index b075d34..24de525 100644 --- a/amaranth_soc/csr/bus.py +++ b/amaranth_soc/csr/bus.py @@ -1,6 +1,7 @@ +from collections import defaultdict +from math import ceil, log2 import enum from amaranth import * -from amaranth.utils import log2_int from ..memory import MemoryMap @@ -171,10 +172,183 @@ def memory_map(self, memory_map): class Multiplexer(Elaboratable): + class _Shadow: + class Chunk: + """The interface between of a CSR multiplexer and a shadow register chunk.""" + def __init__(self, shadow, offset, elements): + self.name = f"{shadow.name}__{offset}" + self.data = Signal(shadow.granularity, name=f"{self.name}__data") + self.r_en = Signal(name=f"{self.name}__r_en") + self.w_en = Signal(name=f"{self.name}__w_en") + self._elements = tuple(elements) + + def elements(self): + """Iterate the address ranges of CSR elements using this chunk.""" + yield from self._elements + + """CSR multiplexer shadow register. + + Attributes + ---------- + name : :class:`str` + Name of the shadow register. + granularity : :class:`int` + Amount of bits stored in a chunk of the shadow register. + overlaps : :class:`int` + Maximum amount of CSR elements that can share a chunk of the shadow register. Optional. + If ``None``, it is implicitly set by :meth:`Multiplexer._Shadow.prepare`. + """ + def __init__(self, granularity, overlaps, *, name): + assert isinstance(name, str) + assert isinstance(granularity, int) and granularity >= 0 + assert overlaps is None or isinstance(overlaps, int) and overlaps >= 0 + self.name = name + self.granularity = granularity + self.overlaps = overlaps + self._ranges = set() + self._size = 1 + self._chunks = None + + @property + def size(self): + """Size of the shadow register. + + Returns + ------- + :class:`int` + The amount of :class:`Multiplexer._Shadow.Chunk`s of the shadow. It can increase + by calling :meth:`Multiplexer._Shadow.add` or :meth:`Multiplexer._Shadow.prepare`. + """ + return self._size + + def add(self, elem_range): + """Add a CSR element to the shadow. + + Arguments + --------- + elem_range : :class:`range` + Address range of a CSR :class:`Element`. It uses ``2 ** ceil(log2(elem_range.stop - + elem_range.start))`` chunks of the shadow register. If this amount is greater than + :attr:`~Multiplexer._Shadow.size`, it replaces the latter. + """ + assert isinstance(elem_range, range) + self._ranges.add(elem_range) + elem_size = 2 ** ceil(log2(elem_range.stop - elem_range.start)) + self._size = max(self._size, elem_size) + + def decode_address(self, addr, elem_range): + """Decode a bus address into a shadow register offset. + + Returns + ------- + :class:`int` + The shadow register offset corresponding to the :class:`Multiplexer._Shadow.Chunk` + used by ``addr``. + + The address decoding scheme is illustrated by the following example: + * ``addr`` is ``0x1c``; + * ``elem_range`` is ``range(0x1b, 0x1f)``; + * the :attr:`~Multiplexer._Shadow.size` of the shadow is ``16``. + + The lower bits of the offset would be ``0b00``, extracted from ``addr``: + + .. code-block:: + + +----+--+--+ + |0001|11|00| + +----+--+--+ + │ └─ 0 + └──── ceil(log2(elem_range.stop - elem_range.start)) + + The upper bits of the offset would be ``0b10``, extracted from ``elem_range.start``: + + .. code-block:: + + +----+--+--+ + |0001|10|11| + +----+--+--+ + │ │ + │ └──── ceil(log2(elem_range.stop - elem_range.start)) + └─────── log2(self.size) + + + The decoded offset would therefore be ``0xc`` (i.e. ``0b1100``). + """ + assert elem_range in self._ranges and addr in elem_range + elem_size = 2 ** ceil(log2(elem_range.stop - elem_range.start)) + self_mask = self.size - 1 + elem_mask = elem_size - 1 + return elem_range.start & self_mask & ~elem_mask | addr & elem_mask + + def encode_offset(self, offset, elem_range): + """Encode a shadow register offset into a bus address. + + Returns + ------- + :class:`int` + The bus address in ``elem_range`` using the :class:`Multiplexer._Shadow.Chunk` + located at ``offset``. See :meth:`~Multiplexer._Shadow.decode_address` for details. + """ + assert elem_range in self._ranges and isinstance(offset, int) + elem_size = 2 ** ceil(log2(elem_range.stop - elem_range.start)) + return elem_range.start + ((offset - elem_range.start) % elem_size) + + def prepare(self): + """Balance out and instantiate the shadow register chunks. + + The scheme used by :meth:`~Multiplexer._Shadow.decode_address` allows multiple bus + addresses to be decoded to the same shadow register offset. Depending on the platform + and its toolchain, this may create nets with high fan-in (if the chunk is read from + the bus) or fan-out (if written), which may impact timing closure or resource usage. + + If any shadow register offset is aliased to more bus addresses than permitted by the + :attr:`~Multiplexer._Shadow.overlaps` constraint, the :attr:`~Multiplexer._Shadow.size` + of the shadow is doubled. This increases the number of address bits used for decoding, + which effectively balances chunk usage across the shadow register. + + This method is recursive until the overlap constraint is satisfied. + """ + if isinstance(self._ranges, frozenset): + return + if self.overlaps is None: + self.overlaps = len(self._ranges) + + elements = defaultdict(list) + balanced = True + + for elem_range in self._ranges: + for chunk_addr in elem_range: + chunk_offset = self.decode_address(chunk_addr, elem_range) + if len(elements[chunk_offset]) > self.overlaps: + balanced = False + break + elements[chunk_offset].append(elem_range) + + if balanced: + self._ranges = frozenset(self._ranges) + self._chunks = dict() + for chunk_offset, chunk_elements in elements.items(): + chunk = Multiplexer._Shadow.Chunk(self, chunk_offset, chunk_elements) + self._chunks[chunk_offset] = chunk + else: + self._size *= 2 + self.prepare() + + def chunks(self): + """Iterate shadow register chunks used by at least one CSR element.""" + for chunk_offset, chunk in self._chunks.items(): + yield chunk_offset, chunk + """CSR register multiplexer. An address-based multiplexer for CSR registers implementing atomic updates. + This implementation assumes the following from the CSR bus: + * an initiator must have exclusive ownership over the multiplexer for the full duration of + a register transaction; + * an initiator must access a register in ascending order of addresses, but it may abort a + transaction after any bus cycle. + Latency ------- @@ -214,16 +388,22 @@ class Multiplexer(Elaboratable): Register alignment. See :class:`..memory.MemoryMap`. name : str Window name. Optional. + shadow_overlaps : int + Maximum number of CSR registers that can share a chunk of a shadow register. + Optional. If ``None``, any number of CSR registers can share a shadow chunk. + See :class:`Multiplexer._Shadow` for details. Attributes ---------- bus : :class:`Interface` CSR bus providing access to registers. """ - def __init__(self, *, addr_width, data_width, alignment=0, name=None): + def __init__(self, *, addr_width, data_width, alignment=0, name=None, shadow_overlaps=None): self._map = MemoryMap(addr_width=addr_width, data_width=data_width, alignment=alignment, name=name) self._bus = None + self._r_shadow = Multiplexer._Shadow(data_width, shadow_overlaps, name="r_shadow") + self._w_shadow = Multiplexer._Shadow(data_width, shadow_overlaps, name="w_shadow") @property def bus(self): @@ -258,50 +438,77 @@ def add(self, element, *, addr=None, alignment=None, extend=False): def elaborate(self, platform): m = Module() - # Instead of a straightforward multiplexer for reads, use a per-element address comparator, - # AND the shadow register chunk with the comparator output, and OR all of those together. - # If the toolchain doesn't already synthesize multiplexer trees this way, this trick can - # save a significant amount of logic, since e.g. one 4-LUT can pack one 2-MUX, but two - # 2-AND or 2-OR gates. - r_data_fanin = 0 - for elem, _, (elem_start, elem_end) in self._map.resources(): - shadow = Signal(elem.width, name="{}__shadow".format(elem.name)) + elem_range = range(elem_start, elem_end) if elem.access.readable(): - shadow_en = Signal(elem_end - elem_start, name="{}__shadow_en".format(elem.name)) - m.d.sync += shadow_en.eq(0) + self._r_shadow.add(elem_range) if elem.access.writable(): - m.d.comb += elem.w_data.eq(shadow) - m.d.sync += elem.w_stb.eq(0) + self._w_shadow.add(elem_range) + + self._r_shadow.prepare() + self._w_shadow.prepare() + + # Instead of a straightforward multiplexer for reads, use an address comparator for each + # shadow register chunk, AND the comparator output with the chunk contents, and OR all of + # those together. If the toolchain doesn't already synthesize multiplexer trees this way, + # this trick can save a significant amount of logic, since e.g. one 4-LUT can pack one + # 2-MUX, but two 2-AND or 2-OR gates. + r_data_fanin = 0 + + for chunk_offset, r_chunk in self._r_shadow.chunks(): + # Use the same trick to select which element is read into a shadow register chunk. + r_chunk_w_en_fanin = 0 + r_chunk_data_fanin = 0 + + m.d.sync += r_chunk.r_en.eq(0) - # Enumerate every address used by the register explicitly, rather than using - # arithmetic comparisons, since some toolchains (e.g. Yosys) are too eager to infer - # carry chains for comparisons, even with a constant. (Register sizes don't have - # to be powers of 2.) with m.Switch(self.bus.addr): - for chunk_offset, chunk_addr in enumerate(range(elem_start, elem_end)): - shadow_slice = shadow.word_select(chunk_offset, self.bus.data_width) + for elem_range in r_chunk.elements(): + chunk_addr = self._r_shadow.encode_offset(chunk_offset, elem_range) + elem = self._map.decode_address(elem_range.start) + elem_offset = chunk_addr - elem_range.start + elem_slice = elem.r_data.word_select(elem_offset, self.bus.data_width) with m.Case(chunk_addr): - if elem.access.readable(): - r_data_fanin |= Mux(shadow_en[chunk_offset], shadow_slice, 0) - if chunk_addr == elem_start: - m.d.comb += elem.r_stb.eq(self.bus.r_stb) - with m.If(self.bus.r_stb): - m.d.sync += shadow.eq(elem.r_data) - # Delay by 1 cycle, allowing reads to be pipelined. - m.d.sync += shadow_en.eq(self.bus.r_stb << chunk_offset) - - if elem.access.writable(): - if chunk_addr == elem_end - 1: - # Delay by 1 cycle, avoiding combinatorial paths through - # the CSR bus and into CSR registers. - m.d.sync += elem.w_stb.eq(self.bus.w_stb) - with m.If(self.bus.w_stb): - m.d.sync += shadow_slice.eq(self.bus.w_data) + if chunk_addr == elem_range.start: + m.d.comb += elem.r_stb.eq(self.bus.r_stb) + # Delay by 1 cycle, allowing reads to be pipelined. + m.d.sync += r_chunk.r_en.eq(self.bus.r_stb) + + r_chunk_w_en_fanin |= elem.r_stb + r_chunk_data_fanin |= Mux(elem.r_stb, elem_slice, 0) + + m.d.comb += r_chunk.w_en.eq(r_chunk_w_en_fanin) + with m.If(r_chunk.w_en): + m.d.sync += r_chunk.data.eq(r_chunk_data_fanin) + + r_data_fanin |= Mux(r_chunk.r_en, r_chunk.data, 0) m.d.comb += self.bus.r_data.eq(r_data_fanin) + for chunk_offset, w_chunk in self._w_shadow.chunks(): + with m.Switch(self.bus.addr): + for elem_range in w_chunk.elements(): + chunk_addr = self._w_shadow.encode_offset(chunk_offset, elem_range) + elem = self._map.decode_address(elem_range.start) + elem_offset = chunk_addr - elem_range.start + elem_slice = elem.w_data.word_select(elem_offset, self.bus.data_width) + + if chunk_addr == elem_range.stop - 1: + m.d.sync += elem.w_stb.eq(0) + + with m.Case(chunk_addr): + if chunk_addr == elem_range.stop - 1: + # Delay by 1 cycle, avoiding combinatorial paths through + # the CSR bus and into CSR registers. + m.d.sync += elem.w_stb.eq(self.bus.w_stb) + m.d.comb += w_chunk.w_en.eq(self.bus.w_stb) + + m.d.comb += elem_slice.eq(w_chunk.data) + + with m.If(w_chunk.w_en): + m.d.sync += w_chunk.data.eq(self.bus.w_data) + return m diff --git a/tests/test_csr_bus.py b/tests/test_csr_bus.py index 0d56d3e..7b7ff47 100644 --- a/tests/test_csr_bus.py +++ b/tests/test_csr_bus.py @@ -172,81 +172,103 @@ def test_add_wrong_out_of_bounds(self): self.dut.add(elem, addr=0x10000) def test_sim(self): - elem_4_r = Element(4, "r") - self.dut.add(elem_4_r) - elem_8_w = Element(8, "w") - self.dut.add(elem_8_w) - elem_16_rw = Element(16, "rw") - self.dut.add(elem_16_rw) - - bus = self.dut.bus - - def sim_test(): - yield elem_4_r.r_data.eq(0xa) - yield elem_16_rw.r_data.eq(0x5aa5) - - yield bus.addr.eq(0) - yield bus.r_stb.eq(1) - yield - yield bus.r_stb.eq(0) - self.assertEqual((yield elem_4_r.r_stb), 1) - self.assertEqual((yield elem_16_rw.r_stb), 0) - yield - self.assertEqual((yield bus.r_data), 0xa) - - yield bus.addr.eq(2) - yield bus.r_stb.eq(1) - yield - yield bus.r_stb.eq(0) - self.assertEqual((yield elem_4_r.r_stb), 0) - self.assertEqual((yield elem_16_rw.r_stb), 1) - yield - yield bus.addr.eq(3) # pipeline a read - self.assertEqual((yield bus.r_data), 0xa5) - - yield bus.r_stb.eq(1) - yield - yield bus.r_stb.eq(0) - self.assertEqual((yield elem_4_r.r_stb), 0) - self.assertEqual((yield elem_16_rw.r_stb), 0) - yield - self.assertEqual((yield bus.r_data), 0x5a) - - yield bus.addr.eq(1) - yield bus.w_data.eq(0x3d) - yield bus.w_stb.eq(1) - yield - yield bus.w_stb.eq(0) - yield bus.addr.eq(2) # change address - yield - self.assertEqual((yield elem_8_w.w_stb), 1) - self.assertEqual((yield elem_8_w.w_data), 0x3d) - self.assertEqual((yield elem_16_rw.w_stb), 0) - yield - self.assertEqual((yield elem_8_w.w_stb), 0) - - yield bus.addr.eq(2) - yield bus.w_data.eq(0x55) - yield bus.w_stb.eq(1) - yield - self.assertEqual((yield elem_8_w.w_stb), 0) - self.assertEqual((yield elem_16_rw.w_stb), 0) - yield bus.addr.eq(3) # pipeline a write - yield bus.w_data.eq(0xaa) - yield - self.assertEqual((yield elem_8_w.w_stb), 0) - self.assertEqual((yield elem_16_rw.w_stb), 0) - yield bus.w_stb.eq(0) - yield - self.assertEqual((yield elem_8_w.w_stb), 0) - self.assertEqual((yield elem_16_rw.w_stb), 1) - self.assertEqual((yield elem_16_rw.w_data), 0xaa55) - - sim = Simulator(self.dut) - sim.add_clock(1e-6) - sim.add_sync_process(sim_test) - with sim.write_vcd(vcd_file=open("test.vcd", "w")): - sim.run() + for shadow_overlaps in [None, 0, 1]: + with self.subTest(shadow_overlaps=shadow_overlaps): + dut = Multiplexer(addr_width=16, data_width=8, shadow_overlaps=shadow_overlaps) + + elem_4_r = Element(4, "r") + dut.add(elem_4_r) + elem_8_w = Element(8, "w") + dut.add(elem_8_w) + elem_16_rw = Element(16, "rw") + dut.add(elem_16_rw) + + bus = dut.bus + + def sim_test(): + yield elem_4_r.r_data.eq(0xa) + yield elem_16_rw.r_data.eq(0x5aa5) + + yield bus.addr.eq(0) + yield bus.r_stb.eq(1) + yield + yield bus.r_stb.eq(0) + self.assertEqual((yield elem_4_r.r_stb), 1) + self.assertEqual((yield elem_16_rw.r_stb), 0) + yield + self.assertEqual((yield bus.r_data), 0xa) + + yield bus.addr.eq(2) + yield bus.r_stb.eq(1) + yield + yield bus.r_stb.eq(0) + self.assertEqual((yield elem_4_r.r_stb), 0) + self.assertEqual((yield elem_16_rw.r_stb), 1) + yield + yield bus.addr.eq(3) # pipeline a read + self.assertEqual((yield bus.r_data), 0xa5) + + yield bus.r_stb.eq(1) + yield + yield bus.r_stb.eq(0) + self.assertEqual((yield elem_4_r.r_stb), 0) + self.assertEqual((yield elem_16_rw.r_stb), 0) + yield + self.assertEqual((yield bus.r_data), 0x5a) + + yield bus.addr.eq(1) + yield bus.w_data.eq(0x3d) + yield bus.w_stb.eq(1) + yield + yield bus.w_stb.eq(0) + yield bus.addr.eq(2) # change address + yield + self.assertEqual((yield elem_8_w.w_stb), 1) + self.assertEqual((yield elem_8_w.w_data), 0x3d) + self.assertEqual((yield elem_16_rw.w_stb), 0) + yield + self.assertEqual((yield elem_8_w.w_stb), 0) + + yield bus.addr.eq(2) + yield bus.w_data.eq(0x55) + yield bus.w_stb.eq(1) + yield + self.assertEqual((yield elem_8_w.w_stb), 0) + self.assertEqual((yield elem_16_rw.w_stb), 0) + yield bus.addr.eq(3) # pipeline a write + yield bus.w_data.eq(0xaa) + yield + self.assertEqual((yield elem_8_w.w_stb), 0) + self.assertEqual((yield elem_16_rw.w_stb), 0) + yield bus.w_stb.eq(0) + yield + self.assertEqual((yield elem_8_w.w_stb), 0) + self.assertEqual((yield elem_16_rw.w_stb), 1) + self.assertEqual((yield elem_16_rw.w_data), 0xaa55) + + yield bus.addr.eq(2) + yield bus.r_stb.eq(1) + yield bus.w_data.eq(0x66) + yield bus.w_stb.eq(1) + yield + self.assertEqual((yield elem_16_rw.r_stb), 1) + self.assertEqual((yield elem_16_rw.w_stb), 0) + yield + yield bus.addr.eq(3) # pipeline a read and a write + yield bus.w_data.eq(0xbb) + self.assertEqual((yield bus.r_data), 0xa5) + yield + yield Delay() + self.assertEqual((yield bus.r_data), 0x5a) + self.assertEqual((yield elem_16_rw.r_stb), 0) + self.assertEqual((yield elem_16_rw.w_stb), 1) + self.assertEqual((yield elem_16_rw.w_data), 0xbb66) + + sim = Simulator(dut) + sim.add_clock(1e-6) + sim.add_sync_process(sim_test) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() class MultiplexerAlignedTestCase(unittest.TestCase): @@ -274,39 +296,44 @@ def test_under_align_to(self): self.assertEqual(self.dut.add(elem_1), (4, 8)) def test_sim(self): - elem_20_rw = Element(20, "rw") - self.dut.add(elem_20_rw) - - bus = self.dut.bus - - def sim_test(): - yield bus.w_stb.eq(1) - yield bus.addr.eq(0) - yield bus.w_data.eq(0x55) - yield - self.assertEqual((yield elem_20_rw.w_stb), 0) - yield bus.addr.eq(1) - yield bus.w_data.eq(0xaa) - yield - self.assertEqual((yield elem_20_rw.w_stb), 0) - yield bus.addr.eq(2) - yield bus.w_data.eq(0x33) - yield - self.assertEqual((yield elem_20_rw.w_stb), 0) - yield bus.addr.eq(3) - yield bus.w_data.eq(0xdd) - yield - self.assertEqual((yield elem_20_rw.w_stb), 0) - yield bus.w_stb.eq(0) - yield - self.assertEqual((yield elem_20_rw.w_stb), 1) - self.assertEqual((yield elem_20_rw.w_data), 0x3aa55) - - sim = Simulator(self.dut) - sim.add_clock(1e-6) - sim.add_sync_process(sim_test) - with sim.write_vcd(vcd_file=open("test.vcd", "w")): - sim.run() + for shadow_overlaps in [None, 0, 1]: + with self.subTest(shadow_overlaps=shadow_overlaps): + dut = Multiplexer(addr_width=16, data_width=8, alignment=2, + shadow_overlaps=shadow_overlaps) + + elem_20_rw = Element(20, "rw") + dut.add(elem_20_rw) + + bus = dut.bus + + def sim_test(): + yield bus.w_stb.eq(1) + yield bus.addr.eq(0) + yield bus.w_data.eq(0x55) + yield + self.assertEqual((yield elem_20_rw.w_stb), 0) + yield bus.addr.eq(1) + yield bus.w_data.eq(0xaa) + yield + self.assertEqual((yield elem_20_rw.w_stb), 0) + yield bus.addr.eq(2) + yield bus.w_data.eq(0x33) + yield + self.assertEqual((yield elem_20_rw.w_stb), 0) + yield bus.addr.eq(3) + yield bus.w_data.eq(0xdd) + yield + self.assertEqual((yield elem_20_rw.w_stb), 0) + yield bus.w_stb.eq(0) + yield + self.assertEqual((yield elem_20_rw.w_stb), 1) + self.assertEqual((yield elem_20_rw.w_data), 0x3aa55) + + sim = Simulator(dut) + sim.add_clock(1e-6) + sim.add_sync_process(sim_test) + with sim.write_vcd(vcd_file=open("test.vcd", "w")): + sim.run() class DecoderTestCase(unittest.TestCase):