working on code -- register_allocator2.py should work... still needs copy merging...
authorJacob Lifshay <programmerjake@gmail.com>
Mon, 7 Nov 2022 02:07:12 +0000 (18:07 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Mon, 7 Nov 2022 02:07:12 +0000 (18:07 -0800)
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/register_allocator2.py

index 7bdd40b0c8931a82d9ed55a5be8a70b0e4fa4171..9f911969c7f96235d472f7fb7b15cac83c9491a9 100644 (file)
@@ -2,14 +2,16 @@ import enum
 from abc import ABCMeta, abstractmethod
 from enum import Enum, unique
 from functools import lru_cache, total_ordering
+from io import StringIO
 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
-                    Sequence, TypeVar, overload)
+                    Mapping, Sequence, TypeVar, overload)
 from weakref import WeakValueDictionary as _WeakVDict
 
 from cached_property import cached_property
 from nmutil.plain_data import fields, plain_data
 
-from bigint_presentation_code.type_util import Self, assert_never, final, Literal
+from bigint_presentation_code.type_util import (Literal, Self, assert_never,
+                                                final)
 from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
 
 
@@ -600,6 +602,16 @@ class Loc:
             reg_len += other.reg_len
         return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
 
+    def get_subloc_at_offset(self, subloc_ty, offset):
+        # type: (Ty, int) -> Loc
+        if subloc_ty.base_ty != self.kind.base_ty:
+            raise ValueError("BaseTy mismatch")
+        start = self.start + offset
+        if offset < 0 or start + subloc_ty.reg_len > self.reg_len:
+            raise ValueError("invalid sub-Loc: offset and/or "
+                             "subloc_ty.reg_len out of range")
+        return Loc(kind=self.kind, start=start, reg_len=subloc_ty.reg_len)
+
 
 SPECIAL_GPRS = (
     Loc(kind=LocKind.GPR, start=0, reg_len=1),
@@ -1014,6 +1026,9 @@ IMM_S16 = range(-1 << 15, 1 << 15)
 _PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
 _PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
 _PRE_RA_SIMS = {}  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
+_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
+_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
+_GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
 
 
 @unique
@@ -1042,27 +1057,44 @@ class OpKind(Enum):
         # type: () -> _PRE_RA_SIM_FN
         return _PRE_RA_SIMS[self.properties]()
 
+    @cached_property
+    def gen_asm(self):
+        # type: () -> _GEN_ASM_FN
+        return _GEN_ASMS[self.properties]()
+
     @staticmethod
     def __clearca_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = False,
+
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     ClearCA = GenericOpProperties(
         demo_asm="addic 0, 0, 0",
         inputs=[],
         outputs=[OD_CA.with_write_stage(OpStage.Late)],
     )
     _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
+    _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm
 
     @staticmethod
     def __setca_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = True,
+
+    @staticmethod
+    def __setca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("subfc 0, 0, 0")
     SetCA = GenericOpProperties(
         demo_asm="subfc 0, 0, 0",
         inputs=[],
         outputs=[OD_CA.with_write_stage(OpStage.Late)],
     )
     _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
+    _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm
 
     @staticmethod
     def __svadde_pre_ra_sim(op, state):
@@ -1078,12 +1110,21 @@ class OpKind(Enum):
             carry = (v >> GPR_SIZE_IN_BITS) != 0
         state.ssa_vals[op.outputs[0]] = tuple(RT)
         state.ssa_vals[op.outputs[1]] = carry,
+
+    @staticmethod
+    def __svadde_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        RT = state.vgpr(op.outputs[0])
+        RA = state.vgpr(op.input_vals[0])
+        RB = state.vgpr(op.input_vals[1])
+        state.writeln(f"sv.adde {RT}, {RA}, {RB}")
     SvAddE = GenericOpProperties(
         demo_asm="sv.adde *RT, *RA, *RB",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
         outputs=[OD_EXTRA3_VGPR, OD_CA],
     )
     _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
+    _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
 
     @staticmethod
     def __svsubfe_pre_ra_sim(op, state):
@@ -1099,12 +1140,21 @@ class OpKind(Enum):
             carry = (v >> GPR_SIZE_IN_BITS) != 0
         state.ssa_vals[op.outputs[0]] = tuple(RT)
         state.ssa_vals[op.outputs[1]] = carry,
+
+    @staticmethod
+    def __svsubfe_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        RT = state.vgpr(op.outputs[0])
+        RA = state.vgpr(op.input_vals[0])
+        RB = state.vgpr(op.input_vals[1])
+        state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
     SvSubFE = GenericOpProperties(
         demo_asm="sv.subfe *RT, *RA, *RB",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
         outputs=[OD_EXTRA3_VGPR, OD_CA],
     )
     _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
+    _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
 
     @staticmethod
     def __svmaddedu_pre_ra_sim(op, state):
@@ -1120,17 +1170,33 @@ class OpKind(Enum):
             carry = v >> GPR_SIZE_IN_BITS
         state.ssa_vals[op.outputs[0]] = tuple(RT)
         state.ssa_vals[op.outputs[1]] = carry,
+
+    @staticmethod
+    def __svmaddedu_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        RT = state.vgpr(op.outputs[0])
+        RA = state.vgpr(op.input_vals[0])
+        RB = state.sgpr(op.input_vals[1])
+        RC = state.sgpr(op.input_vals[2])
+        state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
     SvMAddEDU = GenericOpProperties(
         demo_asm="sv.maddedu *RT, *RA, RB, RC",
         inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
         outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
     )
     _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim
+    _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
 
     @staticmethod
     def __setvli_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = op.immediates[0],
+
+    @staticmethod
+    def __setvli_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        imm = op.immediates[0]
+        state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
     SetVLI = GenericOpProperties(
         demo_asm="setvl 0, 0, imm, 0, 1, 1",
         inputs=(),
@@ -1139,6 +1205,7 @@ class OpKind(Enum):
         is_load_immediate=True,
     )
     _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim
+    _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
 
     @staticmethod
     def __svli_pre_ra_sim(op, state):
@@ -1146,6 +1213,13 @@ class OpKind(Enum):
         VL, = state.ssa_vals[op.input_vals[0]]
         imm = op.immediates[0] & GPR_VALUE_MASK
         state.ssa_vals[op.outputs[0]] = (imm,) * VL
+
+    @staticmethod
+    def __svli_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        RT = state.vgpr(op.outputs[0])
+        imm = op.immediates[0]
+        state.writeln(f"sv.addi {RT}, 0, {imm}")
     SvLI = GenericOpProperties(
         demo_asm="sv.addi *RT, 0, imm",
         inputs=[OD_VL],
@@ -1154,12 +1228,20 @@ class OpKind(Enum):
         is_load_immediate=True,
     )
     _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim
+    _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
 
     @staticmethod
     def __li_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         imm = op.immediates[0] & GPR_VALUE_MASK
         state.ssa_vals[op.outputs[0]] = imm,
+
+    @staticmethod
+    def __li_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        RT = state.sgpr(op.outputs[0])
+        imm = op.immediates[0]
+        state.writeln(f"addi {RT}, 0, {imm}")
     LI = GenericOpProperties(
         demo_asm="addi RT, 0, imm",
         inputs=(),
@@ -1168,11 +1250,35 @@ class OpKind(Enum):
         is_load_immediate=True,
     )
     _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim
+    _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
 
     @staticmethod
     def __veccopytoreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+
+    @staticmethod
+    def __veccopytoreg_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
+        dest_loc = state.loc(op.outputs[0], LocKind.GPR)
+        RT = state.vgpr(dest_loc)
+        if src_loc == dest_loc:
+            return  # no-op
+        assert src_loc.kind in (LocKind.GPR, LocKind.StackI64), \
+            "checked by loc()"
+        if src_loc.kind is LocKind.StackI64:
+            src = state.stack(src_loc)
+            state.writeln(f"sv.ld {RT}, {src}")
+            return
+        elif src_loc.kind is not LocKind.GPR:
+            assert_never(src_loc.kind)
+        rev = ""
+        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
+            rev = "/mrr"
+        src = state.vgpr(src_loc)
+        state.writeln(f"sv.or{rev} {RT}, {src}, {src}")
+
     VecCopyToReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1183,11 +1289,18 @@ class OpKind(Enum):
         is_copy=True,
     )
     _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
