From e415ed46c70a124ef8a4cbf32b0e097788ecff16 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 2 Jul 2019 21:21:05 -0700 Subject: [PATCH] implemented FixedSqrt --- src/ieee754/div_rem_sqrt_rsqrt/algorithm.py | 78 ++++++++++++++++++- .../div_rem_sqrt_rsqrt/test_algorithm.py | 45 ++++++++++- 2 files changed, 120 insertions(+), 3 deletions(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index 580895bf..6b9c2b19 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -485,8 +485,82 @@ def fixed_sqrt(radicand): class FixedSqrt: - # FIXME: finish - pass + """ Fixed-point Square-Root/Remainder. + + :attribute radicand: the radicand + :attribute root: the square root + :attribute root_squared: the square of ``root`` + :attribute remainder: the remainder + :attribute log2_radix: the base-2 log of the division radix. The number of + bits of quotient that are calculated per pipeline stage. + :attribute current_shift: the current bit index + """ + + def __init__(self, radicand, log2_radix=3): + """ Create an FixedSqrt. + + :param radicand: the radicand. + :param log2_radix: the base-2 log of the division radix. The number of + bits of result that are calculated per pipeline stage. + """ + assert isinstance(radicand, Fixed) + assert radicand.signed is False + self.radicand = radicand + self.root = radicand.with_bits(0) + self.root_squared = self.root * self.root + self.remainder = radicand.with_bits(0) - self.root_squared + self.log2_radix = log2_radix + self.current_shift = self.root.bit_width + + def calculate_stage(self): + """ Calculate the next pipeline stage of the division. + + :returns bool: True if this is the last pipeline stage. + """ + if self.current_shift == 0: + return True + log2_radix = min(self.log2_radix, self.current_shift) + assert log2_radix > 0 + self.current_shift -= log2_radix + radix = 1 << log2_radix + trial_squares = [] + for i in range(radix): + v = self.root_squared + factor1 = Fixed.from_bits(i << (self.current_shift + 1), + self.root.fract_width, + self.root.bit_width + 1 + log2_radix, + False) + v += self.root * factor1 + factor2 = Fixed.from_bits(i << self.current_shift, + self.root.fract_width, + self.root.bit_width + log2_radix, + False) + v += factor2 * factor2 + trial_squares.append(self.root_squared.with_value(v)) + root_bits = 0 + new_root_squared = self.root_squared + for i in range(radix): + if self.radicand >= trial_squares[i]: + root_bits = i + new_root_squared = trial_squares[i] + self.root |= Fixed.from_bits(root_bits << self.current_shift, + self.root.fract_width, + self.root.bit_width + log2_radix, + False) + self.root_squared = new_root_squared + if self.current_shift == 0: + self.remainder = self.radicand - self.root_squared + return True + return False + + def calculate(self): + """ Calculate the results of the square root. + + :returns: self + """ + while not self.calculate_stage(): + pass + return self def fixed_rsqrt(radicand): diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py index 91b7b5e7..f633f1a4 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py @@ -728,7 +728,50 @@ class TestFixedSqrtFn(unittest.TestCase): self.assertEqual(str(fixed_sqrt(radicand)), expected) -# FIXME: add tests for FixedSqrt +class TestFixedSqrt(unittest.TestCase): + def helper(self, log2_radix): + for bit_width in range(1, 8): + for fract_width in range(bit_width): + for radicand_bits in range(1 << bit_width): + radicand = Fixed.from_bits(radicand_bits, + fract_width, + bit_width, + False) + root_remainder = fixed_sqrt(radicand) + with self.subTest(radicand=repr(radicand), + root_remainder=repr(root_remainder), + log2_radix=log2_radix): + obj = FixedSqrt(radicand, log2_radix) + for _ in range(250 * bit_width): + self.assertEqual(obj.root * obj.root, + obj.root_squared) + self.assertGreaterEqual(obj.radicand, + obj.root_squared) + if obj.calculate_stage(): + break + else: + self.fail("infinite loop") + self.assertEqual(obj.root * obj.root, + obj.root_squared) + self.assertGreaterEqual(obj.radicand, + obj.root_squared) + self.assertEqual(obj.remainder, + obj.radicand - obj.root_squared) + self.assertEqual(obj.root, root_remainder.root) + self.assertEqual(obj.remainder, + root_remainder.remainder) + + def test_radix_2(self): + self.helper(1) + + def test_radix_4(self): + self.helper(2) + + def test_radix_8(self): + self.helper(3) + + def test_radix_16(self): + self.helper(4) class TestFixedRSqrtFn(unittest.TestCase): -- 2.30.2