ea4eb29de941e621aeebd7569fd223a79035c296
[openpower-isa.git] / src / openpower / test / bigint / powmod.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
3 # Copyright 2023 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7 #
8 # * https://bugs.libre-soc.org/show_bug.cgi?id=1044
9
10 """ modular exponentiation (`pow(x, y, z)`)
11
12 related bugs:
13
14 * https://bugs.libre-soc.org/show_bug.cgi?id=1044
15 """
16
17 from openpower.test.common import TestAccumulatorBase, skip_case
18 from openpower.test.state import ExpectedState
19 from openpower.test.util import assemble
20 from nmutil.sim_util import hash_256
21 from openpower.util import log
22 from nmutil.plain_data import plain_data
23 from cached_property import cached_property
24 from openpower.decoder.isa.svshape import SVSHAPE
25 from openpower.decoder.power_enums import SPRfull
26 from openpower.decoder.selectable_int import SelectableInt
27
28
# SVP64 assembly routine `mul_256_to_512`: multiplies the 256-bit value in
# r4-7 by the 256-bit value in r8-11 and leaves the 512-bit product in r4-11.
# Multi-word values are held least-significant 64-bit word first (that is how
# case_mul_256_x_256_to_512 loads the inputs and reads back the product).
# r32-44 are overwritten as scratch space.
# python_mul_algorithm is a base-100 model of the same sequence of steps.
MUL_256_X_256_TO_512_ASM = (
    "mul_256_to_512:",
    # a is in r4-7, b is in r8-11
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "sv.or *32, *4, *4",  # move args to r32-39
    # a is now in r32-35, b is in r36-39, y is in r4-11, t is in r40-44
    "sv.addi *4, 0, 0",  # clear output
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # r8 (y[4], just cleared) is the carry-in and receives the carry-out
    "sv.maddedu *4, *32, 36, 8",  # first partial-product a * b[0]
    "sv.addi 44, 0, 0",  # clear t[4], the carry word for the next maddedu
    "sv.maddedu *40, *32, 37, 44",  # second partial-product a * b[1]
    # accumulate t into y one word higher than the previous partial-product:
    # addc starts the carry chain, adde propagates it through the rest
    "sv.addc 5, 5, 40",
    "sv.adde *6, *6, *41",
    "sv.addi 44, 0, 0",  # clear t[4] again
    "sv.maddedu *40, *32, 38, 44",  # third partial-product a * b[2]
    "sv.addc 6, 6, 40",
    "sv.adde *7, *7, *41",
    "sv.addi 44, 0, 0",  # clear t[4] again
    "sv.maddedu *40, *32, 39, 44",  # final partial-product a * b[3]
    "sv.addc 7, 7, 40",
    "sv.adde *8, *8, *41",
    "bclr 20, 0, 0 # blr",
)
52
53 # TODO: these really need to go into a common util file, see
54 # openpower/decoder/isa/poly1305-donna.py:def _DSRD(lo, hi, sh)
55 # okok they are modulo 100 but you get the general idea
56
57
def maddedu(a, b, c):
    """ base-100 stand-in for the maddedu instruction.

    computes `a * b + c` and returns it as `(low digit, carry)` where
    the low digit is the result modulo 100 and the carry is the rest.
    """
    carry, low = divmod(a * b + c, 100)
    return low, carry
61
62
def adde(a, b, c):
    """ base-100 stand-in for add-with-carry-in.

    returns `(low digit, carry out)` of `a + b + c`, where the low digit
    is the sum modulo 100.
    """
    carry, low = divmod(a + b + c, 100)
    return low, carry
66
67
def addc(a, b):
    """ base-100 stand-in for add-producing-carry (no carry in).

    returns `(low digit, carry out)` of `a + b`, where the low digit is
    the sum modulo 100.
    """
    carry, low = divmod(a + b, 100)
    return low, carry
71
72
def python_mul_algorithm(a, b):
    """ base-100 model of the MUL_256_X_256_TO_512_ASM algorithm.

    base 100 is used instead of 2^64 because the intermediate values are
    much easier to read; step through this in a debugger to see them.

    `a` and `b` are 4-digit sequences, least-significant digit first.
    returns the 8-digit product as a list, least-significant digit first.
    """
    y = [0] * 8  # the product being accumulated
    t = [0] * 5  # one partial product: 4 digits plus the carry in t[4]

    # first partial product is written straight into y, with y[4] doubling
    # as the multiply carry
    for col in range(4):
        y[col], y[4] = maddedu(a[0], b[col], y[4])

    # remaining partial products go through t, then get added into y
    # shifted up by one more digit each time
    for row in range(1, 4):
        t[4] = 0
        for col in range(4):
            t[col], t[4] = maddedu(a[row], b[col], t[4])
        # lowest digit starts the carry chain, the rest propagate it
        y[row], ca = addc(y[row], t[0])
        for col in range(4):
            dest = row + 1 + col
            y[dest], ca = adde(y[dest], t[1 + col], ca)
    return y
100
101
def python_mul_algorithm2(a, b):
    """ version 2 of the base-100 model of MUL_256_X_256_TO_512_ASM.

    base 100 is used rather than 2^64, since that's easier to read.
    the idea here is that it will "morph" into something more akin to
    using REMAP bigmul (first using REMAP Indexed).

    `a` and `b` are 4-digit sequences, least-significant digit first.
    returns the 8-digit product as a list, least-significant digit first.
    """
    # NOTE: an earlier revision built `iyl`/`il` index-schedule lists here
    # but never read them; they were dead code and have been dropped.

    y = [0] * 8  # result y and temp t of same size
    t = [0] * 8  # no need after this to set t[4] to zero
    for iy in range(4):
        # t[iy+4] has not been written yet on this pass (it is still zero
        # from the initialisation), so it can serve as the multiply carry
        for i in range(4):  # use t[iy+4] as a 64-bit carry
            t[iy+i], t[iy+4] = maddedu(a[iy], b[i], t[iy+4])
        ca = 0
        for i in range(5):  # add vec t to y with 1-bit carry
            idx = iy + i
            y[idx], ca = adde(y[idx], t[idx], ca)
    return y
