radixmmu: read prtable entry
[soc.git] / src / soc / decoder / isa / radixmmu.py
index d86a9e1e40e507225731f4efdcb864a69e1ba2bd..c49a93877808854aade8346e1ddaff82eb65576d 100644 (file)
@@ -19,9 +19,11 @@ from soc.decoder.selectable_int import (FieldSelectableInt, SelectableInt,
                                         selectconcat)
 from soc.decoder.helpers import exts, gtu, ltu, undefined
 from soc.decoder.isa.mem import Mem
+from soc.consts import MSRb  # big-endian (PowerISA versions)
 
 import math
 import sys
+import unittest
 
 # very quick, TODO move to SelectableInt utils later
 def genmask(shift, size):
@@ -43,6 +45,12 @@ def rpte_valid(r):
 def rpte_leaf(r):
     return bool(r[1])
 
+## Shift address bits 61--12 right by 0--47 bits and
+## supply the least significant 16 bits of the result.
+def addrshift(addr,shift):
+    x = addr.value >> shift.value
+    return SelectableInt(x,16)
+
 def NLB(x):
     """
     Next Level Base
@@ -190,16 +198,66 @@ def NLS(x):
 
 """
 
+testmem = {
+
+           0x10000:    # PARTITION_TABLE_2 (not implemented yet)
+                       # PATB_GR=1 PRTB=0x1000 PRTS=0xb
+           0x800000000100000b,
+
+           0x30000:     # RADIX_ROOT_PTE
+                        # V = 1 L = 0 NLB = 0x400 NLS = 9
+           0x8000000000040009,
+           0x40000:     # RADIX_SECOND_LEVEL
+                        #         V = 1 L = 1 SW = 0 RPN = 0
+                           # R = 1 C = 1 ATT = 0 EAA 0x7
+           0xc000000000000187,
+
+           0x1000000:   # PROCESS_TABLE_3
+                       # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
+           0x40000000000300ad,
+          }
+
+# this one has a 2nd level RADIX with a RPN of 0x5000
+testmem2 = {
+
+           0x10000:    # PARTITION_TABLE_2 (not implemented yet)
+                       # PATB_GR=1 PRTB=0x1000 PRTS=0xb
+           0x800000000100000b,
+
+           0x30000:     # RADIX_ROOT_PTE
+                        # V = 1 L = 0 NLB = 0x400 NLS = 9
+           0x8000000000040009,
+           0x40000:     # RADIX_SECOND_LEVEL
+                        #         V = 1 L = 1 SW = 0 RPN = 0x5000
+                           # R = 1 C = 1 ATT = 0 EAA 0x7
+           0xc000000005000187,
+
+           0x1000000:   # PROCESS_TABLE_3
+                       # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
+           0x40000000000300ad,
+          }
+
+
+testresult = """
+    prtbl = 1000000
+    DCACHE GET 1000000 PROCESS_TABLE_3
+    DCACHE GET 30000 RADIX_ROOT_PTE V = 1 L = 0
+    DCACHE GET 40000 RADIX_SECOND_LEVEL V = 1 L = 1
+    DCACHE GET 10000 PARTITION_TABLE_2
+translated done 1 err 0 badtree 0 addr 40000 pte 0
+"""
+
 # see qemu/target/ppc/mmu-radix64.c for reference
 class RADIX:
     def __init__(self, mem, caller):
         self.mem = mem
         self.caller = caller
-        #TODO move to lookup
-        self.dsisr = self.caller.spr["DSISR"]
-        self.dar   = self.caller.spr["DAR"]
-        self.pidr  = self.caller.spr["PIDR"]
-        self.prtbl = self.caller.spr["PRTBL"]
+        if caller is not None:
+            self.dsisr = self.caller.spr["DSISR"]
+            self.dar   = self.caller.spr["DAR"]
+            self.pidr  = self.caller.spr["PIDR"]
+            self.prtbl = self.caller.spr["PRTBL"]
+            self.msr   = self.caller.msr
 
         # cached page table stuff
         self.pgtbl0 = 0
