WIP divmod: finished writing out asm knuth's algorithm d, still buggy
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 13 Oct 2023 22:13:15 +0000 (15:13 -0700)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 1 Dec 2023 17:58:20 +0000 (17:58 +0000)
src/openpower/test/bigint/powmod.py

index ea4eb29de941e621aeebd7569fd223a79035c296..fd406f78fa20e58c9cffdd5a02c05bb8e4d1c1c9 100644 (file)
@@ -206,7 +206,7 @@ class _DivModRegsRegexLogger:
 
         for k, v in changes.items():
             if v is None:
-                del self.__tracked[k]
+                self.__tracked.pop(k, None)
             else:
                 if isinstance(v, (tuple, list)):
                     start_gpr, size = v
@@ -339,9 +339,9 @@ class DivModKnuthAlgorithmD:
                 "t_single": 8,
                 "s_scalar": 10,
                 "t_for_uv_shift": 0,
-                "n_for_unnorm": 32,
+                "n_for_unnorm": 16,
                 "t_for_unnorm": 3,
-                "s_for_unnorm": 34,
+                "s_for_unnorm": 18,
                 "qhat": 12,
                 "rhat_lo": 14,
                 "rhat_hi": 15,
@@ -350,6 +350,9 @@ class DivModKnuthAlgorithmD:
                 "j": 11,
                 "qhat_denom": 18,
                 "qhat_num_hi": 16,
+                "qhat_prod_lo": 15,
+                "qhat_prod_hi": 18,
+                "sub_len": 3,
             }
 
         self.num_size = num_size
@@ -521,20 +524,34 @@ class DivModKnuthAlgorithmD:
             do_log(locals(), qhat_num_hi=None, qhat_denom=None)
 
             while rhat_hi == 0:
-                if qhat * vn[n - 2] > (rhat_lo << self.word_size) + un[j + n - 2]:
-                    on_corner_case("qhat adjustment")
-                    qhat -= 1
-                    do_log(locals())
-                    carry = (rhat_lo + vn[n - 1]) >= 2 ** self.word_size
-                    rhat_lo += vn[n - 1]
-                    rhat_lo %= 2 ** self.word_size
-                    do_log(locals())
-                    rhat_hi = carry
-                    do_log(locals())
-                else:
+                index = n - 2
+                do_log(locals())
+                qhat_prod_lo = (qhat * vn[index]) % 2 ** self.word_size
+                do_log(locals(), qhat_prod_lo="qhat_prod_lo", rhat_hi=None)
+                qhat_prod_hi = (qhat * vn[index]) >> self.word_size
+                do_log(locals(), qhat_prod_hi="qhat_prod_hi")
+                if qhat_prod_hi < rhat_lo:
                     break
+                index = j + n - 2
+                do_log(locals())
+                if qhat_prod_hi == rhat_lo:
+                    if qhat_prod_lo <= un[index]:
+                        break
+                on_corner_case("qhat adjustment")
+                do_log(locals(), index=None,
+                       qhat_prod_lo=None, qhat_prod_hi=None)
+                qhat -= 1
+                do_log(locals(), index="index")
+                index = n - 1
+                do_log(locals())
+                carry = (rhat_lo + vn[index]) >= 2 ** self.word_size
+                rhat_lo = (rhat_lo + vn[index]) % 2 ** self.word_size
+                do_log(locals())
+                rhat_hi = carry
+                do_log(locals(), rhat_hi="rhat_hi")
 
-            do_log(locals(), rhat_lo=None, rhat_hi=None, index=None)
+            do_log(locals(), rhat_lo=None, rhat_hi=None, index=None,
+                   qhat_prod_lo=None, qhat_prod_hi=None)
 
             # Step D4: multiply and subtract
 
@@ -546,11 +563,18 @@ class DivModKnuthAlgorithmD:
                 product[i] = t % 2 ** self.word_size
                 t >>= self.word_size
                 do_log(locals())
