split out PartialResults to separate module
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 15 Feb 2020 15:39:02 +0000 (15:39 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 15 Feb 2020 15:39:02 +0000 (15:39 +0000)
src/ieee754/part_shift/part_shift_dynamic.py

index ee24dc7db89c0999b99757de0122c081bbcd6e1d..f2394d5a279c810f8ac41cc3ea84d78e50787a83 100644 (file)
@@ -45,6 +45,49 @@ class ShifterMask(Elaboratable):
         return m
 
 
+class PartialResult(Elaboratable):
+    def __init__(self, pwid, bwid, reswid):
+        self.pwid = pwid
+        self.bwid = bwid
+        self.reswid = reswid
+        self.element = Signal(bwid, reset_less=True)
+        self.elmux = Signal(bwid, reset_less=True)
+        self.a_interval = Signal(bwid, reset_less=True)
+        self.masked = Signal(bwid, reset_less=True)
+        self.gate = Signal(reset_less=True)
+        self.partial = Signal(reswid, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+
+        shiftbits = math.ceil(math.log2(self.reswid+1))+1 # hmmm...
+        print ("partial", self.reswid, self.pwid, shiftbits)
+        element = Mux(self.gate, self.masked, self.element)
+        comb += self.elmux.eq(element)
+        element = self.elmux
+
+        # This calculates which partition of b to select the
+        # shifter from. According to the table above, the
+        # partition to select is given by the highest set bit in
+        # the partition mask, this calculates that with a mux
+        # chain
+
+        # This computes the partial results table.  note that
+        # the shift amount is truncated because there's no point
+        # trying to shift data by 64 bits if the result width
+        # is only 8.
+        shifter = Signal(shiftbits, reset_less=True)
+        with m.If(element > shiftbits):
+            comb += shifter.eq(shiftbits)
+        with m.Else():
+            comb += shifter.eq(element)
+        comb += shifter.eq(element)
+        comb += self.partial.eq(self.a_interval << shifter)
+
+        return m
+
+
 class PartitionedDynamicShift(Elaboratable):
     def __init__(self, width, partition_points):
         self.width = width
@@ -124,36 +167,16 @@ class PartitionedDynamicShift(Elaboratable):
             shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
             print ("partial", reswid, width, intervals[i], shiftbits)
             s, e = intervals[i]
+            pr = PartialResult(pwid, b_intervals[i].shape()[0], reswid)
+            setattr(m.submodules, "pr%d" % i, pr)
             masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
                           reset_less=True)
-            comb += masked.eq(b_intervals[i] & shifter_masks[i])
-            element = Mux(gates[i-1], masked, element)
-            elmux = Signal(b_intervals[i].shape(), name="elmux%d" % i,
-                          reset_less=True)
-            comb += elmux.eq(element)
-            element = elmux
-
-            # This calculates which partition of b to select the
-            # shifter from. According to the table above, the
-            # partition to select is given by the highest set bit in
-            # the partition mask, this calculates that with a mux
-            # chain
-
-            # This computes the partial results table.  note that
-            # the shift amount is truncated because there's no point
-            # trying to shift data by 64 bits if the result width
-            # is only 8.
-            shifter = Signal(shiftbits, name="shifter%d" % i,
-                          reset_less=True)
-            with m.If(element > shiftbits):
-                comb += shifter.eq(shiftbits)
-            with m.Else():
-                comb += shifter.eq(element)
-            comb += shifter.eq(element)
-            partial = Signal(reswid, name="partial%d" % i, reset_less=True)
-            comb += partial.eq(a_intervals[i] << shifter)
-
-            partial_results.append(partial)
+            comb += pr.masked.eq(b_intervals[i] & shifter_masks[i])
+            comb += pr.gate.eq(gates[i-1])
+            comb += pr.element.eq(element)
+            comb += pr.a_interval.eq(a_intervals[i])
+            partial_results.append(pr.partial)
+            element = pr.elmux
 
         out = []