denominator = value.denominator
else:
value = FixedPoint.cast(value)
- # compute number of bits that should be removed from value
- del_bits = value.frac_wid - frac_wid
- if del_bits == 0:
- return value
- if del_bits < 0: # add bits
- return FixedPoint(value.bits << -del_bits,
- frac_wid)
numerator = value.bits
denominator = 1 << value.frac_wid
if denominator < 0:
iter_count: int
"""the total number of iterations of the division algorithm's loop"""
- # tuple to be immutable
- table: "tuple[FixedPoint, ...]" = field(init=False)
+ # tuple to be immutable, default so repr() works for debugging even when
+ # __post_init__ hasn't finished running yet
+ table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented)
"""the lookup-table"""
- ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False)
+ ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False,
+ default=NotImplemented)
"""the operations needed to perform the goldschmidt division algorithm."""
+ def _shrink_bound(self, bound, round_dir):
+ """prevent fractions from having huge numerators/denominators by
+ rounding to a `FixedPoint` and converting back to a `Fraction`.
+
+ This is intended only for values used to compute bounds, and not for
+ values that end up in the hardware.
+ """
+ assert isinstance(bound, (Fraction, int))
+ assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \
+ "you shouldn't use that round_dir on bounds"
+ frac_wid = self.io_width * 4 + 100 # should be enough precision
+ fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir)
+ return fixed.as_fraction()
+
+ def _shrink_min(self, min_bound):
+ """prevent fractions used as minimum bounds from having huge
+ numerators/denominators by rounding down to a `FixedPoint` and
+ converting back to a `Fraction`.
+
+ This is intended only for values used to compute bounds, and not for
+ values that end up in the hardware.
+ """
+ return self._shrink_bound(min_bound, RoundDir.DOWN)
+
+ def _shrink_max(self, max_bound):
+ """prevent fractions used as maximum bounds from having huge
+ numerators/denominators by rounding up to a `FixedPoint` and
+ converting back to a `Fraction`.
+
+ This is intended only for values used to compute bounds, and not for
+ values that end up in the hardware.
+ """
+ return self._shrink_bound(max_bound, RoundDir.UP)
+
@property
def table_addr_count(self):
"""number of distinct addresses in the lookup-table."""
assert isinstance(addr, int)
assert 0 <= addr < self.table_addr_count
_assert_accuracy(self.io_width >= self.table_addr_bits)
- min_numerator = (1 << self.table_addr_bits) + addr
- denominator = 1 << self.table_addr_bits
- values_per_table_entry = 1 << (self.io_width - self.table_addr_bits)
- max_numerator = min_numerator + values_per_table_entry
+ addr_shift = self.io_width - self.table_addr_bits
+ min_numerator = (1 << self.io_width) + (addr << addr_shift)
+ denominator = 1 << self.io_width
+ values_per_table_entry = 1 << addr_shift
+ max_numerator = min_numerator + values_per_table_entry - 1
min_input = Fraction(min_numerator, denominator)
max_input = Fraction(max_numerator, denominator)
+ min_input = self._shrink_min(min_input)
+ max_input = self._shrink_max(max_input)
+ assert 1 <= min_input <= max_input < 2
return min_input, max_input
def table_value_exact_range(self, addr):
"""return the range of values as `Fraction`s used for the table entry
with address `addr`."""
- min_value, max_value = self.table_input_exact_range(addr)
+ min_input, max_input = self.table_input_exact_range(addr)
# division swaps min/max
- return 1 / max_value, 1 / min_value
+ min_value = 1 / max_input
+ max_value = 1 / min_input
+ min_value = self._shrink_min(min_value)
+ max_value = self._shrink_max(max_value)
+ assert 0.5 < min_value <= max_value <= 1
+ return min_value, max_value
def table_exact_value(self, index):
min_value, max_value = self.table_value_exact_range(index)
with `params.io_width == io_width`.
"""
assert isinstance(io_width, int) and io_width >= 1
+ last_params = None
+ last_error = None
for extra_precision in range(io_width * 2 + 4):
for table_addr_bits in range(1, 7 + 1):
table_data_bits = io_width + extra_precision
table_addr_bits=table_addr_bits,
table_data_bits=table_data_bits,
iter_count=iter_count)
- except ParamsNotAccurateEnough:
- pass
+ except ParamsNotAccurateEnough as e:
+ last_params = (f"GoldschmidtDivParams("
+ f"io_width={io_width!r}, "
+ f"extra_precision={extra_precision!r}, "
+ f"table_addr_bits={table_addr_bits!r}, "
+ f"table_data_bits={table_data_bits!r}, "
+ f"iter_count={iter_count!r})")
+ last_error = e
raise ValueError(f"can't find working parameters for a goldschmidt "
- f"division algorithm with io_width={io_width}")
+ f"division algorithm: last params: {last_params}"
+ ) from last_error
@property
def expanded_width(self):
cur_max_e0 = 1 - min_product
min_e0 = min(min_e0, cur_min_e0)
max_e0 = max(max_e0, cur_max_e0)
+ min_e0 = self._shrink_min(min_e0)
+ max_e0 = self._shrink_max(max_e0)
return min_e0, max_e0
@cached_property
"only one quadrant of interval division implemented"
retval = self.max_neps(i) / min_mpd
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ return self._shrink_max(retval)
@cache_on_self
def max_d(self, i):
"only one quadrant of interval division implemented"
retval = self.max_deps(i) / (1 - self.max_delta(i - 1))
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ return self._shrink_max(retval)
@cache_on_self
def max_f(self, i):
# `f[i] <= max_feps[i]`
retval = self.max_feps(i)
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ return self._shrink_max(retval)
@cache_on_self
def max_delta(self, i):
assert prev_max_delta >= 0
retval = prev_max_delta ** 2 + self.max_f(i - 1)
- # we need Fraction to avoid using float by accident
- # -- it also hints to the IDE to give the correct type
- return Fraction(retval)
+ # `delta[i]` has to be smaller than one otherwise errors would go off
+ # to infinity
+ _assert_accuracy(retval < 1)
+
+ return self._shrink_max(retval)
@cache_on_self
def max_pi(self, i):
# `pi[i] = 1 - (1 - n[i]) * prod`
# where `prod` is the product of,
# for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])`
- min_prod = Fraction(0)
+ min_prod = Fraction(1)
for j in range(i):
max_n_j = self.max_n(j)
max_d_j = self.max_d(j)
max_n_i = self.max_n(i)
assert max_n_i <= 1 and min_prod >= 0, \
"only one quadrant of interval multiplication implemented"
- return 1 - (1 - max_n_i) * min_prod
+ retval = 1 - (1 - max_n_i) * min_prod
+ return self._shrink_max(retval)
@cached_property
def max_n_shift(self):
# `p(N_prime[i]) <= (2 * i) * n_hat \`
# ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)`
i = params.iter_count - 1 # last used `i`
- max_rel_error = (2 * i) * n_hat + \
- (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** i)
+ # compute power manually to prevent huge intermediate values
+ power = params._shrink_max(params.max_abs_e0 + 3 * n_hat / 2)
+ for _ in range(i):
+ power = params._shrink_max(power * power)
+
+ max_rel_error = (2 * i) * n_hat + power
min_a_over_b = Fraction(1, 2)
max_a_over_b = Fraction(2)
max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift)
max_allowed_rel_error = max_allowed_abs_error / min_a_over_b
- _assert_accuracy(max_rel_error < max_allowed_rel_error)
+ _assert_accuracy(max_rel_error < max_allowed_rel_error,
+ f"not accurate enough: max_rel_error={max_rel_error} "
+ f"max_allowed_rel_error={max_allowed_rel_error}")
yield GoldschmidtDivOp.CalcResult
import unittest
from nmutil.formaltest import FHDLTestCase
-from soc.fu.div.experiment.goldschmidt_div_sqrt import (GoldschmidtDivParams, goldschmidt_div,
- FixedPoint)
+from soc.fu.div.experiment.goldschmidt_div_sqrt import (
+ GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint)
class TestFixedPoint(FHDLTestCase):
class TestGoldschmidtDiv(FHDLTestCase):
- @unittest.skip("goldschmidt_div isn't finished yet")
+ def test_case1(self):
+ with self.assertRaises(ParamsNotAccurateEnough):
+ GoldschmidtDivParams(io_width=3, extra_precision=2,
+ table_addr_bits=3, table_data_bits=5,
+ iter_count=2)
+
+ def test_case2(self):
+ with self.assertRaises(ParamsNotAccurateEnough):
+ GoldschmidtDivParams(io_width=4, extra_precision=1,
+ table_addr_bits=1, table_data_bits=5,
+ iter_count=1)
+
def tst(self, io_width):
assert isinstance(io_width, int)
params = GoldschmidtDivParams.get(io_width)
with self.subTest(q=hex(q), r=hex(r)):
self.assertEqual((q, r), (expected_q, expected_r))
- def test_1_through_5(self):
- for io_width in range(1, 5 + 1):
+ def test_1_through_4(self):
+ for io_width in range(1, 4 + 1):
with self.subTest(io_width=io_width):
self.tst(io_width)
+ def test_5(self):
+ self.tst(5)
+
def test_6(self):
self.tst(6)
+ def tst_params(self, io_width):
+ assert isinstance(io_width, int)
+ params = GoldschmidtDivParams.get(io_width)
+ print()
+ print(params)
+
+ def test_params_1(self):
+ self.tst_params(1)
+
+ def test_params_2(self):
+ self.tst_params(2)
+
+ def test_params_3(self):
+ self.tst_params(3)
+
+ def test_params_4(self):
+ self.tst_params(4)
+
+ def test_params_5(self):
+ self.tst_params(5)
+
+ def test_params_6(self):
+ self.tst_params(6)
+
+ def test_params_7(self):
+ self.tst_params(7)
+
+ def test_params_8(self):
+ self.tst_params(8)
+
+ def test_params_9(self):
+ self.tst_params(9)
+
+ def test_params_10(self):
+ self.tst_params(10)
+
+ def test_params_11(self):
+ self.tst_params(11)
+
+ def test_params_12(self):
+ self.tst_params(12)
+
+ def test_params_13(self):
+ self.tst_params(13)
+
+ def test_params_14(self):
+ self.tst_params(14)
+
+ def test_params_15(self):
+ self.tst_params(15)
+
+ def test_params_16(self):
+ self.tst_params(16)
+
+ def test_params_17(self):
+ self.tst_params(17)
+
+ def test_params_18(self):
+ self.tst_params(18)
+
+ def test_params_19(self):
+ self.tst_params(19)
+
+ def test_params_20(self):
+ self.tst_params(20)
+
+ def test_params_21(self):
+ self.tst_params(21)
+
+ def test_params_22(self):
+ self.tst_params(22)
+
+ def test_params_23(self):
+ self.tst_params(23)
+
+ def test_params_24(self):
+ self.tst_params(24)
+
+ def test_params_25(self):
+ self.tst_params(25)
+
+ def test_params_26(self):
+ self.tst_params(26)
+
+ def test_params_27(self):
+ self.tst_params(27)
+
+ def test_params_28(self):
+ self.tst_params(28)
+
+ def test_params_29(self):
+ self.tst_params(29)
+
+ def test_params_30(self):
+ self.tst_params(30)
+
+ def test_params_31(self):
+ self.tst_params(31)
+
+ def test_params_32(self):
+ self.tst_params(32)
+
+ def test_params_33(self):
+ self.tst_params(33)
+
+ def test_params_34(self):
+ self.tst_params(34)
+
+ def test_params_35(self):
+ self.tst_params(35)
+
+ def test_params_36(self):
+ self.tst_params(36)
+
+ def test_params_37(self):
+ self.tst_params(37)
+
+ def test_params_38(self):
+ self.tst_params(38)
+
+ def test_params_39(self):
+ self.tst_params(39)
+
+ def test_params_40(self):
+ self.tst_params(40)
+
+ def test_params_41(self):
+ self.tst_params(41)
+
+ def test_params_42(self):
+ self.tst_params(42)
+
+ def test_params_43(self):
+ self.tst_params(43)
+
+ def test_params_44(self):
+ self.tst_params(44)
+
+ def test_params_45(self):
+ self.tst_params(45)
+
+ def test_params_46(self):
+ self.tst_params(46)
+
+ def test_params_47(self):
+ self.tst_params(47)
+
+ def test_params_48(self):
+ self.tst_params(48)
+
+ def test_params_49(self):
+ self.tst_params(49)
+
+ def test_params_50(self):
+ self.tst_params(50)
+
+ def test_params_51(self):
+ self.tst_params(51)
+
+ def test_params_52(self):
+ self.tst_params(52)
+
+ def test_params_53(self):
+ self.tst_params(53)
+
+ def test_params_54(self):
+ self.tst_params(54)
+
+ def test_params_55(self):
+ self.tst_params(55)
+
+ def test_params_56(self):
+ self.tst_params(56)
+
+ def test_params_57(self):
+ self.tst_params(57)
+
+ def test_params_58(self):
+ self.tst_params(58)
+
+ def test_params_59(self):
+ self.tst_params(59)
+
+ def test_params_60(self):
+ self.tst_params(60)
+
+ def test_params_61(self):
+ self.tst_params(61)
+
+ def test_params_62(self):
+ self.tst_params(62)
+
+ def test_params_63(self):
+ self.tst_params(63)
+
+ def test_params_64(self):
+ self.tst_params(64)
+
if __name__ == "__main__":
unittest.main()