From c1bf1966ecbb6a84ddfa1718d966f9af9df9a04c Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Mon, 31 Oct 2022 23:26:53 -0700 Subject: [PATCH] working on refactoring register allocator to use new ir --- .../_tests/test_compiler_ir2.py | 114 +++-- src/bigint_presentation_code/compiler_ir2.py | 115 ++++- .../register_allocator2.py | 471 ++++++++++++++++++ 3 files changed, 632 insertions(+), 68 deletions(-) create mode 100644 src/bigint_presentation_code/register_allocator2.py diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir2.py b/src/bigint_presentation_code/_tests/test_compiler_ir2.py index 74c38e9..4aa8e3e 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir2.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir2.py @@ -39,30 +39,32 @@ class TestCompilerIR(unittest.TestCase): "Op(kind=OpKind.FuncArgR3, " "inputs=[], " "immediates=[], " - "outputs=(>,), name='arg')", + "outputs=(>,), name='arg')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(>,), name='vl')", + "outputs=(>,), name='vl')", "Op(kind=OpKind.SvLd, " - "inputs=[>, >], " + "inputs=[>, >], " "immediates=[0], " - "outputs=(>,), name='ld')", + "outputs=(>,), name='ld')", "Op(kind=OpKind.SvLI, " - "inputs=[>], " + "inputs=[>], " "immediates=[0], " - "outputs=(>,), name='li')", + "outputs=(>,), name='li')", "Op(kind=OpKind.SetCA, " "inputs=[], " "immediates=[], " - "outputs=(>,), name='ca')", + "outputs=(>,), name='ca')", "Op(kind=OpKind.SvAddE, " - "inputs=[>, >, >, " - ">], " + "inputs=[>, >, " + ">, >], " "immediates=[], " - "outputs=(>, >), name='add')", + "outputs=(>, >), " + "name='add')", "Op(kind=OpKind.SvStd, " - "inputs=[>, >, >], " + "inputs=[>, >, " + ">], " "immediates=[0], " "outputs=(), name='st')", ]) @@ -150,90 +152,93 @@ class TestCompilerIR(unittest.TestCase): "Op(kind=OpKind.FuncArgR3, " "inputs=[], " "immediates=[], " - "outputs=(>,), name='arg')", + "outputs=(>,), name='arg')", "Op(kind=OpKind.CopyFromReg, " - "inputs=[>], " + "inputs=[>], " "immediates=[], " - "outputs=(<2#0: >,), name='2')", + "outputs=(<2.outputs[0]: >,), name='2')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(>,), name='vl')", + "outputs=(>,), name='vl')", "Op(kind=OpKind.CopyToReg, " - "inputs=[<2#0: >], " + "inputs=[<2.outputs[0]: >], " "immediates=[], " - "outputs=(<3#0: >,), name='3')", + "outputs=(<3.outputs[0]: >,), name='3')", "Op(kind=OpKind.SvLd, " - "inputs=[<3#0: >, >], " + "inputs=[<3.outputs[0]: >, >], " "immediates=[0], " - "outputs=(>,), name='ld')", + "outputs=(>,), name='ld')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(<4#0: >,), name='4')", + "outputs=(<4.outputs[0]: >,), name='4')", "Op(kind=OpKind.VecCopyFromReg, " - "inputs=[>, <4#0: >], " + "inputs=[>, <4.outputs[0]: >], " "immediates=[], " - "outputs=(<5#0: >,), name='5')", + "outputs=(<5.outputs[0]: >,), name='5')", "Op(kind=OpKind.SvLI, " - "inputs=[>], " + "inputs=[>], " "immediates=[0], " - "outputs=(>,), name='li')", + "outputs=(>,), name='li')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(<6#0: >,), name='6')", + "outputs=(<6.outputs[0]: >,), name='6')", "Op(kind=OpKind.VecCopyFromReg, " - "inputs=[>, <6#0: >], " + "inputs=[>, <6.outputs[0]: >], " "immediates=[], " - "outputs=(<7#0: >,), name='7')", + "outputs=(<7.outputs[0]: >,), name='7')", "Op(kind=OpKind.SetCA, " "inputs=[], " "immediates=[], " - "outputs=(>,), name='ca')", + "outputs=(>,), name='ca')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(<8#0: >,), name='8')", + "outputs=(<8.outputs[0]: >,), name='8')", "Op(kind=OpKind.VecCopyToReg, " - "inputs=[<5#0: >, <8#0: >], " + "inputs=[<5.outputs[0]: >, <8.outputs[0]: >], " "immediates=[], " - "outputs=(<9#0: >,), name='9')", + "outputs=(<9.outputs[0]: >,), name='9')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(<10#0: >,), name='10')", + "outputs=(<10.outputs[0]: >,), name='10')", "Op(kind=OpKind.VecCopyToReg, " - "inputs=[<7#0: >, <10#0: >], " + "inputs=[<7.outputs[0]: >, <10.outputs[0]: >], " "immediates=[], " - "outputs=(<11#0: >,), name='11')", + "outputs=(<11.outputs[0]: >,), name='11')", "Op(kind=OpKind.SvAddE, " - "inputs=[<9#0: >, <11#0: >, >, " - ">], " + "inputs=[<9.outputs[0]: >, <11.outputs[0]: >, " + ">, >], " "immediates=[], " - "outputs=(>, >), name='add')", + "outputs=(>, >), " + "name='add')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(<12#0: >,), name='12')", + "outputs=(<12.outputs[0]: >,), name='12')", "Op(kind=OpKind.VecCopyFromReg, " - "inputs=[>, <12#0: >], " + "inputs=[>, " + "<12.outputs[0]: >], " "immediates=[], " - "outputs=(<13#0: >,), name='13')", + "outputs=(<13.outputs[0]: >,), name='13')", "Op(kind=OpKind.SetVLI, " "inputs=[], " "immediates=[32], " - "outputs=(<14#0: >,), name='14')", + "outputs=(<14.outputs[0]: >,), name='14')", "Op(kind=OpKind.VecCopyToReg, " - "inputs=[<13#0: >, <14#0: >], " + "inputs=[<13.outputs[0]: >, <14.outputs[0]: >], " "immediates=[], " - "outputs=(<15#0: >,), name='15')", + "outputs=(<15.outputs[0]: >,), name='15')", "Op(kind=OpKind.CopyToReg, " - "inputs=[<2#0: >], " + "inputs=[<2.outputs[0]: >], " "immediates=[], " - "outputs=(<16#0: >,), name='16')", + "outputs=(<16.outputs[0]: >,), name='16')", "Op(kind=OpKind.SvStd, " - "inputs=[<15#0: >, <16#0: >, >], " + "inputs=[<15.outputs[0]: >, <16.outputs[0]: >, " + ">], " "immediates=[0], " "outputs=(), name='st')", ]) @@ -471,16 +476,17 @@ class TestCompilerIR(unittest.TestCase): size_in_bytes=GPR_SIZE_IN_BYTES) self.assertEqual( repr(state), - "PreRASimState(ssa_vals={>: (0x100,)}, memory={\n" + "PreRASimState(ssa_vals={>: (0x100,)}, " + "memory={\n" "0x00100: <0xffffffffffffffff>,\n" "0x00108: <0xabcdef0123456789>})") fn.pre_ra_sim(state) self.assertEqual( repr(state), "PreRASimState(ssa_vals={\n" - ">: (0x100,),\n" - ">: (0x20,),\n" - ">: (\n" + ">: (0x100,),\n" + ">: (0x20,),\n" + ">: (\n" " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" @@ -489,7 +495,7 @@ class TestCompilerIR(unittest.TestCase): " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0),\n" - ">: (\n" + ">: (\n" " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" @@ -498,8 +504,8 @@ class TestCompilerIR(unittest.TestCase): " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0),\n" - ">: (0x1,),\n" - ">: (\n" + ">: (0x1,),\n" + ">: (\n" " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" @@ -508,7 +514,7 @@ class TestCompilerIR(unittest.TestCase): " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0,\n" " 0x0, 0x0, 0x0, 0x0),\n" - ">: (0x0,),\n" + ">: (0x0,),\n" "}, memory={\n" "0x00100: <0x0000000000000000>,\n" "0x00108: <0xabcdef012345678a>,\n" diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index 7109e77..8e36509 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -1,5 +1,5 @@ import enum -from abc import abstractmethod +from abc import ABCMeta, abstractmethod from enum import Enum, unique from functools import lru_cache from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator, @@ -10,7 +10,7 @@ from cached_property import cached_property from nmutil.plain_data import fields, plain_data from bigint_presentation_code.type_util import Self, assert_never, final -from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet +from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet @final @@ -103,6 +103,33 @@ class Fn: assert_never(out.ty.base_ty) +@plain_data(frozen=True, eq=False) +@final +class FnWithUses: + __slots__ = "fn", "uses" + + def __init__(self, fn): + # type: (Fn) -> None + self.fn = fn + retval = {} # type: dict[SSAVal, OSet[SSAUse]] + for op in fn.ops: + for idx, inp in enumerate(op.inputs): + retval[inp].add(SSAUse(op, idx)) + for out in op.outputs: + retval[out] = OSet() + self.uses = FMap((k, OFSet(v)) for k, v in retval.items()) + + def __eq__(self, other): + # type: (FnWithUses | Any) -> bool + if isinstance(other, FnWithUses): + return self.fn == other.fn + return NotImplemented + + def __hash__(self): + # type: () -> int + return hash(self.fn) + + @unique @final class BaseTy(Enum): @@ -1074,41 +1101,93 @@ class OpKind(Enum): _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim +@plain_data(frozen=True, unsafe_hash=True, repr=False) +class SSAValOrUse(metaclass=ABCMeta): + __slots__ = "op", + + def __init__(self, op): + # type: (Op) -> None + self.op = op + + @abstractmethod + def __repr__(self): + # type: () -> str + ... + + @property + @abstractmethod + def defining_descriptor(self): + # type: () -> OperandDesc + ... + + @cached_property + def ty(self): + # type: () -> Ty + return self.defining_descriptor.ty + + @cached_property + def ty_before_spread(self): + # type: () -> Ty + return self.defining_descriptor.ty_before_spread + + @property + def base_ty(self): + # type: () -> BaseTy + return self.ty_before_spread.base_ty + + @plain_data(frozen=True, unsafe_hash=True, repr=False) @final -class SSAVal: - __slots__ = "op", "output_idx" +class SSAVal(SSAValOrUse): + __slots__ = "output_idx", def __init__(self, op, output_idx): # type: (Op, int) -> None - self.op = op + super().__init__(op) if output_idx < 0 or output_idx >= len(op.properties.outputs): raise ValueError("invalid output_idx") self.output_idx = output_idx def __repr__(self): # type: () -> str - return f"<{self.op.name}#{self.output_idx}: {self.ty}>" + return f"<{self.op.name}.outputs[{self.output_idx}]: {self.ty}>" + + @cached_property + def def_loc_set_before_spread(self): + # type: () -> LocSet + return self.defining_descriptor.loc_set_before_spread @cached_property def defining_descriptor(self): # type: () -> OperandDesc return self.op.properties.outputs[self.output_idx] + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +class SSAUse(SSAValOrUse): + __slots__ = "input_idx", + + def __init__(self, op, input_idx): + # type: (Op, int) -> None + super().__init__(op) + self.input_idx = input_idx + if input_idx < 0 or input_idx >= len(op.inputs): + raise ValueError("input_idx out of range") + @cached_property - def loc_set_before_spread(self): + def use_loc_set_before_spread(self): # type: () -> LocSet return self.defining_descriptor.loc_set_before_spread @cached_property - def ty(self): - # type: () -> Ty - return self.defining_descriptor.ty + def defining_descriptor(self): + # type: () -> OperandDesc + return self.op.properties.inputs[self.input_idx] - @cached_property - def ty_before_spread(self): - # type: () -> Ty - return self.defining_descriptor.ty_before_spread + def __repr__(self): + # type: () -> str + return f"<{self.op.name}.inputs[{self.input_idx}]: {self.ty}>" _T = TypeVar("_T") @@ -1135,6 +1214,10 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]): self._verify_write_with_desc(idx, item, desc) return idx + def _on_set(self, idx, new_item, old_item): + # type: (int, _T, _T | None) -> None + pass + @abstractmethod def _get_descriptors(self): # type: () -> tuple[_Desc, ...] @@ -1212,6 +1295,10 @@ class OpInputs(OpInputSeq[SSAVal, OperandDesc]): raise ValueError(f"assigned item's type {item.ty!r} doesn't match " f"corresponding input's type {desc.ty!r}") + def _on_set(self, idx, new_item, old_item): + # type: (int, SSAVal, SSAVal | None) -> None + SSAUses._on_op_input_set(self, idx, new_item, old_item) # type: ignore + def __init__(self, items, op): # type: (Iterable[SSAVal], Op) -> None if hasattr(op, "inputs"): diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py new file mode 100644 index 0000000..962a021 --- /dev/null +++ b/src/bigint_presentation_code/register_allocator2.py @@ -0,0 +1,471 @@ +""" +Register Allocator for Toom-Cook algorithm generator for SVP64 + +this uses an algorithm based on: +[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf) +""" + +from itertools import combinations +from functools import reduce +from typing import Generic, Iterable, Mapping +from cached_property import cached_property +import operator + +from nmutil.plain_data import plain_data + +from bigint_presentation_code.compiler_ir2 import ( + Op, LocSet, Ty, SSAVal, BaseTy, Loc, FnWithUses) +from bigint_presentation_code.type_util import final, Self +from bigint_presentation_code.util import OFSet, OSet, FMap + + +@plain_data(unsafe_hash=True, order=True, frozen=True) +class LiveInterval: + __slots__ = "first_write", "last_use" + + def __init__(self, first_write, last_use=None): + # type: (int, int | None) -> None + if last_use is None: + last_use = first_write + if last_use < first_write: + raise ValueError("uses must be after first_write") + if first_write < 0 or last_use < 0: + raise ValueError("indexes must be nonnegative") + self.first_write = first_write + self.last_use = last_use + + def overlaps(self, other): + # type: (LiveInterval) -> bool + if self.first_write == other.first_write: + return True + return self.last_use > other.first_write \ + and other.last_use > self.first_write + + def __add__(self, use): + # type: (int) -> LiveInterval + last_use = max(self.last_use, use) + return LiveInterval(first_write=self.first_write, last_use=last_use) + + @property + def live_after_op_range(self): + """the range of op indexes where self is live immediately after the + Op at each index + """ + return range(self.first_write, self.last_use) + + +class BadMergedSSAVal(ValueError): + pass + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class MergedSSAVal: + """a set of `SSAVal`s along with their offsets, all register allocated as + a single unit. + + Definition of the term `offset` for this class: + + Let `locs[x]` be the `Loc` that `x` is assigned to after register + allocation and let `msv` be a `MergedSSAVal` instance, then the offset + for each `SSAVal` `ssa_val` in `msv` is defined as: + + ``` + msv.ssa_val_offsets[ssa_val] = (msv.offset + + locs[ssa_val].start - locs[msv].start) + ``` + + Example: + ``` + v1.ty == + v2.ty == + v3.ty == + msv = MergedSSAVal({v1: 0, v2: 4, v3: 1}) + msv.ty == + ``` + if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then + * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)` + * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)` + * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)` + """ + __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set" + + def __init__(self, fn_with_uses, ssa_val_offsets): + # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None + self.fn_with_uses = fn_with_uses + if isinstance(ssa_val_offsets, SSAVal): + ssa_val_offsets = {ssa_val_offsets: 0} + self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int] + base_ty = None + for ssa_val in self.ssa_val_offsets.keys(): + base_ty = ssa_val.base_ty + break + if base_ty is None: + raise BadMergedSSAVal("MergedSSAVal can't be empty") + self.base_ty = base_ty # type: BaseTy + # self.ty checks for mismatched base_ty + reg_len = self.ty.reg_len + loc_set = None # type: None | LocSet + for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items(): + def_spread_idx = ssa_val.defining_descriptor.spread_index or 0 + + def locs(): + # type: () -> Iterable[Loc] + for loc in ssa_val.def_loc_set_before_spread: + disallowed_by_use = False + for use in fn_with_uses.uses[ssa_val]: + use_spread_idx = \ + use.defining_descriptor.spread_index or 0 + # calculate the start for the use's Loc before spread + # e.g. if the def's Loc before spread starts at r6 + # and the def's spread_index is 5 + # and the use's spread_index is 3 + # then the use's Loc before spread starts at r8 + # because 8 == 6 + 5 - 3 + start = loc.start + def_spread_idx - use_spread_idx + use_loc = Loc.try_make( + loc.kind, start=start, + reg_len=use.ty_before_spread.reg_len) + if (use_loc is None or + use_loc not in use.use_loc_set_before_spread): + disallowed_by_use = True + break + if disallowed_by_use: + continue + # FIXME: add spread consistency check + start = loc.start - cur_offset + self.offset + loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len) + if loc is not None and (loc_set is None or loc in loc_set): + yield loc + loc_set = LocSet(locs()) + assert loc_set is not None, "already checked that self isn't empty" + if loc_set.ty is None: + raise BadMergedSSAVal("there are no valid Locs left") + assert loc_set.ty == self.ty, "logic error somewhere" + self.loc_set = loc_set # type: LocSet + + @cached_property + def offset(self): + # type: () -> int + return min(self.ssa_val_offsets_before_spread.values()) + + @cached_property + def ty(self): + # type: () -> Ty + reg_len = 0 + for ssa_val, offset in self.ssa_val_offsets_before_spread.items(): + cur_ty = ssa_val.ty_before_spread + if self.base_ty != cur_ty.base_ty: + raise BadMergedSSAVal( + f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}") + reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset) + return Ty(base_ty=self.base_ty, reg_len=reg_len) + + @cached_property + def ssa_val_offsets_before_spread(self): + # type: () -> FMap[SSAVal, int] + retval = {} # type: dict[SSAVal, int] + for ssa_val, offset in self.ssa_val_offsets.items(): + offset_before_spread = offset + spread_index = ssa_val.defining_descriptor.spread_index + if spread_index is not None: + assert ssa_val.ty.reg_len == 1, ( + "this function assumes spreading always converts a vector " + "to a contiguous sequence of scalars, if that's changed " + "in the future, then this function needs to be adjusted") + offset_before_spread -= spread_index + retval[ssa_val] = offset_before_spread + return FMap(retval) + + def offset_by(self, amount): + # type: (int) -> MergedSSAVal + v = {k: v + amount for k, v in self.ssa_val_offsets.items()} + return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v) + + def normalized(self): + # type: () -> MergedSSAVal + return self.offset_by(-self.offset) + + def with_offset_to_match(self, target): + # type: (MergedSSAVal) -> MergedSSAVal + for ssa_val, offset in self.ssa_val_offsets.items(): + if ssa_val in target.ssa_val_offsets: + return self.offset_by(target.ssa_val_offsets[ssa_val] - offset) + raise ValueError("can't change offset to match unrelated MergedSSAVal") + + +@final +class MergedSSAVals(OFSet[MergedSSAVal]): + def __init__(self, merged_ssa_vals=()): + # type: (Iterable[MergedSSAVal]) -> None + super().__init__(merged_ssa_vals) + merge_map = {} # type: dict[SSAVal, MergedSSAVal] + for merged_ssa_val in self: + for ssa_val in merged_ssa_val.ssa_val_offsets.keys(): + if ssa_val in merge_map: + raise ValueError( + f"overlapping `MergedSSAVal`s: {ssa_val} is in both " + f"{merged_ssa_val} and {merge_map[ssa_val]}") + merge_map[ssa_val] = merged_ssa_val + self.__merge_map = FMap(merge_map) + + @cached_property + def merge_map(self): + # type: () -> FMap[SSAVal, MergedSSAVal] + return self.__merge_map + +# FIXME: work on code from here + + @staticmethod + def minimally_merged(fn_with_uses): + # type: (FnWithUses) -> MergedSSAVals + merge_map = {} # type: dict[SSAVal, MergedSSAVal] + for op in fn_with_uses.fn.ops: + for fn + for val in (*op.inputs().values(), *op.outputs().values()): + if val not in merged_sets: + merged_sets[val] = MergedRegSet(val) + for e in op.get_equality_constraints(): + lhs_set = MergedRegSet.from_equality_constraint(e.lhs) + rhs_set = MergedRegSet.from_equality_constraint(e.rhs) + items = [] # type: list[tuple[SSAVal, int]] + for i in e.lhs: + s = merged_sets[i].with_offset_to_match(lhs_set) + items.extend(s.items()) + for i in e.rhs: + s = merged_sets[i].with_offset_to_match(rhs_set) + items.extend(s.items()) + full_set = MergedRegSet(items) + for val in full_set.keys(): + merged_sets[val] = full_set + + self.__map = {k: v.normalized() for k, v in merged_sets.items()} + + +@final +class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]): + def __init__(self, ops): + # type: (list[Op]) -> None + self.__merged_reg_sets = MergedRegSets(ops) + live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval] + for op_idx, op in enumerate(ops): + for val in op.inputs().values(): + live_intervals[self.__merged_reg_sets[val]] += op_idx + for val in op.outputs().values(): + reg_set = self.__merged_reg_sets[val] + if reg_set not in live_intervals: + live_intervals[reg_set] = LiveInterval(op_idx) + else: + live_intervals[reg_set] += op_idx + self.__live_intervals = live_intervals + live_after = [] # type: list[OSet[MergedRegSet[_RegType]]] + live_after += (OSet() for _ in ops) + for reg_set, live_interval in self.__live_intervals.items(): + for i in live_interval.live_after_op_range: + live_after[i].add(reg_set) + self.__live_after = [OFSet(i) for i in live_after] + + @property + def merged_reg_sets(self): + return self.__merged_reg_sets + + def __getitem__(self, key): + # type: (MergedRegSet[_RegType]) -> LiveInterval + return self.__live_intervals[key] + + def __iter__(self): + return iter(self.__live_intervals) + + def __len__(self): + return len(self.__live_intervals) + + def reg_sets_live_after(self, op_index): + # type: (int) -> OFSet[MergedRegSet[_RegType]] + return self.__live_after[op_index] + + def __repr__(self): + reg_sets_live_after = dict(enumerate(self.__live_after)) + return (f"LiveIntervals(live_intervals={self.__live_intervals}, " + f"merged_reg_sets={self.merged_reg_sets}, " + f"reg_sets_live_after={reg_sets_live_after})") + + +@final +class IGNode(Generic[_RegType]): + """ interference graph node """ + __slots__ = "merged_reg_set", "edges", "reg" + + def __init__(self, merged_reg_set, edges=(), reg=None): + # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None + self.merged_reg_set = merged_reg_set + self.edges = OSet(edges) + self.reg = reg + + def add_edge(self, other): + # type: (IGNode) -> None + self.edges.add(other) + other.edges.add(self) + + def __eq__(self, other): + # type: (object) -> bool + if isinstance(other, IGNode): + return self.merged_reg_set == other.merged_reg_set + return NotImplemented + + def __hash__(self): + return hash(self.merged_reg_set) + + def __repr__(self, nodes=None): + # type: (None | dict[IGNode, int]) -> str + if nodes is None: + nodes = {} + if self in nodes: + return f"" + nodes[self] = len(nodes) + edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}" + return (f"IGNode(#{nodes[self]}, " + f"merged_reg_set={self.merged_reg_set}, " + f"edges={edges}, " + f"reg={self.reg})") + + @property + def reg_class(self): + # type: () -> RegClass + return self.merged_reg_set.ty.reg_class + + def reg_conflicts_with_neighbors(self, reg): + # type: (RegLoc) -> bool + for neighbor in self.edges: + if neighbor.reg is not None and neighbor.reg.conflicts(reg): + return True + return False + + +@final +class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]): + def __init__(self, merged_reg_sets): + # type: (Iterable[MergedRegSet[_RegType]]) -> None + self.__nodes = {i: IGNode(i) for i in merged_reg_sets} + + def __getitem__(self, key): + # type: (MergedRegSet[_RegType]) -> IGNode + return self.__nodes[key] + + def __iter__(self): + return iter(self.__nodes) + + def __len__(self): + return len(self.__nodes) + + def __repr__(self): + nodes = {} + nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()] + nodes_text = ", ".join(nodes_text) + return f"InterferenceGraph(nodes={{{nodes_text}}})" + + +@plain_data() +class AllocationFailed: + __slots__ = "node", "live_intervals", "interference_graph" + + def __init__(self, node, live_intervals, interference_graph): + # type: (IGNode, LiveIntervals, InterferenceGraph) -> None + self.node = node + self.live_intervals = live_intervals + self.interference_graph = interference_graph + + +class AllocationFailedError(Exception): + def __init__(self, msg, allocation_failed): + # type: (str, AllocationFailed) -> None + super().__init__(msg, allocation_failed) + self.allocation_failed = allocation_failed + + +def try_allocate_registers_without_spilling(ops): + # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed + + live_intervals = LiveIntervals(ops) + merged_reg_sets = live_intervals.merged_reg_sets + interference_graph = InterferenceGraph(merged_reg_sets.values()) + for op_idx, op in enumerate(ops): + reg_sets = live_intervals.reg_sets_live_after(op_idx) + for i, j in combinations(reg_sets, 2): + if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: + interference_graph[i].add_edge(interference_graph[j]) + for i, j in op.get_extra_interferences(): + i = merged_reg_sets[i] + j = merged_reg_sets[j] + if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: + interference_graph[i].add_edge(interference_graph[j]) + + nodes_remaining = OSet(interference_graph.values()) + + def local_colorability_score(node): + # type: (IGNode) -> int + """ returns a positive integer if node is locally colorable, returns + zero or a negative integer if node isn't known to be locally + colorable, the more negative the value, the less colorable + """ + if node not in nodes_remaining: + raise ValueError() + retval = len(node.reg_class) + for neighbor in node.edges: + if neighbor in nodes_remaining: + retval -= node.reg_class.max_conflicts_with(neighbor.reg_class) + return retval + + node_stack = [] # type: list[IGNode] + while True: + best_node = None # type: None | IGNode + best_score = 0 + for node in nodes_remaining: + score = local_colorability_score(node) + if best_node is None or score > best_score: + best_node = node + best_score = score + if best_score > 0: + # it's locally colorable, no need to find a better one + break + + if best_node is None: + break + node_stack.append(best_node) + nodes_remaining.remove(best_node) + + retval = {} # type: dict[SSAVal, RegLoc] + + while len(node_stack) > 0: + node = node_stack.pop() + if node.reg is not None: + if node.reg_conflicts_with_neighbors(node.reg): + return AllocationFailed(node=node, + live_intervals=live_intervals, + interference_graph=interference_graph) + else: + # pick the first non-conflicting register in node.reg_class, since + # register classes are ordered from most preferred to least + # preferred register. + for reg in node.reg_class: + if not node.reg_conflicts_with_neighbors(reg): + node.reg = reg + break + if node.reg is None: + return AllocationFailed(node=node, + live_intervals=live_intervals, + interference_graph=interference_graph) + + for ssa_val, offset in node.merged_reg_set.items(): + retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset) + + return retval + + +def allocate_registers(ops): + # type: (list[Op]) -> dict[SSAVal, RegLoc] + retval = try_allocate_registers_without_spilling(ops) + if isinstance(retval, AllocationFailed): + # TODO: implement spilling + raise AllocationFailedError( + "spilling required but not yet implemented", retval) + return retval -- 2.30.2