def __floor__(self):
    """Return the integer floor of this value (support for `math.floor`).

    The stored integer `bits` is shifted right by `frac_wid`, which
    discards the fraction bits; Python's right shift on integers rounds
    toward negative infinity, so negative values floor correctly.
    """
    fraction_bit_count = self.frac_wid
    return self.bits >> fraction_bit_count
def div(self, rhs, frac_wid, round_dir=RoundDir.ERROR_IF_INEXACT):
    """Divide `self` by `rhs`, returning a value with `frac_wid` fraction
    bits.

    arguments:
    rhs:
        the divisor; converted using `FixedPoint.cast`.
    frac_wid: int
        the number of fraction bits of the result, must be non-negative.
    round_dir: RoundDir
        the rounding direction handed to `FixedPoint.with_frac_wid` when
        the quotient is converted to `frac_wid` fraction bits.

    returns:
        the result of `FixedPoint.with_frac_wid` applied to the quotient
        of the two operands' `as_fraction()` values.
    """
    assert isinstance(frac_wid, int) and frac_wid >= 0
    assert isinstance(round_dir, RoundDir)
    divisor = FixedPoint.cast(rhs)
    dividend_fraction = self.as_fraction()
    quotient = dividend_fraction / divisor.as_fraction()
    return FixedPoint.with_frac_wid(quotient, frac_wid, round_dir)
+
def sqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
    """compute the square-root of `self`

    The root is built one bit at a time, from the most-significant
    candidate bit down to the least-significant fraction bit, keeping
    every bit whose inclusion leaves the square no larger than `self`.
    That gives the rounded-down root, which is then adjusted according
    to `round_dir`.

    arguments:
    round_dir: RoundDir
        DOWN: return the rounded-down root.
        UP: add one unit in the last place if the root is inexact.
        NEAREST_TIES_UP: add one unit in the last place if the value
            half-way to the next representable value is not above the
            true root.
        ERROR_IF_INEXACT: raise `ValueError` if the root is inexact.

    returns:
        the square-root, built from values with `self.frac_wid`
        fraction bits. `self` itself is returned when it is zero.

    raises:
        ValueError: if `self` is negative, or if the root is inexact
            and `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        AssertionError: if `round_dir` is not a supported `RoundDir`.
    """
    assert isinstance(round_dir, RoundDir)
    if self < 0:
        raise ValueError("can't compute sqrt of negative number")
    if self == 0:
        return self
    retval = FixedPoint(0, self.frac_wid)
    # number of bits to the left of the binary point (may be negative)
    int_part_wid = self.bits.bit_length() - self.frac_wid
    first_bit_index = -(-int_part_wid // 2)  # division rounds up
    last_bit_index = -self.frac_wid
    for bit_index in range(first_bit_index, last_bit_index - 1, -1):
        # `trial` is `retval` with the bit of weight 2 ** bit_index set
        trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                    self.frac_wid)
        if trial * trial <= self:
            retval = trial
    if round_dir == RoundDir.DOWN:
        pass
    elif round_dir == RoundDir.UP:
        if retval * retval < self:
            retval += FixedPoint(1, self.frac_wid)
    elif round_dir == RoundDir.NEAREST_TIES_UP:
        # one extra fraction bit lets us express retval + half a ULP
        half_way = retval + FixedPoint(1, self.frac_wid + 1)
        if half_way * half_way <= self:
            retval += FixedPoint(1, self.frac_wid)
    elif round_dir == RoundDir.ERROR_IF_INEXACT:
        if retval * retval != self:
            raise ValueError("inexact sqrt")
    else:
        # raise explicitly: `assert False` is removed under `python -O`,
        # which would silently return the rounded-down result instead.
        raise AssertionError("unimplemented round_dir")
    return retval
