# those partial results to calculate a0, a1, a2, and a3
# NOTE(review): this span is a patch fragment ("- " marks a removed line,
# "+ " an added line) and its indentation has been lost; the notes below
# describe only what the visible statements show.
#
# Mask the first b interval. As far as this fragment shows, the result is
# the shift amount for partial0 and is reused unchanged by every later
# iteration of the loop below (element is not reassigned in visible lines).
element = b_intervals[0] & shifter_masks[0]
partial_results = []
- partial_results.append(a_intervals[0] << element)
# The added lines give the first partial result its own named, width-bit
# signal instead of appending the bare shift expression.
+ partial = Signal(width, name="partial0", reset_less=True)
+ comb += partial.eq(a_intervals[0] << element)
+ partial_results.append(partial)
for i in range(1, len(keys)):
# reswid: number of bits from the start of interval i up to the full width.
reswid = width - intervals[i][0]
# shiftbits: bit width given to the per-interval shifter signal below.
shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
# the partition mask, this calculates that with a mux
# chain
- # This computes the partial results table
+ # This computes the partial results table. note that
+ # the shift amount is truncated because there's no point
+ # trying to shift data by 64 bits if the result width
+ # is only 8.
shifter = Signal(shiftbits, name="shifter%d" % i,
reset_less=True)
- #with m.If(element > shiftbits):
- # comb += shifter.eq(shiftbits)
- #with m.Else():
- # comb += shifter.eq(element)
# NOTE(review): shiftbits is the bit width of shifter, yet it is used here
# as both the comparison threshold and the clamp value - confirm that
# reswid (or another shift limit) was not intended.
+ with m.If(element > shiftbits):
+ comb += shifter.eq(shiftbits)
+ with m.Else():
+ comb += shifter.eq(element)
# NOTE(review): this unconditional assignment is kept as a context line
# after the newly enabled If/Else. If the usual nmigen rule applies (the
# last active assignment to a signal wins), it overrides the clamp above -
# confirm whether this line should have been removed by the patch.
comb += shifter.eq(element)
- partial = Signal(width, name="partial%d" % i, reset_less=True)
# The patch narrows each later partial result from width to reswid bits.
+ partial = Signal(reswid, name="partial%d" % i, reset_less=True)
comb += partial.eq(a_intervals[i] << shifter)
partial_results.append(partial)
# NOTE(review): this span is a patch fragment ("- " marks a removed line,
# "+ " an added line) and its indentation has been lost; the notes below
# describe only what the visible statements show.
#
# out collects one slice per entry of keys: the first comes straight from
# partial_results[0], the rest from the per-column res signals built below.
out = []
# This calculates the outputs o0-o3 from the partial results
- # table above.
+ # table above. Note: only relevant bits of the partial result equal
+ # to the width of the output column are accumulated in a Mux-cascade.
s,e = intervals[0]
result = partial_results[0]
out.append(result[s:e])
for i in range(1, len(keys)):
start, end = (intervals[i][0], width)
- result = partial_results[i] | \
- Mux(gates[i-1], 0, result[intervals[0][1]:])[:end-start]
# NOTE(review): reswid is assigned here but not read in the visible
# remainder of the loop body.
+ reswid = width - start
# The previous result, minus its low intervals[0][1] bits and limited to
# end-start bits, is carried forward; presumably Mux yields 0 instead when
# gates[i-1] is set (assumes nmigen Mux(sel, val1, val0) - confirm).
# The patch moves the [:end-start] slice inside the Mux operand.
+ sel = Mux(gates[i-1], 0, result[intervals[0][1]:][:end-start])
# Plain Python print: emits the selected range whenever this code runs.
print("select: [%d:%d]" % (start, end))
- res = Signal(width, name="res%d" % i, reset_less=True)
- comb += res.eq(result)
# NOTE(review): res is declared end-start+1 bits wide, one bit more than
# the reswid-bit partial result - confirm the extra bit is intentional.
+ res = Signal(end-start+1, name="res%d" % i, reset_less=True)
+ comb += res.eq(partial_results[i] | sel)
# res becomes the "previous result" for the next iteration.
+ result = res
# NOTE(review): the output slice always uses intervals[0], not
# intervals[i]; that takes the same bit positions of every res - confirm
# this is intended (e.g. all columns sharing interval 0's width).
s,e = intervals[0]
out.append(res[s:e])