add general TOOM-2 test
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 16 Nov 2022 06:58:14 +0000 (22:58 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 16 Nov 2022 06:58:14 +0000 (22:58 -0800)
src/bigint_presentation_code/_tests/test_toom_cook.py

index 76b9a5e37f2a3bb8e65debc412d47205e23be67c..41a4e2236b159ad89d8481fc00a2740113dfb5a3 100644 (file)
@@ -1,6 +1,6 @@
 from contextlib import contextmanager
 import unittest
-from typing import Any, Callable, ContextManager, Iterator, Tuple
+from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
 
 from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
                                                   GPR_SIZE_IN_BYTES,
@@ -1847,6 +1847,102 @@ class TestToomCook(unittest.TestCase):
             'sv.std *4, 0(3)'
         ])
 
+    def tst_toom_mul_sim(
+        self, code,  # type: Mul
+        lhs_signed,  # type: bool
+        rhs_signed,  # type: bool
+        get_state_factory,  # type: Callable[[Mul], _StateFactory]
+        test_cases,  # type: Iterable[tuple[int, int]]
+    ):
+        print(code.retval[1])
+        print(code.fn.ops_to_str())
+        state_factory = get_state_factory(code)
+        ptr_in = 0x100
+        dest_ptr = ptr_in + code.dest_offset
+        lhs_ptr = ptr_in + code.lhs_offset
+        rhs_ptr = ptr_in + code.rhs_offset
+        lhs_size_in_bits = code.lhs_size_in_words * GPR_SIZE_IN_BITS
+        rhs_size_in_bits = code.rhs_size_in_words * GPR_SIZE_IN_BITS
+        for lhs_value, rhs_value in test_cases:
+            lhs_value %= 1 << lhs_size_in_bits
+            rhs_value %= 1 << rhs_size_in_bits
+            if lhs_signed and lhs_value >> (lhs_size_in_bits - 1):
+                lhs_value -= 1 << lhs_size_in_bits
+            if rhs_signed and rhs_value >> (rhs_size_in_bits - 1):
+                rhs_value -= 1 << rhs_size_in_bits
+            prod_value = lhs_value * rhs_value
+            lhs_value %= 1 << lhs_size_in_bits
+            rhs_value %= 1 << rhs_size_in_bits
+            prod_value %= 1 << (lhs_size_in_bits + rhs_size_in_bits)
+            with self.subTest(lhs_signed=lhs_signed, rhs_signed=rhs_signed,
+                              lhs_value=hex(lhs_value),
+                              rhs_value=hex(rhs_value),
+                              prod_value=hex(prod_value)):
+                with state_factory() as state:
+                    state[code.ptr_in] = ptr_in,
+                    for i in range(code.lhs_size_in_words):
+                        v = lhs_value >> GPR_SIZE_IN_BITS * i
+                        v &= GPR_VALUE_MASK
+                        state.store(lhs_ptr + i * GPR_SIZE_IN_BYTES, v)
+                    for i in range(code.rhs_size_in_words):
+                        v = rhs_value >> GPR_SIZE_IN_BITS * i
+                        v &= GPR_VALUE_MASK
+                        state.store(rhs_ptr + i * GPR_SIZE_IN_BYTES, v)
+                    code.fn.sim(state)
+                    prod = 0
+                    for i in range(code.dest_size_in_words):
+                        v = state.load(dest_ptr + GPR_SIZE_IN_BYTES * i)
+                        prod += v << (GPR_SIZE_IN_BITS * i)
+                    self.assertEqual(hex(prod), hex(prod_value),
+                                     f"failed: state={state}")
+
+    def tst_toom_mul_all_sizes_pre_ra_sim(self, instances):
+        # type: (tuple[ToomCookInstance, ...]) -> None
+        for lhs_signed in False, True:
+            for rhs_signed in False, True:
+                def mul(fn, lhs, rhs):
+                    # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
+                    v = ToomCookMul(
+                        fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
+                        rhs_signed=rhs_signed, instances=instances)
+                    return v.retval, v
+                for lhs_size_in_words in range(1, 32):
+                    for rhs_size_in_words in range(1, 32):
+                        lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
+                        rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
+                        with self.subTest(lhs_size_in_words=lhs_size_in_words,
+                                          rhs_size_in_words=rhs_size_in_words,
+                                          lhs_signed=lhs_signed,
+                                          rhs_signed=rhs_signed):
+                            test_cases = []  # type: list[tuple[int, int]]
+                            test_cases.append((-1, -1))
+                            test_cases.append(((0x80 << 2048) // 0xFF,
+                                               (0x80 << 2048) // 0xFF))
+                            test_cases.append(((0x40 << 2048) // 0xFF,
+                                               (0x80 << 2048) // 0xFF))
+                            test_cases.append(((0x80 << 2048) // 0xFF,
+                                               (0x40 << 2048) // 0xFF))
+                            test_cases.append(((0x40 << 2048) // 0xFF,
+                                               (0x40 << 2048) // 0xFF))
+                            test_cases.append((1 << (lhs_size_in_bits - 1),
+                                               1 << (rhs_size_in_bits - 1)))
+                            test_cases.append((1, 1 << (rhs_size_in_bits - 1)))
+                            test_cases.append((1 << (lhs_size_in_bits - 1), 1))
+                            test_cases.append((1, 1))
+                            self.tst_toom_mul_sim(
+                                code=Mul(mul=mul,
+                                         lhs_size_in_words=lhs_size_in_words,
+                                         rhs_size_in_words=rhs_size_in_words),
+                                lhs_signed=lhs_signed, rhs_signed=rhs_signed,
+                                get_state_factory=get_pre_ra_state_factory,
+                                test_cases=test_cases)
+
+    def test_toom_2_mul_all_sizes_pre_ra_sim(self):
+        self.skipTest("broken")  # FIXME: fix
+        TOOM_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (TOOM_2, TOOM_2, TOOM_2, TOOM_2))
+
 
 if __name__ == "__main__":
     unittest.main()