@@ -212,17 +270,23 @@ class RADIX:
         print("RADIX memread", addr, sz, val)
         return SelectableInt(val, sz*8)
 
-    def ld(self, address, width=8, swap=True, check_in_mem=False):
+    def ld(self, address, width=8, swap=True, check_in_mem=False,
+                 instr_fetch=False):
         print("RADIX: ld from addr 0x%x width %d" % (address, width))
 
-        mode = 'LOAD' # XXX TODO: executable load (icache)
+        priv = ~(self.msr(MSR_PR).value) # problem-state ==> privileged
+        if instr_fetch:
+            mode = 'EXECUTE'
+        else:
+            mode = 'LOAD'
         addr = SelectableInt(address, 64)
         (shift, mbits, pgbase) = self._decode_prte(addr)
         #shift = SelectableInt(0, 32)
 
-        pte = self._walk_tree(addr, pgbase, mode, mbits, shift)
-        # use pte to caclculate phys address
-        return self.mem.ld(address, width, swap, check_in_mem)
+        pte = self._walk_tree(addr, pgbase, mode, mbits, shift, priv)
+
+        # use pte to load from phys address
+        return self.mem.ld(pte.value, width, swap, check_in_mem)
 
         # XXX set SPRs on error
 
@@ -230,13 +294,14 @@ class RADIX:
     def st(self, address, v, width=8, swap=True):
         print("RADIX: st to addr 0x%x width %d data %x" % (address, width, v))
 
+        priv = ~(self.msr(MSR_PR).value) # problem-state ==> privileged
         mode = 'STORE'
         addr = SelectableInt(address, 64)
         (shift, mbits, pgbase) = self._decode_prte(addr)
-        pte = self._walk_tree(addr, pgbase, mode, mbits, shift)
+        pte = self._walk_tree(addr, pgbase, mode, mbits, shift, priv)
 
-        # use pte to caclculate phys address (addr)
-        return self.mem.st(addr.value, v, width, swap)
+        # use pte to store at phys address
+        return self.mem.st(pte.value, v, width, swap)
 
         # XXX set SPRs on error
 
@@ -244,14 +309,20 @@ class RADIX:
         print("memassign", addr, sz, val)
         self.st(addr.value, val.value, sz, swap=False)
 
-    def _next_level(self,r):
-        return rpte_valid(r), rpte_leaf(r)
-        ## DSISR_R_BADCONFIG
-        ## read_entry
-        ## DSISR_NOPTE
-        ## Prepare for next iteration
+    def _next_level(self, addr, entry_width, swap, check_in_mem):
+        # implement read access to mmu mem here
+
+        # DO NOT perform byte-swapping: load 8 bytes (that's the entry size)
+        value = self.mem.ld(addr.value, 8, False, check_in_mem)
+        assert(value is not None, "address lookup %x not found" % addr.value)
 
