pysvp64db: fix traversal
[openpower-isa.git] / src / openpower / test / bigint / powmod.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
3 # Copyright 2023 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7 #
8 # * https://bugs.libre-soc.org/show_bug.cgi?id=1044
9
10 """ modular exponentiation (`pow(x, y, z)`)
11
12 related bugs:
13
14 * https://bugs.libre-soc.org/show_bug.cgi?id=1044
15 """
16
17 from openpower.test.common import TestAccumulatorBase, skip_case
18 from openpower.test.state import ExpectedState
19 from openpower.test.util import assemble
20 from nmutil.sim_util import hash_256
21 from openpower.util import log
22
23
# Assembly source (one instruction or label per string) for a subroutine
# labelled `mul_256_to_512`.  Going by the register comments below and by
# how the test case in this file loads/checks registers: the two 256-bit
# inputs arrive as four 64-bit little-endian words each (a in r4-7, b in
# r8-11) and the 512-bit product is left in r4-11.
# `python_mul_algorithm` in this file walks through the same sequence of
# steps using base-100 digits.
MUL_256_X_256_TO_512_ASM = (
    "mul_256_to_512:",
    # a is in r4-7, b is in r8-11
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "sv.or *32, *4, *4",  # move args to r32-39
    # a is now in r32-35, b is in r36-39, y is in r4-11, t is in r40-44
    "sv.addi *4, 0, 0",  # clear output
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # r8 (y[4], just cleared) is the carry operand of the first product
    "sv.maddedu *4, *32, 36, 8",  # first partial-product a * b[0]
    "sv.addi 44, 0, 0",  # clear r44 (t[4]) before using it as carry operand
    "sv.maddedu *40, *32, 37, 44",  # second partial-product a * b[1]
    "sv.addc 5, 5, 40",  # y[1] += t[0]
    "sv.adde *6, *6, *41",  # y[2..5] += t[1..4] (VL is 4)
    "sv.addi 44, 0, 0",  # clear r44 (t[4]) again
    "sv.maddedu *40, *32, 38, 44",  # third partial-product a * b[2]
    "sv.addc 6, 6, 40",  # y[2] += t[0]
    "sv.adde *7, *7, *41",  # y[3..6] += t[1..4]
    "sv.addi 44, 0, 0",  # clear r44 (t[4]) again
    "sv.maddedu *40, *32, 39, 44",  # final partial-product a * b[3]
    "sv.addc 7, 7, 40",  # y[3] += t[0]
    "sv.adde *8, *8, *41",  # y[4..7] += t[1..4]
    "bclr 20, 0, 0 # blr",
)
47
# TODO: these helpers really belong in a common util file, see
# openpower/decoder/isa/poly1305-donna.py:def _DSRD(lo, hi, sh)
# (the versions below work modulo 100 rather than modulo 2^64, but the
# general idea is the same)
51
52
def maddedu(a, b, c):
    """ base-100 multiply-and-add: compute `a * b + c` and split it.

    returns a `(low, high)` tuple where `low` is the result modulo 100 and
    `high` is the result floor-divided by 100.
    """
    high, low = divmod(a * b + c, 100)
    return low, high
56
57
def adde(a, b, c):
    """ base-100 add with carry-in: compute `a + b + c` and split it.

    returns a `(low, high)` tuple where `low` is the sum modulo 100 and
    `high` is the sum floor-divided by 100 (the carry out).
    """
    carry_out, digit = divmod(a + b + c, 100)
    return digit, carry_out
61
62
def addc(a, b):
    """ base-100 add without carry-in: compute `a + b` and split it.

    returns a `(low, high)` tuple where `low` is the sum modulo 100 and
    `high` is the sum floor-divided by 100 (the carry out).
    """
    carry_out, digit = divmod(a + b, 100)
    return digit, carry_out
66
67
def python_mul_algorithm(a, b):
    """ base-100 model of the MUL_256_X_256_TO_512_ASM algorithm.

    base 100 is used rather than 2^64 since that's easier to read.
    run this file in a debugger to see all the intermediate values.

    `a` and `b` are indexable sequences of four digits, least-significant
    first; the return value is a list of eight result digits in the same
    order.
    """
    y = [0] * 8
    t = [0] * 5

    # first partial product goes straight into y, with y[4] as the carry
    for col in range(4):
        y[col], y[4] = maddedu(a[0], b[col], y[4])

    # each remaining partial product is built in t (t[4] is the carry),
    # then added into y one digit position further up each time
    for row in (1, 2, 3):
        t[4] = 0
        for col in range(4):
            t[col], t[4] = maddedu(a[row], b[col], t[4])
        y[row], ca = addc(y[row], t[0])
        for col in range(4):
            dest = row + 1 + col
            y[dest], ca = adde(y[dest], t[1 + col], ca)
    return y