129
130
# SVP64 assembly routine `divmod_512_by_256`: divides the 512-bit numerator
# in r4-11 by the 256-bit denominator in r32-35, leaving the 256-bit quotient
# in r4-7 and the 256-bit remainder in r8-11 (least-significant 64-bit word
# first, as loaded by case_divmod_shift_sub_512x256_to_256x256).
# r0, r3, CTR and r32-55 are used as scratch.
# python_divmod_shift_sub_algorithm is the python model of this routine.
DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM = (
    # extremely slow and simplistic shift and subtract algorithm.
    # a future task is to rewrite to use Knuth's Algorithm D,
    # which is generally an order of magnitude faster
    "divmod_512_by_256:",
    # n is in r4-11, d is in r32-35
    "addi 3, 0, 256 # li 3, 256",
    "mtspr 9, 3 # mtctr 3",  # set CTR to 256
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    # r is in r40-47
    "sv.or *40, *4, *4",  # assign n to r, in r40-47
    # shifted_d is in r32-39
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 3, 0, 1 # li 3, 1",  # shift amount
    "addi 0, 0, 0 # li 0, 0",  # dsrd carry
    "sv.dsrd/mrr *36, *32, 3, 0",  # shifted_d = d << (256 - 1)
    "sv.addi *32, 0, 0",  # clear lsb half
    "sv.or 35, 0, 0",  # move carry to correct location
    # q is in r4-7
    "sv.addi *4, 0, 0",  # clear q
    # loop body runs CTR (256) times, producing one quotient bit per pass
    "divmod_loop:",
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "subfc 0, 0, 0",  # set CA
    # diff is in r48-55
    "sv.subfe *48, *32, *40",  # diff = r - shifted_d
    # not borrowed is in CA
    "mcrxrx 0",  # move CA to CR0.eq
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *4, *4, 3, 0",  # q <<= 1 (1 is in r3)
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "bc 4, 2, divmod_else # bne divmod_else",  # if borrowed goto divmod_else
    "ori 4, 4, 1",  # q |= 1
    "sv.or *40, *48, *48",  # r = diff
    "divmod_else:",
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *40, *40, 3, 0",  # r <<= 1 (1 is in r3)
    "bc 16, 0, divmod_loop # bdnz divmod_loop",
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # r is in r40-47
    "sv.or *8, *44, *44",  # r >>= 256
    # q is in r4-7, r is in r8-11
    "bclr 20, 0, 0 # blr",
)
175
176
class _DivModRegsRegexLogger:
    """ debugging aid: remembers which python local is expected to live in
    which GPRs, and logs a regex that matches the register dump those
    locals should produce
    """

    def __init__(self, enabled=True, regs=None):
        # name -> (first GPR number, number of GPRs)
        self.__tracked = {}
        # symbolic register name -> GPR number, for non-integer locations
        self.__regs = {} if regs is None else regs
        self.enabled = enabled

    def __update_tracked(self, changes):
        """ apply the keyword arguments given to `log` """
        for name, spec in changes.items():
            if spec is None:
                # stop tracking -- the name must currently be tracked
                del self.__tracked[name]
                continue
            if isinstance(spec, (tuple, list)):
                first_gpr, count = spec
            else:
                first_gpr, count = spec, 1
            if not isinstance(first_gpr, int):
                # symbolic location, translate through the regs table
                first_gpr = self.__regs[first_gpr]
            self.__tracked[name] = first_gpr, count

    def __expected_gprs(self, locals_):
        """ build the 128-entry table of expected GPR contents.

        each entry is either None (don't care) or a tuple of
        `(word value, local's name, word index within that local)`.
        raises AssertionError if two tracked locals claim the same GPR.
        """
        gprs = [None] * 128
        for name, (first_gpr, count) in self.__tracked.items():
            words = locals_[name]
            if words is None:
                continue
            if isinstance(words, (list, tuple)):
                assert len(words) == count, "value has wrong len"
            else:
                # plain integer: split into 64-bit words, lsb word first
                words = [(words >> 64 * k) % 2 ** 64 for k in range(count)]
            for k in range(count):
                word = words[k]
                if word is None:
                    continue
                reg = first_gpr + k
                occupant = gprs[reg]
                if occupant is not None:
                    _, other_name, other_k = occupant
                    raise AssertionError(f"overlapping values at r{reg}: "
                                         f"{name}[{k}] overlaps with "
                                         f"{other_name}[{other_k}]")
                gprs[reg] = word, name, k
        return gprs

    def log(self, locals_, **changes):
        """ update the set of tracked locals, then log the regex.

        each keyword argument names a local variable and gives either:

        * `None` -- stop tracking that variable
        * `(start, size)` -- the variable occupies `size` GPRs from `start`
        * `start` -- shorthand for `(start, 1)`

        `start` is a GPR number, or a key into the `regs` dict given to
        the constructor. example:
        ```
        a = ...

        # pass a fresh `locals()` every time, python doesn't guarantee
        # that an old one stays up-to-date
        logger.log(locals(), a=(4, 6))  # `a` is in r4-r9

        a += 3
        logger.log(locals())  # `a` is still tracked

        b = a + 5
        logger.log(locals(), a=None, b=(4, 6))  # drop `a`, track `b`
        ```
        """
        self.__update_tracked(changes)

        # built even when disabled so that missing/invalid locals and
        # overlapping registers are still caught
        gprs = self.__expected_gprs(locals_)
        if not self.enabled:
            return

        parts = []
        for base in range(0, 128, 8):
            parts.append(f"reg +{base}")
            for entry in gprs[base:base + 8]:
                if entry is None:
                    parts.append(" +[0-9a-f]+")
                    continue
                word = entry[0]
                parts.append(f" +{word:08x}")
            parts.append("\\n")
        log("DIVMOD REGEX:", "".join(parts))