-            product[n] = t
-            do_log(locals(), t=None)
+            index = n
+            do_log(locals(), index="index")
+            product[index] = t
+            do_log(locals(), t=None, index=None)
 
             t = 1
-            for i in range(n + 1):
+            do_log(locals())
+            sub_len = n + 1
+            do_log(locals(), sub_len="sub_len")
+            VL = sub_len
+            do_log(locals(), sub_len=None)
+            for i in range(VL):
                 # subfe
                 not_product = ~product[i] % 2 ** self.word_size
                 t += not_product + un[j + i]
@@ -581,12 +605,16 @@ class DivModKnuthAlgorithmD:
                 do_log(locals())
 
             q[j] = qhat
-            do_log(locals())
+            do_log(locals(), index=None)
 
         # Step D8: un-normalize
-        do_log(locals(), s="s_for_unnorm", vn=None, m=None, j=None)
+
+        # move s and n
+        do_log(locals(), s="s_for_unnorm", n="n_for_unnorm",
+               vn=None, m=None, j=None)
+
         r = [0] * self.r_size  # remainder
-        do_log(locals(), r=("r", self.r_size), n="n_for_unnorm")
+        do_log(locals(), r=("r", self.r_size))
         # r = un >> s
         t = 0
         do_log(locals(), t="t_for_unnorm")
@@ -628,6 +656,9 @@ class DivModKnuthAlgorithmD:
         j = self.regs["j"]
         qhat_num_hi = self.regs["qhat_num_hi"]
         qhat_denom = self.regs["qhat_denom"]
+        qhat_prod_lo = self.regs["qhat_prod_lo"]
+        qhat_prod_hi = self.regs["qhat_prod_hi"]
+        sub_len = self.regs["sub_len"]
         num_size = self.num_size
         denom_size = self.denom_size
         q_size = self.q_size
@@ -670,7 +701,7 @@ class DivModKnuthAlgorithmD:
         yield f"ori 0, 0, {svshape_low}"
         yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
         yield f"svremap 0o01, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RA
-        yield f"sv.cmpi/ff=ne *0, 1, *{u}, 0"
+        yield f"sv.cmpli/ff=ne *0, 1, *{u}, 0"
         yield f"setvl {m}, 0, 1, 0, 0, 0 # getvl {m}"  # m = VL
         yield f"subfic {m}, {m}, {num_size}"  # m = num_size - m
 
@@ -686,12 +717,12 @@ class DivModKnuthAlgorithmD:
         yield f"ori 0, 0, {svshape_low}"
         yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
         yield f"svremap 0o01, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RA
-        yield f"sv.cmpi/ff=ne *0, 1, *{v}, 0"
+        yield f"sv.cmpli/ff=ne *0, 1, *{v}, 0"
         yield f"setvl {n_scalar}, 0, 1, 0, 0, 0 # getvl {n_scalar}"  # n = VL
         # n = denom_size - n
         yield f"subfic {n_scalar}, {n_scalar}, {denom_size}"
 
-        yield f"cmpi 0, 1, {n_scalar}, 1  # cmpdi {n_scalar}, 1"
+        yield f"cmpli 0, 1, {n_scalar}, 1  # cmpldi {n_scalar}, 1"
         yield "bc 4, 2, divmod_skip_sw_divisor # bne divmod_skip_sw_divisor"
 
         # Knuth's algorithm D requires the divisor to have length >= 2
@@ -710,7 +741,7 @@ class DivModKnuthAlgorithmD:
         yield "b divmod_return"
 
         yield "divmod_skip_sw_divisor:"
-        yield f"cmp 0, 1, {m}, {n_scalar}  # cmpd {m}, {n_scalar}"
+        yield f"cmpl 0, 1, {m}, {n_scalar}  # cmpld {m}, {n_scalar}"
         yield "bc 4, 0, divmod_skip_copy_r # bge divmod_skip_copy_r"
         # if m < n:
 
