implemented FixedUDivRemSqrtRSqrt
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 3 Jul 2019 08:46:42 +0000 (01:46 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 3 Jul 2019 08:46:42 +0000 (01:46 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index 7fa051f2fdcfed335e04580d713a61199cc63ead..6ba28311401d0165d0eaa890531b827b506deea3 100644 (file)
@@ -8,6 +8,7 @@ code for simulating/testing the various algorithms
 
 from nmigen.hdl.ast import Const
 import math
+import enum
 
 
 def div_rem(dividend, divisor, bit_width, signed):
@@ -676,3 +677,156 @@ class FixedRSqrt:
         while not self.calculate_stage():
             pass
         return self
+
+
+class Operation(enum.Enum):
+    """ Operation for ``FixedUDivRemSqrtRSqrt``. """
+
+    UDivRem = "unsigned-divide/remainder"
+    SqrtRem = "square-root/remainder"
+    RSqrtRem = "reciprocal-square-root/remainder"
+
+
+class FixedUDivRemSqrtRSqrt:
+    """ Combined class for computing fixed-point unsigned div/rem/sqrt/rsqrt.
+
+    Algorithm based on ``UnsignedDivRem``, ``FixedSqrt``, and ``FixedRSqrt``.
+
+    Formulas solved are:
+    * div/rem:
+        ``dividend == quotient_root * divisor_radicand``
+    * sqrt/rem:
+        ``divisor_radicand == quotient_root * quotient_root``
+    * rsqrt/rem:
+        ``1 == quotient_root * quotient_root * divisor_radicand``
+
+    The remainder is the left-hand-side of the comparison minus the
+    right-hand-side of the comparison in the above formulas.
+
+    Important: not all variables have the same bit-width or fract-width. For
+        instance, ``dividend`` has a bit-width of ``bit_width + fract_width``
+        and a fract-width of ``2 * fract_width`` bits.
+
+    :attribute dividend: dividend for div/rem. Variable with a bit-width of
+        ``bit_width + fract_width`` and a fract-width of ``fract_width * 2``
+        bits.
+    :attribute divisor_radicand: divisor for div/rem and radicand for
+        sqrt/rsqrt. Variable with a bit-width of ``bit_width`` and a
+        fract-width of ``fract_width`` bits.
+    :attribute operation: the ``Operation`` to be computed.
+    :attribute quotient_root: the quotient or root part of the result of the
+        operation. Variable with a bit-width of ``bit_width`` and a fract-width
+        of ``fract_width`` bits.
+    :attribute remainder: the remainder part of the result of the operation.
+        Variable with a bit-width of ``bit_width * 3`` and a fract-width
+        of ``fract_width * 3`` bits.
+    :attribute root_times_radicand: ``quotient_root * divisor_radicand``.
+        Variable with a bit-width of ``bit_width * 2`` and a fract-width of
+        ``fract_width * 2`` bits.
+    :attribute compare_lhs: The left-hand-side of the comparison in the
+        equation to be solved. Variable with a bit-width of ``bit_width * 3``
+        and a fract-width of ``fract_width * 3`` bits.
+    :attribute compare_rhs: The right-hand-side of the comparison in the
+        equation to be solved. Variable with a bit-width of ``bit_width * 3``
+        and a fract-width of ``fract_width * 3`` bits.
+    :attribute bit_width: base bit-width. Constant int.
+    :attribute fract_width: base fract-width. Specifies location of base-2
+        radix point. Constant int.
+    :attribute log2_radix: number of bits of ``quotient_root`` that should be
+        computed per pipeline stage (invocation of ``calculate_stage``).
+        Constant int.
+    :attribute current_shift: the current bit index. Variable int.
+    """
+
+    def __init__(self,
+                 dividend,
+                 divisor_radicand,
+                 operation,
+                 bit_width,
+                 fract_width,
+                 log2_radix):
+        """ Create a new ``FixedUDivRemSqrtRSqrt``.
+
+        :param dividend: ``dividend`` attribute's initializer.
+        :param divisor_radicand: ``divisor_radicand`` attribute's initializer.
+        :param operation: ``operation`` attribute's initializer.
+        :param bit_width: ``bit_width`` attribute's initializer.
+        :param fract_width: ``fract_width`` attribute's initializer.
+        :param log2_radix: ``log2_radix`` attribute's initializer.
+        """
+        assert bit_width > 0
+        assert fract_width >= 0
+        assert fract_width <= bit_width
+        assert log2_radix > 0
+        self.dividend = Const.normalize(dividend,
+                                        (bit_width + fract_width, False))
+        self.divisor_radicand = Const.normalize(divisor_radicand,
+                                                (bit_width, False))
+        self.quotient_root = 0
+        self.root_times_radicand = 0
+        if operation is Operation.UDivRem:
+            self.compare_lhs = self.dividend << fract_width
+        elif operation is Operation.SqrtRem:
+            self.compare_lhs = self.divisor_radicand << (fract_width * 2)
+        else:
+            assert operation is Operation.RSqrtRem
+            self.compare_lhs = 1 << (fract_width * 3)
+        self.compare_rhs = 0
+        self.remainder = self.compare_lhs
+        self.operation = operation
+        self.bit_width = bit_width
+        self.fract_width = fract_width
+        self.log2_radix = log2_radix
+        self.current_shift = bit_width
+
+    def calculate_stage(self):
+        """ Calculate the next pipeline stage of the operation.
+
+        :returns bool: True if this is the last pipeline stage.
+        """
+        if self.current_shift == 0:
+            return True
+        log2_radix = min(self.log2_radix, self.current_shift)
+        assert log2_radix > 0
+        self.current_shift -= log2_radix
+        radix = 1 << log2_radix
+        trial_compare_rhs_values = []
+        for trial_bits in range(radix):
+            shifted_trial_bits = trial_bits << self.current_shift
+            shifted_trial_bits_sqrd = shifted_trial_bits * shifted_trial_bits
+            v = self.compare_rhs
+            if self.operation is Operation.UDivRem:
+                factor1 = self.divisor_radicand * shifted_trial_bits
+                v += factor1 << self.fract_width
+            elif self.operation is Operation.SqrtRem:
+                factor1 = self.quotient_root * (shifted_trial_bits << 1)
+                v += factor1 << self.fract_width
+                factor2 = shifted_trial_bits_sqrd
+                v += factor2 << self.fract_width
+            else:
+                assert self.operation is Operation.RSqrtRem
+                factor1 = self.root_times_radicand * (shifted_trial_bits << 1)
+                v += factor1
+                factor2 = self.divisor_radicand * shifted_trial_bits_sqrd
+                v += factor2
+            trial_compare_rhs_values.append(v)
+        shifted_next_bits = 0
+        next_compare_rhs = trial_compare_rhs_values[0]
+        for trial_bits in range(radix):
+            if self.compare_lhs >= trial_compare_rhs_values[trial_bits]:
+                shifted_next_bits = trial_bits << self.current_shift
+                next_compare_rhs = trial_compare_rhs_values[trial_bits]
+        self.root_times_radicand += self.divisor_radicand * shifted_next_bits
+        self.compare_rhs = next_compare_rhs
+        self.quotient_root |= shifted_next_bits
+        self.remainder = self.compare_lhs - self.compare_rhs
+        return self.current_shift == 0
+
+    def calculate(self):
+        """ Calculate the results of the operation.
+
+        :returns: self
+        """
+        while not self.calculate_stage():
+            pass
+        return self
index 081bda97bddc41218a89e345fb0e0e4cfa1be3ec..c5c3e7b3dba11b5741ed06f38be69aa8586df06d 100644 (file)
@@ -4,7 +4,8 @@
 from nmigen.hdl.ast import Const
 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
                         Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
-                        fixed_rsqrt, FixedRSqrt)
+                        fixed_rsqrt, FixedRSqrt, Operation,
+                        FixedUDivRemSqrtRSqrt)
 import unittest
 import math
 
