remove PartitionedSignal.eq, expectation is to use PartitionedSignal.__Assign__
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
index 2bf0ef837a7a6b5be5d26a645590962732b47577..a862412e7824bb24f5ad285625167d324197eeb1 100644 (file)
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: LGPL-2.1-or-later
 # See Notices.txt for copyright information
 
-from nmigen import Signal, Module, Elaboratable
+from nmigen import Signal, Module, Elaboratable, Mux, Cat
 from nmigen.back.pysim import Simulator, Delay
 from nmigen.cli import rtlil
 
@@ -14,6 +14,7 @@ import unittest
 import itertools
 import math
 
+
 def first_zero(x):
     res = 0
     for i in range(16):
@@ -63,8 +64,10 @@ class TestAddMod2(Elaboratable):
         self.ne_output = Signal(len(partpoints)+1)
         self.lt_output = Signal(len(partpoints)+1)
         self.le_output = Signal(len(partpoints)+1)
-        self.mux_sel = Signal(len(partpoints)+1)
+        self.mux_sel2 = Signal(len(partpoints)+1)
+        self.mux_sel2 = PartitionedSignal(partpoints, len(partpoints))
         self.mux_out = Signal(width)
+        self.mux2_out = Signal(width)
         self.carry_in = Signal(len(partpoints)+1)
         self.add_carry_out = Signal(len(partpoints)+1)
         self.sub_carry_out = Signal(len(partpoints)+1)
@@ -76,6 +79,7 @@ class TestAddMod2(Elaboratable):
         sync = m.d.sync
         self.a.set_module(m)
         self.b.set_module(m)
+        self.mux_sel2.set_module(m)
         # compares
         sync += self.lt_output.eq(self.a < self.b)
         sync += self.ne_output.eq(self.a != self.b)
@@ -100,14 +104,60 @@ class TestAddMod2(Elaboratable):
         sync += self.rs_output.eq(self.a >> self.b)
         ppts = self.partpoints
         sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
+        sync += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
         # scalar left shift
-        comb += self.bsig.eq(self.b.sig)
+        comb += self.bsig.eq(self.b.lower())
         sync += self.ls_scal_output.eq(self.a << self.bsig)
         sync += self.rs_scal_output.eq(self.a >> self.bsig)
 
         return m
 
 
+class TestMuxMod(Elaboratable):
+    def __init__(self, width, partpoints):
+        self.partpoints = partpoints
+        self.a = PartitionedSignal(partpoints, width)
+        self.b = PartitionedSignal(partpoints, width)
+        self.mux_sel = Signal(len(partpoints)+1)
+        self.mux_sel2 = PartitionedSignal(partpoints, len(partpoints)+1)
+        self.mux_out = Signal(width)
+        self.mux_out2 = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        sync = m.d.sync
+        self.a.set_module(m)
+        self.b.set_module(m)
+        self.mux_sel2.set_module(m)
+        ppts = self.partpoints
+
+        comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
+        comb += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
+
+        return m
+
+
+class TestCatMod(Elaboratable):
+    def __init__(self, width, partpoints):
+        self.partpoints = partpoints
+        self.a = PartitionedSignal(partpoints, width)
+        self.b = PartitionedSignal(partpoints, width*2)
+        self.cat_sel = Signal(len(partpoints)+1)
+        self.cat_out = Signal(width*3)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        self.a.set_module(m)
+        self.b.set_module(m)
+        #self.cat_sel.set_module(m)
+
+        comb += self.cat_out.eq(Cat(self.a, self.b))
+
+        return m
+
+
 class TestAddMod(Elaboratable):
     def __init__(self, width, partpoints):
         self.partpoints = partpoints
@@ -126,8 +176,6 @@ class TestAddMod(Elaboratable):
         self.ne_output = Signal(len(partpoints)+1)
         self.lt_output = Signal(len(partpoints)+1)
         self.le_output = Signal(len(partpoints)+1)
