--- /dev/null
+"""
+Compiler IR for Toom-Cook algorithm generator for SVP64
+"""
+
+from abc import ABCMeta, abstractmethod
+from collections import defaultdict
+from enum import Enum, EnumMeta, unique
+from functools import lru_cache
+from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
+ TypeVar)
+
+from nmutil.plain_data import plain_data
+
+if TYPE_CHECKING:
+ from typing_extensions import final
+else:
+ def final(v):
+ return v
+
+
+class ABCEnumMeta(EnumMeta, ABCMeta):
+ pass
+
+
class RegLoc(metaclass=ABCMeta):
    """abstract base class for a physical register location that SSA
    values can be allocated to (a GPR range, an XER bit, a stack slot,
    or the global-memory pseudo-register)"""
    __slots__ = ()

    @abstractmethod
    def conflicts(self, other):
        # type: (RegLoc) -> bool
        """return True if `self` and `other` can't be allocated at the
        same time because they overlap"""
        ...

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> RegLoc
        """return the sub-register of `self` of type `subreg_type`
        starting at `offset` register-units into `self`.

        base implementation only supports `offset == 0` returning `self`
        unchanged; subclasses with multi-unit registers override this.
        """
        if self not in subreg_type.reg_class:
            raise ValueError(f"register not a member of subreg_type: "
                             f"reg={self} subreg_type={subreg_type}")
        if offset != 0:
            raise ValueError(f"non-zero sub-register offset not supported "
                             f"for register: {self}")
        return self
+
+
+GPR_COUNT = 128
+
+
@plain_data(frozen=True, unsafe_hash=True)
@final
class GPRRange(RegLoc, Sequence["GPRRange"]):
    """a contiguous range of GPRs: `length` registers starting at GPR
    number `start`. Acts as a Sequence whose items are 1-register
    sub-ranges."""
    __slots__ = "start", "length"

    def __init__(self, start, length=None):
        # type: (int | range, int | None) -> None
        if isinstance(start, range):
            if length is not None:
                raise TypeError("can't specify length when input is a range")
            if start.step != 1:
                raise ValueError("range must have a step of 1")
            length = len(start)
            start = start.start
        elif length is None:
            length = 1
        if length <= 0 or start < 0 or start + length > GPR_COUNT:
            raise ValueError("invalid GPRRange")
        self.start = start
        self.length = length

    @property
    def stop(self):
        # type: () -> int
        """one past the last GPR number in this range"""
        return self.start + self.length

    @property
    def step(self):
        # type: () -> int
        """always 1 -- kept for symmetry with builtin `range`"""
        return 1

    @property
    def range(self):
        # type: () -> range
        """this GPRRange as a builtin `range` of register numbers"""
        return range(self.start, self.stop, self.step)

    def __len__(self):
        return self.length

    def __getitem__(self, item):
        # type: (int | slice) -> GPRRange
        return GPRRange(self.range[item])

    def __contains__(self, value):
        # type: (GPRRange) -> bool
        # containment is sub-range containment, not element equality
        return value.start >= self.start and value.stop <= self.stop

    def index(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        r = self.range[start:end]
        if sub.start < r.start or sub.stop > r.stop:
            raise ValueError("GPR range not found")
        # index is relative to self, not to the searched slice
        return sub.start - self.start

    def count(self, sub, start=None, end=None):
        # type: (GPRRange, int | None, int | None) -> int
        r = self.range[start:end]
        if len(r) == 0:
            return 0
        # a sub-range occurs at most once in a contiguous range
        return int(sub in GPRRange(r))

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        if isinstance(other, GPRRange):
            # GPR ranges conflict iff they overlap
            return self.stop > other.start and other.stop > self.start
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> GPRRange
        if not isinstance(subreg_type, GPRRangeType):
            raise ValueError(f"subreg_type is not a "
                             f"GPRRangeType: {subreg_type}")
        # bug fix: `offset` is relative to `self.start`, so the upper
        # bound must be `self.length` -- the previous check against
        # `self.stop` accepted sub-registers extending past the end of
        # `self` (e.g. offset 5 into a 4-register range starting at 10)
        if offset < 0 or offset + subreg_type.length > self.length:
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return GPRRange(self.start + offset, subreg_type.length)
+
+
+SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
+
+
+@final
+@unique
+class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
+ CY = "CY"
+
+ def conflicts(self, other):
+ # type: (RegLoc) -> bool
+ if isinstance(other, XERBit):
+ return self == other
+ return False
+
+
+@final
+@unique
+class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
+ """singleton representing all non-StackSlot memory -- treated as a single
+ physical register for register allocation purposes.
+ """
+ GlobalMem = "GlobalMem"
+
+ def conflicts(self, other):
+ # type: (RegLoc) -> bool
+ if isinstance(other, GlobalMem):
+ return self == other
+ return False
+
+
@final
class RegClass(AbstractSet[RegLoc]):
    """ an ordered set of registers.
    earlier registers are preferred by the register allocator.
    """

    def __init__(self, regs):
        # type: (Iterable[RegLoc]) -> None

        # use dict to maintain order
        self.__regs = dict.fromkeys(regs)  # type: dict[RegLoc, None]

    def __len__(self):
        return len(self.__regs)

    def __iter__(self):
        return iter(self.__regs)

    def __contains__(self, v):
        # type: (RegLoc) -> bool
        return v in self.__regs

    def __hash__(self):
        # collections.abc.Set defines __eq__, which disables the default
        # hash; restore hashability using the frozenset-style Set._hash
        # helper so RegClass can be an lru_cache key below
        return super()._hash()

    @lru_cache(maxsize=None, typed=True)
    def max_conflicts_with(self, other):
        # type: (RegClass | RegLoc) -> int
        """the largest number of registers in `self` that a single register
        from `other` can conflict with

        NOTE(review): lru_cache on an instance method keeps every RegClass
        and `other` argument alive for the life of the process (flake8
        B019) -- acceptable if reg classes are a small fixed set; confirm.
        """
        if isinstance(other, RegClass):
            return max(self.max_conflicts_with(i) for i in other)
        else:
            return sum(other.conflicts(i) for i in self)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class RegType(metaclass=ABCMeta):
+ __slots__ = ()
+
+ @property
+ @abstractmethod
+ def reg_class(self):
+ # type: () -> RegClass
+ return ...
+
+
+_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+
+
@plain_data(frozen=True, eq=False)
class GPRRangeType(RegType):
    """the type of an SSA value that occupies a contiguous range of
    `length` GPRs"""
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        if length < 1 or length > GPR_COUNT:
            raise ValueError("invalid length")
        self.length = length

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length):
        # type: (int) -> RegClass
        regs = []
        # bug fix: the last valid start is GPR_COUNT - length (inclusive,
        # since GPRRange allows start + length == GPR_COUNT), so the
        # exclusive range bound needs + 1 -- the old code silently dropped
        # the final allocatable register range
        for start in range(GPR_COUNT - length + 1):
            reg = GPRRange(start, length)
            # skip ranges touching SPECIAL_GPRS (r0, r1, r2, r13 --
            # presumably ABI-reserved; confirm against the target ABI)
            if any(i in reg for i in SPECIAL_GPRS):
                continue
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        return GPRRangeType.__get_reg_class(self.length)

    @final
    def __eq__(self, other):
        # equality by length only, shared by all GPRRangeType subclasses
        if isinstance(other, GPRRangeType):
            return self.length == other.length
        return False

    @final
    def __hash__(self):
        return hash(self.length)
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class GPRType(GPRRangeType):
+ __slots__ = ()
+
+ def __init__(self, length=1):
+ if length != 1:
+ raise ValueError("length must be 1")
+ super().__init__(length=1)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class CYType(RegType):
+ __slots__ = ()
+
+ @property
+ def reg_class(self):
+ # type: () -> RegClass
+ return RegClass([XERBit.CY])
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GlobalMemType(RegType):
+ __slots__ = ()
+
+ @property
+ def reg_class(self):
+ # type: () -> RegClass
+ return RegClass([GlobalMem.GlobalMem])
+
+
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
    """a contiguous run of `length_in_slots` stack slots starting at
    slot index `start_slot`"""
    __slots__ = "start_slot", "length_in_slots",

    def __init__(self, start_slot, length_in_slots):
        # type: (int, int) -> None
        # robustness fix: negative start slots were silently accepted
        if start_slot < 0:
            raise ValueError("invalid start_slot")
        self.start_slot = start_slot
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @property
    def stop_slot(self):
        # type: () -> int
        """one past the last slot in this range"""
        return self.start_slot + self.length_in_slots

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        if isinstance(other, StackSlot):
            # stack slot runs conflict iff they overlap
            return (self.stop_slot > other.start_slot
                    and other.stop_slot > self.start_slot)
        return False

    def get_subreg_at_offset(self, subreg_type, offset):
        # type: (RegType, int) -> StackSlot
        if not isinstance(subreg_type, StackSlotType):
            raise ValueError(f"subreg_type is not a "
                             f"StackSlotType: {subreg_type}")
        # bug fix: `offset` is relative to `self.start_slot`, so the
        # bound must be `self.length_in_slots` -- the previous check
        # against `self.stop_slot` accepted sub-registers extending past
        # the end of `self`
        if offset < 0 or (offset + subreg_type.length_in_slots
                          > self.length_in_slots):
            raise ValueError(f"sub-register offset is out of range: {offset}")
        return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
