code-cleanup in radixmmu
[soc.git] / src / soc / decoder / isa / radixmmu.py
index 1524e949655539f5cdb95fe2b349c5bad7a3ff76..6fafc63afeeb4d0cf5ec175d9e1cc902316a8cb4 100644 (file)
@@ -73,7 +73,7 @@ def NLB(x):
 
 def NLS(x):
     """
-    Next Level Size
+    Next Level Size (PATS and RPDS in same bits btw)
     NLS >= 5
     """
     return x[59:64] # python numbering end+1
@@ -301,10 +301,7 @@ class RADIX:
         else:
             mode = 'LOAD'
         addr = SelectableInt(address, 64)
-        (shift, mbits, pgbase) = self._decode_prte(addr)
-        #shift = SelectableInt(0, 32)
-
-        pte = self._walk_tree(addr, pgbase, mode, mbits, shift, priv)
+        pte = self._walk_tree(addr, mode, priv)
 
         if type(pte)==str:
             print("error on load",pte)
@@ -322,8 +319,7 @@ class RADIX:
         priv = ~(self.msr[MSRb.PR].value) # problem-state ==> privileged
         mode = 'STORE'
         addr = SelectableInt(address, 64)
-        (shift, mbits, pgbase) = self._decode_prte(addr)
-        pte = self._walk_tree(addr, pgbase, mode, mbits, shift, priv)
+        pte = self._walk_tree(addr, mode, priv)
 
         # use pte to store at phys address
         return self.mem.st(pte.value, v, width, swap)
@@ -349,7 +345,7 @@ class RADIX:
         # index += 1
         return data;
 
-    def _walk_tree(self, addr, pgbase, mode, mbits, shift, priv=1):
+    def _walk_tree(self, addr, mode, priv=1):
         """walk tree
 
         // vaddr                    64 Bit
@@ -427,21 +423,24 @@ class RADIX:
         print
 
         # get address of root entry
-        shift = selectconcat(SelectableInt(0,1), prtbl[58:63]) # TODO verify
+        # need to fetch process table entry
+        # v.shift := unsigned('0' & r.prtbl(4 downto 0));
+        shift = selectconcat(SelectableInt(0, 1), NLS(prtbl))
         addr_next = self._get_prtable_addr(shift, prtbl, addr, pidr)
-        print("starting with prtable, addr_next",addr_next)
+        print("starting with prtable, addr_next", addr_next)
 
         assert(addr_next.bits == 64)
         #only for first unit tests assert(addr_next.value == 0x1000000)
 
-        # read an entry from prtable
+        # read an entry from prtable, decode PTRE
         swap = False
         check_in_mem = False
         entry_width = 8
         data = self._next_level(addr_next, entry_width, swap, check_in_mem)
         print("pr_table", data)
         pgtbl = data # this is cached in microwatt (as v.pgtbl3 / v.pgtbl0)
-        shift, mbits = self._get_rts_nls(pgtbl)
+        (rts, mbits, pgbase) = self._decode_prte(pgtbl)
+        print("pgbase", pgbase)
 
         # WIP
         if mbits == 0:
@@ -460,24 +459,16 @@ class RADIX:
         shift = self._segment_check(addr, mbits, shift)
         print("shift", shift)
 
-        # v.pgbase := pgtbl(55 downto 8) & x"00";
-        # see test_RPDB for reference
-        zero8 = SelectableInt(0, 8)
-
-        pgbase = selectconcat(zero8, RPDB(pgtbl), zero8)
-        print("pgbase",pgbase)
-        #assert(pgbase.value==0x30000)
-
-        if type(addr) == str:
+        if isinstance(addr, str):
             return addr
-        if type(shift) == str:
+        if isinstance(shift, str):
             return shift
 
-        addrsh = addrshift(addr,shift)
+        addrsh = addrshift(addr, shift)
         print("addrsh",addrsh)
 
         addr_next = self._get_pgtable_addr(mask_size, pgbase, addrsh)
-        print("DONE addr_next",addr_next)
+        print("DONE addr_next", addr_next)
 
         # walk tree
         while True:
@@ -519,6 +510,13 @@ class RADIX:
                 print("addr_next",addr_next)
                 print("addrsh",addrsh)
 
+    def _get_pgbase(self, data):
+        """
+            v.pgbase := data(55 downto 8) & x"00"; NLB?
+        """
+        zero8 = SelectableInt(0, 8)
+        return selectconcat(zero8, data[8:56], zero8) # shift up 8
+
     def _new_lookup(self, data, shift):
         """
         mbits := unsigned('0' & data(4 downto 0));
@@ -532,15 +530,16 @@ class RADIX:
             v.state := RADIX_LOOKUP; --> next level
         end if;
         """
-        mbits = data[59:64]
+        mbits = selectconcat(SelectableInt(0, 1), NLS(data))
         print("mbits=", mbits)
         if mbits < 5 or mbits > 16: #fixme compare with r.shift
             print("badtree")
             return "badtree"
         # reduce shift (has to be done at same bitwidth)
-        shift = shift - selectconcat(SelectableInt(0, 1), mbits)
-        mask_size = mbits[1:5] # get 4 LSBs
-        pgbase = selectconcat(data[8:56], SelectableInt(0, 8)) # shift up 8
+        shift = shift - mbits
+        assert mbits.bits == 6
+        mask_size = mbits[2:6] # get 4 LSBs from 6-bit (using MSB0 numbering)
+        pgbase = self._get_pgbase(data)
         return shift, mask_size, pgbase
 
     def _decode_prte(self, data):
@@ -553,12 +552,8 @@ class RADIX:
         # note that SelectableInt does big-endian!  so the indices
         # below *directly* match the spec, unlike microwatt which
         # has to turn them around (to LE)
-        zero = SelectableInt(0, 1)
-        rts = RTS(data)
-        masksize = data[59:64]               # RPDS
-        mbits = selectconcat(zero, masksize)
-        pgbase = selectconcat(data[8:56],  # part of RPDB
-                             SelectableInt(0, 16),)
+        rts, mbits = self._get_rts_nls(data)
+        pgbase = self._get_pgbase(data)
 
         return (rts, mbits, pgbase)
 
@@ -741,6 +736,7 @@ class RADIX:
                            )
         return res
 
+
 class TestRadixMMU(unittest.TestCase):
 
     def test_genmask(self):
@@ -760,7 +756,6 @@ class TestRadixMMU(unittest.TestCase):
         result = selectconcat(rtdb,SelectableInt(0,8))
         print("result",result)
 
-
     def test_get_pgtable_addr(self):
 
         mem = None
@@ -827,14 +822,13 @@ class TestRadixMMU(unittest.TestCase):
         mode = None
         #mbits = None
         shift = rts
-        result = mem._walk_tree(addr, pgbase, mode, mbits, shift)
+        result = mem._walk_tree(addr, mode)
         print("     walking tree result", result)
         print("should be", testresult)
         self.assertEqual(result.value, expected,
                              "expected 0x%x got 0x%x" % (expected,
                                                     result.value))
 
-
     def test_walk_tree_2(self):
 
         # test address slightly different
@@ -887,7 +881,7 @@ class TestRadixMMU(unittest.TestCase):
         mode = None
         #mbits = None
         shift = rts
-        result = mem._walk_tree(addr, pgbase, mode, mbits, shift)
+        result = mem._walk_tree(addr, mode)
         print("     walking tree result", result)
         print("should be", testresult)
         self.assertEqual(result.value, expected,