From 77d159f36e5c24453c409f670a70e75cd4e273b0 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 6 Oct 2022 21:08:08 -0700 Subject: [PATCH] started working on toom_cook.py --- src/bigint_presentation_code/toom_cook.py | 707 ++++++++++++++++++++++ 1 file changed, 707 insertions(+) create mode 100644 src/bigint_presentation_code/toom_cook.py diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py new file mode 100644 index 0000000..bd8169b --- /dev/null +++ b/src/bigint_presentation_code/toom_cook.py @@ -0,0 +1,707 @@ +from abc import ABCMeta, abstractmethod +import builtins +from collections import defaultdict +from enum import Enum, unique +from typing import Iterable, Mapping, TYPE_CHECKING + +from nmutil.plain_data import plain_data + +if TYPE_CHECKING: + from typing_extensions import final +else: + def final(v): + return v + + +@plain_data(frozen=True, unsafe_hash=True) +class PhysLoc: + pass + + +@plain_data(frozen=True, unsafe_hash=True) +class GPROrStackLoc(PhysLoc): + pass + + +@final +class GPR(GPROrStackLoc, Enum): + def __init__(self, reg_num): + # type: (int) -> None + self.reg_num = reg_num + # fmt: off + R0 = 0; R1 = 1; R2 = 2; R3 = 3; R4 = 4; R5 = 5 + R6 = 6; R7 = 7; R8 = 8; R9 = 9; R10 = 10; R11 = 11 + R12 = 12; R13 = 13; R14 = 14; R15 = 15; R16 = 16; R17 = 17 + R18 = 18; R19 = 19; R20 = 20; R21 = 21; R22 = 22; R23 = 23 + R24 = 24; R25 = 25; R26 = 26; R27 = 27; R28 = 28; R29 = 29 + R30 = 30; R31 = 31; R32 = 32; R33 = 33; R34 = 34; R35 = 35 + R36 = 36; R37 = 37; R38 = 38; R39 = 39; R40 = 40; R41 = 41 + R42 = 42; R43 = 43; R44 = 44; R45 = 45; R46 = 46; R47 = 47 + R48 = 48; R49 = 49; R50 = 50; R51 = 51; R52 = 52; R53 = 53 + R54 = 54; R55 = 55; R56 = 56; R57 = 57; R58 = 58; R59 = 59 + R60 = 60; R61 = 61; R62 = 62; R63 = 63; R64 = 64; R65 = 65 + R66 = 66; R67 = 67; R68 = 68; R69 = 69; R70 = 70; R71 = 71 + R72 = 72; R73 = 73; R74 = 74; R75 = 75; R76 = 76; R77 = 77 + R78 = 78; R79 = 79; R80 = 80; R81 = 81; R82 = 82; R83 = 83 + R84 = 84; R85 = 85; R86 = 86; R87 = 87; R88 = 88; R89 = 89 + R90 = 90; R91 = 91; R92 = 92; R93 = 93; R94 = 94; R95 = 95 + R96 = 96; R97 = 97; R98 = 98; R99 = 99; R100 = 100; R101 = 101 + R102 = 102; R103 = 103; R104 = 104; R105 = 105; R106 = 106; R107 = 107 + R108 = 108; R109 = 109; R110 = 110; R111 = 111; R112 = 112; R113 = 113 + R114 = 114; R115 = 115; R116 = 116; R117 = 117; R118 = 118; R119 = 119 + R120 = 120; R121 = 121; R122 = 122; R123 = 123; R124 = 124; R125 = 125 + R126 = 126; R127 = 127 + # fmt: on + SP = 1 + TOC = 2 + + +SPECIAL_GPRS = GPR.R0, GPR.SP, GPR.TOC, GPR.R13 + + +@final +@unique +class XERBit(Enum, PhysLoc): + CY = "CY" + + +@final +@unique +class GlobalMem(Enum, PhysLoc): + """singleton representing all non-StackSlot memory""" + GlobalMem = "GlobalMem" + + +@plain_data() +@final +class StackSlot(GPROrStackLoc): + """a stack slot. Use OpCopy to load from/store into this stack slot.""" + __slots__ = "offset", + + def __init__(self, offset=None): + # type: (int | None) -> None + self.offset = offset + + +@plain_data(eq=False) +class SSAVal: + __slots__ = "id", + + def __init__(self, id=None): + # type: (int | None) -> None + if id is None: + id = builtins.id(self) + self.id = id + + def __eq__(self, rhs): + if isinstance(rhs, SSAVal): + return self.id == rhs.id + return False + + def __hash__(self): + return hash(self.id) + + +@plain_data(eq=False) +@final +class SSAGPRVal(SSAVal): + __slots__ = "phys_loc", + + def __init__(self, phys_loc=None): + # type: (GPROrStackLoc | None) -> None + self.phys_loc = phys_loc + super().__init__() + + def __len__(self): + return 1 + + def get_reg_num(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> int | None + phys_loc = None + if value_assignments is not None: + phys_loc = value_assignments.get(self) + if self.phys_loc is not None: + phys_loc = self.phys_loc + if isinstance(phys_loc, GPR): + return phys_loc.reg_num + return None + + +@plain_data(eq=False) +@final +class SSAXERBitVal(SSAVal): + __slots__ = "phys_loc", + + def __init__(self, phys_loc=None): + # type: (XERBit | None) -> None + self.phys_loc = phys_loc + + +@plain_data(eq=False) +@final +class SSAMemory(SSAVal): + __slots__ = "phys_loc", + + def __init__(self, phys_loc=GlobalMem.GlobalMem): + # type: (GlobalMem) -> None + self.phys_loc = phys_loc + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class VecArg: + __slots__ = "regs", + + def __init__(self, regs): + # type: (Iterable[SSAGPRVal]) -> None + self.regs = tuple(regs) + + def __len__(self): + return len(self.regs) + + def try_get_range(self, value_assignments=None, prefer_descending=False): + # type: (dict[SSAVal, PhysLoc] | None, int) -> range | None + if len(self.regs) == 0: + return range(0) + + start = self.regs[0].get_reg_num(value_assignments) + if start is None: + return None + if len(self.regs) == 1: + if prefer_descending: + return range(start, start - 1, -1) + return range(start, start + 1, 1) + reg = self.regs[1].get_reg_num(value_assignments) + if reg is None: + return None + step = reg - start + if abs(step) != 1: + return None + for i, reg in enumerate(self.regs): + reg = self.regs[1].get_reg_num(value_assignments) + if reg is None: + return None + if start + i * step != reg: + return None + return range(start, start + len(self.regs) * step, step) + + def try_get_ascending_range(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> range | None + r = self.try_get_range(value_assignments=value_assignments) + if r is not None and r.step == 1: + return r + return None + + def try_get_descending_range(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> range | None + r = self.try_get_range(value_assignments=value_assignments, + prefer_descending=True) + if r is not None and r.step == -1: + return r + return None + + +@plain_data(unsafe_hash=True, frozen=True) +class Op(metaclass=ABCMeta): + __slots__ = () + + def input_ssa_vals(self): + # type: () -> Iterable[SSAVal] + for arg in self.inputs().values(): + if isinstance(arg, VecArg): + yield from arg.regs + else: + yield arg + + def output_ssa_vals(self): + # type: () -> Iterable[SSAVal] + for arg in self.outputs().values(): + if isinstance(arg, VecArg): + yield from arg.regs + else: + yield arg + + @abstractmethod + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + ... + + @abstractmethod + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + ... + + @abstractmethod + def meets_constraints(self, value_assignments): + # type: (dict[SSAVal, PhysLoc]) -> bool + return True + + def __init__(self): + pass + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpCopy(Op): + __slots__ = "dest", "src" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"dest": self.dest} + + def __init__(self, dest, src): + # type: (VecArg | SSAVal, VecArg | SSAVal) -> None + if isinstance(dest, VecArg) and isinstance(src, VecArg): + if len(src.regs) != len(dest.regs): + raise TypeError(f"source length must match dest " + f"length: {src} doesn't match {dest}") + elif type(dest) != type(src): + raise TypeError(f"source argument type must match dest " + f"argument type: {src} doesn't match {dest}") + self.dest = dest + self.src = src + + def meets_constraints(self, value_assignments): + # type: (dict[SSAVal, PhysLoc]) -> bool + return True + + +def range_overlaps(range1, range2): + # type: (range, range) -> bool + if len(range1) == 0 or len(range2) == 0: + return False + range1_last = next(reversed(range1)) + range2_last = next(reversed(range2)) + return (range1.start in range2 or range1_last in range2 or + range2.start in range1 or range2_last in range1) + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpAddE(Op): + __slots__ = "RT", "RA", "RB", "CY_in", "CY_out" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RT": self.RT, "CY_out": self.CY_out} + + def __init__(self, RT, RA, RB, CY_in, CY_out): + # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None + if len(RA.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {RA} doesn't match {RT}") + if len(RB.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {RB} doesn't match {RT}") + self.RT = RT + self.RA = RA + self.RB = RB + self.CY_in = CY_in + self.CY_out = CY_out + + def meets_constraints(self, value_assignments): + # type: (dict[SSAVal, PhysLoc]) -> bool + RT_range = self.RT.try_get_ascending_range(value_assignments) + RA_range = self.RA.try_get_ascending_range(value_assignments) + RB_range = self.RB.try_get_ascending_range(value_assignments) + if RA_range is None or RB_range is None or RT_range is None: + return False + if RA_range != RT_range and range_overlaps(RA_range, RT_range): + return False + if RB_range != RT_range and range_overlaps(RB_range, RT_range): + return False + return True + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpSubE(Op): + __slots__ = "RT", "lhs", "rhs", "CY_in", "CY_out" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"lhs": self.lhs, "rhs": self.rhs, "CY_in": self.CY_in} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RT": self.RT, "CY_out": self.CY_out} + + def __init__(self, RT, lhs, rhs, CY_in, CY_out): + # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None + if len(lhs.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {lhs} doesn't match {RT}") + if len(rhs.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {rhs} doesn't match {RT}") + self.RT = RT + self.lhs = lhs + self.rhs = rhs + self.CY_in = CY_in + self.CY_out = CY_out + + def meets_constraints(self, value_assignments): + # type: (dict[SSAVal, PhysLoc]) -> bool + RT_range = self.RT.try_get_ascending_range(value_assignments) + lhs_range = self.lhs.try_get_ascending_range(value_assignments) + rhs_range = self.rhs.try_get_ascending_range(value_assignments) + if lhs_range is None or rhs_range is None or RT_range is None: + return False + if lhs_range != RT_range and range_overlaps(lhs_range, RT_range): + return False + if rhs_range != RT_range and range_overlaps(rhs_range, RT_range): + return False + return True + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpMAddEU(Op): + __slots__ = "RT", "RA", "RB", "RC", "RS" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RA": self.RA, "RB": self.RB, "RC": self.RC} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RT": self.RT, "RS": self.RS} + + def __init__(self, RT, RA, RB, RC, RS): + # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None + if len(RA.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {RA} doesn't match {RT}") + self.RT = RT + self.RA = RA + self.RB = RB + self.RC = RC + self.RS = RS + + def meets_constraints(self, value_assignments): + # type: (dict[SSAVal, PhysLoc]) -> bool + RT_range = self.RT.try_get_ascending_range(value_assignments) + RA_range = self.RA.try_get_ascending_range(value_assignments) + RB_reg = self.RB.get_reg_num(value_assignments) + RC_reg = self.RC.get_reg_num(value_assignments) + RS_reg = self.RS.get_reg_num(value_assignments) + if RT_range is None or RA_range is None or RB_reg is None \ + or RC_reg is None or RS_reg is None: + return False + if RA_range != RT_range and range_overlaps(RA_range, RT_range): + return False + if RB_reg in RT_range or RC_reg in RT_range or RS_reg in RT_range \ + or RS_reg in RA_range: + return False + if RS_reg != RC_reg: + return False + return True + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpDivMod2DU(Op): + __slots__ = "RT", "RA", "RB", "RC", "RS" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RA": self.RA, "RB": self.RB, "RC": self.RC} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RT": self.RT, "RS": self.RS} + + def __init__(self, RT, RA, RB, RC, RS): + # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None + if len(RA.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {RA} doesn't match {RT}") + self.RT = RT + self.RA = RA + self.RB = RB + self.RC = RC + self.RS = RS + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpBigIntShl(Op): + __slots__ = "RT", "inp", "sh" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"inp": self.inp, "sh": self.sh} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RT": self.RT} + + def __init__(self, RT, inp, sh): + # type: (VecArg, VecArg, SSAGPRVal) -> None + if len(inp.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {inp} doesn't match {RT}") + self.RT = RT + self.inp = inp + self.sh = sh + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpBigIntShr(Op): + __slots__ = "RT", "inp", "sh" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"inp": self.inp, "sh": self.sh} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RT": self.RT} + + def __init__(self, RT, inp, sh): + # type: (VecArg, VecArg, SSAGPRVal) -> None + if len(inp.regs) != len(RT.regs): + raise TypeError(f"source length must match dest " + f"length: {inp} doesn't match {RT}") + self.RT = RT + self.inp = inp + self.sh = sh + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpLI(Op): + __slots__ = "out", "value" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"out": self.out} + + def __init__(self, out, value): + # type: (VecArg | SSAGPRVal, int) -> None + self.out = out + self.value = value + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpClearCY(Op): + __slots__ = "out", + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"out": self.out} + + def __init__(self, out): + # type: (SSAXERBitVal) -> None + self.out = out + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpLoad(Op): + __slots__ = "RT", "RA", "offset", "mem" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RA": self.RA, "mem": self.mem} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RT": self.RT} + + def __init__(self, RT, RA, offset, mem): + # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None + self.RT = RT + self.RA = RA + self.offset = offset + self.mem = mem + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpStore(Op): + __slots__ = "RS", "RA", "offset", "mem_in", "mem_out" + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"mem_out": self.mem_out} + + def __init__(self, RS, RA, offset, mem_in, mem_out): + # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory, SSAMemory) -> None + self.RS = RS + self.RA = RA + self.offset = offset + self.mem_in = mem_in + self.mem_out = mem_out + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpFuncArg(Op): + __slots__ = "out", + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"out": self.out} + + def __init__(self, out): + # type: (VecArg | SSAGPRVal) -> None + self.out = out + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpInputMem(Op): + __slots__ = "out", + + def inputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, VecArg | SSAVal] + return {"out": self.out} + + def __init__(self, out): + # type: (SSAMemory) -> None + self.out = out + + +def op_set_to_list(ops): + # type: (Iterable[Op]) -> list[Op] + worklists = [set()] # type: list[set[Op]] + input_vals_to_ops_map = defaultdict(set) # type: dict[SSAVal, set[Op]] + ops_to_pending_input_count_map = {} # type: dict[Op, int] + for op in ops: + input_count = 0 + for val in op.input_ssa_vals(): + input_count += 1 + input_vals_to_ops_map[val].add(op) + while len(worklists) <= input_count: + worklists.append(set()) + ops_to_pending_input_count_map[op] = input_count + worklists[input_count].add(op) + retval = [] # type: list[Op] + ready_vals = set() # type: set[SSAVal] + while len(worklists[0]) != 0: + writing_op = worklists[0].pop() + retval.append(writing_op) + for val in writing_op.output_ssa_vals(): + if val in ready_vals: + raise ValueError(f"multiple instructions must not write " + f"to the same SSA value: {val}") + ready_vals.add(val) + for reading_op in input_vals_to_ops_map[val]: + pending = ops_to_pending_input_count_map[reading_op] + worklists[pending].remove(reading_op) + pending -= 1 + worklists[pending].add(reading_op) + ops_to_pending_input_count_map[reading_op] = pending + for worklist in worklists: + for op in worklist: + raise ValueError(f"instruction is part of a dependency loop or " + f"its inputs are never written: {op}") + return retval + + +@plain_data(unsafe_hash=True, order=True, frozen=True) +class LiveInterval: + __slots__ = "assignment", "last_use" + + def __init__(self, assignment, last_use=None): + # type: (int, int | None) -> None + if last_use is None: + last_use = assignment + if last_use < assignment: + raise ValueError("uses must be after assignment") + if assignment < 0 or last_use < 0: + raise ValueError("indexes must be nonnegative") + self.assignment = assignment + self.last_use = last_use + + def overlaps(self, other): + # type: (LiveInterval) -> bool + if self.assignment == other.assignment: + return True + return self.last_use > other.assignment \ + and other.last_use > self.assignment + + def __add__(self, use): + # type: (int) -> LiveInterval + last_use = max(self.last_use, use) + return LiveInterval(assignment=self.assignment, last_use=last_use) + + +class LiveIntervals(Mapping[SSAVal, LiveInterval]): + def __init__(self, ops): + # type: (list[Op]) -> None + live_intervals = {} # type: dict[SSAVal, LiveInterval] + for op_idx, op in enumerate(ops): + for val in op.input_ssa_vals(): + live_intervals[val] += op_idx + for val in op.output_ssa_vals(): + if val in live_intervals: + raise ValueError(f"multiple instructions must not write " + f"to the same SSA value: {val}") + live_intervals[val] = LiveInterval(op_idx) + self.__live_intervals = live_intervals + + def __getitem__(self, key): + # type: (SSAVal) -> LiveInterval + return self.__live_intervals[key] + + def __iter__(self): + return iter(self.__live_intervals) + + +@plain_data() +class AllocationFailed: + __slots__ = "op_idx", "arg", "live_intervals", "free_regs" + + def __init__(self, op_idx, arg, live_intervals, free_regs): + # type: (int, SSAVal | VecArg, LiveIntervals, set[GPR | XERBit]) -> None + self.op_idx = op_idx + self.arg = arg + self.live_intervals = live_intervals + self.free_regs = free_regs + + +def try_allocate_registers_without_spilling(ops): + # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed + live_intervals = LiveIntervals(ops) + free_regs = set() # type: set[GPR | XERBit] + free_regs.update(GPR) + free_regs.difference_update(SPECIAL_GPRS) + free_regs.update(XERBit) + raise NotImplementedError + + +def allocate_registers(ops): + # type: (list[Op]) -> None + raise NotImplementedError -- 2.30.2