add PartitionedSignal.__Mux__ unit test
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 28 Sep 2021 16:00:40 +0000 (17:00 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 28 Sep 2021 16:00:40 +0000 (17:00 +0100)
split out from TestPartitionedSignal to reduce time and make it clear
what is being tested

src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py
src/ieee754/part_mul_add/partpoints.py
src/ieee754/part_mux/part_mux.py

index f68cf8ccd9dd33fc3a215cbe3a233605afd9fe41..ba172aab906b3bc0a4c2358ec879bb2a75536a34 100644 (file)
@@ -21,7 +21,7 @@ from ieee754.part_cmp.eq_gt_ge import PartitionedEqGtGe
 from ieee754.part_bits.xor import PartitionedXOR
 from ieee754.part_shift.part_shift_dynamic import PartitionedDynamicShift
 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
-from ieee754.part_mul_add.partpoints import make_partition, PartitionPoints
+from ieee754.part_mul_add.partpoints import make_partition2, PartitionPoints
 from ieee754.part_mux.part_mux import PMux
 from operator import or_, xor, and_, not_
 
@@ -43,6 +43,13 @@ def applyop(op1, op2, op):
     result.m.d.comb += result.sig.eq(op(getsig(op1), getsig(op2)))
     return result
 
+global modnames
+modnames = {}
+# for sub-modules to be created on-demand. Mux is done slightly
+# differently (has its own global)
+for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor']:
+    modnames[name] = 0
+
 
 class PartitionedSignal(UserValue):
     def __init__(self, mask, *args, src_loc_at=0, **kwargs):
@@ -53,12 +60,7 @@ class PartitionedSignal(UserValue):
         if isinstance(mask, PartitionPoints):
             self.partpoints = mask
         else:
-            self.partpoints = make_partition(mask, width)
-        self.modnames = {}
-        # for sub-modules to be created on-demand. Mux is done slightly
-        # differently
-        for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor']:
-            self.modnames[name] = 0
+            self.partpoints = make_partition2(mask, width)
 
     def lower(self):
         return self.sig
@@ -67,8 +69,8 @@ class PartitionedSignal(UserValue):
         self.m = m
 
     def get_modname(self, category):
-        self.modnames[category] += 1
-        return "%s_%d" % (category, self.modnames[category])
+        modnames[category] += 1
+        return "%s_%d" % (category, modnames[category])
 
     def eq(self, val):
         return self.sig.eq(getsig(val))
@@ -85,6 +87,7 @@ class PartitionedSignal(UserValue):
     # nmigen-redirected constructs (Mux, Cat, Switch, Assign)
 
     def __Mux__(self, val1, val2):
+        # print ("partsig mux", self, val1, val2)
         assert len(val1) == len(val2), \
             "PartitionedSignal width sources must be the same " \
             "val1 == %d, val2 == %d" % (len(val1), len(val2))
index 2bf0ef837a7a6b5be5d26a645590962732b47577..c8b4066a04263fff04561e1eb8e92597f70620b3 100644 (file)
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: LGPL-2.1-or-later
 # See Notices.txt for copyright information
 
-from nmigen import Signal, Module, Elaboratable
+from nmigen import Signal, Module, Elaboratable, Mux
 from nmigen.back.pysim import Simulator, Delay
 from nmigen.cli import rtlil
 
@@ -63,8 +63,10 @@ class TestAddMod2(Elaboratable):
         self.ne_output = Signal(len(partpoints)+1)
         self.lt_output = Signal(len(partpoints)+1)
         self.le_output = Signal(len(partpoints)+1)
-        self.mux_sel = Signal(len(partpoints)+1)
+        self.mux_sel2 = Signal(len(partpoints)+1)
+        self.mux_sel2 = PartitionedSignal(partpoints, len(partpoints))
         self.mux_out = Signal(width)
+        self.mux2_out = Signal(width)
         self.carry_in = Signal(len(partpoints)+1)
         self.add_carry_out = Signal(len(partpoints)+1)
         self.sub_carry_out = Signal(len(partpoints)+1)
@@ -76,6 +78,7 @@ class TestAddMod2(Elaboratable):
         sync = m.d.sync
         self.a.set_module(m)
         self.b.set_module(m)
+        self.mux_sel2.set_module(m)
         # compares
         sync += self.lt_output.eq(self.a < self.b)
         sync += self.ne_output.eq(self.a != self.b)
@@ -100,6 +103,7 @@ class TestAddMod2(Elaboratable):
         sync += self.rs_output.eq(self.a >> self.b)
         ppts = self.partpoints
         sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
+        sync += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
         # scalar left shift
         comb += self.bsig.eq(self.b.sig)
         sync += self.ls_scal_output.eq(self.a << self.bsig)
@@ -108,6 +112,31 @@ class TestAddMod2(Elaboratable):
         return m
 
 
+class TestMuxMod(Elaboratable):
+    def __init__(self, width, partpoints):
+        self.partpoints = partpoints
+        self.a = PartitionedSignal(partpoints, width)
+        self.b = PartitionedSignal(partpoints, width)
+        self.mux_sel = Signal(len(partpoints)+1)
+        self.mux_sel2 = PartitionedSignal(partpoints, len(partpoints)+1)
+        self.mux_out = Signal(width)
+        self.mux_out2 = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        sync = m.d.sync
+        self.a.set_module(m)
+        self.b.set_module(m)
+        self.mux_sel2.set_module(m)
+        ppts = self.partpoints
+
+        comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
+        comb += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
+
+        return m
+
+
 class TestAddMod(Elaboratable):
     def __init__(self, width, partpoints):
         self.partpoints = partpoints
@@ -126,8 +155,6 @@ class TestAddMod(Elaboratable):
         self.ne_output = Signal(len(partpoints)+1)
         self.lt_output = Signal(len(partpoints)+1)
         self.le_output = Signal(len(partpoints)+1)
-        self.mux_sel = Signal(len(partpoints)+1)
-        self.mux_out = Signal(width)
         self.carry_in = Signal(len(partpoints)+1)
         self.add_carry_out = Signal(len(partpoints)+1)
         self.sub_carry_out = Signal(len(partpoints)+1)
@@ -163,8 +190,6 @@ class TestAddMod(Elaboratable):
         # right shift
         comb += self.rs_output.eq(self.a >> self.b)
         ppts = self.partpoints
-        # mux
-        comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
         # scalar left shift
         comb += self.bsig.eq(self.b.sig)
         comb += self.ls_scal_output.eq(self.a << self.bsig)
@@ -174,7 +199,87 @@ class TestAddMod(Elaboratable):
         return m
 
 
-class TestPartitionPoints(unittest.TestCase):
+class TestMux(unittest.TestCase):
+    def test(self):
+        width = 16
+        part_mask = Signal(4)  # divide into 4-bits
+        module = TestMuxMod(width, part_mask)
+
+        test_name = "part_sig_mux"
+        traces = [part_mask,
+                  module.a.sig,
+                  module.b.sig,
+                  module.mux_out,
+                  module.mux_out2]
+        sim = create_simulator(module, traces, test_name)
+
+        def async_process():
+
+            def test_muxop(msg_prefix, *maskbit_list):
+                for a, b in [(0x0000, 0x0000),
+                             (0x1234, 0x1234),
+                             (0xABCD, 0xABCD),
+                             (0xFFFF, 0x0000),
+                             (0x0000, 0x0000),
+                             (0xFFFF, 0xFFFF),
+                             (0x0000, 0xFFFF)]:
+                    # convert to mask_list
+                    mask_list = []
+                    for mb in maskbit_list:
+                        v = 0
+                        for i in range(4):
+                            if mb & (1 << i):
+                                v |= 0xf << (i*4)
+                        mask_list.append(v)
+
+                    # TODO: sel needs to go through permutations of mask_list
+                    for p in perms(len(mask_list)):
+
+                        sel = 0
+                        selmask = 0
+                        for i, v in enumerate(p):
+                            if v == '1':
+                                sel |= maskbit_list[i]
+                                selmask |= mask_list[i]
+
+                        yield module.a.eq(a)
+                        yield module.b.eq(b)
+                        yield module.mux_sel.eq(sel)
+                        yield module.mux_sel2.sig.eq(sel)
+                        yield Delay(0.1e-6)
+                        y = 0
+                        # do the partitioned tests
+                        for i, mask in enumerate(mask_list):
+                            if (selmask & mask):
+                                y |= (a & mask)
+                            else:
+                                y |= (b & mask)
+                        # check the result
+                        outval = (yield module.mux_out)
+                        outval2 = (yield module.mux_out2)
+                        msg = f"{msg_prefix}: mux " + \
+                            f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
+                            f" => 0x{y:X} != 0x{outval:X}, masklist %s"
+                        # print ((msg % str(maskbit_list)).format(locals()))
+                        self.assertEqual(y, outval, msg % str(maskbit_list))
+                        self.assertEqual(y, outval2, msg % str(maskbit_list))
+
+            yield part_mask.eq(0)
+            yield from test_muxop("16-bit", 0b1111)
+            yield part_mask.eq(0b10)
+            yield from test_muxop("8-bit", 0b1100, 0b0011)
+            yield part_mask.eq(0b1111)
+            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+
+        sim.add_process(async_process)
+        with sim.write_vcd(
+                vcd_file=open(test_name + ".vcd", "w"),
+                gtkw_file=open(test_name + ".gtkw", "w"),
+                traces=traces):
+            sim.run()
+
+
+class TestPartitionedSignal(unittest.TestCase):
     def test(self):
         width = 16
         part_mask = Signal(4)  # divide into 4-bits
@@ -416,59 +521,6 @@ class TestPartitionPoints(unittest.TestCase):
                 yield from test_binop("4-bit", test_fn, mod_attr,
                                       0b1000, 0b0100, 0b0010, 0b0001)
 
-            def test_muxop(msg_prefix, *maskbit_list):
-                for a, b in [(0x0000, 0x0000),
-                             (0x1234, 0x1234),
-                             (0xABCD, 0xABCD),
-                             (0xFFFF, 0x0000),
-                             (0x0000, 0x0000),
-                             (0xFFFF, 0xFFFF),
-                             (0x0000, 0xFFFF)]:
-                    # convert to mask_list
-                    mask_list = []
-                    for mb in maskbit_list:
-                        v = 0
-                        for i in range(4):
-                            if mb & (1 << i):
-                                v |= 0xf << (i*4)
-                        mask_list.append(v)
-
-                    # TODO: sel needs to go through permutations of mask_list
-                    for p in perms(len(mask_list)):
-
-                        sel = 0
-                        selmask = 0
-                        for i, v in enumerate(p):
-                            if v == '1':
-                                sel |= maskbit_list[i]
-                                selmask |= mask_list[i]
-
-                        yield module.a.eq(a)
-                        yield module.b.eq(b)
-                        yield module.mux_sel.eq(sel)
-                        yield Delay(0.1e-6)
-                        y = 0
-                        # do the partitioned tests
-                        for i, mask in enumerate(mask_list):
-                            if (selmask & mask):
-                                y |= (a & mask)
-                            else:
-                                y |= (b & mask)
-                        # check the result
-                        outval = (yield module.mux_out)
-                        msg = f"{msg_prefix}: mux " + \
-                            f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
-                            f" => 0x{y:X} != 0x{outval:X}, masklist %s"
-                        # print ((msg % str(maskbit_list)).format(locals()))
-                        self.assertEqual(y, outval, msg % str(maskbit_list))
-
-            yield part_mask.eq(0)
-            yield from test_muxop("16-bit", 0b1111)
-            yield part_mask.eq(0b10)
-            yield from test_muxop("8-bit", 0b1100, 0b0011)
-            yield part_mask.eq(0b1111)
-            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
-
         sim.add_process(async_process)
         with sim.write_vcd(
                 vcd_file=open(test_name + ".vcd", "w"),
index 887372af62cef7bb7b213cd63e98dbb10aae85ce..5f3c38c51f58a17f679b27d9e2b2baadf324036a 100644 (file)
@@ -41,10 +41,15 @@ def make_partition2(mask, width):
     ppoints = {}
     ppos = jumpsize
     midx = 0
+    if isinstance(mask, dict): # convert dict/partpoints to sequential list
+        mask = list(mask.values())
+    print ("make_partition2", width, mask, mlen, jumpsize)
     while ppos < width and midx < mlen: # -1, ignore last bit
+        print ("    make_partition2", ppos, width, midx, mlen)
         ppoints[ppos] = mask[midx]
         ppos += jumpsize
         midx += 1
+    print ("    make_partition2", mask, width, ppoints)
     return ppoints
 
 
index 30d344c1569b513948bac414bd18a3375d49723a..e5d2c87f6d4a11f938411a996d20e8d4da39667e 100644 (file)
@@ -15,7 +15,7 @@ See:
 
 from nmigen import Signal, Module, Elaboratable, Mux
 from ieee754.part_mul_add.partpoints import PartitionPoints
-from ieee754.part_mul_add.partpoints import make_partition
+from ieee754.part_mul_add.partpoints import make_partition2
 
 
 modcount = 0 # global for now
@@ -23,7 +23,7 @@ def PMux(m, mask, sel, a, b):
     global modcount
     modcount += 1
     width = len(a.sig)  # get width
-    part_pts = make_partition(mask, width) # create partition points
+    part_pts = make_partition2(mask, width) # create partition points
     pm = PartitionedMux(width, part_pts)
     m.d.comb += pm.a.eq(a.sig)
     m.d.comb += pm.b.eq(b.sig)