+
+
+STACK_SLOT_COUNT = 128
+
+
@plain_data(frozen=True, eq=False)
@final
class StackSlotType(RegType):
    """the type of an SSA value that occupies `length_in_slots`
    contiguous stack slots"""
    __slots__ = "length_in_slots",

    def __init__(self, length_in_slots=1):
        # type: (int) -> None
        if length_in_slots < 1:
            raise ValueError("invalid length_in_slots")
        self.length_in_slots = length_in_slots

    @staticmethod
    @lru_cache(maxsize=None)
    def __get_reg_class(length_in_slots):
        # type: (int) -> RegClass
        regs = []
        # bug fix: the last valid start is
        # STACK_SLOT_COUNT - length_in_slots (inclusive), so the
        # exclusive range bound needs + 1 -- the old code dropped the
        # final allocatable slot run
        for start in range(STACK_SLOT_COUNT - length_in_slots + 1):
            reg = StackSlot(start, length_in_slots)
            regs.append(reg)
        return RegClass(regs)

    @property
    def reg_class(self):
        # type: () -> RegClass
        return StackSlotType.__get_reg_class(self.length_in_slots)

    @final
    def __eq__(self, other):
        # equality by slot count only
        if isinstance(other, StackSlotType):
            return self.length_in_slots == other.length_in_slots
        return False

    @final
    def __hash__(self):
        return hash(self.length_in_slots)
+
+
@plain_data(frozen=True, eq=False)
@final
class SSAVal(Generic[_RegT_co]):
    """an SSA value: written exactly once, identified by the identity of
    the Op that writes it plus the name of the Op argument it's written
    through."""
    # bug fix: removed the unused "arg_index" slot -- __init__ never
    # assigned it and the class is frozen, so it could never be set;
    # any read of it would raise AttributeError
    __slots__ = "op", "arg_name", "ty"

    def __init__(self, op, arg_name, ty):
        # type: (Op, str, _RegT_co) -> None
        self.op = op
        """the Op that writes this SSAVal"""

        self.arg_name = arg_name
        """the name of the argument of self.op that writes this SSAVal"""

        self.ty = ty
        """the RegType of this SSAVal"""

    def __eq__(self, rhs):
        # identity (not equality) of the writing Op plus the argument
        # name uniquely identify an SSA value
        if isinstance(rhs, SSAVal):
            return (self.op is rhs.op
                    and self.arg_name == rhs.arg_name)
        return False

    def __hash__(self):
        # hash must mirror __eq__: use the Op's id, not its hash
        return hash((id(self.op), self.arg_name))
+
+
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
    """constrains the registers allocated to the values in `lhs` to be
    the same as the registers allocated to the values in `rhs`"""
    __slots__ = "lhs", "rhs"

    def __init__(self, lhs, rhs):
        # type: (list[SSAVal], list[SSAVal]) -> None
        # guard first: an empty side makes the constraint meaningless
        if not lhs or not rhs:
            raise ValueError("can't constrain an empty list to be equal")
        self.lhs = lhs
        self.rhs = rhs
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+class Op(metaclass=ABCMeta):
+ __slots__ = ()
+
+ @abstractmethod
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ ...
+
+ @abstractmethod
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ ...
+
+ def get_equality_constraints(self):
+ # type: () -> Iterable[EqualityConstraint]
+ if False:
+ yield ...
+
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ if False:
+ yield ...
+
+ def __init__(self):
+ pass
+
+
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpLoadFromStackSlot(Op):
    # NOTE(review): the types look swapped relative to the name -- this
    # op consumes a GPR-range `src` and produces a StackSlotType `dest`,
    # which is a *store to* a stack slot, not a load (compare
    # OpStoreToStackSlot, which does the opposite). Either this class
    # name or the data direction appears wrong; confirm with the author.
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, src):
        # type: (SSAVal[GPRRangeType]) -> None
        # one stack slot per GPR in src
        self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
        self.src = src
+
+
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpStoreToStackSlot(Op):
    # NOTE(review): the types look swapped relative to the name -- this
    # op consumes a StackSlotType `src` and produces a GPR-range `dest`,
    # which is a *load from* a stack slot, not a store (compare
    # OpLoadFromStackSlot, which does the opposite). Either this class
    # name or the data direction appears wrong; confirm with the author.
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"src": self.src}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"dest": self.dest}

    def __init__(self, src):
        # type: (SSAVal[StackSlotType]) -> None
        # one GPR per stack slot in src
        self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
        self.src = src
