src/openpower/decoder/isa/test_caller_svp64_chacha20.py
1 """Implementation of chacha20 core in SVP64
2 Copyright (C) 2022,2023 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
3 Licensed under the LGPLv3+
4 Funded by NLnet NGI-ASSURE under EU grant agreement No 957073.
5 * https://nlnet.nl/project/LibreSOC-GigabitRouter/
6 * https://bugs.libre-soc.org/show_bug.cgi?id=965
7 * https://libre-soc.org/openpower/sv/cookbook/chacha20/
8 """
9
10 import unittest
11 from copy import deepcopy
12
13 from nmutil.formaltest import FHDLTestCase
14 from openpower.decoder.isa.caller import SVP64State, set_masked_reg
15 from openpower.decoder.isa.test_caller import run_tst
16 from openpower.decoder.selectable_int import SelectableInt
17 from openpower.simulator.program import Program
18 from openpower.insndb.asm import SVP64Asm
19

# originally from https://github.com/pts/chacha20
# the function is turned into a "schedule" of the
# operations to be applied, where the add32 xor32 rotl32
# are actually carried out by the sthth_round
# higher-order-function. this "split-out" (code-morph)
# of the original code by pts@fazekas.hu allows us to
# share the "schedule" between the pure-python chacha20
# and the SVP64 implementation. the schedule is static:
# it can be printed out and loaded as "magic constants"
# into registers. more details:
# https://libre-soc.org/openpower/sv/cookbook/chacha20/
def quarter_round_schedule(x, a, b, c, d):
    """collate list of reg-offsets for use with svindex/svremap
    """
    #x[a] = (x[a] + x[b]) & 0xffffffff - add32
    #x[d] = x[d] ^ x[a] - xor32
    #x[d] = rotate(x[d], 16) - rotl32
    x.append((a, b, d, 16))

    #x[c] = (x[c] + x[d]) & 0xffffffff - add32
    #x[b] = x[b] ^ x[c] - xor32
    #x[b] = rotate(x[b], 12) - rotl32
    x.append((c, d, b, 12))

    #x[a] = (x[a] + x[b]) & 0xffffffff - add32
    #x[d] = x[d] ^ x[a] - xor32
    #x[d] = rotate(x[d], 8) - rotl32
    x.append((a, b, d, 8))

    #x[c] = (x[c] + x[d]) & 0xffffffff - add32
    #x[b] = x[b] ^ x[c] - xor32
    #x[b] = rotate(x[b], 7) - rotl32
    x.append((c, d, b, 7))

def rotl32(v, c):
    c = c & 0x1f
    res = ((v << c) & 0xffffffff) | v >> (32 - c)
    print("op rotl32", hex(res), hex(v), hex(c))
    return res


def add32(a, b):
    res = (a + b) & 0xffffffff
    print("op add32", hex(res), hex(a), hex(b))
    return res


def xor32(a, b):
    res = a ^ b
    print("op xor32", hex(res), hex(a), hex(b))
    return res

# originally in pts's code there were 4 of these, explicitly loop-unrolled.
# the common constants were extracted (a,b,c,d,rot) and this is what is left
def sthth_round(x, a, b, d, rot):
    x[a] = add32 (x[a], x[b])
    x[d] = xor32 (x[d], x[a])
    x[d] = rotl32(x[d], rot)

# pts's version of quarter_round has the add/xor/rot explicitly
# loop-unrolled four times. instead we call the 16th-round function
# (sthth_round) with the appropriate offsets/rot-magic-constants.
def quarter_round(x, a, b, c, d):
    """perform one chacha20 quarter-round directly (add32/xor32/rotl32)
    """
    sthth_round(x, a, b, d, 16)
    sthth_round(x, c, d, b, 12)
    sthth_round(x, a, b, d, 8)
    sthth_round(x, c, d, b, 7)


# again in pts's version, this is what was originally
# the loop around quarter_round. we can either pass in
# a function that simply collates the indices *or*
# actually do the same job as pts's original code,
# just by passing in a different fn.
def chacha_idx_schedule(x, fn=quarter_round_schedule):
    fn(x, 0, 4, 8, 12)
    fn(x, 1, 5, 9, 13)
    fn(x, 2, 6, 10, 14)
    fn(x, 3, 7, 11, 15)
    fn(x, 0, 5, 10, 15)
    fn(x, 1, 6, 11, 12)
    fn(x, 2, 7, 8, 13)
    fn(x, 3, 4, 9, 14)

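
# A small sketch (added for illustration, not part of the original test):
# as the comments above explain, the same "schedule" can either be collated
# into (a, b, d, rot) tuples or executed immediately.  this hypothetical
# helper demonstrates that replaying the collated schedule through
# sthth_round produces exactly the same state as calling quarter_round
# directly.
def _demo_schedule_equivalence(state16):
    """hypothetical check: one double-round via the schedule vs directly"""
    sched = []
    chacha_idx_schedule(sched, fn=quarter_round_schedule)
    via_schedule = list(state16)
    for a, b, d, rot in sched:          # replay the 32 collated operations
        sthth_round(via_schedule, a, b, d, rot)
    direct = list(state16)
    chacha_idx_schedule(direct, fn=quarter_round)  # do the ops immediately
    assert via_schedule == direct
    return direct
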

class SVSTATETestCase(FHDLTestCase):

    def _check_regs(self, sim, expected):
        print("GPR")
        sim.gpr.dump()
        for i in range(32):
            self.assertEqual(sim.gpr(i), SelectableInt(expected[i], 64),
                             "GPR %d %x expected %x" % \
                             (i, sim.gpr(i).value, expected[i]))

    def test_1_sv_chacha20_main_rounds(self):
        """chacha20 main rounds

        RA, RB, RS and RT are set up via Indexing to perform the *individual*
        add/xor/rotl32 operations (with elwidth=32)

        the inner loop uses "svstep." which detects src/dst-step reaching
        the end of the loop, setting CR0.eq=1. no need for an additional
        counter-register-with-a-decrement. this has the side-effect of
        freeing up CTR for use as a straight decrement-counter.

        both loops are 100% deterministic meaning that there should be
        *ZERO* branch-prediction misses, obviating a need for loop-unrolling.
        """

        nrounds = 2  # should be 10 for full algorithm
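        # note (added observation): one call to chacha_idx_schedule is a
        # chacha20 "double round" (4 column + 4 diagonal quarter-rounds),
        # so the full algorithm needs 10 of them, i.e. 20 rounds in total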

        block = 24  # register for block of 16
        vl = 22  # copy of VL placed in here
        SHAPE0 = 8
        SHAPE1 = 12
        SHAPE2 = 16
        shifts = 20  # registers for 4 32-bit shift amounts
        ctr = 7  # register for CTR

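        # rough register-file map implied by the constants above (added as
        # a reading aid, not taken from the original source):
        #   r8-r11   SHAPE0: 32 x 8-bit indices (1st element of each entry)
        #   r12-r15  SHAPE1: 32 x 8-bit indices (2nd element)
        #   r16-r19  SHAPE2: 32 x 8-bit indices (3rd element)
        #   r20-r21  four 32-bit rotate amounts (16, 12, 8, 7)
        #   r24-r31  the sixteen 32-bit chacha20 state words, two per GPR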
        isa = SVP64Asm([
            # set up VL=32 vertical-first, and SVSHAPEs 0-2
            # vertical-first, set MAXVL (and r17)
            'setvl 0, 0, 32, 1, 1, 1',  # vertical-first, set VL
            'svindex %d, 0, 1, 3, 0, 1, 0' % (SHAPE0//2),  # SVSHAPE0, a
            'svindex %d, 1, 1, 3, 0, 1, 0' % (SHAPE1//2),  # SVSHAPE1, b
            'svindex %d, 2, 1, 3, 0, 1, 0' % (SHAPE2//2),  # SVSHAPE2, c
            'svshape2 0, 0, 3, 4, 0, 1',  # SVSHAPE3, shift amount, mod 4
            # establish CTR for outer round count
            'addi %d, 0, %d' % (ctr, nrounds),  # set number of rounds
            'mtspr 9, %d' % ctr,  # set CTR to number of rounds
            # outer loop begins here (standard CTR loop)
            'setvl 0, 0, 32, 1, 1, 1',  # vertical-first, set VL
            # inner loop begins here. add-xor-rotl32 with remap, step, branch
            'svremap 31, 1, 0, 0, 0, 0, 0',  # RA=1, RB=0, RT=0 (0b01011)
            'sv.add/w=32 *%d, *%d, *%d' % (block, block, block),
            'svremap 31, 2, 0, 2, 2, 0, 0',  # RA=2, RB=0, RS=2 (0b00111)
            'sv.xor/w=32 *%d, *%d, *%d' % (block, block, block),
            'svremap 31, 0, 3, 2, 2, 0, 0',  # RA=2, RB=3, RS=2 (0b01110)
            'sv.rldcl/w=32 *%d, *%d, *%d, 0' % (block, block, shifts),
            'svstep. %d, 0, 1, 0' % ctr,  # step to next in-regs element
            'bc 6, 3, -0x28',  # svstep. Rc=1 loop-end-condition?
            # inner-loop done: outer loop standard CTR-decrement to setvl again
            'bc 16, 0, -0x30',
        ])
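        # note on the branch displacements above (added observation, derived
        # by reading the listing): sv.*-prefixed instructions are 8 bytes,
        # the rest 4, so -0x28 (-40) from the first 'bc' lands back on the
        # first 'svremap' (top of the inner loop), and -0x30 (-48) from the
        # final 'bc' lands back on the second 'setvl' (top of the outer loop)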
        lst = list(isa)
        print("listing", lst)

        schedule = []
        chacha_idx_schedule(schedule, fn=quarter_round_schedule)

        # initial values in GPR regfile
        initial_regs = [0] * 128

        # offsets for a b c
        for i, (a, b, c, d) in enumerate(schedule):
            print("chacha20 schedule", i, hex(a), hex(b), hex(c), hex(d))
            set_masked_reg(initial_regs, SHAPE0, i, ew_bits=8, value=a)
            set_masked_reg(initial_regs, SHAPE1, i, ew_bits=8, value=b)
            set_masked_reg(initial_regs, SHAPE2, i, ew_bits=8, value=c)

        # offsets for d (modulo 4 shift amount)
        shiftvals = [16, 12, 8, 7]  # chacha20 shifts
        for i in range(4):
            set_masked_reg(initial_regs, shifts, i, ew_bits=32,
                           value=shiftvals[i])
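        # (added note) the four rotate amounts pack two-per-GPR into r20-r21;
        # the mod-4 SVSHAPE3 set up by 'svshape2' above lets sv.rldcl's RB
        # (see the 'RB=3' svremap comment) cycle through 16, 12, 8, 7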

        # set up input test vector then pack it into regs
        x = [0] * 16
        x[0] = 0x61707865
        x[1] = 0x3320646e
        x[2] = 0x79622d32
        x[3] = 0x6b206574
        x[4] = 0x6d8bc55e
        x[5] = 0xa5e04f51
        x[6] = 0xea0d1e6f
        x[7] = 0x5a09dc7b
        x[8] = 0x18b6f510
        x[9] = 0x26f2b6bd
        x[10] = 0x7b59cc2f
        x[11] = 0xefb330b2
        x[12] = 0xcff545a3
        x[13] = 0x7c512380
        x[14] = 0x75f0fcc0
        x[15] = 0x5f868c74

        # use packing function which emulates element-width overrides @ 32-bit
        for i in range(16):
            set_masked_reg(initial_regs, block, i, ew_bits=32, value=x[i])

        # SVSTATE vl=32
        svstate = SVP64State()
        #svstate.vl = 32 # VL
        #svstate.maxvl = 32 # MAXVL
        print("SVSTATE", bin(svstate.asint()))

        # copy before running, compute expected results
        expected_regs = deepcopy(initial_regs)
        expected_regs[ctr] = 0  # reaches zero
        #expected_regs[vl] = 32 # gets set to MAXVL
        expected = deepcopy(x)
        # use the pts-derived quarter_round function to
        # compute a pure-python version of chacha20
        for i in range(nrounds):
            chacha_idx_schedule(expected, fn=quarter_round)
        for i in range(16):
            set_masked_reg(expected_regs, block, i, ew_bits=32,
                           value=expected[i])

        with Program(lst, bigendian=False) as program:
            sim = self.run_tst_program(program, initial_regs, svstate=svstate)

            # print out expected: 16 values @ 32-bit ea -> QTY8 64-bit regs
            for i in range(8):
                RS = sim.gpr(i+block).value
                print("expected", i+block, hex(RS), hex(expected_regs[i+block]))

            print(sim.spr)
            SVSHAPE0 = sim.spr['SVSHAPE0']
            SVSHAPE1 = sim.spr['SVSHAPE1']
            print("SVSTATE after", bin(sim.svstate.asint()))
            print(" vl", bin(sim.svstate.vl))
            print(" mvl", bin(sim.svstate.maxvl))
            print(" srcstep", bin(sim.svstate.srcstep))
            print(" dststep", bin(sim.svstate.dststep))
            print(" RMpst", bin(sim.svstate.RMpst))
            print(" SVme", bin(sim.svstate.SVme))
            print(" mo0", bin(sim.svstate.mo0))
            print(" mo1", bin(sim.svstate.mo1))
            print(" mi0", bin(sim.svstate.mi0))
            print(" mi1", bin(sim.svstate.mi1))
            print(" mi2", bin(sim.svstate.mi2))
            print("STATE0svgpr", hex(SVSHAPE0.svgpr))
            print("STATE0 xdim", SVSHAPE0.xdimsz)
            print("STATE0 ydim", SVSHAPE0.ydimsz)
            print("STATE0 skip", bin(SVSHAPE0.skip))
            print("STATE0 inv", SVSHAPE0.invxyz)
            print("STATE0order", SVSHAPE0.order)
            print(sim.gpr.dump())
            self._check_regs(sim, expected_regs)
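            # (added note) the SVSTATE remap fields checked below match the
            # last svremap executed in the loop, 'svremap 31, 0, 3, 2, 2, 0, 0':
            # SVme=0b11111, mi0=0, mi1=3, mi2=2, mo0=2, mo1=0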
            self.assertEqual(sim.svstate.RMpst, 0)
            self.assertEqual(sim.svstate.SVme, 0b11111)
            self.assertEqual(sim.svstate.mi0, 0)
            self.assertEqual(sim.svstate.mi1, 3)
            self.assertEqual(sim.svstate.mi2, 2)
            self.assertEqual(sim.svstate.mo0, 2)
            self.assertEqual(sim.svstate.mo1, 0)
            #self.assertEqual(SVSHAPE0.svgpr, 22)
            #self.assertEqual(SVSHAPE1.svgpr, 30)

    def run_tst_program(self, prog, initial_regs=None,
                        svstate=None):
        if initial_regs is None:
            initial_regs = [0] * 32
        simulator = run_tst(prog, initial_regs, svstate=svstate)
        simulator.gpr.dump()
        return simulator


if __name__ == "__main__":
    unittest.main()