Add test for cmpeqb
[soc.git] / src / soc / decoder / selectable_int.py
index b6414f90d9483d5a4ecdbe02b965aae3ca9ffb5b..e09d85957e345f2fddd74b9460ba4598e53fd76a 100644 (file)
@@ -48,14 +48,24 @@ class FieldSelectableInt:
         print ("getitem", key, self.br)
         if isinstance(key, SelectableInt):
             key = key.value
-        key = self.br[key] # don't do POWER 1.3.4 bit-inversion
-        return self.si[key]
+        if isinstance(key, int):
+            key = self.br[key] # don't do POWER 1.3.4 bit-inversion
+            return self.si[key]
+        if isinstance(key, slice):
+            key = self.br[key]
+            return selectconcat(*[self.si[x] for x in key])
 
     def __setitem__(self, key, value):
         if isinstance(key, SelectableInt):
             key = key.value
         key = self.br[key] # don't do POWER 1.3.4 bit-inversion
-        return self.si.__setitem__(key, value)
+        if isinstance(key, int):
+            return self.si.__setitem__(key, value)
+        else:
+            if not isinstance(value, SelectableInt):
+                value = SelectableInt(value, bits=len(key))
+            for i, k in enumerate(key):
+                self.si[k] = value[i]
 
     def __negate__(self):
         return self._op1(negate)
@@ -110,9 +120,37 @@ class FieldSelectableIntTestCase(unittest.TestCase):
         print (c)
         #self.assertEqual(c.value, a.value + b.value)
 
+    def test_select(self):
+        a = SelectableInt(0b00001111, 8)
+        br = BitRange()
+        br[0] = 0
+        br[1] = 1
+        br[2] = 4
+        br[3] = 5
+        fs = FieldSelectableInt(a, br)
+
+        self.assertEqual(fs.get_range(), 0b0011)
+
+    def test_select_range(self):
+        a = SelectableInt(0b00001111, 8)
+        br = BitRange()
+        br[0] = 0
+        br[1] = 1
+        br[2] = 4
+        br[3] = 5
+        fs = FieldSelectableInt(a, br)
+
+        self.assertEqual(fs[2:4], 0b11)
+
+        fs[0:2] = 0b10
+        self.assertEqual(fs.get_range(), 0b1011)
+        
+
 
 class SelectableInt:
     def __init__(self, value, bits):
+        if isinstance(value, SelectableInt):
+            value = value.value
         mask = (1 << bits) - 1
         self.value = value & mask
         self.bits = bits
@@ -135,6 +173,20 @@ class SelectableInt:
         assert b.bits == self.bits
         return SelectableInt(self.value - b.value, self.bits)
 
+    def __rsub__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(b.value - self.value, self.bits)
+
+    def __radd__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(b.value + self.value, self.bits)
+
     def __mul__(self, b):
         b = check_extsign(self, b)
         assert b.bits == self.bits
@@ -166,12 +218,25 @@ class SelectableInt:
         assert b.bits == self.bits
         return SelectableInt(self.value ^ b.value, self.bits)
 
+    def __rxor__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value ^ b.value, self.bits)
+
     def __invert__(self):
         return SelectableInt(~self.value, self.bits)
 
     def __neg__(self):
         return SelectableInt(~self.value + 1, self.bits)
 
+    def __lshift__(self, b):
+        b = check_extsign(self, b)
+        return SelectableInt(self.value << b.value, self.bits)
+
+    def __rshift__(self, b):
+        b = check_extsign(self, b)
+        return SelectableInt(self.value >> b.value, self.bits)
+
     def __getitem__(self, key):
         if isinstance(key, int):
             assert key < self.bits, "key %d accessing %d" % (key, self.bits)