increase pcdec. output compression by skipping impossible codes
[openpower-isa.git] / src / openpower / test / prefix_codes / prefix_codes_cases.py
index cdbf9c6f9804dd812475671231618010c477426f..c19fa350df1669b782ae722413b00ba59498eed9 100644 (file)
@@ -21,13 +21,22 @@ def make_tree(*codes):
     # type: (*str) -> int
     retval = 0
     for code in sorted(codes, key=len):
-        for i in range(len(code)):
-            assert retval & (1 << tree_code(code[:i])) == 0, \
-                f"conflicting code: {code} conflicts with {code[:i]}"
         retval |= 1 << tree_code(code)
     return retval
 
 
+def from_tree(tree):
+    # type: (int) -> list[str]
+    retval = []  # type: list[str]
+    for bit_len in range(1, 6):
+        for bits in itertools.product("01", repeat=bit_len):
+            bits = "".join(bits)
+            tree_index = int("0b1" + bits, 0)
+            if tree & (1 << tree_index):
+                retval.append(bits)
+    return retval
+
+
 CODE_2 = "0"
 CODE_7 = "11"
 CODE_19 = "1001"
@@ -50,7 +59,17 @@ class PrefixCodesCases(TestAccumulatorBase):
     def check_pcdec(self, supported_codes, input_bits, mode, src_loc_at=0):
         # type: (Iterable[str], str, int, int) -> None
         supported_codes = sorted(supported_codes, key=_code_sort_key)
-        assert len(supported_codes) <= 32
+        compressed_codes = []  # type: list[str]
+        codes = supported_codes.copy()
+        codes += map(lambda v: "".join(v), itertools.product("01", repeat=6))
+        for code in codes:
+            possible = True
+            for i in range(1, len(code)):
+                if code[:i] in compressed_codes:
+                    possible = False
+                    break
+            if possible:
+                compressed_codes.append(code)
         original_input_bits = input_bits
         input_bits = input_bits.replace("_", "")
         assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
@@ -65,7 +84,7 @@ class PrefixCodesCases(TestAccumulatorBase):
             if bit_len > len(input_bits):
                 hit_end = True
                 compressed_index = 0
-                for i, code in enumerate(supported_codes):
+                for i, code in enumerate(compressed_codes):
                     if _code_sort_key(code) < _code_sort_key(cur_input_bits):
                         compressed_index = i + 1
                     else:
@@ -74,17 +93,14 @@ class PrefixCodesCases(TestAccumulatorBase):
             tree_index *= 2
             if cur_input_bits[-1] == "1":
                 tree_index += 1
-            if bit_len < 6:
-                try:
-                    compressed_index = supported_codes.index(cur_input_bits)
-                    found = True
-                    used_bits = bit_len
-                    break
-                except ValueError:
-                    pass
-            else:
-                compressed_index = tree_index - 64 + len(supported_codes)
+            try:
+                compressed_index = compressed_codes.index(cur_input_bits)
                 used_bits = bit_len
+                if bit_len < 6:
+                    found = True
+                break
+            except ValueError:
+                pass
         if mode == 0:
             expected_RT = tree_index
             if not found:
@@ -194,3 +210,12 @@ class PrefixCodesCases(TestAccumulatorBase):
                     self.check_pcdec(CODES, "".join(bits), mode)
                     # 60 so we cover both less and more than 64 bits
                     self.check_pcdec(CODES, "".join(bits) + "0" * 60, mode)
+
+    def case_arbitrary_tree(self):
+        codes = from_tree(0x123456789ABCDEF0)
+        for mode in range(4):
+            for repeat in range(8):
+                for bits in itertools.product("01", repeat=repeat):
+                    self.check_pcdec(codes, "".join(bits), mode)
+                    # 60 so we cover both less and more than 64 bits
+                    self.check_pcdec(codes, "".join(bits) + "0" * 60, mode)