Add test for prtyw pseudocode
[soc.git] / src / soc / decoder / selectable_int.py
index 2e663bc92215ec985d10446cf8a4cd2bf57c4d27..ce7c2ebbb9dc585e25036ba2a62e9fd46fe01fae 100644 (file)
@@ -1,10 +1,14 @@
 import unittest
 from copy import copy
 from soc.decoder.power_fields import BitRange
-from operator import (add, sub, mul, div, mod, or_, and_, xor, neg, inv)
+from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
 
 
 def check_extsign(a, b):
+    if isinstance(b, FieldSelectableInt):
+        b = b.get_range()
+    if isinstance(b, int):
+        return SelectableInt(b, a.bits)
     if b.bits != 256:
         return b
     return SelectableInt(b.value, a.bits)
@@ -22,9 +26,17 @@ class FieldSelectableInt:
             br = _br
         self.br = br # map of indices.
 
+    def eq(self, b):
+        if isinstance(b, SelectableInt):
+            for i in range(b.bits):
+                self[i] = b[i]
+        else:
+            self.si = copy(b.si)
+            self.br = copy(b.br)
+
     def _op(self, op, b):
         vi = self.get_range()
-        vi = op(vi + b)
+        vi = op(vi, b)
         return self.merge(vi)
 
     def _op1(self, op):
@@ -33,12 +45,27 @@ class FieldSelectableInt:
         return self.merge(vi)
 
     def __getitem__(self, key):
-        key = self.br[key]
-        return self.si[key]
+        print ("getitem", key, self.br)
+        if isinstance(key, SelectableInt):
+            key = key.value
+        if isinstance(key, int):
+            key = self.br[key] # don't do POWER 1.3.4 bit-inversion
+            return self.si[key]
+        if isinstance(key, slice):
+            key = self.br[key]
+            return selectconcat(*[self.si[x] for x in key])
 
-    def __setitem__(self, key, value)
-        key = self.br[key]
-        return self.si__setitem__(key, value)
+    def __setitem__(self, key, value):
+        if isinstance(key, SelectableInt):
+            key = key.value
+        key = self.br[key] # don't do POWER 1.3.4 bit-inversion
+        if isinstance(key, int):
+            return self.si.__setitem__(key, value)
+        else:
+            if not isinstance(value, SelectableInt):
+                value = SelectableInt(value, bits=len(key))
+            for i, k in enumerate(key):
+                self.si[k] = value[i]
 
     def __negate__(self):
         return self._op1(negate)
@@ -51,7 +78,9 @@ class FieldSelectableInt:
     def __mul__(self, b):
         return self._op(mul, b)
     def __div__(self, b):
-        return self._op(div, b)
+        return self._op(truediv, b)
+    def __mod__(self, b):
+        return self._op(mod, b)
     def __and__(self, b):
         return self._op(and_, b)
     def __or__(self, b):
@@ -91,6 +120,32 @@ class FieldSelectableIntTestCase(unittest.TestCase):
         print (c)
         #self.assertEqual(c.value, a.value + b.value)
 
+    def test_select(self):
+        a = SelectableInt(0b00001111, 8)
+        br = BitRange()
+        br[0] = 0
+        br[1] = 1
+        br[2] = 4
+        br[3] = 5
+        fs = FieldSelectableInt(a, br)
+
+        self.assertEqual(fs.get_range(), 0b0011)
+
+    def test_select_range(self):
+        a = SelectableInt(0b00001111, 8)
+        br = BitRange()
+        br[0] = 0
+        br[1] = 1
+        br[2] = 4
+        br[3] = 5
+        fs = FieldSelectableInt(a, br)
+
+        self.assertEqual(fs[2:4], 0b11)
+
+        fs[0:2] = 0b10
+        self.assertEqual(fs.get_range(), 0b1011)
+        
+
 
 class SelectableInt:
     def __init__(self, value, bits):
@@ -98,6 +153,10 @@ class SelectableInt:
         self.value = value & mask
         self.bits = bits
 
+    def eq(self, b):
+        self.value = b.value
+        self.bits = b.bits
+
     def __add__(self, b):
         if isinstance(b, int):
             b = SelectableInt(b, self.bits)
@@ -112,6 +171,20 @@ class SelectableInt:
         assert b.bits == self.bits
         return SelectableInt(self.value - b.value, self.bits)
 
+    def __rsub__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(b.value - self.value, self.bits)
+
+    def __radd__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(b.value + self.value, self.bits)
+
     def __mul__(self, b):
         b = check_extsign(self, b)
         assert b.bits == self.bits
@@ -133,6 +206,7 @@ class SelectableInt:
         return SelectableInt(self.value | b.value, self.bits)
 
     def __and__(self, b):
