+from collections import defaultdict
from contextlib import contextmanager
import enum
from abc import ABCMeta, abstractmethod
""" map from SSAValSubRegs to the original SSAValSubRegs that they are
a copy of, looking through all layers of copies. The map excludes all
SSAValSubRegs that aren't copies of other SSAValSubRegs.
+ This ignores inputs of copy Ops that aren't actually being copied
+ (e.g. the VL input of VecCopyToReg).
"""
retval = {} # type: dict[SSAValSubReg, SSAValSubReg]
for op in self.op_indexes.keys():
retval[out] = inp
return FMap(retval)
+    @cached_property
+    def copy_related_ssa_vals(self):
+        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
+        """ map from SSAVals to the full set of SSAVals that are related by
+        being sources/destinations of copies, transitively looking through all
+        copies.
+        This ignores inputs of copy Ops that aren't actually being copied
+        (e.g. the VL input of VecCopyToReg).
+
+        Every key of `self.uses` gets an entry; an SSAVal that takes part in
+        no copy maps to a set containing only itself. All members of one
+        copy-related group map to the very same `OFSet` object.
+        """
+        # invariant: sets_map[x] is the one mutable set that contains x, and
+        # every member of that set maps to that same set object.
+        sets_map = {i: OSet([i]) for i in self.uses.keys()}
+        for k, v in self.copies.items():
+            k_set = sets_map[k.ssa_val]
+            v_set = sets_map[v.ssa_val]
+            if k_set is v_set:
+                continue  # already in the same group
+            # merge v_set into k_set. only the former members of v_set need
+            # to be re-pointed -- everything that was already in k_set
+            # still maps to k_set.
+            for i in v_set:
+                k_set.add(i)
+                sets_map[i] = k_set
+        # dedupe the groups by identity, this way we construct each OFSet
+        # only once rather than for each SSAVal
+        sets_set = {id(i): i for i in sets_map.values()}
+        retval = {}  # type: dict[SSAVal, OFSet[SSAVal]]
+        for v in sets_set.values():
+            v = OFSet(v)
+            for k in v:
+                retval[k] = v
+        return FMap(retval)
+
@cached_property
def const_ssa_vals(self):
# type: () -> FMap[SSAVal, tuple[int, ...]]
[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
"""
+from functools import reduce
from itertools import combinations
from typing import Iterable, Iterator, Mapping, TextIO
f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
f"live_interval={self.live_interval})")
+    @cached_property
+    def copy_related_ssa_vals(self):
+        # type: () -> OFSet[SSAVal]
+        """ the union of the copy-related groups (as given by
+        `self.fn_analysis.copy_related_ssa_vals`) of every SSAVal in
+        `self.ssa_vals`.
+        """
+        # gather the distinct groups first, so a group shared by several of
+        # our SSAVals is only flattened once
+        groups = OSet(
+            self.fn_analysis.copy_related_ssa_vals[ssa_val]
+            for ssa_val in self.ssa_vals)  # type: OSet[OFSet[SSAVal]]
+        return OFSet(member for group in groups for member in group)
+
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
f"{self.__merged_ssa_val_map[ssa_val]}")
self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
added += 1
- retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
+ retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None)
self.__map[merged_ssa_val] = retval
added = None
return retval
def merge_into_one_node(self, final_merged_ssa_val):
# type: (MergedSSAVal) -> IGNode
source_nodes = OSet() # type: OSet[IGNode]
- edges = OSet() # type: OSet[IGNode]
+ edges = {} # type: dict[IGNode, IGEdge]
loc = None # type: Loc | None
for ssa_val in final_merged_ssa_val.ssa_vals:
merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
elif source_node.loc is not None and loc != source_node.loc:
raise ValueError(f"can't merge IGNodes with mismatched `loc` "
f"values: {loc} != {source_node.loc}")
- edges |= source_node.edges
+ for n, edge in source_node.edges.items():
+ if n in edges:
+ edge = edge.merged(edges[n])
+ edges[n] = edge
if len(source_nodes) == 1:
return source_nodes.pop() # merging a single node is a no-op
# we're finished checking validity, now we can modify stuff
- edges -= source_nodes
+ for n in source_nodes:
+ edges.pop(n, None)
retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
loc=loc)
for node in edges:
- node.edges -= source_nodes
- node.edges.add(retval)
+ edge = reduce(IGEdge.merged,
+ (node.edges.pop(n) for n in source_nodes))
+ node.edges[retval] = edge
for node in source_nodes:
del self.__map[node.merged_ssa_val]
self.__map[final_merged_ssa_val] = retval
self.did_full_repr = OSet() # type: OSet[IGNode]
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class IGEdge:
+    """ interference graph edge
+
+    Carries a single flag, `is_copy_related`. Two edges can be combined
+    with `merged`, which produces a new edge instead of modifying either
+    operand.
+    """
+    __slots__ = ("is_copy_related",)
+
+    def __init__(self, is_copy_related):
+        # type: (bool) -> None
+        self.is_copy_related = is_copy_related
+
+    def merged(self, other):
+        # type: (IGEdge) -> IGEdge
+        """ return a new edge whose `is_copy_related` flag is set if it is
+        set on `self` or on `other`.
+        """
+        return IGEdge(
+            is_copy_related=self.is_copy_related | other.is_copy_related)
+
+
@final
class IGNode:
""" interference graph node """
__slots__ = "merged_ssa_val", "edges", "loc"
def __init__(self, merged_ssa_val, edges, loc):
- # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
+ # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None
self.merged_ssa_val = merged_ssa_val
- self.edges = OSet(edges)
+ self.edges = edges
self.loc = loc
- def add_edge(self, other):
- # type: (IGNode) -> None
- self.edges.add(other)
- other.edges.add(self)
+ def add_edge(self, other, edge):
+ # type: (IGNode, IGEdge) -> None
+ self.edges[other] = edge
+ other.edges[self] = edge
def __eq__(self, other):
# type: (object) -> bool
def __repr__(self, repr_state=None, short=False):
# type: (None | IGNodeReprState, bool) -> str
- if repr_state is None:
- repr_state = IGNodeReprState()
- node_id = repr_state.node_ids.get(self, None)
+ rs = repr_state
+ del repr_state
+ if rs is None:
+ rs = IGNodeReprState()
+ node_id = rs.node_ids.get(self, None)
if node_id is None:
- repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
- if short or self in repr_state.did_full_repr:
+ rs.node_ids[self] = node_id = len(rs.node_ids)
+ if short or self in rs.did_full_repr:
return f"<IGNode #{node_id}>"
- repr_state.did_full_repr.add(self)
- edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
+ rs.did_full_repr.add(self)
+ edges = ", ".join(
+ f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items())
return (f"IGNode(#{node_id}, "
f"merged_ssa_val={self.merged_ssa_val}, "
f"edges={{{edges}}}, "
interference_graph.merged_ssa_val_map[ssa_val])
for i, j in combinations(live_merged_ssa_vals, 2):
if i.loc_set.max_conflicts_with(j.loc_set) != 0:
+ # can't use:
+ # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
+ # j.copy_related_ssa_vals)
+ # since it is too coarse
+
+ # TODO: fill in is_copy_related afterwards
+ # using fn_analysis.copies
interference_graph.nodes[i].add_edge(
- interference_graph.nodes[j])
+ interference_graph.nodes[j],
+ edge=IGEdge(is_copy_related=False))
if debug_out is not None:
print(f"processed {pp} out of {fn_analysis.all_program_points}",
file=debug_out, flush=True)