From: Jacob Lifshay Date: Mon, 7 Nov 2022 02:07:12 +0000 (-0800) Subject: working on code -- register_allocator2.py should work... still needs copy merging... X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=9ff4c2196738457c74e83201a7c087cf04098e5a;p=bigint-presentation-code.git working on code -- register_allocator2.py should work... still needs copy merging though --- diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index 7bdd40b..9f91196 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -2,14 +2,16 @@ import enum from abc import ABCMeta, abstractmethod from enum import Enum, unique from functools import lru_cache, total_ordering +from io import StringIO from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator, - Sequence, TypeVar, overload) + Mapping, Sequence, TypeVar, overload) from weakref import WeakValueDictionary as _WeakVDict from cached_property import cached_property from nmutil.plain_data import fields, plain_data -from bigint_presentation_code.type_util import Self, assert_never, final, Literal +from bigint_presentation_code.type_util import (Literal, Self, assert_never, + final) from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet @@ -600,6 +602,16 @@ class Loc: reg_len += other.reg_len return Loc(kind=self.kind, start=self.start, reg_len=reg_len) + def get_subloc_at_offset(self, subloc_ty, offset): + # type: (Ty, int) -> Loc + if subloc_ty.base_ty != self.kind.base_ty: + raise ValueError("BaseTy mismatch") + start = self.start + offset + if offset < 0 or start + subloc_ty.reg_len > self.reg_len: + raise ValueError("invalid sub-Loc: offset and/or " + "subloc_ty.reg_len out of range") + return Loc(kind=self.kind, start=start, reg_len=subloc_ty.reg_len) + SPECIAL_GPRS = ( Loc(kind=LocKind.GPR, start=0, reg_len=1), @@ -1014,6 +1026,9 @@ IMM_S16 = range(-1 << 15, 1 << 15) _PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None] _PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN] _PRE_RA_SIMS = {} # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2] +_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None] +_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN] +_GEN_ASMS = {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2] @unique @@ -1042,27 +1057,44 @@ class OpKind(Enum): # type: () -> _PRE_RA_SIM_FN return _PRE_RA_SIMS[self.properties]() + @cached_property + def gen_asm(self): + # type: () -> _GEN_ASM_FN + return _GEN_ASMS[self.properties]() + @staticmethod def __clearca_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = False, + + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") ClearCA = GenericOpProperties( demo_asm="addic 0, 0, 0", inputs=[], outputs=[OD_CA.with_write_stage(OpStage.Late)], ) _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim + _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm @staticmethod def __setca_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = True, + + @staticmethod + def __setca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("subfc 0, 0, 0") SetCA = GenericOpProperties( demo_asm="subfc 0, 0, 0", inputs=[], outputs=[OD_CA.with_write_stage(OpStage.Late)], ) _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim + _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm @staticmethod def __svadde_pre_ra_sim(op, state): @@ -1078,12 +1110,21 @@ class OpKind(Enum): carry = (v >> GPR_SIZE_IN_BITS) != 0 state.ssa_vals[op.outputs[0]] = tuple(RT) state.ssa_vals[op.outputs[1]] = carry, + + @staticmethod + def __svadde_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + RB = state.vgpr(op.input_vals[1]) + state.writeln(f"sv.adde {RT}, {RA}, {RB}") SvAddE = GenericOpProperties( demo_asm="sv.adde *RT, *RA, *RB", inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], outputs=[OD_EXTRA3_VGPR, OD_CA], ) _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim + _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm @staticmethod def __svsubfe_pre_ra_sim(op, state): @@ -1099,12 +1140,21 @@ class OpKind(Enum): carry = (v >> GPR_SIZE_IN_BITS) != 0 state.ssa_vals[op.outputs[0]] = tuple(RT) state.ssa_vals[op.outputs[1]] = carry, + + @staticmethod + def __svsubfe_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + RB = state.vgpr(op.input_vals[1]) + state.writeln(f"sv.subfe {RT}, {RA}, {RB}") SvSubFE = GenericOpProperties( demo_asm="sv.subfe *RT, *RA, *RB", inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], outputs=[OD_EXTRA3_VGPR, OD_CA], ) _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim + _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm @staticmethod def __svmaddedu_pre_ra_sim(op, state): @@ -1120,17 +1170,33 @@ class OpKind(Enum): carry = v >> GPR_SIZE_IN_BITS state.ssa_vals[op.outputs[0]] = tuple(RT) state.ssa_vals[op.outputs[1]] = carry, + + @staticmethod + def __svmaddedu_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + RB = state.sgpr(op.input_vals[1]) + RC = state.sgpr(op.input_vals[2]) + state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}") SvMAddEDU = GenericOpProperties( demo_asm="sv.maddedu *RT, *RA, RB, RC", inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL], outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)], ) _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim + _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm @staticmethod def __setvli_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = op.immediates[0], + + @staticmethod + def __setvli_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + imm = op.immediates[0] + state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1") SetVLI = GenericOpProperties( demo_asm="setvl 0, 0, imm, 0, 1, 1", inputs=(), @@ -1139,6 +1205,7 @@ class OpKind(Enum): is_load_immediate=True, ) _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim + _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm @staticmethod def __svli_pre_ra_sim(op, state): @@ -1146,6 +1213,13 @@ class OpKind(Enum): VL, = state.ssa_vals[op.input_vals[0]] imm = op.immediates[0] & GPR_VALUE_MASK state.ssa_vals[op.outputs[0]] = (imm,) * VL + + @staticmethod + def __svli_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"sv.addi {RT}, 0, {imm}") SvLI = GenericOpProperties( demo_asm="sv.addi *RT, 0, imm", inputs=[OD_VL], @@ -1154,12 +1228,20 @@ class OpKind(Enum): is_load_immediate=True, ) _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim + _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm @staticmethod def __li_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None imm = op.immediates[0] & GPR_VALUE_MASK state.ssa_vals[op.outputs[0]] = imm, + + @staticmethod + def __li_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.sgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"addi {RT}, 0, {imm}") LI = GenericOpProperties( demo_asm="addi RT, 0, imm", inputs=(), @@ -1168,11 +1250,35 @@ class OpKind(Enum): is_load_immediate=True, ) _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim + _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm @staticmethod def __veccopytoreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + + @staticmethod + def __veccopytoreg_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64)) + dest_loc = state.loc(op.outputs[0], LocKind.GPR) + RT = state.vgpr(dest_loc) + if src_loc == dest_loc: + return # no-op + assert src_loc.kind in (LocKind.GPR, LocKind.StackI64), \ + "checked by loc()" + if src_loc.kind is LocKind.StackI64: + src = state.stack(src_loc) + state.writeln(f"sv.ld {RT}, {src}") + return + elif src_loc.kind is not LocKind.GPR: + assert_never(src_loc.kind) + rev = "" + if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start: + rev = "/mrr" + src = state.vgpr(src_loc) + state.writeln(f"sv.or{rev} {RT}, {src}, {src}") + VecCopyToReg = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[GenericOperandDesc( @@ -1183,11 +1289,18 @@ class OpKind(Enum): is_copy=True, ) _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim + _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm @staticmethod def __veccopyfromreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") VecCopyFromReg = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[OD_EXTRA3_VGPR, OD_VL], @@ -1199,11 +1312,18 @@ class OpKind(Enum): is_copy=True, ) _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim + _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm @staticmethod def __copytoreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") CopyToReg = GenericOpProperties( demo_asm="mv dest, src", inputs=[GenericOperandDesc( @@ -1219,11 +1339,18 @@ class OpKind(Enum): is_copy=True, ) _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim + _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm @staticmethod def __copyfromreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") CopyFromReg = GenericOpProperties( demo_asm="mv dest, src", inputs=[GenericOperandDesc( @@ -1239,12 +1366,19 @@ class OpKind(Enum): is_copy=True, ) _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim + _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm @staticmethod def __concat_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = tuple( state.ssa_vals[i][0] for i in op.input_vals[:-1]) + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") Concat = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[GenericOperandDesc( @@ -1256,12 +1390,19 @@ class OpKind(Enum): is_copy=True, ) _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim + _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm @staticmethod def __spread_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]): state.ssa_vals[op.outputs[idx]] = inp, + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") Spread = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[OD_EXTRA3_VGPR, OD_VL], @@ -1274,6 +1415,7 @@ class OpKind(Enum): is_copy=True, ) _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim + _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm @staticmethod def __svld_pre_ra_sim(op, state): @@ -1286,6 +1428,12 @@ class OpKind(Enum): v = state.load(addr + GPR_SIZE_IN_BYTES * i) RT.append(v & GPR_VALUE_MASK) state.ssa_vals[op.outputs[0]] = tuple(RT) + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") SvLd = GenericOpProperties( demo_asm="sv.ld *RT, imm(RA)", inputs=[OD_EXTRA3_SGPR, OD_VL], @@ -1293,6 +1441,7 @@ class OpKind(Enum): immediates=[IMM_S16], ) _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim + _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm @staticmethod def __ld_pre_ra_sim(op, state): @@ -1301,6 +1450,12 @@ class OpKind(Enum): addr = RA + op.immediates[0] v = state.load(addr) state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK, + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") Ld = GenericOpProperties( demo_asm="ld RT, imm(RA)", inputs=[OD_BASE_SGPR], @@ -1308,6 +1463,7 @@ class OpKind(Enum): immediates=[IMM_S16], ) _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim + _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm @staticmethod def __svstd_pre_ra_sim(op, state): @@ -1318,6 +1474,12 @@ class OpKind(Enum): addr = RA + op.immediates[0] for i in range(VL): state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i]) + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") SvStd = GenericOpProperties( demo_asm="sv.std *RS, imm(RA)", inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL], @@ -1326,6 +1488,7 @@ class OpKind(Enum): has_side_effects=True, ) _PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim + _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm @staticmethod def __std_pre_ra_sim(op, state): @@ -1334,6 +1497,12 @@ class OpKind(Enum): RA, = state.ssa_vals[op.input_vals[1]] addr = RA + op.immediates[0] state.store(addr, value=RS) + + # FIXME: change to correct __*_gen_asm function + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") Std = GenericOpProperties( demo_asm="std RT, imm(RA)", inputs=[OD_BASE_SGPR, OD_BASE_SGPR], @@ -1342,11 +1511,17 @@ class OpKind(Enum): has_side_effects=True, ) _PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim + _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm @staticmethod def __funcargr3_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None pass # return value set before simulation + + @staticmethod + def __funcargr3_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + pass # no instructions needed FuncArgR3 = GenericOpProperties( demo_asm="", inputs=[], @@ -1354,6 +1529,7 @@ class OpKind(Enum): Loc(kind=LocKind.GPR, start=3, reg_len=1))], ) _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim + _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm @plain_data(frozen=True, unsafe_hash=True, repr=False) @@ -1812,3 +1988,62 @@ class PreRASimState: field_vals.append(f"{name}={value!r}") field_vals_str = ", ".join(field_vals) return f"PreRASimState({field_vals_str})" + + +@plain_data(frozen=True) +class GenAsmState: + __slots__ = "allocated_locs", "output" + + def __init__(self, allocated_locs, output): + # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None + super().__init__() + self.allocated_locs = FMap(allocated_locs) + for ssa_val, loc in self.allocated_locs.items(): + if ssa_val.ty != loc.ty: + raise ValueError( + f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}") + if output is None: + output = [] + self.output = output + + def loc(self, ssa_val_or_loc, expected_kinds): + # type: (SSAVal | Loc, LocKind | tuple[LocKind, ...]) -> Loc + if isinstance(ssa_val_or_loc, SSAVal): + retval = self.allocated_locs[ssa_val_or_loc] + else: + retval = ssa_val_or_loc + if isinstance(expected_kinds, LocKind): + expected_kinds = expected_kinds, + if retval.kind not in expected_kinds: + if len(expected_kinds) == 1: + expected_kinds = expected_kinds[0] + raise ValueError(f"LocKind mismatch: {ssa_val_or_loc}: found " + f"{retval.kind} expected {expected_kinds}") + return retval + + def gpr(self, ssa_val_or_loc, is_vec): + # type: (SSAVal | Loc, bool) -> str + loc = self.loc(ssa_val_or_loc, LocKind.GPR) + vec_str = "*" if is_vec else "" + return vec_str + str(loc.start) + + def sgpr(self, ssa_val_or_loc): + # type: (SSAVal | Loc) -> str + return self.gpr(ssa_val_or_loc, is_vec=False) + + def vgpr(self, ssa_val_or_loc): + # type: (SSAVal | Loc) -> str + return self.gpr(ssa_val_or_loc, is_vec=True) + + def stack(self, ssa_val_or_loc): + # type: (SSAVal | Loc) -> str + loc = self.loc(ssa_val_or_loc, LocKind.StackI64) + return f"{loc.start}(1)" + + def writeln(self, *line_segments): + # type: (*str) -> None + line = " ".join(line_segments) + if isinstance(self.output, list): + self.output.append(line) + else: + self.output.write(line + "\n") diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py index 20ca534..d3ca398 100644 --- a/src/bigint_presentation_code/register_allocator2.py +++ b/src/bigint_presentation_code/register_allocator2.py @@ -6,13 +6,13 @@ this uses an algorithm based on: """ from itertools import combinations -from typing import Any, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Dict +from typing import Iterable, Iterator, Mapping from cached_property import cached_property from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir2 import (BaseTy, FnAnalysis, Loc, - LocSet, Op, ProgramRange, +from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc, + LocSet, ProgramRange, SSAVal, Ty) from bigint_presentation_code.type_util import final from bigint_presentation_code.util import FMap, OFSet, OSet @@ -269,7 +269,7 @@ class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]): @final -class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]): +class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): def __init__( self, *, _private_merged_ssa_val_map, # type: dict[SSAVal, MergedSSAVal] @@ -305,7 +305,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]): f"{self.__merged_ssa_val_map[ssa_val]}") self.__merged_ssa_val_map[ssa_val] = merged_ssa_val added += 1 - retval = IGNode(merged_ssa_val) + retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None) self.__map[merged_ssa_val] = retval added = None return retval @@ -319,21 +319,40 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]): def merge_into_one_node(self, final_merged_ssa_val): # type: (MergedSSAVal) -> IGNode - source_nodes = {} # type: dict[MergedSSAVal, IGNode] + source_nodes = OSet() # type: OSet[IGNode] + edges = OSet() # type: OSet[IGNode] + loc = None # type: Loc | None for ssa_val in final_merged_ssa_val.ssa_vals: merged_ssa_val = self.__merged_ssa_val_map[ssa_val] - source_nodes[merged_ssa_val] = self.__map[merged_ssa_val] + source_node = self.__map[merged_ssa_val] + source_nodes.add(source_node) for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals: raise ValueError( f"SSAVal {i} appears in source IGNode's merged_ssa_val " f"but not in merged IGNode's merged_ssa_val: " - f"source_node={self.__map[merged_ssa_val]} " + f"source_node={source_node} " f"final_merged_ssa_val={final_merged_ssa_val}") - # FIXME: work on function from here - raise NotImplementedError - self.__values_set.discard(value) - for ssa_val in value.ssa_val_offsets.keys(): - del self.__merge_map[ssa_val] + if loc is None: + loc = source_node.loc + elif source_node.loc is not None and loc != source_node.loc: + raise ValueError(f"can't merge IGNodes with mismatched `loc` " + f"values: {loc} != {source_node.loc}") + edges |= source_node.edges + if len(source_nodes) == 1: + return source_nodes.pop() # merging a single node is a no-op + # we're finished checking validity, now we can modify stuff + edges -= source_nodes + retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges, + loc=loc) + for node in edges: + node.edges -= source_nodes + node.edges.add(retval) + for node in source_nodes: + del self.__map[node.merged_ssa_val] + self.__map[final_merged_ssa_val] = retval + for ssa_val in final_merged_ssa_val.ssa_vals: + self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val + return retval def __repr__(self): # type: () -> str @@ -387,7 +406,7 @@ class IGNode: """ interference graph node """ __slots__ = "merged_ssa_val", "edges", "loc" - def __init__(self, merged_ssa_val, edges=(), loc=None): + def __init__(self, merged_ssa_val, edges, loc): # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None self.merged_ssa_val = merged_ssa_val self.edges = OSet(edges) @@ -405,6 +424,7 @@ class IGNode: return NotImplemented def __hash__(self): + # type: () -> int return hash(self.merged_ssa_val) def __repr__(self, nodes=None): @@ -433,42 +453,36 @@ class IGNode: return False -@plain_data() -class AllocationFailed: - __slots__ = "node", "merged_ssa_vals", "interference_graph" - - def __init__(self, node, merged_ssa_vals, interference_graph): - # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None - super().__init__() +class AllocationFailedError(Exception): + def __init__(self, msg, node, interference_graph): + # type: (str, IGNode, InterferenceGraph) -> None + super().__init__(msg, node, interference_graph) self.node = node - self.merged_ssa_vals = merged_ssa_vals self.interference_graph = interference_graph -class AllocationFailedError(Exception): - def __init__(self, msg, allocation_failed): - # type: (str, AllocationFailed) -> None - super().__init__(msg, allocation_failed) - self.allocation_failed = allocation_failed +def allocate_registers(fn): + # type: (Fn) -> dict[SSAVal, Loc] + # inserts enough copies that no manual spilling is necessary, all + # spilling is done by the register allocator naturally allocating SSAVals + # to stack slots + fn.pre_ra_insert_copies() -def try_allocate_registers_without_spilling(merged_ssa_vals): - # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed + fn_analysis = FnAnalysis(fn) + interference_graph = InterferenceGraph.minimally_merged(fn_analysis) - interference_graph = { - i: IGNode(i) for i in merged_ssa_vals.merged_ssa_vals} - fn_analysis = merged_ssa_vals.fn_analysis for ssa_vals in fn_analysis.live_at.values(): live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal] for ssa_val in ssa_vals: - live_merged_ssa_vals.add(merged_ssa_vals.merge_map[ssa_val]) + live_merged_ssa_vals.add( + interference_graph.merged_ssa_val_map[ssa_val]) for i, j in combinations(live_merged_ssa_vals, 2): if i.loc_set.max_conflicts_with(j.loc_set) != 0: - interference_graph[i].add_edge(interference_graph[j]) - - nodes_remaining = OSet(interference_graph.values()) + interference_graph.nodes[i].add_edge( + interference_graph.nodes[j]) -# FIXME: work on code from here + nodes_remaining = OSet(interference_graph.nodes.values()) def local_colorability_score(node): # type: (IGNode) -> int @@ -481,9 +495,11 @@ def try_allocate_registers_without_spilling(merged_ssa_vals): retval = len(node.loc_set) for neighbor in node.edges: if neighbor in nodes_remaining: - retval -= node.reg_class.max_conflicts_with(neighbor.reg_class) + retval -= node.loc_set.max_conflicts_with(neighbor.loc_set) return retval + # TODO: implement copy-merging + node_stack = [] # type: list[IGNode] while True: best_node = None # type: None | IGNode @@ -502,39 +518,29 @@ def try_allocate_registers_without_spilling(merged_ssa_vals): node_stack.append(best_node) nodes_remaining.remove(best_node) - retval = {} # type: dict[SSAVal, RegLoc] + retval = {} # type: dict[SSAVal, Loc] while len(node_stack) > 0: node = node_stack.pop() - if node.reg is not None: - if node.reg_conflicts_with_neighbors(node.reg): - return AllocationFailed(node=node, - live_intervals=live_intervals, - interference_graph=interference_graph) + if node.loc is not None: + if node.loc_conflicts_with_neighbors(node.loc): + raise AllocationFailedError( + "IGNode is pre-allocated to a conflicting Loc", + node=node, interference_graph=interference_graph) else: # pick the first non-conflicting register in node.reg_class, since # register classes are ordered from most preferred to least # preferred register. - for reg in node.reg_class: - if not node.reg_conflicts_with_neighbors(reg): - node.reg = reg + for loc in node.loc_set: + if not node.loc_conflicts_with_neighbors(loc): + node.loc = loc break - if node.reg is None: - return AllocationFailed(node=node, - live_intervals=live_intervals, - interference_graph=interference_graph) - - for ssa_val, offset in node.merged_reg_set.items(): - retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset) - - return retval + if node.loc is None: + raise AllocationFailedError( + "failed to allocate Loc for IGNode", + node=node, interference_graph=interference_graph) + for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items(): + retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset) -def allocate_registers(ops): - # type: (list[Op]) -> dict[SSAVal, RegLoc] - retval = try_allocate_registers_without_spilling(ops) - if isinstance(retval, AllocationFailed): - # TODO: implement spilling - raise AllocationFailedError( - "spilling required but not yet implemented", retval) return retval