256
257
def python_divmod_shift_sub_algorithm(n, d, width=256, log_regex=False):
    """ python model of the shift-and-subtract division in
    DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM.

    n: numerator, a non-negative integer less than `d << width`
    d: denominator, a positive integer
    width: number of quotient bits to produce (one loop pass per bit)
    log_regex: if true, log the expected register contents after each step

    returns the tuple `(q, r)` of quotient and remainder.
    raises AssertionError if the inputs are out of range.

    the local variable names and the order of the steps are tied to the
    `do_log(locals(), ...)` calls, which look variables up by name; the
    register numbers passed to them are the ones the assembly uses.
    """
    assert n >= 0 and d > 0 and width > 0 and n < (d << width), "invalid input"
    do_log = _DivModRegsRegexLogger(enabled=log_regex).log

    do_log(locals(), n=(4, 8), d=(32, 4))

    r = n
    do_log(locals(), n=None, r=(40, 8))

    # line the denominator up with the most-significant quotient bit
    shifted_d = d << (width - 1)
    do_log(locals(), d=None, shifted_d=(32, 8))

    q = 0
    do_log(locals(), q=(4, 4))

    for _ in range(width):
        # trial subtraction: if it doesn't go negative this quotient bit is 1
        diff = r - shifted_d
        borrowed = diff < 0
        do_log(locals(), diff=(48, 8))

        q <<= 1
        do_log(locals())

        if not borrowed:
            q |= 1
            do_log(locals())

            r = diff
            do_log(locals())

        # shift the remainder up rather than shifting shifted_d down
        r <<= 1
        do_log(locals())

    # undo the `width` left shifts applied to r by the loop
    r >>= width
    do_log(locals(), r=(8, 4))

    return q, r
295
296
def divmod2du(RA, RB, RC):
    # type: (int, int, int) -> tuple[int, int, bool]
    """ model of a 128-bit by 64-bit unsigned divide.

    the dividend is `RC << 64 | RA` and the divisor is `RB`.
    returns `(quotient, remainder, overflow)`. when `RB` is zero or
    `RC >= RB` the result is reported as an overflow: the quotient is
    then all-ones (2^64 - 1) and the remainder is zero.
    """
    if RB == 0 or RC >= RB:
        return (1 << 64) - 1, 0, True
    quotient, remainder = divmod(RC << 64 | RA, RB)
    return quotient, remainder, False
