1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
3 # Copyright 2023 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
8 # * https://bugs.libre-soc.org/show_bug.cgi?id=1044
10 """ modular exponentiation (`pow(x, y, z)`)
14 * https://bugs.libre-soc.org/show_bug.cgi?id=1044
17 from openpower
.test
.common
import TestAccumulatorBase
, skip_case
18 from openpower
.test
.state
import ExpectedState
19 from openpower
.test
.util
import assemble
20 from nmutil
.sim_util
import hash_256
21 from openpower
.util
import log
22 from nmutil
.plain_data
import plain_data
23 from cached_property
import cached_property
24 from openpower
.decoder
.isa
.svshape
import SVSHAPE
25 from openpower
.decoder
.power_enums
import SPRfull
26 from openpower
.decoder
.selectable_int
import SelectableInt
29 MUL_256_X_256_TO_512_ASM
= (
31 # a is in r4-7, b is in r8-11
32 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
33 "sv.or *32, *4, *4", # move args to r32-39
34 # a is now in r32-35, b is in r36-39, y is in r4-11, t is in r40-44
35 "sv.addi *4, 0, 0", # clear output
36 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
37 "sv.maddedu *4, *32, 36, 8", # first partial-product a * b[0]
39 "sv.maddedu *40, *32, 37, 44", # second partial-product a * b[1]
41 "sv.adde *6, *6, *41",
43 "sv.maddedu *40, *32, 38, 44", # third partial-product a * b[2]
45 "sv.adde *7, *7, *41",
47 "sv.maddedu *40, *32, 39, 44", # final partial-product a * b[3]
49 "sv.adde *8, *8, *41",
50 "bclr 20, 0, 0 # blr",
53 # TODO: these really need to go into a common util file, see
54 # openpower/decoder/isa/poly1305-donna.py:def _DSRD(lo, hi, sh)
55 # NOTE: the helpers here are modulo 100, not 2^64, but the idea is the same
60 return y
% 100, y
// 100
65 return y
% 100, y
// 100
70 return y
% 100, y
// 100
73 def python_mul_algorithm(a
, b
):
74 # version of the MUL_256_X_256_TO_512_ASM algorithm using base 100 rather
75 # than 2^64, since that's easier to read.
76 # run this file in a debugger to see all the intermediate values.
80 y
[i
], y
[4] = maddedu(a
[0], b
[i
], y
[4])
83 t
[i
], t
[4] = maddedu(a
[1], b
[i
], t
[4])
84 y
[1], ca
= addc(y
[1], t
[0])
86 y
[2 + i
], ca
= adde(y
[2 + i
], t
[1 + i
], ca
)
89 t
[i
], t
[4] = maddedu(a
[2], b
[i
], t
[4])
90 y
[2], ca
= addc(y
[2], t
[0])
92 y
[3 + i
], ca
= adde(y
[3 + i
], t
[1 + i
], ca
)
95 t
[i
], t
[4] = maddedu(a
[3], b
[i
], t
[4])
96 y
[3], ca
= addc(y
[3], t
[0])
98 y
[4 + i
], ca
= adde(y
[4 + i
], t
[1 + i
], ca
)
102 def python_mul_algorithm2(a
, b
):
103 # version 2 of the MUL_256_X_256_TO_512_ASM algorithm using base 100 rather
104 # than 2^64, since that's easier to read.
105 # the idea here is that it will "morph" into something more akin to
106 # using REMAP bigmul (first using REMAP Indexed)
108 # create a schedule for use below. the "end of inner loop" marker is 0b01
113 iyl
.append((iy
+i
, i
== 3))
116 iyl
.append((iy
+i
, i
== 4))
119 y
= [0] * 8 # result y and temp t of same size
120 t
= [0] * 8 # no need after this to set t[4] to zero
122 for i
in range(4): # use t[iy+4] as a 64-bit carry
123 t
[iy
+i
], t
[iy
+4] = maddedu(a
[iy
], b
[i
], t
[iy
+4])
125 for i
in range(5): # add vec t to y with 1-bit carry
127 y
[idx
], ca
= adde(y
[idx
], t
[idx
], ca
)
131 DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM
= (
132 # extremely slow and simplistic shift and subtract algorithm.
133 # a future task is to rewrite to use Knuth's Algorithm D,
134 # which is generally an order of magnitude faster
135 "divmod_512_by_256:",
136 # n is in r4-11, d is in r32-35
137 "addi 3, 0, 256 # li 3, 256",
138 "mtspr 9, 3 # mtctr 3", # set CTR to 256
139 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
141 "sv.or *40, *4, *4", # assign n to r, in r40-47
142 # shifted_d is in r32-39
143 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
144 "addi 3, 0, 1 # li 3, 1", # shift amount
145 "addi 0, 0, 0 # li 0, 0", # dsrd carry
146 "sv.dsrd/mrr *36, *32, 3, 0", # shifted_d = d << (256 - 1)
147 "sv.addi *32, 0, 0", # clear lsb half
148 "sv.or 35, 0, 0", # move carry to correct location
150 "sv.addi *4, 0, 0", # clear q
152 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
153 "subfc 0, 0, 0", # set CA
155 "sv.subfe *48, *32, *40", # diff = r - shifted_d
156 # not borrowed is in CA
157 "mcrxrx 0", # move CA to CR0.eq
158 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
159 "addi 0, 0, 0 # li 0, 0", # dsld carry
160 "sv.dsld *4, *4, 3, 0", # q <<= 1 (1 is in r3)
161 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
162 "bc 4, 2, divmod_else # bne divmod_else", # if borrowed goto divmod_else
163 "ori 4, 4, 1", # q |= 1
164 "sv.or *40, *48, *48", # r = diff
166 "addi 0, 0, 0 # li 0, 0", # dsld carry
167 "sv.dsld *40, *40, 3, 0", # r <<= 1 (1 is in r3)
168 "bc 16, 0, divmod_loop # bdnz divmod_loop",
169 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
171 "sv.or *8, *44, *44", # r >>= 256
172 # q is in r4-7, r is in r8-11
173 "bclr 20, 0, 0 # blr",
177 class _DivModRegsRegexLogger
:
178 """ logger that logs a regex that matches the expected register dump for
179 the currently tracked `locals` -- quite useful for debugging
182 def __init__(self
, enabled
=True, regs
=None):
184 self
.__regs
= regs
if regs
is not None else {}
185 self
.enabled
= enabled
187 def log(self
, locals_
, **changes
):
190 # create a variable `a`:
193 # we invoke `locals()` each time since python doesn't guarantee
194 # it's up-to-date otherwise
195 logger.log(locals(), a=(4, 6)) # `a` starts at r4 and uses 6 registers
199 logger.log(locals()) # keeps using `a`
203 logger.log(locals(), a=None, b=(4, 6)) # remove `a` and add `b`
207 for k
, v
in changes
.items():
209 del self
.__tracked
[k
]
211 if isinstance(v
, (tuple, list)):
216 if not isinstance(start_gpr
, int):
217 start_gpr
= self
.__regs
[start_gpr
]
218 self
.__tracked
[k
] = start_gpr
, size
221 for name
, (start_gpr
, size
) in self
.__tracked
.items():
222 value
= locals_
[name
]
225 elif not isinstance(value
, (list, tuple)):
226 value
= [(value
>> 64 * i
) % 2 ** 64 for i
in range(size
)]
228 assert len(value
) == size
, "value has wrong len"
229 for i
in range(size
):
233 if gprs
[reg
] is not None:
234 other_value
, other_name
, other_i
= gprs
[reg
]
235 raise AssertionError(f
"overlapping values at r{reg}: "
236 f
"{name}[{i}] overlaps with "
237 f
"{other_name}[{other_i}]")
238 gprs
[reg
] = value
[i
], name
, i
241 # after building `gprs` so we catch any missing/invalid locals
246 for i
in range(0, 128, 8):
247 segments
.append(f
"reg +{i}")
248 for value
in gprs
[i
:i
+ 8]:
250 segments
.append(" +[0-9a-f]+")
252 value
, name
, i
= value
253 segments
.append(f
" +{value:08x}")
254 segments
.append("\\n")
255 log("DIVMOD REGEX:", "".join(segments
))
258 def python_divmod_shift_sub_algorithm(n
, d
, width
=256, log_regex
=False):
259 assert n
>= 0 and d
> 0 and width
> 0 and n
< (d
<< width
), "invalid input"
260 do_log
= _DivModRegsRegexLogger(enabled
=log_regex
).log
262 do_log(locals(), n
=(4, 8), d
=(32, 4))
265 do_log(locals(), n
=None, r
=(40, 8))
267 shifted_d
= d
<< (width
- 1)
268 do_log(locals(), d
=None, shifted_d
=(32, 8))
271 do_log(locals(), q
=(4, 4))
273 for _
in range(width
):
276 do_log(locals(), diff
=(48, 8))
292 do_log(locals(), r
=(8, 4))
297 def divmod2du(RA
, RB
, RC
):
298 # type: (int, int, int) -> tuple[int, int, bool]
299 if RC
< RB
and RB
!= 0:
300 RT
, RS
= divmod(RC
<< 64 | RA
, RB
)
306 return RT
, RS
, overflow
310 class DivModKnuthAlgorithmD
:
311 __slots__
= "num_size", "denom_size", "q_size", "word_size", "regs"
313 def __init__(self
, num_size
=8, denom_size
=4, q_size
=4,
314 word_size
=64, regs
=None):
315 # type: (int, int, int | None, int, None | dict[str, int]) -> None
316 assert num_size
>= denom_size
, \
317 "the dividend's length must be >= the divisor's length"
321 # quotient length from original algorithm is m - n + 1,
322 # but that assumes v[-1] != 0 -- since we support smaller divisors
323 # the quotient must be larger.
355 self
.num_size
= num_size
356 self
.denom_size
= denom_size
358 self
.word_size
= word_size
363 return self
.denom_size
367 return self
.num_size
+ 1
371 return self
.denom_size
def product_size(self):
    """Length, in words, of the `product` temporary.

    The buffer is one word longer than the dividend (`num_size`).
    """
    dividend_words = self.num_size
    return dividend_words + 1
