reenable tests
[openpower-isa.git] / src / openpower / test / bigint / powmod.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
3 # Copyright 2023 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
7 #
8 # * https://bugs.libre-soc.org/show_bug.cgi?id=1044
9
10 """ modular exponentiation (`pow(x, y, z)`)
11
12 related bugs:
13
14 * https://bugs.libre-soc.org/show_bug.cgi?id=1044
15 """
16
17 from openpower.test.common import TestAccumulatorBase, skip_case
18 from openpower.test.state import ExpectedState
19 from openpower.test.util import assemble
20 from nmutil.sim_util import hash_256
21 from openpower.util import log, LogType
22 from nmutil.plain_data import plain_data
23 from cached_property import cached_property
24 from openpower.decoder.isa.svshape import SVSHAPE
25 from openpower.decoder.power_enums import SPRfull
26 from openpower.decoder.selectable_int import SelectableInt
27
28
# Assembly for a 256-bit x 256-bit -> 512-bit multiply, written as four
# 64-bit-word partial products that are accumulated into the result.
# `python_mul_algorithm` is the base-100 model of this same sequence; the
# y/t names used in the comments below are the ones from that model.
MUL_256_X_256_TO_512_ASM = (
    "mul_256_to_512:",
    # a is in r4-7, b is in r8-11
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "sv.or *32, *4, *4",  # move args to r32-39
    # a is now in r32-35, b is in r36-39, y is in r4-11, t is in r40-44
    "sv.addi *4, 0, 0",  # clear output
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # r8 (y[4]) is the carry word here, like y[4] in python_mul_algorithm
    "sv.maddedu *4, *32, 36, 8",  # first partial-product a * b[0]
    "sv.addi 44, 0, 0",  # t[4] = 0: clear the carry word
    "sv.maddedu *40, *32, 37, 44",  # second partial-product a * b[1]
    "sv.addc 5, 5, 40",  # y[1] += t[0], starting the carry chain
    "sv.adde *6, *6, *41",  # y[2:6] += t[1:5] with carry
    "sv.addi 44, 0, 0",  # t[4] = 0: clear the carry word
    "sv.maddedu *40, *32, 38, 44",  # third partial-product a * b[2]
    "sv.addc 6, 6, 40",  # y[2] += t[0], starting the carry chain
    "sv.adde *7, *7, *41",  # y[3:7] += t[1:5] with carry
    "sv.addi 44, 0, 0",  # t[4] = 0: clear the carry word
    "sv.maddedu *40, *32, 39, 44",  # final partial-product a * b[3]
    "sv.addc 7, 7, 40",  # y[3] += t[0], starting the carry chain
    "sv.adde *8, *8, *41",  # y[4:8] += t[1:5] with carry
    # the 512-bit result y is left in r4-11
    "bclr 20, 0, 0 # blr",
)
52
# TODO: these helpers really belong in a common util file, see
# openpower/decoder/isa/poly1305-donna.py:def _DSRD(lo, hi, sh)
# (the versions below work in base 100 rather than 2^64, but the idea is
# the same)
56
57
def maddedu(a, b, c):
    """ base-100 model of multiply-add: returns `(low digit, high digits)`
    of `a * b + c`, i.e. the result "word" followed by the carry "word".
    """
    carry, digit = divmod(a * b + c, 100)
    return digit, carry
61
62
def adde(a, b, c):
    """ base-100 model of add-with-carry-in: returns `(digit, carry)` of
    `a + b + c` where `c` is the incoming carry.
    """
    carry, digit = divmod(a + b + c, 100)
    return digit, carry
66
67
def addc(a, b):
    """ base-100 model of add-producing-carry: returns `(digit, carry)` of
    `a + b`.
    """
    carry, digit = divmod(a + b, 100)
    return digit, carry
71
72
def python_mul_algorithm(a, b):
    """ model of the MUL_256_X_256_TO_512_ASM algorithm, working in base 100
    rather than 2^64 since that's easier to read.

    `a` and `b` are 4-digit numbers given as lists, least-significant digit
    first; the 8-digit product is returned in the same form.
    run this file in a debugger to see all the intermediate values.
    """
    y = [0] * 8
    t = [0] * 5
    # the first partial product is written straight into y, with y[4]
    # doubling as the carry digit
    for i in range(4):
        y[i], y[4] = maddedu(a[0], b[i], y[4])
    # each remaining partial product is built in t (t[4] is its carry
    # digit) and then added into y one digit further up than the last
    for row in range(1, 4):
        t[4] = 0
        for i in range(4):
            t[i], t[4] = maddedu(a[row], b[i], t[4])
        y[row], ca = addc(y[row], t[0])
        upper = row + 1
        for i in range(4):
            y[upper + i], ca = adde(y[upper + i], t[1 + i], ca)
    return y
100
101
def python_mul_algorithm2(a, b):
    """ version 2 of the MUL_256_X_256_TO_512_ASM algorithm, working in base
    100 rather than 2^64 since that's easier to read.

    the idea here is that it will "morph" into something more akin to
    using REMAP bigmul (first using REMAP Indexed).  the index schedule
    that step will need is not consumed by anything yet, so it is not
    built here.

    `a` and `b` are 4-digit numbers given as lists, least-significant digit
    first; the 8-digit product is returned in the same form.
    """
    y = [0] * 8  # result y and temp t of same size
    t = [0] * 8  # no need after this to set t[4] to zero
    for iy in range(4):
        for i in range(4):  # use t[iy+4] as a 64-bit carry
            t[iy+i], t[iy+4] = maddedu(a[iy], b[i], t[iy+4])
        ca = 0
        for i in range(5):  # add vec t to y with 1-bit carry
            idx = iy + i
            y[idx], ca = adde(y[idx], t[idx], ca)
    return y
