switch to better CLDivRem algorithm
[nmigen-gf.git] / src / nmigen_gf / hdl / test / test_cldivrem.py
index de7cf34025e5e7c13e173d3c30a3440abb63c2f1..27820705176c7528d2298a499292f72b03377b9b 100644 (file)
@@ -144,9 +144,12 @@ class TestCLDivRemShifting(FHDLTestCase):
 class TestCLDivRemComb(FHDLTestCase):
     def tst(self, shape, full):
         assert isinstance(shape, CLDivRemShape)
+        width = shape.width
         m = Module()
-        n_in = Signal(shape.n_width)
-        d_in = Signal(shape.width)
+        n_in = Signal(width)
+        d_in = Signal(width)
+        q_out = Signal(width)
+        r_out = Signal(width)
         states: "list[CLDivRemState]" = []
         for i in shape.step_range:
             states.append(CLDivRemState(shape, name=f"state_{i}"))
@@ -154,13 +157,14 @@ class TestCLDivRemComb(FHDLTestCase):
                 states[i].set_to_initial(m, n=n_in, d=d_in)
             else:
                 states[i].set_to_next(m, states[i - 1])
+        q, r = states[-1].get_output()
+        m.d.comb += [q_out.eq(q), r_out.eq(r)]
 
         def case(n, d):
             assert isinstance(n, int)
             assert isinstance(d, int)
-            max_width = max(shape.width, shape.n_width)
             if d != 0:
-                expected_q, expected_r = cldivrem(n, d, width=max_width)
+                expected_q, expected_r = cldivrem_shifting(n, d, width)
             else:
                 expected_q = expected_r = 0
             with self.subTest(n=hex(n), d=hex(d),
@@ -175,35 +179,38 @@ class TestCLDivRemComb(FHDLTestCase):
                         step = yield states[i].step
                         self.assertEqual(done, i >= shape.done_step)
                         self.assertEqual(step, i)
-                q = yield states[-1].q
-                r = yield states[-1].r
+                q = yield q_out
+                r = yield r_out
                 with self.subTest(q=hex(q), r=hex(r)):
                     # only check results when inputs are valid
-                    if d != 0 and (expected_q >> shape.width) == 0:
+                    if d != 0:
                         self.assertEqual(q, expected_q)
                         self.assertEqual(r, expected_r)
 
         def process():
             if full:
-                for n in range(1 << shape.n_width):
-                    for d in range(1 << shape.width):
+                for n in range(1 << width):
+                    for d in range(1 << width):
                         yield from case(n, d)
             else:
                 for i in range(100):
                     n = hash_256(f"cldivrem comb n {i}")
-                    n = Const.normalize(n, unsigned(shape.n_width))
+                    n = Const.normalize(n, unsigned(width))
                     d = hash_256(f"cldivrem comb d {i}")
-                    d = Const.normalize(d, unsigned(shape.width))
+                    d = Const.normalize(d, unsigned(width))
                     yield from case(n, d)
-        with do_sim(self, m, [n_in, d_in, states[-1].q, states[-1].r]) as sim:
+        with do_sim(self, m, [n_in, d_in, q_out, r_out]) as sim:
             sim.add_process(process)
             sim.run()
 
     def test_4(self):
-        self.tst(CLDivRemShape(width=4, n_width=4), full=True)
+        self.tst(CLDivRemShape(width=4), full=True)
+
+    def test_6(self):
+        self.tst(CLDivRemShape(width=6), full=True)
 
-    def test_8_by_4(self):
-        self.tst(CLDivRemShape(width=4, n_width=8), full=True)
+    def test_8(self):
+        self.tst(CLDivRemShape(width=8), full=False)
 
 
 class TestCLDivRemFSM(FHDLTestCase):
@@ -214,7 +221,7 @@ class TestCLDivRemFSM(FHDLTestCase):
         dut = CLDivRemFSMStage(pspec, shape, steps_per_clock=steps_per_clock)
         i_data: CLDivRemInputData = dut.p.i_data
         o_data: CLDivRemOutputData = dut.n.o_data
-        self.assertEqual(i_data.n.shape(), unsigned(shape.n_width))
+        self.assertEqual(i_data.n.shape(), unsigned(shape.width))
         self.assertEqual(i_data.d.shape(), unsigned(shape.width))
         self.assertEqual(o_data.q.shape(), unsigned(shape.width))
         self.assertEqual(o_data.r.shape(), unsigned(shape.width))
@@ -222,9 +229,8 @@ class TestCLDivRemFSM(FHDLTestCase):
         def case(n, d):
             assert isinstance(n, int)
             assert isinstance(d, int)
-            max_width = max(shape.width, shape.n_width)
             if d != 0:
-                expected_q, expected_r = cldivrem(n, d, width=max_width)
+                expected_q, expected_r = cldivrem(n, d, width=shape.width)
             else:
                 expected_q = expected_r = 0
             with self.subTest(n=hex(n), d=hex(d),
@@ -245,8 +251,7 @@ class TestCLDivRemFSM(FHDLTestCase):
                 yield i_data.n.eq(-1)
                 yield i_data.d.eq(-1)
                 yield dut.p.i_valid.eq(0)
-                for i in range(steps_per_clock * 2, shape.done_step,
-                               steps_per_clock):
+                for step in range(0, shape.done_step, steps_per_clock):
                     yield Delay(0.1e-6)
                     valid = yield dut.n.o_valid
                     ready = yield dut.p.o_ready
@@ -279,13 +284,13 @@ class TestCLDivRemFSM(FHDLTestCase):
 
         def process():
             if full:
-                for n in range(1 << shape.n_width):
+                for n in range(1 << shape.width):
                     for d in range(1 << shape.width):
                         yield from case(n, d)
             else:
                 for i in range(100):
                     n = hash_256(f"cldivrem fsm n {i}")
-                    n = Const.normalize(n, unsigned(shape.n_width))
+                    n = Const.normalize(n, unsigned(shape.width))
                     d = hash_256(f"cldivrem fsm d {i}")
                     d = Const.normalize(d, unsigned(shape.width))
                     yield from case(n, d)
@@ -296,20 +301,40 @@ class TestCLDivRemFSM(FHDLTestCase):
             sim.run()
 
     def test_4_step_1(self):
-        self.tst(CLDivRemShape(width=4, n_width=4),
+        self.tst(CLDivRemShape(width=4),
                  full=True,
                  steps_per_clock=1)
 
     def test_4_step_2(self):
-        self.tst(CLDivRemShape(width=4, n_width=4),
+        self.tst(CLDivRemShape(width=4),
                  full=True,
                  steps_per_clock=2)
 
     def test_4_step_3(self):
-        self.tst(CLDivRemShape(width=4, n_width=4),
+        self.tst(CLDivRemShape(width=4),
                  full=True,
                  steps_per_clock=3)
 
+    def test_4_step_4(self):
+        self.tst(CLDivRemShape(width=4),
+                 full=True,
+                 steps_per_clock=4)
+
+    def test_8_step_4(self):
+        self.tst(CLDivRemShape(width=8),
+                 full=False,
+                 steps_per_clock=4)
+
+    def test_64_step_4(self):
+        self.tst(CLDivRemShape(width=64),
+                 full=False,
+                 steps_per_clock=4)
+
+    def test_64_step_8(self):
+        self.tst(CLDivRemShape(width=64),
+                 full=False,
+                 steps_per_clock=8)
+
 
 if __name__ == "__main__":
     unittest.main()