from openpower.sv.trans.svp64 import SVP64Asm
from openpower.test.state import ExpectedState
from openpower.simulator.program import Program
+from typing import Iterable
def tree_code(code):
return retval
-def reference_pcdec(supported_codes, input_bits, max_count):
- # type: (set[str], str, int) -> tuple[list[str], bool]
- assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
- retval = []
- current_code = ""
- for bit in input_bits:
- current_code += bit
- if len(current_code) > 5:
- break
- if current_code in supported_codes:
- retval.append(current_code)
- current_code = ""
- if len(retval) >= max_count:
- break
- return retval, current_code != ""
-
-
CODE_2 = "0"
CODE_7 = "11"
CODE_19 = "1001"
return Program(list(SVP64Asm(list(instrs))), bigendian=False)
+def _code_sort_key(supported_code):
+ # type: (str) -> tuple[int, str]
+ return len(supported_code), supported_code
+
+
class PrefixCodesCases(TestAccumulatorBase):
- @skip_case("FIXME(programmerjake): update for new pcdec. pseudocode")
- def check_pcdec(self, supported_codes, input_bits, once, src_loc_at=0):
- # type: (set[str], str, bool, int) -> None
+ def check_pcdec(self, supported_codes, input_bits, mode, src_loc_at=0):
+ # type: (Iterable[str], str, int, int) -> None
+ supported_codes = sorted(supported_codes, key=_code_sort_key)
+ assert len(supported_codes) <= 32
original_input_bits = input_bits
input_bits = input_bits.replace("_", "")
assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
assert len(input_bits) < 128, "input_bits too long"
- max_count = 1 if once else 8
- decoded, expected_SO = reference_pcdec(
- supported_codes, input_bits, max_count=max_count)
- expected_GT = len(decoded) == 0
- expected_EQ = len(decoded) < max_count
- expected_RT = int.from_bytes(
- [int("1" + code, 2) for code in decoded], 'little')
- decoded_bits_len = len("".join(decoded))
+ found = False
+ hit_end = False
+ tree_index = 1
+ compressed_index = 0
+ used_bits = 0
+ for bit_len in range(1, 7):
+ cur_input_bits = input_bits[:bit_len]
+ if bit_len > len(input_bits):
+ hit_end = True
+ compressed_index = 0
+ for i, code in enumerate(supported_codes):
+ if _code_sort_key(code) < _code_sort_key(cur_input_bits):
+ compressed_index = i + 1
+ else:
+ break
+ break
+ tree_index *= 2
+ if cur_input_bits[-1] == "1":
+ tree_index += 1
+ if bit_len < 6:
+ try:
+ compressed_index = supported_codes.index(cur_input_bits)
+ found = True
+ used_bits = bit_len
+ break
+ except ValueError:
+ pass
+ else:
+ compressed_index = tree_index - 64 + len(supported_codes)
+ used_bits = bit_len
+ if mode == 0:
+ expected_RT = tree_index
+ if not found:
+ used_bits = 0
+ elif mode == 1:
+ expected_RT = tree_index
+ if hit_end:
+ used_bits = 0
+ elif mode == 2:
+ expected_RT = compressed_index
+ if not found:
+ used_bits = 0
+ expected_RT = tree_index
+ else:
+ assert mode == 3
+ expected_RT = compressed_index
+ if hit_end:
+ used_bits = 0
expected_ra_used = False
- RB_val = make_tree(*supported_codes)
+ RB_val = make_tree(*supported_codes) | mode
rev_input_bits = input_bits[::-1]
RA_val = 0
RA = 0
RA_val = int(rev_input_bits[:64], 2)
RA = 7
rev_input_bits = rev_input_bits[64:]
- expected_ra_used = decoded_bits_len > len(rev_input_bits)
+ expected_ra_used = used_bits > len(rev_input_bits)
if expected_ra_used:
- expected_RS = (RA_val + 2 ** 64) >> (decoded_bits_len
+ expected_RS = (RA_val + 2 ** 64) >> (used_bits
- len(rev_input_bits))
RC_val = int("1" + rev_input_bits, 2)
if expected_RS is None:
- expected_RS = RC_val >> decoded_bits_len
- lst = [f"pcdec. 4,{RA},6,5,{int(once)}"]
+ expected_RS = RC_val >> used_bits
+ lst = [f"pcdec. 4,{RA},6,5"]
gprs = [0] * 32
gprs[6] = RB_val
if RA:
e = ExpectedState(pc=4, int_regs=gprs)
e.intregs[4] = expected_RT
e.intregs[5] = expected_RS
- e.crregs[0] = (expected_ra_used * 8 + expected_GT * 4
- + expected_EQ * 2 + expected_SO)
+ e.crregs[0] = (expected_ra_used * 8 + (tree_index >= 64) * 4
+ + found * 2 + hit_end)
with self.subTest(supported_codes=supported_codes,
- input_bits=original_input_bits, once=once):
+ input_bits=original_input_bits, mode=mode):
self.add_case(_cached_program(*lst), gprs, expected=e,
src_loc_at=src_loc_at + 1)
def case_pcdec_empty(self):
- self.check_pcdec({CODE_2}, "", False)
-
- def case_pcdec_empty_once(self):
- self.check_pcdec({CODE_2}, "", True)
+ for mode in range(4):
+ self.check_pcdec({CODE_2}, "", mode)
def case_pcdec_only_one_code(self):
- self.check_pcdec({CODE_37}, CODE_37, False)
-
- def case_pcdec_only_one_code_once(self):
- self.check_pcdec({CODE_37}, CODE_37, True)
+ for mode in range(4):
+ self.check_pcdec({CODE_37}, CODE_37, mode)
def case_pcdec_short_seq(self):
- self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, CODE_35]), False)
-
- def case_pcdec_short_seq_once(self):
- self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, CODE_35]), True)
+ for mode in range(4):
+ self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, CODE_35]), mode)
def case_pcdec_medium_seq(self):
- self.check_pcdec(
- CODES, "0_11_1001_10101_10111_10111_10101_1001_11_0", False)
-
- def case_pcdec_medium_seq_once(self):
- self.check_pcdec(
- CODES, "0_11_1001_10101_10111_10111_10101_1001_11_0", True)
+ for mode in range(4):
+ self.check_pcdec(
+ CODES, "0_11_1001_10101_10111_10111_10101_1001_11_0", mode)
def case_pcdec_long_seq(self):
- self.check_pcdec(CODES,
- "0_11_1001_10101_10111_10111_10101_1001_11_0"
- + CODE_37 * 6, False)
-
- def case_pcdec_long_seq_once(self):
- self.check_pcdec(CODES,
- "0_11_1001_10101_10111_10111_10101_1001_11_0"
- + CODE_37 * 6, True)
+ for mode in range(4):
+ self.check_pcdec(CODES,
+ "0_11_1001_10101_10111_10111_10101_1001_11_0"
+ + CODE_37 * 6, mode)
def case_pcdec_invalid_code_at_start(self):
- self.check_pcdec(CODES, "_".join(["1000", CODE_35]), False)
-
- def case_pcdec_invalid_code_at_start_once(self):
- self.check_pcdec(CODES, "_".join(["1000", CODE_35]), True)
+ for mode in range(4):
+ self.check_pcdec(CODES, "_".join(["1000", CODE_35]), mode)
def case_pcdec_invalid_code_after_3(self):
- self.check_pcdec(CODES, "_".join(
- [CODE_2, CODE_19, CODE_35, "1000", CODE_35]), False)
-
- def case_pcdec_invalid_code_after_3_once(self):
- self.check_pcdec(CODES, "_".join(
- [CODE_2, CODE_19, CODE_35, "1000", CODE_35]), True)
+ for mode in range(4):
+ self.check_pcdec(CODES, "_".join(
+ [CODE_2, CODE_19, CODE_35, "1000", CODE_35]), mode)
def case_pcdec_invalid_code_after_8(self):
- self.check_pcdec(CODES, "_".join(
- [CODE_2, CODE_19, *([CODE_35] * 6), "1000", CODE_35]), False)
-
- def case_pcdec_invalid_code_after_8_once(self):
- self.check_pcdec(CODES, "_".join(
- [CODE_2, CODE_19, *([CODE_35] * 6), "1000", CODE_35]), True)
+ for mode in range(4):
+ self.check_pcdec(CODES, "_".join(
+ [CODE_2, CODE_19, *([CODE_35] * 6), "1000", CODE_35]), mode)
def case_pcdec_invalid_code_in_rb(self):
- self.check_pcdec(CODES, "_".join(
- [CODE_2, CODE_19, "1000", *([CODE_19] * 15)]), False)
-
- def case_pcdec_invalid_code_in_rb_once(self):
- self.check_pcdec(CODES, "_".join(
- [CODE_2, CODE_19, "1000", *([CODE_19] * 15)]), True)
+ for mode in range(4):
+ self.check_pcdec(CODES, "_".join(
+ [CODE_2, CODE_19, "1000", *([CODE_19] * 15)]), mode)
def case_pcdec_overlong_code(self):
- self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, "10000000"]), False)
-
- def case_pcdec_overlong_code_once(self):
- self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, "10000000"]), True)
+ for mode in range(4):
+ self.check_pcdec(CODES, "_".join(
+ [CODE_2, CODE_19, "10000000"]), mode)
def case_pcdec_incomplete_code(self):
- self.check_pcdec(CODES, "_".join([CODE_19[:-1]]), False)
-
- def case_pcdec_incomplete_code_once(self):
- self.check_pcdec(CODES, "_".join([CODE_19[:-1]]), True)
+ for mode in range(4):
+ self.check_pcdec(CODES, "_".join([CODE_19[:-1]]), mode)
def case_rest(self):
- for repeat in range(8):
- for bits in itertools.product("01", repeat=repeat):
- self.check_pcdec(CODES, "".join(bits), False)
- self.check_pcdec(CODES, "".join(bits), True)
- # 60 so we cover both less and more than 64 bits
- self.check_pcdec(CODES, "".join(bits) + "0" * 60, False)
- self.check_pcdec(CODES, "".join(bits) + "0" * 60, True)
+ for mode in range(4):
+ for repeat in range(8):
+ for bits in itertools.product("01", repeat=repeat):
+ self.check_pcdec(CODES, "".join(bits), mode)
+ # 60 so we cover both less and more than 64 bits
+ self.check_pcdec(CODES, "".join(bits) + "0" * 60, mode)