convert PartitionedRepl over to new "PartType" format
[ieee754fpu.git] / src / ieee754 / part / partsig.py
index 5ed299b9d155d5d2363347f84ad6a583bd2d62dd..d02dac4a1691b8320eb826ec7539d4648edd0902 100644 (file)
@@ -27,10 +27,11 @@ from ieee754.part_mul_add.partpoints import make_partition2, PartitionPoints
 from ieee754.part_mux.part_mux import PMux
 from ieee754.part_ass.passign import PAssign
 from ieee754.part_cat.pcat import PCat
+from ieee754.part_repl.prepl import PRepl
 from operator import or_, xor, and_, not_
 
-from nmigen import (Signal, Const)
-from nmigen.hdl.ast import UserValue
+from nmigen import (Signal, Const, Cat)
+from nmigen.hdl.ast import UserValue, Shape
 
 
 def getsig(op1):
@@ -55,7 +56,25 @@ for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor', 'bool', 'all']:
     modnames[name] = 0
 
 
+
+class PartType: # TODO decide name
+    def __init__(self, psig):
+        self.psig = psig
+    def get_mask(self):
+        return list(self.psig.partpoints.values())
+    def get_switch(self):
+        return Cat(self.get_mask())
+    def get_cases(self):
+        return range(1<<len(self.get_mask()))
+    @property
+    def blanklanes(self):
+        return 0
+
+
 class PartitionedSignal(UserValue):
+    # XXX ################################################### XXX
+    # XXX Keep these functions in the same order as ast.Value XXX
+    # XXX ################################################### XXX
     def __init__(self, mask, *args, src_loc_at=0, **kwargs):
         super().__init__(src_loc_at=src_loc_at)
         self.sig = Signal(*args, **kwargs)
@@ -65,7 +84,7 @@ class PartitionedSignal(UserValue):
             self.partpoints = mask
         else:
             self.partpoints = make_partition2(mask, width)
-
+        self.ptype = PartType(self)
 
     def set_module(self, m):
         self.m = m
@@ -83,18 +102,25 @@ class PartitionedSignal(UserValue):
         result.m = other.m
         return result
 
-    def __len__(self):
-        return len(self.sig)
-    def shape(self):
-        return self.sig.shape()
     def lower(self):
         return self.sig
-    # now using __Assign__
-    #def eq(self, val):
-    #    return self.sig.eq(getsig(val))
 
     # nmigen-redirected constructs (Mux, Cat, Switch, Assign)
 
+    # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=458
+    #def __Part__(self, offset, width, stride=1, *, src_loc_at=0):
+
+    def __Repl__(self, count, *, src_loc_at=0):
+        return PRepl(self.m, self, count, self.ptype)
+
+    def __Cat__(self, *args, src_loc_at=0):
+        args = [self] + list(args)
+        for sig in args:
+            assert isinstance(sig, PartitionedSignal), \
+                "All PartitionedSignal.__Cat__ arguments must be " \
+                "a PartitionedSignal. %s is not." % repr(sig)
+        return PCat(self.m, args, self.partpoints)
+
     def __Mux__(self, val1, val2):
         # print ("partsig mux", self, val1, val2)
         assert len(val1) == len(val2), \
@@ -106,13 +132,12 @@ class PartitionedSignal(UserValue):
         # print ("partsig ass", self, val)
         return PAssign(self.m, self, val, self.partpoints)
 
-    def __Cat__(self, *args, src_loc_at=0):
-        args = [self] + list(args)
-        for sig in args:
-            assert isinstance(sig, PartitionedSignal), \
-                "All PartitionedSignal.__Cat__ arguments must be " \
-                "a PartitionedSignal. %s is not." % repr(sig)
-        return PCat(self.m, args, self.partpoints)
+    # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=458
+    #def __Switch__(self, cases, *, src_loc=None, src_loc_at=0,
+    #                               case_src_locs={}):
+
+    # no override needed, Value.__bool__ sufficient
+    # def __bool__(self):
 
     # unary ops that do not require partitioning
 
@@ -128,72 +153,8 @@ class PartitionedSignal(UserValue):
         result, _ = self.sub_op(z, self)
         return result
 
