working on goldschmidt_div_sqrt.py
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 25 Apr 2022 08:44:31 +0000 (01:44 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 25 Apr 2022 08:44:31 +0000 (01:44 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index f4aee9daf00e329f5335bf911fcf54015dc98865..dc363c5a4bebe56df804193c36bf71fa11db76d7 100644 (file)
@@ -8,6 +8,46 @@ from dataclasses import dataclass, field
 import math
 import enum
 from fractions import Fraction
+from types import FunctionType
+
+try:
+    from functools import cached_property
+except ImportError:
+    from cached_property import cached_property
+
+# fix broken IDE type detection for cached_property
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from functools import cached_property
+
+
+_NOT_FOUND = object()
+
+
+def cache_on_self(func):
+    """like `functools.cached_property`, except for methods. unlike
+    `lru_cache` the cache is per-class instance rather than a global cache
+    per-method."""
+
+    assert isinstance(func, FunctionType), \
+        "non-plain methods are not supported"
+
+    cache_name = func.__name__ + "__cache"
+
+    def wrapper(self, *args, **kwargs):
+        # specifically access through `__dict__` to bypass frozen=True
+        cache = self.__dict__.get(cache_name, _NOT_FOUND)
+        if cache is _NOT_FOUND:
+            self.__dict__[cache_name] = cache = {}
+        key = (args, *kwargs.items())
+        retval = cache.get(key, _NOT_FOUND)
+        if retval is _NOT_FOUND:
+            retval = func(self, *args, **kwargs)
+            cache[key] = retval
+        return retval
+
+    wrapper.__doc__ = func.__doc__
+    return wrapper
 
 
 @enum.unique
@@ -212,14 +252,27 @@ class FixedPoint:
 
 @dataclass
 class GoldschmidtDivState:
+    orig_n: int
+    """original numerator"""
+
+    orig_d: int
+    """original denominator"""
+
     n: FixedPoint
     """numerator -- N_prime[i] in the paper's algorithm 2"""
+
     d: FixedPoint
     """denominator -- D_prime[i] in the paper's algorithm 2"""
+
     f: "FixedPoint | None" = None
     """current factor -- F_prime[i] in the paper's algorithm 2"""
-    result: "int | None" = None
-    """final result"""
+
+    quotient: "int | None" = None
+    """final quotient"""
+
+    remainder: "int | None" = None
+    """final remainder"""
+
     n_shift: "int | None" = None
     """amount the numerator needs to be left-shifted at the end of the
     algorithm.
@@ -242,19 +295,28 @@ class GoldschmidtDivParams:
     """parameters for a Goldschmidt division algorithm.
     Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
     """
+
     io_width: int
     """bit-width of the input divisor and the result.
     the input numerator is `2 * io_width`-bits wide.
     """
+
     extra_precision: int
     """number of bits of additional precision used inside the algorithm."""
+
     table_addr_bits: int
     """the number of address bits used in the lookup-table."""
+
     table_data_bits: int
     """the number of data bits used in the lookup-table."""
+
+    iter_count: int
+    """the total number of iterations of the division algorithm's loop"""
+
     # tuple to be immutable
     table: "tuple[FixedPoint, ...]" = field(init=False)
     """the lookup-table"""
+
     ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False)
     """the operations needed to perform the goldschmidt division algorithm."""
 
@@ -269,7 +331,7 @@ class GoldschmidtDivParams:
         with address `addr`."""
         assert isinstance(addr, int)
         assert 0 <= addr < self.table_addr_count
-        assert self.io_width >= self.table_addr_bits
+        _assert_accuracy(self.io_width >= self.table_addr_bits)
         min_numerator = (1 << self.table_addr_bits) + addr
         denominator = 1 << self.table_addr_bits
         values_per_table_entry = 1 << (self.io_width - self.table_addr_bits)
@@ -296,6 +358,7 @@ class GoldschmidtDivParams:
         assert self.extra_precision >= 0
         assert self.table_addr_bits >= 1
         assert self.table_data_bits >= 1
+        assert self.iter_count >= 1
         table = []
         for addr in range(1 << self.table_addr_bits):
             table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
@@ -311,17 +374,19 @@ class GoldschmidtDivParams:
         with `params.io_width == io_width`.
         """
         assert isinstance(io_width, int) and io_width >= 1
-        for extra_precision in range(io_width * 2):
-            for table_addr_bits in range(3, 7 + 1):
+        for extra_precision in range(io_width * 2 + 4):
+            for table_addr_bits in range(1, 7 + 1):
                 table_data_bits = io_width + extra_precision
-                try:
-                    return GoldschmidtDivParams(
-                        io_width=io_width,
-                        extra_precision=extra_precision,
-                        table_addr_bits=table_addr_bits,
-                        table_data_bits=table_data_bits)
-                except ParamsNotAccurateEnough:
-                    pass
+                for iter_count in range(1, 2 * io_width.bit_length()):
+                    try:
+                        return GoldschmidtDivParams(
+                            io_width=io_width,
+                            extra_precision=extra_precision,
+                            table_addr_bits=table_addr_bits,
+                            table_data_bits=table_data_bits,
+                            iter_count=iter_count)
+                    except ParamsNotAccurateEnough:
+                        pass
         raise ValueError(f"can't find working parameters for a goldschmidt "
                          f"division algorithm with io_width={io_width}")
 
@@ -330,6 +395,227 @@ class GoldschmidtDivParams:
         """the total number of bits of precision used inside the algorithm."""
         return self.io_width + self.extra_precision
 
+    @cache_on_self
+    def max_neps(self, i):
+        """maximum value of `neps[i]`.
+        `neps[i]` is defined to be `n[i] * N_prime[i - 1] * F_prime[i - 1]`.
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        return Fraction(1, 1 << self.expanded_width)
+
+    @cache_on_self
+    def max_deps(self, i):
+        """maximum value of `deps[i]`.
+        `deps[i]` is defined to be `d[i] * D_prime[i - 1] * F_prime[i - 1]`.
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        return Fraction(1, 1 << self.expanded_width)
+
+    @cache_on_self
+    def max_feps(self, i):
+        """maximum value of `feps[i]`.
+        `feps[i]` is defined to be `f[i] * (2 - D_prime[i - 1])`.
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        # zero, because the computation of `F_prime[i]` in
+        # `GoldschmidtDivOp.MulDByF.run(...)` is exact.
+        return Fraction(0)
+
+    @cached_property
+    def e0_range(self):
+        """minimum and maximum values of `e[0]`
+        (the relative error in `F_prime[-1]`)
+        """
+        min_e0 = Fraction(0)
+        max_e0 = Fraction(0)
+        for addr in range(self.table_addr_count):
+            # `F_prime[-1] = (1 - e[0]) / B`
+            # => `e[0] = 1 - B * F_prime[-1]`
+            min_b, max_b = self.table_input_exact_range(addr)
+            f_prime_m1 = self.table[addr].as_fraction()
+            assert min_b >= 0 and f_prime_m1 >= 0, \
+                "only positive quadrant of interval multiplication implemented"
+            min_product = min_b * f_prime_m1
+            max_product = max_b * f_prime_m1
+            # negation swaps min/max
+            cur_min_e0 = 1 - max_product
+            cur_max_e0 = 1 - min_product
+            min_e0 = min(min_e0, cur_min_e0)
+            max_e0 = max(max_e0, cur_max_e0)
+        return min_e0, max_e0
+
+    @cached_property
+    def min_e0(self):
+        """minimum value of `e[0]` (the relative error in `F_prime[-1]`)
+        """
+        min_e0, max_e0 = self.e0_range
+        return min_e0
+
+    @cached_property
+    def max_e0(self):
+        """maximum value of `e[0]` (the relative error in `F_prime[-1]`)
+        """
+        min_e0, max_e0 = self.e0_range
+        return max_e0
+
+    @cached_property
+    def max_abs_e0(self):
+        """maximum value of `abs(e[0])`."""
+        return max(abs(self.min_e0), abs(self.max_e0))
+
+    @cached_property
+    def min_abs_e0(self):
+        """minimum value of `abs(e[0])`."""
+        return Fraction(0)
+
+    @cache_on_self
+    def max_n(self, i):
+        """maximum value of `n[i]` (the relative error in `N_prime[i]`
+        relative to the previous iteration)
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        if i == 0:
+            # from Claim 10
+            # `n[0] = neps[0] / ((1 - e[0]) * (A / B))`
+            # `n[0] <= 2 * neps[0] / (1 - e[0])`
+
+            assert self.max_e0 < 1 and self.max_neps(0) >= 0, \
+                "only one quadrant of interval division implemented"
+            retval = 2 * self.max_neps(0) / (1 - self.max_e0)
+        elif i == 1:
+            # from Claim 10
+            # `n[1] <= neps[1] / ((1 - f[0]) * (1 - pi[0] - delta[0]))`
+            min_mpd = 1 - self.max_pi(0) - self.max_delta(0)
+            assert self.max_f(0) <= 1 and min_mpd >= 0, \
+                "only one quadrant of interval multiplication implemented"
+            prod = (1 - self.max_f(0)) * min_mpd
+            assert self.max_neps(1) >= 0 and prod > 0, \
+                "only one quadrant of interval division implemented"
+            retval = self.max_neps(1) / prod
+        else:
+            # from Claim 6
+            # `0 <= n[i] <= 2 * max_neps[i] / (1 - pi[i - 1] - delta[i - 1])`
+            min_mpd = 1 - self.max_pi(i - 1) - self.max_delta(i - 1)
+            assert self.max_neps(i) >= 0 and min_mpd > 0, \
+                "only one quadrant of interval division implemented"
+            retval = self.max_neps(i) / min_mpd
+
+        # we need Fraction to avoid using float by accident
+        # -- it also hints to the IDE to give the correct type
+        return Fraction(retval)
+
+    @cache_on_self
+    def max_d(self, i):
+        """maximum value of `d[i]` (the relative error in `D_prime[i]`
+        relative to the previous iteration)
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        if i == 0:
+            # from Claim 10
+            # `d[0] = deps[0] / (1 - e[0])`
+
+            assert self.max_e0 < 1 and self.max_deps(0) >= 0, \
+                "only one quadrant of interval division implemented"
+            retval = self.max_deps(0) / (1 - self.max_e0)
+        elif i == 1:
+            # from Claim 10
+            # `d[1] <= deps[1] / ((1 - f[0]) * (1 - delta[0] ** 2))`
+            assert self.max_f(0) <= 1 and self.max_delta(0) <= 1, \
+                "only one quadrant of interval multiplication implemented"
+            divisor = (1 - self.max_f(0)) * (1 - self.max_delta(0) ** 2)
+            assert self.max_deps(1) >= 0 and divisor > 0, \
+                "only one quadrant of interval division implemented"
+            retval = self.max_deps(1) / divisor
+        else:
+            # from Claim 6
+            # `0 <= d[i] <= max_deps[i] / (1 - delta[i - 1])`
+            assert self.max_deps(i) >= 0 and self.max_delta(i - 1) < 1, \
+                "only one quadrant of interval division implemented"
+            retval = self.max_deps(i) / (1 - self.max_delta(i - 1))
+
+        # we need Fraction to avoid using float by accident
+        # -- it also hints to the IDE to give the correct type
+        return Fraction(retval)
+
+    @cache_on_self
+    def max_f(self, i):
+        """maximum value of `f[i]` (the relative error in `F_prime[i]`
+        relative to the previous iteration)
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        if i == 0:
+            # from Claim 10
+            # `f[0] = feps[0] / (1 - delta[0])`
+
+            assert self.max_delta(0) < 1 and self.max_feps(0) >= 0, \
+                "only one quadrant of interval division implemented"
+            retval = self.max_feps(0) / (1 - self.max_delta(0))
+        elif i == 1:
+            # from Claim 10
+            # `f[1] = feps[1]`
+            retval = self.max_feps(1)
+        else:
+            # from Claim 6
+            # `f[i] <= max_feps[i]`
+            retval = self.max_feps(i)
+
+        # we need Fraction to avoid using float by accident
+        # -- it also hints to the IDE to give the correct type
+        return Fraction(retval)
+
+    @cache_on_self
+    def max_delta(self, i):
+        """ maximum value of `delta[i]`.
+        `delta[i]` is defined in Definition 4 of paper.
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        if i == 0:
+            # `delta[0] = abs(e[0]) + 3 * d[0] / 2`
+            retval = self.max_abs_e0 + Fraction(3, 2) * self.max_d(0)
+        else:
+            # `delta[i] = delta[i - 1] ** 2 + f[i - 1]`
+            prev_max_delta = self.max_delta(i - 1)
+            assert prev_max_delta >= 0
+            retval = prev_max_delta ** 2 + self.max_f(i - 1)
+
+        # we need Fraction to avoid using float by accident
+        # -- it also hints to the IDE to give the correct type
+        return Fraction(retval)
+
+    @cache_on_self
+    def max_pi(self, i):
+        """ maximum value of `pi[i]`.
+        `pi[i]` is defined right below Theorem 5 of paper.
+        """
+        assert isinstance(i, int) and 0 <= i < self.iter_count
+        # `pi[i] = 1 - (1 - n[i]) * prod`
+        # where `prod` is the product of,
+        # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
+        min_prod = Fraction(0)
+        for j in range(i):
+            max_n_j = self.max_n(j)
+            max_d_j = self.max_d(j)
+            assert max_n_j <= 1 and max_d_j > -1, \
+                "only one quadrant of interval division implemented"
+            min_prod *= (1 - max_n_j) / (1 + max_d_j)
+        max_n_i = self.max_n(i)
+        assert max_n_i <= 1 and min_prod >= 0, \
+            "only one quadrant of interval multiplication implemented"
+        return 1 - (1 - max_n_i) * min_prod
+
+    @cached_property
+    def max_n_shift(self):
+        """ maximum value of `state.n_shift`.
+        """
+        # input numerator is `2*io_width`-bits
+        max_n = (1 << (self.io_width * 2)) - 1
+        max_n_shift = 0
+        # normalize so 1 <= n < 2
+        while max_n >= 2:
+            max_n >>= 1
+            max_n_shift += 1
+        return max_n_shift
+
 
 @enum.unique
 class GoldschmidtDivOp(enum.Enum):
@@ -378,9 +664,11 @@ class GoldschmidtDivOp(enum.Enum):
             # scale to correct value
             n = state.n * (1 << state.n_shift)
 
-            # avoid incorrectly rounding down
-            n = n.to_frac_wid(params.io_width, round_dir=RoundDir.UP)
-            state.result = math.floor(n)
+            state.quotient = math.floor(n)
+            state.remainder = state.orig_n - state.quotient * state.orig_d
+            if state.remainder >= state.orig_d:
+                state.quotient += 1
+                state.remainder -= state.orig_d
         else:
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
@@ -412,37 +700,8 @@ def _goldschmidt_div_ops(params):
     _assert_accuracy(params.expanded_width > 4)
 
     # 3. require `abs(e[0]) + 3 * d[0] / 2 + f[0] < 1 / 2`.
-
-    # maximum `abs(e[0])`
-    max_abs_e0 = 0
-    # maximum `d[0]`
-    max_d0 = 0
-    # `f[i] = 0` for all `i`
-    fi = 0
-    for addr in range(params.table_addr_count):
-        # `F_prime[-1] = (1 - e[0]) / B`
-        # => `e[0] = 1 - B * F_prime[-1]`
-        min_b, max_b = params.table_input_exact_range(addr)
-        f_prime_m1 = params.table[addr].as_fraction()
-        assert min_b >= 0 and f_prime_m1 >= 0, \
-            "only positive quadrant of interval multiplication implemented"
-        min_product = min_b * f_prime_m1
-        max_product = max_b * f_prime_m1
-        # negation swaps min/max
-        min_e0 = 1 - max_product
-        max_e0 = 1 - min_product
-        max_abs_e0 = max(max_abs_e0, abs(min_e0), abs(max_e0))
-
-        # `D_prime[0] = (1 + d[0]) * B * F_prime[-1]`
-        # `D_prime[0] = abs_round_err + B * F_prime[-1]`
-        # => `d[0] = abs_round_err / (B * F_prime[-1])`
-        max_abs_round_err = Fraction(1, 1 << params.expanded_width)
-        assert min_product > 0 and max_abs_round_err >= 0, \
-            "only positive quadrant of interval division implemented"
-        # division swaps divisor's min/max
-        max_d0 = max(max_d0, max_abs_round_err / min_product)
-
-    _assert_accuracy(max_abs_e0 + 3 * max_d0 / 2 + fi < Fraction(1, 2))
+    _assert_accuracy(params.max_abs_e0 + 3 * params.max_d(0) / 2
+                     + params.max_f(0) < Fraction(1, 2))
 
     # 4. the initial approximation F'[-1] of 1/B is in [1/2, 1].
     # (B is the denominator)
@@ -453,16 +712,32 @@ def _goldschmidt_div_ops(params):
 
     yield GoldschmidtDivOp.FEqTableLookup
 
-    # we use Setting I (section 4.1 of the paper)
-
-    min_bits_of_precision = 1
-    # FIXME: calculate error and check if it's small enough
-    while min_bits_of_precision < params.io_width * 2:
+    # we use Setting I (section 4.1 of the paper):
+    # Require `n[i] <= n_hat` and `d[i] <= n_hat` and `f[i] = 0`
+    n_hat = Fraction(0)
+    for i in range(params.iter_count):
+        _assert_accuracy(params.max_f(i) == 0)
+        n_hat = max(n_hat, params.max_n(i), params.max_d(i))
         yield GoldschmidtDivOp.MulNByF
-        yield GoldschmidtDivOp.MulDByF
-        yield GoldschmidtDivOp.FEq2MinusD
-
-        min_bits_of_precision *= 2
+        if i != params.iter_count - 1:
+            yield GoldschmidtDivOp.MulDByF
+            yield GoldschmidtDivOp.FEq2MinusD
+
+    # relative approximation error `p(N_prime[i])`:
+    # `p(N_prime[i]) = (A / B - N_prime[i]) / (A / B)`
+    # `0 <= p(N_prime[i])`
+    # `p(N_prime[i]) <= (2 * i) * n_hat \`
+    # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
+    i = params.iter_count - 1  # last used `i`
+    max_rel_error = (2 * i) * n_hat + \
+        (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** i)
+
+    min_a_over_b = Fraction(1, 2)
+    max_a_over_b = Fraction(2)
+    max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift)
+    max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
+
+    _assert_accuracy(max_rel_error < max_allowed_rel_error)
 
     yield GoldschmidtDivOp.CalcResult
 
@@ -485,8 +760,9 @@ def goldschmidt_div(n, d, params):
         width: int
             the bit-width of the inputs/outputs. must be a positive integer.
 
-        returns: int
-            the quotient. a `width`-bit unsigned integer.
+        returns: tuple[int, int]
+            the quotient and remainder. a tuple of two `width`-bit unsigned
+            integers.
     """
     assert isinstance(params, GoldschmidtDivParams)
     assert isinstance(d, int) and 0 < d < (1 << params.io_width)
@@ -496,6 +772,8 @@ def goldschmidt_div(n, d, params):
     # have `width` fractional bits
 
     state = GoldschmidtDivState(
+        orig_n=n,
+        orig_d=d,
         n=FixedPoint(n, params.io_width),
         d=FixedPoint(d, params.io_width),
     )
@@ -503,6 +781,7 @@ def goldschmidt_div(n, d, params):
     for op in params.ops:
         op.run(params, state)
 
-    assert state.result is not None
+    assert state.quotient is not None
+    assert state.remainder is not None
 
-    return state.result
+    return state.quotient, state.remainder
index b4c9da7fa492524ac5a360fac82fb0ced93ee5c9..fd07d6151a3e2fb187e36b17d41905a64f09c6e9 100644 (file)
@@ -25,18 +25,21 @@ class TestGoldschmidtDiv(FHDLTestCase):
     def tst(self, io_width):
         assert isinstance(io_width, int)
         params = GoldschmidtDivParams.get(io_width)
-        print(params)
-        for d in range(1, 1 << io_width):
-            for n in range(d << io_width):
-                expected = n // d
-                with self.subTest(io_width=io_width, n=hex(n), d=hex(d),
-                                  expected=hex(expected)):
-                    result = goldschmidt_div(n, d, params)
-                    self.assertEqual(result, expected, f"result={hex(result)}")
+        with self.subTest(params=str(params)):
+            for d in range(1, 1 << io_width):
+                for n in range(d << io_width):
+                    expected_q, expected_r = divmod(n, d)
+                    with self.subTest(n=hex(n), d=hex(d),
+                                      expected_q=hex(expected_q),
+                                      expected_r=hex(expected_r)):
+                        q, r = goldschmidt_div(n, d, params)
+                        with self.subTest(q=hex(q), r=hex(r)):
+                            self.assertEqual((q, r), (expected_q, expected_r))
 
     def test_1_through_5(self):
-        for width in range(1, 5 + 1):
-            self.tst(width)
+        for io_width in range(1, 5 + 1):
+            with self.subTest(io_width=io_width):
+                self.tst(io_width)
 
     def test_6(self):
         self.tst(6)