BaseSimState, Fn,
GenAsmState, OpKind,
PostRASimState,
- PreRASimState)
+ PreRASimState, SSAVal)
from bigint_presentation_code.register_allocator import allocate_registers
-from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
+from bigint_presentation_code.toom_cook import (ToomCookInstance, simple_mul,
+ toom_cook_mul)
-class SimpleMul192x192:
- def __init__(self):
+def simple_umul(fn, lhs, rhs):
+ # type: (Fn, SSAVal, SSAVal) -> SSAVal
+ return simple_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs,
+ rhs_signed=False, name="simple_umul")
+
+
+class Mul:
+ def __init__(self, mul, lhs_size_in_words, rhs_size_in_words):
+ # type: (Callable[[Fn, SSAVal, SSAVal], SSAVal], int, int) -> None
super().__init__()
self.fn = fn = Fn()
self.dest_offset = 0
- self.lhs_offset = 48 + self.dest_offset
- self.rhs_offset = 24 + self.lhs_offset
+ self.dest_size_in_words = lhs_size_in_words + rhs_size_in_words
+ self.dest_size_in_bytes = self.dest_size_in_words * GPR_SIZE_IN_BYTES
+ self.lhs_size_in_words = lhs_size_in_words
+ self.lhs_size_in_bytes = self.lhs_size_in_words * GPR_SIZE_IN_BYTES
+ self.rhs_size_in_words = rhs_size_in_words
+ self.rhs_size_in_bytes = self.rhs_size_in_words * GPR_SIZE_IN_BYTES
+ self.lhs_offset = self.dest_size_in_bytes + self.dest_offset
+ self.rhs_offset = self.lhs_size_in_bytes + self.lhs_offset
self.ptr_in = fn.append_new_op(kind=OpKind.FuncArgR3,
name="ptr_in").outputs[0]
- setvl3 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[3],
- maxvl=3, name="setvl3")
+ lhs_setvl = fn.append_new_op(
+ kind=OpKind.SetVLI, immediates=[lhs_size_in_words],
+ maxvl=lhs_size_in_words, name="lhs_setvl")
load_lhs = fn.append_new_op(
kind=OpKind.SvLd, immediates=[self.lhs_offset],
- input_vals=[self.ptr_in, setvl3.outputs[0]],
- name="load_lhs", maxvl=3)
+ input_vals=[self.ptr_in, lhs_setvl.outputs[0]],
+ name="load_lhs", maxvl=lhs_size_in_words)
+ rhs_setvl = fn.append_new_op(
+ kind=OpKind.SetVLI, immediates=[rhs_size_in_words],
+ maxvl=rhs_size_in_words, name="rhs_setvl")
load_rhs = fn.append_new_op(
kind=OpKind.SvLd, immediates=[self.rhs_offset],
- input_vals=[self.ptr_in, setvl3.outputs[0]],
+ input_vals=[self.ptr_in, rhs_setvl.outputs[0]],
name="load_rhs", maxvl=3)
- retval = simple_mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])
- setvl6 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[6],
- maxvl=6, name="setvl6")
+ retval = mul(fn, load_lhs.outputs[0], load_rhs.outputs[0])
+ dest_setvl = fn.append_new_op(
+ kind=OpKind.SetVLI, immediates=[self.dest_size_in_words],
+ maxvl=self.dest_size_in_words, name="dest_setvl")
fn.append_new_op(
kind=OpKind.SvStd,
- input_vals=[retval, self.ptr_in, setvl6.outputs[0]],
- immediates=[self.dest_offset], maxvl=6, name="store_dest")
+ input_vals=[retval, self.ptr_in, dest_setvl.outputs[0]],
+ immediates=[self.dest_offset], maxvl=self.dest_size_in_words,
+ name="store_dest")
class TestToomCook(unittest.TestCase):
)
def test_simple_mul_192x192_pre_ra_sim(self):
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+
def create_sim_state(code):
- # type: (SimpleMul192x192) -> BaseSimState
+ # type: (Mul) -> BaseSimState
return PreRASimState(ssa_vals={}, memory={})
self.tst_simple_mul_192x192_sim(create_sim_state)
def test_simple_mul_192x192_post_ra_sim(self):
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+
def create_sim_state(code):
- # type: (SimpleMul192x192) -> BaseSimState
+ # type: (Mul) -> BaseSimState
ssa_val_to_loc_map = allocate_registers(code.fn)
return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map,
memory={}, loc_values={})
self.tst_simple_mul_192x192_sim(create_sim_state)
def tst_simple_mul_192x192_sim(self, create_sim_state):
- # type: (Callable[[SimpleMul192x192], BaseSimState]) -> None
+ # type: (Callable[[Mul], BaseSimState]) -> None
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
# test multiplying:
# 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
# * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
# "_3931783239312079_7261727469627261", base=0)
# == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
# 'little')
- code = SimpleMul192x192()
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
state = create_sim_state(code)
ptr_in = 0x100
dest_ptr = ptr_in + code.dest_offset
self.assertEqual(out_bytes, expected_bytes)
def test_simple_mul_192x192_ops(self):
- code = SimpleMul192x192()
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
self.assertEqual([repr(v) for v in fn.ops], [
"Op(kind=OpKind.FuncArgR3, "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[3], "
- "outputs=(<setvl3.outputs[0]: <VL_MAXVL>>,), "
- "name='setvl3')",
+ "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='lhs_setvl')",
"Op(kind=OpKind.SvLd, "
"input_vals=[<ptr_in.outputs[0]: <I64>>, "
- "<setvl3.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<load_lhs.input_uses[0]: <I64>>, "
"<load_lhs.input_uses[1]: <VL_MAXVL>>), immediates=[48], "
"outputs=(<load_lhs.outputs[0]: <I64*3>>,), "
"name='load_lhs')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), immediates=[3], "
+ "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='rhs_setvl')",
"Op(kind=OpKind.SvLd, "
"input_vals=[<ptr_in.outputs[0]: <I64>>, "
- "<setvl3.outputs[0]: <VL_MAXVL>>], "
+ "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<load_rhs.input_uses[0]: <I64>>, "
"<load_rhs.input_uses[1]: <VL_MAXVL>>), immediates=[72], "
"outputs=(<load_rhs.outputs[0]: <I64*3>>,), "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[3], "
- "outputs=(<rhs_setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='rhs_setvl')",
+ "outputs=(<rhs_setvl2.outputs[0]: <VL_MAXVL>>,), "
+ "name='rhs_setvl2')",
"Op(kind=OpKind.Spread, "
"input_vals=[<load_rhs.outputs[0]: <I64*3>>, "
- "<rhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<rhs_setvl2.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<rhs_spread.input_uses[0]: <I64*3>>, "
"<rhs_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<rhs_spread.outputs[0]: <I64>>, "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[3], "
- "outputs=(<lhs_setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='lhs_setvl')",
+ "outputs=(<lhs_setvl3.outputs[0]: <VL_MAXVL>>,), "
+ "name='lhs_setvl3')",
"Op(kind=OpKind.LI, "
"input_vals=[], "
"input_uses=(), immediates=[0], "
"input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
"<rhs_spread.outputs[0]: <I64>>, "
"<zero.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul0.input_uses[0]: <I64*3>>, "
"<mul0.input_uses[1]: <I64>>, "
"<mul0.input_uses[2]: <I64>>, "
"name='mul0')",
"Op(kind=OpKind.Spread, "
"input_vals=[<mul0.outputs[0]: <I64*3>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul0_rt_spread.input_uses[0]: <I64*3>>, "
"<mul0_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<mul0_rt_spread.outputs[0]: <I64>>, "
"input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
"<rhs_spread.outputs[1]: <I64>>, "
"<zero.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul1.input_uses[0]: <I64*3>>, "
"<mul1.input_uses[1]: <I64>>, "
"<mul1.input_uses[2]: <I64>>, "
"input_vals=[<mul0_rt_spread.outputs[1]: <I64>>, "
"<mul0_rt_spread.outputs[2]: <I64>>, "
"<mul0.outputs[1]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add1_rb_concat.input_uses[0]: <I64>>, "
"<add1_rb_concat.input_uses[1]: <I64>>, "
"<add1_rb_concat.input_uses[2]: <I64>>, "
"input_vals=[<mul1.outputs[0]: <I64*3>>, "
"<add1_rb_concat.outputs[0]: <I64*3>>, "
"<clear_ca1.outputs[0]: <CA>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add1.input_uses[0]: <I64*3>>, "
"<add1.input_uses[1]: <I64*3>>, "
"<add1.input_uses[2]: <CA>>, "
"name='add1')",
"Op(kind=OpKind.Spread, "
"input_vals=[<add1.outputs[0]: <I64*3>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add1_rt_spread.input_uses[0]: <I64*3>>, "
"<add1_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<add1_rt_spread.outputs[0]: <I64>>, "
"input_vals=[<load_lhs.outputs[0]: <I64*3>>, "
"<rhs_spread.outputs[2]: <I64>>, "
"<zero.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<mul2.input_uses[0]: <I64*3>>, "
"<mul2.input_uses[1]: <I64>>, "
"<mul2.input_uses[2]: <I64>>, "
"input_vals=[<add1_rt_spread.outputs[1]: <I64>>, "
"<add1_rt_spread.outputs[2]: <I64>>, "
"<add_hi1.outputs[0]: <I64>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add2_rb_concat.input_uses[0]: <I64>>, "
"<add2_rb_concat.input_uses[1]: <I64>>, "
"<add2_rb_concat.input_uses[2]: <I64>>, "
"input_vals=[<mul2.outputs[0]: <I64*3>>, "
"<add2_rb_concat.outputs[0]: <I64*3>>, "
"<clear_ca2.outputs[0]: <CA>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add2.input_uses[0]: <I64*3>>, "
"<add2.input_uses[1]: <I64*3>>, "
"<add2.input_uses[2]: <CA>>, "
"name='add2')",
"Op(kind=OpKind.Spread, "
"input_vals=[<add2.outputs[0]: <I64*3>>, "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>], "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add2_rt_spread.input_uses[0]: <I64*3>>, "
"<add2_rt_spread.input_uses[1]: <VL_MAXVL>>), immediates=[], "
"outputs=(<add2_rt_spread.outputs[0]: <I64>>, "
"Op(kind=OpKind.SetVLI, "
"input_vals=[], "
"input_uses=(), immediates=[6], "
- "outputs=(<setvl6.outputs[0]: <VL_MAXVL>>,), "
- "name='setvl6')",
+ "outputs=(<dest_setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='dest_setvl')",
"Op(kind=OpKind.SvStd, "
"input_vals=[<concat_retval.outputs[0]: <I64*6>>, "
"<ptr_in.outputs[0]: <I64>>, "
- "<setvl6.outputs[0]: <VL_MAXVL>>], "
+ "<dest_setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<store_dest.input_uses[0]: <I64*6>>, "
"<store_dest.input_uses[1]: <I64>>, "
"<store_dest.input_uses[2]: <VL_MAXVL>>), immediates=[0], "
])
def test_simple_mul_192x192_reg_alloc(self):
- code = SimpleMul192x192()
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
assigned_registers = allocate_registers(fn)
self.assertEqual(
"Loc(kind=LocKind.GPR, start=4, reg_len=6), "
"<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<setvl6.outputs[0]: <VL_MAXVL>>: "
+ "<dest_setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<concat_retval.out0.copy.outputs[0]: <I64*6>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=6), "
"Loc(kind=LocKind.GPR, start=18, reg_len=1), "
"<zero.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "<lhs_setvl3.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<rhs_spread.out2.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=19, reg_len=1), "
"Loc(kind=LocKind.GPR, start=3, reg_len=3), "
"<rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "<rhs_setvl2.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=3), "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_rhs.inp0.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
"Loc(kind=LocKind.GPR, start=20, reg_len=3), "
"<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<load_lhs.inp0.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=6, reg_len=1), "
- "<setvl3.outputs[0]: <VL_MAXVL>>: "
+ "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<ptr_in.out0.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=23, reg_len=1), "
"}")
def test_simple_mul_192x192_asm(self):
- code = SimpleMul192x192()
+ self.skipTest("WIP") # FIXME: finish fixing simple_mul
+ code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
assigned_registers = allocate_registers(fn)
gen_asm_state = GenAsmState(assigned_registers)
'sv.ld *3, 48(6)',
'setvl 0, 0, 3, 0, 1, 1',
'sv.or *20, *3, *3',
+ 'setvl 0, 0, 3, 0, 1, 1',
'or 6, 23, 23',
'setvl 0, 0, 3, 0, 1, 1',
'sv.ld *3, 72(6)',
'sv.std *4, 0(3)'
])
+ def test_toom_2_mul_256x256_asm(self):
+ self.skipTest("WIP") # FIXME: finish
+ TOOM_2 = ToomCookInstance.make_toom_2()
+ instances = TOOM_2, TOOM_2
+
+ def mul(fn, lhs, rhs):
+ # type: (Fn, SSAVal, SSAVal) -> SSAVal
+ return toom_cook_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs,
+ rhs_signed=False, instances=instances)
+ code = Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3)
+ fn = code.fn
+ assigned_registers = allocate_registers(fn)
+ gen_asm_state = GenAsmState(assigned_registers)
+ fn.gen_asm(gen_asm_state)
+ self.assertEqual(gen_asm_state.output, [
+ ])
+
if __name__ == "__main__":
unittest.main()
"""
Toom-Cook multiplication algorithm generator for SVP64
"""
-from abc import ABCMeta, abstractmethod
+import math
+from abc import abstractmethod
from enum import Enum
from fractions import Fraction
-from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
+from typing import Iterable, Mapping, Union
+from cached_property import cached_property
from nmutil.plain_data import plain_data
-from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, Fn, OpKind,
- SSAVal)
+from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, BaseTy, Fn,
+ OpKind, SSAVal, Ty)
from bigint_presentation_code.matrix import Matrix
from bigint_presentation_code.type_util import Literal, final
+from bigint_presentation_code.util import InternedMeta
@final
POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
-WORD_BITS = GPR_SIZE_IN_BITS
_EvalOpPolyCoefficients = Union["Mapping[int | None, Fraction | int]",
"EvalOpPoly", Fraction, int, None]
@plain_data(frozen=True, unsafe_hash=True)
+@final
class EvalOpValueRange:
- __slots__ = ("eval_op", "inputs_words", "min_value", "max_value",
- "is_signed")
+ __slots__ = ("eval_op", "inputs", "min_value", "max_value",
+ "is_signed", "output_size")
- def __init__(self, eval_op, inputs_words):
- # type: (EvalOp[Any, Any], Iterable[int]) -> None
+ def __init__(self, eval_op, inputs):
+ # type: (EvalOp | int, tuple[EvalOpGenIrInput, ...]) -> None
super().__init__()
self.eval_op = eval_op
- self.inputs_words = tuple(inputs_words)
- for words in self.inputs_words:
- if words <= 0:
- raise ValueError(f"invalid word count: {words}")
- min_value = max_value = eval_op.poly.const_coeff
- for var, coeff in enumerate(eval_op.poly.var_coeffs):
+ self.inputs = inputs
+ min_value = max_value = self.poly.const_coeff
+ for var, coeff in enumerate(self.poly.var_coeffs):
if coeff == 0:
continue
- var_min = 0
- var_max = (1 << self.inputs_words[var] * WORD_BITS) - 1
- term_min = var_min * coeff
- term_max = var_max * coeff
+ term_min = self.inputs[var].min_value * coeff
+ term_max = self.inputs[var].max_value * coeff
if term_min > term_max:
term_min, term_max = term_max, term_min
min_value += term_min
max_value += term_max
+ # output values are always integers, so eliminate any fractional part
+ # as impossible.
+ self.min_value = math.ceil(min_value) # exclude fractional part
+ self.max_value = math.floor(max_value) # exclude fractional part
+ self.is_signed = min_value < 0
+ output_size = 1
+ if self.is_signed:
+ min_v = -1 << (GPR_SIZE_IN_BITS - 1)
+ max_v = (1 << (GPR_SIZE_IN_BITS - 1)) - 1
+ else:
+ min_v = 0
+ max_v = (1 << GPR_SIZE_IN_BITS) - 1
+ while not (min_v <= self.min_value and self.max_value <= max_v):
+ output_size += 1
+ min_v <<= GPR_SIZE_IN_BITS
+ max_v <<= GPR_SIZE_IN_BITS
+ self.output_size = output_size
+
+ @cached_property
+ def poly(self):
+ if isinstance(self.eval_op, int):
+ return EvalOpPoly(const_coeff=self.eval_op)
+ return self.eval_op.poly
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpGenIrOutput:
+ __slots__ = "output", "value_range"
+
+ def __init__(self, output, value_range):
+ # type: (SSAVal, EvalOpValueRange) -> None
+ super().__init__()
+ if output.ty.reg_len != value_range.output_size:
+ raise ValueError("wrong output size")
+ self.output = output
+ self.value_range = value_range
+
+ @property
+ def eval_op(self):
+ # type: () -> EvalOp | int
+ return self.value_range.eval_op
+
+ @property
+ def inputs(self):
+ # type: () -> tuple[EvalOpGenIrInput, ...]
+ return self.value_range.inputs
+
+ @property
+ def min_value(self):
+ # type: () -> int
+ return self.value_range.min_value
+
+ @property
+ def max_value(self):
+ # type: () -> int
+ return self.value_range.max_value
+
+ @property
+ def is_signed(self):
+ # type: () -> bool
+ return self.value_range.is_signed
+
+ @property
+ def output_size(self):
+ # type: () -> int
+ return self.value_range.output_size
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpGenIrInput:
+ __slots__ = "ssa_val", "is_signed", "min_value", "max_value"
+
+ def __init__(self, ssa_val, is_signed, min_value=None, max_value=None):
+ # type: (SSAVal, bool | None, int | None, int | None) -> None
+ super().__init__()
+ self.ssa_val = ssa_val
+ if ssa_val.base_ty != BaseTy.I64:
+ raise ValueError("input must have a base_ty of BaseTy.I64")
+ if is_signed is None:
+ if min_value is None or max_value is None:
+ raise ValueError("must specify either is_signed or both "
+ "min_value and max_value")
+ is_signed = min_value < 0
+ self.is_signed = is_signed
+ if is_signed:
+ if min_value is None:
+ min_value = -1 << (ssa_val.ty.reg_len * GPR_SIZE_IN_BITS - 1)
+ if max_value is None:
+ max_value = (1 << (
+ ssa_val.ty.reg_len * GPR_SIZE_IN_BITS - 1)) - 1
+ else:
+ if min_value is None:
+ min_value = 0
+ if max_value is None:
+ max_value = (1 << (ssa_val.ty.reg_len * GPR_SIZE_IN_BITS)) - 1
self.min_value = min_value
self.max_value = max_value
- self.is_signed = min_value < 0
+ if self.min_value > self.max_value:
+ raise ValueError("invalid value range")
-_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]")
-_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]")
+@plain_data(frozen=True)
+@final
+class EvalOpGenIrState:
+ __slots__ = "fn", "inputs", "outputs_map"
+
+ def __init__(self, fn, inputs):
+ # type: (Fn, Iterable[EvalOpGenIrInput]) -> None
+ super().__init__()
+ self.fn = fn
+ self.inputs = tuple(inputs)
+ self.outputs_map = {} # type: dict[EvalOp | int, EvalOpGenIrOutput]
+
+ def get_output(self, eval_op):
+ # type: (EvalOp | int) -> EvalOpGenIrOutput
+ retval = self.outputs_map.get(eval_op, None)
+ if retval is not None:
+ return retval
+ value_range = EvalOpValueRange(eval_op=eval_op, inputs=self.inputs)
+ if isinstance(eval_op, int):
+ li = self.fn.append_new_op(OpKind.LI, immediates=[eval_op],
+ name=f"li_{eval_op}")
+ output = cast_to_size(
+ fn=self.fn, ssa_val=li.outputs[0],
+ dest_size=value_range.output_size,
+ src_signed=value_range.is_signed, name=f"cast_{eval_op}")
+ retval = EvalOpGenIrOutput(output=output, value_range=value_range)
+ else:
+ retval = eval_op.make_output(state=self,
+ output_value_range=value_range)
+ if retval.value_range != value_range:
+ raise ValueError("wrong value_range")
+ return self.outputs_map.setdefault(eval_op, retval)
@plain_data(frozen=True, unsafe_hash=True)
-class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS], metaclass=ABCMeta):
+class EvalOp(metaclass=InternedMeta):
__slots__ = "lhs", "rhs", "poly"
@property
# type: () -> EvalOpPoly
...
+ @abstractmethod
+ def make_output(self, state, output_value_range):
+ # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ ...
+
def __init__(self, lhs, rhs):
- # type: (_EvalOpLHS, _EvalOpRHS) -> None
+ # type: (EvalOp | int, EvalOp | int) -> None
super().__init__()
self.lhs = lhs
self.rhs = rhs
@plain_data(frozen=True, unsafe_hash=True)
@final
-class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+class EvalOpAdd(EvalOp):
__slots__ = ()
def _make_poly(self):
# type: () -> EvalOpPoly
return self.lhs_poly + self.rhs_poly
+ def make_output(self, state, output_value_range):
+ # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ lhs = state.get_output(self.lhs)
+ lhs_output = cast_to_size(
+ fn=state.fn, ssa_val=lhs.output,
+ dest_size=output_value_range.output_size, src_signed=lhs.is_signed,
+ name="add_lhs_cast")
+ rhs = state.get_output(self.rhs)
+ rhs_output = cast_to_size(
+ fn=state.fn, ssa_val=rhs.output,
+ dest_size=output_value_range.output_size, src_signed=rhs.is_signed,
+ name="add_rhs_cast")
+
+ raise NotImplementedError # FIXME: finish
+
@plain_data(frozen=True, unsafe_hash=True)
@final
-class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+class EvalOpSub(EvalOp):
__slots__ = ()
def _make_poly(self):
# type: () -> EvalOpPoly
return self.lhs_poly - self.rhs_poly
+ def make_output(self, state, output_value_range):
+ # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ raise NotImplementedError # FIXME: finish
+
@plain_data(frozen=True, unsafe_hash=True)
@final
-class EvalOpMul(EvalOp[_EvalOpLHS, int]):
+class EvalOpMul(EvalOp):
__slots__ = ()
+ rhs: int
def _make_poly(self):
# type: () -> EvalOpPoly
+ if not isinstance(self.rhs, int): # type: ignore
+ raise TypeError("invalid rhs type")
return self.lhs_poly * self.rhs
+ def make_output(self, state, output_value_range):
+ # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ raise NotImplementedError # FIXME: finish
+
@plain_data(frozen=True, unsafe_hash=True)
@final
-class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]):
+class EvalOpExactDiv(EvalOp):
__slots__ = ()
+ rhs: int
def _make_poly(self):
# type: () -> EvalOpPoly
+ if not isinstance(self.rhs, int): # type: ignore
+ raise TypeError("invalid rhs type")
return self.lhs_poly / self.rhs
+ def make_output(self, state, output_value_range):
+ # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ raise NotImplementedError # FIXME: finish
+
@plain_data(frozen=True, unsafe_hash=True)
@final
-class EvalOpInput(EvalOp[int, Literal[0]]):
+class EvalOpInput(EvalOp):
__slots__ = ()
+ lhs: int
+ rhs: Literal[0]
def __init__(self, lhs, rhs=0):
# type: (int, int) -> None
# type: () -> EvalOpPoly
return EvalOpPoly({self.part_index: 1})
+ def make_output(self, state, output_value_range):
+ # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput
+ inp = state.inputs[self.part_index]
+ output = cast_to_size(
+ fn=state.fn, ssa_val=inp.ssa_val, src_signed=inp.is_signed,
+ dest_size=output_value_range.output_size,
+ name="input_{self.part_index}_cast")
+ return EvalOpGenIrOutput(output=output, value_range=output_value_range)
+
@plain_data(frozen=True, unsafe_hash=True)
@final
self, lhs_part_count, # type: int
rhs_part_count, # type: int
eval_points, # type: Iterable[PointAtInfinity | int]
- lhs_eval_ops, # type: Iterable[EvalOp[Any, Any]]
- rhs_eval_ops, # type: Iterable[EvalOp[Any, Any]]
- prod_eval_ops, # type: Iterable[EvalOp[Any, Any]]
+ lhs_eval_ops, # type: Iterable[EvalOp]
+ rhs_eval_ops, # type: Iterable[EvalOp]
+ prod_eval_ops, # type: Iterable[EvalOp]
):
# type: (...) -> None
self.lhs_part_count = lhs_part_count
# TODO: add make_toom_3
-def simple_mul(fn, lhs, rhs):
- # type: (Fn, SSAVal, SSAVal) -> SSAVal
- """ simple O(n^2) big-int unsigned multiply """
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class PartialProduct:
+ __slots__ = "ssa_val_spread", "shift_in_words", "is_signed"
+
+ def __init__(self, ssa_val_spread, shift_in_words, is_signed):
+ # type: (Iterable[SSAVal], int, bool) -> None
+ if shift_in_words < 0:
+ raise ValueError("invalid shift_in_words")
+ self.ssa_val_spread = tuple(ssa_val_spread)
+ for ssa_val in ssa_val_spread:
+ if ssa_val.ty != Ty(base_ty=BaseTy.I64, reg_len=1):
+ raise ValueError("invalid ssa_val.ty")
+ self.shift_in_words = shift_in_words
+ self.is_signed = is_signed
+
+
+def sum_partial_products(fn, partial_products, name):
+ # type: (Fn, Iterable[PartialProduct], str) -> SSAVal
+ retval_spread = [] # type: list[SSAVal]
+ retval_signed = False
+ zero = fn.append_new_op(OpKind.LI, immediates=[0],
+ name=f"{name}_zero").outputs[0]
+ has_carry_word = False
+ for idx, partial_product in enumerate(partial_products):
+ shift_in_words = partial_product.shift_in_words
+ spread = list(partial_product.ssa_val_spread)
+ if not retval_signed and shift_in_words >= len(retval_spread):
+ retval_spread.extend(
+ [zero] * (shift_in_words - len(retval_spread)))
+ retval_spread.extend(spread)
+ retval_signed = partial_product.is_signed
+ has_carry_word = False
+ continue
+ assert len(retval_spread) != 0, "logic error"
+ maxvl = max(len(retval_spread) - shift_in_words, len(spread))
+ if not has_carry_word:
+ maxvl += 1
+ has_carry_word = True
+ retval_spread = cast_to_size_spread(
+ fn=fn, ssa_vals=retval_spread, src_signed=retval_signed,
+ dest_size=maxvl + shift_in_words, name=f"{name}_{idx}_cast_retval")
+ spread = cast_to_size_spread(
+ fn=fn, ssa_vals=spread, src_signed=partial_product.is_signed,
+ dest_size=maxvl, name=f"{name}_{idx}_cast_pp")
+ setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[maxvl],
+ maxvl=maxvl, name=f"{name}_{idx}_setvl")
+ retval_concat = fn.append_new_op(
+ kind=OpKind.Concat,
+ input_vals=[*retval_spread[shift_in_words:], setvl.outputs[0]],
+ name=f"{name}_{idx}_retval_concat", maxvl=maxvl)
+ pp_concat = fn.append_new_op(
+ kind=OpKind.Concat,
+ input_vals=[*retval_spread[shift_in_words:], setvl.outputs[0]],
+ name=f"{name}_{idx}_pp_concat", maxvl=maxvl)
+ clear_ca = fn.append_new_op(kind=OpKind.ClearCA,
+ name=f"{name}_{idx}_clear_ca")
+ add = fn.append_new_op(
+ kind=OpKind.SvAddE, input_vals=[
+ retval_concat.outputs[0], pp_concat.outputs[0],
+ clear_ca.outputs[0], setvl.outputs[0]],
+ maxvl=maxvl, name=f"{name}_{idx}_add")
+ retval_spread[shift_in_words:] = fn.append_new_op(
+ kind=OpKind.Spread,
+ input_vals=[add.outputs[0], setvl.outputs[0]],
+ name=f"{name}_{idx}_sum_spread", maxvl=maxvl).outputs
+ retval_setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[len(retval_spread)],
+ maxvl=len(retval_spread), name=f"{name}_setvl")
+ retval_concat = fn.append_new_op(
+ kind=OpKind.Concat,
+ input_vals=[*retval_spread, retval_setvl.outputs[0]],
+ name=f"{name}_concat", maxvl=len(retval_spread))
+ return retval_concat.outputs[0]
+
+
+def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name):
+ # type: (Fn, SSAVal, bool, SSAVal, bool, str) -> SSAVal
+ """ simple O(n^2) big-int multiply """
if lhs.ty.reg_len < rhs.ty.reg_len:
lhs, rhs = rhs, lhs
+ lhs_signed, rhs_signed = rhs_signed, lhs_signed
# split rhs into elements
rhs_setvl = fn.append_new_op(kind=OpKind.SetVLI,
immediates=[rhs.ty.reg_len], name="rhs_setvl")
kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]],
maxvl=rhs.ty.reg_len, name="rhs_spread")
rhs_words = rhs_spread.outputs
- spread_retval = None # type: tuple[SSAVal, ...] | None
+ zero = fn.append_new_op(
+ kind=OpKind.LI, immediates=[0], name=f"{name}_zero").outputs[0]
maxvl = lhs.ty.reg_len
- lhs_setvl = fn.append_new_op(kind=OpKind.SetVLI,
- immediates=[lhs.ty.reg_len], name="lhs_setvl")
+ lhs_setvl = fn.append_new_op(
+ kind=OpKind.SetVLI, immediates=[maxvl], name="lhs_setvl", maxvl=maxvl)
vl = lhs_setvl.outputs[0]
- zero_op = fn.append_new_op(kind=OpKind.LI, immediates=[0], name="zero")
- zero = zero_op.outputs[0]
- for shift, rhs_word in enumerate(rhs_words):
- mul = fn.append_new_op(kind=OpKind.SvMAddEDU,
- input_vals=[lhs, rhs_word, zero, vl],
- maxvl=maxvl, name=f"mul{shift}")
- if spread_retval is None:
+ if lhs_signed or rhs_signed:
+ raise NotImplementedError # FIXME: implement signed multiply
+
+ def partial_products():
+ # type: () -> Iterable[PartialProduct]
+ for shift_in_words, rhs_word in enumerate(rhs_words):
+ mul = fn.append_new_op(
+ kind=OpKind.SvMAddEDU, input_vals=[lhs, rhs_word, zero, vl],
+ maxvl=maxvl, name=f"{name}_{shift_in_words}_mul")
mul_rt_spread = fn.append_new_op(
kind=OpKind.Spread, input_vals=[mul.outputs[0], vl],
- name=f"mul{shift}_rt_spread", maxvl=maxvl)
- spread_retval = (*mul_rt_spread.outputs, mul.outputs[1])
+ name=f"{name}_{shift_in_words}_mul_rt_spread", maxvl=maxvl)
+ yield PartialProduct(
+ ssa_val_spread=[*mul_rt_spread.outputs, mul.outputs[1]],
+ shift_in_words=shift_in_words,
+ is_signed=False)
+ return sum_partial_products(fn=fn, partial_products=partial_products(),
+ name=name)
+
+
+def cast_to_size(fn, ssa_val, src_signed, dest_size, name):
+ # type: (Fn, SSAVal, bool, int, str) -> SSAVal
+ if dest_size <= 0:
+ raise ValueError("invalid dest_size -- must be a positive integer")
+ if ssa_val.ty.reg_len == dest_size:
+ return ssa_val
+ in_setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[ssa_val.ty.reg_len],
+ maxvl=ssa_val.ty.reg_len, name=f"{name}_in_setvl")
+ spread = fn.append_new_op(
+ OpKind.Spread, input_vals=[ssa_val, in_setvl.outputs[0]],
+ name=f"{name}_spread", maxvl=ssa_val.ty.reg_len)
+ spread_values = cast_to_size_spread(
+ fn=fn, ssa_vals=spread.outputs, src_signed=src_signed,
+ dest_size=dest_size, name=name)
+ out_setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[dest_size], maxvl=dest_size,
+ name=f"{name}_out_setvl")
+ concat = fn.append_new_op(
+ OpKind.Concat, input_vals=[*spread_values, out_setvl.outputs[0]],
+ name=f"{name}_concat", maxvl=dest_size)
+ return concat.outputs[0]
+
+
+def cast_to_size_spread(fn, ssa_vals, src_signed, dest_size, name):
+ # type: (Fn, Iterable[SSAVal], bool, int, str) -> list[SSAVal]
+ if dest_size <= 0:
+ raise ValueError("invalid dest_size -- must be a positive integer")
+ spread_values = list(ssa_vals)
+ for ssa_val in ssa_vals:
+ if ssa_val.ty != Ty(base_ty=BaseTy.I64, reg_len=1):
+ raise ValueError("invalid ssa_val.ty")
+ if len(spread_values) == dest_size:
+ return spread_values
+ if len(spread_values) > dest_size:
+ spread_values[dest_size:] = []
+ elif src_signed:
+ sign = fn.append_new_op(
+ OpKind.SRADI, input_vals=[spread_values[-1]],
+ immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_sign")
+ spread_values += [sign.outputs[0]] * (dest_size - len(spread_values))
+ else:
+ zero = fn.append_new_op(
+ OpKind.LI, immediates=[0], name=f"{name}_zero")
+ spread_values += [zero.outputs[0]] * (dest_size - len(spread_values))
+ return spread_values
+
+
+def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name):
+ # type: (Fn, SSAVal, int, int, str) -> list[SSAVal]
+ """split ssa_val into part_count parts, where all but the last part have
+ `part.ty.reg_len == part_size`.
+ """
+ if part_size <= 0:
+ raise ValueError("invalid part size, must be positive")
+ if part_count <= 0:
+ raise ValueError("invalid part count, must be positive")
+ if part_count == 1:
+ return [ssa_val]
+ too_short_reg_len = (part_count - 1) * part_size
+ if ssa_val.ty.reg_len <= too_short_reg_len:
+ raise ValueError(f"ssa_val is too short to split, must have "
+ f"reg_len > {too_short_reg_len}: {ssa_val}")
+ maxvl = ssa_val.ty.reg_len
+ setvl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
+ maxvl=maxvl, name=f"{name}_setvl")
+ spread = fn.append_new_op(
+ OpKind.Spread, input_vals=[ssa_val, setvl.outputs[0]],
+ name=f"{name}_spread", maxvl=maxvl)
+ retval = [] # type: list[SSAVal]
+ for part in range(part_count):
+ start = part * part_size
+ stop = min(maxvl, start + part_size)
+ part_maxvl = stop - start
+ part_setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[part_size], maxvl=part_size,
+ name=f"{name}_{part}_setvl")
+ concat = fn.append_new_op(
+ OpKind.Concat,
+ input_vals=[*spread.outputs[start:stop], part_setvl.outputs[0]],
+ name=f"{name}_{part}_concat", maxvl=part_maxvl)
+ retval.append(concat.outputs[0])
+ return retval
+
+
+def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances,
+ start_instance_index=0):
+ # type: (Fn, SSAVal, bool, SSAVal, bool, tuple[ToomCookInstance, ...], int) -> SSAVal
+ if start_instance_index < 0:
+ raise ValueError("start_instance_index must be non-negative")
+ instance = None
+ part_size = 0
+ while start_instance_index < len(instances):
+ instance = instances[start_instance_index]
+ part_size = max(lhs.ty.reg_len // instance.lhs_part_count,
+ rhs.ty.reg_len // instance.rhs_part_count)
+ if part_size <= 0:
+ instance = None
+ start_instance_index += 1
else:
- first_part = spread_retval[:shift] # type: tuple[SSAVal, ...]
- last_part = spread_retval[shift:]
-
- add_rb_concat = fn.append_new_op(
- kind=OpKind.Concat, input_vals=[*last_part, vl],
- name=f"add{shift}_rb_concat", maxvl=maxvl)
+ break
+ if instance is None:
+ return simple_mul(fn=fn,
+ lhs=lhs, lhs_signed=lhs_signed,
+ rhs=rhs, rhs_signed=rhs_signed,
+ name="toom_cook_base_case")
+ lhs_parts = split_into_exact_sized_parts(
+ fn=fn, ssa_val=lhs, part_count=instance.lhs_part_count,
+ part_size=part_size, name="lhs")
+ lhs_inputs = [] # type: list[EvalOpGenIrInput]
+ for part, ssa_val in enumerate(lhs_parts):
+ lhs_inputs.append(EvalOpGenIrInput(
+ ssa_val=ssa_val,
+ is_signed=lhs_signed and part == len(lhs_parts) - 1))
+ lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=lhs_inputs)
+ lhs_outputs = [lhs_eval_state.get_output(i) for i in instance.lhs_eval_ops]
+ rhs_parts = split_into_exact_sized_parts(
+ fn=fn, ssa_val=rhs, part_count=instance.rhs_part_count,
+ part_size=part_size, name="rhs")
+ rhs_inputs = [] # type: list[EvalOpGenIrInput]
+ for part, ssa_val in enumerate(rhs_parts):
+ rhs_inputs.append(EvalOpGenIrInput(
+ ssa_val=ssa_val,
+ is_signed=rhs_signed and part == len(rhs_parts) - 1))
+ rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=rhs_inputs)
+ rhs_outputs = [rhs_eval_state.get_output(i) for i in instance.rhs_eval_ops]
+ prod_inputs = [] # type: list[EvalOpGenIrInput]
+ for lhs_output, rhs_output in zip(lhs_outputs, rhs_outputs):
+ ssa_val = toom_cook_mul(
+ fn=fn,
+ lhs=lhs_output.output, lhs_signed=lhs_output.is_signed,
+ rhs=rhs_output.output, rhs_signed=rhs_output.is_signed,
+ instances=instances, start_instance_index=start_instance_index + 1)
+ products = (lhs_output.min_value * rhs_output.min_value,
+ lhs_output.min_value * rhs_output.max_value,
+ lhs_output.max_value * rhs_output.min_value,
+ lhs_output.max_value * rhs_output.max_value)
+ prod_inputs.append(EvalOpGenIrInput(
+ ssa_val=ssa_val,
+ is_signed=None,
+ min_value=min(products),
+ max_value=max(products)))
+ prod_eval_state = EvalOpGenIrState(fn=fn, inputs=prod_inputs)
+ prod_parts = [
+ prod_eval_state.get_output(i) for i in instance.prod_eval_ops]
+ retval_size = lhs.ty.reg_len + rhs.ty.reg_len
+ spread_retval = [] # type: list[SSAVal]
+ retval_signed = False # type: bool
+ # FIXME: replace loop with call to sum_partial_products
+ for part, prod_part in enumerate(prod_parts):
+ shift = part * part_size
+ maxvl = 1 + max(len(spread_retval) - shift,
+ prod_part.output.ty.reg_len)
+ if part == 0:
+ part_maxvl = prod_part.output.ty.reg_len
+ part_setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[part_maxvl],
+ name=f"prod_{part}_setvl", maxvl=part_maxvl)
+ spread_part = fn.append_new_op(
+ OpKind.Spread,
+ input_vals=[prod_part.output, part_setvl.outputs[0]],
+ name=f"prod_{part}_spread", maxvl=part_maxvl)
+ spread_retval[:] = spread_part.outputs
+ else:
+ cast_retval_spread = cast_to_size_spread(
+ fn=fn, ssa_vals=spread_retval[shift:],
+ src_signed=retval_signed, dest_size=maxvl,
+ name=f"prod_{part}_retval_cast")
+ cast_prod = cast_to_size(
+ fn=fn, ssa_val=prod_part.output,
+ src_signed=prod_part.is_signed, dest_size=maxvl,
+ name=f"prod_{part}_cast")
+ part_setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[maxvl],
+ name=f"prod_{part}_setvl", maxvl=maxvl)
+ cast_retval = fn.append_new_op(
+ kind=OpKind.Concat,
+ input_vals=[*cast_retval_spread, part_setvl.outputs[0]],
+ name=f"prod_{part}_concat", maxvl=maxvl)
clear_ca = fn.append_new_op(kind=OpKind.ClearCA,
- name=f"clear_ca{shift}")
+ name=f"prod_{part}_clear_ca")
add = fn.append_new_op(
kind=OpKind.SvAddE, input_vals=[
- mul.outputs[0], add_rb_concat.outputs[0],
- clear_ca.outputs[0], vl],
- maxvl=maxvl, name=f"add{shift}")
- add_rt_spread = fn.append_new_op(
- kind=OpKind.Spread, input_vals=[add.outputs[0], vl],
- name=f"add{shift}_rt_spread", maxvl=maxvl)
- add_hi = fn.append_new_op(
- kind=OpKind.AddZE, input_vals=[mul.outputs[1], add.outputs[1]],
- name=f"add_hi{shift}")
- spread_retval = (
- *first_part, *add_rt_spread.outputs, add_hi.outputs[0])
- assert spread_retval is not None
- lhs_setvl = fn.append_new_op(
- kind=OpKind.SetVLI, immediates=[len(spread_retval)],
- name="retval_setvl")
- concat_retval = fn.append_new_op(
- kind=OpKind.Concat, input_vals=[*spread_retval, lhs_setvl.outputs[0]],
- name="concat_retval", maxvl=len(spread_retval))
- return concat_retval.outputs[0]
-
-
-def toom_cook_mul(fn, lhs, rhs, instances):
- # type: (Fn, SSAVal, SSAVal, list[ToomCookInstance]) -> SSAVal
- raise NotImplementedError
+ cast_prod, cast_retval.outputs[0],
+ clear_ca.outputs[0], part_setvl.outputs[0]],
+ maxvl=maxvl, name=f"prod_{part}_add")
+ spread = fn.append_new_op(
+ kind=OpKind.Spread,
+ input_vals=[add.outputs[0], part_setvl.outputs[0]],
+ name=f"prod_{part}_spread", maxvl=maxvl)
+ spread_retval[shift:] = spread.outputs
+ retval_signed |= prod_part.is_signed
+ while len(spread_retval) > retval_size:
+ spread_retval.pop()
+ assert len(spread_retval) == retval_size, "logic error"
+ retval_setvl = fn.append_new_op(
+ OpKind.SetVLI, immediates=[retval_size], name=f"prod_setvl",
+ maxvl=retval_size)
+ retval_concat = fn.append_new_op(
+ OpKind.Concat, input_vals=[*spread_retval, retval_setvl.outputs[0]],
+ name="prod_concat", maxvl=retval_size)
+ return retval_concat.outputs[0]