From: Jacob Lifshay Date: Wed, 12 Oct 2022 06:49:59 +0000 (-0700) Subject: work on switching algorithms X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=5cf39a0d7bd7d79bbb195e65d25cd133613bba16;p=bigint-presentation-code.git work on switching algorithms --- diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index dd9061b..02ddb4b 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -1,317 +1,300 @@ +""" +Toom-Cook algorithm generator for SVP64 + +the register allocator uses an algorithm based on: +[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf) +""" + from abc import ABCMeta, abstractmethod from collections import defaultdict from enum import Enum, unique -from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING +from functools import lru_cache +from typing import (Sequence, AbstractSet, Iterable, Mapping, + TYPE_CHECKING, Sequence, TypeVar) from nmutil.plain_data import plain_data if TYPE_CHECKING: from typing_extensions import final, Self + from typing import Generic else: def final(v): return v + # make plain_data work with Generics + class Generic: + def __class_getitem__(cls, item): + return object + @plain_data(frozen=True, unsafe_hash=True) -class PhysLoc: - pass +class PhysLoc(metaclass=ABCMeta): + __slots__ = () @plain_data(frozen=True, unsafe_hash=True) -class GPROrStackLoc(PhysLoc): - pass +class RegLoc(PhysLoc): + __slots__ = () + + @abstractmethod + def conflicts(self, other): + # type: (RegLoc) -> bool + ... + +@plain_data(frozen=True, unsafe_hash=True) +class GPRRangeOrStackLoc(PhysLoc): + __slots__ = () + @abstractmethod + def __len__(self): + # type: () -> int + ... + + +GPR_COUNT = 128 + + +@plain_data(frozen=True, unsafe_hash=True) @final -class GPR(GPROrStackLoc, Enum): - def __init__(self, reg_num): - # type: (int) -> None - self.reg_num = reg_num - # fmt: off - R0 = 0; R1 = 1; R2 = 2; R3 = 3; R4 = 4; R5 = 5 - R6 = 6; R7 = 7; R8 = 8; R9 = 9; R10 = 10; R11 = 11 - R12 = 12; R13 = 13; R14 = 14; R15 = 15; R16 = 16; R17 = 17 - R18 = 18; R19 = 19; R20 = 20; R21 = 21; R22 = 22; R23 = 23 - R24 = 24; R25 = 25; R26 = 26; R27 = 27; R28 = 28; R29 = 29 - R30 = 30; R31 = 31; R32 = 32; R33 = 33; R34 = 34; R35 = 35 - R36 = 36; R37 = 37; R38 = 38; R39 = 39; R40 = 40; R41 = 41 - R42 = 42; R43 = 43; R44 = 44; R45 = 45; R46 = 46; R47 = 47 - R48 = 48; R49 = 49; R50 = 50; R51 = 51; R52 = 52; R53 = 53 - R54 = 54; R55 = 55; R56 = 56; R57 = 57; R58 = 58; R59 = 59 - R60 = 60; R61 = 61; R62 = 62; R63 = 63; R64 = 64; R65 = 65 - R66 = 66; R67 = 67; R68 = 68; R69 = 69; R70 = 70; R71 = 71 - R72 = 72; R73 = 73; R74 = 74; R75 = 75; R76 = 76; R77 = 77 - R78 = 78; R79 = 79; R80 = 80; R81 = 81; R82 = 82; R83 = 83 - R84 = 84; R85 = 85; R86 = 86; R87 = 87; R88 = 88; R89 = 89 - R90 = 90; R91 = 91; R92 = 92; R93 = 93; R94 = 94; R95 = 95 - R96 = 96; R97 = 97; R98 = 98; R99 = 99; R100 = 100; R101 = 101 - R102 = 102; R103 = 103; R104 = 104; R105 = 105; R106 = 106; R107 = 107 - R108 = 108; R109 = 109; R110 = 110; R111 = 111; R112 = 112; R113 = 113 - R114 = 114; R115 = 115; R116 = 116; R117 = 117; R118 = 118; R119 = 119 - R120 = 120; R121 = 121; R122 = 122; R123 = 123; R124 = 124; R125 = 125 - R126 = 126; R127 = 127 - # fmt: on - SP = 1 - TOC = 2 - - -SPECIAL_GPRS = GPR.R0, GPR.SP, GPR.TOC, GPR.R13 +class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]): + __slots__ = "start", "length" + + def __init__(self, start, length=None): + # type: (int | range, int | None) -> None + if isinstance(start, range): + if length is not None: + raise TypeError("can't specify length when input is a range") + if start.step != 1: + raise ValueError("range must have a step of 1") + length = len(start) + start = start.start + elif length is None: + length = 1 + if length <= 0 or start < 0 or start + length > GPR_COUNT: + raise ValueError("invalid GPRRange") + self.start = start + self.length = length + + @property + def stop(self): + return self.start + self.length + + @property + def step(self): + return 1 + + @property + def range(self): + return range(self.start, self.stop, self.step) + + def __len__(self): + return self.length + + def __getitem__(self, item): + # type: (int | slice) -> GPRRange + return GPRRange(self.range[item]) + + def __contains__(self, value): + # type: (GPRRange) -> bool + return value.start >= self.start and value.stop <= self.stop + + def index(self, sub, start=None, end=None): + # type: (GPRRange, int | None, int | None) -> int + r = self.range[start:end] + if sub.start < r.start or sub.stop > r.stop: + raise ValueError("GPR range not found") + return sub.start - self.start + + def count(self, sub, start=None, end=None): + # type: (GPRRange, int | None, int | None) -> int + r = self.range[start:end] + if len(r) == 0: + return 0 + return int(sub in GPRRange(r)) + + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, GPRRange): + return self.stop > other.start and other.stop > self.start + return False + + +SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13) @final @unique -class XERBit(Enum, PhysLoc): +class XERBit(Enum, RegLoc): CY = "CY" + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, XERBit): + return self == other + return False + @final @unique -class GlobalMem(Enum, PhysLoc): - """singleton representing all non-StackSlot memory""" +class GlobalMem(Enum, RegLoc): + """singleton representing all non-StackSlot memory -- treated as a single + physical register for register allocation purposes. + """ GlobalMem = "GlobalMem" - -ALLOCATABLE_REGS = frozenset((set(GPR) - set(SPECIAL_GPRS)) - | set(XERBit) | set(GlobalMem)) + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, GlobalMem): + return self == other + return False -@plain_data() @final -class StackSlot(GPROrStackLoc): - """a stack slot. Use OpCopy to load from/store into this stack slot.""" - __slots__ = "offset", +class RegClass(AbstractSet[RegLoc]): + def __init__(self, regs): + # type: (Iterable[RegLoc]) -> None + self.__regs = frozenset(regs) - def __init__(self, offset=None): - # type: (int | None) -> None - self.offset = offset + def __len__(self): + return len(self.__regs) + def __iter__(self): + return iter(self.__regs) -class SSAVal(metaclass=ABCMeta): - __slots__ = "op", "arg_name", "element_index" + def __contains__(self, v): + # type: (RegLoc) -> bool + return v in self.__regs - def __init__(self, op, arg_name, element_index): - # type: (Op, str, int) -> None - self.op = op - """the Op that writes this SSAVal""" + def __hash__(self): + return super()._hash() - self.arg_name = arg_name - self.element_index = element_index - @final - def __eq__(self, rhs): - if isinstance(rhs, SSAVal): - return (self.op is rhs.op - and self.arg_name == rhs.arg_name - and self.element_index == rhs.element_index) - return False +@plain_data(frozen=True, unsafe_hash=True) +class RegType(metaclass=ABCMeta): + __slots__ = () - @final - def __hash__(self): - return hash((id(self.op), self.arg_name, self.element_index)) + @property + @abstractmethod + def reg_class(self): + # type: () -> RegClass + return ... - def _get_phys_loc(self, phys_loc_in, value_assignments=None): - # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None - if phys_loc_in is not None: - return phys_loc_in - if value_assignments is not None: - return value_assignments.get(self) - return None - @abstractmethod - def get_phys_loc(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None - ... +@plain_data(frozen=True, eq=False) +class GPRRangeType(RegType): + __slots__ = "length", + + def __init__(self, length): + # type: (int) -> None + if length < 1 or length > GPR_COUNT: + raise ValueError("invalid length") + self.length = length + + @staticmethod + @lru_cache() + def __get_reg_class(length): + # type: (int) -> RegClass + regs = [] + for start in range(GPR_COUNT - length): + reg = GPRRange(start, length) + if any(i in reg for i in SPECIAL_GPRS): + continue + regs.append(reg) + return RegClass(regs) + + @property + def reg_class(self): + # type: () -> RegClass + return GPRRangeType.__get_reg_class(self.length) @final - def __repr__(self): - name = self.__class__.__name__ - op = object.__repr__(self.op) - phys_loc = self.get_phys_loc() - return (f"{name}(op={op}, arg_name={self.arg_name}, " - f"element_index={self.element_index}, phys_loc={phys_loc})") + def __eq__(self, other): + if isinstance(other, GPRRangeType): + return self.length == other.length + return False @final - def like(self, op, arg_name): - # type: (Op, str) -> Self - """create a new SSAVal based off of self's type. - has same signature as VecArg.like. - """ - return self.__class__(op=op, arg_name=arg_name, - element_index=0) + def __hash__(self): + return hash(self.length) +@plain_data(frozen=True, eq=False) @final -class SSAGPRVal(SSAVal): - __slots__ = "phys_loc", +class GPRType(GPRRangeType): + __slots__ = () - def __init__(self, op, arg_name, element_index, phys_loc=None): - # type: (Op, str, int, GPROrStackLoc | None) -> None - super().__init__(op, arg_name, element_index) - self.phys_loc = phys_loc + def __init__(self, length=1): + if length != 1: + raise ValueError("length must be 1") + super().__init__(length=1) - def __len__(self): - return 1 - def get_phys_loc(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None - loc = self._get_phys_loc(self.phys_loc, value_assignments) - if isinstance(loc, GPROrStackLoc): - return loc - return None - - def get_reg_num(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> int | None - reg = self.get_reg(value_assignments) - if reg is not None: - return reg.reg_num - return None - - def get_reg(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None - loc = self.get_phys_loc(value_assignments) - if isinstance(loc, GPR): - return loc - return None - - def get_stack_slot(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None - loc = self.get_phys_loc(value_assignments) - if isinstance(loc, StackSlot): - return loc - return None - - def possible_reg_assignments(self, value_assignments, - conflicting_regs=set()): - # type: (dict[SSAVal, PhysLoc] | None, set[GPR]) -> Iterable[GPR] - if self.get_phys_loc(value_assignments) is not None: - raise ValueError("can't assign a already-assigned SSA value") - for reg in GPR: - if reg not in conflicting_regs: - yield reg +@plain_data(frozen=True, unsafe_hash=True) +@final +class CYType(RegType): + __slots__ = () + @property + def reg_class(self): + # type: () -> RegClass + return RegClass([XERBit.CY]) -@final -class SSAXERBitVal(SSAVal): - __slots__ = "phys_loc", - def __init__(self, op, arg_name, element_index, phys_loc=None): - # type: (Op, str, int, XERBit | None) -> None - super().__init__(op, arg_name, element_index) - self.phys_loc = phys_loc +@plain_data(frozen=True, unsafe_hash=True) +@final +class GlobalMemType(RegType): + __slots__ = () - def get_phys_loc(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None - loc = self._get_phys_loc(self.phys_loc, value_assignments) - if isinstance(loc, XERBit): - return loc - return None + @property + def reg_class(self): + # type: () -> RegClass + return RegClass([GlobalMem.GlobalMem]) +@plain_data() @final -class SSAMemory(SSAVal): - __slots__ = "phys_loc", +class StackSlot(GPRRangeOrStackLoc): + """a stack slot. Use OpCopy to load from/store into this stack slot.""" + __slots__ = "offset", "length" - def __init__(self, op, arg_name, element_index, - phys_loc=GlobalMem.GlobalMem): - # type: (Op, str, int, GlobalMem) -> None - super().__init__(op, arg_name, element_index) - self.phys_loc = phys_loc + def __init__(self, offset=None, length=1): + # type: (int | None, int) -> None + self.offset = offset + if length < 1: + raise ValueError("invalid length") + self.length = length - def get_phys_loc(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem - loc = self._get_phys_loc(self.phys_loc, value_assignments) - if isinstance(loc, GlobalMem): - return loc - return self.phys_loc + def __len__(self): + return self.length -@plain_data(unsafe_hash=True, frozen=True) +_RegType = TypeVar("_RegType", bound=RegType) + + +@plain_data(frozen=True, eq=False) @final -class VecArg: - __slots__ = "regs", +class SSAVal(Generic[_RegType]): + __slots__ = "op", "arg_name", "ty", "arg_index" - def __init__(self, regs): - # type: (Iterable[SSAGPRVal]) -> None - self.regs = tuple(regs) + def __init__(self, op, arg_name, ty): + # type: (Op, str, _RegType) -> None + self.op = op + """the Op that writes this SSAVal""" - def __len__(self): - return len(self.regs) - - def is_unassigned(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> bool - for val in self.regs: - if val.get_phys_loc(value_assignments) is not None: - return False - return True - - def try_get_range(self, value_assignments=None, allow_unassigned=False, - raise_if_invalid=False): - # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None - if len(self.regs) == 0: - return range(0) - - retval = None # type: range | None - for i, val in enumerate(self.regs): - if val.get_phys_loc(value_assignments) is None: - if not allow_unassigned: - if raise_if_invalid: - raise ValueError("not a valid register range: " - "unassigned SSA value encountered") - return None - continue - reg = val.get_reg_num(value_assignments) - if reg is None: - if raise_if_invalid: - raise ValueError("not a valid register range: " - "non-register encountered") - return None - expected_range = range(reg - i, reg - i + len(self.regs)) - if retval is None: - retval = expected_range - elif retval != expected_range: - if raise_if_invalid: - raise ValueError("not a valid register range: " - "register out of sequence") - return None - return retval - - def possible_reg_assignments( - self, - val, # type: SSAVal - value_assignments, # type: dict[SSAVal, PhysLoc] | None - conflicting_regs=set(), # type: set[GPR] - ): # type: (...) -> Iterable[GPR] - index = self.regs.index(val) - alignment = 1 - while alignment < len(self.regs): - alignment *= 2 - r = self.try_get_range(value_assignments) - if r is not None and r.start % alignment != 0: - raise ValueError("must be a ascending aligned range of GPRs") - if r is None: - for i in range(0, len(GPR), alignment): - r = range(i, i + len(self.regs)) - if any(GPR(reg) in conflicting_regs for reg in r): - continue - yield GPR(r[index]) - else: - yield GPR(r[index]) + self.arg_name = arg_name + """the name of the argument of self.op that writes this SSAVal""" - def like(self, op, arg_name): - # type: (Op, str) -> VecArg - """create a new VecArg based off of self's type. - has same signature as SSAVal.like. - """ - return VecArg( - SSAGPRVal(op, arg_name, i) for i in range(len(self.regs))) + self.ty = ty + def __eq__(self, rhs): + if isinstance(rhs, SSAVal): + return (self.op is rhs.op + and self.arg_name == rhs.arg_name) + return False -def vec_or_scalar_arg(element_count, op, arg_name): - # type: (int | None, Op, str) -> VecArg | SSAGPRVal - if element_count is None: - return SSAGPRVal(op, arg_name, 0) - else: - return VecArg(SSAGPRVal(op, arg_name, i) for i in range(element_count)) + def __hash__(self): + return hash((id(self.op), self.arg_name)) @final @@ -320,44 +303,25 @@ class EqualityConstraint: __slots__ = "lhs", "rhs" def __init__(self, lhs, rhs): - # type: (SSAVal, SSAVal) -> None + # type: (list[SSAVal], list[SSAVal]) -> None self.lhs = lhs self.rhs = rhs + if len(lhs) == 0 or len(rhs) == 0: + raise ValueError("can't constrain an empty list to be equal") @plain_data(unsafe_hash=True, frozen=True) class Op(metaclass=ABCMeta): __slots__ = () - def input_ssa_vals(self): - # type: () -> Iterable[SSAVal] - for arg in self.inputs().values(): - if isinstance(arg, VecArg): - yield from arg.regs - else: - yield arg - - def output_ssa_vals(self): - # type: () -> Iterable[SSAVal] - for arg in self.outputs().values(): - if isinstance(arg, VecArg): - yield from arg.regs - else: - yield arg - @abstractmethod def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] ... @abstractmethod def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] - ... - - @abstractmethod - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + # type: () -> dict[str, SSAVal] ... def get_equality_constraints(self): @@ -371,42 +335,79 @@ class Op(metaclass=ABCMeta): @plain_data(unsafe_hash=True, frozen=True) @final -class OpCopy(Op): +class OpCopy(Op, Generic[_RegType]): __slots__ = "dest", "src" def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"src": self.src} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"dest": self.dest} def __init__(self, src): - # type: (SSAGPRVal) -> None - self.dest = src.like(op=self, arg_name="dest") + # type: (SSAVal[_RegType]) -> None + self.dest = SSAVal(self, "dest", src.ty) self.src = src - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - if not isinstance(val, SSAGPRVal): - raise ValueError("invalid operand type") - return val.possible_reg_assignments(value_assignments) - - -def range_overlaps(range1, range2): - # type: (range, range) -> bool - if len(range1) == 0 or len(range2) == 0: - return False - range1_last = range1[-1] - range2_last = range2[-1] - return (range1.start in range2 or range1_last in range2 or - range2.start in range1 or range2_last in range1) + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpConcat(Op): + __slots__ = "dest", "sources" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {f"sources[{i}]": v for i, v in enumerate(self.sources)} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"dest": self.dest} + + def __init__(self, sources): + # type: (Iterable[SSAVal[GPRRangeType]]) -> None + sources = tuple(sources) + self.dest = SSAVal(self, "dest", GPRRangeType( + sum(i.ty.length for i in sources))) + self.sources = sources + + def get_equality_constraints(self): + # type: () -> Iterable[EqualityConstraint] + yield EqualityConstraint([self.dest], [*self.sources]) + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpSplit(Op): + __slots__ = "results", "src" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {i.arg_name: i for i in self.results} + + def __init__(self, src, split_indexes): + # type: (SSAVal[GPRRangeType], Iterable[int]) -> None + ranges = [] # type: list[GPRRangeType] + last = 0 + for i in split_indexes: + if not (0 < i < src.ty.length): + raise ValueError(f"invalid split index: {i}, must be in " + f"0 < i < {src.ty.length}") + ranges.append(GPRRangeType(i - last)) + last = i + ranges.append(GPRRangeType(src.ty.length - last)) + self.src = src + self.results = tuple( + SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges)) + + def get_equality_constraints(self): + # type: () -> Iterable[EqualityConstraint] + yield EqualityConstraint([*self.results], [self.src]) @plain_data(unsafe_hash=True, frozen=True) @@ -415,59 +416,25 @@ class OpAddSubE(Op): __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub" def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RT": self.RT, "CY_out": self.CY_out} def __init__(self, RA, RB, CY_in, is_sub): - # type: (VecArg, VecArg, SSAXERBitVal, bool) -> None - if len(RA.regs) != len(RB.regs): - raise TypeError(f"source lengths must match: " + # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None + if RA.ty != RB.ty: + raise TypeError(f"source types must match: " f"{RA} doesn't match {RB}") - self.RT = RA.like(op=self, arg_name="RT") + self.RT = SSAVal(self, "RT", RA.ty) self.RA = RA self.RB = RB self.CY_in = CY_in - self.CY_out = CY_in.like(op=self, arg_name="CY_out") + self.CY_out = SSAVal(self, "CY_out", CY_in.ty) self.is_sub = is_sub - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - if self.CY_in == val or self.CY_out == val: - yield XERBit.CY - elif val in self.RT.regs: - # since possible_reg_assignments only returns aligned - # vectors, all possible assignments either are the same as an - # input or don't overlap with an input and we avoid the incorrect - # results caused by partial overlaps overwriting input elements - # before they're read - yield from self.RT.possible_reg_assignments(val, value_assignments) - elif val in self.RA.regs: - yield from self.RA.possible_reg_assignments(val, value_assignments) - else: - yield from self.RB.possible_reg_assignments(val, value_assignments) - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint(self.CY_in, self.CY_out) - - -def to_reg_set(v): - # type: (None | GPR | range) -> set[GPR] - if v is None: - return set() - if isinstance(v, range): - return set(map(GPR, v)) - return {v} - @plain_data(unsafe_hash=True, frozen=True) @final @@ -475,64 +442,25 @@ class OpBigIntMulDiv(Op): __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div" def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RA": self.RA, "RB": self.RB, "RC": self.RC} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RT": self.RT, "RS": self.RS} def __init__(self, RA, RB, RC, is_div): - # type: (VecArg, SSAGPRVal, SSAGPRVal, bool) -> None - self.RT = RA.like(op=self, arg_name="RT") + # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None + self.RT = SSAVal(self, "RT", RA.ty) self.RA = RA self.RB = RB self.RC = RC - self.RS = RC.like(op=self, arg_name="RS") + self.RS = SSAVal(self, "RS", RC.ty) self.is_div = is_div - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - RT_range = self.RT.try_get_range(value_assignments, - allow_unassigned=True, - raise_if_invalid=True) - RA_range = self.RA.try_get_range(value_assignments, - allow_unassigned=True, - raise_if_invalid=True) - RC_RS_reg = self.RC.get_reg(value_assignments) - if RC_RS_reg is None: - RC_RS_reg = self.RS.get_reg(value_assignments) - - if self.RC == val or self.RS == val: - if RC_RS_reg is not None: - yield RC_RS_reg - else: - conflicting_regs = to_reg_set(RT_range) | to_reg_set(RA_range) - yield from self.RC.possible_reg_assignments(value_assignments, - conflicting_regs) - elif val in self.RT.regs: - # since possible_reg_assignments only returns aligned - # vectors, all possible assignments either are the same as - # RA or don't overlap with RA and we avoid the incorrect - # results caused by partial overlaps overwriting input elements - # before they're read - yield from self.RT.possible_reg_assignments( - val, value_assignments, - conflicting_regs=to_reg_set(RA_range) | to_reg_set(RC_RS_reg)) - else: - yield from self.RA.possible_reg_assignments( - val, value_assignments, - conflicting_regs=to_reg_set(RT_range) | to_reg_set(RC_RS_reg)) - def get_equality_constraints(self): # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint(self.RC, self.RS) + yield EqualityConstraint([self.RC], [self.RS]) @final @@ -549,54 +477,20 @@ class OpBigIntShift(Op): __slots__ = "RT", "inp", "sh", "kind" def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"inp": self.inp, "sh": self.sh} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RT": self.RT} def __init__(self, inp, sh, kind): - # type: (VecArg, SSAGPRVal, ShiftKind) -> None - self.RT = inp.like(op=self, arg_name="RT") + # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None + self.RT = SSAVal(self, "RT", inp.ty) self.inp = inp self.sh = sh self.kind = kind - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - RT_range = self.RT.try_get_range(value_assignments, - allow_unassigned=True, - raise_if_invalid=True) - inp_range = self.inp.try_get_range(value_assignments, - allow_unassigned=True, - raise_if_invalid=True) - sh_reg = self.sh.get_reg(value_assignments) - - if self.sh == val: - conflicting_regs = to_reg_set(RT_range) - yield from self.sh.possible_reg_assignments(value_assignments, - conflicting_regs) - elif val in self.RT.regs: - # since possible_reg_assignments only returns aligned - # vectors, all possible assignments either are the same as - # RA or don't overlap with RA and we avoid the incorrect - # results caused by partial overlaps overwriting input elements - # before they're read - yield from self.RT.possible_reg_assignments( - val, value_assignments, - conflicting_regs=to_reg_set(inp_range) | to_reg_set(sh_reg)) - else: - yield from self.inp.possible_reg_assignments( - val, value_assignments, - conflicting_regs=to_reg_set(RT_range)) - @plain_data(unsafe_hash=True, frozen=True) @final @@ -604,32 +498,18 @@ class OpLI(Op): __slots__ = "out", "value" def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self, value, element_count=None): - # type: (int, int | None) -> None - self.out = vec_or_scalar_arg(element_count, op=self, arg_name="out") + def __init__(self, value, length=1): + # type: (int, int) -> None + self.out = SSAVal(self, "out", GPRRangeType(length)) self.value = value - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - if isinstance(self.out, VecArg): - yield from self.out.possible_reg_assignments(val, - value_assignments) - else: - yield from self.out.possible_reg_assignments(value_assignments) - @plain_data(unsafe_hash=True, frozen=True) @final @@ -637,27 +517,16 @@ class OpClearCY(Op): __slots__ = "out", def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"out": self.out} def __init__(self): # type: () -> None - self.out = SSAXERBitVal(op=self, arg_name="out", element_index=0, - phys_loc=XERBit.CY) - - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - yield XERBit.CY + self.out = SSAVal(self, "out", CYType()) @plain_data(unsafe_hash=True, frozen=True) @@ -666,48 +535,20 @@ class OpLoad(Op): __slots__ = "RT", "RA", "offset", "mem" def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RA": self.RA, "mem": self.mem} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RT": self.RT} - def __init__(self, RA, offset, mem, element_count=None): - # type: (SSAGPRVal, int, SSAMemory, int | None) -> None - self.RT = vec_or_scalar_arg(element_count, op=self, arg_name="RT") + def __init__(self, RA, offset, mem, length=1): + # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None + self.RT = SSAVal(self, "RT", GPRRangeType(length)) self.RA = RA self.offset = offset self.mem = mem - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - RA_reg = self.RA.get_reg(value_assignments) - - if self.mem == val: - yield GlobalMem.GlobalMem - elif self.RA == val: - if isinstance(self.RT, VecArg): - conflicting_regs = to_reg_set(self.RT.try_get_range( - value_assignments, allow_unassigned=True, - raise_if_invalid=True)) - else: - conflicting_regs = set() - yield from self.RA.possible_reg_assignments(value_assignments, - conflicting_regs) - elif isinstance(self.RT, VecArg): - yield from self.RT.possible_reg_assignments( - val, value_assignments, - conflicting_regs=to_reg_set(RA_reg)) - else: - yield from self.RT.possible_reg_assignments(value_assignments) - @plain_data(unsafe_hash=True, frozen=True) @final @@ -715,41 +556,20 @@ class OpStore(Op): __slots__ = "RS", "RA", "offset", "mem_in", "mem_out" def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"mem_out": self.mem_out} def __init__(self, RS, RA, offset, mem_in): - # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None + # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None self.RS = RS self.RA = RA self.offset = offset self.mem_in = mem_in - self.mem_out = mem_in.like(op=self, arg_name="mem_out") - - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - if self.mem_in == val or self.mem_out == val: - yield GlobalMem.GlobalMem - elif self.RA == val: - yield from self.RA.possible_reg_assignments(value_assignments) - elif isinstance(self.RS, VecArg): - yield from self.RS.possible_reg_assignments(val, value_assignments) - else: - yield from self.RS.possible_reg_assignments(value_assignments) - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint(self.mem_in, self.mem_out) + self.mem_out = SSAVal(self, "mem_out", mem_in.ty) @plain_data(unsafe_hash=True, frozen=True) @@ -758,34 +578,16 @@ class OpFuncArg(Op): __slots__ = "out", def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self, phys_loc): - # type: (GPROrStackLoc | Iterable[GPROrStackLoc]) -> None - if isinstance(phys_loc, GPROrStackLoc): - self.out = SSAGPRVal(self, "out", 0, phys_loc) - else: - self.out = VecArg( - SSAGPRVal(self, "out", i, v) for i, v in enumerate(phys_loc)) - - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - if isinstance(self.out, VecArg): - yield from self.out.possible_reg_assignments(val, - value_assignments) - else: - yield from self.out.possible_reg_assignments(value_assignments) + def __init__(self, ty): + # type: (RegType) -> None + self.out = SSAVal(self, "out", ty) @plain_data(unsafe_hash=True, frozen=True) @@ -794,26 +596,16 @@ class OpInputMem(Op): __slots__ = "out", def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {} def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] + # type: () -> dict[str, SSAVal] return {"out": self.out} def __init__(self): # type: () -> None - self.out = SSAMemory(op=self, arg_name="out", element_index=0) - - def possible_reg_assignments(self, val, value_assignments): - # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] - if val not in self.input_ssa_vals() \ - and val not in self.output_ssa_vals(): - raise ValueError(f"{val} must be an operand of {self}") - if val.get_phys_loc(value_assignments) is not None: - raise ValueError(f"{val} already assigned a physical location") - - yield GlobalMem.GlobalMem + self.out = SSAVal(self, "out", GlobalMemType()) def op_set_to_list(ops): @@ -823,7 +615,7 @@ def op_set_to_list(ops): ops_to_pending_input_count_map = {} # type: dict[Op, int] for op in ops: input_count = 0 - for val in op.input_ssa_vals(): + for val in op.inputs().values(): input_count += 1 input_vals_to_ops_map[val].add(op) while len(worklists) <= input_count: @@ -835,7 +627,7 @@ def op_set_to_list(ops): while len(worklists[0]) != 0: writing_op = worklists[0].pop() retval.append(writing_op) - for val in writing_op.output_ssa_vals(): + for val in writing_op.outputs().values(): if val in ready_vals: raise ValueError(f"multiple instructions must not write " f"to the same SSA value: {val}") @@ -882,14 +674,101 @@ class LiveInterval: @final -class EqualitySet(AbstractSet[SSAVal]): - def __init__(self, items): - # type: (Iterable[SSAVal]) -> None - self.__items = frozenset(items) +class MergedRegSet(Mapping[SSAVal[_RegType], int]): + def __init__(self, reg_set): + # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None + self.__items = {} # type: dict[SSAVal[_RegType], int] + if isinstance(reg_set, SSAVal): + reg_set = [(reg_set, 0)] + for ssa_val, offset in reg_set: + if ssa_val in self.__items: + other = self.__items[ssa_val] + if offset != other: + raise ValueError( + f"can't merge register sets: conflicting offsets: " + f"for {ssa_val}: {offset} != {other}") + else: + self.__items[ssa_val] = offset + first_item = None + for i in self.__items.items(): + first_item = i + break + if first_item is None: + raise ValueError("can't have empty MergedRegs") + first_ssa_val, start = first_item + ty = first_ssa_val.ty + if isinstance(ty, GPRRangeType): + stop = start + ty.length + for ssa_val, offset in self.__items.items(): + if not isinstance(ssa_val.ty, GPRRangeType): + raise ValueError(f"can't merge incompatible types: " + f"{ssa_val.ty} and {ty}") + stop = max(stop, offset + ssa_val.ty.length) + start = min(start, offset) + ty = GPRRangeType(stop - start) + else: + stop = 1 + for ssa_val, offset in self.__items.items(): + if offset != 0: + raise ValueError(f"can't have non-zero offset " + f"for {ssa_val.ty}") + if ty != ssa_val.ty: + raise ValueError(f"can't merge incompatible types: " + f"{ssa_val.ty} and {ty}") + self.__start = start # type: int + self.__stop = stop # type: int + self.__ty = ty # type: RegType + + @staticmethod + def from_equality_constraint(constraint_sequence): + # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType] + if len(constraint_sequence) == 1: + # any type allowed with len = 1 + return MergedRegSet(constraint_sequence[0]) + offset = 0 + retval = [] + for val in constraint_sequence: + if not isinstance(val.ty, GPRRangeType): + raise ValueError("equality constraint sequences must only " + "have SSAVal type GPRRangeType") + retval.append((val, offset)) + offset += val.ty.length + return MergedRegSet(retval) - def __contains__(self, x): - # type: (object) -> bool - return x in self.__items + @property + def ty(self): + return self.__ty + + @property + def stop(self): + return self.__stop + + @property + def start(self): + return self.__start + + @property + def range(self): + return range(self.__start, self.__stop) + + def offset_by(self, amount): + # type: (int) -> MergedRegSet[_RegType] + return MergedRegSet((k, v + amount) for k, v in self.items()) + + def normalized(self): + # type: () -> MergedRegSet[_RegType] + return self.offset_by(-self.start) + + def with_offset_to_match(self, target): + # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType] + for ssa_val, offset in self.items(): + if ssa_val in target: + return self.offset_by(target[ssa_val] - offset) + raise ValueError("can't change offset to match unrelated MergedRegSet") + + def __getitem__(self, item): + # type: (SSAVal[_RegType]) -> int + return self.__items[item] def __iter__(self): return iter(self.__items) @@ -898,60 +777,66 @@ class EqualitySet(AbstractSet[SSAVal]): return len(self.__items) def __hash__(self): - return super()._hash() + return hash(frozenset(self.items())) + + def __repr__(self): + return f"MergedRegSet({list(self.__items.items())})" @final -class EqualitySets(Mapping[SSAVal, EqualitySet]): +class MergedRegSets(Mapping[SSAVal, MergedRegSet]): def __init__(self, ops): # type: (Iterable[Op]) -> None - indexes = {} # type: dict[SSAVal, int] - sets = [] # type: list[set[SSAVal]] + merged_sets = {} # type: dict[SSAVal, MergedRegSet] for op in ops: - for val in (*op.input_ssa_vals(), *op.output_ssa_vals()): - if val not in indexes: - indexes[val] = len(sets) - sets.append({val}) + for val in (*op.inputs().values(), *op.outputs().values()): + if val not in merged_sets: + merged_sets[val] = MergedRegSet(val) for e in op.get_equality_constraints(): - lhs_index = indexes[e.lhs] - rhs_index = indexes[e.rhs] - sets[lhs_index] |= sets[rhs_index] - for val in sets[rhs_index]: - indexes[val] = lhs_index + lhs_set = MergedRegSet.from_equality_constraint(e.lhs) + rhs_set = MergedRegSet.from_equality_constraint(e.rhs) + lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set) + rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set) + full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()]) + for val in full_set.keys(): + merged_sets[val] = full_set - equality_sets = [EqualitySet(i) for i in sets] - self.__map = {k: equality_sets[v] for k, v in indexes.items()} + self.__map = {k: v.normalized() for k, v in merged_sets.items()} def __getitem__(self, key): - # type: (SSAVal) -> EqualitySet + # type: (SSAVal) -> MergedRegSet return self.__map[key] def __iter__(self): return iter(self.__map) + def __len__(self): + return len(self.__map) + @final -class LiveIntervals(Mapping[EqualitySet, LiveInterval]): +class LiveIntervals(Mapping[MergedRegSet, LiveInterval]): def __init__(self, ops): # type: (list[Op]) -> None - self.__equality_sets = eqsets = EqualitySets(ops) - live_intervals = {} # type: dict[EqualitySet, LiveInterval] + self.__merges_reg_sets = MergedRegSets(ops) + live_intervals = {} # type: dict[MergedRegSet, LiveInterval] for op_idx, op in enumerate(ops): - for val in op.input_ssa_vals(): - live_intervals[eqsets[val]] += op_idx - for val in op.output_ssa_vals(): - if eqsets[val] not in live_intervals: - live_intervals[eqsets[val]] = LiveInterval(op_idx) + for val in op.inputs().values(): + live_intervals[self.__merges_reg_sets[val]] += op_idx + for val in op.outputs().values(): + reg_set = self.__merges_reg_sets[val] + if reg_set not in live_intervals: + live_intervals[reg_set] = LiveInterval(op_idx) else: - live_intervals[eqsets[val]] += op_idx + live_intervals[reg_set] += op_idx self.__live_intervals = live_intervals @property - def equality_sets(self): - return self.__equality_sets + def merges_reg_sets(self): + return self.__merges_reg_sets def __getitem__(self, key): - # type: (EqualitySet) -> LiveInterval + # type: (MergedRegSet) -> LiveInterval return self.__live_intervals[key] def __iter__(self): @@ -961,11 +846,11 @@ class LiveIntervals(Mapping[EqualitySet, LiveInterval]): @final class IGNode: """ interference graph node """ - __slots__ = "equality_set", "edges" + __slots__ = "merged_reg_set", "edges" - def __init__(self, equality_set, edges=()): - # type: (EqualitySet, Iterable[IGNode]) -> None - self.equality_set = equality_set + def __init__(self, merged_reg_set, edges=()): + # type: (MergedRegSet, Iterable[IGNode]) -> None + self.merged_reg_set = merged_reg_set self.edges = set(edges) def add_edge(self, other): @@ -976,11 +861,11 @@ class IGNode: def __eq__(self, other): # type: (object) -> bool if isinstance(other, IGNode): - return self.equality_set == other.equality_set + return self.merged_reg_set == other.merged_reg_set return NotImplemented def __hash__(self): - return self.equality_set.__hash__() + return hash(self.merged_reg_set) def __repr__(self, nodes=None): # type: (None | dict[IGNode, int]) -> str @@ -991,18 +876,18 @@ class IGNode: nodes[self] = len(nodes) edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}" return (f"IGNode(#{nodes[self]}, " - f"equality_set={self.equality_set}, " + f"merged_reg_set={self.merged_reg_set}, " f"edges={edges})") @final -class InterferenceGraph(Mapping[EqualitySet, IGNode]): - def __init__(self, equality_sets): - # type: (Iterable[EqualitySet]) -> None - self.__nodes = {i: IGNode(i) for i in equality_sets} +class InterferenceGraph(Mapping[MergedRegSet, IGNode]): + def __init__(self, merged_reg_sets): + # type: (Iterable[MergedRegSet]) -> None + self.__nodes = {i: IGNode(i) for i in merged_reg_sets} def __getitem__(self, key): - # type: (EqualitySet) -> IGNode + # type: (MergedRegSet) -> IGNode return self.__nodes[key] def __iter__(self): @@ -1014,7 +899,7 @@ class AllocationFailed: __slots__ = "op_idx", "arg", "live_intervals" def __init__(self, op_idx, arg, live_intervals): - # type: (int, SSAVal | VecArg, LiveIntervals) -> None + # type: (int, SSAVal, LiveIntervals) -> None self.op_idx = op_idx self.arg = arg self.live_intervals = live_intervals @@ -1022,11 +907,8 @@ class AllocationFailed: def try_allocate_registers_without_spilling(ops): # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed - live_intervals = LiveIntervals(ops) - def is_constrained(node): - # type: (EqualitySet) -> bool - raise NotImplementedError + live_intervals = LiveIntervals(ops) raise NotImplementedError