1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
3 # Copyright 2023 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
5 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
6 # of Horizon 2020 EU Programme 957073.
8 # * https://bugs.libre-soc.org/show_bug.cgi?id=1044
10 """ modular exponentiation (`pow(x, y, z)`)
14 * https://bugs.libre-soc.org/show_bug.cgi?id=1044
17 from openpower
.test
.common
import TestAccumulatorBase
, skip_case
18 from openpower
.test
.state
import ExpectedState
19 from openpower
.test
.util
import assemble
20 from nmutil
.sim_util
import hash_256
21 from openpower
.util
import log
24 MUL_256_X_256_TO_512_ASM
= (
26 # a is in r4-7, b is in r8-11
27 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
28 "sv.or *32, *4, *4", # move args to r32-39
29 # a is now in r32-35, b is in r36-39, y is in r4-11, t is in r40-44
30 "sv.addi *4, 0, 0", # clear output
31 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
32 "sv.maddedu *4, *32, 36, 8", # first partial-product a * b[0]
34 "sv.maddedu *40, *32, 37, 44", # second partial-product a * b[1]
36 "sv.adde *6, *6, *41",
38 "sv.maddedu *40, *32, 38, 44", # third partial-product a * b[2]
40 "sv.adde *7, *7, *41",
42 "sv.maddedu *40, *32, 39, 44", # final partial-product a * b[3]
44 "sv.adde *8, *8, *41",
45 "bclr 20, 0, 0 # blr",
# TODO: these helpers really need to go into a common util file, see
# openpower/decoder/isa/poly1305-donna.py:def _DSRD(lo, hi, sh)
# (the versions here work modulo 100, but the general idea is the same)
55 return y
% 100, y
// 100
60 return y
% 100, y
// 100
65 return y
% 100, y
// 100
68 def python_mul_algorithm(a
, b
):
69 # version of the MUL_256_X_256_TO_512_ASM algorithm using base 100 rather
70 # than 2^64, since that's easier to read.
71 # run this file in a debugger to see all the intermediate values.
75 y
[i
], y
[4] = maddedu(a
[0], b
[i
], y
[4])
78 t
[i
], t
[4] = maddedu(a
[1], b
[i
], t
[4])
79 y
[1], ca
= addc(y
[1], t
[0])
81 y
[2 + i
], ca
= adde(y
[2 + i
], t
[1 + i
], ca
)
84 t
[i
], t
[4] = maddedu(a
[2], b
[i
], t
[4])
85 y
[2], ca
= addc(y
[2], t
[0])
87 y
[3 + i
], ca
= adde(y
[3 + i
], t
[1 + i
], ca
)
90 t
[i
], t
[4] = maddedu(a
[3], b
[i
], t
[4])
91 y
[3], ca
= addc(y
[3], t
[0])
93 y
[4 + i
], ca
= adde(y
[4 + i
], t
[1 + i
], ca
)
97 def python_mul_algorithm2(a
, b
):
98 # version 2 of the MUL_256_X_256_TO_512_ASM algorithm using base 100 rather
99 # than 2^64, since that's easier to read.
100 # the idea here is that it will "morph" into something more akin to
101 # using REMAP bigmul (first using REMAP Indexed)
103 # create a schedule for use below. the "end of inner loop" marker is 0b01
108 iyl
.append((iy
+i
, i
==3))
111 iyl
.append((iy
+i
, i
==4))
114 y
= [0] * 8 # result y and temp t of same size
115 t
= [0] * 8 # no need after this to set t[4] to zero
117 for i
in range(4): # use t[iy+4] as a 64-bit carry
118 t
[iy
+i
], t
[iy
+4] = maddedu(a
[iy
], b
[i
], t
[iy
+4])
120 for i
in range(5): # add vec t to y with 1-bit carry
122 y
[idx
], ca
= adde(y
[idx
], t
[idx
], ca
)
126 DIVMOD_512x256_TO_256x256_ASM
= (
127 # extremely slow and simplistic shift and subtract algorithm.
128 # a future task is to rewrite to use Knuth's Algorithm D,
129 # which is generally an order of magnitude faster
130 "divmod_512_by_256:",
131 # n is in r4-11, d is in r32-35
132 "addi 3, 0, 256 # li 3, 256",
133 "mtspr 9, 3 # mtctr 3", # set CTR to 256
134 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
136 "sv.or *40, *4, *4", # assign n to r, in r40-47
137 # shifted_d is in r32-39
138 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
139 "addi 3, 0, 1 # li 3, 1", # shift amount
140 "addi 0, 0, 0 # li 0, 0", # dsrd carry
141 "sv.dsrd/mrr *36, *32, 3, 0", # shifted_d = d << (256 - 1)
142 "sv.addi *32, 0, 0", # clear lsb half
143 "sv.or 35, 0, 0", # move carry to correct location
145 "sv.addi *4, 0, 0", # clear q
147 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
148 "subfc 0, 0, 0", # set CA
150 "sv.subfe *48, *32, *40", # diff = r - shifted_d
151 # not borrowed is in CA
152 "mcrxrx 0", # move CA to CR0.eq
153 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
154 "addi 0, 0, 0 # li 0, 0", # dsld carry
155 "sv.dsld *4, *4, 3, 0", # q <<= 1 (1 is in r3)
156 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
157 "bc 4, 2, divmod_else # bne divmod_else", # if borrowed goto divmod_else
158 "ori 4, 4, 1", # q |= 1
159 "sv.or *40, *48, *48", # r = diff
161 "addi 0, 0, 0 # li 0, 0", # dsld carry
162 "sv.dsld *40, *40, 3, 0", # r <<= 1 (1 is in r3)
163 "bc 16, 0, divmod_loop # bdnz divmod_loop",
164 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
166 "sv.or *8, *44, *44", # r >>= 256
167 # q is in r4-7, r is in r8-11
168 "bclr 20, 0, 0 # blr",
172 class _DivModRegsRegexLogger
:
173 """ logger that logs a regex that matches the expected register dump for
174 the currently tracked `locals` -- quite useful for debugging
177 def __init__(self
, enabled
=True):
179 self
.enabled
= enabled
181 def log(self
, locals_
, **changes
):
184 # create a variable `a`:
187 # we invoke `locals()` each time since python doesn't guarantee
188 # it's up-to-date otherwise
189 logger.log(locals(), a=(4, 6)) # `a` starts at r4 and uses 6 registers
193 logger.log(locals()) # keeps using `a`
197 logger.log(locals(), a=None, b=(4, 6)) # remove `a` and add `b`
201 for k
, v
in changes
.items():
203 del self
.__tracked
[k
]
205 self
.__tracked
[k
] = v
208 for name
, (start_gpr
, size
) in self
.__tracked
.items():
209 value
= locals_
[name
]
210 for i
in range(size
):
211 assert gprs
[start_gpr
+ i
] is None, "overlapping values"
212 gprs
[start_gpr
+ i
] = (value
>> 64 * i
) % 2 ** 64
215 # after building `gprs` so we catch any missing/invalid locals
220 for i
in range(0, 128, 8):
221 segments
.append(f
"reg +{i}")
222 for value
in gprs
[i
:i
+ 8]:
224 segments
.append(" +[0-9a-f]+")
226 segments
.append(f
" +{value:08x}")
227 segments
.append("\\n")
228 log("DIVMOD REGEX:", "".join(segments
))
231 def python_divmod_algorithm(n
, d
, width
=256, log_regex
=False):
232 assert n
>= 0 and d
> 0 and width
> 0 and n
< (d
<< width
), "invalid input"
233 do_log
= _DivModRegsRegexLogger(enabled
=log_regex
).log
235 do_log(locals(), n
=(4, 8), d
=(32, 4))
238 do_log(locals(), n
=None, r
=(40, 8))
240 shifted_d
= d
<< (width
- 1)
241 do_log(locals(), d
=None, shifted_d
=(32, 8))
244 do_log(locals(), q
=(4, 4))
246 for _
in range(width
):
249 do_log(locals(), diff
=(48, 8))
265 do_log(locals(), r
=(8, 4))
270 class PowModCases(TestAccumulatorBase
):
def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
    """Register a test case that runs `instructions` until a fixed stop PC.

    The assembled program is handed to `self.add_case` together with the
    initial register file, an initial SPR 8 holding the stop address, and
    the expected final state.

    :param instructions: assembly source lines, passed to `assemble`.
    :param expected: expected final state; modified in place (r1, `pc`
        and the `LR` SPR entry are overwritten here).
    :param initial_regs: initial integer registers; entry 1 is
        overwritten in place with the stack pointer value.
    :param src_loc_at: extra caller-frame depth; forwarded to `add_case`
        incremented by one to account for this wrapper.
    """
    halt_address = 0x10000000
    stack_pointer = 0x1000000

    # r1 carries the stack pointer both on entry and in the expected result.
    # `expected` is written first, matching the original chained assignment.
    expected.intregs[1] = stack_pointer
    initial_regs[1] = stack_pointer

    expected.pc = halt_address
    # LR is set to None in the expected state -- presumably "don't check";
    # confirm against ExpectedState.
    expected.sprs['LR'] = None

    case_options = {
        # SPR 8 starts at the stop address (presumably LR, so that the
        # routine's final `blr` lands on `stop_at_pc` -- confirm).
        "initial_sprs": {8: halt_address},
        "stop_at_pc": halt_address,
        "expected": expected,
        "src_loc_at": src_loc_at + 1,
    }
    self.add_case(assemble(instructions), initial_regs, **case_options)
