WIP getting asm version of knuth's algorithm d working
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 12 Oct 2023 03:29:51 +0000 (20:29 -0700)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 1 Dec 2023 17:58:20 +0000 (17:58 +0000)
src/openpower/test/bigint/powmod.py

index 6127b6042073a2ddc23a6be5a40ca9c438b5aa2c..ea4eb29de941e621aeebd7569fd223a79035c296 100644 (file)
@@ -294,6 +294,18 @@ def python_divmod_shift_sub_algorithm(n, d, width=256, log_regex=False):
     return q, r
 
 
+def divmod2du(RA, RB, RC):
+    # type: (int, int, int) -> tuple[int, int, bool]
+    if RC < RB and RB != 0:
+        RT, RS = divmod(RC << 64 | RA, RB)
+        overflow = False
+    else:
+        overflow = True
+        RT = (1 << 64) - 1
+        RS = 0
+    return RT, RS, overflow
+
+
 @plain_data()
 class DivModKnuthAlgorithmD:
     __slots__ = "num_size", "denom_size", "q_size", "word_size", "regs"
@@ -316,7 +328,7 @@ class DivModKnuthAlgorithmD:
                 "n_0": 4,
                 "d_0": 32,
                 "u": 36,
-                "m": 3,
+                "m": 9,
                 "v": 32,
                 "n_scalar": 8,
                 "q": 4,
@@ -331,8 +343,13 @@ class DivModKnuthAlgorithmD:
                 "t_for_unnorm": 3,
                 "s_for_unnorm": 34,
                 "qhat": 12,
-                "rhat": 0,
-                "t_for_prod": 0,
+                "rhat_lo": 14,
+                "rhat_hi": 15,
+                "t_for_prod": 18,
+                "index": 3,
+                "j": 11,
+                "qhat_denom": 18,
+                "qhat_num_hi": 16,
             }
 
         self.num_size = num_size
@@ -411,12 +428,9 @@ class DivModKnuthAlgorithmD:
                 t = u[self.q_size]
                 m = self.q_size
             do_log(locals(), t="t_single", n=None)
+            do_log(locals(), m=None)  # VL = m, so we don't need it in a GPR
             for i in reversed(range(m)):
-                # divmod2du
-                t <<= self.word_size
-                t += u[i]
-                q[i] = t // v[0]
-                t %= v[0]
+                q[i], t, _ = divmod2du(u[i], v[0], t)
                 do_log(locals())
             r = [0] * self.r_size  # remainder
             r[0] = t
@@ -425,7 +439,7 @@ class DivModKnuthAlgorithmD:
 
         if m < n:
             r = [None] * self.r_size  # remainder
-            do_log(locals(), r=("r", self.r_size), n=None)
+            do_log(locals(), r=("r", self.r_size), m=None, n=None)
             # dividend < divisor
             for i in range(self.r_size):
                 r[i] = u[i]
@@ -438,10 +452,12 @@ class DivModKnuthAlgorithmD:
 
         # calculate amount to shift by -- count leading zeros
         s = 0
-        while (v[n - 1] << s) >> (self.word_size - 1) == 0:
+        index = n - 1
+        do_log(locals(), index="index")
+        while (v[index] << s) >> (self.word_size - 1) == 0:
             s += 1
 
-        do_log(locals(), s="s_scalar")
+        do_log(locals(), s="s_scalar", index=None)
 
         if s != 0:
             on_corner_case("non-zero shift")
@@ -467,39 +483,58 @@ class DivModKnuthAlgorithmD:
             un[i] = t % 2 ** self.word_size
             t >>= self.word_size
             do_log(locals())
-        un[m] = t
+        index = m
+        do_log(locals(), index="index")
+        un[index] = t
 
-        do_log(locals(), u=None, t=None)
+        do_log(locals(), u=None, t=None, index=None)
 
         # Step D2 and Step D7: loop
         for j in range(min(m - n, self.q_size - 1), -1, -1):
+            do_log(locals(), j="j")
             # Step D3: calculate q̂
 
-            t = un[j + n]
-            t <<= self.word_size
-            t += un[j + n - 1]
-            if un[j + n] >= vn[n - 1]:
+            index = j + n
+            do_log(locals(), index="index")
+            qhat_num_hi = un[index]
+            do_log(locals(), qhat_num_hi="qhat_num_hi")
+            index = n - 1
+            do_log(locals())
+            qhat_denom = vn[index]
+            do_log(locals(), qhat_denom="qhat_denom")
+            index = j + n - 1
+            do_log(locals())
+            qhat, rhat_lo, ov = divmod2du(un[index], qhat_denom, qhat_num_hi)
+            rhat_hi = 0
+            do_log(locals(), qhat="qhat", rhat_lo="rhat_lo", rhat_hi="rhat_hi")
+            if ov:
                 # division overflows word
                 on_corner_case("qhat overflows word")
