register_allocator2.py works!
[bigint-presentation-code.git] / src / bigint_presentation_code / compiler_ir2.py
index 9f911969c7f96235d472f7fb7b15cac83c9491a9..b256a07b7a58868bdcce9d203cdf96389fcd6dc6 100644 (file)
@@ -4,7 +4,7 @@ from enum import Enum, unique
 from functools import lru_cache, total_ordering
 from io import StringIO
 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
-                    Mapping, Sequence, TypeVar, overload)
+                    Mapping, Sequence, TypeVar, Union, overload)
 from weakref import WeakValueDictionary as _WeakVDict
 
 from cached_property import cached_property
@@ -59,10 +59,16 @@ class Fn:
         for op in self.ops:
             op.pre_ra_sim(state)
 
+    def gen_asm(self, state):
+        # type: (GenAsmState) -> None
+        for op in self.ops:
+            op.gen_asm(state)
+
     def pre_ra_insert_copies(self):
         # type: () -> None
         orig_ops = list(self.ops)
         copied_outputs = {}  # type: dict[SSAVal, SSAVal]
+        setvli_outputs = {}  # type: dict[SSAVal, Op]
         self.ops.clear()
         for op in orig_ops:
             for i in range(len(op.input_vals)):
@@ -84,12 +90,22 @@ class Fn:
                     op.input_vals[i] = mv.outputs[0]
                 elif inp.ty.base_ty is BaseTy.CA \
                         or inp.ty.base_ty is BaseTy.VL_MAXVL:
-                    # all copies would be no-ops, so we don't need to copy
+                    # all copies would be no-ops, so we don't need to copy,
+                    # though we do need to rematerialize SetVLI ops right
+                    # before the ops VL
+                    if inp in setvli_outputs:
+                        setvl = self.append_new_op(
+                            OpKind.SetVLI,
+                            immediates=setvli_outputs[inp].immediates,
+                            name=f"{op.name}.inp{i}.setvl")
+                        inp = setvl.outputs[0]
                     op.input_vals[i] = inp
                 else:
                     assert_never(inp.ty.base_ty)
             self.ops.append(op)
             for i, out in enumerate(op.outputs):
+                if op.kind is OpKind.SetVLI:
+                    setvli_outputs[out] = op
                 if out.ty.base_ty is BaseTy.I64:
                     maxvl = out.ty.reg_len
                     if out.ty.reg_len != 1:
@@ -263,7 +279,7 @@ class ProgramRange(Sequence[ProgramPoint]):
         return f"<range:{start}..{stop}>"
 
 
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
 @final
 class FnAnalysis:
     __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
@@ -335,6 +351,10 @@ class FnAnalysis:
         # type: () -> int
         return hash(self.fn)
 
+    def __repr__(self):
+        # type: () -> str
+        return "<FnAnalysis>"
+
 
 @unique
 @final
@@ -425,7 +445,7 @@ class LocKind(Enum):
     def loc_count(self):
         # type: () -> int
         if self is LocKind.StackI64:
-            return 1024
+            return 512
         if self is LocKind.GPR or self is LocKind.CA \
                 or self is LocKind.VL_MAXVL:
             return self.base_ty.max_reg_len
@@ -606,11 +626,11 @@ class Loc:
         # type: (Ty, int) -> Loc
         if subloc_ty.base_ty != self.kind.base_ty:
             raise ValueError("BaseTy mismatch")
-        start = self.start + offset
-        if offset < 0 or start + subloc_ty.reg_len > self.reg_len:
+        if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
             raise ValueError("invalid sub-Loc: offset and/or "
                              "subloc_ty.reg_len out of range")
