add unit test for part_mux
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 7 Feb 2020 14:38:08 +0000 (14:38 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 7 Feb 2020 14:38:08 +0000 (14:38 +0000)
src/ieee754/part/test/test_partsig.py
src/ieee754/part_mux/part_mux.py

index 0229794e4c8104e0d124bbf5d0fa682e4de459bd..920cd0d44a92d885121f9cb984b288ce53eac64c 100644 (file)
@@ -7,6 +7,7 @@ from nmigen.back.pysim import Simulator, Delay, Tick, Passive
 from nmigen.cli import verilog, rtlil
 
 from ieee754.part.partsig import PartitionedSignal
+from ieee754.part_mux.part_mux import PMux
 
 import unittest
 
@@ -34,6 +35,8 @@ class TestAddMod(Elaboratable):
         self.ne_output = Signal(len(partpoints)+1)
         self.lt_output = Signal(len(partpoints)+1)
         self.le_output = Signal(len(partpoints)+1)
+        self.mux_sel = Signal(len(partpoints)+1)
+        self.mux_out = Signal(width)
 
     def elaborate(self, platform):
         m = Module()
@@ -46,6 +49,7 @@ class TestAddMod(Elaboratable):
         m.d.comb += self.eq_output.eq(self.a == self.b)
         m.d.comb += self.ge_output.eq(self.a >= self.b)
         m.d.comb += self.add_output.eq(self.a + self.b)
+        m.d.comb += self.mux_out.eq(PMux(m, self.a, self.b, self.mux_sel))
 
         return m
 
@@ -155,6 +159,47 @@ class TestPartitionPoints(unittest.TestCase):
                 yield from test_binop("4-bit", test_fn, mod_attr,
                                       0b1000, 0b0100, 0b0010, 0b0001)
 
+            def test_muxop(msg_prefix, *maskbit_list):
+                for a, b, sel in [(0x0000, 0x0000, 0b0110),
+                             (0x1234, 0x1234, 0b1010),
+                             (0xABCD, 0xABCD, 0b1100),
+                             (0xFFFF, 0x0000, 0b0011),
+                             (0x0000, 0x0000, 0b1001),
+                             (0xFFFF, 0xFFFF, 0b1101),
+                             (0x0000, 0xFFFF, 0b1100)]:
+                    yield module.a.eq(a)
+                    yield module.b.eq(b)
+                    yield module.mux_sel.eq(sel)
+                    yield Delay(0.1e-6)
+                    # convert to mask_list
+                    mask_list = []
+                    for mb in maskbit_list:
+                        v = 0
+                        for i in range(4):
+                            if mb & (1<<i):
+                                v |= 0xf << (i*4)
+                        mask_list.append(v)
+                    y = 0
+                    # do the partitioned tests
+                    for i, mask in enumerate(mask_list):
+                        if (sel & mask):
+                            y |= (a & mask)
+                        else:
+                            y |= (b & mask)
+                    # check the result
+                    outval = (yield module.mux_out)
+                    msg = f"{msg_prefix}: mux 0x{a:X} == 0x{b:X}" + \
+                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
+                    #print ((msg % str(maskbit_list)).format(locals()))
+                    self.assertEqual(y, outval, msg % str(maskbit_list))
+
+            yield part_mask.eq(0)
+            yield from test_muxop("16-bit", 0b1111)
+            yield part_mask.eq(0b10)
+            yield from test_muxop("8-bit", 0b1100, 0b0011)
+            yield part_mask.eq(0b1111)
+            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+
         sim.add_process(async_process)
         sim.run()
 
index 2d9d7eff39fa2cce96563272ff8949c69d5d8830..54f34812d4cd06bc3195ea1aca0b198d2707c43d 100644 (file)
@@ -16,6 +16,16 @@ See:
 from nmigen import Signal, Module, Elaboratable, Mux
 from ieee754.part_mul_add.partpoints import PartitionPoints
 
+modcount = 0 # global for now
+def PMux(m, sel, a, b):
+    modcount += 1
+    pm = PartitionedMux(a.shape()[0])
+    m.d.comb += pm.a.eq(a)
+    m.d.comb += pm.b.eq(b)
+    m.d.comb += pm.sel.eq(sel)
+    setattr(m.submodules, "pmux%d" % modcount, pm)
+    return pm.output
+
 class PartitionedMux(Elaboratable):
     """PartitionedMux: Partitioned "Mux"
 
@@ -33,8 +43,8 @@ class PartitionedMux(Elaboratable):
         self.b = Signal(width, reset_less=True)
         self.sel = Signal(self.mwidth, reset_less=True)
         self.output = Signal(width, reset_less=True)
-        assert self.partition_points.fits_in_width(width),
-                    ("partition_points doesn't fit in width")
+        assert (self.partition_points.fits_in_width(width),
+                    "partition_points doesn't fit in width")
 
     def elaborate(self, platform):
         m = Module()
@@ -54,3 +64,4 @@ class PartitionedMux(Elaboratable):
 
     def ports(self):
         return [self.a, self.b, self.sel, self.output]
+