-                qhat = 2 ** self.word_size - 1
-                rhat = t - qhat * vn[n - 1]
-            else:
-                # divmod2du
-                qhat = t // vn[n - 1]
-                rhat = t % vn[n - 1]
-
-            do_log(locals(), qhat="qhat", rhat="rhat")
+                assert qhat_num_hi == qhat_denom
+                rhat_lo = (qhat * qhat_denom) % 2 ** self.word_size
+                rhat_hi = (qhat * qhat_denom) >> self.word_size
+                do_log(locals())
+                borrow = un[index] < rhat_lo
+                rhat_lo = (un[index] - rhat_lo) % 2 ** self.word_size
+                do_log(locals())
+                rhat_hi = qhat_num_hi - rhat_hi - borrow
+            do_log(locals(), qhat_num_hi=None, qhat_denom=None)
 
-            while rhat < 2 ** self.word_size:
-                if qhat * vn[n - 2] > (rhat << self.word_size) + un[j + n - 2]:
+            while rhat_hi == 0:
+                if qhat * vn[n - 2] > (rhat_lo << self.word_size) + un[j + n - 2]:
                     on_corner_case("qhat adjustment")
                     qhat -= 1
-                    rhat += vn[n - 1]
+                    do_log(locals())
+                    carry = (rhat_lo + vn[n - 1]) >= 2 ** self.word_size
+                    rhat_lo += vn[n - 1]
+                    rhat_lo %= 2 ** self.word_size
+                    do_log(locals())
+                    rhat_hi = carry
                     do_log(locals())
                 else:
                     break
 
-            do_log(locals(), rhat=None)
+            do_log(locals(), rhat_lo=None, rhat_hi=None, index=None)
 
             # Step D4: multiply and subtract
 
@@ -549,12 +584,12 @@ class DivModKnuthAlgorithmD:
             do_log(locals())
 
         # Step D8: un-normalize
-        do_log(locals(), s="s_for_unnorm", vn=None)
+        do_log(locals(), s="s_for_unnorm", vn=None, m=None, j=None)
         r = [0] * self.r_size  # remainder
         do_log(locals(), r=("r", self.r_size), n="n_for_unnorm")
         # r = un >> s
         t = 0
-        do_log(locals(), t="t_for_unnorm", m=None)
+        do_log(locals(), t="t_for_unnorm")
         for i in reversed(range(n)):
             # dsrd
             t <<= self.word_size
@@ -586,8 +621,13 @@ class DivModKnuthAlgorithmD:
         t_for_unnorm = self.regs["t_for_unnorm"]
         s_for_unnorm = self.regs["s_for_unnorm"]
         qhat = self.regs["qhat"]
-        rhat = self.regs["rhat"]
+        rhat_lo = self.regs["rhat_lo"]
+        rhat_hi = self.regs["rhat_hi"]
         t_for_prod = self.regs["t_for_prod"]
+        index = self.regs["index"]
+        j = self.regs["j"]
+        qhat_num_hi = self.regs["qhat_num_hi"]
+        qhat_denom = self.regs["qhat_denom"]
         num_size = self.num_size
         denom_size = self.denom_size
         q_size = self.q_size
@@ -600,6 +640,12 @@ class DivModKnuthAlgorithmD:
         # n in n_0 size num_size
         # d in d_0 size denom_size
 
+        yield "mfspr 0, 8 # mflr 0"
+        yield "std 0, 16(1)"  # save return address
+        yield "setvl 0, 0, 18, 0, 1, 1"  # set VL to 18
+        yield "sv.std *14, -144(1)"  # save all callee-save registers
+        yield "stdu 1, -176(1)"  # create stack frame as required by ABI
+
         # switch to names used by Knuth's algorithm D
         yield f"setvl 0, 0, {num_size}, 0, 1, 1"  # set VL to num_size
         yield f"sv.or *{u}, *{n_0}, *{n_0}"  # u = n
@@ -645,7 +691,7 @@ class DivModKnuthAlgorithmD:
         # n = denom_size - n
         yield f"subfic {n_scalar}, {n_scalar}, {denom_size}"
 