377 def python(self
, n
, d
, log_regex
=False, on_corner_case
=lambda desc
: None):
378 do_log
= _DivModRegsRegexLogger(enabled
=log_regex
, regs
=self
.regs
).log
380 do_log(locals(), n
=("n_0", self
.num_size
), d
=("d_0", self
.denom_size
))
382 # switch to names used by Knuth's algorithm D
383 u
= list(n
) # dividend
384 assert len(u
) == self
.num_size
, "numerator has wrong size"
385 do_log(locals(), n
=None, u
=("u", self
.num_size
))
386 m
= len(u
) # length of dividend
387 do_log(locals(), m
="m")
388 v
= list(d
) # divisor
389 assert len(v
) == self
.denom_size
, "denominator has wrong size"
390 del d
# less confusing to debug
391 do_log(locals(), d
=None, v
=("v", self
.denom_size
))
392 n
= len(v
) # length of divisor
393 do_log(locals(), n
="n_scalar")
395 # allocate outputs/temporaries -- before any normalization so
396 # the outputs/temporaries can be fixed-length in the assembly version.
398 q
= [0] * self
.q_size
# quotient
399 do_log(locals(), q
=("q", self
.q_size
))
400 vn
= [None] * self
.vn_size
# normalized divisor
401 do_log(locals(), vn
=("vn", self
.vn_size
))
402 un
= [None] * self
.un_size
# normalized dividend
403 do_log(locals(), un
=("un", self
.un_size
))
404 product
= [None] * self
.product_size
405 do_log(locals(), product
=("product", self
.product_size
))
407 # get non-zero length of dividend
408 while m
> 0 and u
[m
- 1] == 0:
413 # get non-zero length of divisor
414 while n
> 0 and v
[n
- 1] == 0:
420 raise ZeroDivisionError
423 on_corner_case("single-word divisor")
424 # Knuth's algorithm D requires the divisor to have length >= 2
425 # handle single-word divisors separately
430 do_log(locals(), t
="t_single", n
=None)
431 do_log(locals(), m
=None) # VL = m, so we don't need it in a GPR
432 for i
in reversed(range(m
)):
433 q
[i
], t
, _
= divmod2du(u
[i
], v
[0], t
)
435 r
= [0] * self
.r_size
# remainder
437 do_log(locals(), t
=None, r
=("r", self
.r_size
))
441 r
= [None] * self
.r_size
# remainder
442 do_log(locals(), r
=("r", self
.r_size
), m
=None, n
=None)
444 for i
in range(self
.r_size
):
449 # Knuth's algorithm D starts here:
453 # calculate amount to shift by -- count leading zeros
456 do_log(locals(), index
="index")
457 while (v
[index
] << s
) >> (self
.word_size
- 1) == 0:
460 do_log(locals(), s
="s_scalar", index
=None)
463 on_corner_case("non-zero shift")
467 do_log(locals(), t
="t_for_uv_shift")
471 v
[i
] = None # mark reg as unused
472 vn
[i
] = t
% 2 ** self
.word_size
478 do_log(locals(), v
=None)
482 u
[i
] = None # mark reg as unused
483 un
[i
] = t
% 2 ** self
.word_size
487 do_log(locals(), index
="index")
490 do_log(locals(), u
=None, t
=None, index
=None)
492 # Step D2 and Step D7: loop
493 for j
in range(min(m
- n
, self
.q_size
- 1), -1, -1):
494 do_log(locals(), j
="j")
495 # Step D3: calculate q̂
498 do_log(locals(), index
="index")
499 qhat_num_hi
= un
[index
]
500 do_log(locals(), qhat_num_hi
="qhat_num_hi")
503 qhat_denom
= vn
[index
]
504 do_log(locals(), qhat_denom
="qhat_denom")
507 qhat
, rhat_lo
, ov
= divmod2du(un
[index
], qhat_denom
, qhat_num_hi
)
509 do_log(locals(), qhat
="qhat", rhat_lo
="rhat_lo", rhat_hi
="rhat_hi")
511 # division overflows word
512 on_corner_case("qhat overflows word")
513 assert qhat_num_hi
== qhat_denom
514 rhat_lo
= (qhat
* qhat_denom
) % 2 ** self
.word_size
515 rhat_hi
= (qhat
* qhat_denom
) >> self
.word_size
517 borrow
= un
[index
] < rhat_lo
518 rhat_lo
= (un
[index
] - rhat_lo
) % 2 ** self
.word_size
520 rhat_hi
= qhat_num_hi
- rhat_hi
- borrow
521 do_log(locals(), qhat_num_hi
=None, qhat_denom
=None)
524 if qhat
* vn
[n
- 2] > (rhat_lo
<< self
.word_size
) + un
[j
+ n
- 2]:
525 on_corner_case("qhat adjustment")
528 carry
= (rhat_lo
+ vn
[n
- 1]) >= 2 ** self
.word_size
530 rhat_lo
%= 2 ** self
.word_size
537 do_log(locals(), rhat_lo
=None, rhat_hi
=None, index
=None)
539 # Step D4: multiply and subtract
542 do_log(locals(), t
="t_for_prod")
546 product
[i
] = t
% 2 ** self
.word_size
550 do_log(locals(), t
=None)
553 for i
in range(n
+ 1):
555 not_product
= ~product
[i
] % 2 ** self
.word_size
556 t
+= not_product
+ un
[j
+ i
]
557 un
[j
+ i
] = t
% 2 ** self
.word_size
558 t
= int(t
>= 2 ** self
.word_size
)
562 # Step D5: test remainder
568 on_corner_case("add back")
576 t
+= un
[j
+ i
] + vn
[i
]
577 un
[j
+ i
] = t
% 2 ** self
.word_size
578 t
= int(t
>= 2 ** self
.word_size
)
586 # Step D8: un-normalize
587 do_log(locals(), s
="s_for_unnorm", vn
=None, m
=None, j
=None)
588 r
= [0] * self
.r_size
# remainder
589 do_log(locals(), r
=("r", self
.r_size
), n
="n_for_unnorm")
592 do_log(locals(), t
="t_for_unnorm")
593 for i
in reversed(range(n
)):
596 t |
= (un
[i
] << self
.word_size
) >> s
597 r
[i
] = t
>> self
.word_size
598 t
%= 2 ** self
.word_size
603 def __asm_iter(self
):
604 if self
.word_size
!= 64:
605 raise NotImplementedError("only word_size == 64 is implemented")
606 n_0
= self
.regs
["n_0"]
607 d_0
= self
.regs
["d_0"]
611 n_scalar
= self
.regs
["n_scalar"]
615 product
= self
.regs
["product"]
617 t_single
= self
.regs
["t_single"]
618 s_scalar
= self
.regs
["s_scalar"]
619 t_for_uv_shift
= self
.regs
["t_for_uv_shift"]
620 n_for_unnorm
= self
.regs
["n_for_unnorm"]
621 t_for_unnorm
= self
.regs
["t_for_unnorm"]
622 s_for_unnorm
= self
.regs
["s_for_unnorm"]
623 qhat
= self
.regs
["qhat"]
624 rhat_lo
= self
.regs
["rhat_lo"]
625 rhat_hi
= self
.regs
["rhat_hi"]
626 t_for_prod
= self
.regs
["t_for_prod"]
627 index
= self
.regs
["index"]
629 qhat_num_hi
= self
.regs
["qhat_num_hi"]
630 qhat_denom
= self
.regs
["qhat_denom"]
631 num_size
= self
.num_size
632 denom_size
= self
.denom_size
635 un_size
= self
.un_size
636 vn_size
= self
.vn_size
637 product_size
= self
.product_size
639 yield "divmod_512_by_256:"
640 # n in n_0 size num_size
641 # d in d_0 size denom_size
643 yield "mfspr 0, 8 # mflr 0"
644 yield "std 0, 16(1)" # save return address
645 yield "setvl 0, 0, 18, 0, 1, 1" # set VL to 18
646 yield "sv.std *14, -144(1)" # save all callee-save registers
647 yield "stdu 1, -176(1)" # create stack frame as required by ABI
649 # switch to names used by Knuth's algorithm D
650 yield f
"setvl 0, 0, {num_size}, 0, 1, 1" # set VL to num_size
651 yield f
"sv.or *{u}, *{n_0}, *{n_0}" # u = n
652 yield f
"addi {m}, 0, {num_size}" # m = len(u)
653 assert v
== d_0
, "v and d_0 must be in the same regs" # v = d
654 yield f
"addi {n_scalar}, 0, {denom_size}" # n = len(v)
656 # allocate outputs/temporaries
657 yield f
"setvl 0, 0, {q_size}, 0, 1, 1" # set VL to q_size
658 yield f
"sv.addi *{q}, 0, 0" # q = [0] * q_size
660 # get non-zero length of dividend
661 yield f
"setvl 0, 0, {num_size}, 0, 1, 1" # set VL to num_size
662 # create SVSHAPE that reverses order
664 svshape
.zdimsz
= num_size
665 svshape
.invxyz
= SelectableInt(0b1, 3) # invert Z
666 svshape_low
= int(svshape
) % 2 ** 16
667 svshape_high
= int(svshape
) >> 16
668 SVSHAPE0
= SPRfull
.SVSHAPE0
.value
669 yield f
"addis 0, 0, {svshape_high}"
670 yield f
"ori 0, 0, {svshape_low}"
671 yield f
"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
672 yield f
"svremap 0o01, 0, 0, 0, 0, 0, 0" # enable SVSHAPE0 for RA
673 yield f
"sv.cmpi/ff=ne *0, 1, *{u}, 0"
674 yield f
"setvl {m}, 0, 1, 0, 0, 0 # getvl {m}" # m = VL
675 yield f
"subfic {m}, {m}, {num_size}" # m = num_size - m
677 # get non-zero length of divisor
678 yield f
"setvl 0, 0, {denom_size}, 0, 1, 1" # set VL to denom_size
679 # create SVSHAPE that reverses order
681 svshape
.zdimsz
= denom_size
682 svshape
.invxyz
= SelectableInt(0b1, 3) # invert Z
683 svshape_low
= int(svshape
) % 2 ** 16
684 svshape_high
= int(svshape
) >> 16
685 yield f
"addis 0, 0, {svshape_high}"
686 yield f
"ori 0, 0, {svshape_low}"
687 yield f
"mtspr {SVSHAPE0}, 0 # mtspr SVSHAPE0, 0"
688 yield f
"svremap 0o01, 0, 0, 0, 0, 0, 0" # enable SVSHAPE0 for RA
689 yield f
"sv.cmpi/ff=ne *0, 1, *{v}, 0"
690 yield f
"setvl {n_scalar}, 0, 1, 0, 0, 0 # getvl {n_scalar}" # n = VL
692 yield f
"subfic {n_scalar}, {n_scalar}, {denom_size}"
694 yield f
"cmpi 0, 1, {n_scalar}, 1 # cmpdi {n_scalar}, 1"
695 yield "bc 4, 2, divmod_skip_sw_divisor # bne divmod_skip_sw_divisor"
697 # Knuth's algorithm D requires the divisor to have length >= 2
698 # handle single-word divisors separately
699 yield f
"addi {t_single}, 0, 0"
700 yield f
"setvl. {m}, {m}, {q_size}, 0, 1, 1" # m = VL = min(m, q_size)
701 # if CR0.SO: t = u[q_size]
702 yield f
"sv.isel {t_single}, {u + q_size}, {t_single}, 3"
704 yield f
"sv.divmod2du/mrr *{q}, *{u}, {v}, {t_single}"
706 assert r
== t_single
, "r[0] and t_single must be in the same regs"
707 yield f
"setvl 0, 0, {r_size - 1}, 0, 1, 1" # set VL to r_size - 1
708 yield f
"sv.addi *{r + 1}, 0, 0" # r[1:] = [0] * (r_size - 1)
710 yield "b divmod_return"
712 yield "divmod_skip_sw_divisor:"
713 yield f
"cmp 0, 1, {m}, {n_scalar} # cmpd {m}, {n_scalar}"
714 yield "bc 4, 0, divmod_skip_copy_r # bge divmod_skip_copy_r"
717 yield f
"setvl 0, 0, {r_size}, 0, 1, 1" # set VL to r_size
718 yield f
"sv.or *{r}, *{u}, *{u}" # r[...] = u[...]
719 yield "b divmod_return"
721 yield "divmod_skip_copy_r:"
723 # Knuth's algorithm D starts here:
727 # calculate amount to shift by -- count leading zeros
728 yield f
"addi {index}, {n_scalar}, -1" # index = n - 1
729 assert index
== 3, "index must be r3"
730 yield f
"setvl. 0, 0, {denom_size}, 0, 1, 1" # VL = denom_size
731 yield f
"sv.cntlzd/m=1<<r3 {s_scalar}, *{v}" # s = clz64(v[index])
733 yield f
"addi {t_for_uv_shift}, 0, 0" # t = 0
734 yield f
"setvl. 0, {n_scalar}, {denom_size}, 0, 1, 1" # VL = n
736 yield f
"sv.dsld *{vn}, *{v}, {s_scalar}, {t_for_uv_shift}"
738 yield f
"addi {t_for_uv_shift}, 0, 0" # t = 0
739 yield f
"setvl. 0, {m}, {num_size}, 0, 1, 1" # VL = m
741 yield f
"sv.dsld *{un}, *{u}, {s_scalar}, {t_for_uv_shift}"
742 yield f
"setvl. 0, 0, {un_size}, 0, 1, 1" # VL = un_size
743 yield f
"or {index}, {m}, {m}" # index = m
744 assert index
== 3, "index must be r3"
746 yield f
"sv.or/m=1<<r3 *{un}, {t_for_uv_shift}, {t_for_uv_shift}"
748 # Step D2 and Step D7: loop
750 yield f
"subf {j}, {n_scalar}, {m}"
751 # j = min(j, q_size - 1)
752 yield f
"addi 0, 0, {q_size - 1}"
753 yield f
"minmax {j}, {j}, 0, 0 # maxd {j}, {j}, 0"
754 yield f
"divmod_loop:"
756 # Step D3: calculate q̂
757 yield f
"setvl. 0, 0, {un_size}, 0, 1, 1" # VL = un_size
760 # Step D2 and Step D7: loop
761 yield f
"addic. {j}, {j}, -1" # j -= 1
762 yield f
"bc 4, 0, divmod_loop # bge divmod_loop"
766 yield "divmod_return:"
767 yield "addi 1, 1, 176" # teardown stack frame
769 yield "mtspr 8, 0 # mtlr 0" # restore return address
770 yield "setvl 0, 0, 18, 0, 1, 1" # set VL to 18
771 yield "sv.ld *14, -144(1)" # restore all callee-save registers
772 yield "bclr 20, 0, 0 # blr"
776 return tuple(self
.__asm
_iter
())
780 # base is in r4-7, exp is in r8-11, mod is in r32-35
782 "mfspr 0, 8 # mflr 0",
783 "std 0, 16(1)", # save return address
784 "setvl 0, 0, 18, 0, 1, 1", # set VL to 18
785 "sv.std *14, -144(1)", # save all callee-save registers
786 "stdu 1, -176(1)", # create stack frame as required by ABI
788 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
789 "sv.or *16, *4, *4", # move base to r16-19
790 "sv.or *20, *8, *8", # move exp to r20-23
791 "sv.or *24, *32, *32", # move mod to r24-27
792 "sv.addi *28, 0, 0", # retval in r28-31
793 "addi 28, 0, 1", # retval = 1
795 "addi 14, 0, 256", # ctr in r14
798 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
799 "addi 3, 0, 1 # li 3, 1", # shift amount
800 "addi 0, 0, 0 # li 0, 0", # dsrd carry
801 "sv.dsrd/mrr *20, *20, 3, 0", # exp >>= 1; shifted out bit in r0
802 "cmpli 0, 1, 0, 0 # cmpldi 0, 0",
803 "bc 12, 2, powmod_256_else # beq powmod_256_else", # if lsb:
805 "sv.or *4, *28, *28", # copy retval to r4-7
806 "sv.or *8, *16, *16", # copy base to r8-11
807 "bl mul_256_to_512", # prod = retval * base
810 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
811 "sv.or *32, *24, *24", # copy mod to r32-35
813 "bl divmod_512_by_256", # prod % mod
814 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
815 "sv.or *28, *8, *8", # retval = prod % mod
819 "sv.or *4, *16, *16", # copy base to r4-7
820 "sv.or *8, *16, *16", # copy base to r8-11
821 "bl mul_256_to_512", # prod = base * base
824 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
825 "sv.or *32, *24, *24", # copy mod to r32-35
827 "bl divmod_512_by_256", # prod % mod
828 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
829 "sv.or *16, *8, *8", # base = prod % mod
831 "addic. 14, 14, -1", # decrement ctr and compare against zero
832 "bc 4, 2, powmod_256_loop # bne powmod_256_loop",
834 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
835 "sv.or *4, *28, *28", # move retval to r4-7
837 "addi 1, 1, 176", # teardown stack frame
839 "mtspr 8, 0 # mtlr 0", # restore return address
840 "setvl 0, 0, 18, 0, 1, 1", # set VL to 18
841 "sv.ld *14, -144(1)", # restore all callee-save registers
842 "bclr 20, 0, 0 # blr",
843 *MUL_256_X_256_TO_512_ASM
,
844 *DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM
,
848 def python_powmod_256_algorithm(base
, exp
, mod
):
851 lsb
= bool(exp
& 1) # rshift and retrieve lsb
861 class PowModCases(TestAccumulatorBase
):
def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
    """Register a test case that runs `instructions` like a called function.

    The link register (SPR 8) is preloaded with a fixed address and the
    case is told to stop once the PC reaches it, so code ending in `blr`
    terminates the run by "returning" there.

    instructions -- assembly source, handed to `assemble()`.
    expected -- expected final state; modified in place (r1, pc and LR).
    initial_regs -- initial GPR values; modified in place (r1 is set).
    src_loc_at -- forwarded to `add_case()` incremented by one
        (presumably so this wrapper's own frame is skipped when the
        source location is reported -- TODO confirm).
    """
    return_pc = 0x10000000
    stack_pointer = 0x1000000

    # r1 holds the stack pointer on entry and must be unchanged on exit
    expected.intregs[1] = stack_pointer
    initial_regs[1] = stack_pointer

    # execution is expected to end at the fake return address
    expected.pc = return_pc
    # NOTE(review): None appears to mean "don't check" (cf. `e.ca = None
    # # ignored` elsewhere in this file) -- confirm
    expected.sprs['LR'] = None

    self.add_case(
        assemble(instructions),
        initial_regs,
        initial_sprs={8: return_pc},
        stop_at_pc=return_pc,
        expected=expected,
        src_loc_at=src_loc_at + 1,
    )