+    _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
 
     @staticmethod
     def __veccopyfromreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     VecCopyFromReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[OD_EXTRA3_VGPR, OD_VL],
@@ -1199,11 +1312,18 @@ class OpKind(Enum):
         is_copy=True,
     )
     _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim
+    _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
 
     @staticmethod
     def __copytoreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     CopyToReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1219,11 +1339,18 @@ class OpKind(Enum):
         is_copy=True,
     )
     _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim
+    _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
 
     @staticmethod
     def __copyfromreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     CopyFromReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1239,12 +1366,19 @@ class OpKind(Enum):
         is_copy=True,
     )
     _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim
+    _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
 
     @staticmethod
     def __concat_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = tuple(
             state.ssa_vals[i][0] for i in op.input_vals[:-1])
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     Concat = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1256,12 +1390,19 @@ class OpKind(Enum):
         is_copy=True,
     )
     _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
+    _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
 
     @staticmethod
     def __spread_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
             state.ssa_vals[op.outputs[idx]] = inp,
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     Spread = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[OD_EXTRA3_VGPR, OD_VL],
@@ -1274,6 +1415,7 @@ class OpKind(Enum):
         is_copy=True,
     )
     _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim
+    _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
 
     @staticmethod
     def __svld_pre_ra_sim(op, state):
