From 397136bfba5b2888b050aac4448f3aa22b0d8709 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 2 Jul 2019 21:44:56 -0700 Subject: [PATCH] implement FixedRSqrt --- src/ieee754/div_rem_sqrt_rsqrt/algorithm.py | 86 ++++++++++++++++++- .../div_rem_sqrt_rsqrt/test_algorithm.py | 49 ++++++++++- 2 files changed, 132 insertions(+), 3 deletions(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index c073dd6f..7fa051f2 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -592,5 +592,87 @@ def fixed_rsqrt(radicand): class FixedRSqrt: - # FIXME: finish - pass + """ Fixed-point Reciprocal-Square-Root/Remainder. + + :attribute radicand: the radicand + :attribute root: the reciprocal square root + :attribute radicand_root: ``radicand * root`` + :attribute radicand_root_squared: ``radicand * root * root`` + :attribute remainder: the remainder + :attribute log2_radix: the base-2 log of the operation radix. The number of + bits of root that are calculated per pipeline stage. + :attribute current_shift: the current bit index + """ + + def __init__(self, radicand, log2_radix=3): + """ Create an FixedRSqrt. + + :param radicand: the radicand. + :param log2_radix: the base-2 log of the operation radix. The number of + bits of root that are calculated per pipeline stage. + """ + assert isinstance(radicand, Fixed) + assert radicand.signed is False + self.radicand = radicand + self.root = radicand.with_bits(0) + self.radicand_root = radicand.with_bits(0) * self.root + self.radicand_root_squared = self.radicand_root * self.root + self.remainder = radicand.with_bits(0) - self.radicand_root_squared + self.log2_radix = log2_radix + self.current_shift = self.root.bit_width + + def calculate_stage(self): + """ Calculate the next pipeline stage of the operation. + + :returns bool: True if this is the last pipeline stage. + """ + if self.current_shift == 0: + return True + log2_radix = min(self.log2_radix, self.current_shift) + assert log2_radix > 0 + self.current_shift -= log2_radix + radix = 1 << log2_radix + trial_values = [] + for i in range(radix): + v = self.radicand_root_squared + factor1 = Fixed.from_bits(i << (self.current_shift + 1), + self.root.fract_width, + self.root.bit_width + 1 + log2_radix, + False) + v += self.radicand_root * factor1 + factor2 = Fixed.from_bits(i << self.current_shift, + self.root.fract_width, + self.root.bit_width + log2_radix, + False) + v += self.radicand * factor2 * factor2 + trial_values.append(self.radicand_root_squared.with_value(v)) + root_bits = 0 + new_radicand_root_squared = self.radicand_root_squared + for i in range(radix): + if 1 >= trial_values[i]: + root_bits = i + new_radicand_root_squared = trial_values[i] + v = self.radicand_root + v += self.radicand * Fixed.from_bits(root_bits << self.current_shift, + self.root.fract_width, + self.root.bit_width + log2_radix, + False) + self.radicand_root = self.radicand_root.with_value(v) + self.root |= Fixed.from_bits(root_bits << self.current_shift, + self.root.fract_width, + self.root.bit_width + log2_radix, + False) + self.radicand_root_squared = new_radicand_root_squared + if self.current_shift == 0: + self.remainder = 1 - self.radicand_root_squared + return True + return False + + def calculate(self): + """ Calculate the results of the reciprocal square root. + + :returns: self + """ + while not self.calculate_stage(): + pass + return self diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py index f633f1a4..c8264124 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py @@ -824,4 +824,51 @@ class TestFixedRSqrtFn(unittest.TestCase): self.assertEqual(str(fixed_rsqrt(radicand)), expected) -# FIXME: add tests for FixedRSqrt +class TestFixedRSqrt(unittest.TestCase): + def helper(self, log2_radix): + for bit_width in range(1, 8): + for fract_width in range(bit_width): + for radicand_bits in range(1, 1 << bit_width): + radicand = Fixed.from_bits(radicand_bits, + fract_width, + bit_width, + False) + root_remainder = fixed_rsqrt(radicand) + with self.subTest(radicand=repr(radicand), + root_remainder=repr(root_remainder), + log2_radix=log2_radix): + obj = FixedRSqrt(radicand, log2_radix) + for _ in range(250 * bit_width): + self.assertEqual(obj.radicand * obj.root, + obj.radicand_root) + self.assertEqual(obj.radicand_root * obj.root, + obj.radicand_root_squared) + self.assertGreaterEqual(1, + obj.radicand_root_squared) + if obj.calculate_stage(): + break + else: + self.fail("infinite loop") + self.assertEqual(obj.radicand * obj.root, + obj.radicand_root) + self.assertEqual(obj.radicand_root * obj.root, + obj.radicand_root_squared) + self.assertGreaterEqual(1, + obj.radicand_root_squared) + self.assertEqual(obj.remainder, + 1 - obj.radicand_root_squared) + self.assertEqual(obj.root, root_remainder.root) + self.assertEqual(obj.remainder, + root_remainder.remainder) + + def test_radix_2(self): + self.helper(1) + + def test_radix_4(self): + self.helper(2) + + def test_radix_8(self): + self.helper(3) + + def test_radix_16(self): + self.helper(4) -- 2.30.2