class FixedSqrt:
- # FIXME: finish
- pass
+ """ Fixed-point Square-Root/Remainder.
+
+ :attribute radicand: the radicand
+ :attribute root: the square root
+ :attribute root_squared: the square of ``root``
+ :attribute remainder: the remainder
+ :attribute log2_radix: the base-2 log of the division radix. The number of
+ bits of quotient that are calculated per pipeline stage.
+ :attribute current_shift: the current bit index
+ """
+
+ def __init__(self, radicand, log2_radix=3):
+ """ Create an FixedSqrt.
+
+ :param radicand: the radicand.
+ :param log2_radix: the base-2 log of the division radix. The number of
+ bits of result that are calculated per pipeline stage.
+ """
+ assert isinstance(radicand, Fixed)
+ assert radicand.signed is False
+ self.radicand = radicand
+ self.root = radicand.with_bits(0)
+ self.root_squared = self.root * self.root
+ self.remainder = radicand.with_bits(0) - self.root_squared
+ self.log2_radix = log2_radix
+ self.current_shift = self.root.bit_width
+
+ def calculate_stage(self):
+ """ Calculate the next pipeline stage of the division.
+
+ :returns bool: True if this is the last pipeline stage.
+ """
+ if self.current_shift == 0:
+ return True
+ log2_radix = min(self.log2_radix, self.current_shift)
+ assert log2_radix > 0
+ self.current_shift -= log2_radix
+ radix = 1 << log2_radix
+ trial_squares = []
+ for i in range(radix):
+ v = self.root_squared
+ factor1 = Fixed.from_bits(i << (self.current_shift + 1),
+ self.root.fract_width,
+ self.root.bit_width + 1 + log2_radix,
+ False)
+ v += self.root * factor1
+ factor2 = Fixed.from_bits(i << self.current_shift,
+ self.root.fract_width,
+ self.root.bit_width + log2_radix,
+ False)
+ v += factor2 * factor2
+ trial_squares.append(self.root_squared.with_value(v))
+ root_bits = 0
+ new_root_squared = self.root_squared
+ for i in range(radix):
+ if self.radicand >= trial_squares[i]:
+ root_bits = i
+ new_root_squared = trial_squares[i]
+ self.root |= Fixed.from_bits(root_bits << self.current_shift,
+ self.root.fract_width,
+ self.root.bit_width + log2_radix,
+ False)
+ self.root_squared = new_root_squared
+ if self.current_shift == 0:
+ self.remainder = self.radicand - self.root_squared
+ return True
+ return False
+
+ def calculate(self):
+ """ Calculate the results of the square root.
+
+ :returns: self
+ """
+ while not self.calculate_stage():
+ pass
+ return self
def fixed_rsqrt(radicand):
self.assertEqual(str(fixed_sqrt(radicand)), expected)
-# FIXME: add tests for FixedSqrt
+class TestFixedSqrt(unittest.TestCase):
+ def helper(self, log2_radix):
+ for bit_width in range(1, 8):
+ for fract_width in range(bit_width):
+ for radicand_bits in range(1 << bit_width):
+ radicand = Fixed.from_bits(radicand_bits,
+ fract_width,
+ bit_width,
+ False)
+ root_remainder = fixed_sqrt(radicand)
+ with self.subTest(radicand=repr(radicand),
+ root_remainder=repr(root_remainder),
+ log2_radix=log2_radix):
+ obj = FixedSqrt(radicand, log2_radix)
+ for _ in range(250 * bit_width):
+ self.assertEqual(obj.root * obj.root,
+ obj.root_squared)
+ self.assertGreaterEqual(obj.radicand,
+ obj.root_squared)
+ if obj.calculate_stage():
+ break
+ else:
+ self.fail("infinite loop")
+ self.assertEqual(obj.root * obj.root,
+ obj.root_squared)
+ self.assertGreaterEqual(obj.radicand,
+ obj.root_squared)
+ self.assertEqual(obj.remainder,
+ obj.radicand - obj.root_squared)
+ self.assertEqual(obj.root, root_remainder.root)
+ self.assertEqual(obj.remainder,
+ root_remainder.remainder)
+
+ def test_radix_2(self):
+ self.helper(1)
+
+ def test_radix_4(self):
+ self.helper(2)
+
+ def test_radix_8(self):
+ self.helper(3)
+
+ def test_radix_16(self):
+ self.helper(4)
class TestFixedRSqrtFn(unittest.TestCase):