copy-merging works afaict! -- some tests still broken: out-of-date
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 9 Dec 2022 08:06:41 +0000 (00:06 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 9 Dec 2022 08:06:41 +0000 (00:06 -0800)
tests not yet all updated for new register allocator results

src/bigint_presentation_code/_tests/test_register_allocator.py
src/bigint_presentation_code/_tests/test_toom_cook.py
src/bigint_presentation_code/compiler_ir.py
src/bigint_presentation_code/register_allocator.py
src/bigint_presentation_code/register_allocator_test_util.py [new file with mode: 0644]

index d0d9068404d9512c1ec595167e9e2f2b0122c4fe..be76d2dcf0a461f6ebbc2b7c7b64b34463845e2b 100644 (file)
@@ -1,22 +1,10 @@
 import sys
 import unittest
-import shutil
 
 from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind,
                                                   SSAVal)
 from bigint_presentation_code.register_allocator import allocate_registers
-from nmutil.get_test_path import get_test_path
-
-
-def dump_graphs(test_case, graphs):
-    # type: (unittest.TestCase, dict[str, str]) -> None
-    base_path = get_test_path(test_case, "dumped_graphs")
-    shutil.rmtree(base_path, ignore_errors=True)
-    base_path.mkdir(parents=True, exist_ok=True)
-    for name, dot in graphs.items():
-        path = base_path / name
-        dot_path = path.with_suffix(".dot")
-        dot_path.write_text(dot)
+from bigint_presentation_code.register_allocator_test_util import GraphDumper
 
 
 class TestRegisterAllocator(unittest.TestCase):
@@ -48,7 +36,8 @@ class TestRegisterAllocator(unittest.TestCase):
 
     def test_register_allocate(self):
         fn, _arg = self.make_add_fn()
-        reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+        reg_assignments = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
 
         self.assertEqual(
             repr(reg_assignments),
@@ -108,7 +97,8 @@ class TestRegisterAllocator(unittest.TestCase):
 
     def test_gen_asm(self):
         fn, _arg = self.make_add_fn()
-        reg_assignments = allocate_registers(fn)
+        reg_assignments = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
 
         self.assertEqual(
             repr(reg_assignments),
@@ -131,14 +121,14 @@ class TestRegisterAllocator(unittest.TestCase):
             "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<ld.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
-            "<ld.inp0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<arg.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<st.inp1.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
-            "<arg.out0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
             "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
@@ -170,9 +160,7 @@ class TestRegisterAllocator(unittest.TestCase):
         state = GenAsmState(reg_assignments)
         fn.gen_asm(state)
         self.assertEqual(state.output, [
-            'or 4, 3, 3',
             'setvl 0, 0, 32, 0, 1, 1',
-            'or 3, 4, 4',
             'setvl 0, 0, 32, 0, 1, 1',
             'sv.ld *14, 0(3)',
             'setvl 0, 0, 32, 0, 1, 1',
@@ -187,22 +175,23 @@ class TestRegisterAllocator(unittest.TestCase):
             'sv.adde *14, *78, *46',
             'setvl 0, 0, 32, 0, 1, 1',
             'setvl 0, 0, 32, 0, 1, 1',
-            'or 3, 4, 4',
             'setvl 0, 0, 32, 0, 1, 1',
-            'sv.std *14, 0(3)',
+            'sv.std *14, 0(3)'
         ])
 
     def test_register_allocate_graphs(self):
         fn, _arg = self.make_add_fn()
         graphs = {}  # type: dict[str, str]
 
+        graph_dumper = GraphDumper(self)
+
         def dump_graph(name, dot):
             # type: (str, str) -> None
             self.assertNotIn(name, graphs, "duplicate graph name")
             graphs[name] = dot
-        allocated = allocate_registers(fn, dump_graph=dump_graph,
-                                       debug_out=sys.stdout)
-        dump_graphs(self, graphs)
+            graph_dumper(name, dot)
+        allocated = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=dump_graph)
         self.assertEqual(
             repr(allocated),
             "{"
@@ -224,14 +213,14 @@ class TestRegisterAllocator(unittest.TestCase):
             "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
             "<ld.outputs[0]: <I64*32>>: "
             "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
-            "<ld.inp0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<arg.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<arg.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<ld.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
             "<st.inp1.copy.outputs[0]: <I64>>: "
             "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
-            "<arg.out0.copy.outputs[0]: <I64>>: "
-            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
             "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
             "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
             "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
@@ -338,7 +327,8 @@ class TestRegisterAllocator(unittest.TestCase):
         _concat = fn.append_new_op(
             OpKind.Concat, input_vals=[*spread[::-1], vl],
             name="concat", maxvl=maxvl)
-        reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+        reg_assignments = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
 
         self.assertEqual(
             repr(reg_assignments),
index eadc0809afa4edb60705021033ef54affbae2489..dc3bb8d8985cad70b95a7acf419eb7266f9bb912 100644 (file)
@@ -1,4 +1,5 @@
 from contextlib import contextmanager
+import sys
 import unittest
 from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
 
@@ -9,6 +10,7 @@ from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS,
                                                   PostRASimState,
                                                   PreRASimState, SSAVal)
 from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.register_allocator_test_util import GraphDumper
 from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
                                                 simple_mul)
 from bigint_presentation_code.util import OSet
@@ -77,21 +79,21 @@ class Mul:
             name="store_dest")
 
 
