add selectconcat test
[soc.git] / src / soc / decoder / selectable_int.py
index ad3bfdc18d9ca3cb80d09e99af3f1ee2a1e7632c..fb064fcb383bd54a5bf35bc4d445c2e8fc6a0f10 100644 (file)
@@ -1,18 +1,71 @@
 import unittest
+from copy import copy
+
+def check_extsign(a, b):
+    if b.bits != 256:
+        return b
+    return SelectableInt(b.value, a.bits)
 
 
 class SelectableInt:
     def __init__(self, value, bits):
-        self.value = value
+        mask = (1 << bits) - 1
+        self.value = value & mask
         self.bits = bits
 
     def __add__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value + b.value, self.bits)
 
+    def __sub__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value - b.value, self.bits)
+
+    def __mul__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value * b.value, self.bits)
+
+    def __div__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value / b.value, self.bits)
+
+    def __mod__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value % b.value, self.bits)
+
+    def __or__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value | b.value, self.bits)
+
+    def __and__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value & b.value, self.bits)
+
+    def __xor__(self, b):
+        b = check_extsign(self, b)
+        assert b.bits == self.bits
+        return SelectableInt(self.value ^ b.value, self.bits)
+
+    def __invert__(self):
+        return SelectableInt(~self.value, self.bits)
+
+    def __neg__(self):
+        return SelectableInt(~self.value + 1, self.bits)
+
     def __getitem__(self, key):
         if isinstance(key, int):
-            assert key < self.bits
+            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
             assert key >= 0
             key = self.bits - (key + 1)
 
@@ -27,7 +80,7 @@ class SelectableInt:
             stop = self.bits - key.start
             start = self.bits - key.stop
 
-            bits = stop - start
+            bits = stop - start + 1
             mask = (1 << bits) - 1
             value = (self.value >> start) & mask
             return SelectableInt(value, bits)
@@ -53,40 +106,152 @@ class SelectableInt:
             stop = self.bits - key.start
             start = self.bits - key.stop
 
-            bits = stop - start
+            bits = stop - start + 1
             if isinstance(value, SelectableInt):
-                assert value.bits == bits
+                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                 value = value.value
             mask = ((1 << bits) - 1) << start
             value = value << start
             self.value = (self.value & ~mask) | (value & mask)
 
+    def __ge__(self, other):
+        if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
+            assert other.bits == self.bits
+            other = other.value
+        if isinstance(other, int):
+            return other >= self.value
+        assert False
+
+    def __le__(self, other):
+        if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
+            assert other.bits == self.bits
+            other = other.value
+        if isinstance(other, int):
+            return onebit(other <= self.value)
+        assert False
+
+    def __gt__(self, other):
+        if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
+            assert other.bits == self.bits
+            other = other.value
+        if isinstance(other, int):
+            return onebit(other > self.value)
+        assert False
+
+    def __lt__(self, other):
+        if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
+            assert other.bits == self.bits
+            other = other.value
+        if isinstance(other, int):
+            return onebit(other < self.value)
+        assert False
+
     def __eq__(self, other):
         if isinstance(other, SelectableInt):
-            return other.value == self.value and other.bits == self.bits
+            other = check_extsign(self, other)
+            assert other.bits == self.bits
+            other = other.value
         if isinstance(other, int):
-            return other == self.value
+            return onebit(other == self.value)
         assert False
 
+    def __bool__(self):
+        return self.value != 0
+
     def __repr__(self):
         return "SelectableInt(value={:x}, bits={})".format(self.value,
                                                            self.bits)
 
+def onebit(bit):
+    return SelectableInt(1 if bit else 0, 1)
+
+def selectltu(lhs, rhs):
+    """ less-than (unsigned)
+    """
+    if isinstance(rhs, SelectableInt):
+        rhs = rhs.value
+    return onebit(lhs.value < rhs)
+
+def selectgtu(lhs, rhs):
+    """ greater-than (unsigned)
+    """
+    if isinstance(rhs, SelectableInt):
+        rhs = rhs.value
+    return onebit(lhs.value > rhs)
+
+
+# XXX this probably isn't needed...
+def selectassign(lhs, idx, rhs):
+    if isinstance(idx, tuple):
+        if len(idx) == 2:
+            lower, upper = idx
+            step = None
+        else:
+            lower, upper, step = idx
+        toidx = range(lower, upper, step)
+        fromidx = range(0, upper-lower, step) # XXX eurgh...
+    else:
+        toidx = [idx]
+        fromidx = [0]
+    for t, f in zip(toidx, fromidx):
+        lhs[t] = rhs[f]
+
+
+def selectconcat(*args, repeat=1):
+    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
+        args = [SelectableInt(args[0], 1)]
+    if repeat != 1: # multiplies the incoming arguments
+        tmp = []
+        for i in range(repeat):
+            tmp += args
+        args = tmp
+    res = copy(args[0])
+    for i in args[1:]:
+        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
+        res.bits += i.bits
+        res.value = (res.value << i.bits) | i.value
+    print ("concat", repeat, res)
+    return res
+
 
 class SelectableIntTestCase(unittest.TestCase):
-    def test_add(self):
+    def test_arith(self):
         a = SelectableInt(5, 8)
         b = SelectableInt(9, 8)
         c = a + b
-        assert c.value == a.value + b.value
-        assert c.bits == a.bits
+        d = a - b
+        e = a * b
+        f = -a
+        self.assertEqual(c.value, a.value + b.value)
+        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
+        self.assertEqual(e.value, (a.value * b.value) & 0xFF)
+        self.assertEqual(f.value, (-a.value) & 0xFF)
+        self.assertEqual(c.bits, a.bits)
+        self.assertEqual(d.bits, a.bits)
+        self.assertEqual(e.bits, a.bits)
+
+    def test_logic(self):
+        a = SelectableInt(0x0F, 8)
+        b = SelectableInt(0xA5, 8)
+        c = a & b
+        d = a | b
+        e = a ^ b
+        f = ~a
+        self.assertEqual(c.value, a.value & b.value)
+        self.assertEqual(d.value, a.value | b.value)
+        self.assertEqual(e.value, a.value ^ b.value)
+        self.assertEqual(f.value, 0xF0)
 
     def test_get(self):
         a = SelectableInt(0xa2, 8)
         # These should be big endian
-        assert a[7] == 0
-        assert a[0:4] == 10
-        assert a[4:8] == 2
+        self.assertEqual(a[7], 0)
+        self.assertEqual(a[0:4], 10)
+        self.assertEqual(a[4:8], 2)
 
     def test_set(self):
         a = SelectableInt(0x5, 8)
@@ -97,7 +262,17 @@ class SelectableIntTestCase(unittest.TestCase):
         a[0:4] = 3
         self.assertEqual(a, 0x39)
         a[0:4] = a[4:8]
-        self.assertEqual(a, 0x99)
+        self.assertEqual(a, 0x199)
+
+    def test_concat(self):
+        a = SelectableInt(0x1, 1)
+        c = selectconcat(a, repeat=8)
+        self.assertEqual(c, 0xff)
+        self.assertEqual(c.bits, 8)
+        a = SelectableInt(0x0, 1)
+        c = selectconcat(a, repeat=8)
+        self.assertEqual(c, 0x00)
+        self.assertEqual(c.bits, 8)
 
 
 if __name__ == "__main__":