keys = list(self.partition_points.keys()) + [self.width]
start = 0
- # create a matrix of partial shift-results (similar to PartitionedMul
- # matrices). These however have to be of length suitable to contain
- # the full shifted "contribution". i.e. B from the LSB *could* contain
- # a number great enough to shift the entirety of A LSB right up to
- # the MSB of the output, however B from the *MSB* is *only* going
- # to contribute to the *MSB* of the output.
- for i in range(len(keys)):
- row = []
- start = 0
- for j in range(len(keys)):
- end = keys[j]
- row.append(Signal(width - start,
- name="matrix[%d][%d]" % (i, j)))
- start = end
- matrix.append(row)
# break out both the input and output into partition-stratified blocks
a_intervals = []
b_intervals = []
out_intervals = []
intervals = []
+ widths = []
start = 0
for i in range(len(keys)):
end = keys[i]
+ widths.append(width - start)
a_intervals.append(self.a[start:end])
b_intervals.append(self.b[start:end])
out_intervals.append(self.output[start:end])
intervals.append([start,end])
start = end
- # actually calculate the shift-partials here
- for i, b in enumerate(b_intervals):
- start = 0
- for j in range(i, len(a_intervals)):
- a = a_intervals[j]
- end = keys[i]
- result_width = matrix[i][j].width
- rw = math.ceil(math.log2(result_width + 1))
- # XXX!
- bw = math.ceil(math.log2(self.output.width + 1))
- tshift = Signal(bw, name="ts%d_%d" % (i, j), reset_less=True)
- ow = math.ceil(math.log2(width-start))
- maxshift = (1<<(ow))
- print ("part", i, b, j, a, rw, bw, ow, maxshift)
- with m.If(b[:bw] < maxshift):
- comb += tshift.eq(b[:bw])
- with m.Else():
- comb += tshift.eq(maxshift)
- comb += matrix[i][j].eq(a << tshift)
- start = end
-
- # now create a switch statement which sums the relevant partial results
- # in each output-partition
-
- out = []
- intermed = matrix[0][0]
- s, e = intervals[0]
- out.append(intermed[s:e])
+ # Instead of generating the matrix described in the wiki, I
+ # instead calculate the shift amounts for each partition, then
+ # calculate the partial results of each partition << shift
+ # amount. On the wiki, the following table is given for output #3:
+ # p2p1p0 | o3
+ # 0 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
+ # 0 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
+ # 0 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
+ # 0 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
+ # 1 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
+ # 1 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
+ # 1 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
+ # 1 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
+
+ # Each output for o3 is given by a3bx and the partial results
+ # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
+ # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
+ # those partial results to calculate a0, a1, a2, and a3
+ partial_results = []
+ partial_results.append(a_intervals[0] << b_intervals[0])
for i in range(1, len(out_intervals)):
s, e = intervals[i]
- index = gates[:i] # selects the 'i' least significant bits
- # of gates
- element = matrix[0][i]
+
+ # This calculates which partition of b to select the
+ # shifter from. According to the table above, the
+ # partition to select is given by the highest set bit in
+ # the partition mask, this calculates that with a mux
+ # chain
+ element = b_intervals[0]
for index in range(i):
- element = Mux(gates[index], matrix[index+1][i], element)
- print(keys[i-1])
- temp = Signal(matrix[0][i].width, name="intermed%d" % i)
- print(intermed[keys[0]:])
- # XXX bit of a mess here, but rather than select
- # element or (element | intermed), select between 0 or intermed
- # then unconditionally "|" element on top (once copied into
- # a named Signal)
- # XXX TODO: hmmm rather than pass down the actual intermed
- # here, why not accumulate a cascade of "do we need to include
- # this partial result" things, *then* OR them together?
- # this is where it sort-of becomes like the gt_combiner
- intermed = Mux(gates[i-1], 0, intermed[keys[0]:])
- intermed2 = Signal(intermed.shape())
- comb += intermed2.eq(intermed | element)
- intermed = intermed2
- comb += temp.eq(intermed)
- out.append(temp[:e-s])
+ element = Mux(gates[index], b_intervals[index+1], element)
+
+ # This computes the partial results table
+ shifter = Signal(8)
+ comb += shifter.eq(element)
+ partial = Signal(width, name="partial%d" % i)
+ comb += partial.eq(a_intervals[i] << shifter)
+
+ partial_results.append(partial)
+
+ out = []
+
+ # This calculates the outputs o0-o3 from the partial results
+ # table above.
+ for i in range(len(out_intervals)):
+ result = 0
+ for j in range(i):
+ s,e = intervals[i-j]
+ result = Mux(gates[j], 0, result | partial_results[j][s:e])
+ result = partial_results[i] | result
+ s,e = intervals[0]
+ out.append(result[s:e])
comb += self.output.eq(Cat(*out))
+
return m