add tests for checking if the simulator and assembler agree on SVP64 encodings
[openpower-isa.git] / src / openpower / test / svp64 / encodings.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2023 Jacob Lifshay programmerjake@gmail.com
3
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
6
7 from openpower.test.common import TestAccumulatorBase, skip_case
8 from openpower.test.state import ExpectedState
9 from openpower.test.util import assemble
10 from openpower.decoder.isa.caller import SVP64State
11 from nmutil.plain_data import plain_data
12 from enum import Enum
13 import itertools
14
15 _SCALAR_EXTRA2 = range(64)
16 _VECTOR_EXTRA2 = range(0, 128, 2)
17 _SCALAR_EXTRA3 = range(128)
18 _VECTOR_EXTRA3 = range(128)
19
20
@plain_data(unsafe_hash=True, frozen=True)
class _Arg:
    """Base class for one operand of an instruction under test.

    Subclasses implement :meth:`gen`, which yields every ``(reg, text)``
    choice for this operand given the registers already allocated by the
    operands processed before it.
    """
    __slots__ = "alloc_pass",

    def __init__(self, alloc_pass):
        # type: (int) -> None
        # Operands are processed in ascending ``alloc_pass`` order by
        # gen_all_args, so lower passes get first pick of registers.
        self.alloc_pass = alloc_pass

    def gen(self, regs):
        """Yield ``(reg, text)`` for each possible choice of this operand.

        ``regs`` is the shared dict of currently-allocated register
        numbers; implementations may add entries to it while a choice is
        being yielded.
        """
        raise NotImplementedError
        yield ""  # unreachable; only makes this a generator function

    @staticmethod
    def __gen_all_args(cur_out, output, args, regs, step):
        # Depth-first: pick each choice for args[step], then recurse to
        # pick the remaining operands.  cur_out and regs are mutated in
        # place, so snapshot both when a full combination is reached.
        if step >= len(args):
            output.append((cur_out.copy(), regs.copy()))
            return

        idx, arg = args[step]
        for out in arg.gen(regs):
            # idx is the operand's original position, not its pass order
            cur_out[idx] = out
            _Arg.__gen_all_args(cur_out, output, args, regs, step=step + 1)

    @staticmethod
    def gen_all_args(args):
        """Return every combination of choices for the operands ``args``.

        Each returned element is ``(choices, regs)``: ``choices[i]`` is the
        ``(reg, text)`` tuple chosen for ``args[i]`` and ``regs`` is a
        snapshot of the register allocation dict for that combination.
        """
        # sort by pass; sorted() is stable, so operands in the same pass
        # keep their original relative order
        args = sorted(enumerate(args), key=lambda arg: arg[1].alloc_pass)
        output = []
        cur_out = [""] * len(args)
        _Arg.__gen_all_args(cur_out, output, args, regs={}, step=0)
        return output