129
130
# Assembly for dividing a 512-bit n by a 256-bit d, giving a 256-bit
# quotient and a 256-bit remainder.  `python_divmod_shift_sub_algorithm`
# is the Python model of this same loop and uses the same variable names
# (r, shifted_d, q, diff) as the comments below.
DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM = (
    # extremely slow and simplistic shift and subtract algorithm.
    # a future task is to rewrite to use Knuth's Algorithm D,
    # which is generally an order of magnitude faster
    "divmod_512_by_256:",
    # n is in r4-11, d is in r32-35
    "addi 3, 0, 256 # li 3, 256",
    "mtspr 9, 3 # mtctr 3",  # set CTR to 256
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    # r is in r40-47
    "sv.or *40, *4, *4",  # assign n to r, in r40-47
    # shifted_d is in r32-39
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 3, 0, 1 # li 3, 1",  # shift amount
    "addi 0, 0, 0 # li 0, 0",  # dsrd carry
    "sv.dsrd/mrr *36, *32, 3, 0",  # shifted_d = d << (256 - 1)
    "sv.addi *32, 0, 0",  # clear lsb half
    "sv.or 35, 0, 0",  # move carry to correct location
    # q is in r4-7
    "sv.addi *4, 0, 0",  # clear q
    # loop body, run once per quotient bit (CTR was set to 256 above)
    "divmod_loop:",
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "subfc 0, 0, 0",  # set CA
    # diff is in r48-55
    "sv.subfe *48, *32, *40",  # diff = r - shifted_d
    # not borrowed is in CA
    "mcrxrx 0",  # move CA to CR0.eq
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *4, *4, 3, 0",  # q <<= 1 (1 is in r3)
    "setvl 0, 0, 8, 0, 1, 1",  # set VL to 8
    "bc 4, 2, divmod_else # bne divmod_else",  # if borrowed goto divmod_else
    "ori 4, 4, 1",  # q |= 1
    "sv.or *40, *48, *48",  # r = diff
    "divmod_else:",
    "addi 0, 0, 0 # li 0, 0",  # dsld carry
    "sv.dsld *40, *40, 3, 0",  # r <<= 1 (1 is in r3)
    "bc 16, 0, divmod_loop # bdnz divmod_loop",
    "setvl 0, 0, 4, 0, 1, 1",  # set VL to 4
    # r is in r40-47
    "sv.or *8, *44, *44",  # r >>= 256
    # q is in r4-7, r is in r8-11
    "bclr 20, 0, 0 # blr",
)
175
176
177 class _DivModRegsRegexLogger:
178 """ logger that logs a regex that matches the expected register dump for
179 the currently tracked `locals` -- quite useful for debugging
180 """
181
182 def __init__(self, enabled=True, regs=None):
183 self.__tracked = {}
184 self.__regs = regs if regs is not None else {}
185 self.enabled = enabled
186
187 def log(self, locals_, label=None, **changes):
188 """ use like so:
189 ```
190 # create a variable `a`:
191 a = ...
192
193 # we invoke `locals()` each time since python doesn't guarantee
194 # it's up-to-date otherwise
195 logger.log(locals(), a=(4, 6)) # `a` starts at r4 and uses 6 registers
196
197 a += 3
198
199 logger.log(locals()) # keeps using `a`
200
201 b = a + 5
202
203 logger.log(locals(), a=None, b=(4, 6)) # remove `a` and add `b`
204 ```
205 """
206
207 for k, v in changes.items():
208 if v is None:
209 self.__tracked.pop(k, None)
210 else:
211 if isinstance(v, (tuple, list)):
212 start_gpr, size = v
213 else:
214 start_gpr = v
215 size = 1
216 if not isinstance(start_gpr, int):
217 start_gpr = self.__regs[start_gpr]
218 self.__tracked[k] = start_gpr, size
219
220 gprs = [None] * 128
221 for name, (start_gpr, size) in self.__tracked.items():
222 value = locals_[name]
223 if value is None:
224 continue
225 elif not isinstance(value, (list, tuple)):
226 value = [(value >> 64 * i) % 2 ** 64 for i in range(size)]
227 else:
228 assert len(value) == size, "value has wrong len"
229 for i in range(size):
230 if value[i] is None:
231 continue
232 reg = start_gpr + i
233 if gprs[reg] is not None:
234 other_value, other_name, other_i = gprs[reg]
235 raise AssertionError(f"overlapping values at r{reg}: "
236 f"{name}[{i}] overlaps with "
237 f"{other_name}[{other_i}]")
238 gprs[reg] = value[i], name, i
239
240 if not self.enabled:
241 # after building `gprs` so we catch any missing/invalid locals
242 return
243
244 segments = []
245
246 for i in range(0, 128, 8):
247 segments.append(f"reg +{i}")
248 for value in gprs[i:i + 8]:
249 if value is None:
250 segments.append(" +[0-9a-f]+")
251 else:
252 value, name, i = value
253 segments.append(f" +{value:08x}")
254 segments.append("\\n")
255 prefix = "" if label is None else f"At: {label}\n"
256 log(prefix + "DIVMOD REGEX:", "".join(segments),
257 kind=LogType.OutputMatching)
258
259
def python_divmod_shift_sub_algorithm(n, d, width=256, log_regex=False):
    """ shift-and-subtract division: returns `(q, r)`, the quotient and
    remainder of dividing the integer `n` by the integer `d`.

    `width` is the number of quotient bits produced; `n` must be
    non-negative and less than `d << width`, and `d` must be positive.
    `log_regex` enables logging of the expected-register-dump regexes.

    NOTE: the `do_log` calls look variables up by name through `locals()`,
    so the local variable names and the statement order here must not be
    changed without also changing the `do_log` calls.
    """
    assert n >= 0 and d > 0 and width > 0 and n < (d << width), "invalid input"
    do_log = _DivModRegsRegexLogger(enabled=log_regex).log

    # the (reg, count) pairs below are the GPR locations each variable is
    # expected to occupy at that point
    do_log(locals(), n=(4, 8), d=(32, 4))

    r = n
    do_log(locals(), n=None, r=(40, 8))

    # line the divisor up with the top `width` bits of r
    shifted_d = d << (width - 1)
    do_log(locals(), d=None, shifted_d=(32, 8))

    q = 0
    do_log(locals(), q=(4, 4))

    # one iteration per quotient bit, most-significant bit first
    for _ in range(width):
        diff = r - shifted_d
        borrowed = diff < 0
        do_log(locals(), diff=(48, 8))

        q <<= 1
        do_log(locals())

        if not borrowed:
            # the subtraction fit: this quotient bit is 1 and r is reduced
            q |= 1
            do_log(locals())

            r = diff
            do_log(locals())

        r <<= 1
        do_log(locals())

    # r was shifted left once per iteration, undo that
    r >>= width
    do_log(locals(), r=(8, 4))

    return q, r
297
298
def divmod2du(RA, RB, RC):
    # type: (int, int, int) -> tuple[int, int, bool]
    """ model of a 128-bit by 64-bit unsigned divide.

    divides the two-word value `RC << 64 | RA` by `RB` and returns
    `(quotient, remainder, overflow)`.  when `RB` is zero or `RC >= RB`
    (so the quotient can't fit in one word) `overflow` is `True`, the
    quotient is all ones and the remainder is zero.
    """
    if RB == 0 or RC >= RB:
        return (1 << 64) - 1, 0, True
    dividend = RC << 64 | RA
    quotient, remainder = divmod(dividend, RB)
    return quotient, remainder, False