307
308
@plain_data()
class DivModKnuthAlgorithmD:
    """ multi-word unsigned division using Knuth's Algorithm D.

    `python` is a python model of the algorithm; `asm` is the
    (currently unfinished, see the FIXMEs) SVP64 assembly version.

    fields:

    * num_size -- number of words in the numerator
    * denom_size -- number of words in the denominator
    * q_size -- number of words in the quotient
    * word_size -- number of bits per word
    * regs -- dict mapping each variable's name to the GPR it starts at
    """
    __slots__ = "num_size", "denom_size", "q_size", "word_size", "regs"

    def __init__(self, num_size=8, denom_size=4, q_size=4,
                 word_size=64, regs=None):
        # type: (int, int, int | None, int, None | dict[str, int]) -> None
        """ `q_size=None` selects a quotient as long as the numerator;
        `regs=None` selects the default register allocation below.
        """
        assert num_size >= denom_size, \
            "the dividend's length must be >= the divisor's length"
        assert word_size > 0

        if q_size is None:
            # quotient length from original algorithm is m - n + 1,
            # but that assumes v[-1] != 0 -- since we support smaller divisors
            # the quotient must be larger.
            q_size = num_size

        if regs is None:
            # several names deliberately share a register: they are either
            # never live at the same time or are required to alias (the
            # asserts in __asm_iter check the latter)
            regs = {
                "n_0": 4,
                "d_0": 32,
                "u": 36,
                "m": 9,
                "v": 32,
                "n_scalar": 8,
                "q": 4,
                "vn": 32,
                "un": 36,
                "product": 46,
                "r": 8,
                "t_single": 8,
                "s_scalar": 10,
                "t_for_uv_shift": 0,
                "n_for_unnorm": 32,
                "t_for_unnorm": 3,
                "s_for_unnorm": 34,
                "qhat": 12,
                "rhat_lo": 14,
                "rhat_hi": 15,
                "t_for_prod": 18,
                "index": 3,
                "j": 11,
                "qhat_denom": 18,
                "qhat_num_hi": 16,
            }

        self.num_size = num_size
        self.denom_size = denom_size
        self.q_size = q_size
        self.word_size = word_size
        self.regs = regs

    @property
    def r_size(self):
        """ number of words in the remainder """
        return self.denom_size

    @property
    def un_size(self):
        """ number of words in the normalized dividend -- one more than
        the dividend, to hold the bits shifted out by normalization
        """
        return self.num_size + 1

    @property
    def vn_size(self):
        """ number of words in the normalized divisor """
        return self.denom_size

    @property
    def product_size(self):
        """ number of words reserved for the `qhat * vn` product """
        return self.num_size + 1

    def python(self, n, d, log_regex=False, on_corner_case=lambda desc: None):
        """ python model of the division.

        n: numerator, a sequence of `num_size` words, lsb word first
        d: denominator, a sequence of `denom_size` words, lsb word first
        log_regex: if true, log the expected register contents after
            each step
        on_corner_case: called with a short description string each time
            one of the algorithm's special cases is hit

        returns `(q, r)`: the quotient as a list of `q_size` words and
        the remainder as a list of `r_size` words, both lsb word first.
        raises ZeroDivisionError if every word of `d` is zero.

        the local variable names are tied to the `do_log(locals(), ...)`
        calls, which look the variables up by name.
        """
        do_log = _DivModRegsRegexLogger(enabled=log_regex, regs=self.regs).log

        do_log(locals(), n=("n_0", self.num_size), d=("d_0", self.denom_size))

        # switch to names used by Knuth's algorithm D
        u = list(n)  # dividend
        assert len(u) == self.num_size, "numerator has wrong size"
        do_log(locals(), n=None, u=("u", self.num_size))
        m = len(u)  # length of dividend
        do_log(locals(), m="m")
        v = list(d)  # divisor
        assert len(v) == self.denom_size, "denominator has wrong size"
        del d  # less confusing to debug
        do_log(locals(), d=None, v=("v", self.denom_size))
        n = len(v)  # length of divisor
        do_log(locals(), n="n_scalar")

        # allocate outputs/temporaries -- before any normalization so
        # the outputs/temporaries can be fixed-length in the assembly version.

        q = [0] * self.q_size  # quotient
        do_log(locals(), q=("q", self.q_size))
        vn = [None] * self.vn_size  # normalized divisor
        do_log(locals(), vn=("vn", self.vn_size))
        un = [None] * self.un_size  # normalized dividend
        do_log(locals(), un=("un", self.un_size))
        product = [None] * self.product_size
        do_log(locals(), product=("product", self.product_size))

        # get non-zero length of dividend
        while m > 0 and u[m - 1] == 0:
            m -= 1

        do_log(locals())

        # get non-zero length of divisor
        while n > 0 and v[n - 1] == 0:
            n -= 1

        do_log(locals())

        if n == 0:
            raise ZeroDivisionError

        if n == 1:
            on_corner_case("single-word divisor")
            # Knuth's algorithm D requires the divisor to have length >= 2
            # handle single-word divisors separately
            t = 0
            if m > self.q_size:
                # the word above the quotient's range becomes the initial
                # remainder of the word-by-word division below
                t = u[self.q_size]
                m = self.q_size
            do_log(locals(), t="t_single", n=None)
            do_log(locals(), m=None)  # VL = m, so we don't need it in a GPR
            for i in reversed(range(m)):
                q[i], t, _ = divmod2du(u[i], v[0], t)
                do_log(locals())
            r = [0] * self.r_size  # remainder
            r[0] = t
            do_log(locals(), t=None, r=("r", self.r_size))
            return q, r

        if m < n:
            r = [None] * self.r_size  # remainder
            do_log(locals(), r=("r", self.r_size), m=None, n=None)
            # dividend < divisor
            for i in range(self.r_size):
                r[i] = u[i]
                do_log(locals())
            return q, r

        # Knuth's algorithm D starts here:

        # Step D1: normalize

        # calculate amount to shift by -- count leading zeros
        s = 0
        index = n - 1
        do_log(locals(), index="index")
        while (v[index] << s) >> (self.word_size - 1) == 0:
            s += 1

        do_log(locals(), s="s_scalar", index=None)

        if s != 0:
            on_corner_case("non-zero shift")

        # vn = v << s
        t = 0
        do_log(locals(), t="t_for_uv_shift")
        for i in range(n):
            # dsld
            t |= v[i] << s
            v[i] = None  # mark reg as unused
            vn[i] = t % 2 ** self.word_size
            t >>= self.word_size
            do_log(locals())

        # un = u << s
        t = 0
        do_log(locals(), v=None)
        for i in range(m):
            # dsld
            t |= u[i] << s
            u[i] = None  # mark reg as unused
            un[i] = t % 2 ** self.word_size
            t >>= self.word_size
            do_log(locals())
        index = m
        do_log(locals(), index="index")
        un[index] = t

        do_log(locals(), u=None, t=None, index=None)

        # Step D2 and Step D7: loop
        for j in range(min(m - n, self.q_size - 1), -1, -1):
            do_log(locals(), j="j")
            # Step D3: calculate q̂

            index = j + n
            do_log(locals(), index="index")
            qhat_num_hi = un[index]
            do_log(locals(), qhat_num_hi="qhat_num_hi")
            index = n - 1
            do_log(locals())
            qhat_denom = vn[index]
            do_log(locals(), qhat_denom="qhat_denom")
            index = j + n - 1
            do_log(locals())
            qhat, rhat_lo, ov = divmod2du(un[index], qhat_denom, qhat_num_hi)
            rhat_hi = 0
            do_log(locals(), qhat="qhat", rhat_lo="rhat_lo", rhat_hi="rhat_hi")
            if ov:
                # division overflows word
                on_corner_case("qhat overflows word")
                assert qhat_num_hi == qhat_denom
                # divmod2du returned the saturated quotient, so compute
                # the matching two-word remainder by hand:
                # rhat = numerator - qhat * qhat_denom
                rhat_lo = (qhat * qhat_denom) % 2 ** self.word_size
                rhat_hi = (qhat * qhat_denom) >> self.word_size
                do_log(locals())
                borrow = un[index] < rhat_lo
                rhat_lo = (un[index] - rhat_lo) % 2 ** self.word_size
                do_log(locals())
                rhat_hi = qhat_num_hi - rhat_hi - borrow
            do_log(locals(), qhat_num_hi=None, qhat_denom=None)

            # once rhat no longer fits in one word the comparison below
            # can't succeed, so stop adjusting
            while rhat_hi == 0:
                if qhat * vn[n - 2] > (rhat_lo << self.word_size) + un[j + n - 2]:
                    on_corner_case("qhat adjustment")
                    qhat -= 1
                    do_log(locals())
                    carry = (rhat_lo + vn[n - 1]) >= 2 ** self.word_size
                    rhat_lo += vn[n - 1]
                    rhat_lo %= 2 ** self.word_size
                    do_log(locals())
                    rhat_hi = carry
                    do_log(locals())
                else:
                    break

            do_log(locals(), rhat_lo=None, rhat_hi=None, index=None)

            # Step D4: multiply and subtract

            t = 0
            do_log(locals(), t="t_for_prod")
            for i in range(n):
                # maddedu
                t += vn[i] * qhat
                product[i] = t % 2 ** self.word_size
                t >>= self.word_size
                do_log(locals())
            product[n] = t
            do_log(locals(), t=None)

            # subtract by adding the ones-complement with a carry-in of 1;
            # a final carry-out of 0 means the subtraction borrowed
            t = 1
            for i in range(n + 1):
                # subfe
                not_product = ~product[i] % 2 ** self.word_size
                t += not_product + un[j + i]
                un[j + i] = t % 2 ** self.word_size
                t = int(t >= 2 ** self.word_size)
                do_log(locals())
            need_fixup = not t

            # Step D5: test remainder

            if need_fixup:

                # Step D6: add back

                on_corner_case("add back")

                qhat -= 1
                do_log(locals())

                t = 0
                for i in range(n):
                    # adde
                    t += un[j + i] + vn[i]
                    un[j + i] = t % 2 ** self.word_size
                    t = int(t >= 2 ** self.word_size)
                    do_log(locals())
                # the carry out of the add-back cancels the borrow left by
                # Step D4, so the top word has to wrap around just like a
                # register would -- without the modulo it would be left
                # holding 2 ** word_size, which is not a valid word.
                un[j + n] = (un[j + n] + t) % 2 ** self.word_size
                do_log(locals())

            q[j] = qhat
            do_log(locals())

        # Step D8: un-normalize
        do_log(locals(), s="s_for_unnorm", vn=None, m=None, j=None)
        r = [0] * self.r_size  # remainder
        do_log(locals(), r=("r", self.r_size), n="n_for_unnorm")
        # r = un >> s
        t = 0
        do_log(locals(), t="t_for_unnorm")
        for i in reversed(range(n)):
            # dsrd
            t <<= self.word_size
            t |= (un[i] << self.word_size) >> s
            r[i] = t >> self.word_size
            t %= 2 ** self.word_size
            do_log(locals())

        return q, r

    def __asm_iter(self):
        """ generate the lines of the assembly version, one per `yield`.

        unfinished: only the setup, the single-word-divisor path, the
        dividend-shorter-than-divisor path and normalization are written,
        see the FIXMEs. some of the locals below are therefore not used
        yet. raises NotImplementedError unless `word_size` is 64.
        """
        if self.word_size != 64:
            raise NotImplementedError("only word_size == 64 is implemented")
        n_0 = self.regs["n_0"]
        d_0 = self.regs["d_0"]
        u = self.regs["u"]
        m = self.regs["m"]
        v = self.regs["v"]
        n_scalar = self.regs["n_scalar"]
        q = self.regs["q"]
        vn = self.regs["vn"]
        un = self.regs["un"]
        product = self.regs["product"]
        r = self.regs["r"]
        t_single = self.regs["t_single"]
        s_scalar = self.regs["s_scalar"]
        t_for_uv_shift = self.regs["t_for_uv_shift"]
        n_for_unnorm = self.regs["n_for_unnorm"]
        t_for_unnorm = self.regs["t_for_unnorm"]
        s_for_unnorm = self.regs["s_for_unnorm"]
        qhat = self.regs["qhat"]
        rhat_lo = self.regs["rhat_lo"]
        rhat_hi = self.regs["rhat_hi"]
        t_for_prod = self.regs["t_for_prod"]
        index = self.regs["index"]
        j = self.regs["j"]
        qhat_num_hi = self.regs["qhat_num_hi"]
        qhat_denom = self.regs["qhat_denom"]
        num_size = self.num_size
        denom_size = self.denom_size
        q_size = self.q_size
        r_size = self.r_size
        un_size = self.un_size
        vn_size = self.vn_size
        product_size = self.product_size

        yield "divmod_512_by_256:"
        # n in n_0 size num_size
        # d in d_0 size denom_size

        yield "mfspr 0, 8 # mflr 0"
        yield "std 0, 16(1)"  # save return address
        yield "setvl 0, 0, 18, 0, 1, 1"  # set VL to 18
        yield "sv.std *14, -144(1)"  # save all callee-save registers
        yield "stdu 1, -176(1)"  # create stack frame as required by ABI

        # switch to names used by Knuth's algorithm D
        yield f"setvl 0, 0, {num_size}, 0, 1, 1"  # set VL to num_size
        yield f"sv.or *{u}, *{n_0}, *{n_0}"  # u = n
        yield f"addi {m}, 0, {num_size}"  # m = len(u)
        assert v == d_0, "v and d_0 must be in the same regs"  # v = d
        yield f"addi {n_scalar}, 0, {denom_size}"  # n = len(v)

        # allocate outputs/temporaries
        yield f"setvl 0, 0, {q_size}, 0, 1, 1"  # set VL to q_size
        yield f"sv.addi *{q}, 0, 0"  # q = [0] * q_size

        # get non-zero length of dividend
        yield f"setvl 0, 0, {num_size}, 0, 1, 1"  # set VL to num_size
        # create SVSHAPE that reverses order
        svshape = SVSHAPE(0)
        svshape.zdimsz = num_size
        svshape.invxyz = SelectableInt(0b1, 3)  # invert Z
        svshape_low = int(svshape) % 2 ** 16
        svshape_high = int(svshape) >> 16
        SVSHAPE0 = SPRfull.SVSHAPE0.value
        yield f"addis 0, 0, {svshape_high}"
        yield f"ori 0, 0, {svshape_low}"
        yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
        yield f"svremap 0o01, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RA
        yield f"sv.cmpi/ff=ne *0, 1, *{u}, 0"
        yield f"setvl {m}, 0, 1, 0, 0, 0 # getvl {m}"  # m = VL
        yield f"subfic {m}, {m}, {num_size}"  # m = num_size - m

        # get non-zero length of divisor
        yield f"setvl 0, 0, {denom_size}, 0, 1, 1"  # set VL to denom_size
        # create SVSHAPE that reverses order
        svshape = SVSHAPE(0)
        svshape.zdimsz = denom_size
        svshape.invxyz = SelectableInt(0b1, 3)  # invert Z
        svshape_low = int(svshape) % 2 ** 16
        svshape_high = int(svshape) >> 16
        yield f"addis 0, 0, {svshape_high}"
        yield f"ori 0, 0, {svshape_low}"
        yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
        yield f"svremap 0o01, 0, 0, 0, 0, 0, 0"  # enable SVSHAPE0 for RA
        yield f"sv.cmpi/ff=ne *0, 1, *{v}, 0"
        yield f"setvl {n_scalar}, 0, 1, 0, 0, 0 # getvl {n_scalar}"  # n = VL
        # n = denom_size - n
        yield f"subfic {n_scalar}, {n_scalar}, {denom_size}"

        yield f"cmpi 0, 1, {n_scalar}, 1 # cmpdi {n_scalar}, 1"
        yield "bc 4, 2, divmod_skip_sw_divisor # bne divmod_skip_sw_divisor"

        # Knuth's algorithm D requires the divisor to have length >= 2
        # handle single-word divisors separately
        yield f"addi {t_single}, 0, 0"
        yield f"setvl. {m}, {m}, {q_size}, 0, 1, 1"  # m = VL = min(m, q_size)
        # if CR0.SO: t = u[q_size]
        yield f"sv.isel {t_single}, {u + q_size}, {t_single}, 3"
        # div loop
        yield f"sv.divmod2du/mrr *{q}, *{u}, {v}, {t_single}"
        # r[0] = t
        assert r == t_single, "r[0] and t_single must be in the same regs"
        yield f"setvl 0, 0, {r_size - 1}, 0, 1, 1"  # set VL to r_size - 1
        yield f"sv.addi *{r + 1}, 0, 0"  # r[1:] = [0] * (r_size - 1)

        yield "b divmod_return"

        yield "divmod_skip_sw_divisor:"
        yield f"cmp 0, 1, {m}, {n_scalar} # cmpd {m}, {n_scalar}"
        yield "bc 4, 0, divmod_skip_copy_r # bge divmod_skip_copy_r"
        # if m < n:

        yield f"setvl 0, 0, {r_size}, 0, 1, 1"  # set VL to r_size
        yield f"sv.or *{r}, *{u}, *{u}"  # r[...] = u[...]
        yield "b divmod_return"

        yield "divmod_skip_copy_r:"

        # Knuth's algorithm D starts here:

        # Step D1: normalize

        # calculate amount to shift by -- count leading zeros
        yield f"addi {index}, {n_scalar}, -1"  # index = n - 1
        assert index == 3, "index must be r3"
        yield f"setvl. 0, 0, {denom_size}, 0, 1, 1"  # VL = denom_size
        yield f"sv.cntlzd/m=1<<r3 {s_scalar}, *{v}"  # s = clz64(v[index])

        yield f"addi {t_for_uv_shift}, 0, 0"  # t = 0
        yield f"setvl. 0, {n_scalar}, {denom_size}, 0, 1, 1"  # VL = n
        # vn = v << s
        yield f"sv.dsld *{vn}, *{v}, {s_scalar}, {t_for_uv_shift}"

        yield f"addi {t_for_uv_shift}, 0, 0"  # t = 0
        yield f"setvl. 0, {m}, {num_size}, 0, 1, 1"  # VL = m
        # un = u << s
        yield f"sv.dsld *{un}, *{u}, {s_scalar}, {t_for_uv_shift}"
        yield f"setvl. 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
        yield f"or {index}, {m}, {m}"  # index = m
        assert index == 3, "index must be r3"
        # un[index] = t
        yield f"sv.or/m=1<<r3 *{un}, {t_for_uv_shift}, {t_for_uv_shift}"

        # Step D2 and Step D7: loop
        # j = m - n
        yield f"subf {j}, {n_scalar}, {m}"
        # j = min(j, q_size - 1)
        yield f"addi 0, 0, {q_size - 1}"
        yield f"minmax {j}, {j}, 0, 0 # maxd {j}, {j}, 0"
        yield f"divmod_loop:"

        # Step D3: calculate q̂
        yield f"setvl. 0, 0, {un_size}, 0, 1, 1"  # VL = un_size
        # FIXME: finish

        # Step D2 and Step D7: loop
        yield f"addic. {j}, {j}, -1"  # j -= 1
        yield f"bc 4, 0, divmod_loop # bge divmod_loop"

        # FIXME: finish

        yield "divmod_return:"
        yield "addi 1, 1, 176"  # teardown stack frame
        yield "ld 0, 16(1)"
        yield "mtspr 8, 0 # mtlr 0"  # restore return address
        yield "setvl 0, 0, 18, 0, 1, 1"  # set VL to 18
        yield "sv.ld *14, -144(1)"  # restore all callee-save registers
        yield "bclr 20, 0, 0 # blr"

    @cached_property
    def asm(self):
        """ the assembly version, as a tuple of lines """
        return tuple(self.__asm_iter())
