add carry (not zeros, a Const of correct width)
[ieee754fpu.git] / src / ieee754 / part / partsig.py
index 0ab603def3c8103062c8d33109013eadf2443ff4..ea1fa3a6572ad64089fd2fdec77fcf29a90d571c 100644 (file)
@@ -19,10 +19,11 @@ nmigen.Case, or other constructs: only Mux and other logic.
 from ieee754.part_mul_add.adder import PartitionedAdder
 from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
+from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
 from ieee754.part_mul_add.partpoints import make_partition
 from operator import or_, xor, and_, not_
 
-from nmigen import (Signal)
+from nmigen import (Signal, Const)
 
 
 def getsig(op1):
@@ -50,7 +51,7 @@ class PartitionedSignal:
 
     def get_modname(self, category):
         self.modnames[category] += 1
-        return "%s%d" % (category, self.modnames[category])
+        return "%s_%d" % (category, self.modnames[category])
 
     def eq(self, val):
         return self.sig.eq(getsig(val))
@@ -63,7 +64,8 @@ class PartitionedSignal:
     # unary ops that require partitioning
 
     def __neg__(self):
-        result, _ = self.add_op(self, ~0, carry=0)  # TODO, subop
+        z = Const(0, self.sig.shape())
+        result, _ = self.add_op(self, ~0, carry=z)  # TODO, subop
         return result
 
     # binary ops that don't require partitioning
@@ -93,19 +95,30 @@ class PartitionedSignal:
     # can occur.  if it's a PartitionedSignal, that's much more interesting.
     def ls_op(self, op1, op2, carry):
         op1 = getsig(op1)
-        op2 = getsig(op2)
-        shape = op1.shape()
-        pa = PartitionedDynamicShift(shape[0], self.partpoints)
+        if isinstance(op2, Const) or isinstance(op2, Signal):
+            scalar = True
+            shape = op1.shape()
+            pa = PartitionedScalarShift(shape[0], self.partpoints)
+        else:
+            scalar = False
+            op2 = getsig(op2)
+            shape = op1.shape()
+            pa = PartitionedDynamicShift(shape[0], self.partpoints)
         setattr(self.m.submodules, self.get_modname('ls'), pa)
         comb = self.m.d.comb
-        comb += pa.a.eq(op1)
-        comb += pa.b.eq(op2)
+        if scalar:
+            comb += pa.data.eq(op1)
+            comb += pa.shifter.eq(op2)
+        else:
+            comb += pa.a.eq(op1)
+            comb += pa.b.eq(op2)
         # XXX TODO: carry-in, carry-out
         #comb += pa.carry_in.eq(carry)
         return (pa.output, 0)
 
     def __lshift__(self, other):
-        result, _ = self.ls_op(self, other, carry=0)
+        z = Const(0, self.sig.shape())
+        result, _ = self.ls_op(self, other, carry=z) # TODO, carry
         return result
 
     def __rlshift__(self, other):