icache.py fix several subtle bugs that were lines that I had missed from
[soc.git] / src / soc / experiment / mmu.py
index 1daf26fa871afaa381ee04e84647e3b4e8e442d2..a8c514f7d6e1816a132a804369ebd24596cfeee2 100644 (file)
@@ -1,3 +1,11 @@
+# MMU
+#
+# License for original copyright mmu.vhdl by microwatt authors: CC4
+# License for copyrighted modifications made in mmu.py: LGPLv3+
+#
+# This derivative work although includes CC4 licensed material is
+# covered by the LGPLv3+
+
 """MMU
 
 based on Anton Blanchard microwatt mmu.vhdl
@@ -9,17 +17,21 @@ from nmigen.cli import main
 from nmigen.cli import rtlil
 from nmutil.iocontrol import RecordObject
 from nmutil.byterev import byte_reverse
+from nmutil.mask import Mask, masked
+from nmutil.util import Display
+
+if True:
+    from nmigen.back.pysim import Simulator, Delay, Settle
+else:
+    from nmigen.sim.cxxsim import Simulator, Delay, Settle
+from nmutil.util import wrap
 
-from soc.experiment.mem_types import (LoadStore1ToMmuType,
-                                 MmuToLoadStore1Type,
-                                 MmuToDcacheType,
-                                 DcacheToMmuType,
-                                 MmuToIcacheType)
+from soc.experiment.mem_types import (LoadStore1ToMMUType,
+                                 MMUToLoadStore1Type,
+                                 MMUToDCacheType,
+                                 DCacheToMMUType,
+                                 MMUToICacheType)
 
-# -- Radix MMU
-# -- Supports 4-level trees as in arch 3.0B, but not the
-# -- two-step translation
-# -- for guests under a hypervisor (i.e. there is no gRA -> hRA translation).
 
 @unique
 class State(Enum):
@@ -75,31 +87,34 @@ class MMU(Elaboratable):
     (i.e. there is no gRA -> hRA translation).
     """
     def __init__(self):
-        self.l_in  = LoadStore1ToMmuType()
-        self.l_out = MmuToLoadStore1Type()
-        self.d_out = MmuToDcacheType()
-        self.d_in  = DcacheToMmuType()
-        self.i_out = MmuToIcacheType()
+        self.l_in  = LoadStore1ToMMUType()
+        self.l_out = MMUToLoadStore1Type()
+        self.d_out = MMUToDCacheType()
+        self.d_in  = DCacheToMMUType()
+        self.i_out = MMUToICacheType()
 
     def radix_tree_idle(self, m, l_in, r, v):
         comb = m.d.comb
+        sync = m.d.sync
+
         pt_valid = Signal()
         pgtbl = Signal(64)
+        rts = Signal(6)
+        mbits = Signal(6)
+
         with m.If(~l_in.addr[63]):
             comb += pgtbl.eq(r.pgtbl0)
             comb += pt_valid.eq(r.pt0_valid)
         with m.Else():
-            comb += pgtbl.eq(r.pt3_valid)
+            comb += pgtbl.eq(r.pgtbl3)
             comb += pt_valid.eq(r.pt3_valid)
 
         # rts == radix tree size, number of address bits
         # being translated
-        rts = Signal(6)
         comb += rts.eq(Cat(pgtbl[5:8], pgtbl[61:63]))
 
         # mbits == number of address bits to index top
         # level of tree
-        mbits = Signal(6)
         comb += mbits.eq(pgtbl[0:5])
 
         # set v.shift to rts so that we can use finalmask
@@ -112,6 +127,12 @@ class MMU(Elaboratable):
             comb += v.addr.eq(l_in.addr)
             comb += v.iside.eq(l_in.iside)
             comb += v.store.eq(~(l_in.load | l_in.iside))
