From 61e2f7323558017348df62687c85e6c2746f785a Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 2 Nov 2022 00:32:00 -0700 Subject: [PATCH] working on code --- .../_tests/test_compiler_ir2.py | 175 ++++++++---- src/bigint_presentation_code/compiler_ir2.py | 217 +++++++++------ .../register_allocator2.py | 255 +++++++++++++----- 3 files changed, 447 insertions(+), 200 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir2.py b/src/bigint_presentation_code/_tests/test_compiler_ir2.py index 4aa8e3e..40f02fa 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir2.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir2.py @@ -17,54 +17,66 @@ class TestCompilerIR(unittest.TestCase): op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl") vl = op1.outputs[0] op2 = fn.append_new_op( - OpKind.SvLd, inputs=[arg, vl], immediates=[0], maxvl=MAXVL, + OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL, name="ld") a = op2.outputs[0] - op3 = fn.append_new_op( - OpKind.SvLI, inputs=[vl], immediates=[0], maxvl=MAXVL, name="li") + op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], + maxvl=MAXVL, name="li") b = op3.outputs[0] op4 = fn.append_new_op(OpKind.SetCA, name="ca") ca = op4.outputs[0] op5 = fn.append_new_op( - OpKind.SvAddE, inputs=[a, b, ca, vl], maxvl=MAXVL, name="add") + OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add") s = op5.outputs[0] - fn.append_new_op( - OpKind.SvStd, inputs=[s, arg, vl], immediates=[0], maxvl=MAXVL, - name="st") + fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], + immediates=[0], maxvl=MAXVL, name="st") return fn, arg def test_repr(self): fn, _arg = self.make_add_fn() self.assertEqual([repr(i) for i in fn.ops], [ "Op(kind=OpKind.FuncArgR3, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[], " "outputs=(>,), name='arg')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " "outputs=(>,), name='vl')", "Op(kind=OpKind.SvLd, " - "inputs=[>, >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[0], " "outputs=(>,), name='ld')", "Op(kind=OpKind.SvLI, " - "inputs=[>], " + "input_vals=[>], " + "input_uses=(>,), " "immediates=[0], " "outputs=(>,), name='li')", "Op(kind=OpKind.SetCA, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[], " "outputs=(>,), name='ca')", "Op(kind=OpKind.SvAddE, " - "inputs=[>, >, " - ">, >], " + "input_vals=[>, " + ">, >, " + ">], " + "input_uses=(>, " + ">, >, " + ">), " "immediates=[], " "outputs=(>, >), " "name='add')", "Op(kind=OpKind.SvStd, " - "inputs=[>, >, " + "input_vals=[>, >, " ">], " + "input_uses=(>, " + ">, >), " "immediates=[0], " "outputs=(), name='st')", ]) @@ -150,95 +162,148 @@ class TestCompilerIR(unittest.TestCase): fn.pre_ra_insert_copies() self.assertEqual([repr(i) for i in fn.ops], [ "Op(kind=OpKind.FuncArgR3, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[], " "outputs=(>,), name='arg')", "Op(kind=OpKind.CopyFromReg, " - "inputs=[>], " + "input_vals=[>], " + "input_uses=(>,), " "immediates=[], " - "outputs=(<2.outputs[0]: >,), name='2')", + "outputs=(>,), " + "name='arg.out0.copy')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " "outputs=(>,), name='vl')", "Op(kind=OpKind.CopyToReg, " - "inputs=[<2.outputs[0]: >], " + "input_vals=[>], " + "input_uses=(>,), " "immediates=[], " - "outputs=(<3.outputs[0]: >,), name='3')", + "outputs=(>,), name='ld.inp0.copy')", "Op(kind=OpKind.SvLd, " - "inputs=[<3.outputs[0]: >, >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[0], " "outputs=(>,), name='ld')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " - "outputs=(<4.outputs[0]: >,), name='4')", + "outputs=(>,), " + "name='ld.out0.setvl')", "Op(kind=OpKind.VecCopyFromReg, " - "inputs=[>, <4.outputs[0]: >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[], " - "outputs=(<5.outputs[0]: >,), name='5')", + "outputs=(>,), " + "name='ld.out0.copy')", "Op(kind=OpKind.SvLI, " - "inputs=[>], " + "input_vals=[>], " + "input_uses=(>,), " "immediates=[0], " "outputs=(>,), name='li')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " - "outputs=(<6.outputs[0]: >,), name='6')", + "outputs=(>,), " + "name='li.out0.setvl')", "Op(kind=OpKind.VecCopyFromReg, " - "inputs=[>, <6.outputs[0]: >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[], " - "outputs=(<7.outputs[0]: >,), name='7')", + "outputs=(>,), " + "name='li.out0.copy')", "Op(kind=OpKind.SetCA, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[], " "outputs=(>,), name='ca')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " - "outputs=(<8.outputs[0]: >,), name='8')", + "outputs=(>,), " + "name='add.inp0.setvl')", "Op(kind=OpKind.VecCopyToReg, " - "inputs=[<5.outputs[0]: >, <8.outputs[0]: >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[], " - "outputs=(<9.outputs[0]: >,), name='9')", + "outputs=(>,), " + "name='add.inp0.copy')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " - "outputs=(<10.outputs[0]: >,), name='10')", + "outputs=(>,), " + "name='add.inp1.setvl')", "Op(kind=OpKind.VecCopyToReg, " - "inputs=[<7.outputs[0]: >, <10.outputs[0]: >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[], " - "outputs=(<11.outputs[0]: >,), name='11')", + "outputs=(>,), " + "name='add.inp1.copy')", "Op(kind=OpKind.SvAddE, " - "inputs=[<9.outputs[0]: >, <11.outputs[0]: >, " - ">, >], " + "input_vals=[>, " + ">, >, " + ">], " + "input_uses=(>, " + ">, >, " + ">), " "immediates=[], " "outputs=(>, >), " "name='add')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " - "outputs=(<12.outputs[0]: >,), name='12')", + "outputs=(>,), " + "name='add.out0.setvl')", "Op(kind=OpKind.VecCopyFromReg, " - "inputs=[>, " - "<12.outputs[0]: >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[], " - "outputs=(<13.outputs[0]: >,), name='13')", + "outputs=(>,), " + "name='add.out0.copy')", "Op(kind=OpKind.SetVLI, " - "inputs=[], " + "input_vals=[], " + "input_uses=(), " "immediates=[32], " - "outputs=(<14.outputs[0]: >,), name='14')", + "outputs=(>,), " + "name='st.inp0.setvl')", "Op(kind=OpKind.VecCopyToReg, " - "inputs=[<13.outputs[0]: >, <14.outputs[0]: >], " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " "immediates=[], " - "outputs=(<15.outputs[0]: >,), name='15')", + "outputs=(>,), " + "name='st.inp0.copy')", "Op(kind=OpKind.CopyToReg, " - "inputs=[<2.outputs[0]: >], " + "input_vals=[>], " + "input_uses=(>,), " "immediates=[], " - "outputs=(<16.outputs[0]: >,), name='16')", + "outputs=(>,), " + "name='st.inp1.copy')", "Op(kind=OpKind.SvStd, " - "inputs=[<15.outputs[0]: >, <16.outputs[0]: >, " - ">], " + "input_vals=[>, " + ">, >], " + "input_uses=(>, " + ">, >), " "immediates=[0], " "outputs=(), name='st')", ]) diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index 8e36509..03b623d 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -44,10 +44,11 @@ class Fn: raise ValueError("can't add Op to wrong Fn") self.ops.append(op) - def append_new_op(self, kind, inputs=(), immediates=(), name="", maxvl=1): + def append_new_op(self, kind, input_vals=(), immediates=(), name="", + maxvl=1): # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl), - inputs=inputs, immediates=immediates, name=name) + input_vals=input_vals, immediates=immediates, name=name) self.append_op(retval) return retval @@ -62,38 +63,45 @@ class Fn: copied_outputs = {} # type: dict[SSAVal, SSAVal] self.ops.clear() for op in orig_ops: - for i in range(len(op.inputs)): - inp = copied_outputs[op.inputs[i]] + for i in range(len(op.input_vals)): + inp = copied_outputs[op.input_vals[i]] if inp.ty.base_ty is BaseTy.I64: maxvl = inp.ty.reg_len if inp.ty.reg_len != 1: - setvl = self.append_new_op(OpKind.SetVLI, - immediates=[maxvl]) + setvl = self.append_new_op( + OpKind.SetVLI, immediates=[maxvl], + name=f"{op.name}.inp{i}.setvl") vl = setvl.outputs[0] - mv = self.append_new_op(OpKind.VecCopyToReg, - inputs=[inp, vl], maxvl=maxvl) + mv = self.append_new_op( + OpKind.VecCopyToReg, input_vals=[inp, vl], + maxvl=maxvl, name=f"{op.name}.inp{i}.copy") else: - mv = self.append_new_op(OpKind.CopyToReg, inputs=[inp]) - op.inputs[i] = mv.outputs[0] + mv = self.append_new_op( + OpKind.CopyToReg, input_vals=[inp], + name=f"{op.name}.inp{i}.copy") + op.input_vals[i] = mv.outputs[0] elif inp.ty.base_ty is BaseTy.CA \ or inp.ty.base_ty is BaseTy.VL_MAXVL: # all copies would be no-ops, so we don't need to copy - op.inputs[i] = inp + op.input_vals[i] = inp else: assert_never(inp.ty.base_ty) self.ops.append(op) - for out in op.outputs: + for i, out in enumerate(op.outputs): if out.ty.base_ty is BaseTy.I64: maxvl = out.ty.reg_len if out.ty.reg_len != 1: - setvl = self.append_new_op(OpKind.SetVLI, - immediates=[maxvl]) + setvl = self.append_new_op( + OpKind.SetVLI, immediates=[maxvl], + name=f"{op.name}.out{i}.setvl") vl = setvl.outputs[0] - mv = self.append_new_op(OpKind.VecCopyFromReg, - inputs=[out, vl], maxvl=maxvl) + mv = self.append_new_op( + OpKind.VecCopyFromReg, input_vals=[out, vl], + maxvl=maxvl, name=f"{op.name}.out{i}.copy") else: - mv = self.append_new_op(OpKind.CopyFromReg, - inputs=[out]) + mv = self.append_new_op( + OpKind.CopyFromReg, input_vals=[out], + name=f"{op.name}.out{i}.copy") copied_outputs[out] = mv.outputs[0] elif out.ty.base_ty is BaseTy.CA \ or out.ty.base_ty is BaseTy.VL_MAXVL: @@ -113,7 +121,7 @@ class FnWithUses: self.fn = fn retval = {} # type: dict[SSAVal, OSet[SSAUse]] for op in fn.ops: - for idx, inp in enumerate(op.inputs): + for idx, inp in enumerate(op.input_vals): retval[inp].add(SSAUse(op, idx)) for out in op.outputs: retval[out] = OSet() @@ -581,6 +589,7 @@ class GenericOperandDesc: def instantiate(self, maxvl): # type: (int) -> Iterable[OperandDesc] + # assumes all spread operands have ty.reg_len = 1 rep_count = 1 if self.spread: rep_count = maxvl @@ -636,9 +645,23 @@ class OperandDesc: def ty(self): """ Ty after any spread is applied """ if self.spread_index is not None: + # assumes all spread operands have ty.reg_len = 1 return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1) return self.ty_before_spread + @property + def reg_offset_in_unspread(self): + """ the number of reg-sized slots in the unspread Loc before self's Loc + + e.g. if the unspread Loc containing self is: + `Loc(kind=LocKind.GPR, start=8, reg_len=4)` + and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)` + then reg_offset_into_unspread == 2 == 10 - 8 + """ + if self.spread_index is None: + return 0 + return self.spread_index * self.ty.reg_len + OD_BASE_SGPR = GenericOperandDesc( ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), @@ -816,10 +839,10 @@ class OpKind(Enum): @staticmethod def __svadde_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - RA = state.ssa_vals[op.inputs[0]] - RB = state.ssa_vals[op.inputs[1]] - carry, = state.ssa_vals[op.inputs[2]] - VL, = state.ssa_vals[op.inputs[3]] + RA = state.ssa_vals[op.input_vals[0]] + RB = state.ssa_vals[op.input_vals[1]] + carry, = state.ssa_vals[op.input_vals[2]] + VL, = state.ssa_vals[op.input_vals[3]] RT = [] # type: list[int] for i in range(VL): v = RA[i] + RB[i] + carry @@ -837,10 +860,10 @@ class OpKind(Enum): @staticmethod def __svsubfe_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - RA = state.ssa_vals[op.inputs[0]] - RB = state.ssa_vals[op.inputs[1]] - carry, = state.ssa_vals[op.inputs[2]] - VL, = state.ssa_vals[op.inputs[3]] + RA = state.ssa_vals[op.input_vals[0]] + RB = state.ssa_vals[op.input_vals[1]] + carry, = state.ssa_vals[op.input_vals[2]] + VL, = state.ssa_vals[op.input_vals[3]] RT = [] # type: list[int] for i in range(VL): v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry @@ -858,10 +881,10 @@ class OpKind(Enum): @staticmethod def __svmaddedu_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - RA = state.ssa_vals[op.inputs[0]] - RB, = state.ssa_vals[op.inputs[1]] - carry, = state.ssa_vals[op.inputs[2]] - VL, = state.ssa_vals[op.inputs[3]] + RA = state.ssa_vals[op.input_vals[0]] + RB, = state.ssa_vals[op.input_vals[1]] + carry, = state.ssa_vals[op.input_vals[2]] + VL, = state.ssa_vals[op.input_vals[3]] RT = [] # type: list[int] for i in range(VL): v = RA[i] * RB + carry @@ -892,7 +915,7 @@ class OpKind(Enum): @staticmethod def __svli_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - VL, = state.ssa_vals[op.inputs[0]] + VL, = state.ssa_vals[op.input_vals[0]] imm = op.immediates[0] & GPR_VALUE_MASK state.ssa_vals[op.outputs[0]] = (imm,) * VL SvLI = GenericOpProperties( @@ -921,7 +944,7 @@ class OpKind(Enum): @staticmethod def __veccopytoreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]] + state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] VecCopyToReg = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[GenericOperandDesc( @@ -936,7 +959,7 @@ class OpKind(Enum): @staticmethod def __veccopyfromreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]] + state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] VecCopyFromReg = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[OD_EXTRA3_VGPR, OD_VL], @@ -951,7 +974,7 @@ class OpKind(Enum): @staticmethod def __copytoreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]] + state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] CopyToReg = GenericOpProperties( demo_asm="mv dest, src", inputs=[GenericOperandDesc( @@ -970,7 +993,7 @@ class OpKind(Enum): @staticmethod def __copyfromreg_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]] + state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]] CopyFromReg = GenericOpProperties( demo_asm="mv dest, src", inputs=[GenericOperandDesc( @@ -990,7 +1013,7 @@ class OpKind(Enum): def __concat_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None state.ssa_vals[op.outputs[0]] = tuple( - state.ssa_vals[i][0] for i in op.inputs[:-1]) + state.ssa_vals[i][0] for i in op.input_vals[:-1]) Concat = GenericOpProperties( demo_asm="sv.mv dest, src", inputs=[GenericOperandDesc( @@ -1006,7 +1029,7 @@ class OpKind(Enum): @staticmethod def __spread_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - for idx, inp in enumerate(state.ssa_vals[op.inputs[0]]): + for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]): state.ssa_vals[op.outputs[idx]] = inp, Spread = GenericOpProperties( demo_asm="sv.mv dest, src", @@ -1023,8 +1046,8 @@ class OpKind(Enum): @staticmethod def __svld_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - RA, = state.ssa_vals[op.inputs[0]] - VL, = state.ssa_vals[op.inputs[1]] + RA, = state.ssa_vals[op.input_vals[0]] + VL, = state.ssa_vals[op.input_vals[1]] addr = RA + op.immediates[0] RT = [] # type: list[int] for i in range(VL): @@ -1042,7 +1065,7 @@ class OpKind(Enum): @staticmethod def __ld_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - RA, = state.ssa_vals[op.inputs[0]] + RA, = state.ssa_vals[op.input_vals[0]] addr = RA + op.immediates[0] v = state.load(addr) state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK, @@ -1057,9 +1080,9 @@ class OpKind(Enum): @staticmethod def __svstd_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - RS = state.ssa_vals[op.inputs[0]] - RA, = state.ssa_vals[op.inputs[1]] - VL, = state.ssa_vals[op.inputs[2]] + RS = state.ssa_vals[op.input_vals[0]] + RA, = state.ssa_vals[op.input_vals[1]] + VL, = state.ssa_vals[op.input_vals[2]] addr = RA + op.immediates[0] for i in range(VL): state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i]) @@ -1075,8 +1098,8 @@ class OpKind(Enum): @staticmethod def __std_pre_ra_sim(op, state): # type: (Op, PreRASimState) -> None - RS, = state.ssa_vals[op.inputs[0]] - RA, = state.ssa_vals[op.inputs[1]] + RS, = state.ssa_vals[op.input_vals[0]] + RA, = state.ssa_vals[op.input_vals[1]] addr = RA + op.immediates[0] state.store(addr, value=RS) Std = GenericOpProperties( @@ -1103,11 +1126,14 @@ class OpKind(Enum): @plain_data(frozen=True, unsafe_hash=True, repr=False) class SSAValOrUse(metaclass=ABCMeta): - __slots__ = "op", + __slots__ = "op", "operand_idx" - def __init__(self, op): - # type: (Op) -> None + def __init__(self, op, operand_idx): + # type: (Op, int) -> None self.op = op + if operand_idx < 0 or operand_idx >= len(self.descriptor_array): + raise ValueError("invalid operand_idx") + self.operand_idx = operand_idx @abstractmethod def __repr__(self): @@ -1116,9 +1142,14 @@ class SSAValOrUse(metaclass=ABCMeta): @property @abstractmethod + def descriptor_array(self): + # type: () -> tuple[OperandDesc, ...] + ... + + @property def defining_descriptor(self): # type: () -> OperandDesc - ... + return self.descriptor_array[self.operand_idx] @cached_property def ty(self): @@ -1135,22 +1166,36 @@ class SSAValOrUse(metaclass=ABCMeta): # type: () -> BaseTy return self.ty_before_spread.base_ty + @property + def reg_offset_in_unspread(self): + """ the number of reg-sized slots in the unspread Loc before self's Loc + + e.g. if the unspread Loc containing self is: + `Loc(kind=LocKind.GPR, start=8, reg_len=4)` + and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)` + then reg_offset_into_unspread == 2 == 10 - 8 + """ + return self.defining_descriptor.reg_offset_in_unspread + + @property + def unspread_start_idx(self): + # type: () -> int + return self.operand_idx - (self.defining_descriptor.spread_index or 0) + + @property + def unspread_start(self): + # type: () -> Self + return self.__class__(op=self.op, operand_idx=self.unspread_start_idx) + @plain_data(frozen=True, unsafe_hash=True, repr=False) @final class SSAVal(SSAValOrUse): - __slots__ = "output_idx", - - def __init__(self, op, output_idx): - # type: (Op, int) -> None - super().__init__(op) - if output_idx < 0 or output_idx >= len(op.properties.outputs): - raise ValueError("invalid output_idx") - self.output_idx = output_idx + __slots__ = () def __repr__(self): # type: () -> str - return f"<{self.op.name}.outputs[{self.output_idx}]: {self.ty}>" + return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>" @cached_property def def_loc_set_before_spread(self): @@ -1158,22 +1203,23 @@ class SSAVal(SSAValOrUse): return self.defining_descriptor.loc_set_before_spread @cached_property - def defining_descriptor(self): - # type: () -> OperandDesc - return self.op.properties.outputs[self.output_idx] + def descriptor_array(self): + # type: () -> tuple[OperandDesc, ...] + return self.op.properties.outputs + + @cached_property + def tied_input(self): + # type: () -> None | SSAUse + if self.defining_descriptor.tied_input_index is None: + return None + return SSAUse(op=self.op, + operand_idx=self.defining_descriptor.tied_input_index) @plain_data(frozen=True, unsafe_hash=True, repr=False) @final class SSAUse(SSAValOrUse): - __slots__ = "input_idx", - - def __init__(self, op, input_idx): - # type: (Op, int) -> None - super().__init__(op) - self.input_idx = input_idx - if input_idx < 0 or input_idx >= len(op.inputs): - raise ValueError("input_idx out of range") + __slots__ = () @cached_property def use_loc_set_before_spread(self): @@ -1181,13 +1227,23 @@ class SSAUse(SSAValOrUse): return self.defining_descriptor.loc_set_before_spread @cached_property - def defining_descriptor(self): - # type: () -> OperandDesc - return self.op.properties.inputs[self.input_idx] + def descriptor_array(self): + # type: () -> tuple[OperandDesc, ...] + return self.op.properties.inputs def __repr__(self): # type: () -> str - return f"<{self.op.name}.inputs[{self.input_idx}]: {self.ty}>" + return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>" + + @property + def ssa_val(self): + # type: () -> SSAVal + return self.op.input_vals[self.operand_idx] + + @ssa_val.setter + def ssa_val(self, ssa_val): + # type: (SSAVal) -> None + self.op.input_vals[self.operand_idx] = ssa_val _T = TypeVar("_T") @@ -1282,7 +1338,7 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]): @final -class OpInputs(OpInputSeq[SSAVal, OperandDesc]): +class OpInputVals(OpInputSeq[SSAVal, OperandDesc]): def _get_descriptors(self): # type: () -> tuple[OperandDesc, ...] return self.op.properties.inputs @@ -1329,13 +1385,16 @@ class OpImmediates(OpInputSeq[int, range]): @plain_data(frozen=True, eq=False, repr=False) @final class Op: - __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name" + __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates", + "outputs", "name") - def __init__(self, fn, properties, inputs, immediates, name=""): + def __init__(self, fn, properties, input_vals, immediates, name=""): # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None self.fn = fn self.properties = properties - self.inputs = OpInputs(inputs, op=self) + self.input_vals = OpInputVals(input_vals, op=self) + inputs_len = len(self.properties.inputs) + self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len)) self.immediates = OpImmediates(immediates, op=self) outputs_len = len(self.properties.outputs) self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len)) @@ -1375,7 +1434,7 @@ class Op: def pre_ra_sim(self, state): # type: (PreRASimState) -> None - for inp in self.inputs: + for inp in self.input_vals: if inp not in state.ssa_vals: raise ValueError(f"SSAVal {inp} not yet assigned when " f"running {self}") diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py index 962a021..68443d9 100644 --- a/src/bigint_presentation_code/register_allocator2.py +++ b/src/bigint_presentation_code/register_allocator2.py @@ -6,17 +6,15 @@ this uses an algorithm based on: """ from itertools import combinations -from functools import reduce -from typing import Generic, Iterable, Mapping -from cached_property import cached_property -import operator +from typing import Any, Generic, Iterable, Iterator, Mapping, MutableSet +from cached_property import cached_property from nmutil.plain_data import plain_data -from bigint_presentation_code.compiler_ir2 import ( - Op, LocSet, Ty, SSAVal, BaseTy, Loc, FnWithUses) -from bigint_presentation_code.type_util import final, Self -from bigint_presentation_code.util import OFSet, OSet, FMap +from bigint_presentation_code.compiler_ir2 import (BaseTy, FnWithUses, Loc, + LocSet, Op, SSAVal, Ty) +from bigint_presentation_code.type_util import final +from bigint_presentation_code.util import FMap, OFSet, OSet @plain_data(unsafe_hash=True, order=True, frozen=True) @@ -58,7 +56,7 @@ class BadMergedSSAVal(ValueError): pass -@plain_data(frozen=True, unsafe_hash=True) +@plain_data(frozen=True) @final class MergedSSAVal: """a set of `SSAVal`s along with their offsets, all register allocated as @@ -88,7 +86,7 @@ class MergedSSAVal: * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)` * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)` """ - __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set" + __slots__ = "fn_with_uses", "ssa_val_offsets", "first_ssa_val", "loc_set" def __init__(self, fn_with_uses, ssa_val_offsets): # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None @@ -96,13 +94,13 @@ class MergedSSAVal: if isinstance(ssa_val_offsets, SSAVal): ssa_val_offsets = {ssa_val_offsets: 0} self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int] - base_ty = None - for ssa_val in self.ssa_val_offsets.keys(): - base_ty = ssa_val.base_ty + first_ssa_val = None + for ssa_val in self.ssa_vals: + first_ssa_val = ssa_val break - if base_ty is None: + if first_ssa_val is None: raise BadMergedSSAVal("MergedSSAVal can't be empty") - self.base_ty = base_ty # type: BaseTy + self.first_ssa_val = first_ssa_val # type: SSAVal # self.ty checks for mismatched base_ty reg_len = self.ty.reg_len loc_set = None # type: None | LocSet @@ -144,11 +142,30 @@ class MergedSSAVal: assert loc_set.ty == self.ty, "logic error somewhere" self.loc_set = loc_set # type: LocSet + @cached_property + def __hash(self): + # type: () -> int + return hash((self.fn_with_uses, self.ssa_val_offsets)) + + def __hash__(self): + # type: () -> int + return self.__hash + @cached_property def offset(self): # type: () -> int return min(self.ssa_val_offsets_before_spread.values()) + @property + def base_ty(self): + # type: () -> BaseTy + return self.first_ssa_val.base_ty + + @cached_property + def ssa_vals(self): + # type: () -> OFSet[SSAVal] + return OFSet(self.ssa_val_offsets.keys()) + @cached_property def ty(self): # type: () -> Ty @@ -166,15 +183,8 @@ class MergedSSAVal: # type: () -> FMap[SSAVal, int] retval = {} # type: dict[SSAVal, int] for ssa_val, offset in self.ssa_val_offsets.items(): - offset_before_spread = offset - spread_index = ssa_val.defining_descriptor.spread_index - if spread_index is not None: - assert ssa_val.ty.reg_len == 1, ( - "this function assumes spreading always converts a vector " - "to a contiguous sequence of scalars, if that's changed " - "in the future, then this function needs to be adjusted") - offset_before_spread -= spread_index - retval[ssa_val] = offset_before_spread + retval[ssa_val] = ( + offset - ssa_val.defining_descriptor.reg_offset_in_unspread) return FMap(retval) def offset_by(self, amount): @@ -186,65 +196,178 @@ class MergedSSAVal: # type: () -> MergedSSAVal return self.offset_by(-self.offset) - def with_offset_to_match(self, target): - # type: (MergedSSAVal) -> MergedSSAVal + def with_offset_to_match(self, target, additional_offset=0): + # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal + if isinstance(target, MergedSSAVal): + ssa_val_offsets = target.ssa_val_offsets + else: + ssa_val_offsets = {target: 0} for ssa_val, offset in self.ssa_val_offsets.items(): - if ssa_val in target.ssa_val_offsets: - return self.offset_by(target.ssa_val_offsets[ssa_val] - offset) + if ssa_val in ssa_val_offsets: + return self.offset_by( + ssa_val_offsets[ssa_val] + additional_offset - offset) raise ValueError("can't change offset to match unrelated MergedSSAVal") + def merged(self, *others): + # type: (*MergedSSAVal) -> MergedSSAVal + retval = dict(self.ssa_val_offsets) + for other in others: + if other.fn_with_uses != self.fn_with_uses: + raise ValueError("fn_with_uses mismatch") + for ssa_val, offset in other.ssa_val_offsets.items(): + if ssa_val in retval and retval[ssa_val] != offset: + raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: " + f"{retval[ssa_val]} != {offset}") + retval[ssa_val] = offset + return MergedSSAVal(fn_with_uses=self.fn_with_uses, + ssa_val_offsets=retval) + + +@final +class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]): + def __init__(self): + # type: (...) -> None + self.__merge_map = {} # type: dict[SSAVal, MergedSSAVal] + self.__values_set = MergedSSAValsSet( + _private_merge_map=self.__merge_map, + _private_values_set=OSet()) + + def __getitem__(self, __key): + # type: (SSAVal) -> MergedSSAVal + return self.__merge_map[__key] + + def __iter__(self): + # type: () -> Iterator[SSAVal] + return iter(self.__merge_map) + + def __len__(self): + # type: () -> int + return len(self.__merge_map) + + @property + def values_set(self): + # type: () -> MergedSSAValsSet + return self.__values_set + + def __repr__(self): + # type: () -> str + s = ",\n".join(repr(v) for v in self.__values_set) + return f"MergedSSAValsMap({{{s}}})" + @final -class MergedSSAVals(OFSet[MergedSSAVal]): - def __init__(self, merged_ssa_vals=()): - # type: (Iterable[MergedSSAVal]) -> None - super().__init__(merged_ssa_vals) - merge_map = {} # type: dict[SSAVal, MergedSSAVal] - for merged_ssa_val in self: - for ssa_val in merged_ssa_val.ssa_val_offsets.keys(): - if ssa_val in merge_map: +class MergedSSAValsSet(MutableSet[MergedSSAVal]): + def __init__(self, *, + _private_merge_map, # type: dict[SSAVal, MergedSSAVal] + _private_values_set, # type: OSet[MergedSSAVal] + ): + # type: (...) -> None + self.__merge_map = _private_merge_map + self.__values_set = _private_values_set + + @classmethod + def _from_iterable(cls, it): + # type: (Iterable[MergedSSAVal]) -> OSet[MergedSSAVal] + return OSet(it) + + def __contains__(self, value): + # type: (MergedSSAVal | Any) -> bool + return value in self.__values_set + + def __iter__(self): + # type: () -> Iterator[MergedSSAVal] + return iter(self.__values_set) + + def __len__(self): + # type: () -> int + return len(self.__values_set) + + def add(self, value): + # type: (MergedSSAVal) -> None + if value in self: + return + added = 0 # type: int | None + try: + for ssa_val in value.ssa_vals: + if ssa_val in self.__merge_map: raise ValueError( f"overlapping `MergedSSAVal`s: {ssa_val} is in both " - f"{merged_ssa_val} and {merge_map[ssa_val]}") - merge_map[ssa_val] = merged_ssa_val - self.__merge_map = FMap(merge_map) + f"{value} and {self.__merge_map[ssa_val]}") + self.__merge_map[ssa_val] = value + added += 1 + self.__values_set.add(value) + added = None + finally: + if added is not None: + # remove partially added stuff + for idx, ssa_val in enumerate(value.ssa_vals): + if idx >= added: + break + del self.__merge_map[ssa_val] + + def discard(self, value): + # type: (MergedSSAVal) -> None + if value not in self: + return + self.__values_set.discard(value) + for ssa_val in value.ssa_val_offsets.keys(): + del self.__merge_map[ssa_val] - @cached_property - def merge_map(self): - # type: () -> FMap[SSAVal, MergedSSAVal] - return self.__merge_map + def __repr__(self): + # type: () -> str + s = ",\n".join(repr(v) for v in self.__values_set) + return f"MergedSSAValsSet({{{s}}})" -# FIXME: work on code from here + +@plain_data(frozen=True) +@final +class MergedSSAVals: + __slots__ = "fn_with_uses", "merge_map", "merged_ssa_vals" + + def __init__(self, fn_with_uses, merged_ssa_vals): + # type: (FnWithUses, Iterable[MergedSSAVal]) -> None + self.fn_with_uses = fn_with_uses + self.merge_map = MergedSSAValsMap() + self.merged_ssa_vals = self.merge_map.values_set + for i in merged_ssa_vals: + self.merged_ssa_vals.add(i) + + def merge(self, ssa_val1, ssa_val2, additional_offset=0): + # type: (SSAVal, SSAVal, int) -> MergedSSAVal + merged1 = self.merge_map[ssa_val1] + merged2 = self.merge_map[ssa_val2] + merged = merged1.with_offset_to_match(ssa_val1) + merged = merged.merged(merged2.with_offset_to_match( + ssa_val2, additional_offset=additional_offset)) + self.merged_ssa_vals.remove(merged1) + self.merged_ssa_vals.remove(merged2) + self.merged_ssa_vals.add(merged) + return merged @staticmethod def minimally_merged(fn_with_uses): # type: (FnWithUses) -> MergedSSAVals - merge_map = {} # type: dict[SSAVal, MergedSSAVal] + retval = MergedSSAVals(fn_with_uses=fn_with_uses, merged_ssa_vals=()) for op in fn_with_uses.fn.ops: - for fn - for val in (*op.inputs().values(), *op.outputs().values()): - if val not in merged_sets: - merged_sets[val] = MergedRegSet(val) - for e in op.get_equality_constraints(): - lhs_set = MergedRegSet.from_equality_constraint(e.lhs) - rhs_set = MergedRegSet.from_equality_constraint(e.rhs) - items = [] # type: list[tuple[SSAVal, int]] - for i in e.lhs: - s = merged_sets[i].with_offset_to_match(lhs_set) - items.extend(s.items()) - for i in e.rhs: - s = merged_sets[i].with_offset_to_match(rhs_set) - items.extend(s.items()) - full_set = MergedRegSet(items) - for val in full_set.keys(): - merged_sets[val] = full_set - - self.__map = {k: v.normalized() for k, v in merged_sets.items()} + for inp in op.input_uses: + if inp.unspread_start != inp: + retval.merge(inp.unspread_start.ssa_val, inp.ssa_val, + additional_offset=inp.reg_offset_in_unspread) + for out in op.outputs: + if out.unspread_start != out: + retval.merge(out.unspread_start, out, + additional_offset=out.reg_offset_in_unspread) + if out.tied_input is not None: + retval.merge(out.tied_input.ssa_val, out) + return retval + + +# FIXME: work on code from here @final -class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]): - def __init__(self, ops): +class LiveIntervals(Mapping[MergedSSAVal, LiveInterval]): + def __init__(self, merged_ssa_vals): # type: (list[Op]) -> None self.__merged_reg_sets = MergedRegSets(ops) live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval] -- 2.30.2