finish writing python_divmod_knuth_algorithm_d
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 9 Oct 2023 04:46:57 +0000 (21:46 -0700)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 1 Dec 2023 17:58:20 +0000 (17:58 +0000)
src/openpower/decoder/isa/test_caller_svp64_powmod.py
src/openpower/test/bigint/powmod.py

index 1ea02a23d2743f6b17ba82000a577076882aee1c..fb57efec3c2ff6979c8a31338a56c96619ad5bee 100644 (file)
@@ -16,7 +16,7 @@ from functools import lru_cache
 import os
 from openpower.test.bigint.powmod import (
     PowModCases, python_divmod_shift_sub_algorithm,
-    python_powmod_256_algorithm)
+    python_divmod_knuth_algorithm_d, python_powmod_256_algorithm)
 from openpower.test.runner import TestRunnerBase
 
 
@@ -33,6 +33,36 @@ class TestPythonAlgorithms(unittest.TestCase):
                     self.assertEqual(out_q, q)
                     self.assertEqual(out_r, r)
 
+    def test_python_divmod_knuth_algorithm_d(self):
+        seen_corner_cases = set()
+        for n, d in PowModCases.divmod_512x256_to_256x256_test_inputs():
+            log_regex = n == 2 ** 511 - 1 and d == 2 ** 256 - 1
+            q, r = divmod(n, d)
+            n = [(n >> 64 * i) % 2 ** 64 for i in range(8)]
+            d = [(d >> 64 * i) % 2 ** 64 for i in range(4)]
+            q = [(q >> 64 * i) % 2 ** 64 for i in range(4)]
+            r = [(r >> 64 * i) % 2 ** 64 for i in range(4)]
+            with self.subTest(n=[f"{i:#_x}" for i in n],
+                              d=[f"{i:#_x}" for i in d],
+                              q=[f"{i:#_x}" for i in q],
+                              r=[f"{i:#_x}" for i in r]):
+                out_q, out_r = python_divmod_knuth_algorithm_d(
+                    n, d, log_regex=log_regex,
+                    on_corner_case=seen_corner_cases.add)
+                with self.subTest(out_q=[f"{i:#_x}" for i in out_q],
+                                  out_r=[f"{i:#_x}" for i in out_r]):
+                    self.assertEqual(out_q, q + [0] * 4)
+                    self.assertEqual(out_r, r)
+
+        # ensure our testing actually covers all the corner cases
+        self.assertEqual(seen_corner_cases, {
+            "single-word divisor",
+            "non-zero shift",
+            "qhat overflows word",
+            "qhat adjustment",
+            "add back",
+        })
+
     def test_python_powmod_algorithm(self):
         for base, exp, mod in PowModCases.powmod_256_test_inputs():
             expected = pow(base, exp, mod)
index a4d9bc50623e0e5167ef4fa99f57dfd0742d114c..cb595645774ba8e8db4ba592ba83ea64384ba98a 100644 (file)
@@ -267,7 +267,8 @@ def python_divmod_shift_sub_algorithm(n, d, width=256, log_regex=False):
     return q, r
 
 
-def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False):
+def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False,
+                                    on_corner_case=lambda desc: None):
     do_log = _DivModRegsRegexLogger(enabled=log_regex).log
 
     # switch to names used by Knuth's algorithm D
@@ -304,6 +305,7 @@ def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False):
         raise ZeroDivisionError
 
     if n == 1:
+        on_corner_case("single-word divisor")
         # Knuth's algorithm D requires the divisor to have length >= 2
         # handle single-word divisors separately
         t = 0
@@ -316,6 +318,12 @@ def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False):
         r[0] = t
         return q, r
 
+    if m < n:
+        # dividend < divisor
+        for i in range(m):
+            r[i] = u[i]
+        return q, r
+
     # Knuth's algorithm D starts here:
 
     # Step D1: normalize
@@ -325,6 +333,9 @@ def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False):
     while (v[n - 1] << s) >> (word_size - 1) == 0:
         s += 1
 
+    if s != 0:
+        on_corner_case("non-zero shift")
+
     # vn = v << s
     t = 0
     for i in range(n):
@@ -351,6 +362,7 @@ def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False):
         t += un[j + n - 1]
         if un[j + n] >= vn[n - 1]:
             # division overflows word
+            on_corner_case("qhat overflows word")
             qhat = 2 ** word_size - 1
             rhat = t - qhat * vn[n - 1]
         else:
@@ -360,6 +372,7 @@ def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False):
 
         while rhat < 2 ** word_size:
             if qhat * vn[n - 2] > (rhat << word_size) + un[j + n - 2]:
+                on_corner_case("qhat adjustment")
                 qhat -= 1
                 rhat += vn[n - 1]
             else:
@@ -378,13 +391,43 @@ def python_divmod_knuth_algorithm_d(n, d, word_size=64, log_regex=False):
         t = 1
         for i in range(n + 1):
             # subfe
-            t += ~product[i] + un[j + i]
+            not_product = ~product[i] % 2 ** word_size
+            t += not_product + un[j + i]
             un[j + i] = t % 2 ** word_size
             t = int(t >= 2 ** word_size)
         need_fixup = not t
 
-        # FIXME(jacob): finish
+        # Step D5: test remainder
+
+        q[j] = qhat
+        if need_fixup:
+
+            # Step D6: add back
+
+            on_corner_case("add back")
+
+            q[j] -= 1
+
+            t = 0
+            for i in range(n):
+                # adde
+                t += un[j + i] + vn[i]
+                un[j + i] = t % 2 ** word_size
+                t = int(t >= 2 ** word_size)
+            un[j + n] += t
+
+    # Step D8: un-normalize
+
+    # r = un >> s
+    t = 0
+    for i in reversed(range(n)):
+        # dsrd
+        t <<= word_size
+        t |= (un[i] << word_size) >> s
+        r[i] = t >> word_size
+        t %= 2 ** word_size
 
+    return q, r
 
 POWMOD_256_ASM = (
     # base is in r4-7, exp is in r8-11, mod is in r32-35
@@ -509,19 +552,26 @@ class PowModCases(TestAccumulatorBase):
 
     @staticmethod
     def divmod_512x256_to_256x256_test_inputs():
-        for i in range(10):
+        yield (2 ** (256 - 1), 1)
+        yield (2 ** (512 - 1) - 1, 2 ** 256 - 1)
+
+        # test qhat overflow
+        yield (0x8000 << 128 | 0xFFFE << 64, 0x8000 << 64 | 0xFFFF)
+
+        # tests where add back is required
+        yield (8 << (192 - 4) | 3, 2 << (192 - 4) | 1)
+        yield (0x8000 << 128 | 3, 0x2000 << 128 | 1)
+        yield (0x7FFF << 192 | 0x8000 << 128, 0x8000 << 128 | 1)
+
+        for i in range(20):
             n = hash_256(f"divmod256 input n msb {i}")
             n <<= 256
             n |= hash_256(f"divmod256 input n lsb {i}")
+            n_shift = hash_256(f"divmod256 input n shift {i}") % 512
+            n >>= n_shift
             d = hash_256(f"divmod256 input d {i}")
-            if i == 0:
-                # use known values:
-                n = 2 ** (256 - 1)
-                d = 1
-            elif i == 1:
-                # use known values:
-                n = 2 ** (512 - 1) - 1
-                d = 2 ** 256 - 1
+            d_shift = hash_256(f"divmod256 input d shift {i}") % 256
+            d >>= d_shift
             if d == 0:
                 d = 1
             n %= d << 256