From aa722dc113042a01204d4796b341baf64d63b608 Mon Sep 17 00:00:00 2001
From: Jacob Lifshay
Date: Wed, 27 Apr 2022 22:39:08 -0700
Subject: [PATCH] add the goldschmidt sqrt/rsqrt algorithm, still need code to calculate good parameters

---
NOTE(review): this copy of the patch was recovered from a whitespace-collapsed
listing and amended during review (docstrings/comments added, the table
address constant used instead of a hard-coded 2, argument checks reordered).
The indentation of unchanged context lines is inferred, and the commit id on
the first line and the blob ids on the `index` lines still describe the
original commit. Regenerate with `git format-patch` before applying.

 .../fu/div/experiment/goldschmidt_div_sqrt.py | 220 ++++++++++++++++++
 .../test/test_goldschmidt_div_sqrt.py         |  78 +++++-
 2 files changed, 297 insertions(+), 1 deletion(-)

diff --git a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
index 3af5320d..055ff7c1 100644
--- a/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
+++ b/src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
@@ -244,6 +244,92 @@ class FixedPoint:
     def __floor__(self):
         return self.bits >> self.frac_wid
 
+    def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
+        """compute `self / rhs` as a `FixedPoint` with `frac_wid` fraction
+        bits, rounded according to `round_dir`"""
+        assert isinstance(frac_wid, int) and frac_wid >= 0
+        assert isinstance(round_dir, RoundDir)
+        rhs = FixedPoint.cast(rhs)
+        return FixedPoint.with_frac_wid(self.as_fraction()
+                                        / rhs.as_fraction(),
+                                        frac_wid, round_dir)
+
+    def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
+        """compute the sqrt of `self`, the result has the same `frac_wid`
+        as `self`
+
+        raises ValueError if `self` is negative, or if the result is inexact
+        and `round_dir` is `RoundDir.ERROR_IF_INEXACT`
+        """
+        assert isinstance(round_dir, RoundDir)
+        if self < 0:
+            raise ValueError("can't compute sqrt of negative number")
+        if self == 0:
+            return self
+        retval = FixedPoint(0, self.frac_wid)
+        int_part_wid = self.bits.bit_length() - self.frac_wid
+        first_bit_index = -(-int_part_wid // 2)  # division rounds up
+        last_bit_index = -self.frac_wid
+        # build the rounded-down result one bit at a time, from the msb down
+        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
+            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
+                                        self.frac_wid)
+            if trial * trial <= self:
+                retval = trial
+        if round_dir == RoundDir.DOWN:
+            pass
+        elif round_dir == RoundDir.UP:
+            if retval * retval < self:
+                retval += FixedPoint(1, self.frac_wid)
+        elif round_dir == RoundDir.NEAREST_TIES_UP:
+            half_way = retval + FixedPoint(1, self.frac_wid + 1)
+            if half_way * half_way <= self:
+                retval += FixedPoint(1, self.frac_wid)
+        elif round_dir == RoundDir.ERROR_IF_INEXACT:
+            if retval * retval != self:
+                raise ValueError("inexact sqrt")
+        else:
+            assert False, "unimplemented round_dir"
+        return retval
+
+    def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
+        """compute the reciprocal-sqrt of `self`, the result has the same
+        `frac_wid` as `self`
+
+        raises ValueError if `self` is negative, or if the result is inexact
+        and `round_dir` is `RoundDir.ERROR_IF_INEXACT`;
+        raises ZeroDivisionError if `self` is zero
+        """
+        assert isinstance(round_dir, RoundDir)
+        if self < 0:
+            raise ValueError("can't compute rsqrt of negative number")
+        if self == 0:
+            raise ZeroDivisionError("can't compute rsqrt of zero")
+        retval = FixedPoint(0, self.frac_wid)
+        first_bit_index = -(-self.frac_wid // 2)  # division rounds up
+        last_bit_index = -self.frac_wid
+        # build the rounded-down result one bit at a time, from the msb down
+        for bit_index in range(first_bit_index, last_bit_index - 1, -1):
+            trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
+                                        self.frac_wid)
+            if trial * trial * self <= 1:
+                retval = trial
+        if round_dir == RoundDir.DOWN:
+            pass
+        elif round_dir == RoundDir.UP:
+            if retval * retval * self < 1:
+                retval += FixedPoint(1, self.frac_wid)
+        elif round_dir == RoundDir.NEAREST_TIES_UP:
+            half_way = retval + FixedPoint(1, self.frac_wid + 1)
+            if half_way * half_way * self <= 1:
+                retval += FixedPoint(1, self.frac_wid)
+        elif round_dir == RoundDir.ERROR_IF_INEXACT:
+            if retval * retval * self != 1:
+                raise ValueError("inexact rsqrt")
+        else:
+            assert False, "unimplemented round_dir"
+        return retval
+
 
 @dataclass
 class GoldschmidtDivState:
