truncate partial result intermediate to width of partition
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
index 41bf16e3500eb47c65451aa5164ec657db89aa61..75e3fe4951c7c21549f80fe94096e64f89dce954 100644 (file)
@@ -23,103 +23,133 @@ class PartitionedDynamicShift(Elaboratable):
         self.width = width
         self.partition_points = PartitionPoints(partition_points)
 
-        self.a = Signal(width)
-        self.b = Signal(width)
-        self.output = Signal(width)
+        self.a = Signal(width, reset_less=True)
+        self.b = Signal(width, reset_less=True)
+        self.output = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
         width = self.width
-        gates = Signal(self.partition_points.get_max_partition_count(width)-1)
+        pwid = self.partition_points.get_max_partition_count(width)-1
+        gates = Signal(pwid, reset_less=True)
         comb += gates.eq(self.partition_points.as_sig())
 
         matrix = []
         keys = list(self.partition_points.keys()) + [self.width]
         start = 0
 
-        # create a matrix of partial shift-results (similar to PartitionedMul
-        # matrices).  These however have to be of length suitable to contain
-        # the full shifted "contribution".  i.e. B from the LSB *could* contain
-        # a number great enough to shift the entirety of A LSB right up to
-        # the MSB of the output, however B from the *MSB* is *only* going
-        # to contribute to the *MSB* of the output.
-        for i in range(len(keys)):
-            row = []
-            start = 0
-            for j in range(len(keys)):
-                end = keys[j]
-                row.append(Signal(width - start,
-                           name="matrix[%d][%d]" % (i, j)))
-                start = end
-            matrix.append(row)
-
         # break out both the input and output into partition-stratified blocks
         a_intervals = []
         b_intervals = []
-        out_intervals = []
         intervals = []
+        widths = []
         start = 0
         for i in range(len(keys)):
             end = keys[i]
+            widths.append(width - start)
             a_intervals.append(self.a[start:end])
             b_intervals.append(self.b[start:end])
-            out_intervals.append(self.output[start:end])
             intervals.append([start,end])
             start = end
 
