import sys
 import unittest
+import shutil
 
 from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind,
                                                   SSAVal)
 from bigint_presentation_code.register_allocator import allocate_registers
+from nmutil.get_test_path import get_test_path
+
+
+def dump_graphs(test_case, graphs):
+    # type: (unittest.TestCase, dict[str, str]) -> None
+    base_path = get_test_path(test_case, "dumped_graphs")
+    shutil.rmtree(base_path, ignore_errors=True)
+    base_path.mkdir(parents=True, exist_ok=True)
+    for name, dot in graphs.items():
+        path = base_path / name
+        dot_path = path.with_suffix(".dot")
+        dot_path.write_text(dot)
 
 
 class TestRegisterAllocator(unittest.TestCase):
             'sv.std *14, 0(3)',
         ])
 
+    def test_register_allocate_graphs(self):
+        fn, _arg = self.make_add_fn()
+        graphs = {}  # type: dict[str, str]
+
+        def dump_graph(name, dot):
+            # type: (str, str) -> None
+            self.assertNotIn(name, graphs, "duplicate graph name")
+            graphs[name] = dot
+        allocated = allocate_registers(fn, dump_graph=dump_graph)
+        self.assertEqual(
+            repr(allocated),
+            "{"
+            "<add.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.inp1.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<add.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
+            "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<st.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<st.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ca.outputs[0]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add.outputs[1]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+            "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<vl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<arg.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
+            "}"
+        )
+        dump_graphs(self, graphs)
+        # FIXME: is_copy_related is not correct, it's missing a bunch of
+        # edges (which aren't interference edges)
+        self.assertEqual(graphs, {
+            'initial':
+            'graph {\n'
+            '    "0" [label = "<arg.outputs[0]: <I64>>: 0"]\n'
+            '    "1" [label = "<arg.out0.copy.outputs[0]: <I64>>: 0"]\n'
+            '    "2" [label = "<vl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "3" [label = "<ld.inp0.copy.outputs[0]: <I64>>: 0"]\n'
+            '    "4" [label = "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "5" [label = "<ld.outputs[0]: <I64*32>>: 0"]\n'
+            '    "6" [label = "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "7" [label = "<ld.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
+            '    "8" [label = "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "9" [label = "<li.outputs[0]: <I64*32>>: 0"]\n'
+            '    "10" [label = "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "11" [label = "<li.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
+            '    "12" [label = "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "13" [label = "<add.inp0.copy.outputs[0]: <I64*32>>: 0"]\n'
+            '    "14" [label = "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "15" [label = "<add.inp1.copy.outputs[0]: <I64*32>>: 0"]\n'
+            '    "16" [label = "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "17" [label = "<add.outputs[0]: <I64*32>>: 0"]\n'
+            '    "18" [label = "<ca.outputs[0]: <CA>>: 0\\n'
+            '<add.outputs[1]: <CA>>: 0"]\n'
+            '    "19" [label = "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "20" [label = "<add.out0.copy.outputs[0]: <I64*32>>: 0"]\n'
+            '    "21" [label = "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "22" [label = "<st.inp0.copy.outputs[0]: <I64*32>>: 0"]\n'
+            '    "23" [label = "<st.inp1.copy.outputs[0]: <I64>>: 0"]\n'
+            '    "24" [label = "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: 0"]\n'
+            '    "1" -- "3" [label = "IGEdge(is_copy_related=True)"]\n'
+            '    "1" -- "5" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "7" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "9" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "11" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "13" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "15" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "20" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "1" -- "22" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "3" -- "5" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "7" -- "9" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "7" -- "11" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "11" -- "13" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "13" -- "15" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "13" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "15" -- "17" [label = "IGEdge(is_copy_related=False)"]\n'
+            '    "22" -- "23" [label = "IGEdge(is_copy_related=False)"]\n'
+            '}'
+        })
+
     def test_register_allocate_spread(self):
         fn = Fn()
         maxvl = 32
 
 
 from functools import reduce
 from itertools import combinations