-        self.mux_sel = Signal(len(partpoints)+1)
-        self.mux_out = Signal(width)
         self.carry_in = Signal(len(partpoints)+1)
         self.add_carry_out = Signal(len(partpoints)+1)
         self.sub_carry_out = Signal(len(partpoints)+1)
@@ -163,10 +211,8 @@ class TestAddMod(Elaboratable):
         # right shift
         comb += self.rs_output.eq(self.a >> self.b)
         ppts = self.partpoints
-        # mux
-        comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
         # scalar left shift
-        comb += self.bsig.eq(self.b.sig)
+        comb += self.bsig.eq(self.b.lower())
         comb += self.ls_scal_output.eq(self.a << self.bsig)
         # scalar right shift
         comb += self.rs_scal_output.eq(self.a >> self.bsig)
@@ -174,10 +220,195 @@ class TestAddMod(Elaboratable):
         return m
 
 
-class TestPartitionPoints(unittest.TestCase):
+class TestMux(unittest.TestCase):
     def test(self):
         width = 16
-        part_mask = Signal(4)  # divide into 4-bits
+        part_mask = Signal(3)  # divide into 4-bits
+        module = TestMuxMod(width, part_mask)
+
+        test_name = "part_sig_mux"
+        traces = [part_mask,
+                  module.a.sig,
+                  module.b.sig,
+                  module.mux_out,
+                  module.mux_out2]
+        sim = create_simulator(module, traces, test_name)
+
+        def async_process():
+
+            def test_muxop(msg_prefix, *maskbit_list):
+                for a, b in [(0x0000, 0x0000),
+                             (0x1234, 0x1234),
+                             (0xABCD, 0xABCD),
+                             (0xFFFF, 0x0000),
+                             (0x0000, 0x0000),
+                             (0xFFFF, 0xFFFF),
+                             (0x0000, 0xFFFF)]:
+                    # convert to mask_list
+                    mask_list = []
+                    for mb in maskbit_list:
+                        v = 0
+                        for i in range(4):
+                            if mb & (1 << i):
+                                v |= 0xf << (i*4)
+                        mask_list.append(v)
+
+                    # TODO: sel needs to go through permutations of mask_list
+                    for p in perms(len(mask_list)):
+
+                        sel = 0
+                        selmask = 0
+                        for i, v in enumerate(p):
+                            if v == '1':
+                                sel |= maskbit_list[i]
+                                selmask |= mask_list[i]
+
+                        yield module.a.lower().eq(a)
+                        yield module.b.lower().eq(b)
+                        yield module.mux_sel.eq(sel)
+                        yield module.mux_sel2.lower().eq(sel)
+                        yield Delay(0.1e-6)
+                        y = 0
+                        # do the partitioned tests
+                        for i, mask in enumerate(mask_list):
+                            if (selmask & mask):
+                                y |= (a & mask)
+                            else:
+                                y |= (b & mask)
+                        # check the result
+                        outval = (yield module.mux_out)
+                        outval2 = (yield module.mux_out2)
+                        msg = f"{msg_prefix}: mux " + \
+                            f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
+                            f" => 0x{y:X} != 0x{outval:X}, masklist %s"
+                        # print ((msg % str(maskbit_list)).format(locals()))
+                        self.assertEqual(y, outval, msg % str(maskbit_list))
+                        self.assertEqual(y, outval2, msg % str(maskbit_list))
+
+            yield part_mask.eq(0)
+            yield from test_muxop("16-bit", 0b1111)
+            yield part_mask.eq(0b10)
+            yield from test_muxop("8-bit", 0b1100, 0b0011)
+            yield part_mask.eq(0b1111)
+            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+
+        sim.add_process(async_process)
+        with sim.write_vcd(
+                vcd_file=open(test_name + ".vcd", "w"),
+                gtkw_file=open(test_name + ".gtkw", "w"),
+                traces=traces):
+            sim.run()
+
+
+class TestCat(unittest.TestCase):
+    def test(self):
+        width = 16
+        part_mask = Signal(3)  # divide into 4-bits
+        module = TestCatMod(width, part_mask)
+
+        test_name = "part_sig_mux"
+        traces = [part_mask,
+                  module.a.sig,
+                  module.b.sig,
+                  module.cat_out]
+        sim = create_simulator(module, traces, test_name)
+
+        # annoying recursive import issue
+        from ieee754.part_cat.cat import get_runlengths
+
+        def async_process():
+
+            def test_catop(msg_prefix, *maskbit_list):
+                # define lengths of a/b test input
+                alen, blen = 16, 32
+                # pairs of test values a, b
+                for a, b in [(0x0000, 0x00000000),
+                             (0xDCBA, 0x12345678),
+                             (0xABCD, 0x01234567),
+                             (0xFFFF, 0x0000),
+                             (0x0000, 0x0000),
+                             (0x1F1F, 0xF1F1F1F1),
+                             (0x0000, 0xFFFFFFFF)]:
+                    # convert to mask_list
+                    mask_list = []
+                    for mb in maskbit_list:
+                        v = 0
+                        for i in range(4):
+                            if mb & (1 << i):
+                                v |= 0xf << (i*4)
+                        mask_list.append(v)
+
+                    # convert a and b to partitions
+                    apart, bpart = [], []
+                    ajump, bjump = alen // 4, blen // 4
+                    for i in range(4):
+                        apart.append((a >> (ajump*i) & ((1<<ajump)-1)))
+                        bpart.append((b >> (bjump*i) & ((1<<bjump)-1)))
+
+                    print ("apart bpart", hex(a), hex(b),
+                            list(map(hex, apart)), list(map(hex, bpart)))
+
+                    yield module.a.lower().eq(a)
+                    yield module.b.lower().eq(b)
+                    yield Delay(0.1e-6)
+
+                    y = 0
+                    # work out the runlengths for this mask.
+                    # 0b011 returns [1,1,2] (for a mask of length 3)
+                    mval = yield part_mask
+                    runlengths = get_runlengths(mval, 3)
+                    j = 0
+                    ai = 0
+                    bi = 0
+                    for i in runlengths:
+                        # a first
+                        for _ in range(i):
+                            print ("runlength", i,
+                                   "ai", ai,
+                                   "apart", hex(apart[ai]),
+                                   "j", j)
+                            y |= apart[ai] << j
+                            print ("    y", hex(y))
+                            j += ajump
+                            ai += 1
+                        # now b
+                        for _ in range(i):
+                            print ("runlength", i,
+                                   "bi", bi,
+                                   "bpart", hex(bpart[bi]),
+                                   "j", j)
+                            y |= bpart[bi] << j
+                            print ("    y", hex(y))
+                            j += bjump
+                            bi += 1
+
+                    # check the result
+                    outval = (yield module.cat_out)
+                    msg = f"{msg_prefix}: cat " + \
+                        f"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
+                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
+                    # print ((msg % str(maskbit_list)).format(locals()))
+                    self.assertEqual(y, outval, msg % str(maskbit_list))
+
+            yield part_mask.eq(0)
+            yield from test_catop("16-bit", 0b1111)
+            yield part_mask.eq(0b10)
+            yield from test_catop("8-bit", 0b1100, 0b0011)
+            yield part_mask.eq(0b1111)
+            yield from test_catop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+
+        sim.add_process(async_process)
+        with sim.write_vcd(
+                vcd_file=open(test_name + ".vcd", "w"),
+                gtkw_file=open(test_name + ".gtkw", "w"),
+                traces=traces):
+            sim.run()
+
+
+class TestPartitionedSignal(unittest.TestCase):
+    def test(self):
+        width = 16
+        part_mask = Signal(3)  # divide into 4-bits
         module = TestAddMod(width, part_mask)
 
         test_name = "part_sig_add"
