"Op(kind=OpKind.FuncArgR3, "
"inputs=[], "
"immediates=[], "
- "outputs=(<arg#0: <I64>>,), name='arg')",
+ "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<vl#0: <VL_MAXVL>>,), name='vl')",
+ "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
"Op(kind=OpKind.SvLd, "
- "inputs=[<arg#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+ "inputs=[<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[0], "
- "outputs=(<ld#0: <I64*32>>,), name='ld')",
+ "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
"Op(kind=OpKind.SvLI, "
- "inputs=[<vl#0: <VL_MAXVL>>], "
+ "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[0], "
- "outputs=(<li#0: <I64*32>>,), name='li')",
+ "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
"Op(kind=OpKind.SetCA, "
"inputs=[], "
"immediates=[], "
- "outputs=(<ca#0: <CA>>,), name='ca')",
+ "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
"Op(kind=OpKind.SvAddE, "
- "inputs=[<ld#0: <I64*32>>, <li#0: <I64*32>>, <ca#0: <CA>>, "
- "<vl#0: <VL_MAXVL>>], "
+ "inputs=[<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<add#0: <I64*32>>, <add#1: <CA>>), name='add')",
+ "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+ "name='add')",
"Op(kind=OpKind.SvStd, "
- "inputs=[<add#0: <I64*32>>, <arg#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+ "inputs=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[0], "
"outputs=(), name='st')",
])
"Op(kind=OpKind.FuncArgR3, "
"inputs=[], "
"immediates=[], "
- "outputs=(<arg#0: <I64>>,), name='arg')",
+ "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
"Op(kind=OpKind.CopyFromReg, "
- "inputs=[<arg#0: <I64>>], "
+ "inputs=[<arg.outputs[0]: <I64>>], "
"immediates=[], "
- "outputs=(<2#0: <I64>>,), name='2')",
+ "outputs=(<2.outputs[0]: <I64>>,), name='2')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<vl#0: <VL_MAXVL>>,), name='vl')",
+ "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
"Op(kind=OpKind.CopyToReg, "
- "inputs=[<2#0: <I64>>], "
+ "inputs=[<2.outputs[0]: <I64>>], "
"immediates=[], "
- "outputs=(<3#0: <I64>>,), name='3')",
+ "outputs=(<3.outputs[0]: <I64>>,), name='3')",
"Op(kind=OpKind.SvLd, "
- "inputs=[<3#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+ "inputs=[<3.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[0], "
- "outputs=(<ld#0: <I64*32>>,), name='ld')",
+ "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<4#0: <VL_MAXVL>>,), name='4')",
+ "outputs=(<4.outputs[0]: <VL_MAXVL>>,), name='4')",
"Op(kind=OpKind.VecCopyFromReg, "
- "inputs=[<ld#0: <I64*32>>, <4#0: <VL_MAXVL>>], "
+ "inputs=[<ld.outputs[0]: <I64*32>>, <4.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<5#0: <I64*32>>,), name='5')",
+ "outputs=(<5.outputs[0]: <I64*32>>,), name='5')",
"Op(kind=OpKind.SvLI, "
- "inputs=[<vl#0: <VL_MAXVL>>], "
+ "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[0], "
- "outputs=(<li#0: <I64*32>>,), name='li')",
+ "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<6#0: <VL_MAXVL>>,), name='6')",
+ "outputs=(<6.outputs[0]: <VL_MAXVL>>,), name='6')",
"Op(kind=OpKind.VecCopyFromReg, "
- "inputs=[<li#0: <I64*32>>, <6#0: <VL_MAXVL>>], "
+ "inputs=[<li.outputs[0]: <I64*32>>, <6.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<7#0: <I64*32>>,), name='7')",
+ "outputs=(<7.outputs[0]: <I64*32>>,), name='7')",
"Op(kind=OpKind.SetCA, "
"inputs=[], "
"immediates=[], "
- "outputs=(<ca#0: <CA>>,), name='ca')",
+ "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<8#0: <VL_MAXVL>>,), name='8')",
+ "outputs=(<8.outputs[0]: <VL_MAXVL>>,), name='8')",
"Op(kind=OpKind.VecCopyToReg, "
- "inputs=[<5#0: <I64*32>>, <8#0: <VL_MAXVL>>], "
+ "inputs=[<5.outputs[0]: <I64*32>>, <8.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<9#0: <I64*32>>,), name='9')",
+ "outputs=(<9.outputs[0]: <I64*32>>,), name='9')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<10#0: <VL_MAXVL>>,), name='10')",
+ "outputs=(<10.outputs[0]: <VL_MAXVL>>,), name='10')",
"Op(kind=OpKind.VecCopyToReg, "
- "inputs=[<7#0: <I64*32>>, <10#0: <VL_MAXVL>>], "
+ "inputs=[<7.outputs[0]: <I64*32>>, <10.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<11#0: <I64*32>>,), name='11')",
+ "outputs=(<11.outputs[0]: <I64*32>>,), name='11')",
"Op(kind=OpKind.SvAddE, "
- "inputs=[<9#0: <I64*32>>, <11#0: <I64*32>>, <ca#0: <CA>>, "
- "<vl#0: <VL_MAXVL>>], "
+ "inputs=[<9.outputs[0]: <I64*32>>, <11.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<add#0: <I64*32>>, <add#1: <CA>>), name='add')",
+ "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+ "name='add')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<12#0: <VL_MAXVL>>,), name='12')",
+ "outputs=(<12.outputs[0]: <VL_MAXVL>>,), name='12')",
"Op(kind=OpKind.VecCopyFromReg, "
- "inputs=[<add#0: <I64*32>>, <12#0: <VL_MAXVL>>], "
+ "inputs=[<add.outputs[0]: <I64*32>>, "
+ "<12.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<13#0: <I64*32>>,), name='13')",
+ "outputs=(<13.outputs[0]: <I64*32>>,), name='13')",
"Op(kind=OpKind.SetVLI, "
"inputs=[], "
"immediates=[32], "
- "outputs=(<14#0: <VL_MAXVL>>,), name='14')",
+ "outputs=(<14.outputs[0]: <VL_MAXVL>>,), name='14')",
"Op(kind=OpKind.VecCopyToReg, "
- "inputs=[<13#0: <I64*32>>, <14#0: <VL_MAXVL>>], "
+ "inputs=[<13.outputs[0]: <I64*32>>, <14.outputs[0]: <VL_MAXVL>>], "
"immediates=[], "
- "outputs=(<15#0: <I64*32>>,), name='15')",
+ "outputs=(<15.outputs[0]: <I64*32>>,), name='15')",
"Op(kind=OpKind.CopyToReg, "
- "inputs=[<2#0: <I64>>], "
+ "inputs=[<2.outputs[0]: <I64>>], "
"immediates=[], "
- "outputs=(<16#0: <I64>>,), name='16')",
+ "outputs=(<16.outputs[0]: <I64>>,), name='16')",
"Op(kind=OpKind.SvStd, "
- "inputs=[<15#0: <I64*32>>, <16#0: <I64>>, <vl#0: <VL_MAXVL>>], "
+ "inputs=[<15.outputs[0]: <I64*32>>, <16.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
"immediates=[0], "
"outputs=(), name='st')",
])
size_in_bytes=GPR_SIZE_IN_BYTES)
self.assertEqual(
repr(state),
- "PreRASimState(ssa_vals={<arg#0: <I64>>: (0x100,)}, memory={\n"
+ "PreRASimState(ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)}, "
+ "memory={\n"
"0x00100: <0xffffffffffffffff>,\n"
"0x00108: <0xabcdef0123456789>})")
fn.pre_ra_sim(state)
self.assertEqual(
repr(state),
"PreRASimState(ssa_vals={\n"
- "<arg#0: <I64>>: (0x100,),\n"
- "<vl#0: <VL_MAXVL>>: (0x20,),\n"
- "<ld#0: <I64*32>>: (\n"
+ "<arg.outputs[0]: <I64>>: (0x100,),\n"
+ "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
+ "<ld.outputs[0]: <I64*32>>: (\n"
" 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0),\n"
- "<li#0: <I64*32>>: (\n"
+ "<li.outputs[0]: <I64*32>>: (\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0),\n"
- "<ca#0: <CA>>: (0x1,),\n"
- "<add#0: <I64*32>>: (\n"
+ "<ca.outputs[0]: <CA>>: (0x1,),\n"
+ "<add.outputs[0]: <I64*32>>: (\n"
" 0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0,\n"
" 0x0, 0x0, 0x0, 0x0),\n"
- "<add#1: <CA>>: (0x0,),\n"
+ "<add.outputs[1]: <CA>>: (0x0,),\n"
"}, memory={\n"
"0x00100: <0x0000000000000000>,\n"
"0x00108: <0xabcdef012345678a>,\n"
import enum
-from abc import abstractmethod
+from abc import ABCMeta, abstractmethod
from enum import Enum, unique
from functools import lru_cache
from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
from nmutil.plain_data import fields, plain_data
from bigint_presentation_code.type_util import Self, assert_never, final
-from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet
+from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
@final
assert_never(out.ty.base_ty)
+@plain_data(frozen=True, eq=False)
+@final
+class FnWithUses:
+ # Pairs a Fn with a use-map: for every SSAVal defined in the Fn, the
+ # set of SSAUse sites that read it. eq/hash delegate to the wrapped Fn.
+ __slots__ = "fn", "uses"
+
+ def __init__(self, fn):
+ # type: (Fn) -> None
+ self.fn = fn
+ retval = {} # type: dict[SSAVal, OSet[SSAUse]]
+ for op in fn.ops:
+ for idx, inp in enumerate(op.inputs):
+ # NOTE(review): assumes SSA dominance -- every input must have
+ # been defined by an earlier op, otherwise this raises KeyError
+ retval[inp].add(SSAUse(op, idx))
+ for out in op.outputs:
+ retval[out] = OSet()
+ # freeze into an immutable map of frozen ordered sets
+ self.uses = FMap((k, OFSet(v)) for k, v in retval.items())
+
+ def __eq__(self, other):
+ # type: (FnWithUses | Any) -> bool
+ # compare by the wrapped Fn only; `uses` is derived from it
+ if isinstance(other, FnWithUses):
+ return self.fn == other.fn
+ return NotImplemented
+
+ def __hash__(self):
+ # type: () -> int
+ return hash(self.fn)
+
+
@unique
@final
class BaseTy(Enum):
_PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+class SSAValOrUse(metaclass=ABCMeta):
+ # Abstract base shared by SSAVal (an op's output) and SSAUse (an op's
+ # input); stores the owning op and derives type info from the
+ # subclass-provided defining_descriptor.
+ __slots__ = "op",
+
+ def __init__(self, op):
+ # type: (Op) -> None
+ self.op = op
+
+ @abstractmethod
+ def __repr__(self):
+ # type: () -> str
+ ...
+
+ @property
+ @abstractmethod
+ def defining_descriptor(self):
+ # type: () -> OperandDesc
+ # the OperandDesc describing this value's slot on `op`
+ ...
+
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ # type of this value as seen after any spread
+ return self.defining_descriptor.ty
+
+ @cached_property
+ def ty_before_spread(self):
+ # type: () -> Ty
+ # type of this value before spreading into scalars
+ return self.defining_descriptor.ty_before_spread
+
+ @property
+ def base_ty(self):
+ # type: () -> BaseTy
+ return self.ty_before_spread.base_ty
+
+
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
-class SSAVal:
- __slots__ = "op", "output_idx"
+class SSAVal(SSAValOrUse):
+ __slots__ = "output_idx",
def __init__(self, op, output_idx):
# type: (Op, int) -> None
- self.op = op
+ super().__init__(op)
if output_idx < 0 or output_idx >= len(op.properties.outputs):
raise ValueError("invalid output_idx")
self.output_idx = output_idx
def __repr__(self):
# type: () -> str
- return f"<{self.op.name}#{self.output_idx}: {self.ty}>"
+ return f"<{self.op.name}.outputs[{self.output_idx}]: {self.ty}>"
+
+ @cached_property
+ def def_loc_set_before_spread(self):
+ # type: () -> LocSet
+ return self.defining_descriptor.loc_set_before_spread
@cached_property
def defining_descriptor(self):
# type: () -> OperandDesc
return self.op.properties.outputs[self.output_idx]
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class SSAUse(SSAValOrUse):
+ __slots__ = "input_idx",
+
+ def __init__(self, op, input_idx):
+ # type: (Op, int) -> None
+ super().__init__(op)
+ self.input_idx = input_idx
+ if input_idx < 0 or input_idx >= len(op.inputs):
+ raise ValueError("input_idx out of range")
+
@cached_property
- def loc_set_before_spread(self):
+ def use_loc_set_before_spread(self):
# type: () -> LocSet
return self.defining_descriptor.loc_set_before_spread
@cached_property
- def ty(self):
- # type: () -> Ty
- return self.defining_descriptor.ty
+ def defining_descriptor(self):
+ # type: () -> OperandDesc
+ return self.op.properties.inputs[self.input_idx]
- @cached_property
- def ty_before_spread(self):
- # type: () -> Ty
- return self.defining_descriptor.ty_before_spread
+ def __repr__(self):
+ # type: () -> str
+ return f"<{self.op.name}.inputs[{self.input_idx}]: {self.ty}>"
_T = TypeVar("_T")
self._verify_write_with_desc(idx, item, desc)
return idx
+ def _on_set(self, idx, new_item, old_item):
+ # type: (int, _T, _T | None) -> None
+ pass
+
@abstractmethod
def _get_descriptors(self):
# type: () -> tuple[_Desc, ...]
raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
f"corresponding input's type {desc.ty!r}")
+ def _on_set(self, idx, new_item, old_item):
+ # type: (int, SSAVal, SSAVal | None) -> None
+ SSAUses._on_op_input_set(self, idx, new_item, old_item) # type: ignore
+
def __init__(self, items, op):
# type: (Iterable[SSAVal], Op) -> None
if hasattr(op, "inputs"):
--- /dev/null
+"""
+Register Allocator for Toom-Cook algorithm generator for SVP64
+
+this uses an algorithm based on:
+[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+"""
+
+from itertools import combinations
+from functools import reduce
+from typing import Generic, Iterable, Mapping
+from cached_property import cached_property
+import operator
+
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir2 import (
+ Op, LocSet, Ty, SSAVal, BaseTy, Loc, FnWithUses)
+from bigint_presentation_code.type_util import final, Self
+from bigint_presentation_code.util import OFSet, OSet, FMap
+
+
+@plain_data(unsafe_hash=True, order=True, frozen=True)
+class LiveInterval:
+ # Live range of a value over op indexes: written at `first_write`,
+ # last read at `last_use`.
+ __slots__ = "first_write", "last_use"
+
+ def __init__(self, first_write, last_use=None):
+ # type: (int, int | None) -> None
+ # last_use defaults to first_write: a value written but never read
+ if last_use is None:
+ last_use = first_write
+ if last_use < first_write:
+ raise ValueError("uses must be after first_write")
+ if first_write < 0 or last_use < 0:
+ raise ValueError("indexes must be nonnegative")
+ self.first_write = first_write
+ self.last_use = last_use
+
+ def overlaps(self, other):
+ # type: (LiveInterval) -> bool
+ # two intervals written at the same op index always conflict, even
+ # when one (or both) is zero-length
+ if self.first_write == other.first_write:
+ return True
+ # strict inequalities: an interval ending exactly where the other
+ # starts does not overlap it
+ return self.last_use > other.first_write \
+ and other.last_use > self.first_write
+
+ def __add__(self, use):
+ # type: (int) -> LiveInterval
+ # returns a new interval extended to cover a use at op index `use`
+ last_use = max(self.last_use, use)
+ return LiveInterval(first_write=self.first_write, last_use=last_use)
+
+ @property
+ def live_after_op_range(self):
+ """the range of op indexes where self is live immediately after the
+ Op at each index
+ """
+ return range(self.first_write, self.last_use)
+
+
+class BadMergedSSAVal(ValueError):
+ # raised when a candidate MergedSSAVal is empty, mixes BaseTys, or
+ # has no legal Locs remaining
+ pass
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class MergedSSAVal:
+ """a set of `SSAVal`s along with their offsets, all register allocated as
+ a single unit.
+
+ Definition of the term `offset` for this class:
+
+ Let `locs[x]` be the `Loc` that `x` is assigned to after register
+ allocation and let `msv` be a `MergedSSAVal` instance, then the offset
+ for each `SSAVal` `ssa_val` in `msv` is defined as:
+
+ ```
+ msv.ssa_val_offsets[ssa_val] = (msv.offset
+ + locs[ssa_val].start - locs[msv].start)
+ ```
+
+ Example:
+ ```
+ v1.ty == <I64*4>
+ v2.ty == <I64*2>
+ v3.ty == <I64>
+ msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
+ msv.ty == <I64*6>
+ ```
+ if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
+ * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
+ * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
+ * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
+ """
+ __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set"
+
+ def __init__(self, fn_with_uses, ssa_val_offsets):
+ # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
+ # Validates the merge and computes `loc_set`, the set of Locs this
+ # whole merged unit may legally be allocated to. Raises
+ # BadMergedSSAVal when the merge is empty, mixes BaseTys, or no
+ # legal Loc remains.
+ self.fn_with_uses = fn_with_uses
+ # convenience: a lone SSAVal means {ssa_val: 0}
+ if isinstance(ssa_val_offsets, SSAVal):
+ ssa_val_offsets = {ssa_val_offsets: 0}
+ self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int]
+ base_ty = None
+ # take the BaseTy of an arbitrary member; `self.ty` below checks
+ # that all members agree
+ for ssa_val in self.ssa_val_offsets.keys():
+ base_ty = ssa_val.base_ty
+ break
+ if base_ty is None:
+ raise BadMergedSSAVal("MergedSSAVal can't be empty")
+ self.base_ty = base_ty # type: BaseTy
+ # self.ty checks for mismatched base_ty
+ reg_len = self.ty.reg_len
+ loc_set = None # type: None | LocSet
+ for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
+ def_spread_idx = ssa_val.defining_descriptor.spread_index or 0
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ # yields each Loc allowed for the whole merged unit, as
+ # constrained by this member's def and all of its uses;
+ # intersected with `loc_set` accumulated so far
+ for loc in ssa_val.def_loc_set_before_spread:
+ disallowed_by_use = False
+ for use in fn_with_uses.uses[ssa_val]:
+ use_spread_idx = \
+ use.defining_descriptor.spread_index or 0
+ # calculate the start for the use's Loc before spread
+ # e.g. if the def's Loc before spread starts at r6
+ # and the def's spread_index is 5
+ # and the use's spread_index is 3
+ # then the use's Loc before spread starts at r8
+ # because 8 == 6 + 5 - 3
+ start = loc.start + def_spread_idx - use_spread_idx
+ use_loc = Loc.try_make(
+ loc.kind, start=start,
+ reg_len=use.ty_before_spread.reg_len)
+ if (use_loc is None or
+ use_loc not in use.use_loc_set_before_spread):
+ disallowed_by_use = True
+ break
+ if disallowed_by_use:
+ continue
+ # FIXME: add spread consistency check
+ # translate from this member's Loc to the merged unit's
+ # Loc: shift by this member's offset within the unit
+ start = loc.start - cur_offset + self.offset
+ loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
+ if loc is not None and (loc_set is None or loc in loc_set):
+ yield loc
+ loc_set = LocSet(locs())
+ assert loc_set is not None, "already checked that self isn't empty"
+ if loc_set.ty is None:
+ raise BadMergedSSAVal("there are no valid Locs left")
+ assert loc_set.ty == self.ty, "logic error somewhere"
+ self.loc_set = loc_set # type: LocSet
+
+ @cached_property
+ def offset(self):
+ # type: () -> int
+ # smallest pre-spread offset of any member; used as the unit's base
+ return min(self.ssa_val_offsets_before_spread.values())
+
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ # the merged unit's type: common BaseTy, reg_len spanning all
+ # members; raises BadMergedSSAVal on BaseTy mismatch
+ reg_len = 0
+ for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
+ cur_ty = ssa_val.ty_before_spread
+ if self.base_ty != cur_ty.base_ty:
+ raise BadMergedSSAVal(
+ f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
+ reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
+ return Ty(base_ty=self.base_ty, reg_len=reg_len)
+
+ @cached_property
+ def ssa_val_offsets_before_spread(self):
+ # type: () -> FMap[SSAVal, int]
+ # maps each member to its offset translated back to before the
+ # spread transformation
+ retval = {} # type: dict[SSAVal, int]
+ for ssa_val, offset in self.ssa_val_offsets.items():
+ offset_before_spread = offset
+ spread_index = ssa_val.defining_descriptor.spread_index
+ if spread_index is not None:
+ assert ssa_val.ty.reg_len == 1, (
+ "this function assumes spreading always converts a vector "
+ "to a contiguous sequence of scalars, if that's changed "
+ "in the future, then this function needs to be adjusted")
+ offset_before_spread -= spread_index
+ retval[ssa_val] = offset_before_spread
+ return FMap(retval)
+
+ def offset_by(self, amount):
+ # type: (int) -> MergedSSAVal
+ # returns a copy with every member's offset shifted by `amount`
+ v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
+ return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v)
+
+ def normalized(self):
+ # type: () -> MergedSSAVal
+ # returns a copy shifted so the smallest offset becomes zero
+ return self.offset_by(-self.offset)
+
+ def with_offset_to_match(self, target):
+ # type: (MergedSSAVal) -> MergedSSAVal
+ # aligns self's offsets with `target` via any shared SSAVal;
+ # raises ValueError when the two merges share no member
+ for ssa_val, offset in self.ssa_val_offsets.items():
+ if ssa_val in target.ssa_val_offsets:
+ return self.offset_by(target.ssa_val_offsets[ssa_val] - offset)
+ raise ValueError("can't change offset to match unrelated MergedSSAVal")
+
+
+@final
+class MergedSSAVals(OFSet[MergedSSAVal]):
+ # A frozen set of disjoint MergedSSAVals: construction rejects any
+ # SSAVal appearing in more than one member.
+ def __init__(self, merged_ssa_vals=()):
+ # type: (Iterable[MergedSSAVal]) -> None
+ super().__init__(merged_ssa_vals)
+ merge_map = {} # type: dict[SSAVal, MergedSSAVal]
+ for merged_ssa_val in self:
+ for ssa_val in merged_ssa_val.ssa_val_offsets.keys():
+ if ssa_val in merge_map:
+ raise ValueError(
+ f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
+ f"{merged_ssa_val} and {merge_map[ssa_val]}")
+ merge_map[ssa_val] = merged_ssa_val
+ self.__merge_map = FMap(merge_map)
+
+ @cached_property
+ def merge_map(self):
+ # type: () -> FMap[SSAVal, MergedSSAVal]
+ # read-only view: which MergedSSAVal each SSAVal belongs to
+ return self.__merge_map
+
+# FIXME: work on code from here
+
+ @staticmethod
+ def minimally_merged(fn_with_uses):
+ # type: (FnWithUses) -> MergedSSAVals
+ # NOTE(review): unfinished work-in-progress (see FIXME above) --
+ # `for fn` below is a syntax placeholder, and `merged_sets`,
+ # `MergedRegSet`, and `self` (inside a @staticmethod) are
+ # unresolved leftovers from the old API; do not ship as-is
+ merge_map = {} # type: dict[SSAVal, MergedSSAVal]
+ for op in fn_with_uses.fn.ops:
+ for fn
+ for val in (*op.inputs().values(), *op.outputs().values()):
+ if val not in merged_sets:
+ merged_sets[val] = MergedRegSet(val)
+ for e in op.get_equality_constraints():
+ lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
+ rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
+ items = [] # type: list[tuple[SSAVal, int]]
+ for i in e.lhs:
+ s = merged_sets[i].with_offset_to_match(lhs_set)
+ items.extend(s.items())
+ for i in e.rhs:
+ s = merged_sets[i].with_offset_to_match(rhs_set)
+ items.extend(s.items())
+ full_set = MergedRegSet(items)
+ for val in full_set.keys():
+ merged_sets[val] = full_set
+
+ self.__map = {k: v.normalized() for k, v in merged_sets.items()}
+
+
+@final
+class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
+ # Maps each merged register set to its LiveInterval and precomputes,
+ # per op index, the sets live immediately after that op.
+ # NOTE(review): still written against the old API -- `MergedRegSet`,
+ # `MergedRegSets`, and `_RegType` are not defined in this module, and
+ # `op.inputs()`/`op.outputs()` are called whereas the new Op exposes
+ # them as attributes; needs porting before use
+ def __init__(self, ops):
+ # type: (list[Op]) -> None
+ self.__merged_reg_sets = MergedRegSets(ops)
+ live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval]
+ for op_idx, op in enumerate(ops):
+ for val in op.inputs().values():
+ # extends the interval; assumes the defining op came first
+ # (KeyError otherwise)
+ live_intervals[self.__merged_reg_sets[val]] += op_idx
+ for val in op.outputs().values():
+ reg_set = self.__merged_reg_sets[val]
+ if reg_set not in live_intervals:
+ live_intervals[reg_set] = LiveInterval(op_idx)
+ else:
+ live_intervals[reg_set] += op_idx
+ self.__live_intervals = live_intervals
+ # invert: for each op index, which sets are live just after it
+ live_after = [] # type: list[OSet[MergedRegSet[_RegType]]]
+ live_after += (OSet() for _ in ops)
+ for reg_set, live_interval in self.__live_intervals.items():
+ for i in live_interval.live_after_op_range:
+ live_after[i].add(reg_set)
+ self.__live_after = [OFSet(i) for i in live_after]
+
+ @property
+ def merged_reg_sets(self):
+ return self.__merged_reg_sets
+
+ def __getitem__(self, key):
+ # type: (MergedRegSet[_RegType]) -> LiveInterval
+ return self.__live_intervals[key]
+
+ def __iter__(self):
+ return iter(self.__live_intervals)
+
+ def __len__(self):
+ return len(self.__live_intervals)
+
+ def reg_sets_live_after(self, op_index):
+ # type: (int) -> OFSet[MergedRegSet[_RegType]]
+ # the merged register sets live immediately after ops[op_index]
+ return self.__live_after[op_index]
+
+ def __repr__(self):
+ reg_sets_live_after = dict(enumerate(self.__live_after))
+ return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
+ f"merged_reg_sets={self.merged_reg_sets}, "
+ f"reg_sets_live_after={reg_sets_live_after})")
+
+
+@final
+class IGNode(Generic[_RegType]):
+ """ interference graph node """
+ # identity, equality, and hash all key on merged_reg_set; edges are
+ # kept symmetric by add_edge
+ __slots__ = "merged_reg_set", "edges", "reg"
+
+ def __init__(self, merged_reg_set, edges=(), reg=None):
+ # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
+ self.merged_reg_set = merged_reg_set
+ self.edges = OSet(edges)
+ self.reg = reg
+
+ def add_edge(self, other):
+ # type: (IGNode) -> None
+ # undirected edge: insert into both adjacency sets
+ self.edges.add(other)
+ other.edges.add(self)
+
+ def __eq__(self, other):
+ # type: (object) -> bool
+ if isinstance(other, IGNode):
+ return self.merged_reg_set == other.merged_reg_set
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.merged_reg_set)
+
+ def __repr__(self, nodes=None):
+ # type: (None | dict[IGNode, int]) -> str
+ # `nodes` assigns stable ids to already-printed nodes so the
+ # mutually-referencing edge sets don't recurse forever
+ if nodes is None:
+ nodes = {}
+ if self in nodes:
+ return f"<IGNode #{nodes[self]}>"
+ nodes[self] = len(nodes)
+ edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
+ return (f"IGNode(#{nodes[self]}, "
+ f"merged_reg_set={self.merged_reg_set}, "
+ f"edges={edges}, "
+ f"reg={self.reg})")
+
+ @property
+ def reg_class(self):
+ # type: () -> RegClass
+ return self.merged_reg_set.ty.reg_class
+
+ def reg_conflicts_with_neighbors(self, reg):
+ # type: (RegLoc) -> bool
+ # True when assigning `reg` here would clash with any neighbor's
+ # already-chosen register
+ for neighbor in self.edges:
+ if neighbor.reg is not None and neighbor.reg.conflicts(reg):
+ return True
+ return False
+
+
+@final
+class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
+ # Read-only mapping from merged register set to its IGNode; edges are
+ # added later by the allocator via IGNode.add_edge
+ def __init__(self, merged_reg_sets):
+ # type: (Iterable[MergedRegSet[_RegType]]) -> None
+ self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
+
+ def __getitem__(self, key):
+ # type: (MergedRegSet[_RegType]) -> IGNode
+ return self.__nodes[key]
+
+ def __iter__(self):
+ return iter(self.__nodes)
+
+ def __len__(self):
+ return len(self.__nodes)
+
+ def __repr__(self):
+ # share one id-map across nodes so cross-references print as ids
+ nodes = {}
+ nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
+ nodes_text = ", ".join(nodes_text)
+ return f"InterferenceGraph(nodes={{{nodes_text}}})"
+
+
+@plain_data()
+class AllocationFailed:
+ # diagnostic bundle returned (not raised) when coloring fails: the
+ # uncolorable node plus the analyses needed to decide what to spill
+ __slots__ = "node", "live_intervals", "interference_graph"
+
+ def __init__(self, node, live_intervals, interference_graph):
+ # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+ self.node = node
+ self.live_intervals = live_intervals
+ self.interference_graph = interference_graph
+
+
+class AllocationFailedError(Exception):
+ # raised by allocate_registers when spilling would be required;
+ # carries the AllocationFailed diagnostics
+ def __init__(self, msg, allocation_failed):
+ # type: (str, AllocationFailed) -> None
+ super().__init__(msg, allocation_failed)
+ self.allocation_failed = allocation_failed
+
+
+def try_allocate_registers_without_spilling(ops):
+ # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
+ # Graph-coloring register allocation without spilling (Chaitin-Briggs
+ # style simplify/select). Returns the assignment on success, or an
+ # AllocationFailed diagnostic when some node can't be colored.
+ # NOTE(review): still written against the old API (`MergedRegSets`,
+ # `RegLoc`, `op.get_extra_interferences`); needs porting before use.
+
+ live_intervals = LiveIntervals(ops)
+ merged_reg_sets = live_intervals.merged_reg_sets
+ interference_graph = InterferenceGraph(merged_reg_sets.values())
+ # build interference edges: sets simultaneously live after any op, plus
+ # op-specific extra interferences
+ for op_idx, op in enumerate(ops):
+ reg_sets = live_intervals.reg_sets_live_after(op_idx)
+ for i, j in combinations(reg_sets, 2):
+ if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+ interference_graph[i].add_edge(interference_graph[j])
+ for i, j in op.get_extra_interferences():
+ i = merged_reg_sets[i]
+ j = merged_reg_sets[j]
+ if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+ interference_graph[i].add_edge(interference_graph[j])
+
+ nodes_remaining = OSet(interference_graph.values())
+
+ def local_colorability_score(node):
+ # type: (IGNode) -> int
+ """ returns a positive integer if node is locally colorable, returns
+ zero or a negative integer if node isn't known to be locally
+ colorable, the more negative the value, the less colorable
+ """
+ if node not in nodes_remaining:
+ raise ValueError()
+ retval = len(node.reg_class)
+ for neighbor in node.edges:
+ if neighbor in nodes_remaining:
+ retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
+ return retval
+
+ # simplify phase: repeatedly remove the most-colorable node; any
+ # locally-colorable node is guaranteed a color in the select phase
+ node_stack = [] # type: list[IGNode]
+ while True:
+ best_node = None # type: None | IGNode
+ best_score = 0
+ for node in nodes_remaining:
+ score = local_colorability_score(node)
+ if best_node is None or score > best_score:
+ best_node = node
+ best_score = score
+ if best_score > 0:
+ # it's locally colorable, no need to find a better one
+ break
+
+ if best_node is None:
+ break
+ node_stack.append(best_node)
+ nodes_remaining.remove(best_node)
+
+ retval = {} # type: dict[SSAVal, RegLoc]
+
+ # select phase: pop in reverse removal order and assign registers
+ while len(node_stack) > 0:
+ node = node_stack.pop()
+ if node.reg is not None:
+ # pre-colored node: just verify the fixed register still fits
+ if node.reg_conflicts_with_neighbors(node.reg):
+ return AllocationFailed(node=node,
+ live_intervals=live_intervals,
+ interference_graph=interference_graph)
+ else:
+ # pick the first non-conflicting register in node.reg_class, since
+ # register classes are ordered from most preferred to least
+ # preferred register.
+ for reg in node.reg_class:
+ if not node.reg_conflicts_with_neighbors(reg):
+ node.reg = reg
+ break
+ if node.reg is None:
+ return AllocationFailed(node=node,
+ live_intervals=live_intervals,
+ interference_graph=interference_graph)
+
+ # expand the merged set's register into per-SSAVal sub-registers
+ for ssa_val, offset in node.merged_reg_set.items():
+ retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
+
+ return retval
+
+
+def allocate_registers(ops):
+ # type: (list[Op]) -> dict[SSAVal, RegLoc]
+ # Allocates registers for `ops`, raising AllocationFailedError (with
+ # diagnostics attached) when spilling would be needed.
+ retval = try_allocate_registers_without_spilling(ops)
+ if isinstance(retval, AllocationFailed):
+ # TODO: implement spilling
+ raise AllocationFailedError(
+ "spilling required but not yet implemented", retval)
+ return retval