309
310
311 @plain_data()
312 class DivModKnuthAlgorithmD:
313 __slots__ = "num_size", "denom_size", "q_size", "word_size", "regs"
314
315 def __init__(self, num_size=8, denom_size=4, q_size=4,
316 word_size=64, regs=None):
317 # type: (int, int, int | None, int, None | dict[str, int]) -> None
318 assert num_size >= denom_size, \
319 "the dividend's length must be >= the divisor's length"
320 assert word_size > 0
321
322 if q_size is None:
323 # quotient length from original algorithm is m - n + 1,
324 # but that assumes v[-1] != 0 -- since we support smaller divisors
325 # the quotient must be larger.
326 q_size = num_size
327
328 if regs is None:
329 regs = {
330 "n_0": 4,
331 "d_0": 32,
332 "u": 36,
333 "m": 9,
334 "v": 32,
335 "n_scalar": 8,
336 "q": 4,
337 "vn": 32,
338 "un": 36,
339 "product": 46,
340 "r": 8,
341 "t_single": 8,
342 "s_scalar": 10,
343 "t_for_uv_shift": 0,
344 "n_for_unnorm": 16,
345 "t_for_unnorm": 3,
346 "s_for_unnorm": 18,
347 "qhat": 12,
348 "rhat_lo": 14,
349 "rhat_hi": 15,
350 "t_for_prod": 18,
351 "index": 3,
352 "j": 11,
353 "qhat_denom": 18,
354 "qhat_num_hi": 16,
355 "qhat_prod_lo": 15,
356 "qhat_prod_hi": 18,
357 "sub_len": 3,
358 }
359
360 self.num_size = num_size
361 self.denom_size = denom_size
362 self.q_size = q_size
363 self.word_size = word_size
364 self.regs = regs
365
366 @property
367 def r_size(self):
368 return self.denom_size
369
370 @property
371 def un_size(self):
372 return self.num_size + 1
373
374 @property
375 def vn_size(self):
376 return self.denom_size
377
378 @property
379 def product_size(self):
380 return self.num_size + 1
381
382 def python(self, n, d, log_regex=False, on_corner_case=lambda desc: None):
383 # IMPORTANT: do_log calls match up with the expected register values
384 # in the assembly version at that point in the algorithm, please don't
385 # "simplify" all the seemingly-redundant local variable assignments,
386 # they match what actually happens in the assembly version.
387 do_log = _DivModRegsRegexLogger(enabled=log_regex, regs=self.regs).log
388
389 do_log(locals(), "start",
390 n=("n_0", self.num_size), d=("d_0", self.denom_size))
391
392 # switch to names used by Knuth's algorithm D
393 u = list(n) # dividend
394 assert len(u) == self.num_size, "numerator has wrong size"
395 do_log(locals(), "u = n", n=None, u=("u", self.num_size))
396 m = len(u) # length of dividend
397 do_log(locals(), "m = len(u)", m="m")
398 v = list(d) # divisor
399 assert len(v) == self.denom_size, "denominator has wrong size"
400 del d # less confusing to debug
401 do_log(locals(), "v = d", d=None, v=("v", self.denom_size))
402 n = len(v) # length of divisor
403 do_log(locals(), "n = len(v)", n="n_scalar")
404
405 # allocate outputs/temporaries -- before any normalization so
406 # the outputs/temporaries can be fixed-length in the assembly version.
407
408 q = [0] * self.q_size # quotient
409 do_log(locals(), "q = [0...]", q=("q", self.q_size))
410 vn = [None] * self.vn_size # normalized divisor
411 do_log(locals(), "vn = [None...]", vn=("vn", self.vn_size))
412 un = [None] * self.un_size # normalized dividend
413 do_log(locals(), "un = [None...]", un=("un", self.un_size))
414 product = [None] * self.product_size
415 do_log(locals(), "product = [None...]",
416 product=("product", self.product_size))
417
418 # get non-zero length of dividend
419 while m > 0 and u[m - 1] == 0:
420 m -= 1
421
422 do_log(locals(), "get non-zero dividend len")
423
424 # get non-zero length of divisor
425 while n > 0 and v[n - 1] == 0:
426 n -= 1
427
428 do_log(locals(), "get non-zero divisor len")
429
430 if n == 0:
431 raise ZeroDivisionError
432
433 if n == 1:
434 on_corner_case("single-word divisor")
435 # Knuth's algorithm D requires the divisor to have length >= 2
436 # handle single-word divisors separately
437 t = 0
438 if m > self.q_size:
439 t = u[self.q_size]
440 m = self.q_size
441 do_log(locals(), "t = u[q_size]", t="t_single", n=None)
442 # VL = m, so we don't need it in a GPR
443 do_log(locals(), "VL = m", m=None)
444 for i in reversed(range(m)):
445 q[i], t, _ = divmod2du(u[i], v[0], t)
446 do_log(locals(), "divide step")
447 r = [0] * self.r_size # remainder
448 r[0] = t
449 do_log(locals(), "finished single-word divisor",
450 t=None, r=("r", self.r_size))
451 return q, r
452
453 if m < n:
454 r = [None] * self.r_size # remainder
455 do_log(locals(), "m < n", r=("r", self.r_size), m=None, n=None)
456 # dividend < divisor
457 for i in range(self.r_size):
458 r[i] = u[i]
459 do_log(locals(), "finished m < n")
460 return q, r
461
462 # Knuth's algorithm D starts here:
463
464 # Step D1: normalize
465
466 # calculate amount to shift by -- count leading zeros
467 s = 0
468 index = n - 1
469 do_log(locals(), "index = n - 1", index="index")
470 while (v[index] << s) >> (self.word_size - 1) == 0:
471 s += 1
472
473 do_log(locals(), "s = clz64", s="s_scalar", index=None)
474
475 if s != 0:
476 on_corner_case("non-zero shift")
477
478 # vn = v << s
479 t = 0
480 do_log(locals(), "vn = v << s: t = 0", t="t_for_uv_shift")
481 for i in range(n):
482 # dsld
483 t |= v[i] << s
484 v[i] = None # mark reg as unused
485 vn[i] = t % 2 ** self.word_size
486 t >>= self.word_size
487 do_log(locals(), "vn = v << s: step")
488
489 # un = u << s
490 t = 0
491 do_log(locals(), "un = u << s: t = 0", v=None)
492 for i in range(m):
493 # dsld
494 t |= u[i] << s
495 u[i] = None # mark reg as unused
496 un[i] = t % 2 ** self.word_size
497 t >>= self.word_size
498 do_log(locals(), "un = u << s: step")
499 index = m
500 do_log(locals(), "un = u << s: index = m", index="index")
501 un[index] = t
502
503 do_log(locals(), "un = u << s: un[index] = t",
504 u=None, t=None, index=None)
505
506 # Step D2 and Step D7: loop
507 for j in range(min(m - n, self.q_size - 1), -1, -1):
508 do_log(locals(), "start of j loop", j="j")
509 # Step D3: calculate q̂
510
511 index = j + n
512 do_log(locals(), "qhat: index = j + n", index="index")
513 qhat_num_hi = un[index]
514 do_log(locals(), "qhat_num_hi = un[index]",
515 qhat_num_hi="qhat_num_hi")
516 index = n - 1
517 do_log(locals(), "qhat: index = n - 1")
518 qhat_denom = vn[index]
519 do_log(locals(), "qhat_denom = vn[index]",
520 qhat_denom="qhat_denom")
521 index = j + n - 1
522 do_log(locals(), "qhat: index = j + n - 1")
523 qhat, rhat_lo, ov = divmod2du(un[index], qhat_denom, qhat_num_hi)
524 rhat_hi = 0
525 do_log(locals(), "qhat: initial divmod2du",
526 qhat="qhat", rhat_lo="rhat_lo", rhat_hi="rhat_hi")
527 if ov:
528 # division overflows word
529 on_corner_case("qhat overflows word")
530 assert qhat_num_hi == qhat_denom
531 rhat_lo = (qhat * qhat_denom) % 2 ** self.word_size
532 rhat_hi = (qhat * qhat_denom) >> self.word_size
533 do_log(locals(), "qhat ov: rhat = qhat * qhat_denom")
534 borrow = un[index] < rhat_lo
535 rhat_lo = (un[index] - rhat_lo) % 2 ** self.word_size
536 do_log(locals(), "qhat ov: un[index] - rhat_lo")
537 rhat_hi = qhat_num_hi - rhat_hi - borrow
538 do_log(locals(), "qhat: after overflow check",
539 qhat_num_hi=None, qhat_denom=None)
540
541 while rhat_hi == 0:
542 index = n - 2
543 do_log(locals(), "qhat adj loop: index = n - 2")
544 qhat_prod_lo = (qhat * vn[index]) % 2 ** self.word_size
545 do_log(locals(), "qhat adj loop: prod_lo",
546 qhat_prod_lo="qhat_prod_lo", rhat_hi=None)
547 qhat_prod_hi = (qhat * vn[index]) >> self.word_size
548 do_log(locals(), "qhat adj loop: prod_hi",
549 qhat_prod_hi="qhat_prod_hi")
550 if qhat_prod_hi < rhat_lo:
551 break
552 index = j + n - 2
553 do_log(locals(), "qhat adj loop: index = j + n - 2")
554 if qhat_prod_hi == rhat_lo:
555 if qhat_prod_lo <= un[index]:
556 break
557 on_corner_case("qhat adjustment")
558 do_log(locals(), "qhat adj loop: adj needed", index=None,
559 qhat_prod_lo=None, qhat_prod_hi=None)
560 qhat -= 1
561 do_log(locals(), "qhat adj loop: qhat -= 1", index="index")
562 index = n - 1
563 do_log(locals(), "qhat adj loop: index = n - 1")
564 carry = (rhat_lo + vn[index]) >= 2 ** self.word_size
565 rhat_lo = (rhat_lo + vn[index]) % 2 ** self.word_size
566 do_log(locals(), "qhat adj loop: rhat_lo += vn[index]")
567 rhat_hi = carry
568 do_log(locals(), "qhat adj loop: rhat_hi = CA",
569 rhat_hi="rhat_hi")
570
571 do_log(locals(), "computed qhat", rhat_lo=None, rhat_hi=None,
572 index=None, qhat_prod_lo=None, qhat_prod_hi=None)
573
574 # Step D4: multiply and subtract
575
576 t = 0
577 do_log(locals(), "product: t = 0", t="t_for_prod")
578 for i in range(n):
579 # maddedu
580 t += vn[i] * qhat
581 product[i] = t % 2 ** self.word_size
582 t >>= self.word_size
583 do_log(locals(), "product: step")
584 index = n
585 do_log(locals(), "product: index = n", index="index")
586 product[index] = t
587 do_log(locals(), "product[index] = t", t=None, index=None)
588
589 t = 1
590 do_log(locals(), "subtract: t = 1")
591 sub_len = n + 1
592 do_log(locals(), "subtract: sub_len = n + 1", sub_len="sub_len")
593 VL = sub_len
594 do_log(locals(), "VL = sub_len", sub_len=None)
595 for i in range(VL):
596 # subfe
597 not_product = ~product[i] % 2 ** self.word_size
598 t += not_product + un[j + i]
599 un[j + i] = t % 2 ** self.word_size
600 t = int(t >= 2 ** self.word_size)
601 do_log(locals(), "subtract: step")
602 need_fixup = not t
603
604 # Step D5: test remainder
605
606 if need_fixup:
607
608 # Step D6: add back
609
610 on_corner_case("add back")
611
612 qhat -= 1
613 do_log(locals(), "add back: qhat -= 1")
614
615 t = 0
616 for i in range(n):
617 # adde
618 t += un[j + i] + vn[i]
619 un[j + i] = t % 2 ** self.word_size
620 t = int(t >= 2 ** self.word_size)
621 do_log(locals(), "add back: step")
622 index = j + n
623 do_log(locals(), "add back: index = j + n", index="index")
624 un[index] += t
625 do_log(locals(), "add back: un[index] += t", index=None)
626
627 index = j
628 do_log(locals(), "assign q: index = j", index="index")
629 q[index] = qhat
630 do_log(locals(), "q[index] = qhat", index=None)
631
632 # Step D8: un-normalize
633
634 # move s and n
635 do_log(locals(), "un-normalize", s="s_for_unnorm", n="n_for_unnorm",
636 vn=None, m=None, j=None)
637
638 r = [0] * self.r_size # remainder
639 do_log(locals(), "un-normalize: r = [0...]", r=("r", self.r_size))
640 # r = un >> s
641 t = 0
642 do_log(locals(), "un-normalize: t = 0", t="t_for_unnorm")
643 for i in reversed(range(n)):
644 # dsrd
645 t <<= self.word_size
646 t |= (un[i] << self.word_size) >> s
647 r[i] = t >> self.word_size
648 t %= 2 ** self.word_size
649 do_log(locals(), "un-normalize: step")
650
651 do_log(locals(), "finished un-normalize")
652
653 return q, r
654
655 def __asm_iter(self):
656 # IMPORTANT: the assembly matches up with the python version, if you
657 # make any changes, change the python version to match.
658 if self.word_size != 64:
659 raise NotImplementedError("only word_size == 64 is implemented")
660 n_0 = self.regs["n_0"]
661 d_0 = self.regs["d_0"]
662 u = self.regs["u"]
663 m = self.regs["m"]
664 v = self.regs["v"]
665 n_scalar = self.regs["n_scalar"]
666 q = self.regs["q"]
667 vn = self.regs["vn"]
668 un = self.regs["un"]
669 product = self.regs["product"]
670 r = self.regs["r"]
671 t_single = self.regs["t_single"]
672 s_scalar = self.regs["s_scalar"]
673 t_for_uv_shift = self.regs["t_for_uv_shift"]
674 n_for_unnorm = self.regs["n_for_unnorm"]
675 t_for_unnorm = self.regs["t_for_unnorm"]
676 s_for_unnorm = self.regs["s_for_unnorm"]
677 qhat = self.regs["qhat"]
678 rhat_lo = self.regs["rhat_lo"]
679 rhat_hi = self.regs["rhat_hi"]
680 t_for_prod = self.regs["t_for_prod"]
681 index = self.regs["index"]
682 j = self.regs["j"]
683 qhat_num_hi = self.regs["qhat_num_hi"]
684 qhat_denom = self.regs["qhat_denom"]
685 qhat_prod_lo = self.regs["qhat_prod_lo"]
686 qhat_prod_hi = self.regs["qhat_prod_hi"]
687 sub_len = self.regs["sub_len"]
688 num_size = self.num_size
689 denom_size = self.denom_size
690 q_size = self.q_size
691 r_size = self.r_size
692 un_size = self.un_size
693 vn_size = self.vn_size
694 product_size = self.product_size
695
696 yield "divmod_512_by_256:"
697 # n in n_0 size num_size
698 # d in d_0 size denom_size
699
700 yield "mfspr 0, 8 # mflr 0"
701 yield "std 0, 16(1)" # save return address
702 yield "setvl 0, 0, 18, 0, 1, 1" # set VL to 18
703 yield "sv.std *14, -144(1)" # save all callee-save registers
704 yield "stdu 1, -176(1)" # create stack frame as required by ABI
705
706 # switch to names used by Knuth's algorithm D
707 yield f"setvl 0, 0, {num_size}, 0, 1, 1" # set VL to num_size
708 yield f"sv.or *{u}, *{n_0}, *{n_0}" # u = n
709 yield f"addi {m}, 0, {num_size}" # m = len(u)
710 assert v == d_0, "v and d_0 must be in the same regs" # v = d
711 yield f"addi {n_scalar}, 0, {denom_size}" # n = len(v)
712
713 # allocate outputs/temporaries
714 yield f"setvl 0, 0, {q_size}, 0, 1, 1" # set VL to q_size
715 yield f"sv.addi *{q}, 0, 0" # q = [0] * q_size
716
717 # get non-zero length of dividend
718 yield f"setvl 0, 0, {num_size}, 0, 1, 1" # set VL to num_size
719 # create SVSHAPE that reverses order
720 svshape = SVSHAPE(0)
721 svshape.zdimsz = num_size
722 svshape.invxyz = SelectableInt(0b1, 3) # invert Z
723 svshape_low = int(svshape) % 2 ** 16
724 svshape_high = int(svshape) >> 16
725 SVSHAPE0 = SPRfull.SVSHAPE0.value
726 yield f"addis 0, 0, {svshape_high}"
727 yield f"ori 0, 0, {svshape_low}"
728 yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
729 yield f"svremap 0o01, 0, 0, 0, 0, 0, 0" # enable SVSHAPE0 for RA
730 yield f"sv.cmpli/ff=ne *0, 1, *{u}, 0"
731 yield f"setvl {m}, 0, 1, 0, 0, 0 # getvl {m}" # m = VL
732 yield f"subfic {m}, {m}, {num_size}" # m = num_size - m
733
734 # get non-zero length of divisor
735 yield f"setvl 0, 0, {denom_size}, 0, 1, 1" # set VL to denom_size
736 # create SVSHAPE that reverses order
737 svshape = SVSHAPE(0)
738 svshape.zdimsz = denom_size
739 svshape.invxyz = SelectableInt(0b1, 3) # invert Z
740 svshape_low = int(svshape) % 2 ** 16
741 svshape_high = int(svshape) >> 16
742 yield f"addis 0, 0, {svshape_high}"
743 yield f"ori 0, 0, {svshape_low}"
744 yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
745 yield f"svremap 0o01, 0, 0, 0, 0, 0, 0" # enable SVSHAPE0 for RA
746 yield f"sv.cmpli/ff=ne *0, 1, *{v}, 0"
747 yield f"setvl {n_scalar}, 0, 1, 0, 0, 0 # getvl {n_scalar}" # n = VL
748 # n = denom_size - n
749 yield f"subfic {n_scalar}, {n_scalar}, {denom_size}"
750
751 yield f"cmpli 0, 1, {n_scalar}, 1 # cmpldi {n_scalar}, 1"
752 yield "bc 4, 2, divmod_skip_sw_divisor # bne divmod_skip_sw_divisor"
753
754 # Knuth's algorithm D requires the divisor to have length >= 2
755 # handle single-word divisors separately
756 yield f"addi {t_single}, 0, 0"
757 yield f"setvl. {m}, {m}, {q_size}, 0, 1, 1" # m = VL = min(m, q_size)
758 # if CR0.SO: t = u[q_size]
759 yield f"sv.isel {t_single}, {u + q_size}, {t_single}, 3"
760 # div loop
761 yield f"sv.divmod2du/mrr *{q}, *{u}, {v}, {t_single}"
762 # r[0] = t
763 assert r == t_single, "r[0] and t_single must be in the same regs"
764 yield f"setvl 0, 0, {r_size - 1}, 0, 1, 1" # set VL to r_size - 1
765 yield f"sv.addi *{r + 1}, 0, 0" # r[1:] = [0] * (r_size - 1)
766
767 yield "b divmod_return"
768
769 yield "divmod_skip_sw_divisor:"
770 yield f"cmpl 0, 1, {m}, {n_scalar} # cmpld {m}, {n_scalar}"
771 yield "bc 4, 0, divmod_skip_copy_r # bge divmod_skip_copy_r"
772 # if m < n:
773
774 yield f"setvl 0, 0, {r_size}, 0, 1, 1" # set VL to r_size
775 yield f"sv.or *{r}, *{u}, *{u}" # r[...] = u[...]
776 yield "b divmod_return"
777
778 yield "divmod_skip_copy_r:"
779
780 # Knuth's algorithm D starts here:
781
782 # Step D1: normalize
783
784 # calculate amount to shift by -- count leading zeros
785 yield f"addi {index}, {n_scalar}, -1" # index = n - 1
786 assert index == 3, "index must be r3"
787 yield f"setvl 0, 0, {denom_size}, 0, 1, 1" # VL = denom_size
788 yield f"sv.cntlzd/m=1<<r3 {s_scalar}, *{v}" # s = clz64(v[index])
789
790 yield f"addi {t_for_uv_shift}, 0, 0" # t = 0
791 yield f"setvl 0, {n_scalar}, {denom_size}, 0, 1, 1" # VL = n
792 # vn = v << s
793 yield f"sv.dsld *{vn}, *{v}, {s_scalar}, {t_for_uv_shift}"
794
795 yield f"addi {t_for_uv_shift}, 0, 0" # t = 0
796 yield f"setvl 0, {m}, {num_size}, 0, 1, 1" # VL = m
797 # un = u << s
798 yield f"sv.dsld *{un}, *{u}, {s_scalar}, {t_for_uv_shift}"
799 yield f"setvl 0, 0, {un_size}, 0, 1, 1" # VL = un_size
800 yield f"or {index}, {m}, {m}" # index = m
801 assert index == 3, "index must be r3"
802 # un[index] = t
803 yield f"sv.or/m=1<<r3 *{un}, {t_for_uv_shift}, {t_for_uv_shift}"
804
805 # Step D2 and Step D7: loop
806 # j = m - n
807 yield f"subf {j}, {n_scalar}, {m}"
808 # j = min(j, q_size - 1)
809 yield f"addi 0, 0, {q_size - 1}"
810 yield f"minmax {j}, {j}, 0, 0 # maxd {j}, {j}, 0"
811 yield f"divmod_loop:"
812
813 # Step D3: calculate q̂
814 yield f"setvl 0, 0, {un_size}, 0, 1, 1" # VL = un_size
815 yield f"add {index}, {j}, {n_scalar}" # index = j + n
816 # qhat_num_hi = un[index]
817 assert index == 3, "index must be r3"
818 yield f"sv.or/m=1<<r3 {qhat_num_hi}, *{un}, *{un}"
819 yield f"addi {index}, {n_scalar}, -1" # index = n - 1
820 # qhat_denom = vn[index]
821 yield f"setvl 0, 0, {vn_size}, 0, 1, 1" # VL = vn_size
822 assert index == 3, "index must be r3"
823 yield f"sv.or/m=1<<r3 {qhat_denom}, *{vn}, *{vn}"
824 yield f"add {index}, {index}, {j}" # index = j + n - 1
825 # qhat, rhat_lo, ov = divmod2du(un[index], qhat_denom, qhat_num_hi)
826 yield f"or {rhat_lo}, {qhat_num_hi}, {qhat_num_hi}"
827 yield f"setvl 0, 0, {un_size}, 0, 1, 1" # VL = un_size
828 assert index == 3, "index must be r3"
829 yield f"sv.divmod2du/m=1<<r3 {qhat}, *{un}, {qhat_denom}, {rhat_lo}"
830 yield f"addi {rhat_hi}, 0, 0" # rhat_hi = 0
831 yield f"mcrxrx 0" # move OV to CR0.lt
832 yield "bc 4, 0, divmod_skip_qhat_overflow # bge divmod_..."
833 # if ov:
834 # division overflows word
835 # rhat_lo = (qhat * qhat_denom) % 2 ** self.word_size
836 yield f"mulld {rhat_lo}, {qhat}, {qhat_denom}"
837 # rhat_hi = (qhat * qhat_denom) >> self.word_size
838 yield f"mulhdu {rhat_hi}, {qhat}, {qhat_denom}"
839 # borrow = un[index] < rhat_lo
840 # rhat_lo = (un[index] - rhat_lo) % 2 ** self.word_size
841 assert index == 3, "index must be r3"
842 yield f"sv.subfc/m=1<<r3 {rhat_lo}, {rhat_lo}, *{un}"
843 # rhat_hi = qhat_num_hi - rhat_hi - borrow
844 yield f"subfe {rhat_hi}, {rhat_hi}, {qhat_num_hi}"
845 yield "divmod_skip_qhat_overflow:"
846
847 # while rhat_hi == 0:
848 yield "divmod_qhat_adj_loop:"
849 yield f"cmpli 0, 1, {rhat_hi}, 0 # cmpldi {rhat_hi}, 0"
850 yield "bc 12, 2, divmod_qhat_adj_loop_break # beq divmod_qhat_adj..."
851
852 yield f"setvl 0, 0, {vn_size}, 0, 1, 1" # VL = vn_size
853 yield f"addi {index}, {n_scalar}, -2" # index = n - 2
854 # qhat_prod_lo = (qhat * vn[index]) % 2 ** self.word_size
855 assert index == 3, "index must be r3"
856 yield f"sv.mulld/m=1<<r3 {qhat_prod_lo}, {qhat}, *{vn}"
857 # qhat_prod_hi = (qhat * vn[index]) >> self.word_size
858 yield f"sv.mulhdu/m=1<<r3 {qhat_prod_hi}, {qhat}, *{vn}"
859
860 # if qhat_prod_hi < rhat_lo:
861 # break
862 yield f"cmpl 0, 1, {qhat_prod_hi}, {rhat_lo} # cmpld cr0, ..."
863 yield "bc 12, 0, divmod_qhat_adj_loop_break # blt divmod_qhat_adj..."
864 # if qhat_prod_hi == rhat_lo:
865 yield "bc 4, 2, divmod_qhat_do_adj # bne divmod_qhat_do_adj"
866
867 yield f"add {index}, {index}, {j}" # index = j + n - 2
868 # if qhat_prod_lo <= un[index]:
869 # break
870 yield f"setvl 0, 0, {un_size}, 0, 1, 1" # VL = un_size
871 assert index == 3, "index must be r3"
872 yield f"sv.cmp/m=1<<r3 1, 1, {qhat_prod_lo}, *{un} # cmpld cr1, ..."
873 yield "bc 4, 1, divmod_qhat_adj_loop_break # ble divmod_qhat_adj..."
874 yield "divmod_qhat_do_adj:"
875
876 yield f"addi {qhat}, {qhat}, -1" # qhat -= 1
877
878 yield f"addi {index}, {n_scalar}, -1" # index = n - 1
879 # carry = (rhat_lo + vn[index]) >= 2 ** self.word_size
880 # rhat_lo = (rhat_lo + vn[index]) % 2 ** self.word_size
881 yield f"setvl 0, 0, {vn_size}, 0, 1, 1" # VL = vn_size
882 assert index == 3, "index must be r3"
883 yield f"sv.addc/m=1<<r3 {rhat_lo}, {rhat_lo}, *{vn}"
884 # rhat_hi = carry
885 yield f"addi 0, 0, 0"
886 yield f"addze. {rhat_hi}, 0"
887
888 # while rhat_hi == 0:
889 yield "bc 4, 2, divmod_qhat_adj_loop # bne divmod_qhat_adj_loop"
890 yield "divmod_qhat_adj_loop_break:"
891
892 # Step D4: multiply and subtract
893
894 yield f"setvl 0, {n_scalar}, {vn_size}, 0, 1, 1" # VL = n
895 yield f"addi {t_for_prod}, 0, 0" # t = 0
896 # product[:n] = vn[:n] * qhat
897 yield f"sv.maddedu *{product}, *{vn}, {qhat}, {t_for_prod}"
898 yield f"or {index}, {n_scalar}, {n_scalar}" # index = n
899 yield f"setvl 0, 0, {product_size}, 0, 1, 1" # VL = product_size
900 # product[index] = t
901 assert index == 3, "index must be r3"
902 yield f"sv.or/m=1<<r3 *{product}, {t_for_prod}, {t_for_prod}"
903
904 yield "subfc 0, 0, 0" # t = 1 (t is CA)
905 yield f"addi {sub_len}, {n_scalar}, 1" # sub_len = n + 1
906 yield f"setvl 0, {sub_len}, {product_size}, 0, 1, 1" # VL = sub_len
907 # create svshape that offsets by `j`
908 svshape = SVSHAPE(0)
909 svshape.zdimsz = q_size + 1
910 svshape_low = int(svshape) % 2 ** 16
911 svshape_high = int(svshape) >> 16
912 offset_field = svshape.fsi['offset']
913 assert 2 ** (len(offset_field) - 1) >= q_size, \
914 "max needed offset won't fit in SVSHAPE"
915 mask_start_le = len(svshape) - offset_field.br[0] - 1
916 mask_start = 64 - mask_start_le - 1
917 last = len(offset_field) - 1
918 shift_amount = len(svshape) - offset_field.br[last] - 1
919 # insert j in offset field
920 yield f"rldic 0, {j}, {shift_amount}, {mask_start}"
921 # or in all the other bits
922 if svshape_high != 0:
923 yield f"oris 0, 0, {svshape_high}"
924 if svshape_low != 0:
925 yield f"ori 0, 0, {svshape_low}"
926 yield f"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
927 yield f"svremap 0o12, 0, 0, 0, 0, 0, 0" # enable SVSHAPE0 for RB & RT
928 # un[j:] -= product
929 yield f"sv.subfe *{un}, *{product}, *{un}"
930 # need_fixup = not CA
931
932 # Step D5: test remainder
933
934 yield f"mcrxrx 0" # move CA to CR0.eq
935 # if need_fixup:
936 yield "bc 12, 2, divmod_skip_fixup # beq divmod_skip_fixup"
937
938 # Step D6: add back
939
940 yield f"addi {qhat}, {qhat}, -1" # qhat -= 1
941 yield "addic 0, 0, 0" # t = 0 (t is CA)
942 yield f"setvl 0, {n_scalar}, {vn_size}, 0, 1, 1" # VL = n
943 yield f"svremap 0o11, 0, 0, 0, 0, 0, 0" # enable SVSHAPE0 for RA & RT
944 # un[j:] += vn
945 yield f"sv.adde *{un}, *{un}, *{vn}"
946 yield f"add {index}, {j}, {n_scalar}" # index = j + n
947 # un[index] += t
948 yield f"setvl 0, 0, {un_size}, 0, 1, 1" # VL = un_size
949 assert index == 3, "index must be r3"
950 yield f"sv.addze/m=1<<r3 *{un}, *{un}"
951
952 yield "divmod_skip_fixup:"
953
954 yield f"or {index}, {j}, {j}" # index = j
955 # q[j] = qhat
956 yield f"setvl 0, 0, {q_size}, 0, 1, 1" # VL = q_size
957 yield f"sv.or/m=1<<r3 *{q}, {qhat}, {qhat}"
958
959 # Step D2 and Step D7: loop
960 yield f"addic. {j}, {j}, -1" # j -= 1
961 yield f"bc 4, 0, divmod_loop # bge divmod_loop"
962
963 # Step D8: un-normalize
964
965 # move s and n
966 yield f"or {s_for_unnorm}, {s_scalar}, {s_scalar}"
967 yield f"or {n_for_unnorm}, {n_scalar}, {n_scalar}"
968
969 # r = [0] * self.r_size # remainder
970 yield f"setvl 0, 0, {r_size}, 0, 1, 1" # VL = r_size
971 yield f"sv.addi *{r}, 0, 0"
972
973 # r = un >> s
974 yield f"addi {t_for_unnorm}, 0, 0" # t = 0
975 yield f"setvl 0, {n_for_unnorm}, {r_size}, 0, 1, 1" # VL = n
976 yield f"sv.dsrd/mrr *{r}, *{un}, {s_for_unnorm}, {t_for_unnorm}"
977
978 yield "divmod_return:"
979 yield "addi 1, 1, 176" # teardown stack frame
980 yield "ld 0, 16(1)"
981 yield "mtspr 8, 0 # mtlr 0" # restore return address
982 yield "setvl 0, 0, 18, 0, 1, 1" # set VL to 18
983 yield "sv.ld *14, -144(1)" # restore all callee-save registers
984 yield "bclr 20, 0, 0 # blr"
985
@cached_property
def asm(self):
    """The assembly source lines of this routine, as a tuple of strings.

    The lines are drawn from the private ``__asm_iter`` generator and
    frozen into a tuple; being a ``cached_property`` the generator is
    only run on the first access for a given instance.
    """
    generated_lines = self.__asm_iter()
    return tuple(generated_lines)
