From 918f3eadf7118a6ecd0e2eb6caaaed9da6936299 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 29 Aug 2022 00:32:55 -0700 Subject: [PATCH] svp64_utf_8_validation.py works! --- .../isa/test_caller_svp64_utf_8_validation.py | 45 ++++++- .../test/algorithms/svp64_utf_8_validation.py | 114 +++++++++--------- 2 files changed, 101 insertions(+), 58 deletions(-) diff --git a/src/openpower/decoder/isa/test_caller_svp64_utf_8_validation.py b/src/openpower/decoder/isa/test_caller_svp64_utf_8_validation.py index a1352e48..7234a577 100644 --- a/src/openpower/decoder/isa/test_caller_svp64_utf_8_validation.py +++ b/src/openpower/decoder/isa/test_caller_svp64_utf_8_validation.py @@ -5,20 +5,57 @@ import unittest from openpower.test.algorithms.svp64_utf_8_validation import \ SVP64UTF8ValidationTestCase from openpower.test.runner import TestRunnerBase +from functools import lru_cache # writing the test_caller invocation this way makes it work with pytest -@unittest.skip("not yet working") -class TestSVP64UTF8Validation(TestRunnerBase): +@lru_cache +def make_cases(): + # cache globally, so we only have to create test_data once per process + return SVP64UTF8ValidationTestCase().test_data + + +class TestSVP64UTF8ValidationBase(TestRunnerBase): + __test__ = False + + # split up test cases into SPLIT_COUNT tests, so we get some parallelism + SPLIT_COUNT = 64 + SPLIT_INDEX = -1 + def __init__(self, test): - assert test == 'test' - super().__init__(SVP64UTF8ValidationTestCase().test_data) + assert test == 'test', f"test={test!r}" + cases = make_cases() + assert self.SPLIT_INDEX != -1, "must be overridden" + # split cases evenly over tests + start = (len(cases) * self.SPLIT_INDEX) // self.SPLIT_COUNT + end = (len(cases) * (self.SPLIT_INDEX + 1)) // self.SPLIT_COUNT + # if we have less cases than tests, move them all to the beginning, + # making finding failures faster + if len(cases) < self.SPLIT_COUNT: + start = 0 + end = 0 + if self.SPLIT_INDEX < len(cases): + start = self.SPLIT_INDEX + end = start + 1 + # can't do raise SkipTest if `start == end`, it makes unittest break + super().__init__(cases[start:end]) + + @classmethod + def make_split_classes(cls): + for i in range(cls.SPLIT_COUNT): + exec(f""" +class TestSVP64UTF8Validation{i}(TestSVP64UTF8ValidationBase): + __test__ = True + SPLIT_INDEX = {i} def test(self): # dummy function to make unittest try to test this class pass + """, globals()) + +TestSVP64UTF8ValidationBase.make_split_classes() if __name__ == "__main__": unittest.main() diff --git a/src/openpower/test/algorithms/svp64_utf_8_validation.py b/src/openpower/test/algorithms/svp64_utf_8_validation.py index 082c6d9a..f040f1e6 100644 --- a/src/openpower/test/algorithms/svp64_utf_8_validation.py +++ b/src/openpower/test/algorithms/svp64_utf_8_validation.py @@ -145,6 +145,25 @@ def svp64_utf8_validation_asm(): # nibbles to look up in r64-r71 -- u64x8 temp_vec2 = cur_bytes + vec_sz * 2 temp_vec2_end = temp_vec2 + vec_sz + + def sv_set_0x80_if_ge(out_v, inp_v, temp_s, compare_rhs): + # type: (int, int, int, int) -> list[str] + """ generate values with bit 0x80 set if the input vector is + unsigned `>= compare_rhs`, this assumes `0x80 <= compare_rhs <= 0xFF` + and the input vector elements are in `0 <= v <= 0xFF`. + + can't use CRs for this, since vectors of CRs used as masks currently + max out at 4 in the simulator. + """ + assert 0x80 <= compare_rhs <= 0xFF, \ + "the algorithm only works if compare_rhs is in range" + max_arg = compare_rhs - 1 + add_arg = 0x80 - compare_rhs + return [ + f"addi {temp_s}, 0, {max_arg}", + f"sv.maxu *{out_v}, *{inp_v}, {temp_s}", + f"sv.addi *{out_v}, *{out_v}, {add_arg}" + ] return [ # input addr in r3, input length in r4 f"setvl 0, 0, {prev_bytes_sz}, 0, 1, 1", # set VL to prev_bytes_sz @@ -167,7 +186,9 @@ def svp64_utf8_validation_asm(): f"setvl. 5, 4, {vec_sz}, 0, 1, 1", # set VL to min(vec_sz, r4) # if no bytes left to load, run final check f"bc 12, 2, final_check # beq final_check", - f"sv.lbz/els *{cur_bytes}, 0({inp_addr})", # load bytes + # sv.lbz/els is buggy, use sv.lbzx instead: + f"sv.addi *{cur_bytes + 1}, *{cur_bytes}, 1", # create indexes + f"sv.lbzx *{cur_bytes}, {inp_addr}, *{cur_bytes}", # load bytes f"setvl 0, 0, {vec_sz}, 0, 1, 1", # set VL to vec_sz # now we can operate on vec_sz byte chunks, branch to `fail` if they # don't pass validation. @@ -184,7 +205,7 @@ def svp64_utf8_validation_asm(): f"addi 9, 0, {0xF}", f"sv.and *{temp_vec2}, *{cur_bytes - 1}, 9", # look-up nibbles in table - f"sv.lbzx *{temp_vec2}, 6, *{temp_vec2}", + f"sv.lbzx *{temp_vec2}, 7, *{temp_vec2}", # bitwise and into error flags f"sv.and *{temp_vec1}, *{temp_vec1}, *{temp_vec2}", @@ -192,11 +213,12 @@ def svp64_utf8_validation_asm(): # srdi *{temp_vec2}, *{cur_bytes}, 4 f"sv.rldicl *{temp_vec2}, *{cur_bytes}, {64 - 4}, 4", # look-up nibbles in table - f"sv.lbzx *{temp_vec2}, 6, *{temp_vec2}", + f"sv.lbzx *{temp_vec2}, 8, *{temp_vec2}", # bitwise and into error flags f"sv.and *{temp_vec1}, *{temp_vec1}, *{temp_vec2}", # or-reduce error flags into temp_vec2_end + f"sv.addi {temp_vec2_end}, 0, 0", f"sv.ori *{temp_vec2}, *{temp_vec1}, 0", f"sv.or *{temp_vec2 + 1}, *{temp_vec2}, *{temp_vec2 + 1}", # check for any actual error flags set @@ -207,21 +229,27 @@ def svp64_utf8_validation_asm(): f"bc 4, 2, fail # bne fail", # check for the correct number of continuation bytes for 3/4-byte cases + # set bit 0x80 (TwoContinuations) if input is >= 0xE0 - f"sv.cmpli *0, 1, *{cur_bytes - 2}, {0xE0}", + *sv_set_0x80_if_ge(out_v=temp_vec2, inp_v=cur_bytes - 2, + temp_s=9, compare_rhs=0xE0), # xor into error flags - f"sv.xori/m=ge *{temp_vec1}, *{temp_vec1}, {0x80}", + f"sv.xor *{temp_vec1}, *{temp_vec1}, *{temp_vec2}", # set bit 0x80 (TwoContinuations) if input is >= 0xF0 - f"sv.cmpli *0, 1, *{cur_bytes - 2}, {0xF0}", + *sv_set_0x80_if_ge(out_v=temp_vec2, inp_v=cur_bytes - 3, + temp_s=9, compare_rhs=0xF0), # xor into error flags - f"sv.xori/m=ge *{temp_vec1}, *{temp_vec1}, {0x80}", + f"sv.xor *{temp_vec1}, *{temp_vec1}, *{temp_vec2}", # now bit 0x80 is set in temp_vec1 if there's an error # or-reduce into temp_vec2 + f"sv.addi {temp_vec2}, 0, 0", f"sv.or *{temp_vec1 + 1}, *{temp_vec1}, *{temp_vec1 + 1}", # adjust count/pointer f"add 3, 3, 5", # increment pointer - f"sub 4, 4, 5", # decrement count - f"sv.andi. {temp_vec2}, {temp_vec2}, {0x80}", # check if any errors + f"subf 4, 5, 4", # decrement count + # sv.andi. is buggy, so move to r9 first + f"sv.ori 9, {temp_vec2}, 0", + f"andi. 9, 9, {0x80}", # check if any errors f"bc 12, 2, loop # beq loop", # if no errors loop, else fail f"fail:", f"addi 3, 0, 0", @@ -272,12 +300,19 @@ def assemble(instructions, start_pc=0): class SVP64UTF8ValidationTestCase(TestAccumulatorBase): + def __init__(self): + self.__seen_cases = set() + super().__init__() + @cached_property def program(self): return assemble(svp64_utf8_validation_asm()) def run_case(self, data, src_loc_at=0): # type: (bytes, int) -> None + if data in self.__seen_cases: + return + self.__seen_cases.add(data) expected = 1 try: data.decode("utf-8") @@ -290,16 +325,23 @@ class SVP64UTF8ValidationTestCase(TestAccumulatorBase): initial_mem = {} for i, v in enumerate(data): initial_mem[i + initial_regs[3]] = v, 1 + for i, v in enumerate(FIRST_BYTE_LOW_NIBBLE_LUT): + initial_mem[i + FIRST_BYTE_LOW_NIBBLE_LUT_ADDR] = int(v), 1 + for i, v in enumerate(FIRST_BYTE_HIGH_NIBBLE_LUT): + initial_mem[i + FIRST_BYTE_HIGH_NIBBLE_LUT_ADDR] = int(v), 1 + for i, v in enumerate(SECOND_BYTE_HIGH_NIBBLE_LUT): + initial_mem[i + SECOND_BYTE_HIGH_NIBBLE_LUT_ADDR] = int(v), 1 stop_at_pc = 0x10000000 initial_sprs = {8: SelectableInt(stop_at_pc, 64)} - e = ExpectedState(pc=stop_at_pc, int_regs=4, crregs=0, fp_regs=0) + e = ExpectedState(pc=stop_at_pc, int_regs=4, crregs=0, fp_regs=0, + so=None, ov=None, ca=None) e.intregs[:3] = initial_regs[:3] e.intregs[3] = expected with self.subTest(data=data, expected=expected): self.add_case(self.program, initial_regs, initial_mem=initial_mem, - initial_sprs=initial_sprs, stop_at_pc=stop_at_pc, - expected=e, - src_loc_at=src_loc_at + 1) + initial_sprs=initial_sprs, stop_at_pc=stop_at_pc, + expected=e, + src_loc_at=src_loc_at + 1) def run_cases(self, data): # type: (bytes | str) -> None @@ -307,164 +349,128 @@ class SVP64UTF8ValidationTestCase(TestAccumulatorBase): data = data.encode("utf-8") data = b' ' * 8 + data + b' ' * 8 for i in range(len(data)): - part = data[i:] - for j in range(len(part)): - self.run_case(part[:j], src_loc_at=1) + self.run_case(data[i:], src_loc_at=1) + self.run_case(data[:i], src_loc_at=1) def case_empty(self): self.run_case(b"") + def case_x6_sp_nul(self): + self.run_case(b' ' * 6 + b'\x00') + def case_nul(self): self.run_cases("\u0000") # min 1-byte - @skip_case def case_a(self): self.run_cases("a") - @skip_case def case_7f(self): self.run_cases("\u007F") # max 1-byte - @skip_case def case_c0_80(self): self.run_cases(b"\xC0\x80") # min 2-byte overlong encoding - @skip_case def case_c1_bf(self): self.run_cases(b"\xC1\xBF") # max 2-byte overlong encoding - @skip_case def case_u0080(self): self.run_cases("\u0080") # min 2-byte - @skip_case def case_u07ff(self): self.run_cases("\u07FF") # max 2-byte - @skip_case def case_e0_80_80(self): self.run_cases(b"\xE0\x80\x80") # min 3-byte overlong encoding - @skip_case def case_e0_9f_bf(self): self.run_cases(b"\xE0\x9F\xBF") # max 3-byte overlong encoding - @skip_case def case_u0800(self): self.run_cases("\u0800") # min 3-byte - @skip_case def case_u0fff(self): self.run_cases("\u0FFF") - @skip_case def case_u1000(self): self.run_cases("\u1000") - @skip_case def case_ucfff(self): self.run_cases("\uCFFF") - @skip_case def case_ud000(self): self.run_cases("\uD000") - @skip_case def case_ud7ff(self): self.run_cases("\uD7FF") - @skip_case def case_ed_a0_80(self): self.run_cases(b"\xED\xA0\x80") # first high surrogate - @skip_case def case_ed_af_bf(self): self.run_cases(b"\xED\xAF\xBF") # last high surrogate - @skip_case def case_ed_b0_80(self): self.run_cases(b"\xED\xB0\x80") # first low surrogate - @skip_case def case_ed_bf_bf(self): self.run_cases(b"\xED\xBF\xBF") # last low surrogate - @skip_case def case_ue000(self): self.run_cases("\uE000") - @skip_case def case_uffff(self): self.run_cases("\uFFFF") # max 3-byte - @skip_case def case_f0_80_80_80(self): self.run_cases(b"\xF0\x80\x80\x80") # min 4-byte overlong encoding - @skip_case def case_f0_bf_bf_bf(self): self.run_cases(b"\xF0\x8F\xBF\xBF") # max 4-byte overlong encoding - @skip_case def case_u00010000(self): self.run_cases("\U00010000") # min 4-byte - @skip_case def case_u0003ffff(self): self.run_cases("\U0003FFFF") - @skip_case def case_u00040000(self): self.run_cases("\U00040000") - @skip_case def case_u000fffff(self): self.run_cases("\U000FFFFF") - @skip_case def case_u00100000(self): self.run_cases("\U00100000") - @skip_case def case_u0010ffff(self): self.run_cases("\U0010FFFF") # max 4-byte - @skip_case def case_f4_90_80_80(self): self.run_cases(b"\xF4\x90\x80\x80") # first too-big encoding - @skip_case def case_f7_bf_bf_bf(self): self.run_cases(b"\xF7\xBF\xBF\xBF") # max too-big 4-byte encoding - @skip_case def case_f8_x4_80(self): self.run_cases(b"\xF8" + b"\x80" * 4) # min too-big 5-byte encoding - @skip_case def case_fb_x4_bf(self): self.run_cases(b"\xFB" + b"\xBF" * 4) # max too-big 5-byte encoding - @skip_case def case_fc_x5_80(self): self.run_cases(b"\xFC" + b"\x80" * 5) # min too-big 6-byte encoding - @skip_case def case_fd_x5_bf(self): self.run_cases(b"\xFD" + b"\xBF" * 5) # max too-big 6-byte encoding - @skip_case def case_fe_x6_80(self): self.run_cases(b"\xFE" + b"\x80" * 6) # min too-big 7-byte encoding - @skip_case def case_fe_x6_bf(self): self.run_cases(b"\xFE" + b"\xBF" * 6) # max too-big 7-byte encoding - @skip_case def case_ff_x7_80(self): self.run_cases(b"\xFF" + b"\x80" * 7) # min too-big 8-byte encoding - @skip_case def case_ff_x7_bf(self): self.run_cases(b"\xFF" + b"\xBF" * 7) # max too-big 8-byte encoding -- 2.30.2