+            comb += v.priv.eq(l_in.priv)
+
+            comb += Display("state %d l_in.valid addr %x iside %d store %d "
+                            "rts %x mbits %x pt_valid %d",
+                            v.state, v.addr, v.iside, v.store,
+                            rts, mbits, pt_valid)
 
             with m.If(l_in.tlbie):
                 # Invalidate all iTLB/dTLB entries for
@@ -141,7 +162,7 @@ class MMU(Elaboratable):
                     comb += v.shift.eq(r.prtbl[0:5])
                     comb += v.state.eq(State.PROC_TBL_READ)
 
-                with m.If(~mbits):
+                with m.Elif(mbits == 0):
                     # Use RPDS = 0 to disable radix tree walks
                     comb += v.state.eq(State.RADIX_FINISH)
                     comb += v.invalid.eq(1)
@@ -171,66 +192,96 @@ class MMU(Elaboratable):
         with m.Else():
             comb += v.pgtbl0.eq(data)
             comb += v.pt0_valid.eq(1)
-        # rts == radix tree size, # address bits being translated
+
         rts = Signal(6)
+        mbits = Signal(6)
+
+        # rts == radix tree size, # address bits being translated
         comb += rts.eq(Cat(data[5:8], data[61:63]))
 
         # mbits == # address bits to index top level of tree
-        mbits = Signal(6)
         comb += mbits.eq(data[0:5])
-        # set v.shift to rts so that we can use
-        # finalmask for the segment check
+
+        # set v.shift to rts so that we can use finalmask for the segment check
         comb += v.shift.eq(rts)
         comb += v.mask_size.eq(mbits[0:5])
         comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
 
-        with m.If(~mbits):
+        with m.If(mbits):
+            comb += v.state.eq(State.SEGMENT_CHECK)
+        with m.Else():
             comb += v.state.eq(State.RADIX_FINISH)
             comb += v.invalid.eq(1)
-            comb += v.state.eq(State.SEGMENT_CHECK)
 
     def radix_read_wait(self, m, v, r, d_in, data):
         comb = m.d.comb
+        sync = m.d.sync
+
+        perm_ok = Signal()
+        rc_ok = Signal()
+        mbits = Signal(6)
+        valid = Signal()
+        leaf = Signal()
+        badtree = Signal()
+
+        comb += Display("RDW %016x done %d "
+                        "perm %d rc %d mbits %d shf %d "
+                        "valid %d leaf %d bad %d",
+                        data, d_in.done, perm_ok, rc_ok,
+                        mbits, r.shift, valid, leaf, badtree)
+
+        # set pde
         comb += v.pde.eq(data)
+
         # test valid bit
-        with m.If(data[63]):
-            with m.If(data[62]):
+        comb += valid.eq(data[63]) # valid=data[63]
+        comb += leaf.eq(data[62]) # valid=data[63]
+
+        comb += v.pde.eq(data)
+        # valid & leaf
+        with m.If(valid):
+            with m.If(leaf):
                 # check permissions and RC bits
-                perm_ok = Signal()
-                comb += perm_ok.eq(0)
                 with m.If(r.priv | ~data[3]):
                     with m.If(~r.iside):
-                        comb += perm_ok.eq((data[1] | data[2]) & (~r.store))
+                        comb += perm_ok.eq(data[1] | (data[2] & ~r.store))
                     with m.Else():
-                        # no IAMR, so no KUEP support
-                        # for now deny execute
-                        # permission if cache inhibited
+                        # no IAMR, so no KUEP support for now
+                        # deny execute permission if cache inhibited
                         comb += perm_ok.eq(data[0] & ~data[5])
 
-                rc_ok = Signal()
-                comb += rc_ok.eq(data[8] & (data[7] | (~r.store)))
+                comb += rc_ok.eq(data[8] & (data[7] | ~r.store))
                 with m.If(perm_ok & rc_ok):
                     comb += v.state.eq(State.RADIX_LOAD_TLB)
                 with m.Else():
                     comb += v.state.eq(State.RADIX_FINISH)
                     comb += v.perm_err.eq(~perm_ok)