@@ -727,19 +758,19 @@ class DivModKnuthAlgorithmD:
         # calculate amount to shift by -- count leading zeros
         yield f"addi {index}, {n_scalar}, -1"  # index = n - 1
         assert index == 3, "index must be r3"
-        yield f"setvl. 0, 0, {denom_size}, 0, 1, 1"  # VL = denom_size
+        yield f"setvl 0, 0, {denom_size}, 0, 1, 1"  # VL = denom_size
         yield f"sv.cntlzd/m=1<<r3 {s_scalar}, *{v}"  # s = clz64(v[index])
 
         yield f"addi {t_for_uv_shift}, 0, 0"  # t = 0
-        yield f"setvl. 0, {n_scalar}, {denom_size}, 0, 1, 1"  # VL = n
+        yield f"setvl 0, {n_scalar}, {denom_size}, 0, 1, 1"  # VL = n
         # vn = v << s
         yield f"sv.dsld *{vn}, *{v}, {s_scalar}, {t_for_uv_shift}"
 
         yield f"addi {t_for_uv_shift}, 0, 0"  # t = 0
-        yield f"setvl. 0, {m}, {num_size}, 0, 1, 1"  # VL = m
+        yield f"setvl 0, {m}, {num_size}, 0, 1, 1"  # VL = m
         # un = u << s
         yield f"sv.dsld *{un}, *{u}, {s_scalar}, {t_for_uv_shift}"
-        yield f"setvl. 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
+        yield f"setvl 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
         yield f"or {index}, {m}, {m}"  # index = m
         assert index == 3, "index must be r3"
         # un[index] = t
@@ -754,14 +785,164 @@ class DivModKnuthAlgorithmD:
         yield f"divmod_loop:"
 
         # Step D3: calculate q̂
