self.bit_width = bit_width
self.signed = signed
+ def with_bits(self, bits):
+ """ Create a new Fixed with the specified bits.
+
+ :param bits: the new bits.
+ :returns Fixed: the new Fixed.
+ """
+ return self.from_bits(bits,
+ self.fract_width,
+ self.bit_width,
+ self.signed)
+
+ def with_value(self, value):
+ """ Create a new Fixed with the specified value.
+
+ :param value: the new value.
+ :returns Fixed: the new Fixed.
+ """
+ return Fixed(value,
+ self.fract_width,
+ self.bit_width,
+ self.signed)
+
def __repr__(self):
""" Get representation."""
retval = f"Fixed.from_bits({self.bits}, {self.fract_width}, "
def __float__(self):
""" Convert to float."""
- return self.bits * 2 ** -self.fract_width
+ return self.bits * 2.0 ** -self.fract_width
def __floor__(self):
""" Floor to integer."""
return retval
-def fixed_sqrt():
- # FIXME: finish
- raise NotImplementedError()
+class RootRemainder:
+ """ A polynomial root and remainder.
+
+ :attribute root: the polynomial root.
+ :attribute remainder: the remainder.
+ """
+
+ def __init__(self, root, remainder):
+ """ Create a new RootRemainder.
+
+ :param root: the polynomial root.
+ :param remainder: the remainder.
+ """
+ self.root = root
+ self.remainder = remainder
+
+ def __repr__(self):
+ """ Get the representation as a string. """
+ return f"RootRemainder({repr(self.root)}, {repr(self.remainder)})"
+
+ def __str__(self):
+ """ Convert to a string. """
+ return f"RootRemainder({str(self.root)}, {str(self.remainder)})"
+
+
+def fixed_sqrt(radicand):
+ """ Compute the Square Root and Remainder.
+
+ Solves the polynomial ``radicand - x * x == 0``
+
+ :param radicand: the ``Fixed`` to take the square root of.
+ :returns RootRemainder:
+ """
+ # Written for correctness, not speed
+ if radicand < 0:
+ return None
+ is_int = isinstance(radicand, int)
+ if is_int:
+ radicand = Fixed(radicand, 0, radicand.bit_length() + 1, True)
+ elif not isinstance(radicand, Fixed):
+ raise TypeError()
+
+ def is_remainder_non_negative(root):
+ return radicand >= root * root
+
+ root = radicand.with_bits(0)
+ for i in reversed(range(root.bit_width)):
+ new_root = root.with_bits(root.bits | (1 << i))
+ if new_root < 0: # skip sign bit
+ continue
+ if is_remainder_non_negative(new_root):
+ root = new_root
+ remainder = radicand - root * root
+ if is_int:
+ root = int(root)
+ remainder = int(remainder)
+ return RootRemainder(root, remainder)
class FixedSqrt:
pass
-def fixed_rsqrt():
+def fixed_rsqrt(radicand):
# FIXME: finish
raise NotImplementedError()
from nmigen.hdl.ast import Const
from .algorithm import (div_rem, UnsignedDivRem, DivRem,
- Fixed, fixed_sqrt, FixedSqrt, fixed_rsqrt, FixedRSqrt)
+ Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
+ fixed_rsqrt, FixedRSqrt)
import unittest
import math
"fixed:0x1.23450")
-# FIXME: add tests for fract_sqrt, FractSqrt, fract_rsqrt, and FractRSqrt
+class TestFixedSqrtFn(unittest.TestCase):
+ def test_on_ints(self):
+ for radicand in range(-1, 32):
+ if radicand < 0:
+ expected = None
+ else:
+ root = math.floor(math.sqrt(radicand))
+ remainder = radicand - root * root
+ expected = RootRemainder(root, remainder)
+ with self.subTest(radicand=radicand, expected=expected):
+ self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))
+ radicand = 2 << 64
+ root = 0x16A09E667
+ remainder = radicand - root * root
+ expected = RootRemainder(root, remainder)
+ with self.subTest(radicand=radicand, expected=expected):
+ self.assertEqual(repr(fixed_sqrt(radicand)), repr(expected))
+
+ def test_on_fixed(self):
+ for signed in False, True:
+ for bit_width in range(1, 10):
+ for fract_width in range(bit_width):
+ for bits in range(1 << bit_width):
+ radicand = Fixed.from_bits(bits,
+ fract_width,
+ bit_width,
+ signed)
+ if radicand < 0:
+ continue
+ root = radicand.with_value(math.sqrt(float(radicand)))
+ remainder = radicand - root * root
+ expected = RootRemainder(root, remainder)
+ with self.subTest(radicand=repr(radicand),
+ expected=repr(expected)):
+ self.assertEqual(repr(fixed_sqrt(radicand)),
+ repr(expected))
+
+ def test_misc_cases(self):
+ test_cases = [
+ # radicand, expected
+ (2 << 64, str(RootRemainder(0x16A09E667, 0x2B164C28F))),
+ (Fixed(2, 30, 32, False),
+ "RootRemainder(fixed:0x1.6a09e664, fixed:0x0.0000000b2da028f)")
+ ]
+ for radicand, expected in test_cases:
+ with self.subTest(radicand=str(radicand), expected=expected):
+ self.assertEqual(str(fixed_sqrt(radicand)), expected)
+
+
+# FIXME: add tests for FixedSqrt, fixed_rsqrt, and FixedRSqrt