@@ -872,3 +873,199 @@ class TestFixedRSqrt(unittest.TestCase):
 
     def test_radix_16(self):
         self.helper(4)
+
+
+class TestFixedUDivRemSqrtRSqrt(unittest.TestCase):
+    @staticmethod
+    def show_fixed(bits, fract_width, bit_width):
+        fixed = Fixed.from_bits(bits, fract_width, bit_width, False)
+        return f"{str(fixed)}:{repr(fixed)}"
+
+    def check_invariants(self,
+                         dividend,
+                         divisor_radicand,
+                         operation,
+                         bit_width,
+                         fract_width,
+                         log2_radix,
+                         obj):
+        self.assertEqual(obj.dividend, dividend)
+        self.assertEqual(obj.divisor_radicand, divisor_radicand)
+        self.assertEqual(obj.operation, operation)
+        self.assertEqual(obj.bit_width, bit_width)
+        self.assertEqual(obj.fract_width, fract_width)
+        self.assertEqual(obj.log2_radix, log2_radix)
+        self.assertEqual(obj.root_times_radicand,
+                         obj.quotient_root * obj.divisor_radicand)
+        self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs)
+        self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs)
+        if operation is Operation.UDivRem:
+            self.assertEqual(obj.compare_lhs, obj.dividend << fract_width)
+            self.assertEqual(obj.compare_rhs,
+                             (obj.quotient_root * obj.divisor_radicand)
+                             << fract_width)
+        elif operation is Operation.SqrtRem:
+            self.assertEqual(obj.compare_lhs,
+                             obj.divisor_radicand << (fract_width * 2))
+            self.assertEqual(obj.compare_rhs,
+                             (obj.quotient_root * obj.quotient_root)
+                             << fract_width)
+        else:
+            assert operation is Operation.RSqrtRem
+            self.assertEqual(obj.compare_lhs,
+                             1 << (fract_width * 3))
+            self.assertEqual(obj.compare_rhs,
+                             obj.quotient_root * obj.quotient_root
+                             * obj.divisor_radicand)
+
+    def handle_case(self,
+                    dividend,
+                    divisor_radicand,
+                    operation,
+                    bit_width,
+                    fract_width,
+                    log2_radix):
+        dividend_str = self.show_fixed(dividend,
+                                       fract_width * 2,
+                                       bit_width + fract_width)
+        divisor_radicand_str = self.show_fixed(divisor_radicand,
+                                               fract_width,
+                                               bit_width)
+        with self.subTest(dividend=dividend_str,
+                          divisor_radicand=divisor_radicand_str,
+                          operation=operation.name,
+                          bit_width=bit_width,
+                          fract_width=fract_width,
+                          log2_radix=log2_radix):
+            if operation is Operation.UDivRem:
+                if divisor_radicand == 0:
+                    return
+                quotient_root, remainder = div_rem(dividend,
+                                                   divisor_radicand,
+                                                   bit_width * 3,
+                                                   False)
+                remainder <<= fract_width
+            elif operation is Operation.SqrtRem:
+                root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
+                                                            fract_width,
+                                                            bit_width,
+                                                            False))
+                self.assertEqual(root_remainder.root.bit_width,
+                                 bit_width)
+                self.assertEqual(root_remainder.root.fract_width,
+                                 fract_width)
+                self.assertEqual(root_remainder.remainder.bit_width,
+                                 bit_width * 2)
+                self.assertEqual(root_remainder.remainder.fract_width,
+                                 fract_width * 2)
+                quotient_root = root_remainder.root.bits
+                remainder = root_remainder.remainder.bits << fract_width
+            else:
+                assert operation is Operation.RSqrtRem
+                if divisor_radicand == 0:
+                    return
+                root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
+                                                             fract_width,
+                                                             bit_width,
+                                                             False))
+                self.assertEqual(root_remainder.root.bit_width,
+                                 bit_width)
+                self.assertEqual(root_remainder.root.fract_width,
+                                 fract_width)
+                self.assertEqual(root_remainder.remainder.bit_width,
+                                 bit_width * 3)
+                self.assertEqual(root_remainder.remainder.fract_width,
+                                 fract_width * 3)
+                quotient_root = root_remainder.root.bits
+                remainder = root_remainder.remainder.bits
+            if quotient_root >= (1 << bit_width):
+                return
+            quotient_root_str = self.show_fixed(quotient_root,
+                                                fract_width,
+                                                bit_width)
+            remainder_str = self.show_fixed(remainder,
+                                            fract_width * 3,
+                                            bit_width * 3)
+            with self.subTest(quotient_root=quotient_root_str,
+                              remainder=remainder_str):
+                obj = FixedUDivRemSqrtRSqrt(dividend,
+                                            divisor_radicand,
+                                            operation,
+                                            bit_width,
+                                            fract_width,
+                                            log2_radix)
+                for _ in range(250 * bit_width):
+                    self.check_invariants(dividend,
+                                          divisor_radicand,
+                                          operation,
+                                          bit_width,
+                                          fract_width,
+                                          log2_radix,
+                                          obj)
+                    if obj.calculate_stage():
+                        break
+                else:
+                    self.fail("infinite loop")
+                self.check_invariants(dividend,
+                                      divisor_radicand,
+                                      operation,
+                                      bit_width,
+                                      fract_width,
+                                      log2_radix,
+                                      obj)
+                self.assertEqual(obj.quotient_root, quotient_root)
+                self.assertEqual(obj.remainder, remainder)
+
+    def helper(self, log2_radix, operation):
+        bit_width_range = range(1, 8)
+        if operation is Operation.UDivRem:
+            bit_width_range = range(1, 6)
+        for bit_width in bit_width_range:
+            for fract_width in range(bit_width):
+                for divisor_radicand in range(1 << bit_width):
+                    dividend_range = range(1)
+                    if operation is Operation.UDivRem:
+                        dividend_range = range(1 << (bit_width + fract_width))
+                    for dividend in dividend_range:
+                        self.handle_case(dividend,
+                                         divisor_radicand,
+                                         operation,
+                                         bit_width,
+                                         fract_width,
+                                         log2_radix)
+
+    def test_radix_2_UDiv(self):
+        self.helper(1, Operation.UDivRem)
+
+    def test_radix_4_UDiv(self):
+        self.helper(2, Operation.UDivRem)
+
+    def test_radix_8_UDiv(self):
+        self.helper(3, Operation.UDivRem)
+
+    def test_radix_16_UDiv(self):
+        self.helper(4, Operation.UDivRem)
+
+    def test_radix_2_Sqrt(self):
+        self.helper(1, Operation.SqrtRem)
+
+    def test_radix_4_Sqrt(self):
+        self.helper(2, Operation.SqrtRem)
+
+    def test_radix_8_Sqrt(self):
+        self.helper(3, Operation.SqrtRem)
+
+    def test_radix_16_Sqrt(self):
+        self.helper(4, Operation.SqrtRem)
+
+    def test_radix_2_RSqrt(self):
+        self.helper(1, Operation.RSqrtRem)
+
+    def test_radix_4_RSqrt(self):
+        self.helper(2, Operation.RSqrtRem)
+
+    def test_radix_8_RSqrt(self):
+        self.helper(3, Operation.RSqrtRem)
+
+    def test_radix_16_RSqrt(self):
+        self.helper(4, Operation.RSqrtRem)