2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 """ modular exponentiation (`pow(x, y, z)`)
9 related bugs:
11 * https://bugs.libre-soc.org/show_bug.cgi?id=1044
12 """
14 from openpower.test.common import TestAccumulatorBase, skip_case
15 from openpower.test.state import ExpectedState
16 from openpower.test.util import assemble
17 from nmutil.sim_util import hash_256
MUL_256_X_256_TO_512_ASM = (
    # y = a * b, with 256-bit a and b and a 512-bit y. All multi-word values
    # are stored in LE order: least-significant 64-bit word in the
    # lowest-numbered register. `_python_mul_algorithm` models these steps.
    "mul_256_to_512:",
    # a is in r4-7, b is in r8-11
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "sv.or *32, *4, *4",  # move args to r32-39
    # a is now in r32-35, b is in r36-39, y is in r4-11, t is in r40-44
    "sv.addi *4, 0, 0",  # clear output
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # the first partial-product is written straight into y[0:4]; its high
    # (carry) word lands in r8 == y[4], which was just cleared.
    "sv.maddedu *4, *32, 36, 8",  # first partial-product a * b[0]
    # registers start with junk, so t[4] (r44, the carry word of each
    # following partial-product) must be cleared before every use.
    "addi 44, 0, 0",  # t[4] = 0
    "sv.maddedu *40, *32, 37, 44",  # second partial-product a * b[1]
    "addc 5, 5, 40",  # y[1] += t[0], carry into CA
    "sv.adde *6, *6, *41",  # y[2:6] += t[1:5] + CA
    "addi 44, 0, 0",  # t[4] = 0
    "sv.maddedu *40, *32, 38, 44",  # third partial-product a * b[2]
    "addc 6, 6, 40",  # y[2] += t[0], carry into CA
    "sv.adde *7, *7, *41",  # y[3:7] += t[1:5] + CA
    "addi 44, 0, 0",  # t[4] = 0
    "sv.maddedu *40, *32, 39, 44",  # final partial-product a * b[3]
    "addc 7, 7, 40",  # y[3] += t[0], carry into CA
    "sv.adde *8, *8, *41",  # y[4:8] += t[1:5] + CA
    "bclr 20, 0, 0 # blr",
)
45 def _python_mul_algorithm(a, b):
46 # version of the MUL_256_X_256_TO_512_ASM algorithm using base 100 rather
47 # than 2^64, since that's easier to read.
48 # run this file in a debugger to see all the intermediate values.
50 y = a * b + c
51 return y % 100, y // 100
54 y = a + b + c
55 return y % 100, y // 100
58 y = a + b
59 return y % 100, y // 100
61 y = [0] * 8
62 t = [0] * 5
63 for i in range(4):
64 y[i], y[4] = maddedu(a[0], b[i], y[4])
65 t[4] = 0
66 for i in range(4):
67 t[i], t[4] = maddedu(a[1], b[i], t[4])
68 y[1], ca = addc(y[1], t[0])
69 for i in range(4):
70 y[2 + i], ca = adde(y[2 + i], t[1 + i], ca)
71 t[4] = 0
72 for i in range(4):
73 t[i], t[4] = maddedu(a[2], b[i], t[4])
74 y[2], ca = addc(y[2], t[0])
75 for i in range(4):
76 y[3 + i], ca = adde(y[3 + i], t[1 + i], ca)
77 t[4] = 0
78 for i in range(4):
79 t[i], t[4] = maddedu(a[3], b[i], t[4])
80 y[3], ca = addc(y[3], t[0])
81 for i in range(4):
82 y[4 + i], ca = adde(y[4 + i], t[1 + i], ca)
83 return y
DIVMOD_512x256_TO_256x256_ASM = (
    # extremely slow and simplistic shift and subtract algorithm.
    # a future task is to rewrite to use Knuth's Algorithm D,
    # which is generally an order of magnitude faster.
    # `python_divmod_algorithm` models the same steps; like it, this requires
    # n < d << 256 so that the quotient fits in 256 bits.
    # all multi-word values are stored in LE order: least-significant 64-bit
    # word in the lowest-numbered register.
    "divmod_512_by_256:",
    # n is in r4-11, d is in r32-35
    "addi 3, 0, 256 # li 3, 256",
    "mtspr 9, 3 # mtctr 3",  # set CTR to 256
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    # r is in r40-47
    "sv.or *40, *4, *4",  # assign n to r, in r40-47
    # shifted_d is in r32-39
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 3, 0, 1 # li 3, 1",  # shift amount
    "addi 0, 0, 0 # li 0, 0",  # dsrd carry
    # the upper half of d << 255 is d >> 1; the bit shifted out of the
    # bottom of d is left in r0 (the dsrd carry) and becomes the top bit of
    # the lower half.
    "sv.dsrd/mrr *36, *32, 3, 0",  # shifted_d = d << (256 - 1)
    "sv.addi *32, 0, 0",  # clear lsb half
    "sv.or 35, 0, 0",  # move carry to correct location
    # q is in r4-7
    "sv.addi *4, 0, 0",  # clear q
    "divmod_loop:",
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "subfc 0, 0, 0",  # set CA
    # diff is in r48-55
    "sv.subfe *48, *32, *40",  # diff = r - shifted_d
    # not borrowed is in CA
    "mcrxrx 0",  # move CA to CR0.eq
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *4, *4, 3, 0",  # q <<= 1 (1 is in r3)
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "bc 4, 2, divmod_else # bne divmod_else",  # if borrowed goto divmod_else
    "ori 4, 4, 1",  # q |= 1
    "sv.or *40, *48, *48",  # r = diff
    "divmod_else:",
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *40, *40, 3, 0",  # r <<= 1 (1 is in r3)
    "bc 16, 0, divmod_loop # bdnz divmod_loop",
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # r is in r40-47; r >> 256 is its upper half, r44-47. Both sources must
    # be vectors for `or` to act as an element-wise copy.
    "sv.or *8, *44, *44",  # r >>= 256
    # q is in r4-7, r is in r8-11
    "bclr 20, 0, 0 # blr",
)
def python_divmod_algorithm(n, d, width=256):
    """Python model of DIVMOD_512x256_TO_256x256_ASM's shift-and-subtract loop.

    Computes ``divmod(n, d)`` one quotient bit per iteration, following the
    same steps as the assembly so the two can be compared side by side.

    :param n: dividend; must satisfy ``0 <= n < d << width``.
    :param d: divisor; must be positive.
    :param width: number of quotient bits, i.e. loop iterations (256 for the
        asm).
    :return: ``(quotient, remainder)``.
    :raises AssertionError: "invalid input" if the constraints above are
        violated (only while assertions are enabled).
    """
    assert n >= 0 and d > 0 and width > 0 and n < (d << width), "invalid input"
    # divisor aligned so that its lowest bit sits at bit (width - 1)
    divisor_hi = d << (width - 1)
    remainder, quotient = n, 0
    for _step in range(width):
        difference = remainder - divisor_hi
        quotient <<= 1
        if difference >= 0:
            # no borrow: the divisor fits, so this quotient bit is 1
            quotient |= 1
            remainder = difference
        # shift the partial remainder up instead of shifting the divisor down
        remainder <<= 1
    # undo the `width` left shifts that were applied to the remainder
    return quotient, remainder >> width
