From 8baefc384ebc08fb5fcfb7e9ab7e80007d32fa6c Mon Sep 17 00:00:00 2001 From: Michael Nolan Date: Fri, 14 Feb 2020 10:28:29 -0500 Subject: [PATCH] Refactor part_shift_dynamic.py This removes the matrix of partial results by instead using the partition bits to calculate the shifter (b) for each partition, and computing a short table of partial results from that --- src/ieee754/part_shift/part_shift_dynamic.py | 116 ++++++++---------- .../part_shift/test/test_shift_dynamic.py | 5 + 2 files changed, 57 insertions(+), 64 deletions(-) diff --git a/src/ieee754/part_shift/part_shift_dynamic.py b/src/ieee754/part_shift/part_shift_dynamic.py index 41bf16e3..ec4893b4 100644 --- a/src/ieee754/part_shift/part_shift_dynamic.py +++ b/src/ieee754/part_shift/part_shift_dynamic.py @@ -38,90 +38,78 @@ class PartitionedDynamicShift(Elaboratable): keys = list(self.partition_points.keys()) + [self.width] start = 0 - # create a matrix of partial shift-results (similar to PartitionedMul - # matrices). These however have to be of length suitable to contain - # the full shifted "contribution". i.e. B from the LSB *could* contain - # a number great enough to shift the entirety of A LSB right up to - # the MSB of the output, however B from the *MSB* is *only* going - # to contribute to the *MSB* of the output. - for i in range(len(keys)): - row = [] - start = 0 - for j in range(len(keys)): - end = keys[j] - row.append(Signal(width - start, - name="matrix[%d][%d]" % (i, j))) - start = end - matrix.append(row) # break out both the input and output into partition-stratified blocks a_intervals = [] b_intervals = [] out_intervals = [] intervals = [] + widths = [] start = 0 for i in range(len(keys)): end = keys[i] + widths.append(width - start) a_intervals.append(self.a[start:end]) b_intervals.append(self.b[start:end]) out_intervals.append(self.output[start:end]) intervals.append([start,end]) start = end - # actually calculate the shift-partials here - for i, b in enumerate(b_intervals): - start = 0 - for j in range(i, len(a_intervals)): - a = a_intervals[j] - end = keys[i] - result_width = matrix[i][j].width - rw = math.ceil(math.log2(result_width + 1)) - # XXX! - bw = math.ceil(math.log2(self.output.width + 1)) - tshift = Signal(bw, name="ts%d_%d" % (i, j), reset_less=True) - ow = math.ceil(math.log2(width-start)) - maxshift = (1<<(ow)) - print ("part", i, b, j, a, rw, bw, ow, maxshift) - with m.If(b[:bw] < maxshift): - comb += tshift.eq(b[:bw]) - with m.Else(): - comb += tshift.eq(maxshift) - comb += matrix[i][j].eq(a << tshift) - start = end - - # now create a switch statement which sums the relevant partial results - # in each output-partition - - out = [] - intermed = matrix[0][0] - s, e = intervals[0] - out.append(intermed[s:e]) + # Instead of generating the matrix described in the wiki, I + # instead calculate the shift amounts for each partition, then + # calculate the partial results of each partition << shift + # amount. On the wiki, the following table is given for output #3: + # p2p1p0 | o3 + # 0 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0] + # 0 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0] + # 0 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0] + # 0 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0] + # 1 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0] + # 1 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0] + # 1 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0] + # 1 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0] + + # Each output for o3 is given by a3bx and the partial results + # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the + # partial results [a0b0, a1bx, a2bx, a3bx], I can use just + # those partial results to calculate a0, a1, a2, and a3 + partial_results = [] + partial_results.append(a_intervals[0] << b_intervals[0]) for i in range(1, len(out_intervals)): s, e = intervals[i] - index = gates[:i] # selects the 'i' least significant bits - # of gates - element = matrix[0][i] + + # This calculates which partition of b to select the + # shifter from. According to the table above, the + # partition to select is given by the highest set bit in + # the partition mask, this calculates that with a mux + # chain + element = b_intervals[0] for index in range(i): - element = Mux(gates[index], matrix[index+1][i], element) - print(keys[i-1]) - temp = Signal(matrix[0][i].width, name="intermed%d" % i) - print(intermed[keys[0]:]) - # XXX bit of a mess here, but rather than select - # element or (element | intermed), select between 0 or intermed - # then unconditionally "|" element on top (once copied into - # a named Signal) - # XXX TODO: hmmm rather than pass down the actual intermed - # here, why not accumulate a cascade of "do we need to include - # this partial result" things, *then* OR them together? - # this is where it sort-of becomes like the gt_combiner - intermed = Mux(gates[i-1], 0, intermed[keys[0]:]) - intermed2 = Signal(intermed.shape()) - comb += intermed2.eq(intermed | element) - intermed = intermed2 - comb += temp.eq(intermed) - out.append(temp[:e-s]) + element = Mux(gates[index], b_intervals[index+1], element) + + # This computes the partial results table + shifter = Signal(8) + comb += shifter.eq(element) + partial = Signal(width, name="partial%d" % i) + comb += partial.eq(a_intervals[i] << shifter) + + partial_results.append(partial) + + out = [] + + # This calculates the outputs o0-o3 from the partial results + # table above. + for i in range(len(out_intervals)): + result = 0 + for j in range(i): + s,e = intervals[i-j] + result = Mux(gates[j], 0, result | partial_results[j][s:e]) + result = partial_results[i] | result + s,e = intervals[0] + out.append(result[s:e]) comb += self.output.eq(Cat(*out)) + return m diff --git a/src/ieee754/part_shift/test/test_shift_dynamic.py b/src/ieee754/part_shift/test/test_shift_dynamic.py index 68579914..86db976c 100644 --- a/src/ieee754/part_shift/test/test_shift_dynamic.py +++ b/src/ieee754/part_shift/test/test_shift_dynamic.py @@ -45,6 +45,11 @@ class DynamicShiftTestCase(FHDLTestCase): def process(): yield a.eq(0x01010101) yield b.eq(0x04030201) + for i in range(1<<(mwidth-1)): + yield gates.eq(i) + yield Delay(1e-6) + yield Settle() + yield b.eq(0x0c0b0a09) for i in range(1<<(mwidth-1)): yield gates.eq(i) yield Delay(1e-6) -- 2.30.2