+
+
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpCopy(Op, Generic[_RegT_co]):
    """copies `src` into a freshly-written SSA value `dest` of the same
    register type"""
    __slots__ = "dest", "src"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(src=self.src)

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return dict(dest=self.dest)

    def __init__(self, src):
        # type: (SSAVal[_RegT_co]) -> None
        self.src = src
        self.dest = SSAVal(self, "dest", src.ty)
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpConcat(Op):
+ __slots__ = "dest", "sources"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {f"sources[{i}]": v for i, v in enumerate(self.sources)}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"dest": self.dest}
+
+ def __init__(self, sources):
+ # type: (Iterable[SSAVal[GPRRangeType]]) -> None
+ sources = tuple(sources)
+ self.dest = SSAVal(self, "dest", GPRRangeType(
+ sum(i.ty.length for i in sources)))
+ self.sources = sources
+
+ def get_equality_constraints(self):
+ # type: () -> Iterable[EqualityConstraint]
+ yield EqualityConstraint([self.dest], [*self.sources])
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpSplit(Op):
+ __slots__ = "results", "src"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"src": self.src}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {i.arg_name: i for i in self.results}
+
+ def __init__(self, src, split_indexes):
+ # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
+ ranges = [] # type: list[GPRRangeType]
+ last = 0
+ for i in split_indexes:
+ if not (0 < i < src.ty.length):
+ raise ValueError(f"invalid split index: {i}, must be in "
+ f"0 < i < {src.ty.length}")
+ ranges.append(GPRRangeType(i - last))
+ last = i
+ ranges.append(GPRRangeType(src.ty.length - last))
+ self.src = src
+ self.results = tuple(
+ SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))
+
+ def get_equality_constraints(self):
+ # type: () -> Iterable[EqualityConstraint]
+ yield EqualityConstraint([*self.results], [self.src])
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpAddSubE(Op):
+ __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RT": self.RT, "CY_out": self.CY_out}
+
+ def __init__(self, RA, RB, CY_in, is_sub):
+ # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+ if RA.ty != RB.ty:
+ raise TypeError(f"source types must match: "
+ f"{RA} doesn't match {RB}")
+ self.RT = SSAVal(self, "RT", RA.ty)
+ self.RA = RA
+ self.RB = RB
+ self.CY_in = CY_in
+ self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
+ self.is_sub = is_sub
+
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ yield self.RT, self.RA
+ yield self.RT, self.RB
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpBigIntMulDiv(Op):
+ __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RT": self.RT, "RS": self.RS}
+
+ def __init__(self, RA, RB, RC, is_div):
+ # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+ self.RT = SSAVal(self, "RT", RA.ty)
+ self.RA = RA
+ self.RB = RB
+ self.RC = RC
+ self.RS = SSAVal(self, "RS", RC.ty)
+ self.is_div = is_div
+
+ def get_equality_constraints(self):
+ # type: () -> Iterable[EqualityConstraint]
+ yield EqualityConstraint([self.RC], [self.RS])
+
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ yield self.RT, self.RA
+ yield self.RT, self.RB
+ yield self.RT, self.RC
+ yield self.RT, self.RS
+ yield self.RS, self.RA
+ yield self.RS, self.RB
+
+
+@final
+@unique
+class ShiftKind(Enum):
+ Sl = "sl"
+ Sr = "sr"
+ Sra = "sra"
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpBigIntShift(Op):
+ __slots__ = "RT", "inp", "sh", "kind"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"inp": self.inp, "sh": self.sh}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RT": self.RT}
+
+ def __init__(self, inp, sh, kind):
+ # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+ self.RT = SSAVal(self, "RT", inp.ty)
+ self.inp = inp
+ self.sh = sh
+ self.kind = kind
+
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ yield self.RT, self.inp
+ yield self.RT, self.sh
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLI(Op):
+ __slots__ = "out", "value"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"out": self.out}
+
+ def __init__(self, value, length=1):
+ # type: (int, int) -> None
+ self.out = SSAVal(self, "out", GPRRangeType(length))
+ self.value = value
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpClearCY(Op):
+ __slots__ = "out",
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"out": self.out}
+
+ def __init__(self):
+ # type: () -> None
+ self.out = SSAVal(self, "out", CYType())
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLoad(Op):
+ __slots__ = "RT", "RA", "offset", "mem"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RA": self.RA, "mem": self.mem}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RT": self.RT}
+
+ def __init__(self, RA, offset, mem, length=1):
+ # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+ self.RT = SSAVal(self, "RT", GPRRangeType(length))
+ self.RA = RA
+ self.offset = offset
+ self.mem = mem
+
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ if self.RT.ty.length > 1:
+ yield self.RT, self.RA
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpStore(Op):
+ __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"mem_out": self.mem_out}
+
+ def __init__(self, RS, RA, offset, mem_in):
+ # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+ self.RS = RS
+ self.RA = RA
+ self.offset = offset
+ self.mem_in = mem_in
+ self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpFuncArg(Op):
+ __slots__ = "out",
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"out": self.out}
+
+ def __init__(self, ty):
+ # type: (RegType) -> None
+ self.out = SSAVal(self, "out", ty)
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpInputMem(Op):
+ __slots__ = "out",
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"out": self.out}
+
+ def __init__(self):
+ # type: () -> None
+ self.out = SSAVal(self, "out", GlobalMemType())
+
+
def op_set_to_list(ops):
    # type: (Iterable[Op]) -> list[Op]
    """topologically sort `ops` into data-dependency order, so every op
    appears after all the ops that write its inputs.

    Raises ValueError if more than one op writes the same SSA value, or
    if some op is part of a dependency cycle or reads a value that is
    never written.
    """
    # worklists[n] holds the ops still waiting on n unwritten input values
    worklists = [set()]  # type: list[set[Op]]
    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
    for op in ops:
        # bug fix: count *distinct* input values. The decrement loop
        # below fires once per (value, reading-op) pair, so an op that
        # reads the same SSAVal through several arguments must only be
        # charged once -- the old per-argument count left such ops stuck
        # and misreported them as part of a dependency loop.
        input_vals = set(op.inputs().values())
        input_count = len(input_vals)
        for val in input_vals:
            input_vals_to_ops_map[val].add(op)
        while len(worklists) <= input_count:
            worklists.append(set())
        ops_to_pending_input_count_map[op] = input_count
        worklists[input_count].add(op)
    retval = []  # type: list[Op]
    ready_vals = set()  # type: set[SSAVal]
    # repeatedly emit an op all of whose inputs have been written
    while len(worklists[0]) != 0:
        writing_op = worklists[0].pop()
        retval.append(writing_op)
        for val in writing_op.outputs().values():
            if val in ready_vals:
                raise ValueError(f"multiple instructions must not write "
                                 f"to the same SSA value: {val}")
            ready_vals.add(val)
            for reading_op in input_vals_to_ops_map[val]:
                pending = ops_to_pending_input_count_map[reading_op]
                worklists[pending].remove(reading_op)
                pending -= 1
                worklists[pending].add(reading_op)
                ops_to_pending_input_count_map[reading_op] = pending
    # anything left over never became ready: cycle or missing writer
    for worklist in worklists:
        for op in worklist:
            raise ValueError(f"instruction is part of a dependency loop or "
                             f"its inputs are never written: {op}")
    return retval
