From 7a3b4c87f7c088127f3d15b587c65f9ac6b4a453 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 5 Oct 2023 19:57:29 -0700 Subject: [PATCH] add WIP powmod_256 -- asm test is currently disabled since divmod is too slow --- .../decoder/isa/test_caller_svp64_powmod.py | 12 +- src/openpower/test/bigint/powmod.py | 123 +++++++++++++++++- 2 files changed, 133 insertions(+), 2 deletions(-) diff --git a/src/openpower/decoder/isa/test_caller_svp64_powmod.py b/src/openpower/decoder/isa/test_caller_svp64_powmod.py index c9675ee2..055b40d4 100644 --- a/src/openpower/decoder/isa/test_caller_svp64_powmod.py +++ b/src/openpower/decoder/isa/test_caller_svp64_powmod.py @@ -14,7 +14,8 @@ related bugs: import unittest from functools import lru_cache import os -from openpower.test.bigint.powmod import PowModCases, python_divmod_algorithm +from openpower.test.bigint.powmod import ( + PowModCases, python_divmod_algorithm, python_powmod_256_algorithm) from openpower.test.runner import TestRunnerBase @@ -31,6 +32,15 @@ class TestPythonAlgorithms(unittest.TestCase): self.assertEqual(out_q, q) self.assertEqual(out_r, r) + def test_python_powmod_algorithm(self): + for base, exp, mod in PowModCases.powmod_256_test_inputs(): + expected = pow(base, exp, mod) + with self.subTest(base=f"{base:#_x}", exp=f"{exp:#_x}", + mod=f"{mod:#_x}", expected=f"{expected:#_x}"): + out = python_powmod_256_algorithm(base, exp, mod) + with self.subTest(out=f"{out:#_x}"): + self.assertEqual(expected, out) + # writing the test_caller invocation this way makes it work with pytest diff --git a/src/openpower/test/bigint/powmod.py b/src/openpower/test/bigint/powmod.py index 8cdeeabd..4f2f2d93 100644 --- a/src/openpower/test/bigint/powmod.py +++ b/src/openpower/test/bigint/powmod.py @@ -267,6 +267,88 @@ def python_divmod_algorithm(n, d, width=256, log_regex=False): return q, r +POWMOD_256_ASM = ( + # base is in r4-7, exp is in r8-11, mod is in r32-35 + "powmod_256:", + "mfspr 0, 8 # mflr 0", + "std 0, 16(1)", # save return address + "setvl 0, 0, 18, 0, 1, 1", # set VL to 18 + "sv.std *14, -144(1)", # save all callee-save registers + "stdu 1, -176(1)", # create stack frame as required by ABI + + "setvl 0, 0, 4, 0, 1, 1", # set VL to 4 + "sv.or *16, *4, *4", # move base to r16-19 + "sv.or *20, *8, *8", # move exp to r20-23 + "sv.or *24, *32, *32", # move mod to r24-27 + "sv.addi *28, 0, 0", # retval in r28-31 + "addi 28, 0, 1", # retval = 1 + + "addi 14, 0, 256", # ctr in r14 + + "powmod_256_loop:", + "setvl 0, 0, 4, 0, 1, 1", # set VL to 4 + "addi 3, 0, 1 # li 3, 1", # shift amount + "addi 0, 0, 0 # li 0, 0", # dsrd carry + "sv.dsrd/mrr *20, *20, 3, 0", # exp >>= 1; shifted out bit in r0 + "cmpli 0, 1, 0, 0 # cmpldi 0, 0", + "bc 12, 2, powmod_256_else # beq powmod_256_else", # if lsb: + + "sv.or *4, *28, *28", # copy retval to r4-7 + "sv.or *8, *16, *16", # copy base to r8-11 + "bl mul_256_to_512", # prod = retval * base + # prod in r4-11 + + "setvl 0, 0, 4, 0, 1, 1", # set VL to 4 + "sv.or *32, *24, *24", # copy mod to r32-35 + + "bl divmod_512_by_256", # prod % mod + "setvl 0, 0, 4, 0, 1, 1", # set VL to 4 + "sv.or *28, *8, *8", # retval = prod % mod + + "powmod_256_else:", + + "sv.or *4, *16, *16", # copy base to r4-7 + "sv.or *8, *16, *16", # copy base to r8-11 + "bl mul_256_to_512", # prod = base * base + # prod in r4-11 + + "setvl 0, 0, 4, 0, 1, 1", # set VL to 4 + "sv.or *32, *24, *24", # copy mod to r32-35 + + "bl divmod_512_by_256", # prod % mod + "setvl 0, 0, 4, 0, 1, 1", # set VL to 4 + "sv.or *16, *8, *8", # base = prod % mod + + "addic. 14, 14, -1", # decrement ctr and compare against zero + "bc 4, 2, powmod_256_loop # bne powmod_256_loop", + + "setvl 0, 0, 4, 0, 1, 1", # set VL to 4 + "sv.or *4, *28, *28", # move retval to r4-7 + + "addi 1, 1, 176", # teardown stack frame + "ld 0, 16(1)", + "mtspr 8, 0 # mtlr 0", # restore return address + "setvl 0, 0, 18, 0, 1, 1", # set VL to 18 + "sv.ld *14, -144(1)", # restore all callee-save registers + "bclr 20, 0, 0 # blr", + *MUL_256_X_256_TO_512_ASM, + *DIVMOD_512x256_TO_256x256_ASM, +) + + +def python_powmod_256_algorithm(base, exp, mod): + retval = 1 + for _ in range(256): + lsb = bool(exp & 1) # rshift and retrieve lsb + exp >>= 1 + if lsb: + prod = retval * base + retval = prod % mod + prod = base * base + base = prod % mod + return retval + + class PowModCases(TestAccumulatorBase): def call_case(self, instructions, expected, initial_regs, src_loc_at=0): stop_at_pc = 0x10000000 @@ -354,8 +436,47 @@ class PowModCases(TestAccumulatorBase): self.call_case(DIVMOD_512x256_TO_256x256_ASM, e, initial_regs) - # TODO: add 256-bit modular exponentiation + @staticmethod + def powmod_256_test_inputs(): + for i in range(3): + base = hash_256(f"powmod256 input base {i}") + exp = hash_256(f"powmod256 input exp {i}") + mod = hash_256(f"powmod256 input mod {i}") + if i == 0: + base = 2 + exp = 2 ** 256 - 1 + mod = 2 ** 256 - 189 # largest prime less than 2 ** 256 + if mod == 0: + mod = 1 + base %= mod + yield (base, exp, mod) + + @skip_case("FIXME: divmod is too slow to test powmod") + def case_powmod_256(self): + for base, exp, mod in PowModCases.powmod_256_test_inputs(): + expected = pow(base, exp, mod) + with self.subTest(base=f"{base:#_x}", exp=f"{exp:#_x}", + mod=f"{mod:#_x}", expected=f"{expected:#_x}"): + # registers start filled with junk + initial_regs = [0xABCDEF] * 128 + for i in range(4): + # write n in LE order to regs 4-7 + initial_regs[4 + i] = (base >> (64 * i)) % 2**64 + for i in range(4): + # write n in LE order to regs 8-11 + initial_regs[8 + i] = (exp >> (64 * i)) % 2**64 + for i in range(4): + # write d in LE order to regs 32-35 + initial_regs[32 + i] = (mod >> (64 * i)) % 2**64 + # only check regs up to r7 since that's where the output is. + # don't check CR + e = ExpectedState(int_regs=initial_regs[:8], crregs=0) + e.ca = None # ignored + for i in range(4): + # write output in LE order to regs 4-7 + e.intregs[4 + i] = (expected >> (64 * i)) % 2**64 + self.call_case(POWMOD_256_ASM, e, initial_regs) # for running "quick" simple investigations if __name__ == "__main__": -- 2.30.2