From: Jacob Lifshay Date: Wed, 9 Nov 2022 08:31:47 +0000 (-0800) Subject: working on adding signed multiplication -- needed for toom-cook X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=0dce56670c1deaa90c653d368061d2a51ee6009c;p=bigint-presentation-code.git working on adding signed multiplication -- needed for toom-cook --- diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index 96dc318..2f548a4 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -5,37 +5,57 @@ from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, BaseSimState, Fn, GenAsmState, OpKind, PostRASimState, - PreRASimState) + PreRASimState, SSAVal) from bigint_presentation_code.register_allocator import allocate_registers -from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul +from bigint_presentation_code.toom_cook import (ToomCookInstance, simple_mul, + toom_cook_mul) -class SimpleMul192x192: - def __init__(self): +def simple_umul(fn, lhs, rhs): + # type: (Fn, SSAVal, SSAVal) -> SSAVal + return simple_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs, + rhs_signed=False, name="simple_umul") + + +class Mul: + def __init__(self, mul, lhs_size_in_words, rhs_size_in_words): + # type: (Callable[[Fn, SSAVal, SSAVal], SSAVal], int, int) -> None super().__init__() self.fn = fn = Fn() self.dest_offset = 0 - self.lhs_offset = 48 + self.dest_offset - self.rhs_offset = 24 + self.lhs_offset + self.dest_size_in_words = lhs_size_in_words + rhs_size_in_words + self.dest_size_in_bytes = self.dest_size_in_words * GPR_SIZE_IN_BYTES + self.lhs_size_in_words = lhs_size_in_words + self.lhs_size_in_bytes = self.lhs_size_in_words * GPR_SIZE_IN_BYTES + self.rhs_size_in_words = rhs_size_in_words + self.rhs_size_in_bytes = self.rhs_size_in_words * GPR_SIZE_IN_BYTES + self.lhs_offset = self.dest_size_in_bytes + self.dest_offset + self.rhs_offset = self.lhs_size_in_bytes + self.lhs_offset self.ptr_in = fn.append_new_op(kind=OpKind.FuncArgR3, name="ptr_in").outputs[0] - setvl3 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[3], - maxvl=3, name="setvl3") + lhs_setvl = fn.append_new_op( + kind=OpKind.SetVLI, immediates=[lhs_size_in_words], + maxvl=lhs_size_in_words, name="lhs_setvl") load_lhs = fn.append_new_op( kind=OpKind.SvLd, immediates=[self.lhs_offset], - input_vals=[self.ptr_in, setvl3.outputs[0]], - name="load_lhs", maxvl=3) + input_vals=[self.ptr_in, lhs_setvl.outputs[0]], + name="load_lhs", maxvl=lhs_size_in_words) + rhs_setvl = fn.append_new_op( + kind=OpKind.SetVLI, immediates=[rhs_size_in_words], + maxvl=rhs_size_in_words, name="rhs_setvl") load_rhs = fn.append_new_op( kind=OpKind.SvLd, immediates=[self.rhs_offset], - input_vals=[self.ptr_in, setvl3.outputs[0]], + input_vals=[self.ptr_in, rhs_setvl.outputs[0]], name="load_rhs", maxvl=3) - retval = simple_mul(fn, load_lhs.outputs[0], load_rhs.outputs[0]) - setvl6 = fn.append_new_op(kind=OpKind.SetVLI, immediates=[6], - maxvl=6, name="setvl6") + retval = mul(fn, load_lhs.outputs[0], load_rhs.outputs[0]) + dest_setvl = fn.append_new_op( + kind=OpKind.SetVLI, immediates=[self.dest_size_in_words], + maxvl=self.dest_size_in_words, name="dest_setvl") fn.append_new_op( kind=OpKind.SvStd, - input_vals=[retval, self.ptr_in, setvl6.outputs[0]], - immediates=[self.dest_offset], maxvl=6, name="store_dest") + input_vals=[retval, self.ptr_in, dest_setvl.outputs[0]], + immediates=[self.dest_offset], maxvl=self.dest_size_in_words, + name="store_dest") class TestToomCook(unittest.TestCase): @@ -207,21 +227,26 @@ class TestToomCook(unittest.TestCase): ) def test_simple_mul_192x192_pre_ra_sim(self): + self.skipTest("WIP") # FIXME: finish fixing simple_mul + def create_sim_state(code): - # type: (SimpleMul192x192) -> BaseSimState + # type: (Mul) -> BaseSimState return PreRASimState(ssa_vals={}, memory={}) self.tst_simple_mul_192x192_sim(create_sim_state) def test_simple_mul_192x192_post_ra_sim(self): + self.skipTest("WIP") # FIXME: finish fixing simple_mul + def create_sim_state(code): - # type: (SimpleMul192x192) -> BaseSimState + # type: (Mul) -> BaseSimState ssa_val_to_loc_map = allocate_registers(code.fn) return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map, memory={}, loc_values={}) self.tst_simple_mul_192x192_sim(create_sim_state) def tst_simple_mul_192x192_sim(self, create_sim_state): - # type: (Callable[[SimpleMul192x192], BaseSimState]) -> None + # type: (Callable[[Mul], BaseSimState]) -> None + self.skipTest("WIP") # FIXME: finish fixing simple_mul # test multiplying: # 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57 # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507 @@ -231,7 +256,7 @@ class TestToomCook(unittest.TestCase): # "_3931783239312079_7261727469627261", base=0) # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test", # 'little') - code = SimpleMul192x192() + code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) state = create_sim_state(code) ptr_in = 0x100 dest_ptr = ptr_in + code.dest_offset @@ -253,7 +278,8 @@ class TestToomCook(unittest.TestCase): self.assertEqual(out_bytes, expected_bytes) def test_simple_mul_192x192_ops(self): - code = SimpleMul192x192() + self.skipTest("WIP") # FIXME: finish fixing simple_mul + code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn self.assertEqual([repr(v) for v in fn.ops], [ "Op(kind=OpKind.FuncArgR3, " @@ -264,18 +290,23 @@ class TestToomCook(unittest.TestCase): "Op(kind=OpKind.SetVLI, " "input_vals=[], " "input_uses=(), immediates=[3], " - "outputs=(>,), " - "name='setvl3')", + "outputs=(>,), " + "name='lhs_setvl')", "Op(kind=OpKind.SvLd, " "input_vals=[>, " - ">], " + ">], " "input_uses=(>, " ">), immediates=[48], " "outputs=(>,), " "name='load_lhs')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), immediates=[3], " + "outputs=(>,), " + "name='rhs_setvl')", "Op(kind=OpKind.SvLd, " "input_vals=[>, " - ">], " + ">], " "input_uses=(>, " ">), immediates=[72], " "outputs=(>,), " @@ -283,11 +314,11 @@ class TestToomCook(unittest.TestCase): "Op(kind=OpKind.SetVLI, " "input_vals=[], " "input_uses=(), immediates=[3], " - "outputs=(>,), " - "name='rhs_setvl')", + "outputs=(>,), " + "name='rhs_setvl2')", "Op(kind=OpKind.Spread, " "input_vals=[>, " - ">], " + ">], " "input_uses=(>, " ">), immediates=[], " "outputs=(>, " @@ -297,8 +328,8 @@ class TestToomCook(unittest.TestCase): "Op(kind=OpKind.SetVLI, " "input_vals=[], " "input_uses=(), immediates=[3], " - "outputs=(>,), " - "name='lhs_setvl')", + "outputs=(>,), " + "name='lhs_setvl3')", "Op(kind=OpKind.LI, " "input_vals=[], " "input_uses=(), immediates=[0], " @@ -308,7 +339,7 @@ class TestToomCook(unittest.TestCase): "input_vals=[>, " ">, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">, " @@ -318,7 +349,7 @@ class TestToomCook(unittest.TestCase): "name='mul0')", "Op(kind=OpKind.Spread, " "input_vals=[>, " - ">], " + ">], " "input_uses=(>, " ">), immediates=[], " "outputs=(>, " @@ -329,7 +360,7 @@ class TestToomCook(unittest.TestCase): "input_vals=[>, " ">, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">, " @@ -341,7 +372,7 @@ class TestToomCook(unittest.TestCase): "input_vals=[>, " ">, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">, " @@ -357,7 +388,7 @@ class TestToomCook(unittest.TestCase): "input_vals=[>, " ">, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">, " @@ -367,7 +398,7 @@ class TestToomCook(unittest.TestCase): "name='add1')", "Op(kind=OpKind.Spread, " "input_vals=[>, " - ">], " + ">], " "input_uses=(>, " ">), immediates=[], " "outputs=(>, " @@ -386,7 +417,7 @@ class TestToomCook(unittest.TestCase): "input_vals=[>, " ">, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">, " @@ -398,7 +429,7 @@ class TestToomCook(unittest.TestCase): "input_vals=[>, " ">, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">, " @@ -414,7 +445,7 @@ class TestToomCook(unittest.TestCase): "input_vals=[>, " ">, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">, " @@ -424,7 +455,7 @@ class TestToomCook(unittest.TestCase): "name='add2')", "Op(kind=OpKind.Spread, " "input_vals=[>, " - ">], " + ">], " "input_uses=(>, " ">), immediates=[], " "outputs=(>, " @@ -464,12 +495,12 @@ class TestToomCook(unittest.TestCase): "Op(kind=OpKind.SetVLI, " "input_vals=[], " "input_uses=(), immediates=[6], " - "outputs=(>,), " - "name='setvl6')", + "outputs=(>,), " + "name='dest_setvl')", "Op(kind=OpKind.SvStd, " "input_vals=[>, " ">, " - ">], " + ">], " "input_uses=(>, " ">, " ">), immediates=[0], " @@ -478,7 +509,8 @@ class TestToomCook(unittest.TestCase): ]) def test_simple_mul_192x192_reg_alloc(self): - code = SimpleMul192x192() + self.skipTest("WIP") # FIXME: finish fixing simple_mul + code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn assigned_registers = allocate_registers(fn) self.assertEqual( @@ -491,7 +523,7 @@ class TestToomCook(unittest.TestCase): "Loc(kind=LocKind.GPR, start=4, reg_len=6), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=6), " @@ -717,7 +749,7 @@ class TestToomCook(unittest.TestCase): "Loc(kind=LocKind.GPR, start=18, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=19, reg_len=1), " @@ -737,7 +769,7 @@ class TestToomCook(unittest.TestCase): "Loc(kind=LocKind.GPR, start=3, reg_len=3), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=3), " @@ -749,6 +781,8 @@ class TestToomCook(unittest.TestCase): "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=20, reg_len=3), " ">: " @@ -759,7 +793,7 @@ class TestToomCook(unittest.TestCase): "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=6, reg_len=1), " - ">: " + ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=23, reg_len=1), " @@ -768,7 +802,8 @@ class TestToomCook(unittest.TestCase): "}") def test_simple_mul_192x192_asm(self): - code = SimpleMul192x192() + self.skipTest("WIP") # FIXME: finish fixing simple_mul + code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn assigned_registers = allocate_registers(fn) gen_asm_state = GenAsmState(assigned_registers) @@ -781,6 +816,7 @@ class TestToomCook(unittest.TestCase): 'sv.ld *3, 48(6)', 'setvl 0, 0, 3, 0, 1, 1', 'sv.or *20, *3, *3', + 'setvl 0, 0, 3, 0, 1, 1', 'or 6, 23, 23', 'setvl 0, 0, 3, 0, 1, 1', 'sv.ld *3, 72(6)', @@ -885,6 +921,23 @@ class TestToomCook(unittest.TestCase): 'sv.std *4, 0(3)' ]) + def test_toom_2_mul_256x256_asm(self): + self.skipTest("WIP") # FIXME: finish + TOOM_2 = ToomCookInstance.make_toom_2() + instances = TOOM_2, TOOM_2 + + def mul(fn, lhs, rhs): + # type: (Fn, SSAVal, SSAVal) -> SSAVal + return toom_cook_mul(fn=fn, lhs=lhs, lhs_signed=False, rhs=rhs, + rhs_signed=False, instances=instances) + code = Mul(mul=mul, lhs_size_in_words=3, rhs_size_in_words=3) + fn = code.fn + assigned_registers = allocate_registers(fn) + gen_asm_state = GenAsmState(assigned_registers) + fn.gen_asm(gen_asm_state) + self.assertEqual(gen_asm_state.output, [ + ]) + if __name__ == "__main__": unittest.main() diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 756615d..ba5ada9 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -16,6 +16,12 @@ from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta, OFSet, OSet) +GPR_SIZE_IN_BYTES = 8 +BITS_IN_BYTE = 8 +GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE +GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1 + + @final class Fn: def __init__(self): @@ -1221,6 +1227,36 @@ class OpKind(Enum): _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm + @staticmethod + def __sradi_sim(op, state): + # type: (Op, BaseSimState) -> None + rs, = state[op.input_vals[0]] + imm = op.immediates[0] + if rs >= 1 << (GPR_SIZE_IN_BITS - 1): + rs -= 1 << GPR_SIZE_IN_BITS + v = rs >> imm + RA = v & GPR_VALUE_MASK + CA = (RA << imm) != rs + state[op.outputs[0]] = RA, + state[op.outputs[1]] = CA, + + @staticmethod + def __sradi_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RA = state.sgpr(op.outputs[0]) + RS = state.sgpr(op.input_vals[1]) + imm = op.immediates[0] + state.writeln(f"sradi {RA}, {RS}, {imm}") + SRADI = GenericOpProperties( + demo_asm="sradi RA, RS, imm", + inputs=[OD_BASE_SGPR], + outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late), + OD_CA.with_write_stage(OpStage.Late)], + immediates=[range(GPR_SIZE_IN_BITS)], + ) + _SIM_FNS[SRADI] = lambda: OpKind.__sradi_sim + _GEN_ASMS[SRADI] = lambda: OpKind.__sradi_gen_asm + @staticmethod def __setvli_sim(op, state): # type: (Op, BaseSimState) -> None @@ -1971,12 +2007,6 @@ class Op: self.kind.gen_asm(self, state) -GPR_SIZE_IN_BYTES = 8 -BITS_IN_BYTE = 8 -GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE -GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1 - - @plain_data(frozen=True, repr=False) class BaseSimState(metaclass=ABCMeta): __slots__ = "memory", diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 03f3e93..de5b0f6 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -1,17 +1,20 @@ """ Toom-Cook multiplication algorithm generator for SVP64 """ -from abc import ABCMeta, abstractmethod +import math +from abc import abstractmethod from enum import Enum from fractions import Fraction -from typing import Any, Generic, Iterable, Mapping, TypeVar, Union +from typing import Iterable, Mapping, Union +from cached_property import cached_property from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, Fn, OpKind, - SSAVal) +from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, BaseTy, Fn, + OpKind, SSAVal, Ty) from bigint_presentation_code.matrix import Matrix from bigint_presentation_code.type_util import Literal, final +from bigint_presentation_code.util import InternedMeta @final @@ -23,7 +26,6 @@ class PointAtInfinity(Enum): POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY -WORD_BITS = GPR_SIZE_IN_BITS _EvalOpPolyCoefficients = Union["Mapping[int | None, Fraction | int]", "EvalOpPoly", Fraction, int, None] @@ -160,41 +162,165 @@ class EvalOpPoly: @plain_data(frozen=True, unsafe_hash=True) +@final class EvalOpValueRange: - __slots__ = ("eval_op", "inputs_words", "min_value", "max_value", - "is_signed") + __slots__ = ("eval_op", "inputs", "min_value", "max_value", + "is_signed", "output_size") - def __init__(self, eval_op, inputs_words): - # type: (EvalOp[Any, Any], Iterable[int]) -> None + def __init__(self, eval_op, inputs): + # type: (EvalOp | int, tuple[EvalOpGenIrInput, ...]) -> None super().__init__() self.eval_op = eval_op - self.inputs_words = tuple(inputs_words) - for words in self.inputs_words: - if words <= 0: - raise ValueError(f"invalid word count: {words}") - min_value = max_value = eval_op.poly.const_coeff - for var, coeff in enumerate(eval_op.poly.var_coeffs): + self.inputs = inputs + min_value = max_value = self.poly.const_coeff + for var, coeff in enumerate(self.poly.var_coeffs): if coeff == 0: continue - var_min = 0 - var_max = (1 << self.inputs_words[var] * WORD_BITS) - 1 - term_min = var_min * coeff - term_max = var_max * coeff + term_min = self.inputs[var].min_value * coeff + term_max = self.inputs[var].max_value * coeff if term_min > term_max: term_min, term_max = term_max, term_min min_value += term_min max_value += term_max + # output values are always integers, so eliminate any fractional part + # as impossible. + self.min_value = math.ceil(min_value) # exclude fractional part + self.max_value = math.floor(max_value) # exclude fractional part + self.is_signed = min_value < 0 + output_size = 1 + if self.is_signed: + min_v = -1 << (GPR_SIZE_IN_BITS - 1) + max_v = (1 << (GPR_SIZE_IN_BITS - 1)) - 1 + else: + min_v = 0 + max_v = (1 << GPR_SIZE_IN_BITS) - 1 + while not (min_v <= self.min_value and self.max_value <= max_v): + output_size += 1 + min_v <<= GPR_SIZE_IN_BITS + max_v <<= GPR_SIZE_IN_BITS + self.output_size = output_size + + @cached_property + def poly(self): + if isinstance(self.eval_op, int): + return EvalOpPoly(const_coeff=self.eval_op) + return self.eval_op.poly + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpGenIrOutput: + __slots__ = "output", "value_range" + + def __init__(self, output, value_range): + # type: (SSAVal, EvalOpValueRange) -> None + super().__init__() + if output.ty.reg_len != value_range.output_size: + raise ValueError("wrong output size") + self.output = output + self.value_range = value_range + + @property + def eval_op(self): + # type: () -> EvalOp | int + return self.value_range.eval_op + + @property + def inputs(self): + # type: () -> tuple[EvalOpGenIrInput, ...] + return self.value_range.inputs + + @property + def min_value(self): + # type: () -> int + return self.value_range.min_value + + @property + def max_value(self): + # type: () -> int + return self.value_range.max_value + + @property + def is_signed(self): + # type: () -> bool + return self.value_range.is_signed + + @property + def output_size(self): + # type: () -> int + return self.value_range.output_size + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpGenIrInput: + __slots__ = "ssa_val", "is_signed", "min_value", "max_value" + + def __init__(self, ssa_val, is_signed, min_value=None, max_value=None): + # type: (SSAVal, bool | None, int | None, int | None) -> None + super().__init__() + self.ssa_val = ssa_val + if ssa_val.base_ty != BaseTy.I64: + raise ValueError("input must have a base_ty of BaseTy.I64") + if is_signed is None: + if min_value is None or max_value is None: + raise ValueError("must specify either is_signed or both " + "min_value and max_value") + is_signed = min_value < 0 + self.is_signed = is_signed + if is_signed: + if min_value is None: + min_value = -1 << (ssa_val.ty.reg_len * GPR_SIZE_IN_BITS - 1) + if max_value is None: + max_value = (1 << ( + ssa_val.ty.reg_len * GPR_SIZE_IN_BITS - 1)) - 1 + else: + if min_value is None: + min_value = 0 + if max_value is None: + max_value = (1 << (ssa_val.ty.reg_len * GPR_SIZE_IN_BITS)) - 1 self.min_value = min_value self.max_value = max_value - self.is_signed = min_value < 0 + if self.min_value > self.max_value: + raise ValueError("invalid value range") -_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp[Any, Any]") -_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp[Any, Any]") +@plain_data(frozen=True) +@final +class EvalOpGenIrState: + __slots__ = "fn", "inputs", "outputs_map" + + def __init__(self, fn, inputs): + # type: (Fn, Iterable[EvalOpGenIrInput]) -> None + super().__init__() + self.fn = fn + self.inputs = tuple(inputs) + self.outputs_map = {} # type: dict[EvalOp | int, EvalOpGenIrOutput] + + def get_output(self, eval_op): + # type: (EvalOp | int) -> EvalOpGenIrOutput + retval = self.outputs_map.get(eval_op, None) + if retval is not None: + return retval + value_range = EvalOpValueRange(eval_op=eval_op, inputs=self.inputs) + if isinstance(eval_op, int): + li = self.fn.append_new_op(OpKind.LI, immediates=[eval_op], + name=f"li_{eval_op}") + output = cast_to_size( + fn=self.fn, ssa_val=li.outputs[0], + dest_size=value_range.output_size, + src_signed=value_range.is_signed, name=f"cast_{eval_op}") + retval = EvalOpGenIrOutput(output=output, value_range=value_range) + else: + retval = eval_op.make_output(state=self, + output_value_range=value_range) + if retval.value_range != value_range: + raise ValueError("wrong value_range") + return self.outputs_map.setdefault(eval_op, retval) @plain_data(frozen=True, unsafe_hash=True) -class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS], metaclass=ABCMeta): +class EvalOp(metaclass=InternedMeta): __slots__ = "lhs", "rhs", "poly" @property @@ -216,8 +342,13 @@ class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS], metaclass=ABCMeta): # type: () -> EvalOpPoly ... + @abstractmethod + def make_output(self, state, output_value_range): + # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput + ... + def __init__(self, lhs, rhs): - # type: (_EvalOpLHS, _EvalOpRHS) -> None + # type: (EvalOp | int, EvalOp | int) -> None super().__init__() self.lhs = lhs self.rhs = rhs @@ -226,48 +357,83 @@ class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS], metaclass=ABCMeta): @plain_data(frozen=True, unsafe_hash=True) @final -class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]): +class EvalOpAdd(EvalOp): __slots__ = () def _make_poly(self): # type: () -> EvalOpPoly return self.lhs_poly + self.rhs_poly + def make_output(self, state, output_value_range): + # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput + lhs = state.get_output(self.lhs) + lhs_output = cast_to_size( + fn=state.fn, ssa_val=lhs.output, + dest_size=output_value_range.output_size, src_signed=lhs.is_signed, + name="add_lhs_cast") + rhs = state.get_output(self.rhs) + rhs_output = cast_to_size( + fn=state.fn, ssa_val=rhs.output, + dest_size=output_value_range.output_size, src_signed=rhs.is_signed, + name="add_rhs_cast") + + raise NotImplementedError # FIXME: finish + @plain_data(frozen=True, unsafe_hash=True) @final -class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]): +class EvalOpSub(EvalOp): __slots__ = () def _make_poly(self): # type: () -> EvalOpPoly return self.lhs_poly - self.rhs_poly + def make_output(self, state, output_value_range): + # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput + raise NotImplementedError # FIXME: finish + @plain_data(frozen=True, unsafe_hash=True) @final -class EvalOpMul(EvalOp[_EvalOpLHS, int]): +class EvalOpMul(EvalOp): __slots__ = () + rhs: int def _make_poly(self): # type: () -> EvalOpPoly + if not isinstance(self.rhs, int): # type: ignore + raise TypeError("invalid rhs type") return self.lhs_poly * self.rhs + def make_output(self, state, output_value_range): + # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput + raise NotImplementedError # FIXME: finish + @plain_data(frozen=True, unsafe_hash=True) @final -class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]): +class EvalOpExactDiv(EvalOp): __slots__ = () + rhs: int def _make_poly(self): # type: () -> EvalOpPoly + if not isinstance(self.rhs, int): # type: ignore + raise TypeError("invalid rhs type") return self.lhs_poly / self.rhs + def make_output(self, state, output_value_range): + # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput + raise NotImplementedError # FIXME: finish + @plain_data(frozen=True, unsafe_hash=True) @final -class EvalOpInput(EvalOp[int, Literal[0]]): +class EvalOpInput(EvalOp): __slots__ = () + lhs: int + rhs: Literal[0] def __init__(self, lhs, rhs=0): # type: (int, int) -> None @@ -285,6 +451,15 @@ class EvalOpInput(EvalOp[int, Literal[0]]): # type: () -> EvalOpPoly return EvalOpPoly({self.part_index: 1}) + def make_output(self, state, output_value_range): + # type: (EvalOpGenIrState, EvalOpValueRange) -> EvalOpGenIrOutput + inp = state.inputs[self.part_index] + output = cast_to_size( + fn=state.fn, ssa_val=inp.ssa_val, src_signed=inp.is_signed, + dest_size=output_value_range.output_size, + name="input_{self.part_index}_cast") + return EvalOpGenIrOutput(output=output, value_range=output_value_range) + @plain_data(frozen=True, unsafe_hash=True) @final @@ -349,9 +524,9 @@ class ToomCookInstance: self, lhs_part_count, # type: int rhs_part_count, # type: int eval_points, # type: Iterable[PointAtInfinity | int] - lhs_eval_ops, # type: Iterable[EvalOp[Any, Any]] - rhs_eval_ops, # type: Iterable[EvalOp[Any, Any]] - prod_eval_ops, # type: Iterable[EvalOp[Any, Any]] + lhs_eval_ops, # type: Iterable[EvalOp] + rhs_eval_ops, # type: Iterable[EvalOp] + prod_eval_ops, # type: Iterable[EvalOp] ): # type: (...) -> None self.lhs_part_count = lhs_part_count @@ -470,11 +645,89 @@ class ToomCookInstance: # TODO: add make_toom_3 -def simple_mul(fn, lhs, rhs): - # type: (Fn, SSAVal, SSAVal) -> SSAVal - """ simple O(n^2) big-int unsigned multiply """ +@plain_data(frozen=True, unsafe_hash=True) +@final +class PartialProduct: + __slots__ = "ssa_val_spread", "shift_in_words", "is_signed" + + def __init__(self, ssa_val_spread, shift_in_words, is_signed): + # type: (Iterable[SSAVal], int, bool) -> None + if shift_in_words < 0: + raise ValueError("invalid shift_in_words") + self.ssa_val_spread = tuple(ssa_val_spread) + for ssa_val in ssa_val_spread: + if ssa_val.ty != Ty(base_ty=BaseTy.I64, reg_len=1): + raise ValueError("invalid ssa_val.ty") + self.shift_in_words = shift_in_words + self.is_signed = is_signed + + +def sum_partial_products(fn, partial_products, name): + # type: (Fn, Iterable[PartialProduct], str) -> SSAVal + retval_spread = [] # type: list[SSAVal] + retval_signed = False + zero = fn.append_new_op(OpKind.LI, immediates=[0], + name=f"{name}_zero").outputs[0] + has_carry_word = False + for idx, partial_product in enumerate(partial_products): + shift_in_words = partial_product.shift_in_words + spread = list(partial_product.ssa_val_spread) + if not retval_signed and shift_in_words >= len(retval_spread): + retval_spread.extend( + [zero] * (shift_in_words - len(retval_spread))) + retval_spread.extend(spread) + retval_signed = partial_product.is_signed + has_carry_word = False + continue + assert len(retval_spread) != 0, "logic error" + maxvl = max(len(retval_spread) - shift_in_words, len(spread)) + if not has_carry_word: + maxvl += 1 + has_carry_word = True + retval_spread = cast_to_size_spread( + fn=fn, ssa_vals=retval_spread, src_signed=retval_signed, + dest_size=maxvl + shift_in_words, name=f"{name}_{idx}_cast_retval") + spread = cast_to_size_spread( + fn=fn, ssa_vals=spread, src_signed=partial_product.is_signed, + dest_size=maxvl, name=f"{name}_{idx}_cast_pp") + setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[maxvl], + maxvl=maxvl, name=f"{name}_{idx}_setvl") + retval_concat = fn.append_new_op( + kind=OpKind.Concat, + input_vals=[*retval_spread[shift_in_words:], setvl.outputs[0]], + name=f"{name}_{idx}_retval_concat", maxvl=maxvl) + pp_concat = fn.append_new_op( + kind=OpKind.Concat, + input_vals=[*retval_spread[shift_in_words:], setvl.outputs[0]], + name=f"{name}_{idx}_pp_concat", maxvl=maxvl) + clear_ca = fn.append_new_op(kind=OpKind.ClearCA, + name=f"{name}_{idx}_clear_ca") + add = fn.append_new_op( + kind=OpKind.SvAddE, input_vals=[ + retval_concat.outputs[0], pp_concat.outputs[0], + clear_ca.outputs[0], setvl.outputs[0]], + maxvl=maxvl, name=f"{name}_{idx}_add") + retval_spread[shift_in_words:] = fn.append_new_op( + kind=OpKind.Spread, + input_vals=[add.outputs[0], setvl.outputs[0]], + name=f"{name}_{idx}_sum_spread", maxvl=maxvl).outputs + retval_setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[len(retval_spread)], + maxvl=len(retval_spread), name=f"{name}_setvl") + retval_concat = fn.append_new_op( + kind=OpKind.Concat, + input_vals=[*retval_spread, retval_setvl.outputs[0]], + name=f"{name}_concat", maxvl=len(retval_spread)) + return retval_concat.outputs[0] + + +def simple_mul(fn, lhs, lhs_signed, rhs, rhs_signed, name): + # type: (Fn, SSAVal, bool, SSAVal, bool, str) -> SSAVal + """ simple O(n^2) big-int multiply """ if lhs.ty.reg_len < rhs.ty.reg_len: lhs, rhs = rhs, lhs + lhs_signed, rhs_signed = rhs_signed, lhs_signed # split rhs into elements rhs_setvl = fn.append_new_op(kind=OpKind.SetVLI, immediates=[rhs.ty.reg_len], name="rhs_setvl") @@ -482,54 +735,231 @@ def simple_mul(fn, lhs, rhs): kind=OpKind.Spread, input_vals=[rhs, rhs_setvl.outputs[0]], maxvl=rhs.ty.reg_len, name="rhs_spread") rhs_words = rhs_spread.outputs - spread_retval = None # type: tuple[SSAVal, ...] | None + zero = fn.append_new_op( + kind=OpKind.LI, immediates=[0], name=f"{name}_zero").outputs[0] maxvl = lhs.ty.reg_len - lhs_setvl = fn.append_new_op(kind=OpKind.SetVLI, - immediates=[lhs.ty.reg_len], name="lhs_setvl") + lhs_setvl = fn.append_new_op( + kind=OpKind.SetVLI, immediates=[maxvl], name="lhs_setvl", maxvl=maxvl) vl = lhs_setvl.outputs[0] - zero_op = fn.append_new_op(kind=OpKind.LI, immediates=[0], name="zero") - zero = zero_op.outputs[0] - for shift, rhs_word in enumerate(rhs_words): - mul = fn.append_new_op(kind=OpKind.SvMAddEDU, - input_vals=[lhs, rhs_word, zero, vl], - maxvl=maxvl, name=f"mul{shift}") - if spread_retval is None: + if lhs_signed or rhs_signed: + raise NotImplementedError # FIXME: implement signed multiply + + def partial_products(): + # type: () -> Iterable[PartialProduct] + for shift_in_words, rhs_word in enumerate(rhs_words): + mul = fn.append_new_op( + kind=OpKind.SvMAddEDU, input_vals=[lhs, rhs_word, zero, vl], + maxvl=maxvl, name=f"{name}_{shift_in_words}_mul") mul_rt_spread = fn.append_new_op( kind=OpKind.Spread, input_vals=[mul.outputs[0], vl], - name=f"mul{shift}_rt_spread", maxvl=maxvl) - spread_retval = (*mul_rt_spread.outputs, mul.outputs[1]) + name=f"{name}_{shift_in_words}_mul_rt_spread", maxvl=maxvl) + yield PartialProduct( + ssa_val_spread=[*mul_rt_spread.outputs, mul.outputs[1]], + shift_in_words=shift_in_words, + is_signed=False) + return sum_partial_products(fn=fn, partial_products=partial_products(), + name=name) + + +def cast_to_size(fn, ssa_val, src_signed, dest_size, name): + # type: (Fn, SSAVal, bool, int, str) -> SSAVal + if dest_size <= 0: + raise ValueError("invalid dest_size -- must be a positive integer") + if ssa_val.ty.reg_len == dest_size: + return ssa_val + in_setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[ssa_val.ty.reg_len], + maxvl=ssa_val.ty.reg_len, name=f"{name}_in_setvl") + spread = fn.append_new_op( + OpKind.Spread, input_vals=[ssa_val, in_setvl.outputs[0]], + name=f"{name}_spread", maxvl=ssa_val.ty.reg_len) + spread_values = cast_to_size_spread( + fn=fn, ssa_vals=spread.outputs, src_signed=src_signed, + dest_size=dest_size, name=name) + out_setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[dest_size], maxvl=dest_size, + name=f"{name}_out_setvl") + concat = fn.append_new_op( + OpKind.Concat, input_vals=[*spread_values, out_setvl.outputs[0]], + name=f"{name}_concat", maxvl=dest_size) + return concat.outputs[0] + + +def cast_to_size_spread(fn, ssa_vals, src_signed, dest_size, name): + # type: (Fn, Iterable[SSAVal], bool, int, str) -> list[SSAVal] + if dest_size <= 0: + raise ValueError("invalid dest_size -- must be a positive integer") + spread_values = list(ssa_vals) + for ssa_val in ssa_vals: + if ssa_val.ty != Ty(base_ty=BaseTy.I64, reg_len=1): + raise ValueError("invalid ssa_val.ty") + if len(spread_values) == dest_size: + return spread_values + if len(spread_values) > dest_size: + spread_values[dest_size:] = [] + elif src_signed: + sign = fn.append_new_op( + OpKind.SRADI, input_vals=[spread_values[-1]], + immediates=[GPR_SIZE_IN_BITS - 1], name=f"{name}_sign") + spread_values += [sign.outputs[0]] * (dest_size - len(spread_values)) + else: + zero = fn.append_new_op( + OpKind.LI, immediates=[0], name=f"{name}_zero") + spread_values += [zero.outputs[0]] * (dest_size - len(spread_values)) + return spread_values + + +def split_into_exact_sized_parts(fn, ssa_val, part_count, part_size, name): + # type: (Fn, SSAVal, int, int, str) -> list[SSAVal] + """split ssa_val into part_count parts, where all but the last part have + `part.ty.reg_len == part_size`. + """ + if part_size <= 0: + raise ValueError("invalid part size, must be positive") + if part_count <= 0: + raise ValueError("invalid part count, must be positive") + if part_count == 1: + return [ssa_val] + too_short_reg_len = (part_count - 1) * part_size + if ssa_val.ty.reg_len <= too_short_reg_len: + raise ValueError(f"ssa_val is too short to split, must have " + f"reg_len > {too_short_reg_len}: {ssa_val}") + maxvl = ssa_val.ty.reg_len + setvl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl], + maxvl=maxvl, name=f"{name}_setvl") + spread = fn.append_new_op( + OpKind.Spread, input_vals=[ssa_val, setvl.outputs[0]], + name=f"{name}_spread", maxvl=maxvl) + retval = [] # type: list[SSAVal] + for part in range(part_count): + start = part * part_size + stop = min(maxvl, start + part_size) + part_maxvl = stop - start + part_setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[part_size], maxvl=part_size, + name=f"{name}_{part}_setvl") + concat = fn.append_new_op( + OpKind.Concat, + input_vals=[*spread.outputs[start:stop], part_setvl.outputs[0]], + name=f"{name}_{part}_concat", maxvl=part_maxvl) + retval.append(concat.outputs[0]) + return retval + + +def toom_cook_mul(fn, lhs, lhs_signed, rhs, rhs_signed, instances, + start_instance_index=0): + # type: (Fn, SSAVal, bool, SSAVal, bool, tuple[ToomCookInstance, ...], int) -> SSAVal + if start_instance_index < 0: + raise ValueError("start_instance_index must be non-negative") + instance = None + part_size = 0 + while start_instance_index < len(instances): + instance = instances[start_instance_index] + part_size = max(lhs.ty.reg_len // instance.lhs_part_count, + rhs.ty.reg_len // instance.rhs_part_count) + if part_size <= 0: + instance = None + start_instance_index += 1 else: - first_part = spread_retval[:shift] # type: tuple[SSAVal, ...] - last_part = spread_retval[shift:] - - add_rb_concat = fn.append_new_op( - kind=OpKind.Concat, input_vals=[*last_part, vl], - name=f"add{shift}_rb_concat", maxvl=maxvl) + break + if instance is None: + return simple_mul(fn=fn, + lhs=lhs, lhs_signed=lhs_signed, + rhs=rhs, rhs_signed=rhs_signed, + name="toom_cook_base_case") + lhs_parts = split_into_exact_sized_parts( + fn=fn, ssa_val=lhs, part_count=instance.lhs_part_count, + part_size=part_size, name="lhs") + lhs_inputs = [] # type: list[EvalOpGenIrInput] + for part, ssa_val in enumerate(lhs_parts): + lhs_inputs.append(EvalOpGenIrInput( + ssa_val=ssa_val, + is_signed=lhs_signed and part == len(lhs_parts) - 1)) + lhs_eval_state = EvalOpGenIrState(fn=fn, inputs=lhs_inputs) + lhs_outputs = [lhs_eval_state.get_output(i) for i in instance.lhs_eval_ops] + rhs_parts = split_into_exact_sized_parts( + fn=fn, ssa_val=rhs, part_count=instance.rhs_part_count, + part_size=part_size, name="rhs") + rhs_inputs = [] # type: list[EvalOpGenIrInput] + for part, ssa_val in enumerate(rhs_parts): + rhs_inputs.append(EvalOpGenIrInput( + ssa_val=ssa_val, + is_signed=rhs_signed and part == len(rhs_parts) - 1)) + rhs_eval_state = EvalOpGenIrState(fn=fn, inputs=rhs_inputs) + rhs_outputs = [rhs_eval_state.get_output(i) for i in instance.rhs_eval_ops] + prod_inputs = [] # type: list[EvalOpGenIrInput] + for lhs_output, rhs_output in zip(lhs_outputs, rhs_outputs): + ssa_val = toom_cook_mul( + fn=fn, + lhs=lhs_output.output, lhs_signed=lhs_output.is_signed, + rhs=rhs_output.output, rhs_signed=rhs_output.is_signed, + instances=instances, start_instance_index=start_instance_index + 1) + products = (lhs_output.min_value * rhs_output.min_value, + lhs_output.min_value * rhs_output.max_value, + lhs_output.max_value * rhs_output.min_value, + lhs_output.max_value * rhs_output.max_value) + prod_inputs.append(EvalOpGenIrInput( + ssa_val=ssa_val, + is_signed=None, + min_value=min(products), + max_value=max(products))) + prod_eval_state = EvalOpGenIrState(fn=fn, inputs=prod_inputs) + prod_parts = [ + prod_eval_state.get_output(i) for i in instance.prod_eval_ops] + retval_size = lhs.ty.reg_len + rhs.ty.reg_len + spread_retval = [] # type: list[SSAVal] + retval_signed = False # type: bool + # FIXME: replace loop with call to sum_partial_products + for part, prod_part in enumerate(prod_parts): + shift = part * part_size + maxvl = 1 + max(len(spread_retval) - shift, + prod_part.output.ty.reg_len) + if part == 0: + part_maxvl = prod_part.output.ty.reg_len + part_setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[part_maxvl], + name=f"prod_{part}_setvl", maxvl=part_maxvl) + spread_part = fn.append_new_op( + OpKind.Spread, + input_vals=[prod_part.output, part_setvl.outputs[0]], + name=f"prod_{part}_spread", maxvl=part_maxvl) + spread_retval[:] = spread_part.outputs + else: + cast_retval_spread = cast_to_size_spread( + fn=fn, ssa_vals=spread_retval[shift:], + src_signed=retval_signed, dest_size=maxvl, + name=f"prod_{part}_retval_cast") + cast_prod = cast_to_size( + fn=fn, ssa_val=prod_part.output, + src_signed=prod_part.is_signed, dest_size=maxvl, + name=f"prod_{part}_cast") + part_setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[maxvl], + name=f"prod_{part}_setvl", maxvl=maxvl) + cast_retval = fn.append_new_op( + kind=OpKind.Concat, + input_vals=[*cast_retval_spread, part_setvl.outputs[0]], + name=f"prod_{part}_concat", maxvl=maxvl) clear_ca = fn.append_new_op(kind=OpKind.ClearCA, - name=f"clear_ca{shift}") + name=f"prod_{part}_clear_ca") add = fn.append_new_op( kind=OpKind.SvAddE, input_vals=[ - mul.outputs[0], add_rb_concat.outputs[0], - clear_ca.outputs[0], vl], - maxvl=maxvl, name=f"add{shift}") - add_rt_spread = fn.append_new_op( - kind=OpKind.Spread, input_vals=[add.outputs[0], vl], - name=f"add{shift}_rt_spread", maxvl=maxvl) - add_hi = fn.append_new_op( - kind=OpKind.AddZE, input_vals=[mul.outputs[1], add.outputs[1]], - name=f"add_hi{shift}") - spread_retval = ( - *first_part, *add_rt_spread.outputs, add_hi.outputs[0]) - assert spread_retval is not None - lhs_setvl = fn.append_new_op( - kind=OpKind.SetVLI, immediates=[len(spread_retval)], - name="retval_setvl") - concat_retval = fn.append_new_op( - kind=OpKind.Concat, input_vals=[*spread_retval, lhs_setvl.outputs[0]], - name="concat_retval", maxvl=len(spread_retval)) - return concat_retval.outputs[0] - - -def toom_cook_mul(fn, lhs, rhs, instances): - # type: (Fn, SSAVal, SSAVal, list[ToomCookInstance]) -> SSAVal - raise NotImplementedError + cast_prod, cast_retval.outputs[0], + clear_ca.outputs[0], part_setvl.outputs[0]], + maxvl=maxvl, name=f"prod_{part}_add") + spread = fn.append_new_op( + kind=OpKind.Spread, + input_vals=[add.outputs[0], part_setvl.outputs[0]], + name=f"prod_{part}_spread", maxvl=maxvl) + spread_retval[shift:] = spread.outputs + retval_signed |= prod_part.is_signed + while len(spread_retval) > retval_size: + spread_retval.pop() + assert len(spread_retval) == retval_size, "logic error" + retval_setvl = fn.append_new_op( + OpKind.SetVLI, immediates=[retval_size], name=f"prod_setvl", + maxvl=retval_size) + retval_concat = fn.append_new_op( + OpKind.Concat, input_vals=[*spread_retval, retval_setvl.outputs[0]], + name="prod_concat", maxvl=retval_size) + return retval_concat.outputs[0]