started working on toom_cook.py
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 7 Oct 2022 04:08:08 +0000 (21:08 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 7 Oct 2022 04:08:08 +0000 (21:08 -0700)
src/bigint_presentation_code/toom_cook.py [new file with mode: 0644]

diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py
new file mode 100644 (file)
index 0000000..bd8169b
--- /dev/null
@@ -0,0 +1,707 @@
+from abc import ABCMeta, abstractmethod
+import builtins
+from collections import defaultdict
+from enum import Enum, unique
+from typing import Iterable, Mapping, TYPE_CHECKING
+
+from nmutil.plain_data import plain_data
+
+if TYPE_CHECKING:
+    from typing_extensions import final
+else:
+    def final(v):
+        return v
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class PhysLoc:
+    pass
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class GPROrStackLoc(PhysLoc):
+    pass
+
+
+@final
+class GPR(GPROrStackLoc, Enum):
+    def __init__(self, reg_num):
+        # type: (int) -> None
+        self.reg_num = reg_num
+    # fmt: off
+    R0 = 0;     R1 = 1;     R2 = 2;     R3 = 3;     R4 = 4;     R5 = 5
+    R6 = 6;     R7 = 7;     R8 = 8;     R9 = 9;     R10 = 10;   R11 = 11
+    R12 = 12;   R13 = 13;   R14 = 14;   R15 = 15;   R16 = 16;   R17 = 17
+    R18 = 18;   R19 = 19;   R20 = 20;   R21 = 21;   R22 = 22;   R23 = 23
+    R24 = 24;   R25 = 25;   R26 = 26;   R27 = 27;   R28 = 28;   R29 = 29
+    R30 = 30;   R31 = 31;   R32 = 32;   R33 = 33;   R34 = 34;   R35 = 35
+    R36 = 36;   R37 = 37;   R38 = 38;   R39 = 39;   R40 = 40;   R41 = 41
+    R42 = 42;   R43 = 43;   R44 = 44;   R45 = 45;   R46 = 46;   R47 = 47
+    R48 = 48;   R49 = 49;   R50 = 50;   R51 = 51;   R52 = 52;   R53 = 53
+    R54 = 54;   R55 = 55;   R56 = 56;   R57 = 57;   R58 = 58;   R59 = 59
+    R60 = 60;   R61 = 61;   R62 = 62;   R63 = 63;   R64 = 64;   R65 = 65
+    R66 = 66;   R67 = 67;   R68 = 68;   R69 = 69;   R70 = 70;   R71 = 71
+    R72 = 72;   R73 = 73;   R74 = 74;   R75 = 75;   R76 = 76;   R77 = 77
+    R78 = 78;   R79 = 79;   R80 = 80;   R81 = 81;   R82 = 82;   R83 = 83
+    R84 = 84;   R85 = 85;   R86 = 86;   R87 = 87;   R88 = 88;   R89 = 89
+    R90 = 90;   R91 = 91;   R92 = 92;   R93 = 93;   R94 = 94;   R95 = 95
+    R96 = 96;   R97 = 97;   R98 = 98;   R99 = 99;   R100 = 100; R101 = 101
+    R102 = 102; R103 = 103; R104 = 104; R105 = 105; R106 = 106; R107 = 107
+    R108 = 108; R109 = 109; R110 = 110; R111 = 111; R112 = 112; R113 = 113
+    R114 = 114; R115 = 115; R116 = 116; R117 = 117; R118 = 118; R119 = 119
+    R120 = 120; R121 = 121; R122 = 122; R123 = 123; R124 = 124; R125 = 125
+    R126 = 126; R127 = 127
+    # fmt: on
+    SP = 1
+    TOC = 2
+
+
+SPECIAL_GPRS = GPR.R0, GPR.SP, GPR.TOC, GPR.R13
+
+
+@final
+@unique
+class XERBit(Enum, PhysLoc):
+    CY = "CY"
+
+
+@final
+@unique
+class GlobalMem(Enum, PhysLoc):
+    """singleton representing all non-StackSlot memory"""
+    GlobalMem = "GlobalMem"
+
+
+@plain_data()
+@final
+class StackSlot(GPROrStackLoc):
+    """a stack slot. Use OpCopy to load from/store into this stack slot."""
+    __slots__ = "offset",
+
+    def __init__(self, offset=None):
+        # type: (int | None) -> None
+        self.offset = offset
+
+
+@plain_data(eq=False)
+class SSAVal:
+    __slots__ = "id",
+
+    def __init__(self, id=None):
+        # type: (int | None) -> None
+        if id is None:
+            id = builtins.id(self)
+        self.id = id
+
+    def __eq__(self, rhs):
+        if isinstance(rhs, SSAVal):
+            return self.id == rhs.id
+        return False
+
+    def __hash__(self):
+        return hash(self.id)
+
+
+@plain_data(eq=False)
+@final
+class SSAGPRVal(SSAVal):
+    __slots__ = "phys_loc",
+
+    def __init__(self, phys_loc=None):
+        # type: (GPROrStackLoc | None) -> None
+        self.phys_loc = phys_loc
+        super().__init__()
+
+    def __len__(self):
+        return 1
+
+    def get_reg_num(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> int | None
+        phys_loc = None
+        if value_assignments is not None:
+            phys_loc = value_assignments.get(self)
+        if self.phys_loc is not None:
+            phys_loc = self.phys_loc
+        if isinstance(phys_loc, GPR):
+            return phys_loc.reg_num
+        return None
+
+
+@plain_data(eq=False)
+@final
+class SSAXERBitVal(SSAVal):
+    __slots__ = "phys_loc",
+
+    def __init__(self, phys_loc=None):
+        # type: (XERBit | None) -> None
+        self.phys_loc = phys_loc
+
+
+@plain_data(eq=False)
+@final
+class SSAMemory(SSAVal):
+    __slots__ = "phys_loc",
+
+    def __init__(self, phys_loc=GlobalMem.GlobalMem):
+        # type: (GlobalMem) -> None
+        self.phys_loc = phys_loc
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class VecArg:
+    __slots__ = "regs",
+
+    def __init__(self, regs):
+        # type: (Iterable[SSAGPRVal]) -> None
+        self.regs = tuple(regs)
+
+    def __len__(self):
+        return len(self.regs)
+
+    def try_get_range(self, value_assignments=None, prefer_descending=False):
+        # type: (dict[SSAVal, PhysLoc] | None, int) -> range | None
+        if len(self.regs) == 0:
+            return range(0)
+
+        start = self.regs[0].get_reg_num(value_assignments)
+        if start is None:
+            return None
+        if len(self.regs) == 1:
+            if prefer_descending:
+                return range(start, start - 1, -1)
+            return range(start, start + 1, 1)
+        reg = self.regs[1].get_reg_num(value_assignments)
+        if reg is None:
+            return None
+        step = reg - start
+        if abs(step) != 1:
+            return None
+        for i, reg in enumerate(self.regs):
+            reg = self.regs[1].get_reg_num(value_assignments)
+            if reg is None:
+                return None
+            if start + i * step != reg:
+                return None
+        return range(start, start + len(self.regs) * step, step)
+
+    def try_get_ascending_range(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> range | None
+        r = self.try_get_range(value_assignments=value_assignments)
+        if r is not None and r.step == 1:
+            return r
+        return None
+
+    def try_get_descending_range(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> range | None
+        r = self.try_get_range(value_assignments=value_assignments,
+                               prefer_descending=True)
+        if r is not None and r.step == -1:
+            return r
+        return None
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+class Op(metaclass=ABCMeta):
+    __slots__ = ()
+
+    def input_ssa_vals(self):
+        # type: () -> Iterable[SSAVal]
+        for arg in self.inputs().values():
+            if isinstance(arg, VecArg):
+                yield from arg.regs
+            else:
+                yield arg
+
+    def output_ssa_vals(self):
+        # type: () -> Iterable[SSAVal]
+        for arg in self.outputs().values():
+            if isinstance(arg, VecArg):
+                yield from arg.regs
+            else:
+                yield arg
+
+    @abstractmethod
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        ...
+
+    @abstractmethod
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        ...
+
+    @abstractmethod
+    def meets_constraints(self, value_assignments):
+        # type: (dict[SSAVal, PhysLoc]) -> bool
+        return True
+
+    def __init__(self):
+        pass
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpCopy(Op):
+    __slots__ = "dest", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"src": self.src}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"dest": self.dest}
+
+    def __init__(self, dest, src):
+        # type: (VecArg | SSAVal, VecArg | SSAVal) -> None
+        if isinstance(dest, VecArg) and isinstance(src, VecArg):
+            if len(src.regs) != len(dest.regs):
+                raise TypeError(f"source length must match dest "
+                                f"length: {src} doesn't match {dest}")
+        elif type(dest) != type(src):
+            raise TypeError(f"source argument type must match dest "
+                            f"argument type: {src} doesn't match {dest}")
+        self.dest = dest
+        self.src = src
+
+    def meets_constraints(self, value_assignments):
+        # type: (dict[SSAVal, PhysLoc]) -> bool
+        return True
+
+
+def range_overlaps(range1, range2):
+    # type: (range, range) -> bool
+    if len(range1) == 0 or len(range2) == 0:
+        return False
+    range1_last = next(reversed(range1))
+    range2_last = next(reversed(range2))
+    return (range1.start in range2 or range1_last in range2 or
+            range2.start in range1 or range2_last in range1)
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpAddE(Op):
+    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RT": self.RT, "CY_out": self.CY_out}
+
+    def __init__(self, RT, RA, RB, CY_in, CY_out):
+        # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None
+        if len(RA.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {RA} doesn't match {RT}")
+        if len(RB.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {RB} doesn't match {RT}")
+        self.RT = RT
+        self.RA = RA
+        self.RB = RB
+        self.CY_in = CY_in
+        self.CY_out = CY_out
+
+    def meets_constraints(self, value_assignments):
+        # type: (dict[SSAVal, PhysLoc]) -> bool
+        RT_range = self.RT.try_get_ascending_range(value_assignments)
+        RA_range = self.RA.try_get_ascending_range(value_assignments)
+        RB_range = self.RB.try_get_ascending_range(value_assignments)
+        if RA_range is None or RB_range is None or RT_range is None:
+            return False
+        if RA_range != RT_range and range_overlaps(RA_range, RT_range):
+            return False
+        if RB_range != RT_range and range_overlaps(RB_range, RT_range):
+            return False
+        return True
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpSubE(Op):
+    __slots__ = "RT", "lhs", "rhs", "CY_in", "CY_out"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"lhs": self.lhs, "rhs": self.rhs, "CY_in": self.CY_in}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RT": self.RT, "CY_out": self.CY_out}
+
+    def __init__(self, RT, lhs, rhs, CY_in, CY_out):
+        # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None
+        if len(lhs.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {lhs} doesn't match {RT}")
+        if len(rhs.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {rhs} doesn't match {RT}")
+        self.RT = RT
+        self.lhs = lhs
+        self.rhs = rhs
+        self.CY_in = CY_in
+        self.CY_out = CY_out
+
+    def meets_constraints(self, value_assignments):
+        # type: (dict[SSAVal, PhysLoc]) -> bool
+        RT_range = self.RT.try_get_ascending_range(value_assignments)
+        lhs_range = self.lhs.try_get_ascending_range(value_assignments)
+        rhs_range = self.rhs.try_get_ascending_range(value_assignments)
+        if lhs_range is None or rhs_range is None or RT_range is None:
+            return False
+        if lhs_range != RT_range and range_overlaps(lhs_range, RT_range):
+            return False
+        if rhs_range != RT_range and range_overlaps(rhs_range, RT_range):
+            return False
+        return True
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpMAddEU(Op):
+    __slots__ = "RT", "RA", "RB", "RC", "RS"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RT": self.RT, "RS": self.RS}
+
+    def __init__(self, RT, RA, RB, RC, RS):
+        # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None
+        if len(RA.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {RA} doesn't match {RT}")
+        self.RT = RT
+        self.RA = RA
+        self.RB = RB
+        self.RC = RC
+        self.RS = RS
+
+    def meets_constraints(self, value_assignments):
+        # type: (dict[SSAVal, PhysLoc]) -> bool
+        RT_range = self.RT.try_get_ascending_range(value_assignments)
+        RA_range = self.RA.try_get_ascending_range(value_assignments)
+        RB_reg = self.RB.get_reg_num(value_assignments)
+        RC_reg = self.RC.get_reg_num(value_assignments)
+        RS_reg = self.RS.get_reg_num(value_assignments)
+        if RT_range is None or RA_range is None or RB_reg is None \
+                or RC_reg is None or RS_reg is None:
+            return False
+        if RA_range != RT_range and range_overlaps(RA_range, RT_range):
+            return False
+        if RB_reg in RT_range or RC_reg in RT_range or RS_reg in RT_range \
+                or RS_reg in RA_range:
+            return False
+        if RS_reg != RC_reg:
+            return False
+        return True
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpDivMod2DU(Op):
+    __slots__ = "RT", "RA", "RB", "RC", "RS"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RT": self.RT, "RS": self.RS}
+
+    def __init__(self, RT, RA, RB, RC, RS):
+        # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None
+        if len(RA.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {RA} doesn't match {RT}")
+        self.RT = RT
+        self.RA = RA
+        self.RB = RB
+        self.RC = RC
+        self.RS = RS
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpBigIntShl(Op):
+    __slots__ = "RT", "inp", "sh"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"inp": self.inp, "sh": self.sh}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RT": self.RT}
+
+    def __init__(self, RT, inp, sh):
+        # type: (VecArg, VecArg, SSAGPRVal) -> None
+        if len(inp.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {inp} doesn't match {RT}")
+        self.RT = RT
+        self.inp = inp
+        self.sh = sh
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpBigIntShr(Op):
+    __slots__ = "RT", "inp", "sh"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"inp": self.inp, "sh": self.sh}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RT": self.RT}
+
+    def __init__(self, RT, inp, sh):
+        # type: (VecArg, VecArg, SSAGPRVal) -> None
+        if len(inp.regs) != len(RT.regs):
+            raise TypeError(f"source length must match dest "
+                            f"length: {inp} doesn't match {RT}")
+        self.RT = RT
+        self.inp = inp
+        self.sh = sh
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLI(Op):
+    __slots__ = "out", "value"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"out": self.out}
+
+    def __init__(self, out, value):
+        # type: (VecArg | SSAGPRVal, int) -> None
+        self.out = out
+        self.value = value
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpClearCY(Op):
+    __slots__ = "out",
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"out": self.out}
+
+    def __init__(self, out):
+        # type: (SSAXERBitVal) -> None
+        self.out = out
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpLoad(Op):
+    __slots__ = "RT", "RA", "offset", "mem"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RA": self.RA, "mem": self.mem}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RT": self.RT}
+
+    def __init__(self, RT, RA, offset, mem):
+        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
+        self.RT = RT
+        self.RA = RA
+        self.offset = offset
+        self.mem = mem
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpStore(Op):
+    __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"mem_out": self.mem_out}
+
+    def __init__(self, RS, RA, offset, mem_in, mem_out):
+        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory, SSAMemory) -> None
+        self.RS = RS
+        self.RA = RA
+        self.offset = offset
+        self.mem_in = mem_in
+        self.mem_out = mem_out
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpFuncArg(Op):
+    __slots__ = "out",
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"out": self.out}
+
+    def __init__(self, out):
+        # type: (VecArg | SSAGPRVal) -> None
+        self.out = out
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpInputMem(Op):
+    __slots__ = "out",
+
+    def inputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {}
+
+    def outputs(self):
+        # type: () -> dict[str, VecArg | SSAVal]
+        return {"out": self.out}
+
+    def __init__(self, out):
+        # type: (SSAMemory) -> None
+        self.out = out
+
+
+def op_set_to_list(ops):
+    # type: (Iterable[Op]) -> list[Op]
+    worklists = [set()]  # type: list[set[Op]]
+    input_vals_to_ops_map = defaultdict(set)  # type: dict[SSAVal, set[Op]]
+    ops_to_pending_input_count_map = {}  # type: dict[Op, int]
+    for op in ops:
+        input_count = 0
+        for val in op.input_ssa_vals():
+            input_count += 1
+            input_vals_to_ops_map[val].add(op)
+        while len(worklists) <= input_count:
+            worklists.append(set())
+        ops_to_pending_input_count_map[op] = input_count
+        worklists[input_count].add(op)
+    retval = []  # type: list[Op]
+    ready_vals = set()  # type: set[SSAVal]
+    while len(worklists[0]) != 0:
+        writing_op = worklists[0].pop()
+        retval.append(writing_op)
+        for val in writing_op.output_ssa_vals():
+            if val in ready_vals:
+                raise ValueError(f"multiple instructions must not write "
+                                 f"to the same SSA value: {val}")
+            ready_vals.add(val)
+            for reading_op in input_vals_to_ops_map[val]:
+                pending = ops_to_pending_input_count_map[reading_op]
+                worklists[pending].remove(reading_op)
+                pending -= 1
+                worklists[pending].add(reading_op)
+                ops_to_pending_input_count_map[reading_op] = pending
+    for worklist in worklists:
+        for op in worklist:
+            raise ValueError(f"instruction is part of a dependency loop or "
+                             f"its inputs are never written: {op}")
+    return retval
+
+
+@plain_data(unsafe_hash=True, order=True, frozen=True)
+class LiveInterval:
+    __slots__ = "assignment", "last_use"
+
+    def __init__(self, assignment, last_use=None):
+        # type: (int, int | None) -> None
+        if last_use is None:
+            last_use = assignment
+        if last_use < assignment:
+            raise ValueError("uses must be after assignment")
+        if assignment < 0 or last_use < 0:
+            raise ValueError("indexes must be nonnegative")
+        self.assignment = assignment
+        self.last_use = last_use
+
+    def overlaps(self, other):
+        # type: (LiveInterval) -> bool
+        if self.assignment == other.assignment:
+            return True
+        return self.last_use > other.assignment \
+            and other.last_use > self.assignment
+
+    def __add__(self, use):
+        # type: (int) -> LiveInterval
+        last_use = max(self.last_use, use)
+        return LiveInterval(assignment=self.assignment, last_use=last_use)
+
+
+class LiveIntervals(Mapping[SSAVal, LiveInterval]):
+    def __init__(self, ops):
+        # type: (list[Op]) -> None
+        live_intervals = {}  # type: dict[SSAVal, LiveInterval]
+        for op_idx, op in enumerate(ops):
+            for val in op.input_ssa_vals():
+                live_intervals[val] += op_idx
+            for val in op.output_ssa_vals():
+                if val in live_intervals:
+                    raise ValueError(f"multiple instructions must not write "
+                                     f"to the same SSA value: {val}")
+                live_intervals[val] = LiveInterval(op_idx)
+        self.__live_intervals = live_intervals
+
+    def __getitem__(self, key):
+        # type: (SSAVal) -> LiveInterval
+        return self.__live_intervals[key]
+
+    def __iter__(self):
+        return iter(self.__live_intervals)
+
+
+@plain_data()
+class AllocationFailed:
+    __slots__ = "op_idx", "arg", "live_intervals", "free_regs"
+
+    def __init__(self, op_idx, arg, live_intervals, free_regs):
+        # type: (int, SSAVal | VecArg, LiveIntervals, set[GPR | XERBit]) -> None
+        self.op_idx = op_idx
+        self.arg = arg
+        self.live_intervals = live_intervals
+        self.free_regs = free_regs
+
+
+def try_allocate_registers_without_spilling(ops):
+    # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
+    live_intervals = LiveIntervals(ops)
+    free_regs = set()  # type: set[GPR | XERBit]
+    free_regs.update(GPR)
+    free_regs.difference_update(SPECIAL_GPRS)
+    free_regs.update(XERBit)
+    raise NotImplementedError
+
+
+def allocate_registers(ops):
+    # type: (list[Op]) -> None
+    raise NotImplementedError