# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
from nmigen import Module, Signal, Elaboratable, Mux, Cat
-from nmigen.asserts import Assert, AnyConst, Assume
+from nmigen.asserts import Assert, AnyConst
from nmigen.test.utils import FHDLTestCase
from nmigen.cli import rtlil
with m.Switch(points.as_sig()):
with m.Case(0b000):
- comb += Assume(b <= 32)
- comb += Assert(out == (a<<b[0:6]) & 0xffffffff)
+ comb += Assert(out == (a<<b[0:5]) & 0xffffffff)
with m.Case(0b001):
- comb += Assume(b_intervals[0] <= 8)
comb += Assert(out_intervals[0] ==
- (a_intervals[0] << b_intervals[0]) & 0xff)
- comb += Assume(b_intervals[1] <= 24)
+ (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
comb += Assert(Cat(out_intervals[1:4]) ==
(Cat(a_intervals[1:4])
- << b_intervals[1]) & 0xffffff)
+ << b_intervals[1][0:5]) & 0xffffff)
with m.Case(0b010):
- comb += Assume(b_intervals[0] <= 16)
comb += Assert(Cat(out_intervals[0:2]) ==
(Cat(a_intervals[0:2])
- << b_intervals[0]) & 0xffff)
- comb += Assume(b_intervals[2] <= 16)
+ << (b_intervals[0] & 0xf)) & 0xffff)
comb += Assert(Cat(out_intervals[2:4]) ==
(Cat(a_intervals[2:4])
- << b_intervals[2]) & 0xffff)
+ << (b_intervals[2] & 0xf)) & 0xffff)
with m.Case(0b011):
- comb += Assume(b_intervals[0] <= 8)
comb += Assert(out_intervals[0] ==
- (a_intervals[0] << b_intervals[0]) & 0xff)
- comb += Assume(b_intervals[1] <= 8)
+ (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
comb += Assert(out_intervals[1] ==
- (a_intervals[1] << b_intervals[1]) & 0xff)
- comb += Assume(b_intervals[2] <= 16)
+ (a_intervals[1] << b_intervals[1][0:3]) & 0xff)
comb += Assert(Cat(out_intervals[2:4]) ==
(Cat(a_intervals[2:4])
- << b_intervals[2]) & 0xffff)
+ << b_intervals[2][0:4]) & 0xffff)
with m.Case(0b100):
- comb += Assume(b_intervals[0] <= 24)
comb += Assert(Cat(out_intervals[0:3]) ==
(Cat(a_intervals[0:3])
- << b_intervals[0]) & 0xffffff)
- comb += Assume(b_intervals[3] <= 8)
+ << b_intervals[0][0:5]) & 0xffffff)
comb += Assert(out_intervals[3] ==
- (a_intervals[3] << b_intervals[3]) & 0xff)
+ (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
with m.Case(0b101):
- comb += Assume(b_intervals[0] <= 8)
comb += Assert(out_intervals[0] ==
- (a_intervals[0] << b_intervals[0]) & 0xff)
- comb += Assume(b_intervals[1] <= 16)
+ (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
comb += Assert(Cat(out_intervals[1:3]) ==
(Cat(a_intervals[1:3])
- << b_intervals[1]) & 0xffff)
- comb += Assume(b_intervals[3] <= 8)
+ << b_intervals[1][0:4]) & 0xffff)
comb += Assert(out_intervals[3] ==
- (a_intervals[3] << b_intervals[3]) & 0xff)
+ (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
with m.Case(0b110):
- comb += Assume(b_intervals[0] <= 16)
comb += Assert(Cat(out_intervals[0:2]) ==
(Cat(a_intervals[0:2])
- << b_intervals[0]) & 0xffff)
- comb += Assume(b_intervals[2] <= 8)
+ << b_intervals[0][0:4]) & 0xffff)
comb += Assert(out_intervals[2] ==
- (a_intervals[2] << b_intervals[2]) & 0xff)
- comb += Assume(b_intervals[3] <= 8)
+ (a_intervals[2] << b_intervals[2][0:3]) & 0xff)
comb += Assert(out_intervals[3] ==
- (a_intervals[3] << b_intervals[3]) & 0xff)
+ (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
with m.Case(0b111):
for i, o in enumerate(out_intervals):
- comb += Assume(b_intervals[i] <= 8)
comb += Assert(o ==
- (a_intervals[i] << b_intervals[i]) & 0xff)
+ (a_intervals[i] << b_intervals[i][0:3])
+ & 0xff)
return m
def test_shift(self):
module = ShifterDriver()
self.assertFormal(module, mode="bmc", depth=4)
+
def test_ilang(self):
width = 64
mwidth = 8
if __name__ == "__main__":
unittest.main()
-
intervals.append([start,end])
start = end
+ min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
+ max_bits = math.ceil(math.log2(width))
+
+ shifter_masks = []
+ for i in range(len(b_intervals)):
+ mask = Signal(b_intervals[i].shape(), name="shift_mask%d" % i)
+ bits = []
+ for j in range(i, gates.width):
+ if bits:
+ bits.append(~gates[j] & bits[-1])
+ else:
+ bits.append(~gates[j])
+ comb += mask.eq(Cat((1 << min_bits)-1, bits)
+ & ((1 << max_bits)-1))
+ shifter_masks.append(mask)
+
+ print(shifter_masks)
+
+
# Instead of generating the matrix described in the wiki, I
# instead calculate the shift amounts for each partition, then
# calculate the partial results of each partition << shift
# for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
# partial results [a0b0, a1bx, a2bx, a3bx], I can use just
# those partial results to calculate a0, a1, a2, and a3
+ shiftbits = math.ceil(math.log2(width))
+ element = b_intervals[0] & shifter_masks[0]
partial_results = []
- partial_results.append(a_intervals[0] << b_intervals[0])
- element = b_intervals[0]
+ partial_results.append(a_intervals[0] << element)
for i in range(1, len(out_intervals)):
s, e = intervals[i]
- element = Mux(gates[i-1], b_intervals[i], element)
+ masked = Signal(b_intervals[i].shape(), name="masked%d" % i)
+ comb += masked.eq(b_intervals[i] & shifter_masks[i])
+ element = Mux(gates[i-1], masked,
+ element)
# This calculates which partition of b to select the
# shifter from. According to the table above, the
# the partition mask, this calculates that with a mux
# chain
-
# This computes the partial results table
- shifter = Signal(8, name="shifter%d" % i)
+ shifter = Signal(shiftbits, name="shifter%d" % i)
comb += shifter.eq(element)
partial = Signal(width, name="partial%d" % i)
comb += partial.eq(a_intervals[i] << shifter)