fix so HDL works for 5, 8, 16, 32, and 64-bits.
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 29 Apr 2022 06:04:05 +0000 (23:04 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 29 Apr 2022 06:04:05 +0000 (23:04 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index a86fa78d111e3957b31e45a645b0ba8ac5e85174..1f3f81a4df74891647355d4eefc58964ccc69c90 100644 (file)
@@ -1097,7 +1097,12 @@ class GoldschmidtDivOp(enum.Enum):
             assert state.f is not None
             assert state.f.width == params.n_d_f_total_wid, "invalid f width"
             d = Signal.like(state.d)
-            m.d.comb += d.eq((state.d * state.f) >> params.expanded_width)
+            d_times_f = Signal.like(state.d * state.f)
+            m.d.comb += [
+                d_times_f.eq(state.d * state.f),
+                d.eq((d_times_f >> params.expanded_width)
+                     + (d_times_f[:params.expanded_width] != 0)),
+            ]
             state.d = d
         elif self == GoldschmidtDivOp.FEq2MinusD:
             assert state.d.width == params.n_d_f_total_wid, "invalid d width"
index b4d4fb85cd232b12b8eb28f40226e0679b23855a..66345fe20c77f9505b3dfade84d45bff15d40d0e 100644 (file)
@@ -8,7 +8,7 @@ from dataclasses import fields, replace
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
-from nmutil.sim_util import do_sim
+from nmutil.sim_util import do_sim, hash_256
 from nmigen.sim import Tick, Delay
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
@@ -76,19 +76,42 @@ class TestGoldschmidtDiv(FHDLTestCase):
                                  table_addr_bits=1, table_data_bits=5,
                                  iter_count=1)
 
-    def tst(self, io_width):
+    @staticmethod
+    def cases(io_width, cases=None):
+        assert isinstance(io_width, int) and io_width >= 1
+        if cases is not None:
+            for n, d in cases:
+                assert isinstance(d, int) \
+                    and 0 < d < (1 << io_width), "invalid case"
+                assert isinstance(n, int) \
+                    and 0 <= n < (d << io_width), "invalid case"
+                yield (n, d)
+        elif io_width > 6:
+            assert io_width * 2 <= 256, \
+                "can't generate big enough numbers for test cases"
+            for i in range(10000):
+                d = hash_256(f'd {i}') % (1 << io_width)
+                if d == 0:
+                    d = 1
+                n = hash_256(f'n {i}') % (d << io_width)
+                yield (n, d)
+        else:
+            for d in range(1, 1 << io_width):
+                for n in range(d << io_width):
+                    yield (n, d)
+
+    def tst(self, io_width, cases=None):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
         with self.subTest(params=str(params)):
-            for d in range(1, 1 << io_width):
-                for n in range(d << io_width):
-                    expected_q, expected_r = divmod(n, d)
-                    with self.subTest(n=hex(n), d=hex(d),
-                                      expected_q=hex(expected_q),
-                                      expected_r=hex(expected_r)):
-                        q, r = goldschmidt_div(n, d, params)
-                        with self.subTest(q=hex(q), r=hex(r)):
-                            self.assertEqual((q, r), (expected_q, expected_r))
+            for n, d in self.cases(io_width, cases):
+                expected_q, expected_r = divmod(n, d)
+                with self.subTest(n=hex(n), d=hex(d),
+                                  expected_q=hex(expected_q),
+                                  expected_r=hex(expected_r)):
+                    q, r = goldschmidt_div(n, d, params)
+                    with self.subTest(q=hex(q), r=hex(r)):
+                        self.assertEqual((q, r), (expected_q, expected_r))
 
     def tst_sim(self, io_width, cases=None, pipe_reg_indexes=(),
                 sync_rom=False):
@@ -101,22 +124,9 @@ class TestGoldschmidtDiv(FHDLTestCase):
         # make sync domain get added
         m.d.sync += Signal().eq(0)
 
-        def iter_cases():
-            if cases is not None:
-                for n, d in cases:
-                    assert isinstance(d, int) \
-                        and 0 < d < (1 << params.io_width), "invalid case"
-                    assert isinstance(n, int) \
-                        and 0 <= n < (d << params.io_width), "invalid case"
-                    yield (n, d)
-                return
-            for d in range(1, 1 << io_width):
-                for n in range(d << io_width):
-                    yield (n, d)
-
         def inputs_proc():
             yield Tick()
-            for n, d in iter_cases():
+            for n, d in self.cases(io_width, cases):
                 yield dut.n.eq(n)
                 yield dut.d.eq(d)
                 yield Tick()
@@ -164,7 +174,7 @@ class TestGoldschmidtDiv(FHDLTestCase):
             yield Tick()
             for _ in range(dut.total_pipeline_registers):
                 yield Tick()
-            for n, d in iter_cases():
+            for n, d in self.cases(io_width, cases):
                 yield Delay(0.1e-6)
                 expected_q, expected_r = divmod(n, d)
                 with self.subTest(n=hex(n), d=hex(d),
@@ -196,9 +206,33 @@ class TestGoldschmidtDiv(FHDLTestCase):
     def test_6(self):
         self.tst(6)
 
+    def test_8(self):
+        self.tst(8)
+
+    def test_16(self):
+        self.tst(16)
+
+    def test_32(self):
+        self.tst(32)
+
+    def test_64(self):
+        self.tst(64)
+
     def test_sim_5(self):
         self.tst_sim(5)
 
+    def test_sim_8(self):
+        self.tst_sim(8)
+
+    def test_sim_16(self):
+        self.tst_sim(16)
+
+    def test_sim_32(self):
+        self.tst_sim(32)
+
+    def test_sim_64(self):
+        self.tst_sim(64)
+
     def tst_params(self, io_width):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)