-        # actually calculate the shift-partials here
-        for i, b in enumerate(b_intervals):
-            start = 0
-            for j in range(i, len(a_intervals)):
-                a = a_intervals[j]
-                end = keys[i]
-                result_width = matrix[i][j].width
-                rw = math.ceil(math.log2(result_width + 1))
-                # XXX!
-                bw = math.ceil(math.log2(self.output.width + 1))
-                tshift = Signal(bw, name="ts%d_%d" % (i, j), reset_less=True)
-                ow = math.ceil(math.log2(width-start))
-                maxshift = (1<<(ow))
-                print ("part", i, b, j, a, rw, bw, ow, maxshift)
-                with m.If(b[:bw] < maxshift):
-                    comb += tshift.eq(b[:bw])
-                with m.Else():
-                    comb += tshift.eq(maxshift)
-                comb += matrix[i][j].eq(a << tshift)
-                start = end
-
-        # now create a switch statement which sums the relevant partial results
-        # in each output-partition
+        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
+        max_bits = math.ceil(math.log2(width))
+
+        # shifts are normally done as (e.g. for 32 bit)
+        # result = a << (b & 0b11111), truncating the b input.  however here
+        # of course the size of the partition varies dynamically.
+        shifter_masks = []
+        for i in range(len(b_intervals)):
+            mask = Signal(b_intervals[i].shape(), name="shift_mask%d" % i,
+                          reset_less=True)
+            bits = Signal(gates.width-i+1, name="bits%d" % i, reset_less=True)
+            bl = []
+            for j in range(i, gates.width):
+                if bl:
+                    bl.append(~gates[j] & bits[j-i-1])
+                else:
+                    bl.append(~gates[j])
+            comb += bits.eq(Cat(*bl))
+            comb += mask.eq(Cat((1 << min_bits)-1, bits)
+                            & ((1 << max_bits)-1))
+            shifter_masks.append(mask)
+
+        print(shifter_masks)
+
+        # Rather than generating the matrix described in the wiki, this
+        # calculates the shift amounts for each partition, then
+        # calculates the partial results of each partition << shift
+        # amount. On the wiki, the following table is given for output #3:
+        # p2p1p0 | o3
+        # 0 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
+        # 0 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
+        # 0 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
+        # 0 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
+        # 1 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
+        # 1 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
+        # 1 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
+        # 1 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
+
+        # Each output for o3 is given by a3bx and the partial results
+        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
+        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
+        # those partial results to calculate o0, o1, o2, and o3
+        element = b_intervals[0] & shifter_masks[0]
+        partial_results = []
+        partial = Signal(width, name="partial0", reset_less=True)
+        comb += partial.eq(a_intervals[0] << element)
+        partial_results.append(partial)
+        for i in range(1, len(keys)):
+            reswid = width - intervals[i][0]
+            shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
+            print ("partial", reswid, width, intervals[i], shiftbits)
+            s, e = intervals[i]
+            masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
+                          reset_less=True)
+            comb += masked.eq(b_intervals[i] & shifter_masks[i])
+            element = Mux(gates[i-1], masked, element)
+            elmux = Signal(b_intervals[i].shape(), name="elmux%d" % i,
+                          reset_less=True)
+            comb += elmux.eq(element)
+            element = elmux
+
+            # This calculates which partition of b to select the
+            # shifter from. According to the table above, the
+            # partition to select is given by the highest set bit in
+            # the partition mask; the mux chain on "element" computes
+            # that selection.
+
+            # This computes the partial results table
+            shifter = Signal(shiftbits, name="shifter%d" % i,
+                          reset_less=True)
+            #with m.If(element > shiftbits):
+            #    comb += shifter.eq(shiftbits)
+            #with m.Else():
+            #    comb += shifter.eq(element)
+            comb += shifter.eq(element)
+            partial = Signal(reswid, name="partial%d" % i, reset_less=True)
+            comb += partial.eq(a_intervals[i] << shifter)
+
+            partial_results.append(partial)
 
         out = []
-        intermed = matrix[0][0]
-        s, e = intervals[0]
-        out.append(intermed[s:e])
-        for i in range(1, len(out_intervals)):
-            s, e = intervals[i]
-            index = gates[:i]  # selects the 'i' least significant bits
-                               # of gates
-            element = matrix[0][i]
-            for index in range(i):
-                element = Mux(gates[index], matrix[index+1][i], element)
-            print(keys[i-1])
-            temp = Signal(matrix[0][i].width, name="intermed%d" % i)
-            print(intermed[keys[0]:])
-            # XXX bit of a mess here, but rather than select
-            # element or (element | intermed), select between 0 or intermed
-            # then unconditionally "|" element on top (once copied into
-            # a named Signal)
-            # XXX TODO: hmmm rather than pass down the actual intermed
-            # here, why not accumulate a cascade of "do we need to include
-            # this partial result" things, *then* OR them together?
-            # this is where it sort-of becomes like the gt_combiner
-            intermed = Mux(gates[i-1], 0, intermed[keys[0]:])
-            intermed2 = Signal(intermed.shape())
-            comb += intermed2.eq(intermed | element)
-            intermed = intermed2
-            comb += temp.eq(intermed)
-            out.append(temp[:e-s])
+
+        # This calculates the outputs o0-o3 from the partial results
+        # table above.
+        s,e = intervals[0]
+        result = partial_results[0]
+        out.append(result[s:e])
+        for i in range(1, len(keys)):
+            start, end = (intervals[i][0], width)
+            reswid = width - start
+            sel = Mux(gates[i-1], 0, result[intervals[0][1]:][:end-start])
+            print("select: [%d:%d]" % (start, end))
+            res = Signal(end-start+1, name="res%d" % i, reset_less=True)
+            comb += res.eq(partial_results[i] | sel)
+            result = res
+            s,e = intervals[0]
+            out.append(res[s:e])
 
         comb += self.output.eq(Cat(*out))