with m.Switch(points.as_sig()):
with m.Case(0b000):
comb += Assert(out == (a>>b[0:5]) & 0xffffffff)
+ with m.Case(0b001):
+ comb += Assert(out_intervals[0] ==
+ (a_intervals[0] >> b_intervals[0][0:3]) & 0xff)
+ comb += Assert(Cat(out_intervals[1:4]) ==
+ (Cat(a_intervals[1:4])
+ >> b_intervals[1][0:5]) & 0xffffff)
+ with m.Case(0b010):
+ comb += Assert(Cat(out_intervals[0:2]) ==
+ (Cat(a_intervals[0:2])
+ >> (b_intervals[0] & 0xf)) & 0xffff)
+ comb += Assert(Cat(out_intervals[2:4]) ==
+ (Cat(a_intervals[2:4])
+ >> (b_intervals[2] & 0xf)) & 0xffff)
+ with m.Case(0b011):
+ comb += Assert(out_intervals[0] ==
+ (a_intervals[0] >> b_intervals[0][0:3]) & 0xff)
+ comb += Assert(out_intervals[1] ==
+ (a_intervals[1] >> b_intervals[1][0:3]) & 0xff)
+ comb += Assert(Cat(out_intervals[2:4]) ==
+ (Cat(a_intervals[2:4])
+ >> b_intervals[2][0:4]) & 0xffff)
+ with m.Case(0b100):
+ comb += Assert(Cat(out_intervals[0:3]) ==
+ (Cat(a_intervals[0:3])
+ >> b_intervals[0][0:5]) & 0xffffff)
+ comb += Assert(out_intervals[3] ==
+ (a_intervals[3] >> b_intervals[3][0:3]) & 0xff)
+ with m.Case(0b101):
+ comb += Assert(out_intervals[0] ==
+ (a_intervals[0] >> b_intervals[0][0:3]) & 0xff)
+ comb += Assert(Cat(out_intervals[1:3]) ==
+ (Cat(a_intervals[1:3])
+ >> b_intervals[1][0:4]) & 0xffff)
+ comb += Assert(out_intervals[3] ==
+ (a_intervals[3] >> b_intervals[3][0:3]) & 0xff)
+ with m.Case(0b110):
+ comb += Assert(Cat(out_intervals[0:2]) ==
+ (Cat(a_intervals[0:2])
+ >> b_intervals[0][0:4]) & 0xffff)
+ comb += Assert(out_intervals[2] ==
+ (a_intervals[2] >> b_intervals[2][0:3]) & 0xff)
+ comb += Assert(out_intervals[3] ==
+ (a_intervals[3] >> b_intervals[3][0:3]) & 0xff)
with m.Case(0b111):
for i, o in enumerate(out_intervals):
comb += Assert(o ==
sm = ShifterMask(bitwid, bwid, max_bits, min_bits)
setattr(m.submodules, "sm%d" % i, sm)
if bitwid != 0:
- comb += sm.gates.eq(gate_br.output[i:pwid])
+ comb += sm.gates.eq(gates[i:pwid])
shifter_masks.append(sm.mask)
print(shifter_masks)
reset_less=True)
comb += masked.eq(b_intervals[i] & shifter_masks[i])
masked_b.append(masked)
-
+ b_shl_amount = []
element = Signal(b_intervals[0].shape(), reset_less=True)
comb += element.eq(masked_b[0])
+ b_shl_amount.append(element)
+ for i in range(1, len(keys)):
+ element = Mux(gates[i-1], masked_b[i], element)
+ b_shl_amount.append(element)
+ b_shr_amount = list(reversed(b_shl_amount))
+
+ shift_amounts = []
+ for i in range(len(b_shl_amount)):
+ shift_amount = Signal(masked_b[i].width, name="shift_amount%d" % i)
+ comb += shift_amount.eq(
+ Mux(self.bitrev, b_shr_amount[i], b_shl_amount[i]))
+ shift_amounts.append(shift_amount)
+
partial_results = []
partial = Signal(width, name="partial0", reset_less=True)
- comb += partial.eq(a_intervals[0] << element)
+ comb += partial.eq(a_intervals[0] << shift_amounts[0])
+
partial_results.append(partial)
for i in range(1, len(keys)):
- element = Mux(gate_br.output[i-1], masked_b[i], element)
reswid = width - intervals[i][0]
shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
print ("partial", reswid, width, intervals[i], shiftbits)
pr = PartialResult(pwid, b_intervals[i].shape()[0], reswid)
setattr(m.submodules, "pr%d" % i, pr)
comb += pr.gate.eq(gate_br.output[i-1])
- comb += pr.b.eq(element)
+ comb += pr.b.eq(shift_amounts[i])
comb += pr.a_interval.eq(a_intervals[i])
partial_results.append(pr.partial)
comb += out_br.data.eq(Cat(*out))
return m
-