--- /dev/null
+"""
+Register Allocator for Toom-Cook algorithm generator for SVP64
+
+this uses an algorithm based on:
+[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+"""
+
+from itertools import combinations
+from typing import TYPE_CHECKING, Generic, Iterable, Mapping, TypeVar
+
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
+ RegLoc, RegType, SSAVal)
+
+if TYPE_CHECKING:
+ from typing_extensions import Self, final
+else:
+ def final(v):
+ return v
+
+
+_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+
+
+@plain_data(unsafe_hash=True, order=True, frozen=True)
+class LiveInterval:
+ __slots__ = "first_write", "last_use"
+
+ def __init__(self, first_write, last_use=None):
+ # type: (int, int | None) -> None
+ if last_use is None:
+ last_use = first_write
+ if last_use < first_write:
+ raise ValueError("uses must be after first_write")
+ if first_write < 0 or last_use < 0:
+ raise ValueError("indexes must be nonnegative")
+ self.first_write = first_write
+ self.last_use = last_use
+
+ def overlaps(self, other):
+ # type: (LiveInterval) -> bool
+ if self.first_write == other.first_write:
+ return True
+ return self.last_use > other.first_write \
+ and other.last_use > self.first_write
+
+ def __add__(self, use):
+ # type: (int) -> LiveInterval
+ last_use = max(self.last_use, use)
+ return LiveInterval(first_write=self.first_write, last_use=last_use)
+
+ @property
+ def live_after_op_range(self):
+ """the range of op indexes where self is live immediately after the
+ Op at each index
+ """
+ return range(self.first_write, self.last_use)
+
+
+@final
+class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
+ def __init__(self, reg_set):
+ # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
+ self.__items = {} # type: dict[SSAVal[_RegT_co], int]
+ if isinstance(reg_set, SSAVal):
+ reg_set = [(reg_set, 0)]
+ for ssa_val, offset in reg_set:
+ if ssa_val in self.__items:
+ other = self.__items[ssa_val]
+ if offset != other:
+ raise ValueError(
+ f"can't merge register sets: conflicting offsets: "
+ f"for {ssa_val}: {offset} != {other}")
+ else:
+ self.__items[ssa_val] = offset
+ first_item = None
+ for i in self.__items.items():
+ first_item = i
+ break
+ if first_item is None:
+ raise ValueError("can't have empty MergedRegs")
+ first_ssa_val, start = first_item
+ ty = first_ssa_val.ty
+ if isinstance(ty, GPRRangeType):
+ stop = start + ty.length
+ for ssa_val, offset in self.__items.items():
+ if not isinstance(ssa_val.ty, GPRRangeType):
+ raise ValueError(f"can't merge incompatible types: "
+ f"{ssa_val.ty} and {ty}")
+ stop = max(stop, offset + ssa_val.ty.length)
+ start = min(start, offset)
+ ty = GPRRangeType(stop - start)
+ else:
+ stop = 1
+ for ssa_val, offset in self.__items.items():
+ if offset != 0:
+ raise ValueError(f"can't have non-zero offset "
+ f"for {ssa_val.ty}")
+ if ty != ssa_val.ty:
+ raise ValueError(f"can't merge incompatible types: "
+ f"{ssa_val.ty} and {ty}")
+ self.__start = start # type: int
+ self.__stop = stop # type: int
+ self.__ty = ty # type: RegType
+ self.__hash = hash(frozenset(self.items()))
+
+ @staticmethod
+ def from_equality_constraint(constraint_sequence):
+ # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
+ if len(constraint_sequence) == 1:
+ # any type allowed with len = 1
+ return MergedRegSet(constraint_sequence[0])
+ offset = 0
+ retval = []
+ for val in constraint_sequence:
+ if not isinstance(val.ty, GPRRangeType):
+ raise ValueError("equality constraint sequences must only "
+ "have SSAVal type GPRRangeType")
+ retval.append((val, offset))
+ offset += val.ty.length
+ return MergedRegSet(retval)
+
+ @property
+ def ty(self):
+ return self.__ty
+
+ @property
+ def stop(self):
+ return self.__stop
+
+ @property
+ def start(self):
+ return self.__start
+
+ @property
+ def range(self):
+ return range(self.__start, self.__stop)
+
+ def offset_by(self, amount):
+ # type: (int) -> MergedRegSet[_RegT_co]
+ return MergedRegSet((k, v + amount) for k, v in self.items())
+
+ def normalized(self):
+ # type: () -> MergedRegSet[_RegT_co]
+ return self.offset_by(-self.start)
+
+ def with_offset_to_match(self, target):
+ # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
+ for ssa_val, offset in self.items():
+ if ssa_val in target:
+ return self.offset_by(target[ssa_val] - offset)
+ raise ValueError("can't change offset to match unrelated MergedRegSet")
+
+ def __getitem__(self, item):
+ # type: (SSAVal[_RegT_co]) -> int
+ return self.__items[item]
+
+ def __iter__(self):
+ return iter(self.__items)
+
+ def __len__(self):
+ return len(self.__items)
+
+ def __hash__(self):
+ return self.__hash
+
+ def __repr__(self):
+ return f"MergedRegSet({list(self.__items.items())})"
+
+
+@final
+class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
+ def __init__(self, ops):
+ # type: (Iterable[Op]) -> None
+ merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]]
+ for op in ops:
+ for val in (*op.inputs().values(), *op.outputs().values()):
+ if val not in merged_sets:
+ merged_sets[val] = MergedRegSet(val)
+ for e in op.get_equality_constraints():
+ lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
+ rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
+ lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
+ rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
+ full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
+ for val in full_set.keys():
+ merged_sets[val] = full_set
+
+ self.__map = {k: v.normalized() for k, v in merged_sets.items()}
+
+ def __getitem__(self, key):
+ # type: (SSAVal) -> MergedRegSet
+ return self.__map[key]
+
+ def __iter__(self):
+ return iter(self.__map)
+
+ def __len__(self):
+ return len(self.__map)
+
+ def __repr__(self):
+ return f"MergedRegSets(data={self.__map})"
+
+
@final
class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
    """maps each MergedRegSet in `ops` to the LiveInterval over which
    it's live (first write through last use, in op indexes)"""
    def __init__(self, ops):
        # type: (list[Op]) -> None
        self.__merged_reg_sets = MergedRegSets(ops)
        live_intervals = {}  # type: dict[MergedRegSet[_RegT_co], LiveInterval]
        for op_idx, op in enumerate(ops):
            for val in op.inputs().values():
                # NOTE(review): raises a bare KeyError if a value is read
                # before any op wrote it -- confirm whether a friendlier
                # diagnostic is wanted here
                live_intervals[self.__merged_reg_sets[val]] += op_idx
            for val in op.outputs().values():
                reg_set = self.__merged_reg_sets[val]
                if reg_set not in live_intervals:
                    live_intervals[reg_set] = LiveInterval(op_idx)
                else:
                    live_intervals[reg_set] += op_idx
        self.__live_intervals = live_intervals
        # live_after[i] is the set of reg-sets live just after ops[i]
        live_after = []  # type: list[set[MergedRegSet[_RegT_co]]]
        live_after += (set() for _ in ops)
        for reg_set, live_interval in self.__live_intervals.items():
            for i in live_interval.live_after_op_range:
                live_after[i].add(reg_set)
        self.__live_after = [frozenset(i) for i in live_after]

    @property
    def merged_reg_sets(self):
        return self.__merged_reg_sets

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegT_co]) -> LiveInterval
        return self.__live_intervals[key]

    def __iter__(self):
        return iter(self.__live_intervals)

    def __len__(self):
        # bug fix: Mapping declares __len__ abstract; without this
        # method, instantiating LiveIntervals raises TypeError
        return len(self.__live_intervals)

    def reg_sets_live_after(self, op_index):
        # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
        """the MergedRegSets live immediately after the Op at `op_index`"""
        return self.__live_after[op_index]

    def __repr__(self):
        reg_sets_live_after = dict(enumerate(self.__live_after))
        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
                f"merged_reg_sets={self.merged_reg_sets}, "
                f"reg_sets_live_after={reg_sets_live_after})")
+
+
+@final
+class IGNode(Generic[_RegT_co]):
+ """ interference graph node """
+ __slots__ = "merged_reg_set", "edges", "reg"
+
+ def __init__(self, merged_reg_set, edges=(), reg=None):
+ # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
+ self.merged_reg_set = merged_reg_set
+ self.edges = set(edges)
+ self.reg = reg
+
+ def add_edge(self, other):
+ # type: (IGNode) -> None
+ self.edges.add(other)
+ other.edges.add(self)
+
+ def __eq__(self, other):
+ # type: (object) -> bool
+ if isinstance(other, IGNode):
+ return self.merged_reg_set == other.merged_reg_set
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.merged_reg_set)
+
+ def __repr__(self, nodes=None):
+ # type: (None | dict[IGNode, int]) -> str
+ if nodes is None:
+ nodes = {}
+ if self in nodes:
+ return f"<IGNode #{nodes[self]}>"
+ nodes[self] = len(nodes)
+ edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
+ return (f"IGNode(#{nodes[self]}, "
+ f"merged_reg_set={self.merged_reg_set}, "
+ f"edges={edges}, "
+ f"reg={self.reg})")
+
+ @property
+ def reg_class(self):
+ # type: () -> RegClass
+ return self.merged_reg_set.ty.reg_class
+
+ def reg_conflicts_with_neighbors(self, reg):
+ # type: (RegLoc) -> bool
+ for neighbor in self.edges:
+ if neighbor.reg is not None and neighbor.reg.conflicts(reg):
+ return True
+ return False
+
+
@final
class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
    """interference graph: one IGNode per MergedRegSet, with edges added
    later between nodes that can't share registers"""
    def __init__(self, merged_reg_sets):
        # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}

    def __getitem__(self, key):
        # type: (MergedRegSet[_RegT_co]) -> IGNode
        return self.__nodes[key]

    def __iter__(self):
        return iter(self.__nodes)

    def __len__(self):
        # bug fix: Mapping declares __len__ abstract; without this
        # method, instantiating InterferenceGraph raises TypeError
        return len(self.__nodes)

    def __repr__(self):
        nodes = {}
        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
        nodes_text = ", ".join(nodes_text)
        return f"InterferenceGraph(nodes={{{nodes_text}}})"
