From: Jacob Lifshay Date: Tue, 8 Nov 2022 07:00:11 +0000 (-0800) Subject: rename compiler_ir2.py/register_allocator2.py to compiler_ir.py/register_allocator.py X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=65531388555a7924e9ebde840635348bc5e53b0b;p=bigint-presentation-code.git rename compiler_ir2.py/register_allocator2.py to compiler_ir.py/register_allocator.py --- diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py new file mode 100644 index 0000000..ba29ee0 --- /dev/null +++ b/src/bigint_presentation_code/_tests/test_compiler_ir.py @@ -0,0 +1,1069 @@ +import unittest + +from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, BaseTy, + Fn, FnAnalysis, GenAsmState, + Loc, LocKind, OpKind, + OpStage, PreRASimState, + ProgramPoint, SSAVal, Ty) + + +class TestCompilerIR(unittest.TestCase): + maxDiff = None + + def test_program_point(self): + # type: () -> None + expected = [] # type: list[ProgramPoint] + for op_index in range(5): + for stage in OpStage: + expected.append(ProgramPoint(op_index=op_index, stage=stage)) + + for idx, pp in enumerate(expected): + if idx + 1 < len(expected): + self.assertEqual(pp.next(), expected[idx + 1]) + + self.assertEqual(sorted(expected), expected) + + def make_add_fn(self): + # type: () -> tuple[Fn, SSAVal] + fn = Fn() + op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg") + arg = op0.outputs[0] + MAXVL = 32 + op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl") + vl = op1.outputs[0] + op2 = fn.append_new_op( + OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL, + name="ld") + a = op2.outputs[0] + op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], + maxvl=MAXVL, name="li") + b = op3.outputs[0] + op4 = fn.append_new_op(OpKind.SetCA, name="ca") + ca = op4.outputs[0] + op5 = fn.append_new_op( + OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add") + s = op5.outputs[0] + _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], + immediates=[0], maxvl=MAXVL, name="st") + return fn, arg + + def test_fn_analysis(self): + fn, _arg = self.make_add_fn() + fn_analysis = FnAnalysis(fn) + self.assertEqual( + repr(fn_analysis.uses), + "FMap({" + ">: OFSet([" + ">, >]), " + ">: OFSet([" + ">, >, " + ">, " + ">]), " + ">: OFSet([" + ">]), " + ">: OFSet([" + ">]), " + ">: OFSet([>]), " + ">: OFSet([" + ">]), " + ">: OFSet()})" + ) + self.assertEqual( + repr(fn_analysis.op_indexes), + "FMap({" + "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), " + "immediates=[], outputs=(>,), " + "name='arg'): 0, " + "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), " + "immediates=[32], outputs=(>,), " + "name='vl'): 1, " + "Op(kind=OpKind.SvLd, input_vals=[" + ">, >], " + "input_uses=(>, " + ">), immediates=[0], " + "outputs=(>,), name='ld'): 2, " + "Op(kind=OpKind.SvLI, input_vals=[>], " + "input_uses=(>,), immediates=[0], " + "outputs=(>,), name='li'): 3, " + "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), " + "immediates=[], outputs=(>,), name='ca'): 4, " + "Op(kind=OpKind.SvAddE, input_vals=[" + ">, >, " + ">, >], " + "input_uses=(>, " + ">, >, " + ">), immediates=[], outputs=(" + ">, >), " + "name='add'): 5, " + "Op(kind=OpKind.SvStd, input_vals=[" + ">, >, " + ">], " + "input_uses=(>, " + ">, >), " + "immediates=[0], outputs=(), name='st'): 6})" + ) + self.assertEqual( + repr(fn_analysis.live_ranges), + "FMap({" + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: })" + ) + self.assertEqual( + 
repr(fn_analysis.live_at), + "FMap({" + ": OFSet([>]), " + ": OFSet([>]), " + ": OFSet([>]), " + ": OFSet([" + ">, >]), " + ": OFSet([" + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">, >, " + ">, >, " + ">]), " + ": OFSet([" + ">, >, " + ">, >]), " + ": OFSet([" + ">, >, " + ">]), " + ": OFSet()})" + ) + self.assertEqual( + repr(fn_analysis.def_program_ranges), + "FMap({" + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: })" + ) + self.assertEqual( + repr(fn_analysis.use_program_points), + "FMap({" + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: , " + ">: })" + ) + self.assertEqual( + repr(fn_analysis.all_program_points), + "") + + def test_repr(self): + fn, _arg = self.make_add_fn() + self.assertEqual([repr(i) for i in fn.ops], [ + "Op(kind=OpKind.FuncArgR3, " + "input_vals=[], " + "input_uses=(), " + "immediates=[], " + "outputs=(>,), name='arg')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), name='vl')", + "Op(kind=OpKind.SvLd, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[0], " + "outputs=(>,), name='ld')", + "Op(kind=OpKind.SvLI, " + "input_vals=[>], " + "input_uses=(>,), " + "immediates=[0], " + "outputs=(>,), name='li')", + "Op(kind=OpKind.SetCA, " + "input_vals=[], " + "input_uses=(), " + "immediates=[], " + "outputs=(>,), name='ca')", + "Op(kind=OpKind.SvAddE, " + "input_vals=[>, " + ">, >, " + ">], " + "input_uses=(>, " + ">, >, " + ">), " + "immediates=[], " + "outputs=(>, >), " + "name='add')", + "Op(kind=OpKind.SvStd, " + "input_vals=[>, >, " + ">], " + "input_uses=(>, " + ">, >), " + "immediates=[0], " + "outputs=(), name='st')", + ]) + self.assertEqual([repr(op.properties) for op in fn.ops], [ + "OpProperties(kind=OpKind.FuncArgR3, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([3])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=1)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SvLd, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " + "ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", + "OpProperties(kind=OpKind.SvLI, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, 
spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", + "OpProperties(kind=OpKind.SetCA, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.CA: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SvAddE, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.CA: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.CA: FBitSet([0])}), ty=), " + "tied_input_index=2, spread_index=None, " + "write_stage=OpStage.Early)), maxvl=32)", + "OpProperties(kind=OpKind.SvStd, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " + "ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(), maxvl=32)", + ]) + + def test_pre_ra_insert_copies(self): + fn, _arg = self.make_add_fn() + fn.pre_ra_insert_copies() + self.assertEqual([repr(i) for i in fn.ops], [ + "Op(kind=OpKind.FuncArgR3, " + "input_vals=[], " + "input_uses=(), " + "immediates=[], " + "outputs=(>,), name='arg')", + "Op(kind=OpKind.CopyFromReg, " + "input_vals=[>], " + "input_uses=(>,), " + "immediates=[], " + "outputs=(>,), " + "name='arg.out0.copy')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), name='vl')", + "Op(kind=OpKind.CopyToReg, " + "input_vals=[>], " + "input_uses=(>,), " + "immediates=[], " + "outputs=(>,), name='ld.inp0.copy')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='ld.inp1.setvl')", + "Op(kind=OpKind.SvLd, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[0], " + "outputs=(>,), name='ld')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='ld.out0.setvl')", + "Op(kind=OpKind.VecCopyFromReg, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[], " + "outputs=(>,), " + "name='ld.out0.copy')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='li.inp0.setvl')", + 
"Op(kind=OpKind.SvLI, " + "input_vals=[>], " + "input_uses=(>,), " + "immediates=[0], " + "outputs=(>,), name='li')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='li.out0.setvl')", + "Op(kind=OpKind.VecCopyFromReg, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[], " + "outputs=(>,), " + "name='li.out0.copy')", + "Op(kind=OpKind.SetCA, " + "input_vals=[], " + "input_uses=(), " + "immediates=[], " + "outputs=(>,), name='ca')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='add.inp0.setvl')", + "Op(kind=OpKind.VecCopyToReg, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[], " + "outputs=(>,), " + "name='add.inp0.copy')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='add.inp1.setvl')", + "Op(kind=OpKind.VecCopyToReg, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[], " + "outputs=(>,), " + "name='add.inp1.copy')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='add.inp3.setvl')", + "Op(kind=OpKind.SvAddE, " + "input_vals=[>, " + ">, >, " + ">], " + "input_uses=(>, " + ">, >, " + ">), " + "immediates=[], " + "outputs=(>, >), " + "name='add')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='add.out0.setvl')", + "Op(kind=OpKind.VecCopyFromReg, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[], " + "outputs=(>,), " + "name='add.out0.copy')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='st.inp0.setvl')", + "Op(kind=OpKind.VecCopyToReg, " + "input_vals=[>, " + ">], " + "input_uses=(>, " + ">), " + "immediates=[], " + "outputs=(>,), " + "name='st.inp0.copy')", + "Op(kind=OpKind.CopyToReg, " + "input_vals=[>], " + "input_uses=(>,), " + "immediates=[], " + "outputs=(>,), " + "name='st.inp1.copy')", + "Op(kind=OpKind.SetVLI, " + "input_vals=[], " + "input_uses=(), " + "immediates=[32], " + "outputs=(>,), " + "name='st.inp2.setvl')", + "Op(kind=OpKind.SvStd, " + "input_vals=[>, " + ">, " + ">], " + "input_uses=(>, " + ">, >), " + "immediates=[0], " + "outputs=(), name='st')", + ]) + self.assertEqual([repr(op.properties) for op in fn.ops], [ + "OpProperties(kind=OpKind.FuncArgR3, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([3])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=1)", + "OpProperties(kind=OpKind.CopyFromReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " + "ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " + "LocKind.StackI64: FBitSet(range(0, 512))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + 
"write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.CopyToReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " + "LocKind.StackI64: FBitSet(range(0, 512))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " + "ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SvLd, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " + "ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.VecCopyFromReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97)), " + "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SvLI, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), maxvl=32)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.VecCopyFromReg, " + "inputs=(" + 
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97)), " + "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", + "OpProperties(kind=OpKind.SetCA, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.CA: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.VecCopyToReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97)), " + "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.VecCopyToReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97)), " + "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SvAddE, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + 
"LocKind.CA: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.CA: FBitSet([0])}), ty=), " + "tied_input_index=2, spread_index=None, " + "write_stage=OpStage.Early)), maxvl=32)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.VecCopyFromReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97)), " + "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.VecCopyToReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97)), " + "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=32)", + "OpProperties(kind=OpKind.CopyToReg, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " + "LocKind.StackI64: FBitSet(range(0, 512))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early),), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " + "ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SetVLI, " + "inputs=(), " + "outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late),), maxvl=1)", + "OpProperties(kind=OpKind.SvStd, " + "inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet(range(14, 
97))}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " + "ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " + "tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)), " + "outputs=(), maxvl=32)", + ]) + + def test_sim(self): + fn, arg = self.make_add_fn() + addr = 0x100 + state = PreRASimState(ssa_vals={arg: (addr,)}, memory={}) + state.store(addr=addr, value=0xffffffff_ffffffff, + size_in_bytes=GPR_SIZE_IN_BYTES) + state.store(addr=addr + GPR_SIZE_IN_BYTES, value=0xabcdef01_23456789, + size_in_bytes=GPR_SIZE_IN_BYTES) + self.assertEqual( + repr(state), + "PreRASimState(memory={\n" + "0x00100: <0xffffffffffffffff>,\n" + "0x00108: <0xabcdef0123456789>}, " + "ssa_vals={>: (0x100,)})") + fn.sim(state) + self.assertEqual( + repr(state), + "PreRASimState(memory={\n" + "0x00100: <0x0000000000000000>,\n" + "0x00108: <0xabcdef012345678a>,\n" + "0x00110: <0x0000000000000000>,\n" + "0x00118: <0x0000000000000000>,\n" + "0x00120: <0x0000000000000000>,\n" + "0x00128: <0x0000000000000000>,\n" + "0x00130: <0x0000000000000000>,\n" + "0x00138: <0x0000000000000000>,\n" + "0x00140: <0x0000000000000000>,\n" + "0x00148: <0x0000000000000000>,\n" + "0x00150: <0x0000000000000000>,\n" + "0x00158: <0x0000000000000000>,\n" + "0x00160: <0x0000000000000000>,\n" + "0x00168: <0x0000000000000000>,\n" + "0x00170: <0x0000000000000000>,\n" + "0x00178: <0x0000000000000000>,\n" + "0x00180: <0x0000000000000000>,\n" + "0x00188: <0x0000000000000000>,\n" + "0x00190: <0x0000000000000000>,\n" + "0x00198: <0x0000000000000000>,\n" + "0x001a0: <0x0000000000000000>,\n" + "0x001a8: <0x0000000000000000>,\n" + "0x001b0: <0x0000000000000000>,\n" + "0x001b8: <0x0000000000000000>,\n" + "0x001c0: <0x0000000000000000>,\n" + "0x001c8: <0x0000000000000000>,\n" + "0x001d0: <0x0000000000000000>,\n" + "0x001d8: <0x0000000000000000>,\n" + "0x001e0: <0x0000000000000000>,\n" + "0x001e8: <0x0000000000000000>,\n" + "0x001f0: <0x0000000000000000>,\n" + "0x001f8: <0x0000000000000000>}, ssa_vals={\n" + ">: (0x100,),\n" + ">: (0x20,),\n" + ">: (\n" + " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0),\n" + ">: (\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0),\n" + ">: (0x1,),\n" + ">: (\n" + " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0,\n" + " 0x0, 0x0, 0x0, 0x0),\n" + ">: (0x0,),\n" + "})") + + def test_gen_asm(self): + fn, _arg = self.make_add_fn() + fn.pre_ra_insert_copies() + VL_LOC = Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1) + CA_LOC = Loc(kind=LocKind.CA, start=0, reg_len=1) + state = GenAsmState(allocated_locs={ + fn.ops[0].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1), + fn.ops[1].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1), + fn.ops[2].outputs[0]: VL_LOC, + fn.ops[3].outputs[0]: Loc(kind=LocKind.GPR, 
start=3, reg_len=1), + fn.ops[4].outputs[0]: VL_LOC, + fn.ops[5].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), + fn.ops[6].outputs[0]: VL_LOC, + fn.ops[7].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), + fn.ops[8].outputs[0]: VL_LOC, + fn.ops[9].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32), + fn.ops[10].outputs[0]: VL_LOC, + fn.ops[11].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32), + fn.ops[12].outputs[0]: CA_LOC, + fn.ops[13].outputs[0]: VL_LOC, + fn.ops[14].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), + fn.ops[15].outputs[0]: VL_LOC, + fn.ops[16].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32), + fn.ops[17].outputs[0]: VL_LOC, + fn.ops[18].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), + fn.ops[18].outputs[1]: CA_LOC, + fn.ops[19].outputs[0]: VL_LOC, + fn.ops[20].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), + fn.ops[21].outputs[0]: VL_LOC, + fn.ops[22].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), + fn.ops[23].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1), + fn.ops[24].outputs[0]: VL_LOC, + }) + fn.gen_asm(state) + self.assertEqual(state.output, [ + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.ld *32, 0(3)', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.addi *64, 0, 0', + 'setvl 0, 0, 32, 0, 1, 1', + 'subfc 0, 0, 0', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.adde *32, *32, *64', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.std *32, 0(3)', + ]) + + def test_spread(self): + fn = Fn() + maxvl = 4 + vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl], + name="vl", maxvl=maxvl).outputs[0] + li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], + name="li", maxvl=maxvl).outputs[0] + spread_op = fn.append_new_op(OpKind.Spread, input_vals=[li, vl], + name="spread", maxvl=maxvl) + self.assertEqual(spread_op.outputs[0].ty_before_spread, + Ty(base_ty=BaseTy.I64, reg_len=maxvl)) + _concat = fn.append_new_op( + OpKind.Concat, input_vals=[*spread_op.outputs[::-1], vl], + name="concat", maxvl=maxvl) + self.assertEqual([repr(op.properties) for op in fn.ops], [ + "OpProperties(kind=OpKind.SetVLI, inputs=(" + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late)," + "), maxvl=4)", + "OpProperties(kind=OpKind.SvLI, inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)," + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)," + "), maxvl=4)", + "OpProperties(kind=OpKind.Spread, inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)" + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), 
tied_input_index=None, spread_index=0, " + "write_stage=OpStage.Late), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=1, " + "write_stage=OpStage.Late), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=2, " + "write_stage=OpStage.Late), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=3, " + "write_stage=OpStage.Late)" + "), maxvl=4)", + "OpProperties(kind=OpKind.Concat, inputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=0, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=1, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=2, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=3, " + "write_stage=OpStage.Early), " + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.VL_MAXVL: FBitSet([0])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Early)" + "), outputs=(" + "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" + "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " + "ty=), tied_input_index=None, spread_index=None, " + "write_stage=OpStage.Late)," + "), maxvl=4)", + ]) + self.assertEqual([repr(op) for op in fn.ops], [ + "Op(kind=OpKind.SetVLI, input_vals=[" + "], input_uses=(" + "), immediates=[4], outputs=(" + ">," + "), name='vl')", + "Op(kind=OpKind.SvLI, input_vals=[" + ">" + "], input_uses=(" + ">," + "), immediates=[0], outputs=(" + ">," + "), name='li')", + "Op(kind=OpKind.Spread, input_vals=[" + ">, " + ">" + "], input_uses=(" + ">, " + ">" + "), immediates=[], outputs=(" + ">, " + ">, " + ">, " + ">" + "), name='spread')", + "Op(kind=OpKind.Concat, input_vals=[" + ">, " + ">, " + ">, " + ">, " + ">" + "], input_uses=(" + ">, " + ">, " + ">, " + ">, " + ">" + "), immediates=[], outputs=(" + ">," + "), name='concat')", + ]) + + +if __name__ == "__main__": + _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir2.py b/src/bigint_presentation_code/_tests/test_compiler_ir2.py deleted file mode 100644 index 833dbc9..0000000 --- a/src/bigint_presentation_code/_tests/test_compiler_ir2.py +++ /dev/null @@ -1,1068 +0,0 @@ -import unittest - -from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, BaseTy, Fn, - FnAnalysis, GenAsmState, Loc, LocKind, OpKind, OpStage, - PreRASimState, ProgramPoint, - SSAVal, Ty) - - -class TestCompilerIR(unittest.TestCase): - maxDiff = None - - def test_program_point(self): - # type: () -> None - expected = [] # type: list[ProgramPoint] - for op_index in range(5): - for stage in OpStage: - expected.append(ProgramPoint(op_index=op_index, stage=stage)) - - for idx, pp in enumerate(expected): - if idx + 1 < len(expected): - 
self.assertEqual(pp.next(), expected[idx + 1]) - - self.assertEqual(sorted(expected), expected) - - def make_add_fn(self): - # type: () -> tuple[Fn, SSAVal] - fn = Fn() - op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg") - arg = op0.outputs[0] - MAXVL = 32 - op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl") - vl = op1.outputs[0] - op2 = fn.append_new_op( - OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL, - name="ld") - a = op2.outputs[0] - op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], - maxvl=MAXVL, name="li") - b = op3.outputs[0] - op4 = fn.append_new_op(OpKind.SetCA, name="ca") - ca = op4.outputs[0] - op5 = fn.append_new_op( - OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add") - s = op5.outputs[0] - _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], - immediates=[0], maxvl=MAXVL, name="st") - return fn, arg - - def test_fn_analysis(self): - fn, _arg = self.make_add_fn() - fn_analysis = FnAnalysis(fn) - self.assertEqual( - repr(fn_analysis.uses), - "FMap({" - ">: OFSet([" - ">, >]), " - ">: OFSet([" - ">, >, " - ">, " - ">]), " - ">: OFSet([" - ">]), " - ">: OFSet([" - ">]), " - ">: OFSet([>]), " - ">: OFSet([" - ">]), " - ">: OFSet()})" - ) - self.assertEqual( - repr(fn_analysis.op_indexes), - "FMap({" - "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), " - "immediates=[], outputs=(>,), " - "name='arg'): 0, " - "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), " - "immediates=[32], outputs=(>,), " - "name='vl'): 1, " - "Op(kind=OpKind.SvLd, input_vals=[" - ">, >], " - "input_uses=(>, " - ">), immediates=[0], " - "outputs=(>,), name='ld'): 2, " - "Op(kind=OpKind.SvLI, input_vals=[>], " - "input_uses=(>,), immediates=[0], " - "outputs=(>,), name='li'): 3, " - "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), " - "immediates=[], outputs=(>,), name='ca'): 4, " - "Op(kind=OpKind.SvAddE, input_vals=[" - ">, >, " - ">, >], " - "input_uses=(>, " - ">, >, " - ">), immediates=[], outputs=(" - ">, >), " - "name='add'): 5, " - "Op(kind=OpKind.SvStd, input_vals=[" - ">, >, " - ">], " - "input_uses=(>, " - ">, >), " - "immediates=[0], outputs=(), name='st'): 6})" - ) - self.assertEqual( - repr(fn_analysis.live_ranges), - "FMap({" - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: })" - ) - self.assertEqual( - repr(fn_analysis.live_at), - "FMap({" - ": OFSet([>]), " - ": OFSet([>]), " - ": OFSet([>]), " - ": OFSet([" - ">, >]), " - ": OFSet([" - ">, >, " - ">]), " - ": OFSet([" - ">, >, " - ">]), " - ": OFSet([" - ">, >, " - ">, >]), " - ": OFSet([" - ">, >, " - ">, >]), " - ": OFSet([" - ">, >, " - ">, >]), " - ": OFSet([" - ">, >, " - ">, >, " - ">]), " - ": OFSet([" - ">, >, " - ">, >, " - ">, >, " - ">]), " - ": OFSet([" - ">, >, " - ">, >]), " - ": OFSet([" - ">, >, " - ">]), " - ": OFSet()})" - ) - self.assertEqual( - repr(fn_analysis.def_program_ranges), - "FMap({" - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: })" - ) - self.assertEqual( - repr(fn_analysis.use_program_points), - "FMap({" - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: , " - ">: })" - ) - self.assertEqual( - repr(fn_analysis.all_program_points), - "") - - def test_repr(self): - fn, _arg = self.make_add_fn() - self.assertEqual([repr(i) for i in fn.ops], [ - "Op(kind=OpKind.FuncArgR3, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='arg')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], 
" - "outputs=(>,), name='vl')", - "Op(kind=OpKind.SvLd, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[0], " - "outputs=(>,), name='ld')", - "Op(kind=OpKind.SvLI, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[0], " - "outputs=(>,), name='li')", - "Op(kind=OpKind.SetCA, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='ca')", - "Op(kind=OpKind.SvAddE, " - "input_vals=[>, " - ">, >, " - ">], " - "input_uses=(>, " - ">, >, " - ">), " - "immediates=[], " - "outputs=(>, >), " - "name='add')", - "Op(kind=OpKind.SvStd, " - "input_vals=[>, >, " - ">], " - "input_uses=(>, " - ">, >), " - "immediates=[0], " - "outputs=(), name='st')", - ]) - self.assertEqual([repr(op.properties) for op in fn.ops], [ - "OpProperties(kind=OpKind.FuncArgR3, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([3])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=1)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SvLd, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " - "ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", - "OpProperties(kind=OpKind.SvLI, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", - "OpProperties(kind=OpKind.SetCA, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SvAddE, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: 
FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=2, spread_index=None, " - "write_stage=OpStage.Early)), maxvl=32)", - "OpProperties(kind=OpKind.SvStd, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " - "ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(), maxvl=32)", - ]) - - def test_pre_ra_insert_copies(self): - fn, _arg = self.make_add_fn() - fn.pre_ra_insert_copies() - self.assertEqual([repr(i) for i in fn.ops], [ - "Op(kind=OpKind.FuncArgR3, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='arg')", - "Op(kind=OpKind.CopyFromReg, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[], " - "outputs=(>,), " - "name='arg.out0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), name='vl')", - "Op(kind=OpKind.CopyToReg, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[], " - "outputs=(>,), name='ld.inp0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='ld.inp1.setvl')", - "Op(kind=OpKind.SvLd, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[0], " - "outputs=(>,), name='ld')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='ld.out0.setvl')", - "Op(kind=OpKind.VecCopyFromReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='ld.out0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='li.inp0.setvl')", - "Op(kind=OpKind.SvLI, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[0], " - "outputs=(>,), name='li')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='li.out0.setvl')", - "Op(kind=OpKind.VecCopyFromReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='li.out0.copy')", - "Op(kind=OpKind.SetCA, " - "input_vals=[], " - "input_uses=(), " - "immediates=[], " - "outputs=(>,), name='ca')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='add.inp0.setvl')", - "Op(kind=OpKind.VecCopyToReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='add.inp0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='add.inp1.setvl')", - "Op(kind=OpKind.VecCopyToReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='add.inp1.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - 
"outputs=(>,), " - "name='add.inp3.setvl')", - "Op(kind=OpKind.SvAddE, " - "input_vals=[>, " - ">, >, " - ">], " - "input_uses=(>, " - ">, >, " - ">), " - "immediates=[], " - "outputs=(>, >), " - "name='add')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='add.out0.setvl')", - "Op(kind=OpKind.VecCopyFromReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='add.out0.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='st.inp0.setvl')", - "Op(kind=OpKind.VecCopyToReg, " - "input_vals=[>, " - ">], " - "input_uses=(>, " - ">), " - "immediates=[], " - "outputs=(>,), " - "name='st.inp0.copy')", - "Op(kind=OpKind.CopyToReg, " - "input_vals=[>], " - "input_uses=(>,), " - "immediates=[], " - "outputs=(>,), " - "name='st.inp1.copy')", - "Op(kind=OpKind.SetVLI, " - "input_vals=[], " - "input_uses=(), " - "immediates=[32], " - "outputs=(>,), " - "name='st.inp2.setvl')", - "Op(kind=OpKind.SvStd, " - "input_vals=[>, " - ">, " - ">], " - "input_uses=(>, " - ">, >), " - "immediates=[0], " - "outputs=(), name='st')", - ]) - self.assertEqual([repr(op.properties) for op in fn.ops], [ - "OpProperties(kind=OpKind.FuncArgR3, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([3])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=1)", - "OpProperties(kind=OpKind.CopyFromReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " - "ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " - "LocKind.StackI64: FBitSet(range(0, 512))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.CopyToReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " - "LocKind.StackI64: FBitSet(range(0, 512))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " - "ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SvLd, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " - "ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - 
"tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.VecCopyFromReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97)), " - "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SvLI, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), maxvl=32)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.VecCopyFromReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97)), " - "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", - "OpProperties(kind=OpKind.SetCA, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.VecCopyToReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - 
"LocKind.GPR: FBitSet(range(14, 97)), " - "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.VecCopyToReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97)), " - "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SvAddE, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.CA: FBitSet([0])}), ty=), " - "tied_input_index=2, spread_index=None, " - "write_stage=OpStage.Early)), maxvl=32)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.VecCopyFromReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - 
"LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97)), " - "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.VecCopyToReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97)), " - "LocKind.StackI64: FBitSet(range(0, 481))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=32)", - "OpProperties(kind=OpKind.CopyToReg, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), " - "LocKind.StackI64: FBitSet(range(0, 512))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early),), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " - "ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SetVLI, " - "inputs=(), " - "outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late),), maxvl=1)", - "OpProperties(kind=OpKind.SvStd, " - "inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet(range(14, 97))}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), " - "ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), ty=), " - "tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)), " - "outputs=(), maxvl=32)", - ]) - - def test_sim(self): - fn, arg = self.make_add_fn() - addr = 0x100 - state = PreRASimState(ssa_vals={arg: (addr,)}, memory={}) - state.store(addr=addr, value=0xffffffff_ffffffff, - size_in_bytes=GPR_SIZE_IN_BYTES) - state.store(addr=addr + GPR_SIZE_IN_BYTES, value=0xabcdef01_23456789, - size_in_bytes=GPR_SIZE_IN_BYTES) - self.assertEqual( - repr(state), - "PreRASimState(memory={\n" - "0x00100: <0xffffffffffffffff>,\n" - "0x00108: <0xabcdef0123456789>}, " - "ssa_vals={>: (0x100,)})") - fn.sim(state) - self.assertEqual( - repr(state), - "PreRASimState(memory={\n" - "0x00100: <0x0000000000000000>,\n" - "0x00108: <0xabcdef012345678a>,\n" - "0x00110: <0x0000000000000000>,\n" - "0x00118: <0x0000000000000000>,\n" - "0x00120: 
<0x0000000000000000>,\n" - "0x00128: <0x0000000000000000>,\n" - "0x00130: <0x0000000000000000>,\n" - "0x00138: <0x0000000000000000>,\n" - "0x00140: <0x0000000000000000>,\n" - "0x00148: <0x0000000000000000>,\n" - "0x00150: <0x0000000000000000>,\n" - "0x00158: <0x0000000000000000>,\n" - "0x00160: <0x0000000000000000>,\n" - "0x00168: <0x0000000000000000>,\n" - "0x00170: <0x0000000000000000>,\n" - "0x00178: <0x0000000000000000>,\n" - "0x00180: <0x0000000000000000>,\n" - "0x00188: <0x0000000000000000>,\n" - "0x00190: <0x0000000000000000>,\n" - "0x00198: <0x0000000000000000>,\n" - "0x001a0: <0x0000000000000000>,\n" - "0x001a8: <0x0000000000000000>,\n" - "0x001b0: <0x0000000000000000>,\n" - "0x001b8: <0x0000000000000000>,\n" - "0x001c0: <0x0000000000000000>,\n" - "0x001c8: <0x0000000000000000>,\n" - "0x001d0: <0x0000000000000000>,\n" - "0x001d8: <0x0000000000000000>,\n" - "0x001e0: <0x0000000000000000>,\n" - "0x001e8: <0x0000000000000000>,\n" - "0x001f0: <0x0000000000000000>,\n" - "0x001f8: <0x0000000000000000>}, ssa_vals={\n" - ">: (0x100,),\n" - ">: (0x20,),\n" - ">: (\n" - " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0),\n" - ">: (\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0),\n" - ">: (0x1,),\n" - ">: (\n" - " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0,\n" - " 0x0, 0x0, 0x0, 0x0),\n" - ">: (0x0,),\n" - "})") - - def test_gen_asm(self): - fn, _arg = self.make_add_fn() - fn.pre_ra_insert_copies() - VL_LOC = Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1) - CA_LOC = Loc(kind=LocKind.CA, start=0, reg_len=1) - state = GenAsmState(allocated_locs={ - fn.ops[0].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1), - fn.ops[1].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1), - fn.ops[2].outputs[0]: VL_LOC, - fn.ops[3].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1), - fn.ops[4].outputs[0]: VL_LOC, - fn.ops[5].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), - fn.ops[6].outputs[0]: VL_LOC, - fn.ops[7].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), - fn.ops[8].outputs[0]: VL_LOC, - fn.ops[9].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32), - fn.ops[10].outputs[0]: VL_LOC, - fn.ops[11].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32), - fn.ops[12].outputs[0]: CA_LOC, - fn.ops[13].outputs[0]: VL_LOC, - fn.ops[14].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), - fn.ops[15].outputs[0]: VL_LOC, - fn.ops[16].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32), - fn.ops[17].outputs[0]: VL_LOC, - fn.ops[18].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), - fn.ops[18].outputs[1]: CA_LOC, - fn.ops[19].outputs[0]: VL_LOC, - fn.ops[20].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), - fn.ops[21].outputs[0]: VL_LOC, - fn.ops[22].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32), - fn.ops[23].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1), - fn.ops[24].outputs[0]: VL_LOC, - }) - fn.gen_asm(state) - self.assertEqual(state.output, [ - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.ld *32, 0(3)', - 'setvl 0, 
0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.addi *64, 0, 0', - 'setvl 0, 0, 32, 0, 1, 1', - 'subfc 0, 0, 0', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.adde *32, *32, *64', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.std *32, 0(3)', - ]) - - def test_spread(self): - fn = Fn() - maxvl = 4 - vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl], - name="vl", maxvl=maxvl).outputs[0] - li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], - name="li", maxvl=maxvl).outputs[0] - spread_op = fn.append_new_op(OpKind.Spread, input_vals=[li, vl], - name="spread", maxvl=maxvl) - self.assertEqual(spread_op.outputs[0].ty_before_spread, - Ty(base_ty=BaseTy.I64, reg_len=maxvl)) - _concat = fn.append_new_op( - OpKind.Concat, input_vals=[*spread_op.outputs[::-1], vl], - name="concat", maxvl=maxvl) - self.assertEqual([repr(op.properties) for op in fn.ops], [ - "OpProperties(kind=OpKind.SetVLI, inputs=(" - "), outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), " - "ty=), tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late)," - "), maxvl=4)", - "OpProperties(kind=OpKind.SvLI, inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), " - "ty=), tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)," - "), outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)," - "), maxvl=4)", - "OpProperties(kind=OpKind.Spread, inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), " - "ty=), tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)" - "), outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=0, " - "write_stage=OpStage.Late), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=1, " - "write_stage=OpStage.Late), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=2, " - "write_stage=OpStage.Late), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=3, " - "write_stage=OpStage.Late)" - "), maxvl=4)", - "OpProperties(kind=OpKind.Concat, inputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=0, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=1, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), 
*range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=2, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=3, " - "write_stage=OpStage.Early), " - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.VL_MAXVL: FBitSet([0])}), " - "ty=), tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Early)" - "), outputs=(" - "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({" - "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), " - "ty=), tied_input_index=None, spread_index=None, " - "write_stage=OpStage.Late)," - "), maxvl=4)", - ]) - self.assertEqual([repr(op) for op in fn.ops], [ - "Op(kind=OpKind.SetVLI, input_vals=[" - "], input_uses=(" - "), immediates=[4], outputs=(" - ">," - "), name='vl')", - "Op(kind=OpKind.SvLI, input_vals=[" - ">" - "], input_uses=(" - ">," - "), immediates=[0], outputs=(" - ">," - "), name='li')", - "Op(kind=OpKind.Spread, input_vals=[" - ">, " - ">" - "], input_uses=(" - ">, " - ">" - "), immediates=[], outputs=(" - ">, " - ">, " - ">, " - ">" - "), name='spread')", - "Op(kind=OpKind.Concat, input_vals=[" - ">, " - ">, " - ">, " - ">, " - ">" - "], input_uses=(" - ">, " - ">, " - ">, " - ">, " - ">" - "), immediates=[], outputs=(" - ">," - "), name='concat')", - ]) - - -if __name__ == "__main__": - _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py new file mode 100644 index 0000000..d30ea12 --- /dev/null +++ b/src/bigint_presentation_code/_tests/test_register_allocator.py @@ -0,0 +1,493 @@ +import sys +import unittest + +from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind, + SSAVal) +from bigint_presentation_code.register_allocator import allocate_registers + + +class TestRegisterAllocator(unittest.TestCase): + maxDiff = None + + def make_add_fn(self): + # type: () -> tuple[Fn, SSAVal] + fn = Fn() + op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg") + arg = op0.outputs[0] + MAXVL = 32 + op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl") + vl = op1.outputs[0] + op2 = fn.append_new_op( + OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL, + name="ld") + a = op2.outputs[0] + op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], + maxvl=MAXVL, name="li") + b = op3.outputs[0] + op4 = fn.append_new_op(OpKind.SetCA, name="ca") + ca = op4.outputs[0] + op5 = fn.append_new_op( + OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add") + s = op5.outputs[0] + _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], + immediates=[0], maxvl=MAXVL, name="st") + return fn, arg + + def test_register_allocate(self): + fn, _arg = self.make_add_fn() + reg_assignments = allocate_registers(fn) + + self.assertEqual( + repr(reg_assignments), + "{>: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=78, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, 
start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1)}" + ) + + def test_gen_asm(self): + fn, _arg = self.make_add_fn() + reg_assignments = allocate_registers(fn) + + self.assertEqual( + repr(reg_assignments), + "{>: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=78, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.CA, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1)}" + ) + state = GenAsmState(reg_assignments) + fn.gen_asm(state) + self.assertEqual(state.output, [ + 'or 4, 3, 3', + 'setvl 0, 0, 32, 0, 1, 1', + 'or 3, 4, 4', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.ld *14, 0(3)', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.or *46, *14, *14', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.addi *14, 0, 0', + 'setvl 0, 0, 32, 0, 1, 1', + 'subfc 0, 0, 0', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.or *78, *46, *46', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.or *46, *14, *14', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.adde *14, *78, *46', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'or 3, 4, 4', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.std *14, 0(3)', + ]) + + def test_register_allocate_spread(self): + fn = Fn() + maxvl = 32 + vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl], + name="vl", maxvl=maxvl).outputs[0] + li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], + name="li", 
maxvl=maxvl).outputs[0] + spread = fn.append_new_op(OpKind.Spread, input_vals=[li, vl], + name="spread", maxvl=maxvl).outputs + _concat = fn.append_new_op( + OpKind.Concat, input_vals=[*spread[::-1], vl], + name="concat", maxvl=maxvl) + reg_assignments = allocate_registers(fn, debug_out=sys.stdout) + + self.assertEqual( + repr(reg_assignments), + "{>: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=16, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=17, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=18, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=19, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=20, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=21, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=22, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=23, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=24, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=25, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=26, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=27, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=28, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=29, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=30, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=31, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=32, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=33, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=34, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=35, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=36, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=37, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=38, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=39, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=40, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=41, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=42, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=43, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=44, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=45, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=3, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=4, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=5, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=15, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=16, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=17, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=18, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=19, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=20, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=21, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=22, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=23, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=24, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=25, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=26, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=27, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=28, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=29, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=30, reg_len=1), " + ">: " 
+ "Loc(kind=LocKind.GPR, start=31, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=32, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=33, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=34, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=35, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=36, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=37, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=38, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=39, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=40, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=41, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=42, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=43, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=44, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=45, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=6, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=7, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=8, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=9, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=10, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=11, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=12, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=46, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=47, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=48, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=49, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=50, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=51, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=52, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=53, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=54, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=55, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=56, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=57, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=58, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=59, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=60, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=61, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=62, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=63, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=64, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=65, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=66, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=67, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.GPR, start=14, reg_len=32), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " + ">: " + "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)}" + ) + state = GenAsmState(reg_assignments) + fn.gen_asm(state) + self.assertEqual(state.output, [ + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'sv.addi *14, 0, 0', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1', + 'or 67, 14, 14', + 'or 66, 15, 15', + 'or 65, 16, 16', + 'or 64, 17, 17', + 'or 63, 18, 18', + 'or 62, 19, 19', + 'or 61, 20, 20', + 'or 60, 21, 21', + 'or 59, 22, 22', + 'or 58, 23, 23', + 'or 57, 24, 24', + 'or 56, 25, 25', + 'or 55, 26, 26', + 'or 54, 27, 27', + 'or 53, 28, 28', + 'or 52, 29, 29', + 'or 51, 
30, 30', + 'or 50, 31, 31', + 'or 49, 32, 32', + 'or 48, 33, 33', + 'or 47, 34, 34', + 'or 46, 35, 35', + 'or 12, 36, 36', + 'or 11, 37, 37', + 'or 10, 38, 38', + 'or 9, 39, 39', + 'or 8, 40, 40', + 'or 7, 41, 41', + 'or 6, 42, 42', + 'or 5, 43, 43', + 'or 4, 44, 44', + 'or 3, 45, 45', + 'or 14, 3, 3', + 'or 15, 4, 4', + 'or 16, 5, 5', + 'or 17, 6, 6', + 'or 18, 7, 7', + 'or 19, 8, 8', + 'or 20, 9, 9', + 'or 21, 10, 10', + 'or 22, 11, 11', + 'or 23, 12, 12', + 'or 24, 46, 46', + 'or 25, 47, 47', + 'or 26, 48, 48', + 'or 27, 49, 49', + 'or 28, 50, 50', + 'or 29, 51, 51', + 'or 30, 52, 52', + 'or 31, 53, 53', + 'or 32, 54, 54', + 'or 33, 55, 55', + 'or 34, 56, 56', + 'or 35, 57, 57', + 'or 36, 58, 58', + 'or 37, 59, 59', + 'or 38, 60, 60', + 'or 39, 61, 61', + 'or 40, 62, 62', + 'or 41, 63, 63', + 'or 42, 64, 64', + 'or 43, 65, 65', + 'or 44, 66, 66', + 'or 45, 67, 67', + 'setvl 0, 0, 32, 0, 1, 1', + 'setvl 0, 0, 32, 0, 1, 1']) + + +if __name__ == "__main__": + _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_register_allocator2.py b/src/bigint_presentation_code/_tests/test_register_allocator2.py deleted file mode 100644 index f34ed98..0000000 --- a/src/bigint_presentation_code/_tests/test_register_allocator2.py +++ /dev/null @@ -1,493 +0,0 @@ -import sys -import unittest - -from bigint_presentation_code.compiler_ir2 import (Fn, GenAsmState, OpKind, - SSAVal) -from bigint_presentation_code.register_allocator2 import allocate_registers - - -class TestCompilerIR(unittest.TestCase): - maxDiff = None - - def make_add_fn(self): - # type: () -> tuple[Fn, SSAVal] - fn = Fn() - op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg") - arg = op0.outputs[0] - MAXVL = 32 - op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl") - vl = op1.outputs[0] - op2 = fn.append_new_op( - OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL, - name="ld") - a = op2.outputs[0] - op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], - maxvl=MAXVL, name="li") - b = op3.outputs[0] - op4 = fn.append_new_op(OpKind.SetCA, name="ca") - ca = op4.outputs[0] - op5 = fn.append_new_op( - OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add") - s = op5.outputs[0] - _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], - immediates=[0], maxvl=MAXVL, name="st") - return fn, arg - - def test_register_allocate(self): - fn, _arg = self.make_add_fn() - reg_assignments = allocate_registers(fn) - - self.assertEqual( - repr(reg_assignments), - "{>: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.GPR, start=46, reg_len=32), " - ">: " - "Loc(kind=LocKind.GPR, start=78, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, 
start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=46, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1)}" - ) - - def test_gen_asm(self): - fn, _arg = self.make_add_fn() - reg_assignments = allocate_registers(fn) - - self.assertEqual( - repr(reg_assignments), - "{>: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.GPR, start=46, reg_len=32), " - ">: " - "Loc(kind=LocKind.GPR, start=78, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.CA, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=46, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1)}" - ) - state = GenAsmState(reg_assignments) - fn.gen_asm(state) - self.assertEqual(state.output, [ - 'or 4, 3, 3', - 'setvl 0, 0, 32, 0, 1, 1', - 'or 3, 4, 4', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.ld *14, 0(3)', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.or *46, *14, *14', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.addi *14, 0, 0', - 'setvl 0, 0, 32, 0, 1, 1', - 'subfc 0, 0, 0', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.or *78, *46, *46', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.or *46, *14, *14', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.adde *14, *78, *46', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'or 3, 4, 4', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.std *14, 0(3)', - ]) - - def test_register_allocate_spread(self): - fn = Fn() - maxvl = 32 - vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl], - name="vl", maxvl=maxvl).outputs[0] - li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0], - name="li", maxvl=maxvl).outputs[0] - spread = fn.append_new_op(OpKind.Spread, input_vals=[li, vl], - name="spread", maxvl=maxvl).outputs - _concat = fn.append_new_op( - OpKind.Concat, input_vals=[*spread[::-1], vl], - name="concat", maxvl=maxvl) - reg_assignments = allocate_registers(fn, debug_out=sys.stdout) - - self.assertEqual( - repr(reg_assignments), - "{>: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - 
"Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=15, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=16, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=17, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=18, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=19, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=20, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=21, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=22, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=23, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=24, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=25, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=26, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=27, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=28, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=29, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=30, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=31, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=32, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=33, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=34, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=35, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=36, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=37, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=38, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=39, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=40, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=41, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=42, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=43, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=44, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=45, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=3, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=4, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=5, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=15, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=16, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=17, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=18, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=19, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=20, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=21, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=22, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=23, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=24, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=25, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=26, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=27, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=28, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=29, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=30, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=31, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=32, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=33, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=34, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=35, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=36, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=37, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=38, reg_len=1), " - ">: " - 
"Loc(kind=LocKind.GPR, start=39, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=40, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=41, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=42, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=43, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=44, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=45, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=6, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=7, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=8, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=9, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=10, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=11, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=12, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=46, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=47, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=48, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=49, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=50, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=51, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=52, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=53, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=54, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=55, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=56, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=57, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=58, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=59, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=60, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=61, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=62, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=63, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=64, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=65, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=66, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=67, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.GPR, start=14, reg_len=32), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), " - ">: " - "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)}" - ) - state = GenAsmState(reg_assignments) - fn.gen_asm(state) - self.assertEqual(state.output, [ - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'sv.addi *14, 0, 0', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1', - 'or 67, 14, 14', - 'or 66, 15, 15', - 'or 65, 16, 16', - 'or 64, 17, 17', - 'or 63, 18, 18', - 'or 62, 19, 19', - 'or 61, 20, 20', - 'or 60, 21, 21', - 'or 59, 22, 22', - 'or 58, 23, 23', - 'or 57, 24, 24', - 'or 56, 25, 25', - 'or 55, 26, 26', - 'or 54, 27, 27', - 'or 53, 28, 28', - 'or 52, 29, 29', - 'or 51, 30, 30', - 'or 50, 31, 31', - 'or 49, 32, 32', - 'or 48, 33, 33', - 'or 47, 34, 34', - 'or 46, 35, 35', - 'or 12, 36, 36', - 'or 11, 37, 37', - 'or 10, 38, 38', - 'or 9, 39, 39', - 'or 8, 40, 40', - 'or 7, 41, 41', - 'or 6, 42, 42', - 'or 5, 43, 43', - 'or 4, 44, 44', - 'or 3, 45, 45', - 'or 14, 3, 3', - 'or 15, 4, 4', - 'or 16, 5, 5', - 'or 17, 6, 6', - 'or 18, 7, 7', - 'or 19, 8, 8', - 'or 20, 9, 9', - 'or 21, 10, 10', - 'or 22, 11, 11', - 'or 23, 12, 
12', - 'or 24, 46, 46', - 'or 25, 47, 47', - 'or 26, 48, 48', - 'or 27, 49, 49', - 'or 28, 50, 50', - 'or 29, 51, 51', - 'or 30, 52, 52', - 'or 31, 53, 53', - 'or 32, 54, 54', - 'or 33, 55, 55', - 'or 34, 56, 56', - 'or 35, 57, 57', - 'or 36, 58, 58', - 'or 37, 59, 59', - 'or 38, 60, 60', - 'or 39, 61, 61', - 'or 40, 62, 62', - 'or 41, 63, 63', - 'or 42, 64, 64', - 'or 43, 65, 65', - 'or 44, 66, 66', - 'or 45, 67, 67', - 'setvl 0, 0, 32, 0, 1, 1', - 'setvl 0, 0, 32, 0, 1, 1']) - - -if __name__ == "__main__": - _ = unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_toom_cook.py b/src/bigint_presentation_code/_tests/test_toom_cook.py index 3188430..96dc318 100644 --- a/src/bigint_presentation_code/_tests/test_toom_cook.py +++ b/src/bigint_presentation_code/_tests/test_toom_cook.py @@ -1,12 +1,12 @@ import unittest from typing import Callable -from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, - BaseSimState, Fn, - GenAsmState, OpKind, - PostRASimState, - PreRASimState) -from bigint_presentation_code.register_allocator2 import allocate_registers +from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, + BaseSimState, Fn, + GenAsmState, OpKind, + PostRASimState, + PreRASimState) +from bigint_presentation_code.register_allocator import allocate_registers from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py new file mode 100644 index 0000000..d3b52e8 --- /dev/null +++ b/src/bigint_presentation_code/compiler_ir.py @@ -0,0 +1,2294 @@ +import enum +from abc import ABCMeta, abstractmethod +from enum import Enum, unique +from functools import lru_cache, total_ordering +from io import StringIO +from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator, + Mapping, Sequence, TypeVar, Union, overload) +from weakref import WeakValueDictionary as _WeakVDict + +from cached_property import cached_property +from nmutil.plain_data import fields, plain_data + +from bigint_presentation_code.type_util import (Literal, Self, assert_never, + final) +from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta, + OFSet, OSet) + + +@final +class Fn: + def __init__(self): + self.ops = [] # type: list[Op] + self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op] + self.__next_name_suffix = 2 + + def _add_op_with_unused_name(self, op, name=""): + # type: (Op, str) -> str + if op.fn is not self: + raise ValueError("can't add Op to wrong Fn") + if hasattr(op, "name"): + raise ValueError("Op already named") + orig_name = name + while True: + if name != "" and name not in self.__op_names: + self.__op_names[name] = op + return name + name = orig_name + str(self.__next_name_suffix) + self.__next_name_suffix += 1 + + def __repr__(self): + # type: () -> str + return "" + + def append_op(self, op): + # type: (Op) -> None + if op.fn is not self: + raise ValueError("can't add Op to wrong Fn") + self.ops.append(op) + + def append_new_op(self, kind, input_vals=(), immediates=(), name="", + maxvl=1): + # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op + retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl), + input_vals=input_vals, immediates=immediates, name=name) + self.append_op(retval) + return retval + + def sim(self, state): + # type: (BaseSimState) -> None + for op in self.ops: + op.sim(state) + + def gen_asm(self, state): + # type: (GenAsmState) -> None + for op in self.ops: + 
+            op.gen_asm(state)
+
+    def pre_ra_insert_copies(self):
+        # type: () -> None
+        orig_ops = list(self.ops)
+        copied_outputs = {}  # type: dict[SSAVal, SSAVal]
+        setvli_outputs = {}  # type: dict[SSAVal, Op]
+        self.ops.clear()
+        for op in orig_ops:
+            for i in range(len(op.input_vals)):
+                inp = copied_outputs[op.input_vals[i]]
+                if inp.ty.base_ty is BaseTy.I64:
+                    maxvl = inp.ty.reg_len
+                    if inp.ty.reg_len != 1:
+                        setvl = self.append_new_op(
+                            OpKind.SetVLI, immediates=[maxvl],
+                            name=f"{op.name}.inp{i}.setvl")
+                        vl = setvl.outputs[0]
+                        mv = self.append_new_op(
+                            OpKind.VecCopyToReg, input_vals=[inp, vl],
+                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
+                    else:
+                        mv = self.append_new_op(
+                            OpKind.CopyToReg, input_vals=[inp],
+                            name=f"{op.name}.inp{i}.copy")
+                    op.input_vals[i] = mv.outputs[0]
+                elif inp.ty.base_ty is BaseTy.CA \
+                        or inp.ty.base_ty is BaseTy.VL_MAXVL:
+                    # all copies would be no-ops, so we don't need to copy,
+                    # though we do need to rematerialize SetVLI ops right
+                    # before the op's VL
+                    if inp in setvli_outputs:
+                        setvl = self.append_new_op(
+                            OpKind.SetVLI,
+                            immediates=setvli_outputs[inp].immediates,
+                            name=f"{op.name}.inp{i}.setvl")
+                        inp = setvl.outputs[0]
+                    op.input_vals[i] = inp
+                else:
+                    assert_never(inp.ty.base_ty)
+            self.ops.append(op)
+            for i, out in enumerate(op.outputs):
+                if op.kind is OpKind.SetVLI:
+                    setvli_outputs[out] = op
+                if out.ty.base_ty is BaseTy.I64:
+                    maxvl = out.ty.reg_len
+                    if out.ty.reg_len != 1:
+                        setvl = self.append_new_op(
+                            OpKind.SetVLI, immediates=[maxvl],
+                            name=f"{op.name}.out{i}.setvl")
+                        vl = setvl.outputs[0]
+                        mv = self.append_new_op(
+                            OpKind.VecCopyFromReg, input_vals=[out, vl],
+                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
+                    else:
+                        mv = self.append_new_op(
+                            OpKind.CopyFromReg, input_vals=[out],
+                            name=f"{op.name}.out{i}.copy")
+                    copied_outputs[out] = mv.outputs[0]
+                elif out.ty.base_ty is BaseTy.CA \
+                        or out.ty.base_ty is BaseTy.VL_MAXVL:
+                    # all copies would be no-ops, so we don't need to copy
+                    copied_outputs[out] = out
+                else:
+                    assert_never(out.ty.base_ty)
+
+
+@final
+@unique
+@total_ordering
+class OpStage(Enum):
+    value: Literal[0, 1]  # type: ignore
+
+    def __new__(cls, value):
+        # type: (int) -> OpStage
+        value = int(value)
+        if value not in (0, 1):
+            raise ValueError("invalid value")
+        retval = object.__new__(cls)
+        retval._value_ = value
+        return retval
+
+    Early = 0
+    """ early stage of Op execution, where all input reads occur.
+    all output writes with `write_stage == Early` occur here too, and therefore
+    conflict with input reads, telling the compiler that it can't share
+    that output's register with any inputs that the output isn't tied to.
+
+    All outputs, even unused outputs, can't share registers with any other
+    outputs, independent of `write_stage` settings.
+    """
+    Late = 1
+    """ late stage of Op execution, where all output writes with
+    `write_stage == Late` occur, and therefore don't conflict with input reads,
+    telling the compiler that any inputs can safely use the same register as
+    those outputs.
+
+    All outputs, even unused outputs, can't share registers with any other
+    outputs, independent of `write_stage` settings.
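# --- editor's illustrative sketch, not part of the patch above ---
# Rough usage of Fn.pre_ra_insert_copies() as defined in this file, assuming
# the bigint_presentation_code package from this commit is importable: I64
# values get copy ops inserted around their defs and uses (VecCopyToReg /
# VecCopyFromReg plus a rematerialized SetVLI when reg_len > 1, plain
# CopyToReg/CopyFromReg otherwise), while CA and VL_MAXVL values are left
# uncopied, so the op list grows.
from bigint_presentation_code.compiler_ir import Fn, OpKind

fn = Fn()
arg = fn.append_new_op(OpKind.FuncArgR3, name="arg").outputs[0]
vl = fn.append_new_op(OpKind.SetVLI, immediates=[32], name="vl").outputs[0]
fn.append_new_op(OpKind.SvLd, input_vals=[arg, vl], immediates=[0],
                 maxvl=32, name="ld")
assert len(fn.ops) == 3
fn.pre_ra_insert_copies()
assert len(fn.ops) > 3  # copies and rematerialized SetVLIs were inserted
assert any(op.kind is OpKind.VecCopyFromReg for op in fn.ops)
# --- end of editor's sketch ---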
+ """ + + def __repr__(self): + # type: () -> str + return f"OpStage.{self._name_}" + + def __lt__(self, other): + # type: (OpStage | object) -> bool + if isinstance(other, OpStage): + return self.value < other.value + return NotImplemented + + +assert OpStage.Early < OpStage.Late, "early must be less than late" + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +@total_ordering +class ProgramPoint(metaclass=InternedMeta): + __slots__ = "op_index", "stage" + + def __init__(self, op_index, stage): + # type: (int, OpStage) -> None + self.op_index = op_index + self.stage = stage + + @property + def int_value(self): + # type: () -> int + """ an integer representation of `self` such that it keeps ordering and + successor/predecessor relations. + """ + return self.op_index * 2 + self.stage.value + + @staticmethod + def from_int_value(int_value): + # type: (int) -> ProgramPoint + op_index, stage = divmod(int_value, 2) + return ProgramPoint(op_index=op_index, stage=OpStage(stage)) + + def next(self, steps=1): + # type: (int) -> ProgramPoint + return ProgramPoint.from_int_value(self.int_value + steps) + + def prev(self, steps=1): + # type: (int) -> ProgramPoint + return self.next(steps=-steps) + + def __lt__(self, other): + # type: (ProgramPoint | Any) -> bool + if not isinstance(other, ProgramPoint): + return NotImplemented + if self.op_index != other.op_index: + return self.op_index < other.op_index + return self.stage < other.stage + + def __repr__(self): + # type: () -> str + return f"" + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta): + __slots__ = "start", "stop" + + def __init__(self, start, stop): + # type: (ProgramPoint, ProgramPoint) -> None + self.start = start + self.stop = stop + + @cached_property + def int_value_range(self): + # type: () -> range + return range(self.start.int_value, self.stop.int_value) + + @staticmethod + def from_int_value_range(int_value_range): + # type: (range) -> ProgramRange + if int_value_range.step != 1: + raise ValueError("int_value_range must have step == 1") + return ProgramRange( + start=ProgramPoint.from_int_value(int_value_range.start), + stop=ProgramPoint.from_int_value(int_value_range.stop)) + + @overload + def __getitem__(self, __idx): + # type: (int) -> ProgramPoint + ... + + @overload + def __getitem__(self, __idx): + # type: (slice) -> ProgramRange + ... 
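# --- editor's illustrative sketch, not part of the patch above ---
# How the ProgramPoint encoding defined above behaves (assuming this commit's
# package is importable): every op occupies two consecutive points, Early then
# Late, and int_value == op_index * 2 + stage.value, so integer order matches
# program order and next()/prev() step through stages.
from bigint_presentation_code.compiler_ir import OpStage, ProgramPoint

p = ProgramPoint(op_index=3, stage=OpStage.Early)
assert p.int_value == 6  # 3 * 2 + 0
assert p.next() == ProgramPoint(op_index=3, stage=OpStage.Late)
assert p.next(steps=2) == ProgramPoint(op_index=4, stage=OpStage.Early)
assert p.prev() == ProgramPoint(op_index=2, stage=OpStage.Late)
assert ProgramPoint.from_int_value(p.int_value) == p
assert p < p.next()
# --- end of editor's sketch ---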
+ + def __getitem__(self, __idx): + # type: (int | slice) -> ProgramPoint | ProgramRange + v = range(self.start.int_value, self.stop.int_value)[__idx] + if isinstance(v, int): + return ProgramPoint.from_int_value(v) + return ProgramRange.from_int_value_range(v) + + def __len__(self): + # type: () -> int + return len(self.int_value_range) + + def __iter__(self): + # type: () -> Iterator[ProgramPoint] + return map(ProgramPoint.from_int_value, self.int_value_range) + + def __repr__(self): + # type: () -> str + start = repr(self.start).lstrip("<").rstrip(">") + stop = repr(self.stop).lstrip("<").rstrip(">") + return f"" + + +@plain_data(frozen=True, eq=False, repr=False) +@final +class FnAnalysis: + __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at", + "def_program_ranges", "use_program_points", + "all_program_points") + + def __init__(self, fn): + # type: (Fn) -> None + self.fn = fn + self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops)) + self.all_program_points = ProgramRange( + start=ProgramPoint(op_index=0, stage=OpStage.Early), + stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early)) + def_program_ranges = {} # type: dict[SSAVal, ProgramRange] + use_program_points = {} # type: dict[SSAUse, ProgramPoint] + uses = {} # type: dict[SSAVal, OSet[SSAUse]] + live_range_stops = {} # type: dict[SSAVal, ProgramPoint] + for op in fn.ops: + for use in op.input_uses: + uses[use.ssa_val].add(use) + use_program_point = self.__get_use_program_point(use) + use_program_points[use] = use_program_point + live_range_stops[use.ssa_val] = max( + live_range_stops[use.ssa_val], use_program_point.next()) + for out in op.outputs: + uses[out] = OSet() + def_program_range = self.__get_def_program_range(out) + def_program_ranges[out] = def_program_range + live_range_stops[out] = def_program_range.stop + self.uses = FMap((k, OFSet(v)) for k, v in uses.items()) + self.def_program_ranges = FMap(def_program_ranges) + self.use_program_points = FMap(use_program_points) + live_ranges = {} # type: dict[SSAVal, ProgramRange] + live_at = {i: OSet[SSAVal]() for i in self.all_program_points} + for ssa_val in uses.keys(): + live_ranges[ssa_val] = live_range = ProgramRange( + start=self.def_program_ranges[ssa_val].start, + stop=live_range_stops[ssa_val]) + for program_point in live_range: + live_at[program_point].add(ssa_val) + self.live_ranges = FMap(live_ranges) + self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items()) + + def __get_def_program_range(self, ssa_val): + # type: (SSAVal) -> ProgramRange + write_stage = ssa_val.defining_descriptor.write_stage + start = ProgramPoint( + op_index=self.op_indexes[ssa_val.op], stage=write_stage) + # always include late stage of ssa_val.op, to ensure outputs always + # overlap all other outputs. + # stop is exclusive, so we need the next program point. 
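# --- editor's illustrative sketch, not part of the patch above ---
# What FnAnalysis computes for a small function (assuming this commit's
# package is importable): live_ranges maps each SSA value to the half-open
# ProgramRange from its defining write up to just past its last use, and
# live_at maps every ProgramPoint to the set of values live there.
from bigint_presentation_code.compiler_ir import (Fn, FnAnalysis, OpKind,
                                                  OpStage, ProgramPoint)

fn = Fn()
arg = fn.append_new_op(OpKind.FuncArgR3, name="arg").outputs[0]
vl = fn.append_new_op(OpKind.SetVLI, immediates=[32], name="vl").outputs[0]
ld = fn.append_new_op(OpKind.SvLd, input_vals=[arg, vl], immediates=[0],
                      maxvl=32, name="ld").outputs[0]
fn.append_new_op(OpKind.SvStd, input_vals=[ld, arg, vl], immediates=[0],
                 maxvl=32, name="st")
analysis = FnAnalysis(fn)
# `arg` is written by op 0 (Early) and last read by op 3 (Early), so its
# live range stops at op 3's Late stage (the stop point is exclusive).
assert analysis.live_ranges[arg].start == ProgramPoint(op_index=0,
                                                       stage=OpStage.Early)
assert analysis.live_ranges[arg].stop == ProgramPoint(op_index=3,
                                                      stage=OpStage.Late)
assert arg in analysis.live_at[ProgramPoint(op_index=2, stage=OpStage.Early)]
# --- end of editor's sketch ---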
+ stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next() + return ProgramRange(start=start, stop=stop) + + def __get_use_program_point(self, ssa_use): + # type: (SSAUse) -> ProgramPoint + assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \ + "assumed here, ensured by GenericOpProperties.__init__" + return ProgramPoint( + op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early) + + def __eq__(self, other): + # type: (FnAnalysis | Any) -> bool + if isinstance(other, FnAnalysis): + return self.fn == other.fn + return NotImplemented + + def __hash__(self): + # type: () -> int + return hash(self.fn) + + def __repr__(self): + # type: () -> str + return "" + + +@unique +@final +class BaseTy(Enum): + I64 = enum.auto() + CA = enum.auto() + VL_MAXVL = enum.auto() + + @cached_property + def only_scalar(self): + # type: () -> bool + if self is BaseTy.I64: + return False + elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: + return True + else: + assert_never(self) + + @cached_property + def max_reg_len(self): + # type: () -> int + if self is BaseTy.I64: + return 128 + elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: + return 1 + else: + assert_never(self) + + def __repr__(self): + return "BaseTy." + self._name_ + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +class Ty(metaclass=InternedMeta): + __slots__ = "base_ty", "reg_len" + + @staticmethod + def validate(base_ty, reg_len): + # type: (BaseTy, int) -> str | None + """ return a string with the error if the combination is invalid, + otherwise return None + """ + if base_ty.only_scalar and reg_len != 1: + return f"can't create a vector of an only-scalar type: {base_ty}" + if reg_len < 1 or reg_len > base_ty.max_reg_len: + return "reg_len out of range" + return None + + def __init__(self, base_ty, reg_len): + # type: (BaseTy, int) -> None + msg = self.validate(base_ty=base_ty, reg_len=reg_len) + if msg is not None: + raise ValueError(msg) + self.base_ty = base_ty + self.reg_len = reg_len + + def __repr__(self): + # type: () -> str + if self.reg_len != 1: + reg_len = f"*{self.reg_len}" + else: + reg_len = "" + return f"<{self.base_ty._name_}{reg_len}>" + + +@unique +@final +class LocKind(Enum): + GPR = enum.auto() + StackI64 = enum.auto() + CA = enum.auto() + VL_MAXVL = enum.auto() + + @cached_property + def base_ty(self): + # type: () -> BaseTy + if self is LocKind.GPR or self is LocKind.StackI64: + return BaseTy.I64 + if self is LocKind.CA: + return BaseTy.CA + if self is LocKind.VL_MAXVL: + return BaseTy.VL_MAXVL + else: + assert_never(self) + + @cached_property + def loc_count(self): + # type: () -> int + if self is LocKind.StackI64: + return 512 + if self is LocKind.GPR or self is LocKind.CA \ + or self is LocKind.VL_MAXVL: + return self.base_ty.max_reg_len + else: + assert_never(self) + + def __repr__(self): + return "LocKind." 
+ self._name_ + + +@final +@unique +class LocSubKind(Enum): + BASE_GPR = enum.auto() + SV_EXTRA2_VGPR = enum.auto() + SV_EXTRA2_SGPR = enum.auto() + SV_EXTRA3_VGPR = enum.auto() + SV_EXTRA3_SGPR = enum.auto() + StackI64 = enum.auto() + CA = enum.auto() + VL_MAXVL = enum.auto() + + @cached_property + def kind(self): + # type: () -> LocKind + # pyright fails typechecking when using `in` here: + # reported: https://github.com/microsoft/pyright/issues/4102 + if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR, + LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR, + LocSubKind.SV_EXTRA3_SGPR): + return LocKind.GPR + if self is LocSubKind.StackI64: + return LocKind.StackI64 + if self is LocSubKind.CA: + return LocKind.CA + if self is LocSubKind.VL_MAXVL: + return LocKind.VL_MAXVL + assert_never(self) + + @property + def base_ty(self): + return self.kind.base_ty + + @lru_cache() + def allocatable_locs(self, ty): + # type: (Ty) -> LocSet + if ty.base_ty != self.base_ty: + raise ValueError("type mismatch") + if self is LocSubKind.BASE_GPR: + starts = range(32) + elif self is LocSubKind.SV_EXTRA2_VGPR: + starts = range(0, 128, 2) + elif self is LocSubKind.SV_EXTRA2_SGPR: + starts = range(64) + elif self is LocSubKind.SV_EXTRA3_VGPR \ + or self is LocSubKind.SV_EXTRA3_SGPR: + starts = range(128) + elif self is LocSubKind.StackI64: + starts = range(LocKind.StackI64.loc_count) + elif self is LocSubKind.CA or self is LocSubKind.VL_MAXVL: + return LocSet([Loc(kind=self.kind, start=0, reg_len=1)]) + else: + assert_never(self) + retval = [] # type: list[Loc] + for start in starts: + loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len) + if loc is None: + continue + conflicts = False + for special_loc in SPECIAL_GPRS: + if loc.conflicts(special_loc): + conflicts = True + break + if not conflicts: + retval.append(loc) + return LocSet(retval) + + def __repr__(self): + return "LocSubKind." 
+ self._name_ + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class GenericTy(metaclass=InternedMeta): + __slots__ = "base_ty", "is_vec" + + def __init__(self, base_ty, is_vec): + # type: (BaseTy, bool) -> None + self.base_ty = base_ty + if base_ty.only_scalar and is_vec: + raise ValueError(f"base_ty={base_ty} requires is_vec=False") + self.is_vec = is_vec + + def instantiate(self, maxvl): + # type: (int) -> Ty + # here's where subvl and elwid would be accounted for + if self.is_vec: + return Ty(self.base_ty, maxvl) + return Ty(self.base_ty, 1) + + def can_instantiate_to(self, ty): + # type: (Ty) -> bool + if self.base_ty != ty.base_ty: + return False + if self.is_vec: + return True + return ty.reg_len == 1 + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class Loc(metaclass=InternedMeta): + __slots__ = "kind", "start", "reg_len" + + @staticmethod + def validate(kind, start, reg_len): + # type: (LocKind, int, int) -> str | None + msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len) + if msg is not None: + return msg + if reg_len > kind.loc_count: + return "invalid reg_len" + if start < 0 or start + reg_len > kind.loc_count: + return "start not in valid range" + return None + + @staticmethod + def try_make(kind, start, reg_len): + # type: (LocKind, int, int) -> Loc | None + msg = Loc.validate(kind=kind, start=start, reg_len=reg_len) + if msg is not None: + return None + return Loc(kind=kind, start=start, reg_len=reg_len) + + def __init__(self, kind, start, reg_len): + # type: (LocKind, int, int) -> None + msg = self.validate(kind=kind, start=start, reg_len=reg_len) + if msg is not None: + raise ValueError(msg) + self.kind = kind + self.reg_len = reg_len + self.start = start + + def conflicts(self, other): + # type: (Loc) -> bool + return (self.kind == other.kind + and self.start < other.stop and other.start < self.stop) + + @staticmethod + def make_ty(kind, reg_len): + # type: (LocKind, int) -> Ty + return Ty(base_ty=kind.base_ty, reg_len=reg_len) + + @cached_property + def ty(self): + # type: () -> Ty + return self.make_ty(kind=self.kind, reg_len=self.reg_len) + + @property + def stop(self): + # type: () -> int + return self.start + self.reg_len + + def try_concat(self, *others): + # type: (*Loc | None) -> Loc | None + reg_len = self.reg_len + stop = self.stop + for other in others: + if other is None or other.kind != self.kind: + return None + if stop != other.start: + return None + stop = other.stop + reg_len += other.reg_len + return Loc(kind=self.kind, start=self.start, reg_len=reg_len) + + def get_subloc_at_offset(self, subloc_ty, offset): + # type: (Ty, int) -> Loc + if subloc_ty.base_ty != self.kind.base_ty: + raise ValueError("BaseTy mismatch") + if offset < 0 or offset + subloc_ty.reg_len > self.reg_len: + raise ValueError("invalid sub-Loc: offset and/or " + "subloc_ty.reg_len out of range") + return Loc(kind=self.kind, + start=self.start + offset, reg_len=subloc_ty.reg_len) + + +SPECIAL_GPRS = ( + Loc(kind=LocKind.GPR, start=0, reg_len=1), + Loc(kind=LocKind.GPR, start=1, reg_len=1), + Loc(kind=LocKind.GPR, start=2, reg_len=1), + Loc(kind=LocKind.GPR, start=13, reg_len=1), +) + + +@final +class _LocSetHashHelper(AbstractSet[Loc]): + """helper to more quickly compute LocSet's hash""" + + def __init__(self, locs): + # type: (Iterable[Loc]) -> None + super().__init__() + self.locs = list(locs) + + def __hash__(self): + # type: () -> int + return super()._hash() + + def __contains__(self, x): + # type: (Loc | Any) -> bool + return x in self.locs + + def 
__iter__(self): + # type: () -> Iterator[Loc] + return iter(self.locs) + + def __len__(self): + return len(self.locs) + + +@plain_data(frozen=True, eq=False, repr=False) +@final +class LocSet(AbstractSet[Loc], metaclass=InternedMeta): + __slots__ = "starts", "ty", "_LocSet__hash" + + def __init__(self, __locs=()): + # type: (Iterable[Loc]) -> None + if isinstance(__locs, LocSet): + self.starts = __locs.starts # type: FMap[LocKind, FBitSet] + self.ty = __locs.ty # type: Ty | None + self._LocSet__hash = __locs._LocSet__hash # type: int + return + starts = {i: BitSet() for i in LocKind} + ty = None # type: None | Ty + + def locs(): + # type: () -> Iterable[Loc] + nonlocal ty + for loc in __locs: + if ty is None: + ty = loc.ty + if ty != loc.ty: + raise ValueError(f"conflicting types: {ty} != {loc.ty}") + starts[loc.kind].add(loc.start) + yield loc + self._LocSet__hash = _LocSetHashHelper(locs()).__hash__() + self.starts = FMap( + (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0) + self.ty = ty + + @cached_property + def stops(self): + # type: () -> FMap[LocKind, FBitSet] + if self.ty is None: + return FMap() + sh = self.ty.reg_len + return FMap( + (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items()) + + @property + def kinds(self): + # type: () -> AbstractSet[LocKind] + return self.starts.keys() + + @property + def reg_len(self): + # type: () -> int | None + if self.ty is None: + return None + return self.ty.reg_len + + @property + def base_ty(self): + # type: () -> BaseTy | None + if self.ty is None: + return None + return self.ty.base_ty + + def concat(self, *others): + # type: (*LocSet) -> LocSet + if self.ty is None: + return LocSet() + base_ty = self.ty.base_ty + reg_len = self.ty.reg_len + starts = {k: BitSet(v) for k, v in self.starts.items()} + for other in others: + if other.ty is None: + return LocSet() + if other.ty.base_ty != base_ty: + return LocSet() + for kind, other_starts in other.starts.items(): + if kind not in starts: + continue + starts[kind].bits &= other_starts.bits >> reg_len + if starts[kind] == 0: + del starts[kind] + if len(starts) == 0: + return LocSet() + reg_len += other.ty.reg_len + + def locs(): + # type: () -> Iterable[Loc] + for kind, v in starts.items(): + for start in v: + loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len) + if loc is not None: + yield loc + return LocSet(locs()) + + def __contains__(self, loc): + # type: (Loc | Any) -> bool + if not isinstance(loc, Loc) or loc.ty != self.ty: + return False + if loc.kind not in self.starts: + return False + return loc.start in self.starts[loc.kind] + + def __iter__(self): + # type: () -> Iterator[Loc] + if self.ty is None: + return + for kind, starts in self.starts.items(): + for start in starts: + yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len) + + @cached_property + def __len(self): + return sum((len(v) for v in self.starts.values()), 0) + + def __len__(self): + return self.__len + + def __hash__(self): + return self._LocSet__hash + + def __eq__(self, __other): + # type: (LocSet | Any) -> bool + if isinstance(__other, LocSet): + return self.ty == __other.ty and self.starts == __other.starts + return super().__eq__(__other) + + @lru_cache(maxsize=None, typed=True) + def max_conflicts_with(self, other): + # type: (LocSet | Loc) -> int + """the largest number of Locs in `self` that a single Loc + from `other` can conflict with + """ + if isinstance(other, LocSet): + return max(self.max_conflicts_with(i) for i in other) + else: + return sum(other.conflicts(i) for i in 
self) + + def __repr__(self): + items = [] # type: list[str] + for name in fields(self): + if name.startswith("_"): + continue + items.append(f"{name}={getattr(self, name)!r}") + return f"LocSet({', '.join(items)})" + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class GenericOperandDesc(metaclass=InternedMeta): + """generic Op operand descriptor""" + __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread", + "write_stage") + + def __init__( + self, ty, # type: GenericTy + sub_kinds, # type: Iterable[LocSubKind] + *, + fixed_loc=None, # type: Loc | None + tied_input_index=None, # type: int | None + spread=False, # type: bool + write_stage=OpStage.Early, # type: OpStage + ): + # type: (...) -> None + self.ty = ty + self.sub_kinds = OFSet(sub_kinds) + if len(self.sub_kinds) == 0: + raise ValueError("sub_kinds can't be empty") + self.fixed_loc = fixed_loc + if fixed_loc is not None: + if tied_input_index is not None: + raise ValueError("operand can't be both tied and fixed") + if not ty.can_instantiate_to(fixed_loc.ty): + raise ValueError( + f"fixed_loc has incompatible type for given generic " + f"type: fixed_loc={fixed_loc} generic ty={ty}") + if len(self.sub_kinds) != 1: + raise ValueError( + "multiple sub_kinds not allowed for fixed operand") + for sub_kind in self.sub_kinds: + if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty): + raise ValueError( + f"fixed_loc not in given sub_kind: " + f"fixed_loc={fixed_loc} sub_kind={sub_kind}") + for sub_kind in self.sub_kinds: + if sub_kind.base_ty != ty.base_ty: + raise ValueError(f"sub_kind is incompatible with type: " + f"sub_kind={sub_kind} ty={ty}") + if tied_input_index is not None and tied_input_index < 0: + raise ValueError("invalid tied_input_index") + self.tied_input_index = tied_input_index + self.spread = spread + if spread: + if self.tied_input_index is not None: + raise ValueError("operand can't be both spread and tied") + if self.fixed_loc is not None: + raise ValueError("operand can't be both spread and fixed") + if self.ty.is_vec: + raise ValueError("operand can't be both spread and vector") + self.write_stage = write_stage + + @cached_property + def ty_before_spread(self): + # type: () -> GenericTy + if self.spread: + return GenericTy(base_ty=self.ty.base_ty, is_vec=True) + return self.ty + + def tied_to_input(self, tied_input_index): + # type: (int) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, + tied_input_index=tied_input_index, + write_stage=self.write_stage) + + def with_fixed_loc(self, fixed_loc): + # type: (Loc) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc, + write_stage=self.write_stage) + + def with_write_stage(self, write_stage): + # type: (OpStage) -> Self + return GenericOperandDesc(self.ty, self.sub_kinds, + fixed_loc=self.fixed_loc, + tied_input_index=self.tied_input_index, + spread=self.spread, + write_stage=write_stage) + + def instantiate(self, maxvl): + # type: (int) -> Iterable[OperandDesc] + # assumes all spread operands have ty.reg_len = 1 + rep_count = 1 + if self.spread: + rep_count = maxvl + ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl) + + def locs_before_spread(): + # type: () -> Iterable[Loc] + if self.fixed_loc is not None: + if ty_before_spread != self.fixed_loc.ty: + raise ValueError( + f"instantiation failed: type mismatch with fixed_loc: " + f"instantiated type: {ty_before_spread} " + f"fixed_loc: {self.fixed_loc}") + yield self.fixed_loc + return + for sub_kind in self.sub_kinds: + yield from 
sub_kind.allocatable_locs(ty_before_spread) + loc_set_before_spread = LocSet(locs_before_spread()) + for idx in range(rep_count): + if not self.spread: + idx = None + yield OperandDesc(loc_set_before_spread=loc_set_before_spread, + tied_input_index=self.tied_input_index, + spread_index=idx, write_stage=self.write_stage) + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class OperandDesc(metaclass=InternedMeta): + """Op operand descriptor""" + __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index", + "write_stage") + + def __init__(self, loc_set_before_spread, tied_input_index, spread_index, + write_stage): + # type: (LocSet, int | None, int | None, OpStage) -> None + if len(loc_set_before_spread) == 0: + raise ValueError("loc_set_before_spread must not be empty") + self.loc_set_before_spread = loc_set_before_spread + self.tied_input_index = tied_input_index + if self.tied_input_index is not None and spread_index is not None: + raise ValueError("operand can't be both spread and tied") + self.spread_index = spread_index + self.write_stage = write_stage + + @cached_property + def ty_before_spread(self): + # type: () -> Ty + ty = self.loc_set_before_spread.ty + assert ty is not None, ( + "__init__ checked that the LocSet isn't empty, " + "non-empty LocSets should always have ty set") + return ty + + @cached_property + def ty(self): + """ Ty after any spread is applied """ + if self.spread_index is not None: + # assumes all spread operands have ty.reg_len = 1 + return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1) + return self.ty_before_spread + + @property + def reg_offset_in_unspread(self): + """ the number of reg-sized slots in the unspread Loc before self's Loc + + e.g. if the unspread Loc containing self is: + `Loc(kind=LocKind.GPR, start=8, reg_len=4)` + and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)` + then reg_offset_into_unspread == 2 == 10 - 8 + """ + if self.spread_index is None: + return 0 + return self.spread_index * self.ty.reg_len + + +OD_BASE_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.BASE_GPR]) +OD_EXTRA3_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_SGPR]) +OD_EXTRA3_VGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA3_VGPR]) +OD_EXTRA2_SGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA2_SGPR]) +OD_EXTRA2_VGPR = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA2_VGPR]) +OD_CA = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.CA, is_vec=False), + sub_kinds=[LocSubKind.CA]) +OD_VL = GenericOperandDesc( + ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False), + sub_kinds=[LocSubKind.VL_MAXVL]) + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class GenericOpProperties(metaclass=InternedMeta): + __slots__ = ("demo_asm", "inputs", "outputs", "immediates", + "is_copy", "is_load_immediate", "has_side_effects") + + def __init__( + self, demo_asm, # type: str + inputs, # type: Iterable[GenericOperandDesc] + outputs, # type: Iterable[GenericOperandDesc] + immediates=(), # type: Iterable[range] + is_copy=False, # type: bool + is_load_immediate=False, # type: bool + has_side_effects=False, # type: bool + ): + # type: (...) -> None + self.demo_asm = demo_asm # type: str + self.inputs = tuple(inputs) # type: tuple[GenericOperandDesc, ...] 
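# editor's note: the validation loop that follows enforces two invariants on
# input operand descriptors: inputs may never carry a tied_input_index (only
# outputs are tied back to inputs), and a non-default write_stage is rejected
# on inputs, since all operand reads happen in OpStage.Early (see OpStage).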
+ for inp in self.inputs: + if inp.tied_input_index is not None: + raise ValueError( + f"tied_input_index is not allowed on inputs: {inp}") + if inp.write_stage is not OpStage.Early: + raise ValueError( + f"write_stage is not allowed on inputs: {inp}") + self.outputs = tuple(outputs) # type: tuple[GenericOperandDesc, ...] + fixed_locs = [] # type: list[tuple[Loc, int]] + for idx, out in enumerate(self.outputs): + if out.tied_input_index is not None: + if out.tied_input_index >= len(self.inputs): + raise ValueError(f"tied_input_index out of range: {out}") + tied_inp = self.inputs[out.tied_input_index] + expected_out = tied_inp.tied_to_input(out.tied_input_index) \ + .with_write_stage(out.write_stage) + if expected_out != out: + raise ValueError(f"output can't be tied to non-equivalent " + f"input: {out} tied to {tied_inp}") + if out.fixed_loc is not None: + for other_fixed_loc, other_idx in fixed_locs: + if not other_fixed_loc.conflicts(out.fixed_loc): + continue + raise ValueError( + f"conflicting fixed_locs: outputs[{idx}] and " + f"outputs[{other_idx}]: {out.fixed_loc} conflicts " + f"with {other_fixed_loc}") + fixed_locs.append((out.fixed_loc, idx)) + self.immediates = tuple(immediates) # type: tuple[range, ...] + self.is_copy = is_copy # type: bool + self.is_load_immediate = is_load_immediate # type: bool + self.has_side_effects = has_side_effects # type: bool + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class OpProperties(metaclass=InternedMeta): + __slots__ = "kind", "inputs", "outputs", "maxvl" + + def __init__(self, kind, maxvl): + # type: (OpKind, int) -> None + self.kind = kind # type: OpKind + inputs = [] # type: list[OperandDesc] + for inp in self.generic.inputs: + inputs.extend(inp.instantiate(maxvl=maxvl)) + self.inputs = tuple(inputs) # type: tuple[OperandDesc, ...] + outputs = [] # type: list[OperandDesc] + for out in self.generic.outputs: + outputs.extend(out.instantiate(maxvl=maxvl)) + self.outputs = tuple(outputs) # type: tuple[OperandDesc, ...] + self.maxvl = maxvl # type: int + + @property + def generic(self): + # type: () -> GenericOpProperties + return self.kind.properties + + @property + def immediates(self): + # type: () -> tuple[range, ...] + return self.generic.immediates + + @property + def demo_asm(self): + # type: () -> str + return self.generic.demo_asm + + @property + def is_copy(self): + # type: () -> bool + return self.generic.is_copy + + @property + def is_load_immediate(self): + # type: () -> bool + return self.generic.is_load_immediate + + @property + def has_side_effects(self): + # type: () -> bool + return self.generic.has_side_effects + + +IMM_S16 = range(-1 << 15, 1 << 15) + +_SIM_FN = Callable[["Op", "BaseSimState"], None] +_SIM_FN2 = Callable[[], _SIM_FN] +_SIM_FNS = {} # type: dict[GenericOpProperties | Any, _SIM_FN2] +_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None] +_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN] +_GEN_ASMS = {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2] + + +@unique +@final +class OpKind(Enum): + def __init__(self, properties): + # type: (GenericOpProperties) -> None + super().__init__() + self.__properties = properties + + @property + def properties(self): + # type: () -> GenericOpProperties + return self.__properties + + def instantiate(self, maxvl): + # type: (int) -> OpProperties + return OpProperties(self, maxvl=maxvl) + + def __repr__(self): + # type: () -> str + return "OpKind." 
+ self._name_ + + @cached_property + def sim(self): + # type: () -> _SIM_FN + return _SIM_FNS[self.properties]() + + @cached_property + def gen_asm(self): + # type: () -> _GEN_ASM_FN + return _GEN_ASMS[self.properties]() + + @staticmethod + def __clearca_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = False, + + @staticmethod + def __clearca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("addic 0, 0, 0") + ClearCA = GenericOpProperties( + demo_asm="addic 0, 0, 0", + inputs=[], + outputs=[OD_CA.with_write_stage(OpStage.Late)], + ) + _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim + _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm + + @staticmethod + def __setca_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = True, + + @staticmethod + def __setca_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + state.writeln("subfc 0, 0, 0") + SetCA = GenericOpProperties( + demo_asm="subfc 0, 0, 0", + inputs=[], + outputs=[OD_CA.with_write_stage(OpStage.Late)], + ) + _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim + _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm + + @staticmethod + def __svadde_sim(op, state): + # type: (Op, BaseSimState) -> None + RA = state[op.input_vals[0]] + RB = state[op.input_vals[1]] + carry, = state[op.input_vals[2]] + VL, = state[op.input_vals[3]] + RT = [] # type: list[int] + for i in range(VL): + v = RA[i] + RB[i] + carry + RT.append(v & GPR_VALUE_MASK) + carry = (v >> GPR_SIZE_IN_BITS) != 0 + state[op.outputs[0]] = tuple(RT) + state[op.outputs[1]] = carry, + + @staticmethod + def __svadde_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + RB = state.vgpr(op.input_vals[1]) + state.writeln(f"sv.adde {RT}, {RA}, {RB}") + SvAddE = GenericOpProperties( + demo_asm="sv.adde *RT, *RA, *RB", + inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], + outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], + ) + _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim + _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm + + @staticmethod + def __addze_sim(op, state): + # type: (Op, BaseSimState) -> None + RA, = state[op.input_vals[0]] + carry, = state[op.input_vals[1]] + v = RA + carry + RT = v & GPR_VALUE_MASK + carry = (v >> GPR_SIZE_IN_BITS) != 0 + state[op.outputs[0]] = RT, + state[op.outputs[1]] = carry, + + @staticmethod + def __addze_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + state.writeln(f"addze {RT}, {RA}") + AddZE = GenericOpProperties( + demo_asm="addze RT, RA", + inputs=[OD_BASE_SGPR, OD_CA], + outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)], + ) + _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim + _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm + + @staticmethod + def __svsubfe_sim(op, state): + # type: (Op, BaseSimState) -> None + RA = state[op.input_vals[0]] + RB = state[op.input_vals[1]] + carry, = state[op.input_vals[2]] + VL, = state[op.input_vals[3]] + RT = [] # type: list[int] + for i in range(VL): + v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry + RT.append(v & GPR_VALUE_MASK) + carry = (v >> GPR_SIZE_IN_BITS) != 0 + state[op.outputs[0]] = tuple(RT) + state[op.outputs[1]] = carry, + + @staticmethod + def __svsubfe_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + RB = state.vgpr(op.input_vals[1]) + state.writeln(f"sv.subfe {RT}, {RA}, {RB}") + SvSubFE = 
GenericOpProperties( + demo_asm="sv.subfe *RT, *RA, *RB", + inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], + outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], + ) + _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim + _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm + + @staticmethod + def __svmaddedu_sim(op, state): + # type: (Op, BaseSimState) -> None + RA = state[op.input_vals[0]] + RB, = state[op.input_vals[1]] + carry, = state[op.input_vals[2]] + VL, = state[op.input_vals[3]] + RT = [] # type: list[int] + for i in range(VL): + v = RA[i] * RB + carry + RT.append(v & GPR_VALUE_MASK) + carry = v >> GPR_SIZE_IN_BITS + state[op.outputs[0]] = tuple(RT) + state[op.outputs[1]] = carry, + + @staticmethod + def __svmaddedu_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + RA = state.vgpr(op.input_vals[0]) + RB = state.sgpr(op.input_vals[1]) + RC = state.sgpr(op.input_vals[2]) + state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}") + SvMAddEDU = GenericOpProperties( + demo_asm="sv.maddedu *RT, *RA, RB, RC", + inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL], + outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)], + ) + _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim + _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm + + @staticmethod + def __setvli_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = op.immediates[0], + + @staticmethod + def __setvli_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + imm = op.immediates[0] + state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1") + SetVLI = GenericOpProperties( + demo_asm="setvl 0, 0, imm, 0, 1, 1", + inputs=(), + outputs=[OD_VL.with_write_stage(OpStage.Late)], + immediates=[range(1, 65)], + is_load_immediate=True, + ) + _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim + _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm + + @staticmethod + def __svli_sim(op, state): + # type: (Op, BaseSimState) -> None + VL, = state[op.input_vals[0]] + imm = op.immediates[0] & GPR_VALUE_MASK + state[op.outputs[0]] = (imm,) * VL + + @staticmethod + def __svli_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.vgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"sv.addi {RT}, 0, {imm}") + SvLI = GenericOpProperties( + demo_asm="sv.addi *RT, 0, imm", + inputs=[OD_VL], + outputs=[OD_EXTRA3_VGPR], + immediates=[IMM_S16], + is_load_immediate=True, + ) + _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim + _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm + + @staticmethod + def __li_sim(op, state): + # type: (Op, BaseSimState) -> None + imm = op.immediates[0] & GPR_VALUE_MASK + state[op.outputs[0]] = imm, + + @staticmethod + def __li_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RT = state.sgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"addi {RT}, 0, {imm}") + LI = GenericOpProperties( + demo_asm="addi RT, 0, imm", + inputs=(), + outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)], + immediates=[IMM_S16], + is_load_immediate=True, + ) + _SIM_FNS[LI] = lambda: OpKind.__li_sim + _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm + + @staticmethod + def __veccopytoreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] + + @staticmethod + def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state): + # type: (Loc, Loc, bool, GenAsmState) -> None + sv = "sv." 
if is_vec else "" + rev = "" + if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start: + rev = "/mrr" + if src_loc == dest_loc: + return # no-op + if src_loc.kind not in (LocKind.GPR, LocKind.StackI64): + raise ValueError(f"invalid src_loc.kind: {src_loc.kind}") + if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64): + raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}") + if src_loc.kind is LocKind.StackI64: + if dest_loc.kind is LocKind.StackI64: + raise ValueError( + f"can't copy from stack to stack: {src_loc} {dest_loc}") + elif dest_loc.kind is not LocKind.GPR: + assert_never(dest_loc.kind) + src = state.stack(src_loc) + dest = state.gpr(dest_loc, is_vec=is_vec) + state.writeln(f"{sv}ld {dest}, {src}") + elif dest_loc.kind is LocKind.StackI64: + if src_loc.kind is not LocKind.GPR: + assert_never(src_loc.kind) + src = state.gpr(src_loc, is_vec=is_vec) + dest = state.stack(dest_loc) + state.writeln(f"{sv}std {src}, {dest}") + elif src_loc.kind is LocKind.GPR: + if dest_loc.kind is not LocKind.GPR: + assert_never(dest_loc.kind) + src = state.gpr(src_loc, is_vec=is_vec) + dest = state.gpr(dest_loc, is_vec=is_vec) + state.writeln(f"{sv}or{rev} {dest}, {src}, {src}") + else: + assert_never(src_loc.kind) + + @staticmethod + def __veccopytoreg_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc( + op.input_vals[0], (LocKind.GPR, LocKind.StackI64)), + dest_loc=state.loc(op.outputs[0], LocKind.GPR), + is_vec=True, state=state) + + VecCopyToReg = GenericOpProperties( + demo_asm="sv.mv dest, src", + inputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64], + ), OD_VL], + outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], + is_copy=True, + ) + _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim + _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm + + @staticmethod + def __veccopyfromreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] + + @staticmethod + def __veccopyfromreg_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0], LocKind.GPR), + dest_loc=state.loc( + op.outputs[0], (LocKind.GPR, LocKind.StackI64)), + is_vec=True, state=state) + VecCopyFromReg = GenericOpProperties( + demo_asm="sv.mv dest, src", + inputs=[OD_EXTRA3_VGPR, OD_VL], + outputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=True), + sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64], + write_stage=OpStage.Late, + )], + is_copy=True, + ) + _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim + _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm + + @staticmethod + def __copytoreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] + + @staticmethod + def __copytoreg_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc( + op.input_vals[0], (LocKind.GPR, LocKind.StackI64)), + dest_loc=state.loc(op.outputs[0], LocKind.GPR), + is_vec=False, state=state) + CopyToReg = GenericOpProperties( + demo_asm="mv dest, src", + inputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR, + LocSubKind.StackI64], + )], + outputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=False), + 
sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR], + write_stage=OpStage.Late, + )], + is_copy=True, + ) + _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim + _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm + + @staticmethod + def __copyfromreg_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = state[op.input_vals[0]] + + @staticmethod + def __copyfromreg_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0], LocKind.GPR), + dest_loc=state.loc( + op.outputs[0], (LocKind.GPR, LocKind.StackI64)), + is_vec=False, state=state) + CopyFromReg = GenericOpProperties( + demo_asm="mv dest, src", + inputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR], + )], + outputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR, + LocSubKind.StackI64], + write_stage=OpStage.Late, + )], + is_copy=True, + ) + _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim + _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm + + @staticmethod + def __concat_sim(op, state): + # type: (Op, BaseSimState) -> None + state[op.outputs[0]] = tuple( + state[i][0] for i in op.input_vals[:-1]) + + @staticmethod + def __concat_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR), + dest_loc=state.loc(op.outputs[0], LocKind.GPR), + is_vec=True, state=state) + Concat = GenericOpProperties( + demo_asm="sv.mv dest, src", + inputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_VGPR], + spread=True, + ), OD_VL], + outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], + is_copy=True, + ) + _SIM_FNS[Concat] = lambda: OpKind.__concat_sim + _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm + + @staticmethod + def __spread_sim(op, state): + # type: (Op, BaseSimState) -> None + for idx, inp in enumerate(state[op.input_vals[0]]): + state[op.outputs[idx]] = inp, + + @staticmethod + def __spread_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + OpKind.__copy_to_from_reg_gen_asm( + src_loc=state.loc(op.input_vals[0], LocKind.GPR), + dest_loc=state.loc(op.outputs, LocKind.GPR), + is_vec=True, state=state) + Spread = GenericOpProperties( + demo_asm="sv.mv dest, src", + inputs=[OD_EXTRA3_VGPR, OD_VL], + outputs=[GenericOperandDesc( + ty=GenericTy(BaseTy.I64, is_vec=False), + sub_kinds=[LocSubKind.SV_EXTRA3_VGPR], + spread=True, + write_stage=OpStage.Late, + )], + is_copy=True, + ) + _SIM_FNS[Spread] = lambda: OpKind.__spread_sim + _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm + + @staticmethod + def __svld_sim(op, state): + # type: (Op, BaseSimState) -> None + RA, = state[op.input_vals[0]] + VL, = state[op.input_vals[1]] + addr = RA + op.immediates[0] + RT = [] # type: list[int] + for i in range(VL): + v = state.load(addr + GPR_SIZE_IN_BYTES * i) + RT.append(v & GPR_VALUE_MASK) + state[op.outputs[0]] = tuple(RT) + + @staticmethod + def __svld_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RA = state.sgpr(op.input_vals[0]) + RT = state.vgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"sv.ld {RT}, {imm}({RA})") + SvLd = GenericOpProperties( + demo_asm="sv.ld *RT, imm(RA)", + inputs=[OD_EXTRA3_SGPR, OD_VL], + outputs=[OD_EXTRA3_VGPR], + immediates=[IMM_S16], + ) + _SIM_FNS[SvLd] = lambda: 
OpKind.__svld_sim + _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm + + @staticmethod + def __ld_sim(op, state): + # type: (Op, BaseSimState) -> None + RA, = state[op.input_vals[0]] + addr = RA + op.immediates[0] + v = state.load(addr) + state[op.outputs[0]] = v & GPR_VALUE_MASK, + + @staticmethod + def __ld_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RA = state.sgpr(op.input_vals[0]) + RT = state.sgpr(op.outputs[0]) + imm = op.immediates[0] + state.writeln(f"ld {RT}, {imm}({RA})") + Ld = GenericOpProperties( + demo_asm="ld RT, imm(RA)", + inputs=[OD_BASE_SGPR], + outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)], + immediates=[IMM_S16], + ) + _SIM_FNS[Ld] = lambda: OpKind.__ld_sim + _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm + + @staticmethod + def __svstd_sim(op, state): + # type: (Op, BaseSimState) -> None + RS = state[op.input_vals[0]] + RA, = state[op.input_vals[1]] + VL, = state[op.input_vals[2]] + addr = RA + op.immediates[0] + for i in range(VL): + state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i]) + + @staticmethod + def __svstd_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RS = state.vgpr(op.input_vals[0]) + RA = state.sgpr(op.input_vals[1]) + imm = op.immediates[0] + state.writeln(f"sv.std {RS}, {imm}({RA})") + SvStd = GenericOpProperties( + demo_asm="sv.std *RS, imm(RA)", + inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL], + outputs=[], + immediates=[IMM_S16], + has_side_effects=True, + ) + _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim + _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm + + @staticmethod + def __std_sim(op, state): + # type: (Op, BaseSimState) -> None + RS, = state[op.input_vals[0]] + RA, = state[op.input_vals[1]] + addr = RA + op.immediates[0] + state.store(addr, value=RS) + + @staticmethod + def __std_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + RS = state.sgpr(op.input_vals[0]) + RA = state.sgpr(op.input_vals[1]) + imm = op.immediates[0] + state.writeln(f"std {RS}, {imm}({RA})") + Std = GenericOpProperties( + demo_asm="std RS, imm(RA)", + inputs=[OD_BASE_SGPR, OD_BASE_SGPR], + outputs=[], + immediates=[IMM_S16], + has_side_effects=True, + ) + _SIM_FNS[Std] = lambda: OpKind.__std_sim + _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm + + @staticmethod + def __funcargr3_sim(op, state): + # type: (Op, BaseSimState) -> None + pass # return value set before simulation + + @staticmethod + def __funcargr3_gen_asm(op, state): + # type: (Op, GenAsmState) -> None + pass # no instructions needed + FuncArgR3 = GenericOpProperties( + demo_asm="", + inputs=[], + outputs=[OD_BASE_SGPR.with_fixed_loc( + Loc(kind=LocKind.GPR, start=3, reg_len=1))], + ) + _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim + _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +class SSAValOrUse(metaclass=InternedMeta): + __slots__ = "op", "operand_idx" + + def __init__(self, op, operand_idx): + # type: (Op, int) -> None + super().__init__() + self.op = op + if operand_idx < 0 or operand_idx >= len(self.descriptor_array): + raise ValueError("invalid operand_idx") + self.operand_idx = operand_idx + + @abstractmethod + def __repr__(self): + # type: () -> str + ... + + @property + @abstractmethod + def descriptor_array(self): + # type: () -> tuple[OperandDesc, ...] + ... 
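# editor's note: descriptor_array is the hook that distinguishes the two
# concrete subclasses below: SSAVal indexes into op.properties.outputs while
# SSAUse indexes into op.properties.inputs; defining_descriptor and ty are
# then derived from the selected OperandDesc. A minimal sketch (the immediate
# 5 and the op name "li" are made up for illustration):
#
#     >>> fn = Fn()
#     >>> li = fn.append_new_op(OpKind.LI, immediates=[5], name="li")
#     >>> li.outputs[0]
#     <li.outputs[0]: <I64>>
#     >>> li.outputs[0].ty
#     <I64>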
+ + @cached_property + def defining_descriptor(self): + # type: () -> OperandDesc + return self.descriptor_array[self.operand_idx] + + @cached_property + def ty(self): + # type: () -> Ty + return self.defining_descriptor.ty + + @cached_property + def ty_before_spread(self): + # type: () -> Ty + return self.defining_descriptor.ty_before_spread + + @property + def base_ty(self): + # type: () -> BaseTy + return self.ty_before_spread.base_ty + + @property + def reg_offset_in_unspread(self): + """ the number of reg-sized slots in the unspread Loc before self's Loc + + e.g. if the unspread Loc containing self is: + `Loc(kind=LocKind.GPR, start=8, reg_len=4)` + and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)` + then reg_offset_into_unspread == 2 == 10 - 8 + """ + return self.defining_descriptor.reg_offset_in_unspread + + @property + def unspread_start_idx(self): + # type: () -> int + return self.operand_idx - (self.defining_descriptor.spread_index or 0) + + @property + def unspread_start(self): + # type: () -> Self + return self.__class__(op=self.op, operand_idx=self.unspread_start_idx) + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +class SSAVal(SSAValOrUse): + __slots__ = () + + def __repr__(self): + # type: () -> str + return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>" + + @cached_property + def def_loc_set_before_spread(self): + # type: () -> LocSet + return self.defining_descriptor.loc_set_before_spread + + @cached_property + def descriptor_array(self): + # type: () -> tuple[OperandDesc, ...] + return self.op.properties.outputs + + @cached_property + def tied_input(self): + # type: () -> None | SSAUse + if self.defining_descriptor.tied_input_index is None: + return None + return SSAUse(op=self.op, + operand_idx=self.defining_descriptor.tied_input_index) + + @property + def write_stage(self): + # type: () -> OpStage + return self.defining_descriptor.write_stage + + +@plain_data(frozen=True, unsafe_hash=True, repr=False) +@final +class SSAUse(SSAValOrUse): + __slots__ = () + + @cached_property + def use_loc_set_before_spread(self): + # type: () -> LocSet + return self.defining_descriptor.loc_set_before_spread + + @cached_property + def descriptor_array(self): + # type: () -> tuple[OperandDesc, ...] 
+ return self.op.properties.inputs + + def __repr__(self): + # type: () -> str + return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>" + + @property + def ssa_val(self): + # type: () -> SSAVal + return self.op.input_vals[self.operand_idx] + + @ssa_val.setter + def ssa_val(self, ssa_val): + # type: (SSAVal) -> None + self.op.input_vals[self.operand_idx] = ssa_val + + +_T = TypeVar("_T") +_Desc = TypeVar("_Desc") + + +class OpInputSeq(Sequence[_T], Generic[_T, _Desc]): + @abstractmethod + def _verify_write_with_desc(self, idx, item, desc): + # type: (int, _T | Any, _Desc) -> None + raise NotImplementedError + + @final + def _verify_write(self, idx, item): + # type: (int | Any, _T | Any) -> int + if not isinstance(idx, int): + if isinstance(idx, slice): + raise TypeError( + f"can't write to slice of {self.__class__.__name__}") + raise TypeError(f"can't write with index {idx!r}") + # normalize idx, raising IndexError if it is out of range + idx = range(len(self.descriptors))[idx] + desc = self.descriptors[idx] + self._verify_write_with_desc(idx, item, desc) + return idx + + def _on_set(self, idx, new_item, old_item): + # type: (int, _T, _T | None) -> None + pass + + @abstractmethod + def _get_descriptors(self): + # type: () -> tuple[_Desc, ...] + raise NotImplementedError + + @cached_property + @final + def descriptors(self): + # type: () -> tuple[_Desc, ...] + return self._get_descriptors() + + @property + @final + def op(self): + return self.__op + + def __init__(self, items, op): + # type: (Iterable[_T], Op) -> None + super().__init__() + self.__op = op + self.__items = [] # type: list[_T] + for idx, item in enumerate(items): + if idx >= len(self.descriptors): + raise ValueError("too many items") + _ = self._verify_write(idx, item) + self.__items.append(item) + if len(self.__items) < len(self.descriptors): + raise ValueError("not enough items") + + @final + def __iter__(self): + # type: () -> Iterator[_T] + yield from self.__items + + @overload + def __getitem__(self, idx): + # type: (int) -> _T + ... + + @overload + def __getitem__(self, idx): + # type: (slice) -> list[_T] + ... + + @final + def __getitem__(self, idx): + # type: (int | slice) -> _T | list[_T] + return self.__items[idx] + + @final + def __setitem__(self, idx, item): + # type: (int, _T) -> None + idx = self._verify_write(idx, item) + self.__items[idx] = item + + @final + def __len__(self): + # type: () -> int + return len(self.__items) + + def __repr__(self): + # type: () -> str + return f"{self.__class__.__name__}({self.__items}, op=...)" + + +@final +class OpInputVals(OpInputSeq[SSAVal, OperandDesc]): + def _get_descriptors(self): + # type: () -> tuple[OperandDesc, ...] + return self.op.properties.inputs + + def _verify_write_with_desc(self, idx, item, desc): + # type: (int, SSAVal | Any, OperandDesc) -> None + if not isinstance(item, SSAVal): + raise TypeError("expected value of type SSAVal") + if item.ty != desc.ty: + raise ValueError(f"assigned item's type {item.ty!r} doesn't match " + f"corresponding input's type {desc.ty!r}") + + def _on_set(self, idx, new_item, old_item): + # type: (int, SSAVal, SSAVal | None) -> None + SSAUses._on_op_input_set(self, idx, new_item, old_item) # type: ignore + + def __init__(self, items, op): + # type: (Iterable[SSAVal], Op) -> None + if hasattr(op, "inputs"): + raise ValueError("Op.inputs already set") + super().__init__(items, op) + + +@final +class OpImmediates(OpInputSeq[int, range]): + def _get_descriptors(self): + # type: () -> tuple[range, ...] 
+ return self.op.properties.immediates + + def _verify_write_with_desc(self, idx, item, desc): + # type: (int, int | Any, range) -> None + if not isinstance(item, int): + raise TypeError("expected value of type int") + if item not in desc: + raise ValueError(f"immediate value {item!r} not in {desc!r}") + + def __init__(self, items, op): + # type: (Iterable[int], Op) -> None + if hasattr(op, "immediates"): + raise ValueError("Op.immediates already set") + super().__init__(items, op) + + +@plain_data(frozen=True, eq=False, repr=False) +@final +class Op: + __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates", + "outputs", "name") + + def __init__(self, fn, properties, input_vals, immediates, name=""): + # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None + self.fn = fn + self.properties = properties + self.input_vals = OpInputVals(input_vals, op=self) + inputs_len = len(self.properties.inputs) + self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len)) + self.immediates = OpImmediates(immediates, op=self) + outputs_len = len(self.properties.outputs) + self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len)) + self.name = fn._add_op_with_unused_name(self, name) # type: ignore + + @property + def kind(self): + # type: () -> OpKind + return self.properties.kind + + def __eq__(self, other): + # type: (Op | Any) -> bool + if isinstance(other, Op): + return self is other + return NotImplemented + + def __hash__(self): + # type: () -> int + return object.__hash__(self) + + def __repr__(self): + # type: () -> str + field_vals = [] # type: list[str] + for name in fields(self): + if name == "properties": + name = "kind" + elif name == "fn": + continue + try: + value = getattr(self, name) + except AttributeError: + field_vals.append(f"{name}=") + continue + if isinstance(value, OpInputSeq): + value = list(value) # type: ignore + field_vals.append(f"{name}={value!r}") + field_vals_str = ", ".join(field_vals) + return f"Op({field_vals_str})" + + def sim(self, state): + # type: (BaseSimState) -> None + for inp in self.input_vals: + try: + val = state[inp] + except KeyError: + raise ValueError(f"SSAVal {inp} not yet assigned when " + f"running {self}") + if len(val) != inp.ty.reg_len: + raise ValueError( + f"value of SSAVal {inp} has wrong number of elements: " + f"expected {inp.ty.reg_len} found " + f"{len(val)}: {val!r}") + if isinstance(state, PreRASimState): + for out in self.outputs: + if out in state.ssa_vals: + if self.kind is OpKind.FuncArgR3: + continue + raise ValueError(f"SSAVal {out} already assigned before " + f"running {self}") + self.kind.sim(self, state) + for out in self.outputs: + try: + val = state[out] + except KeyError: + raise ValueError(f"running {self} failed to assign to {out}") + if len(val) != out.ty.reg_len: + raise ValueError( + f"value of SSAVal {out} has wrong number of elements: " + f"expected {out.ty.reg_len} found " + f"{len(val)}: {val!r}") + + def gen_asm(self, state): + # type: (GenAsmState) -> None + all_loc_kinds = tuple(LocKind) + for inp in self.input_vals: + state.loc(inp, expected_kinds=all_loc_kinds) + for out in self.outputs: + state.loc(out, expected_kinds=all_loc_kinds) + self.kind.gen_asm(self, state) + + +GPR_SIZE_IN_BYTES = 8 +BITS_IN_BYTE = 8 +GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE +GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1 + + +@plain_data(frozen=True, repr=False) +class BaseSimState(metaclass=ABCMeta): + __slots__ = "memory", + + def __init__(self, memory): + # type: (dict[int, 
int]) -> None + super().__init__() + self.memory = memory # type: dict[int, int] + + def load_byte(self, addr): + # type: (int) -> int + addr &= GPR_VALUE_MASK + return self.memory.get(addr, 0) & 0xFF + + def store_byte(self, addr, value): + # type: (int, int) -> None + addr &= GPR_VALUE_MASK + value &= 0xFF + self.memory[addr] = value + + def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False): + # type: (int, int, bool) -> int + if addr % size_in_bytes != 0: + raise ValueError(f"address not aligned: {hex(addr)} " + f"required alignment: {size_in_bytes}") + retval = 0 + for i in range(size_in_bytes): + retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE + if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0: + retval -= 1 << size_in_bytes * BITS_IN_BYTE + return retval + + def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES): + # type: (int, int, int) -> None + if addr % size_in_bytes != 0: + raise ValueError(f"address not aligned: {hex(addr)} " + f"required alignment: {size_in_bytes}") + for i in range(size_in_bytes): + self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF) + + def _memory__repr(self): + # type: () -> str + if len(self.memory) == 0: + return "{}" + keys = sorted(self.memory.keys(), reverse=True) + CHUNK_SIZE = GPR_SIZE_IN_BYTES + items = [] # type: list[str] + while len(keys) != 0: + addr = keys[-1] + if (len(keys) >= CHUNK_SIZE + and addr % CHUNK_SIZE == 0 + and keys[-CHUNK_SIZE:] + == list(reversed(range(addr, addr + CHUNK_SIZE)))): + value = self.load(addr, size_in_bytes=CHUNK_SIZE) + items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>") + keys[-CHUNK_SIZE:] = () + else: + items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}") + if len(items) == 1: + return f"{{{items[0]}}}" + items_str = ",\n".join(items) + return f"{{\n{items_str}}}" + + def __repr__(self): + # type: () -> str + field_vals = [] # type: list[str] + for name in fields(self): + try: + value = getattr(self, name) + except AttributeError: + field_vals.append(f"{name}=") + continue + repr_fn = getattr(self, f"_{name}__repr", None) + if callable(repr_fn): + field_vals.append(f"{name}={repr_fn()}") + else: + field_vals.append(f"{name}={value!r}") + field_vals_str = ", ".join(field_vals) + return f"{self.__class__.__name__}({field_vals_str})" + + @abstractmethod + def __getitem__(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] + ... + + @abstractmethod + def __setitem__(self, ssa_val, value): + # type: (SSAVal, tuple[int, ...]) -> None + ... 
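# editor's note: a minimal end-to-end sketch of the simulation and asm-gen
# interfaces built on BaseSimState (the immediate 5, the op name "li" and
# GPR r5 are hand-picked for illustration; normally the register allocator
# chooses the Loc):
#
#     >>> fn = Fn()
#     >>> li = fn.append_new_op(OpKind.LI, immediates=[5], name="li")
#     >>> pre_ra = PreRASimState(ssa_vals={}, memory={})
#     >>> fn.sim(pre_ra)
#     >>> pre_ra[li.outputs[0]]
#     (5,)
#     >>> asm = GenAsmState(allocated_locs={
#     ...     li.outputs[0]: Loc(kind=LocKind.GPR, start=5, reg_len=1)})
#     >>> fn.gen_asm(asm)
#     >>> asm.output
#     ['addi 5, 0, 5']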
+ + +@plain_data(frozen=True, repr=False) +@final +class PreRASimState(BaseSimState): + __slots__ = "ssa_vals", + + def __init__(self, ssa_vals, memory): + # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None + super().__init__(memory) + self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]] + + def _ssa_vals__repr(self): + # type: () -> str + if len(self.ssa_vals) == 0: + return "{}" + items = [] # type: list[str] + CHUNK_SIZE = 4 + for k, v in self.ssa_vals.items(): + element_strs = [] # type: list[str] + for i, el in enumerate(v): + if i % CHUNK_SIZE != 0: + element_strs.append(" " + hex(el)) + else: + element_strs.append("\n " + hex(el)) + if len(element_strs) <= CHUNK_SIZE: + element_strs[0] = element_strs[0].lstrip() + if len(element_strs) == 1: + element_strs.append("") + v_str = ",".join(element_strs) + items.append(f"{k!r}: ({v_str})") + if len(items) == 1 and "\n" not in items[0]: + return f"{{{items[0]}}}" + items_str = ",\n".join(items) + return f"{{\n{items_str},\n}}" + + def __getitem__(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] + return self.ssa_vals[ssa_val] + + def __setitem__(self, ssa_val, value): + # type: (SSAVal, tuple[int, ...]) -> None + if len(value) != ssa_val.ty.reg_len: + raise ValueError("value has wrong len") + self.ssa_vals[ssa_val] = value + + +@plain_data(frozen=True, repr=False) +@final +class PostRASimState(BaseSimState): + __slots__ = "ssa_val_to_loc_map", "loc_values" + + def __init__(self, ssa_val_to_loc_map, memory, loc_values): + # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None + super().__init__(memory) + self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map) + for ssa_val, loc in self.ssa_val_to_loc_map.items(): + if ssa_val.ty != loc.ty: + raise ValueError( + f"type mismatch for SSAVal and Loc: {ssa_val} {loc}") + self.loc_values = loc_values + for loc in self.loc_values.keys(): + if loc.reg_len != 1: + raise ValueError( + "loc_values must only contain Locs with reg_len=1, all " + "larger Locs will be split into reg_len=1 sub-Locs") + + def _loc_values__repr(self): + # type: () -> str + locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start)) + items = [] # type: list[str] + for loc in locs: + items.append(f"{loc}: 0x{self.loc_values[loc]:x}") + items_str = ",\n".join(items) + return f"{{\n{items_str},\n}}" + + def __getitem__(self, ssa_val): + # type: (SSAVal) -> tuple[int, ...] 
+ loc = self.ssa_val_to_loc_map[ssa_val] + subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1) + retval = [] # type: list[int] + for i in range(loc.reg_len): + subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i) + retval.append(self.loc_values.get(subloc, 0)) + return tuple(retval) + + def __setitem__(self, ssa_val, value): + # type: (SSAVal, tuple[int, ...]) -> None + if len(value) != ssa_val.ty.reg_len: + raise ValueError("value has wrong len") + loc = self.ssa_val_to_loc_map[ssa_val] + subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1) + for i in range(loc.reg_len): + subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i) + self.loc_values[subloc] = value[i] + + +@plain_data(frozen=True) +class GenAsmState: + __slots__ = "allocated_locs", "output" + + def __init__(self, allocated_locs, output=None): + # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None + super().__init__() + self.allocated_locs = FMap(allocated_locs) + for ssa_val, loc in self.allocated_locs.items(): + if ssa_val.ty != loc.ty: + raise ValueError( + f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}") + if output is None: + output = [] + self.output = output + + __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]] + + def loc(self, ssa_val_or_locs, expected_kinds): + # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc + if isinstance(ssa_val_or_locs, (SSAVal, Loc)): + ssa_val_or_locs = [ssa_val_or_locs] + locs = [] # type: list[Loc] + for i in ssa_val_or_locs: + if isinstance(i, SSAVal): + locs.append(self.allocated_locs[i]) + else: + locs.append(i) + if len(locs) == 0: + raise ValueError("invalid Loc sequence: must not be empty") + retval = locs[0].try_concat(*locs[1:]) + if retval is None: + raise ValueError("invalid Loc sequence: try_concat failed") + if isinstance(expected_kinds, LocKind): + expected_kinds = expected_kinds, + if retval.kind not in expected_kinds: + if len(expected_kinds) == 1: + expected_kinds = expected_kinds[0] + raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found " + f"{retval.kind} expected {expected_kinds}") + return retval + + def gpr(self, ssa_val_or_locs, is_vec): + # type: (__SSA_VAL_OR_LOCS, bool) -> str + loc = self.loc(ssa_val_or_locs, LocKind.GPR) + vec_str = "*" if is_vec else "" + return vec_str + str(loc.start) + + def sgpr(self, ssa_val_or_locs): + # type: (__SSA_VAL_OR_LOCS) -> str + return self.gpr(ssa_val_or_locs, is_vec=False) + + def vgpr(self, ssa_val_or_locs): + # type: (__SSA_VAL_OR_LOCS) -> str + return self.gpr(ssa_val_or_locs, is_vec=True) + + def stack(self, ssa_val_or_locs): + # type: (__SSA_VAL_OR_LOCS) -> str + loc = self.loc(ssa_val_or_locs, LocKind.StackI64) + return f"{loc.start}(1)" + + def writeln(self, *line_segments): + # type: (*str) -> None + line = " ".join(line_segments) + if isinstance(self.output, list): + self.output.append(line) + else: + self.output.write(line + "\n") diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py deleted file mode 100644 index d3b52e8..0000000 --- a/src/bigint_presentation_code/compiler_ir2.py +++ /dev/null @@ -1,2294 +0,0 @@ -import enum -from abc import ABCMeta, abstractmethod -from enum import Enum, unique -from functools import lru_cache, total_ordering -from io import StringIO -from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator, - Mapping, Sequence, TypeVar, Union, overload) -from weakref import WeakValueDictionary as _WeakVDict - -from cached_property import 
cached_property -from nmutil.plain_data import fields, plain_data - -from bigint_presentation_code.type_util import (Literal, Self, assert_never, - final) -from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta, - OFSet, OSet) - - -@final -class Fn: - def __init__(self): - self.ops = [] # type: list[Op] - self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op] - self.__next_name_suffix = 2 - - def _add_op_with_unused_name(self, op, name=""): - # type: (Op, str) -> str - if op.fn is not self: - raise ValueError("can't add Op to wrong Fn") - if hasattr(op, "name"): - raise ValueError("Op already named") - orig_name = name - while True: - if name != "" and name not in self.__op_names: - self.__op_names[name] = op - return name - name = orig_name + str(self.__next_name_suffix) - self.__next_name_suffix += 1 - - def __repr__(self): - # type: () -> str - return "" - - def append_op(self, op): - # type: (Op) -> None - if op.fn is not self: - raise ValueError("can't add Op to wrong Fn") - self.ops.append(op) - - def append_new_op(self, kind, input_vals=(), immediates=(), name="", - maxvl=1): - # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op - retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl), - input_vals=input_vals, immediates=immediates, name=name) - self.append_op(retval) - return retval - - def sim(self, state): - # type: (BaseSimState) -> None - for op in self.ops: - op.sim(state) - - def gen_asm(self, state): - # type: (GenAsmState) -> None - for op in self.ops: - op.gen_asm(state) - - def pre_ra_insert_copies(self): - # type: () -> None - orig_ops = list(self.ops) - copied_outputs = {} # type: dict[SSAVal, SSAVal] - setvli_outputs = {} # type: dict[SSAVal, Op] - self.ops.clear() - for op in orig_ops: - for i in range(len(op.input_vals)): - inp = copied_outputs[op.input_vals[i]] - if inp.ty.base_ty is BaseTy.I64: - maxvl = inp.ty.reg_len - if inp.ty.reg_len != 1: - setvl = self.append_new_op( - OpKind.SetVLI, immediates=[maxvl], - name=f"{op.name}.inp{i}.setvl") - vl = setvl.outputs[0] - mv = self.append_new_op( - OpKind.VecCopyToReg, input_vals=[inp, vl], - maxvl=maxvl, name=f"{op.name}.inp{i}.copy") - else: - mv = self.append_new_op( - OpKind.CopyToReg, input_vals=[inp], - name=f"{op.name}.inp{i}.copy") - op.input_vals[i] = mv.outputs[0] - elif inp.ty.base_ty is BaseTy.CA \ - or inp.ty.base_ty is BaseTy.VL_MAXVL: - # all copies would be no-ops, so we don't need to copy, - # though we do need to rematerialize SetVLI ops right - # before the ops VL - if inp in setvli_outputs: - setvl = self.append_new_op( - OpKind.SetVLI, - immediates=setvli_outputs[inp].immediates, - name=f"{op.name}.inp{i}.setvl") - inp = setvl.outputs[0] - op.input_vals[i] = inp - else: - assert_never(inp.ty.base_ty) - self.ops.append(op) - for i, out in enumerate(op.outputs): - if op.kind is OpKind.SetVLI: - setvli_outputs[out] = op - if out.ty.base_ty is BaseTy.I64: - maxvl = out.ty.reg_len - if out.ty.reg_len != 1: - setvl = self.append_new_op( - OpKind.SetVLI, immediates=[maxvl], - name=f"{op.name}.out{i}.setvl") - vl = setvl.outputs[0] - mv = self.append_new_op( - OpKind.VecCopyFromReg, input_vals=[out, vl], - maxvl=maxvl, name=f"{op.name}.out{i}.copy") - else: - mv = self.append_new_op( - OpKind.CopyFromReg, input_vals=[out], - name=f"{op.name}.out{i}.copy") - copied_outputs[out] = mv.outputs[0] - elif out.ty.base_ty is BaseTy.CA \ - or out.ty.base_ty is BaseTy.VL_MAXVL: - # all copies would be no-ops, so we don't need to copy - copied_outputs[out] = out - 
else: - assert_never(out.ty.base_ty) - - -@final -@unique -@total_ordering -class OpStage(Enum): - value: Literal[0, 1] # type: ignore - - def __new__(cls, value): - # type: (int) -> OpStage - value = int(value) - if value not in (0, 1): - raise ValueError("invalid value") - retval = object.__new__(cls) - retval._value_ = value - return retval - - Early = 0 - """ early stage of Op execution, where all input reads occur. - all output writes with `write_stage == Early` occur here too, and therefore - conflict with input reads, telling the compiler that it that can't share - that output's register with any inputs that the output isn't tied to. - - All outputs, even unused outputs, can't share registers with any other - outputs, independent of `write_stage` settings. - """ - Late = 1 - """ late stage of Op execution, where all output writes with - `write_stage == Late` occur, and therefore don't conflict with input reads, - telling the compiler that any inputs can safely use the same register as - those outputs. - - All outputs, even unused outputs, can't share registers with any other - outputs, independent of `write_stage` settings. - """ - - def __repr__(self): - # type: () -> str - return f"OpStage.{self._name_}" - - def __lt__(self, other): - # type: (OpStage | object) -> bool - if isinstance(other, OpStage): - return self.value < other.value - return NotImplemented - - -assert OpStage.Early < OpStage.Late, "early must be less than late" - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -@final -@total_ordering -class ProgramPoint(metaclass=InternedMeta): - __slots__ = "op_index", "stage" - - def __init__(self, op_index, stage): - # type: (int, OpStage) -> None - self.op_index = op_index - self.stage = stage - - @property - def int_value(self): - # type: () -> int - """ an integer representation of `self` such that it keeps ordering and - successor/predecessor relations. - """ - return self.op_index * 2 + self.stage.value - - @staticmethod - def from_int_value(int_value): - # type: (int) -> ProgramPoint - op_index, stage = divmod(int_value, 2) - return ProgramPoint(op_index=op_index, stage=OpStage(stage)) - - def next(self, steps=1): - # type: (int) -> ProgramPoint - return ProgramPoint.from_int_value(self.int_value + steps) - - def prev(self, steps=1): - # type: (int) -> ProgramPoint - return self.next(steps=-steps) - - def __lt__(self, other): - # type: (ProgramPoint | Any) -> bool - if not isinstance(other, ProgramPoint): - return NotImplemented - if self.op_index != other.op_index: - return self.op_index < other.op_index - return self.stage < other.stage - - def __repr__(self): - # type: () -> str - return f"" - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -@final -class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta): - __slots__ = "start", "stop" - - def __init__(self, start, stop): - # type: (ProgramPoint, ProgramPoint) -> None - self.start = start - self.stop = stop - - @cached_property - def int_value_range(self): - # type: () -> range - return range(self.start.int_value, self.stop.int_value) - - @staticmethod - def from_int_value_range(int_value_range): - # type: (range) -> ProgramRange - if int_value_range.step != 1: - raise ValueError("int_value_range must have step == 1") - return ProgramRange( - start=ProgramPoint.from_int_value(int_value_range.start), - stop=ProgramPoint.from_int_value(int_value_range.stop)) - - @overload - def __getitem__(self, __idx): - # type: (int) -> ProgramPoint - ... 
- - @overload - def __getitem__(self, __idx): - # type: (slice) -> ProgramRange - ... - - def __getitem__(self, __idx): - # type: (int | slice) -> ProgramPoint | ProgramRange - v = range(self.start.int_value, self.stop.int_value)[__idx] - if isinstance(v, int): - return ProgramPoint.from_int_value(v) - return ProgramRange.from_int_value_range(v) - - def __len__(self): - # type: () -> int - return len(self.int_value_range) - - def __iter__(self): - # type: () -> Iterator[ProgramPoint] - return map(ProgramPoint.from_int_value, self.int_value_range) - - def __repr__(self): - # type: () -> str - start = repr(self.start).lstrip("<").rstrip(">") - stop = repr(self.stop).lstrip("<").rstrip(">") - return f"" - - -@plain_data(frozen=True, eq=False, repr=False) -@final -class FnAnalysis: - __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at", - "def_program_ranges", "use_program_points", - "all_program_points") - - def __init__(self, fn): - # type: (Fn) -> None - self.fn = fn - self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops)) - self.all_program_points = ProgramRange( - start=ProgramPoint(op_index=0, stage=OpStage.Early), - stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early)) - def_program_ranges = {} # type: dict[SSAVal, ProgramRange] - use_program_points = {} # type: dict[SSAUse, ProgramPoint] - uses = {} # type: dict[SSAVal, OSet[SSAUse]] - live_range_stops = {} # type: dict[SSAVal, ProgramPoint] - for op in fn.ops: - for use in op.input_uses: - uses[use.ssa_val].add(use) - use_program_point = self.__get_use_program_point(use) - use_program_points[use] = use_program_point - live_range_stops[use.ssa_val] = max( - live_range_stops[use.ssa_val], use_program_point.next()) - for out in op.outputs: - uses[out] = OSet() - def_program_range = self.__get_def_program_range(out) - def_program_ranges[out] = def_program_range - live_range_stops[out] = def_program_range.stop - self.uses = FMap((k, OFSet(v)) for k, v in uses.items()) - self.def_program_ranges = FMap(def_program_ranges) - self.use_program_points = FMap(use_program_points) - live_ranges = {} # type: dict[SSAVal, ProgramRange] - live_at = {i: OSet[SSAVal]() for i in self.all_program_points} - for ssa_val in uses.keys(): - live_ranges[ssa_val] = live_range = ProgramRange( - start=self.def_program_ranges[ssa_val].start, - stop=live_range_stops[ssa_val]) - for program_point in live_range: - live_at[program_point].add(ssa_val) - self.live_ranges = FMap(live_ranges) - self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items()) - - def __get_def_program_range(self, ssa_val): - # type: (SSAVal) -> ProgramRange - write_stage = ssa_val.defining_descriptor.write_stage - start = ProgramPoint( - op_index=self.op_indexes[ssa_val.op], stage=write_stage) - # always include late stage of ssa_val.op, to ensure outputs always - # overlap all other outputs. - # stop is exclusive, so we need the next program point. 
- stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next() - return ProgramRange(start=start, stop=stop) - - def __get_use_program_point(self, ssa_use): - # type: (SSAUse) -> ProgramPoint - assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \ - "assumed here, ensured by GenericOpProperties.__init__" - return ProgramPoint( - op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early) - - def __eq__(self, other): - # type: (FnAnalysis | Any) -> bool - if isinstance(other, FnAnalysis): - return self.fn == other.fn - return NotImplemented - - def __hash__(self): - # type: () -> int - return hash(self.fn) - - def __repr__(self): - # type: () -> str - return "" - - -@unique -@final -class BaseTy(Enum): - I64 = enum.auto() - CA = enum.auto() - VL_MAXVL = enum.auto() - - @cached_property - def only_scalar(self): - # type: () -> bool - if self is BaseTy.I64: - return False - elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: - return True - else: - assert_never(self) - - @cached_property - def max_reg_len(self): - # type: () -> int - if self is BaseTy.I64: - return 128 - elif self is BaseTy.CA or self is BaseTy.VL_MAXVL: - return 1 - else: - assert_never(self) - - def __repr__(self): - return "BaseTy." + self._name_ - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -@final -class Ty(metaclass=InternedMeta): - __slots__ = "base_ty", "reg_len" - - @staticmethod - def validate(base_ty, reg_len): - # type: (BaseTy, int) -> str | None - """ return a string with the error if the combination is invalid, - otherwise return None - """ - if base_ty.only_scalar and reg_len != 1: - return f"can't create a vector of an only-scalar type: {base_ty}" - if reg_len < 1 or reg_len > base_ty.max_reg_len: - return "reg_len out of range" - return None - - def __init__(self, base_ty, reg_len): - # type: (BaseTy, int) -> None - msg = self.validate(base_ty=base_ty, reg_len=reg_len) - if msg is not None: - raise ValueError(msg) - self.base_ty = base_ty - self.reg_len = reg_len - - def __repr__(self): - # type: () -> str - if self.reg_len != 1: - reg_len = f"*{self.reg_len}" - else: - reg_len = "" - return f"<{self.base_ty._name_}{reg_len}>" - - -@unique -@final -class LocKind(Enum): - GPR = enum.auto() - StackI64 = enum.auto() - CA = enum.auto() - VL_MAXVL = enum.auto() - - @cached_property - def base_ty(self): - # type: () -> BaseTy - if self is LocKind.GPR or self is LocKind.StackI64: - return BaseTy.I64 - if self is LocKind.CA: - return BaseTy.CA - if self is LocKind.VL_MAXVL: - return BaseTy.VL_MAXVL - else: - assert_never(self) - - @cached_property - def loc_count(self): - # type: () -> int - if self is LocKind.StackI64: - return 512 - if self is LocKind.GPR or self is LocKind.CA \ - or self is LocKind.VL_MAXVL: - return self.base_ty.max_reg_len - else: - assert_never(self) - - def __repr__(self): - return "LocKind." 
+ self._name_ - - -@final -@unique -class LocSubKind(Enum): - BASE_GPR = enum.auto() - SV_EXTRA2_VGPR = enum.auto() - SV_EXTRA2_SGPR = enum.auto() - SV_EXTRA3_VGPR = enum.auto() - SV_EXTRA3_SGPR = enum.auto() - StackI64 = enum.auto() - CA = enum.auto() - VL_MAXVL = enum.auto() - - @cached_property - def kind(self): - # type: () -> LocKind - # pyright fails typechecking when using `in` here: - # reported: https://github.com/microsoft/pyright/issues/4102 - if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR, - LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR, - LocSubKind.SV_EXTRA3_SGPR): - return LocKind.GPR - if self is LocSubKind.StackI64: - return LocKind.StackI64 - if self is LocSubKind.CA: - return LocKind.CA - if self is LocSubKind.VL_MAXVL: - return LocKind.VL_MAXVL - assert_never(self) - - @property - def base_ty(self): - return self.kind.base_ty - - @lru_cache() - def allocatable_locs(self, ty): - # type: (Ty) -> LocSet - if ty.base_ty != self.base_ty: - raise ValueError("type mismatch") - if self is LocSubKind.BASE_GPR: - starts = range(32) - elif self is LocSubKind.SV_EXTRA2_VGPR: - starts = range(0, 128, 2) - elif self is LocSubKind.SV_EXTRA2_SGPR: - starts = range(64) - elif self is LocSubKind.SV_EXTRA3_VGPR \ - or self is LocSubKind.SV_EXTRA3_SGPR: - starts = range(128) - elif self is LocSubKind.StackI64: - starts = range(LocKind.StackI64.loc_count) - elif self is LocSubKind.CA or self is LocSubKind.VL_MAXVL: - return LocSet([Loc(kind=self.kind, start=0, reg_len=1)]) - else: - assert_never(self) - retval = [] # type: list[Loc] - for start in starts: - loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len) - if loc is None: - continue - conflicts = False - for special_loc in SPECIAL_GPRS: - if loc.conflicts(special_loc): - conflicts = True - break - if not conflicts: - retval.append(loc) - return LocSet(retval) - - def __repr__(self): - return "LocSubKind." 
+ self._name_ - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class GenericTy(metaclass=InternedMeta): - __slots__ = "base_ty", "is_vec" - - def __init__(self, base_ty, is_vec): - # type: (BaseTy, bool) -> None - self.base_ty = base_ty - if base_ty.only_scalar and is_vec: - raise ValueError(f"base_ty={base_ty} requires is_vec=False") - self.is_vec = is_vec - - def instantiate(self, maxvl): - # type: (int) -> Ty - # here's where subvl and elwid would be accounted for - if self.is_vec: - return Ty(self.base_ty, maxvl) - return Ty(self.base_ty, 1) - - def can_instantiate_to(self, ty): - # type: (Ty) -> bool - if self.base_ty != ty.base_ty: - return False - if self.is_vec: - return True - return ty.reg_len == 1 - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class Loc(metaclass=InternedMeta): - __slots__ = "kind", "start", "reg_len" - - @staticmethod - def validate(kind, start, reg_len): - # type: (LocKind, int, int) -> str | None - msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len) - if msg is not None: - return msg - if reg_len > kind.loc_count: - return "invalid reg_len" - if start < 0 or start + reg_len > kind.loc_count: - return "start not in valid range" - return None - - @staticmethod - def try_make(kind, start, reg_len): - # type: (LocKind, int, int) -> Loc | None - msg = Loc.validate(kind=kind, start=start, reg_len=reg_len) - if msg is not None: - return None - return Loc(kind=kind, start=start, reg_len=reg_len) - - def __init__(self, kind, start, reg_len): - # type: (LocKind, int, int) -> None - msg = self.validate(kind=kind, start=start, reg_len=reg_len) - if msg is not None: - raise ValueError(msg) - self.kind = kind - self.reg_len = reg_len - self.start = start - - def conflicts(self, other): - # type: (Loc) -> bool - return (self.kind == other.kind - and self.start < other.stop and other.start < self.stop) - - @staticmethod - def make_ty(kind, reg_len): - # type: (LocKind, int) -> Ty - return Ty(base_ty=kind.base_ty, reg_len=reg_len) - - @cached_property - def ty(self): - # type: () -> Ty - return self.make_ty(kind=self.kind, reg_len=self.reg_len) - - @property - def stop(self): - # type: () -> int - return self.start + self.reg_len - - def try_concat(self, *others): - # type: (*Loc | None) -> Loc | None - reg_len = self.reg_len - stop = self.stop - for other in others: - if other is None or other.kind != self.kind: - return None - if stop != other.start: - return None - stop = other.stop - reg_len += other.reg_len - return Loc(kind=self.kind, start=self.start, reg_len=reg_len) - - def get_subloc_at_offset(self, subloc_ty, offset): - # type: (Ty, int) -> Loc - if subloc_ty.base_ty != self.kind.base_ty: - raise ValueError("BaseTy mismatch") - if offset < 0 or offset + subloc_ty.reg_len > self.reg_len: - raise ValueError("invalid sub-Loc: offset and/or " - "subloc_ty.reg_len out of range") - return Loc(kind=self.kind, - start=self.start + offset, reg_len=subloc_ty.reg_len) - - -SPECIAL_GPRS = ( - Loc(kind=LocKind.GPR, start=0, reg_len=1), - Loc(kind=LocKind.GPR, start=1, reg_len=1), - Loc(kind=LocKind.GPR, start=2, reg_len=1), - Loc(kind=LocKind.GPR, start=13, reg_len=1), -) - - -@final -class _LocSetHashHelper(AbstractSet[Loc]): - """helper to more quickly compute LocSet's hash""" - - def __init__(self, locs): - # type: (Iterable[Loc]) -> None - super().__init__() - self.locs = list(locs) - - def __hash__(self): - # type: () -> int - return super()._hash() - - def __contains__(self, x): - # type: (Loc | Any) -> bool - return x in self.locs - - def 
__iter__(self): - # type: () -> Iterator[Loc] - return iter(self.locs) - - def __len__(self): - return len(self.locs) - - -@plain_data(frozen=True, eq=False, repr=False) -@final -class LocSet(AbstractSet[Loc], metaclass=InternedMeta): - __slots__ = "starts", "ty", "_LocSet__hash" - - def __init__(self, __locs=()): - # type: (Iterable[Loc]) -> None - if isinstance(__locs, LocSet): - self.starts = __locs.starts # type: FMap[LocKind, FBitSet] - self.ty = __locs.ty # type: Ty | None - self._LocSet__hash = __locs._LocSet__hash # type: int - return - starts = {i: BitSet() for i in LocKind} - ty = None # type: None | Ty - - def locs(): - # type: () -> Iterable[Loc] - nonlocal ty - for loc in __locs: - if ty is None: - ty = loc.ty - if ty != loc.ty: - raise ValueError(f"conflicting types: {ty} != {loc.ty}") - starts[loc.kind].add(loc.start) - yield loc - self._LocSet__hash = _LocSetHashHelper(locs()).__hash__() - self.starts = FMap( - (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0) - self.ty = ty - - @cached_property - def stops(self): - # type: () -> FMap[LocKind, FBitSet] - if self.ty is None: - return FMap() - sh = self.ty.reg_len - return FMap( - (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items()) - - @property - def kinds(self): - # type: () -> AbstractSet[LocKind] - return self.starts.keys() - - @property - def reg_len(self): - # type: () -> int | None - if self.ty is None: - return None - return self.ty.reg_len - - @property - def base_ty(self): - # type: () -> BaseTy | None - if self.ty is None: - return None - return self.ty.base_ty - - def concat(self, *others): - # type: (*LocSet) -> LocSet - if self.ty is None: - return LocSet() - base_ty = self.ty.base_ty - reg_len = self.ty.reg_len - starts = {k: BitSet(v) for k, v in self.starts.items()} - for other in others: - if other.ty is None: - return LocSet() - if other.ty.base_ty != base_ty: - return LocSet() - for kind, other_starts in other.starts.items(): - if kind not in starts: - continue - starts[kind].bits &= other_starts.bits >> reg_len - if starts[kind] == 0: - del starts[kind] - if len(starts) == 0: - return LocSet() - reg_len += other.ty.reg_len - - def locs(): - # type: () -> Iterable[Loc] - for kind, v in starts.items(): - for start in v: - loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len) - if loc is not None: - yield loc - return LocSet(locs()) - - def __contains__(self, loc): - # type: (Loc | Any) -> bool - if not isinstance(loc, Loc) or loc.ty != self.ty: - return False - if loc.kind not in self.starts: - return False - return loc.start in self.starts[loc.kind] - - def __iter__(self): - # type: () -> Iterator[Loc] - if self.ty is None: - return - for kind, starts in self.starts.items(): - for start in starts: - yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len) - - @cached_property - def __len(self): - return sum((len(v) for v in self.starts.values()), 0) - - def __len__(self): - return self.__len - - def __hash__(self): - return self._LocSet__hash - - def __eq__(self, __other): - # type: (LocSet | Any) -> bool - if isinstance(__other, LocSet): - return self.ty == __other.ty and self.starts == __other.starts - return super().__eq__(__other) - - @lru_cache(maxsize=None, typed=True) - def max_conflicts_with(self, other): - # type: (LocSet | Loc) -> int - """the largest number of Locs in `self` that a single Loc - from `other` can conflict with - """ - if isinstance(other, LocSet): - return max(self.max_conflicts_with(i) for i in other) - else: - return sum(other.conflicts(i) for i in 
self) - - def __repr__(self): - items = [] # type: list[str] - for name in fields(self): - if name.startswith("_"): - continue - items.append(f"{name}={getattr(self, name)!r}") - return f"LocSet({', '.join(items)})" - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class GenericOperandDesc(metaclass=InternedMeta): - """generic Op operand descriptor""" - __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread", - "write_stage") - - def __init__( - self, ty, # type: GenericTy - sub_kinds, # type: Iterable[LocSubKind] - *, - fixed_loc=None, # type: Loc | None - tied_input_index=None, # type: int | None - spread=False, # type: bool - write_stage=OpStage.Early, # type: OpStage - ): - # type: (...) -> None - self.ty = ty - self.sub_kinds = OFSet(sub_kinds) - if len(self.sub_kinds) == 0: - raise ValueError("sub_kinds can't be empty") - self.fixed_loc = fixed_loc - if fixed_loc is not None: - if tied_input_index is not None: - raise ValueError("operand can't be both tied and fixed") - if not ty.can_instantiate_to(fixed_loc.ty): - raise ValueError( - f"fixed_loc has incompatible type for given generic " - f"type: fixed_loc={fixed_loc} generic ty={ty}") - if len(self.sub_kinds) != 1: - raise ValueError( - "multiple sub_kinds not allowed for fixed operand") - for sub_kind in self.sub_kinds: - if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty): - raise ValueError( - f"fixed_loc not in given sub_kind: " - f"fixed_loc={fixed_loc} sub_kind={sub_kind}") - for sub_kind in self.sub_kinds: - if sub_kind.base_ty != ty.base_ty: - raise ValueError(f"sub_kind is incompatible with type: " - f"sub_kind={sub_kind} ty={ty}") - if tied_input_index is not None and tied_input_index < 0: - raise ValueError("invalid tied_input_index") - self.tied_input_index = tied_input_index - self.spread = spread - if spread: - if self.tied_input_index is not None: - raise ValueError("operand can't be both spread and tied") - if self.fixed_loc is not None: - raise ValueError("operand can't be both spread and fixed") - if self.ty.is_vec: - raise ValueError("operand can't be both spread and vector") - self.write_stage = write_stage - - @cached_property - def ty_before_spread(self): - # type: () -> GenericTy - if self.spread: - return GenericTy(base_ty=self.ty.base_ty, is_vec=True) - return self.ty - - def tied_to_input(self, tied_input_index): - # type: (int) -> Self - return GenericOperandDesc(self.ty, self.sub_kinds, - tied_input_index=tied_input_index, - write_stage=self.write_stage) - - def with_fixed_loc(self, fixed_loc): - # type: (Loc) -> Self - return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc, - write_stage=self.write_stage) - - def with_write_stage(self, write_stage): - # type: (OpStage) -> Self - return GenericOperandDesc(self.ty, self.sub_kinds, - fixed_loc=self.fixed_loc, - tied_input_index=self.tied_input_index, - spread=self.spread, - write_stage=write_stage) - - def instantiate(self, maxvl): - # type: (int) -> Iterable[OperandDesc] - # assumes all spread operands have ty.reg_len = 1 - rep_count = 1 - if self.spread: - rep_count = maxvl - ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl) - - def locs_before_spread(): - # type: () -> Iterable[Loc] - if self.fixed_loc is not None: - if ty_before_spread != self.fixed_loc.ty: - raise ValueError( - f"instantiation failed: type mismatch with fixed_loc: " - f"instantiated type: {ty_before_spread} " - f"fixed_loc: {self.fixed_loc}") - yield self.fixed_loc - return - for sub_kind in self.sub_kinds: - yield from 
sub_kind.allocatable_locs(ty_before_spread) - loc_set_before_spread = LocSet(locs_before_spread()) - for idx in range(rep_count): - if not self.spread: - idx = None - yield OperandDesc(loc_set_before_spread=loc_set_before_spread, - tied_input_index=self.tied_input_index, - spread_index=idx, write_stage=self.write_stage) - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class OperandDesc(metaclass=InternedMeta): - """Op operand descriptor""" - __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index", - "write_stage") - - def __init__(self, loc_set_before_spread, tied_input_index, spread_index, - write_stage): - # type: (LocSet, int | None, int | None, OpStage) -> None - if len(loc_set_before_spread) == 0: - raise ValueError("loc_set_before_spread must not be empty") - self.loc_set_before_spread = loc_set_before_spread - self.tied_input_index = tied_input_index - if self.tied_input_index is not None and spread_index is not None: - raise ValueError("operand can't be both spread and tied") - self.spread_index = spread_index - self.write_stage = write_stage - - @cached_property - def ty_before_spread(self): - # type: () -> Ty - ty = self.loc_set_before_spread.ty - assert ty is not None, ( - "__init__ checked that the LocSet isn't empty, " - "non-empty LocSets should always have ty set") - return ty - - @cached_property - def ty(self): - """ Ty after any spread is applied """ - if self.spread_index is not None: - # assumes all spread operands have ty.reg_len = 1 - return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1) - return self.ty_before_spread - - @property - def reg_offset_in_unspread(self): - """ the number of reg-sized slots in the unspread Loc before self's Loc - - e.g. if the unspread Loc containing self is: - `Loc(kind=LocKind.GPR, start=8, reg_len=4)` - and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)` - then reg_offset_into_unspread == 2 == 10 - 8 - """ - if self.spread_index is None: - return 0 - return self.spread_index * self.ty.reg_len - - -OD_BASE_SGPR = GenericOperandDesc( - ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.BASE_GPR]) -OD_EXTRA3_SGPR = GenericOperandDesc( - ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.SV_EXTRA3_SGPR]) -OD_EXTRA3_VGPR = GenericOperandDesc( - ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), - sub_kinds=[LocSubKind.SV_EXTRA3_VGPR]) -OD_EXTRA2_SGPR = GenericOperandDesc( - ty=GenericTy(base_ty=BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.SV_EXTRA2_SGPR]) -OD_EXTRA2_VGPR = GenericOperandDesc( - ty=GenericTy(base_ty=BaseTy.I64, is_vec=True), - sub_kinds=[LocSubKind.SV_EXTRA2_VGPR]) -OD_CA = GenericOperandDesc( - ty=GenericTy(base_ty=BaseTy.CA, is_vec=False), - sub_kinds=[LocSubKind.CA]) -OD_VL = GenericOperandDesc( - ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False), - sub_kinds=[LocSubKind.VL_MAXVL]) - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class GenericOpProperties(metaclass=InternedMeta): - __slots__ = ("demo_asm", "inputs", "outputs", "immediates", - "is_copy", "is_load_immediate", "has_side_effects") - - def __init__( - self, demo_asm, # type: str - inputs, # type: Iterable[GenericOperandDesc] - outputs, # type: Iterable[GenericOperandDesc] - immediates=(), # type: Iterable[range] - is_copy=False, # type: bool - is_load_immediate=False, # type: bool - has_side_effects=False, # type: bool - ): - # type: (...) -> None - self.demo_asm = demo_asm # type: str - self.inputs = tuple(inputs) # type: tuple[GenericOperandDesc, ...] 
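        # editor's note: the validation below enforces the conventions
        # assumed elsewhere in this file -- an input operand may not be
        # tied (tied_input_index) and must be read in the Early stage;
        # tying and OpStage.Late writes are reserved for outputs.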
- for inp in self.inputs: - if inp.tied_input_index is not None: - raise ValueError( - f"tied_input_index is not allowed on inputs: {inp}") - if inp.write_stage is not OpStage.Early: - raise ValueError( - f"write_stage is not allowed on inputs: {inp}") - self.outputs = tuple(outputs) # type: tuple[GenericOperandDesc, ...] - fixed_locs = [] # type: list[tuple[Loc, int]] - for idx, out in enumerate(self.outputs): - if out.tied_input_index is not None: - if out.tied_input_index >= len(self.inputs): - raise ValueError(f"tied_input_index out of range: {out}") - tied_inp = self.inputs[out.tied_input_index] - expected_out = tied_inp.tied_to_input(out.tied_input_index) \ - .with_write_stage(out.write_stage) - if expected_out != out: - raise ValueError(f"output can't be tied to non-equivalent " - f"input: {out} tied to {tied_inp}") - if out.fixed_loc is not None: - for other_fixed_loc, other_idx in fixed_locs: - if not other_fixed_loc.conflicts(out.fixed_loc): - continue - raise ValueError( - f"conflicting fixed_locs: outputs[{idx}] and " - f"outputs[{other_idx}]: {out.fixed_loc} conflicts " - f"with {other_fixed_loc}") - fixed_locs.append((out.fixed_loc, idx)) - self.immediates = tuple(immediates) # type: tuple[range, ...] - self.is_copy = is_copy # type: bool - self.is_load_immediate = is_load_immediate # type: bool - self.has_side_effects = has_side_effects # type: bool - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class OpProperties(metaclass=InternedMeta): - __slots__ = "kind", "inputs", "outputs", "maxvl" - - def __init__(self, kind, maxvl): - # type: (OpKind, int) -> None - self.kind = kind # type: OpKind - inputs = [] # type: list[OperandDesc] - for inp in self.generic.inputs: - inputs.extend(inp.instantiate(maxvl=maxvl)) - self.inputs = tuple(inputs) # type: tuple[OperandDesc, ...] - outputs = [] # type: list[OperandDesc] - for out in self.generic.outputs: - outputs.extend(out.instantiate(maxvl=maxvl)) - self.outputs = tuple(outputs) # type: tuple[OperandDesc, ...] - self.maxvl = maxvl # type: int - - @property - def generic(self): - # type: () -> GenericOpProperties - return self.kind.properties - - @property - def immediates(self): - # type: () -> tuple[range, ...] - return self.generic.immediates - - @property - def demo_asm(self): - # type: () -> str - return self.generic.demo_asm - - @property - def is_copy(self): - # type: () -> bool - return self.generic.is_copy - - @property - def is_load_immediate(self): - # type: () -> bool - return self.generic.is_load_immediate - - @property - def has_side_effects(self): - # type: () -> bool - return self.generic.has_side_effects - - -IMM_S16 = range(-1 << 15, 1 << 15) - -_SIM_FN = Callable[["Op", "BaseSimState"], None] -_SIM_FN2 = Callable[[], _SIM_FN] -_SIM_FNS = {} # type: dict[GenericOpProperties | Any, _SIM_FN2] -_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None] -_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN] -_GEN_ASMS = {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2] - - -@unique -@final -class OpKind(Enum): - def __init__(self, properties): - # type: (GenericOpProperties) -> None - super().__init__() - self.__properties = properties - - @property - def properties(self): - # type: () -> GenericOpProperties - return self.__properties - - def instantiate(self, maxvl): - # type: (int) -> OpProperties - return OpProperties(self, maxvl=maxvl) - - def __repr__(self): - # type: () -> str - return "OpKind." 
+ self._name_ - - @cached_property - def sim(self): - # type: () -> _SIM_FN - return _SIM_FNS[self.properties]() - - @cached_property - def gen_asm(self): - # type: () -> _GEN_ASM_FN - return _GEN_ASMS[self.properties]() - - @staticmethod - def __clearca_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = False, - - @staticmethod - def __clearca_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - state.writeln("addic 0, 0, 0") - ClearCA = GenericOpProperties( - demo_asm="addic 0, 0, 0", - inputs=[], - outputs=[OD_CA.with_write_stage(OpStage.Late)], - ) - _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim - _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm - - @staticmethod - def __setca_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = True, - - @staticmethod - def __setca_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - state.writeln("subfc 0, 0, 0") - SetCA = GenericOpProperties( - demo_asm="subfc 0, 0, 0", - inputs=[], - outputs=[OD_CA.with_write_stage(OpStage.Late)], - ) - _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim - _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm - - @staticmethod - def __svadde_sim(op, state): - # type: (Op, BaseSimState) -> None - RA = state[op.input_vals[0]] - RB = state[op.input_vals[1]] - carry, = state[op.input_vals[2]] - VL, = state[op.input_vals[3]] - RT = [] # type: list[int] - for i in range(VL): - v = RA[i] + RB[i] + carry - RT.append(v & GPR_VALUE_MASK) - carry = (v >> GPR_SIZE_IN_BITS) != 0 - state[op.outputs[0]] = tuple(RT) - state[op.outputs[1]] = carry, - - @staticmethod - def __svadde_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RT = state.vgpr(op.outputs[0]) - RA = state.vgpr(op.input_vals[0]) - RB = state.vgpr(op.input_vals[1]) - state.writeln(f"sv.adde {RT}, {RA}, {RB}") - SvAddE = GenericOpProperties( - demo_asm="sv.adde *RT, *RA, *RB", - inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], - outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], - ) - _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim - _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm - - @staticmethod - def __addze_sim(op, state): - # type: (Op, BaseSimState) -> None - RA, = state[op.input_vals[0]] - carry, = state[op.input_vals[1]] - v = RA + carry - RT = v & GPR_VALUE_MASK - carry = (v >> GPR_SIZE_IN_BITS) != 0 - state[op.outputs[0]] = RT, - state[op.outputs[1]] = carry, - - @staticmethod - def __addze_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RT = state.vgpr(op.outputs[0]) - RA = state.vgpr(op.input_vals[0]) - state.writeln(f"addze {RT}, {RA}") - AddZE = GenericOpProperties( - demo_asm="addze RT, RA", - inputs=[OD_BASE_SGPR, OD_CA], - outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)], - ) - _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim - _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm - - @staticmethod - def __svsubfe_sim(op, state): - # type: (Op, BaseSimState) -> None - RA = state[op.input_vals[0]] - RB = state[op.input_vals[1]] - carry, = state[op.input_vals[2]] - VL, = state[op.input_vals[3]] - RT = [] # type: list[int] - for i in range(VL): - v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry - RT.append(v & GPR_VALUE_MASK) - carry = (v >> GPR_SIZE_IN_BITS) != 0 - state[op.outputs[0]] = tuple(RT) - state[op.outputs[1]] = carry, - - @staticmethod - def __svsubfe_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RT = state.vgpr(op.outputs[0]) - RA = state.vgpr(op.input_vals[0]) - RB = state.vgpr(op.input_vals[1]) - state.writeln(f"sv.subfe {RT}, {RA}, {RB}") - SvSubFE = 
GenericOpProperties( - demo_asm="sv.subfe *RT, *RA, *RB", - inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL], - outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)], - ) - _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim - _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm - - @staticmethod - def __svmaddedu_sim(op, state): - # type: (Op, BaseSimState) -> None - RA = state[op.input_vals[0]] - RB, = state[op.input_vals[1]] - carry, = state[op.input_vals[2]] - VL, = state[op.input_vals[3]] - RT = [] # type: list[int] - for i in range(VL): - v = RA[i] * RB + carry - RT.append(v & GPR_VALUE_MASK) - carry = v >> GPR_SIZE_IN_BITS - state[op.outputs[0]] = tuple(RT) - state[op.outputs[1]] = carry, - - @staticmethod - def __svmaddedu_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RT = state.vgpr(op.outputs[0]) - RA = state.vgpr(op.input_vals[0]) - RB = state.sgpr(op.input_vals[1]) - RC = state.sgpr(op.input_vals[2]) - state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}") - SvMAddEDU = GenericOpProperties( - demo_asm="sv.maddedu *RT, *RA, RB, RC", - inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL], - outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)], - ) - _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim - _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm - - @staticmethod - def __setvli_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = op.immediates[0], - - @staticmethod - def __setvli_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - imm = op.immediates[0] - state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1") - SetVLI = GenericOpProperties( - demo_asm="setvl 0, 0, imm, 0, 1, 1", - inputs=(), - outputs=[OD_VL.with_write_stage(OpStage.Late)], - immediates=[range(1, 65)], - is_load_immediate=True, - ) - _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim - _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm - - @staticmethod - def __svli_sim(op, state): - # type: (Op, BaseSimState) -> None - VL, = state[op.input_vals[0]] - imm = op.immediates[0] & GPR_VALUE_MASK - state[op.outputs[0]] = (imm,) * VL - - @staticmethod - def __svli_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RT = state.vgpr(op.outputs[0]) - imm = op.immediates[0] - state.writeln(f"sv.addi {RT}, 0, {imm}") - SvLI = GenericOpProperties( - demo_asm="sv.addi *RT, 0, imm", - inputs=[OD_VL], - outputs=[OD_EXTRA3_VGPR], - immediates=[IMM_S16], - is_load_immediate=True, - ) - _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim - _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm - - @staticmethod - def __li_sim(op, state): - # type: (Op, BaseSimState) -> None - imm = op.immediates[0] & GPR_VALUE_MASK - state[op.outputs[0]] = imm, - - @staticmethod - def __li_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RT = state.sgpr(op.outputs[0]) - imm = op.immediates[0] - state.writeln(f"addi {RT}, 0, {imm}") - LI = GenericOpProperties( - demo_asm="addi RT, 0, imm", - inputs=(), - outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)], - immediates=[IMM_S16], - is_load_immediate=True, - ) - _SIM_FNS[LI] = lambda: OpKind.__li_sim - _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm - - @staticmethod - def __veccopytoreg_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = state[op.input_vals[0]] - - @staticmethod - def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state): - # type: (Loc, Loc, bool, GenAsmState) -> None - sv = "sv." 
if is_vec else "" - rev = "" - if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start: - rev = "/mrr" - if src_loc == dest_loc: - return # no-op - if src_loc.kind not in (LocKind.GPR, LocKind.StackI64): - raise ValueError(f"invalid src_loc.kind: {src_loc.kind}") - if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64): - raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}") - if src_loc.kind is LocKind.StackI64: - if dest_loc.kind is LocKind.StackI64: - raise ValueError( - f"can't copy from stack to stack: {src_loc} {dest_loc}") - elif dest_loc.kind is not LocKind.GPR: - assert_never(dest_loc.kind) - src = state.stack(src_loc) - dest = state.gpr(dest_loc, is_vec=is_vec) - state.writeln(f"{sv}ld {dest}, {src}") - elif dest_loc.kind is LocKind.StackI64: - if src_loc.kind is not LocKind.GPR: - assert_never(src_loc.kind) - src = state.gpr(src_loc, is_vec=is_vec) - dest = state.stack(dest_loc) - state.writeln(f"{sv}std {src}, {dest}") - elif src_loc.kind is LocKind.GPR: - if dest_loc.kind is not LocKind.GPR: - assert_never(dest_loc.kind) - src = state.gpr(src_loc, is_vec=is_vec) - dest = state.gpr(dest_loc, is_vec=is_vec) - state.writeln(f"{sv}or{rev} {dest}, {src}, {src}") - else: - assert_never(src_loc.kind) - - @staticmethod - def __veccopytoreg_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - OpKind.__copy_to_from_reg_gen_asm( - src_loc=state.loc( - op.input_vals[0], (LocKind.GPR, LocKind.StackI64)), - dest_loc=state.loc(op.outputs[0], LocKind.GPR), - is_vec=True, state=state) - - VecCopyToReg = GenericOpProperties( - demo_asm="sv.mv dest, src", - inputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=True), - sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64], - ), OD_VL], - outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], - is_copy=True, - ) - _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim - _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm - - @staticmethod - def __veccopyfromreg_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = state[op.input_vals[0]] - - @staticmethod - def __veccopyfromreg_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - OpKind.__copy_to_from_reg_gen_asm( - src_loc=state.loc(op.input_vals[0], LocKind.GPR), - dest_loc=state.loc( - op.outputs[0], (LocKind.GPR, LocKind.StackI64)), - is_vec=True, state=state) - VecCopyFromReg = GenericOpProperties( - demo_asm="sv.mv dest, src", - inputs=[OD_EXTRA3_VGPR, OD_VL], - outputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=True), - sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64], - write_stage=OpStage.Late, - )], - is_copy=True, - ) - _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim - _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm - - @staticmethod - def __copytoreg_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = state[op.input_vals[0]] - - @staticmethod - def __copytoreg_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - OpKind.__copy_to_from_reg_gen_asm( - src_loc=state.loc( - op.input_vals[0], (LocKind.GPR, LocKind.StackI64)), - dest_loc=state.loc(op.outputs[0], LocKind.GPR), - is_vec=False, state=state) - CopyToReg = GenericOpProperties( - demo_asm="mv dest, src", - inputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR, - LocSubKind.StackI64], - )], - outputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=False), - 
sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR], - write_stage=OpStage.Late, - )], - is_copy=True, - ) - _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim - _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm - - @staticmethod - def __copyfromreg_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = state[op.input_vals[0]] - - @staticmethod - def __copyfromreg_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - OpKind.__copy_to_from_reg_gen_asm( - src_loc=state.loc(op.input_vals[0], LocKind.GPR), - dest_loc=state.loc( - op.outputs[0], (LocKind.GPR, LocKind.StackI64)), - is_vec=False, state=state) - CopyFromReg = GenericOpProperties( - demo_asm="mv dest, src", - inputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR], - )], - outputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR, - LocSubKind.StackI64], - write_stage=OpStage.Late, - )], - is_copy=True, - ) - _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim - _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm - - @staticmethod - def __concat_sim(op, state): - # type: (Op, BaseSimState) -> None - state[op.outputs[0]] = tuple( - state[i][0] for i in op.input_vals[:-1]) - - @staticmethod - def __concat_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - OpKind.__copy_to_from_reg_gen_asm( - src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR), - dest_loc=state.loc(op.outputs[0], LocKind.GPR), - is_vec=True, state=state) - Concat = GenericOpProperties( - demo_asm="sv.mv dest, src", - inputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.SV_EXTRA3_VGPR], - spread=True, - ), OD_VL], - outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)], - is_copy=True, - ) - _SIM_FNS[Concat] = lambda: OpKind.__concat_sim - _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm - - @staticmethod - def __spread_sim(op, state): - # type: (Op, BaseSimState) -> None - for idx, inp in enumerate(state[op.input_vals[0]]): - state[op.outputs[idx]] = inp, - - @staticmethod - def __spread_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - OpKind.__copy_to_from_reg_gen_asm( - src_loc=state.loc(op.input_vals[0], LocKind.GPR), - dest_loc=state.loc(op.outputs, LocKind.GPR), - is_vec=True, state=state) - Spread = GenericOpProperties( - demo_asm="sv.mv dest, src", - inputs=[OD_EXTRA3_VGPR, OD_VL], - outputs=[GenericOperandDesc( - ty=GenericTy(BaseTy.I64, is_vec=False), - sub_kinds=[LocSubKind.SV_EXTRA3_VGPR], - spread=True, - write_stage=OpStage.Late, - )], - is_copy=True, - ) - _SIM_FNS[Spread] = lambda: OpKind.__spread_sim - _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm - - @staticmethod - def __svld_sim(op, state): - # type: (Op, BaseSimState) -> None - RA, = state[op.input_vals[0]] - VL, = state[op.input_vals[1]] - addr = RA + op.immediates[0] - RT = [] # type: list[int] - for i in range(VL): - v = state.load(addr + GPR_SIZE_IN_BYTES * i) - RT.append(v & GPR_VALUE_MASK) - state[op.outputs[0]] = tuple(RT) - - @staticmethod - def __svld_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RA = state.sgpr(op.input_vals[0]) - RT = state.vgpr(op.outputs[0]) - imm = op.immediates[0] - state.writeln(f"sv.ld {RT}, {imm}({RA})") - SvLd = GenericOpProperties( - demo_asm="sv.ld *RT, imm(RA)", - inputs=[OD_EXTRA3_SGPR, OD_VL], - outputs=[OD_EXTRA3_VGPR], - immediates=[IMM_S16], - ) - _SIM_FNS[SvLd] = lambda: 
OpKind.__svld_sim - _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm - - @staticmethod - def __ld_sim(op, state): - # type: (Op, BaseSimState) -> None - RA, = state[op.input_vals[0]] - addr = RA + op.immediates[0] - v = state.load(addr) - state[op.outputs[0]] = v & GPR_VALUE_MASK, - - @staticmethod - def __ld_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RA = state.sgpr(op.input_vals[0]) - RT = state.sgpr(op.outputs[0]) - imm = op.immediates[0] - state.writeln(f"ld {RT}, {imm}({RA})") - Ld = GenericOpProperties( - demo_asm="ld RT, imm(RA)", - inputs=[OD_BASE_SGPR], - outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)], - immediates=[IMM_S16], - ) - _SIM_FNS[Ld] = lambda: OpKind.__ld_sim - _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm - - @staticmethod - def __svstd_sim(op, state): - # type: (Op, BaseSimState) -> None - RS = state[op.input_vals[0]] - RA, = state[op.input_vals[1]] - VL, = state[op.input_vals[2]] - addr = RA + op.immediates[0] - for i in range(VL): - state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i]) - - @staticmethod - def __svstd_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RS = state.vgpr(op.input_vals[0]) - RA = state.sgpr(op.input_vals[1]) - imm = op.immediates[0] - state.writeln(f"sv.std {RS}, {imm}({RA})") - SvStd = GenericOpProperties( - demo_asm="sv.std *RS, imm(RA)", - inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL], - outputs=[], - immediates=[IMM_S16], - has_side_effects=True, - ) - _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim - _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm - - @staticmethod - def __std_sim(op, state): - # type: (Op, BaseSimState) -> None - RS, = state[op.input_vals[0]] - RA, = state[op.input_vals[1]] - addr = RA + op.immediates[0] - state.store(addr, value=RS) - - @staticmethod - def __std_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - RS = state.sgpr(op.input_vals[0]) - RA = state.sgpr(op.input_vals[1]) - imm = op.immediates[0] - state.writeln(f"std {RS}, {imm}({RA})") - Std = GenericOpProperties( - demo_asm="std RS, imm(RA)", - inputs=[OD_BASE_SGPR, OD_BASE_SGPR], - outputs=[], - immediates=[IMM_S16], - has_side_effects=True, - ) - _SIM_FNS[Std] = lambda: OpKind.__std_sim - _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm - - @staticmethod - def __funcargr3_sim(op, state): - # type: (Op, BaseSimState) -> None - pass # return value set before simulation - - @staticmethod - def __funcargr3_gen_asm(op, state): - # type: (Op, GenAsmState) -> None - pass # no instructions needed - FuncArgR3 = GenericOpProperties( - demo_asm="", - inputs=[], - outputs=[OD_BASE_SGPR.with_fixed_loc( - Loc(kind=LocKind.GPR, start=3, reg_len=1))], - ) - _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim - _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -class SSAValOrUse(metaclass=InternedMeta): - __slots__ = "op", "operand_idx" - - def __init__(self, op, operand_idx): - # type: (Op, int) -> None - super().__init__() - self.op = op - if operand_idx < 0 or operand_idx >= len(self.descriptor_array): - raise ValueError("invalid operand_idx") - self.operand_idx = operand_idx - - @abstractmethod - def __repr__(self): - # type: () -> str - ... - - @property - @abstractmethod - def descriptor_array(self): - # type: () -> tuple[OperandDesc, ...] - ... 
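    # editor's note (illustrative): the two concrete subclasses below resolve
    # descriptor_array differently -- SSAVal returns op.properties.outputs
    # and SSAUse returns op.properties.inputs -- so defining_descriptor is
    # simply descriptor_array[operand_idx], the OperandDesc describing this
    # particular output or input operand.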
- - @cached_property - def defining_descriptor(self): - # type: () -> OperandDesc - return self.descriptor_array[self.operand_idx] - - @cached_property - def ty(self): - # type: () -> Ty - return self.defining_descriptor.ty - - @cached_property - def ty_before_spread(self): - # type: () -> Ty - return self.defining_descriptor.ty_before_spread - - @property - def base_ty(self): - # type: () -> BaseTy - return self.ty_before_spread.base_ty - - @property - def reg_offset_in_unspread(self): - """ the number of reg-sized slots in the unspread Loc before self's Loc - - e.g. if the unspread Loc containing self is: - `Loc(kind=LocKind.GPR, start=8, reg_len=4)` - and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)` - then reg_offset_into_unspread == 2 == 10 - 8 - """ - return self.defining_descriptor.reg_offset_in_unspread - - @property - def unspread_start_idx(self): - # type: () -> int - return self.operand_idx - (self.defining_descriptor.spread_index or 0) - - @property - def unspread_start(self): - # type: () -> Self - return self.__class__(op=self.op, operand_idx=self.unspread_start_idx) - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -@final -class SSAVal(SSAValOrUse): - __slots__ = () - - def __repr__(self): - # type: () -> str - return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>" - - @cached_property - def def_loc_set_before_spread(self): - # type: () -> LocSet - return self.defining_descriptor.loc_set_before_spread - - @cached_property - def descriptor_array(self): - # type: () -> tuple[OperandDesc, ...] - return self.op.properties.outputs - - @cached_property - def tied_input(self): - # type: () -> None | SSAUse - if self.defining_descriptor.tied_input_index is None: - return None - return SSAUse(op=self.op, - operand_idx=self.defining_descriptor.tied_input_index) - - @property - def write_stage(self): - # type: () -> OpStage - return self.defining_descriptor.write_stage - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -@final -class SSAUse(SSAValOrUse): - __slots__ = () - - @cached_property - def use_loc_set_before_spread(self): - # type: () -> LocSet - return self.defining_descriptor.loc_set_before_spread - - @cached_property - def descriptor_array(self): - # type: () -> tuple[OperandDesc, ...] 
- return self.op.properties.inputs - - def __repr__(self): - # type: () -> str - return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>" - - @property - def ssa_val(self): - # type: () -> SSAVal - return self.op.input_vals[self.operand_idx] - - @ssa_val.setter - def ssa_val(self, ssa_val): - # type: (SSAVal) -> None - self.op.input_vals[self.operand_idx] = ssa_val - - -_T = TypeVar("_T") -_Desc = TypeVar("_Desc") - - -class OpInputSeq(Sequence[_T], Generic[_T, _Desc]): - @abstractmethod - def _verify_write_with_desc(self, idx, item, desc): - # type: (int, _T | Any, _Desc) -> None - raise NotImplementedError - - @final - def _verify_write(self, idx, item): - # type: (int | Any, _T | Any) -> int - if not isinstance(idx, int): - if isinstance(idx, slice): - raise TypeError( - f"can't write to slice of {self.__class__.__name__}") - raise TypeError(f"can't write with index {idx!r}") - # normalize idx, raising IndexError if it is out of range - idx = range(len(self.descriptors))[idx] - desc = self.descriptors[idx] - self._verify_write_with_desc(idx, item, desc) - return idx - - def _on_set(self, idx, new_item, old_item): - # type: (int, _T, _T | None) -> None - pass - - @abstractmethod - def _get_descriptors(self): - # type: () -> tuple[_Desc, ...] - raise NotImplementedError - - @cached_property - @final - def descriptors(self): - # type: () -> tuple[_Desc, ...] - return self._get_descriptors() - - @property - @final - def op(self): - return self.__op - - def __init__(self, items, op): - # type: (Iterable[_T], Op) -> None - super().__init__() - self.__op = op - self.__items = [] # type: list[_T] - for idx, item in enumerate(items): - if idx >= len(self.descriptors): - raise ValueError("too many items") - _ = self._verify_write(idx, item) - self.__items.append(item) - if len(self.__items) < len(self.descriptors): - raise ValueError("not enough items") - - @final - def __iter__(self): - # type: () -> Iterator[_T] - yield from self.__items - - @overload - def __getitem__(self, idx): - # type: (int) -> _T - ... - - @overload - def __getitem__(self, idx): - # type: (slice) -> list[_T] - ... - - @final - def __getitem__(self, idx): - # type: (int | slice) -> _T | list[_T] - return self.__items[idx] - - @final - def __setitem__(self, idx, item): - # type: (int, _T) -> None - idx = self._verify_write(idx, item) - self.__items[idx] = item - - @final - def __len__(self): - # type: () -> int - return len(self.__items) - - def __repr__(self): - # type: () -> str - return f"{self.__class__.__name__}({self.__items}, op=...)" - - -@final -class OpInputVals(OpInputSeq[SSAVal, OperandDesc]): - def _get_descriptors(self): - # type: () -> tuple[OperandDesc, ...] - return self.op.properties.inputs - - def _verify_write_with_desc(self, idx, item, desc): - # type: (int, SSAVal | Any, OperandDesc) -> None - if not isinstance(item, SSAVal): - raise TypeError("expected value of type SSAVal") - if item.ty != desc.ty: - raise ValueError(f"assigned item's type {item.ty!r} doesn't match " - f"corresponding input's type {desc.ty!r}") - - def _on_set(self, idx, new_item, old_item): - # type: (int, SSAVal, SSAVal | None) -> None - SSAUses._on_op_input_set(self, idx, new_item, old_item) # type: ignore - - def __init__(self, items, op): - # type: (Iterable[SSAVal], Op) -> None - if hasattr(op, "inputs"): - raise ValueError("Op.inputs already set") - super().__init__(items, op) - - -@final -class OpImmediates(OpInputSeq[int, range]): - def _get_descriptors(self): - # type: () -> tuple[range, ...] 
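        # editor's note: each immediate operand is described by a plain
        # `range` of permitted values; _verify_write_with_desc rejects any
        # value that falls outside that range.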
- return self.op.properties.immediates - - def _verify_write_with_desc(self, idx, item, desc): - # type: (int, int | Any, range) -> None - if not isinstance(item, int): - raise TypeError("expected value of type int") - if item not in desc: - raise ValueError(f"immediate value {item!r} not in {desc!r}") - - def __init__(self, items, op): - # type: (Iterable[int], Op) -> None - if hasattr(op, "immediates"): - raise ValueError("Op.immediates already set") - super().__init__(items, op) - - -@plain_data(frozen=True, eq=False, repr=False) -@final -class Op: - __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates", - "outputs", "name") - - def __init__(self, fn, properties, input_vals, immediates, name=""): - # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None - self.fn = fn - self.properties = properties - self.input_vals = OpInputVals(input_vals, op=self) - inputs_len = len(self.properties.inputs) - self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len)) - self.immediates = OpImmediates(immediates, op=self) - outputs_len = len(self.properties.outputs) - self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len)) - self.name = fn._add_op_with_unused_name(self, name) # type: ignore - - @property - def kind(self): - # type: () -> OpKind - return self.properties.kind - - def __eq__(self, other): - # type: (Op | Any) -> bool - if isinstance(other, Op): - return self is other - return NotImplemented - - def __hash__(self): - # type: () -> int - return object.__hash__(self) - - def __repr__(self): - # type: () -> str - field_vals = [] # type: list[str] - for name in fields(self): - if name == "properties": - name = "kind" - elif name == "fn": - continue - try: - value = getattr(self, name) - except AttributeError: - field_vals.append(f"{name}=") - continue - if isinstance(value, OpInputSeq): - value = list(value) # type: ignore - field_vals.append(f"{name}={value!r}") - field_vals_str = ", ".join(field_vals) - return f"Op({field_vals_str})" - - def sim(self, state): - # type: (BaseSimState) -> None - for inp in self.input_vals: - try: - val = state[inp] - except KeyError: - raise ValueError(f"SSAVal {inp} not yet assigned when " - f"running {self}") - if len(val) != inp.ty.reg_len: - raise ValueError( - f"value of SSAVal {inp} has wrong number of elements: " - f"expected {inp.ty.reg_len} found " - f"{len(val)}: {val!r}") - if isinstance(state, PreRASimState): - for out in self.outputs: - if out in state.ssa_vals: - if self.kind is OpKind.FuncArgR3: - continue - raise ValueError(f"SSAVal {out} already assigned before " - f"running {self}") - self.kind.sim(self, state) - for out in self.outputs: - try: - val = state[out] - except KeyError: - raise ValueError(f"running {self} failed to assign to {out}") - if len(val) != out.ty.reg_len: - raise ValueError( - f"value of SSAVal {out} has wrong number of elements: " - f"expected {out.ty.reg_len} found " - f"{len(val)}: {val!r}") - - def gen_asm(self, state): - # type: (GenAsmState) -> None - all_loc_kinds = tuple(LocKind) - for inp in self.input_vals: - state.loc(inp, expected_kinds=all_loc_kinds) - for out in self.outputs: - state.loc(out, expected_kinds=all_loc_kinds) - self.kind.gen_asm(self, state) - - -GPR_SIZE_IN_BYTES = 8 -BITS_IN_BYTE = 8 -GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE -GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1 - - -@plain_data(frozen=True, repr=False) -class BaseSimState(metaclass=ABCMeta): - __slots__ = "memory", - - def __init__(self, memory): - # type: (dict[int, 
int]) -> None - super().__init__() - self.memory = memory # type: dict[int, int] - - def load_byte(self, addr): - # type: (int) -> int - addr &= GPR_VALUE_MASK - return self.memory.get(addr, 0) & 0xFF - - def store_byte(self, addr, value): - # type: (int, int) -> None - addr &= GPR_VALUE_MASK - value &= 0xFF - self.memory[addr] = value - - def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False): - # type: (int, int, bool) -> int - if addr % size_in_bytes != 0: - raise ValueError(f"address not aligned: {hex(addr)} " - f"required alignment: {size_in_bytes}") - retval = 0 - for i in range(size_in_bytes): - retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE - if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0: - retval -= 1 << size_in_bytes * BITS_IN_BYTE - return retval - - def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES): - # type: (int, int, int) -> None - if addr % size_in_bytes != 0: - raise ValueError(f"address not aligned: {hex(addr)} " - f"required alignment: {size_in_bytes}") - for i in range(size_in_bytes): - self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF) - - def _memory__repr(self): - # type: () -> str - if len(self.memory) == 0: - return "{}" - keys = sorted(self.memory.keys(), reverse=True) - CHUNK_SIZE = GPR_SIZE_IN_BYTES - items = [] # type: list[str] - while len(keys) != 0: - addr = keys[-1] - if (len(keys) >= CHUNK_SIZE - and addr % CHUNK_SIZE == 0 - and keys[-CHUNK_SIZE:] - == list(reversed(range(addr, addr + CHUNK_SIZE)))): - value = self.load(addr, size_in_bytes=CHUNK_SIZE) - items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>") - keys[-CHUNK_SIZE:] = () - else: - items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}") - if len(items) == 1: - return f"{{{items[0]}}}" - items_str = ",\n".join(items) - return f"{{\n{items_str}}}" - - def __repr__(self): - # type: () -> str - field_vals = [] # type: list[str] - for name in fields(self): - try: - value = getattr(self, name) - except AttributeError: - field_vals.append(f"{name}=") - continue - repr_fn = getattr(self, f"_{name}__repr", None) - if callable(repr_fn): - field_vals.append(f"{name}={repr_fn()}") - else: - field_vals.append(f"{name}={value!r}") - field_vals_str = ", ".join(field_vals) - return f"{self.__class__.__name__}({field_vals_str})" - - @abstractmethod - def __getitem__(self, ssa_val): - # type: (SSAVal) -> tuple[int, ...] - ... - - @abstractmethod - def __setitem__(self, ssa_val, value): - # type: (SSAVal, tuple[int, ...]) -> None - ... 
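    # editor's sketch (hypothetical values, not from the original source):
    # memory is a byte-addressed, little-endian dict, so after
    #   state.store_byte(0x100, 0x34)
    #   state.store_byte(0x101, 0x12)
    # a state.load(0x100) (8 bytes, the GPR size) returns 0x1234, and the
    # signed=True variant would sign-extend from bit 63 if it were set.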
- - -@plain_data(frozen=True, repr=False) -@final -class PreRASimState(BaseSimState): - __slots__ = "ssa_vals", - - def __init__(self, ssa_vals, memory): - # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None - super().__init__(memory) - self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]] - - def _ssa_vals__repr(self): - # type: () -> str - if len(self.ssa_vals) == 0: - return "{}" - items = [] # type: list[str] - CHUNK_SIZE = 4 - for k, v in self.ssa_vals.items(): - element_strs = [] # type: list[str] - for i, el in enumerate(v): - if i % CHUNK_SIZE != 0: - element_strs.append(" " + hex(el)) - else: - element_strs.append("\n " + hex(el)) - if len(element_strs) <= CHUNK_SIZE: - element_strs[0] = element_strs[0].lstrip() - if len(element_strs) == 1: - element_strs.append("") - v_str = ",".join(element_strs) - items.append(f"{k!r}: ({v_str})") - if len(items) == 1 and "\n" not in items[0]: - return f"{{{items[0]}}}" - items_str = ",\n".join(items) - return f"{{\n{items_str},\n}}" - - def __getitem__(self, ssa_val): - # type: (SSAVal) -> tuple[int, ...] - return self.ssa_vals[ssa_val] - - def __setitem__(self, ssa_val, value): - # type: (SSAVal, tuple[int, ...]) -> None - if len(value) != ssa_val.ty.reg_len: - raise ValueError("value has wrong len") - self.ssa_vals[ssa_val] = value - - -@plain_data(frozen=True, repr=False) -@final -class PostRASimState(BaseSimState): - __slots__ = "ssa_val_to_loc_map", "loc_values" - - def __init__(self, ssa_val_to_loc_map, memory, loc_values): - # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None - super().__init__(memory) - self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map) - for ssa_val, loc in self.ssa_val_to_loc_map.items(): - if ssa_val.ty != loc.ty: - raise ValueError( - f"type mismatch for SSAVal and Loc: {ssa_val} {loc}") - self.loc_values = loc_values - for loc in self.loc_values.keys(): - if loc.reg_len != 1: - raise ValueError( - "loc_values must only contain Locs with reg_len=1, all " - "larger Locs will be split into reg_len=1 sub-Locs") - - def _loc_values__repr(self): - # type: () -> str - locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start)) - items = [] # type: list[str] - for loc in locs: - items.append(f"{loc}: 0x{self.loc_values[loc]:x}") - items_str = ",\n".join(items) - return f"{{\n{items_str},\n}}" - - def __getitem__(self, ssa_val): - # type: (SSAVal) -> tuple[int, ...] 
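        # editor's note: values are stored per reg_len=1 sub-Loc in
        # loc_values, so reading an SSAVal walks its allocated Loc one
        # register at a time and reassembles the tuple (missing entries
        # default to 0).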
-        loc = self.ssa_val_to_loc_map[ssa_val]
-        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
-        retval = []  # type: list[int]
-        for i in range(loc.reg_len):
-            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
-            retval.append(self.loc_values.get(subloc, 0))
-        return tuple(retval)
-
-    def __setitem__(self, ssa_val, value):
-        # type: (SSAVal, tuple[int, ...]) -> None
-        if len(value) != ssa_val.ty.reg_len:
-            raise ValueError("value has wrong len")
-        loc = self.ssa_val_to_loc_map[ssa_val]
-        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
-        for i in range(loc.reg_len):
-            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
-            self.loc_values[subloc] = value[i]
-
-
-@plain_data(frozen=True)
-class GenAsmState:
-    __slots__ = "allocated_locs", "output"
-
-    def __init__(self, allocated_locs, output=None):
-        # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
-        super().__init__()
-        self.allocated_locs = FMap(allocated_locs)
-        for ssa_val, loc in self.allocated_locs.items():
-            if ssa_val.ty != loc.ty:
-                raise ValueError(
-                    f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
-        if output is None:
-            output = []
-        self.output = output
-
-    __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
-
-    def loc(self, ssa_val_or_locs, expected_kinds):
-        # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
-        if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
-            ssa_val_or_locs = [ssa_val_or_locs]
-        locs = []  # type: list[Loc]
-        for i in ssa_val_or_locs:
-            if isinstance(i, SSAVal):
-                locs.append(self.allocated_locs[i])
-            else:
-                locs.append(i)
-        if len(locs) == 0:
-            raise ValueError("invalid Loc sequence: must not be empty")
-        retval = locs[0].try_concat(*locs[1:])
-        if retval is None:
-            raise ValueError("invalid Loc sequence: try_concat failed")
-        if isinstance(expected_kinds, LocKind):
-            expected_kinds = expected_kinds,
-        if retval.kind not in expected_kinds:
-            if len(expected_kinds) == 1:
-                expected_kinds = expected_kinds[0]
-            raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
-                             f"{retval.kind} expected {expected_kinds}")
-        return retval
-
-    def gpr(self, ssa_val_or_locs, is_vec):
-        # type: (__SSA_VAL_OR_LOCS, bool) -> str
-        loc = self.loc(ssa_val_or_locs, LocKind.GPR)
-        vec_str = "*" if is_vec else ""
-        return vec_str + str(loc.start)
-
-    def sgpr(self, ssa_val_or_locs):
-        # type: (__SSA_VAL_OR_LOCS) -> str
-        return self.gpr(ssa_val_or_locs, is_vec=False)
-
-    def vgpr(self, ssa_val_or_locs):
-        # type: (__SSA_VAL_OR_LOCS) -> str
-        return self.gpr(ssa_val_or_locs, is_vec=True)
-
-    def stack(self, ssa_val_or_locs):
-        # type: (__SSA_VAL_OR_LOCS) -> str
-        loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
-        return f"{loc.start}(1)"
-
-    def writeln(self, *line_segments):
-        # type: (*str) -> None
-        line = " ".join(line_segments)
-        if isinstance(self.output, list):
-            self.output.append(line)
-        else:
-            self.output.write(line + "\n")
diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py
new file mode 100644
index 0000000..b3f682e
--- /dev/null
+++ b/src/bigint_presentation_code/register_allocator.py
@@ -0,0 +1,574 @@
+"""
+Register Allocator for Toom-Cook algorithm generator for SVP64
+
+this uses an algorithm based on:
+[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+"""
+
+from itertools import combinations
+from typing import Iterable, Iterator, Mapping, TextIO
+
+from cached_property import cached_property
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
+                                                  LocSet, ProgramRange, SSAVal,
+                                                  Ty)
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
+
+
+class BadMergedSSAVal(ValueError):
+    pass
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class MergedSSAVal(metaclass=InternedMeta):
+    """a set of `SSAVal`s along with their offsets, all register allocated as
+    a single unit.
+
+    Definition of the term `offset` for this class:
+
+    Let `locs[x]` be the `Loc` that `x` is assigned to after register
+    allocation and let `msv` be a `MergedSSAVal` instance, then the offset
+    for each `SSAVal` `ssa_val` in `msv` is defined as:
+
+    ```
+    msv.ssa_val_offsets[ssa_val] = (msv.offset
+                                    + locs[ssa_val].start - locs[msv].start)
+    ```
+
+    Example:
+    ```
+    v1.ty == <I64*4>
+    v2.ty == <I64*2>
+    v3.ty == <I64*1>
+    msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
+    msv.ty == <I64*6>
+    ```
+    if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
+    * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
+    * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
+    * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
+    """
+    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
+
+    def __init__(self, fn_analysis, ssa_val_offsets):
+        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
+        self.fn_analysis = fn_analysis
+        if isinstance(ssa_val_offsets, SSAVal):
+            ssa_val_offsets = {ssa_val_offsets: 0}
+        self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
+        first_ssa_val = None
+        for ssa_val in self.ssa_vals:
+            first_ssa_val = ssa_val
+            break
+        if first_ssa_val is None:
+            raise BadMergedSSAVal("MergedSSAVal can't be empty")
+        self.first_ssa_val = first_ssa_val  # type: SSAVal
+        # self.ty checks for mismatched base_ty
+        reg_len = self.ty.reg_len
+        loc_set = None  # type: None | LocSet
+        for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
+            def locs():
+                # type: () -> Iterable[Loc]
+                for loc in ssa_val.def_loc_set_before_spread:
+                    disallowed_by_use = False
+                    for use in fn_analysis.uses[ssa_val]:
+                        # calculate the start for the use's Loc before spread
+                        # e.g. if the def's Loc before spread starts at r6
+                        # and the def's reg_offset_in_unspread is 5
+                        # and the use's reg_offset_in_unspread is 3
+                        # then the use's Loc before spread starts at r8
+                        # because 8 == 6 + 5 - 3
+                        start = (loc.start + ssa_val.reg_offset_in_unspread
+                                 - use.reg_offset_in_unspread)
+                        use_loc = Loc.try_make(
+                            loc.kind, start=start,
+                            reg_len=use.ty_before_spread.reg_len)
+                        if (use_loc is None or
+                                use_loc not in use.use_loc_set_before_spread):
+                            disallowed_by_use = True
+                            break
+                    if disallowed_by_use:
+                        continue
+                    # FIXME: add spread consistency check
+                    start = loc.start - cur_offset + self.offset
+                    loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
+                    if loc is not None and (loc_set is None or loc in loc_set):
+                        yield loc
+            loc_set = LocSet(locs())
+        assert loc_set is not None, "already checked that self isn't empty"
+        if loc_set.ty is None:
+            raise BadMergedSSAVal("there are no valid Locs left")
+        assert loc_set.ty == self.ty, "logic error somewhere"
+        self.loc_set = loc_set  # type: LocSet
+
+    @cached_property
+    def __hash(self):
+        # type: () -> int
+        return hash((self.fn_analysis, self.ssa_val_offsets))
+
+    def __hash__(self):
+        # type: () -> int
+        return self.__hash
+
+    @cached_property
+    def offset(self):
+        # type: () -> int
+        return min(self.ssa_val_offsets_before_spread.values())
+
+    @property
+    def base_ty(self):
+        # type: () -> BaseTy
+        return self.first_ssa_val.base_ty
+
+    @cached_property
+    def ssa_vals(self):
+        # type: () -> OFSet[SSAVal]
+        return OFSet(self.ssa_val_offsets.keys())
+
+    @cached_property
+    def ty(self):
+        # type: () -> Ty
+        reg_len = 0
+        for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
+            cur_ty = ssa_val.ty_before_spread
+            if self.base_ty != cur_ty.base_ty:
+                raise BadMergedSSAVal(
+                    f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
+            reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
+        return Ty(base_ty=self.base_ty, reg_len=reg_len)
+
+    @cached_property
+    def ssa_val_offsets_before_spread(self):
+        # type: () -> FMap[SSAVal, int]
+        retval = {}  # type: dict[SSAVal, int]
+        for ssa_val, offset in self.ssa_val_offsets.items():
+            retval[ssa_val] = (
+                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
+        return FMap(retval)
+
+    def offset_by(self, amount):
+        # type: (int) -> MergedSSAVal
+        v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
+        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)
+
+    def normalized(self):
+        # type: () -> MergedSSAVal
+        return self.offset_by(-self.offset)
+
+    def with_offset_to_match(self, target, additional_offset=0):
+        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
+        if isinstance(target, MergedSSAVal):
+            ssa_val_offsets = target.ssa_val_offsets
+        else:
+            ssa_val_offsets = {target: 0}
+        for ssa_val, offset in self.ssa_val_offsets.items():
+            if ssa_val in ssa_val_offsets:
+                return self.offset_by(
+                    ssa_val_offsets[ssa_val] + additional_offset - offset)
+        raise ValueError("can't change offset to match unrelated MergedSSAVal")
+
+    def merged(self, *others):
+        # type: (*MergedSSAVal) -> MergedSSAVal
+        retval = dict(self.ssa_val_offsets)
+        for other in others:
+            if other.fn_analysis != self.fn_analysis:
+                raise ValueError("fn_analysis mismatch")
+            for ssa_val, offset in other.ssa_val_offsets.items():
+                if ssa_val in retval and retval[ssa_val] != offset:
+                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
+                                          f"{retval[ssa_val]} != {offset}")
+                retval[ssa_val] = offset
+        return MergedSSAVal(fn_analysis=self.fn_analysis,
+                            ssa_val_offsets=retval)
+
+    @cached_property
+    def live_interval(self):
+        # type: () -> ProgramRange
+        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
+        start = live_range.start
+        stop = live_range.stop
+        for ssa_val in self.ssa_vals:
+            live_range = self.fn_analysis.live_ranges[ssa_val]
+            start = min(start, live_range.start)
+            stop = max(stop, live_range.stop)
+        return ProgramRange(start=start, stop=stop)
+
+    def __repr__(self):
+        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
+                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
+                f"live_interval={self.live_interval})")
+
+
+@final
+class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
+    def __init__(self):
+        # type: (...) -> None
+        self.__map = {}  # type: dict[SSAVal, MergedSSAVal]
+        self.__ig_node_map = MergedSSAValToIGNodeMap(
+            _private_merged_ssa_val_map=self.__map)
+
+    def __getitem__(self, __key):
+        # type: (SSAVal) -> MergedSSAVal
+        return self.__map[__key]
+
+    def __iter__(self):
+        # type: () -> Iterator[SSAVal]
+        return iter(self.__map)
+
+    def __len__(self):
+        # type: () -> int
+        return len(self.__map)
+
+    @property
+    def ig_node_map(self):
+        # type: () -> MergedSSAValToIGNodeMap
+        return self.__ig_node_map
+
+    def __repr__(self):
+        # type: () -> str
+        s = ",\n".join(repr(v) for v in self.__ig_node_map)
+        return f"SSAValToMergedSSAValMap({{{s}}})"
+
+
+@final
+class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
+    def __init__(
+        self, *,
+        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
+    ):
+        # type: (...) -> None
+        self.__merged_ssa_val_map = _private_merged_ssa_val_map
+        self.__map = {}  # type: dict[MergedSSAVal, IGNode]
+
+    def __getitem__(self, __key):
+        # type: (MergedSSAVal) -> IGNode
+        return self.__map[__key]
+
+    def __iter__(self):
+        # type: () -> Iterator[MergedSSAVal]
+        return iter(self.__map)
+
+    def __len__(self):
+        # type: () -> int
+        return len(self.__map)
+
+    def add_node(self, merged_ssa_val):
+        # type: (MergedSSAVal) -> IGNode
+        node = self.__map.get(merged_ssa_val, None)
+        if node is not None:
+            return node
+        added = 0  # type: int | None
+        try:
+            for ssa_val in merged_ssa_val.ssa_vals:
+                if ssa_val in self.__merged_ssa_val_map:
+                    raise ValueError(
+                        f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
+                        f"{merged_ssa_val} and "
+                        f"{self.__merged_ssa_val_map[ssa_val]}")
+                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
+                added += 1
+            retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
+            self.__map[merged_ssa_val] = retval
+            added = None
+            return retval
+        finally:
+            if added is not None:
+                # remove partially added stuff
+                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
+                    if idx >= added:
+                        break
+                    del self.__merged_ssa_val_map[ssa_val]
+
+    def merge_into_one_node(self, final_merged_ssa_val):
+        # type: (MergedSSAVal) -> IGNode
+        source_nodes = OSet()  # type: OSet[IGNode]
+        edges = OSet()  # type: OSet[IGNode]
+        loc = None  # type: Loc | None
+        for ssa_val in final_merged_ssa_val.ssa_vals:
+            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
+            source_node = self.__map[merged_ssa_val]
+            source_nodes.add(source_node)
+            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
+                raise ValueError(
+                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
+                    f"but not in merged IGNode's merged_ssa_val: "
+                    f"source_node={source_node} "
+                    f"final_merged_ssa_val={final_merged_ssa_val}")
+            if loc is None:
+                loc = source_node.loc
+            elif source_node.loc is not None and loc != source_node.loc:
+                raise ValueError(f"can't merge IGNodes with mismatched `loc` "
+                                 f"values: {loc} != {source_node.loc}")
+            edges |= source_node.edges
+        if len(source_nodes) == 1:
+            return source_nodes.pop()  # merging a single node is a no-op
+        # we're finished checking validity, now we can modify stuff
+        edges -= source_nodes
+        retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
+                        loc=loc)
+        for node in edges:
+            node.edges -= source_nodes
+            node.edges.add(retval)
+        for node in source_nodes:
+            del self.__map[node.merged_ssa_val]
+        self.__map[final_merged_ssa_val] = retval
+        for ssa_val in final_merged_ssa_val.ssa_vals:
+            self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
+        return retval
+
+    def __repr__(self, repr_state=None):
+        # type: (None | IGNodeReprState) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
+        return f"MergedSSAValToIGNodeMap({{{s}}})"
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class InterferenceGraph:
+    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"
+
+    def __init__(self, fn_analysis, merged_ssa_vals):
+        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
+        self.fn_analysis = fn_analysis
+        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
+        self.nodes = self.merged_ssa_val_map.ig_node_map
+        for i in merged_ssa_vals:
+            self.nodes.add_node(i)
+
+    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
+        # type: (SSAVal, SSAVal, int) -> IGNode
+        merged1 = self.merged_ssa_val_map[ssa_val1]
+        merged2 = self.merged_ssa_val_map[ssa_val2]
+        merged = merged1.with_offset_to_match(ssa_val1)
+        merged = merged.merged(merged2.with_offset_to_match(
+            ssa_val2, additional_offset=additional_offset))
+        return self.nodes.merge_into_one_node(merged)
+
+    @staticmethod
+    def minimally_merged(fn_analysis):
+        # type: (FnAnalysis) -> InterferenceGraph
+        retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
+        for op in fn_analysis.fn.ops:
+            for inp in op.input_uses:
+                if inp.unspread_start != inp:
+                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
+                                 additional_offset=inp.reg_offset_in_unspread)
+            for out in op.outputs:
+                retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
+                if out.unspread_start != out:
+                    retval.merge(out.unspread_start, out,
+                                 additional_offset=out.reg_offset_in_unspread)
+                if out.tied_input is not None:
+                    retval.merge(out.tied_input.ssa_val, out)
+        return retval
+
+    def __repr__(self, repr_state=None):
+        # type: (None | IGNodeReprState) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        s = self.nodes.__repr__(repr_state)
+        return f"InterferenceGraph(nodes={s}, <...>)"
+
+
+@plain_data(repr=False)
+class IGNodeReprState:
+    __slots__ = "node_ids", "did_full_repr"
+
+    def __init__(self):
+        super().__init__()
+        self.node_ids = {}  # type: dict[IGNode, int]
+        self.did_full_repr = OSet()  # type: OSet[IGNode]
+
+
+@final
+class IGNode:
+    """ interference graph node """
+    __slots__ = "merged_ssa_val", "edges", "loc"
+
+    def __init__(self, merged_ssa_val, edges, loc):
+        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
+        self.merged_ssa_val = merged_ssa_val
+        self.edges = OSet(edges)
+        self.loc = loc
+
+    def add_edge(self, other):
+        # type: (IGNode) -> None
+        self.edges.add(other)
+        other.edges.add(self)
+
+    def __eq__(self, other):
+        # type: (object) -> bool
+        if isinstance(other, IGNode):
+            return self.merged_ssa_val == other.merged_ssa_val
+        return NotImplemented
+
+    def __hash__(self):
+        # type: () -> int
+        return hash(self.merged_ssa_val)
+
+    def __repr__(self, repr_state=None, short=False):
+        # type: (None | IGNodeReprState, bool) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        node_id = repr_state.node_ids.get(self, None)
+        if node_id is None:
+            repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
+        if short or self in repr_state.did_full_repr:
+            return f"<IGNode #{node_id}>"
+        repr_state.did_full_repr.add(self)
+        edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
+        return (f"IGNode(#{node_id}, "
+                f"merged_ssa_val={self.merged_ssa_val}, "
+                f"edges={{{edges}}}, "
+                f"loc={self.loc})")
+
+    @property
+    def loc_set(self):
+        # type: () -> LocSet
+        return self.merged_ssa_val.loc_set
+
+    def loc_conflicts_with_neighbors(self, loc):
+        # type: (Loc) -> bool
+        for neighbor in self.edges:
+            if neighbor.loc is not None and neighbor.loc.conflicts(loc):
+                return True
+        return False
+
+
+class AllocationFailedError(Exception):
+    def __init__(self, msg, node, interference_graph):
+        # type: (str, IGNode, InterferenceGraph) -> None
+        super().__init__(msg, node, interference_graph)
+        self.node = node
+        self.interference_graph = interference_graph
+
+    def __repr__(self, repr_state=None):
+        # type: (None | IGNodeReprState) -> str
+        if repr_state is None:
+            repr_state = IGNodeReprState()
+        return (f"{__class__.__name__}({self.args[0]!r}, "
+                f"node={self.node.__repr__(repr_state, True)}, "
+                f"interference_graph="
+                f"{self.interference_graph.__repr__(repr_state)})")
+
+    def __str__(self):
+        # type: () -> str
+        return self.__repr__()
+
+
+def allocate_registers(fn, debug_out=None):
+    # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
+
+    # inserts enough copies that no manual spilling is necessary, all
+    # spilling is done by the register allocator naturally allocating SSAVals
+    # to stack slots
+    fn.pre_ra_insert_copies()
+
+    if debug_out is not None:
+        print(f"After pre_ra_insert_copies():\n{fn.ops}",
+              file=debug_out, flush=True)
+
+    fn_analysis = FnAnalysis(fn)
+    interference_graph = InterferenceGraph.minimally_merged(fn_analysis)
+
+    if debug_out is not None:
+        print(f"After InterferenceGraph.minimally_merged():\n"
+              f"{interference_graph}", file=debug_out, flush=True)
+
+    for pp, ssa_vals in fn_analysis.live_at.items():
+        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
+        for ssa_val in ssa_vals:
+            live_merged_ssa_vals.add(
+                interference_graph.merged_ssa_val_map[ssa_val])
+        for i, j in combinations(live_merged_ssa_vals, 2):
+            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
+                interference_graph.nodes[i].add_edge(
+                    interference_graph.nodes[j])
+        if debug_out is not None:
+            print(f"processed {pp} out of {fn_analysis.all_program_points}",
+                  file=debug_out, flush=True)
+
+    if debug_out is not None:
+        print(f"After adding interference graph edges:\n"
+              f"{interference_graph}", file=debug_out, flush=True)
+
+    nodes_remaining = OSet(interference_graph.nodes.values())
+
+    def local_colorability_score(node):
+        # type: (IGNode) -> int
+        """ returns a positive integer if node is locally colorable, returns
+        zero or a negative integer if node isn't known to be locally
+        colorable, the more negative the value, the less colorable
+        """
+        if node not in nodes_remaining:
+            raise ValueError()
+        retval = len(node.loc_set)
+        for neighbor in node.edges:
+            if neighbor in nodes_remaining:
+                retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
+        return retval
+
+    # TODO: implement copy-merging
+
+    node_stack = []  # type: list[IGNode]
+    while True:
+        best_node = None  # type: None | IGNode
+        best_score = 0
+        for node in nodes_remaining:
+            score = local_colorability_score(node)
+            if best_node is None or score > best_score:
+                best_node = node
+                best_score = score
+                if best_score > 0:
+                    # it's locally colorable, no need to find a better one
+                    break
+
+        if best_node is None:
+            break
+        node_stack.append(best_node)
+        nodes_remaining.remove(best_node)
+
+    if debug_out is not None:
+        print(f"After deciding node allocation order:\n"
+              f"{node_stack}", file=debug_out, flush=True)
+
+    retval = {}  # type: dict[SSAVal, Loc]
+
+    while len(node_stack) > 0:
+        node = node_stack.pop()
+        if node.loc is not None:
+            if node.loc_conflicts_with_neighbors(node.loc):
+                raise AllocationFailedError(
+                    "IGNode is pre-allocated to a conflicting Loc",
+                    node=node, interference_graph=interference_graph)
+        else:
+            # pick the first non-conflicting register in node.reg_class, since
+            # register classes are ordered from most preferred to least
+            # preferred register.
+            for loc in node.loc_set:
+                if not node.loc_conflicts_with_neighbors(loc):
+                    node.loc = loc
+                    break
+            if node.loc is None:
+                raise AllocationFailedError(
+                    "failed to allocate Loc for IGNode",
+                    node=node, interference_graph=interference_graph)

+        if debug_out is not None:
+            print(f"After allocating Loc for node:\n{node}",
+                  file=debug_out, flush=True)
+
+        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
+            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)
+
+    if debug_out is not None:
+        print(f"final Locs for all SSAVals:\n{retval}",
+              file=debug_out, flush=True)
+
+    return retval
diff --git a/src/bigint_presentation_code/register_allocator2.py b/src/bigint_presentation_code/register_allocator2.py
deleted file mode 100644
index 278693d..0000000
--- a/src/bigint_presentation_code/register_allocator2.py
+++ /dev/null
@@ -1,574 +0,0 @@
-"""
-Register Allocator for Toom-Cook algorithm generator for SVP64
-
-this uses an algorithm based on:
-[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
-"""
-
-from itertools import combinations
-from typing import Iterable, Iterator, Mapping, TextIO
-
-from cached_property import cached_property
-from nmutil.plain_data import plain_data
-
-from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc,
-                                                   LocSet, ProgramRange,
-                                                   SSAVal, Ty)
-from bigint_presentation_code.type_util import final
-from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
-
-
-class BadMergedSSAVal(ValueError):
-    pass
-
-
-@plain_data(frozen=True, repr=False)
-@final
-class MergedSSAVal(metaclass=InternedMeta):
-    """a set of `SSAVal`s along with their offsets, all register allocated as
-    a single unit.
- - Definition of the term `offset` for this class: - - Let `locs[x]` be the `Loc` that `x` is assigned to after register - allocation and let `msv` be a `MergedSSAVal` instance, then the offset - for each `SSAVal` `ssa_val` in `msv` is defined as: - - ``` - msv.ssa_val_offsets[ssa_val] = (msv.offset - + locs[ssa_val].start - locs[msv].start) - ``` - - Example: - ``` - v1.ty == - v2.ty == - v3.ty == - msv = MergedSSAVal({v1: 0, v2: 4, v3: 1}) - msv.ty == - ``` - if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then - * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)` - * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)` - * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)` - """ - __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set" - - def __init__(self, fn_analysis, ssa_val_offsets): - # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None - self.fn_analysis = fn_analysis - if isinstance(ssa_val_offsets, SSAVal): - ssa_val_offsets = {ssa_val_offsets: 0} - self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int] - first_ssa_val = None - for ssa_val in self.ssa_vals: - first_ssa_val = ssa_val - break - if first_ssa_val is None: - raise BadMergedSSAVal("MergedSSAVal can't be empty") - self.first_ssa_val = first_ssa_val # type: SSAVal - # self.ty checks for mismatched base_ty - reg_len = self.ty.reg_len - loc_set = None # type: None | LocSet - for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items(): - def locs(): - # type: () -> Iterable[Loc] - for loc in ssa_val.def_loc_set_before_spread: - disallowed_by_use = False - for use in fn_analysis.uses[ssa_val]: - # calculate the start for the use's Loc before spread - # e.g. if the def's Loc before spread starts at r6 - # and the def's reg_offset_in_unspread is 5 - # and the use's reg_offset_in_unspread is 3 - # then the use's Loc before spread starts at r8 - # because 8 == 6 + 5 - 3 - start = (loc.start + ssa_val.reg_offset_in_unspread - - use.reg_offset_in_unspread) - use_loc = Loc.try_make( - loc.kind, start=start, - reg_len=use.ty_before_spread.reg_len) - if (use_loc is None or - use_loc not in use.use_loc_set_before_spread): - disallowed_by_use = True - break - if disallowed_by_use: - continue - # FIXME: add spread consistency check - start = loc.start - cur_offset + self.offset - loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len) - if loc is not None and (loc_set is None or loc in loc_set): - yield loc - loc_set = LocSet(locs()) - assert loc_set is not None, "already checked that self isn't empty" - if loc_set.ty is None: - raise BadMergedSSAVal("there are no valid Locs left") - assert loc_set.ty == self.ty, "logic error somewhere" - self.loc_set = loc_set # type: LocSet - - @cached_property - def __hash(self): - # type: () -> int - return hash((self.fn_analysis, self.ssa_val_offsets)) - - def __hash__(self): - # type: () -> int - return self.__hash - - @cached_property - def offset(self): - # type: () -> int - return min(self.ssa_val_offsets_before_spread.values()) - - @property - def base_ty(self): - # type: () -> BaseTy - return self.first_ssa_val.base_ty - - @cached_property - def ssa_vals(self): - # type: () -> OFSet[SSAVal] - return OFSet(self.ssa_val_offsets.keys()) - - @cached_property - def ty(self): - # type: () -> Ty - reg_len = 0 - for ssa_val, offset in self.ssa_val_offsets_before_spread.items(): - cur_ty = ssa_val.ty_before_spread - if self.base_ty != cur_ty.base_ty: - raise BadMergedSSAVal( - 
f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}") - reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset) - return Ty(base_ty=self.base_ty, reg_len=reg_len) - - @cached_property - def ssa_val_offsets_before_spread(self): - # type: () -> FMap[SSAVal, int] - retval = {} # type: dict[SSAVal, int] - for ssa_val, offset in self.ssa_val_offsets.items(): - retval[ssa_val] = ( - offset - ssa_val.defining_descriptor.reg_offset_in_unspread) - return FMap(retval) - - def offset_by(self, amount): - # type: (int) -> MergedSSAVal - v = {k: v + amount for k, v in self.ssa_val_offsets.items()} - return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v) - - def normalized(self): - # type: () -> MergedSSAVal - return self.offset_by(-self.offset) - - def with_offset_to_match(self, target, additional_offset=0): - # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal - if isinstance(target, MergedSSAVal): - ssa_val_offsets = target.ssa_val_offsets - else: - ssa_val_offsets = {target: 0} - for ssa_val, offset in self.ssa_val_offsets.items(): - if ssa_val in ssa_val_offsets: - return self.offset_by( - ssa_val_offsets[ssa_val] + additional_offset - offset) - raise ValueError("can't change offset to match unrelated MergedSSAVal") - - def merged(self, *others): - # type: (*MergedSSAVal) -> MergedSSAVal - retval = dict(self.ssa_val_offsets) - for other in others: - if other.fn_analysis != self.fn_analysis: - raise ValueError("fn_analysis mismatch") - for ssa_val, offset in other.ssa_val_offsets.items(): - if ssa_val in retval and retval[ssa_val] != offset: - raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: " - f"{retval[ssa_val]} != {offset}") - retval[ssa_val] = offset - return MergedSSAVal(fn_analysis=self.fn_analysis, - ssa_val_offsets=retval) - - @cached_property - def live_interval(self): - # type: () -> ProgramRange - live_range = self.fn_analysis.live_ranges[self.first_ssa_val] - start = live_range.start - stop = live_range.stop - for ssa_val in self.ssa_vals: - live_range = self.fn_analysis.live_ranges[ssa_val] - start = min(start, live_range.start) - stop = max(stop, live_range.stop) - return ProgramRange(start=start, stop=stop) - - def __repr__(self): - return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, " - f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, " - f"live_interval={self.live_interval})") - - -@final -class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]): - def __init__(self): - # type: (...) -> None - self.__map = {} # type: dict[SSAVal, MergedSSAVal] - self.__ig_node_map = MergedSSAValToIGNodeMap( - _private_merged_ssa_val_map=self.__map) - - def __getitem__(self, __key): - # type: (SSAVal) -> MergedSSAVal - return self.__map[__key] - - def __iter__(self): - # type: () -> Iterator[SSAVal] - return iter(self.__map) - - def __len__(self): - # type: () -> int - return len(self.__map) - - @property - def ig_node_map(self): - # type: () -> MergedSSAValToIGNodeMap - return self.__ig_node_map - - def __repr__(self): - # type: () -> str - s = ",\n".join(repr(v) for v in self.__ig_node_map) - return f"SSAValToMergedSSAValMap({{{s}}})" - - -@final -class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]): - def __init__( - self, *, - _private_merged_ssa_val_map, # type: dict[SSAVal, MergedSSAVal] - ): - # type: (...) 
-> None - self.__merged_ssa_val_map = _private_merged_ssa_val_map - self.__map = {} # type: dict[MergedSSAVal, IGNode] - - def __getitem__(self, __key): - # type: (MergedSSAVal) -> IGNode - return self.__map[__key] - - def __iter__(self): - # type: () -> Iterator[MergedSSAVal] - return iter(self.__map) - - def __len__(self): - # type: () -> int - return len(self.__map) - - def add_node(self, merged_ssa_val): - # type: (MergedSSAVal) -> IGNode - node = self.__map.get(merged_ssa_val, None) - if node is not None: - return node - added = 0 # type: int | None - try: - for ssa_val in merged_ssa_val.ssa_vals: - if ssa_val in self.__merged_ssa_val_map: - raise ValueError( - f"overlapping `MergedSSAVal`s: {ssa_val} is in both " - f"{merged_ssa_val} and " - f"{self.__merged_ssa_val_map[ssa_val]}") - self.__merged_ssa_val_map[ssa_val] = merged_ssa_val - added += 1 - retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None) - self.__map[merged_ssa_val] = retval - added = None - return retval - finally: - if added is not None: - # remove partially added stuff - for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals): - if idx >= added: - break - del self.__merged_ssa_val_map[ssa_val] - - def merge_into_one_node(self, final_merged_ssa_val): - # type: (MergedSSAVal) -> IGNode - source_nodes = OSet() # type: OSet[IGNode] - edges = OSet() # type: OSet[IGNode] - loc = None # type: Loc | None - for ssa_val in final_merged_ssa_val.ssa_vals: - merged_ssa_val = self.__merged_ssa_val_map[ssa_val] - source_node = self.__map[merged_ssa_val] - source_nodes.add(source_node) - for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals: - raise ValueError( - f"SSAVal {i} appears in source IGNode's merged_ssa_val " - f"but not in merged IGNode's merged_ssa_val: " - f"source_node={source_node} " - f"final_merged_ssa_val={final_merged_ssa_val}") - if loc is None: - loc = source_node.loc - elif source_node.loc is not None and loc != source_node.loc: - raise ValueError(f"can't merge IGNodes with mismatched `loc` " - f"values: {loc} != {source_node.loc}") - edges |= source_node.edges - if len(source_nodes) == 1: - return source_nodes.pop() # merging a single node is a no-op - # we're finished checking validity, now we can modify stuff - edges -= source_nodes - retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges, - loc=loc) - for node in edges: - node.edges -= source_nodes - node.edges.add(retval) - for node in source_nodes: - del self.__map[node.merged_ssa_val] - self.__map[final_merged_ssa_val] = retval - for ssa_val in final_merged_ssa_val.ssa_vals: - self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val - return retval - - def __repr__(self, repr_state=None): - # type: (None | IGNodeReprState) -> str - if repr_state is None: - repr_state = IGNodeReprState() - s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values()) - return f"MergedSSAValToIGNodeMap({{{s}}})" - - -@plain_data(frozen=True, repr=False) -@final -class InterferenceGraph: - __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes" - - def __init__(self, fn_analysis, merged_ssa_vals): - # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None - self.fn_analysis = fn_analysis - self.merged_ssa_val_map = SSAValToMergedSSAValMap() - self.nodes = self.merged_ssa_val_map.ig_node_map - for i in merged_ssa_vals: - self.nodes.add_node(i) - - def merge(self, ssa_val1, ssa_val2, additional_offset=0): - # type: (SSAVal, SSAVal, int) -> IGNode - merged1 = self.merged_ssa_val_map[ssa_val1] - merged2 = self.merged_ssa_val_map[ssa_val2] - 
merged = merged1.with_offset_to_match(ssa_val1) - merged = merged.merged(merged2.with_offset_to_match( - ssa_val2, additional_offset=additional_offset)) - return self.nodes.merge_into_one_node(merged) - - @staticmethod - def minimally_merged(fn_analysis): - # type: (FnAnalysis) -> InterferenceGraph - retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=()) - for op in fn_analysis.fn.ops: - for inp in op.input_uses: - if inp.unspread_start != inp: - retval.merge(inp.unspread_start.ssa_val, inp.ssa_val, - additional_offset=inp.reg_offset_in_unspread) - for out in op.outputs: - retval.nodes.add_node(MergedSSAVal(fn_analysis, out)) - if out.unspread_start != out: - retval.merge(out.unspread_start, out, - additional_offset=out.reg_offset_in_unspread) - if out.tied_input is not None: - retval.merge(out.tied_input.ssa_val, out) - return retval - - def __repr__(self, repr_state=None): - # type: (None | IGNodeReprState) -> str - if repr_state is None: - repr_state = IGNodeReprState() - s = self.nodes.__repr__(repr_state) - return f"InterferenceGraph(nodes={s}, <...>)" - - -@plain_data(repr=False) -class IGNodeReprState: - __slots__ = "node_ids", "did_full_repr" - - def __init__(self): - super().__init__() - self.node_ids = {} # type: dict[IGNode, int] - self.did_full_repr = OSet() # type: OSet[IGNode] - - -@final -class IGNode: - """ interference graph node """ - __slots__ = "merged_ssa_val", "edges", "loc" - - def __init__(self, merged_ssa_val, edges, loc): - # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None - self.merged_ssa_val = merged_ssa_val - self.edges = OSet(edges) - self.loc = loc - - def add_edge(self, other): - # type: (IGNode) -> None - self.edges.add(other) - other.edges.add(self) - - def __eq__(self, other): - # type: (object) -> bool - if isinstance(other, IGNode): - return self.merged_ssa_val == other.merged_ssa_val - return NotImplemented - - def __hash__(self): - # type: () -> int - return hash(self.merged_ssa_val) - - def __repr__(self, repr_state=None, short=False): - # type: (None | IGNodeReprState, bool) -> str - if repr_state is None: - repr_state = IGNodeReprState() - node_id = repr_state.node_ids.get(self, None) - if node_id is None: - repr_state.node_ids[self] = node_id = len(repr_state.node_ids) - if short or self in repr_state.did_full_repr: - return f"" - repr_state.did_full_repr.add(self) - edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges) - return (f"IGNode(#{node_id}, " - f"merged_ssa_val={self.merged_ssa_val}, " - f"edges={{{edges}}}, " - f"loc={self.loc})") - - @property - def loc_set(self): - # type: () -> LocSet - return self.merged_ssa_val.loc_set - - def loc_conflicts_with_neighbors(self, loc): - # type: (Loc) -> bool - for neighbor in self.edges: - if neighbor.loc is not None and neighbor.loc.conflicts(loc): - return True - return False - - -class AllocationFailedError(Exception): - def __init__(self, msg, node, interference_graph): - # type: (str, IGNode, InterferenceGraph) -> None - super().__init__(msg, node, interference_graph) - self.node = node - self.interference_graph = interference_graph - - def __repr__(self, repr_state=None): - # type: (None | IGNodeReprState) -> str - if repr_state is None: - repr_state = IGNodeReprState() - return (f"{__class__.__name__}({self.args[0]!r}, " - f"node={self.node.__repr__(repr_state, True)}, " - f"interference_graph=" - f"{self.interference_graph.__repr__(repr_state)})") - - def __str__(self): - # type: () -> str - return self.__repr__() - - -def allocate_registers(fn, 
debug_out=None): - # type: (Fn, TextIO | None) -> dict[SSAVal, Loc] - - # inserts enough copies that no manual spilling is necessary, all - # spilling is done by the register allocator naturally allocating SSAVals - # to stack slots - fn.pre_ra_insert_copies() - - if debug_out is not None: - print(f"After pre_ra_insert_copies():\n{fn.ops}", - file=debug_out, flush=True) - - fn_analysis = FnAnalysis(fn) - interference_graph = InterferenceGraph.minimally_merged(fn_analysis) - - if debug_out is not None: - print(f"After InterferenceGraph.minimally_merged():\n" - f"{interference_graph}", file=debug_out, flush=True) - - for pp, ssa_vals in fn_analysis.live_at.items(): - live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal] - for ssa_val in ssa_vals: - live_merged_ssa_vals.add( - interference_graph.merged_ssa_val_map[ssa_val]) - for i, j in combinations(live_merged_ssa_vals, 2): - if i.loc_set.max_conflicts_with(j.loc_set) != 0: - interference_graph.nodes[i].add_edge( - interference_graph.nodes[j]) - if debug_out is not None: - print(f"processed {pp} out of {fn_analysis.all_program_points}", - file=debug_out, flush=True) - - if debug_out is not None: - print(f"After adding interference graph edges:\n" - f"{interference_graph}", file=debug_out, flush=True) - - nodes_remaining = OSet(interference_graph.nodes.values()) - - def local_colorability_score(node): - # type: (IGNode) -> int - """ returns a positive integer if node is locally colorable, returns - zero or a negative integer if node isn't known to be locally - colorable, the more negative the value, the less colorable - """ - if node not in nodes_remaining: - raise ValueError() - retval = len(node.loc_set) - for neighbor in node.edges: - if neighbor in nodes_remaining: - retval -= node.loc_set.max_conflicts_with(neighbor.loc_set) - return retval - - # TODO: implement copy-merging - - node_stack = [] # type: list[IGNode] - while True: - best_node = None # type: None | IGNode - best_score = 0 - for node in nodes_remaining: - score = local_colorability_score(node) - if best_node is None or score > best_score: - best_node = node - best_score = score - if best_score > 0: - # it's locally colorable, no need to find a better one - break - - if best_node is None: - break - node_stack.append(best_node) - nodes_remaining.remove(best_node) - - if debug_out is not None: - print(f"After deciding node allocation order:\n" - f"{node_stack}", file=debug_out, flush=True) - - retval = {} # type: dict[SSAVal, Loc] - - while len(node_stack) > 0: - node = node_stack.pop() - if node.loc is not None: - if node.loc_conflicts_with_neighbors(node.loc): - raise AllocationFailedError( - "IGNode is pre-allocated to a conflicting Loc", - node=node, interference_graph=interference_graph) - else: - # pick the first non-conflicting register in node.reg_class, since - # register classes are ordered from most preferred to least - # preferred register. 
-            for loc in node.loc_set:
-                if not node.loc_conflicts_with_neighbors(loc):
-                    node.loc = loc
-                    break
-            if node.loc is None:
-                raise AllocationFailedError(
-                    "failed to allocate Loc for IGNode",
-                    node=node, interference_graph=interference_graph)
-
-        if debug_out is not None:
-            print(f"After allocating Loc for node:\n{node}",
-                  file=debug_out, flush=True)
-
-        for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
-            retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)
-
-    if debug_out is not None:
-        print(f"final Locs for all SSAVals:\n{retval}",
-              file=debug_out, flush=True)
-
-    return retval
diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py
index aa18967..7e9748c 100644
--- a/src/bigint_presentation_code/toom_cook.py
+++ b/src/bigint_presentation_code/toom_cook.py
@@ -8,7 +8,7 @@ from typing import Any, Generic, Iterable, Mapping, TypeVar, Union
 
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir2 import (Fn, OpKind, SSAVal)
+from bigint_presentation_code.compiler_ir import Fn, OpKind, SSAVal
 from bigint_presentation_code.matrix import Matrix
 from bigint_presentation_code.type_util import Literal, final
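
For readers trying out the renamed modules, here is a minimal usage sketch (not part of the patch): it rebuilds the same small SVP64 add function as make_add_fn() in the new test_compiler_ir.py and feeds it to allocate_registers(). The module paths, op kinds, and the debug_out argument are taken from the code above; everything else is illustrative.

```
# usage sketch: build the SVP64 bigint-add function and allocate registers
import sys

from bigint_presentation_code.compiler_ir import Fn, OpKind
from bigint_presentation_code.register_allocator import allocate_registers

fn = Fn()
MAXVL = 32
arg = fn.append_new_op(OpKind.FuncArgR3, name="arg").outputs[0]
vl = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl").outputs[0]
a = fn.append_new_op(OpKind.SvLd, input_vals=[arg, vl], immediates=[0],
                     maxvl=MAXVL, name="ld").outputs[0]
b = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
                     maxvl=MAXVL, name="li").outputs[0]
ca = fn.append_new_op(OpKind.SetCA, name="ca").outputs[0]
s = fn.append_new_op(OpKind.SvAddE, input_vals=[a, b, ca, vl],
                     maxvl=MAXVL, name="add").outputs[0]
fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl], immediates=[0],
                 maxvl=MAXVL, name="st")

# allocate_registers() mutates fn (pre_ra_insert_copies) and returns a dict
# mapping every SSAVal to the Loc it was assigned
ssa_val_to_loc = allocate_registers(fn, debug_out=sys.stderr)
for ssa_val, loc in ssa_val_to_loc.items():
    print(ssa_val, "->", loc)
```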
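The allocator itself is a simplify/select graph colorer generalized to irregular register sets through LocSet conflict counts. As a rough mental model (again not part of the patch, and a simplification that assumes a fixed pool of interchangeable colors, which is what local_colorability_score() reduces to when every Loc conflicts only with itself), the same loop on an ordinary undirected graph looks like this:

```
NUM_COLORS = 4


def color_graph(edges):
    # type: (dict[str, set[str]]) -> dict[str, int]
    """simplify: repeatedly push the node with the best colorability score
    (NUM_COLORS minus its still-remaining degree) onto a stack;
    select: pop nodes and give each the first color unused by its neighbors
    """
    remaining = set(edges)
    stack = []  # type: list[str]
    while remaining:
        best = max(remaining,
                   key=lambda n: NUM_COLORS - len(edges[n] & remaining))
        stack.append(best)
        remaining.remove(best)
    colors = {}  # type: dict[str, int]
    while stack:
        node = stack.pop()
        used = {colors[n] for n in edges[node] if n in colors}
        free = [c for c in range(NUM_COLORS) if c not in used]
        if not free:
            raise ValueError(f"would need to spill {node}")
        colors[node] = free[0]
    return colors


# three mutually-interfering values still fit in four registers
print(color_graph({"a": {"b", "c"}, "b": {"a", "c"}, "c": {"a", "b"}}))
```

The real allocator never hits the "spill" branch because pre_ra_insert_copies() lets stack slots appear in a node's LocSet, so spilling falls out of ordinary coloring.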