clip shift amount
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
index 9e214d0f41b974533f994e1008fdc7bf79d04e36..ccfc5d04d8f44307cf0e562d403d073d2bbdd15d 100644 (file)
@@ -97,7 +97,9 @@ class PartitionedDynamicShift(Elaboratable):
         # those partial results to calculate a0, a1, a2, and a3
         element = b_intervals[0] & shifter_masks[0]
         partial_results = []
-        partial_results.append(a_intervals[0] << element)
+        partial = Signal(width, name="partial0", reset_less=True)
+        comb += partial.eq(a_intervals[0] << element)
+        partial_results.append(partial)
         for i in range(1, len(keys)):
             reswid = width - intervals[i][0]
             shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
@@ -118,15 +120,18 @@ class PartitionedDynamicShift(Elaboratable):
             # the partition mask, this calculates that with a mux
             # chain
 
-            # This computes the partial results table
+            # This computes the partial results table.  Note that
+            # the shift amount is truncated because there's no point
+            # trying to shift data by 64 bits if the result width
+            # is only 8.
             shifter = Signal(shiftbits, name="shifter%d" % i,
                           reset_less=True)
-            #with m.If(element > shiftbits):
-            #    comb += shifter.eq(shiftbits)
-            #with m.Else():
-            #    comb += shifter.eq(element)
+            with m.If(element > shiftbits):
+                comb += shifter.eq(shiftbits)
+            with m.Else():
+                comb += shifter.eq(element)
             comb += shifter.eq(element)
-            partial = Signal(width, name="partial%d" % i, reset_less=True)
+            partial = Signal(reswid, name="partial%d" % i, reset_less=True)
             comb += partial.eq(a_intervals[i] << shifter)
 
             partial_results.append(partial)
@@ -134,17 +139,19 @@ class PartitionedDynamicShift(Elaboratable):
         out = []
 
         # This calculates the outputs o0-o3 from the partial results
-        # table above.
+        # table above.  Note: only the bits of each partial result that fit
+        # the width of the output column are accumulated in a Mux-cascade.
         s,e = intervals[0]
         result = partial_results[0]
         out.append(result[s:e])
         for i in range(1, len(keys)):
             start, end = (intervals[i][0], width)
-            result = partial_results[i] | \
-                Mux(gates[i-1], 0, result[intervals[0][1]:])[:end-start]
+            reswid = width - start
+            sel = Mux(gates[i-1], 0, result[intervals[0][1]:][:end-start])
             print("select: [%d:%d]" % (start, end))
-            res = Signal(width, name="res%d" % i, reset_less=True)
-            comb += res.eq(result)
+            res = Signal(end-start+1, name="res%d" % i, reset_less=True)
+            comb += res.eq(partial_results[i] | sel)
+            result = res
             s,e = intervals[0]
             out.append(res[s:e])