add segment_check function, plus quick test.
[soc.git] / src / soc / decoder / isa / caller.py
index ecc2a3beeca903b73e85ea879e59c00cf2bbddcc..a90ebed7646c60127c58f998e5abd7d8bc4df25e 100644 (file)
@@ -75,6 +75,14 @@ def create_args(reglist, extra=None):
     return retval
 
 
+# very quick, TODO move to SelectableInt utils later
+def genmask(shift, size):
+    res = SelectableInt(0, size)
+    for i in range(size):
+        if i < shift:
+            res[size-1-i] = SelectableInt(1, 1)
+    return res
+
 """
     Get Root Page
 
@@ -321,16 +329,18 @@ class RADIX:
             0 1  2 3 4                55 56  58 59      63
         """
         zero = SelectableInt(0, 1)
-        rts = selectconcat(data[5:8],      # [56-58] - RTS2
-                           data[61:63],    # [1-2]   - RTS1
-                           zero)
-        masksize = data[0:5]               # [59-63] - RPDS
-        mbits = selectconcat(masksize, zero)
-        pgbase = selectconcat(SelectableInt(0, 16),
-                              data[8:56])  # [8-55] - part of RPDB
+        rts = selectconcat(zero,
+                           data[56:59],      # RTS2
+                           data[1:3],        # RTS1
+                           )
+        masksize = data[59:64]               # RPDS
+        mbits = selectconcat(zero, masksize)
+        pgbase = selectconcat(data[8:56],  # part of RPDB
+                             SelectableInt(0, 16),)
+
         return (rts, mbits, pgbase)
 
-    def _segment_check(self):
+    def _segment_check(self, addr, mbits, shift):
         """checks segment valid
                     mbits := '0' & r.mask_size;
             v.shift := r.shift + (31 - 12) - mbits;
@@ -344,6 +354,17 @@ class RADIX:
             else
                 v.state := RADIX_LOOKUP;
         """
+        mask = genmask(shift, 43)
+        nonzero = addr[1:32] & mask[12:43]
+        print ("RADIX _segment_check nonzero", bin(nonzero.value))
+        print ("RADIX _segment_check addr[0-1]", addr[0].value, addr[1].value)
+        if addr[0] != addr[1] or nonzero == 1:
+            return "segerror"
+        limit = shift + (31 - 12)
+        if mbits < 5 or mbits > 16 or mbits > limit:
+            return "badtree"
+        new_shift = shift + (31 - 12) - mbits
+        return new_shift
 
     def _check_perms(self):
         """check page permissions
@@ -1418,3 +1439,29 @@ def inject():
         return decorator
 
     return variable_injector
+
+
+# very quick test of maskgen function (TODO, move to util later)
+if __name__ == '__main__':
+    shift = SelectableInt(5, 6)
+    mask = genmask(shift, 43)
+    print ("    mask", bin(mask.value))
+
+    mem = Mem(row_bytes=8)
+    mem = RADIX(mem, None)
+    # -----------------------------------------------
+    # |/|RTS1|/|     RPDB          | RTS2 |  RPDS   |
+    # -----------------------------------------------
+    # |0|1  2|3|4                55|56  58|59     63|
+    data = SelectableInt(0, 64)
+    data[1:3] = 0b01
+    data[56:59] = 0b11
+    data[59:64] = 0b01101 # mask
+    data[55] = 1
+    (rts, mbits, pgbase) = mem._decode_prte(data)
+    print ("    rts", bin(rts.value), rts.bits)
+    print ("    mbits", bin(mbits.value), mbits.bits)
+    print ("    pgbase", hex(pgbase.value), pgbase.bits)
+    addr = SelectableInt(0x1000, 64)
+    check = mem._segment_check(addr, mbits, shift)
+    print ("    segment check", check)