From 74a864a3ef7f009e322a965353be511ad9f031b0 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 9 Dec 2022 00:06:41 -0800 Subject: [PATCH] copy-merging works afaict! -- some tests still broken: out-of-date tests not yet all updated for new register allocator results --- .../_tests/test_register_allocator.py | 52 +++--- .../_tests/test_toom_cook.py | 147 +++++---------- src/bigint_presentation_code/compiler_ir.py | 26 +++ .../register_allocator.py | 171 +++++++++++++----- .../register_allocator_test_util.py | 21 +++ 5 files changed, 235 insertions(+), 182 deletions(-) create mode 100644 src/bigint_presentation_code/register_allocator_test_util.py diff --git a/src/bigint_presentation_code/_tests/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py index d0d9068..be76d2d 100644 --- a/src/bigint_presentation_code/_tests/test_register_allocator.py +++ b/src/bigint_presentation_code/_tests/test_register_allocator.py @@ -1,22 +1,10 @@ import sys import unittest -import shutil from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind, SSAVal) from bigint_presentation_code.register_allocator import allocate_registers -from nmutil.get_test_path import get_test_path - - -def dump_graphs(test_case, graphs): - # type: (unittest.TestCase, dict[str, str]) -> None - base_path = get_test_path(test_case, "dumped_graphs") - shutil.rmtree(base_path, ignore_errors=True) - base_path.mkdir(parents=True, exist_ok=True) - for name, dot in graphs.items(): - path = base_path / name - dot_path = path.with_suffix(".dot") - dot_path.write_text(dot) +from bigint_presentation_code.register_allocator_test_util import GraphDumper class TestRegisterAllocator(unittest.TestCase): @@ -48,7 +36,8 @@ class TestRegisterAllocator(unittest.TestCase): def test_register_allocate(self): fn, _arg = self.make_add_fn() - reg_assignments = allocate_registers(fn, debug_out=sys.stdout) + reg_assignments = allocate_registers( + fn, debug_out=sys.stdout, dump_graph=GraphDumper(self)) self.assertEqual( repr(reg_assignments), @@ -108,7 +97,8 @@ class TestRegisterAllocator(unittest.TestCase): def test_gen_asm(self): fn, _arg = self.make_add_fn() - reg_assignments = allocate_registers(fn) + reg_assignments = allocate_registers( + fn, debug_out=sys.stdout, dump_graph=GraphDumper(self)) self.assertEqual( repr(reg_assignments), @@ -131,14 +121,14 @@ class TestRegisterAllocator(unittest.TestCase): "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " @@ -170,9 +160,7 @@ class TestRegisterAllocator(unittest.TestCase): state = GenAsmState(reg_assignments) fn.gen_asm(state) self.assertEqual(state.output, [ - 'or 4, 3, 3', 'setvl 0, 0, 32, 0, 1, 1', - 'or 3, 4, 4', 'setvl 0, 0, 32, 0, 1, 1', 'sv.ld *14, 0(3)', 'setvl 0, 0, 32, 0, 1, 1', @@ -187,22 +175,23 @@ class TestRegisterAllocator(unittest.TestCase): 'sv.adde *14, *78, *46', 'setvl 0, 0, 32, 0, 1, 1', 'setvl 0, 0, 32, 0, 1, 1', - 'or 3, 4, 4', 'setvl 0, 0, 32, 0, 1, 1', - 'sv.std *14, 0(3)', + 'sv.std *14, 0(3)' ]) def test_register_allocate_graphs(self): fn, _arg = self.make_add_fn() graphs = {} # type: dict[str, str] + graph_dumper = GraphDumper(self) + def dump_graph(name, dot): # type: (str, str) -> None self.assertNotIn(name, graphs, "duplicate graph name") graphs[name] = dot - allocated = allocate_registers(fn, dump_graph=dump_graph, - debug_out=sys.stdout) - dump_graphs(self, graphs) + graph_dumper(name, dot) + allocated = allocate_registers( + fn, debug_out=sys.stdout, dump_graph=dump_graph) self.assertEqual( repr(allocated), "{" @@ -224,14 +213,14 @@ class TestRegisterAllocator(unittest.TestCase): "Loc(kind=LocKind.GPR, start=14, reg_len=32), " ">: " "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " ">: " "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " ">: " "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " ">: " @@ -338,7 +327,8 @@ class TestRegisterAllocator(unittest.TestCase): _concat = fn.append_new_op( OpKind.Concat, input_vals=[*spread[::-1], vl], name="concat", maxvl=maxvl) - reg_assignments = allocate_registers(fn, debug_out=sys.stdout) + reg_assignments = allocate_registers( + fn, debug_out=sys.stdout, dump_graph=GraphDumper(self)) self.assertEqual( repr(reg_assignments), diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index eadc080..dc3bb8d 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -1,4 +1,5 @@ from contextlib import contextmanager +import sys import unittest from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable @@ -9,6 +10,7 @@ from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BITS, PostRASimState, PreRASimState, SSAVal) from bigint_presentation_code.register_allocator import allocate_registers +from bigint_presentation_code.register_allocator_test_util import GraphDumper from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul, simple_mul) from bigint_presentation_code.util import OSet @@ -77,21 +79,21 @@ class Mul: name="store_dest") -def get_post_ra_state_factory(code): - # type: (Mul) -> _StateFactory - ssa_val_to_loc_map = allocate_registers(code.fn) - - @contextmanager - def state_factory(): - yield PostRASimState( - ssa_val_to_loc_map=ssa_val_to_loc_map, - memory={}, loc_values={}) - return state_factory - - class TestToomCook(unittest.TestCase): maxDiff = None + def get_post_ra_state_factory(self, code): + # type: (Mul) -> _StateFactory + ssa_val_to_loc_map = allocate_registers( + code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self)) + + @contextmanager + def state_factory(): + yield PostRASimState( + ssa_val_to_loc_map=ssa_val_to_loc_map, + memory={}, loc_values={}) + return state_factory + def test_toom_2_repr(self): TOOM_2 = ToomCookInstance.make_toom_2() # print(repr(repr(TOOM_2))) @@ -269,7 +271,7 @@ class TestToomCook(unittest.TestCase): for rhs_signed in False, True: self.tst_simple_mul_192x192_sim( lhs_signed=lhs_signed, rhs_signed=rhs_signed, - get_state_factory=get_post_ra_state_factory) + get_state_factory=self.get_post_ra_state_factory) def tst_simple_mul_192x192_sim( self, lhs_signed, # type: bool @@ -485,7 +487,8 @@ class TestToomCook(unittest.TestCase): def test_simple_mul_192x192_reg_alloc(self): code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn - assigned_registers = allocate_registers(fn) + assigned_registers = allocate_registers( + fn, debug_out=sys.stdout, dump_graph=GraphDumper(self)) self.assertEqual( repr(assigned_registers), "{" ">: " @@ -865,150 +868,85 @@ class TestToomCook(unittest.TestCase): def test_simple_mul_192x192_asm(self): code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3) fn = code.fn - assigned_registers = allocate_registers(fn) + assigned_registers = allocate_registers( + fn, debug_out=sys.stdout, dump_graph=GraphDumper(self)) gen_asm_state = GenAsmState(assigned_registers) fn.gen_asm(gen_asm_state) self.assertEqual(gen_asm_state.output, [ - 'or 27, 3, 3', + 'or 9, 3, 3', 'setvl 0, 0, 3, 0, 1, 1', - 'or 6, 27, 27', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.ld *3, 48(6)', + 'sv.ld *20, 48(9)', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *24, *3, *3', 'setvl 0, 0, 3, 0, 1, 1', - 'or 6, 27, 27', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.ld *3, 72(6)', + 'sv.ld *10, 72(9)', 'setvl 0, 0, 3, 0, 1, 1', 'setvl 0, 0, 3, 0, 1, 1', 'setvl 0, 0, 3, 0, 1, 1', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or/mrr *5, *3, *3', - 'or 4, 5, 5', - 'or 14, 6, 6', - 'or 23, 7, 7', - 'addi 3, 0, 0', - 'or 22, 3, 3', + 'addi 6, 0, 0', + 'or 8, 6, 6', 'setvl 0, 0, 3, 0, 1, 1', 'addi 3, 0, 0', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *8, *24, *24', - 'or 7, 4, 4', - 'or 6, 22, 22', + 'or 6, 8, 8', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.maddedu *3, *8, 7, 6', + 'sv.maddedu *14, *20, 10, 6', 'setvl 0, 0, 3, 0, 1, 1', - 'or 19, 6, 6', + 'or 17, 6, 6', 'setvl 0, 0, 3, 0, 1, 1', 'setvl 0, 0, 3, 0, 1, 1', - 'or 21, 3, 3', - 'or 12, 4, 4', - 'or 11, 5, 5', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *8, *24, *24', - 'or 7, 14, 14', - 'or 6, 22, 22', + 'or 6, 8, 8', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.maddedu *3, *8, 7, 6', + 'sv.maddedu *3, *20, 11, 6', 'setvl 0, 0, 3, 0, 1, 1', - 'or 18, 6, 6', 'setvl 0, 0, 3, 0, 1, 1', 'setvl 0, 0, 3, 0, 1, 1', - 'or 17, 3, 3', - 'or 16, 4, 4', - 'or 15, 5, 5', - 'addi 3, 0, 0', - 'or 8, 3, 3', - 'addi 3, 0, 0', - 'or 14, 3, 3', + 'addi 18, 0, 0', + 'addi 7, 0, 0', 'setvl 0, 0, 5, 0, 1, 1', - 'or 3, 12, 12', - 'or 4, 11, 11', - 'or 5, 19, 19', - 'or 6, 8, 8', - 'or 7, 8, 8', + 'or 19, 18, 18', 'setvl 0, 0, 5, 0, 1, 1', 'setvl 0, 0, 5, 0, 1, 1', - 'sv.or *8, *3, *3', - 'or 3, 17, 17', - 'or 4, 16, 16', - 'or 5, 15, 15', - 'or 6, 18, 18', - 'or 7, 14, 14', 'setvl 0, 0, 5, 0, 1, 1', 'setvl 0, 0, 5, 0, 1, 1', 'addic 0, 0, 0', 'setvl 0, 0, 5, 0, 1, 1', - 'sv.or *14, *8, *8', 'setvl 0, 0, 5, 0, 1, 1', - 'sv.or *8, *3, *3', 'setvl 0, 0, 5, 0, 1, 1', - 'sv.adde *3, *14, *8', + 'sv.adde *15, *15, *3', 'setvl 0, 0, 5, 0, 1, 1', 'setvl 0, 0, 5, 0, 1, 1', 'setvl 0, 0, 5, 0, 1, 1', - 'or 20, 3, 3', - 'or 19, 4, 4', - 'or 18, 5, 5', - 'or 17, 6, 6', - 'or 16, 7, 7', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.or *8, *24, *24', - 'or 7, 23, 23', - 'or 6, 22, 22', + 'or 6, 8, 8', 'setvl 0, 0, 3, 0, 1, 1', - 'sv.maddedu *3, *8, 7, 6', + 'sv.maddedu *3, *20, 12, 6', 'setvl 0, 0, 3, 0, 1, 1', - 'or 15, 6, 6', 'setvl 0, 0, 3, 0, 1, 1', 'setvl 0, 0, 3, 0, 1, 1', - 'or 14, 3, 3', - 'or 12, 4, 4', - 'or 11, 5, 5', 'setvl 0, 0, 4, 0, 1, 1', - 'or 3, 19, 19', - 'or 4, 18, 18', - 'or 5, 17, 17', - 'or 6, 16, 16', 'setvl 0, 0, 4, 0, 1, 1', 'setvl 0, 0, 4, 0, 1, 1', - 'sv.or *7, *3, *3', - 'or 3, 14, 14', - 'or 4, 12, 12', - 'or 5, 11, 11', - 'or 6, 15, 15', 'setvl 0, 0, 4, 0, 1, 1', 'setvl 0, 0, 4, 0, 1, 1', 'addic 0, 0, 0', 'setvl 0, 0, 4, 0, 1, 1', - 'sv.or *14, *7, *7', 'setvl 0, 0, 4, 0, 1, 1', - 'sv.or *7, *3, *3', 'setvl 0, 0, 4, 0, 1, 1', - 'sv.adde *3, *14, *7', + 'sv.adde *16, *16, *3', 'setvl 0, 0, 4, 0, 1, 1', 'setvl 0, 0, 4, 0, 1, 1', 'setvl 0, 0, 4, 0, 1, 1', - 'or 12, 3, 3', - 'or 11, 4, 4', - 'or 10, 5, 5', - 'or 9, 6, 6', 'setvl 0, 0, 6, 0, 1, 1', - 'or 3, 21, 21', - 'or 4, 20, 20', - 'or 5, 12, 12', - 'or 6, 11, 11', - 'or 7, 10, 10', - 'or 8, 9, 9', 'setvl 0, 0, 6, 0, 1, 1', 'setvl 0, 0, 6, 0, 1, 1', 'setvl 0, 0, 6, 0, 1, 1', 'setvl 0, 0, 6, 0, 1, 1', - 'sv.or/mrr *4, *3, *3', - 'or 3, 27, 27', 'setvl 0, 0, 6, 0, 1, 1', - 'sv.std *4, 0(3)' + 'sv.std *14, 0(9)', ]) def toom_2_mul_256x256(self, lhs_signed, rhs_signed): @@ -1122,27 +1060,28 @@ class TestToomCook(unittest.TestCase): def test_toom_2_mul_256x256_uu_post_ra_sim(self): self.tst_toom_2_mul_256x256_sim( lhs_signed=False, rhs_signed=False, - get_state_factory=get_post_ra_state_factory) + get_state_factory=self.get_post_ra_state_factory) def test_toom_2_mul_256x256_su_post_ra_sim(self): self.tst_toom_2_mul_256x256_sim( lhs_signed=True, rhs_signed=False, - get_state_factory=get_post_ra_state_factory) + get_state_factory=self.get_post_ra_state_factory) def test_toom_2_mul_256x256_us_post_ra_sim(self): self.tst_toom_2_mul_256x256_sim( lhs_signed=False, rhs_signed=True, - get_state_factory=get_post_ra_state_factory) + get_state_factory=self.get_post_ra_state_factory) def test_toom_2_mul_256x256_ss_post_ra_sim(self): self.tst_toom_2_mul_256x256_sim( lhs_signed=True, rhs_signed=True, - get_state_factory=get_post_ra_state_factory) + get_state_factory=self.get_post_ra_state_factory) def test_toom_2_mul_256x256_asm(self): code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False) fn = code.fn - assigned_registers = allocate_registers(fn) + assigned_registers = allocate_registers( + fn, debug_out=sys.stdout, dump_graph=GraphDumper(self)) gen_asm_state = GenAsmState(assigned_registers) fn.gen_asm(gen_asm_state) self.assertEqual(gen_asm_state.output, [ diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 699102b..4f8fb79 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -787,6 +787,20 @@ class Loc(metaclass=InternedMeta): return Loc(kind=self.kind, start=self.start + offset, reg_len=subloc_ty.reg_len) + def get_superloc_with_self_at_offset(self, superloc_ty, offset): + # type: (Ty, int) -> Loc + """get the Loc containing `self` such that: + `retval.get_subloc_at_offset(self.ty, offset) == self` + and `retval.ty == superloc_ty` + """ + if superloc_ty.base_ty != self.kind.base_ty: + raise ValueError("BaseTy mismatch") + if offset < 0 or offset + self.reg_len > superloc_ty.reg_len: + raise ValueError("invalid sub-Loc: offset and/or " + "superloc_ty.reg_len out of range") + return Loc(kind=self.kind, + start=self.start - offset, reg_len=superloc_ty.reg_len) + SPECIAL_GPRS = ( Loc(kind=LocKind.GPR, start=0, reg_len=1), @@ -900,6 +914,18 @@ class LocSet(OFSet[Loc], metaclass=InternedMeta): def __repr__(self): return f"LocSet(starts={self.starts!r}, ty={self.ty!r})" + @cached_property + def only_loc(self): + # type: () -> Loc | None + """if len(self) == 1 then return the Loc in self, otherwise None""" + only_loc = None + for i in self: + if only_loc is None: + only_loc = i + else: + return None # len(self) > 1 + return only_loc + @plain_data(frozen=True, unsafe_hash=True) @final diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index 540af0a..76fc966 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -6,11 +6,11 @@ this uses an algorithm based on: """ from functools import reduce -from itertools import chain, combinations, count +from itertools import combinations, count from typing import Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple from cached_property import cached_property -from nmutil.plain_data import plain_data +from nmutil.plain_data import plain_data, replace from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc, LocSet, Op, ProgramRange, @@ -59,8 +59,8 @@ class MergedSSAVal(metaclass=InternedMeta): __slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set", "first_loc") - def __init__(self, fn_analysis, ssa_val_offsets): - # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None + def __init__(self, fn_analysis, ssa_val_offsets, loc_set=None): + # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None self.fn_analysis = fn_analysis if isinstance(ssa_val_offsets, SSAVal): ssa_val_offsets = {ssa_val_offsets: 0} @@ -74,7 +74,10 @@ class MergedSSAVal(metaclass=InternedMeta): self.first_ssa_val = first_ssa_val # type: SSAVal # self.ty checks for mismatched base_ty reg_len = self.ty.reg_len - loc_set = None # type: None | LocSet + if loc_set is not None and loc_set.ty != self.ty: + raise ValueError( + f"invalid loc_set, type doesn't match: " + f"{loc_set.ty} != {self.ty}") for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items(): def locs(): # type: () -> Iterable[Loc] @@ -163,12 +166,17 @@ class MergedSSAVal(metaclass=InternedMeta): @cached_property def __hash(self): # type: () -> int - return hash((self.fn_analysis, self.ssa_val_offsets)) + return hash((self.fn_analysis, self.ssa_val_offsets, self.loc_set)) def __hash__(self): # type: () -> int return self.__hash + @property + def only_loc(self): + # type: () -> Loc | None + return self.loc_set.only_loc + @cached_property def offset(self): # type: () -> int @@ -226,6 +234,16 @@ class MergedSSAVal(metaclass=InternedMeta): ssa_val_offsets[ssa_val] + additional_offset - offset) raise ValueError("can't change offset to match unrelated MergedSSAVal") + def with_loc(self, loc): + # type: (Loc) -> MergedSSAVal + if loc not in self.loc_set: + raise ValueError( + f"Loc is not allowed -- not a member of `self.loc_set`: " + f"{loc} not in {self.loc_set}") + return MergedSSAVal(fn_analysis=self.fn_analysis, + ssa_val_offsets=self.ssa_val_offsets, + loc_set=LocSet([loc])) + def merged(self, *others): # type: (*MergedSSAVal) -> MergedSSAVal retval = dict(self.ssa_val_offsets) @@ -278,6 +296,21 @@ class MergedSSAVal(metaclass=InternedMeta): return lhs, rhs return None + def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation): + # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal + cr_lhs, cr_rhs = copy_relation + if cr_lhs.ssa_val not in self.ssa_vals: + cr_lhs, cr_rhs = cr_rhs, cr_lhs + lhs_merged = self.with_offset_to_match( + cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx) + if lhs_loc is not None: + lhs_merged = lhs_merged.with_loc(lhs_loc) + rhs_merged = rhs.with_offset_to_match( + cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx) + if rhs_loc is not None: + rhs_merged = rhs_merged.with_loc(rhs_loc) + return lhs_merged.merged(rhs_merged).normalized() + @final class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]): @@ -350,7 +383,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): added += 1 retval = IGNode( node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val, - edges={}, loc=None, ignored=False) + edges={}, loc=merged_ssa_val.only_loc, ignored=False) self.__map[merged_ssa_val] = retval self.__next_node_id += 1 added = None @@ -367,7 +400,6 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): # type: (MergedSSAVal) -> IGNode source_nodes = OSet() # type: OSet[IGNode] edges = {} # type: dict[IGNode, IGEdge] - loc = None # type: Loc | None for ssa_val in final_merged_ssa_val.ssa_vals: merged_ssa_val = self.__merged_ssa_val_map[ssa_val] source_node = self.__map[merged_ssa_val] @@ -380,11 +412,11 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): f"but not in merged IGNode's merged_ssa_val: " f"source_node={source_node} " f"final_merged_ssa_val={final_merged_ssa_val}") - if loc is None: - loc = source_node.loc - elif source_node.loc is not None and loc != source_node.loc: - raise ValueError(f"can't merge IGNodes with mismatched `loc` " - f"values: {loc} != {source_node.loc}") + if source_node.loc != source_node.merged_ssa_val.only_loc: + raise ValueError( + f"can't merge IGNodes: loc != merged_ssa_val.only_loc: " + f"{source_node.loc} != " + f"{source_node.merged_ssa_val.only_loc}") for n, edge in source_node.edges.items(): if n in edges: edge = edge.merged(edges[n]) @@ -394,6 +426,18 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): # we're finished checking validity, now we can modify stuff for n in source_nodes: edges.pop(n, None) + loc = final_merged_ssa_val.only_loc + for n, edge in edges.items(): + if edge.copy_relation is None or not edge.interferes: + continue + try: + # if merging works, then the edge can't interfere + _ = final_merged_ssa_val.copy_merged( + lhs_loc=loc, rhs=n.merged_ssa_val, rhs_loc=n.loc, + copy_relation=edge.copy_relation) + except BadMergedSSAVal: + continue + edges[n] = replace(edge, interferes=False) retval = IGNode( node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val, edges=edges, loc=loc, ignored=False) @@ -440,7 +484,7 @@ class InterferenceGraph: merged2 = self.merged_ssa_val_map[ssa_val2] merged = merged1.with_offset_to_match(ssa_val1) return merged.merged(merged2.with_offset_to_match( - ssa_val2, additional_offset=additional_offset)) + ssa_val2, additional_offset=additional_offset)).normalized() def merge(self, ssa_val1, ssa_val2, additional_offset=0): # type: (SSAVal, SSAVal, int) -> IGNode @@ -448,27 +492,9 @@ class InterferenceGraph: ssa_val1=ssa_val1, ssa_val2=ssa_val2, additional_offset=additional_offset)) - def copy_merge_preview(self, node1, node2): - # type: (IGNode, IGNode) -> MergedSSAVal - try: - copy_relation = node1.edges[node2].copy_relation - except KeyError: - raise ValueError("nodes aren't copy related") - if copy_relation is None: - raise ValueError("nodes aren't copy related") - lhs, rhs = copy_relation - if lhs.ssa_val not in node1.merged_ssa_val.ssa_vals: - lhs, rhs = rhs, lhs - lhs_merged = node1.merged_ssa_val.with_offset_to_match( - lhs.ssa_val, additional_offset=-lhs.reg_idx) - rhs_merged = node2.merged_ssa_val.with_offset_to_match( - rhs.ssa_val, additional_offset=-rhs.reg_idx) - return lhs_merged.merged(rhs_merged) - def copy_merge(self, node1, node2): # type: (IGNode, IGNode) -> IGNode - return self.nodes.merge_into_one_node(self.copy_merge_preview( - node1=node1, node2=node2)) + return self.nodes.merge_into_one_node(node1.copy_merge_preview(node2)) def local_colorability_score(self, node, merged_in_copy=None): # type: (IGNode, None | IGNode) -> int @@ -485,7 +511,10 @@ class InterferenceGraph: loc_set = node.loc_set edges = node.edges if merged_in_copy is not None: - loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set + if merged_in_copy.ignored: + raise ValueError( + "can't get local_colorability_score of ignored node") + loc_set = node.copy_merge_preview(merged_in_copy).loc_set edges = edges.copy() for neighbor, edge in merged_in_copy.edges.items(): edges[neighbor] = edge.merged(edges.get(neighbor)) @@ -665,12 +694,12 @@ class IGNode: def __eq__(self, other): # type: (object) -> bool if isinstance(other, IGNode): - return self.merged_ssa_val == other.merged_ssa_val + return self.node_id == other.node_id return NotImplemented def __hash__(self): # type: () -> int - return hash(self.merged_ssa_val) + return hash(self.node_id) def __repr__(self, repr_state=None, short=False): # type: (None | IGNodeReprState, bool) -> str @@ -703,6 +732,19 @@ class IGNode: return True return False + def copy_merge_preview(self, rhs_node): + # type: (IGNode) -> MergedSSAVal + try: + copy_relation = self.edges[rhs_node].copy_relation + except KeyError: + raise ValueError("nodes aren't copy related") + if copy_relation is None: + raise ValueError("nodes aren't copy related") + return self.merged_ssa_val.copy_merged( + lhs_loc=self.loc, + rhs=rhs_node.merged_ssa_val, rhs_loc=rhs_node.loc, + copy_relation=copy_relation) + class AllocationFailedError(Exception): def __init__(self, msg, node, interference_graph): @@ -748,24 +790,32 @@ def allocate_registers( print(f"After InterferenceGraph.minimally_merged():\n" f"{interference_graph}", file=debug_out, flush=True) + for i, j in combinations(interference_graph.nodes.values(), 2): + copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val) + i.merge_edge(j, IGEdge(copy_relation=copy_relation)) + for pp, ssa_vals in fn_analysis.live_at.items(): live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal] for ssa_val in ssa_vals: live_merged_ssa_vals.add( interference_graph.merged_ssa_val_map[ssa_val]) for i, j in combinations(live_merged_ssa_vals, 2): - if i.loc_set.max_conflicts_with(j.loc_set) != 0: - interference_graph.nodes[i].merge_edge( - interference_graph.nodes[j], - edge=IGEdge(interferes=True)) + if i.loc_set.max_conflicts_with(j.loc_set) == 0: + continue + node_i = interference_graph.nodes[i] + node_j = interference_graph.nodes[j] + if node_j in node_i.edges: + if node_i.edges[node_j].copy_relation is not None: + try: + _ = node_i.copy_merge_preview(node_j) + continue # doesn't interfere if copy merging succeeds + except BadMergedSSAVal: + pass + node_i.merge_edge(node_j, edge=IGEdge(interferes=True)) if debug_out is not None: print(f"processed {pp} out of {fn_analysis.all_program_points}", file=debug_out, flush=True) - for i, j in combinations(interference_graph.nodes.values(), 2): - copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val) - i.merge_edge(j, IGEdge(copy_relation=copy_relation)) - if debug_out is not None: print(f"After adding interference graph edges:\n" f"{interference_graph}", file=debug_out, flush=True) @@ -900,10 +950,37 @@ def allocate_registers( "IGNode is pre-allocated to a conflicting Loc", node=node, interference_graph=interference_graph) else: - # pick the first non-conflicting register in node.reg_class, since - # register classes are ordered from most preferred to least - # preferred register. + # Locs to try allocating, ordered from most preferred to least + # preferred + locs = OSet() + # prefer eliminating copies + for neighbor, edge in node.edges.items(): + if neighbor.loc is None or edge.copy_relation is None: + continue + try: + merged = node.copy_merge_preview(neighbor) + except BadMergedSSAVal: + continue + # get merged_loc if merged.loc_set has a single Loc + merged_loc = merged.only_loc + if merged_loc is None: + continue + ssa_val = node.merged_ssa_val.first_ssa_val + ssa_val_loc = merged_loc.get_subloc_at_offset( + subloc_ty=ssa_val.ty, + offset=merged.ssa_val_offsets[ssa_val]) + node_loc = ssa_val_loc.get_superloc_with_self_at_offset( + superloc_ty=node.merged_ssa_val.ty, + offset=node.merged_ssa_val.ssa_val_offsets[ssa_val]) + assert node_loc in node.merged_ssa_val.loc_set, "logic error" + locs.add(node_loc) + # add node's allowed Locs as fallback for loc in node.loc_set: + # TODO: add in order of preference + locs.add(loc) + # pick the first non-conflicting register in locs, since locs is + # ordered from most preferred to least preferred register. + for loc in locs: if not node.loc_conflicts_with_neighbors(loc): node.loc = loc break diff --git a/src/bigint_presentation_code/register_allocator_test_util.py b/src/bigint_presentation_code/register_allocator_test_util.py new file mode 100644 index 0000000..e0b475b --- /dev/null +++ b/src/bigint_presentation_code/register_allocator_test_util.py @@ -0,0 +1,21 @@ +import unittest +import shutil + +from nmutil.get_test_path import get_test_path +from bigint_presentation_code.type_util import final + + +@final +class GraphDumper: + def __init__(self, test_case): + # type: (unittest.TestCase) -> None + self.test_case = test_case + self.base_path = get_test_path(test_case, "dumped_graphs") + shutil.rmtree(self.base_path, ignore_errors=True) + self.base_path.mkdir(parents=True, exist_ok=True) + + def __call__(self, name, dot): + # type: (str, str) -> None + path = self.base_path / name + dot_path = path.with_suffix(".dot") + dot_path.write_text(dot) -- 2.30.2