# setup the inputs and outputs of the DUT as anyconst
a = Signal(width)
b = Signal(width)
+ bitrev = Signal()
out = Signal(width)
points = PartitionPoints()
gates = Signal(mwidth-1)
comb += [a.eq(AnyConst(width)),
b.eq(AnyConst(width)),
+ bitrev.eq(AnyConst(1)),
gates.eq(AnyConst(mwidth-1))]
m.submodules.dut = dut = PartitionedDynamicShift(width, points)
comb += [dut.a.eq(a),
dut.b.eq(b),
+ dut.bitrev.eq(bitrev),
out.eq(dut.output)]
- with m.Switch(points.as_sig()):
- with m.Case(0b000):
- comb += Assert(out == (a<<b[0:5]) & 0xffffffff)
- with m.Case(0b001):
- comb += Assert(out_intervals[0] ==
- (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
- comb += Assert(Cat(out_intervals[1:4]) ==
- (Cat(a_intervals[1:4])
- << b_intervals[1][0:5]) & 0xffffff)
- with m.Case(0b010):
- comb += Assert(Cat(out_intervals[0:2]) ==
- (Cat(a_intervals[0:2])
- << (b_intervals[0] & 0xf)) & 0xffff)
- comb += Assert(Cat(out_intervals[2:4]) ==
- (Cat(a_intervals[2:4])
- << (b_intervals[2] & 0xf)) & 0xffff)
- with m.Case(0b011):
- comb += Assert(out_intervals[0] ==
- (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
- comb += Assert(out_intervals[1] ==
- (a_intervals[1] << b_intervals[1][0:3]) & 0xff)
- comb += Assert(Cat(out_intervals[2:4]) ==
- (Cat(a_intervals[2:4])
- << b_intervals[2][0:4]) & 0xffff)
- with m.Case(0b100):
- comb += Assert(Cat(out_intervals[0:3]) ==
- (Cat(a_intervals[0:3])
- << b_intervals[0][0:5]) & 0xffffff)
- comb += Assert(out_intervals[3] ==
- (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
- with m.Case(0b101):
- comb += Assert(out_intervals[0] ==
- (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
- comb += Assert(Cat(out_intervals[1:3]) ==
- (Cat(a_intervals[1:3])
- << b_intervals[1][0:4]) & 0xffff)
- comb += Assert(out_intervals[3] ==
- (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
- with m.Case(0b110):
- comb += Assert(Cat(out_intervals[0:2]) ==
- (Cat(a_intervals[0:2])
- << b_intervals[0][0:4]) & 0xffff)
- comb += Assert(out_intervals[2] ==
- (a_intervals[2] << b_intervals[2][0:3]) & 0xff)
- comb += Assert(out_intervals[3] ==
- (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
- with m.Case(0b111):
- for i, o in enumerate(out_intervals):
- comb += Assert(o ==
- (a_intervals[i] << b_intervals[i][0:3])
- & 0xff)
+ with m.If(bitrev == 0):
+ with m.Switch(points.as_sig()):
+ with m.Case(0b000):
+ comb += Assert(out == (a<<b[0:5]) & 0xffffffff)
+ with m.Case(0b001):
+ comb += Assert(out_intervals[0] ==
+ (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
+ comb += Assert(Cat(out_intervals[1:4]) ==
+ (Cat(a_intervals[1:4])
+ << b_intervals[1][0:5]) & 0xffffff)
+ with m.Case(0b010):
+ comb += Assert(Cat(out_intervals[0:2]) ==
+ (Cat(a_intervals[0:2])
+ << (b_intervals[0] & 0xf)) & 0xffff)
+ comb += Assert(Cat(out_intervals[2:4]) ==
+ (Cat(a_intervals[2:4])
+ << (b_intervals[2] & 0xf)) & 0xffff)
+ with m.Case(0b011):
+ comb += Assert(out_intervals[0] ==
+ (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
+ comb += Assert(out_intervals[1] ==
+ (a_intervals[1] << b_intervals[1][0:3]) & 0xff)
+ comb += Assert(Cat(out_intervals[2:4]) ==
+ (Cat(a_intervals[2:4])
+ << b_intervals[2][0:4]) & 0xffff)
+ with m.Case(0b100):
+ comb += Assert(Cat(out_intervals[0:3]) ==
+ (Cat(a_intervals[0:3])
+ << b_intervals[0][0:5]) & 0xffffff)
+ comb += Assert(out_intervals[3] ==
+ (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
+ with m.Case(0b101):
+ comb += Assert(out_intervals[0] ==
+ (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
+ comb += Assert(Cat(out_intervals[1:3]) ==
+ (Cat(a_intervals[1:3])
+ << b_intervals[1][0:4]) & 0xffff)
+ comb += Assert(out_intervals[3] ==
+ (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
+ with m.Case(0b110):
+ comb += Assert(Cat(out_intervals[0:2]) ==
+ (Cat(a_intervals[0:2])
+ << b_intervals[0][0:4]) & 0xffff)
+ comb += Assert(out_intervals[2] ==
+ (a_intervals[2] << b_intervals[2][0:3]) & 0xff)
+ comb += Assert(out_intervals[3] ==
+ (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
+ with m.Case(0b111):
+ for i, o in enumerate(out_intervals):
+ comb += Assert(o ==
+ (a_intervals[i] << b_intervals[i][0:3])
+ & 0xff)
+ with m.Else():
+ with m.Switch(points.as_sig()):
+ with m.Case(0b000):
+ comb += Assert(out == (a>>b[0:5]) & 0xffffffff)
+ with m.Case(0b111):
+ for i, o in enumerate(out_intervals):
+ comb += Assert(o ==
+ (a_intervals[i] >> b_intervals[i][0:3])
+ & 0xff)
return m
"""
from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_shift.bitrev import GatedBitReverse
import math
class ShifterMask(Elaboratable):
self.a = Signal(width, reset_less=True)
self.b = Signal(width, reset_less=True)
+ self.bitrev = Signal(reset_less=True)
self.output = Signal(width, reset_less=True)
def elaborate(self, platform):
keys = list(self.partition_points.keys()) + [self.width]
start = 0
+ m.submodules.a_br = a_br = GatedBitReverse(self.a.width)
+ comb += a_br.data.eq(self.a)
+ comb += a_br.reverse_en.eq(self.bitrev)
+
+ m.submodules.out_br = out_br = GatedBitReverse(self.output.width)
+ comb += out_br.reverse_en.eq(self.bitrev)
+ comb += self.output.eq(out_br.output)
+
+ m.submodules.gate_br = gate_br = GatedBitReverse(pwid)
+ comb += gate_br.data.eq(gates)
+ comb += gate_br.reverse_en.eq(self.bitrev)
+
+
# break out both the input and output into partition-stratified blocks
a_intervals = []
b_intervals = []
for i in range(len(keys)):
end = keys[i]
widths.append(width - start)
- a_intervals.append(self.a[start:end])
+ a_intervals.append(a_br.output[start:end])
b_intervals.append(self.b[start:end])
intervals.append([start,end])
start = end
min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
- # shifts are normally done as (e.g. for 32 bit) result = a & (b&0b11111)
- # truncating the b input. however here of course the size of the
- # partition varies dynamically.
+ # shifts are normally done as (e.g. for 32 bit) result = a &
+ # (b&0b11111) truncating the b input. however here of course
+ # the size of the partition varies dynamically.
shifter_masks = []
for i in range(len(b_intervals)):
bwid = b_intervals[i].shape()[0]
sm = ShifterMask(bitwid, bwid, max_bits, min_bits)
setattr(m.submodules, "sm%d" % i, sm)
if bitwid != 0:
- comb += sm.gates.eq(gates[i:pwid])
+ comb += sm.gates.eq(gate_br.output[i:pwid])
shifter_masks.append(sm.mask)
print(shifter_masks)
masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
reset_less=True)
comb += pr.masked.eq(b_intervals[i] & shifter_masks[i])
- comb += pr.gate.eq(gates[i-1])
+ comb += pr.gate.eq(gate_br.output[i-1])
comb += pr.element.eq(element)
comb += pr.a_interval.eq(a_intervals[i])
partial_results.append(pr.partial)
for i in range(1, len(keys)):
start, end = (intervals[i][0], width)
reswid = width - start
- sel = Mux(gates[i-1], 0, result[intervals[0][1]:][:end-start])
+ sel = Mux(gate_br.output[i-1], 0,
+ result[intervals[0][1]:][:end-start])
print("select: [%d:%d]" % (start, end))
res = Signal(end-start+1, name="res%d" % i, reset_less=True)
comb += res.eq(partial_results[i] | sel)
s,e = intervals[0]
out.append(res[s:e])
- comb += self.output.eq(Cat(*out))
+ comb += out_br.data.eq(Cat(*out))
return m