-    def _walk_tree(self, addr, pgbase, mode, mbits, shift):
+        print("addr", hex(addr.value))
+        data = SelectableInt(value, 64) # convert to SelectableInt
+        print("value", hex(value))
+        # index += 1
+        return data;
+
+    def _walk_tree(self, addr, pgbase, mode, mbits, shift, priv=1):
         """walk tree
 
         // vaddr                    64 Bit
@@ -327,54 +398,101 @@ class RADIX:
         p = addr[55:63]
         print("last 8 bits ----------")
         print
-        
 
         # get address of root entry
-        prtable_addr = self._get_prtable_addr(shift, prtbl, addr, pidr)
-        print("prtable_addr",prtable_addr)
+        shift = selectconcat(SelectableInt(0,1), prtbl[58:63]) # TODO verify
+        addr_next = self._get_prtable_addr(shift, prtbl, addr, pidr)
+        print("starting with prtable, addr_next",addr_next)
+
+        assert(addr_next.bits == 64)
+        assert(addr_next.value == 0x1000000) #TODO
 
-        # read root entry - imcomplete
+        # read an entry from prtable
         swap = False
         check_in_mem = False
         entry_width = 8
-        value = self.mem.ld(prtable_addr.value, entry_width, swap, check_in_mem)
-        print("value",value)
+        data = self._next_level(addr_next, entry_width, swap, check_in_mem)
+        print("pr_table",data)
+
+        # rts = shift = unsigned('0' & data(62 downto 61) & data(7 downto 5));
+        shift = selectconcat(SelectableInt(0,1), data[1:3], data[55:58])
+        assert(shift.bits==6) # variable rts : unsigned(5 downto 0);
+        print("shift",shift)
+
+        # mbits := unsigned('0' & data(4 downto 0));
+        mbits = selectconcat(SelectableInt(0,1), data[58:63])
+        assert(mbits.bits==6) #variable mbits : unsigned(5 downto 0);
+        print("mbits",mbits)
+
+        new_shift = self._segment_check(addr, mbits, shift)
+        print("new_shift",new_shift)
 
-        test_input = [
-            SelectableInt(0x8000000000000007, 64), #valid
-            SelectableInt(0xc000000000000000, 64) #exit
-        ]
-        index = 0
+        addr_next = SelectableInt(0x30000,64) # radix root for testing
+        # this needs to be calculated using the code above
 
         # walk tree starts on prtbl
         while True:
             print("nextlevel----------------------------")
-            l = test_input[index]
-            index += 1
-            valid,leaf = self._next_level(l)
-            if not leaf:
-                mbits = l[59:64]
-                print("mbits=")
-                print(mbits)
-                if mbits < 5 or mbits > 16:
-                    print("badtree")
-                    return None
-                """
-                mbits := unsigned('0' & data(4 downto 0));
-                if mbits < 5 or mbits > 16 or mbits > r.shift then
-                    v.state := RADIX_FINISH;
-                    v.badtree := '1'; -- throw error
-                else
-                    v.shift := v.shift - mbits;
-                    v.mask_size := mbits(4 downto 0);
-                    v.pgbase := data(55 downto 8) & x"00"; NLB?
-                    v.state := RADIX_LOOKUP; --> next level
-                end if;
-                """
-            print(valid)
-            print(leaf)
-            if not valid: return None
-            if leaf: return None
+            # read an entry
+            swap = False
+            check_in_mem = False
+            entry_width = 8
+
+            data = self._next_level(addr_next, entry_width, swap, check_in_mem)
+            valid = rpte_valid(data)
+            leaf = rpte_leaf(data)
+
+            print("    valid, leaf", valid, leaf)
+            if not valid:
+                return "invalid" # TODO: return error
+            if leaf:
+                print ("is leaf, checking perms")
+                ok = self._check_perms(data, priv, mode)
+                if ok == True: # data was ok, found phys address, return it?
+                    paddr = self._get_pte(addrsh, addr, data)
+                    print ("    phys addr", hex(paddr.value))
+                    return paddr
+                return ok # return the error code
+            else:
+                newlookup = self._new_lookup(data, mbits, shift)
+                if newlookup == 'badtree':
+                    return newlookup
+                shift, mask, pgbase = newlookup
+                print ("   next level", shift, mask, pgbase)
+                shift = SelectableInt(shift.value,16) #THIS is wrong !!!
+                print("calling _get_pgtable_addr")
+                print(mask)    #SelectableInt(value=0x9, bits=4)
+                print(pgbase)  #SelectableInt(value=0x40000, bits=56)
+                print(shift)   #SelectableInt(value=0x4, bits=16) #FIXME
+                pgbase = SelectableInt(pgbase.value, 64)
+                addrsh = addrshift(addr,shift)
+                addr_next = self._get_pgtable_addr(mask, pgbase, addrsh)
+                print("addr_next",addr_next)
+                print("addrsh",addrsh)
+
+    def _new_lookup(self, data, mbits, shift):
+        """
+        mbits := unsigned('0' & data(4 downto 0));
+        if mbits < 5 or mbits > 16 or mbits > r.shift then
+            v.state := RADIX_FINISH;
+            v.badtree := '1'; -- throw error
+        else
+            v.shift := v.shift - mbits;
+            v.mask_size := mbits(4 downto 0);
+            v.pgbase := data(55 downto 8) & x"00"; NLB?
+            v.state := RADIX_LOOKUP; --> next level
+        end if;
+        """
+        mbits = data[59:64]
+        print("mbits=", mbits)
+        if mbits < 5 or mbits > 16: #fixme compare with r.shift
+            print("badtree")
+            return "badtree"
+        # reduce shift (has to be done at same bitwidth)
+        shift = shift - selectconcat(SelectableInt(0, 1), mbits)
+        mask_size = mbits[1:5] # get 4 LSBs
+        pgbase = selectconcat(data[8:56], SelectableInt(0, 8)) # shift up 8
+        return shift, mask_size, pgbase
 
     def _decode_prte(self, data):
         """PRTE0 Layout
