# SPDX-License-Identifier: LGPL-3.0-or-later
2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 """ modular exponentiation (`pow(x, y, z)`)
11 * https://bugs.libre-soc.org/show_bug.cgi?id=1044
14 from openpower
.test
.common
import TestAccumulatorBase
, skip_case
15 from openpower
.test
.state
import ExpectedState
16 from openpower
.test
.util
import assemble
17 from nmutil
.sim_util
import hash_256
20 MUL_256_X_256_TO_512_ASM
= (
22 # a is in r4-7, b is in r8-11
23 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
24 "sv.or *32, *4, *4", # move args to r32-39
25 # a is now in r32-35, b is in r36-39, y is in r4-11, t is in r40-44
26 "sv.addi *4, 0, 0", # clear output
27 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
28 "sv.maddedu *4, *32, 36, 8", # first partial-product a * b[0]
30 "sv.maddedu *40, *32, 37, 44", # second partial-product a * b[1]
32 "sv.adde *6, *6, *41",
34 "sv.maddedu *40, *32, 38, 44", # third partial-product a * b[2]
36 "sv.adde *7, *7, *41",
38 "sv.maddedu *40, *32, 39, 44", # final partial-product a * b[3]
40 "sv.adde *8, *8, *41",
41 "bclr 20, 0, 0 # blr",
45 def _python_mul_algorithm(a
, b
):
46 # version of the MUL_256_X_256_TO_512_ASM algorithm using base 100 rather
47 # than 2^64, since that's easier to read.
48 # run this file in a debugger to see all the intermediate values.
51 return y
% 100, y
// 100
55 return y
% 100, y
// 100
59 return y
% 100, y
// 100
64 y
[i
], y
[4] = maddedu(a
[0], b
[i
], y
[4])
67 t
[i
], t
[4] = maddedu(a
[1], b
[i
], t
[4])
68 y
[1], ca
= addc(y
[1], t
[0])
70 y
[2 + i
], ca
= adde(y
[2 + i
], t
[1 + i
], ca
)
73 t
[i
], t
[4] = maddedu(a
[2], b
[i
], t
[4])
74 y
[2], ca
= addc(y
[2], t
[0])
76 y
[3 + i
], ca
= adde(y
[3 + i
], t
[1 + i
], ca
)
79 t
[i
], t
[4] = maddedu(a
[3], b
[i
], t
[4])
80 y
[3], ca
= addc(y
[3], t
[0])
82 y
[4 + i
], ca
= adde(y
[4 + i
], t
[1 + i
], ca
)
86 DIVMOD_512x256_TO_256x256_ASM
= (
87 # extremely slow and simplistic shift and subtract algorithm.
88 # a future task is to rewrite to use Knuth's Algorithm D,
89 # which is generally an order of magnitude faster
91 # n is in r4-11, d is in r32-35
92 "addi 3, 0, 256 # li 3, 256",
93 "mtspr 9, 3 # mtctr 3", # set CTR to 256
94 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
96 "sv.or *40, *4, *4", # assign n to r, in r40-47
97 # shifted_d is in r32-39
98 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
99 "addi 3, 0, 1 # li 3, 1", # shift amount
100 "addi 0, 0, 0 # li 0, 0", # dsrd carry
101 "sv.dsrd/mrr *36, *32, 3, 0", # shifted_d = d << (256 - 1)
102 "sv.addi *32, 0, 0", # clear lsb half
103 "sv.or 35, 0, 0", # move carry to correct location
105 "sv.addi *4, 0, 0", # clear q
107 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
108 "subfc 0, 0, 0", # set CA
110 "sv.subfe *48, *32, *40", # diff = r - shifted_d
111 # not borrowed is in CA
112 "mcrxrx 0", # move CA to CR0.eq
113 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
114 "addi 0, 0, 0 # li 0, 0", # dsld carry
115 "sv.dsld *4, *4, 3, 0", # q <<= 1 (1 is in r3)
116 "setvl 0, 0, 8, 0, 1, 1", # set VL to 8
117 "bc 4, 2, divmod_else # bne divmod_else", # if borrowed goto divmod_else
118 "ori 4, 4, 1", # q |= 1
119 "sv.or *40, *48, *48", # r = diff
121 "addi 0, 0, 0 # li 0, 0", # dsld carry
122 "sv.dsld *40, *40, 3, 0", # r <<= 1 (1 is in r3)
123 "bc 16, 0, divmod_loop # bdnz divmod_loop",
124 "setvl 0, 0, 4, 0, 1, 1", # set VL to 4
126 "sv.or *8, *44, 44", # r >>= 256
127 # q is in r4-7, r is in r8-11
128 "bclr 20, 0, 0 # blr",
132 def python_divmod_algorithm(n
, d
, width
=256):
133 assert n
>= 0 and d
> 0 and width
> 0 and n
< (d
<< width
), "invalid input"
135 shifted_d
= d
<< (width
- 1)
137 for _
in range(width
):
148 class PowModCases(TestAccumulatorBase
):
def call_case(self, instructions, expected, initial_regs, src_loc_at=0):
    """Register a test case that runs `instructions` until a fixed stop PC.

    The assembly listing is passed through `assemble` and handed to
    `self.add_case` together with the initial register file and the
    expected final state.

    Both `initial_regs` and `expected` are modified in place:

    * index 1 of `initial_regs` and of `expected.intregs` is set to
      0x1000000 (used as the stack pointer),
    * `expected.pc` is set to the stop address 0x10000000,
    * `expected.sprs['LR']` is set to None.

    :param instructions: assembly source lines accepted by `assemble`.
    :param expected: expected final state; updated as described above.
    :param initial_regs: mutable sequence of initial integer register
        values; entry 1 is overwritten.
    :param src_loc_at: caller-frame offset forwarded to `add_case`,
        incremented by one here.
    """
    return_pc = 0x10000000
    stack_pointer = 0x1000000

    # r1 holds the stack pointer both before and after the run
    expected.intregs[1] = stack_pointer
    initial_regs[1] = stack_pointer

    expected.pc = return_pc
    # NOTE(review): None presumably means "don't check LR" -- confirm
    # against ExpectedState.
    expected.sprs['LR'] = None

    self.add_case(
        assemble(instructions),
        initial_regs,
        # SPR 8 starts out as the stop address; presumably this is LR, so
        # the routine's closing `blr` jumps to where simulation stops --
        # confirm.
        initial_sprs={8: return_pc},
        stop_at_pc=return_pc,
        expected=expected,
        # NOTE(review): +1 presumably skips this wrapper's own frame when
        # the source location is reported -- confirm.
        src_loc_at=src_loc_at + 1,
    )
160 def case_mul_256_x_256_to_512(self
):
162 a
= hash_256(f
"mul256 input a {i}")
163 b
= hash_256(f
"mul256 input b {i}")
169 a
= b
= (2**256 - 1) // 0xFF
171 with self
.subTest(a
=f
"{a:#_x}", b
=f
"{b:#_x}", y
=f
"{y:#_x}"):
172 # registers start filled with junk
173 initial_regs
= [0xABCDEF] * 128
175 # write a in LE order to regs 4-7
176 initial_regs
[4 + i
] = (a
>> (64 * i
)) % 2**64
177 # write b in LE order to regs 8-11
178 initial_regs
[8 + i
] = (b
>> (64 * i
)) % 2**64
179 # only check regs up to r11 since that's where the output is
180 e
= ExpectedState(int_regs
=initial_regs
[:12])
182 # write y in LE order to regs 4-11
183 e
.intregs
[4 + i
] = (y
>> (64 * i
)) % 2**64
185 self
.call_case(MUL_256_X_256_TO_512_ASM
, e
, initial_regs
)
187 @skip_case("FIXME: wip -- currently broken")
188 def case_divmod_512x256_to_256x256(self
):
190 n
= hash_256(f
"divmod256 input n msb {i}")
192 n |
= hash_256(f
"divmod256 input n lsb {i}")
193 d
= hash_256(f
"divmod256 input d {i}")
200 n
= 2 ** (512 - 1) - 1
207 with self
.subTest(n
=f
"{n:#_x}", d
=f
"{d:#_x}",
208 q
=f
"{q:#_x}", r
=f
"{r:#_x}"):
209 # registers start filled with junk
210 initial_regs
= [0xABCDEF] * 128
212 # write n in LE order to regs 4-11
213 initial_regs
[4 + i
] = (n
>> (64 * i
)) % 2**64
215 # write d in LE order to regs 32-35
216 initial_regs
[32 + i
] = (d
>> (64 * i
)) % 2**64
217 # only check regs up to r11 since that's where the output is.
219 e
= ExpectedState(int_regs
=initial_regs
[:12], crregs
=0)
220 e
.intregs
[0] = 0 # leftovers -- ignore
221 e
.intregs
[3] = 1 # leftovers -- ignore
223 # write q in LE order to regs 4-7
224 e
.intregs
[4 + i
] = (q
>> (64 * i
)) % 2**64
225 # write r in LE order to regs 8-11
226 e
.intregs
[8 + i
] = (r
>> (64 * i
)) % 2**64
228 self
.call_case(DIVMOD_512x256_TO_256x256_ASM
, e
, initial_regs
)
230 # TODO: add 256-bit modular exponentiation
if __name__ == "__main__":
    # Sanity-check the base-100 model of the multiply routine:
    # 99999999 * 99999999 == 9999999800000001, and the expected list below
    # holds that product as base-100 digits, least-significant digit first.
    a = (99, 99, 99, 99)
    b = a
    assert _python_mul_algorithm(a, b) == [
        1, 0, 0, 0,
        98, 99, 99, 99,
    ]