52
53
54 _DEFAULT_TEST_VALUES = 0x123456789ABCDEF0,
55
56
@plain_data(unsafe_hash=True, frozen=True)
class _ArgReg(_Arg):
    """Register operand: either a sweep over a register range or a constant.

    ``vec`` selects vector (``*N``) rather than scalar (``N``) assembler
    syntax.  ``regs`` is the set of candidate register numbers.  ``values``
    is the tuple of test values the chosen register is initialized with.
    ``all_regs`` says whether every free register in ``regs`` is tried (a
    sweep) or only the first free one (a constant).
    """
    __slots__ = "vec", "regs", "values", "all_regs"

    def __init__(self, vec, regs, values, all_regs):
        # type: (bool, range, list[int] | tuple[int, ...], bool) -> None
        self.vec = vec
        self.regs = regs
        self.values = tuple(values)
        self.all_regs = all_regs
        # Sweeps are allocated in pass 0 and constants in pass 1, so a
        # constant takes whichever register in its range the sweep left free.
        super().__init__(0 if all_regs else 1)

    @staticmethod
    def const(value, vec=False, regs=range(4, 32, 2)):
        # type: (int, bool, range) -> _ArgReg
        """Operand using the first free register in ``regs``, holding ``value``.

        ``value`` is the register's initial contents, not a register number.
        """
        return _ArgReg(vec, regs, values=(value,), all_regs=False)

    @staticmethod
    def reg_range(vec, regs, values=_DEFAULT_TEST_VALUES, skip_r0=False):
        # type: (bool, range, list[int] | tuple[int, ...], bool) -> _ArgReg
        """Operand sweeping every free register in ``regs``.

        If ``skip_r0`` is true, ``regs`` must start at 0 and that first
        entry (r0) is excluded from the sweep.
        """
        if skip_r0:
            assert regs.start == 0, "skip_r0 requires regs to start at r0"
            # drop the first element (r0), keeping the same step
            regs = regs[1:]
        return _ArgReg(vec, regs, values, all_regs=True)

    @staticmethod
    def s_extra2(values=_DEFAULT_TEST_VALUES, skip_r0=False):
        # type: (list[int] | tuple[int, ...], bool) -> _ArgReg
        """Scalar operand swept over the EXTRA2 scalar register range."""
        return _ArgReg.reg_range(vec=False, regs=_SCALAR_EXTRA2,
                                 values=values, skip_r0=skip_r0)

    @staticmethod
    def v_extra2(values=_DEFAULT_TEST_VALUES, skip_r0=False):
        # type: (list[int] | tuple[int, ...], bool) -> _ArgReg
        """Vector operand swept over the EXTRA2 vector register range."""
        return _ArgReg.reg_range(vec=True, regs=_VECTOR_EXTRA2,
                                 values=values, skip_r0=skip_r0)

    @staticmethod
    def s_extra3(values=_DEFAULT_TEST_VALUES, skip_r0=False):
        # type: (list[int] | tuple[int, ...], bool) -> _ArgReg
        """Scalar operand swept over the EXTRA3 scalar register range."""
        return _ArgReg.reg_range(vec=False, regs=_SCALAR_EXTRA3,
                                 values=values, skip_r0=skip_r0)

    @staticmethod
    def v_extra3(values=_DEFAULT_TEST_VALUES, skip_r0=False):
        # type: (list[int] | tuple[int, ...], bool) -> _ArgReg
        """Vector operand swept over the EXTRA3 vector register range."""
        return _ArgReg.reg_range(vec=True, regs=_VECTOR_EXTRA3,
                                 values=values, skip_r0=skip_r0)

    def gen(self, regs):
        """Yield ``(reg, text)`` for each free register this operand may use.

        While a choice is being yielded, ``regs[reg]`` is set to
        ``self.values`` so later operands cannot reuse that register.
        """
        for reg in self.regs:
            if reg in regs:
                continue  # already taken by an earlier operand
            regs[reg] = self.values
            try:
                s = str(reg)
                if self.vec:
                    s = "*" + s
                yield (reg, s)
            finally:
                # Always release the allocation, even if the consumer
                # closes the generator early or an exception is thrown in;
                # otherwise the register would stay reserved in the
                # shared dict for unrelated combinations.
                del regs[reg]
            if not self.all_regs:
                break  # constants use only the first free register
119
120
@plain_data(unsafe_hash=True, frozen=True)
class _ArgLiteral(_Arg):
    """Operand emitted verbatim as assembler text; allocates no register."""
    __slots__ = ("text",)

    def __init__(self, text):
        # type: (str) -> None
        self.text = text
        # literals never compete for registers, so put them in pass 0
        super().__init__(0)

    def gen(self, regs):
        # Exactly one choice: no register number, only the literal text.
        # The allocation dict ``regs`` is never touched.
        literal = self.text
        yield None, literal