873 def case_mul_256_x_256_to_512(self
):
875 a
= hash_256(f
"mul256 input a {i}")
876 b
= hash_256(f
"mul256 input b {i}")
882 a
= b
= (2**256 - 1) // 0xFF
884 with self
.subTest(a
=f
"{a:#_x}", b
=f
"{b:#_x}", y
=f
"{y:#_x}"):
885 # registers start filled with junk
886 initial_regs
= [0xABCDEF] * 128
888 # write a in LE order to regs 4-7
889 initial_regs
[4 + i
] = (a
>> (64 * i
)) % 2**64
890 # write b in LE order to regs 8-11
891 initial_regs
[8 + i
] = (b
>> (64 * i
)) % 2**64
892 # only check regs up to r11 since that's where the output is
893 e
= ExpectedState(int_regs
=initial_regs
[:12])
895 # write y in LE order to regs 4-11
896 e
.intregs
[4 + i
] = (y
>> (64 * i
)) % 2**64
898 self
.call_case(MUL_256_X_256_TO_512_ASM
, e
, initial_regs
)
901 def divmod_512x256_to_256x256_test_inputs():
902 yield (2 ** (256 - 1), 1)
903 yield (2 ** (512 - 1) - 1, 2 ** 256 - 1)
905 # test division by single word
906 yield (((1 << 256) - 1) << 32, 1 << 32)
907 yield (((1 << 192) - 1) << 32, 1 << 32)
908 yield (((1 << 64) - 1) << 32, 1 << 32)
909 yield (1 << 32, 1 << 32)
912 yield (0x8000 << 128 |
0xFFFE << 64, 0x8000 << 64 |
0xFFFF)
914 # tests where add back is required
915 yield (8 << (192 - 4) |
3, 2 << (192 - 4) |
1)
916 yield (0x8000 << 128 |
3, 0x2000 << 128 |
1)
917 yield (0x7FFF << 192 |
0x8000 << 128, 0x8000 << 128 |
1)
920 n
= hash_256(f
"divmod256 input n msb {i}")
922 n |
= hash_256(f
"divmod256 input n lsb {i}")
923 n_shift
= hash_256(f
"divmod256 input n shift {i}") % 512
925 d
= hash_256(f
"divmod256 input d {i}")
926 d_shift
= hash_256(f
"divmod256 input d shift {i}") % 256
933 def case_divmod_shift_sub_512x256_to_256x256(self
):
934 cases
= list(self
.divmod_512x256_to_256x256_test_inputs())
935 del cases
[2:-1] # speed up tests by removing most test cases
938 with self
.subTest(n
=f
"{n:#_x}", d
=f
"{d:#_x}",
939 q
=f
"{q:#_x}", r
=f
"{r:#_x}"):
940 # registers start filled with junk
941 initial_regs
= [0xABCDEF] * 128
943 # write n in LE order to regs 4-11
944 initial_regs
[4 + i
] = (n
>> (64 * i
)) % 2**64
946 # write d in LE order to regs 32-35
947 initial_regs
[32 + i
] = (d
>> (64 * i
)) % 2**64
948 # only check regs up to r11 since that's where the output is.
950 e
= ExpectedState(int_regs
=initial_regs
[:12], crregs
=0)
951 e
.intregs
[0] = 0 # leftovers -- ignore
952 e
.intregs
[3] = 1 # leftovers -- ignore
953 e
.ca
= None # ignored
955 # write q in LE order to regs 4-7
956 e
.intregs
[4 + i
] = (q
>> (64 * i
)) % 2**64
957 # write r in LE order to regs 8-11
958 e
.intregs
[8 + i
] = (r
>> (64 * i
)) % 2**64
961 DIVMOD_SHIFT_SUB_512x256_TO_256x256_ASM
, e
, initial_regs
)
963 def case_divmod_knuth_algorithm_d_512x256_to_256x256(self
):
964 cases
= list(self
.divmod_512x256_to_256x256_test_inputs())
965 asm
= DivModKnuthAlgorithmD().asm
971 # FIXME: only part of the algorithm is implemented,
972 # so we skip the cases that we expect to fail
975 with self
.subTest(n
=f
"{n:#_x}", d
=f
"{d:#_x}",
976 q
=f
"{q:#_x}", r
=f
"{r:#_x}"):
977 # registers start filled with junk
978 initial_regs
= [0xABCDEF] * 128
980 # write n in LE order to regs 4-11
981 initial_regs
[4 + i
] = (n
>> (64 * i
)) % 2**64
983 # write d in LE order to regs 32-35
984 initial_regs
[32 + i
] = (d
>> (64 * i
)) % 2**64
985 # only check regs up to r11 since that's where the output is.
987 e
= ExpectedState(int_regs
=initial_regs
[:12], crregs
=0)
988 e
.intregs
[0] = None # ignored
989 e
.intregs
[3] = None # ignored
990 e
.ca
= None # ignored
991 e
.sprs
['SVSHAPE0'] = None
993 # write q in LE order to regs 4-7
994 e
.intregs
[4 + i
] = (q
>> (64 * i
)) % 2**64
995 # write r in LE order to regs 8-11
996 e
.intregs
[8 + i
] = (r
>> (64 * i
)) % 2**64
998 self
.call_case(asm
, e
, initial_regs
)
1001 def powmod_256_test_inputs():
1003 base
= hash_256(f
"powmod256 input base {i}")
1004 exp
= hash_256(f
"powmod256 input exp {i}")
1005 mod
= hash_256(f
"powmod256 input mod {i}")
1009 mod
= 2 ** 256 - 189 # largest prime less than 2 ** 256
1013 yield (base
, exp
, mod
)
1015 @skip_case("FIXME: divmod is too slow to test powmod")
1016 def case_powmod_256(self
):
1017 for base
, exp
, mod
in PowModCases
.powmod_256_test_inputs():
1018 expected
= pow(base
, exp
, mod
)
1019 with self
.subTest(base
=f
"{base:#_x}", exp
=f
"{exp:#_x}",
1020 mod
=f
"{mod:#_x}", expected
=f
"{expected:#_x}"):
1021 # registers start filled with junk
1022 initial_regs
= [0xABCDEF] * 128
1024 # write base in LE order to regs 4-7
1025 initial_regs
[4 + i
] = (base
>> (64 * i
)) % 2**64
1027 # write exp in LE order to regs 8-11
1028 initial_regs
[8 + i
] = (exp
>> (64 * i
)) % 2**64
1030 # write mod in LE order to regs 32-35
1031 initial_regs
[32 + i
] = (mod
>> (64 * i
)) % 2**64
1032 # only check regs up to r7 since that's where the output is.
1034 e
= ExpectedState(int_regs
=initial_regs
[:8], crregs
=0)
1035 e
.ca
= None # ignored
1037 # write output in LE order to regs 4-7
1038 e
.intregs
[4 + i
] = (expected
>> (64 * i
)) % 2**64
1040 self
.call_case(POWMOD_256_ASM
, e
, initial_regs
)
1043 # for running "quick" simple investigations
1044 if __name__
== "__main__":
1045 # first check if python_mul_algorithm works
1046 a
= b
= (99, 99, 99, 99)
1047 expected
= [1, 0, 0, 0, 98, 99, 99, 99]
1048 assert python_mul_algorithm(a
, b
) == expected
1050 # now test python_mul_algorithm2 *against* python_mul_algorithm
1052 random
.seed(0) # reproducible values
1053 for i
in range(10000):
1057 a
.append(random
.randint(0, 99))
1058 b
.append(random
.randint(0, 99))
1059 expected
= python_mul_algorithm(a
, b
)
1060 testing
= python_mul_algorithm2(a
, b
)
1061 report
= "%+17s * %-17s = %s\n" % (repr(a
), repr(b
), repr(expected
))
1062 report
+= " (%s)" % repr(testing
)
1064 assert expected
== testing