class PowModCases(TestAccumulatorBase):
    """Test cases that run the assembly routines above and check the
    resulting register state."""

    def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
        """Assemble `instructions` and register them as a test case.

        The routine is entered like a function: LR (SPR 8) is set to
        `stop_at_pc`, so the routine's final `blr` returns there, which is
        also where the run is told to stop.

        :param instructions: sequence of assembly-language lines.
        :param expected: ExpectedState to check against; its stack pointer,
            pc and LR entries are filled in here.
        :param initial_regs: list of initial GPR values; r1 (the stack
            pointer) is overwritten here.
        :param src_loc_at: passed on (plus one for this frame) to `add_case`.
        """
        stop_at_pc = 0x10000000
        sprs = {8: stop_at_pc}
        expected.intregs[1] = initial_regs[1] = 0x1000000  # set stack pointer
        expected.pc = stop_at_pc
        # NOTE(review): presumably None means LR's final value is not checked
        expected.sprs['LR'] = None
        self.add_case(assemble(instructions),
                      initial_regs, initial_sprs=sprs,
                      stop_at_pc=stop_at_pc, expected=expected,
                      src_loc_at=src_loc_at + 1)

    def case_mul_256_x_256_to_512(self):
        """Check MUL_256_X_256_TO_512_ASM on 2 known and 8 hashed inputs."""
        for case_idx in range(10):
            a = hash_256(f"mul256 input a {case_idx}")
            b = hash_256(f"mul256 input b {case_idx}")
            if case_idx == 0:
                # use known values:
                a = b = 2**256 - 1
            elif case_idx == 1:
                # use known values:
                a = b = (2**256 - 1) // 0xFF
            y = a * b
            with self.subTest(a=f"{a:#_x}", b=f"{b:#_x}", y=f"{y:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for word in range(4):
                    # write a in LE order to regs 4-7
                    initial_regs[4 + word] = (a >> (64 * word)) % 2**64
                    # write b in LE order to regs 8-11
                    initial_regs[8 + word] = (b >> (64 * word)) % 2**64
                # only check regs up to r11 since that's where the output is
                e = ExpectedState(int_regs=initial_regs[:12])
                for word in range(8):
                    # write y in LE order to regs 4-11
                    e.intregs[4 + word] = (y >> (64 * word)) % 2**64

                self.call_case(MUL_256_X_256_TO_512_ASM, e, initial_regs)

    @skip_case("FIXME: wip -- currently broken")
    def case_divmod_512x256_to_256x256(self):
        """Check DIVMOD_512x256_TO_256x256_ASM on 2 known and 8 hashed
        inputs."""
        for case_idx in range(10):
            n = hash_256(f"divmod256 input n msb {case_idx}")
            n <<= 256
            n |= hash_256(f"divmod256 input n lsb {case_idx}")
            d = hash_256(f"divmod256 input d {case_idx}")
            if case_idx == 0:
                # use known values:
                n = 2 ** (256 - 1)
                d = 1
            elif case_idx == 1:
                # use known values:
                n = 2 ** (512 - 1) - 1
                d = 2 ** 256 - 1
            if d == 0:
                d = 1
            # the algorithm requires n < d << 256 (so q fits in 256 bits).
            # Subtracting d << 256 once is not enough when d < 2**255, so
            # reduce modulo d << 256 instead.
            n %= d << 256
            q, r = divmod(n, d)
            with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",
                              q=f"{q:#_x}", r=f"{r:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for word in range(8):
                    # write n in LE order to regs 4-11
                    initial_regs[4 + word] = (n >> (64 * word)) % 2**64
                for word in range(4):
                    # write d in LE order to regs 32-35
                    initial_regs[32 + word] = (d >> (64 * word)) % 2**64
                # only check regs up to r11 since that's where the output is.
                # don't check CR
                e = ExpectedState(int_regs=initial_regs[:12], crregs=0)
                e.intregs[0] = 0  # leftovers -- ignore
                e.intregs[3] = 1  # leftovers -- ignore
                for word in range(4):
                    # write q in LE order to regs 4-7
                    e.intregs[4 + word] = (q >> (64 * word)) % 2**64
                    # write r in LE order to regs 8-11
                    e.intregs[8 + word] = (r >> (64 * word)) % 2**64

                self.call_case(DIVMOD_512x256_TO_256x256_ASM, e, initial_regs)

    # TODO: add 256-bit modular exponentiation
if __name__ == "__main__":
    # Sanity-check the base-100 model: 99999999 ** 2 == 9999999800000001,
    # whose base-100 digits (least-significant first) are listed below.
    nines = (99,) * 4
    expected_digits = [1, 0, 0, 0, 98, 99, 99, 99]
    assert _python_mul_algorithm(nines, nines) == expected_digits