989
990
# 256-bit modular exponentiation by right-to-left square-and-multiply
# (the same steps as python_powmod_256_algorithm).
#
# inputs:  base in r4-7, exp in r8-11, mod in r32-35
# output:  result in r4-7
#
# register use inside the routine:
#   r14    loop counter (one iteration per exponent bit)
#   r16-19 base, r20-23 exp, r24-27 mod, r28-31 retval
POWMOD_256_ASM = (
    "powmod_256:",

    # --- prologue: save the return address and callee-save registers
    # (VL = 18 covers r14-r31), then create the stack frame as
    # required by the ABI
    "mfspr 0, 8 # mflr 0",
    "std 0, 16(1)",
    "setvl 0, 0, 18, 0, 1, 1",
    "sv.std *14, -144(1)",
    "stdu 1, -176(1)",

    # --- with VL = 4, move the arguments into callee-save registers:
    # base -> r16-19, exp -> r20-23, mod -> r24-27; retval (r28-31)
    # is cleared and then its lowest word is set so that retval = 1
    "setvl 0, 0, 4, 0, 1, 1",
    "sv.or *16, *4, *4",
    "sv.or *20, *8, *8",
    "sv.or *24, *32, *32",
    "sv.addi *28, 0, 0",
    "addi 28, 0, 1",

    # loop counter in r14: 256 iterations
    "addi 14, 0, 256",

    "powmod_256_loop:",
    # exp >>= 1 with VL = 4; r3 holds the shift amount (1) and r0 the
    # dsrd carry, which receives the bit that was shifted out
    "setvl 0, 0, 4, 0, 1, 1",
    "addi 3, 0, 1 # li 3, 1",
    "addi 0, 0, 0 # li 0, 0",
    "sv.dsrd/mrr *20, *20, 3, 0",
    # skip the multiply step when the shifted-out bit is zero
    "cmpli 0, 1, 0, 0 # cmpldi 0, 0",
    "bc 12, 2, powmod_256_else # beq powmod_256_else",

    # --- bit was set: retval = (retval * base) % mod
    # operands go in r4-7 and r8-11, the product comes back in r4-11
    "sv.or *4, *28, *28",
    "sv.or *8, *16, *16",
    "bl mul_256_to_512",

    # the divisor (mod) is passed in r32-35
    "setvl 0, 0, 4, 0, 1, 1",
    "sv.or *32, *24, *24",

    # the remainder is taken from r8-11
    "bl divmod_512_by_256",
    "setvl 0, 0, 4, 0, 1, 1",
    "sv.or *28, *8, *8",

    "powmod_256_else:",

    # --- every iteration: base = (base * base) % mod
    "sv.or *4, *16, *16",
    "sv.or *8, *16, *16",
    "bl mul_256_to_512",

    "setvl 0, 0, 4, 0, 1, 1",
    "sv.or *32, *24, *24",

    "bl divmod_512_by_256",
    "setvl 0, 0, 4, 0, 1, 1",
    "sv.or *16, *8, *8",

    # decrement the counter, comparing against zero, and loop until
    # it reaches zero
    "addic. 14, 14, -1",
    "bc 4, 2, powmod_256_loop # bne powmod_256_loop",

    # --- move retval to the output registers r4-7
    "setvl 0, 0, 4, 0, 1, 1",
    "sv.or *4, *28, *28",

    # --- epilogue: tear down the stack frame, restore the return
    # address and all callee-save registers, then return
    "addi 1, 1, 176",
    "ld 0, 16(1)",
    "mtspr 8, 0 # mtlr 0",
    "setvl 0, 0, 18, 0, 1, 1",
    "sv.ld *14, -144(1)",
    "bclr 20, 0, 0 # blr",

    # the subroutines called above are appended after the main routine
    *MUL_256_X_256_TO_512_ASM,
    *DivModKnuthAlgorithmD().asm,
)
1058
1059
def python_powmod_256_algorithm(base, exp, mod):
    """Compute ``base ** exp % mod`` by right-to-left square-and-multiply.

    Exactly 256 iterations are performed, one per exponent bit starting
    from the least-significant one, so any bits of `exp` above bit 255
    are ignored.

    The accumulator starts at 1 and is only reduced when a set exponent
    bit is seen, so `exp == 0` yields 1 for every `mod` -- including
    `mod == 1`, where the builtin `pow(base, 0, 1)` gives 0.
    """
    accumulator = 1
    bits_left = 256
    while bits_left > 0:
        bits_left -= 1
        # take the lowest exponent bit, then shift it out
        bit_is_set = bool(exp & 1)
        exp >>= 1
        if bit_is_set:
            # multiply step: fold the current power of base into the result
            accumulator = (accumulator * base) % mod
        # squaring step: advance base to the next power of two exponent
        base = (base * base) % mod
    return accumulator
