from bigint_presentation_code.register_allocator import allocate_registers
from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
simple_mul)
+from bigint_presentation_code.util import OSet
_StateFactory = Callable[[], ContextManager[BaseSimState]]
self.assertEqual(hex(prod), hex(prod_value),
f"failed: state={state}")
- def tst_toom_mul_all_sizes_pre_ra_sim(self, instances):
- # type: (tuple[ToomCookInstance, ...]) -> None
- for lhs_signed in False, True:
- for rhs_signed in False, True:
- def mul(fn, lhs, rhs):
- # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
- v = ToomCookMul(
- fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
- rhs_signed=rhs_signed, instances=instances)
- return v.retval, v
- for lhs_size_in_words in range(1, 32):
- for rhs_size_in_words in range(1, 32):
- lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
- rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
- with self.subTest(lhs_size_in_words=lhs_size_in_words,
- rhs_size_in_words=rhs_size_in_words,
- lhs_signed=lhs_signed,
- rhs_signed=rhs_signed):
- test_cases = [] # type: list[tuple[int, int]]
- test_cases.append((-1, -1))
- test_cases.append(((0x80 << 2048) // 0xFF,
- (0x80 << 2048) // 0xFF))
- test_cases.append(((0x40 << 2048) // 0xFF,
- (0x80 << 2048) // 0xFF))
- test_cases.append(((0x80 << 2048) // 0xFF,
- (0x40 << 2048) // 0xFF))
- test_cases.append(((0x40 << 2048) // 0xFF,
- (0x40 << 2048) // 0xFF))
- test_cases.append((1 << (lhs_size_in_bits - 1),
- 1 << (rhs_size_in_bits - 1)))
- test_cases.append((1, 1 << (rhs_size_in_bits - 1)))
- test_cases.append((1 << (lhs_size_in_bits - 1), 1))
- test_cases.append((1, 1))
- self.tst_toom_mul_sim(
- code=Mul(mul=mul,
- lhs_size_in_words=lhs_size_in_words,
- rhs_size_in_words=rhs_size_in_words),
- lhs_signed=lhs_signed, rhs_signed=rhs_signed,
- get_state_factory=get_pre_ra_state_factory,
- test_cases=test_cases)
+    def tst_toom_mul_all_sizes_pre_ra_sim(self, instances, lhs_signed, rhs_signed):
+        # type: (tuple[ToomCookInstance, ...], bool, bool) -> None
+        # Shared driver for the per-signedness tests: for every lhs/rhs
+        # pairing of the selected operand sizes, builds a Mul around
+        # ToomCookMul and hands it to tst_toom_mul_sim with the pre-RA
+        # state factory.
+        def mul(fn, lhs, rhs):
+            # type: (Fn, SSAVal, SSAVal) -> tuple[SSAVal, ToomCookMul]
+            toom_mul = ToomCookMul(
+                fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs,
+                rhs_signed=rhs_signed, instances=instances)
+            return toom_mul.retval, toom_mul
+
+        # candidate sizes are 1 << i and 3 << i words; only those within
+        # 1..16 are kept, in ascending order
+        candidate_sizes = OSet(
+            factor << i for i in range(6) for factor in (1, 3)
+        )  # type: OSet[int]
+        sizes_in_words = OSet(
+            size for size in sorted(candidate_sizes) if 1 <= size <= 16)
+
+        # repeating 0x80 / 0x40 byte patterns
+        pattern_80 = (0x80 << 2048) // 0xFF
+        pattern_40 = (0x40 << 2048) // 0xFF
+
+        for lhs_size_in_words in sizes_in_words:
+            for rhs_size_in_words in sizes_in_words:
+                lhs_size_in_bits = GPR_SIZE_IN_BITS * lhs_size_in_words
+                rhs_size_in_bits = GPR_SIZE_IN_BITS * rhs_size_in_words
+                with self.subTest(lhs_size_in_words=lhs_size_in_words,
+                                  rhs_size_in_words=rhs_size_in_words,
+                                  lhs_signed=lhs_signed,
+                                  rhs_signed=rhs_signed):
+                    # values with only the top bit of each operand set
+                    lhs_top_bit = 1 << (lhs_size_in_bits - 1)
+                    rhs_top_bit = 1 << (rhs_size_in_bits - 1)
+                    test_cases = [
+                        (-1, -1),
+                        (pattern_80, pattern_80),
+                        (pattern_40, pattern_80),
+                        (pattern_80, pattern_40),
+                        (pattern_40, pattern_40),
+                        (lhs_top_bit, rhs_top_bit),
+                        (1, rhs_top_bit),
+                        (lhs_top_bit, 1),
+                        (1, 1),
+                    ]  # type: list[tuple[int, int]]
+                    self.tst_toom_mul_sim(
+                        code=Mul(mul=mul,
+                                 lhs_size_in_words=lhs_size_in_words,
+                                 rhs_size_in_words=rhs_size_in_words),
+                        lhs_signed=lhs_signed, rhs_signed=rhs_signed,
+                        get_state_factory=get_pre_ra_state_factory,
+                        test_cases=test_cases)
+
+    def test_toom_2_once_mul_uu_all_sizes_pre_ra_sim(self):
+        # a single TOOM-2 instance, unsigned lhs * unsigned rhs
+        instances = (ToomCookInstance.make_toom_2(),)
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=False, rhs_signed=False)
+
+    def test_toom_2_once_mul_us_all_sizes_pre_ra_sim(self):
+        # a single TOOM-2 instance, unsigned lhs * signed rhs
+        instances = (ToomCookInstance.make_toom_2(),)
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=False, rhs_signed=True)
+
+    def test_toom_2_once_mul_su_all_sizes_pre_ra_sim(self):
+        # a single TOOM-2 instance, signed lhs * unsigned rhs
+        instances = (ToomCookInstance.make_toom_2(),)
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=True, rhs_signed=False)
+
+    def test_toom_2_once_mul_ss_all_sizes_pre_ra_sim(self):
+        # a single TOOM-2 instance, signed lhs * signed rhs
+        instances = (ToomCookInstance.make_toom_2(),)
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            instances, lhs_signed=True, rhs_signed=True)
+
+    def test_toom_2_mul_uu_all_sizes_pre_ra_sim(self):
+        # the same TOOM-2 instance listed four times, unsigned * unsigned
+        toom_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (toom_2,) * 4, lhs_signed=False, rhs_signed=False)
+
+    def test_toom_2_mul_us_all_sizes_pre_ra_sim(self):
+        # the same TOOM-2 instance listed four times, unsigned * signed
+        toom_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (toom_2,) * 4, lhs_signed=False, rhs_signed=True)
+
+    def test_toom_2_mul_su_all_sizes_pre_ra_sim(self):
+        # the same TOOM-2 instance listed four times, signed * unsigned
+        toom_2 = ToomCookInstance.make_toom_2()
+        self.tst_toom_mul_all_sizes_pre_ra_sim(
+            (toom_2,) * 4, lhs_signed=True, rhs_signed=False)
- def test_toom_2_mul_all_sizes_pre_ra_sim(self):
- self.skipTest("broken") # FIXME: fix
+ def test_toom_2_mul_ss_all_sizes_pre_ra_sim(self):
+ # the same TOOM-2 instance listed four times, signed lhs * signed rhs;
+ # takes the place of the previous combined test, which was skipped as
+ # "broken"
TOOM_2 = ToomCookInstance.make_toom_2()
+ instances = TOOM_2, TOOM_2, TOOM_2, TOOM_2
self.tst_toom_mul_all_sizes_pre_ra_sim(
- (TOOM_2, TOOM_2, TOOM_2, TOOM_2))
+ instances, lhs_signed=True, rhs_signed=True)
if __name__ == "__main__":
@final
class EvalOpValueRange:
__slots__ = ("eval_op", "inputs", "min_value", "max_value",
- "is_signed", "output_size")
+ "is_signed", "output_size", "name_part")
def __init__(self, eval_op, inputs):
# type: (EvalOp | int, tuple[EvalOpGenIrInput, ...]) -> None
min_v <<= GPR_SIZE_IN_BITS
max_v <<= GPR_SIZE_IN_BITS
self.output_size = output_size
+ # reuse the shared helper so the "const_<n>" naming rule for integer
+ # constants is defined in exactly one place
+ self.name_part = EvalOp.get_name_part(eval_op)
@cached_property
def poly(self):
@plain_data(frozen=True)
@final
class EvalOpGenIrState:
- __slots__ = "fn", "inputs", "outputs_map"
+ __slots__ = "fn", "inputs", "outputs_map", "name"
- def __init__(self, fn, inputs):
- # type: (Fn, Iterable[EvalOpGenIrInput]) -> None
+ def __init__(self, fn, inputs, name):
+ # type: (Fn, Iterable[EvalOpGenIrInput], str) -> None
super().__init__()
self.fn = fn
self.inputs = tuple(inputs)
+ self.name = name
self.outputs_map = {} # type: dict[EvalOp | int, EvalOpGenIrOutput]
def get_output(self, eval_op):
return retval
value_range = EvalOpValueRange(eval_op=eval_op, inputs=self.inputs)
if isinstance(eval_op, int):
+ # integer constant: load it with LI, then cast the result to the size
+ # required by the value range; both ops share the same name prefix
+ name = f"{self.name}_{EvalOp.get_name_part(eval_op)}"
li = self.fn.append_new_op(OpKind.LI, immediates=[eval_op],
- name=f"li_{eval_op}")
+ name=f"{name}_li")
output = cast_to_size(
fn=self.fn, ssa_val=li.outputs[0],
dest_size=value_range.output_size,
- src_signed=value_range.is_signed, name=f"cast_{eval_op}")
+ src_signed=value_range.is_signed, name=f"{name}_cast")
retval = EvalOpGenIrOutput(output=output, value_range=value_range)
else:
retval = eval_op.make_output(state=self,
# type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
...
+ @cached_property
+ @abstractmethod
+ def name_part(self):
+ # type: () -> str
+ # String fragment identifying this op; make_output implementations
+ # combine it with the eval state's name to name the ops they emit.
+ # NOTE(review): if cached_property here is functools.cached_property, it
+ # does not forward __isabstractmethod__, so a subclass that omits
+ # name_part may still be instantiable -- confirm which cached_property
+ # this module imports.
+ ...
+
+    @staticmethod
+    def get_name_part(eval_op):
+        # type: (EvalOp | int) -> str
+        # An EvalOp supplies its own name_part; a plain int is rendered as
+        # "const_<value>".
+        if not isinstance(eval_op, int):
+            return eval_op.name_part
+        return "const_{}".format(eval_op)
+
+    @property
+    @final
+    def lhs_name_part(self):
+        # type: () -> str
+        # name fragment of the left operand, which may be an EvalOp or an int
+        operand = self.lhs
+        return EvalOp.get_name_part(operand)
+
+    @property
+    @final
+    def rhs_name_part(self):
+        # type: () -> str
+        # name fragment of the right operand, which may be an EvalOp or an int
+        operand = self.rhs
+        return EvalOp.get_name_part(operand)
+
def __init__(self, lhs, rhs):
# type: (EvalOp | int, EvalOp | int) -> None
super().__init__()
# type: () -> EvalOpPoly
return self.lhs_poly + self.rhs_poly
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        # both operand fragments joined by "+" and wrapped in parentheses
+        return "({}+{})".format(self.lhs_name_part, self.rhs_name_part)
+
def make_output(self, state, output_value_range):
# type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ # Emits the IR for this addition: both operands are cast to
+ # output_value_range.output_size, then added with SvAddE using a
+ # cleared CA as carry-in. Every emitted op is named
+ # "<state.name>_<self.name_part>_<role>".
lhs = state.get_output(self.lhs)
lhs_output = cast_to_size(
fn=state.fn, ssa_val=lhs.output,
dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
- name="add_lhs_cast")
+ name=f"{state.name}_{self.name_part}_lhs_cast")
rhs = state.get_output(self.rhs)
rhs_output = cast_to_size(
fn=state.fn, ssa_val=rhs.output,
dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
- name="add_rhs_cast")
+ name=f"{state.name}_{self.name_part}_rhs_cast")
setvl = state.fn.append_new_op(
OpKind.SetVLI, immediates=[output_value_range.output_size],
- name="setvl", maxvl=output_value_range.output_size)
- clear_ca = state.fn.append_new_op(OpKind.ClearCA, name="clear_ca")
+ name=f"{state.name}_{self.name_part}_setvl",
+ maxvl=output_value_range.output_size)
+ clear_ca = state.fn.append_new_op(
+ OpKind.ClearCA, name=f"{state.name}_{self.name_part}_clear_ca")
add = state.fn.append_new_op(
OpKind.SvAddE, input_vals=[
lhs_output, rhs_output, clear_ca.outputs[0], setvl.outputs[0]],
- maxvl=output_value_range.output_size, name="add")
+ maxvl=output_value_range.output_size,
+ name=f"{state.name}_{self.name_part}_add")
return EvalOpGenIrOutput(
output=add.outputs[0], value_range=output_value_range)
# type: () -> EvalOpPoly
return self.lhs_poly - self.rhs_poly
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        # both operand fragments joined by "-" and wrapped in parentheses
+        return "({}-{})".format(self.lhs_name_part, self.rhs_name_part)
+
def make_output(self, state, output_value_range):
# type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ # Emits the IR for this subtraction: both operands are cast to
+ # output_value_range.output_size, then combined with SvSubFE using a
+ # set CA as carry-in. Every emitted op is named
+ # "<state.name>_<self.name_part>_<role>".
lhs = state.get_output(self.lhs)
lhs_output = cast_to_size(
fn=state.fn, ssa_val=lhs.output,
dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
- name="add_lhs_cast")
+ name=f"{state.name}_{self.name_part}_lhs_cast")
rhs = state.get_output(self.rhs)
rhs_output = cast_to_size(
fn=state.fn, ssa_val=rhs.output,
dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
- name="add_rhs_cast")
+ name=f"{state.name}_{self.name_part}_rhs_cast")
setvl = state.fn.append_new_op(
OpKind.SetVLI, immediates=[output_value_range.output_size],
- name="setvl", maxvl=output_value_range.output_size)
- set_ca = state.fn.append_new_op(OpKind.SetCA, name="set_ca")
+ name=f"{state.name}_{self.name_part}_setvl",
+ maxvl=output_value_range.output_size)
+ set_ca = state.fn.append_new_op(
+ OpKind.SetCA, name=f"{state.name}_{self.name_part}_set_ca")
+ # rhs is passed before lhs -- presumably SvSubFE is a "subtract from"
+ # op computing second - first; confirm against the OpKind definition
sub = state.fn.append_new_op(
OpKind.SvSubFE, input_vals=[
rhs_output, lhs_output, set_ca.outputs[0], setvl.outputs[0]],
- maxvl=output_value_range.output_size, name="sub")
+ maxvl=output_value_range.output_size,
+ name=f"{state.name}_{self.name_part}_sub")
return EvalOpGenIrOutput(
output=sub.outputs[0], value_range=output_value_range)
raise TypeError("invalid rhs type")
return self.lhs_poly * self.rhs
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        # lhs fragment and the raw rhs value joined by "*", in parentheses
+        return "({}*{})".format(self.lhs_name_part, self.rhs)
+
def make_output(self, state, output_value_range):
# type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
raise NotImplementedError # FIXME: finish
raise TypeError("invalid rhs type")
return self.lhs_poly / self.rhs
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        # lhs fragment and the raw rhs value joined by "/", in parentheses
+        return "({}/{})".format(self.lhs_name_part, self.rhs)
+
def make_output(self, state, output_value_range):
# type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
raise NotImplementedError # FIXME: finish
def part_index(self):
return self.lhs
+    @cached_property
+    def name_part(self):
+        # type: () -> str
+        # inputs are named after the index of the part they select
+        return "part[{}]".format(self.part_index)
+
def _make_poly(self):
# type: () -> EvalOpPoly
return EvalOpPoly({self.part_index: 1})
output = cast_to_size(
fn=state.fn, ssa_val=inp.ssa_val, src_signed=inp.is_signed,
dest_size=output_value_range.output_size,
- name=f"input_{self.part_index}_cast")
+ name=f"{state.name}_{self.name_part}_cast")
return EvalOpGenIrOutput(output=output, value_range=output_value_range)
for part in range(part_count):
start = part * part_size
stop = min(maxvl, start + part_size)
+ if part == part_count - 1:
+ stop = maxvl
part_maxvl = stop - start
part_setvl = fn.append_new_op(
OpKind.SetVLI, immediates=[part_size], maxvl=part_size,
class ToomCookMul:
__slots__ = (
"fn", "lhs", "lhs_signed", "rhs", "rhs_signed", "instances",
- "retval_size", "start_instance_index", "instance", "part_size",
+ "retval_size", "start_instance_index", "name", "instance", "part_size",
"lhs_parts", "lhs_inputs", "lhs_eval_state", "lhs_outputs",
"rhs_parts", "rhs_inputs", "rhs_eval_state", "rhs_outputs",
"prod_inputs", "prod_eval_state", "prod_parts",
"partial_products", "retval",
)
- def __init__(self, fn, lhs, lhs_signed, rhs, rhs_signed, instances,
- retval_size=None, start_instance_index=0):
- # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> None
+ def __init__(
+ self, fn, # type: Fn
+ lhs, # type: SSAVal
+ lhs_signed, # type: bool
+ rhs, # type: SSAVal
+ rhs_signed, # type: bool
+ instances, # type: _TCIs
+ retval_size=None, # type: None | int
+ name=None, # type: None | str
+ start_instance_index=0, # type: int
+ ):
+ # type: (...) -> None
self.fn = fn
self.lhs = lhs
self.lhs_signed = lhs_signed
if retval_size is None:
retval_size = lhs.ty.reg_len + rhs.ty.reg_len
self.retval_size = retval_size
+ if name is None:
+ name = "mul"
+ self.name = name
if start_instance_index < 0:
raise ValueError("start_instance_index must be non-negative")
self.start_instance_index = start_instance_index
self.part_size = 0 # type: int
while start_instance_index < len(instances):
self.instance = instances[start_instance_index]
- self.part_size = max(
- lhs.ty.reg_len // self.instance.lhs_part_count,
- rhs.ty.reg_len // self.instance.rhs_part_count)
+ self.part_size = 0
+ # FIXME: this loop is some kind of integer division,
+ # figure out the correct formula
+ # NOTE(review): the loop tries bits 5..0 from high to low, so (assuming
+ # both part counts are >= 1, which makes the test monotonic in the part
+ # size) it ends with the largest part_size in 0..63 that passes both
+ # comparisons. For a part count c > 1 and an integer reg_len,
+ # `reg_len > (c - 1) * p` is the same as `p <= (reg_len - 1) // (c - 1)`,
+ # so a candidate closed form is
+ # max(0, min(63, (lhs_len - 1) // (lhs_count - 1),
+ # (rhs_len - 1) // (rhs_count - 1)));
+ # a part count of 1 needs separate handling -- verify before replacing.
+ for shift in reversed(range(6)):
+ next_part_size = self.part_size + (1 << shift)
+ if (lhs.ty.reg_len > (
+ self.instance.lhs_part_count - 1) * next_part_size
+ and rhs.ty.reg_len > (
+ self.instance.rhs_part_count - 1) * next_part_size):
+ self.part_size = next_part_size
if self.part_size <= 0:
self.instance = None
start_instance_index += 1
self.retval = simple_mul(fn=fn,
lhs=lhs, lhs_signed=lhs_signed,
rhs=rhs, rhs_signed=rhs_signed,
- name="toom_cook_base_case")
+ name=f"{name}_base_case")
return
self.lhs_parts = split_into_exact_sized_parts(
fn=fn, ssa_val=lhs, part_count=self.instance.lhs_part_count,
- part_size=self.part_size, name="lhs")
+ part_size=self.part_size, name=f"{name}_lhs")
self.lhs_inputs = [] # type: list[EvalOpGenIrInput]
for part, ssa_val in enumerate(self.lhs_parts):
self.lhs_inputs.append(EvalOpGenIrInput(
ssa_val=ssa_val,
is_signed=lhs_signed and part == len(self.lhs_parts) - 1))
- self.lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.lhs_inputs)
+ self.lhs_eval_state = EvalOpGenIrState(
+ fn=fn, inputs=self.lhs_inputs, name=f"{name}_lhs_eval")
lhs_eval_ops = self.instance.lhs_eval_ops
self.lhs_outputs = [
self.lhs_eval_state.get_output(i) for i in lhs_eval_ops]
self.rhs_parts = split_into_exact_sized_parts(
fn=fn, ssa_val=rhs, part_count=self.instance.rhs_part_count,
- part_size=self.part_size, name="rhs")
+ part_size=self.part_size, name=f"{name}_rhs")
self.rhs_inputs = [] # type: list[EvalOpGenIrInput]
for part, ssa_val in enumerate(self.rhs_parts):
self.rhs_inputs.append(EvalOpGenIrInput(
ssa_val=ssa_val,
is_signed=rhs_signed and part == len(self.rhs_parts) - 1))
- self.rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=self.rhs_inputs)
+ self.rhs_eval_state = EvalOpGenIrState(
+ fn=fn, inputs=self.rhs_inputs, name=f"{name}_rhs_eval")
rhs_eval_ops = self.instance.rhs_eval_ops
self.rhs_outputs = [
self.rhs_eval_state.get_output(i) for i in rhs_eval_ops]
self.prod_inputs = [] # type: list[EvalOpGenIrInput]
- for lhs_output, rhs_output in zip(self.lhs_outputs, self.rhs_outputs):
- ssa_val = toom_cook_mul(
+ for point_index, (lhs_output, rhs_output) in enumerate(
+ zip(self.lhs_outputs, self.rhs_outputs)):
+ ssa_val = ToomCookMul(
fn=fn,
lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
instances=instances,
- start_instance_index=start_instance_index + 1)
+ start_instance_index=start_instance_index + 1,
+ retval_size=None,
+ name=f"{name}_pt{point_index}").retval
products = (lhs_output.min_value * rhs_output.min_value,
lhs_output.min_value * rhs_output.max_value,
lhs_output.max_value * rhs_output.min_value,
is_signed=None,
min_value=min(products),
max_value=max(products)))
- self.prod_eval_state = EvalOpGenIrState(fn=fn, inputs=self.prod_inputs)
+ self.prod_eval_state = EvalOpGenIrState(
+ fn=fn, inputs=self.prod_inputs, name=f"{name}_prod_eval")
prod_eval_ops = self.instance.prod_eval_ops
self.prod_parts = [
self.prod_eval_state.get_output(i) for i in prod_eval_ops]
part_maxvl = prod_part.output.ty.reg_len
part_setvl = fn.append_new_op(
OpKind.SetVLI, immediates=[part_maxvl],
- name=f"prod_{part}_setvl", maxvl=part_maxvl)
+ name=f"{name}_prod_{part}_setvl", maxvl=part_maxvl)
spread_part = fn.append_new_op(
OpKind.Spread,
input_vals=[prod_part.output, part_setvl.outputs[0]],
- name=f"prod_{part}_spread", maxvl=part_maxvl)
+ name=f"{name}_prod_{part}_spread", maxvl=part_maxvl)
yield PartialProduct(
spread_part.outputs, shift_in_words=part * self.part_size,
is_signed=prod_part.is_signed, subtract=False)
self.partial_products = tuple(partial_products())
self.retval = sum_partial_products(
fn=fn, partial_products=self.partial_products,
- retval_size=retval_size, name="prod")
-
-
-def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
- retval_size=None, start_instance_index=0):
- # type: (Fn, SSAVal, bool, SSAVal, bool, _TCIs, None | int, int) -> SSAVal
+ retval_size=retval_size, name=f"{name}_sum_p_prods")
+
+
+def toom_cook_mul(
+ fn, # type: Fn
+ lhs, # type: SSAVal
+ lhs_signed, # type: bool
+ rhs, # type: SSAVal
+ rhs_signed, # type: bool
+ instances, # type: _TCIs
+ retval_size=None, # type: None | int
+ name=None, # type: None | str
+):
+ # type: (...) -> SSAVal
+ # Convenience wrapper: constructs a ToomCookMul from the arguments and
+ # returns only its retval. name is forwarded unchanged; ToomCookMul
+ # substitutes "mul" when it is None.
+ # NOTE(review): start_instance_index is no longer accepted here; any
+ # caller that passed it has to construct ToomCookMul directly -- verify
+ # that no such callers remain.
return ToomCookMul(
fn=fn, lhs=lhs, lhs_signed=lhs_signed, rhs=rhs, rhs_signed=rhs_signed,
- instances=instances, retval_size=retval_size,
- start_instance_index=start_instance_index).retval
+ instances=instances, retval_size=retval_size, name=name).retval