Remove element mux calculation from PartialResult
authorMichael Nolan <mtnolan2640@gmail.com>
Wed, 26 Feb 2020 16:09:20 +0000 (11:09 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Wed, 26 Feb 2020 16:09:20 +0000 (11:09 -0500)
src/ieee754/part_shift/part_shift_dynamic.py

index 230455092914d4a7ce82aca275463cb0caf94a5e..788b400a465bd31ee23a089aef60998e9bd2fe47 100644 (file)
@@ -63,10 +63,8 @@ class PartialResult(Elaboratable):
         self.pwid = pwid
         self.bwid = bwid
         self.reswid = reswid
-        self.element = Signal(bwid, reset_less=True)
-        self.elmux = Signal(bwid, reset_less=True)
+        self.b = Signal(bwid, reset_less=True)
         self.a_interval = Signal(bwid, reset_less=True)
-        self.masked = Signal(bwid, reset_less=True)
         self.gate = Signal(reset_less=True)
         self.partial = Signal(reswid, reset_less=True)
 
@@ -76,9 +74,7 @@ class PartialResult(Elaboratable):
 
         shiftbits = math.ceil(math.log2(self.reswid+1))+1 # hmmm...
         print ("partial", self.reswid, self.pwid, shiftbits)
-        element = Mux(self.gate, self.masked, self.element)
-        comb += self.elmux.eq(element)
-        element = self.elmux
+        element = self.b
 
         # This calculates which partition of b to select the
         # shifter from. According to the table above, the
@@ -189,27 +185,32 @@ class PartitionedDynamicShift(Elaboratable):
         # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
         # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
         # those partial results to calculate a0, a1, a2, and a3
+
+        masked_b = []
+        for i in range(0, len(keys)):
+            masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
+                          reset_less=True)
+            comb += masked.eq(b_intervals[i] & shifter_masks[i])
+            masked_b.append(masked)
+
         element = Signal(b_intervals[0].shape(), reset_less=True)
-        comb += element.eq(b_intervals[0] & shifter_masks[0])
+        comb += element.eq(masked_b[0])
         partial_results = []
         partial = Signal(width, name="partial0", reset_less=True)
         comb += partial.eq(a_intervals[0] << element)
         partial_results.append(partial)
         for i in range(1, len(keys)):
+            element = Mux(gate_br.output[i-1], masked_b[i], element)
             reswid = width - intervals[i][0]
             shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
             print ("partial", reswid, width, intervals[i], shiftbits)
             s, e = intervals[i]
             pr = PartialResult(pwid, b_intervals[i].shape()[0], reswid)
             setattr(m.submodules, "pr%d" % i, pr)
-            masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
-                          reset_less=True)
-            comb += pr.masked.eq(b_intervals[i] & shifter_masks[i])
             comb += pr.gate.eq(gate_br.output[i-1])
-            comb += pr.element.eq(element)
+            comb += pr.b.eq(element)
             comb += pr.a_interval.eq(a_intervals[i])
             partial_results.append(pr.partial)
-            element = pr.elmux
 
         out = []