-        yield f"setvl. 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
-        # FIXME: finish
+        yield f"setvl 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
+        yield f"add {index}, {j}, {n_scalar}"  # index = j + n
+        # qhat_num_hi = un[index]
+        assert index == 3, "index must be r3"
+        yield f"sv.or/m=1<<r3 {qhat_num_hi}, *{un}, *{un}"
+        yield f"addi {index}, {n_scalar}, -1"  # index = n - 1
+        # qhat_denom = vn[index]
+        yield f"setvl 0, 0, {vn_size}, 0, 1, 1"  # VL = vn_size
+        assert index == 3, "index must be r3"
+        yield f"sv.or/m=1<<r3 {qhat_denom}, *{vn}, *{vn}"
+        yield f"add {index}, {index}, {j}"  # index = j + n - 1
+        # qhat, rhat_lo, ov = divmod2du(un[index], qhat_denom, qhat_num_hi)
+        yield f"or {rhat_lo}, {qhat_num_hi}, {qhat_num_hi}"
+        yield f"setvl 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
+        assert index == 3, "index must be r3"
+        yield f"sv.divmod2du/m=1<<r3 {qhat}, *{un}, {qhat_denom}, {rhat_lo}"
+        yield f"mcrxrx 0"  # move OV to CR0.lt
+        yield "bc 4, 0, divmod_skip_qhat_overflow # bge divmod_..."
+        # if ov:
+        # division overflows word
+        # rhat_lo = (qhat * qhat_denom) % 2 ** self.word_size
+        yield f"mulld {rhat_lo}, {qhat}, {qhat_denom}"
+        # rhat_hi = (qhat * qhat_denom) >> self.word_size
+        yield f"mulhdu {rhat_hi}, {qhat}, {qhat_denom}"
+        # borrow = un[index] < rhat_lo
+        # rhat_lo = (un[index] - rhat_lo) % 2 ** self.word_size
+        assert index == 3, "index must be r3"
+        yield f"sv.subfc/m=1<<r3 {rhat_lo}, {rhat_lo}, *{un}"
+        # rhat_hi = qhat_num_hi - rhat_hi - borrow
+        yield f"subfe {rhat_hi}, {rhat_hi}, {qhat_num_hi}"
+        yield "divmod_skip_qhat_overflow:"
+
+        # while rhat_hi == 0:
+        yield "divmod_qhat_adj_loop:"
+        yield f"cmpli 0, 1, {rhat_hi}, 0  # cmpldi {rhat_hi}, 0"
+        yield "bc 12, 2, divmod_qhat_adj_loop_break # beq divmod_qhat_adj..."
+
+        yield f"setvl 0, 0, {vn_size}, 0, 1, 1"  # VL = vn_size
+        yield f"addi {index}, {n_scalar}, -2"  # index = n - 2
+        # qhat_prod_lo = (qhat * vn[index]) % 2 ** self.word_size
+        assert index == 3, "index must be r3"
+        yield f"sv.mulld/m=1<<r3 {qhat_prod_lo}, {qhat}, *{vn}"
+        # qhat_prod_hi = (qhat * vn[index]) >> self.word_size
+        yield f"sv.mulhdu/m=1<<r3 {qhat_prod_hi}, {qhat}, *{vn}"
+
+        # if qhat_prod_hi < rhat_lo:
+        #     break
+        yield f"cmpl 0, 1, {qhat_prod_hi}, {rhat_lo}  # cmpld cr0, ..."
+        yield "bc 12, 0, divmod_qhat_adj_loop_break # blt divmod_qhat_adj..."
+        # if qhat_prod_hi == rhat_lo:
+        yield "bc 4, 2, divmod_qhat_do_adj # bne divmod_qhat_do_adj"
+
+        yield f"add {index}, {index}, {j}"  # index = j + n - 2
+        # if qhat_prod_lo <= un[index]:
+        #     break
+        yield f"setvl 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
+        assert index == 3, "index must be r3"
+        yield f"sv.cmp/m=1<<r3 1, 1, {qhat_prod_lo}, *{un}  # cmpld cr1, ..."
+        yield "bc 4, 1, divmod_qhat_adj_loop_break # ble divmod_qhat_adj..."
+        yield "divmod_qhat_do_adj:"
+
+        yield f"addi {qhat}, {qhat}, -1"  # qhat -= 1
+
+        yield f"addi {index}, {n_scalar}, -1"  # index = n - 1
+        # carry = (rhat_lo + vn[index]) >= 2 ** self.word_size
+        # rhat_lo = (rhat_lo + vn[index]) % 2 ** self.word_size
+        yield f"setvl 0, 0, {vn_size}, 0, 1, 1"  # VL = vn_size
+        assert index == 3, "index must be r3"
+        yield f"sv.addc/m=1<<r3 {rhat_lo}, {rhat_lo}, *{vn}"
+        # rhat_hi = carry
+        yield f"addi 0, 0, 0"
+        yield f"addze. {rhat_hi}, 0"
+
+        # while rhat_hi == 0:
+        yield "bc 4, 2, divmod_qhat_adj_loop # bne divmod_qhat_adj_loop"
+        yield "divmod_qhat_adj_loop_break:"
+
+        # Step D4: multiply and subtract
+
+        yield f"setvl 0, {n_scalar}, {vn_size}, 0, 1, 1"  # VL = n
+        yield f"addi {t_for_prod}, 0, 0"  # t = 0
+        # product[:n] = vn[:n] * qhat
+        yield f"sv.maddedu *{product}, *{vn}, {qhat}, {t_for_prod}"
+        yield f"or {index}, {n_scalar}, {n_scalar}"  # index = n
+        yield f"setvl 0, 0, {vn_size}, 0, 1, 1"  # VL = vn_size
+        # product[index] = t
+        assert index == 3, "index must be r3"
+        yield f"sv.or/m=1<<r3 *{product}, {t_for_prod}, {t_for_prod}"
+
+        yield "subfc 0, 0, 0"  # t = 1 (t is CA)
+        yield f"addi {sub_len}, {n_scalar}, 1"  # sub_len = n + 1
+        yield f"setvl 0, {sub_len}, {product_size}, 0, 1, 1"  # VL = sub_len
+        # create svshape that offsets by `j`
+        svshape = SVSHAPE(0)
+        svshape.zdimsz = q_size
+        svshape_low = int(svshape) % 2 ** 16
+        svshape_high = int(svshape) >> 16
+        offset_field = svshape.fsi['offset']
+        assert 2 ** (len(offset_field) - 1) >= q_size, \
+            "max needed offset won't fit in SVSHAPE"
+        mask_start_le = len(svshape) - offset_field.br[0] - 1
+        mask_start = 64 - mask_start_le - 1
+        last = len(offset_field) - 1
+        shift_amount = len(svshape) - offset_field.br[last] - 1
+        # insert j in offset field
+        yield f"rldic 0, {j}, {shift_amount}, {mask_start}"
+        # or in all the other bits
+        if svshape_high != 0:
+            yield f"oris 0, 0, {svshape_high}"
+        yield f"ori 0, 0, {svshape_low}"
+        yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
+        yield f"svremap 0o12, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RB & RT
+        # un[j:] -= product
+        yield f"sv.subfe *{un}, *{product}, *{un}"
+        # need_fixup = not CA
+
+        # Step D5: test remainder
+
+        yield f"mcrxrx 0"  # move CA to CR0.eq
+        # if need_fixup:
+        yield "bc 4, 2, divmod_skip_fixup # bne divmod_skip_fixup"
+
+        # Step D6: add back
+
+        yield f"addi {qhat}, {qhat}, -1"  # qhat -= 1
+        yield "addic 0, 0, 0"  # t = 0 (t is CA)
+        yield f"setvl 0, {n_scalar}, {vn_size}, 0, 1, 1"  # VL = n
+        yield f"svremap 0o11, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RA & RT
+        # un[j:] += vn
+        yield f"sv.adde *{un}, *{un}, *{vn}"
+        yield f"svremap 0o11, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RA & RT
+        # un[j + n] += t
+        yield f"sv.addze *{un}, *{un}"
+
+        yield "divmod_skip_fixup:"
+        yield f"setvl 0, 0, {q_size}, 0, 1, 1"  # VL = q_size
+        yield f"svremap 0o10, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RT
+        # q[j] = qhat
+        yield f"sv.or {q}, {qhat}, {qhat}"
 
         # Step D2 and Step D7: loop
         yield f"addic. {j}, {j}, -1"  # j -= 1
         yield f"bc 4, 0, divmod_loop # bge divmod_loop"
 