-        yield f"cmpi 0, 1, {n_scalar}, 1  # cmpi {n_scalar}, 1"
+        yield f"cmpi 0, 1, {n_scalar}, 1  # cmpdi {n_scalar}, 1"
         yield "bc 4, 2, divmod_skip_sw_divisor # bne divmod_skip_sw_divisor"
 
         # Knuth's algorithm D requires the divisor to have length >= 2
@@ -661,11 +707,68 @@ class DivModKnuthAlgorithmD:
         yield f"setvl 0, 0, {r_size - 1}, 0, 1, 1"  # set VL to r_size - 1
         yield f"sv.addi *{r + 1}, 0, 0"  # r[1:] = [0] * (r_size - 1)
 
-        yield "bclr 20, 0, 0 # blr"
+        yield "b divmod_return"
 
         yield "divmod_skip_sw_divisor:"
+        yield f"cmp 0, 1, {m}, {n_scalar}  # cmpd {m}, {n_scalar}"
+        yield "bc 4, 0, divmod_skip_copy_r # bge divmod_skip_copy_r"
+        # if m < n:
+
+        yield f"setvl 0, 0, {r_size}, 0, 1, 1"  # set VL to r_size
+        yield f"sv.or *{r}, *{u}, *{u}"  # r[...] = u[...]
+        yield "b divmod_return"
 
+        yield "divmod_skip_copy_r:"
+
+        # Knuth's algorithm D starts here:
+
+        # Step D1: normalize
+
+        # calculate amount to shift by -- count leading zeros
+        yield f"addi {index}, {n_scalar}, -1"  # index = n - 1
+        assert index == 3, "index must be r3"
+        yield f"setvl. 0, 0, {denom_size}, 0, 1, 1"  # VL = denom_size
+        yield f"sv.cntlzd/m=1<<r3 {s_scalar}, *{v}"  # s = clz64(v[index])
+
+        yield f"addi {t_for_uv_shift}, 0, 0"  # t = 0
+        yield f"setvl. 0, {n_scalar}, {denom_size}, 0, 1, 1"  # VL = n
+        # vn = v << s
+        yield f"sv.dsld *{vn}, *{v}, {s_scalar}, {t_for_uv_shift}"
+
+        yield f"addi {t_for_uv_shift}, 0, 0"  # t = 0
+        yield f"setvl. 0, {m}, {num_size}, 0, 1, 1"  # VL = m
+        # un = u << s
+        yield f"sv.dsld *{un}, *{u}, {s_scalar}, {t_for_uv_shift}"
+        yield f"setvl. 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
+        yield f"or {index}, {m}, {m}"  # index = m
+        assert index == 3, "index must be r3"
+        # un[index] = t
+        yield f"sv.or/m=1<<r3 *{un}, {t_for_uv_shift}, {t_for_uv_shift}"
+
+        # Step D2 and Step D7: loop
+        # j = m - n
+        yield f"subf {j}, {n_scalar}, {m}"
+        # j = min(j, q_size - 1)
+        yield f"addi 0, 0, {q_size - 1}"
+        yield f"minmax {j}, {j}, 0, 0  # maxd {j}, {j}, 0"
+        yield f"divmod_loop:"
+
+        # Step D3: calculate q̂
+        yield f"setvl. 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
         # FIXME: finish
+
+        # Step D2 and Step D7: loop
+        yield f"addic. {j}, {j}, -1"  # j -= 1
+        yield f"bc 4, 0, divmod_loop # bge divmod_loop"
+
+        # FIXME: finish
+
+        yield "divmod_return:"
+        yield "addi 1, 1, 176"  # teardown stack frame
+        yield "ld 0, 16(1)"
+        yield "mtspr 8, 0 # mtlr 0"  # restore return address
+        yield "setvl 0, 0, 18, 0, 1, 1"  # set VL to 18
+        yield "sv.ld *14, -144(1)"  # restore all callee-save registers
         yield "bclr 20, 0, 0 # blr"
 
     @cached_property
@@ -861,9 +964,12 @@ class PowModCases(TestAccumulatorBase):
         cases = list(self.divmod_512x256_to_256x256_test_inputs())
         asm = DivModKnuthAlgorithmD().asm
         for n, d in cases:
-            if d >= 2 ** 64:
-                # FIXME: only single-word part of algorithm implemented,
-                # so we skip the ones that we expect to fail
+            skip = d >= 2 ** 64
+            if n << 64 < n:
+                skip = False
+            if skip:
+                # FIXME: only part of the algorithm is implemented,
+                # so we skip the cases that we expect to fail
                 continue
             q, r = divmod(n, d)
             with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",