Refactor part_shift_dynamic.py
authorMichael Nolan <mtnolan2640@gmail.com>
Fri, 14 Feb 2020 15:28:29 +0000 (10:28 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Fri, 14 Feb 2020 15:28:29 +0000 (10:28 -0500)
This removes the matrix of partial results by instead using the
partition bits to calculate the shifter (b) for each partition, and
computing a short table of partial results from that

src/ieee754/part_shift/part_shift_dynamic.py
src/ieee754/part_shift/test/test_shift_dynamic.py

index 41bf16e3500eb47c65451aa5164ec657db89aa61..ec4893b48557306b4791868350b941326c36730d 100644 (file)
@@ -38,90 +38,78 @@ class PartitionedDynamicShift(Elaboratable):
         keys = list(self.partition_points.keys()) + [self.width]
         start = 0
 
-        # create a matrix of partial shift-results (similar to PartitionedMul
-        # matrices).  These however have to be of length suitable to contain
-        # the full shifted "contribution".  i.e. B from the LSB *could* contain
-        # a number great enough to shift the entirety of A LSB right up to
-        # the MSB of the output, however B from the *MSB* is *only* going
-        # to contribute to the *MSB* of the output.
-        for i in range(len(keys)):
-            row = []
-            start = 0
-            for j in range(len(keys)):
-                end = keys[j]
-                row.append(Signal(width - start,
-                           name="matrix[%d][%d]" % (i, j)))
-                start = end
-            matrix.append(row)
 
         # break out both the input and output into partition-stratified blocks
         a_intervals = []
         b_intervals = []
         out_intervals = []
         intervals = []
+        widths = []
         start = 0
         for i in range(len(keys)):
             end = keys[i]
+            widths.append(width - start)
             a_intervals.append(self.a[start:end])
             b_intervals.append(self.b[start:end])
             out_intervals.append(self.output[start:end])
             intervals.append([start,end])
             start = end
 
-        # actually calculate the shift-partials here
-        for i, b in enumerate(b_intervals):
-            start = 0
-            for j in range(i, len(a_intervals)):
-                a = a_intervals[j]
-                end = keys[i]
-                result_width = matrix[i][j].width
-                rw = math.ceil(math.log2(result_width + 1))
-                # XXX!
-                bw = math.ceil(math.log2(self.output.width + 1))
-                tshift = Signal(bw, name="ts%d_%d" % (i, j), reset_less=True)
-                ow = math.ceil(math.log2(width-start))
-                maxshift = (1<<(ow))
-                print ("part", i, b, j, a, rw, bw, ow, maxshift)
-                with m.If(b[:bw] < maxshift):
-                    comb += tshift.eq(b[:bw])
-                with m.Else():
-                    comb += tshift.eq(maxshift)
-                comb += matrix[i][j].eq(a << tshift)
-                start = end
-
-        # now create a switch statement which sums the relevant partial results
-        # in each output-partition
-
-        out = []
-        intermed = matrix[0][0]
-        s, e = intervals[0]
-        out.append(intermed[s:e])
+        # Instead of generating the matrix described in the wiki, I
+        # instead calculate the shift amounts for each partition, then
+        # calculate the partial results of each partition << shift
+        # amount. On the wiki, the following table is given for output #3:
+        # p2p1p0 | o3
+        # 0 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
+        # 0 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
+        # 0 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
+        # 0 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
+        # 1 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
+        # 1 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
+        # 1 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
+        # 1 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
+
+        # Each output for o3 is given by a3bx and the partial results
+        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
+        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
+        # those partial results to calculate a0, a1, a2, and a3
+        partial_results = []
+        partial_results.append(a_intervals[0] << b_intervals[0])
         for i in range(1, len(out_intervals)):
             s, e = intervals[i]
-            index = gates[:i]  # selects the 'i' least significant bits
-                               # of gates
-            element = matrix[0][i]
+
+            # This calculates which partition of b to select the
+            # shifter from. According to the table above, the
+            # partition to select is given by the highest set bit in
+            # the partition mask, this calculates that with a mux
+            # chain
+            element = b_intervals[0]
             for index in range(i):
-                element = Mux(gates[index], matrix[index+1][i], element)
-            print(keys[i-1])
-            temp = Signal(matrix[0][i].width, name="intermed%d" % i)
-            print(intermed[keys[0]:])
-            # XXX bit of a mess here, but rather than select
-            # element or (element | intermed), select between 0 or intermed
-            # then unconditionally "|" element on top (once copied into
-            # a named Signal)
-            # XXX TODO: hmmm rather than pass down the actual intermed
-            # here, why not accumulate a cascade of "do we need to include
-            # this partial result" things, *then* OR them together?
-            # this is where it sort-of becomes like the gt_combiner
-            intermed = Mux(gates[i-1], 0, intermed[keys[0]:])
-            intermed2 = Signal(intermed.shape())
-            comb += intermed2.eq(intermed | element)
-            intermed = intermed2
-            comb += temp.eq(intermed)
-            out.append(temp[:e-s])
+                element = Mux(gates[index], b_intervals[index+1], element)
+
+            # This computes the partial results table
+            shifter = Signal(8)
+            comb += shifter.eq(element)
+            partial = Signal(width, name="partial%d" % i)
+            comb += partial.eq(a_intervals[i] << shifter)
+
+            partial_results.append(partial)
+
+        out = []
+
+        # This calculates the outputs o0-o3 from the partial results
+        # table above.
+        for i in range(len(out_intervals)):
+            result = 0
+            for j in range(i):
+                s,e = intervals[i-j]
+                result = Mux(gates[j], 0, result | partial_results[j][s:e])
+            result = partial_results[i] | result
+            s,e = intervals[0]
+            out.append(result[s:e])
 
         comb += self.output.eq(Cat(*out))
 
+
         return m
 
index 685799141a1c2498b649510c201a6625d8c25179..86db976cac0bb8977fecbc8bac8ad52bf4b516b9 100644 (file)
@@ -45,6 +45,11 @@ class DynamicShiftTestCase(FHDLTestCase):
         def process():
             yield a.eq(0x01010101)
             yield b.eq(0x04030201)
+            for i in range(1<<(mwidth-1)):
+                yield gates.eq(i)
+                yield Delay(1e-6)
+                yield Settle()
+            yield b.eq(0x0c0b0a09)
             for i in range(1<<(mwidth-1)):
                 yield gates.eq(i)
                 yield Delay(1e-6)