# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
-from nmigen import Signal, Module, Elaboratable, Mux
+from nmigen import Signal, Module, Elaboratable, Mux, Cat
from nmigen.back.pysim import Simulator, Delay
from nmigen.cli import rtlil
return m
+class TestCatMod(Elaboratable):
+ def __init__(self, width, partpoints):
+ self.partpoints = partpoints
+ self.a = PartitionedSignal(partpoints, width)
+ self.b = PartitionedSignal(partpoints, width*2)
+ self.cat_sel = Signal(len(partpoints)+1)
+ self.cat_out = Signal(width*3)
+
+ def elaborate(self, platform):
+ m = Module()
+ comb = m.d.comb
+ self.a.set_module(m)
+ self.b.set_module(m)
+ #self.cat_sel.set_module(m)
+
+ comb += self.cat_out.eq(Cat(self.a, self.b))
+
+ return m
+
+
class TestAddMod(Elaboratable):
def __init__(self, width, partpoints):
self.partpoints = partpoints
sim.run()
+class TestCat(unittest.TestCase):
+ def test(self):
+ width = 16
+ part_mask = Signal(3) # divide into 4-bits
+ module = TestCatMod(width, part_mask)
+
+ test_name = "part_sig_mux"
+ traces = [part_mask,
+ module.a.sig,
+ module.b.sig,
+ module.cat_out]
+ sim = create_simulator(module, traces, test_name)
+
+ # annoying recursive import issue
+ from ieee754.part_cat.cat import get_runlengths
+
+ def async_process():
+
+ def test_catop(msg_prefix, *maskbit_list):
+ # define lengths of a/b test input
+ alen, blen = 16, 32
+ # pairs of test values a, b
+ for a, b in [(0x0000, 0x00000000),
+ (0xDCBA, 0x12345678),
+ (0xABCD, 0x01234567),
+ (0xFFFF, 0x0000),
+ (0x0000, 0x0000),
+ (0x1F1F, 0xF1F1F1F1),
+ (0x0000, 0xFFFFFFFF)]:
+ # convert to mask_list
+ mask_list = []
+ for mb in maskbit_list:
+ v = 0
+ for i in range(4):
+ if mb & (1 << i):
+ v |= 0xf << (i*4)
+ mask_list.append(v)
+
+ # convert a and b to partitions
+ apart, bpart = [], []
+ ajump, bjump = alen // 4, blen // 4
+ for i in range(4):
+ apart.append((a >> (ajump*i) & ((1<<ajump)-1)))
+ bpart.append((b >> (bjump*i) & ((1<<bjump)-1)))
+
+ print ("apart bpart", hex(a), hex(b),
+ list(map(hex, apart)), list(map(hex, bpart)))
+
+ yield module.a.eq(a)
+ yield module.b.eq(b)
+ yield Delay(0.1e-6)
+
+ y = 0
+ # work out the runlengths for this mask.
+ # 0b011 returns [1,1,2] (for a mask of length 3)
+ mval = yield part_mask
+ runlengths = get_runlengths(mval, 3)
+ j = 0
+ ai = 0
+ bi = 0
+ for i in runlengths:
+ # a first
+ for _ in range(i):
+ print ("runlength", i,
+ "ai", ai,
+ "apart", hex(apart[ai]),
+ "j", j)
+ y |= apart[ai] << j
+ print (" y", hex(y))
+ j += ajump
+ ai += 1
+ # now b
+ for _ in range(i):
+ print ("runlength", i,
+ "bi", bi,
+ "bpart", hex(bpart[bi]),
+ "j", j)
+ y |= bpart[bi] << j
+ print (" y", hex(y))
+ j += bjump
+ bi += 1
+
+ # check the result
+ outval = (yield module.cat_out)
+ msg = f"{msg_prefix}: cat " + \
+ f"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
+ f" => 0x{y:X} != 0x{outval:X}, masklist %s"
+ # print ((msg % str(maskbit_list)).format(locals()))
+ self.assertEqual(y, outval, msg % str(maskbit_list))
+
+ yield part_mask.eq(0)
+ yield from test_catop("16-bit", 0b1111)
+ yield part_mask.eq(0b10)
+ yield from test_catop("8-bit", 0b1100, 0b0011)
+ yield part_mask.eq(0b1111)
+ yield from test_catop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+
+ sim.add_process(async_process)
+ with sim.write_vcd(
+ vcd_file=open(test_name + ".vcd", "w"),
+ gtkw_file=open(test_name + ".gtkw", "w"),
+ traces=traces):
+ sim.run()
+
+
class TestPartitionedSignal(unittest.TestCase):
def test(self):
width = 16
"""
# work out the length (total of all PartitionedSignals)
self.catlist = catlist
+ if isinstance(mask, dict):
+ mask = list(mask.values())
self.mask = mask
width = 0
for p in catlist:
keys = list(self.partition_points.keys())
print ("keys", keys, "values", self.partition_points.values())
- with m.Switch(self.mask[:-1]):
+ print ("mask", self.mask)
+ with m.Switch(Cat(self.mask)):
# for each partition possibility, create a Cat sequence
for pbit in range(1<<len(keys)):
# set up some indices pointing to where things have got
if __name__ == "__main__":
m = Module()
- mask = Signal(4)
+ mask = Signal(3)
a = PartitionedSignal(mask, 32)
b = PartitionedSignal(mask, 16)
catlist = [a, b]
def make_partition2(mask, width):
""" from a mask and a bitwidth, create partition points.
- note that the assumption is that the mask indicates the
- breakpoints in regular intervals, and that the last bit (MSB)
- of the mask is therefore *ignored*.
- mask len = 4, width == 16 will return:
+ note that the mask represents the actual partition points
+ and therefore must be ONE LESS than the number of required
+ partitions
+
+ mask len = 3, width == 16 will return:
{4: mask[0], 8: mask[1], 12: mask[2]}
- mask len = 8, width == 64 will return:
+ mask len = 7, width == 64 will return:
{8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}
"""
- mlen = len(mask)
+ mlen = len(mask) + 1 # ONE MORE partitions than break-points
jumpsize = width // mlen # amount to jump by (size of each partition)
ppoints = {}
ppos = jumpsize