csr.bus: use proper enum instead of ad-hoc string enumeration.
authorwhitequark <whitequark@whitequark.org>
Fri, 25 Oct 2019 10:54:49 +0000 (10:54 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 25 Oct 2019 10:54:49 +0000 (10:54 +0000)
nmigen_soc/csr/bus.py
nmigen_soc/test/test_csr_bus.py

index df8ecabfb0c9a3eb21b11fd348b492046b696180..f65f7e8a82705a6b08dbbbce23d7e8184243cd22 100644 (file)
@@ -1,12 +1,27 @@
-from functools import reduce
+import enum
 from nmigen import *
-from nmigen import tracer
 
 
 __all__ = ["Element", "Interface", "Decoder"]
 
 
 class Element(Record):
+    class Access(enum.Enum):
+        """Register access mode.
+
+        Coarse access mode for the entire register. Individual fields can have more restrictive
+        access mode, e.g. R/O fields can be a part of an R/W register.
+        """
+        R  = "r"
+        W  = "w"
+        RW = "rw"
+
+        def readable(self):
+            return self == self.R or self == self.RW
+
+        def writable(self):
+            return self == self.W or self == self.RW
+
     """Peripheral-side CSR interface.
 
     A low-level interface to a single atomically readable and writable register in a peripheral.
@@ -17,6 +32,8 @@ class Element(Record):
     ----------
     width : int
         Width of the register.
+    access : :class:`Access`
+        Register access mode.
     name : str
         Name of the underlying record.
 
@@ -37,19 +54,19 @@ class Element(Record):
         if not isinstance(width, int) or width < 0:
             raise ValueError("Width must be a non-negative integer, not {!r}"
                              .format(width))
-        if access not in ("r", "w", "rw"):
+        if not isinstance(access, Element.Access) and access not in ("r", "w", "rw"):
             raise ValueError("Access mode must be one of \"r\", \"w\", or \"rw\", not {!r}"
                              .format(access))
         self.width  = width
-        self.access = access
+        self.access = Element.Access(access)
 
         layout = []
-        if "r" in self.access:
+        if self.access.readable():
             layout += [
                 ("r_data", width),
                 ("r_stb",  1),
             ]
-        if "w" in self.access:
+        if self.access.writable():
             layout += [
                 ("w_data", width),
                 ("w_stb",  1),
@@ -240,7 +257,7 @@ class Decoder(Elaboratable):
 
         for elem_addr, (elem, elem_size) in self._elements.items():
             shadow = Signal(elem.width, name="{}__shadow".format(elem.name))
-            if "w" in elem.access:
+            if elem.access.writable():
                 m.d.comb += elem.w_data.eq(shadow)
 
             # Enumerate every address used by the register explicitly, rather than using
@@ -252,7 +269,7 @@ class Decoder(Elaboratable):
                     chunk_slice = slice(chunk_offset * self.bus.data_width,
                                         (chunk_offset + 1) * self.bus.data_width)
                     with m.Case(elem_addr + chunk_offset):
-                        if "r" in elem.access:
+                        if elem.access.readable():
                             chunk_r_stb = Signal(self.bus.data_width,
                                 name="{}__r_stb_{}".format(elem.name, chunk_offset))
                             r_data_fanin |= Mux(chunk_r_stb, shadow[chunk_slice], 0)
@@ -263,7 +280,7 @@ class Decoder(Elaboratable):
                             # Delay by 1 cycle, allowing reads to be pipelined.
                             m.d.sync += chunk_r_stb.eq(self.bus.r_stb)
 
-                        if "w" in elem.access:
+                        if elem.access.writable():
                             if chunk_offset == elem_size - 1:
                                 # Delay by 1 cycle, avoiding combinatorial paths through
                                 # the CSR bus and into CSR registers.
index 59b9d14017b00bfa8b451a792beecd8a8bbe91f2..1426aef7817dd2dc9baffaacac8b133c3d3ad0f5 100644 (file)
@@ -10,7 +10,7 @@ class ElementTestCase(unittest.TestCase):
     def test_layout_1_ro(self):
         elem = Element(1, "r")
         self.assertEqual(elem.width, 1)
-        self.assertEqual(elem.access, "r")
+        self.assertEqual(elem.access, Element.Access.R)
         self.assertEqual(elem.layout, Layout.cast([
             ("r_data", 1),
             ("r_stb", 1),
@@ -19,7 +19,7 @@ class ElementTestCase(unittest.TestCase):
     def test_layout_8_rw(self):
         elem = Element(8, access="rw")
         self.assertEqual(elem.width, 8)
-        self.assertEqual(elem.access, "rw")
+        self.assertEqual(elem.access, Element.Access.RW)
         self.assertEqual(elem.layout, Layout.cast([
             ("r_data", 8),
             ("r_stb", 1),
@@ -30,16 +30,16 @@ class ElementTestCase(unittest.TestCase):
     def test_layout_10_wo(self):
         elem = Element(10, "w")
         self.assertEqual(elem.width, 10)
-        self.assertEqual(elem.access, "w")
+        self.assertEqual(elem.access, Element.Access.W)
         self.assertEqual(elem.layout, Layout.cast([
             ("w_data", 10),
             ("w_stb", 1),
         ]))
 
     def test_layout_0_rw(self): # degenerate but legal case
-        elem = Element(0, access="rw")
+        elem = Element(0, access=Element.Access.RW)
         self.assertEqual(elem.width, 0)
-        self.assertEqual(elem.access, "rw")
+        self.assertEqual(elem.access, Element.Access.RW)
         self.assertEqual(elem.layout, Layout.cast([
             ("r_data", 0),
             ("r_stb", 1),