From 0291c0a744c7b0ee25725f5d527922dece8f49c2 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 8 Dec 2022 00:45:05 -0800 Subject: [PATCH] WIP: copy merging -- currently broken currently not merging everything it should --- .../_tests/test_register_allocator.py | 101 ++++----- .../register_allocator.py | 204 ++++++++++++++---- 2 files changed, 218 insertions(+), 87 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py index 5e9d25a..d0d9068 100644 --- a/src/bigint_presentation_code/_tests/test_register_allocator.py +++ b/src/bigint_presentation_code/_tests/test_register_allocator.py @@ -48,7 +48,7 @@ class TestRegisterAllocator(unittest.TestCase): def test_register_allocate(self): fn, _arg = self.make_add_fn() - reg_assignments = allocate_registers(fn) + reg_assignments = allocate_registers(fn, debug_out=sys.stdout) self.assertEqual( repr(reg_assignments), @@ -112,22 +112,37 @@ class TestRegisterAllocator(unittest.TestCase): self.assertEqual( repr(reg_assignments), - "{>: " + "{" + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " ">: " "Loc(kind=LocKind.GPR, start=78, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " @@ -140,30 +155,17 @@ class TestRegisterAllocator(unittest.TestCase): "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=46, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1)}" + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)" + "}" ) state = GenAsmState(reg_assignments) fn.gen_asm(state) @@ -174,15 +176,13 @@ class TestRegisterAllocator(unittest.TestCase): 'setvl 0, 0, 32, 0, 1, 1', 'sv.ld *14, 0(3)', 'setvl 0, 0, 32, 0, 1, 1', - 'sv.or *46, *14, *14', 'setvl 0, 0, 32, 0, 1, 1', - 'sv.addi *14, 0, 0', + 'sv.addi *46, 0, 0', 'setvl 0, 0, 32, 0, 1, 1', 'subfc 0, 0, 0', 'setvl 0, 0, 32, 0, 1, 1', - 'sv.or *78, *46, *46', + 'sv.or *78, *14, *14', 'setvl 0, 0, 32, 0, 1, 1', - 'sv.or *46, *14, *14', 'setvl 0, 0, 32, 0, 1, 1', 'sv.adde *14, *78, *46', 'setvl 0, 0, 32, 0, 1, 1', @@ -200,26 +200,42 @@ class TestRegisterAllocator(unittest.TestCase): # type: (str, str) -> None self.assertNotIn(name, graphs, "duplicate graph name") graphs[name] = dot - allocated = allocate_registers(fn, dump_graph=dump_graph) + allocated = allocate_registers(fn, dump_graph=dump_graph, + debug_out=sys.stdout) + dump_graphs(self, graphs) self.assertEqual( repr(allocated), "{" ">: " "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " ">: " "Loc(kind=LocKind.GPR, start=78, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " @@ -232,33 +248,18 @@ class TestRegisterAllocator(unittest.TestCase): "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=46, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1)" + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)" "}" ) - dump_graphs(self, graphs) # FIXME: is_copy_related is not correct, it's missing a bunch of # edges (which aren't interference edges) self.assertEqual(graphs, { diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index e4e3ab6..540af0a 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -6,8 +6,8 @@ this uses an algorithm based on: """ from functools import reduce -from itertools import chain, combinations -from typing import Callable, Iterable, Iterator, Mapping, TextIO, Tuple +from itertools import chain, combinations, count +from typing import Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple from cached_property import cached_property from nmutil.plain_data import plain_data @@ -275,7 +275,7 @@ class MergedSSAVal(metaclass=InternedMeta): lhs_src = self.fn_analysis.copies.get(lhs, lhs) rhs_src = self.fn_analysis.copies.get(rhs, rhs) if lhs_src == rhs_src: - return lhs_src, rhs_src + return lhs, rhs return None @@ -319,6 +319,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): # type: (...) -> None self.__merged_ssa_val_map = _private_merged_ssa_val_map self.__map = {} # type: dict[MergedSSAVal, IGNode] + self.__next_node_id = 0 def __getitem__(self, __key): # type: (MergedSSAVal) -> IGNode @@ -347,9 +348,11 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): f"{self.__merged_ssa_val_map[ssa_val]}") self.__merged_ssa_val_map[ssa_val] = merged_ssa_val added += 1 - retval = IGNode(merged_ssa_val=merged_ssa_val, edges={}, loc=None, - ignored=False) + retval = IGNode( + node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val, + edges={}, loc=None, ignored=False) self.__map[merged_ssa_val] = retval + self.__next_node_id += 1 added = None return retval finally: @@ -391,12 +394,18 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): # we're finished checking validity, now we can modify stuff for n in source_nodes: edges.pop(n, None) - retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges, - loc=loc, ignored=False) + retval = IGNode( + node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val, + edges=edges, loc=loc, ignored=False) + self.__next_node_id += 1 + empty_e = IGEdge() for node in edges: edge = reduce(IGEdge.merged, - (node.edges.pop(n) for n in source_nodes)) - node.edges[retval] = edge + (node.edges.pop(n, empty_e) for n in source_nodes)) + if edge == empty_e: + node.edges.pop(retval, None) + else: + node.edges[retval] = edge for node in source_nodes: del self.__map[node.merged_ssa_val] self.__map[final_merged_ssa_val] = retval @@ -452,7 +461,7 @@ class InterferenceGraph: lhs, rhs = rhs, lhs lhs_merged = node1.merged_ssa_val.with_offset_to_match( lhs.ssa_val, additional_offset=-lhs.reg_idx) - rhs_merged = node1.merged_ssa_val.with_offset_to_match( + rhs_merged = node2.merged_ssa_val.with_offset_to_match( rhs.ssa_val, additional_offset=-rhs.reg_idx) return lhs_merged.merged(rhs_merged) @@ -474,12 +483,14 @@ class InterferenceGraph: raise ValueError( "can't get local_colorability_score of ignored node") loc_set = node.loc_set - edges = node.edges.items() + edges = node.edges if merged_in_copy is not None: loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set - edges = chain(edges, merged_in_copy.edges.items()) + edges = edges.copy() + for neighbor, edge in merged_in_copy.edges.items(): + edges[neighbor] = edge.merged(edges.get(neighbor)) retval = len(loc_set) - for neighbor, edge in edges: + for neighbor, edge in edges.items(): if neighbor.ignored or not edge.interferes: continue if neighbor == merged_in_copy or neighbor == node: @@ -512,8 +523,12 @@ class InterferenceGraph: s = self.nodes.__repr__(repr_state) return f"InterferenceGraph(nodes={s}, <...>)" - def dump_to_dot(self): - # type: () -> str + def dump_to_dot( + self, highlighted_nodes=(), # type: Container[IGNode] + node_scores=None, # type: None | dict[IGNode, int] + edge_scores=None, # type: None | dict[tuple[IGNode, IGNode], int] + ): + # type: (...) -> str def quote(s): # type: (object) -> str @@ -523,10 +538,15 @@ class InterferenceGraph: s = s.replace('\n', r'\n') return f'"{s}"' + if node_scores is None: + node_scores = {} + if edge_scores is None: + edge_scores = {} + edges = {} # type: dict[tuple[IGNode, IGNode], IGEdge] node_ids = {} # type: dict[IGNode, str] for node in self.nodes.values(): - node_ids[node] = quote(len(node_ids)) + node_ids[node] = quote(node.node_id) for neighbor, edge in node.edges.items(): edge_key = (node, neighbor) # ensure we only insert each edge once by checking for @@ -539,10 +559,23 @@ class InterferenceGraph: ] for node, node_id in node_ids.items(): label_lines = [] # type: list[str] + score = node_scores.get(node) + if score is not None: + label_lines.append(f"score={score}") for k, v in node.merged_ssa_val.ssa_val_offsets.items(): label_lines.append(f"{k}: {v}") label = quote("\n".join(label_lines)) - lines.append(f" {node_id} [label = {label}]") + style = "dotted" if node.ignored else "solid" + color = "black" + if node in highlighted_nodes: + style = "bold" + color = "green" + style = quote(style) + color = quote(color) + lines.append(f" {node_id} [" + f"label = {label}, " + f"style = {style}, " + f"color = {color}]") def append_edge(node1, node2, label, color, style): # type: (IGNode, IGNode, str, str, str) -> None @@ -555,11 +588,17 @@ class InterferenceGraph: f"style = {style}, " f"decorate = true]") for (node1, node2), edge in edges.items(): + score = edge_scores.get((node1, node2)) + if score is None: + score = edge_scores.get((node2, node1)) + label_prefix = "" + if score is not None: + label_prefix = f"score={score}\n" if edge.interferes: - append_edge(node1, node2, label="interferes", + append_edge(node1, node2, label=label_prefix + "interferes", color="darkred", style="bold") if edge.copy_relation is not None: - append_edge(node1, node2, label="copy related", + append_edge(node1, node2, label=label_prefix + "copy related", color="blue", style="dashed") lines.append("}") return "\n".join(lines) @@ -567,11 +606,10 @@ class InterferenceGraph: @plain_data(repr=False) class IGNodeReprState: - __slots__ = "node_ids", "did_full_repr" + __slots__ = "did_full_repr", def __init__(self): super().__init__() - self.node_ids = {} # type: dict[IGNode, int] self.did_full_repr = OSet() # type: OSet[IGNode] @@ -600,10 +638,11 @@ class IGEdge: @final class IGNode: """ interference graph node """ - __slots__ = "merged_ssa_val", "edges", "loc", "ignored" + __slots__ = "node_id", "merged_ssa_val", "edges", "loc", "ignored" - def __init__(self, merged_ssa_val, edges, loc, ignored): - # type: (MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None + def __init__(self, node_id, merged_ssa_val, edges, loc, ignored): + # type: (int, MergedSSAVal, dict[IGNode, IGEdge], Loc | None, bool) -> None + self.node_id = node_id self.merged_ssa_val = merged_ssa_val self.edges = edges self.loc = loc @@ -611,6 +650,8 @@ class IGNode: def merge_edge(self, other, edge): # type: (IGNode, IGEdge) -> None + if self == other: + raise ValueError("can't have self-loops") old_edge = self.edges.get(other, None) assert old_edge is other.edges.get(self, None), "inconsistent edges" edge = edge.merged(old_edge) @@ -637,15 +678,12 @@ class IGNode: del repr_state if rs is None: rs = IGNodeReprState() - node_id = rs.node_ids.get(self, None) - if node_id is None: - rs.node_ids[self] = node_id = len(rs.node_ids) if short or self in rs.did_full_repr: - return f"" + return f"" rs.did_full_repr.add(self) edges = ", ".join( f"{k.__repr__(rs, True)}: {v}" for k, v in self.edges.items()) - return (f"IGNode(#{node_id}, " + return (f"IGNode(#{self.node_id}, " f"merged_ssa_val={self.merged_ssa_val}, " f"edges={{{edges}}}, " f"loc={self.loc}, " @@ -734,28 +772,120 @@ def allocate_registers( if dump_graph is not None: dump_graph("initial", interference_graph.dump_to_dot()) - # TODO: implement copy-merging - node_stack = [] # type: list[IGNode] - while True: + + debug_node_scores = {} # type: dict[IGNode, int] + debug_edge_scores = {} # type: dict[tuple[IGNode, IGNode], int] + + def find_best_node(has_copy_relation): + # type: (bool) -> None | IGNode best_node = None # type: None | IGNode best_score = 0 for node in interference_graph.nodes.values(): if node.ignored: continue + node_has_copy_relation = False + for neighbor, edge in node.edges.items(): + if neighbor.ignored: + continue + if edge.copy_relation is not None: + node_has_copy_relation = True + break + if node_has_copy_relation != has_copy_relation: + continue score = interference_graph.local_colorability_score(node) + debug_node_scores[node] = score if best_node is None or score > best_score: best_node = node best_score = score if best_score > 0: # it's locally colorable, no need to find a better one break + if debug_out is not None: + print(f"find_best_node(has_copy_relation={has_copy_relation}):\n" + f"{best_node}", file=debug_out, flush=True) + return best_node + # copy-merging algorithm based on Iterated Register Coalescing, section 5: + # https://dl.acm.org/doi/pdf/10.1145/229542.229546 + # Build step is above. + for step in count(): + debug_node_scores.clear() + debug_edge_scores.clear() + # Simplify: + best_node = find_best_node(has_copy_relation=False) + if best_node is not None: + if dump_graph is not None: + dump_graph( + f"step_{step}_simplify", interference_graph.dump_to_dot( + highlighted_nodes=[best_node], + node_scores=debug_node_scores, + edge_scores=debug_edge_scores)) + node_stack.append(best_node) + best_node.ignored = True + continue + # Coalesce (aka. do copy-merges): + did_any_copy_merges = False + for node in interference_graph.nodes.values(): + if node.ignored: + continue + for neighbor, edge in node.edges.items(): + if neighbor.ignored: + continue + if edge.copy_relation is None: + continue + try: + score = interference_graph.local_colorability_score( + node, merged_in_copy=neighbor) + except BadMergedSSAVal: + continue + if (neighbor, node) in debug_edge_scores: + debug_edge_scores[(neighbor, node)] = score + else: + debug_edge_scores[(node, neighbor)] = score + if score > 0: # merged node is locally colorable + if dump_graph is not None: + dump_graph( + f"step_{step}_copy_merge", + interference_graph.dump_to_dot( + highlighted_nodes=[node, neighbor], + node_scores=debug_node_scores, + edge_scores=debug_edge_scores)) + if debug_out is not None: + print(f"\nCopy-merging:\n{node}\nwith:\n{neighbor}", + file=debug_out, flush=True) + merged_node = interference_graph.copy_merge(node, neighbor) + if dump_graph is not None: + dump_graph( + f"step_{step}_copy_merge_result", + interference_graph.dump_to_dot( + highlighted_nodes=[merged_node])) + if debug_out is not None: + print(f"merged_node:\n" + f"{merged_node}", file=debug_out, flush=True) + did_any_copy_merges = True + break + if did_any_copy_merges: + break + if did_any_copy_merges: + continue + # Freeze: + best_node = find_best_node(has_copy_relation=True) + if best_node is not None: + if dump_graph is not None: + dump_graph(f"step_{step}_freeze", + interference_graph.dump_to_dot( + highlighted_nodes=[best_node], + node_scores=debug_node_scores, + edge_scores=debug_edge_scores)) + # no need to clear copy relations since best_node won't be + # considered since it's now ignored. + node_stack.append(best_node) + best_node.ignored = True + continue + break - if best_node is None: - break - node_stack.append(best_node) - best_node.ignored = True - + if dump_graph is not None: + dump_graph("final", interference_graph.dump_to_dot()) if debug_out is not None: print(f"After deciding node allocation order:\n" f"{node_stack}", file=debug_out, flush=True) -- 2.30.2