1071
1072
class PowModCases(TestAccumulatorBase):
    """Test cases for the big-integer assembly routines of this file:
    256x256->512-bit multiply, 512-by-256-bit divmod and 256-bit powmod.

    Multi-word integers are passed in consecutive 64-bit registers,
    least-significant word first.
    """

    def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
        """Register one test case that calls `instructions` as a function.

        SPR 8 (LR) is preloaded with the address at which the simulation
        stops, so the routine's final `blr` ends the test.  Both
        `expected` and `initial_regs` are modified in place: r1 is set to
        the stack pointer, the expected PC is the stop address and the
        expected LR is set to None.
        """
        stop_address = 0x10000000
        stack_pointer = 0x1000000
        # the stack pointer is the same before and after the call
        expected.intregs[1] = stack_pointer
        initial_regs[1] = stack_pointer
        expected.pc = stop_address
        expected.sprs['LR'] = None
        program = assemble(instructions)
        self.add_case(program, initial_regs,
                      initial_sprs={8: stop_address},
                      stop_at_pc=stop_address, expected=expected,
                      src_loc_at=src_loc_at + 1)

    def case_mul_256_x_256_to_512(self):
        """Check MUL_256_X_256_TO_512_ASM against Python's `a * b`."""
        for case_num in range(10):
            a = hash_256(f"mul256 input a {case_num}")
            b = hash_256(f"mul256 input b {case_num}")
            # the first two cases use known values instead of hashes
            if case_num == 0:
                a = b = 2**256 - 1
            elif case_num == 1:
                a = b = (2**256 - 1) // 0xFF
            y = a * b
            with self.subTest(a=f"{a:#_x}", b=f"{b:#_x}", y=f"{y:#_x}"):
                # every register starts out holding junk
                initial_regs = [0xABCDEF] * 128
                for word in range(4):
                    shift = 64 * word
                    # a goes in r4-7 and b in r8-11, both in LE order
                    initial_regs[4 + word] = (a >> shift) % 2**64
                    initial_regs[8 + word] = (b >> shift) % 2**64
                # the output ends at r11, so only r0-r11 are checked
                e = ExpectedState(int_regs=initial_regs[:12])
                for word in range(8):
                    # y is expected in r4-11 in LE order
                    e.intregs[4 + word] = (y >> (64 * word)) % 2**64

                self.call_case(MUL_256_X_256_TO_512_ASM, e, initial_regs)

    @staticmethod
    def divmod_512x256_to_256x256_test_inputs():
        """Yield `(n, d)` numerator/divisor pairs for the divmod tests:
        hand-picked cases first, then 20 hash-derived ones.
        """
        handwritten_cases = (
            (2 ** (256 - 1), 1),
            (2 ** (512 - 1) - 1, 2 ** 256 - 1),

            # division by a single word
            (((1 << 256) - 1) << 32, 1 << 32),
            (((1 << 192) - 1) << 32, 1 << 32),
            (((1 << 64) - 1) << 32, 1 << 32),
            (1 << 32, 1 << 32),

            # qhat overflow
            (0x8000 << 128 | 0xFFFE << 64, 0x8000 << 64 | 0xFFFF),

            # cases where add back is required
            (8 << (192 - 4) | 3, 2 << (192 - 4) | 1),
            (0x8000 << 128 | 3, 0x2000 << 128 | 1),
            (0x7FFF << 192 | 0x8000 << 128, 0x8000 << 128 | 1),
        )
        yield from handwritten_cases

        for seed in range(20):
            # 512-bit numerator from two hashes, shortened by a
            # hash-derived shift
            n = hash_256(f"divmod256 input n msb {seed}")
            n <<= 256
            n |= hash_256(f"divmod256 input n lsb {seed}")
            n_shift = hash_256(f"divmod256 input n shift {seed}") % 512
            n >>= n_shift
            # divisor built the same way from a single hash
            d = hash_256(f"divmod256 input d {seed}")
            d_shift = hash_256(f"divmod256 input d shift {seed}") % 256
            d >>= d_shift
            if d == 0:
                d = 1
            # reduce n so that the quotient n // d is below 2 ** 256
            n %= d << 256
            yield (n, d)

    def case_divmod_shift_sub_512x256_to_256x256(self):
        """Check DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM against
        Python's `divmod`, on a reduced list of inputs.
        """
        cases = list(self.divmod_512x256_to_256x256_test_inputs())
        del cases[2:-1]  # speed up tests by removing most test cases
        for n, d in cases:
            q, r = divmod(n, d)
            with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",
                              q=f"{q:#_x}", r=f"{r:#_x}"):
                # every register starts out holding junk
                initial_regs = [0xABCDEF] * 128
                for word in range(8):
                    # n goes in r4-11 in LE order
                    initial_regs[4 + word] = (n >> (64 * word)) % 2**64
                for word in range(4):
                    # d goes in r32-35 in LE order
                    initial_regs[32 + word] = (d >> (64 * word)) % 2**64
                # the output ends at r11, so only r0-r11 are checked.
                # crregs=0: CR is not checked
                e = ExpectedState(int_regs=initial_regs[:12], crregs=0)
                # values left behind in r0 and r3 by the routine
                e.intregs[0] = 0
                e.intregs[3] = 1
                e.ca = None  # ignored
                for word in range(4):
                    shift = 64 * word
                    # q is expected in r4-7 and r in r8-11, in LE order
                    e.intregs[4 + word] = (q >> shift) % 2**64
                    e.intregs[8 + word] = (r >> shift) % 2**64

                self.call_case(
                    DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM, e, initial_regs)

    def case_divmod_knuth_algorithm_d_512x256_to_256x256(self):
        """Check the DivModKnuthAlgorithmD assembly against Python's
        `divmod`, on the full list of inputs.
        """
        cases = list(self.divmod_512x256_to_256x256_test_inputs())
        asm = DivModKnuthAlgorithmD().asm
        for n, d in cases:
            q, r = divmod(n, d)
            with self.subTest(n=f"{n:#_x}", d=f"{d:#_x}",
                              q=f"{q:#_x}", r=f"{r:#_x}"):
                # every register starts out holding junk
                initial_regs = [0xABCDEF] * 128
                for word in range(8):
                    # n goes in r4-11 in LE order
                    initial_regs[4 + word] = (n >> (64 * word)) % 2**64
                for word in range(4):
                    # d goes in r32-35 in LE order
                    initial_regs[32 + word] = (d >> (64 * word)) % 2**64
                # the output ends at r11, so only r0-r11 are checked.
                # crregs=0: CR is not checked
                e = ExpectedState(int_regs=initial_regs[:12], crregs=0)
                # r0, r3, CA and SVSHAPE0 are ignored
                e.intregs[0] = None
                e.intregs[3] = None
                e.ca = None
                e.sprs['SVSHAPE0'] = None
                for word in range(4):
                    shift = 64 * word
                    # q is expected in r4-7 and r in r8-11, in LE order
                    e.intregs[4 + word] = (q >> shift) % 2**64
                    e.intregs[8 + word] = (r >> shift) % 2**64

                self.call_case(asm, e, initial_regs)

    @staticmethod
    def powmod_256_test_inputs():
        """Yield `(base, exp, mod)` triples for the powmod test, with
        `mod` non-zero and `base` already reduced modulo `mod`.
        """
        for seed in range(3):
            base = hash_256(f"powmod256 input base {seed}")
            exp = hash_256(f"powmod256 input exp {seed}")
            mod = hash_256(f"powmod256 input mod {seed}")
            if seed == 0:
                # the first case uses known values instead of hashes
                base = 2
                exp = 2 ** 256 - 1
                mod = 2 ** 256 - 189  # largest prime less than 2 ** 256
            if mod == 0:
                mod = 1
            base %= mod
            yield (base, exp, mod)

    def case_powmod_256(self):
        """Check POWMOD_256_ASM against Python's `pow(base, exp, mod)`."""
        for base, exp, mod in PowModCases.powmod_256_test_inputs():
            expected = pow(base, exp, mod)
            with self.subTest(base=f"{base:#_x}", exp=f"{exp:#_x}",
                              mod=f"{mod:#_x}", expected=f"{expected:#_x}"):
                # every register starts out holding junk
                initial_regs = [0xABCDEF] * 128
                for word in range(4):
                    # base goes in r4-7 in LE order
                    initial_regs[4 + word] = (base >> (64 * word)) % 2**64
                for word in range(4):
                    # exp goes in r8-11 in LE order
                    initial_regs[8 + word] = (exp >> (64 * word)) % 2**64
                for word in range(4):
                    # mod goes in r32-35 in LE order
                    initial_regs[32 + word] = (mod >> (64 * word)) % 2**64
                # crregs=0: CR is not checked
                e = ExpectedState(int_regs=initial_regs, crregs=0)
                # r1, r2, r13 and the nonvolatile registers r14-r31 keep
                # their expected values; every other register is ignored
                kept_regs = {1, 2, 13, *range(14, 32)}
                for reg in range(128):
                    if reg not in kept_regs:
                        e.intregs[reg] = None
                e.ca = None  # ignored
                e.sprs['SVSHAPE0'] = None
                for word in range(4):
                    # the result is expected in r4-7 in LE order
                    e.intregs[4 + word] = (expected >> (64 * word)) % 2**64

                self.call_case(POWMOD_256_ASM, e, initial_regs)
1250
1251
# quick stand-alone investigations, only run when executed as a script
if __name__ == "__main__":
    import random

    # first make sure python_mul_algorithm reproduces a known product
    limbs = (99, 99, 99, 99)
    known_product = [1, 0, 0, 0, 98, 99, 99, 99]
    assert python_mul_algorithm(limbs, limbs) == known_product

    # then compare python_mul_algorithm2 *against* python_mul_algorithm
    # on reproducible pseudo-random inputs
    random.seed(0)
    for _ in range(10000):
        # draw the limbs pairwise (a's limb first, then b's) so that the
        # random sequence is consumed in a fixed, interleaved order
        limb_pairs = [(random.randint(0, 99), random.randint(0, 99))
                      for _ in range(4)]
        a = [a_limb for a_limb, _ in limb_pairs]
        b = [b_limb for _, b_limb in limb_pairs]
        expected = python_mul_algorithm(a, b)
        testing = python_mul_algorithm2(a, b)
        first_line = "%+17s * %-17s = %s\n" % (
            repr(a), repr(b), repr(expected))
        second_line = " (%s)" % repr(testing)
        print(first_line + second_line)
        assert expected == testing