777
778
# SVP64 assembly routine `powmod_256`: modular exponentiation by
# square-and-multiply, one pass of the loop per exponent bit starting from
# the least-significant bit. the result is left in r4-7.
# LR and r14-31 are saved on the stack on entry and restored before
# returning; base, exp, mod, the result and the loop counter are kept in
# those registers across the calls to the multiply and divmod routines,
# whose code is appended at the end of this tuple.
# python_powmod_256_algorithm is the python model of this routine.
POWMOD_256_ASM = (
    # base is in r4-7, exp is in r8-11, mod is in r32-35
    "powmod_256:",
    "mfspr 0, 8 # mflr 0",
    "std 0, 16(1)",  # save return address
    "setvl 0, 0, 18, 0, 1, 1",  # set VL to 18
    "sv.std *14, -144(1)",  # save all callee-save registers
    "stdu 1, -176(1)",  # create stack frame as required by ABI

    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "sv.or *16, *4, *4",  # move base to r16-19
    "sv.or *20, *8, *8",  # move exp to r20-23
    "sv.or *24, *32, *32",  # move mod to r24-27
    "sv.addi *28, 0, 0",  # retval in r28-31
    "addi 28, 0, 1",  # retval = 1

    "addi 14, 0, 256",  # ctr in r14

    "powmod_256_loop:",
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 3, 0, 1 # li 3, 1",  # shift amount
    "addi 0, 0, 0 # li 0, 0",  # dsrd carry
    "sv.dsrd/mrr *20, *20, 3, 0",  # exp >>= 1; shifted out bit in r0
    "cmpli 0, 1, 0, 0 # cmpldi 0, 0",
    "bc 12, 2, powmod_256_else # beq powmod_256_else",  # if lsb:

    "sv.or *4, *28, *28",  # copy retval to r4-7
    "sv.or *8, *16, *16",  # copy base to r8-11
    "bl mul_256_to_512",  # prod = retval * base
    # prod in r4-11

    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "sv.or *32, *24, *24",  # copy mod to r32-35

    "bl divmod_512_by_256",  # prod % mod
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "sv.or *28, *8, *8",  # retval = prod % mod

    "powmod_256_else:",

    # the squaring below runs on every pass, whether or not the bit was set
    "sv.or *4, *16, *16",  # copy base to r4-7
    "sv.or *8, *16, *16",  # copy base to r8-11
    "bl mul_256_to_512",  # prod = base * base
    # prod in r4-11

    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "sv.or *32, *24, *24",  # copy mod to r32-35

    "bl divmod_512_by_256",  # prod % mod
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "sv.or *16, *8, *8",  # base = prod % mod

    "addic. 14, 14, -1",  # decrement ctr and compare against zero
    "bc 4, 2, powmod_256_loop # bne powmod_256_loop",

    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "sv.or *4, *28, *28",  # move retval to r4-7

    "addi 1, 1, 176",  # teardown stack frame
    "ld 0, 16(1)",
    "mtspr 8, 0 # mtlr 0",  # restore return address
    "setvl 0, 0, 18, 0, 1, 1",  # set VL to 18
    "sv.ld *14, -144(1)",  # restore all callee-save registers
    "bclr 20, 0, 0 # blr",
    # the subroutines called with `bl` above
    *MUL_256_X_256_TO_512_ASM,
    *DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM,
)
846
847
def python_powmod_256_algorithm(base, exp, mod):
    """ python model of POWMOD_256_ASM: square-and-multiply modular
    exponentiation over the low 256 bits of `exp`, lsb first.

    returns `base ** exp % mod` for `exp` below 2^256 -- except that the
    accumulator starts at 1 and is only reduced when a multiply happens,
    so `exp == 0` gives 1 even when `mod == 1`.
    """
    result = 1
    square = base  # base ** (2 ** bit_index) % mod
    remaining = exp
    for _ in range(256):
        bit_is_set = bool(remaining & 1)
        remaining >>= 1
        if bit_is_set:
            result = (result * square) % mod
        square = (square * square) % mod
    return result