-def get_post_ra_state_factory(code):
-    # type: (Mul) -> _StateFactory
-    ssa_val_to_loc_map = allocate_registers(code.fn)
-
-    @contextmanager
-    def state_factory():
-        yield PostRASimState(
-            ssa_val_to_loc_map=ssa_val_to_loc_map,
-            memory={}, loc_values={})
-    return state_factory
-
-
 class TestToomCook(unittest.TestCase):
     maxDiff = None
 
+    def get_post_ra_state_factory(self, code):
+        # type: (Mul) -> _StateFactory
+        ssa_val_to_loc_map = allocate_registers(
+            code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
+
+        @contextmanager
+        def state_factory():
+            yield PostRASimState(
+                ssa_val_to_loc_map=ssa_val_to_loc_map,
+                memory={}, loc_values={})
+        return state_factory
+
     def test_toom_2_repr(self):
         TOOM_2 = ToomCookInstance.make_toom_2()
         # print(repr(repr(TOOM_2)))
@@ -269,7 +271,7 @@ class TestToomCook(unittest.TestCase):
             for rhs_signed in False, True:
                 self.tst_simple_mul_192x192_sim(
                     lhs_signed=lhs_signed, rhs_signed=rhs_signed,
-                    get_state_factory=get_post_ra_state_factory)
+                    get_state_factory=self.get_post_ra_state_factory)
 
     def tst_simple_mul_192x192_sim(
         self, lhs_signed,  # type: bool
@@ -485,7 +487,8 @@ class TestToomCook(unittest.TestCase):
     def test_simple_mul_192x192_reg_alloc(self):
         code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         fn = code.fn
-        assigned_registers = allocate_registers(fn)
+        assigned_registers = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
         self.assertEqual(
             repr(assigned_registers), "{"
             "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
@@ -865,150 +868,85 @@ class TestToomCook(unittest.TestCase):
     def test_simple_mul_192x192_asm(self):
         code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
         fn = code.fn
-        assigned_registers = allocate_registers(fn)
+        assigned_registers = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
         gen_asm_state = GenAsmState(assigned_registers)
         fn.gen_asm(gen_asm_state)
         self.assertEqual(gen_asm_state.output, [
-            'or 27, 3, 3',
+            'or 9, 3, 3',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 6, 27, 27',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.ld *3, 48(6)',
+            'sv.ld *20, 48(9)',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *24, *3, *3',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 6, 27, 27',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.ld *3, 72(6)',
+            'sv.ld *10, 72(9)',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or/mrr *5, *3, *3',
-            'or 4, 5, 5',
-            'or 14, 6, 6',
-            'or 23, 7, 7',
-            'addi 3, 0, 0',
-            'or 22, 3, 3',
+            'addi 6, 0, 0',
+            'or 8, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'addi 3, 0, 0',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *8, *24, *24',
-            'or 7, 4, 4',
-            'or 6, 22, 22',
+            'or 6, 8, 8',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.maddedu *3, *8, 7, 6',
+            'sv.maddedu *14, *20, 10, 6',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 19, 6, 6',
+            'or 17, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 21, 3, 3',
-            'or 12, 4, 4',
-            'or 11, 5, 5',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *8, *24, *24',
-            'or 7, 14, 14',
-            'or 6, 22, 22',
+            'or 6, 8, 8',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.maddedu *3, *8, 7, 6',
+            'sv.maddedu *3, *20, 11, 6',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 18, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 17, 3, 3',
-            'or 16, 4, 4',
-            'or 15, 5, 5',
-            'addi 3, 0, 0',
-            'or 8, 3, 3',
-            'addi 3, 0, 0',
-            'or 14, 3, 3',
+            'addi 18, 0, 0',
+            'addi 7, 0, 0',
             'setvl 0, 0, 5, 0, 1, 1',
-            'or 3, 12, 12',
-            'or 4, 11, 11',
-            'or 5, 19, 19',
-            'or 6, 8, 8',
-            'or 7, 8, 8',
+            'or 19, 18, 18',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.or *8, *3, *3',
-            'or 3, 17, 17',
-            'or 4, 16, 16',
-            'or 5, 15, 15',
-            'or 6, 18, 18',
-            'or 7, 14, 14',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
             'addic 0, 0, 0',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.or *14, *8, *8',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.or *8, *3, *3',
             'setvl 0, 0, 5, 0, 1, 1',
-            'sv.adde *3, *14, *8',
+            'sv.adde *15, *15, *3',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
             'setvl 0, 0, 5, 0, 1, 1',
-            'or 20, 3, 3',
-            'or 19, 4, 4',
-            'or 18, 5, 5',
-            'or 17, 6, 6',
-            'or 16, 7, 7',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.or *8, *24, *24',
-            'or 7, 23, 23',
-            'or 6, 22, 22',
+            'or 6, 8, 8',
             'setvl 0, 0, 3, 0, 1, 1',
-            'sv.maddedu *3, *8, 7, 6',
+            'sv.maddedu *3, *20, 12, 6',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 15, 6, 6',
             'setvl 0, 0, 3, 0, 1, 1',
             'setvl 0, 0, 3, 0, 1, 1',
-            'or 14, 3, 3',
-            'or 12, 4, 4',
-            'or 11, 5, 5',
             'setvl 0, 0, 4, 0, 1, 1',
-            'or 3, 19, 19',
-            'or 4, 18, 18',
-            'or 5, 17, 17',
-            'or 6, 16, 16',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.or *7, *3, *3',
-            'or 3, 14, 14',
-            'or 4, 12, 12',
-            'or 5, 11, 11',
-            'or 6, 15, 15',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
             'addic 0, 0, 0',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.or *14, *7, *7',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.or *7, *3, *3',
             'setvl 0, 0, 4, 0, 1, 1',
-            'sv.adde *3, *14, *7',
+            'sv.adde *16, *16, *3',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
             'setvl 0, 0, 4, 0, 1, 1',
-            'or 12, 3, 3',
-            'or 11, 4, 4',
-            'or 10, 5, 5',
-            'or 9, 6, 6',
             'setvl 0, 0, 6, 0, 1, 1',
-            'or 3, 21, 21',
-            'or 4, 20, 20',
-            'or 5, 12, 12',
-            'or 6, 11, 11',
-            'or 7, 10, 10',
-            'or 8, 9, 9',
             'setvl 0, 0, 6, 0, 1, 1',
             'setvl 0, 0, 6, 0, 1, 1',
             'setvl 0, 0, 6, 0, 1, 1',
             'setvl 0, 0, 6, 0, 1, 1',
-            'sv.or/mrr *4, *3, *3',
-            'or 3, 27, 27',
             'setvl 0, 0, 6, 0, 1, 1',
-            'sv.std *4, 0(3)'
+            'sv.std *14, 0(9)',
         ])
 
     def toom_2_mul_256x256(self, lhs_signed, rhs_signed):
@@ -1122,27 +1060,28 @@ class TestToomCook(unittest.TestCase):
     def test_toom_2_mul_256x256_uu_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=False, rhs_signed=False,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_su_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=True, rhs_signed=False,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_us_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=False, rhs_signed=True,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_ss_post_ra_sim(self):
         self.tst_toom_2_mul_256x256_sim(
             lhs_signed=True, rhs_signed=True,
-            get_state_factory=get_post_ra_state_factory)
+            get_state_factory=self.get_post_ra_state_factory)
 
     def test_toom_2_mul_256x256_asm(self):
         code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False)
         fn = code.fn
