# type: (*str) -> int
retval = 0
for code in sorted(codes, key=len):
- for i in range(len(code)):
- assert retval & (1 << tree_code(code[:i])) == 0, \
- f"conflicting code: {code} conflicts with {code[:i]}"
retval |= 1 << tree_code(code)
return retval
+def from_tree(tree):
+ # type: (int) -> list[str]
+ retval = [] # type: list[str]
+ for bit_len in range(1, 6):
+ for bits in itertools.product("01", repeat=bit_len):
+ bits = "".join(bits)
+ tree_index = int("0b1" + bits, 0)
+ if tree & (1 << tree_index):
+ retval.append(bits)
+ return retval
+
+
CODE_2 = "0"
CODE_7 = "11"
CODE_19 = "1001"
def check_pcdec(self, supported_codes, input_bits, mode, src_loc_at=0):
# type: (Iterable[str], str, int, int) -> None
supported_codes = sorted(supported_codes, key=_code_sort_key)
- assert len(supported_codes) <= 32
+ compressed_codes = [] # type: list[str]
+ codes = supported_codes.copy()
+ codes += map(lambda v: "".join(v), itertools.product("01", repeat=6))
+ for code in codes:
+ possible = True
+ for i in range(1, len(code)):
+ if code[:i] in compressed_codes:
+ possible = False
+ break
+ if possible:
+ compressed_codes.append(code)
original_input_bits = input_bits
input_bits = input_bits.replace("_", "")
assert input_bits.lstrip("01") == "", "input_bits must be binary bits"
if bit_len > len(input_bits):
hit_end = True
compressed_index = 0
- for i, code in enumerate(supported_codes):
+ for i, code in enumerate(compressed_codes):
if _code_sort_key(code) < _code_sort_key(cur_input_bits):
compressed_index = i + 1
else:
tree_index *= 2
if cur_input_bits[-1] == "1":
tree_index += 1
- if bit_len < 6:
- try:
- compressed_index = supported_codes.index(cur_input_bits)
- found = True
- used_bits = bit_len
- break
- except ValueError:
- pass
- else:
- compressed_index = tree_index - 64 + len(supported_codes)
+ try:
+ compressed_index = compressed_codes.index(cur_input_bits)
used_bits = bit_len
+ if bit_len < 6:
+ found = True
+ break
+ except ValueError:
+ pass
if mode == 0:
expected_RT = tree_index
if not found:
self.check_pcdec(CODES, "".join(bits), mode)
# 60 so we cover both less and more than 64 bits
self.check_pcdec(CODES, "".join(bits) + "0" * 60, mode)
+
+ def case_arbitrary_tree(self):
+ codes = from_tree(0x123456789ABCDEF0)
+ for mode in range(4):
+ for repeat in range(8):
+ for bits in itertools.product("01", repeat=repeat):
+ self.check_pcdec(codes, "".join(bits), mode)
+ # 60 so we cover both less and more than 64 bits
+ self.check_pcdec(codes, "".join(bits) + "0" * 60, mode)