@@ -297,8 +528,8 @@ class TestPartitionPoints(unittest.TestCase):
                              (0x0000, 0x0000),
                              (0xFFFF, 0xFFFF),
                              (0x0000, 0xFFFF)] + rand_data:
-                    yield module.a.eq(a)
-                    yield module.b.eq(b)
+                    yield module.a.lower().eq(a)
+                    yield module.b.lower().eq(b)
                     carry_sig = 0xf if carry else 0
                     yield module.carry_in.eq(carry_sig)
                     yield Delay(0.1e-6)
@@ -324,6 +555,11 @@ class TestPartitionPoints(unittest.TestCase):
                             f" => 0x{carry_result:X} != 0x{c_outval:X}"
                         self.assertEqual(carry_result, c_outval, msg)
 
+            # run through series of operations with corresponding
+            # "helper" routines to reproduce the result (test_fn).  the same
+            # a/b input is passed to *all* outputs, where the name of the
+            # output attribute (mod_attr) will contain the result to be
+            # compared against the expected output from test_fn
             for (test_fn, mod_attr) in (
                                         (test_ls_scal_fn, "ls_scal"),
                                         (test_ls_fn, "ls"),
@@ -376,8 +612,8 @@ class TestPartitionPoints(unittest.TestCase):
                              (0xABCD, 0xABCE),
                              (0x8000, 0x0000),
                              (0xBEEF, 0xFEED)]:
