csr.bus: rewrite using the MemoryMap abstraction.
authorwhitequark <whitequark@whitequark.org>
Fri, 25 Oct 2019 18:41:40 +0000 (18:41 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 25 Oct 2019 18:42:56 +0000 (18:42 +0000)
nmigen_soc/csr/bus.py
nmigen_soc/test/test_csr_bus.py

index f65f7e8a82705a6b08dbbbce23d7e8184243cd22..8450acc9f385e0e8d54add23c4c54198f52a809c 100644 (file)
@@ -1,6 +1,8 @@
 import enum
 from nmigen import *
 
+from ..memory import MemoryMap
+
 
 __all__ = ["Element", "Interface", "Decoder"]
 
@@ -73,6 +75,9 @@ class Element(Record):
             ]
         super().__init__(layout, name=name, src_loc_at=1)
 
+    # FIXME: get rid of this
+    __hash__ = object.__hash__
+
 
 class Interface(Record):
     """CPU-side CSR interface.
@@ -189,61 +194,27 @@ class Decoder(Elaboratable):
         CSR bus providing access to registers.
     """
     def __init__(self, *, addr_width, data_width, alignment=0):
-        self.bus = Interface(addr_width=addr_width, data_width=data_width)
+        self.bus  = Interface(addr_width=addr_width, data_width=data_width)
+        self._map = MemoryMap(addr_width=addr_width, data_width=data_width, alignment=alignment)
 
-        if not isinstance(alignment, int) or alignment < 0:
-            raise ValueError("Alignment must be a non-negative integer, not {!r}"
-                             .format(alignment))
-        self.alignment = alignment
+    def align_to(self, alignment):
+        """Align the implicit address of the next register.
 
-        self._next_addr = 0
-        self._elements  = dict()
+        See :meth:`MemoryMap.align_to` for details.
+        """
+        return self._map.align_to(alignment)
 
-    def add(self, element):
+    def add(self, element, *, addr=None, alignment=None):
         """Add a register.
 
-        Arguments
-        ---------
-        element : :class:`Element`
-            Interface of the register.
-
-        Return value
-        ------------
-        An ``(addr, size)`` tuple, where ``addr`` is the address assigned to the first chunk of
-        the register, and ``size`` is the amount of chunks it takes, which may be greater than
-        ``element.size // self.data_width`` due to alignment.
+        See :meth:`MemoryMap.add_resource` for details.
         """
         if not isinstance(element, Element):
             raise TypeError("Element must be an instance of csr.Element, not {!r}"
                             .format(element))
 
-        addr = self.align_to(self.alignment)
-        self._next_addr += (element.width + self.bus.data_width - 1) // self.bus.data_width
-        size = self.align_to(self.alignment) - addr
-        self._elements[addr] = element, size
-        return addr, size
-
-    def align_to(self, alignment):
-        """Align the next register explicitly.
-
-        Arguments
-        ---------
-        alignment : int
-            Register alignment. The address assigned to the next register will be a multiple of
-            ``2 ** alignment`` or ``2 ** self.alignment``, whichever is greater.
-
-        Return value
-        ------------
-        Address of the next register.
-        """
-        if not isinstance(alignment, int) or alignment < 0:
-            raise ValueError("Alignment must be a non-negative integer, not {!r}"
-                             .format(alignment))
-
-        align_chunks = 1 << alignment
-        if self._next_addr % align_chunks != 0:
-            self._next_addr += align_chunks - (self._next_addr % align_chunks)
-        return self._next_addr
+        size = (element.width + self.bus.data_width - 1) // self.bus.data_width
+        return self._map.add_resource(element, size=size, addr=addr, alignment=alignment)
 
     def elaborate(self, platform):
         m = Module()
@@ -255,7 +226,7 @@ class Decoder(Elaboratable):
         # 2-AND or 2-OR gates.
         r_data_fanin = 0
 
-        for elem_addr, (elem, elem_size) in self._elements.items():
+        for elem, (elem_start, elem_end) in self._map.resources():
             shadow = Signal(elem.width, name="{}__shadow".format(elem.name))
             if elem.access.writable():
                 m.d.comb += elem.w_data.eq(shadow)
@@ -265,15 +236,16 @@ class Decoder(Elaboratable):
             # carry chains for comparisons, even with a constant. (Register sizes don't have
             # to be powers of 2.)
             with m.Switch(self.bus.addr):