@@ -978,3 +1064,137 @@ def goldschmidt_div(n, d, params):
 
     assert state.remainder is not None
     return state.quotient, state.remainder
+
+
+GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2
+
+
+@lru_cache()
+def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
+    """Generate the look-up table needed for Goldschmidt's square-root and
+    reciprocal-square-root algorithm.
+
+    arguments:
+    table_addr_bits: int
+        the number of address bits for the look-up table.
+    table_data_bits: int
+        the number of data bits for the look-up table.
+
+    returns: tuple[FixedPoint | None, ...]
+        the look-up table, indexed by table address. Entry 0 is zero,
+        entries for other input values less than 1 are `None` since they
+        should be unused.
+    """
+    assert isinstance(table_addr_bits, int) and \
+        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+    assert isinstance(table_data_bits, int) and table_data_bits >= 1
+    table_addr_frac_wid = table_addr_bits
+    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+    table = []
+    table_len = 1 << table_addr_bits
+    for addr in range(table_len):
+        if addr == 0:
+            value = FixedPoint(0, table_data_bits)
+        elif (addr << GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID) < table_len:
+            value = None  # table entries should be unused
+        else:
+            # upper bound of the input values that round down to `addr`
+            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
+            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
+            value = max_input_value.to_frac_wid(max_frac_wid)
+            value = value.rsqrt(RoundDir.DOWN)
+            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
+        table.append(value)
+
+    # tuple for immutability
+    return tuple(table)
+
+
+def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
+                           table_addr_bits, table_data_bits, iter_count):
+    """Goldschmidt's square-root and reciprocal-square-root algorithm.
+
+    uses algorithm based on second method at:
+    https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm
+
+    arguments:
+    radicand: FixedPoint(frac_wid=frac_wid)
+        the input value to take the square-root and reciprocal-square-root of.
+    io_width: int
+        the number of bits in the input (`radicand`) and output values.
+    frac_wid: int
+        the number of fraction bits in the input (`radicand`) and output
+        values.
+    extra_precision: int
+        the number of bits of internal extra precision.
+    table_addr_bits: int
+        the number of address bits for the look-up table.
+    table_data_bits: int
+        the number of data bits for the look-up table.
+    iter_count: int
+        the number of Goldschmidt iteration steps to run.
+
+    returns: tuple[FixedPoint, FixedPoint]
+        the square-root and reciprocal-square-root, rounded down to the
+        nearest representable value. If `radicand == 0`, then the
+        reciprocal-square-root value returned is zero.
+    """
+    # check `io_width` and `frac_wid` before they're used to check `radicand`
+    assert isinstance(io_width, int) and io_width >= 1
+    assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
+    assert (isinstance(radicand, FixedPoint)
+            and radicand.frac_wid == frac_wid
+            and 0 <= radicand.bits < (1 << io_width))
+    assert isinstance(extra_precision, int) and extra_precision >= io_width
+    assert isinstance(table_addr_bits, int) and \
+        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+    assert isinstance(table_data_bits, int) and table_data_bits >= 1
+    assert isinstance(iter_count, int) and iter_count >= 0
+    expanded_frac_wid = frac_wid + extra_precision
+    s = radicand.to_frac_wid(expanded_frac_wid)
+    sqrt_rshift = extra_precision
+    rsqrt_rshift = extra_precision
+    # scale a non-zero `s` into `1 <= s < 4`, every factor of 4 in `s` is a
+    # factor of 2 in `sqrt(s)` and of 1/2 in `rsqrt(s)`
+    while s != 0 and s < 1:
+        s = (s * 4).to_frac_wid(expanded_frac_wid)
+        sqrt_rshift += 1
+        rsqrt_rshift -= 1
+    while s >= 4:
+        s = s.div(4, expanded_frac_wid)
+        sqrt_rshift -= 1
+        rsqrt_rshift += 1
+    table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
+                                         table_data_bits=table_data_bits)
+    # core goldschmidt sqrt/rsqrt algorithm:
+    # initial setup:
+    table_addr_frac_wid = table_addr_bits
+    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
+    addr = s.to_frac_wid(table_addr_frac_wid, RoundDir.DOWN)
+    assert 0 <= addr.bits < (1 << table_addr_bits), "table addr out of range"
+    f = table[addr.bits]
+    assert f is not None, "accessed invalid table entry"
+    # use with_frac_wid to fix IDE type deduction
+    f = FixedPoint.with_frac_wid(f, expanded_frac_wid, RoundDir.DOWN)
+    x = (s * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+    h = (f * 0.5).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+    for _ in range(iter_count):
+        # iteration step:
+        f = (1.5 - x * h).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+        x = (x * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+        h = (h * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
+    r = 2 * h
+    # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`
+
+    sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
+    rsqrt = FixedPoint(r.bits >> rsqrt_rshift, frac_wid)
+
+    # step each result up by one ULP if that still isn't too big
+    next_sqrt = FixedPoint(sqrt.bits + 1, frac_wid)
+    if next_sqrt * next_sqrt <= radicand:
+        sqrt = next_sqrt
+
+    next_rsqrt = FixedPoint(rsqrt.bits + 1, frac_wid)
+    if next_rsqrt * next_rsqrt * radicand <= 1 and radicand != 0:
+        rsqrt = next_rsqrt
+    return sqrt, rsqrt
diff --git a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py
index 9e276341..e2984dc1 100644
--- a/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py
+++ b/src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py
@@ -4,10 +4,12 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
+import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint)
+    GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
+    FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
 
 class TestFixedPoint(FHDLTestCase):
@@ -19,6 +21,41 @@ class TestFixedPoint(FHDLTestCase):
                     round_trip_value = FixedPoint.cast(str(value))
                     self.assertEqual(value, round_trip_value)
 
+    @staticmethod
+    def trap(f):
+        try:
+            return f(), None
+        except (ValueError, ZeroDivisionError) as e:
+            return None, e.__class__.__name__
+
+    def test_sqrt(self):
+        for frac_wid in range(8):
+            for bits in range(1 << 9):
+                for round_dir in RoundDir:
+                    radicand = FixedPoint(bits, frac_wid)
+                    expected_f = math.sqrt(float(radicand))
+                    expected = self.trap(lambda: FixedPoint.with_frac_wid(
+                        expected_f, frac_wid, round_dir))
+                    with self.subTest(radicand=repr(radicand),
+                                      round_dir=str(round_dir),
+                                      expected=repr(expected)):
+                        result = self.trap(lambda: radicand.sqrt(round_dir))
+                        self.assertEqual(result, expected)
+
+    def test_rsqrt(self):
+        for frac_wid in range(8):
+            for bits in range(1, 1 << 9):
+                for round_dir in RoundDir:
+                    radicand = FixedPoint(bits, frac_wid)
+                    expected_f = 1 / math.sqrt(float(radicand))
+                    expected = self.trap(lambda: FixedPoint.with_frac_wid(
+                        expected_f, frac_wid, round_dir))
+                    with self.subTest(radicand=repr(radicand),
+                                      round_dir=str(round_dir),
+                                      expected=repr(expected)):
+                        result = self.trap(lambda: radicand.rsqrt(round_dir))
+                        self.assertEqual(result, expected)
+
 
 class TestGoldschmidtDiv(FHDLTestCase):
     def test_case1(self):
@@ -257,5 +294,44 @@ class TestGoldschmidtDiv(FHDLTestCase):
         self.tst_params(64)
 
 
+class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
+    def tst(self, io_width, frac_wid, extra_precision,
+            table_addr_bits, table_data_bits, iter_count):
+        assert isinstance(io_width, int)
+        assert isinstance(frac_wid, int)
+        assert isinstance(extra_precision, int)
+        assert isinstance(table_addr_bits, int)
+        assert isinstance(table_data_bits, int)
+        assert isinstance(iter_count, int)
+        with self.subTest(io_width=io_width, frac_wid=frac_wid,
+                          extra_precision=extra_precision,
+                          table_addr_bits=table_addr_bits,
+                          table_data_bits=table_data_bits,
+                          iter_count=iter_count):
+            for bits in range(1 << io_width):
+                radicand = FixedPoint(bits, frac_wid)
+                expected_sqrt = radicand.sqrt(RoundDir.DOWN)
+                expected_rsqrt = FixedPoint(0, frac_wid)
+                if radicand > 0:
+                    expected_rsqrt = radicand.rsqrt(RoundDir.DOWN)
+                with self.subTest(radicand=repr(radicand),
+                                  expected_sqrt=repr(expected_sqrt),
+                                  expected_rsqrt=repr(expected_rsqrt)):
+                    sqrt, rsqrt = goldschmidt_sqrt_rsqrt(
+                        radicand=radicand, io_width=io_width,
+                        frac_wid=frac_wid,
+                        extra_precision=extra_precision,
+                        table_addr_bits=table_addr_bits,
+                        table_data_bits=table_data_bits,
+                        iter_count=iter_count)
+                    with self.subTest(sqrt=repr(sqrt), rsqrt=repr(rsqrt)):
+                        self.assertEqual((sqrt, rsqrt),
+                                         (expected_sqrt, expected_rsqrt))
+
+    def test1(self):
+        self.tst(io_width=16, frac_wid=8, extra_precision=20,
+                 table_addr_bits=4, table_data_bits=28, iter_count=4)
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
2.30.2