add more tests and fix missing corner case
[openpower-isa.git] / src / openpower / test / prefix_codes / prefix_codes_cases.py
1 import functools
2 import itertools
3 from openpower.test.common import TestAccumulatorBase
4 from openpower.sv.trans.svp64 import SVP64Asm
5 from openpower.test.state import ExpectedState
6 from openpower.simulator.program import Program
7
8
def tree_code(code):
    # type: (str) -> int
    """Return the code-tree position of the prefix code ``code``.

    The position is the binary number formed by a leading ``1`` followed
    by the bits of ``code`` in order, so the empty code maps to 1, ``"0"``
    to 2 and ``"11"`` to 7.

    An assertion fails when ``code`` contains anything other than ``0``
    and ``1`` characters, or when the position would not be below 64.
    """
    # start from the marker bit so that leading zeros stay significant
    position = 1
    for digit in code:
        assert digit in "01"
        position = 2 * position + int(digit)
    assert position < 64, "code too long"
    return position
17
18
def make_tree(*codes):
    # type: (*str) -> int
    """Build the bitmask that describes a set of prefix codes.

    Bit ``tree_code(c)`` of the result is set for every code ``c`` passed
    in.  Codes are processed shortest first; an assertion fails if any
    proper prefix of a code (the empty prefix included) is already present
    in the mask, because the set would then not be prefix-free.
    """
    tree = 0
    for code in sorted(codes, key=len):
        # every proper prefix, from the empty one up to all but the last
        # bit, must not already be registered as a complete code
        for end in range(len(code)):
            prefix = code[:end]
            assert not tree & (1 << tree_code(prefix)), \
                f"conflicting code: {code} conflicts with {prefix}"
        tree |= 1 << tree_code(code)
    return tree
28
29
def reference_pcdec(supported_codes, input_bits, max_count):
    # type: (set[str], str, int) -> tuple[list[str], bool]
    """Reference decoder used to work out the expected decode result.

    Bits of ``input_bits`` are consumed left to right and collected until
    they spell one of ``supported_codes``; that code is recorded and
    collection starts afresh.  Scanning ends when the input runs out, when
    ``max_count`` codes have been recorded, or as soon as more than 5 bits
    have been collected without matching a code.

    Returns a tuple of the recorded codes (in input order) and a flag that
    is True when scanning ended with collected bits that did not form a
    complete code.

    An assertion fails if ``input_bits`` contains characters other than
    ``0`` and ``1``.
    """
    assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
    decoded = []
    pending = ""
    for bit in input_bits:
        pending += bit
        if len(pending) > 5:
            # longer than any decodable code: give up, leaving the bits
            # pending so the caller sees the failure flag
            break
        if pending not in supported_codes:
            continue
        decoded.append(pending)
        pending = ""
        if len(decoded) >= max_count:
            break
    return decoded, pending != ""
45
46
# Prefix codes shared by the test cases.  The number in each name is the
# hexadecimal value tree_code() yields for that code, e.g. "1001" -> 0x19.
CODE_2 = "0"
CODE_7 = "11"
CODE_19 = "1001"
CODE_35 = "10101"
CODE_37 = "10111"

# all of the codes above as one set; none is a prefix of another
CODES = {
    CODE_2,
    CODE_7,
    CODE_19,
    CODE_35,
    CODE_37,
}
53
54
@functools.lru_cache()
def _cached_program(*instrs):
    """Assemble ``instrs`` into a little-endian ``Program``.

    Results are memoised on the instruction strings, so cases that use the
    same instruction sequence receive the same ``Program`` object.
    """
    assembled = SVP64Asm(list(instrs))
    return Program(list(assembled), bigendian=False)