-        return Loc(kind=self.kind, start=start, reg_len=subloc_ty.reg_len)
+        return Loc(kind=self.kind,
+                   start=self.start + offset, reg_len=subloc_ty.reg_len)
 
 
 SPECIAL_GPRS = (
@@ -726,13 +746,26 @@ class LocSet(AbstractSet[Loc]):
     def __len__(self):
         return self.__len
 
+    __HASHES = {}  # type: dict[tuple[Ty | None, FMap[LocKind, FBitSet]], int]
+
     @cached_property
     def __hash(self):
-        return super()._hash()
+        # cache hashes to avoid slow LocSet iteration
+        key = self.ty, self.starts
+        retval = self.__HASHES.get(key, None)
+        if retval is None:
+            self.__HASHES[key] = retval = super(LocSet, self)._hash()
+        return retval
 
     def __hash__(self):
         return self.__hash
 
+    def __eq__(self, __other):
+        # type: (LocSet | Any) -> bool
+        if isinstance(__other, LocSet):
+            return self.ty == __other.ty and self.starts == __other.starts
+        return super().__eq__(__other)
+
     @lru_cache(maxsize=None, typed=True)
     def max_conflicts_with(self, other):
         # type: (LocSet | Loc) -> int
@@ -861,7 +894,7 @@ class OperandDesc:
             raise ValueError("loc_set_before_spread must not be empty")
         self.loc_set_before_spread = loc_set_before_spread
         self.tied_input_index = tied_input_index
-        if self.tied_input_index is not None and self.spread_index is not None:
+        if self.tied_input_index is not None and spread_index is not None:
             raise ValueError("operand can't be both spread and tied")
         self.spread_index = spread_index
         self.write_stage = write_stage
@@ -1121,7 +1154,7 @@ class OpKind(Enum):
     SvAddE = GenericOpProperties(
         demo_asm="sv.adde *RT, *RA, *RB",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
-        outputs=[OD_EXTRA3_VGPR, OD_CA],
+        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
     )
     _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
     _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
@@ -1151,7 +1184,7 @@ class OpKind(Enum):
     SvSubFE = GenericOpProperties(
         demo_asm="sv.subfe *RT, *RA, *RB",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
-        outputs=[OD_EXTRA3_VGPR, OD_CA],
+        outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
     )
     _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
     _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
@@ -1258,26 +1291,50 @@ class OpKind(Enum):
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
     @staticmethod
-    def __veccopytoreg_gen_asm(op, state):
-        # type: (Op, GenAsmState) -> None
-        src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
-        dest_loc = state.loc(op.outputs[0], LocKind.GPR)
-        RT = state.vgpr(dest_loc)
+    def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
+        # type: (Loc, Loc, bool, GenAsmState) -> None
+        sv = "sv." if is_vec else ""
+        rev = ""
+        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
+            rev = "/mrr"
         if src_loc == dest_loc:
             return  # no-op
-        assert src_loc.kind in (LocKind.GPR, LocKind.StackI64), \
-            "checked by loc()"
+        if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+            raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
+        if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+            raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
         if src_loc.kind is LocKind.StackI64:
+            if dest_loc.kind is LocKind.StackI64:
+                raise ValueError(
+                    f"can't copy from stack to stack: {src_loc} {dest_loc}")
+            elif dest_loc.kind is not LocKind.GPR:
+                assert_never(dest_loc.kind)
             src = state.stack(src_loc)
-            state.writeln(f"sv.ld {RT}, {src}")
-            return
-        elif src_loc.kind is not LocKind.GPR:
+            dest = state.gpr(dest_loc, is_vec=is_vec)
+            state.writeln(f"{sv}ld {dest}, {src}")
+        elif dest_loc.kind is LocKind.StackI64:
+            if src_loc.kind is not LocKind.GPR:
+                assert_never(src_loc.kind)
+            src = state.gpr(src_loc, is_vec=is_vec)
+            dest = state.stack(dest_loc)
+            state.writeln(f"{sv}std {src}, {dest}")
+        elif src_loc.kind is LocKind.GPR:
+            if dest_loc.kind is not LocKind.GPR:
+                assert_never(dest_loc.kind)
+            src = state.gpr(src_loc, is_vec=is_vec)
+            dest = state.gpr(dest_loc, is_vec=is_vec)
+            state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
+        else:
             assert_never(src_loc.kind)
-        rev = ""
-        if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
-            rev = "/mrr"
-        src = state.vgpr(src_loc)
-        state.writeln(f"sv.or{rev} {RT}, {src}, {src}")
+
+    @staticmethod
+    def __veccopytoreg_gen_asm(op, state):
+        # type: (Op, GenAsmState) -> None
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(
+                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+            is_vec=True, state=state)
 
     VecCopyToReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
@@ -1296,11 +1353,14 @@ class OpKind(Enum):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __veccopyfromreg_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+            dest_loc=state.loc(
+                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+            is_vec=True, state=state)
     VecCopyFromReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[OD_EXTRA3_VGPR, OD_VL],
@@ -1319,11 +1379,14 @@ class OpKind(Enum):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __copytoreg_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(
+                op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+            is_vec=False, state=state)
     CopyToReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1346,11 +1409,14 @@ class OpKind(Enum):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __copyfromreg_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+            dest_loc=state.loc(
+                op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+            is_vec=False, state=state)
     CopyFromReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1374,11 +1440,13 @@ class OpKind(Enum):
         state.ssa_vals[op.outputs[0]] = tuple(
             state.ssa_vals[i][0] for i in op.input_vals[:-1])
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __concat_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
+            dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+            is_vec=True, state=state)
     Concat = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1398,11 +1466,13 @@ class OpKind(Enum):
         for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
             state.ssa_vals[op.outputs[idx]] = inp,
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __spread_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        OpKind.__copy_to_from_reg_gen_asm(
+            src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+            dest_loc=state.loc(op.outputs, LocKind.GPR),
+            is_vec=True, state=state)
     Spread = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[OD_EXTRA3_VGPR, OD_VL],