-                    # permission error takes precedence
-                    # over RC error
+                    # permission error takes precedence over RC error
                     comb += v.rc_error.eq(perm_ok)
+
+            # valid & !leaf
             with m.Else():
-                mbits = Signal(6)
                 comb += mbits.eq(data[0:5])
-                with m.If((mbits < 5) | (mbits > 16) | (mbits > r.shift)):
+                comb += badtree.eq((mbits < 5) |
+                                   (mbits > 16) |
+                                   (mbits > r.shift))
+                with m.If(badtree):
                     comb += v.state.eq(State.RADIX_FINISH)
                     comb += v.badtree.eq(1)
                 with m.Else():
-                    comb += v.shift.eq(v.shift - mbits)
+                    comb += v.shift.eq(r.shift - mbits)
                     comb += v.mask_size.eq(mbits[0:5])
                     comb += v.pgbase.eq(Cat(C(0, 8), data[8:56]))
                     comb += v.state.eq(State.RADIX_LOOKUP)
 
+        with m.Else():
+            # non-present PTE, generate a DSI
+            comb += v.state.eq(State.RADIX_FINISH)
+            comb += v.invalid.eq(1)
+
     def segment_check(self, m, v, r, data, finalmask):
         comb = m.d.comb
+
         mbits = Signal(6)
         nonzero = Signal()
         comb += mbits.eq(r.mask_size)
@@ -246,25 +297,10 @@ class MMU(Elaboratable):
         with m.Else():
             comb += v.state.eq(State.RADIX_LOOKUP)
 
-    def elaborate(self, platform):
-        m = Module()
-
+    def mmu_0(self, m, r, rin, l_in, l_out, d_out, addrsh, mask):
         comb = m.d.comb
         sync = m.d.sync
 
-        addrsh = Signal(16)
-        mask = Signal(16)
-        finalmask = Signal(44)
-
-        r = RegStage("r")
-        rin = RegStage("r_in")
-
-        l_in  = self.l_in
-        l_out = self.l_out
-        d_out = self.d_out
-        d_in  = self.d_in
-        i_out = self.i_out
-
         # Multiplex internal SPR values back to loadstore1,
         # selected by l_in.sprn.
         with m.If(l_in.sprn[9]):
@@ -273,30 +309,45 @@ class MMU(Elaboratable):
             comb += l_out.sprval.eq(r.pid)
 
         with m.If(rin.valid):
-            pass
-            #sync += Display(f"MMU got tlb miss for {rin.addr}")
+            sync += Display("MMU got tlb miss for %x", rin.addr)
 
         with m.If(l_out.done):
-            pass
-            # sync += Display("MMU completing op without error")
+            sync += Display("MMU completing op without error")
 
         with m.If(l_out.err):
-            pass
-            # sync += Display(f"MMU completing op with err invalid"
-            #                 "{l_out.invalid} badtree={l_out.badtree}")
+            sync += Display("MMU completing op with err invalid"
+                            "%d badtree=%d", l_out.invalid, l_out.badtree)
 
         with m.If(rin.state == State.RADIX_LOOKUP):
-            pass
-            # sync += Display (f"radix lookup shift={rin.shift}"
-            #          "msize={rin.mask_size}")
+            sync += Display ("radix lookup shift=%d msize=%d",
+                            rin.shift, rin.mask_size)
 
         with m.If(r.state == State.RADIX_LOOKUP):
-            pass
-            # sync += Display(f"send load addr={d_out.addr}"
-            #           "addrsh={addrsh} mask={mask}")
-
+            sync += Display(f"send load addr=%x addrsh=%d mask=%x",
+                            d_out.addr, addrsh, mask)
         sync += r.eq(rin)
 