+        print ("__and__", self, b)
         b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value & b.value, self.bits)
@@ -142,12 +216,25 @@ class SelectableInt:
         assert b.bits == self.bits
         return SelectableInt(self.value ^ b.value, self.bits)
 
+    def __rxor__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value ^ b.value, self.bits)
+
     def __invert__(self):
         return SelectableInt(~self.value, self.bits)
 
     def __neg__(self):
         return SelectableInt(~self.value + 1, self.bits)
 
+    def __lshift__(self, b):
+        b = check_extsign(self, b)
+        return SelectableInt(self.value << b.value, self.bits)
+
+    def __rshift__(self, b):
+        b = check_extsign(self, b)
+        return SelectableInt(self.value >> b.value, self.bits)
+
     def __getitem__(self, key):
         if isinstance(key, int):
             assert key < self.bits, "key %d accessing %d" % (key, self.bits)
@@ -167,7 +254,8 @@ class SelectableInt:
             stop = self.bits - key.start
             start = self.bits - key.stop
 
-            bits = stop - start + 1
+            bits = stop - start
+            #print ("__getitem__ slice num bits", bits)
             mask = (1 << bits) - 1
             value = (self.value >> start) & mask
             return SelectableInt(value, bits)
@@ -193,7 +281,8 @@ class SelectableInt:
             stop = self.bits - key.start
             start = self.bits - key.stop
 
-            bits = stop - start + 1
+            bits = stop - start
+            #print ("__setitem__ slice num bits", bits)
             if isinstance(value, SelectableInt):
                 assert value.bits == bits, "%d into %d" % (value.bits, bits)
                 value = value.value
@@ -202,42 +291,53 @@ class SelectableInt:
             self.value = (self.value & ~mask) | (value & mask)
 
     def __ge__(self, other):
+        if isinstance(other, FieldSelectableInt):
+            other = other.get_range()
         if isinstance(other, SelectableInt):
             other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
-            return other >= self.value
+            return onebit(self.value >= other.value)
         assert False
 
     def __le__(self, other):
+        if isinstance(other, FieldSelectableInt):
+            other = other.get_range()
         if isinstance(other, SelectableInt):
             other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
-            return onebit(other <= self.value)
+            return onebit(self.value <= other)
         assert False
 
     def __gt__(self, other):
+        if isinstance(other, FieldSelectableInt):
+            other = other.get_range()
         if isinstance(other, SelectableInt):
             other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
-            return onebit(other > self.value)
+            return onebit(self.value > other)
         assert False
 
     def __lt__(self, other):
+        if isinstance(other, FieldSelectableInt):
+            other = other.get_range()
         if isinstance(other, SelectableInt):
             other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
-            return onebit(other < self.value)
+            return onebit(self.value < other)
         assert False
 
     def __eq__(self, other):
+        print ("__eq__", self, other)
+        if isinstance(other, FieldSelectableInt):
+            other = other.get_range()
         if isinstance(other, SelectableInt):
             other = check_extsign(self, other)
             assert other.bits == self.bits
@@ -257,6 +357,9 @@ class SelectableInt:
         return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                            self.bits)
 
+    def __len__(self):
+        return self.bits
+
 def onebit(bit):
     return SelectableInt(1 if bit else 0, 1)
 
@@ -302,6 +405,8 @@ def selectconcat(*args, repeat=1):
         args = tmp
     res = copy(args[0])
     for i in args[1:]:
+        if isinstance(i, FieldSelectableInt):
+            i = i.si
         assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
         res.bits += i.bits
         res.value = (res.value << i.bits) | i.value
@@ -353,7 +458,7 @@ class SelectableIntTestCase(unittest.TestCase):
         a[0:4] = 3
         self.assertEqual(a, 0x39)
         a[0:4] = a[4:8]
-        self.assertEqual(a, 0x199)
+        self.assertEqual(a, 0x99)
 
     def test_concat(self):
         a = SelectableInt(0x1, 1)
@@ -371,5 +476,21 @@ class SelectableIntTestCase(unittest.TestCase):
             b = eval(repr(a))
             self.assertEqual(a, b)
 
+    def test_cmp(self):
+        a = SelectableInt(10, bits=8)
+        b = SelectableInt(5, bits=8)
+        self.assertTrue(a > b)
+        self.assertFalse(a < b)
+        self.assertTrue(a != b)
+        self.assertFalse(a == b)
+
+    def test_unsigned(self):
+        a = SelectableInt(0x80, bits=8)
+        b = SelectableInt(0x7f, bits=8)
+        self.assertTrue(a > b)
+        self.assertFalse(a < b)
+        self.assertTrue(a != b)
+        self.assertFalse(a == b)
+
 if __name__ == "__main__":
     unittest.main()