increase pcdec. output compression by skipping impossible codes
[openpower-isa.git] / src / openpower / test / prefix_codes / prefix_codes_cases.py
1 import functools
2 import itertools
3 from openpower.test.common import TestAccumulatorBase, skip_case
4 from openpower.sv.trans.svp64 import SVP64Asm
5 from openpower.test.state import ExpectedState
6 from openpower.simulator.program import Program
7 from typing import Iterable
8
9
def tree_code(code):
    # type: (str) -> int
    """Return the decode-tree index for the bit string `code`.

    The index is the integer whose binary form is a leading ``1`` followed
    by the bits of `code` (``""`` -> 1, ``"0"`` -> 2, ``"11"`` -> 7).  The
    result must stay below 64, so only codes of at most 5 bits are accepted.

    Raises AssertionError for any character other than ``0``/``1`` and for
    codes longer than 5 bits.
    """
    index = 1
    for ch in code:
        assert ch in "01"
        # shift in the next code bit below the leading 1
        index = 2 * index + int(ch)
    assert index < 64, "code too long"
    return index
18
19
20 def make_tree(*codes):
21 # type: (*str) -> int
22 retval = 0
23 for code in sorted(codes, key=len):
24 retval |= 1 << tree_code(code)
25 return retval
26
27
def from_tree(tree):
    # type: (int) -> list[str]
    """Return the codes whose tree-index bits are set in `tree`.

    Tree indices 2..63 are scanned in increasing order; the code for index
    ``i`` is the binary form of ``i`` with its leading ``1`` removed.  The
    result is therefore ordered shortest first, then lexicographically, and
    for codes of 1-5 bits this inverts make_tree.  Bits 0, 1 and 64 and up
    are ignored.
    """
    return [bin(index)[3:] for index in range(2, 64) if (tree >> index) & 1]
38
39
# Sample prefix-free code set shared by several PrefixCodesCases tests.
# NOTE(review): the numeric suffixes match tree_code() only for CODE_2 and
# CODE_7; tree_code("1001") == 25, tree_code("10101") == 53 and
# tree_code("10111") == 55 -- confirm what the numbers are meant to denote.
CODE_2 = "0"
CODE_7 = "11"
CODE_19 = "1001"
CODE_35 = "10101"
CODE_37 = "10111"
CODES = {CODE_2, CODE_7, CODE_19, CODE_35, CODE_37}
46
47
@functools.lru_cache()
def _cached_program(*instrs):
    """Assemble `instrs` into a little-endian Program.

    Memoized on the tuple of instruction strings so identical test programs
    are assembled only once.
    """
    machine_code = list(SVP64Asm(list(instrs)))
    return Program(machine_code, bigendian=False)