95
96
def python_mul_algorithm2(a, b):
    """ version 2 of the base-100 model of MUL_256_X_256_TO_512_ASM.

    base 100 is used rather than 2^64 since that's easier to read.
    the idea here is that it will "morph" into something more akin to
    using REMAP bigmul (first using REMAP Indexed).

    `a` and `b` are indexable sequences of four digits, least-significant
    first; the return value is a list of eight result digits.
    """
    # index schedule: for every row, four entries for the multiply loop
    # followed by five for the add loop.  the boolean in each `iyl` entry
    # marks the end of an inner loop.  note: the loops below do not (yet)
    # consume these lists.
    iyl = []
    il = []
    for iy in range(4):
        for count in (4, 5):
            final = count - 1
            iyl.extend((iy + i, i == final) for i in range(count))
            il.extend(range(count))

    y = [0] * 8  # result y and temp t of same size
    t = [0] * 8  # no need after this to set t[4] to zero
    for iy in range(4):
        carry_idx = iy + 4  # t[iy+4] serves as the 64-bit carry
        for i in range(4):
            t[iy + i], t[carry_idx] = maddedu(a[iy], b[i], t[carry_idx])
        ca = 0
        # add vec t to y with 1-bit carry
        for idx in range(iy, iy + 5):
            y[idx], ca = adde(y[idx], t[idx], ca)
    return y
