goldschmidt division works! still needs better parameter selection tho...
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 26 Apr 2022 03:56:30 +0000 (20:56 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 26 Apr 2022 03:56:30 +0000 (20:56 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index dc363c5a4bebe56df804193c36bf71fa11db76d7..62156ee5ec418d9490e31edb89a6f401d587a06f 100644 (file)
@@ -125,13 +125,6 @@ class FixedPoint:
             denominator = value.denominator
         else:
             value = FixedPoint.cast(value)
-            # compute number of bits that should be removed from value
-            del_bits = value.frac_wid - frac_wid
-            if del_bits == 0:
-                return value
-            if del_bits < 0:  # add bits
-                return FixedPoint(value.bits << -del_bits,
-                                  frac_wid)
             numerator = value.bits
             denominator = 1 << value.frac_wid
         if denominator < 0:
@@ -313,13 +306,49 @@ class GoldschmidtDivParams:
     iter_count: int
     """the total number of iterations of the division algorithm's loop"""
 
-    # tuple to be immutable
-    table: "tuple[FixedPoint, ...]" = field(init=False)
+    # tuple to be immutable, default so repr() works for debugging even when
+    # __post_init__ hasn't finished running yet
+    table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented)
     """the lookup-table"""
 
-    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False)
+    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False,
+                                                default=NotImplemented)
     """the operations needed to perform the goldschmidt division algorithm."""
 
+    def _shrink_bound(self, bound, round_dir):
+        """prevent fractions from having huge numerators/denominators by
+        rounding to a `FixedPoint` and converting back to a `Fraction`.
+
+        This is intended only for values used to compute bounds, and not for
+        values that end up in the hardware.
+        """
+        assert isinstance(bound, (Fraction, int))
+        assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
+            "you shouldn't use that round_dir on bounds"
+        frac_wid = self.io_width * 4 + 100  # should be enough precision
+        fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
+        return fixed.as_fraction()
+
+    def _shrink_min(self, min_bound):
+        """prevent fractions used as minimum bounds from having huge
+        numerators/denominators by rounding down to a `FixedPoint` and
+        converting back to a `Fraction`.
+
+        This is intended only for values used to compute bounds, and not for
+        values that end up in the hardware.
+        """
+        return self._shrink_bound(min_bound, RoundDir.DOWN)
+
+    def _shrink_max(self, max_bound):
+        """prevent fractions used as maximum bounds from having huge
+        numerators/denominators by rounding up to a `FixedPoint` and
+        converting back to a `Fraction`.
+
+        This is intended only for values used to compute bounds, and not for
+        values that end up in the hardware.
+        """
+        return self._shrink_bound(max_bound, RoundDir.UP)
+
     @property
     def table_addr_count(self):
         """number of distinct addresses in the lookup-table."""
@@ -332,20 +361,29 @@ class GoldschmidtDivParams:
         assert isinstance(addr, int)
         assert 0 <= addr < self.table_addr_count
         _assert_accuracy(self.io_width >= self.table_addr_bits)
-        min_numerator = (1 << self.table_addr_bits) + addr
-        denominator = 1 << self.table_addr_bits
-        values_per_table_entry = 1 << (self.io_width - self.table_addr_bits)
-        max_numerator = min_numerator + values_per_table_entry
+        addr_shift = self.io_width - self.table_addr_bits
+        min_numerator = (1 << self.io_width) + (addr << addr_shift)
+        denominator = 1 << self.io_width
+        values_per_table_entry = 1 << addr_shift
+        max_numerator = min_numerator + values_per_table_entry - 1
         min_input = Fraction(min_numerator, denominator)
         max_input = Fraction(max_numerator, denominator)
+        min_input = self._shrink_min(min_input)
+        max_input = self._shrink_max(max_input)
+        assert 1 <= min_input <= max_input < 2
         return min_input, max_input
 
     def table_value_exact_range(self, addr):
         """return the range of values as `Fraction`s used for the table entry
         with address `addr`."""