@@ -1286,6 +1428,12 @@ class OpKind(Enum):
             v = state.load(addr + GPR_SIZE_IN_BYTES * i)
             RT.append(v & GPR_VALUE_MASK)
         state.ssa_vals[op.outputs[0]] = tuple(RT)
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     SvLd = GenericOpProperties(
         demo_asm="sv.ld *RT, imm(RA)",
         inputs=[OD_EXTRA3_SGPR, OD_VL],
@@ -1293,6 +1441,7 @@ class OpKind(Enum):
         immediates=[IMM_S16],
     )
     _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim
+    _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
 
     @staticmethod
     def __ld_pre_ra_sim(op, state):
@@ -1301,6 +1450,12 @@ class OpKind(Enum):
         addr = RA + op.immediates[0]
         v = state.load(addr)
         state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     Ld = GenericOpProperties(
         demo_asm="ld RT, imm(RA)",
         inputs=[OD_BASE_SGPR],
@@ -1308,6 +1463,7 @@ class OpKind(Enum):
         immediates=[IMM_S16],
     )
     _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
+    _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
 
     @staticmethod
     def __svstd_pre_ra_sim(op, state):
@@ -1318,6 +1474,12 @@ class OpKind(Enum):
         addr = RA + op.immediates[0]
         for i in range(VL):
             state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     SvStd = GenericOpProperties(
         demo_asm="sv.std *RS, imm(RA)",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
@@ -1326,6 +1488,7 @@ class OpKind(Enum):
         has_side_effects=True,
     )
     _PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim
+    _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
 
     @staticmethod
     def __std_pre_ra_sim(op, state):
@@ -1334,6 +1497,12 @@ class OpKind(Enum):
         RA, = state.ssa_vals[op.input_vals[1]]
         addr = RA + op.immediates[0]
         state.store(addr, value=RS)
+
+    # FIXME: change to correct __*_gen_asm function
+    @staticmethod
+    def __clearca_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        state.writeln("addic 0, 0, 0")
     Std = GenericOpProperties(
         demo_asm="std RT, imm(RA)",
         inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
@@ -1342,11 +1511,17 @@ class OpKind(Enum):
         has_side_effects=True,
     )
     _PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim
+    _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
 
     @staticmethod
     def __funcargr3_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         pass  # return value set before simulation
+
+    @staticmethod
+    def __funcargr3_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        pass  # no instructions needed
     FuncArgR3 = GenericOpProperties(
         demo_asm="",
         inputs=[],
@@ -1354,6 +1529,7 @@ class OpKind(Enum):
             Loc(kind=LocKind.GPR, start=3, reg_len=1))],
     )
     _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
+    _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
 
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
@@ -1812,3 +1988,62 @@ class PreRASimState:
                 field_vals.append(f"{name}={value!r}")
         field_vals_str = ", ".join(field_vals)
         return f"PreRASimState({field_vals_str})"
+
+
+@plain_data(frozen=True)
+class GenAsmState:
+    __slots__ = "allocated_locs", "output"
+
+    def __init__(self, allocated_locs, output):
+        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
+        super().__init__()
+        self.allocated_locs = FMap(allocated_locs)
+        for ssa_val, loc in self.allocated_locs.items():
+            if ssa_val.ty != loc.ty:
+                raise ValueError(
+                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
+        if output is None:
+            output = []
+        self.output = output
+
+    def loc(self, ssa_val_or_loc, expected_kinds):
+        # type: (SSAVal | Loc, LocKind | tuple[LocKind, ...]) -> Loc
+        if isinstance(ssa_val_or_loc, SSAVal):
+            retval = self.allocated_locs[ssa_val_or_loc]
+        else:
+            retval = ssa_val_or_loc
+        if isinstance(expected_kinds, LocKind):
+            expected_kinds = expected_kinds,
+        if retval.kind not in expected_kinds:
+            if len(expected_kinds) == 1:
+                expected_kinds = expected_kinds[0]
+            raise ValueError(f"LocKind mismatch: {ssa_val_or_loc}: found "
+                             f"{retval.kind} expected {expected_kinds}")
+        return retval
+
+    def gpr(self, ssa_val_or_loc, is_vec):
+        # type: (SSAVal | Loc, bool) -> str
+        loc = self.loc(ssa_val_or_loc, LocKind.GPR)
+        vec_str = "*" if is_vec else ""
+        return vec_str + str(loc.start)
+
+    def sgpr(self, ssa_val_or_loc):
+        # type: (SSAVal | Loc) -> str
+        return self.gpr(ssa_val_or_loc, is_vec=False)
+
+    def vgpr(self, ssa_val_or_loc):
+        # type: (SSAVal | Loc) -> str
+        return self.gpr(ssa_val_or_loc, is_vec=True)
+
+    def stack(self, ssa_val_or_loc):
+        # type: (SSAVal | Loc) -> str
+        loc = self.loc(ssa_val_or_loc, LocKind.StackI64)
+        return f"{loc.start}(1)"
+
+    def writeln(self, *line_segments):
+        # type: (*str) -> None
+        line = " ".join(line_segments)
+        if isinstance(self.output, list):
+            self.output.append(line)
+        else:
+            self.output.write(line + "\n")
index 20ca534b108db69184deaacdf4e96f44fe43d163..d3ca3983c9dfb114fb0a7454484e9a6c3688a200 100644 (file)
@@ -6,13 +6,13 @@ this uses an algorithm based on:
 """
 
 from itertools import combinations
-from typing import Any, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Dict
+from typing import Iterable, Iterator, Mapping
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir2 import (BaseTy, FnAnalysis, Loc,
-                                                   LocSet, Op, ProgramRange,
+from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc,
+                                                   LocSet, ProgramRange,
                                                    SSAVal, Ty)
 from bigint_presentation_code.type_util import final
 from bigint_presentation_code.util import FMap, OFSet, OSet
@@ -269,7 +269,7 @@ class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
 
 
 @final
-class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]):
+class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
     def __init__(
         self, *,
         _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
@@ -305,7 +305,7 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]):
                         f"{self.__merged_ssa_val_map[ssa_val]}")
                 self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                 added += 1
-            retval = IGNode(merged_ssa_val)
+            retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
             self.__map[merged_ssa_val] = retval
             added = None
             return retval
@@ -319,21 +319,40 @@ class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]):
 
     def merge_into_one_node(self, final_merged_ssa_val):
         # type: (MergedSSAVal) -> IGNode
-        source_nodes = {}  # type: dict[MergedSSAVal, IGNode]
+        source_nodes = OSet()  # type: OSet[IGNode]
+        edges = OSet()  # type: OSet[IGNode]
+        loc = None  # type: Loc | None
         for ssa_val in final_merged_ssa_val.ssa_vals:
             merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
-            source_nodes[merged_ssa_val] = self.__map[merged_ssa_val]
+            source_node = self.__map[merged_ssa_val]
+            source_nodes.add(source_node)
             for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
                 raise ValueError(
                     f"SSAVal {i} appears in source IGNode's merged_ssa_val "
                     f"but not in merged IGNode's merged_ssa_val: "
-                    f"source_node={self.__map[merged_ssa_val]} "
+                    f"source_node={source_node} "
                     f"final_merged_ssa_val={final_merged_ssa_val}")
-        # FIXME: work on function from here
-        raise NotImplementedError
-        self.__values_set.discard(value)
-        for ssa_val in value.ssa_val_offsets.keys():
-            del self.__merge_map[ssa_val]
+            if loc is None:
+                loc = source_node.loc
+            elif source_node.loc is not None and loc != source_node.loc:
+                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
+                                 f"values: {loc} != {source_node.loc}")
+            edges |= source_node.edges
+        if len(source_nodes) == 1:
+            return source_nodes.pop()  # merging a single node is a no-op
+        # we're finished checking validity, now we can modify stuff
+        edges -= source_nodes
+        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
+                        loc=loc)
+        for node in edges:
+            node.edges -= source_nodes
+            node.edges.add(retval)
+        for node in source_nodes:
+            del self.__map[node.merged_ssa_val]
+        self.__map[final_merged_ssa_val] = retval
+        for ssa_val in final_merged_ssa_val.ssa_vals:
+            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
+        return retval
 
     def __repr__(self):
         # type: () -> str
@@ -387,7 +406,7 @@ class IGNode:
     """ interference graph node """
     __slots__ = "merged_ssa_val", "edges", "loc"
 
-    def __init__(self, merged_ssa_val, edges=(), loc=None):
+    def __init__(self, merged_ssa_val, edges, loc):
         # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
         self.merged_ssa_val = merged_ssa_val
         self.edges = OSet(edges)
@@ -405,6 +424,7 @@ class IGNode:
         return NotImplemented
 
     def __hash__(self):
+        # type: () -> int
         return hash(self.merged_ssa_val)
 
     def __repr__(self, nodes=None):
@@ -433,42 +453,36 @@ class IGNode:
         return False
 
 
-@plain_data()
-class AllocationFailed:
-    __slots__ = "node", "merged_ssa_vals", "interference_graph"
-
-    def __init__(self, node, merged_ssa_vals, interference_graph):
-        # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None
-        super().__init__()
+class AllocationFailedError(Exception):
+    def __init__(self, msg, node, interference_graph):
+        # type: (str, IGNode, InterferenceGraph) -> None
+        super().__init__(msg, node, interference_graph)
         self.node = node
-        self.merged_ssa_vals = merged_ssa_vals
         self.interference_graph = interference_graph
 
 
-class AllocationFailedError(Exception):
-    def __init__(self, msg, allocation_failed):
-        # type: (str, AllocationFailed) -> None
-        super().__init__(msg, allocation_failed)
-        self.allocation_failed = allocation_failed
+def allocate_registers(fn):
+    # type: (Fn) -> dict[SSAVal, Loc]
 
+    # inserts enough copies that no manual spilling is necessary, all
+    # spilling is done by the register allocator naturally allocating SSAVals
+    # to stack slots
+    fn.pre_ra_insert_copies()
 
-def try_allocate_registers_without_spilling(merged_ssa_vals):
-    # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed
+    fn_analysis = FnAnalysis(fn)
+    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)
 
-    interference_graph = {
-        i: IGNode(i) for i in merged_ssa_vals.merged_ssa_vals}
-    fn_analysis = merged_ssa_vals.fn_analysis
     for ssa_vals in fn_analysis.live_at.values():
         live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
         for ssa_val in ssa_vals:
-            live_merged_ssa_vals.add(merged_ssa_vals.merge_map[ssa_val])
+            live_merged_ssa_vals.add(
+                interference_graph.merged_ssa_val_map[ssa_val])
         for i, j in combinations(live_merged_ssa_vals, 2):
             if i.loc_set.max_conflicts_with(j.loc_set) != 0:
-                interference_graph[i].add_edge(interference_graph[j])
-
-    nodes_remaining = OSet(interference_graph.values())
+                interference_graph.nodes[i].add_edge(
+                    interference_graph.nodes[j])
 
-# FIXME: work on code from here
+    nodes_remaining = OSet(interference_graph.nodes.values())
 
     def local_colorability_score(node):
         # type: (IGNode) -> int
@@ -481,9 +495,11 @@ def try_allocate_registers_without_spilling(merged_ssa_vals):
         retval = len(node.loc_set)
         for neighbor in node.edges:
             if neighbor in nodes_remaining:
-                retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
+                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
         return retval
 
+    # TODO: implement copy-merging
+
     node_stack = []  # type: list[IGNode]
     while True:
         best_node = None  # type: None | IGNode
@@ -502,39 +518,29 @@ def try_allocate_registers_without_spilling(merged_ssa_vals):
         node_stack.append(best_node)
         nodes_remaining.remove(best_node)
 
-    retval = {}  # type: dict[SSAVal, RegLoc]
+    retval = {}  # type: dict[SSAVal, Loc]
 
     while len(node_stack) > 0:
         node = node_stack.pop()
-        if node.reg is not None:
-            if node.reg_conflicts_with_neighbors(node.reg):
-                return AllocationFailed(node=node,
-                                        live_intervals=live_intervals,
-                                        interference_graph=interference_graph)
+        if node.loc is not None:
+            if node.loc_conflicts_with_neighbors(node.loc):
+                raise AllocationFailedError(
+                    "IGNode is pre-allocated to a conflicting Loc",
+                    node=node, interference_graph=interference_graph)
         else:
             # pick the first non-conflicting register in node.reg_class, since
             # register classes are ordered from most preferred to least
             # preferred register.
-            for reg in node.reg_class:
-                if not node.reg_conflicts_with_neighbors(reg):
-                    node.reg = reg
+            for loc in node.loc_set:
+                if not node.loc_conflicts_with_neighbors(loc):
+                    node.loc = loc
                     break
-            if node.reg is None:
-                return AllocationFailed(node=node,
-                                        live_intervals=live_intervals,
-                                        interference_graph=interference_graph)
-
-        for ssa_val, offset in node.merged_reg_set.items():
-            retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
-
-    return retval
+            if node.loc is None:
+                raise AllocationFailedError(
+                    "failed to allocate Loc for IGNode",
+                    node=node, interference_graph=interference_graph)
 
+        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
+            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)
 
-def allocate_registers(ops):
-    # type: (list[Op]) -> dict[SSAVal, RegLoc]
-    retval = try_allocate_registers_without_spilling(ops)
-    if isinstance(retval, AllocationFailed):
-        # TODO: implement spilling
-        raise AllocationFailedError(
-            "spilling required but not yet implemented", retval)
     return retval