prefix codes tests pass
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 30 Sep 2022 23:08:53 +0000 (16:08 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 30 Sep 2022 23:08:53 +0000 (16:08 -0700)
src/openpower/test/prefix_codes/prefix_codes_cases.py

index 8bdc0382028296f93f76011d2971dbfa1df33a91..cdbf9c6f9804dd812475671231618010c477426f 100644 (file)
@@ -4,6 +4,7 @@ from openpower.test.common import TestAccumulatorBase, skip_case
 from openpower.sv.trans.svp64 import SVP64Asm
 from openpower.test.state import ExpectedState
 from openpower.simulator.program import Program
+from typing import Iterable
 
 
 def tree_code(code):
@@ -27,23 +28,6 @@ def make_tree(*codes):
     return retval
 
 
-def reference_pcdec(supported_codes, input_bits, max_count):
-    # type: (set[str], str, int) -> tuple[list[str], bool]
-    assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
-    retval = []
-    current_code = ""
-    for bit in input_bits:
-        current_code += bit
-        if len(current_code) > 5:
-            break
-        if current_code in supported_codes:
-            retval.append(current_code)
-            current_code = ""
-            if len(retval) >= max_count:
-                break
-    return retval, current_code != ""
-
-
 CODE_2 = "0"
 CODE_7 = "11"
 CODE_19 = "1001"
@@ -57,24 +41,70 @@ def _cached_program(*instrs):
     return Program(list(SVP64Asm(list(instrs))), bigendian=False)
 
 
+def _code_sort_key(supported_code):
+    # type: (str) -> tuple[int, str]
+    return len(supported_code), supported_code
+
+
 class PrefixCodesCases(TestAccumulatorBase):
-    @skip_case("FIXME(programmerjake): update for new pcdec. pseudocode")
-    def check_pcdec(self, supported_codes, input_bits, once, src_loc_at=0):
-        # type: (set[str], str, bool, int) -> None
+    def check_pcdec(self, supported_codes, input_bits, mode, src_loc_at=0):
+        # type: (Iterable[str], str, int, int) -> None
+        supported_codes = sorted(supported_codes, key=_code_sort_key)
+        assert len(supported_codes) <= 32
         original_input_bits = input_bits
         input_bits = input_bits.replace("_", "")
         assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
         assert len(input_bits) < 128, "input_bits too long"
-        max_count = 1 if once else 8
-        decoded, expected_SO = reference_pcdec(
-            supported_codes, input_bits, max_count=max_count)
-        expected_GT = len(decoded) == 0
-        expected_EQ = len(decoded) < max_count
-        expected_RT = int.from_bytes(
-            [int("1" + code, 2) for code in decoded], 'little')
-        decoded_bits_len = len("".join(decoded))
+        found = False
+        hit_end = False
+        tree_index = 1
+        compressed_index = 0
+        used_bits = 0
+        for bit_len in range(1, 7):
+            cur_input_bits = input_bits[:bit_len]
+            if bit_len > len(input_bits):
+                hit_end = True
+                compressed_index = 0
+                for i, code in enumerate(supported_codes):
+                    if _code_sort_key(code) < _code_sort_key(cur_input_bits):
+                        compressed_index = i + 1
+                    else:
+                        break
+                break
+            tree_index *= 2
+            if cur_input_bits[-1] == "1":
+                tree_index += 1
+            if bit_len < 6:
+                try:
+                    compressed_index = supported_codes.index(cur_input_bits)
+                    found = True
+                    used_bits = bit_len
+                    break
+                except ValueError:
+                    pass
+            else:
+                compressed_index = tree_index - 64 + len(supported_codes)
+                used_bits = bit_len
+        if mode == 0:
+            expected_RT = tree_index
+            if not found:
+                used_bits = 0
+        elif mode == 1:
+            expected_RT = tree_index
+            if hit_end:
+                used_bits = 0
+        elif mode == 2:
+            expected_RT = compressed_index
+            if not found:
+                used_bits = 0
+                expected_RT = tree_index
+        else:
+            assert mode == 3
+            expected_RT = compressed_index
+            if hit_end:
+                used_bits = 0
         expected_ra_used = False
-        RB_val = make_tree(*supported_codes)
+        RB_val = make_tree(*supported_codes) | mode
         rev_input_bits = input_bits[::-1]
         RA_val = 0
         RA = 0
@@ -83,14 +113,14 @@ class PrefixCodesCases(TestAccumulatorBase):
             RA_val = int(rev_input_bits[:64], 2)
             RA = 7
             rev_input_bits = rev_input_bits[64:]
-            expected_ra_used = decoded_bits_len > len(rev_input_bits)
+            expected_ra_used = used_bits > len(rev_input_bits)
             if expected_ra_used:
-                expected_RS = (RA_val + 2 ** 64) >> (decoded_bits_len
+                expected_RS = (RA_val + 2 ** 64) >> (used_bits
                                                      - len(rev_input_bits))
         RC_val = int("1" + rev_input_bits, 2)
         if expected_RS is None:
-            expected_RS = RC_val >> decoded_bits_len
-        lst = [f"pcdec. 4,{RA},6,5,{int(once)}"]
+            expected_RS = RC_val >> used_bits
+        lst = [f"pcdec. 4,{RA},6,5"]
         gprs = [0] * 32
         gprs[6] = RB_val
         if RA:
@@ -99,96 +129,68 @@ class PrefixCodesCases(TestAccumulatorBase):
         e = ExpectedState(pc=4, int_regs=gprs)
         e.intregs[4] = expected_RT
         e.intregs[5] = expected_RS
-        e.crregs[0] = (expected_ra_used * 8 + expected_GT * 4
-                       + expected_EQ * 2 + expected_SO)
+        e.crregs[0] = (expected_ra_used * 8 + (tree_index >= 64) * 4
+                       + found * 2 + hit_end)
         with self.subTest(supported_codes=supported_codes,
-                          input_bits=original_input_bits, once=once):
+                          input_bits=original_input_bits, mode=mode):
             self.add_case(_cached_program(*lst), gprs, expected=e,
                           src_loc_at=src_loc_at + 1)
 
     def case_pcdec_empty(self):
-        self.check_pcdec({CODE_2}, "", False)
-
-    def case_pcdec_empty_once(self):
-        self.check_pcdec({CODE_2}, "", True)
+        for mode in range(4):
+            self.check_pcdec({CODE_2}, "", mode)
 
     def case_pcdec_only_one_code(self):
-        self.check_pcdec({CODE_37}, CODE_37, False)
-
-    def case_pcdec_only_one_code_once(self):
-        self.check_pcdec({CODE_37}, CODE_37, True)
+        for mode in range(4):
+            self.check_pcdec({CODE_37}, CODE_37, mode)
 
     def case_pcdec_short_seq(self):
-        self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, CODE_35]), False)
-
-    def case_pcdec_short_seq_once(self):
-        self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, CODE_35]), True)
+        for mode in range(4):
+            self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, CODE_35]), mode)
 
     def case_pcdec_medium_seq(self):