282 def case_mul_256_x_256_to_512(self
):
284 a
= hash_256(f
"mul256 input a {i}")
285 b
= hash_256(f
"mul256 input b {i}")
291 a
= b
= (2**256 - 1) // 0xFF
293 with self
.subTest(a
=f
"{a:#_x}", b
=f
"{b:#_x}", y
=f
"{y:#_x}"):
294 # registers start filled with junk
295 initial_regs
= [0xABCDEF] * 128
297 # write a in LE order to regs 4-7
298 initial_regs
[4 + i
] = (a
>> (64 * i
)) % 2**64
299 # write b in LE order to regs 8-11
300 initial_regs
[8 + i
] = (b
>> (64 * i
)) % 2**64
301 # only check regs up to r11 since that's where the output is
302 e
= ExpectedState(int_regs
=initial_regs
[:12])
304 # write y in LE order to regs 4-11
305 e
.intregs
[4 + i
] = (y
>> (64 * i
)) % 2**64
307 self
.call_case(MUL_256_X_256_TO_512_ASM
, e
, initial_regs
)
310 def divmod_512x256_to_256x256_test_inputs():
312 n
= hash_256(f
"divmod256 input n msb {i}")
314 n |
= hash_256(f
"divmod256 input n lsb {i}")
315 d
= hash_256(f
"divmod256 input d {i}")
322 n
= 2 ** (512 - 1) - 1
330 def case_divmod_512x256_to_256x256(self
):
331 for n
, d
in self
.divmod_512x256_to_256x256_test_inputs():
333 with self
.subTest(n
=f
"{n:#_x}", d
=f
"{d:#_x}",
334 q
=f
"{q:#_x}", r
=f
"{r:#_x}"):
335 # registers start filled with junk
336 initial_regs
= [0xABCDEF] * 128
338 # write n in LE order to regs 4-11
339 initial_regs
[4 + i
] = (n
>> (64 * i
)) % 2**64
341 # write d in LE order to regs 32-35
342 initial_regs
[32 + i
] = (d
>> (64 * i
)) % 2**64
343 # only check regs up to r11 since that's where the output is.
345 e
= ExpectedState(int_regs
=initial_regs
[:12], crregs
=0)
346 e
.intregs
[0] = 0 # leftovers -- ignore
347 e
.intregs
[3] = 1 # leftovers -- ignore
348 e
.ca
= None # ignored
350 # write q in LE order to regs 4-7
351 e
.intregs
[4 + i
] = (q
>> (64 * i
)) % 2**64
352 # write r in LE order to regs 8-11
353 e
.intregs
[8 + i
] = (r
>> (64 * i
)) % 2**64
355 self
.call_case(DIVMOD_512x256_TO_256x256_ASM
, e
, initial_regs
)
357 # TODO: add 256-bit modular exponentiation
360 # for running "quick" simple investigations
361 if __name__
== "__main__":
362 # first check if python_mul_algorithm works
363 a
= b
= (99, 99, 99, 99)
364 expected
= [1, 0, 0, 0, 98, 99, 99, 99]
365 assert python_mul_algorithm(a
, b
) == expected
367 # now test python_mul_algorithm2 *against* python_mul_algorithm
369 random
.seed(0) # reproducible values
370 for i
in range(10000):
374 a
.append(random
.randint(0, 99))
375 b
.append(random
.randint(0, 99))
376 expected
= python_mul_algorithm(a
, b
)
377 testing
= python_mul_algorithm2(a
, b
)
378 report
= "%+17s * %-17s = %s\n" % (repr(a
), repr(b
), repr(expected
))
379 report
+= " (%s)" % repr(testing
)
381 assert expected
== testing