-    # binary ops that don't require partitioning
-
-    def __and__(self, other):
-        return applyop(self, other, and_)
-
-    def __rand__(self, other):
-        return applyop(other, self, and_)
-
-    def __or__(self, other):
-        return applyop(self, other, or_)
-
-    def __ror__(self, other):
-        return applyop(other, self, or_)
-
-    def __xor__(self, other):
-        return applyop(self, other, xor)
-
-    def __rxor__(self, other):
-        return applyop(other, self, xor)
-
     # binary ops that need partitioning
 
-    # TODO: detect if the 2nd operand is a Const, a Signal or a
-    # PartitionedSignal.  if it's a Const or a Signal, a global shift
-    # can occur.  if it's a PartitionedSignal, that's much more interesting.
-    def ls_op(self, op1, op2, carry, shr_flag=0):
-        op1 = getsig(op1)
-        if isinstance(op2, Const) or isinstance(op2, Signal):
-            scalar = True
-            pa = PartitionedScalarShift(len(op1), self.partpoints)
-        else:
-            scalar = False
-            op2 = getsig(op2)
-            pa = PartitionedDynamicShift(len(op1), self.partpoints)
-        setattr(self.m.submodules, self.get_modname('ls'), pa)
-        comb = self.m.d.comb
-        if scalar:
-            comb += pa.data.eq(op1)
-            comb += pa.shifter.eq(op2)
-            comb += pa.shift_right.eq(shr_flag)
-        else:
-            comb += pa.a.eq(op1)
-            comb += pa.b.eq(op2)
-            comb += pa.shift_right.eq(shr_flag)
-        # XXX TODO: carry-in, carry-out
-        #comb += pa.carry_in.eq(carry)
-        return (pa.output, 0)
-
-    def __lshift__(self, other):
-        z = Const(0, len(self.partpoints)+1)
-        result, _ = self.ls_op(self, other, carry=z) # TODO, carry
-        return result
-
-    def __rlshift__(self, other):
-        raise NotImplementedError
-        return Operator("<<", [other, self])
-
-    def __rshift__(self, other):
-        z = Const(0, len(self.partpoints)+1)
-        result, _ = self.ls_op(self, other, carry=z, shr_flag=1) # TODO, carry
-        return result
-
-    def __rrshift__(self, other):
-        raise NotImplementedError
-        return Operator(">>", [other, self])
-
     def add_op(self, op1, op2, carry):
         op1 = getsig(op1)
         op2 = getsig(op2)
@@ -225,6 +186,7 @@ class PartitionedSignal(UserValue):
         return result
 
     def __radd__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
         result, _ = self.add_op(other, self)
         return result
 
@@ -233,24 +195,20 @@ class PartitionedSignal(UserValue):
         return result
 
     def __rsub__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
         result, _ = self.sub_op(other, self)
         return result
 
     def __mul__(self, other):
+        raise NotImplementedError # too complicated at the moment
         return Operator("*", [self, other])
 
     def __rmul__(self, other):
+        raise NotImplementedError # too complicated at the moment
         return Operator("*", [other, self])
 
-    def __check_divisor(self):
-        width, signed = self.shape()
-        if signed:
-            # Python's division semantics and Verilog's division semantics
-            # differ for negative divisors (Python uses div/mod, Verilog
-            # uses quo/rem); for now, avoid the issue
-            # completely by prohibiting such division operations.
-            raise NotImplementedError(
-                    "Division by a signed value is not supported")
+    # not needed: same as Value.__check_divisor
+    #def __check_divisor(self):
 
     def __mod__(self, other):
         raise NotImplementedError
@@ -274,6 +232,79 @@ class PartitionedSignal(UserValue):
         self.__check_divisor()
         return Operator("//", [other, self])
 
