from contextlib import contextmanager
import unittest
-from typing import Any, Callable, ContextManager, Iterator, Tuple
+from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
GPR_SIZE_IN_BYTES,
'sv.std *4, 0(3)'
])
+ def tst_toom_mul_sim(
+ self, code, # type: Mul
+ lhs_signed, # type: bool
+ rhs_signed, # type: bool
+ get_state_factory, # type: Callable[[Mul], _StateFactory]
+ test_cases, # type: Iterable[tuple[int, int]]
+ ):
+ print(code.retval[1])
+ print(code.fn.ops_to_str())
+ state_factory = get_state_factory(code)
+ ptr_in = 0x100
+ dest_ptr = ptr_in + code.dest_offset
+ lhs_ptr = ptr_in + code.lhs_offset
+ rhs_ptr = ptr_in + code.rhs_offset
+ lhs_size_in_bits = code.lhs_size_in_words * GPR_SIZE_IN_BITS
+ rhs_size_in_bits = code.rhs_size_in_words * GPR_SIZE_IN_BITS
+ for lhs_value, rhs_value in test_cases:
+ lhs_value %= 1 << lhs_size_in_bits
+ rhs_value %= 1 << rhs_size_in_bits
+ if lhs_signed and lhs_value >> (lhs_size_in_bits - 1):
+ lhs_value -= 1 << lhs_size_in_bits
+ if rhs_signed and rhs_value >> (rhs_size_in_bits - 1):
+ rhs_value -= 1 << rhs_size_in_bits
+ prod_value = lhs_value * rhs_value
+ lhs_value %= 1 << lhs_size_in_bits
+ rhs_value %= 1 << rhs_size_in_bits
+ prod_value %= 1 << (lhs_size_in_bits + rhs_size_in_bits)
+ with self.subTest(lhs_signed=lhs_signed, rhs_signed=rhs_signed,
+ lhs_value=hex(lhs_value),
+ rhs_value=hex(rhs_value),
+ prod_value=hex(prod_value)):
+ with state_factory() as state:
+ state[code.ptr_in] = ptr_in,
+ for i in range(code.lhs_size_in_words):
+ v = lhs_value >> GPR_SIZE_IN_BITS * i
+ v &= GPR_VALUE_MASK
+ state.store(lhs_ptr + i * GPR_SIZE_IN_BYTES, v)
+ for i in range(code.rhs_size_in_words):
+ v = rhs_value >> GPR_SIZE_IN_BITS * i
+ v &= GPR_VALUE_MASK
+ state.store(rhs_ptr + i * GPR_SIZE_IN_BYTES, v)
+ code.fn.sim(state)
+ prod = 0
+ for i in range(code.dest_size_in_words):
+ v = state.load(dest_ptr + GPR_SIZE_IN_BYTES * i)
+ prod += v << (GPR_SIZE_IN_BITS * i)
+ self.assertEqual(hex(prod), hex(prod_value),
+ f"failed: state={state}")
+
+ def tst_toom_mul_all_sizes_pre_ra_sim(self, instances):
+ # type: (tuple[ToomCookInstance, ...]) -> None
+ for lhs_signed in False, True:
+ for rhs_signed in False, True:
+ def mul(fn, lhs, rhs):
+ # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
+ v = ToomCookMul(
+ fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
+ rhs_signed=rhs_signed, instances=instances)
+ return v.retval, v
+ for lhs_size_in_words in range(1, 32):
+ for rhs_size_in_words in range(1, 32):
+ lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
+ rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
+ with self.subTest(lhs_size_in_words=lhs_size_in_words,
+ rhs_size_in_words=rhs_size_in_words,
+ lhs_signed=lhs_signed,
+ rhs_signed=rhs_signed):
+ test_cases = [] # type: list[tuple[int, int]]
+ test_cases.append((-1, -1))
+ test_cases.append(((0x80 << 2048) // 0xFF,
+ (0x80 << 2048) // 0xFF))
+ test_cases.append(((0x40 << 2048) // 0xFF,
+ (0x80 << 2048) // 0xFF))
+ test_cases.append(((0x80 << 2048) // 0xFF,
+ (0x40 << 2048) // 0xFF))
+ test_cases.append(((0x40 << 2048) // 0xFF,
+ (0x40 << 2048) // 0xFF))
+ test_cases.append((1 << (lhs_size_in_bits - 1),
+ 1 << (rhs_size_in_bits - 1)))
+ test_cases.append((1, 1 << (rhs_size_in_bits - 1)))
+ test_cases.append((1 << (lhs_size_in_bits - 1), 1))
+ test_cases.append((1, 1))
+ self.tst_toom_mul_sim(
+ code=Mul(mul=mul,
+ lhs_size_in_words=lhs_size_in_words,
+ rhs_size_in_words=rhs_size_in_words),
+ lhs_signed=lhs_signed, rhs_signed=rhs_signed,
+ get_state_factory=get_pre_ra_state_factory,
+ test_cases=test_cases)
+
+ def test_toom_2_mul_all_sizes_pre_ra_sim(self):
+ self.skipTest("broken") # FIXME: fix
+ TOOM_2 = ToomCookInstance.make_toom_2()
+ self.tst_toom_mul_all_sizes_pre_ra_sim(
+ (TOOM_2, TOOM_2, TOOM_2, TOOM_2))
+
if __name__ == "__main__":
unittest.main()