working on fixing DivPipeCore's test cases
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_algorithm.py
old mode 100644 (file)
new mode 100755 (executable)
index c826412..9200b3b
@@ -1,10 +1,12 @@
+#!/usr/bin/env python3
 # SPDX-License-Identifier: LGPL-2.1-or-later
 # See Notices.txt for copyright information
 
 from nmigen.hdl.ast import Const
 from .algorithm import (div_rem, UnsignedDivRem, DivRem,
                         Fixed, RootRemainder, fixed_sqrt, FixedSqrt,
-                        fixed_rsqrt, FixedRSqrt)
+                        fixed_rsqrt, FixedRSqrt, Operation,
+                        FixedUDivRemSqrtRSqrt)
 import unittest
 import math
 
@@ -295,14 +297,22 @@ class TestUnsignedDivRem(unittest.TestCase):
                 with self.subTest(n=n, d=d, q=q, r=r):
                     udr = UnsignedDivRem(n, d, bit_width, log2_radix)
                     for _ in range(250 * bit_width):
-                        self.assertEqual(n, udr.quotient * udr.divisor
-                                         + udr.remainder)
+                        self.assertEqual(udr.dividend, n)
+                        self.assertEqual(udr.divisor, d)
+                        self.assertEqual(udr.quotient_times_divisor,
+                                         udr.quotient * udr.divisor)
+                        self.assertGreaterEqual(udr.dividend,
+                                                udr.quotient_times_divisor)
                         if udr.calculate_stage():
                             break
                     else:
                         self.fail("infinite loop")
-                    self.assertEqual(n, udr.quotient * udr.divisor
-                                     + udr.remainder)
+                    self.assertEqual(udr.dividend, n)
+                    self.assertEqual(udr.divisor, d)
+                    self.assertEqual(udr.quotient_times_divisor,
+                                     udr.quotient * udr.divisor)
+                    self.assertGreaterEqual(udr.dividend,
+                                            udr.quotient_times_divisor)
                     self.assertEqual(udr.quotient, q)
                     self.assertEqual(udr.remainder, r)
 
@@ -389,7 +399,7 @@ class TestFixed(unittest.TestCase):
         self.assertEqual(value.bit_width, 8)
         self.assertEqual(value.signed, True)
 
