csr.bus: add CSRElement and CSRMultiplexer.
authorwhitequark <whitequark@whitequark.org>
Mon, 21 Oct 2019 15:05:24 +0000 (15:05 +0000)
committerwhitequark <whitequark@whitequark.org>
Mon, 21 Oct 2019 15:05:31 +0000 (15:05 +0000)
.gitignore
nmigen_soc/csr/__init__.py [new file with mode: 0644]
nmigen_soc/csr/bus.py [new file with mode: 0644]
nmigen_soc/test/__init__.py [new file with mode: 0644]
nmigen_soc/test/test_csr_bus.py [new file with mode: 0644]

index 610c3991361f00fca236f12bcde2ab8f14cd2d8b..11fffae198762340e79072c1efffba4bef1e6df2 100644 (file)
@@ -2,3 +2,7 @@
 *.pyc
 /*.egg-info
 /.eggs
+
+# tests
+*.vcd
+*.gtkw
diff --git a/nmigen_soc/csr/__init__.py b/nmigen_soc/csr/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/nmigen_soc/csr/bus.py b/nmigen_soc/csr/bus.py
new file mode 100644 (file)
index 0000000..54129cc
--- /dev/null
@@ -0,0 +1,255 @@
+from functools import reduce
+from nmigen import *
+from nmigen import tracer
+
+
+__all__ = ["CSRElement", "CSRMultiplexer"]
+
+
+class CSRElement(Record):
+    """Peripheral-side CSR interface.
+
+    A low-level interface to a single atomically readable and writable register in a peripheral.
+    This interface supports any register width and semantics, provided that both reads and writes
+    always succeed and complete in one cycle.
+
+    Parameters
+    ----------
+    width : int
+        Width of the register.
+    name : str
+        Name of the underlying record.
+
+    Attributes
+    ----------
+    r_data : Signal(width)
+        Read data. Must be always valid, and is sampled when ``r_stb`` is asserted.
+    r_stb : Signal()
+        Read strobe. Registers with read side effects should perform the read side effect when this
+        strobe is asserted.
+    w_data : Signal(width)
+        Write data. Valid only when ``w_stb`` is asserted.
+    w_stb : Signal()
+        Write strobe. Registers should update their value or perform the write side effect when
+        this strobe is asserted.
+    """
+    def __init__(self, width, access, *, name=None, src_loc_at=0):
+        if not isinstance(width, int) or width < 0:
+            raise ValueError("Width must be a non-negative integer, not {!r}"
+                             .format(width))
+        if access not in ("r", "w", "rw"):
+            raise ValueError("Access mode must be one of \"r\", \"w\", or \"rw\", not {!r}"
+                             .format(access))
+
+        self.width  = int(width)
+        self.access = access
+
+        layout = []
+        if "r" in self.access:
+            layout += [
+                ("r_data", width),
+                ("r_stb",  1),
+            ]
+        if "w" in self.access:
+            layout += [
+                ("w_data", width),
+                ("w_stb",  1),
+            ]
+        super().__init__(layout, name=name, src_loc_at=1)
+
+
+class CSRMultiplexer(Elaboratable):
+    """CPU-side CSR interface.
+
+    A low-level interface to a set of peripheral CSR registers that implements address-based
+    multiplexing and atomic updates of wide registers.
+
+    Operation
+    ---------
+
+    The CSR multiplexer splits each CSR register into chunks according to its data width. Each
+    chunk is assigned an address, and the first chunk of each register always has the provided
+    minimum alignment. This allows accessing CSRs of any size using any datapath width.
+
+    When the first chunk of a register is read, the value of a register is captured, and reads
+    from subsequent chunks of the same register return the captured values. When any chunk except
+    the last chunk of a register is written, the written value is captured; a write to the last
+    chunk writes the captured value to the register. This allows atomically accessing CSRs larger
+    than datapath width.
+
+    Reads to padding bytes return zeroes, and writes to padding bytes are ignored.
+
+    Writes are registered, and add 1 cycle of latency.
+
+    Wide registers
+    --------------
+
+    Because the CSR bus conserves logic and routing resources, it is common to e.g. access
+    a CSR bus with an *n*-bit data path from a CPU with a *k*-bit datapath in cases where CSR
+    access latency is less important than resource usage. In this case, two strategies are
+    possible for connecting the CSR bus to the CPU:
+        * The CPU could access the CSR bus directly (with no intervening logic other than simple
+          translation of control signals). In this case, the register alignment should be set
+          to 1, and each *w*-bit register would occupy *ceil(w/n)* addresses from the CPU
+          perspective, requiring the same amount of memory instructions to access.
+        * The CPU could also access the CSR bus through a width down-converter, which would issue
+          *k/n* CSR accesses for each CPU access. In this case, the register alignment should be
+          set to *k/n*, and each *w*-bit register would occupy *ceil(w/k)* addresses from the CPU
+          perspective, requiring the same amount of memory instructions to access.
+
+    If alignment is greater than 1, it affects which CSR bus write is considered a write to
+    the last register chunk. For example, if a 24-bit register is used with a 8-bit CSR bus and
+    a CPU with a 32-bit datapath, a write to this register requires 4 CSR bus writes to complete
+    and the 4th write is the one that actually writes the value to the register. This allows
+    determining write latency solely from the amount of addresses the register occupies in
+    the CPU address space, and the width of the CSR bus.
+
+    Parameters
+    ----------
+    addr_width : int
+        Address width. At most ``(2 ** addr_width) * data_width`` register bits will be available.
+    data_width : int
+        Data width. Registers are accessed in ``data_width`` sized chunks.
+    alignment : int
+        Register alignment. The address assigned to each register will be a multiple of
+        ``2 ** alignment``.
+
+    Attributes
+    ----------
+    addr : Signal(addr_width)
+        Address for reads and writes.
+    r_data : Signal(data_width)
+        Read data. Valid on the next cycle after ``r_stb`` is asserted.
+    r_stb : Signal()
+        Read strobe. If ``addr`` points to the first chunk of a register, captures register value
+        and causes read side effects to be performed (if any). If ``addr`` points to any chunk
+        of a register, latches the captured value to ``r_data``. Otherwise, latches zero
+        to ``r_data``.
+    w_data : Signal(data_width)
+        Write data. Must be valid when ``w_stb`` is asserted.
+    w_stb : Signal()
+        Write strobe. If ``addr`` points to the last chunk of a register, writes captured value
+        to the register and causes write side effects to be performed (if any). If ``addr`` points
+        to any chunk of a register, latches ``w_data`` to the captured value. Otherwise, does
+        nothing.
+    """
+    def __init__(self, *, addr_width, data_width, alignment=0):
+        if not isinstance(addr_width, int) or addr_width <= 0:
+            raise ValueError("Address width must be a positive integer, not {!r}"
+                             .format(addr_width))
+        if not isinstance(data_width, int) or data_width <= 0:
+            raise ValueError("Data width must be a positive integer, not {!r}"
+                             .format(data_width))
+        if not isinstance(alignment, int) or alignment < 0:
+            raise ValueError("Alignment must be a non-negative integer, not {!r}"
+                             .format(alignment))
+
+        self.addr_width = int(addr_width)
+        self.data_width = int(data_width)
+        self.alignment  = alignment
+
+        self._next_addr = 0
+        self._elements  = dict()
+
+        self.addr   = Signal(addr_width)
+        self.r_data = Signal(data_width)
+        self.r_stb  = Signal()
+        self.w_data = Signal(data_width)
+        self.w_stb  = Signal()
+
+    def add(self, element):
+        """Add a register.
+
+        Arguments
+        ---------
+        element : CSRElement
+            Interface of the register.
+
+        Return value
+        ------------
+        An ``(addr, size)`` tuple, where ``addr`` is the address assigned to the first chunk of
+        the register, and ``size`` is the amount of chunks it takes, which may be greater than
+        ``element.size // self.data_width`` due to alignment.
+        """
+        if not isinstance(element, CSRElement):
+            raise TypeError("Element must be an instance of CSRElement, not {!r}"
+                            .format(element))
+
+        addr = self.align_to(self.alignment)
+        self._next_addr += (element.width + self.data_width - 1) // self.data_width
+        size = self.align_to(self.alignment) - addr
+        self._elements[addr] = element, size
+        return addr, size
+
+    def align_to(self, alignment):
+        """Align the next register explicitly.
+
+        Arguments
+        ---------
+        alignment : int
+            Register alignment. The address assigned to the next register will be a multiple of
+            ``2 ** alignment`` or ``2 ** self.alignment``, whichever is greater.
+
+        Return value
+        ------------
+        Address of the next register.
+        """
+        if not isinstance(alignment, int) or alignment < 0:
+            raise ValueError("Alignment must be a non-negative integer, not {!r}"
+                             .format(alignment))
+
+        align_chunks = 1 << alignment
+        if self._next_addr % align_chunks != 0:
+            self._next_addr += align_chunks - (self._next_addr % align_chunks)
+        return self._next_addr
+
+    def elaborate(self, platform):
+        m = Module()
+
+        # Instead of a straightforward multiplexer for reads, use a per-element address comparator,
+        # clear the shadow register when it does not match, and OR every selected shadow register
+        # part to form the output. This can save a significant amount of logic; the size of
+        # a complete k-OR or k-MUX gate tree for n inputs is `s = ceil((n - 1) / (k - 1))`,
+        # and its logic depth is `ceil(log_k(s))`, but a 4-LUT can implement either a 4-OR or
+        # a 2-MUX gate.
+        r_data_fanin = 0
+
+        for elem_addr, (elem, elem_size) in self._elements.items():
+            shadow = Signal(elem.width, name="{}__shadow".format(elem.name))
+            if "w" in elem.access:
+                m.d.comb += elem.w_data.eq(shadow)
+
+            # Enumerate every address used by the register explicitly, rather than using
+            # arithmetic comparisons, since some toolchains (e.g. Yosys) are too eager to infer
+            # carry chains for comparisons, even with a constant. (Register sizes don't have
+            # to be powers of 2.)
+            with m.Switch(self.addr):
+                for chunk_offset in range(elem_size):
+                    chunk_slice = slice(chunk_offset * self.data_width,
+                                        (chunk_offset + 1) * self.data_width)
+                    with m.Case(elem_addr + chunk_offset):
+                        if "r" in elem.access:
+                            chunk_r_stb = Signal(self.data_width,
+                                name="{}__r_stb_{}".format(elem.name, chunk_offset))
+                            r_data_fanin |= Mux(chunk_r_stb, shadow[chunk_slice], 0)
+                            if chunk_offset == 0:
+                                m.d.comb += elem.r_stb.eq(self.r_stb)
+                                with m.If(self.r_stb):
+                                    m.d.sync += shadow.eq(elem.r_data)
+                            # Delay by 1 cycle, allowing reads to be pipelined.
+                            m.d.sync += chunk_r_stb.eq(self.r_stb)
+
+                        if "w" in elem.access:
+                            if chunk_offset == elem_size - 1:
+                                # Delay by 1 cycle, avoiding combinatorial paths through
+                                # the CSR bus and into CSR registers.
+                                m.d.sync += elem.w_stb.eq(self.w_stb)
+                            with m.If(self.w_stb):
+                                m.d.sync += shadow[chunk_slice].eq(self.w_data)
+
+                with m.Default():
+                    m.d.sync += shadow.eq(0)
+
+        m.d.comb += self.r_data.eq(r_data_fanin)
+
+        return m
diff --git a/nmigen_soc/test/__init__.py b/nmigen_soc/test/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/nmigen_soc/test/test_csr_bus.py b/nmigen_soc/test/test_csr_bus.py
new file mode 100644 (file)
index 0000000..ac427b9
--- /dev/null
@@ -0,0 +1,248 @@
+import unittest
+from nmigen import *
+from nmigen.hdl.rec import Layout
+from nmigen.back.pysim import *
+
+from ..csr.bus import *
+
+
+class CSRElementTestCase(unittest.TestCase):
+    def test_1_ro(self):
+        elem = CSRElement(1, "r")
+        self.assertEqual(elem.width, 1)
+        self.assertEqual(elem.access, "r")
+        self.assertEqual(elem.layout, Layout.cast([
+            ("r_data", 1),
+            ("r_stb", 1),
+        ]))
+
+    def test_8_rw(self):
+        elem = CSRElement(8, access="rw")
+        self.assertEqual(elem.width, 8)
+        self.assertEqual(elem.access, "rw")
+        self.assertEqual(elem.layout, Layout.cast([
+            ("r_data", 8),
+            ("r_stb", 1),
+            ("w_data", 8),
+            ("w_stb", 1),
+        ]))
+
+    def test_10_wo(self):
+        elem = CSRElement(10, "w")
+        self.assertEqual(elem.width, 10)
+        self.assertEqual(elem.access, "w")
+        self.assertEqual(elem.layout, Layout.cast([
+            ("w_data", 10),
+            ("w_stb", 1),
+        ]))
+
+    def test_0_rw(self): # degenerate but legal case
+        elem = CSRElement(0, access="rw")
+        self.assertEqual(elem.width, 0)
+        self.assertEqual(elem.access, "rw")
+        self.assertEqual(elem.layout, Layout.cast([
+            ("r_data", 0),
+            ("r_stb", 1),
+            ("w_data", 0),
+            ("w_stb", 1),
+        ]))
+
+    def test_width_wrong(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Width must be a non-negative integer, not -1"):
+            CSRElement(-1, "rw")
+
+    def test_access_wrong(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Access mode must be one of \"r\", \"w\", or \"rw\", not 'wo'"):
+            CSRElement(1, "wo")
+
+
+class CSRMultiplexerTestCase(unittest.TestCase):
+    def setUp(self):
+        self.dut = CSRMultiplexer(addr_width=16, data_width=8)
+
+    def test_addr_width_wrong(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Address width must be a positive integer, not -1"):
+            CSRMultiplexer(addr_width=-1, data_width=8)
+
+    def test_data_width_wrong(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Data width must be a positive integer, not -1"):
+            CSRMultiplexer(addr_width=16, data_width=-1)
+
+    def test_alignment_wrong(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Alignment must be a non-negative integer, not -1"):
+            CSRMultiplexer(addr_width=16, data_width=8, alignment=-1)
+
+    def test_attrs(self):
+        self.assertEqual(self.dut.addr_width, 16)
+        self.assertEqual(self.dut.data_width, 8)
+        self.assertEqual(self.dut.alignment, 0)
+
+    def test_add_4b(self):
+        self.assertEqual(self.dut.add(CSRElement(4, "rw")),
+                         (0, 1))
+
+    def test_add_8b(self):
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (0, 1))
+
+    def test_add_12b(self):
+        self.assertEqual(self.dut.add(CSRElement(12, "rw")),
+                         (0, 2))
+
+    def test_add_16b(self):
+        self.assertEqual(self.dut.add(CSRElement(16, "rw")),
+                         (0, 2))
+
+    def test_add_two(self):
+        self.assertEqual(self.dut.add(CSRElement(16, "rw")),
+                         (0, 2))
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (2, 1))
+
+    def test_add_wrong(self):
+        with self.assertRaisesRegex(ValueError,
+                r"Width must be a non-negative integer, not -1"):
+            CSRElement(-1, "rw")
+
+    def test_align_to(self):
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (0, 1))
+        self.assertEqual(self.dut.align_to(2), 4)
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (4, 1))
+
+    def test_sim(self):
+        elem_4_r = CSRElement(4, "r")
+        self.dut.add(elem_4_r)
+        elem_8_w = CSRElement(8, "w")
+        self.dut.add(elem_8_w)
+        elem_16_rw = CSRElement(16, "rw")
+        self.dut.add(elem_16_rw)
+
+        def sim_test():
+            yield elem_4_r.r_data.eq(0xa)
+            yield elem_16_rw.r_data.eq(0x5aa5)
+
+            yield self.dut.addr.eq(0)
+            yield self.dut.r_stb.eq(1)
+            yield
+            yield self.dut.r_stb.eq(0)
+            self.assertEqual((yield elem_4_r.r_stb), 1)
+            self.assertEqual((yield elem_16_rw.r_stb), 0)
+            yield
+            self.assertEqual((yield self.dut.r_data), 0xa)
+
+            yield self.dut.addr.eq(2)
+            yield self.dut.r_stb.eq(1)
+            yield
+            yield self.dut.r_stb.eq(0)
+            self.assertEqual((yield elem_4_r.r_stb), 0)
+            self.assertEqual((yield elem_16_rw.r_stb), 1)
+            yield
+            yield self.dut.addr.eq(3) # pipeline a read
+            self.assertEqual((yield self.dut.r_data), 0xa5)
+
+            yield self.dut.r_stb.eq(1)
+            yield
+            yield self.dut.r_stb.eq(0)
+            self.assertEqual((yield elem_4_r.r_stb), 0)
+            self.assertEqual((yield elem_16_rw.r_stb), 0)
+            yield
+            self.assertEqual((yield self.dut.r_data), 0x5a)
+
+            yield self.dut.addr.eq(1)
+            yield self.dut.w_data.eq(0x3d)
+            yield self.dut.w_stb.eq(1)
+            yield
+            yield self.dut.w_stb.eq(0)
+            yield
+            self.assertEqual((yield elem_8_w.w_stb), 1)
+            self.assertEqual((yield elem_8_w.w_data), 0x3d)
+            self.assertEqual((yield elem_16_rw.w_stb), 0)
+
+            yield self.dut.addr.eq(2)
+            yield self.dut.w_data.eq(0x55)
+            yield self.dut.w_stb.eq(1)
+            yield
+            self.assertEqual((yield elem_8_w.w_stb), 0)
+            self.assertEqual((yield elem_16_rw.w_stb), 0)
+            yield self.dut.addr.eq(3) # pipeline a write
+            yield self.dut.w_data.eq(0xaa)
+            yield
+            self.assertEqual((yield elem_8_w.w_stb), 0)
+            self.assertEqual((yield elem_16_rw.w_stb), 0)
+            yield self.dut.w_stb.eq(0)
+            yield
+            self.assertEqual((yield elem_8_w.w_stb), 0)
+            self.assertEqual((yield elem_16_rw.w_stb), 1)
+            self.assertEqual((yield elem_16_rw.w_data), 0xaa55)
+
+        with Simulator(self.dut, vcd_file=open("test.vcd", "w")) as sim:
+            sim.add_clock(1e-6)
+            sim.add_sync_process(sim_test())
+            sim.run()
+
+
+class CSRAlignedMultiplexerTestCase(unittest.TestCase):
+    def setUp(self):
+        self.dut = CSRMultiplexer(addr_width=16, data_width=8, alignment=2)
+
+    def test_attrs(self):
+        self.assertEqual(self.dut.alignment, 2)
+
+    def test_add_two(self):
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (0, 4))
+        self.assertEqual(self.dut.add(CSRElement(16, "rw")),
+                         (4, 4))
+
+    def test_over_align_to(self):
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (0, 4))
+        self.assertEqual(self.dut.align_to(3), 8)
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (8, 4))
+
+    def test_under_align_to(self):
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (0, 4))
+        self.assertEqual(self.dut.align_to(1), 4)
+        self.assertEqual(self.dut.add(CSRElement(8, "rw")),
+                         (4, 4))
+
+    def test_sim(self):
+        elem_20_rw = CSRElement(20, "rw")
+        self.dut.add(elem_20_rw)
+
+        def sim_test():
+            yield self.dut.w_stb.eq(1)
+            yield self.dut.addr.eq(0)
+            yield self.dut.w_data.eq(0x55)
+            yield
+            self.assertEqual((yield elem_20_rw.w_stb), 0)
+            yield self.dut.addr.eq(1)
+            yield self.dut.w_data.eq(0xaa)
+            yield
+            self.assertEqual((yield elem_20_rw.w_stb), 0)
+            yield self.dut.addr.eq(2)
+            yield self.dut.w_data.eq(0x33)
+            yield
+            self.assertEqual((yield elem_20_rw.w_stb), 0)
+            yield self.dut.addr.eq(3)
+            yield self.dut.w_data.eq(0xdd)
+            yield
+            self.assertEqual((yield elem_20_rw.w_stb), 0)
+            yield self.dut.w_stb.eq(0)
+            yield
+            self.assertEqual((yield elem_20_rw.w_stb), 1)
+            self.assertEqual((yield elem_20_rw.w_data), 0x3aa55)
+
+        with Simulator(self.dut, vcd_file=open("test.vcd", "w")) as sim:
+            sim.add_clock(1e-6)
+            sim.add_sync_process(sim_test())
+            sim.run()