import unittest
-import bigint_presentation_code.toom_cook
+from bigint_presentation_code.toom_cook import ToomCookInstance
class TestToomCook(unittest.TestCase):
- pass # no tests yet, just testing importing
+ def test_toom_2(self):
+ TOOM_2 = ToomCookInstance.make_toom_2()
+ print(repr(repr(TOOM_2)))
+ self.assertEqual(
+ repr(TOOM_2),
+ "ToomCookInstance(lhs_part_count=2, rhs_part_count=2, "
+ "eval_points=(0, 1, POINT_AT_INFINITY), "
+ "lhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+ " rhs_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpAdd(lhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(1, 1), 1: Fraction(1, 1)})), "
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)}))),"
+ " prod_eval_ops=("
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "EvalOpSub(lhs="
+ "EvalOpSub(lhs="
+ "EvalOpInput(lhs=1, rhs=0, poly=EvalOpPoly({1: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=0, rhs=0, poly=EvalOpPoly({0: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({0: Fraction(-1, 1), 1: Fraction(1, 1)})), "
+ "rhs="
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)})), "
+ "poly=EvalOpPoly({"
+ "0: Fraction(-1, 1), 1: Fraction(1, 1), 2: Fraction(-1, 1)})), "
+ "EvalOpInput(lhs=2, rhs=0, poly=EvalOpPoly({2: Fraction(1, 1)}))))"
+ )
if __name__ == "__main__":
from abc import abstractmethod
from enum import Enum
from fractions import Fraction
-from typing import Any, Generic, Iterable, Sequence, TypeVar
+from typing import Any, Generic, Iterable, Mapping, Sequence, TypeVar, Union
from nmutil.plain_data import plain_data
from bigint_presentation_code.compiler_ir import Fn, Op
-from bigint_presentation_code.util import Literal, OFSet, OSet, final
+from bigint_presentation_code.matrix import Matrix
+from bigint_presentation_code.util import Literal, OSet, final
@final
POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
WORD_BITS = 64
+_EvalOpPolyCoefficients = Union["Mapping[int | None, Fraction | int]",
+ "EvalOpPoly", Fraction, int, None]
-@plain_data(frozen=True, unsafe_hash=True)
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class EvalOpPoly:
"""polynomial"""
- __slots__ = "coefficients",
-
- def __init__(self, coefficients=()):
- # type: (Iterable[Fraction | int] | EvalOpPoly | Fraction | int) -> None
- if isinstance(coefficients, EvalOpPoly):
- coefficients = coefficients.coefficients
- elif isinstance(coefficients, (int, Fraction)):
- coefficients = coefficients,
- v = list(map(Fraction, coefficients))
- while len(v) != 0 and v[-1] == 0:
- v.pop()
- self.coefficients = tuple(v) # type: tuple[Fraction, ...]
+ __slots__ = "const_coeff", "var_coeffs"
+
+ def __init__(
+ self, coeffs=None, # type: _EvalOpPolyCoefficients
+ const_coeff=None, # type: Fraction | int | None
+ var_coeffs=(), # type: Iterable[Fraction | int] | None
+ ):
+ if coeffs is not None:
+ if const_coeff is not None or var_coeffs != ():
+ raise ValueError(
+ "can't specify const_coeff or "
+ "var_coeffs along with coeffs")
+ if isinstance(coeffs, EvalOpPoly):
+ self.const_coeff = coeffs.const_coeff
+ self.var_coeffs = coeffs.var_coeffs
+ return
+ if isinstance(coeffs, (int, Fraction)):
+ const_coeff = Fraction(coeffs)
+ final_var_coeffs = [] # type: list[Fraction]
+ else:
+ const_coeff = 0
+ final_var_coeffs = []
+ for var, coeff in coeffs.items():
+ if coeff == 0:
+ continue
+ coeff = Fraction(coeff)
+ if var is None:
+ const_coeff = coeff
+ continue
+ if var < 0:
+ raise ValueError("invalid variable index")
+ if var >= len(final_var_coeffs):
+ additional = var - len(final_var_coeffs)
+ final_var_coeffs.extend((Fraction(),) * additional)
+ final_var_coeffs.append(coeff)
+ else:
+ final_var_coeffs[var] = coeff
+ else:
+ if var_coeffs is None:
+ final_var_coeffs = []
+ else:
+ final_var_coeffs = [Fraction(v) for v in var_coeffs]
+ while len(final_var_coeffs) > 0 and final_var_coeffs[-1] == 0:
+ final_var_coeffs.pop()
+ if const_coeff is None:
+ const_coeff = 0
+ self.const_coeff = Fraction(const_coeff)
+ self.var_coeffs = tuple(final_var_coeffs)
def __add__(self, rhs):
# type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
rhs = EvalOpPoly(rhs)
- retval = list(self.coefficients)
- extra = len(rhs.coefficients) - len(retval)
- if extra > 0:
- retval.extend([Fraction(0)] * extra)
- for i, v in enumerate(rhs.coefficients):
- retval[i] += v
- return EvalOpPoly(retval)
+ const_coeff = self.const_coeff + rhs.const_coeff
+ var_coeffs = list(self.var_coeffs)
+ if len(rhs.var_coeffs) > len(var_coeffs):
+ var_coeffs.extend(rhs.var_coeffs[len(var_coeffs):])
+ for var in range(min(len(self.var_coeffs), len(rhs.var_coeffs))):
+ var_coeffs[var] += rhs.var_coeffs[var]
+ return EvalOpPoly(const_coeff=const_coeff, var_coeffs=var_coeffs)
+
+ @property
+ def coefficients(self):
+ # type: () -> dict[int | None, Fraction]
+ retval = {} # type: dict[int | None, Fraction]
+ if self.const_coeff != 0:
+ retval[None] = self.const_coeff
+ for var, coeff in enumerate(self.var_coeffs):
+ if coeff != 0:
+ retval[var] = coeff
+ return retval
+
+ @property
+ def is_const(self):
+ # type: () -> bool
+ return self.var_coeffs == ()
+
+ def coeff(self, var):
+ # type: (int | None) -> Fraction
+ if var is None:
+ return self.const_coeff
+ if var < 0:
+ raise ValueError("invalid variable index")
+ if var < len(self.var_coeffs):
+ return self.var_coeffs[var]
+ return Fraction()
__radd__ = __add__
def __neg__(self):
- return EvalOpPoly(-v for v in self.coefficients)
+ return EvalOpPoly(const_coeff=-self.const_coeff,
+ var_coeffs=(-v for v in self.var_coeffs))
def __sub__(self, rhs):
# type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
return lhs + -self
def __mul__(self, rhs):
- # type: (int | Fraction) -> EvalOpPoly
- return EvalOpPoly(v * rhs for v in self.coefficients)
+ # type: (int | Fraction | EvalOpPoly) -> EvalOpPoly
+ if isinstance(rhs, EvalOpPoly):
+ if self.is_const:
+ self, rhs = rhs, self
+ if not rhs.is_const:
+ raise ValueError("can't represent exponents larger than one")
+ rhs = rhs.const_coeff
+ if rhs == 0:
+ return EvalOpPoly()
+ return EvalOpPoly(const_coeff=self.const_coeff * rhs,
+ var_coeffs=(i * rhs for i in self.var_coeffs))
__rmul__ = __mul__
# type: (int | Fraction) -> EvalOpPoly
if rhs == 0:
raise ZeroDivisionError()
- return EvalOpPoly(v / rhs for v in self.coefficients)
+ return EvalOpPoly(const_coeff=self.const_coeff / rhs,
+ var_coeffs=(i / rhs for i in self.var_coeffs))
+
+ def __repr__(self):
+ return f"EvalOpPoly({self.coefficients})"
_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
def __init__(self, lhs, rhs=0):
# type: (...) -> None
if lhs < 0:
- raise ValueError("Input split_index (lhs) must be >= 0")
+ raise ValueError("Input part_index (lhs) must be >= 0")
if rhs != 0:
raise ValueError("Input rhs must be 0")
super().__init__(lhs, rhs)
@property
- def split_index(self):
+ def part_index(self):
return self.lhs
def _make_poly(self):
# type: () -> EvalOpPoly
- return EvalOpPoly([0] * self.split_index + [1])
+ return EvalOpPoly({self.part_index: 1})
@plain_data(frozen=True, unsafe_hash=True)
@final
class ToomCookInstance:
- __slots__ = ("lhs_split_count", "rhs_split_count", "eval_points",
- "lhs_eval_ops", "rhs_eval_ops", "product_eval_ops")
+ __slots__ = ("lhs_part_count", "rhs_part_count", "eval_points",
+ "lhs_eval_ops", "rhs_eval_ops", "prod_eval_ops")
+
+ @property
+ def prod_part_count(self):
+ return self.lhs_part_count + self.rhs_part_count - 1
+
+ @staticmethod
+ def make_eval_matrix(width, eval_points):
+ # type: (int, tuple[PointAtInfinity | int, ...]) -> Matrix[Fraction]
+ retval = Matrix(height=len(eval_points), width=width)
+ for row, col in retval.indexes():
+ eval_point = eval_points[row]
+ if eval_point is POINT_AT_INFINITY:
+ retval[row, col] = int(col == width - 1)
+ else:
+ retval[row, col] = eval_point ** col
+ return retval
+
+ def get_lhs_eval_matrix(self):
+ # type: () -> Matrix[Fraction]
+ return self.make_eval_matrix(self.lhs_part_count, self.eval_points)
+
+ @staticmethod
+ def make_input_poly_vector(height):
+ # type: (int) -> Matrix[EvalOpPoly]
+ return Matrix(height=height, width=1, element_type=EvalOpPoly,
+ data=(EvalOpPoly({i: 1}) for i in range(height)))
+
+ def get_lhs_eval_polys(self):
+ # type: () -> list[EvalOpPoly]
+ return list(self.get_lhs_eval_matrix().cast(EvalOpPoly)
+ @ self.make_input_poly_vector(self.lhs_part_count))
+
+ def get_rhs_eval_matrix(self):
+ # type: () -> Matrix[Fraction]
+ return self.make_eval_matrix(self.rhs_part_count, self.eval_points)
+
+ def get_rhs_eval_polys(self):
+ # type: () -> list[EvalOpPoly]
+ return list(self.get_rhs_eval_matrix().cast(EvalOpPoly)
+ @ self.make_input_poly_vector(self.rhs_part_count))
+
+ def get_prod_inverse_eval_matrix(self):
+ # type: () -> Matrix[Fraction]
+ return self.make_eval_matrix(self.prod_part_count, self.eval_points)
+
+ def get_prod_eval_matrix(self):
+ # type: () -> Matrix[Fraction]
+ return self.get_prod_inverse_eval_matrix().inverse()
+
+ def get_prod_eval_polys(self):
+ # type: () -> list[EvalOpPoly]
+ return list(self.get_prod_eval_matrix().cast(EvalOpPoly)
+ @ self.make_input_poly_vector(self.prod_part_count))
def __init__(
- self, lhs_split_count, # type: int
- rhs_split_count, # type: int
+ self, lhs_part_count, # type: int
+ rhs_part_count, # type: int
eval_points, # type: Iterable[PointAtInfinity | int]
lhs_eval_ops, # type: Iterable[EvalOp[Any, Any]]
rhs_eval_ops, # type: Iterable[EvalOp[Any, Any]]
- product_eval_ops, # type: Iterable[EvalOp[Any, Any]]
+ prod_eval_ops, # type: Iterable[EvalOp[Any, Any]]
):
# type: (...) -> None
- self.lhs_split_count = lhs_split_count
- if self.lhs_split_count < 2:
- raise ValueError("lhs_split_count must be at least 2")
- self.rhs_split_count = rhs_split_count
- if self.rhs_split_count < 2:
- raise ValueError("rhs_split_count must be at least 2")
+ self.lhs_part_count = lhs_part_count
+ if self.lhs_part_count < 2:
+ raise ValueError("lhs_part_count must be at least 2")
+ self.rhs_part_count = rhs_part_count
+ if self.rhs_part_count < 2:
+ raise ValueError("rhs_part_count must be at least 2")
eval_points = list(eval_points)
- self.eval_points = OFSet(eval_points)
- if len(self.eval_points) != len(eval_points):
+ self.eval_points = tuple(eval_points)
+ if len(self.eval_points) != len(set(self.eval_points)):
raise ValueError("duplicate eval points")
self.lhs_eval_ops = tuple(lhs_eval_ops)
- if len(self.lhs_eval_ops) != len(self.eval_points):
+ if len(self.lhs_eval_ops) != self.prod_part_count:
raise ValueError("wrong number of lhs_eval_ops")
self.rhs_eval_ops = tuple(rhs_eval_ops)
- if len(self.rhs_eval_ops) != len(self.eval_points):
+ if len(self.rhs_eval_ops) != self.prod_part_count:
raise ValueError("wrong number of rhs_eval_ops")
- if self.lhs_split_count < 2:
- raise ValueError("lhs_split_count must be at least 2")
- if self.rhs_split_count < 2:
- raise ValueError("rhs_split_count must be at least 2")
- if (self.lhs_split_count + self.rhs_split_count - 1
- != len(self.eval_points)):
+ if len(self.eval_points) != self.prod_part_count:
raise ValueError("wrong number of eval_points")
- self.product_eval_ops = tuple(product_eval_ops)
- if len(self.product_eval_ops) != len(self.eval_points):
- raise ValueError("wrong number of product_eval_ops")
- # TODO: compute and check matrix and all the *_eval_ops
- raise NotImplementedError
+ self.prod_eval_ops = tuple(prod_eval_ops)
+ if len(self.prod_eval_ops) != self.prod_part_count:
+ raise ValueError("wrong number of prod_eval_ops")
+
+ lhs_eval_polys = self.get_lhs_eval_polys()
+ for i, eval_op in enumerate(self.lhs_eval_ops):
+ if lhs_eval_polys[i] != eval_op.poly:
+ raise ValueError(
+ f"lhs_eval_ops[{i}] is incorrect: expected polynomial: "
+ f"{lhs_eval_polys[i]} found polynomial: {eval_op.poly}")
+
+ rhs_eval_polys = self.get_rhs_eval_polys()
+ for i, eval_op in enumerate(self.rhs_eval_ops):
+ if rhs_eval_polys[i] != eval_op.poly:
+ raise ValueError(
+ f"rhs_eval_ops[{i}] is incorrect: expected polynomial: "
+ f"{rhs_eval_polys[i]} found polynomial: {eval_op.poly}")
+
+ prod_eval_polys = self.get_prod_eval_polys() # also checks matrix
+ for i, eval_op in enumerate(self.prod_eval_ops):
+ if prod_eval_polys[i] != eval_op.poly:
+ raise ValueError(
+ f"prod_eval_ops[{i}] is incorrect: expected polynomial: "
+ f"{prod_eval_polys[i]} found polynomial: {eval_op.poly}")
+
+ @staticmethod
+ def make_toom_2():
+ # type: () -> ToomCookInstance
+ return ToomCookInstance(
+ lhs_part_count=2,
+ rhs_part_count=2,
+ eval_points=[0, 1, POINT_AT_INFINITY],
+ lhs_eval_ops=[
+ EvalOpInput(0),
+ EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
+ EvalOpInput(1),
+ ],
+ rhs_eval_ops=[
+ EvalOpInput(0),
+ EvalOpAdd(EvalOpInput(0), EvalOpInput(1)),
+ EvalOpInput(1),
+ ],
+ prod_eval_ops=[
+ EvalOpInput(0),
+ EvalOpSub(EvalOpSub(EvalOpInput(1), EvalOpInput(0)),
+ EvalOpInput(2)),
+ EvalOpInput(2),
+ ],
+ )
def toom_cook_mul(fn, word_count, instances):