test_caller_svp64_powmod: rename to test_aaa_caller_svp64_powmod so pytest tries...
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
index 22fb8311cef15eed2b76f88f2c6b09893552fd68..1eb0d2a7bd22c53be962853731607a0ec680a80e 100644 (file)
@@ -2,22 +2,34 @@ import unittest
 import struct
 from copy import copy
 import functools
-from openpower.decoder.power_fields import BitRange
+from collections import OrderedDict
 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
                       neg, inv, lshift, rshift, lt, eq)
 from openpower.util import log
 
+EFFECTIVELY_UNLIMITED = 1024
 
 def check_extsign(a, b):
     if isinstance(b, FieldSelectableInt):
         b = b.get_range()
     if isinstance(b, int):
         return SelectableInt(b, a.bits)
-    if b.bits != 256:
+    if b.bits != EFFECTIVELY_UNLIMITED:
         return b
     return SelectableInt(b.value, a.bits)
 
 
+class BitRange(OrderedDict):
+    """BitRange: remaps from straight indices (0,1,2..) to bit numbers
+    """
+
+    def __getitem__(self, subscript):
+        if isinstance(subscript, slice):
+            return list(self.values())[subscript]
+        else:
+            return OrderedDict.__getitem__(self, subscript)
+
+
 @functools.total_ordering
 class FieldSelectableInt:
     """FieldSelectableInt: allows bit-range selection onto another target
@@ -47,18 +59,10 @@ class FieldSelectableInt:
         self.br = br  # map of indices
 
     def eq(self, b):
-        if isinstance(b, int):
-            # convert integer to same SelectableInt of same bitlength as range
-            blen = len(self.br)
-            b = SelectableInt(b, blen)
-            for i in range(b.bits):
-                self[i] = b[i]
-        elif isinstance(b, SelectableInt):
-            for i in range(b.bits):
-                self[i] = b[i]
-        else:
-            self.si = copy(b.si)
-            self.br = copy(b.br)
+        if not isinstance(b, SelectableInt):
+            b = SelectableInt(b, len(self.br))
+        for i in range(b.bits):
+            self[i] = b[i]
 
     def _op(self, op, b):
         vi = self.get_range()
@@ -74,7 +78,7 @@ class FieldSelectableInt:
         return len(self.br)
 
     def __getitem__(self, key):
-        log("getitem", key, self.br)
+        #log("getitem", key, self.br)
         if isinstance(key, SelectableInt):
             key = key.value
 
@@ -86,6 +90,8 @@ class FieldSelectableInt:
             return selectconcat(*[self.si[x] for x in key])
         elif isinstance(key, (tuple, list, range)):
             return FieldSelectableInt(si=self, br=key)
+        else:
+            raise ValueError(key)
 
     def __setitem__(self, key, value):
         if isinstance(key, SelectableInt):
@@ -100,7 +106,7 @@ class FieldSelectableInt:
                 self.si[k] = value[i]
 
     def __negate__(self):
-        return self._op1(negate)
+        return self._op1(neg)
 
     def __invert__(self):
         return self._op1(inv)
@@ -159,11 +165,14 @@ class FieldSelectableInt:
                 return True
         return False
 
+    def __int__(self):
+        return self.asint(msb0=True)
+
     def asint(self, msb0=False):
         res = 0
         brlen = len(self.br)
         for i, key in self.br.items():
-            log("asint", i, key, self.si[key])
+            #log("asint", i, key, self.si[key])
             bit = self.si[key].value
             #log("asint", i, key, bit)
             res |= bit << ((brlen-i-1) if msb0 else i)
@@ -209,6 +218,7 @@ class FieldSelectableIntTestCase(unittest.TestCase):
         self.assertEqual(fs.get_range(), 0b1011)
 
 
+@functools.total_ordering
 class SelectableInt:
     """SelectableInt - a class that behaves exactly like python int
 
@@ -219,9 +229,21 @@ class SelectableInt:
     FieldSelectableInt can then operate on partial bits, and because there
     is a bit width associated with SelectableInt, slices operate correctly
     including negative start/end points.
