Minor refactoring of part_shift_dynamic
authorMichael Nolan <mtnolan2640@gmail.com>
Fri, 14 Feb 2020 16:23:32 +0000 (11:23 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Fri, 14 Feb 2020 16:23:32 +0000 (11:23 -0500)
src/ieee754/part_shift/part_shift_dynamic.py

index ec4893b48557306b4791868350b941326c36730d..edcb7162ac68f3014acda5b1bf6dcac7d4b41267 100644 (file)
@@ -75,20 +75,20 @@ class PartitionedDynamicShift(Elaboratable):
         # those partial results to calculate a0, a1, a2, and a3
         partial_results = []
         partial_results.append(a_intervals[0] << b_intervals[0])
+        element = b_intervals[0]
         for i in range(1, len(out_intervals)):
             s, e = intervals[i]
+            element = Mux(gates[i-1], b_intervals[i], element)
 
             # This calculates which partition of b to select the
             # shifter from. According to the table above, the
             # partition to select is given by the highest set bit in
             # the partition mask, this calculates that with a mux
             # chain
-            element = b_intervals[0]
-            for index in range(i):
-                element = Mux(gates[index], b_intervals[index+1], element)
+
 
             # This computes the partial results table
-            shifter = Signal(8)
+            shifter = Signal(8, name="shifter%d" % i)
             comb += shifter.eq(element)
             partial = Signal(width, name="partial%d" % i)
             comb += partial.eq(a_intervals[i] << shifter)
@@ -99,17 +99,20 @@ class PartitionedDynamicShift(Elaboratable):
 
         # This calculates the outputs o0-o3 from the partial results
         # table above.
-        for i in range(len(out_intervals)):
-            result = 0
-            for j in range(i):
-                s,e = intervals[i-j]
-                result = Mux(gates[j], 0, result | partial_results[j][s:e])
-            result = partial_results[i] | result
+        s,e = intervals[0]
+        result = partial_results[0]
+        out.append(result[s:e])
+        for i in range(1, len(out_intervals)):
+            start, end = (intervals[i][0], width)
+            result = partial_results[i] | \
+                Mux(gates[i-1], 0, result[intervals[0][1]:])[:end-start]
+            print("select: [%d:%d]" % (start, end))
+            res = Signal(width, name="res%d" % i)
+            comb += res.eq(result)
             s,e = intervals[0]
-            out.append(result[s:e])
+            out.append(res[s:e])
 
         comb += self.output.eq(Cat(*out))
 
-
         return m