From a5b8891ddced10b189006459c5e67a7f588c74bd Mon Sep 17 00:00:00 2001
From: Jacob Lifshay
Date: Thu, 1 Dec 2022 23:19:59 -0800
Subject: [PATCH] add is_copy_related to interference graph edges

---
Review notes (text between "---" and the first "diff --git" line is
ignored when the patch is applied):

* This patch was recovered from a copy whose line breaks and indentation
  had been collapsed. The layout below is reconstructed so that every
  hunk matches its "@@" line counts and the diffstat. Docstring wrapping
  and the indentation of continuation/context lines are best guesses --
  verify against the tree before applying.
* merge_into_one_node: a neighbour of the merged node is only known to be
  adjacent to at least one of the source nodes, so popping every source
  node out of node.edges raised KeyError. The generator now skips source
  nodes the neighbour has no edge to. The blob ids on the
  register_allocator.py "index" line still describe the unfixed version.
* NOTE(review): the "From:" header carries no e-mail address and the
  context line `return f""` in IGNode.__repr__ has an empty f-string;
  both look like text lost in extraction -- confirm against the
  repository.

 src/bigint_presentation_code/compiler_ir.py   | 32 ++++++++
 .../register_allocator.py                     | 82 ++++++++++++++-----
 2 files changed, 94 insertions(+), 20 deletions(-)

diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py
index 12a32f1..0553a40 100644
--- a/src/bigint_presentation_code/compiler_ir.py
+++ b/src/bigint_presentation_code/compiler_ir.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from contextlib import contextmanager
 import enum
 from abc import ABCMeta, abstractmethod
@@ -413,6 +414,8 @@ class FnAnalysis:
         """ map from SSAValSubRegs to the original SSAValSubRegs that they are
         a copy of, looking through all layers of copies. The map excludes all
         SSAValSubRegs that aren't copies of other SSAValSubRegs.
+        This ignores inputs of copy Ops that aren't actually being copied
+        (e.g. the VL input of VecCopyToReg).
         """
         retval = {}  # type: dict[SSAValSubReg, SSAValSubReg]
         for op in self.op_indexes.keys():
@@ -434,6 +437,35 @@ class FnAnalysis:
                 retval[out] = inp
         return FMap(retval)
 
+    @cached_property
+    def copy_related_ssa_vals(self):
+        # type: () -> FMap[SSAVal, OFSet[SSAVal]]
+        """ map from SSAVals to the full set of SSAVals that are related by
+        being sources/destinations of copies, transitively looking through all
+        copies.
+        This ignores inputs of copy Ops that aren't actually being copied
+        (e.g. the VL input of VecCopyToReg).
+        """
+        sets_map = {i: OSet([i]) for i in self.uses.keys()}
+        for k, v in self.copies.items():
+            k_set = sets_map[k.ssa_val]
+            v_set = sets_map[v.ssa_val]
+            # merge k_set and v_set
+            if k_set is v_set:
+                continue
+            k_set |= v_set
+            for i in k_set:
+                sets_map[i] = k_set
+        # this way we construct each OFSet only once rather than
+        # for each SSAVal
+        sets_set = {id(i): i for i in sets_map.values()}
+        retval = {}  # type: dict[SSAVal, OFSet[SSAVal]]
+        for v in sets_set.values():
+            v = OFSet(v)
+            for k in v:
+                retval[k] = v
+        return FMap(retval)
+
     @cached_property
     def const_ssa_vals(self):
         # type: () -> FMap[SSAVal, tuple[int, ...]]
diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py
index 73381e4..be1522e 100644
--- a/src/bigint_presentation_code/register_allocator.py
+++ b/src/bigint_presentation_code/register_allocator.py
@@ -5,6 +5,7 @@ this uses an algorithm based on:
 [Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
 """
 
+from functools import reduce
 from itertools import combinations
 from typing import Iterable, Iterator, Mapping, TextIO
 
@@ -253,6 +254,15 @@ class MergedSSAVal(metaclass=InternedMeta):
                 f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
                 f"live_interval={self.live_interval})")
 
+    @cached_property
+    def copy_related_ssa_vals(self):
+        # type: () -> OFSet[SSAVal]
+        sets = OSet()  # type: OSet[OFSet[SSAVal]]
+        # avoid merging the same sets multiple times
+        for ssa_val in self.ssa_vals:
+            sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
+        return OFSet(v for s in sets for v in s)
+
 
 @final
 class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
@@ -322,7 +332,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                         f"{self.__merged_ssa_val_map[ssa_val]}")
                 self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                 added += 1
-            retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
+            retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None)
             self.__map[merged_ssa_val] = retval
             added = None
             return retval
@@ -337,7 +347,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
     def merge_into_one_node(self, final_merged_ssa_val):
         # type: (MergedSSAVal) -> IGNode
         source_nodes = OSet()  # type: OSet[IGNode]
-        edges = OSet()  # type: OSet[IGNode]
+        edges = {}  # type: dict[IGNode, IGEdge]
         loc = None  # type: Loc | None
         for ssa_val in final_merged_ssa_val.ssa_vals:
             merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
@@ -354,16 +364,21 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
             elif source_node.loc is not None and loc != source_node.loc:
                 raise ValueError(f"can't merge IGNodes with mismatched `loc` "
                                  f"values: {loc} != {source_node.loc}")
-            edges |= source_node.edges
+            for n, edge in source_node.edges.items():
+                if n in edges:
+                    edge = edge.merged(edges[n])
+                edges[n] = edge
         if len(source_nodes) == 1:
             return source_nodes.pop()  # merging a single node is a no-op
         # we're finished checking validity, now we can modify stuff
-        edges -= source_nodes
+        for n in source_nodes:
+            edges.pop(n, None)
         retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
                         loc=loc)
         for node in edges:
-            node.edges -= source_nodes
-            node.edges.add(retval)
+            edge = reduce(IGEdge.merged, (
+                node.edges.pop(n) for n in source_nodes if n in node.edges))
+            node.edges[retval] = edge
         for node in source_nodes:
             del self.__map[node.merged_ssa_val]
         self.__map[final_merged_ssa_val] = retval
@@ -437,21 +452,37 @@ class IGNodeReprState:
         self.did_full_repr = OSet()  # type: OSet[IGNode]
 
 
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class IGEdge:
+    """ interference graph edge """
+    __slots__ = "is_copy_related",
+
+    def __init__(self, is_copy_related):
+        # type: (bool) -> None
+        self.is_copy_related = is_copy_related
+
+    def merged(self, other):
+        # type: (IGEdge) -> IGEdge
+        is_copy_related = self.is_copy_related | other.is_copy_related
+        return IGEdge(is_copy_related=is_copy_related)
+
+
 @final
 class IGNode:
     """ interference graph node """
     __slots__ = "merged_ssa_val", "edges", "loc"
 
     def __init__(self, merged_ssa_val, edges, loc):
-        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
+        # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None) -> None
         self.merged_ssa_val = merged_ssa_val
-        self.edges = OSet(edges)
+        self.edges = edges
         self.loc = loc
 
-    def add_edge(self, other):
-        # type: (IGNode) -> None
-        self.edges.add(other)
-        other.edges.add(self)
+    def add_edge(self, other, edge):
+        # type: (IGNode, IGEdge) -> None
+        self.edges[other] = edge
+        other.edges[self] = edge
 
     def __eq__(self, other):
         # type: (object) -> bool
@@ -465,15 +496,18 @@ class IGNode:
 
     def __repr__(self, repr_state=None, short=False):
         # type: (None | IGNodeReprState, bool) -> str
-        if repr_state is None:
-            repr_state = IGNodeReprState()
-        node_id = repr_state.node_ids.get(self, None)
+        rs = repr_state
+        del repr_state
+        if rs is None:
+            rs = IGNodeReprState()
+        node_id = rs.node_ids.get(self, None)
         if node_id is None:
-            repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
-        if short or self in repr_state.did_full_repr:
+            rs.node_ids[self] = node_id = len(rs.node_ids)
+        if short or self in rs.did_full_repr:
             return f""
-        repr_state.did_full_repr.add(self)
-        edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
+        rs.did_full_repr.add(self)
+        edges = ", ".join(
+            f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items())
         return (f"IGNode(#{node_id}, "
                 f"merged_ssa_val={self.merged_ssa_val}, "
                 f"edges={{{edges}}}, "
@@ -539,8 +573,16 @@ def allocate_registers(fn, debug_out=None):
                 interference_graph.merged_ssa_val_map[ssa_val])
         for i, j in combinations(live_merged_ssa_vals, 2):
             if i.loc_set.max_conflicts_with(j.loc_set) != 0:
+                # can't use:
+                # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
+                #     j.copy_related_ssa_vals)
+                # since it is too coarse
+
+                # TODO: fill in is_copy_related afterwards
+                # using fn_analysis.copies
                 interference_graph.nodes[i].add_edge(
-                    interference_graph.nodes[j])
+                    interference_graph.nodes[j],
+                    edge=IGEdge(is_copy_related=False))
         if debug_out is not None:
             print(f"processed {pp} out of {fn_analysis.all_program_points}",
                   file=debug_out, flush=True)
-- 
2.30.2