make addrshift human readable
[soc.git] / src / soc / decoder / isa / radixmmu.py
index 4c74e9f644b842b018d2dd17b58807745c329182..10b95ff084e51e04229bac6d1c7d21c24394427b 100644 (file)
@@ -19,9 +19,11 @@ from soc.decoder.selectable_int import (FieldSelectableInt, SelectableInt,
                                         selectconcat)
 from soc.decoder.helpers import exts, gtu, ltu, undefined
 from soc.decoder.isa.mem import Mem
+from soc.consts import MSRb  # big-endian (PowerISA versions)
 
 import math
 import sys
+import unittest
 
 # very quick, TODO move to SelectableInt utils later
 def genmask(shift, size):
@@ -43,6 +45,12 @@ def rpte_valid(r):
 def rpte_leaf(r):
     return bool(r[1])
 
+## Shift address bits 61--12 right by 0--47 bits and
+## supply the least significant 16 bits of the result.
+def addrshift(addr,shift):
+    x = addr.value >> shift.value
+    return SelectableInt(x,16)
+
 def NLB(x):
     """
     Next Level Base
@@ -190,16 +198,40 @@ def NLS(x):
 
 """
 
+testaddr = 0x10000
+testmem = {
+
+           0x10000:    # PARTITION_TABLE_2 (not implemented yet)
+                       # PATB_GR=1 PRTB=0x1000 PRTS=0xb
+           0x800000000100000b,
+
+           0x30000:     # RADIX_ROOT_PTE
+                        # V = 1 L = 0 NLB = 0x400 NLS = 9
+           0x8000000000040009,
+########   0x4000000 #### wrong address calculated by _get_pgtable_addr
+           0x40000:     # RADIX_SECOND_LEVEL
+                        #         V = 1 L = 1 SW = 0 RPN = 0
+                           # R = 1 C = 1 ATT = 0 EAA 0x7
+           0xc000000000000187,
+
+           0x1000000:   # PROCESS_TABLE_3
+                       # RTS1 = 0x2 RPDB = 0x300 RTS2 = 0x5 RPDS = 13
+           0x40000000000300ad,
+          }
+          
+
+
 # see qemu/target/ppc/mmu-radix64.c for reference
 class RADIX:
     def __init__(self, mem, caller):
         self.mem = mem
         self.caller = caller
-        #TODO move to lookup
-        self.dsisr = self.caller.spr["DSISR"]
-        self.dar   = self.caller.spr["DAR"]
-        self.pidr  = self.caller.spr["PIDR"]
-        self.prtbl = self.caller.spr["PRTBL"]
+        if caller is not None:
+            self.dsisr = self.caller.spr["DSISR"]
+            self.dar   = self.caller.spr["DAR"]
+            self.pidr  = self.caller.spr["PIDR"]
+            self.prtbl = self.caller.spr["PRTBL"]
+            self.msr   = self.caller.msr
 
         # cached page table stuff
         self.pgtbl0 = 0
@@ -216,7 +248,7 @@ class RADIX:
                  instr_fetch=False):
         print("RADIX: ld from addr 0x%x width %d" % (address, width))
 
-        priv = 1 # XXX TODO: read MSR PR bit here priv = not ctrl.msr(MSR_PR);
+        priv = ~(self.msr(MSR_PR).value) # problem-state ==> privileged
         if instr_fetch:
             mode = 'EXECUTE'
         else:
@@ -235,7 +267,7 @@ class RADIX:
     def st(self, address, v, width=8, swap=True):
         print("RADIX: st to addr 0x%x width %d data %x" % (address, width, v))
 
-        priv = 1 # XXX TODO: read MSR PR bit here priv = not ctrl.msr(MSR_PR);
+        priv = ~(self.msr(MSR_PR).value) # problem-state ==> privileged
         mode = 'STORE'
         addr = SelectableInt(address, 64)
         (shift, mbits, pgbase) = self._decode_prte(addr)
@@ -252,10 +284,17 @@ class RADIX:
 
     def _next_level(self, addr, entry_width, swap, check_in_mem):
         # implement read access to mmu mem here
-        value = self.mem.ld(addr.value, entry_width, swap, check_in_mem)
-        print("addr", addr.value)
+
+        value = 0
+        if addr.value in testmem:
+            value = testmem[addr.value]
+        else:
+            print("not found")
+
+        ##value = self.mem.ld(addr.value, entry_width, swap, check_in_mem)
+        print("addr", hex(addr.value))
         data = SelectableInt(value, 64) # convert to SelectableInt
-        print("value", value)
+        print("value", hex(value))
         # index += 1
         return data;
 
@@ -339,11 +378,7 @@ class RADIX:
         # get address of root entry
         addr_next = self._get_prtable_addr(shift, prtbl, addr, pidr)
 
-        #test_input = [
-        #    SelectableInt(0x8000000000000007, 64), #valid
-        #    SelectableInt(0xc000000000000000, 64) #exit
-        #]
-        #index = 0
+        addr_next = SelectableInt(0x30000,64) # radix root for testing
 
         # walk tree starts on prtbl
         while True:
@@ -371,7 +406,18 @@ class RADIX:
                     return newlookup
                 shift, mask, pgbase = newlookup
                 print ("   next level", shift, mask, pgbase)
