import sys
import unittest
+import shutil
from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind,
SSAVal)
from bigint_presentation_code.register_allocator import allocate_registers
+from nmutil.get_test_path import get_test_path
+
+
def dump_graphs(test_case, graphs):
    # type: (unittest.TestCase, dict[str, str]) -> None
    """Write every graph in `graphs` to its own `.dot` file.

    The files are placed in the `dumped_graphs` path that `get_test_path`
    returns for `test_case`. That directory is removed and recreated first,
    so files left over from an earlier run can't be mistaken for current
    output.

    test_case: the running test; passed to `get_test_path` to pick the
        output directory.
    graphs: maps each graph's name to the text written to `<name>.dot`.
    """
    base_path = get_test_path(test_case, "dumped_graphs")
    # ignore_errors=True: the directory not existing yet is not an error
    shutil.rmtree(base_path, ignore_errors=True)
    base_path.mkdir(parents=True, exist_ok=True)
    for name, dot in graphs.items():
        # Append ".dot" instead of using Path.with_suffix(): with_suffix()
        # replaces everything after the last "." in `name`, so names such
        # as "a.before" and "a.after" would both be written to "a.dot".
        dot_path = base_path / (name + ".dot")
        dot_path.write_text(dot)
class TestRegisterAllocator(unittest.TestCase):
'sv.std *14, 0(3)',
])
def test_register_allocate_graphs(self):
    """Check the register assignments and the dumped interference graph
    produced by `allocate_registers` for the function that
    `self.make_add_fn()` builds.

    Every graph handed to the `dump_graph` callback is also written out
    through `dump_graphs` so it can be inspected after the run.
    """
    fn, _arg = self.make_add_fn()
    graphs = {}  # type: dict[str, str]

    def dump_graph(name, dot):
        # type: (str, str) -> None
        # a repeated name would silently replace the earlier entry in
        # `graphs`, so treat it as a test failure instead
        self.assertNotIn(name, graphs, "duplicate graph name")
        graphs[name] = dot
    try:
        allocated = allocate_registers(fn, dump_graph=dump_graph)
    finally:
        # Write out whatever was collected *before* checking anything, so
        # the .dot files exist for debugging even when allocation raises
        # or one of the assertions below fails.
        dump_graphs(self, graphs)
    self.assertEqual(
        repr(allocated),
        "{"
        "<add.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
        "<add.inp1.copy.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
        "<add.inp0.copy.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
        "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<st.inp1.copy.outputs[0]: <I64>>: "
        "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
        "<st.inp0.copy.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
        "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<add.out0.copy.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
        "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<ca.outputs[0]: <CA>>: "
        "Loc(kind=LocKind.CA, start=0, reg_len=1), "
        "<add.outputs[1]: <CA>>: "
        "Loc(kind=LocKind.CA, start=0, reg_len=1), "
        "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<li.out0.copy.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
        "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<li.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
        "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<ld.out0.copy.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
        "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<ld.outputs[0]: <I64*32>>: "
        "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
        "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<ld.inp0.copy.outputs[0]: <I64>>: "
        "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
        "<vl.outputs[0]: <VL_MAXVL>>: "
        "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
        "<arg.out0.copy.outputs[0]: <I64>>: "
        "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
        "<arg.outputs[0]: <I64>>: "
        "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
        "}"
    )
    # FIXME: is_copy_related is not correct, it's missing a bunch of
    # edges (which aren't interference edges)
    self.assertEqual(graphs, {
        'initial':
        'graph {\n'
        ' "0" [label = "<arg.outputs[0]: <I64>>: 0"]\n'
        ' "1" [label = "<arg.out0.copy.outputs[0]: <I64>>: 0"]\n'
        ' "2" [label = "<vl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "3" [label = "<ld.inp0.copy.outputs[0]: <I64>>: 0"]\n'
        ' "4" [label = "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "5" [label = "<ld.outputs[0]: <I64*32>>: 0"]\n'
        ' "6" [label = "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "7" [label = "<ld.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
        ' "8" [label = "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "9" [label = "<li.outputs[0]: <I64*32>>: 0"]\n'
        ' "10" [label = "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "11" [label = "<li.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
        ' "12" [label = "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "13" [label = "<add.inp0.copy.outputs[0]: <I64*32>>: 0"]\n'
        ' "14" [label = "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "15" [label = "<add.inp1.copy.outputs[0]: <I64*32>>: 0"]\n'
        ' "16" [label = "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "17" [label = "<add.outputs[0]: <I64*32>>: 0"]\n'
        ' "18" [label = "<ca.outputs[0]: <CA>>: 0\\n'
        '<add.outputs[1]: <CA>>: 0"]\n'
        ' "19" [label = "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "20" [label = "<add.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
        ' "21" [label = "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "22" [label = "<st.inp0.copy.outputs[0]: <I64*32>>: 0"]\n'
        ' "23" [label = "<st.inp1.copy.outputs[0]: <I64>>: 0"]\n'
        ' "24" [label = "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
        ' "1" -- "3" [label = "IGEdge(is_copy_related=True)"]\n'
        ' "1" -- "5" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "7" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "9" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "11" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "13" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "15" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "20" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "1" -- "22" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "3" -- "5" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "7" -- "9" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "7" -- "11" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "11" -- "13" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "13" -- "15" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "13" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "15" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
        ' "22" -- "23" [label = "IGEdge(is_copy_related=False)"]\n'
        '}'
    })
