import math
class ShifterMask(Elaboratable):
+
def __init__(self, pwid, bwid, max_bits, min_bits):
self.max_bits = max_bits
self.min_bits = min_bits
comb += self.mask.eq(minm)
return m
+ # create bit-cascade
bits = Signal(self.pwid, reset_less=True)
bl = []
for j in range(self.pwid):
else:
comb += bit.eq(~self.gates[j])
bl.append(bit)
+
# XXX ARGH, really annoying: simulation bug, can't use Cat(*bl).
for j in range(bits.shape()[0]):
comb += bits[j].eq(bl[j])
class PartitionedDynamicShift(Elaboratable):
+
def __init__(self, width, partition_points):
self.width = width
self.partition_points = PartitionPoints(partition_points)
def elaborate(self, platform):
m = Module()
+
+ # temporaries
comb = m.d.comb
width = self.width
pwid = self.partition_points.get_max_partition_count(width)-1
keys = list(self.partition_points.keys()) + [self.width]
start = 0
+ # create gated-reversed versions of a, b and the output
+ # left-shift is non-reversed, right-shift is reversed
m.submodules.a_br = a_br = GatedBitReverse(self.a.width)
comb += a_br.data.eq(self.a)
comb += a_br.reverse_en.eq(self.shift_right)
comb += gate_br.data.eq(gates)
comb += gate_br.reverse_en.eq(self.shift_right)
-
# break out both the input and output into partition-stratified blocks
a_intervals = []
b_intervals = []
b_shl_amount.append(element)
for i in range(1, len(keys)):
element = Mux(gates[i-1], masked_b[i], element)
- b_shl_amount.append(element)
+ b_shl_amount.append(element) # FIXME: creates an O(N^2) cascade
b_shr_amount = list(reversed(b_shl_amount))
+ # select shift-amount (b) for partition based on op being left or right
shift_amounts = []
for i in range(len(b_shl_amount)):
shift_amount = Signal(masked_b[i].width, name="shift_amount%d" % i)
- comb += shift_amount.eq(
- Mux(self.shift_right, b_shr_amount[i], b_shl_amount[i]))
+ sel = Mux(self.shift_right, b_shr_amount[i], b_shl_amount[i])
+ comb += shift_amount.eq(sel)
shift_amounts.append(shift_amount)
+ # now calculate partial results
+
+ # first item (simple)
partial_results = []
partial = Signal(width, name="partial0", reset_less=True)
comb += partial.eq(a_intervals[0] << shift_amounts[0])
-
partial_results.append(partial)
+
+ # rest of list
for i in range(1, len(keys)):
reswid = width - intervals[i][0]
shiftbits = math.ceil(math.log2(reswid+1))+1 # hmmm...
comb += pr.a_interval.eq(a_intervals[i])
partial_results.append(pr.partial)
- out = []
-
# This calculates the outputs o0-o3 from the partial results
# table above. Note: only relevant bits of the partial result equal
# to the width of the output column are accumulated in a Mux-cascade.
+ out = []
s,e = intervals[0]
result = partial_results[0]
out.append(result[s:e])