124
125
# Assembly source (one instruction or label per string) for a subroutine
# labelled `divmod_512_by_256`.  Going by the register comments below and
# by how the test case in this file loads/checks registers: the 512-bit
# numerator n arrives in r4-11 and the 256-bit divisor d in r32-35 (64-bit
# little-endian words); on return the quotient q is in r4-7 and the
# remainder r is in r8-11.  r0 and r3 are used as scratch.
# `python_divmod_algorithm` in this file follows the same steps and logs
# the same register assignments.
DIVMOD_512x256_TO_256x256_ASM = (
    # extremely slow and simplistic shift and subtract algorithm.
    # a future task is to rewrite to use Knuth's Algorithm D,
    # which is generally an order of magnitude faster
    "divmod_512_by_256:",
    # n is in r4-11, d is in r32-35
    "addi 3, 0, 256 # li 3, 256",
    "mtspr 9, 3 # mtctr 3",  # set CTR to 256
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    # r is in r40-47
    "sv.or *40, *4, *4",  # assign n to r, in r40-47
    # shifted_d is in r32-39
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 3, 0, 1 # li 3, 1",  # shift amount
    "addi 0, 0, 0 # li 0, 0",  # dsrd carry
    "sv.dsrd/mrr *36, *32, 3, 0",  # shifted_d = d << (256 - 1)
    "sv.addi *32, 0, 0",  # clear lsb half
    "sv.or 35, 0, 0",  # move carry to correct location
    # q is in r4-7
    "sv.addi *4, 0, 0",  # clear q
    # loop body runs once per CTR count (CTR was set to 256 above)
    "divmod_loop:",
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "subfc 0, 0, 0",  # set CA
    # diff is in r48-55
    "sv.subfe *48, *32, *40",  # diff = r - shifted_d
    # not borrowed is in CA
    "mcrxrx 0",  # move CA to CR0.eq
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *4, *4, 3, 0",  # q <<= 1 (1 is in r3)
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "bc 4, 2, divmod_else # bne divmod_else",  # if borrowed goto divmod_else
    "ori 4, 4, 1",  # q |= 1
    "sv.or *40, *48, *48",  # r = diff
    "divmod_else:",
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *40, *40, 3, 0",  # r <<= 1 (1 is in r3)
    "bc 16, 0, divmod_loop # bdnz divmod_loop",
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # r is in r40-47
    "sv.or *8, *44, *44",  # r >>= 256
    # q is in r4-7, r is in r8-11
    "bclr 20, 0, 0 # blr",
)
170
171
172 class _DivModRegsRegexLogger:
173 """ logger that logs a regex that matches the expected register dump for
174 the currently tracked `locals` -- quite useful for debugging
175 """
176
177 def __init__(self, enabled=True):
178 self.__tracked = {}
179 self.enabled = enabled
180
181 def log(self, locals_, **changes):
182 """ use like so:
183 ```
184 # create a variable `a`:
185 a = ...
186
187 # we invoke `locals()` each time since python doesn't guarantee
188 # it's up-to-date otherwise
189 logger.log(locals(), a=(4, 6)) # `a` starts at r4 and uses 6 registers
190
191 a += 3
192
193 logger.log(locals()) # keeps using `a`
194
195 b = a + 5
196
197 logger.log(locals(), a=None, b=(4, 6)) # remove `a` and add `b`
198 ```
199 """
200
201 for k, v in changes.items():
202 if v is None:
203 del self.__tracked[k]
204 else:
205 self.__tracked[k] = v
206
207 gprs = [None] * 128
208 for name, (start_gpr, size) in self.__tracked.items():
209 value = locals_[name]
210 for i in range(size):
211 assert gprs[start_gpr + i] is None, "overlapping values"
212 gprs[start_gpr + i] = (value >> 64 * i) % 2 ** 64
213
214 if not self.enabled:
215 # after building `gprs` so we catch any missing/invalid locals
216 return
217
218 segments = []
219
220 for i in range(0, 128, 8):
221 segments.append(f"reg +{i}")
222 for value in gprs[i:i + 8]:
223 if value is None:
224 segments.append(" +[0-9a-f]+")
225 else:
226 segments.append(f" +{value:08x}")
227 segments.append("\\n")
228 log("DIVMOD REGEX:", "".join(segments))
229
230
def python_divmod_algorithm(n, d, width=256, log_regex=False):
    """ python model of the shift-and-subtract loop used by
    DIVMOD_512x256_TO_256x256_ASM.

    `n` is the numerator and `d` the divisor; `n` must be less than
    `d << width`.  returns the tuple `(q, r)` built up by the loop.

    when `log_regex` is true, every step logs a register-dump regex via
    `_DivModRegsRegexLogger`; when false the logger still validates the
    tracked locals but logs nothing.

    the local variable names matter: the logger looks each tracked name up
    in the `locals()` mapping it is handed, so the keyword names passed to
    `do_log` must match the locals used here.

    NOTE(review): the `(start_gpr, size)` register ranges below are fixed
    and match the register comments in the 256-bit assembly -- confirm
    before relying on the logging with a different `width`.

    raises AssertionError when the inputs are out of range.
    """
    assert n >= 0 and d > 0 and width > 0 and n < (d << width), "invalid input"
    do_log = _DivModRegsRegexLogger(enabled=log_regex).log

    do_log(locals(), n=(4, 8), d=(32, 4))

    r = n
    # n is no longer tracked; r takes over starting at r40
    do_log(locals(), n=None, r=(40, 8))

    shifted_d = d << (width - 1)
    # shifted_d replaces d, now occupying eight registers from r32
    do_log(locals(), d=None, shifted_d=(32, 8))

    q = 0
    do_log(locals(), q=(4, 4))

    # one quotient bit is produced per iteration
    for _ in range(width):
        diff = r - shifted_d
        borrowed = diff < 0
        do_log(locals(), diff=(48, 8))

        q <<= 1
        do_log(locals())

        if not borrowed:
            # the subtraction fit: record a 1 bit and keep the difference
            q |= 1
            do_log(locals())

            r = diff
            do_log(locals())

        r <<= 1
        do_log(locals())

    r >>= width
    # r now moves to four registers starting at r8
    do_log(locals(), r=(8, 4))

    return q, r