-        min_value, max_value = self.table_input_exact_range(addr)
+        min_input, max_input = self.table_input_exact_range(addr)
         # division swaps min/max
-        return 1 / max_value, 1 / min_value
+        min_value = 1 / max_input
+        max_value = 1 / min_input
+        min_value = self._shrink_min(min_value)
+        max_value = self._shrink_max(max_value)
+        assert 0.5 < min_value <= max_value <= 1
+        return min_value, max_value
 
     def table_exact_value(self, index):
         min_value, max_value = self.table_value_exact_range(index)
@@ -374,6 +412,8 @@ class GoldschmidtDivParams:
         with `params.io_width == io_width`.
         """
         assert isinstance(io_width, int) and io_width >= 1
+        last_params = None
+        last_error = None
         for extra_precision in range(io_width * 2 + 4):
             for table_addr_bits in range(1, 7 + 1):
                 table_data_bits = io_width + extra_precision
@@ -385,10 +425,17 @@ class GoldschmidtDivParams:
                             table_addr_bits=table_addr_bits,
                             table_data_bits=table_data_bits,
                             iter_count=iter_count)
-                    except ParamsNotAccurateEnough:
-                        pass
+                    except ParamsNotAccurateEnough as e:
+                        last_params = (f"GoldschmidtDivParams("
+                                       f"io_width={io_width!r}, "
+                                       f"extra_precision={extra_precision!r}, "
+                                       f"table_addr_bits={table_addr_bits!r}, "
+                                       f"table_data_bits={table_data_bits!r}, "
+                                       f"iter_count={iter_count!r})")
+                        last_error = e
         raise ValueError(f"can't find working parameters for a goldschmidt "
-                         f"division algorithm with io_width={io_width}")
+                         f"division algorithm: last params: {last_params}"
+                         ) from last_error
 
     @property
     def expanded_width(self):
@@ -442,6 +489,8 @@ class GoldschmidtDivParams:
             cur_max_e0 = 1 - min_product
             min_e0 = min(min_e0, cur_min_e0)
             max_e0 = max(max_e0, cur_max_e0)
+        min_e0 = self._shrink_min(min_e0)
+        max_e0 = self._shrink_max(max_e0)
         return min_e0, max_e0
 
     @cached_property
@@ -500,9 +549,7 @@ class GoldschmidtDivParams:
                 "only one quadrant of interval division implemented"
             retval = self.max_neps(i) / min_mpd
 
-        # we need Fraction to avoid using float by accident
-        # -- it also hints to the IDE to give the correct type
-        return Fraction(retval)
+        return self._shrink_max(retval)
 
     @cache_on_self
     def max_d(self, i):
@@ -533,9 +580,7 @@ class GoldschmidtDivParams:
                 "only one quadrant of interval division implemented"
             retval = self.max_deps(i) / (1 - self.max_delta(i - 1))
 
-        # we need Fraction to avoid using float by accident
-        # -- it also hints to the IDE to give the correct type
-        return Fraction(retval)
+        return self._shrink_max(retval)
 
     @cache_on_self
     def max_f(self, i):
@@ -559,9 +604,7 @@ class GoldschmidtDivParams:
             # `f[i] <= max_feps[i]`
             retval = self.max_feps(i)
 
-        # we need Fraction to avoid using float by accident
-        # -- it also hints to the IDE to give the correct type
-        return Fraction(retval)
+        return self._shrink_max(retval)
 
     @cache_on_self
     def max_delta(self, i):
@@ -578,9 +621,11 @@ class GoldschmidtDivParams:
             assert prev_max_delta >= 0
             retval = prev_max_delta ** 2 + self.max_f(i - 1)
 
-        # we need Fraction to avoid using float by accident
-        # -- it also hints to the IDE to give the correct type
-        return Fraction(retval)
+        # `delta[i]` has to be smaller than one otherwise errors would go off
+        # to infinity
+        _assert_accuracy(retval < 1)
+
+        return self._shrink_max(retval)
 
     @cache_on_self
     def max_pi(self, i):
@@ -591,7 +636,7 @@ class GoldschmidtDivParams:
         # `pi[i] = 1 - (1 - n[i]) * prod`
         # where `prod` is the product of,
         # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
-        min_prod = Fraction(0)
+        min_prod = Fraction(1)
         for j in range(i):
             max_n_j = self.max_n(j)
             max_d_j = self.max_d(j)
@@ -601,7 +646,8 @@ class GoldschmidtDivParams:
         max_n_i = self.max_n(i)
         assert max_n_i <= 1 and min_prod >= 0, \
             "only one quadrant of interval multiplication implemented"
-        return 1 - (1 - max_n_i) * min_prod
+        retval = 1 - (1 - max_n_i) * min_prod
+        return self._shrink_max(retval)
 
     @cached_property
     def max_n_shift(self):
@@ -729,15 +775,21 @@ def _goldschmidt_div_ops(params):
     # `p(N_prime[i]) <= (2 * i) * n_hat \`
     # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
     i = params.iter_count - 1  # last used `i`
-    max_rel_error = (2 * i) * n_hat + \
-        (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** i)
+    # compute power manually to prevent huge intermediate values
+    power = params._shrink_max(params.max_abs_e0 + 3 * n_hat / 2)
+    for _ in range(i):
+        power = params._shrink_max(power * power)
+
+    max_rel_error = (2 * i) * n_hat + power
 
     min_a_over_b = Fraction(1, 2)
     max_a_over_b = Fraction(2)
     max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift)
     max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
 
-    _assert_accuracy(max_rel_error < max_allowed_rel_error)
+    _assert_accuracy(max_rel_error < max_allowed_rel_error,
+                     f"not accurate enough: max_rel_error={max_rel_error} "
+                     f"max_allowed_rel_error={max_allowed_rel_error}")
 
     yield GoldschmidtDivOp.CalcResult
 
index fd07d6151a3e2fb187e36b17d41905a64f09c6e9..9e2763410f65e20aa25888b8094bbcab3b80104b 100644 (file)
@@ -6,8 +6,8 @@
 
 import unittest
 from nmutil.formaltest import FHDLTestCase
-from soc.fu.div.experiment.goldschmidt_div_sqrt import (GoldschmidtDivParams, goldschmidt_div,
-                                                        FixedPoint)
+from soc.fu.div.experiment.goldschmidt_div_sqrt import (
+    GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint)
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -21,7 +21,18 @@ class TestFixedPoint(FHDLTestCase):
 
 
 class TestGoldschmidtDiv(FHDLTestCase):
-    @unittest.skip("goldschmidt_div isn't finished yet")
+    def test_case1(self):
+        with self.assertRaises(ParamsNotAccurateEnough):
+            GoldschmidtDivParams(io_width=3, extra_precision=2,
+                                 table_addr_bits=3, table_data_bits=5,
+                                 iter_count=2)
+
+    def test_case2(self):
+        with self.assertRaises(ParamsNotAccurateEnough):
+            GoldschmidtDivParams(io_width=4, extra_precision=1,
+                                 table_addr_bits=1, table_data_bits=5,
+                                 iter_count=1)
+
     def tst(self, io_width):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
@@ -36,14 +47,215 @@ class TestGoldschmidtDiv(FHDLTestCase):
                         with self.subTest(q=hex(q), r=hex(r)):
                             self.assertEqual((q, r), (expected_q, expected_r))
 
-    def test_1_through_5(self):
-        for io_width in range(1, 5 + 1):
+    def test_1_through_4(self):
+        for io_width in range(1, 4 + 1):
             with self.subTest(io_width=io_width):
                 self.tst(io_width)
 
+    def test_5(self):
+        self.tst(5)
+
     def test_6(self):
         self.tst(6)
 
+    def tst_params(self, io_width):
+        assert isinstance(io_width, int)
+        params = GoldschmidtDivParams.get(io_width)
+        print()
+        print(params)
+
+    def test_params_1(self):
+        self.tst_params(1)
+
+    def test_params_2(self):
+        self.tst_params(2)
+
+    def test_params_3(self):
+        self.tst_params(3)
+
+    def test_params_4(self):
+        self.tst_params(4)
+
+    def test_params_5(self):
+        self.tst_params(5)
+
+    def test_params_6(self):
+        self.tst_params(6)
+
+    def test_params_7(self):
+        self.tst_params(7)
+
+    def test_params_8(self):
+        self.tst_params(8)
+
+    def test_params_9(self):
+        self.tst_params(9)
+
+    def test_params_10(self):
+        self.tst_params(10)
+
+    def test_params_11(self):
+        self.tst_params(11)
+
+    def test_params_12(self):
+        self.tst_params(12)
+
+    def test_params_13(self):
+        self.tst_params(13)
+
+    def test_params_14(self):
+        self.tst_params(14)
+
+    def test_params_15(self):
+        self.tst_params(15)
+
+    def test_params_16(self):
+        self.tst_params(16)
+
+    def test_params_17(self):
+        self.tst_params(17)
+
+    def test_params_18(self):
+        self.tst_params(18)
+
+    def test_params_19(self):
+        self.tst_params(19)
+
+    def test_params_20(self):
+        self.tst_params(20)
+
+    def test_params_21(self):
+        self.tst_params(21)
+
+    def test_params_22(self):
+        self.tst_params(22)
+
+    def test_params_23(self):
+        self.tst_params(23)
+
+    def test_params_24(self):
+        self.tst_params(24)
+
+    def test_params_25(self):
+        self.tst_params(25)
+
+    def test_params_26(self):
+        self.tst_params(26)
+
+    def test_params_27(self):
+        self.tst_params(27)
+
+    def test_params_28(self):
+        self.tst_params(28)
+
+    def test_params_29(self):
+        self.tst_params(29)
+
+    def test_params_30(self):
+        self.tst_params(30)
+
+    def test_params_31(self):
+        self.tst_params(31)
+
+    def test_params_32(self):
+        self.tst_params(32)
+
+    def test_params_33(self):
+        self.tst_params(33)
+
+    def test_params_34(self):
+        self.tst_params(34)
+
+    def test_params_35(self):
+        self.tst_params(35)
+
+    def test_params_36(self):
+        self.tst_params(36)
+
+    def test_params_37(self):
+        self.tst_params(37)
+
+    def test_params_38(self):
+        self.tst_params(38)
+
+    def test_params_39(self):
+        self.tst_params(39)
+
+    def test_params_40(self):
+        self.tst_params(40)
+
+    def test_params_41(self):
+        self.tst_params(41)
+
+    def test_params_42(self):
+        self.tst_params(42)
+
+    def test_params_43(self):
+        self.tst_params(43)
+
+    def test_params_44(self):
+        self.tst_params(44)
+
+    def test_params_45(self):
+        self.tst_params(45)
+
+    def test_params_46(self):
+        self.tst_params(46)
+
+    def test_params_47(self):
+        self.tst_params(47)
+
+    def test_params_48(self):
+        self.tst_params(48)
+
+    def test_params_49(self):
+        self.tst_params(49)
+
+    def test_params_50(self):
+        self.tst_params(50)
+
+    def test_params_51(self):
+        self.tst_params(51)
+
+    def test_params_52(self):
+        self.tst_params(52)
+
+    def test_params_53(self):
+        self.tst_params(53)
+
+    def test_params_54(self):
+        self.tst_params(54)
+
+    def test_params_55(self):
+        self.tst_params(55)
+
+    def test_params_56(self):
+        self.tst_params(56)
+
+    def test_params_57(self):
+        self.tst_params(57)
+
+    def test_params_58(self):
+        self.tst_params(58)
+
+    def test_params_59(self):
+        self.tst_params(59)
+
+    def test_params_60(self):
+        self.tst_params(60)
+
+    def test_params_61(self):
+        self.tst_params(61)
+
+    def test_params_62(self):
+        self.tst_params(62)
+
+    def test_params_63(self):
+        self.tst_params(63)
+
+    def test_params_64(self):
+        self.tst_params(64)
+
 
 if __name__ == "__main__":
     unittest.main()