From cc516cd03faca03319c6500b2ca4101e23df41b7 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 13 Oct 2022 22:52:15 -0700 Subject: [PATCH] try_allocate_registers_without_spilling is completed, but untested --- src/bigint_presentation_code/toom_cook.py | 276 +++++++++++++++++++--- 1 file changed, 237 insertions(+), 39 deletions(-) diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index ce26d68..88e8f7e 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -11,7 +11,7 @@ from enum import Enum, unique, EnumMeta from functools import lru_cache from itertools import combinations from typing import (Sequence, AbstractSet, Iterable, Mapping, - TYPE_CHECKING, Sequence, Sized, TypeVar, Generic) + TYPE_CHECKING, Sequence, TypeVar, Generic) from nmutil.plain_data import plain_data @@ -26,11 +26,7 @@ class ABCEnumMeta(EnumMeta, ABCMeta): pass -class PhysLoc(metaclass=ABCMeta): - __slots__ = () - - -class RegLoc(PhysLoc): +class RegLoc(metaclass=ABCMeta): __slots__ = () @abstractmethod @@ -38,9 +34,15 @@ class RegLoc(PhysLoc): # type: (RegLoc) -> bool ... - -class GPRRangeOrStackLoc(PhysLoc, Sized): - __slots__ = () + def get_subreg_at_offset(self, subreg_type, offset): + # type: (RegType, int) -> RegLoc + if self not in subreg_type.reg_class: + raise ValueError(f"register not a member of subreg_type: " + f"reg={self} subreg_type={subreg_type}") + if offset != 0: + raise ValueError(f"non-zero sub-register offset not supported " + f"for register: {self}") + return self GPR_COUNT = 128 @@ -48,7 +50,7 @@ GPR_COUNT = 128 @plain_data(frozen=True, unsafe_hash=True) @final -class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]): +class GPRRange(RegLoc, Sequence["GPRRange"]): __slots__ = "start", "length" def __init__(self, start, length=None): @@ -110,6 +112,15 @@ class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]): return self.stop > other.start and other.stop > self.start return False + def get_subreg_at_offset(self, subreg_type, offset): + # type: (RegType, int) -> GPRRange + if not isinstance(subreg_type, GPRRangeType): + raise ValueError(f"subreg_type is not a " + f"GPRRangeType: {subreg_type}") + if offset < 0 or offset + subreg_type.length > self.stop: + raise ValueError(f"sub-register offset is out of range: {offset}") + return GPRRange(self.start + offset, subreg_type.length) + SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13) @@ -143,9 +154,14 @@ class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta): @final class RegClass(AbstractSet[RegLoc]): + """ an ordered set of registers. + earlier registers are preferred by the register allocator. + """ def __init__(self, regs): # type: (Iterable[RegLoc]) -> None - self.__regs = frozenset(regs) + + # use dict to maintain order + self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None] def __len__(self): return len(self.__regs) @@ -160,6 +176,17 @@ class RegClass(AbstractSet[RegLoc]): def __hash__(self): return super()._hash() + @lru_cache(maxsize=None, typed=True) + def max_conflicts_with(self, other): + # type: (RegClass | RegLoc) -> int + """the largest number of registers in `self` that a single register + from `other` can conflict with + """ + if isinstance(other, RegClass): + return max(self.max_conflicts_with(i) for i in other) + else: + return sum(other.conflicts(i) for i in self) + @plain_data(frozen=True, unsafe_hash=True) class RegType(metaclass=ABCMeta): @@ -183,7 +210,7 @@ class GPRRangeType(RegType): self.length = length @staticmethod - @lru_cache() + @lru_cache(maxsize=None) def __get_reg_class(length): # type: (int) -> RegClass regs = [] @@ -243,21 +270,77 @@ class GlobalMemType(RegType): return RegClass([GlobalMem.GlobalMem]) -@plain_data() +@plain_data(frozen=True, unsafe_hash=True) @final -class StackSlot(GPRRangeOrStackLoc): - """a stack slot. Use OpCopy to load from/store into this stack slot.""" - __slots__ = "offset", "length" +class StackSlot(RegLoc): + __slots__ = "start_slot", "length_in_slots", - def __init__(self, offset=None, length=1): - # type: (int | None, int) -> None - self.offset = offset - if length < 1: - raise ValueError("invalid length") - self.length = length + def __init__(self, start_slot, length_in_slots): + # type: (int, int) -> None + self.start_slot = start_slot + if length_in_slots < 1: + raise ValueError("invalid length_in_slots") + self.length_in_slots = length_in_slots - def __len__(self): - return self.length + @property + def stop_slot(self): + return self.start_slot + self.length_in_slots + + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, StackSlot): + return (self.stop_slot > other.start_slot + and other.stop_slot > self.start_slot) + return False + + def get_subreg_at_offset(self, subreg_type, offset): + # type: (RegType, int) -> StackSlot + if not isinstance(subreg_type, StackSlotType): + raise ValueError(f"subreg_type is not a " + f"StackSlotType: {subreg_type}") + if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot: + raise ValueError(f"sub-register offset is out of range: {offset}") + return StackSlot(self.start_slot + offset, subreg_type.length_in_slots) + + +STACK_SLOT_COUNT = 128 + + +@plain_data(frozen=True, eq=False) +@final +class StackSlotType(RegType): + __slots__ = "length_in_slots", + + def __init__(self, length_in_slots=1): + # type: (int) -> None + if length_in_slots < 1: + raise ValueError("invalid length_in_slots") + self.length_in_slots = length_in_slots + + @staticmethod + @lru_cache(maxsize=None) + def __get_reg_class(length_in_slots): + # type: (int) -> RegClass + regs = [] + for start in range(STACK_SLOT_COUNT - length_in_slots): + reg = StackSlot(start, length_in_slots) + regs.append(reg) + return RegClass(regs) + + @property + def reg_class(self): + # type: () -> RegClass + return StackSlotType.__get_reg_class(self.length_in_slots) + + @final + def __eq__(self, other): + if isinstance(other, StackSlotType): + return self.length_in_slots == other.length_in_slots + return False + + @final + def __hash__(self): + return hash(self.length_in_slots) _RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True) @@ -329,6 +412,44 @@ class Op(metaclass=ABCMeta): pass +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpLoadFromStackSlot(Op): + __slots__ = "dest", "src" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"dest": self.dest} + + def __init__(self, src): + # type: (SSAVal[GPRRangeType]) -> None + self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) + self.src = src + + +@plain_data(unsafe_hash=True, frozen=True) +@final +class OpStoreToStackSlot(Op): + __slots__ = "dest", "src" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"src": self.src} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"dest": self.dest} + + def __init__(self, src): + # type: (SSAVal[StackSlotType]) -> None + self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) + self.src = src + + @plain_data(unsafe_hash=True, frozen=True) @final class OpCopy(Op, Generic[_RegT_co]): @@ -745,6 +866,7 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): self.__start = start # type: int self.__stop = stop # type: int self.__ty = ty # type: RegType + self.__hash = hash(frozenset(self.items())) @staticmethod def from_equality_constraint(constraint_sequence): @@ -804,7 +926,7 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): return len(self.__items) def __hash__(self): - return hash(frozenset(self.items())) + return self.__hash def __repr__(self): return f"MergedRegSet({list(self.__items.items())})" @@ -892,12 +1014,13 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): @final class IGNode(Generic[_RegT_co]): """ interference graph node """ - __slots__ = "merged_reg_set", "edges" + __slots__ = "merged_reg_set", "edges", "reg" - def __init__(self, merged_reg_set, edges=()): - # type: (MergedRegSet[_RegT_co], Iterable[IGNode]) -> None + def __init__(self, merged_reg_set, edges=(), reg=None): + # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None self.merged_reg_set = merged_reg_set self.edges = set(edges) + self.reg = reg def add_edge(self, other): # type: (IGNode) -> None @@ -923,7 +1046,20 @@ class IGNode(Generic[_RegT_co]): edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}" return (f"IGNode(#{nodes[self]}, " f"merged_reg_set={self.merged_reg_set}, " - f"edges={edges})") + f"edges={edges}, " + f"reg={self.reg})") + + @property + def reg_class(self): + # type: () -> RegClass + return self.merged_reg_set.ty.reg_class + + def reg_conflicts_with_neighbors(self, reg): + # type: (RegLoc) -> bool + for neighbor in self.edges: + if neighbor.reg is not None and neighbor.reg.conflicts(reg): + return True + return False @final @@ -948,17 +1084,17 @@ class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]): @plain_data() class AllocationFailed: - __slots__ = "op_idx", "arg", "live_intervals" + __slots__ = "node", "live_intervals", "interference_graph" - def __init__(self, op_idx, arg, live_intervals): - # type: (int, SSAVal, LiveIntervals) -> None - self.op_idx = op_idx - self.arg = arg + def __init__(self, node, live_intervals, interference_graph): + # type: (IGNode, LiveIntervals, InterferenceGraph) -> None + self.node = node self.live_intervals = live_intervals + self.interference_graph = interference_graph def try_allocate_registers_without_spilling(ops): - # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed + # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed live_intervals = LiveIntervals(ops) merged_reg_sets = live_intervals.merged_reg_sets @@ -966,12 +1102,74 @@ def try_allocate_registers_without_spilling(ops): for op_idx, op in enumerate(ops): reg_sets = live_intervals.reg_sets_live_after(op_idx) for i, j in combinations(reg_sets, 2): - interference_graph[i].add_edge(interference_graph[j]) + if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: + interference_graph[i].add_edge(interference_graph[j]) for i, j in op.get_extra_interferences(): - interference_graph[merged_reg_sets[i]].add_edge( - interference_graph[merged_reg_sets[j]]) + i = merged_reg_sets[i] + j = merged_reg_sets[j] + if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: + interference_graph[i].add_edge(interference_graph[j]) + + nodes_remaining = set(interference_graph.values()) + + def local_colorability_score(node): + # type: (IGNode) -> int + """ returns a positive integer if node is locally colorable, returns + zero or a negative integer if node isn't known to be locally + colorable, the more negative the value, the less colorable + """ + if node not in nodes_remaining: + raise ValueError() + retval = len(node.reg_class) + for neighbor in node.edges: + if neighbor in nodes_remaining: + retval -= node.reg_class.max_conflicts_with(neighbor.reg_class) + return retval + + node_stack = [] # type: list[IGNode] + while True: + best_node = None # type: None | IGNode + best_score = 0 + for node in nodes_remaining: + score = local_colorability_score(node) + if best_node is None or score > best_score: + best_node = node + best_score = score + if best_score > 0: + # it's locally colorable, no need to find a better one + break + + if best_node is None: + break + node_stack.append(best_node) + nodes_remaining.remove(best_node) + + retval = {} # type: dict[SSAVal, RegLoc] + + while len(node_stack) > 0: + node = node_stack.pop() + if node.reg is not None: + if node.reg_conflicts_with_neighbors(node.reg): + return AllocationFailed(node=node, + live_intervals=live_intervals, + interference_graph=interference_graph) + else: + # pick the first non-conflicting register in node.reg_class, since + # register classes are ordered from most preferred to least + # preferred register. + for reg in node.reg_class: + if not node.reg_conflicts_with_neighbors(reg): + node.reg = reg + break + if node.reg is None: + return AllocationFailed(node=node, + live_intervals=live_intervals, + interference_graph=interference_graph) + + for ssa_val, offset in node.merged_reg_set.items(): + retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset) - raise NotImplementedError + return retval def allocate_registers(ops): -- 2.30.2