import builtins
from collections import defaultdict
from enum import Enum, unique
-from typing import Iterable, Mapping, TYPE_CHECKING
+from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING
from nmutil.plain_data import plain_data
return LiveInterval(assignment=self.assignment, last_use=last_use)
-class LiveIntervals(Mapping[SSAVal, LiveInterval]):
+@final
+class EqualitySet(AbstractSet[SSAVal]):
+ def __init__(self, items):
+ # type: (Iterable[SSAVal]) -> None
+ self.__items = frozenset(items)
+
+ def __contains__(self, x):
+ # type: (object) -> bool
+ return x in self.__items
+
+ def __iter__(self):
+ return iter(self.__items)
+
+ def __len__(self):
+ return len(self.__items)
+
+
+@final
+class EqualitySets(Mapping[SSAVal, EqualitySet]):
+ def __init__(self, ops):
+ # type: (Iterable[Op]) -> None
+ indexes = {} # type: dict[SSAVal, int]
+ sets = [] # type: list[set[SSAVal]]
+ for op in ops:
+ for val in (*op.input_ssa_vals(), *op.output_ssa_vals()):
+ if val not in indexes:
+ indexes[val] = len(sets)
+ sets.append({val})
+ for e in op.get_equality_constraints():
+ lhs_index = indexes[e.lhs]
+ rhs_index = indexes[e.rhs]
+ sets[lhs_index] |= sets[rhs_index]
+ for val in sets[rhs_index]:
+ indexes[val] = lhs_index
+
+ equality_sets = [EqualitySet(i) for i in sets]
+ self.__map = {k: equality_sets[v] for k, v in indexes.items()}
+
+ def __getitem__(self, key):
+ # type: (SSAVal) -> EqualitySet
+ return self.__map[key]
+
+ def __iter__(self):
+ return iter(self.__map)
+
+
+@final
+class LiveIntervals(Mapping[EqualitySet, LiveInterval]):
def __init__(self, ops):
# type: (list[Op]) -> None
- live_intervals = {} # type: dict[SSAVal, LiveInterval]
+ self.__equality_sets = eqsets = EqualitySets(ops)
+ live_intervals = {} # type: dict[EqualitySet, LiveInterval]
for op_idx, op in enumerate(ops):
for val in op.input_ssa_vals():
- live_intervals[val] += op_idx
+ live_intervals[eqsets[val]] += op_idx
for val in op.output_ssa_vals():
- if val in live_intervals:
- raise ValueError(f"multiple instructions must not write "
- f"to the same SSA value: {val}")
- live_intervals[val] = LiveInterval(op_idx)
+ if eqsets[val] not in live_intervals:
+ live_intervals[eqsets[val]] = LiveInterval(op_idx)
+ else:
+ live_intervals[eqsets[val]] += op_idx
self.__live_intervals = live_intervals
+ @property
+ def equality_sets(self):
+ return self.__equality_sets
+
def __getitem__(self, key):
- # type: (SSAVal) -> LiveInterval
+ # type: (EqualitySet) -> LiveInterval
return self.__live_intervals[key]
def __iter__(self):
@plain_data()
class AllocationFailed:
- __slots__ = "op_idx", "arg", "live_intervals", "free_regs"
+ __slots__ = "op_idx", "arg", "live_intervals"
- def __init__(self, op_idx, arg, live_intervals, free_regs):
- # type: (int, SSAVal | VecArg, LiveIntervals, set[GPR | XERBit]) -> None
+ def __init__(self, op_idx, arg, live_intervals):
+ # type: (int, SSAVal | VecArg, LiveIntervals) -> None
self.op_idx = op_idx
self.arg = arg
self.live_intervals = live_intervals
- self.free_regs = free_regs
def try_allocate_registers_without_spilling(ops):