+
+
+@plain_data()
+class AllocationFailed:
+ __slots__ = "node", "live_intervals", "interference_graph"
+
+ def __init__(self, node, live_intervals, interference_graph):
+ # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+ self.node = node
+ self.live_intervals = live_intervals
+ self.interference_graph = interference_graph
+
+
def try_allocate_registers_without_spilling(ops):
    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
    """graph-coloring register allocation without spilling.

    Returns a map from each SSAVal to its allocated RegLoc, or an
    AllocationFailed describing the first node that couldn't be colored.
    """

    live_intervals = LiveIntervals(ops)
    merged_reg_sets = live_intervals.merged_reg_sets
    interference_graph = InterferenceGraph(merged_reg_sets.values())
    # add interference edges: reg-sets simultaneously live after some op,
    # plus any extra interferences the ops themselves declare
    for op_idx, op in enumerate(ops):
        reg_sets = live_intervals.reg_sets_live_after(op_idx)
        for i, j in combinations(reg_sets, 2):
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])
        for i, j in op.get_extra_interferences():
            i = merged_reg_sets[i]
            j = merged_reg_sets[j]
            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
                interference_graph[i].add_edge(interference_graph[j])

    nodes_remaining = set(interference_graph.values())

    def local_colorability_score(node):
        # type: (IGNode) -> int
        """ returns a positive integer if node is locally colorable, returns
        zero or a negative integer if node isn't known to be locally
        colorable, the more negative the value, the less colorable
        """
        if node not in nodes_remaining:
            raise ValueError()
        retval = len(node.reg_class)
        for neighbor in node.edges:
            if neighbor in nodes_remaining:
                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
        return retval

    # simplify phase: repeatedly remove the most-colorable node; the
    # stack order means the hardest nodes get colored first on pop
    node_stack = []  # type: list[IGNode]
    while True:
        best_node = None  # type: None | IGNode
        best_score = 0
        for node in nodes_remaining:
            score = local_colorability_score(node)
            # the first candidate is always taken so that even
            # not-known-colorable nodes eventually get stacked
            if best_node is None or score > best_score:
                best_node = node
                best_score = score
            if best_score > 0:
                # it's locally colorable, no need to find a better one
                break

        if best_node is None:
            break
        node_stack.append(best_node)
        nodes_remaining.remove(best_node)

    retval = {}  # type: dict[SSAVal, RegLoc]

    # select phase: pop nodes and assign the first non-conflicting register
    while len(node_stack) > 0:
        node = node_stack.pop()
        if node.reg is not None:
            # node came with a pre-assigned register; just validate it
            if node.reg_conflicts_with_neighbors(node.reg):
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)
        else:
            # pick the first non-conflicting register in node.reg_class, since
            # register classes are ordered from most preferred to least
            # preferred register.
            for reg in node.reg_class:
                if not node.reg_conflicts_with_neighbors(reg):
                    node.reg = reg
                    break
            if node.reg is None:
                return AllocationFailed(node=node,
                                        live_intervals=live_intervals,
                                        interference_graph=interference_graph)

        # expand the node's single RegLoc into one RegLoc per member SSAVal
        for ssa_val, offset in node.merged_reg_set.items():
            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)

    return retval