132
133
class SVP64EncodingsCases(TestAccumulatorBase):
    """Test cases checking that the simulator and assembler agree on SVP64
    operand encodings.

    Each case assembles one SVP64 instruction for every combination of
    operand registers produced by ``_Arg.gen_all_args``.  It records the
    expected state computed by a Python model of the instruction.
    """

    def do_check(self, insn, args, gen_expected, src_loc_at=0):
        """Add one test case per operand/value combination of ``insn``.

        :param insn: assembler mnemonic, e.g. ``"sv.add"``.
        :param args: list of ``_Arg`` operands in assembler order.
        :param gen_expected: called as ``gen_expected(cur_args, gprs)``.
            ``cur_args`` is the list of ``(reg, text)`` operand choices and
            ``gprs`` is a private copy of the initial register values, so
            it may be modified.  It must return the ``ExpectedState`` after
            the instruction executes.
        :param src_loc_at: extra stack frames to skip when recording the
            source location of the added cases.
        """
        # recognizable filler for registers this test doesn't initialize
        UNINIT = int.from_bytes(b"uninit..", "little")
        for cur_args, cur_regs in _Arg.gen_all_args(args):
            asm = insn + " " + ", ".join(text for _, text in cur_args)
            with self.subTest(asm=asm):
                prog = assemble([asm])
                for values in itertools.product(*cur_regs.values()):
                    gprs = [UNINIT] * 128
                    for reg, v in zip(cur_regs.keys(), values):
                        gprs[reg] = v
                    svstate = SVP64State()
                    svstate.vl = 1
                    svstate.maxvl = 1
                    # Give gen_expected its own copy.  The helpers below
                    # wrap the list in ExpectedState and write the expected
                    # results into it.  Without a copy those writes could
                    # leak into the initial registers passed to add_case,
                    # letting an instruction that never writes its
                    # destination pass.
                    e = gen_expected(cur_args, list(gprs))
                    sorted_regs = sorted(cur_regs)
                    expected_gprs = "\n".join(
                        f"r{reg} = 0x{e.intregs[reg]:X}"
                        for reg in sorted_regs)
                    input_gprs = "\n".join(
                        f"r{reg} = 0x{gprs[reg]:X}" for reg in sorted_regs)
                    with self.subTest(
                        expected_gprs=expected_gprs, input_gprs=input_gprs,
                    ):
                        self.add_case(prog, gprs, expected=e,
                                      initial_svstate=svstate,
                                      src_loc_at=src_loc_at + 1)

    # test RM-1P-2S1D

    @staticmethod
    def __sv_add_gen_expected(cur_args, gprs):
        # Model of ``add RT, RA, RB``: RT = (RA + RB) mod 2**64.
        # pc=8 because the prefixed SVP64 instruction is 8 bytes long.
        e = ExpectedState(pc=8, int_regs=gprs)
        RT_reg = cur_args[0][0]
        RA_reg = cur_args[1][0]
        RB_reg = cur_args[2][0]
        RA = gprs[RA_reg]
        RB = gprs[RB_reg]
        e.intregs[RT_reg] = (RA + RB) % 2 ** 64
        return e

    def case_sv_add_vvs_rt(self):
        self.do_check("sv.add", [
            _ArgReg.v_extra3(),
            _ArgReg.const(1, vec=True),
            _ArgReg.const(1)], self.__sv_add_gen_expected)

    def case_sv_add_vvs_ra(self):
        self.do_check("sv.add", [
            _ArgReg.const(0, vec=True),
            _ArgReg.v_extra3(),
            _ArgReg.const(1)], self.__sv_add_gen_expected)

    def case_sv_add_vvs_rb(self):
        self.do_check("sv.add", [
            _ArgReg.const(0, vec=True),
            _ArgReg.const(1, vec=True),
            _ArgReg.s_extra3()], self.__sv_add_gen_expected)

    # test RM-1P-3S1D

    @staticmethod
    def __sv_maddedu_gen_expected(cur_args, gprs):
        # Model of ``maddedu RT, RA, RB, RC``: v = RA * RB + RC.  The low
        # 64 bits go to RT and the high 64 bits (RS) go to the RC register.
        e = ExpectedState(pc=8, int_regs=gprs)
        RT_reg = cur_args[0][0]
        RA_reg = cur_args[1][0]
        RB_reg = cur_args[2][0]
        RC_reg = cur_args[3][0]
        RA = gprs[RA_reg]
        RB = gprs[RB_reg]
        RC = gprs[RC_reg]
        v = (RA * RB) + RC
        RT = v % 2 ** 64
        RS = v >> 64  # can't overflow, so no need for wrapping
        e.intregs[RT_reg] = RT
        e.intregs[RC_reg] = RS
        return e

    def case_sv_maddedu_vvss_rt(self):
        self.do_check("sv.maddedu", [
            _ArgReg.v_extra2(),
            _ArgReg.const(1, vec=True),
            _ArgReg.const(1),
            _ArgReg.const(0)], self.__sv_maddedu_gen_expected)

    def case_sv_maddedu_vvss_ra(self):
        self.do_check("sv.maddedu", [
            _ArgReg.const(0, vec=True),
            _ArgReg.v_extra2(),
            _ArgReg.const(1),
            _ArgReg.const(0)], self.__sv_maddedu_gen_expected)

    def case_sv_maddedu_vvss_rb(self):
        self.do_check("sv.maddedu", [
            _ArgReg.const(0, vec=True),
            _ArgReg.const(1, vec=True),
            _ArgReg.s_extra2(),
            _ArgReg.const(0)], self.__sv_maddedu_gen_expected)

    def case_sv_maddedu_vvss_rc(self):
        self.do_check("sv.maddedu", [
            _ArgReg.const(0, vec=True),
            _ArgReg.const(0, vec=True),
            _ArgReg.const(0),
            _ArgReg.s_extra2()], self.__sv_maddedu_gen_expected)
241 _ArgReg.const(0),
242 _ArgReg.s_extra2()], self.__sv_maddedu_gen_expected)