-                for chunk_offset in range(elem_size):
-                    chunk_slice = slice(chunk_offset * self.bus.data_width,
-                                        (chunk_offset + 1) * self.bus.data_width)
-                    with m.Case(elem_addr + chunk_offset):
+                for chunk_offset, chunk_addr in enumerate(range(elem_start, elem_end)):
+                    with m.Case(chunk_addr):
+                        shadow_slice = shadow[chunk_offset * self.bus.data_width:
+                                              (chunk_offset + 1) * self.bus.data_width]
+
                         if elem.access.readable():
                             chunk_r_stb = Signal(self.bus.data_width,
                                 name="{}__r_stb_{}".format(elem.name, chunk_offset))
-                            r_data_fanin |= Mux(chunk_r_stb, shadow[chunk_slice], 0)
-                            if chunk_offset == 0:
+                            r_data_fanin |= Mux(chunk_r_stb, shadow_slice, 0)
+                            if chunk_addr == elem_start:
                                 m.d.comb += elem.r_stb.eq(self.bus.r_stb)
                                 with m.If(self.bus.r_stb):
                                     m.d.sync += shadow.eq(elem.r_data)
@@ -281,12 +253,12 @@ class Decoder(Elaboratable):
                             m.d.sync += chunk_r_stb.eq(self.bus.r_stb)
 
                         if elem.access.writable():
-                            if chunk_offset == elem_size - 1:
+                            if chunk_addr == elem_end - 1:
                                 # Delay by 1 cycle, avoiding combinatorial paths through
                                 # the CSR bus and into CSR registers.
                                 m.d.sync += elem.w_stb.eq(self.bus.w_stb)
                             with m.If(self.bus.w_stb):
-                                m.d.sync += shadow[chunk_slice].eq(self.bus.w_data)
+                                m.d.sync += shadow_slice.eq(self.bus.w_data)
 
                 with m.Default():
                     m.d.sync += shadow.eq(0)
index 1426aef7817dd2dc9baffaacac8b133c3d3ad0f5..0c9f8d4917dea8ecbf378f2fac806abccfc244f1 100644 (file)
@@ -86,14 +86,6 @@ class DecoderTestCase(unittest.TestCase):
     def setUp(self):
         self.dut = Decoder(addr_width=16, data_width=8)
 
-    def test_alignment_wrong(self):
-        with self.assertRaisesRegex(ValueError,
-                r"Alignment must be a non-negative integer, not -1"):
-            Decoder(addr_width=16, data_width=8, alignment=-1)
-
-    def test_attrs(self):
-        self.assertEqual(self.dut.alignment, 0)
-
     def test_add_4b(self):
         self.assertEqual(self.dut.add(Element(4, "rw")),
                          (0, 1))
@@ -114,7 +106,7 @@ class DecoderTestCase(unittest.TestCase):
         self.assertEqual(self.dut.add(Element(16, "rw")),
                          (0, 2))
         self.assertEqual(self.dut.add(Element(8, "rw")),
-                         (2, 1))
+                         (2, 3))
 
     def test_add_wrong(self):
         with self.assertRaisesRegex(ValueError,
@@ -126,7 +118,7 @@ class DecoderTestCase(unittest.TestCase):
                          (0, 1))
         self.assertEqual(self.dut.align_to(2), 4)
         self.assertEqual(self.dut.add(Element(8, "rw")),
-                         (4, 1))
+                         (4, 5))
 
     def test_sim(self):
         bus = self.dut.bus
@@ -206,28 +198,25 @@ class DecoderAlignedTestCase(unittest.TestCase):
     def setUp(self):
         self.dut = Decoder(addr_width=16, data_width=8, alignment=2)
 
-    def test_attrs(self):
-        self.assertEqual(self.dut.alignment, 2)
-
     def test_add_two(self):
         self.assertEqual(self.dut.add(Element(8, "rw")),
                          (0, 4))
         self.assertEqual(self.dut.add(Element(16, "rw")),
-                         (4, 4))
+                         (4, 8))
 
     def test_over_align_to(self):
         self.assertEqual(self.dut.add(Element(8, "rw")),
                          (0, 4))
         self.assertEqual(self.dut.align_to(3), 8)
         self.assertEqual(self.dut.add(Element(8, "rw")),
-                         (8, 4))
+                         (8, 12))
 
     def test_under_align_to(self):
         self.assertEqual(self.dut.add(Element(8, "rw")),
                          (0, 4))
         self.assertEqual(self.dut.align_to(1), 4)
         self.assertEqual(self.dut.add(Element(8, "rw")),
-                         (4, 4))
+                         (4, 8))
 
     def test_sim(self):
         bus = self.dut.bus