@@ -416,18 +534,18 @@ class RADIX:
         # below *directly* match the spec, unlike microwatt which
         # has to turn them around (to LE)
         mask = genmask(shift, 44)
-        nonzero = addr[1:32] & mask[13:44] # mask 31 LSBs (BE numbered 13:44)
+        nonzero = addr[2:33] & mask[13:44] # mask 31 LSBs (BE numbered 13:44)
         print ("RADIX _segment_check nonzero", bin(nonzero.value))
         print ("RADIX _segment_check addr[0-1]", addr[0].value, addr[1].value)
-        if addr[0] != addr[1] or nonzero == 1:
+        if addr[0] != addr[1] or nonzero != 0:
             return "segerror"
         limit = shift + (31 - 12)
         if mbits < 5 or mbits > 16 or mbits > limit:
-            return "badtree"
+            return "badtree mbits="+str(mbits)+" limit="+str(limit)
         new_shift = shift + (31 - 12) - mbits
         return new_shift
 
-    def _check_perms(self, data, priv, iside, store):
+    def _check_perms(self, data, priv, mode):
         """check page permissions
         // Leaf PDE                                           |
         // |------------------------------|           |----------------|
@@ -472,10 +590,17 @@ class RADIX:
                             v.rc_error := perm_ok;
                         end if;
         """
+        # decode mode into something that matches microwatt equivalent code
+        instr_fetch, store = 0, 0
+        if mode == 'STORE':
+            store = 1
+        if mode == 'EXECUTE':
+            inst_fetch = 1
+
         # check permissions and RC bits
         perm_ok = 0
         if priv == 1 or data[60] == 0:
-            if iside == 0:
+            if instr_fetch == 0:
                 perm_ok = data[62] | (data[61] & (store == 0))
             # no IAMR, so no KUEP support for now
             # deny execute permission if cache inhibited
@@ -483,6 +608,7 @@ class RADIX:
         rc_ok = data[55] & (data[56] | (store == 0))
         if perm_ok == 1 and rc_ok == 1:
             return True
+
         return "perm_err" if perm_ok == 0 else "rc_err"
 
     def _get_prtable_addr(self, shift, prtbl, addr, pid):
@@ -497,16 +623,16 @@ class RADIX:
                 (effpid(31 downto 8) and finalmask(23 downto 0))) &
                 effpid(7 downto 0) & "0000";
         """
-        print ("_get_prtable_addr_", shift, prtbl, addr, pid)
+        print ("_get_prtable_addr", shift, prtbl, addr, pid)
         finalmask = genmask(shift, 44)
         finalmask24 = finalmask[20:44]
         if addr[0].value == 1:
             effpid = SelectableInt(0, 32)
         else:
             effpid = pid #self.pid # TODO, check on this
-        zero16 = SelectableInt(0, 16)
+        zero8 = SelectableInt(0, 8)
         zero4 = SelectableInt(0, 4)
-        res = selectconcat(zero16,
+        res = selectconcat(zero8,
                            prtbl[8:28],                        #
                            (prtbl[28:52] & ~finalmask24) |     #
                            (effpid[0:24] & finalmask24),       #
@@ -526,7 +652,7 @@ class RADIX:
         zero3 = SelectableInt(0, 3)
         res = selectconcat(zero8,
                            pgbase[8:45],              #
-                           (prtbl[45:61] & ~mask16) | #
+                           (pgbase[45:61] & ~mask16) | #
                            (addrsh       & mask16),   #
                            zero3
                            )
@@ -539,56 +665,164 @@ class RADIX:
          (r.addr(55 downto 12) and finalmask))
         & r.pde(11 downto 0);
         """
