From 2f0d5e605815a0acf57caf3c7451569f14fc22a9 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 19 Oct 2022 23:46:02 -0700 Subject: [PATCH] ToomCookInstance works! --- src/bigint_presentation_code/matrix.py | 10 +- .../test_toom_cook.py | 40 ++- src/bigint_presentation_code/toom_cook.py | 282 ++++++++++++++---- 3 files changed, 273 insertions(+), 59 deletions(-) diff --git a/src/bigint_presentation_code/matrix.py b/src/bigint_presentation_code/matrix.py index 49acddf..89c3ea2 100644 --- a/src/bigint_presentation_code/matrix.py +++ b/src/bigint_presentation_code/matrix.py @@ -2,11 +2,12 @@ import operator from enum import Enum, unique from fractions import Fraction from numbers import Rational -from typing import Callable, Generic, Iterable, Iterator, Type, TypeVar +from typing import Any, Callable, Generic, Iterable, Iterator, Type, TypeVar from bigint_presentation_code.util import final _T = TypeVar("_T") +_T2 = TypeVar("_T2") @final @@ -37,7 +38,7 @@ class Matrix(Generic[_T]): def __init__(self, height, width, data=SpecialMatrix.Zero, element_type=Fraction): - # type: (int, int, Iterable[_T | int] | SpecialMatrix, Type[_T]) -> None + # type: (int, int, Iterable[_T | int | Any] | SpecialMatrix, Type[_T]) -> None if width < 0 or height < 0: raise ValueError("matrix size must be non-negative") self.__height = height @@ -55,6 +56,11 @@ class Matrix(Generic[_T]): if len(self.__data) != height * width: raise ValueError("data has wrong length") + def cast(self, element_type): + # type: (Type[_T2]) -> Matrix[_T2] + data = self # type: Iterable[Any] + return Matrix(self.height, self.width, data, element_type=element_type) + def __idx(self, row, col): # type: (int, int) -> int if 0 <= col < self.width and 0 <= row < self.height: diff --git a/src/bigint_presentation_code/test_toom_cook.py b/src/bigint_presentation_code/test_toom_cook.py index 8fe6cea..f880a81 100644 --- a/src/bigint_presentation_code/test_toom_cook.py +++ b/src/bigint_presentation_code/test_toom_cook.py @@ -1,10 +1,46 @@ import unittest -import bigint_presentation_code.toom_cook +from bigint_presentation_code.toom_cook import ToomCookInstance class TestToomCook(unittest.TestCase): - pass # no tests yet, just testing importing + def test_toom_2(self): + TOOM_2 = ToomCookInstance.make_toom_2() + print(repr(repr(TOOM_2))) + self.assertEqual( + repr(TOOM_2), + "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, " + "eval_points=(0, 1, POINT_AT_INFINITY), " + "lhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," + " rhs_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpAdd(lhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), " + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})))," + " prod_eval_ops=(" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "EvalOpSub(lhs=" + "EvalOpSub(lhs=" + "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), " + "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), " + "rhs=" + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), " + "poly=EvalOpPoly({" + "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), " + "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))" + ) if __name__ == "__main__": diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index 865015e..76a8a99 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -4,12 +4,13 @@ Toom-Cook multiplication algorithm generator for SVP64 from abc import abstractmethod from enum import Enum from fractions import Fraction -from typing import Any, Generic, Iterable, Sequence, TypeVar +from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union from nmutil.plain_data import plain_data from bigint_presentation_code.compiler_ir import Fn, Op -from bigint_presentation_code.util import Literal, OFSet, OSet, final +from bigint_presentation_code.matrix import Matrix +from bigint_presentation_code.util import Literal, OSet, final @final @@ -23,39 +24,105 @@ class PointAtInfinity(Enum): POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY WORD_BITS = 64 +_EvalOpPolyCoefficients = Union["Mapping[int | None, Fraction | int]", + "EvalOpPoly", Fraction, int, None] -@plain_data(frozen=True, unsafe_hash=True) + +@plain_data(frozen=True, unsafe_hash=True, repr=False) @final class EvalOpPoly: """polynomial""" - __slots__ = "coefficients", - - def __init__(self, coefficients=()): - # type: (Iterable[Fraction | int] | EvalOpPoly | Fraction | int) -> None - if isinstance(coefficients, EvalOpPoly): - coefficients = coefficients.coefficients - elif isinstance(coefficients, (int, Fraction)): - coefficients = coefficients, - v = list(map(Fraction, coefficients)) - while len(v) != 0 and v[-1] == 0: - v.pop() - self.coefficients = tuple(v) # type: tuple[Fraction, ...] + __slots__ = "const_coeff", "var_coeffs" + + def __init__( + self, coeffs=None, # type: _EvalOpPolyCoefficients + const_coeff=None, # type: Fraction | int | None + var_coeffs=(), # type: Iterable[Fraction | int] | None + ): + if coeffs is not None: + if const_coeff is not None or var_coeffs != (): + raise ValueError( + "can't specify const_coeff or " + "var_coeffs along with coeffs") + if isinstance(coeffs, EvalOpPoly): + self.const_coeff = coeffs.const_coeff + self.var_coeffs = coeffs.var_coeffs + return + if isinstance(coeffs, (int, Fraction)): + const_coeff = Fraction(coeffs) + final_var_coeffs = [] # type: list[Fraction] + else: + const_coeff = 0 + final_var_coeffs = [] + for var, coeff in coeffs.items(): + if coeff == 0: + continue + coeff = Fraction(coeff) + if var is None: + const_coeff = coeff + continue + if var < 0: + raise ValueError("invalid variable index") + if var >= len(final_var_coeffs): + additional = var - len(final_var_coeffs) + final_var_coeffs.extend((Fraction(),) * additional) + final_var_coeffs.append(coeff) + else: + final_var_coeffs[var] = coeff + else: + if var_coeffs is None: + final_var_coeffs = [] + else: + final_var_coeffs = [Fraction(v) for v in var_coeffs] + while len(final_var_coeffs) > 0 and final_var_coeffs[-1] == 0: + final_var_coeffs.pop() + if const_coeff is None: + const_coeff = 0 + self.const_coeff = Fraction(const_coeff) + self.var_coeffs = tuple(final_var_coeffs) def __add__(self, rhs): # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly rhs = EvalOpPoly(rhs) - retval = list(self.coefficients) - extra = len(rhs.coefficients) - len(retval) - if extra > 0: - retval.extend([Fraction(0)] * extra) - for i, v in enumerate(rhs.coefficients): - retval[i] += v - return EvalOpPoly(retval) + const_coeff = self.const_coeff + rhs.const_coeff + var_coeffs = list(self.var_coeffs) + if len(rhs.var_coeffs) > len(var_coeffs): + var_coeffs.extend(rhs.var_coeffs[len(var_coeffs):]) + for var in range(min(len(self.var_coeffs), len(rhs.var_coeffs))): + var_coeffs[var] += rhs.var_coeffs[var] + return EvalOpPoly(const_coeff=const_coeff, var_coeffs=var_coeffs) + + @property + def coefficients(self): + # type: () -> dict[int | None, Fraction] + retval = {} # type: dict[int | None, Fraction] + if self.const_coeff != 0: + retval[None] = self.const_coeff + for var, coeff in enumerate(self.var_coeffs): + if coeff != 0: + retval[var] = coeff + return retval + + @property + def is_const(self): + # type: () -> bool + return self.var_coeffs == () + + def coeff(self, var): + # type: (int | None) -> Fraction + if var is None: + return self.const_coeff + if var < 0: + raise ValueError("invalid variable index") + if var < len(self.var_coeffs): + return self.var_coeffs[var] + return Fraction() __radd__ = __add__ def __neg__(self): - return EvalOpPoly(-v for v in self.coefficients) + return EvalOpPoly(const_coeff=-self.const_coeff, + var_coeffs=(-v for v in self.var_coeffs)) def __sub__(self, rhs): # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly @@ -66,8 +133,17 @@ class EvalOpPoly: return lhs + -self def __mul__(self, rhs): - # type: (int | Fraction) -> EvalOpPoly - return EvalOpPoly(v * rhs for v in self.coefficients) + # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly + if isinstance(rhs, EvalOpPoly): + if self.is_const: + self, rhs = rhs, self + if not rhs.is_const: + raise ValueError("can't represent exponents larger than one") + rhs = rhs.const_coeff + if rhs == 0: + return EvalOpPoly() + return EvalOpPoly(const_coeff=self.const_coeff * rhs, + var_coeffs=(i * rhs for i in self.var_coeffs)) __rmul__ = __mul__ @@ -75,7 +151,11 @@ class EvalOpPoly: # type: (int | Fraction) -> EvalOpPoly if rhs == 0: raise ZeroDivisionError() - return EvalOpPoly(v / rhs for v in self.coefficients) + return EvalOpPoly(const_coeff=self.const_coeff / rhs, + var_coeffs=(i / rhs for i in self.var_coeffs)) + + def __repr__(self): + return f"EvalOpPoly({self.coefficients})" _EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp") @@ -160,63 +240,155 @@ class EvalOpInput(EvalOp[int, Literal[0]]): def __init__(self, lhs, rhs=0): # type: (...) -> None if lhs < 0: - raise ValueError("Input split_index (lhs) must be >= 0") + raise ValueError("Input part_index (lhs) must be >= 0") if rhs != 0: raise ValueError("Input rhs must be 0") super().__init__(lhs, rhs) @property - def split_index(self): + def part_index(self): return self.lhs def _make_poly(self): # type: () -> EvalOpPoly - return EvalOpPoly([0] * self.split_index + [1]) + return EvalOpPoly({self.part_index: 1}) @plain_data(frozen=True, unsafe_hash=True) @final class ToomCookInstance: - __slots__ = ("lhs_split_count", "rhs_split_count", "eval_points", - "lhs_eval_ops", "rhs_eval_ops", "product_eval_ops") + __slots__ = ("lhs_part_count", "rhs_part_count", "eval_points", + "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops") + + @property + def prod_part_count(self): + return self.lhs_part_count + self.rhs_part_count - 1 + + @staticmethod + def make_eval_matrix(width, eval_points): + # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction] + retval = Matrix(height=len(eval_points), width=width) + for row, col in retval.indexes(): + eval_point = eval_points[row] + if eval_point is POINT_AT_INFINITY: + retval[row, col] = int(col == width - 1) + else: + retval[row, col] = eval_point ** col + return retval + + def get_lhs_eval_matrix(self): + # type: () -> Matrix[Fraction] + return self.make_eval_matrix(self.lhs_part_count, self.eval_points) + + @staticmethod + def make_input_poly_vector(height): + # type: (int) -> Matrix[EvalOpPoly] + return Matrix(height=height, width=1, element_type=EvalOpPoly, + data=(EvalOpPoly({i: 1}) for i in range(height))) + + def get_lhs_eval_polys(self): + # type: () -> list[EvalOpPoly] + return list(self.get_lhs_eval_matrix().cast(EvalOpPoly) + @ self.make_input_poly_vector(self.lhs_part_count)) + + def get_rhs_eval_matrix(self): + # type: () -> Matrix[Fraction] + return self.make_eval_matrix(self.rhs_part_count, self.eval_points) + + def get_rhs_eval_polys(self): + # type: () -> list[EvalOpPoly] + return list(self.get_rhs_eval_matrix().cast(EvalOpPoly) + @ self.make_input_poly_vector(self.rhs_part_count)) + + def get_prod_inverse_eval_matrix(self): + # type: () -> Matrix[Fraction] + return self.make_eval_matrix(self.prod_part_count, self.eval_points) + + def get_prod_eval_matrix(self): + # type: () -> Matrix[Fraction] + return self.get_prod_inverse_eval_matrix().inverse() + + def get_prod_eval_polys(self): + # type: () -> list[EvalOpPoly] + return list(self.get_prod_eval_matrix().cast(EvalOpPoly) + @ self.make_input_poly_vector(self.prod_part_count)) def __init__( - self, lhs_split_count, # type: int - rhs_split_count, # type: int + self, lhs_part_count, # type: int + rhs_part_count, # type: int eval_points, # type: Iterable[PointAtInfinity | int] lhs_eval_ops, # type: Iterable[EvalOp[Any, Any]] rhs_eval_ops, # type: Iterable[EvalOp[Any, Any]] - product_eval_ops, # type: Iterable[EvalOp[Any, Any]] + prod_eval_ops, # type: Iterable[EvalOp[Any, Any]] ): # type: (...) -> None - self.lhs_split_count = lhs_split_count - if self.lhs_split_count < 2: - raise ValueError("lhs_split_count must be at least 2") - self.rhs_split_count = rhs_split_count - if self.rhs_split_count < 2: - raise ValueError("rhs_split_count must be at least 2") + self.lhs_part_count = lhs_part_count + if self.lhs_part_count < 2: + raise ValueError("lhs_part_count must be at least 2") + self.rhs_part_count = rhs_part_count + if self.rhs_part_count < 2: + raise ValueError("rhs_part_count must be at least 2") eval_points = list(eval_points) - self.eval_points = OFSet(eval_points) - if len(self.eval_points) != len(eval_points): + self.eval_points = tuple(eval_points) + if len(self.eval_points) != len(set(self.eval_points)): raise ValueError("duplicate eval points") self.lhs_eval_ops = tuple(lhs_eval_ops) - if len(self.lhs_eval_ops) != len(self.eval_points): + if len(self.lhs_eval_ops) != self.prod_part_count: raise ValueError("wrong number of lhs_eval_ops") self.rhs_eval_ops = tuple(rhs_eval_ops) - if len(self.rhs_eval_ops) != len(self.eval_points): + if len(self.rhs_eval_ops) != self.prod_part_count: raise ValueError("wrong number of rhs_eval_ops") - if self.lhs_split_count < 2: - raise ValueError("lhs_split_count must be at least 2") - if self.rhs_split_count < 2: - raise ValueError("rhs_split_count must be at least 2") - if (self.lhs_split_count + self.rhs_split_count - 1 - != len(self.eval_points)): + if len(self.eval_points) != self.prod_part_count: raise ValueError("wrong number of eval_points") - self.product_eval_ops = tuple(product_eval_ops) - if len(self.product_eval_ops) != len(self.eval_points): - raise ValueError("wrong number of product_eval_ops") - # TODO: compute and check matrix and all the *_eval_ops - raise NotImplementedError + self.prod_eval_ops = tuple(prod_eval_ops) + if len(self.prod_eval_ops) != self.prod_part_count: + raise ValueError("wrong number of prod_eval_ops") + + lhs_eval_polys = self.get_lhs_eval_polys() + for i, eval_op in enumerate(self.lhs_eval_ops): + if lhs_eval_polys[i] != eval_op.poly: + raise ValueError( + f"lhs_eval_ops[{i}] is incorrect: expected polynomial: " + f"{lhs_eval_polys[i]} found polynomial: {eval_op.poly}") + + rhs_eval_polys = self.get_rhs_eval_polys() + for i, eval_op in enumerate(self.rhs_eval_ops): + if rhs_eval_polys[i] != eval_op.poly: + raise ValueError( + f"rhs_eval_ops[{i}] is incorrect: expected polynomial: " + f"{rhs_eval_polys[i]} found polynomial: {eval_op.poly}") + + prod_eval_polys = self.get_prod_eval_polys() # also checks matrix + for i, eval_op in enumerate(self.prod_eval_ops): + if prod_eval_polys[i] != eval_op.poly: + raise ValueError( + f"prod_eval_ops[{i}] is incorrect: expected polynomial: " + f"{prod_eval_polys[i]} found polynomial: {eval_op.poly}") + + @staticmethod + def make_toom_2(): + # type: () -> ToomCookInstance + return ToomCookInstance( + lhs_part_count=2, + rhs_part_count=2, + eval_points=[0, 1, POINT_AT_INFINITY], + lhs_eval_ops=[ + EvalOpInput(0), + EvalOpAdd(EvalOpInput(0), EvalOpInput(1)), + EvalOpInput(1), + ], + rhs_eval_ops=[ + EvalOpInput(0), + EvalOpAdd(EvalOpInput(0), EvalOpInput(1)), + EvalOpInput(1), + ], + prod_eval_ops=[ + EvalOpInput(0), + EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)), + EvalOpInput(2)), + EvalOpInput(2), + ], + ) def toom_cook_mul(fn, word_count, instances): -- 2.30.2