import sys
import unittest
-import shutil
from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind,
SSAVal)
from bigint_presentation_code.register_allocator import allocate_registers
-from nmutil.get_test_path import get_test_path
-
-
-def dump_graphs(test_case, graphs):
- # type: (unittest.TestCase, dict[str, str]) -> None
- base_path = get_test_path(test_case, "dumped_graphs")
- shutil.rmtree(base_path, ignore_errors=True)
- base_path.mkdir(parents=True, exist_ok=True)
- for name, dot in graphs.items():
- path = base_path / name
- dot_path = path.with_suffix(".dot")
- dot_path.write_text(dot)
+from bigint_presentation_code.register_allocator_test_util import GraphDumper
class TestRegisterAllocator(unittest.TestCase):
def test_register_allocate(self):
fn, _arg = self.make_add_fn()
- reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+ reg_assignments = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(reg_assignments),
def test_gen_asm(self):
fn, _arg = self.make_add_fn()
- reg_assignments = allocate_registers(fn)
+ reg_assignments = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(reg_assignments),
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
"<ld.outputs[0]: <I64*32>>: "
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<ld.inp0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<arg.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<st.inp1.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<arg.out0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
"<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
state = GenAsmState(reg_assignments)
fn.gen_asm(state)
self.assertEqual(state.output, [
- 'or 4, 3, 3',
'setvl 0, 0, 32, 0, 1, 1',
- 'or 3, 4, 4',
'setvl 0, 0, 32, 0, 1, 1',
'sv.ld *14, 0(3)',
'setvl 0, 0, 32, 0, 1, 1',
'sv.adde *14, *78, *46',
'setvl 0, 0, 32, 0, 1, 1',
'setvl 0, 0, 32, 0, 1, 1',
- 'or 3, 4, 4',
'setvl 0, 0, 32, 0, 1, 1',
- 'sv.std *14, 0(3)',
+ 'sv.std *14, 0(3)'
])
def test_register_allocate_graphs(self):
fn, _arg = self.make_add_fn()
graphs = {} # type: dict[str, str]
+ graph_dumper = GraphDumper(self)
+
def dump_graph(name, dot):
# type: (str, str) -> None
self.assertNotIn(name, graphs, "duplicate graph name")
graphs[name] = dot
- allocated = allocate_registers(fn, dump_graph=dump_graph,
- debug_out=sys.stdout)
- dump_graphs(self, graphs)
+ graph_dumper(name, dot)
+ allocated = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=dump_graph)
self.assertEqual(
repr(allocated),
"{"
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
"<ld.outputs[0]: <I64*32>>: "
"Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<ld.inp0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<arg.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
"<st.inp1.copy.outputs[0]: <I64>>: "
"Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<arg.out0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
"<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
"Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
"<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
_concat = fn.append_new_op(
OpKind.Concat, input_vals=[*spread[::-1], vl],
name="concat", maxvl=maxvl)
- reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+ reg_assignments = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(reg_assignments),
from contextlib import contextmanager
+import sys
import unittest
from typing import Any, Callable, ContextManager, Iterator, Tuple, Iterable
PostRASimState,
PreRASimState, SSAVal)
from bigint_presentation_code.register_allocator import allocate_registers
+from bigint_presentation_code.register_allocator_test_util import GraphDumper
from bigint_presentation_code.toom_cook import (ToomCookInstance, ToomCookMul,
simple_mul)
from bigint_presentation_code.util import OSet
name="store_dest")
-def get_post_ra_state_factory(code):
- # type: (Mul) -> _StateFactory
- ssa_val_to_loc_map = allocate_registers(code.fn)
-
- @contextmanager
- def state_factory():
- yield PostRASimState(
- ssa_val_to_loc_map=ssa_val_to_loc_map,
- memory={}, loc_values={})
- return state_factory
-
-
class TestToomCook(unittest.TestCase):
maxDiff = None
+ def get_post_ra_state_factory(self, code):
+ # type: (Mul) -> _StateFactory
+ ssa_val_to_loc_map = allocate_registers(
+ code.fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
+
+ @contextmanager
+ def state_factory():
+ yield PostRASimState(
+ ssa_val_to_loc_map=ssa_val_to_loc_map,
+ memory={}, loc_values={})
+ return state_factory
+
def test_toom_2_repr(self):
TOOM_2 = ToomCookInstance.make_toom_2()
# print(repr(repr(TOOM_2)))
for rhs_signed in False, True:
self.tst_simple_mul_192x192_sim(
lhs_signed=lhs_signed, rhs_signed=rhs_signed,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def tst_simple_mul_192x192_sim(
self, lhs_signed, # type: bool
def test_simple_mul_192x192_reg_alloc(self):
code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
- assigned_registers = allocate_registers(fn)
+ assigned_registers = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
self.assertEqual(
repr(assigned_registers), "{"
"<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
def test_simple_mul_192x192_asm(self):
code = Mul(mul=simple_umul, lhs_size_in_words=3, rhs_size_in_words=3)
fn = code.fn
- assigned_registers = allocate_registers(fn)
+ assigned_registers = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
gen_asm_state = GenAsmState(assigned_registers)
fn.gen_asm(gen_asm_state)
self.assertEqual(gen_asm_state.output, [
- 'or 27, 3, 3',
+ 'or 9, 3, 3',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 6, 27, 27',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.ld *3, 48(6)',
+ 'sv.ld *20, 48(9)',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *24, *3, *3',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 6, 27, 27',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.ld *3, 72(6)',
+ 'sv.ld *10, 72(9)',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or/mrr *5, *3, *3',
- 'or 4, 5, 5',
- 'or 14, 6, 6',
- 'or 23, 7, 7',
- 'addi 3, 0, 0',
- 'or 22, 3, 3',
+ 'addi 6, 0, 0',
+ 'or 8, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'addi 3, 0, 0',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *8, *24, *24',
- 'or 7, 4, 4',
- 'or 6, 22, 22',
+ 'or 6, 8, 8',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.maddedu *3, *8, 7, 6',
+ 'sv.maddedu *14, *20, 10, 6',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 19, 6, 6',
+ 'or 17, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 21, 3, 3',
- 'or 12, 4, 4',
- 'or 11, 5, 5',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *8, *24, *24',
- 'or 7, 14, 14',
- 'or 6, 22, 22',
+ 'or 6, 8, 8',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.maddedu *3, *8, 7, 6',
+ 'sv.maddedu *3, *20, 11, 6',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 18, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 17, 3, 3',
- 'or 16, 4, 4',
- 'or 15, 5, 5',
- 'addi 3, 0, 0',
- 'or 8, 3, 3',
- 'addi 3, 0, 0',
- 'or 14, 3, 3',
+ 'addi 18, 0, 0',
+ 'addi 7, 0, 0',
'setvl 0, 0, 5, 0, 1, 1',
- 'or 3, 12, 12',
- 'or 4, 11, 11',
- 'or 5, 19, 19',
- 'or 6, 8, 8',
- 'or 7, 8, 8',
+ 'or 19, 18, 18',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.or *8, *3, *3',
- 'or 3, 17, 17',
- 'or 4, 16, 16',
- 'or 5, 15, 15',
- 'or 6, 18, 18',
- 'or 7, 14, 14',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
'addic 0, 0, 0',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.or *14, *8, *8',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.or *8, *3, *3',
'setvl 0, 0, 5, 0, 1, 1',
- 'sv.adde *3, *14, *8',
+ 'sv.adde *15, *15, *3',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
'setvl 0, 0, 5, 0, 1, 1',
- 'or 20, 3, 3',
- 'or 19, 4, 4',
- 'or 18, 5, 5',
- 'or 17, 6, 6',
- 'or 16, 7, 7',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.or *8, *24, *24',
- 'or 7, 23, 23',
- 'or 6, 22, 22',
+ 'or 6, 8, 8',
'setvl 0, 0, 3, 0, 1, 1',
- 'sv.maddedu *3, *8, 7, 6',
+ 'sv.maddedu *3, *20, 12, 6',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 15, 6, 6',
'setvl 0, 0, 3, 0, 1, 1',
'setvl 0, 0, 3, 0, 1, 1',
- 'or 14, 3, 3',
- 'or 12, 4, 4',
- 'or 11, 5, 5',
'setvl 0, 0, 4, 0, 1, 1',
- 'or 3, 19, 19',
- 'or 4, 18, 18',
- 'or 5, 17, 17',
- 'or 6, 16, 16',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.or *7, *3, *3',
- 'or 3, 14, 14',
- 'or 4, 12, 12',
- 'or 5, 11, 11',
- 'or 6, 15, 15',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
'addic 0, 0, 0',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.or *14, *7, *7',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.or *7, *3, *3',
'setvl 0, 0, 4, 0, 1, 1',
- 'sv.adde *3, *14, *7',
+ 'sv.adde *16, *16, *3',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
'setvl 0, 0, 4, 0, 1, 1',
- 'or 12, 3, 3',
- 'or 11, 4, 4',
- 'or 10, 5, 5',
- 'or 9, 6, 6',
'setvl 0, 0, 6, 0, 1, 1',
- 'or 3, 21, 21',
- 'or 4, 20, 20',
- 'or 5, 12, 12',
- 'or 6, 11, 11',
- 'or 7, 10, 10',
- 'or 8, 9, 9',
'setvl 0, 0, 6, 0, 1, 1',
'setvl 0, 0, 6, 0, 1, 1',
'setvl 0, 0, 6, 0, 1, 1',
'setvl 0, 0, 6, 0, 1, 1',
- 'sv.or/mrr *4, *3, *3',
- 'or 3, 27, 27',
'setvl 0, 0, 6, 0, 1, 1',
- 'sv.std *4, 0(3)'
+ 'sv.std *14, 0(9)',
])
def toom_2_mul_256x256(self, lhs_signed, rhs_signed):
def test_toom_2_mul_256x256_uu_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=False, rhs_signed=False,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_su_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=True, rhs_signed=False,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_us_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=False, rhs_signed=True,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_ss_post_ra_sim(self):
self.tst_toom_2_mul_256x256_sim(
lhs_signed=True, rhs_signed=True,
- get_state_factory=get_post_ra_state_factory)
+ get_state_factory=self.get_post_ra_state_factory)
def test_toom_2_mul_256x256_asm(self):
code = self.toom_2_mul_256x256(lhs_signed=False, rhs_signed=False)
fn = code.fn
- assigned_registers = allocate_registers(fn)
+ assigned_registers = allocate_registers(
+ fn, debug_out=sys.stdout, dump_graph=GraphDumper(self))
gen_asm_state = GenAsmState(assigned_registers)
fn.gen_asm(gen_asm_state)
self.assertEqual(gen_asm_state.output, [
return Loc(kind=self.kind,
start=self.start + offset, reg_len=subloc_ty.reg_len)
+ def get_superloc_with_self_at_offset(self, superloc_ty, offset):
+ # type: (Ty, int) -> Loc
+ """get the Loc containing `self` such that:
+ `retval.get_subloc_at_offset(self.ty, offset) == self`
+ and `retval.ty == superloc_ty`
+ """
+ if superloc_ty.base_ty != self.kind.base_ty:
+ raise ValueError("BaseTy mismatch")
+ if offset < 0 or offset + self.reg_len > superloc_ty.reg_len:
+ raise ValueError("invalid sub-Loc: offset and/or "
+ "superloc_ty.reg_len out of range")
+ return Loc(kind=self.kind,
+ start=self.start - offset, reg_len=superloc_ty.reg_len)
+
SPECIAL_GPRS = (
Loc(kind=LocKind.GPR, start=0, reg_len=1),
def __repr__(self):
return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
+ @cached_property
+ def only_loc(self):
+ # type: () -> Loc | None
+ """if len(self) == 1 then return the Loc in self, otherwise None"""
+ only_loc = None
+ for i in self:
+ if only_loc is None:
+ only_loc = i
+ else:
+ return None # len(self) > 1
+ return only_loc
+
@plain_data(frozen=True, unsafe_hash=True)
@final
"""
from functools import reduce
-from itertools import chain, combinations, count
+from itertools import combinations, count
from typing import Callable, Container, Iterable, Iterator, Mapping, TextIO, Tuple
from cached_property import cached_property
-from nmutil.plain_data import plain_data
+from nmutil.plain_data import plain_data, replace
from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
LocSet, Op, ProgramRange,
__slots__ = ("fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set",
"first_loc")
- def __init__(self, fn_analysis, ssa_val_offsets):
- # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
+ def __init__(self, fn_analysis, ssa_val_offsets, loc_set=None):
+ # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal, LocSet | None) -> None
self.fn_analysis = fn_analysis
if isinstance(ssa_val_offsets, SSAVal):
ssa_val_offsets = {ssa_val_offsets: 0}
self.first_ssa_val = first_ssa_val # type: SSAVal
# self.ty checks for mismatched base_ty
reg_len = self.ty.reg_len
- loc_set = None # type: None | LocSet
+ if loc_set is not None and loc_set.ty != self.ty:
+ raise ValueError(
+ f"invalid loc_set, type doesn't match: "
+ f"{loc_set.ty} != {self.ty}")
for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
def locs():
# type: () -> Iterable[Loc]
@cached_property
def __hash(self):
# type: () -> int
- return hash((self.fn_analysis, self.ssa_val_offsets))
+ return hash((self.fn_analysis, self.ssa_val_offsets, self.loc_set))
def __hash__(self):
# type: () -> int
return self.__hash
+ @property
+ def only_loc(self):
+ # type: () -> Loc | None
+ return self.loc_set.only_loc
+
@cached_property
def offset(self):
# type: () -> int
ssa_val_offsets[ssa_val] + additional_offset - offset)
raise ValueError("can't change offset to match unrelated MergedSSAVal")
+ def with_loc(self, loc):
+ # type: (Loc) -> MergedSSAVal
+ if loc not in self.loc_set:
+ raise ValueError(
+ f"Loc is not allowed -- not a member of `self.loc_set`: "
+ f"{loc} not in {self.loc_set}")
+ return MergedSSAVal(fn_analysis=self.fn_analysis,
+ ssa_val_offsets=self.ssa_val_offsets,
+ loc_set=LocSet([loc]))
+
def merged(self, *others):
# type: (*MergedSSAVal) -> MergedSSAVal
retval = dict(self.ssa_val_offsets)
return lhs, rhs
return None
+ def copy_merged(self, lhs_loc, rhs, rhs_loc, copy_relation):
+ # type: (Loc | None, MergedSSAVal, Loc | None, _CopyRelation) -> MergedSSAVal
+ cr_lhs, cr_rhs = copy_relation
+ if cr_lhs.ssa_val not in self.ssa_vals:
+ cr_lhs, cr_rhs = cr_rhs, cr_lhs
+ lhs_merged = self.with_offset_to_match(
+ cr_lhs.ssa_val, additional_offset=-cr_lhs.reg_idx)
+ if lhs_loc is not None:
+ lhs_merged = lhs_merged.with_loc(lhs_loc)
+ rhs_merged = rhs.with_offset_to_match(
+ cr_rhs.ssa_val, additional_offset=-cr_rhs.reg_idx)
+ if rhs_loc is not None:
+ rhs_merged = rhs_merged.with_loc(rhs_loc)
+ return lhs_merged.merged(rhs_merged).normalized()
+
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
added += 1
retval = IGNode(
node_id=self.__next_node_id, merged_ssa_val=merged_ssa_val,
- edges={}, loc=None, ignored=False)
+ edges={}, loc=merged_ssa_val.only_loc, ignored=False)
self.__map[merged_ssa_val] = retval
self.__next_node_id += 1
added = None
# type: (MergedSSAVal) -> IGNode
source_nodes = OSet() # type: OSet[IGNode]
edges = {} # type: dict[IGNode, IGEdge]
- loc = None # type: Loc | None
for ssa_val in final_merged_ssa_val.ssa_vals:
merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
source_node = self.__map[merged_ssa_val]
f"but not in merged IGNode's merged_ssa_val: "
f"source_node={source_node} "
f"final_merged_ssa_val={final_merged_ssa_val}")
- if loc is None:
- loc = source_node.loc
- elif source_node.loc is not None and loc != source_node.loc:
- raise ValueError(f"can't merge IGNodes with mismatched `loc` "
- f"values: {loc} != {source_node.loc}")
+ if source_node.loc != source_node.merged_ssa_val.only_loc:
+ raise ValueError(
+ f"can't merge IGNodes: loc != merged_ssa_val.only_loc: "
+ f"{source_node.loc} != "
+ f"{source_node.merged_ssa_val.only_loc}")
for n, edge in source_node.edges.items():
if n in edges:
edge = edge.merged(edges[n])
# we're finished checking validity, now we can modify stuff
for n in source_nodes:
edges.pop(n, None)
+ loc = final_merged_ssa_val.only_loc
+ for n, edge in edges.items():
+ if edge.copy_relation is None or not edge.interferes:
+ continue
+ try:
+ # if merging works, then the edge can't interfere
+ _ = final_merged_ssa_val.copy_merged(
+ lhs_loc=loc, rhs=n.merged_ssa_val, rhs_loc=n.loc,
+ copy_relation=edge.copy_relation)
+ except BadMergedSSAVal:
+ continue
+ edges[n] = replace(edge, interferes=False)
retval = IGNode(
node_id=self.__next_node_id, merged_ssa_val=final_merged_ssa_val,
edges=edges, loc=loc, ignored=False)
merged2 = self.merged_ssa_val_map[ssa_val2]
merged = merged1.with_offset_to_match(ssa_val1)
return merged.merged(merged2.with_offset_to_match(
- ssa_val2, additional_offset=additional_offset))
+ ssa_val2, additional_offset=additional_offset)).normalized()
def merge(self, ssa_val1, ssa_val2, additional_offset=0):
# type: (SSAVal, SSAVal, int) -> IGNode
ssa_val1=ssa_val1, ssa_val2=ssa_val2,
additional_offset=additional_offset))
- def copy_merge_preview(self, node1, node2):
- # type: (IGNode, IGNode) -> MergedSSAVal
- try:
- copy_relation = node1.edges[node2].copy_relation
- except KeyError:
- raise ValueError("nodes aren't copy related")
- if copy_relation is None:
- raise ValueError("nodes aren't copy related")
- lhs, rhs = copy_relation
- if lhs.ssa_val not in node1.merged_ssa_val.ssa_vals:
- lhs, rhs = rhs, lhs
- lhs_merged = node1.merged_ssa_val.with_offset_to_match(
- lhs.ssa_val, additional_offset=-lhs.reg_idx)
- rhs_merged = node2.merged_ssa_val.with_offset_to_match(
- rhs.ssa_val, additional_offset=-rhs.reg_idx)
- return lhs_merged.merged(rhs_merged)
-
def copy_merge(self, node1, node2):
# type: (IGNode, IGNode) -> IGNode
- return self.nodes.merge_into_one_node(self.copy_merge_preview(
- node1=node1, node2=node2))
+ return self.nodes.merge_into_one_node(node1.copy_merge_preview(node2))
def local_colorability_score(self, node, merged_in_copy=None):
# type: (IGNode, None | IGNode) -> int
loc_set = node.loc_set
edges = node.edges
if merged_in_copy is not None:
- loc_set = self.copy_merge_preview(node, merged_in_copy).loc_set
+ if merged_in_copy.ignored:
+ raise ValueError(
+ "can't get local_colorability_score of ignored node")
+ loc_set = node.copy_merge_preview(merged_in_copy).loc_set
edges = edges.copy()
for neighbor, edge in merged_in_copy.edges.items():
edges[neighbor] = edge.merged(edges.get(neighbor))
def __eq__(self, other):
# type: (object) -> bool
if isinstance(other, IGNode):
- return self.merged_ssa_val == other.merged_ssa_val
+ return self.node_id == other.node_id
return NotImplemented
def __hash__(self):
# type: () -> int
- return hash(self.merged_ssa_val)
+ return hash(self.node_id)
def __repr__(self, repr_state=None, short=False):
# type: (None | IGNodeReprState, bool) -> str
return True
return False
+ def copy_merge_preview(self, rhs_node):
+ # type: (IGNode) -> MergedSSAVal
+ try:
+ copy_relation = self.edges[rhs_node].copy_relation
+ except KeyError:
+ raise ValueError("nodes aren't copy related")
+ if copy_relation is None:
+ raise ValueError("nodes aren't copy related")
+ return self.merged_ssa_val.copy_merged(
+ lhs_loc=self.loc,
+ rhs=rhs_node.merged_ssa_val, rhs_loc=rhs_node.loc,
+ copy_relation=copy_relation)
+
class AllocationFailedError(Exception):
def __init__(self, msg, node, interference_graph):
print(f"After InterferenceGraph.minimally_merged():\n"
f"{interference_graph}", file=debug_out, flush=True)
+ for i, j in combinations(interference_graph.nodes.values(), 2):
+ copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
+ i.merge_edge(j, IGEdge(copy_relation=copy_relation))
+
for pp, ssa_vals in fn_analysis.live_at.items():
live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal]
for ssa_val in ssa_vals:
live_merged_ssa_vals.add(
interference_graph.merged_ssa_val_map[ssa_val])
for i, j in combinations(live_merged_ssa_vals, 2):
- if i.loc_set.max_conflicts_with(j.loc_set) != 0:
- interference_graph.nodes[i].merge_edge(
- interference_graph.nodes[j],
- edge=IGEdge(interferes=True))
+ if i.loc_set.max_conflicts_with(j.loc_set) == 0:
+ continue
+ node_i = interference_graph.nodes[i]
+ node_j = interference_graph.nodes[j]
+ if node_j in node_i.edges:
+ if node_i.edges[node_j].copy_relation is not None:
+ try:
+ _ = node_i.copy_merge_preview(node_j)
+ continue # doesn't interfere if copy merging succeeds
+ except BadMergedSSAVal:
+ pass
+ node_i.merge_edge(node_j, edge=IGEdge(interferes=True))
if debug_out is not None:
print(f"processed {pp} out of {fn_analysis.all_program_points}",
file=debug_out, flush=True)
- for i, j in combinations(interference_graph.nodes.values(), 2):
- copy_relation = i.merged_ssa_val.get_copy_relation(j.merged_ssa_val)
- i.merge_edge(j, IGEdge(copy_relation=copy_relation))
-
if debug_out is not None:
print(f"After adding interference graph edges:\n"
f"{interference_graph}", file=debug_out, flush=True)
"IGNode is pre-allocated to a conflicting Loc",
node=node, interference_graph=interference_graph)
else:
- # pick the first non-conflicting register in node.reg_class, since
- # register classes are ordered from most preferred to least
- # preferred register.
+ # Locs to try allocating, ordered from most preferred to least
+ # preferred
+ locs = OSet()
+ # prefer eliminating copies
+ for neighbor, edge in node.edges.items():
+ if neighbor.loc is None or edge.copy_relation is None:
+ continue
+ try:
+ merged = node.copy_merge_preview(neighbor)
+ except BadMergedSSAVal:
+ continue
+ # get merged_loc if merged.loc_set has a single Loc
+ merged_loc = merged.only_loc
+ if merged_loc is None:
+ continue
+ ssa_val = node.merged_ssa_val.first_ssa_val
+ ssa_val_loc = merged_loc.get_subloc_at_offset(
+ subloc_ty=ssa_val.ty,
+ offset=merged.ssa_val_offsets[ssa_val])
+ node_loc = ssa_val_loc.get_superloc_with_self_at_offset(
+ superloc_ty=node.merged_ssa_val.ty,
+ offset=node.merged_ssa_val.ssa_val_offsets[ssa_val])
+ assert node_loc in node.merged_ssa_val.loc_set, "logic error"
+ locs.add(node_loc)
+ # add node's allowed Locs as fallback
for loc in node.loc_set:
+ # TODO: add in order of preference
+ locs.add(loc)
+ # pick the first non-conflicting register in locs, since locs is
+ # ordered from most preferred to least preferred register.
+ for loc in locs:
if not node.loc_conflicts_with_neighbors(loc):
node.loc = loc
break
--- /dev/null
+import unittest
+import shutil
+
+from nmutil.get_test_path import get_test_path
+from bigint_presentation_code.type_util import final
+
+
+@final
+class GraphDumper:
+ def __init__(self, test_case):
+ # type: (unittest.TestCase) -> None
+ self.test_case = test_case
+ self.base_path = get_test_path(test_case, "dumped_graphs")
+ shutil.rmtree(self.base_path, ignore_errors=True)
+ self.base_path.mkdir(parents=True, exist_ok=True)
+
+ def __call__(self, name, dot):
+ # type: (str, str) -> None
+ path = self.base_path / name
+ dot_path = path.with_suffix(".dot")
+ dot_path.write_text(dot)