+        shift.value = 12
         finalmask = genmask(shift, 44)
         zero8 = SelectableInt(0, 8)
+        rpn = pde[8:52]       # RPN = Real Page Number
+        abits = addr[8:52] # non-masked address bits
+        print("     get_pte RPN", hex(rpn.value))
+        print("             abits", hex(abits.value))
+        print("             shift", shift.value)
+        print("             finalmask", bin(finalmask.value))
         res = selectconcat(zero8,
-                           (pde[8:52]  & ~finalmask) | #
-                           (addr[8:52] & finalmask),   #
-                           pde[52:64],
+                           (rpn  & ~finalmask) | #
+                           (abits & finalmask),   #
+                           addr[52:64],
                            )
         return res
 
+class TestRadixMMU(unittest.TestCase):
+
+    def test_genmask(self):
+        shift = SelectableInt(5, 6)
+        mask = genmask(shift, 43)
+        print ("    mask", bin(mask.value))
+
+        self.assertEqual(mask.value, 0b11111, "mask should be 5 1s")
+
+    def test_get_pgtable_addr(self):
+
+        mem = None
+        caller = None
+        dut = RADIX(mem, caller)
+
+        mask_size=4
+        pgbase = SelectableInt(0,64)
+        addrsh = SelectableInt(0,16)
+        ret = dut._get_pgtable_addr(mask_size, pgbase, addrsh)
+        print("ret=", ret)
+        self.assertEqual(ret, 0, "pgtbl_addr should be 0")
+
+    def test_walk_tree_1(self):
+
+        # test address as in
+        # https://github.com/power-gem5/gem5/blob/gem5-experimental/src/arch/power/radix_walk_example.txt#L65
+        testaddr = 0x1000
+        expected = 0x1000
+
+        # starting prtbl
+        prtbl = 0x1000000
+
+        # set up dummy minimal ISACaller
+        spr = {'DSISR': SelectableInt(0, 64),
+               'DAR': SelectableInt(0, 64),
+               'PIDR': SelectableInt(0, 64),
+               'PRTBL': SelectableInt(prtbl, 64)
+        }
+        # set problem state == 0 (other unit tests, set to 1)
+        msr = SelectableInt(0, 64)
+        msr[MSRb.PR] = 0
+        class ISACaller: pass
+        caller = ISACaller()
+        caller.spr = spr
+        caller.msr = msr
+
+        shift = SelectableInt(5, 6)
+        mask = genmask(shift, 43)
+        print ("    mask", bin(mask.value))
+
+        mem = Mem(row_bytes=8, initial_mem=testmem)
+        mem = RADIX(mem, caller)
+        # -----------------------------------------------
+        # |/|RTS1|/|     RPDB          | RTS2 |  RPDS   |
+        # -----------------------------------------------
+        # |0|1  2|3|4                55|56  58|59     63|
+        data = SelectableInt(0, 64)
+        data[1:3] = 0b01
+        data[56:59] = 0b11
+        data[59:64] = 0b01101 # mask
+        data[55] = 1
+        (rts, mbits, pgbase) = mem._decode_prte(data)
+        print ("    rts", bin(rts.value), rts.bits)
+        print ("    mbits", bin(mbits.value), mbits.bits)
+        print ("    pgbase", hex(pgbase.value), pgbase.bits)
+        addr = SelectableInt(0x1000, 64)
+        check = mem._segment_check(addr, mbits, shift)
+        print ("    segment check", check)
+
+        print("walking tree")
+        addr = SelectableInt(testaddr,64)
+        # pgbase = None
+        mode = None
+        #mbits = None
+        shift = rts
+        result = mem._walk_tree(addr, pgbase, mode, mbits, shift)
+        print("     walking tree result", result)
+        print("should be", testresult)
+        self.assertEqual(result.value, expected,
+                             "expected 0x%x got 0x%x" % (expected,
+                                                    result.value))
+
+
+    def test_walk_tree_2(self):
+
+        # test address slightly different
+        testaddr = 0x1101
+        expected = 0x5001101
+
+        # starting prtbl
+        prtbl = 0x1000000
+
+        # set up dummy minimal ISACaller
+        spr = {'DSISR': SelectableInt(0, 64),
+               'DAR': SelectableInt(0, 64),
+               'PIDR': SelectableInt(0, 64),
+               'PRTBL': SelectableInt(prtbl, 64)
+        }
+        # set problem state == 0 (other unit tests, set to 1)
+        msr = SelectableInt(0, 64)
+        msr[MSRb.PR] = 0
+        class ISACaller: pass
+        caller = ISACaller()
+        caller.spr = spr
+        caller.msr = msr
+
+        shift = SelectableInt(5, 6)
+        mask = genmask(shift, 43)
+        print ("    mask", bin(mask.value))
+
+        mem = Mem(row_bytes=8, initial_mem=testmem2)
+        mem = RADIX(mem, caller)
+        # -----------------------------------------------
+        # |/|RTS1|/|     RPDB          | RTS2 |  RPDS   |
+        # -----------------------------------------------
+        # |0|1  2|3|4                55|56  58|59     63|
+        data = SelectableInt(0, 64)
+        data[1:3] = 0b01
+        data[56:59] = 0b11
+        data[59:64] = 0b01101 # mask
+        data[55] = 1
+        (rts, mbits, pgbase) = mem._decode_prte(data)
+        print ("    rts", bin(rts.value), rts.bits)
+        print ("    mbits", bin(mbits.value), mbits.bits)
+        print ("    pgbase", hex(pgbase.value), pgbase.bits)
+        addr = SelectableInt(0x1000, 64)
+        check = mem._segment_check(addr, mbits, shift)
+        print ("    segment check", check)
+
+        print("walking tree")
+        addr = SelectableInt(testaddr,64)
+        # pgbase = None
+        mode = None
+        #mbits = None
+        shift = rts
+        result = mem._walk_tree(addr, pgbase, mode, mbits, shift)
+        print("     walking tree result", result)
+        print("should be", testresult)
+        self.assertEqual(result.value, expected,
+                             "expected 0x%x got 0x%x" % (expected,
+                                                    result.value))
+
 