+
def test_register_allocate_spread(self):
fn = Fn()
maxvl = 32
from functools import reduce
from itertools import combinations
-from typing import Iterable, Iterator, Mapping, TextIO
+from typing import Callable, Iterable, Iterator, Mapping, TextIO
from cached_property import cached_property
from nmutil.plain_data import plain_data
sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
return OFSet(v for s in sets for v in s)
def is_copy_related(self, other):
    # type: (MergedSSAVal) -> bool
    """Return True if some sub-register of `self` and some sub-register of
    `other` map to equal values through `self.fn_analysis.copies`.

    Every sub-register in `ssa_vals[*].ssa_val_sub_regs` is looked up in
    `copies` exactly once; a sub-register with no entry stands for itself.
    Returns False when either side has no sub-registers.
    """
    copies = self.fn_analysis.copies

    def resolved_sub_regs(merged):
        # type: (MergedSSAVal) -> list
        # one `copies` lookup per sub-register -- never re-resolve an
        # already resolved value, so the answer can't depend on the order
        # in which the sub-registers happen to be visited
        return [copies.get(sub_reg, sub_reg)
                for ssa_val in merged.ssa_vals
                for sub_reg in ssa_val.ssa_val_sub_regs]

    lhs_resolved = resolved_sub_regs(self)
    rhs_resolved = resolved_sub_regs(other)
    return any(lhs == rhs
               for lhs in lhs_resolved
               for rhs in rhs_resolved)
+
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
s = self.nodes.__repr__(repr_state)
return f"InterferenceGraph(nodes={s}, <...>)"
def dump_to_dot(self):
    # type: () -> str
    """Render this graph as the text of an undirected `graph { ... }`.

    Nodes are numbered in the iteration order of `self.nodes.values()`;
    each one is labelled with a `key: value` line for every item of its
    `merged_ssa_val.ssa_val_offsets`. Each edge is written once, labelled
    with the `repr()` of its edge object.
    """

    def dot_str(value):
        # type: (object) -> str
        # backslashes must be escaped first, otherwise the backslashes
        # added by the later replacements would get doubled
        text = str(value)
        for raw, escaped in (('\\', r'\\'), ('"', r'\"'), ('\n', r'\n')):
            text = text.replace(raw, escaped)
        return '"' + text + '"'

    ids = {}  # type: dict[IGNode, str]
    unique_edges = {}  # type: dict[tuple[IGNode, IGNode], IGEdge]
    for ig_node in self.nodes.values():
        ids[ig_node] = dot_str(len(ids))
        for other, ig_edge in ig_node.edges.items():
            # an edge is reachable from both of its end points; keep only
            # the first direction that is encountered
            if (ig_node, other) in unique_edges \
                    or (other, ig_node) in unique_edges:
                continue
            unique_edges[ig_node, other] = ig_edge

    out = ["graph {"]
    for ig_node, ident in ids.items():
        offsets = ig_node.merged_ssa_val.ssa_val_offsets
        text = "\n".join(f"{k}: {v}" for k, v in offsets.items())
        out.append(f" {ident} [label = {dot_str(text)}]")
    for (src, dst), ig_edge in unique_edges.items():
        edge_label = dot_str(repr(ig_edge))
        out.append(f" {ids[src]} -- {ids[dst]} [label = {edge_label}]")
    out.append("}")
    return "\n".join(out)
+
@plain_data(repr=False)
class IGNodeReprState:
return self.__repr__()
-def allocate_registers(fn, debug_out=None):
- # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
+def allocate_registers(
+ fn, # type: Fn
+ debug_out=None, # type: TextIO | None
+ dump_graph=None, # type: Callable[[str, str], None] | None
+):
+ # type: (...) -> dict[SSAVal, Loc]
# inserts enough copies that no manual spilling is necessary, all
# spilling is done by the register allocator naturally allocating SSAVals
# is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
# j.copy_related_ssa_vals)
# since it is too coarse
-
- # TODO: fill in is_copy_related afterwards
- # using fn_analysis.copies
interference_graph.nodes[i].add_edge(
interference_graph.nodes[j],
- edge=IGEdge(is_copy_related=False))
+ edge=IGEdge(is_copy_related=i.is_copy_related(j)))
if debug_out is not None:
print(f"processed {pp} out of {fn_analysis.all_program_points}",
file=debug_out, flush=True)
if debug_out is not None:
print(f"After adding interference graph edges:\n"
f"{interference_graph}", file=debug_out, flush=True)
+ if dump_graph is not None:
+ dump_graph("initial", interference_graph.dump_to_dot())
nodes_remaining = OSet(interference_graph.nodes.values())