+    def elaborate(self, platform):
+        m = Module()
+
+        comb = m.d.comb
+        sync = m.d.sync
+
+        addrsh = Signal(16)
+        mask = Signal(16)
+        finalmask = Signal(44)
+
+        self.rin = rin = RegStage("r_in")
+        r = RegStage("r")
+
+        l_in  = self.l_in
+        l_out = self.l_out
+        d_out = self.d_out
+        d_in  = self.d_in
+        i_out = self.i_out
+
+        self.mmu_0(m, r, rin, l_in, l_out, d_out, addrsh, mask)
+
         v = RegStage()
         dcreq = Signal()
         tlb_load = Signal()
@@ -304,8 +355,8 @@ class MMU(Elaboratable):
         tlbie_req = Signal()
         prtbl_rd = Signal()
         effpid = Signal(32)
-        prtable_addr = Signal(64)
-        pgtable_addr = Signal(64)
+        prtb_adr = Signal(64)
+        pgtb_adr = Signal(64)
         pte = Signal(64)
         tlb_data = Signal(64)
         addr = Signal(64)
@@ -331,11 +382,18 @@ class MMU(Elaboratable):
         data = byte_reverse(m, "data", d_in.data, 8)
 
         # generate mask for extracting address fields for PTE addr generation
-        comb += mask.eq(Cat(C(0x1f,5), ((1<<r.mask_size)-1)))
+        m.submodules.pte_mask = pte_mask = Mask(16-5)
+        comb += pte_mask.shift.eq(r.mask_size - 5)
+        comb += mask.eq(Cat(C(0x1f, 5), pte_mask.mask))
 
         # generate mask for extracting address bits to go in
         # TLB entry in order to support pages > 4kB
-        comb += finalmask.eq(((1<<r.shift)-1))
+        m.submodules.tlb_mask = tlb_mask = Mask(44)
+        comb += tlb_mask.shift.eq(r.shift)
+        comb += finalmask.eq(tlb_mask.mask)
+
+        with m.If(r.state != State.IDLE):
+            sync += Display("MMU state %d %016x", r.state, data)
 
         with m.Switch(r.state):
             with m.Case(State.IDLE):
@@ -351,6 +409,7 @@ class MMU(Elaboratable):
                     comb += v.state.eq(State.RADIX_FINISH)
 
             with m.Case(State.PROC_TBL_READ):
+                sync += Display("   TBL_READ %016x", prtb_adr)
                 comb += dcreq.eq(1)
                 comb += prtbl_rd.eq(1)
                 comb += v.state.eq(State.PROC_TBL_WAIT)
@@ -367,17 +426,14 @@ class MMU(Elaboratable):
                 self.segment_check(m, v, r, data, finalmask)
 
             with m.Case(State.RADIX_LOOKUP):
+                sync += Display("   RADIX_LOOKUP")
                 comb += dcreq.eq(1)
                 comb += v.state.eq(State.RADIX_READ_WAIT)
 
             with m.Case(State.RADIX_READ_WAIT):
+                sync += Display("   READ_WAIT")
                 with m.If(d_in.done):
                     self.radix_read_wait(m, v, r, d_in, data)
-                with m.Else():
-                    # non-present PTE, generate a DSI
-                    comb += v.state.eq(State.RADIX_FINISH)
-                    comb += v.invalid.eq(1)
-
                 with m.If(d_in.err):
                     comb += v.state.eq(State.RADIX_FINISH)
                     comb += v.badtree.eq(1)
@@ -392,6 +448,7 @@ class MMU(Elaboratable):
                     comb += v.state.eq(State.IDLE)
 
             with m.Case(State.RADIX_FINISH):
