disable mul and rmul in PartitionedSignal for now
[ieee754fpu.git] / src / ieee754 / part / partsig.py
index 35487330ba342a3567d23a306b09525c00b5dc23..8e5154add74e50e82c423e388a0a64f559064465 100644 (file)
@@ -19,10 +19,13 @@ nmigen.Case, or other constructs: only Mux and other logic.
 from ieee754.part_mul_add.adder import PartitionedAdder
 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
 from ieee754.part_bits.xor import PartitionedXOR
+from ieee754.part_bits.bool import PartitionedBool
+from ieee754.part_bits.all import PartitionedAll
 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
 from ieee754.part_mul_add.partpoints import make_partition2, PartitionPoints
 from ieee754.part_mux.part_mux import PMux
+from ieee754.part_ass.passign import PAssign
 from ieee754.part_cat.pcat import PCat
 from operator import or_, xor, and_, not_
 
@@ -48,7 +51,7 @@ global modnames
 modnames = {}
 # for sub-modules to be created on-demand. Mux is done slightly
 # differently (has its own global)
-for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor']:
+for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor', 'bool', 'all']:
     modnames[name] = 0
 
 
@@ -63,8 +66,6 @@ class PartitionedSignal(UserValue):
         else:
             self.partpoints = make_partition2(mask, width)
 
-    def lower(self):
-        return self.sig
 
     def set_module(self, m):
         self.m = m
@@ -73,9 +74,6 @@ class PartitionedSignal(UserValue):
         modnames[category] += 1
         return "%s_%d" % (category, modnames[category])
 
-    def eq(self, val):
-        return self.sig.eq(getsig(val))
-
     @staticmethod
     def like(other, *args, **kwargs):
         """Builds a new PartitionedSignal with the same PartitionPoints and
@@ -87,6 +85,13 @@ class PartitionedSignal(UserValue):
 
     def __len__(self):
         return len(self.sig)
+    def shape(self):
+        return self.sig.shape()
+    def lower(self):
+        return self.sig
+    # now using __Assign__
+    #def eq(self, val):
+    #    return self.sig.eq(getsig(val))
 
     # nmigen-redirected constructs (Mux, Cat, Switch, Assign)
 
@@ -97,6 +102,10 @@ class PartitionedSignal(UserValue):
             "val1 == %d, val2 == %d" % (len(val1), len(val2))
         return PMux(self.m, self.partpoints, self, val1, val2)
 
+    def __Assign__(self, val, *, src_loc_at=0):
+        # print ("partsig ass", self, val)
+        return PAssign(self.m, self, val, self.partpoints)
+
     def __Cat__(self, *args, src_loc_at=0):
         args = [self] + list(args)
         for sig in args:
@@ -153,6 +162,10 @@ class PartitionedSignal(UserValue):
             scalar = False
             op2 = getsig(op2)
             pa = PartitionedDynamicShift(len(op1), self.partpoints)
+        # else:
+        #   TODO: case where the *shifter* is a PartitionedSignal but
+        #   the thing *being* Shifted is a scalar (Signal, expression)
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
         setattr(self.m.submodules, self.get_modname('ls'), pa)
         comb = self.m.d.comb
         if scalar:
@@ -163,7 +176,7 @@ class PartitionedSignal(UserValue):
             comb += pa.a.eq(op1)
             comb += pa.b.eq(op2)
             comb += pa.shift_right.eq(shr_flag)
-        # XXX TODO: carry-in, carry-out
+        # XXX TODO: carry-in, carry-out (for arithmetic shift)
         #comb += pa.carry_in.eq(carry)
         return (pa.output, 0)
 
@@ -173,6 +186,7 @@ class PartitionedSignal(UserValue):
         return result
 
     def __rlshift__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
         raise NotImplementedError
         return Operator("<<", [other, self])
 
@@ -182,6 +196,7 @@ class PartitionedSignal(UserValue):
         return result
 
     def __rrshift__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
         raise NotImplementedError
         return Operator(">>", [other, self])
 
@@ -216,6 +231,7 @@ class PartitionedSignal(UserValue):
         return result
 
     def __radd__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
         result, _ = self.add_op(other, self)
         return result
 
@@ -224,13 +240,16 @@ class PartitionedSignal(UserValue):
         return result
 
     def __rsub__(self, other):
+        #   https://bugs.libre-soc.org/show_bug.cgi?id=718
         result, _ = self.sub_op(other, self)
         return result
 
     def __mul__(self, other):
+        raise NotImplementedError # too complicated at the moment
         return Operator("*", [self, other])
 
     def __rmul__(self, other):
+        raise NotImplementedError # too complicated at the moment
         return Operator("*", [other, self])
 
     def __check_divisor(self):
@@ -322,8 +341,11 @@ class PartitionedSignal(UserValue):
         Value, out
             ``1`` if any bits are set, ``0`` otherwise.
         """
-        return self.any() # have to see how this goes
-        #return Operator("b", [self])
+        width = len(self.sig)
+        pa = PartitionedBool(width, self.partpoints)
+        setattr(self.m.submodules, self.get_modname("bool"), pa)
+        self.m.d.comb += pa.a.eq(self.sig)
+        return pa.output
 
     def any(self):
         """Check if any bits are ``1``.
@@ -344,6 +366,13 @@ class PartitionedSignal(UserValue):
         Value, out
             ``1`` if all bits are set, ``0`` otherwise.
         """
+        # something wrong with PartitionedAll, but self == Const(-1)"
+        # XXX https://bugs.libre-soc.org/show_bug.cgi?id=176#c17
+        #width = len(self.sig)
+        #pa = PartitionedAll(width, self.partpoints)
+        #setattr(self.m.submodules, self.get_modname("all"), pa)
+        #self.m.d.comb += pa.a.eq(self.sig)
+        #return pa.output
         return self == Const(-1) # leverage the __eq__ operator here
 
     def xor(self):