from collections import defaultdict
from enum import Enum, EnumMeta, unique
from functools import lru_cache
-from typing import (TYPE_CHECKING, Any, Generic, Iterable, Sequence, Type,
- TypeVar, cast)
+from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
from nmutil.plain_data import fields, plain_data
-from bigint_presentation_code.ordered_set import OFSet, OSet
-
-if TYPE_CHECKING:
- from typing_extensions import final
-else:
- def final(v):
- return v
+from bigint_presentation_code.util import OFSet, OSet, final
class ABCEnumMeta(EnumMeta, ABCMeta):
+++ /dev/null
-from typing import AbstractSet, Iterable, MutableSet, TypeVar
-
-_T_co = TypeVar("_T_co", covariant=True)
-_T = TypeVar("_T")
-
-
-class OFSet(AbstractSet[_T_co]):
- """ ordered frozen set """
-
- def __init__(self, items=()):
- # type: (Iterable[_T_co]) -> None
- self.__items = {v: None for v in items}
-
- def __contains__(self, x):
- return x in self.__items
-
- def __iter__(self):
- return iter(self.__items)
-
- def __len__(self):
- return len(self.__items)
-
- def __hash__(self):
- return self._hash()
-
- def __repr__(self):
- if len(self) == 0:
- return "OFSet()"
- return f"OFSet({list(self)})"
-
-
-class OSet(MutableSet[_T]):
- """ ordered mutable set """
-
- def __init__(self, items=()):
- # type: (Iterable[_T]) -> None
- self.__items = {v: None for v in items}
-
- def __contains__(self, x):
- return x in self.__items
-
- def __iter__(self):
- return iter(self.__items)
-
- def __len__(self):
- return len(self.__items)
-
- def add(self, value):
- # type: (_T) -> None
- self.__items[value] = None
-
- def discard(self, value):
- # type: (_T) -> None
- self.__items.pop(value, None)
-
- def __repr__(self):
- if len(self) == 0:
- return "OSet()"
- return f"OSet({list(self)})"
"""
from itertools import combinations
-from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar
+from typing import Generic, Iterable, Mapping, TypeVar
from nmutil.plain_data import plain_data
from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
RegLoc, RegType, SSAVal)
-from bigint_presentation_code.ordered_set import OFSet, OSet
-
-if TYPE_CHECKING:
- from typing_extensions import final
-else:
- def final(v):
- return v
-
+from bigint_presentation_code.util import OFSet, OSet, final
_RegType = TypeVar("_RegType", bound=RegType)
"""
-Toom-Cook algorithm generator for SVP64
-
-the register allocator uses an algorithm based on:
-[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+Toom-Cook multiplication algorithm generator for SVP64
"""
-from bigint_presentation_code.compiler_ir import Op
-from bigint_presentation_code.register_allocator import allocate_registers, AllocationFailed
+from abc import abstractmethod
+from enum import Enum
+from fractions import Fraction
+from typing import Any, Generic, Iterable, Sequence, TypeVar
+
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir import Fn, Op
+from bigint_presentation_code.util import Literal, OFSet, OSet, final
+
+
+@final
+class PointAtInfinity(Enum):
+ POINT_AT_INFINITY = "POINT_AT_INFINITY"
+
+ def __repr__(self):
+ return self.name
+
+
+POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY
+WORD_BITS = 64
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpPoly:
+ """polynomial"""
+ __slots__ = "coefficients",
+
+ def __init__(self, coefficients=()):
+ # type: (Iterable[Fraction | int] | EvalOpPoly | Fraction | int) -> None
+ if isinstance(coefficients, EvalOpPoly):
+ coefficients = coefficients.coefficients
+ elif isinstance(coefficients, (int, Fraction)):
+ coefficients = coefficients,
+ v = list(map(Fraction, coefficients))
+ while len(v) != 0 and v[-1] == 0:
+ v.pop()
+ self.coefficients = tuple(v) # type: tuple[Fraction, ...]
+
+ def __add__(self, rhs):
+ # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+ rhs = EvalOpPoly(rhs)
+ retval = list(self.coefficients)
+ extra = len(rhs.coefficients) - len(retval)
+ if extra > 0:
+ retval.extend([Fraction(0)] * extra)
+ for i, v in enumerate(rhs.coefficients):
+ retval[i] += v
+ return EvalOpPoly(retval)
+
+ __radd__ = __add__
+
+ def __neg__(self):
+ return EvalOpPoly(-v for v in self.coefficients)
+
+ def __sub__(self, rhs):
+ # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+ return self + -rhs
+
+ def __rsub__(self, lhs):
+ # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly
+ return lhs + -self
+
+ def __mul__(self, rhs):
+ # type: (int | Fraction) -> EvalOpPoly
+ return EvalOpPoly(v * rhs for v in self.coefficients)
+
+ __rmul__ = __mul__
+
+ def __truediv__(self, rhs):
+ # type: (int | Fraction) -> EvalOpPoly
+ if rhs == 0:
+ raise ZeroDivisionError()
+ return EvalOpPoly(v / rhs for v in self.coefficients)
+
+
+_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp")
+_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp")
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]):
+ __slots__ = "lhs", "rhs", "poly"
+
+ @property
+ def lhs_poly(self):
+ # type: () -> EvalOpPoly
+ if isinstance(self.lhs, int):
+ return EvalOpPoly(self.lhs)
+ return self.lhs.poly
+
+ @property
+ def rhs_poly(self):
+ # type: () -> EvalOpPoly
+ if isinstance(self.rhs, int):
+ return EvalOpPoly(self.rhs)
+ return self.rhs.poly
+
+ @abstractmethod
+ def _make_poly(self):
+ # type: () -> EvalOpPoly
+ ...
+
+ def __init__(self, lhs, rhs):
+ # type: (_EvalOpLHS, _EvalOpRHS) -> None
+ self.lhs = lhs
+ self.rhs = rhs
+ self.poly = self._make_poly()
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+ __slots__ = ()
+
+ def _make_poly(self):
+ # type: () -> EvalOpPoly
+ return self.lhs_poly + self.rhs_poly
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]):
+ __slots__ = ()
+
+ def _make_poly(self):
+ # type: () -> EvalOpPoly
+ return self.lhs_poly - self.rhs_poly
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpMul(EvalOp[_EvalOpLHS, int]):
+ __slots__ = ()
+
+ def _make_poly(self):
+ # type: () -> EvalOpPoly
+ return self.lhs_poly * self.rhs
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]):
+ __slots__ = ()
+
+ def _make_poly(self):
+ # type: () -> EvalOpPoly
+ return self.lhs_poly / self.rhs
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class EvalOpInput(EvalOp[int, Literal[0]]):
+ __slots__ = ()
+
+ def __init__(self, lhs, rhs=0):
+ # type: (...) -> None
+ if lhs < 0:
+ raise ValueError("Input split_index (lhs) must be >= 0")
+ if rhs != 0:
+ raise ValueError("Input rhs must be 0")
+ super().__init__(lhs, rhs)
+
+ @property
+ def split_index(self):
+ return self.lhs
+
+ def _make_poly(self):
+ # type: () -> EvalOpPoly
+ return EvalOpPoly([0] * self.split_index + [1])
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class ToomCookInstance:
+ __slots__ = ("lhs_split_count", "rhs_split_count", "eval_points",
+ "lhs_eval_ops", "rhs_eval_ops", "product_eval_ops")
+
+ def __init__(
+ self, lhs_split_count, # type: int
+ rhs_split_count, # type: int
+ eval_points, # type: Iterable[PointAtInfinity | int]
+ lhs_eval_ops, # type: Iterable[EvalOp[Any, Any]]
+ rhs_eval_ops, # type: Iterable[EvalOp[Any, Any]]
+ product_eval_ops, # type: Iterable[EvalOp[Any, Any]]
+ ):
+ # type: (...) -> None
+ self.lhs_split_count = lhs_split_count
+ if self.lhs_split_count < 2:
+ raise ValueError("lhs_split_count must be at least 2")
+ self.rhs_split_count = rhs_split_count
+ if self.rhs_split_count < 2:
+ raise ValueError("rhs_split_count must be at least 2")
+ eval_points = list(eval_points)
+ self.eval_points = OFSet(eval_points)
+ if len(self.eval_points) != len(eval_points):
+ raise ValueError("duplicate eval points")
+ self.lhs_eval_ops = tuple(lhs_eval_ops)
+ if len(self.lhs_eval_ops) != len(self.eval_points):
+ raise ValueError("wrong number of lhs_eval_ops")
+ self.rhs_eval_ops = tuple(rhs_eval_ops)
+ if len(self.rhs_eval_ops) != len(self.eval_points):
+ raise ValueError("wrong number of rhs_eval_ops")
+ if self.lhs_split_count < 2:
+ raise ValueError("lhs_split_count must be at least 2")
+ if self.rhs_split_count < 2:
+ raise ValueError("rhs_split_count must be at least 2")
+ if (self.lhs_split_count + self.rhs_split_count - 1
+ != len(self.eval_points)):
+ raise ValueError("wrong number of eval_points")
+ self.product_eval_ops = tuple(product_eval_ops)
+ if len(self.product_eval_ops) != len(self.eval_points):
+ raise ValueError("wrong number of product_eval_ops")
+ # TODO: compute and check matrix and all the *_eval_ops
+ raise NotImplementedError
+
+
+def toom_cook_mul(fn, word_count, instances):
+ # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op]
+ retval = OSet() # type: OSet[Op]
+ raise NotImplementedError
+ return retval
--- /dev/null
+from typing import (TYPE_CHECKING, AbstractSet, Iterable, Iterator, Mapping,
+ MutableSet, TypeVar, Union)
+
+if TYPE_CHECKING:
+ from typing_extensions import Literal, final
+else:
+ def final(v):
+ return v
+
+ class _Literal:
+ def __getitem__(self, v):
+ if isinstance(v, tuple):
+ return Union[tuple(type(i) for i in v)]
+ return type(v)
+
+ Literal = _Literal()
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+__all__ = ["final", "Literal", "OFSet", "OSet", "FMap"]
+
+
+class OFSet(AbstractSet[_T_co]):
+ """ ordered frozen set """
+ __slots__ = "__items",
+
+ def __init__(self, items=()):
+ # type: (Iterable[_T_co]) -> None
+ self.__items = {v: None for v in items}
+
+ def __contains__(self, x):
+ return x in self.__items
+
+ def __iter__(self):
+ return iter(self.__items)
+
+ def __len__(self):
+ return len(self.__items)
+
+ def __hash__(self):
+ return self._hash()
+
+ def __repr__(self):
+ if len(self) == 0:
+ return "OFSet()"
+ return f"OFSet({list(self)})"
+
+
+class OSet(MutableSet[_T]):
+ """ ordered mutable set """
+ __slots__ = "__items",
+
+ def __init__(self, items=()):
+ # type: (Iterable[_T]) -> None
+ self.__items = {v: None for v in items}
+
+ def __contains__(self, x):
+ return x in self.__items
+
+ def __iter__(self):
+ return iter(self.__items)
+
+ def __len__(self):
+ return len(self.__items)
+
+ def add(self, value):
+ # type: (_T) -> None
+ self.__items[value] = None
+
+ def discard(self, value):
+ # type: (_T) -> None
+ self.__items.pop(value, None)
+
+ def __repr__(self):
+ if len(self) == 0:
+ return "OSet()"
+ return f"OSet({list(self)})"
+
+
+class FMap(Mapping[_T, _T_co]):
+ """ordered frozen hashable mapping"""
+ __slots__ = "__items", "__hash"
+
+ def __init__(self, items=()):
+ # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
+ self.__items = dict(items) # type: dict[_T, _T_co]
+ self.__hash = None # type: None | int
+
+ def __getitem__(self, item):
+ # type: (_T) -> _T_co
+ return self.__items[item]
+
+ def __iter__(self):
+ # type: () -> Iterator[_T]
+ return iter(self.__items)
+
+ def __len__(self):
+ return len(self.__items)
+
+ def __eq__(self, other):
+ # type: (object) -> bool
+ if isinstance(other, FMap):
+ return self.__items == other.__items
+ return super().__eq__(other)
+
+ def __hash__(self):
+ if self.__hash is None:
+ self.__hash = hash(frozenset(self.items()))
+ return self.__hash
+
+ def __repr__(self):
+ return f"FMap({self.__items})"
--- /dev/null
+from typing import (AbstractSet, Iterable, Iterator, Mapping,
+ MutableSet, TypeVar, overload)
+from typing_extensions import final, Literal
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+__all__ = ["final", "Literal", "OFSet", "OSet", "FMap"]
+
+
+class OFSet(AbstractSet[_T_co]):
+ """ ordered frozen set """
+
+ def __init__(self, items: Iterable[_T_co] = ()):
+ ...
+
+ def __contains__(self, x: object) -> bool:
+ ...
+
+ def __iter__(self) -> Iterator[_T_co]:
+ ...
+
+ def __len__(self) -> int:
+ ...
+
+ def __hash__(self) -> int:
+ ...
+
+ def __repr__(self) -> str:
+ ...
+
+
+class OSet(MutableSet[_T]):
+ """ ordered mutable set """
+
+ def __init__(self, items: Iterable[_T] = ()):
+ ...
+
+ def __contains__(self, x: object) -> bool:
+ ...
+
+ def __iter__(self) -> Iterator[_T]:
+ ...
+
+ def __len__(self) -> int:
+ ...
+
+ def add(self, value: _T) -> None:
+ ...
+
+ def discard(self, value: _T) -> None:
+ ...
+
+ def __repr__(self) -> str:
+ ...
+
+
+class FMap(Mapping[_T, _T_co]):
+ """ordered frozen hashable mapping"""
+ @overload
+ def __init__(self, items: Mapping[_T, _T_co] = ...): ...
+ @overload
+ def __init__(self, items: Iterable[tuple[_T, _T_co]] = ...): ...
+
+ def __getitem__(self, item: _T) -> _T_co:
+ ...
+
+ def __iter__(self) -> Iterator[_T]:
+ ...
+
+ def __len__(self) -> int:
+ ...
+
+ def __eq__(self, other: object) -> bool:
+ ...
+
+ def __hash__(self) -> int:
+ ...
+
+ def __repr__(self) -> str:
+ ...