-        assigned_registers = allocate_registers(fn)
+        assigned_registers = allocate_registers(
+            fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
         gen_asm_state = GenAsmState(assigned_registers)
         fn.gen_asm(gen_asm_state)
         self.assertEqual(gen_asm_state.output, [
index 699102b655fe86d19dd67283e21620861c012985..4f8fb79bbbf15788663ca0875fd6150db7679e51 100644 (file)
@@ -787,6 +787,20 @@ class Loc(metaclass=InternedMeta):
         return Loc(kind=self.kind,
                    start=self.start + offset, reg_len=subloc_ty.reg_len)
 
+    def get_superloc_with_self_at_offset(self, superloc_ty, offset):
+        # type: (Ty, int) -> Loc
+        """get the Loc containing `self` such that:
+        `retval.get_subloc_at_offset(self.ty, offset) == self`
+        and `retval.ty == superloc_ty`
+        """
+        if superloc_ty.base_ty != self.kind.base_ty:
+            raise ValueError("BaseTy mismatch")
+        if offset < 0 or offset + self.reg_len > superloc_ty.reg_len:
+            raise ValueError("invalid sub-Loc: offset and/or "
+                             "superloc_ty.reg_len out of range")
+        return Loc(kind=self.kind,
+                   start=self.start - offset, reg_len=superloc_ty.reg_len)
+
 
 SPECIAL_GPRS = (
     Loc(kind=LocKind.GPR, start=0, reg_len=1),
@@ -900,6 +914,18 @@ class LocSet(OFSet[Loc], metaclass=InternedMeta):
     def __repr__(self):
         return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
 
+    @cached_property
+    def only_loc(self):
+        # type: () -> Loc | None
+        """if len(self) == 1 then return the Loc in self, otherwise None"""
+        only_loc = None
+        for i in self:
+            if only_loc is None:
+                only_loc = i
+            else:
+                return None  # len(self) > 1
+        return only_loc
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
index 540af0a3e31316203cfefff1851ca2be3e1d816d..76fc966ccfbf721a85132f8ff2c2a1a0b880e881 100644 (file)
@@ -6,11 +6,11 @@ this uses an algorithm based on:
 """
 
 from functools import reduce
-from itertools import chain, combinations, count
+from itertools import combinations, count
 from typing import Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple
 
 from cached_property import cached_property
-from nmutil.plain_data import plain_data
+from nmutil.plain_data import plain_data, replace
 
 from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
                                                   LocSet, Op, ProgramRange,
@@ -59,8 +59,8 @@ class MergedSSAVal(metaclass=InternedMeta):
     __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
                  "first_loc")
 
-    def __init__(self, fn_analysis, ssa_val_offsets):
-        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
+    def __init__(self, fn_analysis, ssa_val_offsets, loc_set=None):
+        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
         self.fn_analysis = fn_analysis
         if isinstance(ssa_val_offsets, SSAVal):
             ssa_val_offsets = {ssa_val_offsets: 0}
@@ -74,7 +74,10 @@ class MergedSSAVal(metaclass=InternedMeta):
         self.first_ssa_val = first_ssa_val  # type: SSAVal
         # self.ty checks for mismatched base_ty
         reg_len = self.ty.reg_len
-        loc_set = None  # type: None | LocSet
+        if loc_set is not None and loc_set.ty != self.ty:
+            raise ValueError(
+                f"invalid loc_set, type doesn't match: "
+                f"{loc_set.ty} != {self.ty}")
         for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
             def locs():
                 # type: () -> Iterable[Loc]
@@ -163,12 +166,17 @@ class MergedSSAVal(metaclass=InternedMeta):
     @cached_property
     def __hash(self):
         # type: () -> int
-        return hash((self.fn_analysis, self.ssa_val_offsets))
+        return hash((self.fn_analysis, self.ssa_val_offsets, self.loc_set))
 
     def __hash__(self):
         # type: () -> int
         return self.__hash
 
+    @property
+    def only_loc(self):
+        # type: () -> Loc | None
+        return self.loc_set.only_loc
+
     @cached_property
     def offset(self):
         # type: () -> int
@@ -226,6 +234,16 @@ class MergedSSAVal(metaclass=InternedMeta):
                     ssa_val_offsets[ssa_val] + additional_offset - offset)
         raise ValueError("can't change offset to match unrelated MergedSSAVal")
 
+    def with_loc(self, loc):
+        # type: (Loc) -> MergedSSAVal
+        if loc not in self.loc_set:
+            raise ValueError(
+                f"Loc is not allowed -- not a member of `self.loc_set`: "
+                f"{loc} not in {self.loc_set}")
+        return MergedSSAVal(fn_analysis=self.fn_analysis,
+                            ssa_val_offsets=self.ssa_val_offsets,
+                            loc_set=LocSet([loc]))
+
     def merged(self, *others):
         # type: (*MergedSSAVal) -> MergedSSAVal
         retval = dict(self.ssa_val_offsets)
@@ -278,6 +296,21 @@ class MergedSSAVal(metaclass=InternedMeta):
                             return lhs, rhs
         return None
 
+    def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
+        # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
+        cr_lhs, cr_rhs = copy_relation
+        if cr_lhs.ssa_val not in self.ssa_vals:
+            cr_lhs, cr_rhs = cr_rhs, cr_lhs
+        lhs_merged = self.with_offset_to_match(
+            cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
+        if lhs_loc is not None:
+            lhs_merged = lhs_merged.with_loc(lhs_loc)
+        rhs_merged = rhs.with_offset_to_match(
+            cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
+        if rhs_loc is not None:
+            rhs_merged = rhs_merged.with_loc(rhs_loc)
+        return lhs_merged.merged(rhs_merged).normalized()
+
 
 @final
 class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
@@ -350,7 +383,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                 added += 1
             retval = IGNode(
                 node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val,
-                edges={}, loc=None, ignored=False)
+                edges={}, loc=merged_ssa_val.only_loc, ignored=False)
             self.__map[merged_ssa_val] = retval
             self.__next_node_id += 1
             added = None
@@ -367,7 +400,6 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         # type: (MergedSSAVal) -> IGNode
         source_nodes = OSet()  # type: OSet[IGNode]
         edges = {}  # type: dict[IGNode, IGEdge]
-        loc = None  # type: Loc | None
         for ssa_val in final_merged_ssa_val.ssa_vals:
             merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
             source_node = self.__map[merged_ssa_val]
@@ -380,11 +412,11 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
                     f"but not in merged IGNode's merged_ssa_val: "
                     f"source_node={source_node} "
                     f"final_merged_ssa_val={final_merged_ssa_val}")
-            if loc is None:
-                loc = source_node.loc
-            elif source_node.loc is not None and loc != source_node.loc:
-                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
-                                 f"values: {loc} != {source_node.loc}")
+            if source_node.loc != source_node.merged_ssa_val.only_loc:
+                raise ValueError(
+                    f"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
+                    f"{source_node.loc} != "
+                    f"{source_node.merged_ssa_val.only_loc}")
             for n, edge in source_node.edges.items():
                 if n in edges:
                     edge = edge.merged(edges[n])
@@ -394,6 +426,18 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
         # we're finished checking validity, now we can modify stuff
         for n in source_nodes:
             edges.pop(n, None)
+        loc = final_merged_ssa_val.only_loc
+        for n, edge in edges.items():
+            if edge.copy_relation is None or not edge.interferes:
+                continue
+            try:
+                # if merging works, then the edge can't interfere
+                _ = final_merged_ssa_val.copy_merged(
+                    lhs_loc=loc, rhs=n.merged_ssa_val, rhs_loc=n.loc,
+                    copy_relation=edge.copy_relation)
+            except BadMergedSSAVal:
+                continue
+            edges[n] = replace(edge, interferes=False)
         retval = IGNode(
             node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val,
             edges=edges, loc=loc, ignored=False)
@@ -440,7 +484,7 @@ class InterferenceGraph:
         merged2 = self.merged_ssa_val_map[ssa_val2]
         merged = merged1.with_offset_to_match(ssa_val1)
         return merged.merged(merged2.with_offset_to_match(
-            ssa_val2, additional_offset=additional_offset))
+            ssa_val2, additional_offset=additional_offset)).normalized()
 
     def merge(self, ssa_val1, ssa_val2, additional_offset=0):
         # type: (SSAVal, SSAVal, int) -> IGNode
@@ -448,27 +492,9 @@ class InterferenceGraph:
             ssa_val1=ssa_val1, ssa_val2=ssa_val2,
             additional_offset=additional_offset))
 
-    def copy_merge_preview(self, node1, node2):
-        # type: (IGNode, IGNode) -> MergedSSAVal
-        try:
-            copy_relation = node1.edges[node2].copy_relation
-        except KeyError:
-            raise ValueError("nodes aren't copy related")
-        if copy_relation is None:
-            raise ValueError("nodes aren't copy related")
-        lhs, rhs = copy_relation
-        if lhs.ssa_val not in node1.merged_ssa_val.ssa_vals:
-            lhs, rhs = rhs, lhs
-        lhs_merged = node1.merged_ssa_val.with_offset_to_match(
-            lhs.ssa_val, additional_offset=-lhs.reg_idx)
-        rhs_merged = node2.merged_ssa_val.with_offset_to_match(
-            rhs.ssa_val, additional_offset=-rhs.reg_idx)
-        return lhs_merged.merged(rhs_merged)
-
     def copy_merge(self, node1, node2):
         # type: (IGNode, IGNode) -> IGNode
-        return self.nodes.merge_into_one_node(self.copy_merge_preview(
-            node1=node1, node2=node2))
+        return self.nodes.merge_into_one_node(node1.copy_merge_preview(node2))
 
     def local_colorability_score(self, node, merged_in_copy=None):
         # type: (IGNode, None | IGNode) -> int
@@ -485,7 +511,10 @@ class InterferenceGraph:
         loc_set = node.loc_set
         edges = node.edges
         if merged_in_copy is not None:
-            loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set
+            if merged_in_copy.ignored:
+                raise ValueError(
+                    "can't get local_colorability_score of ignored node")
+            loc_set = node.copy_merge_preview(merged_in_copy).loc_set
             edges = edges.copy()
             for neighbor, edge in merged_in_copy.edges.items():
                 edges[neighbor] = edge.merged(edges.get(neighbor))
@@ -665,12 +694,12 @@ class IGNode:
     def __eq__(self, other):
         # type: (object) -> bool
         if isinstance(other, IGNode):
-            return self.merged_ssa_val == other.merged_ssa_val
+            return self.node_id == other.node_id
         return NotImplemented
 
     def __hash__(self):
         # type: () -> int
-        return hash(self.merged_ssa_val)
+        return hash(self.node_id)
 
     def __repr__(self, repr_state=None, short=False):
         # type: (None | IGNodeReprState, bool) -> str
@@ -703,6 +732,19 @@ class IGNode:
                 return True
         return False
 
+    def copy_merge_preview(self, rhs_node):
+        # type: (IGNode) -> MergedSSAVal
+        try:
+            copy_relation = self.edges[rhs_node].copy_relation
+        except KeyError:
+            raise ValueError("nodes aren't copy related")
+        if copy_relation is None:
+            raise ValueError("nodes aren't copy related")
+        return self.merged_ssa_val.copy_merged(
+            lhs_loc=self.loc,
+            rhs=rhs_node.merged_ssa_val, rhs_loc=rhs_node.loc,
+            copy_relation=copy_relation)
+
 
 class AllocationFailedError(Exception):
     def __init__(self, msg, node, interference_graph):
@@ -748,24 +790,32 @@ def allocate_registers(
         print(f"After InterferenceGraph.minimally_merged():\n"
               f"{interference_graph}", file=debug_out, flush=True)
 
+    for i, j in combinations(interference_graph.nodes.values(), 2):
+        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
+        i.merge_edge(j, IGEdge(copy_relation=copy_relation))
+
     for pp, ssa_vals in fn_analysis.live_at.items():
         live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
         for ssa_val in ssa_vals:
             live_merged_ssa_vals.add(
                 interference_graph.merged_ssa_val_map[ssa_val])
         for i, j in combinations(live_merged_ssa_vals, 2):
-            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
-                interference_graph.nodes[i].merge_edge(
-                    interference_graph.nodes[j],
-                    edge=IGEdge(interferes=True))
+            if i.loc_set.max_conflicts_with(j.loc_set) == 0:
+                continue
+            node_i = interference_graph.nodes[i]
+            node_j = interference_graph.nodes[j]
+            if node_j in node_i.edges:
+                if node_i.edges[node_j].copy_relation is not None:
+                    try:
+                        _ = node_i.copy_merge_preview(node_j)
+                        continue  # doesn't interfere if copy merging succeeds
+                    except BadMergedSSAVal:
+                        pass
+            node_i.merge_edge(node_j, edge=IGEdge(interferes=True))
         if debug_out is not None:
             print(f"processed {pp} out of {fn_analysis.all_program_points}",
                   file=debug_out, flush=True)
 
-    for i, j in combinations(interference_graph.nodes.values(), 2):
-        copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
-        i.merge_edge(j, IGEdge(copy_relation=copy_relation))
-
     if debug_out is not None:
         print(f"After adding interference graph edges:\n"
               f"{interference_graph}", file=debug_out, flush=True)
@@ -900,10 +950,37 @@ def allocate_registers(
                     "IGNode is pre-allocated to a conflicting Loc",
                     node=node, interference_graph=interference_graph)
         else:
-            # pick the first non-conflicting register in node.reg_class, since
-            # register classes are ordered from most preferred to least
-            # preferred register.
+            # Locs to try allocating, ordered from most preferred to least
+            # preferred
+            locs = OSet()
+            # prefer eliminating copies
+            for neighbor, edge in node.edges.items():
+                if neighbor.loc is None or edge.copy_relation is None:
+                    continue
+                try:
+                    merged = node.copy_merge_preview(neighbor)
+                except BadMergedSSAVal:
+                    continue
+                # get merged_loc if merged.loc_set has a single Loc
+                merged_loc = merged.only_loc
+                if merged_loc is None:
+                    continue
+                ssa_val = node.merged_ssa_val.first_ssa_val
+                ssa_val_loc = merged_loc.get_subloc_at_offset(
+                    subloc_ty=ssa_val.ty,
+                    offset=merged.ssa_val_offsets[ssa_val])
+                node_loc = ssa_val_loc.get_superloc_with_self_at_offset(
+                    superloc_ty=node.merged_ssa_val.ty,
+                    offset=node.merged_ssa_val.ssa_val_offsets[ssa_val])
+                assert node_loc in node.merged_ssa_val.loc_set, "logic error"
+                locs.add(node_loc)
+            # add node's allowed Locs as fallback
             for loc in node.loc_set:
+                # TODO: add in order of preference
+                locs.add(loc)
+            # pick the first non-conflicting register in locs, since locs is
+            # ordered from most preferred to least preferred register.
+            for loc in locs:
                 if not node.loc_conflicts_with_neighbors(loc):
                     node.loc = loc
                     break
diff --git a/src/bigint_presentation_code/register_allocator_test_util.py b/src/bigint_presentation_code/register_allocator_test_util.py
new file mode 100644 (file)
index 0000000..e0b475b
--- /dev/null
@@ -0,0 +1,21 @@
+import unittest
+import shutil
+
+from nmutil.get_test_path import get_test_path
+from bigint_presentation_code.type_util import final
+
+
+@final
+class GraphDumper:
+    def __init__(self, test_case):
+        # type: (unittest.TestCase) -> None
+        self.test_case = test_case
+        self.base_path = get_test_path(test_case, "dumped_graphs")
+        shutil.rmtree(self.base_path, ignore_errors=True)
+        self.base_path.mkdir(parents=True, exist_ok=True)
+
+    def __call__(self, name, dot):
+        # type: (str, str) -> None
+        path = self.base_path / name
+        dot_path = path.with_suffix(".dot")
+        dot_path.write_text(dot)