From 25d0f2f7a9d5387f15319e95be5b72f860c871f9 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 1 Jul 2019 03:21:45 -0700 Subject: [PATCH] implement fixed_sqrt --- src/ieee754/div_rem_sqrt_rsqrt/algorithm.py | 86 +++++++++++++++++-- .../div_rem_sqrt_rsqrt/test_algorithm.py | 54 +++++++++++- 2 files changed, 133 insertions(+), 7 deletions(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index 199450ed..b5cde64d 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -200,6 +200,28 @@ class Fixed: self.bit_width = bit_width self.signed = signed + def with_bits(self, bits): + """ Create a new Fixed with the specified bits. + + :param bits: the new bits. + :returns Fixed: the new Fixed. + """ + return self.from_bits(bits, + self.fract_width, + self.bit_width, + self.signed) + + def with_value(self, value): + """ Create a new Fixed with the specified value. + + :param value: the new value. + :returns Fixed: the new Fixed. + """ + return Fixed(value, + self.fract_width, + self.bit_width, + self.signed) + def __repr__(self): """ Get representation.""" retval = f"Fixed.from_bits({self.bits}, {self.fract_width}, " @@ -217,7 +239,7 @@ class Fixed: def __float__(self): """ Convert to float.""" - return self.bits * 2 ** -self.fract_width + return self.bits * 2.0 ** -self.fract_width def __floor__(self): """ Floor to integer.""" @@ -403,9 +425,63 @@ class Fixed: return retval -def fixed_sqrt(): - # FIXME: finish - raise NotImplementedError() +class RootRemainder: + """ A polynomial root and remainder. + + :attribute root: the polynomial root. + :attribute remainder: the remainder. + """ + + def __init__(self, root, remainder): + """ Create a new RootRemainder. + + :param root: the polynomial root. + :param remainder: the remainder. + """ + self.root = root + self.remainder = remainder + + def __repr__(self): + """ Get the representation as a string. """ + return f"RootRemainder({repr(self.root)}, {repr(self.remainder)})" + + def __str__(self): + """ Convert to a string. """ + return f"RootRemainder({str(self.root)}, {str(self.remainder)})" + + +def fixed_sqrt(radicand): + """ Compute the Square Root and Remainder. + + Solves the polynomial ``radicand - x * x == 0`` + + :param radicand: the ``Fixed`` to take the square root of. + :returns RootRemainder: + """ + # Written for correctness, not speed + if radicand < 0: + return None + is_int = isinstance(radicand, int) + if is_int: + radicand = Fixed(radicand, 0, radicand.bit_length() + 1, True) + elif not isinstance(radicand, Fixed): + raise TypeError() + + def is_remainder_non_negative(root): + return radicand >= root * root + + root = radicand.with_bits(0) + for i in reversed(range(root.bit_width)): + new_root = root.with_bits(root.bits | (1 << i)) + if new_root < 0: # skip sign bit + continue + if is_remainder_non_negative(new_root): + root = new_root + remainder = radicand - root * root + if is_int: + root = int(root) + remainder = int(remainder) + return RootRemainder(root, remainder) class FixedSqrt: @@ -413,7 +489,7 @@ class FixedSqrt: pass -def fixed_rsqrt(): +def fixed_rsqrt(radicand): # FIXME: finish raise NotImplementedError() diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py index a72f9243..a9356263 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py @@ -3,7 +3,8 @@ from nmigen.hdl.ast import Const from .algorithm import (div_rem, UnsignedDivRem, DivRem, - Fixed, fixed_sqrt, FixedSqrt, fixed_rsqrt, FixedRSqrt) + Fixed, RootRemainder, fixed_sqrt, FixedSqrt, + fixed_rsqrt, FixedRSqrt) import unittest import math @@ -678,4 +679,53 @@ class TestFixed(unittest.TestCase): "fixed:0x1.23450") -# FIXME: add tests for fract_sqrt, FractSqrt, fract_rsqrt, and FractRSqrt +class TestFixedSqrtFn(unittest.TestCase): + def test_on_ints(self): + for radicand in range(-1, 32): + if radicand < 0: + expected = None + else: + root = math.floor(math.sqrt(radicand)) + remainder = radicand - root * root + expected = RootRemainder(root, remainder) + with self.subTest(radicand=radicand, expected=expected): + self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected)) + radicand = 2 << 64 + root = 0x16A09E667 + remainder = radicand - root * root + expected = RootRemainder(root, remainder) + with self.subTest(radicand=radicand, expected=expected): + self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected)) + + def test_on_fixed(self): + for signed in False, True: + for bit_width in range(1, 10): + for fract_width in range(bit_width): + for bits in range(1 << bit_width): + radicand = Fixed.from_bits(bits, + fract_width, + bit_width, + signed) + if radicand < 0: + continue + root = radicand.with_value(math.sqrt(float(radicand))) + remainder = radicand - root * root + expected = RootRemainder(root, remainder) + with self.subTest(radicand=repr(radicand), + expected=repr(expected)): + self.assertEqual(repr(fixed_sqrt(radicand)), + repr(expected)) + + def test_misc_cases(self): + test_cases = [ + # radicand, expected + (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))), + (Fixed(2, 30, 32, False), + "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)") + ] + for radicand, expected in test_cases: + with self.subTest(radicand=str(radicand), expected=expected): + self.assertEqual(str(fixed_sqrt(radicand)), expected) + + +# FIXME: add tests for FixedSqrt, fixed_rsqrt, and FixedRSqrt -- 2.30.2