From 00daae74b1970e71e1377c6b0afd72fe7a9a83f5 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 5 Dec 2022 23:04:35 -0800 Subject: [PATCH] add dot graph output for interference graph --- .gitignore | 1 + .../_tests/test_register_allocator.py | 132 ++++++++++++++++++ .../register_allocator.py | 64 ++++++++- typings/nmutil/get_test_path.pyi | 20 +++ 4 files changed, 210 insertions(+), 7 deletions(-) create mode 100644 typings/nmutil/get_test_path.pyi diff --git a/.gitignore b/.gitignore index 4134655..ac96be5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__ *.egg-info *.il /.vscode +dumped_graphs diff --git a/src/bigint_presentation_code/_tests/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py index d30ea12..7db0796 100644 --- a/src/bigint_presentation_code/_tests/test_register_allocator.py +++ b/src/bigint_presentation_code/_tests/test_register_allocator.py @@ -1,9 +1,22 @@ import sys import unittest +import shutil from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind, SSAVal) from bigint_presentation_code.register_allocator import allocate_registers +from nmutil.get_test_path import get_test_path + + +def dump_graphs(test_case, graphs): + # type: (unittest.TestCase, dict[str, str]) -> None + base_path = get_test_path(test_case, "dumped_graphs") + shutil.rmtree(base_path, ignore_errors=True) + base_path.mkdir(parents=True, exist_ok=True) + for name, dot in graphs.items(): + path = base_path / name + dot_path = path.with_suffix(".dot") + dot_path.write_text(dot) class TestRegisterAllocator(unittest.TestCase): @@ -179,6 +192,125 @@ class TestRegisterAllocator(unittest.TestCase): 'sv.std *14, 0(3)', ]) + def test_register_allocate_graphs(self): + fn, _arg = self.make_add_fn() + graphs = {} # type: dict[str, str] + + def dump_graph(name, dot): + # type: (str, str) -> None + self.assertNotIn(name, graphs, "duplicate graph name") + graphs[name] = dot + allocated = allocate_registers(fn, dump_graph=dump_graph) + self.assertEqual( + repr(allocated), + "{" + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=78, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1)" + "}" + ) + dump_graphs(self, graphs) + # FIXME: is_copy_related is not correct, it's missing a bunch of + # edges (which aren't interference edges) + self.assertEqual(graphs, { + 'initial': + 'graph {\n' + ' "0" [label = ">: 0"]\n' + ' "1" [label = ">: 0"]\n' + ' "2" [label = ">: 0"]\n' + ' "3" [label = ">: 0"]\n' + ' "4" [label = ">: 0"]\n' + ' "5" [label = ">: 0"]\n' + ' "6" [label = ">: 0"]\n' + ' "7" [label = ">: 0"]\n' + ' "8" [label = ">: 0"]\n' + ' "9" [label = ">: 0"]\n' + ' "10" [label = ">: 0"]\n' + ' "11" [label = ">: 0"]\n' + ' "12" [label = ">: 0"]\n' + ' "13" [label = ">: 0"]\n' + ' "14" [label = ">: 0"]\n' + ' "15" [label = ">: 0"]\n' + ' "16" [label = ">: 0"]\n' + ' "17" [label = ">: 0"]\n' + ' "18" [label = ">: 0\\n' + '>: 0"]\n' + ' "19" [label = ">: 0"]\n' + ' "20" [label = ">: 0"]\n' + ' "21" [label = ">: 0"]\n' + ' "22" [label = ">: 0"]\n' + ' "23" [label = ">: 0"]\n' + ' "24" [label = ">: 0"]\n' + ' "1" -- "3" [label = "IGEdge(is_copy_related=True)"]\n' + ' "1" -- "5" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "7" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "9" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "11" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "13" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "15" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "17" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "20" [label = "IGEdge(is_copy_related=False)"]\n' + ' "1" -- "22" [label = "IGEdge(is_copy_related=False)"]\n' + ' "3" -- "5" [label = "IGEdge(is_copy_related=False)"]\n' + ' "7" -- "9" [label = "IGEdge(is_copy_related=False)"]\n' + ' "7" -- "11" [label = "IGEdge(is_copy_related=False)"]\n' + ' "11" -- "13" [label = "IGEdge(is_copy_related=False)"]\n' + ' "13" -- "15" [label = "IGEdge(is_copy_related=False)"]\n' + ' "13" -- "17" [label = "IGEdge(is_copy_related=False)"]\n' + ' "15" -- "17" [label = "IGEdge(is_copy_related=False)"]\n' + ' "22" -- "23" [label = "IGEdge(is_copy_related=False)"]\n' + '}' + }) + def test_register_allocate_spread(self): fn = Fn() maxvl = 32 diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index be1522e..6e7d87c 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -7,7 +7,7 @@ this uses an algorithm based on: from functools import reduce from itertools import combinations -from typing import Iterable, Iterator, Mapping, TextIO +from typing import Callable, Iterable, Iterator, Mapping, TextIO from cached_property import cached_property from nmutil.plain_data import plain_data @@ -263,6 +263,18 @@ class MergedSSAVal(metaclass=InternedMeta): sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val]) return OFSet(v for s in sets for v in s) + def is_copy_related(self, other): + # type: (MergedSSAVal) -> bool + for lhs_ssa_val in self.ssa_vals: + for rhs_ssa_val in other.ssa_vals: + for lhs in lhs_ssa_val.ssa_val_sub_regs: + for rhs in rhs_ssa_val.ssa_val_sub_regs: + lhs = self.fn_analysis.copies.get(lhs, lhs) + rhs = self.fn_analysis.copies.get(rhs, rhs) + if lhs == rhs: + return True + return False + @final class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]): @@ -441,6 +453,41 @@ class InterferenceGraph: s = self.nodes.__repr__(repr_state) return f"InterferenceGraph(nodes={s}, <...>)" + def dump_to_dot(self): + # type: () -> str + + def quote(s): + # type: (object) -> str + s = str(s) + s = s.replace('\\', r'\\') + s = s.replace('"', r'\"') + s = s.replace('\n', r'\n') + return f'"{s}"' + + edges = {} # type: dict[tuple[IGNode, IGNode], IGEdge] + node_ids = {} # type: dict[IGNode, str] + for node in self.nodes.values(): + node_ids[node] = quote(len(node_ids)) + for neighbor, edge in node.edges.items(): + edge_key = (node, neighbor) + # ensure we only insert each edge once by checking for + # both directions + if edge_key not in edges and edge_key[::-1] not in edges: + edges[edge_key] = edge + lines = ["graph {"] + for node, node_id in node_ids.items(): + label_lines = [] # type: list[str] + for k, v in node.merged_ssa_val.ssa_val_offsets.items(): + label_lines.append(f"{k}: {v}") + label = quote("\n".join(label_lines)) + lines.append(f" {node_id} [label = {label}]") + for (node1, node2), edge in edges.items(): + label = quote(repr(edge)) + lines.append(f" {node_ids[node1]} -- {node_ids[node2]} " + f"[label = {label}]") + lines.append("}") + return "\n".join(lines) + @plain_data(repr=False) class IGNodeReprState: @@ -547,8 +594,12 @@ class AllocationFailedError(Exception): return self.__repr__() -def allocate_registers(fn, debug_out=None): - # type: (Fn, TextIO | None) -> dict[SSAVal, Loc] +def allocate_registers( + fn, # type: Fn + debug_out=None, # type: TextIO | None + dump_graph=None, # type: Callable[[str, str], None] | None +): + # type: (...) -> dict[SSAVal, Loc] # inserts enough copies that no manual spilling is necessary, all # spilling is done by the register allocator naturally allocating SSAVals @@ -577,12 +628,9 @@ def allocate_registers(fn, debug_out=None): # is_copy_related = not i.copy_related_ssa_vals.isdisjoint( # j.copy_related_ssa_vals) # since it is too coarse - - # TODO: fill in is_copy_related afterwards - # using fn_analysis.copies interference_graph.nodes[i].add_edge( interference_graph.nodes[j], - edge=IGEdge(is_copy_related=False)) + edge=IGEdge(is_copy_related=i.is_copy_related(j))) if debug_out is not None: print(f"processed {pp} out of {fn_analysis.all_program_points}", file=debug_out, flush=True) @@ -590,6 +638,8 @@ def allocate_registers(fn, debug_out=None): if debug_out is not None: print(f"After adding interference graph edges:\n" f"{interference_graph}", file=debug_out, flush=True) + if dump_graph is not None: + dump_graph("initial", interference_graph.dump_to_dot()) nodes_remaining = OSet(interference_graph.nodes.values()) diff --git a/typings/nmutil/get_test_path.pyi b/typings/nmutil/get_test_path.pyi new file mode 100644 index 0000000..df24d89 --- /dev/null +++ b/typings/nmutil/get_test_path.pyi @@ -0,0 +1,20 @@ +from os import PathLike +from pathlib import Path +from typing import Any +import unittest + + +class RunCounter: + def next(self, k: str) -> int: + """get a incrementing run counter for a `str` key `k`. returns an `int`.""" + ... + + @staticmethod + def get(obj: Any) -> RunCounter: ... + + +def get_test_path(test_case: unittest.TestCase, + base_path: str | PathLike[str]) -> Path: + """get the `Path` for a particular unittest.TestCase instance + (`test_case`). base_path is either a str or a path-like.""" + ... -- 2.30.2