-        self.check_pcdec(
-            CODES, "0_11_1001_10101_10111_10111_10101_1001_11_0", False)
-
-    def case_pcdec_medium_seq_once(self):
-        self.check_pcdec(
-            CODES, "0_11_1001_10101_10111_10111_10101_1001_11_0", True)
+        for mode in range(4):
+            self.check_pcdec(
+                CODES, "0_11_1001_10101_10111_10111_10101_1001_11_0", mode)
 
     def case_pcdec_long_seq(self):
-        self.check_pcdec(CODES,
-                         "0_11_1001_10101_10111_10111_10101_1001_11_0"
-                         + CODE_37 * 6, False)
-
-    def case_pcdec_long_seq_once(self):
-        self.check_pcdec(CODES,
-                         "0_11_1001_10101_10111_10111_10101_1001_11_0"
-                         + CODE_37 * 6, True)
+        for mode in range(4):
+            self.check_pcdec(CODES,
+                             "0_11_1001_10101_10111_10111_10101_1001_11_0"
+                             + CODE_37 * 6, mode)
 
     def case_pcdec_invalid_code_at_start(self):
-        self.check_pcdec(CODES, "_".join(["1000", CODE_35]), False)
-
-    def case_pcdec_invalid_code_at_start_once(self):
-        self.check_pcdec(CODES, "_".join(["1000", CODE_35]), True)
+        for mode in range(4):
+            self.check_pcdec(CODES, "_".join(["1000", CODE_35]), mode)
 
     def case_pcdec_invalid_code_after_3(self):
