move clmul files into nmigen-gf.git
[nmutil.git] / src / nmutil / test / test_grev.py
index 18e1917d2a1b93f246ed9b6a9c07bbedbd4a1f2f..780239d8a13b2954a7953d5d2e312dd517a80347 100644 (file)
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: LGPL-3-or-later
-# See Notices.txt for copyright information
+# Copyright 2021 Jacob Lifshay
+
+# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
+# of Horizon 2020 EU Programme 957073.
 
 import unittest
 from nmigen.hdl.ast import AnyConst, Assert
@@ -11,63 +14,73 @@ from nmutil.sim_util import do_sim, hash_256
 
 
 class TestGrev(FHDLTestCase):
-    def test(self):
-        log2_width = 6
+    def tst(self, msb_first, log2_width=6):
         width = 2 ** log2_width
-        dut = GRev(log2_width)
+        dut = GRev(log2_width, msb_first)
         self.assertEqual(width, dut.width)
-        self.assertEqual(len(dut._steps), log2_width + 1)
+        self.assertEqual(len(dut._intermediates), log2_width + 1)
 
-        def case(input, chunk_sizes):
-            expected = grev(input, chunk_sizes, log2_width)
-            with self.subTest(input=hex(input), chunk_sizes=bin(chunk_sizes),
+        def case(inval, chunk_sizes):
+            expected = grev(inval, chunk_sizes, log2_width)
+            with self.subTest(inval=hex(inval), chunk_sizes=bin(chunk_sizes),
                               expected=hex(expected)):
-                yield dut.input.eq(input)
+                yield dut.input.eq(inval)
                 yield dut.chunk_sizes.eq(chunk_sizes)
                 yield Delay(1e-6)
                 output = yield dut.output
                 with self.subTest(output=hex(output)):
                     self.assertEqual(expected, output)
-                for i, step in enumerate(dut._steps):
-                    cur_chunk_sizes = chunk_sizes & (2 ** i - 1)
-                    step_expected = grev(input, cur_chunk_sizes, log2_width)
-                    step = yield step
-                    with self.subTest(i=i, step=hex(step),
-                                      cur_chunk_sizes=bin(cur_chunk_sizes),
-                                      step_expected=hex(step_expected)):
-                        self.assertEqual(step, step_expected)
+                for sig, expected in dut._sigs_and_expected(inval,
+                                                            chunk_sizes):
+                    value = yield sig
+                    with self.subTest(sig=sig.name, value=hex(value),
+                                      expected=hex(expected)):
+                        self.assertEqual(value, expected)
 
         def process():
             for count in range(width + 1):
-                input = (1 << count) - 1
+                inval = (1 << count) - 1
                 for chunk_sizes in range(2 ** log2_width):
-                    yield from case(input, chunk_sizes)
+                    yield from case(inval, chunk_sizes)
             for i in range(100):
-                input = hash_256(f"grev input {i}")
-                input &= 2 ** width - 1
+                inval = hash_256(f"grev input {i}")
+                inval &= 2 ** width - 1
                 chunk_sizes = hash_256(f"grev 2 {i}")
                 chunk_sizes &= 2 ** log2_width - 1
-                yield from case(input, chunk_sizes)
+                yield from case(inval, chunk_sizes)
         with do_sim(self, dut, [dut.input, dut.chunk_sizes,
-                                *dut._steps, dut.output]) as sim:
+                                dut.output]) as sim:
             sim.add_process(process)
             sim.run()
 
-    def test_formal(self):
+    def test(self):
+        self.tst(msb_first=False)
+
+    def test_msb_first(self):
+        self.tst(msb_first=True)
+
+    def test_small(self):
+        self.tst(msb_first=False, log2_width=3)
+
+    def test_small_msb_first(self):
+        self.tst(msb_first=True, log2_width=3)
+
+    def tst_formal(self, msb_first):
         log2_width = 4
-        dut = GRev(log2_width)
+        dut = GRev(log2_width, msb_first)
         m = Module()
         m.submodules.dut = dut
         m.d.comb += dut.input.eq(AnyConst(2 ** log2_width))
         m.d.comb += dut.chunk_sizes.eq(AnyConst(log2_width))
-        m.d.comb += Assert(dut.output == grev(dut.input,
-                                              dut.chunk_sizes, log2_width))
-        for i, step in enumerate(dut._steps):
-            cur_chunk_sizes = dut.chunk_sizes & (2 ** i - 1)
-            step_expected = grev(dut.input, cur_chunk_sizes, log2_width)
-            m.d.comb += Assert(step == step_expected)
+        # actual formal correctness proof is inside the module itself, now
         self.assertFormal(m)
 
+    def test_formal(self):
+        self.tst_formal(msb_first=False)
+
+    def test_formal_msb_first(self):
+        self.tst_formal(msb_first=True)
+
 
 if __name__ == "__main__":
     unittest.main()