from functools import lru_cache
from itertools import combinations
from typing import (Sequence, AbstractSet, Iterable, Mapping,
- TYPE_CHECKING, Sequence, Sized, TypeVar, Generic)
+ TYPE_CHECKING, Sequence, TypeVar, Generic)
from nmutil.plain_data import plain_data
pass
-class PhysLoc(metaclass=ABCMeta):
- __slots__ = ()
-
-
-class RegLoc(PhysLoc):
+class RegLoc(metaclass=ABCMeta):
__slots__ = ()
@abstractmethod
# type: (RegLoc) -> bool
...
-
-class GPRRangeOrStackLoc(PhysLoc, Sized):
- __slots__ = ()
+ def get_subreg_at_offset(self, subreg_type, offset):
+ # type: (RegType, int) -> RegLoc
+ if self not in subreg_type.reg_class:
+ raise ValueError(f"register not a member of subreg_type: "
+ f"reg={self} subreg_type={subreg_type}")
+ if offset != 0:
+ raise ValueError(f"non-zero sub-register offset not supported "
+ f"for register: {self}")
+ return self
GPR_COUNT = 128
@plain_data(frozen=True, unsafe_hash=True)
@final
-class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]):
+class GPRRange(RegLoc, Sequence["GPRRange"]):
__slots__ = "start", "length"
def __init__(self, start, length=None):
return self.stop > other.start and other.stop > self.start
return False
+ def get_subreg_at_offset(self, subreg_type, offset):
+ # type: (RegType, int) -> GPRRange
+ if not isinstance(subreg_type, GPRRangeType):
+ raise ValueError(f"subreg_type is not a "
+ f"GPRRangeType: {subreg_type}")
+ if offset < 0 or offset + subreg_type.length > self.stop:
+ raise ValueError(f"sub-register offset is out of range: {offset}")
+ return GPRRange(self.start + offset, subreg_type.length)
+
SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
@final
class RegClass(AbstractSet[RegLoc]):
+ """ an ordered set of registers.
+ earlier registers are preferred by the register allocator.
+ """
def __init__(self, regs):
# type: (Iterable[RegLoc]) -> None
- self.__regs = frozenset(regs)
+
+ # use dict to maintain order
+ self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None]
def __len__(self):
return len(self.__regs)
def __hash__(self):
return super()._hash()
+ @lru_cache(maxsize=None, typed=True)
+ def max_conflicts_with(self, other):
+ # type: (RegClass | RegLoc) -> int
+ """the largest number of registers in `self` that a single register
+ from `other` can conflict with
+ """
+ if isinstance(other, RegClass):
+ return max(self.max_conflicts_with(i) for i in other)
+ else:
+ return sum(other.conflicts(i) for i in self)
+
@plain_data(frozen=True, unsafe_hash=True)
class RegType(metaclass=ABCMeta):
self.length = length
@staticmethod
- @lru_cache()
+ @lru_cache(maxsize=None)
def __get_reg_class(length):
# type: (int) -> RegClass
regs = []
return RegClass([GlobalMem.GlobalMem])
-@plain_data()
+@plain_data(frozen=True, unsafe_hash=True)
@final
-class StackSlot(GPRRangeOrStackLoc):
- """a stack slot. Use OpCopy to load from/store into this stack slot."""
- __slots__ = "offset", "length"
+class StackSlot(RegLoc):
+ __slots__ = "start_slot", "length_in_slots",
- def __init__(self, offset=None, length=1):
- # type: (int | None, int) -> None
- self.offset = offset
- if length < 1:
- raise ValueError("invalid length")
- self.length = length
+ def __init__(self, start_slot, length_in_slots):
+ # type: (int, int) -> None
+ self.start_slot = start_slot
+ if length_in_slots < 1:
+ raise ValueError("invalid length_in_slots")
+ self.length_in_slots = length_in_slots
- def __len__(self):
- return self.length
+ @property
+ def stop_slot(self):
+ return self.start_slot + self.length_in_slots
+
+ def conflicts(self, other):
+ # type: (RegLoc) -> bool
+ if isinstance(other, StackSlot):
+ return (self.stop_slot > other.start_slot
+ and other.stop_slot > self.start_slot)
+ return False
+
+ def get_subreg_at_offset(self, subreg_type, offset):
+ # type: (RegType, int) -> StackSlot
+ if not isinstance(subreg_type, StackSlotType):
+ raise ValueError(f"subreg_type is not a "
+ f"StackSlotType: {subreg_type}")
+ if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot:
+ raise ValueError(f"sub-register offset is out of range: {offset}")
+ return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
+
+
+STACK_SLOT_COUNT = 128
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class StackSlotType(RegType):
+ __slots__ = "length_in_slots",
+
+ def __init__(self, length_in_slots=1):
+ # type: (int) -> None
+ if length_in_slots < 1:
+ raise ValueError("invalid length_in_slots")
+ self.length_in_slots = length_in_slots
+
+ @staticmethod
+ @lru_cache(maxsize=None)
+ def __get_reg_class(length_in_slots):
+ # type: (int) -> RegClass
+ regs = []
+ for start in range(STACK_SLOT_COUNT - length_in_slots):
+ reg = StackSlot(start, length_in_slots)
+ regs.append(reg)
+ return RegClass(regs)
+
+ @property
+ def reg_class(self):
+ # type: () -> RegClass
+ return StackSlotType.__get_reg_class(self.length_in_slots)
+
+ @final
+ def __eq__(self, other):
+ if isinstance(other, StackSlotType):
+ return self.length_in_slots == other.length_in_slots
+ return False
+
+ @final
+ def __hash__(self):
+ return hash(self.length_in_slots)
_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
pass
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLoadFromStackSlot(Op):
+ __slots__ = "dest", "src"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"src": self.src}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"dest": self.dest}
+
+ def __init__(self, src):
+ # type: (SSAVal[GPRRangeType]) -> None
+ self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
+ self.src = src
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpStoreToStackSlot(Op):
+ __slots__ = "dest", "src"
+
+ def inputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"src": self.src}
+
+ def outputs(self):
+ # type: () -> dict[str, SSAVal]
+ return {"dest": self.dest}
+
+ def __init__(self, src):
+ # type: (SSAVal[StackSlotType]) -> None
+ self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
+ self.src = src
+
+
@plain_data(unsafe_hash=True, frozen=True)
@final
class OpCopy(Op, Generic[_RegT_co]):
self.__start = start # type: int
self.__stop = stop # type: int
self.__ty = ty # type: RegType
+ self.__hash = hash(frozenset(self.items()))
@staticmethod
def from_equality_constraint(constraint_sequence):
return len(self.__items)
def __hash__(self):
- return hash(frozenset(self.items()))
+ return self.__hash
def __repr__(self):
return f"MergedRegSet({list(self.__items.items())})"
@final
class IGNode(Generic[_RegT_co]):
""" interference graph node """
- __slots__ = "merged_reg_set", "edges"
+ __slots__ = "merged_reg_set", "edges", "reg"
- def __init__(self, merged_reg_set, edges=()):
- # type: (MergedRegSet[_RegT_co], Iterable[IGNode]) -> None
+ def __init__(self, merged_reg_set, edges=(), reg=None):
+ # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
self.merged_reg_set = merged_reg_set
self.edges = set(edges)
+ self.reg = reg
def add_edge(self, other):
# type: (IGNode) -> None
edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
return (f"IGNode(#{nodes[self]}, "
f"merged_reg_set={self.merged_reg_set}, "
- f"edges={edges})")
+ f"edges={edges}, "
+ f"reg={self.reg})")
+
+ @property
+ def reg_class(self):
+ # type: () -> RegClass
+ return self.merged_reg_set.ty.reg_class
+
+ def reg_conflicts_with_neighbors(self, reg):
+ # type: (RegLoc) -> bool
+ for neighbor in self.edges:
+ if neighbor.reg is not None and neighbor.reg.conflicts(reg):
+ return True
+ return False
@final
@plain_data()
class AllocationFailed:
- __slots__ = "op_idx", "arg", "live_intervals"
+ __slots__ = "node", "live_intervals", "interference_graph"
- def __init__(self, op_idx, arg, live_intervals):
- # type: (int, SSAVal, LiveIntervals) -> None
- self.op_idx = op_idx
- self.arg = arg
+ def __init__(self, node, live_intervals, interference_graph):
+ # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+ self.node = node
self.live_intervals = live_intervals
+ self.interference_graph = interference_graph
def try_allocate_registers_without_spilling(ops):
- # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
+ # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
live_intervals = LiveIntervals(ops)
merged_reg_sets = live_intervals.merged_reg_sets
for op_idx, op in enumerate(ops):
reg_sets = live_intervals.reg_sets_live_after(op_idx)
for i, j in combinations(reg_sets, 2):
- interference_graph[i].add_edge(interference_graph[j])
+ if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+ interference_graph[i].add_edge(interference_graph[j])
for i, j in op.get_extra_interferences():
- interference_graph[merged_reg_sets[i]].add_edge(
- interference_graph[merged_reg_sets[j]])
+ i = merged_reg_sets[i]
+ j = merged_reg_sets[j]
+ if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+ interference_graph[i].add_edge(interference_graph[j])
+
+ nodes_remaining = set(interference_graph.values())
+
+ def local_colorability_score(node):
+ # type: (IGNode) -> int
+ """ returns a positive integer if node is locally colorable, returns
+ zero or a negative integer if node isn't known to be locally
+ colorable, the more negative the value, the less colorable
+ """
+ if node not in nodes_remaining:
+ raise ValueError()
+ retval = len(node.reg_class)
+ for neighbor in node.edges:
+ if neighbor in nodes_remaining:
+ retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
+ return retval
+
+ node_stack = [] # type: list[IGNode]
+ while True:
+ best_node = None # type: None | IGNode
+ best_score = 0
+ for node in nodes_remaining:
+ score = local_colorability_score(node)
+ if best_node is None or score > best_score:
+ best_node = node
+ best_score = score
+ if best_score > 0:
+ # it's locally colorable, no need to find a better one
+ break
+
+ if best_node is None:
+ break
+ node_stack.append(best_node)
+ nodes_remaining.remove(best_node)
+
+ retval = {} # type: dict[SSAVal, RegLoc]
+
+ while len(node_stack) > 0:
+ node = node_stack.pop()
+ if node.reg is not None:
+ if node.reg_conflicts_with_neighbors(node.reg):
+ return AllocationFailed(node=node,
+ live_intervals=live_intervals,
+ interference_graph=interference_graph)
+ else:
+ # pick the first non-conflicting register in node.reg_class, since
+ # register classes are ordered from most preferred to least
+ # preferred register.
+ for reg in node.reg_class:
+ if not node.reg_conflicts_with_neighbors(reg):
+ node.reg = reg
+ break
+ if node.reg is None:
+ return AllocationFailed(node=node,
+ live_intervals=live_intervals,
+ interference_graph=interference_graph)
+
+ for ssa_val, offset in node.merged_reg_set.items():
+ retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
- raise NotImplementedError
+ return retval
def allocate_registers(ops):