268
269
class PowModCases(TestAccumulatorBase):
    """ test cases for the big-integer assembly routines in this file.

    each `case_*` method builds initial register contents and an
    `ExpectedState`, then registers the routine through `call_case`.
    """

    def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
        """ register one test case that runs `instructions` as a subroutine.

        `instructions` is the sequence of assembly lines to assemble,
        `expected` the `ExpectedState` to compare against and
        `initial_regs` the initial GPR values.  both `expected` and
        `initial_regs` are modified in place: r1 is set to a stack pointer
        value, and the expected PC is set to the stop address.
        `src_loc_at` is forwarded (plus one, for this wrapper's own stack
        frame) to `add_case`.
        """
        stop_at_pc = 0x10000000
        # SPR 8 is preloaded with the stop address -- presumably this is
        # LR, so the routine's final `blr` lands on stop_at_pc
        sprs = {8: stop_at_pc}
        expected.intregs[1] = initial_regs[1] = 0x1000000  # set stack pointer
        expected.pc = stop_at_pc
        # presumably None means "don't check" (cf. `e.ca = None # ignored`)
        expected.sprs['LR'] = None
        self.add_case(assemble(instructions),
                      initial_regs, initial_sprs=sprs,
                      stop_at_pc=stop_at_pc, expected=expected,
                      src_loc_at=src_loc_at + 1)

    def case_mul_256_x_256_to_512(self):
        """ 256-bit x 256-bit -> 512-bit multiply: two known-value cases
        followed by hash-derived inputs.
        """
        for i in range(10):
            a = hash_256(f"mul256 input a {i}")
            b = hash_256(f"mul256 input b {i}")
            if i == 0:
                # use known values:
                a = b = 2**256 - 1
            elif i == 1:
                # use known values:
                a = b = (2**256 - 1) // 0xFF
            y = a * b
            with self.subTest(a=f"{a:#_x}", b=f"{b:#_x}", y=f"{y:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for word in range(4):
                    # write a in LE order to regs 4-7
                    initial_regs[4 + word] = (a >> (64 * word)) % 2**64
                    # write b in LE order to regs 8-11
                    initial_regs[8 + word] = (b >> (64 * word)) % 2**64
                # only check regs up to r11 since that's where the output is
                e = ExpectedState(int_regs=initial_regs[:12])
                for word in range(8):
                    # write y in LE order to regs 4-11
                    e.intregs[4 + word] = (y >> (64 * word)) % 2**64

                self.call_case(MUL_256_X_256_TO_512_ASM, e, initial_regs)

    @staticmethod
    def divmod_512x256_to_256x256_test_inputs():
        """ yield `(n, d)` input pairs for the 512-by-256-bit divmod test.

        every pair satisfies `d > 0` and `0 <= n < (d << 256)`, so the
        quotient always fits in 256 bits.
        """
        for i in range(10):
            n = hash_256(f"divmod256 input n msb {i}")
            n <<= 256
            n |= hash_256(f"divmod256 input n lsb {i}")
            d = hash_256(f"divmod256 input d {i}")
            if i == 0:
                # use known values:
                n = 2 ** (256 - 1)
                d = 1
            elif i == 1:
                # use known values:
                n = 2 ** (512 - 1) - 1
                d = 2 ** 256 - 1
            if d == 0:
                d = 1
            # reduce n so that n < (d << 256).  a single conditional
            # subtraction is not enough when n >= 2 * (d << 256); the
            # modulo gives the same result as that subtraction whenever
            # the subtraction would have sufficed.
            n %= d << 256
            yield (n, d)

    def case_divmod_512x256_to_256x256(self):
        """ 512-bit by 256-bit divmod: checks quotient and remainder
        against python's `divmod`.
        """
        for n, d in self.divmod_512x256_to_256x256_test_inputs():
            q, r = divmod(n, d)
            with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",
                              q=f"{q:#_x}", r=f"{r:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for word in range(8):
                    # write n in LE order to regs 4-11
                    initial_regs[4 + word] = (n >> (64 * word)) % 2**64
                for word in range(4):
                    # write d in LE order to regs 32-35
                    initial_regs[32 + word] = (d >> (64 * word)) % 2**64
                # only check regs up to r11 since that's where the output is.
                # don't check CR
                e = ExpectedState(int_regs=initial_regs[:12], crregs=0)
                e.intregs[0] = 0  # leftovers -- ignore
                e.intregs[3] = 1  # leftovers -- ignore
                e.ca = None  # ignored
                for word in range(4):
                    # write q in LE order to regs 4-7
                    e.intregs[4 + word] = (q >> (64 * word)) % 2**64
                    # write r in LE order to regs 8-11
                    e.intregs[8 + word] = (r >> (64 * word)) % 2**64

                self.call_case(DIVMOD_512x256_TO_256x256_ASM, e, initial_regs)

    # TODO: add 256-bit modular exponentiation
358
359
# for running "quick" simple investigations
if __name__ == "__main__":
    import random

    # first check python_mul_algorithm against a known product
    all_nines = (99, 99, 99, 99)
    known_product = [1, 0, 0, 0, 98, 99, 99, 99]
    assert python_mul_algorithm(all_nines, all_nines) == known_product

    # now test python_mul_algorithm2 *against* python_mul_algorithm
    random.seed(0)  # reproducible values
    for _round in range(10000):
        # digits are drawn alternately (a[j] then b[j]) so that the
        # sequence of random values consumed stays the same
        digit_pairs = [(random.randint(0, 99), random.randint(0, 99))
                       for _digit in range(4)]
        a = [pair[0] for pair in digit_pairs]
        b = [pair[1] for pair in digit_pairs]
        expected = python_mul_algorithm(a, b)
        testing = python_mul_algorithm2(a, b)
        report = "".join((
            "%+17s * %-17s = %s\n" % (repr(a), repr(b), repr(expected)),
            " (%s)" % repr(testing),
        ))
        print(report)
        assert expected == testing