svp64_utf_8_validation.py works!
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 29 Aug 2022 07:32:55 +0000 (00:32 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 29 Aug 2022 07:32:55 +0000 (00:32 -0700)
src/openpower/decoder/isa/test_caller_svp64_utf_8_validation.py
src/openpower/test/algorithms/svp64_utf_8_validation.py

index a1352e4879b5214cf55fad888ac2608694b6ff11..7234a57794cb56719ffc350c8615e7f32e4bd8dd 100644 (file)
@@ -5,20 +5,57 @@ import unittest
 from openpower.test.algorithms.svp64_utf_8_validation import \
     SVP64UTF8ValidationTestCase
 from openpower.test.runner import TestRunnerBase
+from functools import lru_cache
 
 # writing the test_caller invocation this way makes it work with pytest
 
 
-@unittest.skip("not yet working")
-class TestSVP64UTF8Validation(TestRunnerBase):
+@lru_cache
+def make_cases():
+    # cache globally, so we only have to create test_data once per process
+    return SVP64UTF8ValidationTestCase().test_data
+
+
+class TestSVP64UTF8ValidationBase(TestRunnerBase):
+    __test__ = False
+
+    # split up test cases into SPLIT_COUNT tests, so we get some parallelism
+    SPLIT_COUNT = 64
+    SPLIT_INDEX = -1
+
     def __init__(self, test):
-        assert test == 'test'
-        super().__init__(SVP64UTF8ValidationTestCase().test_data)
+        assert test == 'test', f"test={test!r}"
+        cases = make_cases()
+        assert self.SPLIT_INDEX != -1, "must be overridden"
+        # split cases evenly over tests
+        start = (len(cases) * self.SPLIT_INDEX) // self.SPLIT_COUNT
+        end = (len(cases) * (self.SPLIT_INDEX + 1)) // self.SPLIT_COUNT
+        # if we have less cases than tests, move them all to the beginning,
+        # making finding failures faster
+        if len(cases) < self.SPLIT_COUNT:
+            start = 0
+            end = 0
+            if self.SPLIT_INDEX < len(cases):
+                start = self.SPLIT_INDEX
+                end = start + 1
+        # can't do raise SkipTest if `start == end`, it makes unittest break
+        super().__init__(cases[start:end])
+
+    @classmethod
+    def make_split_classes(cls):
+        for i in range(cls.SPLIT_COUNT):
+            exec(f"""
+class TestSVP64UTF8Validation{i}(TestSVP64UTF8ValidationBase):
+    __test__ = True
+    SPLIT_INDEX = {i}
 
     def test(self):
         # dummy function to make unittest try to test this class
         pass
+            """, globals())
+
 
+TestSVP64UTF8ValidationBase.make_split_classes()
 
 if __name__ == "__main__":
     unittest.main()
index 082c6d9afdac4465c61491e88b3737e21b7a695b..f040f1e6f114927bfcae0714d9d76bb737e0c42c 100644 (file)
@@ -145,6 +145,25 @@ def svp64_utf8_validation_asm():
     # nibbles to look up in r64-r71 -- u64x8
     temp_vec2 = cur_bytes + vec_sz * 2
     temp_vec2_end = temp_vec2 + vec_sz
+
+    def sv_set_0x80_if_ge(out_v, inp_v, temp_s, compare_rhs):
+        # type: (int, int, int, int) -> list[str]
+        """ generate values with bit 0x80 set if the input vector is
+        unsigned `>= compare_rhs`, this assumes `0x80 <= compare_rhs <= 0xFF`
+        and the input vector elements are in `0 <= v <= 0xFF`.
+
+        can't use CRs for this, since vectors of CRs used as masks currently
+        max out at 4 in the simulator.
+        """
+        assert 0x80 <= compare_rhs <= 0xFF, \
+            "the algorithm only works if compare_rhs is in range"
+        max_arg = compare_rhs - 1
+        add_arg = 0x80 - compare_rhs
+        return [
+            f"addi {temp_s}, 0, {max_arg}",
+            f"sv.maxu *{out_v}, *{inp_v}, {temp_s}",
+            f"sv.addi *{out_v}, *{out_v}, {add_arg}"
+        ]
     return [
         # input addr in r3, input length in r4
         f"setvl 0, 0, {prev_bytes_sz}, 0, 1, 1",  # set VL to prev_bytes_sz
@@ -167,7 +186,9 @@ def svp64_utf8_validation_asm():
         f"setvl. 5, 4, {vec_sz}, 0, 1, 1",  # set VL to min(vec_sz, r4)
         # if no bytes left to load, run final check
         f"bc 12, 2, final_check # beq final_check",
-        f"sv.lbz/els *{cur_bytes}, 0({inp_addr})",  # load bytes
+        # sv.lbz/els is buggy, use sv.lbzx instead:
+        f"sv.addi *{cur_bytes + 1}, *{cur_bytes}, 1",  # create indexes
+        f"sv.lbzx *{cur_bytes}, {inp_addr}, *{cur_bytes}",  # load bytes
         f"setvl 0, 0, {vec_sz}, 0, 1, 1",  # set VL to vec_sz
         # now we can operate on vec_sz byte chunks, branch to `fail` if they
         # don't pass validation.
@@ -184,7 +205,7 @@ def svp64_utf8_validation_asm():
         f"addi 9, 0, {0xF}",
         f"sv.and *{temp_vec2}, *{cur_bytes - 1}, 9",
         # look-up nibbles in table
-        f"sv.lbzx *{temp_vec2}, 6, *{temp_vec2}",
+        f"sv.lbzx *{temp_vec2}, 7, *{temp_vec2}",
         # bitwise and into error flags
         f"sv.and *{temp_vec1}, *{temp_vec1}, *{temp_vec2}",
 
@@ -192,11 +213,12 @@ def svp64_utf8_validation_asm():
         # srdi *{temp_vec2}, *{cur_bytes}, 4
         f"sv.rldicl *{temp_vec2}, *{cur_bytes}, {64 - 4}, 4",
         # look-up nibbles in table
-        f"sv.lbzx *{temp_vec2}, 6, *{temp_vec2}",
+        f"sv.lbzx *{temp_vec2}, 8, *{temp_vec2}",
         # bitwise and into error flags
         f"sv.and *{temp_vec1}, *{temp_vec1}, *{temp_vec2}",
 
         # or-reduce error flags into temp_vec2_end
+        f"sv.addi {temp_vec2_end}, 0, 0",
         f"sv.ori *{temp_vec2}, *{temp_vec1}, 0",
         f"sv.or *{temp_vec2 + 1}, *{temp_vec2}, *{temp_vec2 + 1}",
         # check for any actual error flags set
@@ -207,21 +229,27 @@ def svp64_utf8_validation_asm():
         f"bc 4, 2, fail # bne fail",
 
         # check for the correct number of continuation bytes for 3/4-byte cases
+
         # set bit 0x80 (TwoContinuations) if input is >= 0xE0
-        f"sv.cmpli *0, 1, *{cur_bytes - 2}, {0xE0}",
+        *sv_set_0x80_if_ge(out_v=temp_vec2, inp_v=cur_bytes - 2,
+                           temp_s=9, compare_rhs=0xE0),
         # xor into error flags
-        f"sv.xori/m=ge *{temp_vec1}, *{temp_vec1}, {0x80}",
+        f"sv.xor *{temp_vec1}, *{temp_vec1}, *{temp_vec2}",
         # set bit 0x80 (TwoContinuations) if input is >= 0xF0
-        f"sv.cmpli *0, 1, *{cur_bytes - 2}, {0xF0}",
+        *sv_set_0x80_if_ge(out_v=temp_vec2, inp_v=cur_bytes - 3,
+                           temp_s=9, compare_rhs=0xF0),
         # xor into error flags
-        f"sv.xori/m=ge *{temp_vec1}, *{temp_vec1}, {0x80}",
+        f"sv.xor *{temp_vec1}, *{temp_vec1}, *{temp_vec2}",
         # now bit 0x80 is set in temp_vec1 if there's an error
         # or-reduce into temp_vec2
+        f"sv.addi {temp_vec2}, 0, 0",
         f"sv.or *{temp_vec1 + 1}, *{temp_vec1}, *{temp_vec1 + 1}",
         # adjust count/pointer
         f"add 3, 3, 5",  # increment pointer
-        f"sub 4, 4, 5",  # decrement count
-        f"sv.andi. {temp_vec2}, {temp_vec2}, {0x80}",  # check if any errors
+        f"subf 4, 5, 4",  # decrement count
+        # sv.andi. is buggy, so move to r9 first
+        f"sv.ori 9, {temp_vec2}, 0",
+        f"andi. 9, 9, {0x80}",  # check if any errors
         f"bc 12, 2, loop # beq loop",  # if no errors loop, else fail
         f"fail:",
         f"addi 3, 0, 0",
@@ -272,12 +300,19 @@ def assemble(instructions, start_pc=0):
 
 
 class SVP64UTF8ValidationTestCase(TestAccumulatorBase):
+    def __init__(self):
+        self.__seen_cases = set()
+        super().__init__()
+
     @cached_property
     def program(self):
         return assemble(svp64_utf8_validation_asm())
 
     def run_case(self, data, src_loc_at=0):
         # type: (bytes, int) -> None
+        if data in self.__seen_cases:
+            return
+        self.__seen_cases.add(data)
         expected = 1
         try:
             data.decode("utf-8")
@@ -290,16 +325,23 @@ class SVP64UTF8ValidationTestCase(TestAccumulatorBase):
         initial_mem = {}
         for i, v in enumerate(data):
             initial_mem[i + initial_regs[3]] = v, 1
+        for i, v in enumerate(FIRST_BYTE_LOW_NIBBLE_LUT):
+            initial_mem[i + FIRST_BYTE_LOW_NIBBLE_LUT_ADDR] = int(v), 1
+        for i, v in enumerate(FIRST_BYTE_HIGH_NIBBLE_LUT):
+            initial_mem[i + FIRST_BYTE_HIGH_NIBBLE_LUT_ADDR] = int(v), 1
+        for i, v in enumerate(SECOND_BYTE_HIGH_NIBBLE_LUT):
+            initial_mem[i + SECOND_BYTE_HIGH_NIBBLE_LUT_ADDR] = int(v), 1
         stop_at_pc = 0x10000000
         initial_sprs = {8: SelectableInt(stop_at_pc, 64)}
-        e = ExpectedState(pc=stop_at_pc, int_regs=4, crregs=0, fp_regs=0)
+        e = ExpectedState(pc=stop_at_pc, int_regs=4, crregs=0, fp_regs=0,
+                          so=None, ov=None, ca=None)
         e.intregs[:3] = initial_regs[:3]
         e.intregs[3] = expected
         with self.subTest(data=data, expected=expected):
             self.add_case(self.program, initial_regs, initial_mem=initial_mem,
-                        initial_sprs=initial_sprs, stop_at_pc=stop_at_pc,
-                        expected=e,
-                        src_loc_at=src_loc_at + 1)
+                          initial_sprs=initial_sprs, stop_at_pc=stop_at_pc,
+                          expected=e,
+                          src_loc_at=src_loc_at + 1)
 
     def run_cases(self, data):
         # type: (bytes | str) -> None
@@ -307,164 +349,128 @@ class SVP64UTF8ValidationTestCase(TestAccumulatorBase):
             data = data.encode("utf-8")
         data = b' ' * 8 + data + b' ' * 8
         for i in range(len(data)):
-            part = data[i:]
-            for j in range(len(part)):
-                self.run_case(part[:j], src_loc_at=1)
+            self.run_case(data[i:], src_loc_at=1)
+            self.run_case(data[:i], src_loc_at=1)
 
     def case_empty(self):
         self.run_case(b"")
 
+    def case_x6_sp_nul(self):
+        self.run_case(b' ' * 6 + b'\x00')
+
     def case_nul(self):
         self.run_cases("\u0000")  # min 1-byte
 
-    @skip_case
     def case_a(self):
         self.run_cases("a")
 
-    @skip_case
     def case_7f(self):
         self.run_cases("\u007F")  # max 1-byte
 
-    @skip_case
     def case_c0_80(self):
         self.run_cases(b"\xC0\x80")  # min 2-byte overlong encoding
 
-    @skip_case
     def case_c1_bf(self):
         self.run_cases(b"\xC1\xBF")  # max 2-byte overlong encoding
 
-    @skip_case
     def case_u0080(self):
         self.run_cases("\u0080")  # min 2-byte
 
-    @skip_case
     def case_u07ff(self):
         self.run_cases("\u07FF")  # max 2-byte
 
-    @skip_case
     def case_e0_80_80(self):
         self.run_cases(b"\xE0\x80\x80")  # min 3-byte overlong encoding
 
-    @skip_case
     def case_e0_9f_bf(self):
         self.run_cases(b"\xE0\x9F\xBF")  # max 3-byte overlong encoding
 
-    @skip_case
     def case_u0800(self):
         self.run_cases("\u0800")  # min 3-byte
 
-    @skip_case
     def case_u0fff(self):
         self.run_cases("\u0FFF")
 
-    @skip_case
     def case_u1000(self):
         self.run_cases("\u1000")
 
-    @skip_case
     def case_ucfff(self):
         self.run_cases("\uCFFF")
 
-    @skip_case
     def case_ud000(self):
         self.run_cases("\uD000")
 
-    @skip_case
     def case_ud7ff(self):
         self.run_cases("\uD7FF")
 
-    @skip_case
     def case_ed_a0_80(self):
         self.run_cases(b"\xED\xA0\x80")  # first high surrogate
 
-    @skip_case
     def case_ed_af_bf(self):
         self.run_cases(b"\xED\xAF\xBF")  # last high surrogate
 
-    @skip_case
     def case_ed_b0_80(self):
         self.run_cases(b"\xED\xB0\x80")  # first low surrogate
 
-    @skip_case
     def case_ed_bf_bf(self):
         self.run_cases(b"\xED\xBF\xBF")  # last low surrogate
 
-    @skip_case
     def case_ue000(self):
         self.run_cases("\uE000")
 
-    @skip_case
     def case_uffff(self):
         self.run_cases("\uFFFF")  # max 3-byte
 
-    @skip_case
     def case_f0_80_80_80(self):
         self.run_cases(b"\xF0\x80\x80\x80")  # min 4-byte overlong encoding
 
-    @skip_case
     def case_f0_bf_bf_bf(self):
         self.run_cases(b"\xF0\x8F\xBF\xBF")  # max 4-byte overlong encoding
 
-    @skip_case
     def case_u00010000(self):
         self.run_cases("\U00010000")  # min 4-byte
 
-    @skip_case
     def case_u0003ffff(self):
         self.run_cases("\U0003FFFF")
 
-    @skip_case
     def case_u00040000(self):
         self.run_cases("\U00040000")
 
-    @skip_case
     def case_u000fffff(self):
         self.run_cases("\U000FFFFF")
 
-    @skip_case
     def case_u00100000(self):
         self.run_cases("\U00100000")
 
-    @skip_case
     def case_u0010ffff(self):
         self.run_cases("\U0010FFFF")  # max 4-byte
 
-    @skip_case
     def case_f4_90_80_80(self):
         self.run_cases(b"\xF4\x90\x80\x80")  # first too-big encoding
 
-    @skip_case
     def case_f7_bf_bf_bf(self):
         self.run_cases(b"\xF7\xBF\xBF\xBF")  # max too-big 4-byte encoding
 
-    @skip_case
     def case_f8_x4_80(self):
         self.run_cases(b"\xF8" + b"\x80" * 4)  # min too-big 5-byte encoding
 
-    @skip_case
     def case_fb_x4_bf(self):
         self.run_cases(b"\xFB" + b"\xBF" * 4)  # max too-big 5-byte encoding
 
-    @skip_case
     def case_fc_x5_80(self):
         self.run_cases(b"\xFC" + b"\x80" * 5)  # min too-big 6-byte encoding
 
-    @skip_case
     def case_fd_x5_bf(self):
         self.run_cases(b"\xFD" + b"\xBF" * 5)  # max too-big 6-byte encoding
 
-    @skip_case
     def case_fe_x6_80(self):
         self.run_cases(b"\xFE" + b"\x80" * 6)  # min too-big 7-byte encoding
 
-    @skip_case
     def case_fe_x6_bf(self):
         self.run_cases(b"\xFE" + b"\xBF" * 6)  # max too-big 7-byte encoding
 
-    @skip_case
     def case_ff_x7_80(self):
         self.run_cases(b"\xFF" + b"\x80" * 7)  # min too-big 8-byte encoding
 
-    @skip_case
     def case_ff_x7_bf(self):
         self.run_cases(b"\xFF" + b"\xBF" * 7)  # max too-big 8-byte encoding