-                    yield module.a.eq(a)
-                    yield module.b.eq(b)
+                    yield module.a.lower().eq(a)
+                    yield module.b.lower().eq(b)
                     yield Delay(0.1e-6)
                     # convert to mask_list
                     mask_list = []
@@ -416,59 +652,6 @@ class TestPartitionPoints(unittest.TestCase):
                 yield from test_binop("4-bit", test_fn, mod_attr,
                                       0b1000, 0b0100, 0b0010, 0b0001)
 
-            def test_muxop(msg_prefix, *maskbit_list):
-                for a, b in [(0x0000, 0x0000),
-                             (0x1234, 0x1234),
-                             (0xABCD, 0xABCD),
-                             (0xFFFF, 0x0000),
-                             (0x0000, 0x0000),
-                             (0xFFFF, 0xFFFF),
-                             (0x0000, 0xFFFF)]:
-                    # convert to mask_list
-                    mask_list = []
-                    for mb in maskbit_list:
-                        v = 0
-                        for i in range(4):
-                            if mb & (1 << i):
-                                v |= 0xf << (i*4)
-                        mask_list.append(v)
-
-                    # TODO: sel needs to go through permutations of mask_list
-                    for p in perms(len(mask_list)):
-
-                        sel = 0
-                        selmask = 0
-                        for i, v in enumerate(p):
-                            if v == '1':
-                                sel |= maskbit_list[i]
-                                selmask |= mask_list[i]
-
-                        yield module.a.eq(a)
-                        yield module.b.eq(b)
-                        yield module.mux_sel.eq(sel)
-                        yield Delay(0.1e-6)
-                        y = 0
-                        # do the partitioned tests
-                        for i, mask in enumerate(mask_list):
-                            if (selmask & mask):
-                                y |= (a & mask)
-                            else:
-                                y |= (b & mask)
-                        # check the result
-                        outval = (yield module.mux_out)
-                        msg = f"{msg_prefix}: mux " + \
-                            f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
-                            f" => 0x{y:X} != 0x{outval:X}, masklist %s"
-                        # print ((msg % str(maskbit_list)).format(locals()))
-                        self.assertEqual(y, outval, msg % str(maskbit_list))
-
-            yield part_mask.eq(0)
-            yield from test_muxop("16-bit", 0b1111)
-            yield part_mask.eq(0b10)
-            yield from test_muxop("8-bit", 0b1100, 0b0011)
-            yield part_mask.eq(0b1111)
-            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
-
         sim.add_process(async_process)
         with sim.write_vcd(
                 vcd_file=open(test_name + ".vcd", "w"),