from collections import defaultdict
from enum import Enum, unique, EnumMeta
from functools import lru_cache
+from itertools import combinations
from typing import (Sequence, AbstractSet, Iterable, Mapping,
TYPE_CHECKING, Sequence, Sized, TypeVar, Generic)
return self.length
-_RegType = TypeVar("_RegType", bound=RegType)
+_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
@plain_data(frozen=True, eq=False)
@final
-class SSAVal(Generic[_RegType]):
+class SSAVal(Generic[_RegT_co]):
__slots__ = "op", "arg_name", "ty", "arg_index"
def __init__(self, op, arg_name, ty):
- # type: (Op, str, _RegType) -> None
+ # type: (Op, str, _RegT_co) -> None
self.op = op
"""the Op that writes this SSAVal"""
if False:
yield ...
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ if False:
+ yield ...
+
def __init__(self):
pass
@plain_data(unsafe_hash=True, frozen=True)
@final
-class OpCopy(Op, Generic[_RegType]):
+class OpCopy(Op, Generic[_RegT_co]):
__slots__ = "dest", "src"
def inputs(self):
return {"dest": self.dest}
def __init__(self, src):
- # type: (SSAVal[_RegType]) -> None
+ # type: (SSAVal[_RegT_co]) -> None
self.dest = SSAVal(self, "dest", src.ty)
self.src = src
self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
self.is_sub = is_sub
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ yield self.RT, self.RA
+ yield self.RT, self.RB
+
@plain_data(unsafe_hash=True, frozen=True)
@final
# type: () -> Iterable[EqualityConstraint]
yield EqualityConstraint([self.RC], [self.RS])
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ yield self.RT, self.RA
+ yield self.RT, self.RB
+ yield self.RT, self.RC
+ yield self.RT, self.RS
+ yield self.RS, self.RA
+ yield self.RS, self.RB
+
@final
@unique
self.sh = sh
self.kind = kind
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ yield self.RT, self.inp
+ yield self.RT, self.sh
+
@plain_data(unsafe_hash=True, frozen=True)
@final
self.offset = offset
self.mem = mem
+ def get_extra_interferences(self):
+ # type: () -> Iterable[tuple[SSAVal, SSAVal]]
+ if self.RT.ty.length > 1:
+ yield self.RT, self.RA
+
@plain_data(unsafe_hash=True, frozen=True)
@final
last_use = max(self.last_use, use)
return LiveInterval(first_write=self.first_write, last_use=last_use)
+ @property
+ def live_after_op_range(self):
+ """the range of op indexes where self is live immediately after the
+ Op at each index
+ """
+ return range(self.first_write, self.last_use)
+
@final
-class MergedRegSet(Mapping[SSAVal[_RegType], int]):
+class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
def __init__(self, reg_set):
- # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
- self.__items = {} # type: dict[SSAVal[_RegType], int]
+ # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
+ self.__items = {} # type: dict[SSAVal[_RegT_co], int]
if isinstance(reg_set, SSAVal):
reg_set = [(reg_set, 0)]
for ssa_val, offset in reg_set:
@staticmethod
def from_equality_constraint(constraint_sequence):
- # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
+ # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
if len(constraint_sequence) == 1:
# any type allowed with len = 1
return MergedRegSet(constraint_sequence[0])
return range(self.__start, self.__stop)
def offset_by(self, amount):
- # type: (int) -> MergedRegSet[_RegType]
+ # type: (int) -> MergedRegSet[_RegT_co]
return MergedRegSet((k, v + amount) for k, v in self.items())
def normalized(self):
- # type: () -> MergedRegSet[_RegType]
+ # type: () -> MergedRegSet[_RegT_co]
return self.offset_by(-self.start)
def with_offset_to_match(self, target):
- # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
+ # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
for ssa_val, offset in self.items():
if ssa_val in target:
return self.offset_by(target[ssa_val] - offset)
raise ValueError("can't change offset to match unrelated MergedRegSet")
def __getitem__(self, item):
- # type: (SSAVal[_RegType]) -> int
+ # type: (SSAVal[_RegT_co]) -> int
return self.__items[item]
def __iter__(self):
@final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet]):
+class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
def __init__(self, ops):
# type: (Iterable[Op]) -> None
- merged_sets = {} # type: dict[SSAVal, MergedRegSet]
+ merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]]
for op in ops:
for val in (*op.inputs().values(), *op.outputs().values()):
if val not in merged_sets:
def __len__(self):
return len(self.__map)
+ def __repr__(self):
+ return f"MergedRegSets(data={self.__map})"
+
@final
-class LiveIntervals(Mapping[MergedRegSet, LiveInterval]):
+class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
def __init__(self, ops):
# type: (list[Op]) -> None
self.__merged_reg_sets = MergedRegSets(ops)
- live_intervals = {} # type: dict[MergedRegSet, LiveInterval]
+ live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval]
for op_idx, op in enumerate(ops):
for val in op.inputs().values():
live_intervals[self.__merged_reg_sets[val]] += op_idx
else:
live_intervals[reg_set] += op_idx
self.__live_intervals = live_intervals
+ live_after = [] # type: list[set[MergedRegSet[_RegT_co]]]
+ live_after += (set() for _ in ops)
+ for reg_set, live_interval in self.__live_intervals.items():
+ for i in live_interval.live_after_op_range:
+ live_after[i].add(reg_set)
+ self.__live_after = [frozenset(i) for i in live_after]
@property
def merged_reg_sets(self):
return self.__merged_reg_sets
def __getitem__(self, key):
- # type: (MergedRegSet) -> LiveInterval
+ # type: (MergedRegSet[_RegT_co]) -> LiveInterval
return self.__live_intervals[key]
def __iter__(self):
return iter(self.__live_intervals)
+ def reg_sets_live_after(self, op_index):
+ # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
+ return self.__live_after[op_index]
+
+ def __repr__(self):
+ reg_sets_live_after = dict(enumerate(self.__live_after))
+ return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
+ f"merged_reg_sets={self.merged_reg_sets}, "
+ f"reg_sets_live_after={reg_sets_live_after})")
+
@final
-class IGNode:
+class IGNode(Generic[_RegT_co]):
""" interference graph node """
__slots__ = "merged_reg_set", "edges"
def __init__(self, merged_reg_set, edges=()):
- # type: (MergedRegSet, Iterable[IGNode]) -> None
+ # type: (MergedRegSet[_RegT_co], Iterable[IGNode]) -> None
self.merged_reg_set = merged_reg_set
self.edges = set(edges)
@final
-class InterferenceGraph(Mapping[MergedRegSet, IGNode]):
+class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
def __init__(self, merged_reg_sets):
- # type: (Iterable[MergedRegSet]) -> None
+ # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
def __getitem__(self, key):
- # type: (MergedRegSet) -> IGNode
+ # type: (MergedRegSet[_RegT_co]) -> IGNode
return self.__nodes[key]
def __iter__(self):
return iter(self.__nodes)
+ def __repr__(self):
+ nodes = {}
+ nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
+ nodes_text = ", ".join(nodes_text)
+ return f"InterferenceGraph(nodes={{{nodes_text}}})"
+
@plain_data()
class AllocationFailed:
# type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
live_intervals = LiveIntervals(ops)
+ merged_reg_sets = live_intervals.merged_reg_sets
+ interference_graph = InterferenceGraph(merged_reg_sets.values())
+ for op_idx, op in enumerate(ops):
+ reg_sets = live_intervals.reg_sets_live_after(op_idx)
+ for i, j in combinations(reg_sets, 2):
+ interference_graph[i].add_edge(interference_graph[j])
+ for i, j in op.get_extra_interferences():
+ interference_graph[merged_reg_sets[i]].add_edge(
+ interference_graph[merged_reg_sets[j]])
raise NotImplementedError