+
+    value: int
+        the bits contained by `self`
+    bits: int
+        the number of bits contained by `self`.
+    ok: bool
+        a flag to detect if outputs have been written by pseudo-code
+
+        instruction inputs have `ok` set to `False`, all changed or new
+        SelectableInt instances set `ok` to `True`.
     """
 
-    def __init__(self, value, bits=None):
+    def __init__(self, value, bits=None, *, ok=True):
+        if isinstance(value, FieldSelectableInt):
+            value = value.get_range()
         if isinstance(value, SelectableInt):
             if bits is not None:
                 # check if the bitlength is different. TODO, allow override?
@@ -229,11 +251,7 @@ class SelectableInt:
                     raise ValueError(value)
             bits = value.bits
             value = value.value
-        elif isinstance(value, FieldSelectableInt):
-            if bits is not None:
-                raise ValueError(value)
-            bits = len(value.br)
-            value = value.si.value
+            # intentionally don't copy ok
         else:
             if not isinstance(value, int):
                 raise ValueError(value)
@@ -243,10 +261,12 @@ class SelectableInt:
         self.value = value & mask
         self.bits = bits
         self.overflow = (value & ~mask) != 0
+        self.ok = ok
 
     def eq(self, b):
         self.value = b.value
         self.bits = b.bits
+        self.ok = True
 
     def to_signed_int(self):
         log ("to signed?", self.value & (1<<(self.bits-1)), self.value)
@@ -306,7 +326,7 @@ class SelectableInt:
     def __rsub__(self, b):
         log("rsub", b, self.value)
         if isinstance(b, int):
-            b = SelectableInt(b, 256) # max extent
+            b = SelectableInt(b, EFFECTIVELY_UNLIMITED) # max extent
         #b = check_extsign(self, b)
         #assert b.bits == self.bits
         return SelectableInt(b.value - self.value, b.bits)
@@ -340,7 +360,7 @@ class SelectableInt:
         return SelectableInt(self.value >> b.value, self.bits)
 
     def __getitem__(self, key):
-        log ("SelectableInt.__getitem__", self, key, type(key))
+        #log ("SelectableInt.__getitem__", self, key, type(key))
         if isinstance(key, SelectableInt):
             key = key.value
         if isinstance(key, int):
@@ -351,32 +371,51 @@ class SelectableInt:
             key = self.bits - (key + 1)
 
             value = (self.value >> key) & 1
-            log("getitem", key, self.bits, hex(self.value), value)
+            #log("getitem", key, self.bits, hex(self.value), value)
             return SelectableInt(value, 1)
         elif isinstance(key, slice):
-            assert key.step is None or key.step == 1
-            assert key.start < key.stop
-            assert key.start >= 0
-            assert key.stop <= self.bits
-
-            stop = self.bits - key.start
-            start = self.bits - key.stop
-
+            start = key.start
+            if isinstance(start, SelectableInt):
+                start = start.value
+            stop = key.stop
+            if isinstance(stop, SelectableInt):
+                stop = stop.value
+            step = key.step
+            if isinstance(step, SelectableInt):
+                step = step.value
+
+            assert step is None or step == 1
+            assert start < stop
+            assert start >= 0
+            assert stop <= self.bits
+
+            (start, stop) = (
+                (self.bits - stop),
+                (self.bits - start),
+            )
             bits = stop - start
-            log ("__getitem__ slice num bits", start, stop, bits)
+            #log ("__getitem__ slice num bits", start, stop, bits)
             mask = (1 << bits) - 1
             value = (self.value >> start) & mask
-            log("getitem", stop, start, self.bits, hex(self.value), value)
+            #log("getitem", stop, start, self.bits, hex(self.value), value)
             return SelectableInt(value, bits)
+        else:
+            bits = []
+            for bit in key:
+                if not isinstance(bit, (int, SelectableInt)):
+                    raise ValueError(key)
+                bits.append(self[bit])
+            return selectconcat(*bits)
 
     def __setitem__(self, key, value):
+        self.ok = True
         if isinstance(key, SelectableInt):
             key = key.value
         if isinstance(key, int):
             if isinstance(value, SelectableInt):
                 assert value.bits == 1
                 value = value.value
-            log("setitem", key, self.bits, hex(self.value), hex(value))
+            #log("setitem", key, self.bits, hex(self.value), hex(value))
 
             assert key < self.bits
             assert key >= 0
@@ -390,7 +429,7 @@ class SelectableInt:
             if isinstance(kstart, SelectableInt): kstart = kstart.asint()
             if isinstance(kstop, SelectableInt): kstop = kstop.asint()
             if isinstance(kstep, SelectableInt): kstep = kstep.asint()
-            log ("__setitem__ slice ", kstart, kstop, kstep)
+            #log ("__setitem__ slice ", kstart, kstop, kstep)
             assert kstep is None or kstep == 1
             assert kstart < kstop
             assert kstart >= 0
@@ -405,46 +444,29 @@ class SelectableInt:
             if isinstance(value, SelectableInt):
                 assert value.bits == bits, "%d into %d" % (value.bits, bits)
                 value = value.value
-            log("setitem", key, self.bits, hex(self.value), hex(value))
+            #log("setitem", key, self.bits, hex(self.value), hex(value))
             mask = ((1 << bits) - 1) << start
             value = value << start
             self.value = (self.value & ~mask) | (value & mask)
+        else:
+            bits = []
+            for bit in key:
+                if not isinstance(bit, (int, SelectableInt)):
+                    raise ValueError(key)
+                bits.append(bit)
+
+            if isinstance(value, int):
+                if value.bit_length() > len(bits):
+                    raise ValueError(value)
+                value = SelectableInt(value=value, bits=len(bits))
+            if not isinstance(value, SelectableInt):
+                raise ValueError(value)
 
-    def __ge__(self, other):
-        if isinstance(other, FieldSelectableInt):
-            other = other.get_range()
-        if isinstance(other, SelectableInt):
-            other = check_extsign(self, other)
-            assert other.bits == self.bits
-            other = other.to_signed_int()
-        if isinstance(other, int):
-            return onebit(self.to_signed_int() >= other)
-        assert False
-
-    def __le__(self, other):
-        if isinstance(other, FieldSelectableInt):
-            other = other.get_range()
-        if isinstance(other, SelectableInt):
-            other = check_extsign(self, other)
-            assert other.bits == self.bits
-            other = other.to_signed_int()
-        if isinstance(other, int):
-            return onebit(self.to_signed_int() <= other)
-        assert False
-
-    def __gt__(self, other):
-        if isinstance(other, FieldSelectableInt):
-            other = other.get_range()
-        if isinstance(other, SelectableInt):
-            other = check_extsign(self, other)
-            assert other.bits == self.bits
-            other = other.to_signed_int()
-        if isinstance(other, int):
-            return onebit(self.to_signed_int() > other)
-        assert False
+            for (src, dst) in enumerate(bits):
+                self[dst] = value[src]
 
     def __lt__(self, other):
-        log ("SelectableInt lt", self, other)
+        log ("SelectableInt __lt__", self, other)
         if isinstance(other, FieldSelectableInt):
             other = other.get_range()
         if isinstance(other, SelectableInt):
@@ -459,7 +481,7 @@ class SelectableInt:
         assert False
 
     def __eq__(self, other):
-        log("__eq__", self, other)
+        log("SelectableInt __eq__", self, other)
         if isinstance(other, FieldSelectableInt):
             other = other.get_range()
         if isinstance(other, SelectableInt):
@@ -479,8 +501,10 @@ class SelectableInt:
         return self.value != 0
 
     def __repr__(self):
-        value = f"value={hex(self.value)}, bits={self.bits}"
-        return f"{self.__class__.__name__}({value})"
+        value = "value=%#x, bits=%d" % (self.value, self.bits)
+        if not self.ok:
+            value += ", ok=False"
+        return "%s(%s)" % (self.__class__.__name__, value)
 
     def __len__(self):
         return self.bits
@@ -495,111 +519,14 @@ class SelectableInt:
         """convert to double-precision float.  TODO, properly convert
         rather than a hack-job: must actually support Power IEEE754 FP
         """
+        if self.bits == 32:
+            data = self.value.to_bytes(4, byteorder='little')
+            return struct.unpack('<f', data)[0]
         assert self.bits == 64 # must be 64-bit
         data = self.value.to_bytes(8, byteorder='little')
         return struct.unpack('<d', data)[0]
 
 
-class SelectableIntMappingMeta(type):
-    class Field(tuple):
-        def __call__(self, si):
-            return FieldSelectableInt(si=si, br=self)
-
-    class FieldMapping(dict):
-        def __init__(self, items):
-            if isinstance(items, dict):
-                items = items.items()
-
-            length = 0
-            mapping = {}
-            Field = SelectableIntMappingMeta.Field
-            for (key, value) in items:
-                field = Field(value)
-                mapping[key] = field
-                length = max(length, len(field))
-
-            self.__length = length
-
-            return super().__init__(mapping)
-
-        def __iter__(self):
-            yield from self.items()
-
-        def __len__(self):
-            return self.__length
-
-        def __call__(self, si):
-            return {key:value(si=si) for (key, value) in self}
-
-    def __new__(metacls, name, bases, attrs):
-        mapping = {}
-        valid = False
-        for base in reversed(bases):
-            if issubclass(base.__class__, metacls):
-                mapping.update(base)
-            if not valid and issubclass(base, SelectableInt):
-                valid = True
-        if not valid:
-            raise ValueError(bases)
-
-        for (key, value) in tuple(attrs.items()):
-            if key.startswith("_"):
-                continue
-            if isinstance(value, dict):
-                value = metacls.FieldMapping(value)
-            elif isinstance(value, (list, tuple, range)):
-                value = metacls.Field(value)
-            else:
-                continue
-            mapping[key] = value
-            attrs[key] = value
-
-        length = 0
-        for (key, value) in mapping.items():
-            length = max(length, len(value))
-
-        cls = super().__new__(metacls, name, bases, attrs)
-        cls.__length = length
-        cls.__mapping = mapping
-
-        return cls
-
-    def __len__(cls):
-        return cls.__length
-
-    def __contains__(cls, key):
-        return cls.__mapping.__contains__(key)
-
-    def __getitem__(cls, key):
-        return cls.__mapping.__getitem__(key)
-
-    def __iter__(cls):
-        yield from cls.__mapping.items()
-        if type(cls) is not SelectableIntMappingMeta:
-            yield from super().__iter__()
-
-
-class SelectableIntMapping(SelectableInt, metaclass=SelectableIntMappingMeta):
-    def __init__(self, value=0, bits=None):
-        if isinstance(value, SelectableInt):
-            value = value.value
-        if bits is None:
-            bits = len(self.__class__)
-        if bits != len(self.__class__):
-            raise ValueError(bits)
-
-        return super().__init__(value=value, bits=bits)
-
-    def __iter__(self):
-        for (name, _) in self.__class__:
-            yield (name, getattr(self, name))
-
-    def __getattribute__(self, attr):
-        if (attr != "__class__") and (attr in self.__class__):
-            return self.__class__[attr](si=self)
-        return super().__getattribute__(attr)
-
-
 def onebit(bit):
     return SelectableInt(1 if bit else 0, 1)
 
@@ -607,17 +534,21 @@ def onebit(bit):
 def selectltu(lhs, rhs):
     """ less-than (unsigned)
     """
+    if isinstance(lhs, SelectableInt):
+        lhs = lhs.value
     if isinstance(rhs, SelectableInt):
         rhs = rhs.value
-    return onebit(lhs.value < rhs)
+    return onebit(lhs < rhs)
 
 
 def selectgtu(lhs, rhs):
     """ greater-than (unsigned)
     """
+    if isinstance(lhs, SelectableInt):
+        lhs = lhs.value
     if isinstance(rhs, SelectableInt):
         rhs = rhs.value
-    return onebit(lhs.value > rhs)
+    return onebit(lhs > rhs)
 
 
 # XXX this probably isn't needed...
@@ -638,17 +569,23 @@ def selectassign(lhs, idx, rhs):
 
 
 def selectconcat(*args, repeat=1):
-    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
+    if isinstance(repeat, SelectableInt):
+        repeat = repeat.value
+    if len(args) == 1 and isinstance(args[0], int) and args[0] in (0, 1):
         args = [SelectableInt(args[0], 1)]
     if repeat != 1:  # multiplies the incoming arguments
         tmp = []
         for i in range(repeat):
             tmp += args
         args = tmp
-    res = copy(args[0])
+    if isinstance(args[0], FieldSelectableInt):
+        res = args[0].get_range()
+    else:
+        assert isinstance(args[0], SelectableInt), "can only concat SIs, sorry"
+        res = SelectableInt(args[0])
     for i in args[1:]:
         if isinstance(i, FieldSelectableInt):
-            i = i.si
+            i = i.get_range()
         assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
         res.bits += i.bits
         res.value = (res.value << i.bits) | i.value