58
59
class PrefixCodesCases(TestAccumulatorBase):
    """Test cases for the ``pcdec.`` prefix-code decode instruction."""

    def check_pcdec(self, supported_codes, input_bits, once, src_loc_at=0):
        # type: (set[str], str, bool, int) -> None
        """Add one ``pcdec.`` case together with its expected outcome.

        supported_codes -- the prefix codes; make_tree() turns them into
            the value loaded into register 6.
        input_bits -- the bits to decode, first bit first.  ``_`` may be
            used as a visual separator and is ignored; fewer than 128 bits
            are allowed.
        once -- when true at most one code is expected to be decoded,
            otherwise at most eight.
        src_loc_at -- forwarded, plus one, to add_case(); presumably the
            number of extra caller frames to skip when recording the
            source location (confirm against TestAccumulatorBase).
        """
        label = input_bits  # reported by subTest with separators intact
        input_bits = input_bits.replace("_", "")
        assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
        assert len(input_bits) < 128, "input_bits too long"

        max_count = 1 if once else 8
        decoded, expected_SO = reference_pcdec(
            supported_codes, input_bits, max_count=max_count)
        expected_GT = not decoded
        expected_EQ = len(decoded) < max_count
        # one byte per decoded code, first code in the lowest byte, each
        # byte being the code's bits below a leading 1
        expected_RT = int.from_bytes(
            bytes(int("1" + code, 2) for code in decoded), 'little')
        consumed = sum(len(code) for code in decoded)
        RB_val = make_tree(*supported_codes)

        # Register images hold the first input bit in the least significant
        # position.  Register 5 gets the bits below a leading 1; once there
        # are 64 bits or more, the last 64 go to register 7 instead and
        # register 5 keeps only the bits before them.
        reversed_bits = input_bits[::-1]
        if len(input_bits) >= 64:
            RA = 7
            RA_val = int(reversed_bits[:64], 2)
            rc_bits = reversed_bits[64:]
        else:
            RA = 0
            RA_val = 0
            rc_bits = reversed_bits
        RC_val = int("1" + rc_bits, 2)

        expected_ra_used = RA != 0 and consumed > len(rc_bits)
        if expected_ra_used:
            # decoding ran past register 5's bits: what is left are the
            # undecoded RA bits below a leading 1
            expected_RS = (RA_val + 2 ** 64) >> (consumed - len(rc_bits))
        else:
            expected_RS = RC_val >> consumed

        asm = f"pcdec. 4,{RA},6,5,{int(once)}"
        gprs = [0] * 32
        gprs[6] = RB_val
        if RA:
            gprs[RA] = RA_val
        gprs[5] = RC_val

        e = ExpectedState(pc=4, int_regs=gprs)
        e.intregs[4] = expected_RT
        e.intregs[5] = expected_RS
        # 4-bit field: RA-used flag on top, then GT, EQ and SO
        e.crregs[0] = ((expected_ra_used << 3) | (expected_GT << 2)
                       | (expected_EQ << 1) | expected_SO)
        with self.subTest(supported_codes=supported_codes,
                          input_bits=label, once=once):
            self.add_case(_cached_program(asm), gprs, expected=e,
                          src_loc_at=src_loc_at + 1)

    def case_pcdec_empty(self):
        """No input bits at all."""
        self.check_pcdec({CODE_2}, "", once=False)

    def case_pcdec_empty_once(self):
        """No input bits at all, single-code mode."""
        self.check_pcdec({CODE_2}, "", once=True)

    def case_pcdec_only_one_code(self):
        """The input is exactly the only supported code."""
        self.check_pcdec({CODE_37}, CODE_37, once=False)

    def case_pcdec_only_one_code_once(self):
        """The input is exactly the only supported code, single-code mode."""
        self.check_pcdec({CODE_37}, CODE_37, once=True)

    def case_pcdec_short_seq(self):
        """Three valid codes in a row."""
        bits = "_".join((CODE_2, CODE_19, CODE_35))
        self.check_pcdec(CODES, bits, once=False)

    def case_pcdec_short_seq_once(self):
        """Three valid codes in a row, single-code mode."""
        bits = "_".join((CODE_2, CODE_19, CODE_35))
        self.check_pcdec(CODES, bits, once=True)

    def case_pcdec_medium_seq(self):
        """Ten valid codes, more than the eight allowed per check."""
        bits = "0_11_1001_10101_10111_10111_10101_1001_11_0"
        self.check_pcdec(CODES, bits, once=False)

    def case_pcdec_medium_seq_once(self):
        """Ten valid codes, single-code mode."""
        bits = "0_11_1001_10101_10111_10111_10101_1001_11_0"
        self.check_pcdec(CODES, bits, once=True)

    def case_pcdec_long_seq(self):
        """Valid codes adding up to exactly 64 bits."""
        bits = "0_11_1001_10101_10111_10111_10101_1001_11_0" + CODE_37 * 6
        self.check_pcdec(CODES, bits, once=False)

    def case_pcdec_long_seq_once(self):
        """Valid codes adding up to exactly 64 bits, single-code mode."""
        bits = "0_11_1001_10101_10111_10111_10101_1001_11_0" + CODE_37 * 6
        self.check_pcdec(CODES, bits, once=True)

    def case_pcdec_invalid_code_at_start(self):
        """The input begins with bits that match no supported code."""
        bits = "_".join(("1000", CODE_35))
        self.check_pcdec(CODES, bits, once=False)

    def case_pcdec_invalid_code_at_start_once(self):
        """The input begins with an unsupported code, single-code mode."""
        bits = "_".join(("1000", CODE_35))
        self.check_pcdec(CODES, bits, once=True)

    def case_pcdec_invalid_code_after_3(self):
        """Three valid codes followed by an unsupported one."""
        bits = "_".join((CODE_2, CODE_19, CODE_35, "1000", CODE_35))
        self.check_pcdec(CODES, bits, once=False)

    def case_pcdec_invalid_code_after_3_once(self):
        """Three valid codes then an unsupported one, single-code mode."""
        bits = "_".join((CODE_2, CODE_19, CODE_35, "1000", CODE_35))
        self.check_pcdec(CODES, bits, once=True)

    def case_pcdec_invalid_code_after_8(self):
        """Eight valid codes followed by an unsupported one."""
        parts = [CODE_2, CODE_19] + [CODE_35] * 6 + ["1000", CODE_35]
        self.check_pcdec(CODES, "_".join(parts), once=False)

    def case_pcdec_invalid_code_after_8_once(self):
        """Eight valid codes then an unsupported one, single-code mode."""
        parts = [CODE_2, CODE_19] + [CODE_35] * 6 + ["1000", CODE_35]
        self.check_pcdec(CODES, "_".join(parts), once=True)

    def case_pcdec_invalid_code_in_rb(self):
        """Unsupported code early in a 69-bit input, so RA is populated."""
        parts = [CODE_2, CODE_19, "1000"] + [CODE_19] * 15
        self.check_pcdec(CODES, "_".join(parts), once=False)

    def case_pcdec_invalid_code_in_rb_once(self):
        """Unsupported code early in a 69-bit input, single-code mode."""
        parts = [CODE_2, CODE_19, "1000"] + [CODE_19] * 15
        self.check_pcdec(CODES, "_".join(parts), once=True)

    def case_pcdec_overlong_code(self):
        """Two valid codes followed by an 8-bit run matching no code."""
        bits = "_".join((CODE_2, CODE_19, "10000000"))
        self.check_pcdec(CODES, bits, once=False)

    def case_pcdec_overlong_code_once(self):
        """Two valid codes then an 8-bit non-code run, single-code mode."""
        bits = "_".join((CODE_2, CODE_19, "10000000"))
        self.check_pcdec(CODES, bits, once=True)

    def case_pcdec_incomplete_code(self):
        """The input ends part-way through a code (last bit missing)."""
        truncated = CODE_19[:-1]
        self.check_pcdec(CODES, truncated, once=False)

    def case_pcdec_incomplete_code_once(self):
        """The input ends part-way through a code, single-code mode."""
        truncated = CODE_19[:-1]
        self.check_pcdec(CODES, truncated, once=True)

    def case_rest(self):
        """Try every bit string of up to 7 bits, bare and zero-padded."""
        for width in range(8):
            for combo in itertools.product("01", repeat=width):
                bits = "".join(combo)
                self.check_pcdec(CODES, bits, once=False)
                self.check_pcdec(CODES, bits, once=True)
                # 60 so we cover both less and more than 64 bits
                padded = bits + "0" * 60
                self.check_pcdec(CODES, padded, once=False)
                self.check_pcdec(CODES, padded, once=True)