+
def rsqrt(self, round_dir=RoundDir.ERROR_IF_INEXACT):
    """compute the reciprocal-sqrt of `self`

    The result is built one bit at a time, from the most-significant
    candidate bit down to the least-significant fraction bit, keeping
    every bit for which `result * result * self` stays no larger than 1.
    That gives the rounded-down reciprocal-sqrt, which is then adjusted
    according to `round_dir`.

    arguments:
    round_dir: RoundDir
        DOWN: return the rounded-down result.
        UP: add one unit in the last place if the result is inexact.
        NEAREST_TIES_UP: add one unit in the last place if the value
            half-way to the next representable value is not above the
            true result.
        ERROR_IF_INEXACT: raise `ValueError` if the result is inexact.

    returns:
        the reciprocal-sqrt, built from values with `self.frac_wid`
        fraction bits.

    raises:
        ValueError: if `self` is negative, or if the result is inexact
            and `round_dir` is `RoundDir.ERROR_IF_INEXACT`.
        ZeroDivisionError: if `self` is zero.
        AssertionError: if `round_dir` is not a supported `RoundDir`.
    """
    assert isinstance(round_dir, RoundDir)
    if self < 0:
        raise ValueError("can't compute rsqrt of negative number")
    if self == 0:
        raise ZeroDivisionError("can't compute rsqrt of zero")
    retval = FixedPoint(0, self.frac_wid)
    # the smallest positive input is 2 ** -frac_wid, whose
    # reciprocal-sqrt is 2 ** (frac_wid / 2), so start from that bit
    first_bit_index = -(-self.frac_wid // 2)  # division rounds up
    last_bit_index = -self.frac_wid
    for bit_index in range(first_bit_index, last_bit_index - 1, -1):
        # `trial` is `retval` with the bit of weight 2 ** bit_index set
        trial = retval + FixedPoint(1 << (bit_index + self.frac_wid),
                                    self.frac_wid)
        if trial * trial * self <= 1:
            retval = trial
    if round_dir == RoundDir.DOWN:
        pass
    elif round_dir == RoundDir.UP:
        if retval * retval * self < 1:
            retval += FixedPoint(1, self.frac_wid)
    elif round_dir == RoundDir.NEAREST_TIES_UP:
        # one extra fraction bit lets us express retval + half a ULP
        half_way = retval + FixedPoint(1, self.frac_wid + 1)
        if half_way * half_way * self <= 1:
            retval += FixedPoint(1, self.frac_wid)
    elif round_dir == RoundDir.ERROR_IF_INEXACT:
        if retval * retval * self != 1:
            raise ValueError("inexact rsqrt")
    else:
        # raise explicitly: `assert False` is removed under `python -O`,
        # which would silently return the rounded-down result instead.
        raise AssertionError("unimplemented round_dir")
    return retval
+
@dataclass
class GoldschmidtDivState:
assert state.remainder is not None
return state.quotient, state.remainder
+
+
# Number of integer bits in a look-up table address; the remaining
# `table_addr_bits - GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID` address bits
# are fraction bits.
GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID = 2


@lru_cache()
def goldschmidt_sqrt_rsqrt_table(table_addr_bits, table_data_bits):
    """Generate the look-up table needed for Goldschmidt's square-root and
    reciprocal-square-root algorithm.

    arguments:
    table_addr_bits: int
        the number of address bits for the look-up table.
    table_data_bits: int
        the number of data bits for the look-up table.

    returns: tuple
        a tuple with `1 << table_addr_bits` entries, indexed by address:
        * entry 0 is a zero `FixedPoint`.
        * entries for non-zero addresses that encode a value less than 1
          are `None`, since they should be unused.
        * every other entry is the rounded-down reciprocal-square-root of
          the largest input value that maps to that address, with
          `table_data_bits` fraction bits.
    """
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    # the address that encodes exactly 1; non-zero addresses below it
    # encode values less than 1
    first_used_addr = 1 << table_addr_frac_wid
    table = []
    for addr in range(1 << table_addr_bits):
        if addr == 0:
            value = FixedPoint(0, table_data_bits)
        elif addr < first_used_addr:
            value = None  # table entries should be unused
        else:
            # `addr + 1` is the upper bound of the inputs that round down
            # to `addr`, so its rsqrt is the smallest rsqrt in that range
            max_input_value = FixedPoint(addr + 1, table_addr_frac_wid)
            max_frac_wid = max(max_input_value.frac_wid, table_data_bits)
            value = max_input_value.to_frac_wid(max_frac_wid)
            value = value.rsqrt(RoundDir.DOWN)
            value = value.to_frac_wid(table_data_bits, RoundDir.DOWN)
        table.append(value)

    # tuple for immutability
    return tuple(table)
+
+
def goldschmidt_sqrt_rsqrt(radicand, io_width, frac_wid, extra_precision,
                           table_addr_bits, table_data_bits, iter_count):
    """Goldschmidt's square-root and reciprocal-square-root algorithm.

    uses algorithm based on second method at:
    https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Goldschmidt%E2%80%99s_algorithm

    arguments:
    radicand: FixedPoint(frac_wid=frac_wid)
        the input value to take the square-root and reciprocal-square-root of.
    io_width: int
        the number of bits in the input (`radicand`) and output values.
    frac_wid: int
        the number of fraction bits in the input (`radicand`) and output
        values.
    extra_precision: int
        the number of bits of internal extra precision.
    table_addr_bits: int
        the number of address bits for the look-up table.
    table_data_bits: int
        the number of data bits for the look-up table.
    iter_count: int
        the number of iteration steps to run after the initial estimate
        is read from the look-up table.

    returns: tuple[FixedPoint, FixedPoint]
        the square-root and reciprocal-square-root, rounded down to the
        nearest representable value. If `radicand == 0`, then the
        reciprocal-square-root value returned is zero.
    """
    # check the integer parameters before the checks that compute with them
    assert isinstance(io_width, int) and io_width >= 1
    assert isinstance(frac_wid, int) and 0 <= frac_wid < io_width
    assert (isinstance(radicand, FixedPoint)
            and radicand.frac_wid == frac_wid
            and 0 <= radicand.bits < (1 << io_width))
    assert isinstance(extra_precision, int) and extra_precision >= io_width
    # same lower bound as `goldschmidt_sqrt_rsqrt_table` enforces
    assert isinstance(table_addr_bits, int) and \
        table_addr_bits >= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    assert isinstance(table_data_bits, int) and table_data_bits >= 1
    assert isinstance(iter_count, int) and iter_count >= 0
    expanded_frac_wid = frac_wid + extra_precision
    s = radicand.to_frac_wid(expanded_frac_wid)
    sqrt_rshift = extra_precision
    rsqrt_rshift = extra_precision
    # scale `s` by powers of 4 until it is zero or in the range [1, 4).
    # each factor of 4 in `s` is a factor of 2 in sqrt(s) and of 1/2 in
    # rsqrt(s), which the final right-shift amounts compensate for.
    while s != 0 and s < 1:
        s = (s * 4).to_frac_wid(expanded_frac_wid)
        sqrt_rshift += 1
        rsqrt_rshift -= 1
    while s >= 4:
        s = s.div(4, expanded_frac_wid)
        sqrt_rshift -= 1
        rsqrt_rshift += 1
    table = goldschmidt_sqrt_rsqrt_table(table_addr_bits=table_addr_bits,
                                         table_data_bits=table_data_bits)
    # core goldschmidt sqrt/rsqrt algorithm:
    # initial setup:
    table_addr_frac_wid = table_addr_bits
    table_addr_frac_wid -= GOLDSCHMIDT_SQRT_RSQRT_TABLE_ADDR_INT_WID
    addr = s.to_frac_wid(table_addr_frac_wid, RoundDir.DOWN)
    assert 0 <= addr.bits < (1 << table_addr_bits), "table addr out of range"
    f = table[addr.bits]
    assert f is not None, "accessed invalid table entry"
    # use with_frac_wid to fix IDE type deduction
    f = FixedPoint.with_frac_wid(f, expanded_frac_wid, RoundDir.DOWN)
    x = (s * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    h = (f * 0.5).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    for _ in range(iter_count):
        # iteration step:
        f = (1.5 - x * h).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        x = (x * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
        h = (h * f).to_frac_wid(expanded_frac_wid, RoundDir.DOWN)
    r = 2 * h
    # now `x` is approximately `sqrt(s)` and `r` is approximately `rsqrt(s)`

    # drop the extra precision and undo the scaling applied to `s`
    sqrt = FixedPoint(x.bits >> sqrt_rshift, frac_wid)
    rsqrt = FixedPoint(r.bits >> rsqrt_rshift, frac_wid)

    # final correction: move up by one unit in the last place when the
    # next representable value is still not above the exact result
    next_sqrt = FixedPoint(sqrt.bits + 1, frac_wid)
    if next_sqrt * next_sqrt <= radicand:
        sqrt = next_sqrt

    next_rsqrt = FixedPoint(rsqrt.bits + 1, frac_wid)
    if next_rsqrt * next_rsqrt * radicand <= 1 and radicand != 0:
        rsqrt = next_rsqrt
    return sqrt, rsqrt
# Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
# of Horizon 2020 EU Programme 957073.
+import math
import unittest
from nmutil.formaltest import FHDLTestCase
from soc.fu.div.experiment.goldschmidt_div_sqrt import (
- GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div, FixedPoint)
+ GoldschmidtDivParams, ParamsNotAccurateEnough, goldschmidt_div,
+ FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
class TestFixedPoint(FHDLTestCase):
round_trip_value = FixedPoint.cast(str(value))
self.assertEqual(value, round_trip_value)
@staticmethod
def trap(f):
    """Call `f()` and capture either its result or its expected failure.

    returns: tuple
        `(result, None)` when `f` returns normally, or
        `(None, exception_class_name)` when `f` raises `ValueError` or
        `ZeroDivisionError`. Any other exception propagates.
    """
    try:
        outcome = f()
    except (ValueError, ZeroDivisionError) as exc:
        return None, exc.__class__.__name__
    return outcome, None
+
def test_sqrt(self):
    """Compare `FixedPoint.sqrt` with a `math.sqrt` reference.

    Covers every `RoundDir`, fraction widths 0 through 7 and all 9-bit
    `bits` values. Both sides go through `trap`, so a `ValueError` or
    `ZeroDivisionError` must occur on both sides or on neither.
    """
    for frac_wid in range(8):
        for bits in range(1 << 9):
            # the input and the float reference don't depend on round_dir
            radicand = FixedPoint(bits, frac_wid)
            expected_f = math.sqrt(float(radicand))
            for round_dir in RoundDir:
                def reference():
                    return FixedPoint.with_frac_wid(
                        expected_f, frac_wid, round_dir)

                def actual():
                    return radicand.sqrt(round_dir)

                expected = self.trap(reference)
                with self.subTest(radicand=repr(radicand),
                                  round_dir=str(round_dir),
                                  expected=repr(expected)):
                    result = self.trap(actual)
                    self.assertEqual(result, expected)
+
def test_rsqrt(self):
    """Compare `FixedPoint.rsqrt` with a `1 / math.sqrt` reference.

    Covers every `RoundDir`, fraction widths 0 through 7 and all non-zero
    9-bit `bits` values. Both sides go through `trap`, so a `ValueError`
    or `ZeroDivisionError` must occur on both sides or on neither.
    """
    for frac_wid in range(8):
        # start at 1: the float reference can't be computed for zero
        for bits in range(1, 1 << 9):
            # the input and the float reference don't depend on round_dir
            radicand = FixedPoint(bits, frac_wid)
            expected_f = 1 / math.sqrt(float(radicand))
            for round_dir in RoundDir:
                def reference():
                    return FixedPoint.with_frac_wid(
                        expected_f, frac_wid, round_dir)

                def actual():
                    return radicand.rsqrt(round_dir)

                expected = self.trap(reference)
                with self.subTest(radicand=repr(radicand),
                                  round_dir=str(round_dir),
                                  expected=repr(expected)):
                    result = self.trap(actual)
                    self.assertEqual(result, expected)
+
class TestGoldschmidtDiv(FHDLTestCase):
def test_case1(self):
self.tst_params(64)
class TestGoldschmidtSqrtRSqrt(FHDLTestCase):
    """Exhaustive comparison of `goldschmidt_sqrt_rsqrt` against the
    rounded-down `FixedPoint.sqrt` / `FixedPoint.rsqrt` results."""

    def tst(self, io_width, frac_wid, extra_precision,
            table_addr_bits, table_data_bits, iter_count):
        """Check every `io_width`-bit radicand for one parameter set.

        The expected reciprocal-square-root of a zero radicand is zero.
        """
        params = dict(io_width=io_width, frac_wid=frac_wid,
                      extra_precision=extra_precision,
                      table_addr_bits=table_addr_bits,
                      table_data_bits=table_data_bits,
                      iter_count=iter_count)
        for param_value in params.values():
            assert isinstance(param_value, int)
        with self.subTest(**params):
            for bits in range(1 << io_width):
                radicand = FixedPoint(bits, frac_wid)
                expected_sqrt = radicand.sqrt(RoundDir.DOWN)
                if radicand > 0:
                    expected_rsqrt = radicand.rsqrt(RoundDir.DOWN)
                else:
                    expected_rsqrt = FixedPoint(0, frac_wid)
                with self.subTest(radicand=repr(radicand),
                                  expected_sqrt=repr(expected_sqrt),
                                  expected_rsqrt=repr(expected_rsqrt)):
                    sqrt, rsqrt = goldschmidt_sqrt_rsqrt(
                        radicand=radicand, **params)
                    with self.subTest(sqrt=repr(sqrt), rsqrt=repr(rsqrt)):
                        self.assertEqual((sqrt, rsqrt),
                                         (expected_sqrt, expected_rsqrt))

    def test1(self):
        """16-bit values with 8 fraction bits."""
        self.tst(io_width=16, frac_wid=8, extra_precision=20,
                 table_addr_bits=4, table_data_bits=28, iter_count=4)
+
+
# run the tests in this module when the file is executed as a script
if __name__ == "__main__":
    unittest.main()