-        # FIXME: finish
+        # Step D8: un-normalize
+
+        # move s and n
+        yield f"or {s_for_unnorm}, {s_scalar}, {s_scalar}"
+        yield f"or {n_for_unnorm}, {n_scalar}, {n_scalar}"
+
+        # r = [0] * self.r_size  # remainder
+        yield f"setvl 0, 0, {r_size}, 0, 1, 1"  # VL = r_size
+        yield f"sv.addi *{r}, 0, 0"
+
+        # r = un >> s
+        yield f"addi {t_for_unnorm}, 0, 0"  # t = 0
+        yield f"setvl 0, {n_for_unnorm}, {r_size}, 0, 1, 1"  # VL = n
+        yield f"sv.dsrd/mrr *{r}, *{un}, {s_for_unnorm}, {t_for_unnorm}"
 
         yield "divmod_return:"
         yield "addi 1, 1, 176"  # teardown stack frame
@@ -968,8 +1149,8 @@ class PowModCases(TestAccumulatorBase):
             if n << 64 < n:
                 skip = False
             if skip:
-                # FIXME: only part of the algorithm is implemented,
-                # so we skip the cases that we expect to fail
+                # FIXME: only part of the algorithm works,
+                # so we skip the cases that we know fail
                 continue
             q, r = divmod(n, d)
             with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",