51
52
53 def _code_sort_key(supported_code):
54 # type: (str) -> tuple[int, str]
55 return len(supported_code), supported_code
56
57
class PrefixCodesCases(TestAccumulatorBase):
    """Test cases for the ``pcdec.`` instruction.

    Each ``case_*`` method calls ``check_pcdec`` once per mode (0-3); that
    helper computes the expected register and CR0 results in Python and
    registers a single-instruction program via ``add_case``.
    """

    def check_pcdec(self, supported_codes, input_bits, mode, src_loc_at=0):
        # type: (Iterable[str], str, int, int) -> None
        """Add one ``pcdec. 4,RA,6,5`` case together with its expected state.

        supported_codes: codes encoded (via make_tree) into RB (r6).
        input_bits: bits to decode, first bit first; ``_`` separators are
            stripped.  Must be fewer than 128 bits.
        mode: value OR-ed into RB (0-3); the per-mode comments below show
            how it selects RT and which bits count as consumed.
        src_loc_at: forwarded (plus one) to ``add_case``.
        """
        supported_codes = sorted(supported_codes, key=_code_sort_key)
        # compressed_codes: the list whose indices are the "compressed"
        # results (modes 2 and 3).  Candidates are the sorted supported
        # codes followed by every 6-bit pattern; a candidate is dropped if
        # one of its proper prefixes is already in the list, since the
        # decoder stops at that prefix.  Supported codes are at most 5 bits
        # (make_tree/tree_code enforce this), so the list stays in
        # _code_sort_key order, which the hit_end search below relies on.
        compressed_codes = []  # type: list[str]
        codes = supported_codes.copy()
        codes += map(lambda v: "".join(v), itertools.product("01", repeat=6))
        for code in codes:
            possible = True
            for i in range(1, len(code)):
                if code[:i] in compressed_codes:
                    possible = False
                    break
            if possible:
                compressed_codes.append(code)
        original_input_bits = input_bits
        input_bits = input_bits.replace("_", "")
        assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
        # At most 64 bits go to RA and at most 63 (plus a sentinel 1) to RC.
        assert len(input_bits) < 128, "input_bits too long"
        # Simulate the decoder: read up to 6 input bits, building tree_index
        # as a leading 1 followed by the bits read so far.
        found = False  # matched a supported (1-5 bit) code
        hit_end = False  # input ran out before any match
        tree_index = 1
        compressed_index = 0
        used_bits = 0
        for bit_len in range(1, 7):
            cur_input_bits = input_bits[:bit_len]
            if bit_len > len(input_bits):
                hit_end = True
                # Out of input: report the position at which the partial
                # input would be inserted into the sorted compressed_codes.
                compressed_index = 0
                for i, code in enumerate(compressed_codes):
                    if _code_sort_key(code) < _code_sort_key(cur_input_bits):
                        compressed_index = i + 1
                    else:
                        break
                break
            tree_index *= 2
            if cur_input_bits[-1] == "1":
                tree_index += 1
            try:
                compressed_index = compressed_codes.index(cur_input_bits)
                used_bits = bit_len
                # A 6-bit entry is never a supported code, so only shorter
                # matches count as found.
                if bit_len < 6:
                    found = True
                    break
            except ValueError:
                pass
        if mode == 0:
            # RT = tree index; bits are consumed only on a supported match.
            expected_RT = tree_index
            if not found:
                used_bits = 0
        elif mode == 1:
            # RT = tree index; bits (including a 6-bit non-match) are
            # consumed unless the input ran out.
            expected_RT = tree_index
            if hit_end:
                used_bits = 0
        elif mode == 2:
            # RT = compressed index on a supported match, otherwise the tree
            # index; bits are consumed only on a supported match.
            expected_RT = compressed_index
            if not found:
                used_bits = 0
                expected_RT = tree_index
        else:
            assert mode == 3
            # RT = compressed index; bits are consumed unless the input ran
            # out.
            expected_RT = compressed_index
            if hit_end:
                used_bits = 0
        expected_ra_used = False
        # RB: tree bitmask with the mode OR-ed into its low bits (tree
        # indices of non-empty codes start at 2).
        RB_val = make_tree(*supported_codes) | mode
        # Register values hold the input reversed, so the first input bit is
        # the least significant.  With 64 or more bits, the last 64 input
        # bits go to RA (r7) and the remaining leading bits stay in RC (r5).
        # A 1 sentinel sits just above RC's data bits (and above RA's when
        # RS is taken from RA).
        rev_input_bits = input_bits[::-1]
        RA_val = 0
        RA = 0
        expected_RS = None
        if len(input_bits) >= 64:
            RA_val = int(rev_input_bits[:64], 2)
            RA = 7
            rev_input_bits = rev_input_bits[64:]
            # Consumption ran past RC's data bits into RA.
            expected_ra_used = used_bits > len(rev_input_bits)
            if expected_ra_used:
                expected_RS = (RA_val + 2 ** 64) >> (used_bits
                                                     - len(rev_input_bits))
        RC_val = int("1" + rev_input_bits, 2)
        if expected_RS is None:
            # RS is RC with the consumed bits shifted out.
            expected_RS = RC_val >> used_bits
        lst = [f"pcdec. 4,{RA},6,5"]
        gprs = [0] * 32
        gprs[6] = RB_val
        if RA:
            gprs[RA] = RA_val
        gprs[5] = RC_val
        # NOTE(review): assumes ExpectedState copies int_regs; if it keeps a
        # reference to gprs, the writes to e.intregs below also change the
        # initial registers handed to add_case -- confirm.
        e = ExpectedState(pc=4, int_regs=gprs)
        e.intregs[4] = expected_RT
        e.intregs[5] = expected_RS
        # CR0 weights: 8 = RA used, 4 = all six bits were read into
        # tree_index (tree_index >= 64), 2 = supported code found,
        # 1 = input ran out.
        e.crregs[0] = (expected_ra_used * 8 + (tree_index >= 64) * 4
                       + found * 2 + hit_end)
        with self.subTest(supported_codes=supported_codes,
                          input_bits=original_input_bits, mode=mode):
            self.add_case(_cached_program(*lst), gprs, expected=e,
                          src_loc_at=src_loc_at + 1)

    def case_pcdec_empty(self):
        """Empty input with only CODE_2 supported, in every mode."""
        for mode in range(4):
            self.check_pcdec({CODE_2}, "", mode)

    def case_pcdec_only_one_code(self):
        """Input is exactly CODE_37, which is the only supported code."""
        for mode in range(4):
            self.check_pcdec({CODE_37}, CODE_37, mode)

    def case_pcdec_short_seq(self):
        """Short sequence of three valid codes from CODES."""
        for mode in range(4):
            self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, CODE_35]), mode)

    def case_pcdec_medium_seq(self):
        """34-bit sequence of valid codes from CODES."""
        for mode in range(4):
            self.check_pcdec(
                CODES, "0_11_1001_10101_10111_10111_10101_1001_11_0", mode)

    def case_pcdec_long_seq(self):
        """64-bit sequence: every input bit is placed in RA, none in RC."""
        for mode in range(4):
            self.check_pcdec(CODES,
                             "0_11_1001_10101_10111_10111_10101_1001_11_0"
                             + CODE_37 * 6, mode)

    def case_pcdec_invalid_code_at_start(self):
        """Input starts with ``1000``, which no code in CODES begins with."""
        for mode in range(4):
            self.check_pcdec(CODES, "_".join(["1000", CODE_35]), mode)

    def case_pcdec_invalid_code_after_3(self):
        """Invalid ``1000`` following three valid codes."""
        for mode in range(4):
            self.check_pcdec(CODES, "_".join(
                [CODE_2, CODE_19, CODE_35, "1000", CODE_35]), mode)

    def case_pcdec_invalid_code_after_8(self):
        """Invalid ``1000`` following eight valid codes."""
        for mode in range(4):
            self.check_pcdec(CODES, "_".join(
                [CODE_2, CODE_19, *([CODE_35] * 6), "1000", CODE_35]), mode)

    def case_pcdec_invalid_code_in_rb(self):
        """69-bit input with ``1000`` after two valid codes.

        check_pcdec places the last 64 input bits, which start with the
        ``1000``, in RA (r7); the first 5 bits stay in RC (r5).
        """
        for mode in range(4):
            self.check_pcdec(CODES, "_".join(
                [CODE_2, CODE_19, "1000", *([CODE_19] * 15)]), mode)

    def case_pcdec_overlong_code(self):
        """Eight-bit ``10000000`` following two valid codes."""
        for mode in range(4):
            self.check_pcdec(CODES, "_".join(
                [CODE_2, CODE_19, "10000000"]), mode)

    def case_pcdec_incomplete_code(self):
        """Input is CODE_19 minus its last bit, so the input runs out."""
        for mode in range(4):
            self.check_pcdec(CODES, "_".join([CODE_19[:-1]]), mode)

    def case_rest(self):
        """Every input of 0-7 bits with CODES, plain and zero-padded."""
        for mode in range(4):
            for repeat in range(8):
                for bits in itertools.product("01", repeat=repeat):
                    self.check_pcdec(CODES, "".join(bits), mode)
                    # 60 so we cover both less and more than 64 bits
                    self.check_pcdec(CODES, "".join(bits) + "0" * 60, mode)

    def case_arbitrary_tree(self):
        """Like case_rest, but with the codes of an arbitrary tree mask."""
        codes = from_tree(0x123456789ABCDEF0)
        for mode in range(4):
            for repeat in range(8):
                for bits in itertools.product("01", repeat=repeat):
                    self.check_pcdec(codes, "".join(bits), mode)
                    # 60 so we cover both less and more than 64 bits
                    self.check_pcdec(codes, "".join(bits) + "0" * 60, mode)
220 # 60 so we cover both less and more than 64 bits
221 self.check_pcdec(codes, "".join(bits) + "0" * 60, mode)