From: Jacob Lifshay Date: Wed, 19 Oct 2022 09:01:16 +0000 (-0700) Subject: working on toom-cook multiplication X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=82a992e04903f76dbf7e42bdd6f20c7b21378571;p=bigint-presentation-code.git working on toom-cook multiplication --- diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 3ebb7cf..517e542 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -8,18 +8,11 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict from enum import Enum, EnumMeta, unique from functools import lru_cache -from typing import (TYPE_CHECKING, Any, Generic, Iterable, Sequence, Type, - TypeVar, cast) +from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast from nmutil.plain_data import fields, plain_data -from bigint_presentation_code.ordered_set import OFSet, OSet - -if TYPE_CHECKING: - from typing_extensions import final -else: - def final(v): - return v +from bigint_presentation_code.util import OFSet, OSet, final class ABCEnumMeta(EnumMeta, ABCMeta): diff --git a/src/bigint_presentation_code/ordered_set.py b/src/bigint_presentation_code/ordered_set.py deleted file mode 100644 index 018f97b..0000000 --- a/src/bigint_presentation_code/ordered_set.py +++ /dev/null @@ -1,59 +0,0 @@ -from typing import AbstractSet, Iterable, MutableSet, TypeVar - -_T_co = TypeVar("_T_co", covariant=True) -_T = TypeVar("_T") - - -class OFSet(AbstractSet[_T_co]): - """ ordered frozen set """ - - def __init__(self, items=()): - # type: (Iterable[_T_co]) -> None - self.__items = {v: None for v in items} - - def __contains__(self, x): - return x in self.__items - - def __iter__(self): - return iter(self.__items) - - def __len__(self): - return len(self.__items) - - def __hash__(self): - return self._hash() - - def __repr__(self): - if len(self) == 0: - return "OFSet()" - return f"OFSet({list(self)})" - - -class OSet(MutableSet[_T]): - """ ordered mutable set """ - - def __init__(self, items=()): - # type: (Iterable[_T]) -> None - self.__items = {v: None for v in items} - - def __contains__(self, x): - return x in self.__items - - def __iter__(self): - return iter(self.__items) - - def __len__(self): - return len(self.__items) - - def add(self, value): - # type: (_T) -> None - self.__items[value] = None - - def discard(self, value): - # type: (_T) -> None - self.__items.pop(value, None) - - def __repr__(self): - if len(self) == 0: - return "OSet()" - return f"OSet({list(self)})" diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index a22299f..b8269e4 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -6,20 +6,13 @@ this uses an algorithm based on: """ from itertools import combinations -from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar +from typing import Generic, Iterable, Mapping, TypeVar from nmutil.plain_data import plain_data from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass, RegLoc, RegType, SSAVal) -from bigint_presentation_code.ordered_set import OFSet, OSet - -if TYPE_CHECKING: - from typing_extensions import final -else: - def final(v): - return v - +from bigint_presentation_code.util import OFSet, OSet, final _RegType = TypeVar("_RegType", bound=RegType) diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index c014d09..865015e 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -1,8 +1,226 @@ """ -Toom-Cook algorithm generator for SVP64 - -the register allocator uses an algorithm based on: -[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf) +Toom-Cook multiplication algorithm generator for SVP64 """ -from bigint_presentation_code.compiler_ir import Op -from bigint_presentation_code.register_allocator import allocate_registers, AllocationFailed +from abc import abstractmethod +from enum import Enum +from fractions import Fraction +from typing import Any, Generic, Iterable, Sequence, TypeVar + +from nmutil.plain_data import plain_data + +from bigint_presentation_code.compiler_ir import Fn, Op +from bigint_presentation_code.util import Literal, OFSet, OSet, final + + +@final +class PointAtInfinity(Enum): + POINT_AT_INFINITY = "POINT_AT_INFINITY" + + def __repr__(self): + return self.name + + +POINT_AT_INFINITY = PointAtInfinity.POINT_AT_INFINITY +WORD_BITS = 64 + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpPoly: + """polynomial""" + __slots__ = "coefficients", + + def __init__(self, coefficients=()): + # type: (Iterable[Fraction | int] | EvalOpPoly | Fraction | int) -> None + if isinstance(coefficients, EvalOpPoly): + coefficients = coefficients.coefficients + elif isinstance(coefficients, (int, Fraction)): + coefficients = coefficients, + v = list(map(Fraction, coefficients)) + while len(v) != 0 and v[-1] == 0: + v.pop() + self.coefficients = tuple(v) # type: tuple[Fraction, ...] + + def __add__(self, rhs): + # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly + rhs = EvalOpPoly(rhs) + retval = list(self.coefficients) + extra = len(rhs.coefficients) - len(retval) + if extra > 0: + retval.extend([Fraction(0)] * extra) + for i, v in enumerate(rhs.coefficients): + retval[i] += v + return EvalOpPoly(retval) + + __radd__ = __add__ + + def __neg__(self): + return EvalOpPoly(-v for v in self.coefficients) + + def __sub__(self, rhs): + # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly + return self + -rhs + + def __rsub__(self, lhs): + # type: (EvalOpPoly | int | Fraction) -> EvalOpPoly + return lhs + -self + + def __mul__(self, rhs): + # type: (int | Fraction) -> EvalOpPoly + return EvalOpPoly(v * rhs for v in self.coefficients) + + __rmul__ = __mul__ + + def __truediv__(self, rhs): + # type: (int | Fraction) -> EvalOpPoly + if rhs == 0: + raise ZeroDivisionError() + return EvalOpPoly(v / rhs for v in self.coefficients) + + +_EvalOpLHS = TypeVar("_EvalOpLHS", int, "EvalOp") +_EvalOpRHS = TypeVar("_EvalOpRHS", int, "EvalOp") + + +@plain_data(frozen=True, unsafe_hash=True) +class EvalOp(Generic[_EvalOpLHS, _EvalOpRHS]): + __slots__ = "lhs", "rhs", "poly" + + @property + def lhs_poly(self): + # type: () -> EvalOpPoly + if isinstance(self.lhs, int): + return EvalOpPoly(self.lhs) + return self.lhs.poly + + @property + def rhs_poly(self): + # type: () -> EvalOpPoly + if isinstance(self.rhs, int): + return EvalOpPoly(self.rhs) + return self.rhs.poly + + @abstractmethod + def _make_poly(self): + # type: () -> EvalOpPoly + ... + + def __init__(self, lhs, rhs): + # type: (_EvalOpLHS, _EvalOpRHS) -> None + self.lhs = lhs + self.rhs = rhs + self.poly = self._make_poly() + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpAdd(EvalOp[_EvalOpLHS, _EvalOpRHS]): + __slots__ = () + + def _make_poly(self): + # type: () -> EvalOpPoly + return self.lhs_poly + self.rhs_poly + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpSub(EvalOp[_EvalOpLHS, _EvalOpRHS]): + __slots__ = () + + def _make_poly(self): + # type: () -> EvalOpPoly + return self.lhs_poly - self.rhs_poly + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpMul(EvalOp[_EvalOpLHS, int]): + __slots__ = () + + def _make_poly(self): + # type: () -> EvalOpPoly + return self.lhs_poly * self.rhs + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpExactDiv(EvalOp[_EvalOpLHS, int]): + __slots__ = () + + def _make_poly(self): + # type: () -> EvalOpPoly + return self.lhs_poly / self.rhs + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class EvalOpInput(EvalOp[int, Literal[0]]): + __slots__ = () + + def __init__(self, lhs, rhs=0): + # type: (...) -> None + if lhs < 0: + raise ValueError("Input split_index (lhs) must be >= 0") + if rhs != 0: + raise ValueError("Input rhs must be 0") + super().__init__(lhs, rhs) + + @property + def split_index(self): + return self.lhs + + def _make_poly(self): + # type: () -> EvalOpPoly + return EvalOpPoly([0] * self.split_index + [1]) + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class ToomCookInstance: + __slots__ = ("lhs_split_count", "rhs_split_count", "eval_points", + "lhs_eval_ops", "rhs_eval_ops", "product_eval_ops") + + def __init__( + self, lhs_split_count, # type: int + rhs_split_count, # type: int + eval_points, # type: Iterable[PointAtInfinity | int] + lhs_eval_ops, # type: Iterable[EvalOp[Any, Any]] + rhs_eval_ops, # type: Iterable[EvalOp[Any, Any]] + product_eval_ops, # type: Iterable[EvalOp[Any, Any]] + ): + # type: (...) -> None + self.lhs_split_count = lhs_split_count + if self.lhs_split_count < 2: + raise ValueError("lhs_split_count must be at least 2") + self.rhs_split_count = rhs_split_count + if self.rhs_split_count < 2: + raise ValueError("rhs_split_count must be at least 2") + eval_points = list(eval_points) + self.eval_points = OFSet(eval_points) + if len(self.eval_points) != len(eval_points): + raise ValueError("duplicate eval points") + self.lhs_eval_ops = tuple(lhs_eval_ops) + if len(self.lhs_eval_ops) != len(self.eval_points): + raise ValueError("wrong number of lhs_eval_ops") + self.rhs_eval_ops = tuple(rhs_eval_ops) + if len(self.rhs_eval_ops) != len(self.eval_points): + raise ValueError("wrong number of rhs_eval_ops") + if self.lhs_split_count < 2: + raise ValueError("lhs_split_count must be at least 2") + if self.rhs_split_count < 2: + raise ValueError("rhs_split_count must be at least 2") + if (self.lhs_split_count + self.rhs_split_count - 1 + != len(self.eval_points)): + raise ValueError("wrong number of eval_points") + self.product_eval_ops = tuple(product_eval_ops) + if len(self.product_eval_ops) != len(self.eval_points): + raise ValueError("wrong number of product_eval_ops") + # TODO: compute and check matrix and all the *_eval_ops + raise NotImplementedError + + +def toom_cook_mul(fn, word_count, instances): + # type: (Fn, int, Sequence[ToomCookInstance]) -> OSet[Op] + retval = OSet() # type: OSet[Op] + raise NotImplementedError + return retval diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py new file mode 100644 index 0000000..b8b2934 --- /dev/null +++ b/src/bigint_presentation_code/util.py @@ -0,0 +1,113 @@ +from typing import (TYPE_CHECKING, AbstractSet, Iterable, Iterator, Mapping, + MutableSet, TypeVar, Union) + +if TYPE_CHECKING: + from typing_extensions import Literal, final +else: + def final(v): + return v + + class _Literal: + def __getitem__(self, v): + if isinstance(v, tuple): + return Union[tuple(type(i) for i in v)] + return type(v) + + Literal = _Literal() + +_T_co = TypeVar("_T_co", covariant=True) +_T = TypeVar("_T") + +__all__ = ["final", "Literal", "OFSet", "OSet", "FMap"] + + +class OFSet(AbstractSet[_T_co]): + """ ordered frozen set """ + __slots__ = "__items", + + def __init__(self, items=()): + # type: (Iterable[_T_co]) -> None + self.__items = {v: None for v in items} + + def __contains__(self, x): + return x in self.__items + + def __iter__(self): + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + def __hash__(self): + return self._hash() + + def __repr__(self): + if len(self) == 0: + return "OFSet()" + return f"OFSet({list(self)})" + + +class OSet(MutableSet[_T]): + """ ordered mutable set """ + __slots__ = "__items", + + def __init__(self, items=()): + # type: (Iterable[_T]) -> None + self.__items = {v: None for v in items} + + def __contains__(self, x): + return x in self.__items + + def __iter__(self): + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + def add(self, value): + # type: (_T) -> None + self.__items[value] = None + + def discard(self, value): + # type: (_T) -> None + self.__items.pop(value, None) + + def __repr__(self): + if len(self) == 0: + return "OSet()" + return f"OSet({list(self)})" + + +class FMap(Mapping[_T, _T_co]): + """ordered frozen hashable mapping""" + __slots__ = "__items", "__hash" + + def __init__(self, items=()): + # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None + self.__items = dict(items) # type: dict[_T, _T_co] + self.__hash = None # type: None | int + + def __getitem__(self, item): + # type: (_T) -> _T_co + return self.__items[item] + + def __iter__(self): + # type: () -> Iterator[_T] + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + def __eq__(self, other): + # type: (object) -> bool + if isinstance(other, FMap): + return self.__items == other.__items + return super().__eq__(other) + + def __hash__(self): + if self.__hash is None: + self.__hash = hash(frozenset(self.items())) + return self.__hash + + def __repr__(self): + return f"FMap({self.__items})" diff --git a/src/bigint_presentation_code/util.pyi b/src/bigint_presentation_code/util.pyi new file mode 100644 index 0000000..48445b1 --- /dev/null +++ b/src/bigint_presentation_code/util.pyi @@ -0,0 +1,81 @@ +from typing import (AbstractSet, Iterable, Iterator, Mapping, + MutableSet, TypeVar, overload) +from typing_extensions import final, Literal + +_T_co = TypeVar("_T_co", covariant=True) +_T = TypeVar("_T") + +__all__ = ["final", "Literal", "OFSet", "OSet", "FMap"] + + +class OFSet(AbstractSet[_T_co]): + """ ordered frozen set """ + + def __init__(self, items: Iterable[_T_co] = ()): + ... + + def __contains__(self, x: object) -> bool: + ... + + def __iter__(self) -> Iterator[_T_co]: + ... + + def __len__(self) -> int: + ... + + def __hash__(self) -> int: + ... + + def __repr__(self) -> str: + ... + + +class OSet(MutableSet[_T]): + """ ordered mutable set """ + + def __init__(self, items: Iterable[_T] = ()): + ... + + def __contains__(self, x: object) -> bool: + ... + + def __iter__(self) -> Iterator[_T]: + ... + + def __len__(self) -> int: + ... + + def add(self, value: _T) -> None: + ... + + def discard(self, value: _T) -> None: + ... + + def __repr__(self) -> str: + ... + + +class FMap(Mapping[_T, _T_co]): + """ordered frozen hashable mapping""" + @overload + def __init__(self, items: Mapping[_T, _T_co] = ...): ... + @overload + def __init__(self, items: Iterable[tuple[_T, _T_co]] = ...): ... + + def __getitem__(self, item: _T) -> _T_co: + ... + + def __iter__(self) -> Iterator[_T]: + ... + + def __len__(self) -> int: + ... + + def __eq__(self, other: object) -> bool: + ... + + def __hash__(self) -> int: + ... + + def __repr__(self) -> str: + ...