+    # not needed: same as Value.__check_shamt
+    #def __check_shamt(self):
+
+    # TODO: detect if the 2nd operand is a Const, a Signal or a
+    # PartitionedSignal.  if it's a Const or a Signal, a global shift
+    # can occur.  if it's a PartitionedSignal, that's much more interesting.
+    def ls_op(self, op1, op2, carry, shr_flag=0):
+        op1 = getsig(op1)
+        if isinstance(op2, Const) or isinstance(op2, Signal):
+            scalar = True
+            pa = PartitionedScalarShift(len(op1), self.partpoints)
+        else:
+            scalar = False
+            op2 = getsig(op2)
+            pa = PartitionedDynamicShift(len(op1), self.partpoints)
+        # else:
+        #   TODO: case where the *shifter* is a PartitionedSignal but
+        #   the thing *being* Shifted is a scalar (Signal, expression)
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
+        setattr(self.m.submodules, self.get_modname('ls'), pa)
+        comb = self.m.d.comb
+        if scalar:
+            comb += pa.data.eq(op1)
+            comb += pa.shifter.eq(op2)
+            comb += pa.shift_right.eq(shr_flag)
+        else:
+            comb += pa.a.eq(op1)
+            comb += pa.b.eq(op2)
+            comb += pa.shift_right.eq(shr_flag)
+        # XXX TODO: carry-in, carry-out (for arithmetic shift)
+        #comb += pa.carry_in.eq(carry)
+        return (pa.output, 0)
+
+    def __lshift__(self, other):
+        z = Const(0, len(self.partpoints)+1)
+        result, _ = self.ls_op(self, other, carry=z) # TODO, carry
+        return result
+
+    def __rlshift__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
+        raise NotImplementedError
+        return Operator("<<", [other, self])
+
+    def __rshift__(self, other):
+        z = Const(0, len(self.partpoints)+1)
+        result, _ = self.ls_op(self, other, carry=z, shr_flag=1) # TODO, carry
+        return result
+
+    def __rrshift__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
+        raise NotImplementedError
+        return Operator(">>", [other, self])
+
+    # binary ops that don't require partitioning
+
+    def __and__(self, other):
+        return applyop(self, other, and_)
+
+    def __rand__(self, other):
+        return applyop(other, self, and_)
+
+    def __or__(self, other):
+        return applyop(self, other, or_)
+
+    def __ror__(self, other):
+        return applyop(other, self, or_)
+
+    def __xor__(self, other):
+        return applyop(self, other, xor)
+
+    def __rxor__(self, other):
+        return applyop(other, self, xor)
+
     # binary comparison ops that need partitioning
 
     def _compare(self, width, op1, op2, opname, optype):
@@ -303,24 +334,45 @@ class PartitionedSignal(UserValue):
         self.m.d.comb += ne.eq(~eq)
         return ne
 
-    def __gt__(self, other):
-        width = len(self.sig)
-        return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT)
-
     def __lt__(self, other):
         width = len(self.sig)
         # swap operands, use gt to do lt
         return self._compare(width, other, self, "gt", PartitionedEqGtGe.GT)
 
-    def __ge__(self, other):
-        width = len(self.sig)
-        return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE)
-
     def __le__(self, other):
         width = len(self.sig)
         # swap operands, use ge to do le
         return self._compare(width, other, self, "ge", PartitionedEqGtGe.GE)
 
+    def __gt__(self, other):
+        width = len(self.sig)
+        return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT)
+
+    def __ge__(self, other):
+        width = len(self.sig)
+        return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE)
+
+    # no override needed: Value.__abs__ is general enough it does the job
+    # def __abs__(self):
+
+    def __len__(self):
+        return len(self.sig)
+
+    # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=716
+    # def __getitem__(self, key):
+
+    def __new_sign(self, signed):
+        shape = Shape(len(self), signed=signed)
+        result = PartitionedSignal.like(self, shape=shape)
+        self.m.d.comb += result.sig.eq(self.sig)
+        return result
+
+    # http://bugs.libre-riscv.org/show_bug.cgi?id=719
+    def as_unsigned(self):
+        return self.__new_sign(False)
+    def as_signed(self):
+        return self.__new_sign(True)
+
     # useful operators
 
     def bool(self):
@@ -380,14 +432,16 @@ class PartitionedSignal(UserValue):
         self.m.d.comb += pa.a.eq(self.sig)
         return pa.output
 
-    def implies(premise, conclusion):
-        """Implication.
+    # not needed: Value.implies does the job
+    # def implies(premise, conclusion):
 
-        Returns
-        -------
-        Value, out
-            ``0`` if ``premise`` is true and ``conclusion`` is not,
-            ``1`` otherwise.
-        """
-        # amazingly, this should actually work.
-        return ~premise | conclusion
+    # TODO. contains a Value.cast which means an override is needed (on both)
+    # def bit_select(self, offset, width):
+    # def word_select(self, offset, width):
+
+    # not needed: Value.matches, amazingly, should do the job
+    # def matches(self, *patterns):
+
+    # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=713
+    def shape(self):
+        return self.sig.shape()