-        self.check_pcdec(CODES, "_".join(
-            [CODE_2, CODE_19, CODE_35, "1000", CODE_35]), False)
-
-    def case_pcdec_invalid_code_after_3_once(self):
-        self.check_pcdec(CODES, "_".join(
-            [CODE_2, CODE_19, CODE_35, "1000", CODE_35]), True)
+        for mode in range(4):
+            self.check_pcdec(CODES, "_".join(
+                [CODE_2, CODE_19, CODE_35, "1000", CODE_35]), mode)
 
     def case_pcdec_invalid_code_after_8(self):
-        self.check_pcdec(CODES, "_".join(
-            [CODE_2, CODE_19, *([CODE_35] * 6), "1000", CODE_35]), False)
-
-    def case_pcdec_invalid_code_after_8_once(self):
-        self.check_pcdec(CODES, "_".join(
-            [CODE_2, CODE_19, *([CODE_35] * 6), "1000", CODE_35]), True)
+        for mode in range(4):
+            self.check_pcdec(CODES, "_".join(
+                [CODE_2, CODE_19, *([CODE_35] * 6), "1000", CODE_35]), mode)
 
     def case_pcdec_invalid_code_in_rb(self):
-        self.check_pcdec(CODES, "_".join(
-            [CODE_2, CODE_19, "1000", *([CODE_19] * 15)]), False)
-
-    def case_pcdec_invalid_code_in_rb_once(self):
-        self.check_pcdec(CODES, "_".join(
-            [CODE_2, CODE_19, "1000", *([CODE_19] * 15)]), True)
+        for mode in range(4):
+            self.check_pcdec(CODES, "_".join(
+                [CODE_2, CODE_19, "1000", *([CODE_19] * 15)]), mode)
 
     def case_pcdec_overlong_code(self):
-        self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, "10000000"]), False)
-
-    def case_pcdec_overlong_code_once(self):
-        self.check_pcdec(CODES, "_".join([CODE_2, CODE_19, "10000000"]), True)
+        for mode in range(4):
+            self.check_pcdec(CODES, "_".join(
+                [CODE_2, CODE_19, "10000000"]), mode)
 
     def case_pcdec_incomplete_code(self):
-        self.check_pcdec(CODES, "_".join([CODE_19[:-1]]), False)
-
-    def case_pcdec_incomplete_code_once(self):
-        self.check_pcdec(CODES, "_".join([CODE_19[:-1]]), True)
+        for mode in range(4):
+            self.check_pcdec(CODES, "_".join([CODE_19[:-1]]), mode)
 
     def case_rest(self):
-        for repeat in range(8):
-            for bits in itertools.product("01", repeat=repeat):
-                self.check_pcdec(CODES, "".join(bits), False)
-                self.check_pcdec(CODES, "".join(bits), True)
-                # 60 so we cover both less and more than 64 bits
-                self.check_pcdec(CODES, "".join(bits) + "0" * 60, False)
-                self.check_pcdec(CODES, "".join(bits) + "0" * 60, True)
+        for mode in range(4):
+            for repeat in range(8):
+                for bits in itertools.product("01", repeat=repeat):
+                    self.check_pcdec(CODES, "".join(bits), mode)
+                    # 60 so we cover both less and more than 64 bits
+                    self.check_pcdec(CODES, "".join(bits) + "0" * 60, mode)