add selectconcat test
[soc.git] / src / soc / decoder / selectable_int.py
index 8b5715bbf21692186a08dab39ca667c5e9114e1e..fb064fcb383bd54a5bf35bc4d445c2e8fc6a0f10 100644 (file)
@@ -1,6 +1,11 @@
 import unittest
 from copy import copy
 
+def check_extsign(a, b):
+    if b.bits != 256:
+        return b
+    return SelectableInt(b.value, a.bits)
+
 
 class SelectableInt:
     def __init__(self, value, bits):
@@ -9,34 +14,46 @@ class SelectableInt:
         self.bits = bits
 
     def __add__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value + b.value, self.bits)
 
     def __sub__(self, b):
+        if isinstance(b, int):
+            b = SelectableInt(b, self.bits)
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value - b.value, self.bits)
 
     def __mul__(self, b):
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value * b.value, self.bits)
 
     def __div__(self, b):
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value / b.value, self.bits)
 
     def __mod__(self, b):
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value % b.value, self.bits)
 
     def __or__(self, b):
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value | b.value, self.bits)
 
     def __and__(self, b):
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value & b.value, self.bits)
 
     def __xor__(self, b):
+        b = check_extsign(self, b)
         assert b.bits == self.bits
         return SelectableInt(self.value ^ b.value, self.bits)
 
@@ -99,6 +116,7 @@ class SelectableInt:
 
     def __ge__(self, other):
         if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
@@ -107,6 +125,7 @@ class SelectableInt:
 
     def __le__(self, other):
         if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
@@ -115,6 +134,7 @@ class SelectableInt:
 
     def __gt__(self, other):
         if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
@@ -123,6 +143,7 @@ class SelectableInt:
 
     def __lt__(self, other):
         if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
@@ -131,6 +152,7 @@ class SelectableInt:
 
     def __eq__(self, other):
         if isinstance(other, SelectableInt):
+            other = check_extsign(self, other)
             assert other.bits == self.bits
             other = other.value
         if isinstance(other, int):
@@ -179,12 +201,20 @@ def selectassign(lhs, idx, rhs):
         lhs[t] = rhs[f]
 
 
-def selectconcat(*args):
+def selectconcat(*args, repeat=1):
+    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
+        args = [SelectableInt(args[0], 1)]
+    if repeat != 1: # multiplies the incoming arguments
+        tmp = []
+        for i in range(repeat):
+            tmp += args
+        args = tmp
     res = copy(args[0])
     for i in args[1:]:
         assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
         res.bits += i.bits
         res.value = (res.value << i.bits) | i.value
+    print ("concat", repeat, res)
     return res
 
 
@@ -234,5 +264,16 @@ class SelectableIntTestCase(unittest.TestCase):
         a[0:4] = a[4:8]
         self.assertEqual(a, 0x199)
 
+    def test_concat(self):
+        a = SelectableInt(0x1, 1)
+        c = selectconcat(a, repeat=8)
+        self.assertEqual(c, 0xff)
+        self.assertEqual(c.bits, 8)
+        a = SelectableInt(0x0, 1)
+        c = selectconcat(a, repeat=8)
+        self.assertEqual(c, 0x00)
+        self.assertEqual(c.bits, 8)
+
+
 if __name__ == "__main__":
     unittest.main()