-    def helper_test_from_bits(self, bit_width, fract_width):
+    def helper_tst_from_bits(self, bit_width, fract_width):
         signed = False
         for bits in range(1 << bit_width):
             with self.subTest(bit_width=bit_width,
@@ -416,7 +426,7 @@ class TestFixed(unittest.TestCase):
     def test_from_bits(self):
         for bit_width in range(1, 5):
             for fract_width in range(bit_width):
-                self.helper_test_from_bits(bit_width, fract_width)
+                self.helper_tst_from_bits(bit_width, fract_width)
 
     def test_repr(self):
         self.assertEqual(repr(Fixed.from_bits(1, 2, 3, False)),
@@ -872,3 +882,255 @@ class TestFixedRSqrt(unittest.TestCase):
 
     def test_radix_16(self):
         self.helper(4)
+
+
+class TestFixedUDivRemSqrtRSqrt(unittest.TestCase):
+    @staticmethod
+    def show_fixed(bits, fract_width, bit_width):
+        fixed = Fixed.from_bits(bits, fract_width, bit_width, False)
+        return f"{str(fixed)}:{repr(fixed)}"
+
+    def check_invariants(self,
+                         dividend,
+                         divisor_radicand,
+                         operation,
+                         bit_width,
+                         fract_width,
+                         log2_radix,
+                         obj):
+        self.assertEqual(obj.dividend, dividend)
+        self.assertEqual(obj.divisor_radicand, divisor_radicand)
+        self.assertEqual(obj.operation, operation)
+        self.assertEqual(obj.bit_width, bit_width)
+        self.assertEqual(obj.fract_width, fract_width)
+        self.assertEqual(obj.log2_radix, log2_radix)
+        self.assertEqual(obj.root_times_radicand,
+                         obj.quotient_root * obj.divisor_radicand)
+        self.assertGreaterEqual(obj.compare_lhs, obj.compare_rhs)
+        self.assertEqual(obj.remainder, obj.compare_lhs - obj.compare_rhs)
+        if operation is Operation.UDivRem:
+            self.assertEqual(obj.compare_lhs, obj.dividend << fract_width)
+            self.assertEqual(obj.compare_rhs,
+                             (obj.quotient_root * obj.divisor_radicand)
+                             << fract_width)
+        elif operation is Operation.SqrtRem:
+            self.assertEqual(obj.compare_lhs,
+                             obj.divisor_radicand << (fract_width * 2))
+            self.assertEqual(obj.compare_rhs,
+                             (obj.quotient_root * obj.quotient_root)
+                             << fract_width)
+        else:
+            assert operation is Operation.RSqrtRem
+            self.assertEqual(obj.compare_lhs,
+                             1 << (fract_width * 3))
+            self.assertEqual(obj.compare_rhs,
+                             obj.quotient_root * obj.quotient_root
+                             * obj.divisor_radicand)
+
+    def handle_case(self,
+                    dividend,
+                    divisor_radicand,
+                    operation,
+                    bit_width,
+                    fract_width,
+                    log2_radix):
+        dividend_str = self.show_fixed(dividend,
+                                       fract_width * 2,
+                                       bit_width + fract_width)
+        divisor_radicand_str = self.show_fixed(divisor_radicand,
+                                               fract_width,
+                                               bit_width)
+        with self.subTest(dividend=dividend_str,
+                          divisor_radicand=divisor_radicand_str,
+                          operation=operation.name,
+                          bit_width=bit_width,
+                          fract_width=fract_width,
+                          log2_radix=log2_radix):
+            if operation is Operation.UDivRem:
+                if divisor_radicand == 0:
+                    return
+                quotient_root, remainder = div_rem(dividend,
+                                                   divisor_radicand,
+                                                   bit_width * 3,
+                                                   False)
+                remainder <<= fract_width
+            elif operation is Operation.SqrtRem:
+                root_remainder = fixed_sqrt(Fixed.from_bits(divisor_radicand,
+                                                            fract_width,
+                                                            bit_width,
+                                                            False))
+                self.assertEqual(root_remainder.root.bit_width,
+                                 bit_width)
+                self.assertEqual(root_remainder.root.fract_width,
+                                 fract_width)
+                self.assertEqual(root_remainder.remainder.bit_width,
+                                 bit_width * 2)
+                self.assertEqual(root_remainder.remainder.fract_width,
+                                 fract_width * 2)
+                quotient_root = root_remainder.root.bits
+                remainder = root_remainder.remainder.bits << fract_width
+            else:
+                assert operation is Operation.RSqrtRem
+                if divisor_radicand == 0:
+                    return
+                root_remainder = fixed_rsqrt(Fixed.from_bits(divisor_radicand,
+                                                             fract_width,
+                                                             bit_width,
+                                                             False))
+                self.assertEqual(root_remainder.root.bit_width,
+                                 bit_width)
+                self.assertEqual(root_remainder.root.fract_width,
+                                 fract_width)
+                self.assertEqual(root_remainder.remainder.bit_width,
+                                 bit_width * 3)
+                self.assertEqual(root_remainder.remainder.fract_width,
+                                 fract_width * 3)
+                quotient_root = root_remainder.root.bits
+                remainder = root_remainder.remainder.bits
+            if quotient_root >= (1 << bit_width):
+                return
+            quotient_root_str = self.show_fixed(quotient_root,
+                                                fract_width,
+                                                bit_width)
+            remainder_str = self.show_fixed(remainder,
+                                            fract_width * 3,
+                                            bit_width * 3)
+            with self.subTest(quotient_root=quotient_root_str,
+                              remainder=remainder_str):
+                obj = FixedUDivRemSqrtRSqrt(dividend,
+                                            divisor_radicand,
+                                            operation,
+                                            bit_width,
+                                            fract_width,
+                                            log2_radix)
+                for _ in range(250 * bit_width):
+                    self.check_invariants(dividend,
+                                          divisor_radicand,
+                                          operation,
+                                          bit_width,
+                                          fract_width,
+                                          log2_radix,
+                                          obj)
+                    if obj.calculate_stage():
+                        break
+                else:
+                    self.fail("infinite loop")
+                self.check_invariants(dividend,
+                                      divisor_radicand,
+                                      operation,
+                                      bit_width,
+                                      fract_width,
+                                      log2_radix,
+                                      obj)
+                self.assertEqual(obj.quotient_root, quotient_root)
+                self.assertEqual(obj.remainder, remainder)
+
+    def helper(self, log2_radix, operation):
+        bit_width_range = range(1, 8)
+        if operation is Operation.UDivRem:
+            bit_width_range = range(1, 6)
+        for bit_width in bit_width_range:
+            for fract_width in range(bit_width):
+                for divisor_radicand in range(1 << bit_width):
+                    dividend_range = range(1)
+                    if operation is Operation.UDivRem:
+                        dividend_range = range(1 << (bit_width + fract_width))
+                    for dividend in dividend_range:
+                        self.handle_case(dividend,
+                                         divisor_radicand,
+                                         operation,
+                                         bit_width,
+                                         fract_width,
+                                         log2_radix)
+
+    def test_radix_2_UDiv(self):
+        self.helper(1, Operation.UDivRem)
+
+    def test_radix_4_UDiv(self):
+        self.helper(2, Operation.UDivRem)
+
+    def test_radix_8_UDiv(self):
+        self.helper(3, Operation.UDivRem)
+
+    def test_radix_16_UDiv(self):
+        self.helper(4, Operation.UDivRem)
+
+    def test_radix_2_Sqrt(self):
+        self.helper(1, Operation.SqrtRem)
+
+    def test_radix_4_Sqrt(self):
+        self.helper(2, Operation.SqrtRem)
+
+    def test_radix_8_Sqrt(self):
+        self.helper(3, Operation.SqrtRem)
+
+    def test_radix_16_Sqrt(self):
+        self.helper(4, Operation.SqrtRem)
+
+    def test_radix_2_RSqrt(self):
+        self.helper(1, Operation.RSqrtRem)
+
+    def test_radix_4_RSqrt(self):
+        self.helper(2, Operation.RSqrtRem)
+
+    def test_radix_8_RSqrt(self):
+        self.helper(3, Operation.RSqrtRem)
+
+    def test_radix_16_RSqrt(self):
+        self.helper(4, Operation.RSqrtRem)
+
+    def test_int_div(self):
+        bit_width = 8
+        fract_width = 4
+        log2_radix = 3
+        for dividend in range(1 << bit_width):
+            for divisor in range(1, 1 << bit_width):
+                obj = FixedUDivRemSqrtRSqrt(dividend,
+                                            divisor,
+                                            Operation.UDivRem,
+                                            bit_width,
+                                            fract_width,
+                                            log2_radix)
+                obj.calculate()
+                quotient, remainder = div_rem(dividend,
+                                              divisor,
+                                              bit_width,
+                                              False)
+                shifted_remainder = remainder << fract_width
+                with self.subTest(dividend=dividend,
+                                  divisor=divisor,
+                                  quotient=quotient,
+                                  remainder=remainder,
+                                  shifted_remainder=shifted_remainder):
+                    self.assertEqual(obj.quotient_root, quotient)
+                    self.assertEqual(obj.remainder, shifted_remainder)
+
+    def test_fract_div(self):
+        bit_width = 8
+        fract_width = 4
+        log2_radix = 3
+        for dividend in range(1 << bit_width):
+            for divisor in range(1, 1 << bit_width):
+                obj = FixedUDivRemSqrtRSqrt(dividend << fract_width,
+                                            divisor,
+                                            Operation.UDivRem,
+                                            bit_width,
+                                            fract_width,
+                                            log2_radix)
+                obj.calculate()
+                quotient = (dividend << fract_width) // divisor
+                if quotient >= (1 << bit_width):
+                    continue
+                remainder = (dividend << fract_width) % divisor
+                shifted_remainder = remainder << fract_width
+                with self.subTest(dividend=dividend,
+                                  divisor=divisor,
+                                  quotient=quotient,
+                                  remainder=remainder,
+                                  shifted_remainder=shifted_remainder):
+                    self.assertEqual(obj.quotient_root, quotient)
+                    self.assertEqual(obj.remainder, shifted_remainder)
+
+
+if __name__ == '__main__':
+    unittest.main()