From 6af1a415a6885a2645fc1bc3adc8a3a3cc3aaaf7 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 25 Apr 2022 20:56:30 -0700 Subject: [PATCH] goldschmidt division works! still needs better parameter selection tho... --- .../fu/div/experiment/goldschmidt_div_sqrt.py | 124 +++++++--- .../test/test_goldschmidt_div_sqrt.py | 222 +++++++++++++++++- 2 files changed, 305 insertions(+), 41 deletions(-) diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py index dc363c5a..62156ee5 100644 --- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py @@ -125,13 +125,6 @@ class FixedPoint: denominator = value.denominator else: value = FixedPoint.cast(value) - # compute number of bits that should be removed from value - del_bits = value.frac_wid - frac_wid - if del_bits == 0: - return value - if del_bits < 0: # add bits - return FixedPoint(value.bits << -del_bits, - frac_wid) numerator = value.bits denominator = 1 << value.frac_wid if denominator < 0: @@ -313,13 +306,49 @@ class GoldschmidtDivParams: iter_count: int """the total number of iterations of the division algorithm's loop""" - # tuple to be immutable - table: "tuple[FixedPoint, ...]" = field(init=False) + # tuple to be immutable, default so repr() works for debugging even when + # __post_init__ hasn't finished running yet + table: "tuple[FixedPoint, ...]" = field(init=False, default=NotImplemented) """the lookup-table""" - ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False) + ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, + default=NotImplemented) """the operations needed to perform the goldschmidt division algorithm.""" + def _shrink_bound(self, bound, round_dir): + """prevent fractions from having huge numerators/denominators by + rounding to a `FixedPoint` and converting back to a `Fraction`. + + This is intended only for values used to compute bounds, and not for + values that end up in the hardware. + """ + assert isinstance(bound, (Fraction, int)) + assert round_dir is RoundDir.DOWN or round_dir is RoundDir.UP, \ + "you shouldn't use that round_dir on bounds" + frac_wid = self.io_width * 4 + 100 # should be enough precision + fixed = FixedPoint.with_frac_wid(bound, frac_wid, round_dir) + return fixed.as_fraction() + + def _shrink_min(self, min_bound): + """prevent fractions used as minimum bounds from having huge + numerators/denominators by rounding down to a `FixedPoint` and + converting back to a `Fraction`. + + This is intended only for values used to compute bounds, and not for + values that end up in the hardware. + """ + return self._shrink_bound(min_bound, RoundDir.DOWN) + + def _shrink_max(self, max_bound): + """prevent fractions used as maximum bounds from having huge + numerators/denominators by rounding up to a `FixedPoint` and + converting back to a `Fraction`. + + This is intended only for values used to compute bounds, and not for + values that end up in the hardware. + """ + return self._shrink_bound(max_bound, RoundDir.UP) + @property def table_addr_count(self): """number of distinct addresses in the lookup-table.""" @@ -332,20 +361,29 @@ class GoldschmidtDivParams: assert isinstance(addr, int) assert 0 <= addr < self.table_addr_count _assert_accuracy(self.io_width >= self.table_addr_bits) - min_numerator = (1 << self.table_addr_bits) + addr - denominator = 1 << self.table_addr_bits - values_per_table_entry = 1 << (self.io_width - self.table_addr_bits) - max_numerator = min_numerator + values_per_table_entry + addr_shift = self.io_width - self.table_addr_bits + min_numerator = (1 << self.io_width) + (addr << addr_shift) + denominator = 1 << self.io_width + values_per_table_entry = 1 << addr_shift + max_numerator = min_numerator + values_per_table_entry - 1 min_input = Fraction(min_numerator, denominator) max_input = Fraction(max_numerator, denominator) + min_input = self._shrink_min(min_input) + max_input = self._shrink_max(max_input) + assert 1 <= min_input <= max_input < 2 return min_input, max_input def table_value_exact_range(self, addr): """return the range of values as `Fraction`s used for the table entry with address `addr`.""" - min_value, max_value = self.table_input_exact_range(addr) + min_input, max_input = self.table_input_exact_range(addr) # division swaps min/max - return 1 / max_value, 1 / min_value + min_value = 1 / max_input + max_value = 1 / min_input + min_value = self._shrink_min(min_value) + max_value = self._shrink_max(max_value) + assert 0.5 < min_value <= max_value <= 1 + return min_value, max_value def table_exact_value(self, index): min_value, max_value = self.table_value_exact_range(index) @@ -374,6 +412,8 @@ class GoldschmidtDivParams: with `params.io_width == io_width`. """ assert isinstance(io_width, int) and io_width >= 1 + last_params = None + last_error = None for extra_precision in range(io_width * 2 + 4): for table_addr_bits in range(1, 7 + 1): table_data_bits = io_width + extra_precision @@ -385,10 +425,17 @@ class GoldschmidtDivParams: table_addr_bits=table_addr_bits, table_data_bits=table_data_bits, iter_count=iter_count) - except ParamsNotAccurateEnough: - pass + except ParamsNotAccurateEnough as e: + last_params = (f"GoldschmidtDivParams(" + f"io_width={io_width!r}, " + f"extra_precision={extra_precision!r}, " + f"table_addr_bits={table_addr_bits!r}, " + f"table_data_bits={table_data_bits!r}, " + f"iter_count={iter_count!r})") + last_error = e raise ValueError(f"can't find working parameters for a goldschmidt " - f"division algorithm with io_width={io_width}") + f"division algorithm: last params: {last_params}" + ) from last_error @property def expanded_width(self): @@ -442,6 +489,8 @@ class GoldschmidtDivParams: cur_max_e0 = 1 - min_product min_e0 = min(min_e0, cur_min_e0) max_e0 = max(max_e0, cur_max_e0) + min_e0 = self._shrink_min(min_e0) + max_e0 = self._shrink_max(max_e0) return min_e0, max_e0 @cached_property @@ -500,9 +549,7 @@ class GoldschmidtDivParams: "only one quadrant of interval division implemented" retval = self.max_neps(i) / min_mpd - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + return self._shrink_max(retval) @cache_on_self def max_d(self, i): @@ -533,9 +580,7 @@ class GoldschmidtDivParams: "only one quadrant of interval division implemented" retval = self.max_deps(i) / (1 - self.max_delta(i - 1)) - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + return self._shrink_max(retval) @cache_on_self def max_f(self, i): @@ -559,9 +604,7 @@ class GoldschmidtDivParams: # `f[i] <= max_feps[i]` retval = self.max_feps(i) - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + return self._shrink_max(retval) @cache_on_self def max_delta(self, i): @@ -578,9 +621,11 @@ class GoldschmidtDivParams: assert prev_max_delta >= 0 retval = prev_max_delta ** 2 + self.max_f(i - 1) - # we need Fraction to avoid using float by accident - # -- it also hints to the IDE to give the correct type - return Fraction(retval) + # `delta[i]` has to be smaller than one otherwise errors would go off + # to infinity + _assert_accuracy(retval < 1) + + return self._shrink_max(retval) @cache_on_self def max_pi(self, i): @@ -591,7 +636,7 @@ class GoldschmidtDivParams: # `pi[i] = 1 - (1 - n[i]) * prod` # where `prod` is the product of, # for `j` in `0 <= j < i`, `(1 - n[j]) / (1 + d[j])` - min_prod = Fraction(0) + min_prod = Fraction(1) for j in range(i): max_n_j = self.max_n(j) max_d_j = self.max_d(j) @@ -601,7 +646,8 @@ class GoldschmidtDivParams: max_n_i = self.max_n(i) assert max_n_i <= 1 and min_prod >= 0, \ "only one quadrant of interval multiplication implemented" - return 1 - (1 - max_n_i) * min_prod + retval = 1 - (1 - max_n_i) * min_prod + return self._shrink_max(retval) @cached_property def max_n_shift(self): @@ -729,15 +775,21 @@ def _goldschmidt_div_ops(params): # `p(N_prime[i]) <= (2 * i) * n_hat \` # ` + (abs(e[0]) + 3 * n_hat / 2) ** (2 ** i)` i = params.iter_count - 1 # last used `i` - max_rel_error = (2 * i) * n_hat + \ - (params.max_abs_e0 + 3 * n_hat / 2) ** (2 ** i) + # compute power manually to prevent huge intermediate values + power = params._shrink_max(params.max_abs_e0 + 3 * n_hat / 2) + for _ in range(i): + power = params._shrink_max(power * power) + + max_rel_error = (2 * i) * n_hat + power min_a_over_b = Fraction(1, 2) max_a_over_b = Fraction(2) max_allowed_abs_error = max_a_over_b / (1 << params.max_n_shift) max_allowed_rel_error = max_allowed_abs_error / min_a_over_b - _assert_accuracy(max_rel_error < max_allowed_rel_error) + _assert_accuracy(max_rel_error < max_allowed_rel_error, + f"not accurate enough: max_rel_error={max_rel_error} " + f"max_allowed_rel_error={max_allowed_rel_error}") yield GoldschmidtDivOp.CalcResult diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py index fd07d615..9e276341 100644 --- a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py +++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py @@ -6,8 +6,8 @@ import unittest from nmutil.formaltest import FHDLTestCase -from soc.fu.div.experiment.goldschmidt_div_sqrt import (GoldschmidtDivParams, goldschmidt_div, - FixedPoint) +from soc.fu.div.experiment.goldschmidt_div_sqrt import ( + GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint) class TestFixedPoint(FHDLTestCase): @@ -21,7 +21,18 @@ class TestFixedPoint(FHDLTestCase): class TestGoldschmidtDiv(FHDLTestCase): - @unittest.skip("goldschmidt_div isn't finished yet") + def test_case1(self): + with self.assertRaises(ParamsNotAccurateEnough): + GoldschmidtDivParams(io_width=3, extra_precision=2, + table_addr_bits=3, table_data_bits=5, + iter_count=2) + + def test_case2(self): + with self.assertRaises(ParamsNotAccurateEnough): + GoldschmidtDivParams(io_width=4, extra_precision=1, + table_addr_bits=1, table_data_bits=5, + iter_count=1) + def tst(self, io_width): assert isinstance(io_width, int) params = GoldschmidtDivParams.get(io_width) @@ -36,14 +47,215 @@ class TestGoldschmidtDiv(FHDLTestCase): with self.subTest(q=hex(q), r=hex(r)): self.assertEqual((q, r), (expected_q, expected_r)) - def test_1_through_5(self): - for io_width in range(1, 5 + 1): + def test_1_through_4(self): + for io_width in range(1, 4 + 1): with self.subTest(io_width=io_width): self.tst(io_width) + def test_5(self): + self.tst(5) + def test_6(self): self.tst(6) + def tst_params(self, io_width): + assert isinstance(io_width, int) + params = GoldschmidtDivParams.get(io_width) + print() + print(params) + + def test_params_1(self): + self.tst_params(1) + + def test_params_2(self): + self.tst_params(2) + + def test_params_3(self): + self.tst_params(3) + + def test_params_4(self): + self.tst_params(4) + + def test_params_5(self): + self.tst_params(5) + + def test_params_6(self): + self.tst_params(6) + + def test_params_7(self): + self.tst_params(7) + + def test_params_8(self): + self.tst_params(8) + + def test_params_9(self): + self.tst_params(9) + + def test_params_10(self): + self.tst_params(10) + + def test_params_11(self): + self.tst_params(11) + + def test_params_12(self): + self.tst_params(12) + + def test_params_13(self): + self.tst_params(13) + + def test_params_14(self): + self.tst_params(14) + + def test_params_15(self): + self.tst_params(15) + + def test_params_16(self): + self.tst_params(16) + + def test_params_17(self): + self.tst_params(17) + + def test_params_18(self): + self.tst_params(18) + + def test_params_19(self): + self.tst_params(19) + + def test_params_20(self): + self.tst_params(20) + + def test_params_21(self): + self.tst_params(21) + + def test_params_22(self): + self.tst_params(22) + + def test_params_23(self): + self.tst_params(23) + + def test_params_24(self): + self.tst_params(24) + + def test_params_25(self): + self.tst_params(25) + + def test_params_26(self): + self.tst_params(26) + + def test_params_27(self): + self.tst_params(27) + + def test_params_28(self): + self.tst_params(28) + + def test_params_29(self): + self.tst_params(29) + + def test_params_30(self): + self.tst_params(30) + + def test_params_31(self): + self.tst_params(31) + + def test_params_32(self): + self.tst_params(32) + + def test_params_33(self): + self.tst_params(33) + + def test_params_34(self): + self.tst_params(34) + + def test_params_35(self): + self.tst_params(35) + + def test_params_36(self): + self.tst_params(36) + + def test_params_37(self): + self.tst_params(37) + + def test_params_38(self): + self.tst_params(38) + + def test_params_39(self): + self.tst_params(39) + + def test_params_40(self): + self.tst_params(40) + + def test_params_41(self): + self.tst_params(41) + + def test_params_42(self): + self.tst_params(42) + + def test_params_43(self): + self.tst_params(43) + + def test_params_44(self): + self.tst_params(44) + + def test_params_45(self): + self.tst_params(45) + + def test_params_46(self): + self.tst_params(46) + + def test_params_47(self): + self.tst_params(47) + + def test_params_48(self): + self.tst_params(48) + + def test_params_49(self): + self.tst_params(49) + + def test_params_50(self): + self.tst_params(50) + + def test_params_51(self): + self.tst_params(51) + + def test_params_52(self): + self.tst_params(52) + + def test_params_53(self): + self.tst_params(53) + + def test_params_54(self): + self.tst_params(54) + + def test_params_55(self): + self.tst_params(55) + + def test_params_56(self): + self.tst_params(56) + + def test_params_57(self): + self.tst_params(57) + + def test_params_58(self): + self.tst_params(58) + + def test_params_59(self): + self.tst_params(59) + + def test_params_60(self): + self.tst_params(60) + + def test_params_61(self): + self.tst_params(61) + + def test_params_62(self): + self.tst_params(62) + + def test_params_63(self): + self.tst_params(63) + + def test_params_64(self): + self.tst_params(64) + if __name__ == "__main__": unittest.main() -- 2.30.2