-                addr_next = self._get_pgtable_addr(mask, pgbase, shift)
+                shift = SelectableInt(shift.value,16) #THIS is wrong !!!
+                print("calling _get_pgtable_addr")
+                print(mask)    #SelectableInt(value=0x9, bits=4)
+                print(pgbase)  #SelectableInt(value=0x40000, bits=56)
+                print(shift)   #SelectableInt(value=0x4, bits=16) #FIXME
+                pgbase = SelectableInt(pgbase.value,64)
+                addrsh = addrshift(addr,shift)
+                addr_next = self._get_pgtable_addr(mask, pgbase, addrsh)
+                print("addr_next",addr_next)
+                print("addrsh",addrsh)
+                assert(addr_next == 0x40000)
+                return "TODO verify next level"
 
     def _new_lookup(self, data, mbits, shift):
         """
@@ -555,7 +601,7 @@ class RADIX:
         zero3 = SelectableInt(0, 3)
         res = selectconcat(zero8,
                            pgbase[8:45],              #
-                           (prtbl[45:61] & ~mask16) | #
+                           (pgbase[45:61] & ~mask16) | #
                            (addrsh       & mask16),   #
                            zero3
                            )
@@ -578,46 +624,75 @@ class RADIX:
         return res
 
 
-# very quick test of maskgen function (TODO, move to util later)
+class TestRadixMMU(unittest.TestCase):
+
+    def test_genmask(self):
+        shift = SelectableInt(5, 6)
+        mask = genmask(shift, 43)
+        print ("    mask", bin(mask.value))
+
+        self.assertEqual(sum([1, 2, 3]), 6, "Should be 6")
+
+    def test_get_pgtable_addr(self):
+
+        mem = None
+        caller = None
+        dut = RADIX(mem, caller)
+
+        mask_size=4
+        pgbase = SelectableInt(0,64)
+        addrsh = SelectableInt(0,16)
+        ret = dut._get_pgtable_addr(mask_size, pgbase, addrsh)
+        print("ret=",ret)
+        assert(ret==0)
+
+    def test_walk_tree(self):
+        # set up dummy minimal ISACaller
+        spr = {'DSISR': SelectableInt(0, 64),
+               'DAR': SelectableInt(0, 64),
+               'PIDR': SelectableInt(0, 64),
+               'PRTBL': SelectableInt(0, 64)
+        }
+        # set problem state == 0 (other unit tests, set to 1)
+        msr = SelectableInt(0, 64)
+        msr[MSRb.PR] = 0
+        class ISACaller: pass
+        caller = ISACaller()
+        caller.spr = spr
+        caller.msr = msr
+
+        shift = SelectableInt(5, 6)
+        mask = genmask(shift, 43)
+        print ("    mask", bin(mask.value))
+
+        mem = Mem(row_bytes=8)
+        mem = RADIX(mem, caller)
+        # -----------------------------------------------
+        # |/|RTS1|/|     RPDB          | RTS2 |  RPDS   |
+        # -----------------------------------------------
+        # |0|1  2|3|4                55|56  58|59     63|
+        data = SelectableInt(0, 64)
+        data[1:3] = 0b01
+        data[56:59] = 0b11
+        data[59:64] = 0b01101 # mask
+        data[55] = 1
+        (rts, mbits, pgbase) = mem._decode_prte(data)
+        print ("    rts", bin(rts.value), rts.bits)
+        print ("    mbits", bin(mbits.value), mbits.bits)
+        print ("    pgbase", hex(pgbase.value), pgbase.bits)
+        addr = SelectableInt(0x1000, 64)
+        check = mem._segment_check(addr, mbits, shift)
+        print ("    segment check", check)
+
+        print("walking tree")
+        addr = SelectableInt(testaddr,64)
+        # pgbase = None
+        mode = None
+        #mbits = None
+        shift = rts
+        result = mem._walk_tree(addr, pgbase, mode, mbits, shift)
+        print("     walking tree result", result)
+
+
 if __name__ == '__main__':
-    # set up dummy minimal ISACaller
-    spr = {'DSISR': SelectableInt(0, 64),
-           'DAR': SelectableInt(0, 64),
-           'PIDR': SelectableInt(0, 64),
-           'PRTBL': SelectableInt(0, 64)
-    }
-    class ISACaller: pass
-    caller = ISACaller()
-    caller.spr = spr
-
-    shift = SelectableInt(5, 6)
-    mask = genmask(shift, 43)
-    print ("    mask", bin(mask.value))
-
-    mem = Mem(row_bytes=8)
-    mem = RADIX(mem, caller)
-    # -----------------------------------------------
-    # |/|RTS1|/|     RPDB          | RTS2 |  RPDS   |
-    # -----------------------------------------------
-    # |0|1  2|3|4                55|56  58|59     63|
-    data = SelectableInt(0, 64)
-    data[1:3] = 0b01
-    data[56:59] = 0b11
-    data[59:64] = 0b01101 # mask
-    data[55] = 1
-    (rts, mbits, pgbase) = mem._decode_prte(data)
-    print ("    rts", bin(rts.value), rts.bits)
-    print ("    mbits", bin(mbits.value), mbits.bits)
-    print ("    pgbase", hex(pgbase.value), pgbase.bits)
-    addr = SelectableInt(0x1000, 64)
-    check = mem._segment_check(addr, mbits, shift)
-    print ("    segment check", check)
-
-    print("walking tree")
-    # addr = unchanged
-    # pgbase = None
-    mode = None
-    #mbits = None
-    shift = rts
-    result = mem._walk_tree(addr, pgbase, mode, mbits, shift)
-    print("     walking tree result", result)
+    unittest.main()