-from typing import Iterable, Iterator, Mapping, TextIO
+from typing import Callable, Iterable, Iterator, Mapping, TextIO
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
             sets.add(self.fn_analysis.copy_related_ssa_vals[ssa_val])
         return OFSet(v for s in sets for v in s)
 
+    def is_copy_related(self, other):
+        # type: (MergedSSAVal) -> bool
+        for lhs_ssa_val in self.ssa_vals:
+            for rhs_ssa_val in other.ssa_vals:
+                for lhs in lhs_ssa_val.ssa_val_sub_regs:
+                    for rhs in rhs_ssa_val.ssa_val_sub_regs:
+                        lhs = self.fn_analysis.copies.get(lhs, lhs)
+                        rhs = self.fn_analysis.copies.get(rhs, rhs)
+                        if lhs == rhs:
+                            return True
+        return False
+
 
 @final
 class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
         s = self.nodes.__repr__(repr_state)
         return f"InterferenceGraph(nodes={s}, <...>)"
 
+    def dump_to_dot(self):
+        # type: () -> str
+
+        def quote(s):
+            # type: (object) -> str
+            s = str(s)
+            s = s.replace('\\', r'\\')
+            s = s.replace('"', r'\"')
+            s = s.replace('\n', r'\n')
+            return f'"{s}"'
+
+        edges = {}  # type: dict[tuple[IGNode, IGNode], IGEdge]
+        node_ids = {}  # type: dict[IGNode, str]
+        for node in self.nodes.values():
+            node_ids[node] = quote(len(node_ids))
+            for neighbor, edge in node.edges.items():
+                edge_key = (node, neighbor)
+                # ensure we only insert each edge once by checking for
+                # both directions
+                if edge_key not in edges and edge_key[::-1] not in edges:
+                    edges[edge_key] = edge
+        lines = ["graph {"]
+        for node, node_id in node_ids.items():
+            label_lines = []  # type: list[str]
+            for k, v in node.merged_ssa_val.ssa_val_offsets.items():
+                label_lines.append(f"{k}: {v}")
+            label = quote("\n".join(label_lines))
+            lines.append(f"    {node_id} [label = {label}]")
+        for (node1, node2), edge in edges.items():
+            label = quote(repr(edge))
+            lines.append(f"    {node_ids[node1]} -- {node_ids[node2]} "
+                         f"[label = {label}]")
+        lines.append("}")
+        return "\n".join(lines)
+
 
 @plain_data(repr=False)
 class IGNodeReprState:
         return self.__repr__()
 
 
-def allocate_registers(fn, debug_out=None):
-    # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
+def allocate_registers(
+    fn,  # type: Fn
+    debug_out=None,  # type: TextIO | None
+    dump_graph=None,  # type: Callable[[str, str], None] | None
+):
+    # type: (...) -> dict[SSAVal, Loc]
 
     # inserts enough copies that no manual spilling is necessary, all
     # spilling is done by the register allocator naturally allocating SSAVals
                 # is_copy_related = not i.copy_related_ssa_vals.isdisjoint(
                 #     j.copy_related_ssa_vals)
                 # since it is too coarse
-
-                # TODO: fill in is_copy_related afterwards
-                # using fn_analysis.copies
                 interference_graph.nodes[i].add_edge(
                     interference_graph.nodes[j],
-                    edge=IGEdge(is_copy_related=False))
+                    edge=IGEdge(is_copy_related=i.is_copy_related(j)))
         if debug_out is not None:
             print(f"processed {pp} out of {fn_analysis.all_program_points}",
                   file=debug_out, flush=True)
     if debug_out is not None:
         print(f"After adding interference graph edges:\n"
               f"{interference_graph}", file=debug_out, flush=True)
+    if dump_graph is not None:
+        dump_graph("initial", interference_graph.dump_to_dot())
 
     nodes_remaining = OSet(interference_graph.nodes.values())