-# very quick test of maskgen function (TODO, move to util later)
 if __name__ == '__main__':
-    # set up dummy minimal ISACaller
-    spr = {'DSISR': SelectableInt(0, 64),
-           'DAR': SelectableInt(0, 64),
-           'PIDR': SelectableInt(0, 64),
-           'PRTBL': SelectableInt(0, 64)
-    }
-    class ISACaller: pass
-    caller = ISACaller()
-    caller.spr = spr
-
-    shift = SelectableInt(5, 6)
-    mask = genmask(shift, 43)
-    print ("    mask", bin(mask.value))
-
-    mem = Mem(row_bytes=8)
-    mem = RADIX(mem, caller)
-    # -----------------------------------------------
-    # |/|RTS1|/|     RPDB          | RTS2 |  RPDS   |
-    # -----------------------------------------------
-    # |0|1  2|3|4                55|56  58|59     63|
-    data = SelectableInt(0, 64)
-    data[1:3] = 0b01
-    data[56:59] = 0b11
-    data[59:64] = 0b01101 # mask
-    data[55] = 1
-    (rts, mbits, pgbase) = mem._decode_prte(data)
-    print ("    rts", bin(rts.value), rts.bits)
-    print ("    mbits", bin(mbits.value), mbits.bits)
-    print ("    pgbase", hex(pgbase.value), pgbase.bits)
-    addr = SelectableInt(0x1000, 64)
-    check = mem._segment_check(addr, mbits, shift)
-    print ("    segment check", check)
-
-    print("walking tree")
-    # addr = unchanged
-    # pgbase = None 
-    mode = None
-    #mbits = None
-    shift = rts
-    result = mem._walk_tree(addr, pgbase, mode, mbits, shift)
-    print(result)
+    unittest.main()