unit test for PartitionedSignal.__Cat__
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 29 Sep 2021 15:43:50 +0000 (16:43 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 29 Sep 2021 15:43:50 +0000 (16:43 +0100)
decided also to change the mask to be the number of partition points
(not waste one bit)

src/ieee754/part/test/test_partsig.py
src/ieee754/part_cat/cat.py
src/ieee754/part_mul_add/partpoints.py

index c8b4066a04263fff04561e1eb8e92597f70620b3..387a2e9c3912e47533c8b84454135a6106fd97ad 100644 (file)
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: LGPL-2.1-or-later
 # See Notices.txt for copyright information
 
-from nmigen import Signal, Module, Elaboratable, Mux
+from nmigen import Signal, Module, Elaboratable, Mux, Cat
 from nmigen.back.pysim import Simulator, Delay
 from nmigen.cli import rtlil
 
@@ -137,6 +137,26 @@ class TestMuxMod(Elaboratable):
         return m
 
 
+class TestCatMod(Elaboratable):
+    def __init__(self, width, partpoints):
+        self.partpoints = partpoints
+        self.a = PartitionedSignal(partpoints, width)
+        self.b = PartitionedSignal(partpoints, width*2)
+        self.cat_sel = Signal(len(partpoints)+1)
+        self.cat_out = Signal(width*3)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        self.a.set_module(m)
+        self.b.set_module(m)
+        #self.cat_sel.set_module(m)
+
+        comb += self.cat_out.eq(Cat(self.a, self.b))
+
+        return m
+
+
 class TestAddMod(Elaboratable):
     def __init__(self, width, partpoints):
         self.partpoints = partpoints
@@ -279,6 +299,111 @@ class TestMux(unittest.TestCase):
             sim.run()
 
 
+class TestCat(unittest.TestCase):
+    def test(self):
+        width = 16
+        part_mask = Signal(3)  # divide into 4-bits
+        module = TestCatMod(width, part_mask)
+
+        test_name = "part_sig_mux"
+        traces = [part_mask,
+                  module.a.sig,
+                  module.b.sig,
+                  module.cat_out]
+        sim = create_simulator(module, traces, test_name)
+
+        # annoying recursive import issue
+        from ieee754.part_cat.cat import get_runlengths
+
+        def async_process():
+
+            def test_catop(msg_prefix, *maskbit_list):
+                # define lengths of a/b test input
+                alen, blen = 16, 32
+                # pairs of test values a, b
+                for a, b in [(0x0000, 0x00000000),
+                             (0xDCBA, 0x12345678),
+                             (0xABCD, 0x01234567),
+                             (0xFFFF, 0x0000),
+                             (0x0000, 0x0000),
+                             (0x1F1F, 0xF1F1F1F1),
+                             (0x0000, 0xFFFFFFFF)]:
+                    # convert to mask_list
+                    mask_list = []
+                    for mb in maskbit_list:
+                        v = 0
+                        for i in range(4):
+                            if mb & (1 << i):
+                                v |= 0xf << (i*4)
+                        mask_list.append(v)
+
+                    # convert a and b to partitions
+                    apart, bpart = [], []
+                    ajump, bjump = alen // 4, blen // 4
+                    for i in range(4):
+                        apart.append((a >> (ajump*i) & ((1<<ajump)-1)))
+                        bpart.append((b >> (bjump*i) & ((1<<bjump)-1)))
+
+                    print ("apart bpart", hex(a), hex(b),
+                            list(map(hex, apart)), list(map(hex, bpart)))
+
+                    yield module.a.eq(a)
+                    yield module.b.eq(b)
+                    yield Delay(0.1e-6)
+
+                    y = 0
+                    # work out the runlengths for this mask.
+                    # 0b011 returns [1,1,2] (for a mask of length 3)
+                    mval = yield part_mask
+                    runlengths = get_runlengths(mval, 3)
+                    j = 0
+                    ai = 0
+                    bi = 0
+                    for i in runlengths:
+                        # a first
+                        for _ in range(i):
+                            print ("runlength", i,
+                                   "ai", ai,
+                                   "apart", hex(apart[ai]),
+                                   "j", j)
+                            y |= apart[ai] << j
+                            print ("    y", hex(y))
+                            j += ajump
+                            ai += 1
+                        # now b
+                        for _ in range(i):
+                            print ("runlength", i,
+                                   "bi", bi,
+                                   "bpart", hex(bpart[bi]),
+                                   "j", j)
+                            y |= bpart[bi] << j
+                            print ("    y", hex(y))
+                            j += bjump
+                            bi += 1
+
+                    # check the result
+                    outval = (yield module.cat_out)
+                    msg = f"{msg_prefix}: cat " + \
+                        f"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
+                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
+                    # print ((msg % str(maskbit_list)).format(locals()))
+                    self.assertEqual(y, outval, msg % str(maskbit_list))
+
+            yield part_mask.eq(0)
+            yield from test_catop("16-bit", 0b1111)
+            yield part_mask.eq(0b10)
+            yield from test_catop("8-bit", 0b1100, 0b0011)
+            yield part_mask.eq(0b1111)
+            yield from test_catop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+
+        sim.add_process(async_process)
+        with sim.write_vcd(
+                vcd_file=open(test_name + ".vcd", "w"),
+                gtkw_file=open(test_name + ".gtkw", "w"),
+                traces=traces):
+            sim.run()
+
+
 class TestPartitionedSignal(unittest.TestCase):
     def test(self):
         width = 16
index a9541080be0e6f1a5cc64135d4ed645933b26fa2..6fdfb66b1b1c32f8700c2511f978862d85f34d6f 100644 (file)
@@ -73,6 +73,8 @@ class PartitionedCat(Elaboratable):
         """
         # work out the length (total of all PartitionedSignals)
         self.catlist = catlist
+        if isinstance(mask, dict):
+            mask = list(mask.values())
         self.mask = mask
         width = 0
         for p in catlist:
@@ -101,7 +103,8 @@ class PartitionedCat(Elaboratable):
 
         keys = list(self.partition_points.keys())
         print ("keys", keys, "values", self.partition_points.values())
-        with m.Switch(self.mask[:-1]):
+        print ("mask", self.mask)
+        with m.Switch(Cat(self.mask)):
             # for each partition possibility, create a Cat sequence
             for pbit in range(1<<len(keys)):
                 # set up some indices pointing to where things have got
@@ -131,7 +134,7 @@ class PartitionedCat(Elaboratable):
 
 if __name__ == "__main__":
     m = Module()
-    mask = Signal(4)
+    mask = Signal(3)
     a = PartitionedSignal(mask, 32)
     b = PartitionedSignal(mask, 16)
     catlist = [a, b]
index 5f3c38c51f58a17f679b27d9e2b2baadf324036a..8ae238103e32b3d7b7f176d6c2a838a517a372f4 100644 (file)
@@ -28,15 +28,16 @@ def make_partition(mask, width):
 
 def make_partition2(mask, width):
     """ from a mask and a bitwidth, create partition points.
-        note that the assumption is that the mask indicates the
-        breakpoints in regular intervals, and that the last bit (MSB)
-        of the mask is therefore *ignored*.
-        mask len = 4, width == 16 will return:
+        note that the mask represents the actual partition points
+        and therefore must be ONE LESS than the number of required
+        partitions
+
+        mask len = 3, width == 16 will return:
             {4: mask[0], 8: mask[1], 12: mask[2]}
-        mask len = 8, width == 64 will return:
+        mask len = 7, width == 64 will return:
             {8: mask[0], 16: mask[1], 24: mask[2], .... 56: mask[6]}
     """
-    mlen = len(mask)
+    mlen = len(mask) + 1     # ONE MORE partitions than break-points
     jumpsize = width // mlen # amount to jump by (size of each partition)
     ppoints = {}
     ppos = jumpsize