+
+
+def allocate_registers(ops):
+ # type: (list[Op]) -> None
+ raise NotImplementedError
--- /dev/null
+import unittest
+
+from bigint_presentation_code.compiler_ir import Op, op_set_to_list
+
+
+class TestCompilerIR(unittest.TestCase):
+ pass # no tests yet, just testing importing
+
+
+if __name__ == "__main__":
+ unittest.main()
--- /dev/null
+import unittest
+
+from bigint_presentation_code.compiler_ir import Op
+from bigint_presentation_code.register_allocator import (
+ AllocationFailed, allocate_registers,
+ try_allocate_registers_without_spilling)
+
+
+class TestCompilerIR(unittest.TestCase):
+ pass # no tests yet, just testing importing
+
+
+if __name__ == "__main__":
+ unittest.main()
import unittest
-from bigint_presentation_code.toom_cook import Op
+
+import bigint_presentation_code.toom_cook
class TestToomCook(unittest.TestCase):
the register allocator uses an algorithm based on:
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
-
-from abc import ABCMeta, abstractmethod
-from collections import defaultdict
-from enum import Enum, unique, EnumMeta
-from functools import lru_cache
-from itertools import combinations
-from typing import (Sequence, AbstractSet, Iterable, Mapping,
- TYPE_CHECKING, Sequence, TypeVar, Generic)
-
-from nmutil.plain_data import plain_data
-
-if TYPE_CHECKING:
- from typing_extensions import final, Self
-else:
- def final(v):
- return v
-
-
-class ABCEnumMeta(EnumMeta, ABCMeta):
- pass
-
-
-class RegLoc(metaclass=ABCMeta):
- __slots__ = ()
-
- @abstractmethod
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- ...
-
- def get_subreg_at_offset(self, subreg_type, offset):
- # type: (RegType, int) -> RegLoc
- if self not in subreg_type.reg_class:
- raise ValueError(f"register not a member of subreg_type: "
- f"reg={self} subreg_type={subreg_type}")
- if offset != 0:
- raise ValueError(f"non-zero sub-register offset not supported "
- f"for register: {self}")
- return self
-
-
-GPR_COUNT = 128
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GPRRange(RegLoc, Sequence["GPRRange"]):
- __slots__ = "start", "length"
-
- def __init__(self, start, length=None):
- # type: (int | range, int | None) -> None
- if isinstance(start, range):
- if length is not None:
- raise TypeError("can't specify length when input is a range")
- if start.step != 1:
- raise ValueError("range must have a step of 1")
- length = len(start)
- start = start.start
- elif length is None:
- length = 1
- if length <= 0 or start < 0 or start + length > GPR_COUNT:
- raise ValueError("invalid GPRRange")
- self.start = start
- self.length = length
-
- @property
- def stop(self):
- return self.start + self.length
-
- @property
- def step(self):
- return 1
-
- @property
- def range(self):
- return range(self.start, self.stop, self.step)
-
- def __len__(self):
- return self.length
-
- def __getitem__(self, item):
- # type: (int | slice) -> GPRRange
- return GPRRange(self.range[item])
-
- def __contains__(self, value):
- # type: (GPRRange) -> bool
- return value.start >= self.start and value.stop <= self.stop
-
- def index(self, sub, start=None, end=None):
- # type: (GPRRange, int | None, int | None) -> int
- r = self.range[start:end]
- if sub.start < r.start or sub.stop > r.stop:
- raise ValueError("GPR range not found")
- return sub.start - self.start
-
- def count(self, sub, start=None, end=None):
- # type: (GPRRange, int | None, int | None) -> int
- r = self.range[start:end]
- if len(r) == 0:
- return 0
- return int(sub in GPRRange(r))
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, GPRRange):
- return self.stop > other.start and other.stop > self.start
- return False
-
- def get_subreg_at_offset(self, subreg_type, offset):
- # type: (RegType, int) -> GPRRange
- if not isinstance(subreg_type, GPRRangeType):
- raise ValueError(f"subreg_type is not a "
- f"GPRRangeType: {subreg_type}")
- if offset < 0 or offset + subreg_type.length > self.stop:
- raise ValueError(f"sub-register offset is out of range: {offset}")
- return GPRRange(self.start + offset, subreg_type.length)
-
-
-SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
-
-
-@final
-@unique
-class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
- CY = "CY"
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, XERBit):
- return self == other
- return False
-
-
-@final
-@unique
-class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
- """singleton representing all non-StackSlot memory -- treated as a single
- physical register for register allocation purposes.
- """
- GlobalMem = "GlobalMem"
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, GlobalMem):
- return self == other
- return False
-
-
-@final
-class RegClass(AbstractSet[RegLoc]):
- """ an ordered set of registers.
- earlier registers are preferred by the register allocator.
- """
- def __init__(self, regs):
- # type: (Iterable[RegLoc]) -> None
-
- # use dict to maintain order
- self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None]
-
- def __len__(self):
- return len(self.__regs)
-
- def __iter__(self):
- return iter(self.__regs)
-
- def __contains__(self, v):
- # type: (RegLoc) -> bool
- return v in self.__regs
-
- def __hash__(self):
- return super()._hash()
-
- @lru_cache(maxsize=None, typed=True)
- def max_conflicts_with(self, other):
- # type: (RegClass | RegLoc) -> int
- """the largest number of registers in `self` that a single register
- from `other` can conflict with
- """
- if isinstance(other, RegClass):
- return max(self.max_conflicts_with(i) for i in other)
- else:
- return sum(other.conflicts(i) for i in self)
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-class RegType(metaclass=ABCMeta):
- __slots__ = ()
-
- @property
- @abstractmethod
- def reg_class(self):
- # type: () -> RegClass
- return ...
-
-
-@plain_data(frozen=True, eq=False)
-class GPRRangeType(RegType):
- __slots__ = "length",
-
- def __init__(self, length):
- # type: (int) -> None
- if length < 1 or length > GPR_COUNT:
- raise ValueError("invalid length")
- self.length = length
-
- @staticmethod
- @lru_cache(maxsize=None)
- def __get_reg_class(length):
- # type: (int) -> RegClass
- regs = []
- for start in range(GPR_COUNT - length):
- reg = GPRRange(start, length)
- if any(i in reg for i in SPECIAL_GPRS):
- continue
- regs.append(reg)
- return RegClass(regs)
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return GPRRangeType.__get_reg_class(self.length)
-
- @final
- def __eq__(self, other):
- if isinstance(other, GPRRangeType):
- return self.length == other.length
- return False
-
- @final
- def __hash__(self):
- return hash(self.length)
-
-
-@plain_data(frozen=True, eq=False)
-@final
-class GPRType(GPRRangeType):
- __slots__ = ()
-
- def __init__(self, length=1):
- if length != 1:
- raise ValueError("length must be 1")
- super().__init__(length=1)
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class CYType(RegType):
- __slots__ = ()
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return RegClass([XERBit.CY])
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GlobalMemType(RegType):
- __slots__ = ()
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return RegClass([GlobalMem.GlobalMem])
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class StackSlot(RegLoc):
- __slots__ = "start_slot", "length_in_slots",
-
- def __init__(self, start_slot, length_in_slots):
- # type: (int, int) -> None
- self.start_slot = start_slot
- if length_in_slots < 1:
- raise ValueError("invalid length_in_slots")
- self.length_in_slots = length_in_slots
-
- @property
- def stop_slot(self):
- return self.start_slot + self.length_in_slots
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, StackSlot):
- return (self.stop_slot > other.start_slot
- and other.stop_slot > self.start_slot)
- return False
-
- def get_subreg_at_offset(self, subreg_type, offset):
- # type: (RegType, int) -> StackSlot
- if not isinstance(subreg_type, StackSlotType):
- raise ValueError(f"subreg_type is not a "
- f"StackSlotType: {subreg_type}")
- if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot:
- raise ValueError(f"sub-register offset is out of range: {offset}")
- return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
-
-
-STACK_SLOT_COUNT = 128
-
-
-@plain_data(frozen=True, eq=False)
-@final
-class StackSlotType(RegType):
- __slots__ = "length_in_slots",
-
- def __init__(self, length_in_slots=1):
- # type: (int) -> None
- if length_in_slots < 1:
- raise ValueError("invalid length_in_slots")
- self.length_in_slots = length_in_slots
-
- @staticmethod
- @lru_cache(maxsize=None)
- def __get_reg_class(length_in_slots):
- # type: (int) -> RegClass
- regs = []
- for start in range(STACK_SLOT_COUNT - length_in_slots):
- reg = StackSlot(start, length_in_slots)
- regs.append(reg)
- return RegClass(regs)
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return StackSlotType.__get_reg_class(self.length_in_slots)
-
- @final
- def __eq__(self, other):
- if isinstance(other, StackSlotType):
- return self.length_in_slots == other.length_in_slots
- return False
-
- @final
- def __hash__(self):
- return hash(self.length_in_slots)
-
-
-_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
-
-
-@plain_data(frozen=True, eq=False)
-@final
-class SSAVal(Generic[_RegT_co]):
- __slots__ = "op", "arg_name", "ty", "arg_index"
-
- def __init__(self, op, arg_name, ty):
- # type: (Op, str, _RegT_co) -> None
- self.op = op
- """the Op that writes this SSAVal"""
-
- self.arg_name = arg_name
- """the name of the argument of self.op that writes this SSAVal"""
-
- self.ty = ty
-
- def __eq__(self, rhs):
- if isinstance(rhs, SSAVal):
- return (self.op is rhs.op
- and self.arg_name == rhs.arg_name)
- return False
-
- def __hash__(self):
- return hash((id(self.op), self.arg_name))
-
-
-@final
-@plain_data(unsafe_hash=True, frozen=True)
-class EqualityConstraint:
- __slots__ = "lhs", "rhs"
-
- def __init__(self, lhs, rhs):
- # type: (list[SSAVal], list[SSAVal]) -> None
- self.lhs = lhs
- self.rhs = rhs
- if len(lhs) == 0 or len(rhs) == 0:
- raise ValueError("can't constrain an empty list to be equal")
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-class Op(metaclass=ABCMeta):
- __slots__ = ()
-
- @abstractmethod
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- ...
-
- @abstractmethod
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- ...
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- if False:
- yield ...
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- if False:
- yield ...
-
- def __init__(self):
- pass
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpLoadFromStackSlot(Op):
- __slots__ = "dest", "src"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"src": self.src}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, src):
- # type: (SSAVal[GPRRangeType]) -> None
- self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
- self.src = src
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpStoreToStackSlot(Op):
- __slots__ = "dest", "src"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"src": self.src}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, src):
- # type: (SSAVal[StackSlotType]) -> None
- self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
- self.src = src
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpCopy(Op, Generic[_RegT_co]):
- __slots__ = "dest", "src"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"src": self.src}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, src):
- # type: (SSAVal[_RegT_co]) -> None
- self.dest = SSAVal(self, "dest", src.ty)
- self.src = src
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpConcat(Op):
- __slots__ = "dest", "sources"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {f"sources[{i}]": v for i, v in enumerate(self.sources)}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, sources):
- # type: (Iterable[SSAVal[GPRRangeType]]) -> None
- sources = tuple(sources)
- self.dest = SSAVal(self, "dest", GPRRangeType(
- sum(i.ty.length for i in sources)))
- self.sources = sources
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- yield EqualityConstraint([self.dest], [*self.sources])
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpSplit(Op):
- __slots__ = "results", "src"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"src": self.src}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {i.arg_name: i for i in self.results}
-
- def __init__(self, src, split_indexes):
- # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
- ranges = [] # type: list[GPRRangeType]
- last = 0
- for i in split_indexes:
- if not (0 < i < src.ty.length):
- raise ValueError(f"invalid split index: {i}, must be in "
- f"0 < i < {src.ty.length}")
- ranges.append(GPRRangeType(i - last))
- last = i
- ranges.append(GPRRangeType(src.ty.length - last))
- self.src = src
- self.results = tuple(
- SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- yield EqualityConstraint([*self.results], [self.src])
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpAddSubE(Op):
- __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RT": self.RT, "CY_out": self.CY_out}
-
- def __init__(self, RA, RB, CY_in, is_sub):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
- if RA.ty != RB.ty:
- raise TypeError(f"source types must match: "
- f"{RA} doesn't match {RB}")
- self.RT = SSAVal(self, "RT", RA.ty)
- self.RA = RA
- self.RB = RB
- self.CY_in = CY_in
- self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
- self.is_sub = is_sub
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.RT, self.RA
- yield self.RT, self.RB
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpBigIntMulDiv(Op):
- __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RT": self.RT, "RS": self.RS}
-
- def __init__(self, RA, RB, RC, is_div):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
- self.RT = SSAVal(self, "RT", RA.ty)
- self.RA = RA
- self.RB = RB
- self.RC = RC
- self.RS = SSAVal(self, "RS", RC.ty)
- self.is_div = is_div
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- yield EqualityConstraint([self.RC], [self.RS])
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.RT, self.RA
- yield self.RT, self.RB
- yield self.RT, self.RC
- yield self.RT, self.RS
- yield self.RS, self.RA
- yield self.RS, self.RB
-
-
-@final
-@unique
-class ShiftKind(Enum):
- Sl = "sl"
- Sr = "sr"
- Sra = "sra"
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpBigIntShift(Op):
- __slots__ = "RT", "inp", "sh", "kind"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"inp": self.inp, "sh": self.sh}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RT": self.RT}
-
- def __init__(self, inp, sh, kind):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
- self.RT = SSAVal(self, "RT", inp.ty)
- self.inp = inp
- self.sh = sh
- self.kind = kind
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.RT, self.inp
- yield self.RT, self.sh
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpLI(Op):
- __slots__ = "out", "value"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self, value, length=1):
- # type: (int, int) -> None
- self.out = SSAVal(self, "out", GPRRangeType(length))
- self.value = value
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpClearCY(Op):
- __slots__ = "out",
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self):
- # type: () -> None
- self.out = SSAVal(self, "out", CYType())
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpLoad(Op):
- __slots__ = "RT", "RA", "offset", "mem"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RA": self.RA, "mem": self.mem}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RT": self.RT}
-
- def __init__(self, RA, offset, mem, length=1):
- # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
- self.RT = SSAVal(self, "RT", GPRRangeType(length))
- self.RA = RA
- self.offset = offset
- self.mem = mem
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- if self.RT.ty.length > 1:
- yield self.RT, self.RA
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpStore(Op):
- __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"mem_out": self.mem_out}
-
- def __init__(self, RS, RA, offset, mem_in):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
- self.RS = RS
- self.RA = RA
- self.offset = offset
- self.mem_in = mem_in
- self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpFuncArg(Op):
- __slots__ = "out",
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self, ty):
- # type: (RegType) -> None
- self.out = SSAVal(self, "out", ty)
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpInputMem(Op):
- __slots__ = "out",
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self):
- # type: () -> None
- self.out = SSAVal(self, "out", GlobalMemType())
-
-
-def op_set_to_list(ops):
- # type: (Iterable[Op]) -> list[Op]
- worklists = [set()] # type: list[set[Op]]
- input_vals_to_ops_map = defaultdict(set) # type: dict[SSAVal, set[Op]]
- ops_to_pending_input_count_map = {} # type: dict[Op, int]
- for op in ops:
- input_count = 0
- for val in op.inputs().values():
- input_count += 1
- input_vals_to_ops_map[val].add(op)
- while len(worklists) <= input_count:
- worklists.append(set())
- ops_to_pending_input_count_map[op] = input_count
- worklists[input_count].add(op)
- retval = [] # type: list[Op]
- ready_vals = set() # type: set[SSAVal]
- while len(worklists[0]) != 0:
- writing_op = worklists[0].pop()
- retval.append(writing_op)
- for val in writing_op.outputs().values():
- if val in ready_vals:
- raise ValueError(f"multiple instructions must not write "
- f"to the same SSA value: {val}")
- ready_vals.add(val)
- for reading_op in input_vals_to_ops_map[val]:
- pending = ops_to_pending_input_count_map[reading_op]
- worklists[pending].remove(reading_op)
- pending -= 1
- worklists[pending].add(reading_op)
- ops_to_pending_input_count_map[reading_op] = pending
- for worklist in worklists:
- for op in worklist:
- raise ValueError(f"instruction is part of a dependency loop or "
- f"its inputs are never written: {op}")
- return retval
-
-
-@plain_data(unsafe_hash=True, order=True, frozen=True)
-class LiveInterval:
- __slots__ = "first_write", "last_use"
-
- def __init__(self, first_write, last_use=None):
- # type: (int, int | None) -> None
- if last_use is None:
- last_use = first_write
- if last_use < first_write:
- raise ValueError("uses must be after first_write")
- if first_write < 0 or last_use < 0:
- raise ValueError("indexes must be nonnegative")
- self.first_write = first_write
- self.last_use = last_use
-
- def overlaps(self, other):
- # type: (LiveInterval) -> bool
- if self.first_write == other.first_write:
- return True
- return self.last_use > other.first_write \
- and other.last_use > self.first_write
-
- def __add__(self, use):
- # type: (int) -> LiveInterval
- last_use = max(self.last_use, use)
- return LiveInterval(first_write=self.first_write, last_use=last_use)
-
- @property
- def live_after_op_range(self):
- """the range of op indexes where self is live immediately after the
- Op at each index
- """
- return range(self.first_write, self.last_use)
-
-
-@final
-class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
- def __init__(self, reg_set):
- # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
- self.__items = {} # type: dict[SSAVal[_RegT_co], int]
- if isinstance(reg_set, SSAVal):
- reg_set = [(reg_set, 0)]
- for ssa_val, offset in reg_set:
- if ssa_val in self.__items:
- other = self.__items[ssa_val]
- if offset != other:
- raise ValueError(
- f"can't merge register sets: conflicting offsets: "
- f"for {ssa_val}: {offset} != {other}")
- else:
- self.__items[ssa_val] = offset
- first_item = None
- for i in self.__items.items():
- first_item = i
- break
- if first_item is None:
- raise ValueError("can't have empty MergedRegs")
- first_ssa_val, start = first_item
- ty = first_ssa_val.ty
- if isinstance(ty, GPRRangeType):
- stop = start + ty.length
- for ssa_val, offset in self.__items.items():
- if not isinstance(ssa_val.ty, GPRRangeType):
- raise ValueError(f"can't merge incompatible types: "
- f"{ssa_val.ty} and {ty}")
- stop = max(stop, offset + ssa_val.ty.length)
- start = min(start, offset)
- ty = GPRRangeType(stop - start)
- else:
- stop = 1
- for ssa_val, offset in self.__items.items():
- if offset != 0:
- raise ValueError(f"can't have non-zero offset "
- f"for {ssa_val.ty}")
- if ty != ssa_val.ty:
- raise ValueError(f"can't merge incompatible types: "
- f"{ssa_val.ty} and {ty}")
- self.__start = start # type: int
- self.__stop = stop # type: int
- self.__ty = ty # type: RegType
- self.__hash = hash(frozenset(self.items()))
-
- @staticmethod
- def from_equality_constraint(constraint_sequence):
- # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
- if len(constraint_sequence) == 1:
- # any type allowed with len = 1
- return MergedRegSet(constraint_sequence[0])
- offset = 0
- retval = []
- for val in constraint_sequence:
- if not isinstance(val.ty, GPRRangeType):
- raise ValueError("equality constraint sequences must only "
- "have SSAVal type GPRRangeType")
- retval.append((val, offset))
- offset += val.ty.length
- return MergedRegSet(retval)
-
- @property
- def ty(self):
- return self.__ty
-
- @property
- def stop(self):
- return self.__stop
-
- @property
- def start(self):
- return self.__start
-
- @property
- def range(self):
- return range(self.__start, self.__stop)
-
- def offset_by(self, amount):
- # type: (int) -> MergedRegSet[_RegT_co]
- return MergedRegSet((k, v + amount) for k, v in self.items())
-
- def normalized(self):
- # type: () -> MergedRegSet[_RegT_co]
- return self.offset_by(-self.start)
-
- def with_offset_to_match(self, target):
- # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
- for ssa_val, offset in self.items():
- if ssa_val in target:
- return self.offset_by(target[ssa_val] - offset)
- raise ValueError("can't change offset to match unrelated MergedRegSet")
-
- def __getitem__(self, item):
- # type: (SSAVal[_RegT_co]) -> int
- return self.__items[item]
-
- def __iter__(self):
- return iter(self.__items)
-
- def __len__(self):
- return len(self.__items)
-
- def __hash__(self):
- return self.__hash
-
- def __repr__(self):
- return f"MergedRegSet({list(self.__items.items())})"
-
-
-@final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
- def __init__(self, ops):
- # type: (Iterable[Op]) -> None
- merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]]
- for op in ops:
- for val in (*op.inputs().values(), *op.outputs().values()):
- if val not in merged_sets:
- merged_sets[val] = MergedRegSet(val)
- for e in op.get_equality_constraints():
- lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
- rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
- lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
- rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
- full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
- for val in full_set.keys():
- merged_sets[val] = full_set
-
- self.__map = {k: v.normalized() for k, v in merged_sets.items()}
-
- def __getitem__(self, key):
- # type: (SSAVal) -> MergedRegSet
- return self.__map[key]
-
- def __iter__(self):
- return iter(self.__map)
-
- def __len__(self):
- return len(self.__map)
-
- def __repr__(self):
- return f"MergedRegSets(data={self.__map})"
-
-
-@final
-class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
- def __init__(self, ops):
- # type: (list[Op]) -> None
- self.__merged_reg_sets = MergedRegSets(ops)
- live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval]
- for op_idx, op in enumerate(ops):
- for val in op.inputs().values():
- live_intervals[self.__merged_reg_sets[val]] += op_idx
- for val in op.outputs().values():
- reg_set = self.__merged_reg_sets[val]
- if reg_set not in live_intervals:
- live_intervals[reg_set] = LiveInterval(op_idx)
- else:
- live_intervals[reg_set] += op_idx
- self.__live_intervals = live_intervals
- live_after = [] # type: list[set[MergedRegSet[_RegT_co]]]
- live_after += (set() for _ in ops)
- for reg_set, live_interval in self.__live_intervals.items():
- for i in live_interval.live_after_op_range:
- live_after[i].add(reg_set)
- self.__live_after = [frozenset(i) for i in live_after]
-
- @property
- def merged_reg_sets(self):
- return self.__merged_reg_sets
-
- def __getitem__(self, key):
- # type: (MergedRegSet[_RegT_co]) -> LiveInterval
- return self.__live_intervals[key]
-
- def __iter__(self):
- return iter(self.__live_intervals)
-
- def reg_sets_live_after(self, op_index):
- # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
- return self.__live_after[op_index]
-
- def __repr__(self):
- reg_sets_live_after = dict(enumerate(self.__live_after))
- return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
- f"merged_reg_sets={self.merged_reg_sets}, "
- f"reg_sets_live_after={reg_sets_live_after})")
-
-
-@final
-class IGNode(Generic[_RegT_co]):
- """ interference graph node """
- __slots__ = "merged_reg_set", "edges", "reg"
-
- def __init__(self, merged_reg_set, edges=(), reg=None):
- # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
- self.merged_reg_set = merged_reg_set
- self.edges = set(edges)
- self.reg = reg
-
- def add_edge(self, other):
- # type: (IGNode) -> None
- self.edges.add(other)
- other.edges.add(self)
-
- def __eq__(self, other):
- # type: (object) -> bool
- if isinstance(other, IGNode):
- return self.merged_reg_set == other.merged_reg_set
- return NotImplemented
-
- def __hash__(self):
- return hash(self.merged_reg_set)
-
- def __repr__(self, nodes=None):
- # type: (None | dict[IGNode, int]) -> str
- if nodes is None:
- nodes = {}
- if self in nodes:
- return f"<IGNode #{nodes[self]}>"
- nodes[self] = len(nodes)
- edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
- return (f"IGNode(#{nodes[self]}, "
- f"merged_reg_set={self.merged_reg_set}, "
- f"edges={edges}, "
- f"reg={self.reg})")
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return self.merged_reg_set.ty.reg_class
-
- def reg_conflicts_with_neighbors(self, reg):
- # type: (RegLoc) -> bool
- for neighbor in self.edges:
- if neighbor.reg is not None and neighbor.reg.conflicts(reg):
- return True
- return False
-
-
-@final
-class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
- def __init__(self, merged_reg_sets):
- # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
- self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
-
- def __getitem__(self, key):
- # type: (MergedRegSet[_RegT_co]) -> IGNode
- return self.__nodes[key]
-
- def __iter__(self):
- return iter(self.__nodes)
-
- def __repr__(self):
- nodes = {}
- nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
- nodes_text = ", ".join(nodes_text)
- return f"InterferenceGraph(nodes={{{nodes_text}}})"
-
-
-@plain_data()
-class AllocationFailed:
- __slots__ = "node", "live_intervals", "interference_graph"
-
- def __init__(self, node, live_intervals, interference_graph):
- # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
- self.node = node
- self.live_intervals = live_intervals
- self.interference_graph = interference_graph
-
-
-def try_allocate_registers_without_spilling(ops):
- # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
-
- live_intervals = LiveIntervals(ops)
- merged_reg_sets = live_intervals.merged_reg_sets
- interference_graph = InterferenceGraph(merged_reg_sets.values())
- for op_idx, op in enumerate(ops):
- reg_sets = live_intervals.reg_sets_live_after(op_idx)
- for i, j in combinations(reg_sets, 2):
- if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
- interference_graph[i].add_edge(interference_graph[j])
- for i, j in op.get_extra_interferences():
- i = merged_reg_sets[i]
- j = merged_reg_sets[j]
- if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
- interference_graph[i].add_edge(interference_graph[j])
-
- nodes_remaining = set(interference_graph.values())
-
- def local_colorability_score(node):
- # type: (IGNode) -> int
- """ returns a positive integer if node is locally colorable, returns
- zero or a negative integer if node isn't known to be locally
- colorable, the more negative the value, the less colorable
- """
- if node not in nodes_remaining:
- raise ValueError()
- retval = len(node.reg_class)
- for neighbor in node.edges:
- if neighbor in nodes_remaining:
- retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
- return retval
-
- node_stack = [] # type: list[IGNode]
- while True:
- best_node = None # type: None | IGNode
- best_score = 0
- for node in nodes_remaining:
- score = local_colorability_score(node)
- if best_node is None or score > best_score:
- best_node = node
- best_score = score
- if best_score > 0:
- # it's locally colorable, no need to find a better one
- break
-
- if best_node is None:
- break
- node_stack.append(best_node)
- nodes_remaining.remove(best_node)
-
- retval = {} # type: dict[SSAVal, RegLoc]
-
- while len(node_stack) > 0:
- node = node_stack.pop()
- if node.reg is not None:
- if node.reg_conflicts_with_neighbors(node.reg):
- return AllocationFailed(node=node,
- live_intervals=live_intervals,
- interference_graph=interference_graph)
- else:
- # pick the first non-conflicting register in node.reg_class, since
- # register classes are ordered from most preferred to least
- # preferred register.
- for reg in node.reg_class:
- if not node.reg_conflicts_with_neighbors(reg):
- node.reg = reg
- break
- if node.reg is None:
- return AllocationFailed(node=node,
- live_intervals=live_intervals,
- interference_graph=interference_graph)
-
- for ssa_val, offset in node.merged_reg_set.items():
- retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
-
- return retval
-
-
-def allocate_registers(ops):
- # type: (list[Op]) -> None
- raise NotImplementedError
+from bigint_presentation_code.compiler_ir import Op
+from bigint_presentation_code.register_allocator import allocate_registers, AllocationFailed