From 7e0c8fc663f4e17808df8de5282671d1f6fbca57 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 1 Jul 2019 04:01:45 -0700 Subject: [PATCH] implement fixed_rsqrt --- src/ieee754/div_rem_sqrt_rsqrt/algorithm.py | 27 ++++++++- .../div_rem_sqrt_rsqrt/test_algorithm.py | 55 ++++++++++++++++++- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index b5cde64d..580895bf 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -490,8 +490,31 @@ class FixedSqrt: def fixed_rsqrt(radicand): - # FIXME: finish - raise NotImplementedError() + """ Compute the Reciprocal Square Root and Remainder. + + Solves the polynomial ``1 - x * x * radicand == 0`` + + :param radicand: the ``Fixed`` to take the reciprocal square root of. + :returns RootRemainder: + """ + # Written for correctness, not speed + if radicand <= 0: + return None + if not isinstance(radicand, Fixed): + raise TypeError() + + def is_remainder_non_negative(root): + return 1 >= root * root * radicand + + root = radicand.with_bits(0) + for i in reversed(range(root.bit_width)): + new_root = root.with_bits(root.bits | (1 << i)) + if new_root < 0: # skip sign bit + continue + if is_remainder_non_negative(new_root): + root = new_root + remainder = 1 - root * root * radicand + return RootRemainder(root, remainder) class FixedRSqrt: diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py index a9356263..91b7b5e7 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py @@ -728,4 +728,57 @@ class TestFixedSqrtFn(unittest.TestCase): self.assertEqual(str(fixed_sqrt(radicand)), expected) -# FIXME: add tests for FixedSqrt, fixed_rsqrt, and FixedRSqrt +# FIXME: add tests for FixedSqrt + + +class TestFixedRSqrtFn(unittest.TestCase): + def test2(self): + for bits in range(1, 1 << 5): + radicand = Fixed.from_bits(bits, 5, 12, False) + float_root = 1 / math.sqrt(float(radicand)) + root = radicand.with_value(float_root) + remainder = 1 - root * root * radicand + expected = RootRemainder(root, remainder) + with self.subTest(radicand=repr(radicand), + expected=repr(expected)): + self.assertEqual(repr(fixed_rsqrt(radicand)), + repr(expected)) + + def test(self): + for signed in False, True: + for bit_width in range(1, 10): + for fract_width in range(bit_width): + for bits in range(1 << bit_width): + radicand = Fixed.from_bits(bits, + fract_width, + bit_width, + signed) + if radicand <= 0: + continue + float_root = 1 / math.sqrt(float(radicand)) + max_value = radicand.with_bits( + (1 << (bit_width - signed)) - 1) + if float_root > float(max_value): + root = max_value + else: + root = radicand.with_value(float_root) + remainder = 1 - root * root * radicand + expected = RootRemainder(root, remainder) + with self.subTest(radicand=repr(radicand), + expected=repr(expected)): + self.assertEqual(repr(fixed_rsqrt(radicand)), + repr(expected)) + + def test_misc_cases(self): + test_cases = [ + # radicand, expected + (Fixed(0.5, 30, 32, False), + "RootRemainder(fixed:0x1.6a09e664, " + "fixed:0x0.0000000596d014780000000)") + ] + for radicand, expected in test_cases: + with self.subTest(radicand=str(radicand), expected=expected): + self.assertEqual(str(fixed_rsqrt(radicand)), expected) + + +# FIXME: add tests for FixedRSqrt -- 2.30.2