big reorg, shuffle code to functions, makes the FSM clearer
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 16 Aug 2020 10:11:48 +0000 (11:11 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 16 Aug 2020 10:11:48 +0000 (11:11 +0100)
src/soc/experiment/mmu.py

index 28e2675d95d3049e777a26ee5afefa2878ef27eb..46d9867ce952d0019ab238496880c8ebff4bc8e2 100644 (file)
@@ -6,6 +6,7 @@ based on Anton Blanchard microwatt mmu.vhdl
 from enum import Enum, unique
 from nmigen import (C, Module, Signal, Elaboratable, Mux, Cat, Repl, Signal)
 from nmigen.cli import main
+from nmigen.cli import rtlil
 from nmutil.iocontrol import RecordObject
 from nmutil.byterev import byte_reverse
 
@@ -160,6 +161,171 @@ class MMU(Elaboratable):
         self.d_in  = DcacheToMmuType()
         self.i_out = MmuToIcacheType()
 
+    def radix_tree_idle(self, m, l_in, v):
+        comb = m.d.comb
+        pt_valid = Signal()
+        pgtbl = Signal(64)
+        with m.If(~l_in.addr[63]):
+            comb += pgtbl.eq(r.pgtbl0)
+            comb += pt_valid.eq(r.pt0_valid)
+        with m.Else():
+            comb += pgtbl.eq(r.pt3_valid)
+            comb += pt_valid.eq(r.pt3_valid)
+
+        # rts == radix tree size, number of address bits
+        # being translated
+        rts = Signal(6)
+        comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
+
+        # mbits == number of address bits to index top
+        # level of tree
+        mbits = Signal(6)
+        comb += mbits.eq(pgtbl[0:5])
+
+        # set v.shift to rts so that we can use finalmask
+        # for the segment check
+        comb += v.shift.eq(rts)
+        comb += v.mask_size.eq(mbits[0:5])
+        comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
+
+        with m.If(l_in.valid):
+            comb += v.addr.eq(l_in.addr)
+            comb += v.iside.eq(l_in.iside)
+            comb += v.store.eq(~(l_in.load | l_in.iside))
+
+            with m.If(l_in.tlbie):
+                # Invalidate all iTLB/dTLB entries for
+                # tlbie with RB[IS] != 0 or RB[AP] != 0,
+                # or for slbia
+                comb += v.inval_all.eq(l_in.slbia
+                                       | l_in.addr[11]
+                                       | l_in.addr[10]
+                                       | l_in.addr[7]
+                                       | l_in.addr[6]
+                                       | l_in.addr[5]
+                                      )
+                # The RIC field of the tlbie instruction
+                # comes across on the sprn bus as bits 2--3.
+                # RIC=2 flushes process table caches.
+                with m.If(l_in.sprn[3]):
+                    comb += v.pt0_valid.eq(0)
+                    comb += v.pt3_valid.eq(0)
+                comb += v.state.eq(State.DO_TLBIE)
+            with m.Else():
+                comb += v.valid.eq(1)
+                with m.If(~pt_valid):
+                    # need to fetch process table entry
+                    # set v.shift so we can use finalmask
+                    # for generating the process table
+                    # entry address
+                    comb += v.shift.eq(r.prtble[0:5])
+                    comb += v.state.eq(State.PROC_TBL_READ)
+
+                with m.If(~mbits):
+                    # Use RPDS = 0 to disable radix tree walks
+                    comb += v.state.eq(State.RADIX_FINISH)
+                    comb += v.invalid.eq(1)
+                with m.Else():
+                    comb += v.state.eq(State.SEGMENT_CHECK)
+
+        with m.If(l_in.mtspr):
+            # Move to PID needs to invalidate L1 TLBs
+            # and cached pgtbl0 value.  Move to PRTBL
+            # does that plus invalidating the cached
+            # pgtbl3 value as well.
+            with m.If(~l_in.sprn[9]):
+                comb += v.pid.eq(l_in.rs[0:32])
+            with m.Else():
+                comb += v.prtbl.eq(l_in.rs)
+                comb += v.pt3_valid.eq(0)
+
+            comb += v.pt0_valid.eq(0)
+            comb += v.inval_all.eq(1)
+            comb += v.state.eq(State.DO_TLBIE)
+
+    def proc_tbl_wait(self, m, v, r, data):
+        comb = m.d.comb
+        with m.If(r.addr[63]):
+            comb += v.pgtbl3.eq(data)
+            comb += v.pt3_valid.eq(1)
+        with m.Else():
+            comb += v.pgtbl0.eq(data)
+            comb += v.pt0_valid.eq(1)
+        # rts == radix tree size, # address bits being translated
+        rts = Signal(6)
+        comb += rts.eq(Cat(data[5:8], data[61:63]))
+
+        # mbits == # address bits to index top level of tree
+        mbits = Signal(6)
+        comb += mbits.eq(data[0:5])
+        # set v.shift to rts so that we can use
+        # finalmask for the segment check
+        comb += v.shift.eq(rts)
+        comb += v.mask_size.eq(mbits[0:5])
+        comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
+
+        with m.If(~mbits):
+            comb += v.state.eq(State.RADIX_FINISH)
+            comb += v.invalid.eq(1)
+            comb += v.state.eq(State.SEGMENT_CHECK)
+
+    def radix_read_wait(self, m, v, r, d_in, data):
+        comb = m.d.comb
+        comb += v.pde.eq(data)
+        # test valid bit
+        with m.If(data[63]):
+            with m.If(data[62]):
+                # check permissions and RC bits
+                perm_ok = Signal()
+                comb += perm_ok.eq(0)
+                with m.If(r.priv | ~data[3]):
+                    with m.If(~r.iside):
+                        comb += perm_ok.eq((data[1] | data[2]) & (~r.store))
+                    with m.Else():
+                        # no IAMR, so no KUEP support
+                        # for now deny execute
+                        # permission if cache inhibited
+                        comb += perm_ok.eq(data[0] & ~data[5])
+
+                rc_ok = Signal()
+                comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
+                with m.If(perm_ok & rc_ok):
+                    comb += v.state.eq(State.RADIX_LOAD_TLB)
+                with m.Else():
+                    comb += v.state.eq(State.RADIX_ERROR)
+                    comb += v.perm_err.eq(~perm_ok)
+                    # permission error takes precedence
+                    # over RC error
+                    comb += v.rc_error.eq(perm_ok)
+            with m.Else():
+                mbits = Signal(6)
+                comb += mbits.eq(data[0:5])
+                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
+                    comb += v.state.eq(State.RADIX_FINISH)
+                    comb += v.badtree.eq(1)
+                with m.Else():
+                    comb += v.shift.eq(v.shif - mbits)
+                    comb += v.mask_size.eq(mbits[0:5])
+                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
+                    comb += v.state.eq(State.RADIX_LOOKUP)
+
+    def segment_check(self, m, v, r, data):
+        comb = m.d.comb
+        mbits = Signal(6)
+        nonzero = Signal()
+        comb += mbits.eq(r.mask_size)
+        comb += v.shift.eq(r.shift + (31 - 12) - mbits)
+        comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
+        with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
+            comb += v.state.eq(State.RADIX_FINISH)
+            comb += v.segerror.eq(1)
+        with m.Elif((mbits < 5) | (mbits > 16) |
+                    (mbits > (r.shift + (31-12)))):
+            comb += v.state.eq(State.RADIX_FINISH)
+            comb += v.badtree.eq(1)
+        with m.Else():
+            comb += v.state.eq(State.RADIX_LOOKUP)
+
     def elaborate(self, platform):
         m = Module()
 
@@ -212,22 +378,16 @@ class MMU(Elaboratable):
         sync += r.eq(rin)
 
         v = RegStage()
-        dcrq = Signal()
+        dcreq = Signal()
         tlb_load = Signal()
         itlb_load = Signal()
         tlbie_req = Signal()
         prtbl_rd = Signal()
-        pt_valid = Signal()
         effpid = Signal(32)
         prtable_addr = Signal(64)
-        rts = Signal(6)
-        mbits = Signal(6)
         pgtable_addr = Signal(64)
         pte = Signal(64)
         tlb_data = Signal(64)
-        nonzero = Signal()
-        pgtbl = Signal(64)
-        rc_ok = Signal()
         addr = Signal(64)
 
         comb += v.eq(r)
@@ -252,80 +412,7 @@ class MMU(Elaboratable):
 
         with m.Switch(r.state):
             with m.Case(State.IDLE):
-                with m.If(~l_in.addr[63]):
-                    comb += pgtbl.eq(r.pgtbl0)
-                    comb += pt_valid.eq(r.pt0_valid)
-                with m.Else():
-                    comb += pgtbl.eq(r.pt3_valid)
-                    comb += pt_valid.eq(r.pt3_valid)
-
-                # rts == radix tree size, number of address bits
-                # being translated
-                comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
-
-                # mbits == number of address bits to index top
-                # level of tree
-                comb += mbits.eq(pgtbl[0:5])
-
-                # set v.shift to rts so that we can use finalmask
-                # for the segment check
-                comb += v.shift.eq(rts)
-                comb += v.mask_size.eq(mbits[0:5])
-                comb += v.pgbase.eq(Cat(C(0, 8), pgtbl[8:56]))
-
-                with m.If(l_in.valid):
-                    comb += v.addr.eq(l_in.addr)
-                    comb += v.iside.eq(l_in.iside)
-                    comb += v.store.eq(~(l_in.load | l_in.iside))
-                    with m.If(l_in.tlbie):
-                        # Invalidate all iTLB/dTLB entries for
-                        # tlbie with RB[IS] != 0 or RB[AP] != 0,
-                        # or for slbia
-                        comb += v.inval_all.eq(l_in.slbia
-                                               | l_in.addr[11]
-                                               | l_in.addr[10]
-                                               | l_in.addr[7]
-                                               | l_in.addr[6]
-                                               | l_in.addr[5]
-                                              )
-                        # The RIC field of the tlbie instruction
-                        # comes across on the sprn bus as bits 2--3.
-                        # RIC=2 flushes process table caches.
-                        with m.If(l_in.sprn[3]):
-                            comb += v.pt0_valid.eq(0)
-                            comb += v.pt3_valid.eq(0)
-                        comb += v.state.eq(State.DO_TLBIE)
-                    with m.Else():
-                        comb += v.valid.eq(1)
-                        with m.If(~pt_valid):
-                            # need to fetch process table entry
-                            # set v.shift so we can use finalmask
-                            # for generating the process table
-                            # entry address
-                            comb += v.shift.eq(r.prtble[0:5])
-                            comb += v.state.eq(State.PROC_TBL_READ)
-
-                        with m.If(~mbits):
-                            # Use RPDS = 0 to disable radix tree walks
-                            comb += v.state.eq(State.RADIX_FINISH)
-                            comb += v.invalid.eq(1)
-                        with m.Else():
-                            comb += v.state.eq(State.SEGMENT_CHECK)
-
-                with m.If(l_in.mtspr):
-                    # Move to PID needs to invalidate L1 TLBs
-                    # and cached pgtbl0 value.  Move to PRTBL
-                    # does that plus invalidating the cached
-                    # pgtbl3 value as well.
-                    with m.If(~l_in.sprn[9]):
-                        comb += v.pid.eq(l_in.rs[0:32])
-                    with m.Else():
-                        comb += v.prtbl.eq(l_in.rs)
-                        comb += v.pt3_valid.eq(0)
-
-                    comb += v.pt0_valid.eq(0)
-                    comb += v.inval_all.eq(1)
-                    comb += v.state.eq(State.DO_TLBIE)
+                self.radix_tree_idle(m, l_in, v)
 
             with m.Case(State.DO_TLBIE):
                 comb += dcreq.eq(1)
@@ -343,45 +430,14 @@ class MMU(Elaboratable):
 
             with m.Case(State.PROC_TBL_WAIT):
                 with m.If(d_in.done):
-                    with m.If(r.addr[63]):
-                        comb += v.pgtbl3.eq(data)
-                        comb += v.pt3_valid.eq(1)
-                    with m.Else():
-                        comb += v.pgtbl0.eq(data)
-                        comb += v.pt0_valid.eq(1)
-                    # rts == radix tree size, # address bits being translated
-                    comb += rts.eq(Cat(data[5:8], data[61:63]))
-
-                    # mbits == # address bits to index top level of tree
-                    comb += mbits.eq(data[0:5])
-                    # set v.shift to rts so that we can use
-                    # finalmask for the segment check
-                    comb += v.shift.eq(rts)
-                    comb += v.mask_size.eq(mbits[0:5])
-                    comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
-
-                    with m.If(~mbits):
-                        comb += v.state.eq(State.RADIX_FINISH)
-                        comb += v.invalid.eq(1)
-                        comb += v.state.eq(State.SEGMENT_CHECK)
+                    self.proc_tbl_wait(m, v, r, data)
 
                 with m.If(d_in.err):
                     comb += v.state.eq(State.RADIX_FINISH)
                     comb += v.badtree.eq(1)
 
             with m.Case(State.SEGMENT_CHECK):
-                comb += mbits.eq(r.mask_size)
-                comb += v.shift.eq(r.shift + (31 - 12) - mbits)
-                comb += nonzero.eq((r.addr[31:62] & ~finalmask[0:31]).bool())
-                with m.If((r.addr[63] ^ r.addr[62]) | nonzero):
-                    comb += v.state.eq(State.RADIX_FINISH)
-                    comb += v.segerror.eq(1)
-                with m.Elif((mbits < 5) | (mbits > 16)
-                          | (mbits > (r.shift + (31-12)))):
-                    comb += v.state.eq(State.RADIX_FINISH)
-                    comb += v.badtree.eq(1)
-                with m.Else():
-                    comb += v.state.eq(State.RADIX_LOOKUP)
+                self.segment_check(m, v, r, data)
 
             with m.Case(State.RADIX_LOOKUP):
                 comb += dcreq.eq(1)
@@ -389,47 +445,11 @@ class MMU(Elaboratable):
 
             with m.Case(State.RADIX_READ_WAIT):
                 with m.If(d_in.done):
-                    comb += v.pde.eq(data)
-                    # test valid bit
-                    with m.If(data[63]):
-                        with m.If(data[62]):
-                            # check permissions and RC bits
-                            perm_ok = Signal()
-                            comb += perm_ok.eq(0)
-                            with m.If(r.priv | ~data[3]):
-                                with m.If(~r.iside):
-                                    comb += perm_ok.eq((data[1] | data[2]) &
-                                                       (~r.store))
-                                with m.Else():
-                                    # no IAMR, so no KUEP support
-                                    # for now deny execute
-                                    # permission if cache inhibited
-                                    comb += perm_ok.eq(data[0] & ~data[5])
-
-                            comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
-                            with m.If(perm_ok & rc_ok):
-                                comb += v.state.eq(State.RADIX_LOAD_TLB)
-                            with m.Else():
-                                comb += v.state.eq(State.RADIX_ERROR)
-                                comb += v.perm_err.eq(~perm_ok)
-                                # permission error takes precedence
-                                # over RC error
-                                comb += v.rc_error.eq(perm_ok)
-                        with m.Else():
-                            comb += mbits.eq(data[0:5])
-                            with m.If((mbits < 5) | (mbits > 16) |
-                                      (mbits > r.shift)):
-                                comb += v.state.eq(State.RADIX_FINISH)
-                                comb += v.badtree.eq(1)
-                            with m.Else():
-                                comb += v.shift.eq(v.shif - mbits)
-                                comb += v.mask_size.eq(mbits[0:5])
-                                comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
-                                comb += v.state.eq(State.RADIX_LOOKUP)
-                    with m.Else():
-                        # non-present PTE, generate a DSI
-                        comb += v.state.eq(State.RADIX_FINISH)
-                        comb += v.invalid.eq(1)
+                    self.radix_read_wait(m, v, r, d_in, data)
+                with m.Else():
+                    # non-present PTE, generate a DSI
+                    comb += v.state.eq(State.RADIX_FINISH)
+                    comb += v.invalid.eq(1)
 
                 with m.If(d_in.err):
                     comb += v.state.eq(State.RADIX_FINISH)
@@ -447,8 +467,8 @@ class MMU(Elaboratable):
             with m.Case(State.RADIX_FINISH):
                 comb += v.state.eq(State.IDLE)
 
-        with m.If((v.state == State.RADIX_FINISH)
-                  | ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
+        with m.If((v.state == State.RADIX_FINISH) |
+                 ((v.state == State.RADIX_LOAD_TLB) & r.iside)):
             comb += v.err.eq(v.invalid | v.badtree | v.segerror
                              | v.perm_err | v.rc_error)
             comb += v.done.eq(~v.err)
@@ -502,7 +522,7 @@ class MMU(Elaboratable):
         comb += d_out.valid.eq(dcreq)
         comb += d_out.tlbie.eq(tlbie_req)
         comb += d_out.doall.eq(r.inval_all)
-        comb += d_out.tlbld.eeq(tlb_load)
+        comb += d_out.tlbld.eq(tlb_load)
         comb += d_out.addr.eq(addr)
         comb += d_out.pte.eq(tlb_data)
 
@@ -549,9 +569,7 @@ def mmu_sim():
 
 def test_mmu():
     dut = MMU()
-    rp = dut.read_port()
-    wp = dut.write_port()
-    vl = rtlil.convert(dut, ports=dut.ports())
+    vl = rtlil.convert(dut, ports=[])#dut.ports())
     with open("test_mmu.il", "w") as f:
         f.write(vl)