+                sync += Display("   RADIX_FINISH")
                 comb += v.state.eq(State.IDLE)
 
         with m.If((v.state == State.RADIX_FINISH) |
@@ -403,29 +460,20 @@ class MMU(Elaboratable):
         with m.If(~r.addr[63]):
             comb += effpid.eq(r.pid)
 
-        comb += prtable_addr.eq(Cat(
-                                 C(0b0000, 4),
-                                 effpid[0:8],
-                                 (r.prtbl[12:36] & ~finalmask[0:24]) |
-                                 (effpid[8:32]   &  finalmask[0:24]),
-                                 r.prtbl[36:56]
-                                ))
-
-        comb += pgtable_addr.eq(Cat(
-                                 C(0b000, 3),
-                                 (r.pgbase[3:19] & ~mask) |
-                                 (addrsh         &  mask),
-                                 r.pgbase[19:56]
-                                ))
-
-        comb += pte.eq(Cat(
-                         r.pde[0:12],
-                          (r.pde[12:56]    & ~finalmask) |
-                          (r.addr[12:56] &  finalmask),
-                        ))
+        pr24 = Signal(24, reset_less=True)
+        comb += pr24.eq(masked(r.prtbl[12:36], effpid[8:32], finalmask))
+        comb += prtb_adr.eq(Cat(C(0, 4), effpid[0:8], pr24, r.prtbl[36:56]))
+
+        pg16 = Signal(16, reset_less=True)
+        comb += pg16.eq(masked(r.pgbase[3:19], addrsh, mask))
+        comb += pgtb_adr.eq(Cat(C(0, 3), pg16, r.pgbase[19:56]))
+
+        pd44 = Signal(44, reset_less=True)
+        comb += pd44.eq(masked(r.pde[12:56], r.addr[12:56], finalmask))
+        comb += pte.eq(Cat(r.pde[0:12], pd44))
 
         # update registers
-        rin.eq(v)
+        comb += rin.eq(v)
 
         # drive outputs
         with m.If(tlbie_req):
@@ -434,9 +482,9 @@ class MMU(Elaboratable):
             comb += addr.eq(Cat(C(0, 12), r.addr[12:64]))
             comb += tlb_data.eq(pte)
         with m.Elif(prtbl_rd):
-            comb += addr.eq(prtable_addr)
+            comb += addr.eq(prtb_adr)
         with m.Else():
-            comb += addr.eq(pgtable_addr)
+            comb += addr.eq(pgtb_adr)
 
         comb += l_out.done.eq(r.done)
         comb += l_out.err.eq(r.err)
@@ -461,40 +509,121 @@ class MMU(Elaboratable):
 
         return m
 
+stop = False
+
+def dcache_get(dut):
+    """simulator process for getting memory load requests
+    """
 
-def mmu_sim():
-    yield wp.waddr.eq(1)
-    yield wp.data_i.eq(2)
-    yield wp.wen.eq(1)
+    global stop
+
+    def b(x):
+        return int.from_bytes(x.to_bytes(8, byteorder='little'),
+                              byteorder='big', signed=False)
+
+    mem = {0x0: 0x000000, # to get mtspr prtbl working
+
+           0x10000:    # PARTITION_TABLE_2
+                       # PATB_GR=1 PRTB=0x1000 PRTS=0xb
+           b(0x800000000100000b),
+
+           0x30000:     # RADIX_ROOT_PTE
+                        # V = 1 L = 0 NLB = 0x400 NLS = 9
+           b(0x8000000000040009),
+
+           0x40000:     # RADIX_SECOND_LEVEL
+                        #         V = 1 L = 1 SW = 0 RPN = 0
+                           # R = 1 C = 1 ATT = 0 EAA 0x7
+           b(0xc000000000000187),
+
+          0x1000000:   # PROCESS_TABLE_3
+                       # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
+           b(0x40000000000300ad),
+          }
+
+    while not stop:
+        while True: # wait for dc_valid
+            if stop:
+                return
+            dc_valid = yield (dut.d_out.valid)
+            if dc_valid:
+                break
+            yield
+        addr = yield dut.d_out.addr
+        if addr not in mem:
+            print ("    DCACHE LOOKUP FAIL %x" % (addr))
+            stop = True
+            return
+
+        yield
+        data = mem[addr]
+        yield dut.d_in.data.eq(data)
+        print ("    DCACHE GET %x data %x" % (addr, data))
+        yield dut.d_in.done.eq(1)
+        yield
+        yield dut.d_in.done.eq(0)
+
+def mmu_wait(dut):
+    global stop
+    while not stop: # wait for dc_valid / err
+        l_done = yield (dut.l_out.done)
+        l_err = yield (dut.l_out.err)
+        l_badtree = yield (dut.l_out.badtree)
+        l_permerr = yield (dut.l_out.perm_error)
+        l_rc_err = yield (dut.l_out.rc_error)
+        l_segerr = yield (dut.l_out.segerr)
+        l_invalid = yield (dut.l_out.invalid)
+        if (l_done or l_err or l_badtree or
+            l_permerr or l_rc_err or l_segerr or l_invalid):
+            break
+        yield
+        yield dut.l_in.valid.eq(0) # data already in MMU by now
+        yield dut.l_in.mtspr.eq(0) # captured by RegStage(s)
+        yield dut.l_in.load.eq(0)  # can reset everything safely
+
+def mmu_sim(dut):
+    global stop
+
+    # MMU MTSPR set prtbl
+    yield dut.l_in.mtspr.eq(1)
+    yield dut.l_in.sprn[9].eq(1) # totally fake way to set SPR=prtbl
+    yield dut.l_in.rs.eq(0x1000000) # set process table
+    yield dut.l_in.valid.eq(1)
+    yield from mmu_wait(dut)
     yield
-    yield wp.wen.eq(0)
-    yield rp.ren.eq(1)
-    yield rp.raddr.eq(1)
-    yield Settle()
-    data = yield rp.data_o
-    print(data)
-    assert data == 2
+    yield dut.l_in.sprn.eq(0)
+    yield dut.l_in.rs.eq(0)
     yield
 
-    yield wp.waddr.eq(5)
-    yield rp.raddr.eq(5)
-    yield rp.ren.eq(1)
-    yield wp.wen.eq(1)
-    yield wp.data_i.eq(6)
-    yield Settle()
-    data = yield rp.data_o
-    print(data)
-    assert data == 6
-    yield
-    yield wp.wen.eq(0)
-    yield rp.ren.eq(0)
-    yield Settle()
-    data = yield rp.data_o
-    print(data)
-    assert data == 0
+    prtbl = yield (dut.rin.prtbl)
+    print ("prtbl after MTSPR %x" % prtbl)
+    assert prtbl == 0x1000000
+
+    #yield dut.rin.prtbl.eq(0x1000000) # manually set process table
+    #yield
+
+
+    # MMU PTE request
+    yield dut.l_in.load.eq(1)
+    yield dut.l_in.priv.eq(1)
+    yield dut.l_in.addr.eq(0x10000)
+    yield dut.l_in.valid.eq(1)
+    yield from mmu_wait(dut)
+
+    addr = yield dut.d_out.addr
+    pte = yield dut.d_out.pte
+    l_done = yield (dut.l_out.done)
+    l_err = yield (dut.l_out.err)
+    l_badtree = yield (dut.l_out.badtree)
+    print ("translated done %d err %d badtree %d addr %x pte %x" % \
+               (l_done, l_err, l_badtree, addr, pte))
     yield
-    data = yield rp.data_o
-    print(data)
+    yield dut.l_in.priv.eq(0)
+    yield dut.l_in.addr.eq(0)
+
+
+    stop = True
+
 
 def test_mmu():
     dut = MMU()
@@ -502,7 +631,17 @@ def test_mmu():
     with open("test_mmu.il", "w") as f:
         f.write(vl)
 
-    run_simulation(dut, mmu_sim(), vcd_name='test_mmu.vcd')
+    m = Module()
+    m.submodules.mmu = dut
+
+    # nmigen Simulation
+    sim = Simulator(m)
+    sim.add_clock(1e-6)
+
+    sim.add_sync_process(wrap(mmu_sim(dut)))
+    sim.add_sync_process(wrap(dcache_get(dut)))
+    with sim.write_vcd('test_mmu.vcd'):
+        sim.run()
 
 if __name__ == '__main__':
     test_mmu()