@@ -1429,11 +1499,13 @@ class OpKind(Enum):
             RT.append(v & GPR_VALUE_MASK)
         state.ssa_vals[op.outputs[0]] = tuple(RT)
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __svld_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RA = state.sgpr(op.input_vals[0])
+        RT = state.vgpr(op.outputs[0])
+        imm = op.immediates[0]
+        state.writeln(f"sv.ld {RT}, {imm}({RA})")
     SvLd = GenericOpProperties(
         demo_asm="sv.ld *RT, imm(RA)",
         inputs=[OD_EXTRA3_SGPR, OD_VL],
@@ -1451,11 +1523,13 @@ class OpKind(Enum):
         v = state.load(addr)
         state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __ld_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RA = state.sgpr(op.input_vals[0])
+        RT = state.sgpr(op.outputs[0])
+        imm = op.immediates[0]
+        state.writeln(f"ld {RT}, {imm}({RA})")
     Ld = GenericOpProperties(
         demo_asm="ld RT, imm(RA)",
         inputs=[OD_BASE_SGPR],
@@ -1475,11 +1549,13 @@ class OpKind(Enum):
         for i in range(VL):
             state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __svstd_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RS = state.vgpr(op.input_vals[0])
+        RA = state.sgpr(op.input_vals[1])
+        imm = op.immediates[0]
+        state.writeln(f"sv.std {RS}, {imm}({RA})")
     SvStd = GenericOpProperties(
         demo_asm="sv.std *RS, imm(RA)",
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
@@ -1498,13 +1574,15 @@ class OpKind(Enum):
         addr = RA + op.immediates[0]
         state.store(addr, value=RS)
 
-    # FIXME: change to correct __*_gen_asm function
     @staticmethod
-    def __clearca_gen_asm(op, state):
+    def __std_gen_asm(op, state):
         # type: (Op, GenAsmState) -> None
-        state.writeln("addic 0, 0, 0")
+        RS = state.sgpr(op.input_vals[0])
+        RA = state.sgpr(op.input_vals[1])
+        imm = op.immediates[0]
+        state.writeln(f"std {RS}, {imm}({RA})")
     Std = GenericOpProperties(
-        demo_asm="std RT, imm(RA)",
+        demo_asm="std RS, imm(RA)",
         inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
         outputs=[],
         immediates=[IMM_S16],
@@ -1877,6 +1955,15 @@ class Op:
                     f"expected {out.ty.reg_len} found "
                     f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
 
+    def gen_asm(self, state):
+        # type: (GenAsmState) -> None
+        all_loc_kinds = tuple(LocKind)
+        for inp in self.input_vals:
+            state.loc(inp, expected_kinds=all_loc_kinds)
+        for out in self.outputs:
+            state.loc(out, expected_kinds=all_loc_kinds)
+        self.kind.gen_asm(self, state)
+
 
 GPR_SIZE_IN_BYTES = 8
 BITS_IN_BYTE = 8
@@ -1994,7 +2081,7 @@ class PreRASimState:
 class GenAsmState:
     __slots__ = "allocated_locs", "output"
 
-    def __init__(self, allocated_locs, output):
+    def __init__(self, allocated_locs, output=None):
         # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
         super().__init__()
         self.allocated_locs = FMap(allocated_locs)
@@ -2006,38 +2093,49 @@ class GenAsmState:
             output = []
         self.output = output
 
-    def loc(self, ssa_val_or_loc, expected_kinds):
-        # type: (SSAVal | Loc, LocKind | tuple[LocKind, ...]) -> Loc
-        if isinstance(ssa_val_or_loc, SSAVal):
-            retval = self.allocated_locs[ssa_val_or_loc]
-        else:
-            retval = ssa_val_or_loc
+    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
+
+    def loc(self, ssa_val_or_locs, expected_kinds):
+        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
+        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
+            ssa_val_or_locs = [ssa_val_or_locs]
+        locs = []  # type: list[Loc]
+        for i in ssa_val_or_locs:
+            if isinstance(i, SSAVal):
+                locs.append(self.allocated_locs[i])
+            else:
+                locs.append(i)
+        if len(locs) == 0:
+            raise ValueError("invalid Loc sequence: must not be empty")
+        retval = locs[0].try_concat(*locs[1:])
+        if retval is None:
+            raise ValueError("invalid Loc sequence: try_concat failed")
         if isinstance(expected_kinds, LocKind):
             expected_kinds = expected_kinds,
         if retval.kind not in expected_kinds:
             if len(expected_kinds) == 1:
                 expected_kinds = expected_kinds[0]
-            raise ValueError(f"LocKind mismatch: {ssa_val_or_loc}: found "
+            raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
                              f"{retval.kind} expected {expected_kinds}")
         return retval
 
-    def gpr(self, ssa_val_or_loc, is_vec):
-        # type: (SSAVal | Loc, bool) -> str
-        loc = self.loc(ssa_val_or_loc, LocKind.GPR)
+    def gpr(self, ssa_val_or_locs, is_vec):
+        # type: (__SSA_VAL_OR_LOCS, bool) -> str
+        loc = self.loc(ssa_val_or_locs, LocKind.GPR)
         vec_str = "*" if is_vec else ""
         return vec_str + str(loc.start)
 
-    def sgpr(self, ssa_val_or_loc):
-        # type: (SSAVal | Loc) -> str
-        return self.gpr(ssa_val_or_loc, is_vec=False)
+    def sgpr(self, ssa_val_or_locs):
+        # type: (__SSA_VAL_OR_LOCS) -> str
+        return self.gpr(ssa_val_or_locs, is_vec=False)
 
-    def vgpr(self, ssa_val_or_loc):
-        # type: (SSAVal | Loc) -> str
-        return self.gpr(ssa_val_or_loc, is_vec=True)
+    def vgpr(self, ssa_val_or_locs):
+        # type: (__SSA_VAL_OR_LOCS) -> str
+        return self.gpr(ssa_val_or_locs, is_vec=True)
 
-    def stack(self, ssa_val_or_loc):
-        # type: (SSAVal | Loc) -> str
-        loc = self.loc(ssa_val_or_loc, LocKind.StackI64)
+    def stack(self, ssa_val_or_locs):
+        # type: (__SSA_VAL_OR_LOCS) -> str
+        loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
         return f"{loc.start}(1)"
 
     def writeln(self, *line_segments):