859
860
class PowModCases(TestAccumulatorBase):
    """ test cases for the bigint multiply, divmod and powmod assembly
    routines in this file.

    each case loads its inputs into GPRs as 64-bit words, least-significant
    word first, calls the routine, and checks the output registers.
    """

    def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
        """ add a test case that runs `instructions` as a subroutine.

        LR is pointed at an address that the run stops at, so the
        routine's final `blr` ends the test. the stack pointer (r1) is set
        in both `initial_regs` and `expected`, and LR is excluded from the
        comparison.
        """
        stop_at_pc = 0x10000000
        sprs = {8: stop_at_pc}
        expected.intregs[1] = initial_regs[1] = 0x1000000  # set stack pointer
        expected.pc = stop_at_pc
        expected.sprs['LR'] = None
        self.add_case(assemble(instructions),
                      initial_regs, initial_sprs=sprs,
                      stop_at_pc=stop_at_pc, expected=expected,
                      src_loc_at=src_loc_at + 1)

    def case_mul_256_x_256_to_512(self):
        """ check MUL_256_X_256_TO_512_ASM against python's `*` """
        for i in range(10):
            a = hash_256(f"mul256 input a {i}")
            b = hash_256(f"mul256 input b {i}")
            if i == 0:
                # use known values:
                a = b = 2**256 - 1
            elif i == 1:
                # use known values:
                a = b = (2**256 - 1) // 0xFF
            y = a * b
            with self.subTest(a=f"{a:#_x}", b=f"{b:#_x}", y=f"{y:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for word in range(4):
                    # write a in LE order to regs 4-7
                    initial_regs[4 + word] = (a >> (64 * word)) % 2**64
                    # write b in LE order to regs 8-11
                    initial_regs[8 + word] = (b >> (64 * word)) % 2**64
                # only check regs up to r11 since that's where the output is
                e = ExpectedState(int_regs=initial_regs[:12])
                for word in range(8):
                    # write y in LE order to regs 4-11
                    e.intregs[4 + word] = (y >> (64 * word)) % 2**64

                self.call_case(MUL_256_X_256_TO_512_ASM, e, initial_regs)

    @staticmethod
    def divmod_512x256_to_256x256_test_inputs():
        """ yield `(n, d)` pairs to divide: hand-picked corner cases
        followed by pseudo-random ones. every pair satisfies
        `n < d << 256`, so the quotient fits in 256 bits.
        """
        yield (2 ** (256 - 1), 1)
        yield (2 ** (512 - 1) - 1, 2 ** 256 - 1)

        # test division by single word
        yield (((1 << 256) - 1) << 32, 1 << 32)
        yield (((1 << 192) - 1) << 32, 1 << 32)
        yield (((1 << 64) - 1) << 32, 1 << 32)
        yield (1 << 32, 1 << 32)

        # test qhat overflow
        yield (0x8000 << 128 | 0xFFFE << 64, 0x8000 << 64 | 0xFFFF)

        # tests where add back is required
        yield (8 << (192 - 4) | 3, 2 << (192 - 4) | 1)
        yield (0x8000 << 128 | 3, 0x2000 << 128 | 1)
        yield (0x7FFF << 192 | 0x8000 << 128, 0x8000 << 128 | 1)

        for i in range(20):
            n = hash_256(f"divmod256 input n msb {i}")
            n <<= 256
            n |= hash_256(f"divmod256 input n lsb {i}")
            n_shift = hash_256(f"divmod256 input n shift {i}") % 512
            n >>= n_shift
            d = hash_256(f"divmod256 input d {i}")
            d_shift = hash_256(f"divmod256 input d shift {i}") % 256
            d >>= d_shift
            if d == 0:
                d = 1
            # keep the quotient within 256 bits
            n %= d << 256
            yield (n, d)

    def case_divmod_shift_sub_512x256_to_256x256(self):
        """ check DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM against python's
        `divmod`
        """
        cases = list(self.divmod_512x256_to_256x256_test_inputs())
        del cases[2:-1]  # speed up tests by removing most test cases
        for n, d in cases:
            q, r = divmod(n, d)
            with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",
                              q=f"{q:#_x}", r=f"{r:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for i in range(8):
                    # write n in LE order to regs 4-11
                    initial_regs[4 + i] = (n >> (64 * i)) % 2**64
                for i in range(4):
                    # write d in LE order to regs 32-35
                    initial_regs[32 + i] = (d >> (64 * i)) % 2**64
                # only check regs up to r11 since that's where the output is.
                # don't check CR
                e = ExpectedState(int_regs=initial_regs[:12], crregs=0)
                e.intregs[0] = 0  # leftovers -- ignore
                e.intregs[3] = 1  # leftovers -- ignore
                e.ca = None  # ignored
                for i in range(4):
                    # write q in LE order to regs 4-7
                    e.intregs[4 + i] = (q >> (64 * i)) % 2**64
                    # write r in LE order to regs 8-11
                    e.intregs[8 + i] = (r >> (64 * i)) % 2**64

                self.call_case(
                    DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM, e, initial_regs)

    def case_divmod_knuth_algorithm_d_512x256_to_256x256(self):
        """ check the implemented parts of DivModKnuthAlgorithmD's assembly
        against python's `divmod`
        """
        cases = list(self.divmod_512x256_to_256x256_test_inputs())
        asm = DivModKnuthAlgorithmD().asm
        for n, d in cases:
            # multi-word divisors need the main loop, which isn't written
            skip = d >= 2 ** 64
            # ...unless the dividend has fewer words than the divisor, which
            # is handled before the main loop. `n << 64 < d` guarantees that.
            # (this used to compare against `n`, which can never be true.)
            if n << 64 < d:
                skip = False
            if skip:
                # FIXME: only part of the algorithm is implemented,
                # so we skip the cases that we expect to fail
                continue
            q, r = divmod(n, d)
            with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",
                              q=f"{q:#_x}", r=f"{r:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for i in range(8):
                    # write n in LE order to regs 4-11
                    initial_regs[4 + i] = (n >> (64 * i)) % 2**64
                for i in range(4):
                    # write d in LE order to regs 32-35
                    initial_regs[32 + i] = (d >> (64 * i)) % 2**64
                # only check regs up to r11 since that's where the output is.
                # don't check CR
                e = ExpectedState(int_regs=initial_regs[:12], crregs=0)
                e.intregs[0] = None  # ignored
                e.intregs[3] = None  # ignored
                e.ca = None  # ignored
                e.sprs['SVSHAPE0'] = None
                for i in range(4):
                    # write q in LE order to regs 4-7
                    e.intregs[4 + i] = (q >> (64 * i)) % 2**64
                    # write r in LE order to regs 8-11
                    e.intregs[8 + i] = (r >> (64 * i)) % 2**64

                self.call_case(asm, e, initial_regs)

    @staticmethod
    def powmod_256_test_inputs():
        """ yield `(base, exp, mod)` triples: one hand-picked, the rest
        pseudo-random. `mod` is never zero and `base` is reduced modulo
        `mod`.
        """
        for i in range(3):
            base = hash_256(f"powmod256 input base {i}")
            exp = hash_256(f"powmod256 input exp {i}")
            mod = hash_256(f"powmod256 input mod {i}")
            if i == 0:
                base = 2
                exp = 2 ** 256 - 1
                mod = 2 ** 256 - 189  # largest prime less than 2 ** 256
            if mod == 0:
                mod = 1
            base %= mod
            yield (base, exp, mod)

    @skip_case("FIXME: divmod is too slow to test powmod")
    def case_powmod_256(self):
        """ check POWMOD_256_ASM against python's `pow` """
        for base, exp, mod in PowModCases.powmod_256_test_inputs():
            expected = pow(base, exp, mod)
            with self.subTest(base=f"{base:#_x}", exp=f"{exp:#_x}",
                              mod=f"{mod:#_x}", expected=f"{expected:#_x}"):
                # registers start filled with junk
                initial_regs = [0xABCDEF] * 128
                for i in range(4):
                    # write base in LE order to regs 4-7
                    initial_regs[4 + i] = (base >> (64 * i)) % 2**64
                for i in range(4):
                    # write exp in LE order to regs 8-11
                    initial_regs[8 + i] = (exp >> (64 * i)) % 2**64
                for i in range(4):
                    # write mod in LE order to regs 32-35
                    initial_regs[32 + i] = (mod >> (64 * i)) % 2**64
                # only check regs up to r7 since that's where the output is.
                # don't check CR
                e = ExpectedState(int_regs=initial_regs[:8], crregs=0)
                e.ca = None  # ignored
                for i in range(4):
                    # write output in LE order to regs 4-7
                    e.intregs[4 + i] = (expected >> (64 * i)) % 2**64

                self.call_case(POWMOD_256_ASM, e, initial_regs)
1041
1042
1043 # for running "quick" simple investigations
if __name__ == "__main__":
    import random

    # sanity check: python_mul_algorithm against a product worked out by
    # hand (99999999 squared, in base-100 digits, lsb digit first)
    a = b = (99, 99, 99, 99)
    expected = [1, 0, 0, 0, 98, 99, 99, 99]
    assert python_mul_algorithm(a, b) == expected

    # then check python_mul_algorithm2 *against* python_mul_algorithm
    random.seed(0)  # reproducible values
    for i in range(10000):
        # digits are drawn alternately for a and b (a[0], b[0], a[1], ...)
        digits = [random.randint(0, 99) for _ in range(8)]
        a = digits[0::2]
        b = digits[1::2]
        expected = python_mul_algorithm(a, b)
        testing = python_mul_algorithm2(a, b)
        report = "%+17s * %-17s = %s\n" % (repr(a), repr(b), repr(expected))
        report += " (%s)" % repr(testing)
        print(report)
        assert expected == testing
1063 print(report)
1064 assert expected == testing