--- /dev/null
+import unittest
+
+from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, BaseTy,
+ Fn, FnAnalysis, GenAsmState,
+ Loc, LocKind, OpKind,
+ OpStage, PreRASimState,
+ ProgramPoint, SSAVal, Ty)
+
+
+class TestCompilerIR(unittest.TestCase):
+ maxDiff = None
+
+ def test_program_point(self):
+ # type: () -> None
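+ # ProgramPoint.next() should walk the points in program order: Early,
+ # then Late within an op, then on to the next op's Early. Sorting the
+ # points must reproduce that same order.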
+ expected = [] # type: list[ProgramPoint]
+ for op_index in range(5):
+ for stage in OpStage:
+ expected.append(ProgramPoint(op_index=op_index, stage=stage))
+
+ for idx, pp in enumerate(expected):
+ if idx + 1 < len(expected):
+ self.assertEqual(pp.next(), expected[idx + 1])
+
+ self.assertEqual(sorted(expected), expected)
+
+ def make_add_fn(self):
+ # type: () -> tuple[Fn, SSAVal]
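+ # Builds the function used by the tests below: load a 32-element
+ # bigint from the address in r3 (FuncArgR3), add an all-zero vector
+ # to it with the carry flag set first (SetCA, SvAddE), and store the
+ # sum back to the same address (SvStd). Returns the Fn and the
+ # address argument's SSAVal.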
+ fn = Fn()
+ op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg")
+ arg = op0.outputs[0]
+ MAXVL = 32
+ op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
+ vl = op1.outputs[0]
+ op2 = fn.append_new_op(
+ OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
+ name="ld")
+ a = op2.outputs[0]
+ op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ maxvl=MAXVL, name="li")
+ b = op3.outputs[0]
+ op4 = fn.append_new_op(OpKind.SetCA, name="ca")
+ ca = op4.outputs[0]
+ op5 = fn.append_new_op(
+ OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
+ s = op5.outputs[0]
+ _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+ immediates=[0], maxvl=MAXVL, name="st")
+ return fn, arg
+
+ def test_fn_analysis(self):
+ fn, _arg = self.make_add_fn()
+ fn_analysis = FnAnalysis(fn)
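+ # The analysis results are checked through their reprs: use sets per
+ # SSA value, op indexes, live ranges, liveness at each program point,
+ # definition ranges, and use points.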
+ self.assertEqual(
+ repr(fn_analysis.uses),
+ "FMap({"
+ "<arg.outputs[0]: <I64>>: OFSet(["
+ "<ld.input_uses[0]: <I64>>, <st.input_uses[1]: <I64>>]), "
+ "<vl.outputs[0]: <VL_MAXVL>>: OFSet(["
+ "<ld.input_uses[1]: <VL_MAXVL>>, <li.input_uses[0]: <VL_MAXVL>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>, "
+ "<st.input_uses[2]: <VL_MAXVL>>]), "
+ "<ld.outputs[0]: <I64*32>>: OFSet(["
+ "<add.input_uses[0]: <I64*32>>]), "
+ "<li.outputs[0]: <I64*32>>: OFSet(["
+ "<add.input_uses[1]: <I64*32>>]), "
+ "<ca.outputs[0]: <CA>>: OFSet([<add.input_uses[2]: <CA>>]), "
+ "<add.outputs[0]: <I64*32>>: OFSet(["
+ "<st.input_uses[0]: <I64*32>>]), "
+ "<add.outputs[1]: <CA>>: OFSet()})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.op_indexes),
+ "FMap({"
+ "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), "
+ "immediates=[], outputs=(<arg.outputs[0]: <I64>>,), "
+ "name='arg'): 0, "
+ "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), "
+ "immediates=[32], outputs=(<vl.outputs[0]: <VL_MAXVL>>,), "
+ "name='vl'): 1, "
+ "Op(kind=OpKind.SvLd, input_vals=["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.input_uses[0]: <I64>>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>), immediates=[0], "
+ "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld'): 2, "
+ "Op(kind=OpKind.SvLI, input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), immediates=[0], "
+ "outputs=(<li.outputs[0]: <I64*32>>,), name='li'): 3, "
+ "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), "
+ "immediates=[], outputs=(<ca.outputs[0]: <CA>>,), name='ca'): 4, "
+ "Op(kind=OpKind.SvAddE, input_vals=["
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.input_uses[0]: <I64*32>>, "
+ "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>), immediates=[], outputs=("
+ "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+ "name='add'): 5, "
+ "Op(kind=OpKind.SvStd, input_vals=["
+ "<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.input_uses[0]: <I64*32>>, "
+ "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
+ "immediates=[0], outputs=(), name='st'): 6})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.live_ranges),
+ "FMap({"
+ "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[6]:Late>, "
+ "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[6]:Late>, "
+ "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[5]:Late>, "
+ "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[5]:Late>, "
+ "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Late>, "
+ "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Late>, "
+ "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.live_at),
+ "FMap({"
+ "<ops[0]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
+ "<ops[0]:Late>: OFSet([<arg.outputs[0]: <I64>>]), "
+ "<ops[1]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
+ "<ops[1]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>]), "
+ "<ops[2]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>]), "
+ "<ops[2]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>]), "
+ "<ops[3]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+ "<ops[3]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+ "<ops[4]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+ "<ops[4]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>]), "
+ "<ops[5]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>, <add.outputs[0]: <I64*32>>, "
+ "<add.outputs[1]: <CA>>]), "
+ "<ops[5]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>]), "
+ "<ops[6]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<add.outputs[0]: <I64*32>>]), "
+ "<ops[6]:Late>: OFSet()})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.def_program_ranges),
+ "FMap({"
+ "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[1]:Early>, "
+ "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[2]:Early>, "
+ "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[3]:Early>, "
+ "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[4]:Early>, "
+ "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Early>, "
+ "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Early>, "
+ "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.use_program_points),
+ "FMap({"
+ "<ld.input_uses[0]: <I64>>: <ops[2]:Early>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>: <ops[2]:Early>, "
+ "<li.input_uses[0]: <VL_MAXVL>>: <ops[3]:Early>, "
+ "<add.input_uses[0]: <I64*32>>: <ops[5]:Early>, "
+ "<add.input_uses[1]: <I64*32>>: <ops[5]:Early>, "
+ "<add.input_uses[2]: <CA>>: <ops[5]:Early>, "
+ "<add.input_uses[3]: <VL_MAXVL>>: <ops[5]:Early>, "
+ "<st.input_uses[0]: <I64*32>>: <ops[6]:Early>, "
+ "<st.input_uses[1]: <I64>>: <ops[6]:Early>, "
+ "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.all_program_points),
+ "<range:ops[0]:Early..ops[7]:Early>")
+
+ def test_repr(self):
+ fn, _arg = self.make_add_fn()
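+ # Check the repr of every op and of its OpProperties (operand
+ # location sets, tied inputs, spread indexes, and write stages).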
+ self.assertEqual([repr(i) for i in fn.ops], [
+ "Op(kind=OpKind.FuncArgR3, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[], "
+ "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
+ "Op(kind=OpKind.SvLd, "
+ "input_vals=[<arg.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.input_uses[0]: <I64>>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[0], "
+ "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
+ "Op(kind=OpKind.SvLI, "
+ "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
+ "immediates=[0], "
+ "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
+ "Op(kind=OpKind.SetCA, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[], "
+ "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
+ "Op(kind=OpKind.SvAddE, "
+ "input_vals=[<ld.outputs[0]: <I64*32>>, "
+ "<li.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.input_uses[0]: <I64*32>>, "
+ "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+ "name='add')",
+ "Op(kind=OpKind.SvStd, "
+ "input_vals=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.input_uses[0]: <I64*32>>, "
+ "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
+ "immediates=[0], "
+ "outputs=(), name='st')",
+ ])
+ self.assertEqual([repr(op.properties) for op in fn.ops], [
+ "OpProperties(kind=OpKind.FuncArgR3, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=1)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SvLd, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
+ "ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
+ "OpProperties(kind=OpKind.SvLI, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetCA, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.CA: FBitSet([0])}), ty=<CA>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SvAddE, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.CA: FBitSet([0])}), ty=<CA>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.CA: FBitSet([0])}), ty=<CA>), "
+ "tied_input_index=2, spread_index=None, "
+ "write_stage=OpStage.Early)), maxvl=32)",
+ "OpProperties(kind=OpKind.SvStd, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
+ "ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=(), maxvl=32)",
+ ])
+
+ def test_pre_ra_insert_copies(self):
+ fn, _arg = self.make_add_fn()
+ fn.pre_ra_insert_copies()
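+ # After copy insertion, each operand of the original ops is reached
+ # through its own copy op (CopyFromReg/CopyToReg or the Vec* forms,
+ # each vector copy paired with its own SetVLI), as the expected reprs
+ # below show.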
+ self.assertEqual([repr(i) for i in fn.ops], [
+ "Op(kind=OpKind.FuncArgR3, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[], "
+ "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
+ "Op(kind=OpKind.CopyFromReg, "
+ "input_vals=[<arg.outputs[0]: <I64>>], "
+ "input_uses=(<arg.out0.copy.input_uses[0]: <I64>>,), "
+ "immediates=[], "
+ "outputs=(<arg.out0.copy.outputs[0]: <I64>>,), "
+ "name='arg.out0.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
+ "Op(kind=OpKind.CopyToReg, "
+ "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
+ "input_uses=(<ld.inp0.copy.input_uses[0]: <I64>>,), "
+ "immediates=[], "
+ "outputs=(<ld.inp0.copy.outputs[0]: <I64>>,), name='ld.inp0.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='ld.inp1.setvl')",
+ "Op(kind=OpKind.SvLd, "
+ "input_vals=[<ld.inp0.copy.outputs[0]: <I64>>, "
+ "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.input_uses[0]: <I64>>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[0], "
+ "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<ld.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='ld.out0.setvl')",
+ "Op(kind=OpKind.VecCopyFromReg, "
+ "input_vals=[<ld.outputs[0]: <I64*32>>, "
+ "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.out0.copy.input_uses[0]: <I64*32>>, "
+ "<ld.out0.copy.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<ld.out0.copy.outputs[0]: <I64*32>>,), "
+ "name='ld.out0.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<li.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='li.inp0.setvl')",
+ "Op(kind=OpKind.SvLI, "
+ "input_vals=[<li.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
+ "immediates=[0], "
+ "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<li.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='li.out0.setvl')",
+ "Op(kind=OpKind.VecCopyFromReg, "
+ "input_vals=[<li.outputs[0]: <I64*32>>, "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.out0.copy.input_uses[0]: <I64*32>>, "
+ "<li.out0.copy.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<li.out0.copy.outputs[0]: <I64*32>>,), "
+ "name='li.out0.copy')",
+ "Op(kind=OpKind.SetCA, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[], "
+ "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<add.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.inp0.setvl')",
+ "Op(kind=OpKind.VecCopyToReg, "
+ "input_vals=[<ld.out0.copy.outputs[0]: <I64*32>>, "
+ "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.inp0.copy.input_uses[0]: <I64*32>>, "
+ "<add.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<add.inp0.copy.outputs[0]: <I64*32>>,), "
+ "name='add.inp0.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<add.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.inp1.setvl')",
+ "Op(kind=OpKind.VecCopyToReg, "
+ "input_vals=[<li.out0.copy.outputs[0]: <I64*32>>, "
+ "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.inp1.copy.input_uses[0]: <I64*32>>, "
+ "<add.inp1.copy.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<add.inp1.copy.outputs[0]: <I64*32>>,), "
+ "name='add.inp1.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<add.inp3.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.inp3.setvl')",
+ "Op(kind=OpKind.SvAddE, "
+ "input_vals=[<add.inp0.copy.outputs[0]: <I64*32>>, "
+ "<add.inp1.copy.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
+ "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.input_uses[0]: <I64*32>>, "
+ "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+ "name='add')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<add.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.out0.setvl')",
+ "Op(kind=OpKind.VecCopyFromReg, "
+ "input_vals=[<add.outputs[0]: <I64*32>>, "
+ "<add.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.out0.copy.input_uses[0]: <I64*32>>, "
+ "<add.out0.copy.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<add.out0.copy.outputs[0]: <I64*32>>,), "
+ "name='add.out0.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<st.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='st.inp0.setvl')",
+ "Op(kind=OpKind.VecCopyToReg, "
+ "input_vals=[<add.out0.copy.outputs[0]: <I64*32>>, "
+ "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.inp0.copy.input_uses[0]: <I64*32>>, "
+ "<st.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
+ "immediates=[], "
+ "outputs=(<st.inp0.copy.outputs[0]: <I64*32>>,), "
+ "name='st.inp0.copy')",
+ "Op(kind=OpKind.CopyToReg, "
+ "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
+ "input_uses=(<st.inp1.copy.input_uses[0]: <I64>>,), "
+ "immediates=[], "
+ "outputs=(<st.inp1.copy.outputs[0]: <I64>>,), "
+ "name='st.inp1.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<st.inp2.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='st.inp2.setvl')",
+ "Op(kind=OpKind.SvStd, "
+ "input_vals=[<st.inp0.copy.outputs[0]: <I64*32>>, "
+ "<st.inp1.copy.outputs[0]: <I64>>, "
+ "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.input_uses[0]: <I64*32>>, "
+ "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
+ "immediates=[0], "
+ "outputs=(), name='st')",
+ ])
+ self.assertEqual([repr(op.properties) for op in fn.ops], [
+ "OpProperties(kind=OpKind.FuncArgR3, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=1)",
+ "OpProperties(kind=OpKind.CopyFromReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
+ "ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
+ "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.CopyToReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
+ "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
+ "ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SvLd, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
+ "ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.VecCopyFromReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97)), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SvLI, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.VecCopyFromReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97)), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetCA, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.CA: FBitSet([0])}), ty=<CA>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.VecCopyToReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97)), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.VecCopyToReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97)), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SvAddE, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.CA: FBitSet([0])}), ty=<CA>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.CA: FBitSet([0])}), ty=<CA>), "
+ "tied_input_index=2, spread_index=None, "
+ "write_stage=OpStage.Early)), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.VecCopyFromReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97)), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.VecCopyToReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97)), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.CopyToReg, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
+ "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
+ "ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SvStd, "
+ "inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
+ "ty=<I64>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
+ "outputs=(), maxvl=32)",
+ ])
+
+ def test_sim(self):
+ fn, arg = self.make_add_fn()
+ addr = 0x100
+ state = PreRASimState(ssa_vals={arg: (addr,)}, memory={})
+ state.store(addr=addr, value=0xffffffff_ffffffff,
+ size_in_bytes=GPR_SIZE_IN_BYTES)
+ state.store(addr=addr + GPR_SIZE_IN_BYTES, value=0xabcdef01_23456789,
+ size_in_bytes=GPR_SIZE_IN_BYTES)
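+ # Seed two nonzero words at the argument address; the remaining 30
+ # words of the 32-word operand are loaded as zero (see ld.outputs in
+ # the post-sim state below).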
+ self.assertEqual(
+ repr(state),
+ "PreRASimState(memory={\n"
+ "0x00100: <0xffffffffffffffff>,\n"
+ "0x00108: <0xabcdef0123456789>}, "
+ "ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)})")
+ fn.sim(state)
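+ # Adding the all-zero vector with carry-in set increments the bigint:
+ # 0xffff...ffff wraps to zero and the carry ripples into the next
+ # word (...6789 becomes ...678a).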
+ self.assertEqual(
+ repr(state),
+ "PreRASimState(memory={\n"
+ "0x00100: <0x0000000000000000>,\n"
+ "0x00108: <0xabcdef012345678a>,\n"
+ "0x00110: <0x0000000000000000>,\n"
+ "0x00118: <0x0000000000000000>,\n"
+ "0x00120: <0x0000000000000000>,\n"
+ "0x00128: <0x0000000000000000>,\n"
+ "0x00130: <0x0000000000000000>,\n"
+ "0x00138: <0x0000000000000000>,\n"
+ "0x00140: <0x0000000000000000>,\n"
+ "0x00148: <0x0000000000000000>,\n"
+ "0x00150: <0x0000000000000000>,\n"
+ "0x00158: <0x0000000000000000>,\n"
+ "0x00160: <0x0000000000000000>,\n"
+ "0x00168: <0x0000000000000000>,\n"
+ "0x00170: <0x0000000000000000>,\n"
+ "0x00178: <0x0000000000000000>,\n"
+ "0x00180: <0x0000000000000000>,\n"
+ "0x00188: <0x0000000000000000>,\n"
+ "0x00190: <0x0000000000000000>,\n"
+ "0x00198: <0x0000000000000000>,\n"
+ "0x001a0: <0x0000000000000000>,\n"
+ "0x001a8: <0x0000000000000000>,\n"
+ "0x001b0: <0x0000000000000000>,\n"
+ "0x001b8: <0x0000000000000000>,\n"
+ "0x001c0: <0x0000000000000000>,\n"
+ "0x001c8: <0x0000000000000000>,\n"
+ "0x001d0: <0x0000000000000000>,\n"
+ "0x001d8: <0x0000000000000000>,\n"
+ "0x001e0: <0x0000000000000000>,\n"
+ "0x001e8: <0x0000000000000000>,\n"
+ "0x001f0: <0x0000000000000000>,\n"
+ "0x001f8: <0x0000000000000000>}, ssa_vals={\n"
+ "<arg.outputs[0]: <I64>>: (0x100,),\n"
+ "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
+ "<ld.outputs[0]: <I64*32>>: (\n"
+ " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0),\n"
+ "<li.outputs[0]: <I64*32>>: (\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0),\n"
+ "<ca.outputs[0]: <CA>>: (0x1,),\n"
+ "<add.outputs[0]: <I64*32>>: (\n"
+ " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0),\n"
+ "<add.outputs[1]: <CA>>: (0x0,),\n"
+ "})")
+
+ def test_gen_asm(self):
+ fn, _arg = self.make_add_fn()
+ fn.pre_ra_insert_copies()
+ VL_LOC = Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)
+ CA_LOC = Loc(kind=LocKind.CA, start=0, reg_len=1)
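+ # No register allocator is involved: every SSA value is assigned a
+ # location by hand (r3 for the scalars, 32-register vectors starting
+ # at r32 and r64, plus the VL and CA locations).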
+ state = GenAsmState(allocated_locs={
+ fn.ops[0].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[1].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[2].outputs[0]: VL_LOC,
+ fn.ops[3].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[4].outputs[0]: VL_LOC,
+ fn.ops[5].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[6].outputs[0]: VL_LOC,
+ fn.ops[7].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[8].outputs[0]: VL_LOC,
+ fn.ops[9].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+ fn.ops[10].outputs[0]: VL_LOC,
+ fn.ops[11].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+ fn.ops[12].outputs[0]: CA_LOC,
+ fn.ops[13].outputs[0]: VL_LOC,
+ fn.ops[14].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[15].outputs[0]: VL_LOC,
+ fn.ops[16].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+ fn.ops[17].outputs[0]: VL_LOC,
+ fn.ops[18].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[18].outputs[1]: CA_LOC,
+ fn.ops[19].outputs[0]: VL_LOC,
+ fn.ops[20].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[21].outputs[0]: VL_LOC,
+ fn.ops[22].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[23].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[24].outputs[0]: VL_LOC,
+ })
+ fn.gen_asm(state)
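+ # Expected output is the lowering of the copy-inserted function:
+ # setvl for each vector-length op, sv.ld/sv.std for the load and
+ # store, sv.addi for the zero vector, subfc to set CA, and sv.adde
+ # for the wide add.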
+ self.assertEqual(state.output, [
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.ld *32, 0(3)',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.addi *64, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'subfc 0, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.adde *32, *32, *64',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.std *32, 0(3)',
+ ])
+
+ def test_spread(self):
+ fn = Fn()
+ maxvl = 4
+ vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
+ name="vl", maxvl=maxvl).outputs[0]
+ li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ name="li", maxvl=maxvl).outputs[0]
+ spread_op = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
+ name="spread", maxvl=maxvl)
+ self.assertEqual(spread_op.outputs[0].ty_before_spread,
+ Ty(base_ty=BaseTy.I64, reg_len=maxvl))
+ _concat = fn.append_new_op(
+ OpKind.Concat, input_vals=[*spread_op.outputs[::-1], vl],
+ name="concat", maxvl=maxvl)
+ self.assertEqual([repr(op.properties) for op in fn.ops], [
+ "OpProperties(kind=OpKind.SetVLI, inputs=("
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),"
+ "), maxvl=4)",
+ "OpProperties(kind=OpKind.SvLI, inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),"
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),"
+ "), maxvl=4)",
+ "OpProperties(kind=OpKind.Spread, inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)"
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=0, "
+ "write_stage=OpStage.Late), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=1, "
+ "write_stage=OpStage.Late), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=2, "
+ "write_stage=OpStage.Late), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=3, "
+ "write_stage=OpStage.Late)"
+ "), maxvl=4)",
+ "OpProperties(kind=OpKind.Concat, inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=0, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=1, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=2, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=3, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)"
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),"
+ "), maxvl=4)",
+ ])
+ self.assertEqual([repr(op) for op in fn.ops], [
+ "Op(kind=OpKind.SetVLI, input_vals=["
+ "], input_uses=("
+ "), immediates=[4], outputs=("
+ "<vl.outputs[0]: <VL_MAXVL>>,"
+ "), name='vl')",
+ "Op(kind=OpKind.SvLI, input_vals=["
+ "<vl.outputs[0]: <VL_MAXVL>>"
+ "], input_uses=("
+ "<li.input_uses[0]: <VL_MAXVL>>,"
+ "), immediates=[0], outputs=("
+ "<li.outputs[0]: <I64*4>>,"
+ "), name='li')",
+ "Op(kind=OpKind.Spread, input_vals=["
+ "<li.outputs[0]: <I64*4>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>"
+ "], input_uses=("
+ "<spread.input_uses[0]: <I64*4>>, "
+ "<spread.input_uses[1]: <VL_MAXVL>>"
+ "), immediates=[], outputs=("
+ "<spread.outputs[0]: <I64>>, "
+ "<spread.outputs[1]: <I64>>, "
+ "<spread.outputs[2]: <I64>>, "
+ "<spread.outputs[3]: <I64>>"
+ "), name='spread')",
+ "Op(kind=OpKind.Concat, input_vals=["
+ "<spread.outputs[3]: <I64>>, "
+ "<spread.outputs[2]: <I64>>, "
+ "<spread.outputs[1]: <I64>>, "
+ "<spread.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>"
+ "], input_uses=("
+ "<concat.input_uses[0]: <I64>>, "
+ "<concat.input_uses[1]: <I64>>, "
+ "<concat.input_uses[2]: <I64>>, "
+ "<concat.input_uses[3]: <I64>>, "
+ "<concat.input_uses[4]: <VL_MAXVL>>"
+ "), immediates=[], outputs=("
+ "<concat.outputs[0]: <I64*4>>,"
+ "), name='concat')",
+ ])
+
+
+if __name__ == "__main__":
+ _ = unittest.main()
+++ /dev/null
-import unittest
-
-from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, BaseTy, Fn,
- FnAnalysis, GenAsmState, Loc, LocKind, OpKind, OpStage,
- PreRASimState, ProgramPoint,
- SSAVal, Ty)
-
-
-class TestCompilerIR(unittest.TestCase):
- maxDiff = None
-
- def test_program_point(self):
- # type: () -> None
- expected = [] # type: list[ProgramPoint]
- for op_index in range(5):
- for stage in OpStage:
- expected.append(ProgramPoint(op_index=op_index, stage=stage))
-
- for idx, pp in enumerate(expected):
- if idx + 1 < len(expected):
- self.assertEqual(pp.next(), expected[idx + 1])
-
- self.assertEqual(sorted(expected), expected)
-
- def make_add_fn(self):
- # type: () -> tuple[Fn, SSAVal]
- fn = Fn()
- op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg")
- arg = op0.outputs[0]
- MAXVL = 32
- op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
- vl = op1.outputs[0]
- op2 = fn.append_new_op(
- OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
- name="ld")
- a = op2.outputs[0]
- op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
- maxvl=MAXVL, name="li")
- b = op3.outputs[0]
- op4 = fn.append_new_op(OpKind.SetCA, name="ca")
- ca = op4.outputs[0]
- op5 = fn.append_new_op(
- OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
- s = op5.outputs[0]
- _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
- immediates=[0], maxvl=MAXVL, name="st")
- return fn, arg
-
- def test_fn_analysis(self):
- fn, _arg = self.make_add_fn()
- fn_analysis = FnAnalysis(fn)
- self.assertEqual(
- repr(fn_analysis.uses),
- "FMap({"
- "<arg.outputs[0]: <I64>>: OFSet(["
- "<ld.input_uses[0]: <I64>>, <st.input_uses[1]: <I64>>]), "
- "<vl.outputs[0]: <VL_MAXVL>>: OFSet(["
- "<ld.input_uses[1]: <VL_MAXVL>>, <li.input_uses[0]: <VL_MAXVL>>, "
- "<add.input_uses[3]: <VL_MAXVL>>, "
- "<st.input_uses[2]: <VL_MAXVL>>]), "
- "<ld.outputs[0]: <I64*32>>: OFSet(["
- "<add.input_uses[0]: <I64*32>>]), "
- "<li.outputs[0]: <I64*32>>: OFSet(["
- "<add.input_uses[1]: <I64*32>>]), "
- "<ca.outputs[0]: <CA>>: OFSet([<add.input_uses[2]: <CA>>]), "
- "<add.outputs[0]: <I64*32>>: OFSet(["
- "<st.input_uses[0]: <I64*32>>]), "
- "<add.outputs[1]: <CA>>: OFSet()})"
- )
- self.assertEqual(
- repr(fn_analysis.op_indexes),
- "FMap({"
- "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), "
- "immediates=[], outputs=(<arg.outputs[0]: <I64>>,), "
- "name='arg'): 0, "
- "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), "
- "immediates=[32], outputs=(<vl.outputs[0]: <VL_MAXVL>>,), "
- "name='vl'): 1, "
- "Op(kind=OpKind.SvLd, input_vals=["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<ld.input_uses[0]: <I64>>, "
- "<ld.input_uses[1]: <VL_MAXVL>>), immediates=[0], "
- "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld'): 2, "
- "Op(kind=OpKind.SvLI, input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), immediates=[0], "
- "outputs=(<li.outputs[0]: <I64*32>>,), name='li'): 3, "
- "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), "
- "immediates=[], outputs=(<ca.outputs[0]: <CA>>,), name='ca'): 4, "
- "Op(kind=OpKind.SvAddE, input_vals=["
- "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
- "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<add.input_uses[0]: <I64*32>>, "
- "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
- "<add.input_uses[3]: <VL_MAXVL>>), immediates=[], outputs=("
- "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
- "name='add'): 5, "
- "Op(kind=OpKind.SvStd, input_vals=["
- "<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
- "<vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<st.input_uses[0]: <I64*32>>, "
- "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
- "immediates=[0], outputs=(), name='st'): 6})"
- )
- self.assertEqual(
- repr(fn_analysis.live_ranges),
- "FMap({"
- "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[6]:Late>, "
- "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[6]:Late>, "
- "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[5]:Late>, "
- "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[5]:Late>, "
- "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Late>, "
- "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Late>, "
- "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
- )
- self.assertEqual(
- repr(fn_analysis.live_at),
- "FMap({"
- "<ops[0]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
- "<ops[0]:Late>: OFSet([<arg.outputs[0]: <I64>>]), "
- "<ops[1]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
- "<ops[1]:Late>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>]), "
- "<ops[2]:Early>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<ld.outputs[0]: <I64*32>>]), "
- "<ops[2]:Late>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<ld.outputs[0]: <I64*32>>]), "
- "<ops[3]:Early>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
- "<ops[3]:Late>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
- "<ops[4]:Early>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
- "<ops[4]:Late>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
- "<ca.outputs[0]: <CA>>]), "
- "<ops[5]:Early>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
- "<ca.outputs[0]: <CA>>, <add.outputs[0]: <I64*32>>, "
- "<add.outputs[1]: <CA>>]), "
- "<ops[5]:Late>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>]), "
- "<ops[6]:Early>: OFSet(["
- "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
- "<add.outputs[0]: <I64*32>>]), "
- "<ops[6]:Late>: OFSet()})"
- )
- self.assertEqual(
- repr(fn_analysis.def_program_ranges),
- "FMap({"
- "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[1]:Early>, "
- "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[2]:Early>, "
- "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[3]:Early>, "
- "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[4]:Early>, "
- "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Early>, "
- "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Early>, "
- "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
- )
- self.assertEqual(
- repr(fn_analysis.use_program_points),
- "FMap({"
- "<ld.input_uses[0]: <I64>>: <ops[2]:Early>, "
- "<ld.input_uses[1]: <VL_MAXVL>>: <ops[2]:Early>, "
- "<li.input_uses[0]: <VL_MAXVL>>: <ops[3]:Early>, "
- "<add.input_uses[0]: <I64*32>>: <ops[5]:Early>, "
- "<add.input_uses[1]: <I64*32>>: <ops[5]:Early>, "
- "<add.input_uses[2]: <CA>>: <ops[5]:Early>, "
- "<add.input_uses[3]: <VL_MAXVL>>: <ops[5]:Early>, "
- "<st.input_uses[0]: <I64*32>>: <ops[6]:Early>, "
- "<st.input_uses[1]: <I64>>: <ops[6]:Early>, "
- "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>})"
- )
- self.assertEqual(
- repr(fn_analysis.all_program_points),
- "<range:ops[0]:Early..ops[7]:Early>")
-
- def test_repr(self):
- fn, _arg = self.make_add_fn()
- self.assertEqual([repr(i) for i in fn.ops], [
- "Op(kind=OpKind.FuncArgR3, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[], "
- "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
- "Op(kind=OpKind.SvLd, "
- "input_vals=[<arg.outputs[0]: <I64>>, "
- "<vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<ld.input_uses[0]: <I64>>, "
- "<ld.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[0], "
- "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
- "Op(kind=OpKind.SvLI, "
- "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
- "immediates=[0], "
- "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
- "Op(kind=OpKind.SetCA, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[], "
- "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
- "Op(kind=OpKind.SvAddE, "
- "input_vals=[<ld.outputs[0]: <I64*32>>, "
- "<li.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
- "<vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<add.input_uses[0]: <I64*32>>, "
- "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
- "<add.input_uses[3]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
- "name='add')",
- "Op(kind=OpKind.SvStd, "
- "input_vals=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
- "<vl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<st.input_uses[0]: <I64*32>>, "
- "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
- "immediates=[0], "
- "outputs=(), name='st')",
- ])
- self.assertEqual([repr(op.properties) for op in fn.ops], [
- "OpProperties(kind=OpKind.FuncArgR3, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), maxvl=1)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SvLd, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
- "ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), maxvl=32)",
- "OpProperties(kind=OpKind.SvLI, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), maxvl=32)",
- "OpProperties(kind=OpKind.SetCA, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SvAddE, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=2, spread_index=None, "
- "write_stage=OpStage.Early)), maxvl=32)",
- "OpProperties(kind=OpKind.SvStd, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
- "ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=(), maxvl=32)",
- ])
-
- def test_pre_ra_insert_copies(self):
- fn, _arg = self.make_add_fn()
- fn.pre_ra_insert_copies()
- self.assertEqual([repr(i) for i in fn.ops], [
- "Op(kind=OpKind.FuncArgR3, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[], "
- "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
- "Op(kind=OpKind.CopyFromReg, "
- "input_vals=[<arg.outputs[0]: <I64>>], "
- "input_uses=(<arg.out0.copy.input_uses[0]: <I64>>,), "
- "immediates=[], "
- "outputs=(<arg.out0.copy.outputs[0]: <I64>>,), "
- "name='arg.out0.copy')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
- "Op(kind=OpKind.CopyToReg, "
- "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
- "input_uses=(<ld.inp0.copy.input_uses[0]: <I64>>,), "
- "immediates=[], "
- "outputs=(<ld.inp0.copy.outputs[0]: <I64>>,), name='ld.inp0.copy')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='ld.inp1.setvl')",
- "Op(kind=OpKind.SvLd, "
- "input_vals=[<ld.inp0.copy.outputs[0]: <I64>>, "
- "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<ld.input_uses[0]: <I64>>, "
- "<ld.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[0], "
- "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<ld.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='ld.out0.setvl')",
- "Op(kind=OpKind.VecCopyFromReg, "
- "input_vals=[<ld.outputs[0]: <I64*32>>, "
- "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<ld.out0.copy.input_uses[0]: <I64*32>>, "
- "<ld.out0.copy.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<ld.out0.copy.outputs[0]: <I64*32>>,), "
- "name='ld.out0.copy')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<li.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='li.inp0.setvl')",
- "Op(kind=OpKind.SvLI, "
- "input_vals=[<li.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
- "immediates=[0], "
- "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<li.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='li.out0.setvl')",
- "Op(kind=OpKind.VecCopyFromReg, "
- "input_vals=[<li.outputs[0]: <I64*32>>, "
- "<li.out0.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<li.out0.copy.input_uses[0]: <I64*32>>, "
- "<li.out0.copy.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<li.out0.copy.outputs[0]: <I64*32>>,), "
- "name='li.out0.copy')",
- "Op(kind=OpKind.SetCA, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[], "
- "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<add.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='add.inp0.setvl')",
- "Op(kind=OpKind.VecCopyToReg, "
- "input_vals=[<ld.out0.copy.outputs[0]: <I64*32>>, "
- "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<add.inp0.copy.input_uses[0]: <I64*32>>, "
- "<add.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<add.inp0.copy.outputs[0]: <I64*32>>,), "
- "name='add.inp0.copy')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<add.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='add.inp1.setvl')",
- "Op(kind=OpKind.VecCopyToReg, "
- "input_vals=[<li.out0.copy.outputs[0]: <I64*32>>, "
- "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<add.inp1.copy.input_uses[0]: <I64*32>>, "
- "<add.inp1.copy.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<add.inp1.copy.outputs[0]: <I64*32>>,), "
- "name='add.inp1.copy')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<add.inp3.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='add.inp3.setvl')",
- "Op(kind=OpKind.SvAddE, "
- "input_vals=[<add.inp0.copy.outputs[0]: <I64*32>>, "
- "<add.inp1.copy.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
- "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<add.input_uses[0]: <I64*32>>, "
- "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
- "<add.input_uses[3]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
- "name='add')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<add.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='add.out0.setvl')",
- "Op(kind=OpKind.VecCopyFromReg, "
- "input_vals=[<add.outputs[0]: <I64*32>>, "
- "<add.out0.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<add.out0.copy.input_uses[0]: <I64*32>>, "
- "<add.out0.copy.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<add.out0.copy.outputs[0]: <I64*32>>,), "
- "name='add.out0.copy')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<st.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='st.inp0.setvl')",
- "Op(kind=OpKind.VecCopyToReg, "
- "input_vals=[<add.out0.copy.outputs[0]: <I64*32>>, "
- "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<st.inp0.copy.input_uses[0]: <I64*32>>, "
- "<st.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
- "immediates=[], "
- "outputs=(<st.inp0.copy.outputs[0]: <I64*32>>,), "
- "name='st.inp0.copy')",
- "Op(kind=OpKind.CopyToReg, "
- "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
- "input_uses=(<st.inp1.copy.input_uses[0]: <I64>>,), "
- "immediates=[], "
- "outputs=(<st.inp1.copy.outputs[0]: <I64>>,), "
- "name='st.inp1.copy')",
- "Op(kind=OpKind.SetVLI, "
- "input_vals=[], "
- "input_uses=(), "
- "immediates=[32], "
- "outputs=(<st.inp2.setvl.outputs[0]: <VL_MAXVL>>,), "
- "name='st.inp2.setvl')",
- "Op(kind=OpKind.SvStd, "
- "input_vals=[<st.inp0.copy.outputs[0]: <I64*32>>, "
- "<st.inp1.copy.outputs[0]: <I64>>, "
- "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>], "
- "input_uses=(<st.input_uses[0]: <I64*32>>, "
- "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
- "immediates=[0], "
- "outputs=(), name='st')",
- ])
- self.assertEqual([repr(op.properties) for op in fn.ops], [
- "OpProperties(kind=OpKind.FuncArgR3, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), maxvl=1)",
- "OpProperties(kind=OpKind.CopyFromReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
- "ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
- "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.CopyToReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
- "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
- "ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SvLd, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
- "ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), maxvl=32)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.VecCopyFromReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=32)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SvLI, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), maxvl=32)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.VecCopyFromReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=32)",
- "OpProperties(kind=OpKind.SetCA, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.VecCopyToReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=32)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.VecCopyToReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=32)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SvAddE, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=2, spread_index=None, "
- "write_stage=OpStage.Early)), maxvl=32)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.VecCopyFromReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=32)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.VecCopyToReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=32)",
- "OpProperties(kind=OpKind.CopyToReg, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
- "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
- "ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SetVLI, "
- "inputs=(), "
- "outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),), maxvl=1)",
- "OpProperties(kind=OpKind.SvStd, "
- "inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
- "ty=<I64>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)), "
- "outputs=(), maxvl=32)",
- ])
-
- def test_sim(self):
- fn, arg = self.make_add_fn()
- addr = 0x100
- state = PreRASimState(ssa_vals={arg: (addr,)}, memory={})
- state.store(addr=addr, value=0xffffffff_ffffffff,
- size_in_bytes=GPR_SIZE_IN_BYTES)
- state.store(addr=addr + GPR_SIZE_IN_BYTES, value=0xabcdef01_23456789,
- size_in_bytes=GPR_SIZE_IN_BYTES)
- self.assertEqual(
- repr(state),
- "PreRASimState(memory={\n"
- "0x00100: <0xffffffffffffffff>,\n"
- "0x00108: <0xabcdef0123456789>}, "
- "ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)})")
- fn.sim(state)
- self.assertEqual(
- repr(state),
- "PreRASimState(memory={\n"
- "0x00100: <0x0000000000000000>,\n"
- "0x00108: <0xabcdef012345678a>,\n"
- "0x00110: <0x0000000000000000>,\n"
- "0x00118: <0x0000000000000000>,\n"
- "0x00120: <0x0000000000000000>,\n"
- "0x00128: <0x0000000000000000>,\n"
- "0x00130: <0x0000000000000000>,\n"
- "0x00138: <0x0000000000000000>,\n"
- "0x00140: <0x0000000000000000>,\n"
- "0x00148: <0x0000000000000000>,\n"
- "0x00150: <0x0000000000000000>,\n"
- "0x00158: <0x0000000000000000>,\n"
- "0x00160: <0x0000000000000000>,\n"
- "0x00168: <0x0000000000000000>,\n"
- "0x00170: <0x0000000000000000>,\n"
- "0x00178: <0x0000000000000000>,\n"
- "0x00180: <0x0000000000000000>,\n"
- "0x00188: <0x0000000000000000>,\n"
- "0x00190: <0x0000000000000000>,\n"
- "0x00198: <0x0000000000000000>,\n"
- "0x001a0: <0x0000000000000000>,\n"
- "0x001a8: <0x0000000000000000>,\n"
- "0x001b0: <0x0000000000000000>,\n"
- "0x001b8: <0x0000000000000000>,\n"
- "0x001c0: <0x0000000000000000>,\n"
- "0x001c8: <0x0000000000000000>,\n"
- "0x001d0: <0x0000000000000000>,\n"
- "0x001d8: <0x0000000000000000>,\n"
- "0x001e0: <0x0000000000000000>,\n"
- "0x001e8: <0x0000000000000000>,\n"
- "0x001f0: <0x0000000000000000>,\n"
- "0x001f8: <0x0000000000000000>}, ssa_vals={\n"
- "<arg.outputs[0]: <I64>>: (0x100,),\n"
- "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
- "<ld.outputs[0]: <I64*32>>: (\n"
- " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0),\n"
- "<li.outputs[0]: <I64*32>>: (\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0),\n"
- "<ca.outputs[0]: <CA>>: (0x1,),\n"
- "<add.outputs[0]: <I64*32>>: (\n"
- " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0),\n"
- "<add.outputs[1]: <CA>>: (0x0,),\n"
- "})")
-
- def test_gen_asm(self):
- fn, _arg = self.make_add_fn()
- fn.pre_ra_insert_copies()
- VL_LOC = Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)
- CA_LOC = Loc(kind=LocKind.CA, start=0, reg_len=1)
- state = GenAsmState(allocated_locs={
- fn.ops[0].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
- fn.ops[1].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
- fn.ops[2].outputs[0]: VL_LOC,
- fn.ops[3].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
- fn.ops[4].outputs[0]: VL_LOC,
- fn.ops[5].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
- fn.ops[6].outputs[0]: VL_LOC,
- fn.ops[7].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
- fn.ops[8].outputs[0]: VL_LOC,
- fn.ops[9].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
- fn.ops[10].outputs[0]: VL_LOC,
- fn.ops[11].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
- fn.ops[12].outputs[0]: CA_LOC,
- fn.ops[13].outputs[0]: VL_LOC,
- fn.ops[14].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
- fn.ops[15].outputs[0]: VL_LOC,
- fn.ops[16].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
- fn.ops[17].outputs[0]: VL_LOC,
- fn.ops[18].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
- fn.ops[18].outputs[1]: CA_LOC,
- fn.ops[19].outputs[0]: VL_LOC,
- fn.ops[20].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
- fn.ops[21].outputs[0]: VL_LOC,
- fn.ops[22].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
- fn.ops[23].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
- fn.ops[24].outputs[0]: VL_LOC,
- })
- fn.gen_asm(state)
- self.assertEqual(state.output, [
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.ld *32, 0(3)',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.addi *64, 0, 0',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'subfc 0, 0, 0',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.adde *32, *32, *64',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.std *32, 0(3)',
- ])
-
- def test_spread(self):
- fn = Fn()
- maxvl = 4
- vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
- name="vl", maxvl=maxvl).outputs[0]
- li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
- name="li", maxvl=maxvl).outputs[0]
- spread_op = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
- name="spread", maxvl=maxvl)
- self.assertEqual(spread_op.outputs[0].ty_before_spread,
- Ty(base_ty=BaseTy.I64, reg_len=maxvl))
- _concat = fn.append_new_op(
- OpKind.Concat, input_vals=[*spread_op.outputs[::-1], vl],
- name="concat", maxvl=maxvl)
- self.assertEqual([repr(op.properties) for op in fn.ops], [
- "OpProperties(kind=OpKind.SetVLI, inputs=("
- "), outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), "
- "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),"
- "), maxvl=4)",
- "OpProperties(kind=OpKind.SvLI, inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), "
- "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),"
- "), outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early),"
- "), maxvl=4)",
- "OpProperties(kind=OpKind.Spread, inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), "
- "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)"
- "), outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=0, "
- "write_stage=OpStage.Late), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=1, "
- "write_stage=OpStage.Late), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=2, "
- "write_stage=OpStage.Late), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=3, "
- "write_stage=OpStage.Late)"
- "), maxvl=4)",
- "OpProperties(kind=OpKind.Concat, inputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=0, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=1, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=2, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=3, "
- "write_stage=OpStage.Early), "
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.VL_MAXVL: FBitSet([0])}), "
- "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Early)"
- "), outputs=("
- "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
- "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
- "ty=<I64*4>), tied_input_index=None, spread_index=None, "
- "write_stage=OpStage.Late),"
- "), maxvl=4)",
- ])
- self.assertEqual([repr(op) for op in fn.ops], [
- "Op(kind=OpKind.SetVLI, input_vals=["
- "], input_uses=("
- "), immediates=[4], outputs=("
- "<vl.outputs[0]: <VL_MAXVL>>,"
- "), name='vl')",
- "Op(kind=OpKind.SvLI, input_vals=["
- "<vl.outputs[0]: <VL_MAXVL>>"
- "], input_uses=("
- "<li.input_uses[0]: <VL_MAXVL>>,"
- "), immediates=[0], outputs=("
- "<li.outputs[0]: <I64*4>>,"
- "), name='li')",
- "Op(kind=OpKind.Spread, input_vals=["
- "<li.outputs[0]: <I64*4>>, "
- "<vl.outputs[0]: <VL_MAXVL>>"
- "], input_uses=("
- "<spread.input_uses[0]: <I64*4>>, "
- "<spread.input_uses[1]: <VL_MAXVL>>"
- "), immediates=[], outputs=("
- "<spread.outputs[0]: <I64>>, "
- "<spread.outputs[1]: <I64>>, "
- "<spread.outputs[2]: <I64>>, "
- "<spread.outputs[3]: <I64>>"
- "), name='spread')",
- "Op(kind=OpKind.Concat, input_vals=["
- "<spread.outputs[3]: <I64>>, "
- "<spread.outputs[2]: <I64>>, "
- "<spread.outputs[1]: <I64>>, "
- "<spread.outputs[0]: <I64>>, "
- "<vl.outputs[0]: <VL_MAXVL>>"
- "], input_uses=("
- "<concat.input_uses[0]: <I64>>, "
- "<concat.input_uses[1]: <I64>>, "
- "<concat.input_uses[2]: <I64>>, "
- "<concat.input_uses[3]: <I64>>, "
- "<concat.input_uses[4]: <VL_MAXVL>>"
- "), immediates=[], outputs=("
- "<concat.outputs[0]: <I64*4>>,"
- "), name='concat')",
- ])
-
-
-if __name__ == "__main__":
- _ = unittest.main()
--- /dev/null
+import sys
+import unittest
+
+from bigint_presentation_code.compiler_ir import (Fn, GenAsmState, OpKind,
+ SSAVal)
+from bigint_presentation_code.register_allocator import allocate_registers
+
+
+class TestRegisterAllocator(unittest.TestCase):
+ maxDiff = None
+
+ def make_add_fn(self):
+ # type: () -> tuple[Fn, SSAVal]
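+ # Builds a small vectorized add-with-carry function: load 32 doublewords
+ # from the address in the R3 argument, add an all-zero vector with the
+ # carry flag set (effectively a bigint increment), and store the result
+ # back. Returns the Fn together with the pointer argument's SSAVal.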
+ fn = Fn()
+ op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg")
+ arg = op0.outputs[0]
+ MAXVL = 32
+ op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
+ vl = op1.outputs[0]
+ op2 = fn.append_new_op(
+ OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
+ name="ld")
+ a = op2.outputs[0]
+ op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ maxvl=MAXVL, name="li")
+ b = op3.outputs[0]
+ op4 = fn.append_new_op(OpKind.SetCA, name="ca")
+ ca = op4.outputs[0]
+ op5 = fn.append_new_op(
+ OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
+ s = op5.outputs[0]
+ _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+ immediates=[0], maxvl=MAXVL, name="st")
+ return fn, arg
+
+ def test_register_allocate(self):
+ fn, _arg = self.make_add_fn()
+ reg_assignments = allocate_registers(fn)
+
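+ # The allocation is pinned down via its exact repr, so any change in the
+ # allocator's location choices or visiting order is caught here.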
+ self.assertEqual(
+ repr(reg_assignments),
+ "{<add.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.inp1.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<add.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
+ "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<st.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<st.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ca.outputs[0]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<vl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<arg.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+ )
+
+ def test_gen_asm(self):
+ fn, _arg = self.make_add_fn()
+ reg_assignments = allocate_registers(fn)
+
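+ # Same function and allocation as test_register_allocate; the assignment
+ # is re-checked so the generated assembly below is known to match it.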
+ self.assertEqual(
+ repr(reg_assignments),
+ "{<add.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.inp1.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<add.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
+ "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<st.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<st.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ca.outputs[0]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<vl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<arg.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+ )
+ state = GenAsmState(reg_assignments)
+ fn.gen_asm(state)
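+ # Register copies are emitted as 'or'/'sv.or' (copies whose source and
+ # destination coincide emit nothing), and SetCA appears as 'subfc 0, 0, 0'.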
+ self.assertEqual(state.output, [
+ 'or 4, 3, 3',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'or 3, 4, 4',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.ld *14, 0(3)',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.or *46, *14, *14',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.addi *14, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'subfc 0, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.or *78, *46, *46',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.or *46, *14, *14',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.adde *14, *78, *46',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'or 3, 4, 4',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.std *14, 0(3)',
+ ])
+
+ def test_register_allocate_spread(self):
+ fn = Fn()
+ maxvl = 32
+ vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
+ name="vl", maxvl=maxvl).outputs[0]
+ li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ name="li", maxvl=maxvl).outputs[0]
+ spread = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
+ name="spread", maxvl=maxvl).outputs
+ _concat = fn.append_new_op(
+ OpKind.Concat, input_vals=[*spread[::-1], vl],
+ name="concat", maxvl=maxvl)
+ reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+
+ self.assertEqual(
+ repr(reg_assignments),
+ "{<concat.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<concat.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<concat.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+ "<concat.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+ "<concat.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+ "<concat.inp3.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+ "<concat.inp4.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+ "<concat.inp5.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+ "<concat.inp6.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
+ "<concat.inp7.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
+ "<concat.inp8.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
+ "<concat.inp9.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+ "<concat.inp10.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
+ "<concat.inp11.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
+ "<concat.inp12.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
+ "<concat.inp13.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
+ "<concat.inp14.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
+ "<concat.inp15.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
+ "<concat.inp16.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
+ "<concat.inp17.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
+ "<concat.inp18.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
+ "<concat.inp19.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
+ "<concat.inp20.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
+ "<concat.inp21.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
+ "<concat.inp22.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
+ "<concat.inp23.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
+ "<concat.inp24.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
+ "<concat.inp25.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
+ "<concat.inp26.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
+ "<concat.inp27.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
+ "<concat.inp28.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
+ "<concat.inp29.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
+ "<concat.inp30.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
+ "<concat.inp31.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
+ "<concat.inp32.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<spread.out31.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<spread.out30.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<spread.out29.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<spread.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+ "<spread.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+ "<spread.outputs[2]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+ "<spread.outputs[3]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+ "<spread.outputs[4]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+ "<spread.outputs[5]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+ "<spread.outputs[6]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
+ "<spread.outputs[7]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
+ "<spread.outputs[8]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
+ "<spread.outputs[9]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+ "<spread.outputs[10]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
+ "<spread.outputs[11]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
+ "<spread.outputs[12]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
+ "<spread.outputs[13]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
+ "<spread.outputs[14]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
+ "<spread.outputs[15]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
+ "<spread.outputs[16]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
+ "<spread.outputs[17]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
+ "<spread.outputs[18]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
+ "<spread.outputs[19]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
+ "<spread.outputs[20]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
+ "<spread.outputs[21]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
+ "<spread.outputs[22]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
+ "<spread.outputs[23]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
+ "<spread.outputs[24]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
+ "<spread.outputs[25]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
+ "<spread.outputs[26]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
+ "<spread.outputs[27]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
+ "<spread.outputs[28]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
+ "<spread.outputs[29]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
+ "<spread.outputs[30]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
+ "<spread.outputs[31]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
+ "<spread.out28.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<spread.out27.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+ "<spread.out26.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
+ "<spread.out25.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
+ "<spread.out24.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
+ "<spread.out23.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+ "<spread.out22.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+ "<spread.out21.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=1), "
+ "<spread.out20.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=47, reg_len=1), "
+ "<spread.out19.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=48, reg_len=1), "
+ "<spread.out18.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=49, reg_len=1), "
+ "<spread.out17.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=50, reg_len=1), "
+ "<spread.out16.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=51, reg_len=1), "
+ "<spread.out15.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=52, reg_len=1), "
+ "<spread.out14.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=53, reg_len=1), "
+ "<spread.out13.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=54, reg_len=1), "
+ "<spread.out12.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=55, reg_len=1), "
+ "<spread.out11.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=56, reg_len=1), "
+ "<spread.out10.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=57, reg_len=1), "
+ "<spread.out9.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=58, reg_len=1), "
+ "<spread.out8.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=59, reg_len=1), "
+ "<spread.out7.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=60, reg_len=1), "
+ "<spread.out6.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=61, reg_len=1), "
+ "<spread.out5.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=62, reg_len=1), "
+ "<spread.out4.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=63, reg_len=1), "
+ "<spread.out3.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=64, reg_len=1), "
+ "<spread.out2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=65, reg_len=1), "
+ "<spread.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=66, reg_len=1), "
+ "<spread.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=67, reg_len=1), "
+ "<spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<spread.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<vl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)}"
+ )
+ state = GenAsmState(reg_assignments)
+ fn.gen_asm(state)
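+ # The spread/concat round trip is lowered to two chains of 32 scalar 'or'
+ # copies, one fanning the vector out and one gathering it back.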
+ self.assertEqual(state.output, [
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.addi *14, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'or 67, 14, 14',
+ 'or 66, 15, 15',
+ 'or 65, 16, 16',
+ 'or 64, 17, 17',
+ 'or 63, 18, 18',
+ 'or 62, 19, 19',
+ 'or 61, 20, 20',
+ 'or 60, 21, 21',
+ 'or 59, 22, 22',
+ 'or 58, 23, 23',
+ 'or 57, 24, 24',
+ 'or 56, 25, 25',
+ 'or 55, 26, 26',
+ 'or 54, 27, 27',
+ 'or 53, 28, 28',
+ 'or 52, 29, 29',
+ 'or 51, 30, 30',
+ 'or 50, 31, 31',
+ 'or 49, 32, 32',
+ 'or 48, 33, 33',
+ 'or 47, 34, 34',
+ 'or 46, 35, 35',
+ 'or 12, 36, 36',
+ 'or 11, 37, 37',
+ 'or 10, 38, 38',
+ 'or 9, 39, 39',
+ 'or 8, 40, 40',
+ 'or 7, 41, 41',
+ 'or 6, 42, 42',
+ 'or 5, 43, 43',
+ 'or 4, 44, 44',
+ 'or 3, 45, 45',
+ 'or 14, 3, 3',
+ 'or 15, 4, 4',
+ 'or 16, 5, 5',
+ 'or 17, 6, 6',
+ 'or 18, 7, 7',
+ 'or 19, 8, 8',
+ 'or 20, 9, 9',
+ 'or 21, 10, 10',
+ 'or 22, 11, 11',
+ 'or 23, 12, 12',
+ 'or 24, 46, 46',
+ 'or 25, 47, 47',
+ 'or 26, 48, 48',
+ 'or 27, 49, 49',
+ 'or 28, 50, 50',
+ 'or 29, 51, 51',
+ 'or 30, 52, 52',
+ 'or 31, 53, 53',
+ 'or 32, 54, 54',
+ 'or 33, 55, 55',
+ 'or 34, 56, 56',
+ 'or 35, 57, 57',
+ 'or 36, 58, 58',
+ 'or 37, 59, 59',
+ 'or 38, 60, 60',
+ 'or 39, 61, 61',
+ 'or 40, 62, 62',
+ 'or 41, 63, 63',
+ 'or 42, 64, 64',
+ 'or 43, 65, 65',
+ 'or 44, 66, 66',
+ 'or 45, 67, 67',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ ])
+
+
+if __name__ == "__main__":
+ _ = unittest.main()
+++ /dev/null
-import sys
-import unittest
-
-from bigint_presentation_code.compiler_ir2 import (Fn, GenAsmState, OpKind,
- SSAVal)
-from bigint_presentation_code.register_allocator2 import allocate_registers
-
-
-class TestCompilerIR(unittest.TestCase):
- maxDiff = None
-
- def make_add_fn(self):
- # type: () -> tuple[Fn, SSAVal]
- fn = Fn()
- op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg")
- arg = op0.outputs[0]
- MAXVL = 32
- op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
- vl = op1.outputs[0]
- op2 = fn.append_new_op(
- OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
- name="ld")
- a = op2.outputs[0]
- op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
- maxvl=MAXVL, name="li")
- b = op3.outputs[0]
- op4 = fn.append_new_op(OpKind.SetCA, name="ca")
- ca = op4.outputs[0]
- op5 = fn.append_new_op(
- OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
- s = op5.outputs[0]
- _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
- immediates=[0], maxvl=MAXVL, name="st")
- return fn, arg
-
- def test_register_allocate(self):
- fn, _arg = self.make_add_fn()
- reg_assignments = allocate_registers(fn)
-
- self.assertEqual(
- repr(reg_assignments),
- "{<add.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<add.inp1.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
- "<add.inp0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
- "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<st.inp1.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<st.inp0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<add.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ca.outputs[0]: <CA>>: "
- "Loc(kind=LocKind.CA, start=0, reg_len=1), "
- "<add.outputs[1]: <CA>>: "
- "Loc(kind=LocKind.CA, start=0, reg_len=1), "
- "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<li.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<li.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ld.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
- "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ld.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ld.inp0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<vl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<arg.out0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
- "<arg.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
- )
-
- def test_gen_asm(self):
- fn, _arg = self.make_add_fn()
- reg_assignments = allocate_registers(fn)
-
- self.assertEqual(
- repr(reg_assignments),
- "{<add.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<add.inp1.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
- "<add.inp0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
- "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<st.inp1.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<st.inp0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<add.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ca.outputs[0]: <CA>>: "
- "Loc(kind=LocKind.CA, start=0, reg_len=1), "
- "<add.outputs[1]: <CA>>: "
- "Loc(kind=LocKind.CA, start=0, reg_len=1), "
- "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<li.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<li.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ld.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
- "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ld.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<ld.inp0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<vl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<arg.out0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
- "<arg.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
- )
- state = GenAsmState(reg_assignments)
- fn.gen_asm(state)
- self.assertEqual(state.output, [
- 'or 4, 3, 3',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'or 3, 4, 4',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.ld *14, 0(3)',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.or *46, *14, *14',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.addi *14, 0, 0',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'subfc 0, 0, 0',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.or *78, *46, *46',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.or *46, *14, *14',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.adde *14, *78, *46',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'or 3, 4, 4',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.std *14, 0(3)',
- ])
-
- def test_register_allocate_spread(self):
- fn = Fn()
- maxvl = 32
- vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
- name="vl", maxvl=maxvl).outputs[0]
- li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
- name="li", maxvl=maxvl).outputs[0]
- spread = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
- name="spread", maxvl=maxvl).outputs
- _concat = fn.append_new_op(
- OpKind.Concat, input_vals=[*spread[::-1], vl],
- name="concat", maxvl=maxvl)
- reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
-
- self.assertEqual(
- repr(reg_assignments),
- "{<concat.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<concat.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<concat.inp0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
- "<concat.inp1.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
- "<concat.inp2.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
- "<concat.inp3.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
- "<concat.inp4.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
- "<concat.inp5.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
- "<concat.inp6.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
- "<concat.inp7.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
- "<concat.inp8.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
- "<concat.inp9.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
- "<concat.inp10.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
- "<concat.inp11.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
- "<concat.inp12.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
- "<concat.inp13.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
- "<concat.inp14.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
- "<concat.inp15.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
- "<concat.inp16.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
- "<concat.inp17.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
- "<concat.inp18.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
- "<concat.inp19.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
- "<concat.inp20.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
- "<concat.inp21.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
- "<concat.inp22.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
- "<concat.inp23.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
- "<concat.inp24.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
- "<concat.inp25.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
- "<concat.inp26.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
- "<concat.inp27.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
- "<concat.inp28.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
- "<concat.inp29.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
- "<concat.inp30.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
- "<concat.inp31.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
- "<concat.inp32.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<spread.out31.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
- "<spread.out30.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
- "<spread.out29.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
- "<spread.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
- "<spread.outputs[1]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
- "<spread.outputs[2]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
- "<spread.outputs[3]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
- "<spread.outputs[4]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
- "<spread.outputs[5]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
- "<spread.outputs[6]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
- "<spread.outputs[7]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
- "<spread.outputs[8]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
- "<spread.outputs[9]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
- "<spread.outputs[10]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
- "<spread.outputs[11]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
- "<spread.outputs[12]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
- "<spread.outputs[13]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
- "<spread.outputs[14]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
- "<spread.outputs[15]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
- "<spread.outputs[16]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
- "<spread.outputs[17]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
- "<spread.outputs[18]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
- "<spread.outputs[19]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
- "<spread.outputs[20]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
- "<spread.outputs[21]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
- "<spread.outputs[22]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
- "<spread.outputs[23]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
- "<spread.outputs[24]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
- "<spread.outputs[25]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
- "<spread.outputs[26]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
- "<spread.outputs[27]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
- "<spread.outputs[28]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
- "<spread.outputs[29]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
- "<spread.outputs[30]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
- "<spread.outputs[31]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
- "<spread.out28.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
- "<spread.out27.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
- "<spread.out26.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
- "<spread.out25.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
- "<spread.out24.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
- "<spread.out23.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
- "<spread.out22.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
- "<spread.out21.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=46, reg_len=1), "
- "<spread.out20.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=47, reg_len=1), "
- "<spread.out19.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=48, reg_len=1), "
- "<spread.out18.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=49, reg_len=1), "
- "<spread.out17.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=50, reg_len=1), "
- "<spread.out16.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=51, reg_len=1), "
- "<spread.out15.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=52, reg_len=1), "
- "<spread.out14.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=53, reg_len=1), "
- "<spread.out13.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=54, reg_len=1), "
- "<spread.out12.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=55, reg_len=1), "
- "<spread.out11.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=56, reg_len=1), "
- "<spread.out10.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=57, reg_len=1), "
- "<spread.out9.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=58, reg_len=1), "
- "<spread.out8.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=59, reg_len=1), "
- "<spread.out7.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=60, reg_len=1), "
- "<spread.out6.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=61, reg_len=1), "
- "<spread.out5.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=62, reg_len=1), "
- "<spread.out4.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=63, reg_len=1), "
- "<spread.out3.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=64, reg_len=1), "
- "<spread.out2.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=65, reg_len=1), "
- "<spread.out1.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=66, reg_len=1), "
- "<spread.out0.copy.outputs[0]: <I64>>: "
- "Loc(kind=LocKind.GPR, start=67, reg_len=1), "
- "<spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<spread.inp0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<li.out0.copy.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<li.outputs[0]: <I64*32>>: "
- "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
- "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
- "<vl.outputs[0]: <VL_MAXVL>>: "
- "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)}"
- )
- state = GenAsmState(reg_assignments)
- fn.gen_asm(state)
- self.assertEqual(state.output, [
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'sv.addi *14, 0, 0',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'or 67, 14, 14',
- 'or 66, 15, 15',
- 'or 65, 16, 16',
- 'or 64, 17, 17',
- 'or 63, 18, 18',
- 'or 62, 19, 19',
- 'or 61, 20, 20',
- 'or 60, 21, 21',
- 'or 59, 22, 22',
- 'or 58, 23, 23',
- 'or 57, 24, 24',
- 'or 56, 25, 25',
- 'or 55, 26, 26',
- 'or 54, 27, 27',
- 'or 53, 28, 28',
- 'or 52, 29, 29',
- 'or 51, 30, 30',
- 'or 50, 31, 31',
- 'or 49, 32, 32',
- 'or 48, 33, 33',
- 'or 47, 34, 34',
- 'or 46, 35, 35',
- 'or 12, 36, 36',
- 'or 11, 37, 37',
- 'or 10, 38, 38',
- 'or 9, 39, 39',
- 'or 8, 40, 40',
- 'or 7, 41, 41',
- 'or 6, 42, 42',
- 'or 5, 43, 43',
- 'or 4, 44, 44',
- 'or 3, 45, 45',
- 'or 14, 3, 3',
- 'or 15, 4, 4',
- 'or 16, 5, 5',
- 'or 17, 6, 6',
- 'or 18, 7, 7',
- 'or 19, 8, 8',
- 'or 20, 9, 9',
- 'or 21, 10, 10',
- 'or 22, 11, 11',
- 'or 23, 12, 12',
- 'or 24, 46, 46',
- 'or 25, 47, 47',
- 'or 26, 48, 48',
- 'or 27, 49, 49',
- 'or 28, 50, 50',
- 'or 29, 51, 51',
- 'or 30, 52, 52',
- 'or 31, 53, 53',
- 'or 32, 54, 54',
- 'or 33, 55, 55',
- 'or 34, 56, 56',
- 'or 35, 57, 57',
- 'or 36, 58, 58',
- 'or 37, 59, 59',
- 'or 38, 60, 60',
- 'or 39, 61, 61',
- 'or 40, 62, 62',
- 'or 41, 63, 63',
- 'or 42, 64, 64',
- 'or 43, 65, 65',
- 'or 44, 66, 66',
- 'or 45, 67, 67',
- 'setvl 0, 0, 32, 0, 1, 1',
- 'setvl 0, 0, 32, 0, 1, 1'])
-
-
-if __name__ == "__main__":
- _ = unittest.main()
import unittest
from typing import Callable
-from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES,
- BaseSimState, Fn,
- GenAsmState, OpKind,
- PostRASimState,
- PreRASimState)
-from bigint_presentation_code.register_allocator2 import allocate_registers
+from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES,
+ BaseSimState, Fn,
+ GenAsmState, OpKind,
+ PostRASimState,
+ PreRASimState)
+from bigint_presentation_code.register_allocator import allocate_registers
from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
--- /dev/null
+import enum
+from abc import ABCMeta, abstractmethod
+from enum import Enum, unique
+from functools import lru_cache, total_ordering
+from io import StringIO
+from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
+ Mapping, Sequence, TypeVar, Union, overload)
+from weakref import WeakValueDictionary as _WeakVDict
+
+from cached_property import cached_property
+from nmutil.plain_data import fields, plain_data
+
+from bigint_presentation_code.type_util import (Literal, Self, assert_never,
+ final)
+from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
+ OFSet, OSet)
+
+
+@final
+class Fn:
+ def __init__(self):
+ self.ops = [] # type: list[Op]
+ self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op]
+ self.__next_name_suffix = 2
+
+ def _add_op_with_unused_name(self, op, name=""):
+ # type: (Op, str) -> str
+ if op.fn is not self:
+ raise ValueError("can't add Op to wrong Fn")
+ if hasattr(op, "name"):
+ raise ValueError("Op already named")
+ orig_name = name
+ while True:
+ if name != "" and name not in self.__op_names:
+ self.__op_names[name] = op
+ return name
+ name = orig_name + str(self.__next_name_suffix)
+ self.__next_name_suffix += 1
+
+ def __repr__(self):
+ # type: () -> str
+ return "<Fn>"
+
+ def append_op(self, op):
+ # type: (Op) -> None
+ if op.fn is not self:
+ raise ValueError("can't add Op to wrong Fn")
+ self.ops.append(op)
+
+ def append_new_op(self, kind, input_vals=(), immediates=(), name="",
+ maxvl=1):
+ # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
+ retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
+ input_vals=input_vals, immediates=immediates, name=name)
+ self.append_op(retval)
+ return retval
+
+ def sim(self, state):
+ # type: (BaseSimState) -> None
+ for op in self.ops:
+ op.sim(state)
+
+ def gen_asm(self, state):
+ # type: (GenAsmState) -> None
+ for op in self.ops:
+ op.gen_asm(state)
+
+ def pre_ra_insert_copies(self):
+ # type: () -> None
+ orig_ops = list(self.ops)
+ copied_outputs = {} # type: dict[SSAVal, SSAVal]
+ setvli_outputs = {} # type: dict[SSAVal, Op]
+ self.ops.clear()
+ for op in orig_ops:
+ for i in range(len(op.input_vals)):
+ inp = copied_outputs[op.input_vals[i]]
+ if inp.ty.base_ty is BaseTy.I64:
+ maxvl = inp.ty.reg_len
+ if inp.ty.reg_len != 1:
+ setvl = self.append_new_op(
+ OpKind.SetVLI, immediates=[maxvl],
+ name=f"{op.name}.inp{i}.setvl")
+ vl = setvl.outputs[0]
+ mv = self.append_new_op(
+ OpKind.VecCopyToReg, input_vals=[inp, vl],
+ maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
+ else:
+ mv = self.append_new_op(
+ OpKind.CopyToReg, input_vals=[inp],
+ name=f"{op.name}.inp{i}.copy")
+ op.input_vals[i] = mv.outputs[0]
+ elif inp.ty.base_ty is BaseTy.CA \
+ or inp.ty.base_ty is BaseTy.VL_MAXVL:
+ # all copies would be no-ops, so we don't need to copy,
+ # though we do need to rematerialize SetVLI ops right
+                    # before the ops that use VL
+ if inp in setvli_outputs:
+ setvl = self.append_new_op(
+ OpKind.SetVLI,
+ immediates=setvli_outputs[inp].immediates,
+ name=f"{op.name}.inp{i}.setvl")
+ inp = setvl.outputs[0]
+ op.input_vals[i] = inp
+ else:
+ assert_never(inp.ty.base_ty)
+ self.ops.append(op)
+ for i, out in enumerate(op.outputs):
+ if op.kind is OpKind.SetVLI:
+ setvli_outputs[out] = op
+ if out.ty.base_ty is BaseTy.I64:
+ maxvl = out.ty.reg_len
+ if out.ty.reg_len != 1:
+ setvl = self.append_new_op(
+ OpKind.SetVLI, immediates=[maxvl],
+ name=f"{op.name}.out{i}.setvl")
+ vl = setvl.outputs[0]
+ mv = self.append_new_op(
+ OpKind.VecCopyFromReg, input_vals=[out, vl],
+ maxvl=maxvl, name=f"{op.name}.out{i}.copy")
+ else:
+ mv = self.append_new_op(
+ OpKind.CopyFromReg, input_vals=[out],
+ name=f"{op.name}.out{i}.copy")
+ copied_outputs[out] = mv.outputs[0]
+ elif out.ty.base_ty is BaseTy.CA \
+ or out.ty.base_ty is BaseTy.VL_MAXVL:
+ # all copies would be no-ops, so we don't need to copy
+ copied_outputs[out] = out
+ else:
+ assert_never(out.ty.base_ty)
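+
+    # Rough sketch of the effect, assuming an op named "add" with one vector
+    # I64 input and one vector I64 output: the pass emits
+    # "add.inp0.setvl"/"add.inp0.copy" ops in front of "add" to copy the
+    # input into place, and "add.out0.setvl"/"add.out0.copy" ops after it to
+    # copy the result back out (matching the generated-op names used in the
+    # tests elsewhere in this patch). CA and VL_MAXVL values are never copied
+    # since the copies would be no-ops; VL inputs coming from a SetVLI are
+    # instead rematerialized by emitting a fresh SetVLI right before the
+    # consuming op.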
+
+
+@final
+@unique
+@total_ordering
+class OpStage(Enum):
+ value: Literal[0, 1] # type: ignore
+
+ def __new__(cls, value):
+ # type: (int) -> OpStage
+ value = int(value)
+ if value not in (0, 1):
+ raise ValueError("invalid value")
+ retval = object.__new__(cls)
+ retval._value_ = value
+ return retval
+
+ Early = 0
+ """ early stage of Op execution, where all input reads occur.
+ all output writes with `write_stage == Early` occur here too, and therefore
+    conflict with input reads, telling the compiler that it can't share
+ that output's register with any inputs that the output isn't tied to.
+
+ All outputs, even unused outputs, can't share registers with any other
+ outputs, independent of `write_stage` settings.
+ """
+ Late = 1
+ """ late stage of Op execution, where all output writes with
+ `write_stage == Late` occur, and therefore don't conflict with input reads,
+ telling the compiler that any inputs can safely use the same register as
+ those outputs.
+
+ All outputs, even unused outputs, can't share registers with any other
+ outputs, independent of `write_stage` settings.
+ """
+
+ def __repr__(self):
+ # type: () -> str
+ return f"OpStage.{self._name_}"
+
+ def __lt__(self, other):
+ # type: (OpStage | object) -> bool
+ if isinstance(other, OpStage):
+ return self.value < other.value
+ return NotImplemented
+
+
+assert OpStage.Early < OpStage.Late, "early must be less than late"
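+
+# For example, the copy ops defined further below write their outputs at
+# OpStage.Late, so a copy's destination may share a register with its source,
+# while OpKind.SvAddE's RT output is written at OpStage.Early and therefore
+# can't share a register with the untied RA and RB inputs.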
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+@total_ordering
+class ProgramPoint(metaclass=InternedMeta):
+ __slots__ = "op_index", "stage"
+
+ def __init__(self, op_index, stage):
+ # type: (int, OpStage) -> None
+ self.op_index = op_index
+ self.stage = stage
+
+ @property
+ def int_value(self):
+ # type: () -> int
+ """ an integer representation of `self` such that it keeps ordering and
+ successor/predecessor relations.
+ """
+ return self.op_index * 2 + self.stage.value
+
+ @staticmethod
+ def from_int_value(int_value):
+ # type: (int) -> ProgramPoint
+ op_index, stage = divmod(int_value, 2)
+ return ProgramPoint(op_index=op_index, stage=OpStage(stage))
+
+ def next(self, steps=1):
+ # type: (int) -> ProgramPoint
+ return ProgramPoint.from_int_value(self.int_value + steps)
+
+ def prev(self, steps=1):
+ # type: (int) -> ProgramPoint
+ return self.next(steps=-steps)
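+
+    # Example: ProgramPoint(op_index=3, stage=OpStage.Early).int_value is 6,
+    # its .next() is <ops[3]:Late> (int_value 7), and .next(2) is
+    # <ops[4]:Early> (int_value 8).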
+
+ def __lt__(self, other):
+ # type: (ProgramPoint | Any) -> bool
+ if not isinstance(other, ProgramPoint):
+ return NotImplemented
+ if self.op_index != other.op_index:
+ return self.op_index < other.op_index
+ return self.stage < other.stage
+
+ def __repr__(self):
+ # type: () -> str
+ return f"<ops[{self.op_index}]:{self.stage._name_}>"
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
+ __slots__ = "start", "stop"
+
+ def __init__(self, start, stop):
+ # type: (ProgramPoint, ProgramPoint) -> None
+ self.start = start
+ self.stop = stop
+
+ @cached_property
+ def int_value_range(self):
+ # type: () -> range
+ return range(self.start.int_value, self.stop.int_value)
+
+ @staticmethod
+ def from_int_value_range(int_value_range):
+ # type: (range) -> ProgramRange
+ if int_value_range.step != 1:
+ raise ValueError("int_value_range must have step == 1")
+ return ProgramRange(
+ start=ProgramPoint.from_int_value(int_value_range.start),
+ stop=ProgramPoint.from_int_value(int_value_range.stop))
+
+ @overload
+ def __getitem__(self, __idx):
+ # type: (int) -> ProgramPoint
+ ...
+
+ @overload
+ def __getitem__(self, __idx):
+ # type: (slice) -> ProgramRange
+ ...
+
+ def __getitem__(self, __idx):
+ # type: (int | slice) -> ProgramPoint | ProgramRange
+ v = range(self.start.int_value, self.stop.int_value)[__idx]
+ if isinstance(v, int):
+ return ProgramPoint.from_int_value(v)
+ return ProgramRange.from_int_value_range(v)
+
+ def __len__(self):
+ # type: () -> int
+ return len(self.int_value_range)
+
+ def __iter__(self):
+ # type: () -> Iterator[ProgramPoint]
+ return map(ProgramPoint.from_int_value, self.int_value_range)
+
+ def __repr__(self):
+ # type: () -> str
+ start = repr(self.start).lstrip("<").rstrip(">")
+ stop = repr(self.stop).lstrip("<").rstrip(">")
+ return f"<range:{start}..{stop}>"
+
+
+@plain_data(frozen=True, eq=False, repr=False)
+@final
+class FnAnalysis:
+ __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
+ "def_program_ranges", "use_program_points",
+ "all_program_points")
+
+ def __init__(self, fn):
+ # type: (Fn) -> None
+ self.fn = fn
+ self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
+ self.all_program_points = ProgramRange(
+ start=ProgramPoint(op_index=0, stage=OpStage.Early),
+ stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
+ def_program_ranges = {} # type: dict[SSAVal, ProgramRange]
+ use_program_points = {} # type: dict[SSAUse, ProgramPoint]
+ uses = {} # type: dict[SSAVal, OSet[SSAUse]]
+ live_range_stops = {} # type: dict[SSAVal, ProgramPoint]
+ for op in fn.ops:
+ for use in op.input_uses:
+ uses[use.ssa_val].add(use)
+ use_program_point = self.__get_use_program_point(use)
+ use_program_points[use] = use_program_point
+ live_range_stops[use.ssa_val] = max(
+ live_range_stops[use.ssa_val], use_program_point.next())
+ for out in op.outputs:
+ uses[out] = OSet()
+ def_program_range = self.__get_def_program_range(out)
+ def_program_ranges[out] = def_program_range
+ live_range_stops[out] = def_program_range.stop
+ self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
+ self.def_program_ranges = FMap(def_program_ranges)
+ self.use_program_points = FMap(use_program_points)
+ live_ranges = {} # type: dict[SSAVal, ProgramRange]
+ live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
+ for ssa_val in uses.keys():
+ live_ranges[ssa_val] = live_range = ProgramRange(
+ start=self.def_program_ranges[ssa_val].start,
+ stop=live_range_stops[ssa_val])
+ for program_point in live_range:
+ live_at[program_point].add(ssa_val)
+ self.live_ranges = FMap(live_ranges)
+ self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
+
+ def __get_def_program_range(self, ssa_val):
+ # type: (SSAVal) -> ProgramRange
+ write_stage = ssa_val.defining_descriptor.write_stage
+ start = ProgramPoint(
+ op_index=self.op_indexes[ssa_val.op], stage=write_stage)
+ # always include late stage of ssa_val.op, to ensure outputs always
+ # overlap all other outputs.
+ # stop is exclusive, so we need the next program point.
+ stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
+ return ProgramRange(start=start, stop=stop)
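+
+    # Sketch: for an Early-written output of fn.ops[2] this returns
+    # <range:ops[2]:Early..ops[3]:Early>, covering both the Early and the
+    # Late point of ops[2] even though the write itself happens Early.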
+
+ def __get_use_program_point(self, ssa_use):
+ # type: (SSAUse) -> ProgramPoint
+ assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
+ "assumed here, ensured by GenericOpProperties.__init__"
+ return ProgramPoint(
+ op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)
+
+ def __eq__(self, other):
+ # type: (FnAnalysis | Any) -> bool
+ if isinstance(other, FnAnalysis):
+ return self.fn == other.fn
+ return NotImplemented
+
+ def __hash__(self):
+ # type: () -> int
+ return hash(self.fn)
+
+ def __repr__(self):
+ # type: () -> str
+ return "<FnAnalysis>"
+
+
+@unique
+@final
+class BaseTy(Enum):
+ I64 = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
+
+ @cached_property
+ def only_scalar(self):
+ # type: () -> bool
+ if self is BaseTy.I64:
+ return False
+ elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
+ return True
+ else:
+ assert_never(self)
+
+ @cached_property
+ def max_reg_len(self):
+ # type: () -> int
+ if self is BaseTy.I64:
+ return 128
+ elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
+ return 1
+ else:
+ assert_never(self)
+
+ def __repr__(self):
+ return "BaseTy." + self._name_
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class Ty(metaclass=InternedMeta):
+ __slots__ = "base_ty", "reg_len"
+
+ @staticmethod
+ def validate(base_ty, reg_len):
+ # type: (BaseTy, int) -> str | None
+ """ return a string with the error if the combination is invalid,
+ otherwise return None
+ """
+ if base_ty.only_scalar and reg_len != 1:
+ return f"can't create a vector of an only-scalar type: {base_ty}"
+ if reg_len < 1 or reg_len > base_ty.max_reg_len:
+ return "reg_len out of range"
+ return None
+
+ def __init__(self, base_ty, reg_len):
+ # type: (BaseTy, int) -> None
+ msg = self.validate(base_ty=base_ty, reg_len=reg_len)
+ if msg is not None:
+ raise ValueError(msg)
+ self.base_ty = base_ty
+ self.reg_len = reg_len
+
+ def __repr__(self):
+ # type: () -> str
+ if self.reg_len != 1:
+ reg_len = f"*{self.reg_len}"
+ else:
+ reg_len = ""
+ return f"<{self.base_ty._name_}{reg_len}>"
+
+
+@unique
+@final
+class LocKind(Enum):
+ GPR = enum.auto()
+ StackI64 = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
+
+ @cached_property
+ def base_ty(self):
+ # type: () -> BaseTy
+ if self is LocKind.GPR or self is LocKind.StackI64:
+ return BaseTy.I64
+ if self is LocKind.CA:
+ return BaseTy.CA
+ if self is LocKind.VL_MAXVL:
+ return BaseTy.VL_MAXVL
+ else:
+ assert_never(self)
+
+ @cached_property
+ def loc_count(self):
+ # type: () -> int
+ if self is LocKind.StackI64:
+ return 512
+ if self is LocKind.GPR or self is LocKind.CA \
+ or self is LocKind.VL_MAXVL:
+ return self.base_ty.max_reg_len
+ else:
+ assert_never(self)
+
+ def __repr__(self):
+ return "LocKind." + self._name_
+
+
+@final
+@unique
+class LocSubKind(Enum):
+ BASE_GPR = enum.auto()
+ SV_EXTRA2_VGPR = enum.auto()
+ SV_EXTRA2_SGPR = enum.auto()
+ SV_EXTRA3_VGPR = enum.auto()
+ SV_EXTRA3_SGPR = enum.auto()
+ StackI64 = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
+
+ @cached_property
+ def kind(self):
+ # type: () -> LocKind
+ # pyright fails typechecking when using `in` here:
+ # reported: https://github.com/microsoft/pyright/issues/4102
+ if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
+ LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
+ LocSubKind.SV_EXTRA3_SGPR):
+ return LocKind.GPR
+ if self is LocSubKind.StackI64:
+ return LocKind.StackI64
+ if self is LocSubKind.CA:
+ return LocKind.CA
+ if self is LocSubKind.VL_MAXVL:
+ return LocKind.VL_MAXVL
+ assert_never(self)
+
+ @property
+ def base_ty(self):
+ return self.kind.base_ty
+
+ @lru_cache()
+ def allocatable_locs(self, ty):
+ # type: (Ty) -> LocSet
+ if ty.base_ty != self.base_ty:
+ raise ValueError("type mismatch")
+ if self is LocSubKind.BASE_GPR:
+ starts = range(32)
+ elif self is LocSubKind.SV_EXTRA2_VGPR:
+ starts = range(0, 128, 2)
+ elif self is LocSubKind.SV_EXTRA2_SGPR:
+ starts = range(64)
+ elif self is LocSubKind.SV_EXTRA3_VGPR \
+ or self is LocSubKind.SV_EXTRA3_SGPR:
+ starts = range(128)
+ elif self is LocSubKind.StackI64:
+ starts = range(LocKind.StackI64.loc_count)
+ elif self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
+ return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
+ else:
+ assert_never(self)
+ retval = [] # type: list[Loc]
+ for start in starts:
+ loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
+ if loc is None:
+ continue
+ conflicts = False
+ for special_loc in SPECIAL_GPRS:
+ if loc.conflicts(special_loc):
+ conflicts = True
+ break
+ if not conflicts:
+ retval.append(loc)
+ return LocSet(retval)
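+
+    # Sketch: BASE_GPR.allocatable_locs(Ty(BaseTy.I64, 1)) yields the scalar
+    # GPRs r0..r31 minus the SPECIAL_GPRS defined below (r0, r1, r2 and r13),
+    # while SV_EXTRA2_VGPR only yields even-numbered starts in r0..r126,
+    # again skipping anything that overlaps SPECIAL_GPRS.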
+
+ def __repr__(self):
+ return "LocSubKind." + self._name_
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GenericTy(metaclass=InternedMeta):
+ __slots__ = "base_ty", "is_vec"
+
+ def __init__(self, base_ty, is_vec):
+ # type: (BaseTy, bool) -> None
+ self.base_ty = base_ty
+ if base_ty.only_scalar and is_vec:
+ raise ValueError(f"base_ty={base_ty} requires is_vec=False")
+ self.is_vec = is_vec
+
+ def instantiate(self, maxvl):
+ # type: (int) -> Ty
+ # here's where subvl and elwid would be accounted for
+ if self.is_vec:
+ return Ty(self.base_ty, maxvl)
+ return Ty(self.base_ty, 1)
+
+ def can_instantiate_to(self, ty):
+ # type: (Ty) -> bool
+ if self.base_ty != ty.base_ty:
+ return False
+ if self.is_vec:
+ return True
+ return ty.reg_len == 1
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Loc(metaclass=InternedMeta):
+ __slots__ = "kind", "start", "reg_len"
+
+ @staticmethod
+ def validate(kind, start, reg_len):
+ # type: (LocKind, int, int) -> str | None
+ msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
+ if msg is not None:
+ return msg
+ if reg_len > kind.loc_count:
+ return "invalid reg_len"
+ if start < 0 or start + reg_len > kind.loc_count:
+ return "start not in valid range"
+ return None
+
+ @staticmethod
+ def try_make(kind, start, reg_len):
+ # type: (LocKind, int, int) -> Loc | None
+ msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
+ if msg is not None:
+ return None
+ return Loc(kind=kind, start=start, reg_len=reg_len)
+
+ def __init__(self, kind, start, reg_len):
+ # type: (LocKind, int, int) -> None
+ msg = self.validate(kind=kind, start=start, reg_len=reg_len)
+ if msg is not None:
+ raise ValueError(msg)
+ self.kind = kind
+ self.reg_len = reg_len
+ self.start = start
+
+ def conflicts(self, other):
+ # type: (Loc) -> bool
+ return (self.kind == other.kind
+ and self.start < other.stop and other.start < self.stop)
+
+ @staticmethod
+ def make_ty(kind, reg_len):
+ # type: (LocKind, int) -> Ty
+ return Ty(base_ty=kind.base_ty, reg_len=reg_len)
+
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ return self.make_ty(kind=self.kind, reg_len=self.reg_len)
+
+ @property
+ def stop(self):
+ # type: () -> int
+ return self.start + self.reg_len
+
+ def try_concat(self, *others):
+ # type: (*Loc | None) -> Loc | None
+ reg_len = self.reg_len
+ stop = self.stop
+ for other in others:
+ if other is None or other.kind != self.kind:
+ return None
+ if stop != other.start:
+ return None
+ stop = other.stop
+ reg_len += other.reg_len
+ return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
+
+ def get_subloc_at_offset(self, subloc_ty, offset):
+ # type: (Ty, int) -> Loc
+ if subloc_ty.base_ty != self.kind.base_ty:
+ raise ValueError("BaseTy mismatch")
+ if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
+ raise ValueError("invalid sub-Loc: offset and/or "
+ "subloc_ty.reg_len out of range")
+ return Loc(kind=self.kind,
+ start=self.start + offset, reg_len=subloc_ty.reg_len)
+
+
+SPECIAL_GPRS = (
+ Loc(kind=LocKind.GPR, start=0, reg_len=1),
+ Loc(kind=LocKind.GPR, start=1, reg_len=1),
+ Loc(kind=LocKind.GPR, start=2, reg_len=1),
+ Loc(kind=LocKind.GPR, start=13, reg_len=1),
+)
+
+
+@final
+class _LocSetHashHelper(AbstractSet[Loc]):
+ """helper to more quickly compute LocSet's hash"""
+
+ def __init__(self, locs):
+ # type: (Iterable[Loc]) -> None
+ super().__init__()
+ self.locs = list(locs)
+
+ def __hash__(self):
+ # type: () -> int
+ return super()._hash()
+
+ def __contains__(self, x):
+ # type: (Loc | Any) -> bool
+ return x in self.locs
+
+ def __iter__(self):
+ # type: () -> Iterator[Loc]
+ return iter(self.locs)
+
+ def __len__(self):
+ return len(self.locs)
+
+
+@plain_data(frozen=True, eq=False, repr=False)
+@final
+class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
+ __slots__ = "starts", "ty", "_LocSet__hash"
+
+ def __init__(self, __locs=()):
+ # type: (Iterable[Loc]) -> None
+ if isinstance(__locs, LocSet):
+ self.starts = __locs.starts # type: FMap[LocKind, FBitSet]
+ self.ty = __locs.ty # type: Ty | None
+ self._LocSet__hash = __locs._LocSet__hash # type: int
+ return
+ starts = {i: BitSet() for i in LocKind}
+ ty = None # type: None | Ty
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ nonlocal ty
+ for loc in __locs:
+ if ty is None:
+ ty = loc.ty
+ if ty != loc.ty:
+ raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+ starts[loc.kind].add(loc.start)
+ yield loc
+ self._LocSet__hash = _LocSetHashHelper(locs()).__hash__()
+ self.starts = FMap(
+ (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
+ self.ty = ty
+
+ @cached_property
+ def stops(self):
+ # type: () -> FMap[LocKind, FBitSet]
+ if self.ty is None:
+ return FMap()
+ sh = self.ty.reg_len
+ return FMap(
+ (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())
+
+ @property
+ def kinds(self):
+ # type: () -> AbstractSet[LocKind]
+ return self.starts.keys()
+
+ @property
+ def reg_len(self):
+ # type: () -> int | None
+ if self.ty is None:
+ return None
+ return self.ty.reg_len
+
+ @property
+ def base_ty(self):
+ # type: () -> BaseTy | None
+ if self.ty is None:
+ return None
+ return self.ty.base_ty
+
+ def concat(self, *others):
+ # type: (*LocSet) -> LocSet
+ if self.ty is None:
+ return LocSet()
+ base_ty = self.ty.base_ty
+ reg_len = self.ty.reg_len
+ starts = {k: BitSet(v) for k, v in self.starts.items()}
+ for other in others:
+ if other.ty is None:
+ return LocSet()
+ if other.ty.base_ty != base_ty:
+ return LocSet()
+ for kind, other_starts in other.starts.items():
+ if kind not in starts:
+ continue
+ starts[kind].bits &= other_starts.bits >> reg_len
+                if starts[kind].bits == 0:
+ del starts[kind]
+ if len(starts) == 0:
+ return LocSet()
+ reg_len += other.ty.reg_len
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ for kind, v in starts.items():
+ for start in v:
+ loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
+ if loc is not None:
+ yield loc
+ return LocSet(locs())
+
+ def __contains__(self, loc):
+ # type: (Loc | Any) -> bool
+ if not isinstance(loc, Loc) or loc.ty != self.ty:
+ return False
+ if loc.kind not in self.starts:
+ return False
+ return loc.start in self.starts[loc.kind]
+
+ def __iter__(self):
+ # type: () -> Iterator[Loc]
+ if self.ty is None:
+ return
+ for kind, starts in self.starts.items():
+ for start in starts:
+ yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
+
+ @cached_property
+ def __len(self):
+ return sum((len(v) for v in self.starts.values()), 0)
+
+ def __len__(self):
+ return self.__len
+
+ def __hash__(self):
+ return self._LocSet__hash
+
+ def __eq__(self, __other):
+ # type: (LocSet | Any) -> bool
+ if isinstance(__other, LocSet):
+ return self.ty == __other.ty and self.starts == __other.starts
+ return super().__eq__(__other)
+
+ @lru_cache(maxsize=None, typed=True)
+ def max_conflicts_with(self, other):
+ # type: (LocSet | Loc) -> int
+ """the largest number of Locs in `self` that a single Loc
+ from `other` can conflict with
+ """
+ if isinstance(other, LocSet):
+ return max(self.max_conflicts_with(i) for i in other)
+ else:
+ return sum(other.conflicts(i) for i in self)
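+
+    # Sketch: if `self` is the four 1-register Locs r14..r17 and `other` is
+    # Loc(kind=LocKind.GPR, start=14, reg_len=4), that single Loc conflicts
+    # with all four members of `self`, so the result is 4.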
+
+ def __repr__(self):
+ items = [] # type: list[str]
+ for name in fields(self):
+ if name.startswith("_"):
+ continue
+ items.append(f"{name}={getattr(self, name)!r}")
+ return f"LocSet({', '.join(items)})"
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GenericOperandDesc(metaclass=InternedMeta):
+ """generic Op operand descriptor"""
+ __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
+ "write_stage")
+
+ def __init__(
+ self, ty, # type: GenericTy
+ sub_kinds, # type: Iterable[LocSubKind]
+ *,
+ fixed_loc=None, # type: Loc | None
+ tied_input_index=None, # type: int | None
+ spread=False, # type: bool
+ write_stage=OpStage.Early, # type: OpStage
+ ):
+ # type: (...) -> None
+ self.ty = ty
+ self.sub_kinds = OFSet(sub_kinds)
+ if len(self.sub_kinds) == 0:
+ raise ValueError("sub_kinds can't be empty")
+ self.fixed_loc = fixed_loc
+ if fixed_loc is not None:
+ if tied_input_index is not None:
+ raise ValueError("operand can't be both tied and fixed")
+ if not ty.can_instantiate_to(fixed_loc.ty):
+ raise ValueError(
+ f"fixed_loc has incompatible type for given generic "
+ f"type: fixed_loc={fixed_loc} generic ty={ty}")
+ if len(self.sub_kinds) != 1:
+ raise ValueError(
+ "multiple sub_kinds not allowed for fixed operand")
+ for sub_kind in self.sub_kinds:
+ if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
+ raise ValueError(
+ f"fixed_loc not in given sub_kind: "
+ f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
+ for sub_kind in self.sub_kinds:
+ if sub_kind.base_ty != ty.base_ty:
+ raise ValueError(f"sub_kind is incompatible with type: "
+ f"sub_kind={sub_kind} ty={ty}")
+ if tied_input_index is not None and tied_input_index < 0:
+ raise ValueError("invalid tied_input_index")
+ self.tied_input_index = tied_input_index
+ self.spread = spread
+ if spread:
+ if self.tied_input_index is not None:
+ raise ValueError("operand can't be both spread and tied")
+ if self.fixed_loc is not None:
+ raise ValueError("operand can't be both spread and fixed")
+ if self.ty.is_vec:
+ raise ValueError("operand can't be both spread and vector")
+ self.write_stage = write_stage
+
+ @cached_property
+ def ty_before_spread(self):
+ # type: () -> GenericTy
+ if self.spread:
+ return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
+ return self.ty
+
+ def tied_to_input(self, tied_input_index):
+ # type: (int) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds,
+ tied_input_index=tied_input_index,
+ write_stage=self.write_stage)
+
+ def with_fixed_loc(self, fixed_loc):
+ # type: (Loc) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
+ write_stage=self.write_stage)
+
+ def with_write_stage(self, write_stage):
+ # type: (OpStage) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds,
+ fixed_loc=self.fixed_loc,
+ tied_input_index=self.tied_input_index,
+ spread=self.spread,
+ write_stage=write_stage)
+
+ def instantiate(self, maxvl):
+ # type: (int) -> Iterable[OperandDesc]
+ # assumes all spread operands have ty.reg_len = 1
+ rep_count = 1
+ if self.spread:
+ rep_count = maxvl
+ ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)
+
+ def locs_before_spread():
+ # type: () -> Iterable[Loc]
+ if self.fixed_loc is not None:
+ if ty_before_spread != self.fixed_loc.ty:
+ raise ValueError(
+ f"instantiation failed: type mismatch with fixed_loc: "
+ f"instantiated type: {ty_before_spread} "
+ f"fixed_loc: {self.fixed_loc}")
+ yield self.fixed_loc
+ return
+ for sub_kind in self.sub_kinds:
+ yield from sub_kind.allocatable_locs(ty_before_spread)
+ loc_set_before_spread = LocSet(locs_before_spread())
+ for idx in range(rep_count):
+ if not self.spread:
+ idx = None
+ yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
+ tied_input_index=self.tied_input_index,
+ spread_index=idx, write_stage=self.write_stage)
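+
+    # Sketch: a spread operand instantiated with maxvl=4 expands into four
+    # scalar OperandDescs that share one vector loc_set_before_spread and
+    # carry spread_index 0..3, while a non-spread operand yields a single
+    # OperandDesc with spread_index=None.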
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class OperandDesc(metaclass=InternedMeta):
+ """Op operand descriptor"""
+ __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
+ "write_stage")
+
+ def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
+ write_stage):
+ # type: (LocSet, int | None, int | None, OpStage) -> None
+ if len(loc_set_before_spread) == 0:
+ raise ValueError("loc_set_before_spread must not be empty")
+ self.loc_set_before_spread = loc_set_before_spread
+ self.tied_input_index = tied_input_index
+ if self.tied_input_index is not None and spread_index is not None:
+ raise ValueError("operand can't be both spread and tied")
+ self.spread_index = spread_index
+ self.write_stage = write_stage
+
+ @cached_property
+ def ty_before_spread(self):
+ # type: () -> Ty
+ ty = self.loc_set_before_spread.ty
+ assert ty is not None, (
+ "__init__ checked that the LocSet isn't empty, "
+ "non-empty LocSets should always have ty set")
+ return ty
+
+ @cached_property
+ def ty(self):
+ """ Ty after any spread is applied """
+ if self.spread_index is not None:
+ # assumes all spread operands have ty.reg_len = 1
+ return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
+ return self.ty_before_spread
+
+ @property
+ def reg_offset_in_unspread(self):
+ """ the number of reg-sized slots in the unspread Loc before self's Loc
+
+ e.g. if the unspread Loc containing self is:
+ `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
+ and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
+        then reg_offset_in_unspread == 2 == 10 - 8
+ """
+ if self.spread_index is None:
+ return 0
+ return self.spread_index * self.ty.reg_len
+
+
+OD_BASE_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.BASE_GPR])
+OD_EXTRA3_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
+OD_EXTRA3_VGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
+OD_EXTRA2_SGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
+OD_EXTRA2_VGPR = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
+OD_CA = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
+ sub_kinds=[LocSubKind.CA])
+OD_VL = GenericOperandDesc(
+ ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
+ sub_kinds=[LocSubKind.VL_MAXVL])
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GenericOpProperties(metaclass=InternedMeta):
+ __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
+ "is_copy", "is_load_immediate", "has_side_effects")
+
+ def __init__(
+ self, demo_asm, # type: str
+ inputs, # type: Iterable[GenericOperandDesc]
+ outputs, # type: Iterable[GenericOperandDesc]
+ immediates=(), # type: Iterable[range]
+ is_copy=False, # type: bool
+ is_load_immediate=False, # type: bool
+ has_side_effects=False, # type: bool
+ ):
+ # type: (...) -> None
+ self.demo_asm = demo_asm # type: str
+ self.inputs = tuple(inputs) # type: tuple[GenericOperandDesc, ...]
+ for inp in self.inputs:
+ if inp.tied_input_index is not None:
+ raise ValueError(
+ f"tied_input_index is not allowed on inputs: {inp}")
+ if inp.write_stage is not OpStage.Early:
+ raise ValueError(
+ f"write_stage is not allowed on inputs: {inp}")
+ self.outputs = tuple(outputs) # type: tuple[GenericOperandDesc, ...]
+ fixed_locs = [] # type: list[tuple[Loc, int]]
+ for idx, out in enumerate(self.outputs):
+ if out.tied_input_index is not None:
+ if out.tied_input_index >= len(self.inputs):
+ raise ValueError(f"tied_input_index out of range: {out}")
+ tied_inp = self.inputs[out.tied_input_index]
+ expected_out = tied_inp.tied_to_input(out.tied_input_index) \
+ .with_write_stage(out.write_stage)
+ if expected_out != out:
+ raise ValueError(f"output can't be tied to non-equivalent "
+ f"input: {out} tied to {tied_inp}")
+ if out.fixed_loc is not None:
+ for other_fixed_loc, other_idx in fixed_locs:
+ if not other_fixed_loc.conflicts(out.fixed_loc):
+ continue
+ raise ValueError(
+ f"conflicting fixed_locs: outputs[{idx}] and "
+ f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
+ f"with {other_fixed_loc}")
+ fixed_locs.append((out.fixed_loc, idx))
+ self.immediates = tuple(immediates) # type: tuple[range, ...]
+ self.is_copy = is_copy # type: bool
+ self.is_load_immediate = is_load_immediate # type: bool
+ self.has_side_effects = has_side_effects # type: bool
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class OpProperties(metaclass=InternedMeta):
+ __slots__ = "kind", "inputs", "outputs", "maxvl"
+
+ def __init__(self, kind, maxvl):
+ # type: (OpKind, int) -> None
+ self.kind = kind # type: OpKind
+ inputs = [] # type: list[OperandDesc]
+ for inp in self.generic.inputs:
+ inputs.extend(inp.instantiate(maxvl=maxvl))
+ self.inputs = tuple(inputs) # type: tuple[OperandDesc, ...]
+ outputs = [] # type: list[OperandDesc]
+ for out in self.generic.outputs:
+ outputs.extend(out.instantiate(maxvl=maxvl))
+ self.outputs = tuple(outputs) # type: tuple[OperandDesc, ...]
+ self.maxvl = maxvl # type: int
+
+ @property
+ def generic(self):
+ # type: () -> GenericOpProperties
+ return self.kind.properties
+
+ @property
+ def immediates(self):
+ # type: () -> tuple[range, ...]
+ return self.generic.immediates
+
+ @property
+ def demo_asm(self):
+ # type: () -> str
+ return self.generic.demo_asm
+
+ @property
+ def is_copy(self):
+ # type: () -> bool
+ return self.generic.is_copy
+
+ @property
+ def is_load_immediate(self):
+ # type: () -> bool
+ return self.generic.is_load_immediate
+
+ @property
+ def has_side_effects(self):
+ # type: () -> bool
+ return self.generic.has_side_effects
+
+
+IMM_S16 = range(-1 << 15, 1 << 15)
+
+_SIM_FN = Callable[["Op", "BaseSimState"], None]
+_SIM_FN2 = Callable[[], _SIM_FN]
+_SIM_FNS = {} # type: dict[GenericOpProperties | Any, _SIM_FN2]
+_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
+_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
+_GEN_ASMS = {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
+
+
+@unique
+@final
+class OpKind(Enum):
+ def __init__(self, properties):
+ # type: (GenericOpProperties) -> None
+ super().__init__()
+ self.__properties = properties
+
+ @property
+ def properties(self):
+ # type: () -> GenericOpProperties
+ return self.__properties
+
+ def instantiate(self, maxvl):
+ # type: (int) -> OpProperties
+ return OpProperties(self, maxvl=maxvl)
+
+ def __repr__(self):
+ # type: () -> str
+ return "OpKind." + self._name_
+
+ @cached_property
+ def sim(self):
+ # type: () -> _SIM_FN
+ return _SIM_FNS[self.properties]()
+
+ @cached_property
+ def gen_asm(self):
+ # type: () -> _GEN_ASM_FN
+ return _GEN_ASMS[self.properties]()
+
+ @staticmethod
+ def __clearca_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = False,
+
+ @staticmethod
+ def __clearca_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ state.writeln("addic 0, 0, 0")
+ ClearCA = GenericOpProperties(
+ demo_asm="addic 0, 0, 0",
+ inputs=[],
+ outputs=[OD_CA.with_write_stage(OpStage.Late)],
+ )
+ _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
+ _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm
+
+ @staticmethod
+ def __setca_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = True,
+
+ @staticmethod
+ def __setca_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ state.writeln("subfc 0, 0, 0")
+ SetCA = GenericOpProperties(
+ demo_asm="subfc 0, 0, 0",
+ inputs=[],
+ outputs=[OD_CA.with_write_stage(OpStage.Late)],
+ )
+ _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
+ _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm
+
+ @staticmethod
+ def __svadde_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA = state[op.input_vals[0]]
+ RB = state[op.input_vals[1]]
+ carry, = state[op.input_vals[2]]
+ VL, = state[op.input_vals[3]]
+ RT = [] # type: list[int]
+ for i in range(VL):
+ v = RA[i] + RB[i] + carry
+ RT.append(v & GPR_VALUE_MASK)
+ carry = (v >> GPR_SIZE_IN_BITS) != 0
+ state[op.outputs[0]] = tuple(RT)
+ state[op.outputs[1]] = carry,
+
+ @staticmethod
+ def __svadde_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RT = state.vgpr(op.outputs[0])
+ RA = state.vgpr(op.input_vals[0])
+ RB = state.vgpr(op.input_vals[1])
+ state.writeln(f"sv.adde {RT}, {RA}, {RB}")
+ SvAddE = GenericOpProperties(
+ demo_asm="sv.adde *RT, *RA, *RB",
+ inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
+ outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
+ )
+ _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
+ _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
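+
+    # Worked example (assuming 64-bit GPRs): with RA = [2**64 - 1, 5],
+    # RB = [1, 7], CA = 0 and VL = 2, element 0 computes 2**64, so RT[0] = 0
+    # and a carry of 1 flows into element 1, giving RT[1] = 13 and a final CA
+    # of 0. This is the scalar adde carry chain applied element by element.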
+
+ @staticmethod
+ def __addze_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA, = state[op.input_vals[0]]
+ carry, = state[op.input_vals[1]]
+ v = RA + carry
+ RT = v & GPR_VALUE_MASK
+ carry = (v >> GPR_SIZE_IN_BITS) != 0
+ state[op.outputs[0]] = RT,
+ state[op.outputs[1]] = carry,
+
+ @staticmethod
+ def __addze_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RT = state.vgpr(op.outputs[0])
+ RA = state.vgpr(op.input_vals[0])
+ state.writeln(f"addze {RT}, {RA}")
+ AddZE = GenericOpProperties(
+ demo_asm="addze RT, RA",
+ inputs=[OD_BASE_SGPR, OD_CA],
+ outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
+ )
+ _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
+ _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm
+
+ @staticmethod
+ def __svsubfe_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA = state[op.input_vals[0]]
+ RB = state[op.input_vals[1]]
+ carry, = state[op.input_vals[2]]
+ VL, = state[op.input_vals[3]]
+ RT = [] # type: list[int]
+ for i in range(VL):
+ v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
+ RT.append(v & GPR_VALUE_MASK)
+ carry = (v >> GPR_SIZE_IN_BITS) != 0
+ state[op.outputs[0]] = tuple(RT)
+ state[op.outputs[1]] = carry,
+
+ @staticmethod
+ def __svsubfe_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RT = state.vgpr(op.outputs[0])
+ RA = state.vgpr(op.input_vals[0])
+ RB = state.vgpr(op.input_vals[1])
+ state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
+ SvSubFE = GenericOpProperties(
+ demo_asm="sv.subfe *RT, *RA, *RB",
+ inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
+ outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
+ )
+ _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
+ _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
+
+ @staticmethod
+ def __svmaddedu_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA = state[op.input_vals[0]]
+ RB, = state[op.input_vals[1]]
+ carry, = state[op.input_vals[2]]
+ VL, = state[op.input_vals[3]]
+ RT = [] # type: list[int]
+ for i in range(VL):
+ v = RA[i] * RB + carry
+ RT.append(v & GPR_VALUE_MASK)
+ carry = v >> GPR_SIZE_IN_BITS
+ state[op.outputs[0]] = tuple(RT)
+ state[op.outputs[1]] = carry,
+
+ @staticmethod
+ def __svmaddedu_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RT = state.vgpr(op.outputs[0])
+ RA = state.vgpr(op.input_vals[0])
+ RB = state.sgpr(op.input_vals[1])
+ RC = state.sgpr(op.input_vals[2])
+ state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
+ SvMAddEDU = GenericOpProperties(
+ demo_asm="sv.maddedu *RT, *RA, RB, RC",
+ inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
+        outputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
+ )
+ _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
+ _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
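+
+    # Worked example (assuming 64-bit GPRs): with vector RA = [10, 20],
+    # scalar RB = 2**63, RC = 0 and VL = 2, element 0 computes
+    # 10 * 2**63 = 5 * 2**64, so RT[0] = 0 and the carry becomes 5; element 1
+    # computes 20 * 2**63 + 5, so RT[1] = 5 and the final carry is 10. Unlike
+    # the add/subtract ops above, the carry here is a full 64-bit limb rather
+    # than a single bit.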
+
+ @staticmethod
+ def __setvli_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = op.immediates[0],
+
+ @staticmethod
+ def __setvli_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ imm = op.immediates[0]
+ state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
+ SetVLI = GenericOpProperties(
+ demo_asm="setvl 0, 0, imm, 0, 1, 1",
+ inputs=(),
+ outputs=[OD_VL.with_write_stage(OpStage.Late)],
+ immediates=[range(1, 65)],
+ is_load_immediate=True,
+ )
+ _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
+ _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
+
+ @staticmethod
+ def __svli_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ VL, = state[op.input_vals[0]]
+ imm = op.immediates[0] & GPR_VALUE_MASK
+ state[op.outputs[0]] = (imm,) * VL
+
+ @staticmethod
+ def __svli_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RT = state.vgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"sv.addi {RT}, 0, {imm}")
+ SvLI = GenericOpProperties(
+ demo_asm="sv.addi *RT, 0, imm",
+ inputs=[OD_VL],
+ outputs=[OD_EXTRA3_VGPR],
+ immediates=[IMM_S16],
+ is_load_immediate=True,
+ )
+ _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
+ _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
+
+ @staticmethod
+ def __li_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ imm = op.immediates[0] & GPR_VALUE_MASK
+ state[op.outputs[0]] = imm,
+
+ @staticmethod
+ def __li_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RT = state.sgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"addi {RT}, 0, {imm}")
+ LI = GenericOpProperties(
+ demo_asm="addi RT, 0, imm",
+ inputs=(),
+ outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
+ immediates=[IMM_S16],
+ is_load_immediate=True,
+ )
+ _SIM_FNS[LI] = lambda: OpKind.__li_sim
+ _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
+
+ @staticmethod
+ def __veccopytoreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
+
+ @staticmethod
+ def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
+ # type: (Loc, Loc, bool, GenAsmState) -> None
+ sv = "sv." if is_vec else ""
+ rev = ""
+ if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
+ rev = "/mrr"
+ if src_loc == dest_loc:
+ return # no-op
+ if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+ raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
+ if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+ raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
+ if src_loc.kind is LocKind.StackI64:
+ if dest_loc.kind is LocKind.StackI64:
+ raise ValueError(
+ f"can't copy from stack to stack: {src_loc} {dest_loc}")
+ elif dest_loc.kind is not LocKind.GPR:
+ assert_never(dest_loc.kind)
+ src = state.stack(src_loc)
+ dest = state.gpr(dest_loc, is_vec=is_vec)
+ state.writeln(f"{sv}ld {dest}, {src}")
+ elif dest_loc.kind is LocKind.StackI64:
+ if src_loc.kind is not LocKind.GPR:
+ assert_never(src_loc.kind)
+ src = state.gpr(src_loc, is_vec=is_vec)
+ dest = state.stack(dest_loc)
+ state.writeln(f"{sv}std {src}, {dest}")
+ elif src_loc.kind is LocKind.GPR:
+ if dest_loc.kind is not LocKind.GPR:
+ assert_never(dest_loc.kind)
+ src = state.gpr(src_loc, is_vec=is_vec)
+ dest = state.gpr(dest_loc, is_vec=is_vec)
+ state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
+ else:
+ assert_never(src_loc.kind)
+
+ @staticmethod
+ def __veccopytoreg_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(
+ op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=True, state=state)
+
+ VecCopyToReg = GenericOpProperties(
+ demo_asm="sv.mv dest, src",
+ inputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
+ ), OD_VL],
+ outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
+ is_copy=True,
+ )
+ _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
+ _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
+
+ @staticmethod
+ def __veccopyfromreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
+
+ @staticmethod
+ def __veccopyfromreg_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(
+ op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+ is_vec=True, state=state)
+ VecCopyFromReg = GenericOpProperties(
+ demo_asm="sv.mv dest, src",
+ inputs=[OD_EXTRA3_VGPR, OD_VL],
+ outputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=True),
+ sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
+ write_stage=OpStage.Late,
+ )],
+ is_copy=True,
+ )
+ _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
+ _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
+
+ @staticmethod
+ def __copytoreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
+
+ @staticmethod
+ def __copytoreg_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(
+ op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=False, state=state)
+ CopyToReg = GenericOpProperties(
+ demo_asm="mv dest, src",
+ inputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
+ LocSubKind.StackI64],
+ )],
+ outputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
+ write_stage=OpStage.Late,
+ )],
+ is_copy=True,
+ )
+ _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
+ _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
+
+ @staticmethod
+ def __copyfromreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
+
+ @staticmethod
+ def __copyfromreg_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(
+ op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+ is_vec=False, state=state)
+ CopyFromReg = GenericOpProperties(
+ demo_asm="mv dest, src",
+ inputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
+ )],
+ outputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
+ LocSubKind.StackI64],
+ write_stage=OpStage.Late,
+ )],
+ is_copy=True,
+ )
+ _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
+ _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
+
+ @staticmethod
+ def __concat_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = tuple(
+ state[i][0] for i in op.input_vals[:-1])
+
+ @staticmethod
+ def __concat_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=True, state=state)
+ Concat = GenericOpProperties(
+ demo_asm="sv.mv dest, src",
+ inputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
+ spread=True,
+ ), OD_VL],
+ outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
+ is_copy=True,
+ )
+ _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
+ _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
+
+ @staticmethod
+ def __spread_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ for idx, inp in enumerate(state[op.input_vals[0]]):
+ state[op.outputs[idx]] = inp,
+
+ @staticmethod
+ def __spread_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(op.outputs, LocKind.GPR),
+ is_vec=True, state=state)
+ Spread = GenericOpProperties(
+ demo_asm="sv.mv dest, src",
+ inputs=[OD_EXTRA3_VGPR, OD_VL],
+ outputs=[GenericOperandDesc(
+ ty=GenericTy(BaseTy.I64, is_vec=False),
+ sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
+ spread=True,
+ write_stage=OpStage.Late,
+ )],
+ is_copy=True,
+ )
+ _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
+ _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
+
+ @staticmethod
+ def __svld_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA, = state[op.input_vals[0]]
+ VL, = state[op.input_vals[1]]
+ addr = RA + op.immediates[0]
+ RT = [] # type: list[int]
+ for i in range(VL):
+ v = state.load(addr + GPR_SIZE_IN_BYTES * i)
+ RT.append(v & GPR_VALUE_MASK)
+ state[op.outputs[0]] = tuple(RT)
+
+ @staticmethod
+ def __svld_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RA = state.sgpr(op.input_vals[0])
+ RT = state.vgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"sv.ld {RT}, {imm}({RA})")
+ SvLd = GenericOpProperties(
+ demo_asm="sv.ld *RT, imm(RA)",
+ inputs=[OD_EXTRA3_SGPR, OD_VL],
+ outputs=[OD_EXTRA3_VGPR],
+ immediates=[IMM_S16],
+ )
+ _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
+ _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
+
+ @staticmethod
+ def __ld_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA, = state[op.input_vals[0]]
+ addr = RA + op.immediates[0]
+ v = state.load(addr)
+ state[op.outputs[0]] = v & GPR_VALUE_MASK,
+
+ @staticmethod
+ def __ld_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RA = state.sgpr(op.input_vals[0])
+ RT = state.sgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"ld {RT}, {imm}({RA})")
+ Ld = GenericOpProperties(
+ demo_asm="ld RT, imm(RA)",
+ inputs=[OD_BASE_SGPR],
+ outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
+ immediates=[IMM_S16],
+ )
+ _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
+ _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
+
+ @staticmethod
+ def __svstd_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RS = state[op.input_vals[0]]
+ RA, = state[op.input_vals[1]]
+ VL, = state[op.input_vals[2]]
+ addr = RA + op.immediates[0]
+ for i in range(VL):
+ state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
+
+ @staticmethod
+ def __svstd_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RS = state.vgpr(op.input_vals[0])
+ RA = state.sgpr(op.input_vals[1])
+ imm = op.immediates[0]
+ state.writeln(f"sv.std {RS}, {imm}({RA})")
+ SvStd = GenericOpProperties(
+ demo_asm="sv.std *RS, imm(RA)",
+ inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
+ outputs=[],
+ immediates=[IMM_S16],
+ has_side_effects=True,
+ )
+ _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
+ _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
+
+ @staticmethod
+ def __std_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RS, = state[op.input_vals[0]]
+ RA, = state[op.input_vals[1]]
+ addr = RA + op.immediates[0]
+ state.store(addr, value=RS)
+
+ @staticmethod
+ def __std_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ RS = state.sgpr(op.input_vals[0])
+ RA = state.sgpr(op.input_vals[1])
+ imm = op.immediates[0]
+ state.writeln(f"std {RS}, {imm}({RA})")
+ Std = GenericOpProperties(
+ demo_asm="std RS, imm(RA)",
+ inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
+ outputs=[],
+ immediates=[IMM_S16],
+ has_side_effects=True,
+ )
+ _SIM_FNS[Std] = lambda: OpKind.__std_sim
+ _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
+
+ @staticmethod
+ def __funcargr3_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ pass # return value set before simulation
+
+ @staticmethod
+ def __funcargr3_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ pass # no instructions needed
+ FuncArgR3 = GenericOpProperties(
+ demo_asm="",
+ inputs=[],
+ outputs=[OD_BASE_SGPR.with_fixed_loc(
+ Loc(kind=LocKind.GPR, start=3, reg_len=1))],
+ )
+ _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
+ _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+class SSAValOrUse(metaclass=InternedMeta):
+ __slots__ = "op", "operand_idx"
+
+ def __init__(self, op, operand_idx):
+ # type: (Op, int) -> None
+ super().__init__()
+ self.op = op
+ if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
+ raise ValueError("invalid operand_idx")
+ self.operand_idx = operand_idx
+
+ @abstractmethod
+ def __repr__(self):
+ # type: () -> str
+ ...
+
+ @property
+ @abstractmethod
+ def descriptor_array(self):
+ # type: () -> tuple[OperandDesc, ...]
+ ...
+
+ @cached_property
+ def defining_descriptor(self):
+ # type: () -> OperandDesc
+ return self.descriptor_array[self.operand_idx]
+
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ return self.defining_descriptor.ty
+
+ @cached_property
+ def ty_before_spread(self):
+ # type: () -> Ty
+ return self.defining_descriptor.ty_before_spread
+
+ @property
+ def base_ty(self):
+ # type: () -> BaseTy
+ return self.ty_before_spread.base_ty
+
+ @property
+ def reg_offset_in_unspread(self):
+ """ the number of reg-sized slots in the unspread Loc before self's Loc
+
+ e.g. if the unspread Loc containing self is:
+ `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
+ and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
+ then reg_offset_in_unspread == 2 == 10 - 8
+ """
+ return self.defining_descriptor.reg_offset_in_unspread
+
+ @property
+ def unspread_start_idx(self):
+ # type: () -> int
+ return self.operand_idx - (self.defining_descriptor.spread_index or 0)
+
+ @property
+ def unspread_start(self):
+ # type: () -> Self
+ return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class SSAVal(SSAValOrUse):
+ __slots__ = ()
+
+ def __repr__(self):
+ # type: () -> str
+ return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
+
+ @cached_property
+ def def_loc_set_before_spread(self):
+ # type: () -> LocSet
+ return self.defining_descriptor.loc_set_before_spread
+
+ @cached_property
+ def descriptor_array(self):
+ # type: () -> tuple[OperandDesc, ...]
+ return self.op.properties.outputs
+
+ @cached_property
+ def tied_input(self):
+ # type: () -> None | SSAUse
+ if self.defining_descriptor.tied_input_index is None:
+ return None
+ return SSAUse(op=self.op,
+ operand_idx=self.defining_descriptor.tied_input_index)
+
+ @property
+ def write_stage(self):
+ # type: () -> OpStage
+ return self.defining_descriptor.write_stage
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class SSAUse(SSAValOrUse):
+ __slots__ = ()
+
+ @cached_property
+ def use_loc_set_before_spread(self):
+ # type: () -> LocSet
+ return self.defining_descriptor.loc_set_before_spread
+
+ @cached_property
+ def descriptor_array(self):
+ # type: () -> tuple[OperandDesc, ...]
+ return self.op.properties.inputs
+
+ def __repr__(self):
+ # type: () -> str
+ return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
+
+ @property
+ def ssa_val(self):
+ # type: () -> SSAVal
+ return self.op.input_vals[self.operand_idx]
+
+ @ssa_val.setter
+ def ssa_val(self, ssa_val):
+ # type: (SSAVal) -> None
+ self.op.input_vals[self.operand_idx] = ssa_val
+
+
+_T = TypeVar("_T")
+_Desc = TypeVar("_Desc")
+
+
+class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
+ @abstractmethod
+ def _verify_write_with_desc(self, idx, item, desc):
+ # type: (int, _T | Any, _Desc) -> None
+ raise NotImplementedError
+
+ @final
+ def _verify_write(self, idx, item):
+ # type: (int | Any, _T | Any) -> int
+ if not isinstance(idx, int):
+ if isinstance(idx, slice):
+ raise TypeError(
+ f"can't write to slice of {self.__class__.__name__}")
+ raise TypeError(f"can't write with index {idx!r}")
+ # normalize idx, raising IndexError if it is out of range
+ idx = range(len(self.descriptors))[idx]
+ desc = self.descriptors[idx]
+ self._verify_write_with_desc(idx, item, desc)
+ return idx
+
+ def _on_set(self, idx, new_item, old_item):
+ # type: (int, _T, _T | None) -> None
+ pass
+
+ @abstractmethod
+ def _get_descriptors(self):
+ # type: () -> tuple[_Desc, ...]
+ raise NotImplementedError
+
+ @cached_property
+ @final
+ def descriptors(self):
+ # type: () -> tuple[_Desc, ...]
+ return self._get_descriptors()
+
+ @property
+ @final
+ def op(self):
+ return self.__op
+
+ def __init__(self, items, op):
+ # type: (Iterable[_T], Op) -> None
+ super().__init__()
+ self.__op = op
+ self.__items = [] # type: list[_T]
+ for idx, item in enumerate(items):
+ if idx >= len(self.descriptors):
+ raise ValueError("too many items")
+ _ = self._verify_write(idx, item)
+ self.__items.append(item)
+ if len(self.__items) < len(self.descriptors):
+ raise ValueError("not enough items")
+
+ @final
+ def __iter__(self):
+ # type: () -> Iterator[_T]
+ yield from self.__items
+
+ @overload
+ def __getitem__(self, idx):
+ # type: (int) -> _T
+ ...
+
+ @overload
+ def __getitem__(self, idx):
+ # type: (slice) -> list[_T]
+ ...
+
+ @final
+ def __getitem__(self, idx):
+ # type: (int | slice) -> _T | list[_T]
+ return self.__items[idx]
+
+ @final
+ def __setitem__(self, idx, item):
+ # type: (int, _T) -> None
+ idx = self._verify_write(idx, item)
+ self.__items[idx] = item
+
+ @final
+ def __len__(self):
+ # type: () -> int
+ return len(self.__items)
+
+ def __repr__(self):
+ # type: () -> str
+ return f"{self.__class__.__name__}({self.__items}, op=...)"
+
+
+@final
+class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
+ def _get_descriptors(self):
+ # type: () -> tuple[OperandDesc, ...]
+ return self.op.properties.inputs
+
+ def _verify_write_with_desc(self, idx, item, desc):
+ # type: (int, SSAVal | Any, OperandDesc) -> None
+ if not isinstance(item, SSAVal):
+ raise TypeError("expected value of type SSAVal")
+ if item.ty != desc.ty:
+ raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
+ f"corresponding input's type {desc.ty!r}")
+
+ def _on_set(self, idx, new_item, old_item):
+ # type: (int, SSAVal, SSAVal | None) -> None
+ SSAUses._on_op_input_set(self, idx, new_item, old_item) # type: ignore
+
+ def __init__(self, items, op):
+ # type: (Iterable[SSAVal], Op) -> None
+ if hasattr(op, "input_vals"):
+ raise ValueError("Op.input_vals already set")
+ super().__init__(items, op)
+
+
+@final
+class OpImmediates(OpInputSeq[int, range]):
+ def _get_descriptors(self):
+ # type: () -> tuple[range, ...]
+ return self.op.properties.immediates
+
+ def _verify_write_with_desc(self, idx, item, desc):
+ # type: (int, int | Any, range) -> None
+ if not isinstance(item, int):
+ raise TypeError("expected value of type int")
+ if item not in desc:
+ raise ValueError(f"immediate value {item!r} not in {desc!r}")
+
+ def __init__(self, items, op):
+ # type: (Iterable[int], Op) -> None
+ if hasattr(op, "immediates"):
+ raise ValueError("Op.immediates already set")
+ super().__init__(items, op)
+
+
+@plain_data(frozen=True, eq=False, repr=False)
+@final
+class Op:
+ __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
+ "outputs", "name")
+
+ def __init__(self, fn, properties, input_vals, immediates, name=""):
+ # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
+ self.fn = fn
+ self.properties = properties
+ self.input_vals = OpInputVals(input_vals, op=self)
+ inputs_len = len(self.properties.inputs)
+ self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
+ self.immediates = OpImmediates(immediates, op=self)
+ outputs_len = len(self.properties.outputs)
+ self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
+ self.name = fn._add_op_with_unused_name(self, name) # type: ignore
+
+ @property
+ def kind(self):
+ # type: () -> OpKind
+ return self.properties.kind
+
+ def __eq__(self, other):
+ # type: (Op | Any) -> bool
+ if isinstance(other, Op):
+ return self is other
+ return NotImplemented
+
+ def __hash__(self):
+ # type: () -> int
+ return object.__hash__(self)
+
+ def __repr__(self):
+ # type: () -> str
+ field_vals = [] # type: list[str]
+ for name in fields(self):
+ if name == "properties":
+ name = "kind"
+ elif name == "fn":
+ continue
+ try:
+ value = getattr(self, name)
+ except AttributeError:
+ field_vals.append(f"{name}=<not set>")
+ continue
+ if isinstance(value, OpInputSeq):
+ value = list(value) # type: ignore
+ field_vals.append(f"{name}={value!r}")
+ field_vals_str = ", ".join(field_vals)
+ return f"Op({field_vals_str})"
+
+ def sim(self, state):
+ # type: (BaseSimState) -> None
+ for inp in self.input_vals:
+ try:
+ val = state[inp]
+ except KeyError:
+ raise ValueError(f"SSAVal {inp} not yet assigned when "
+ f"running {self}")
+ if len(val) != inp.ty.reg_len:
+ raise ValueError(
+ f"value of SSAVal {inp} has wrong number of elements: "
+ f"expected {inp.ty.reg_len} found "
+ f"{len(val)}: {val!r}")
+ if isinstance(state, PreRASimState):
+ for out in self.outputs:
+ if out in state.ssa_vals:
+ if self.kind is OpKind.FuncArgR3:
+ continue
+ raise ValueError(f"SSAVal {out} already assigned before "
+ f"running {self}")
+ self.kind.sim(self, state)
+ for out in self.outputs:
+ try:
+ val = state[out]
+ except KeyError:
+ raise ValueError(f"running {self} failed to assign to {out}")
+ if len(val) != out.ty.reg_len:
+ raise ValueError(
+ f"value of SSAVal {out} has wrong number of elements: "
+ f"expected {out.ty.reg_len} found "
+ f"{len(val)}: {val!r}")
+
+ def gen_asm(self, state):
+ # type: (GenAsmState) -> None
+ all_loc_kinds = tuple(LocKind)
+ for inp in self.input_vals:
+ state.loc(inp, expected_kinds=all_loc_kinds)
+ for out in self.outputs:
+ state.loc(out, expected_kinds=all_loc_kinds)
+ self.kind.gen_asm(self, state)
+
+
+GPR_SIZE_IN_BYTES = 8
+BITS_IN_BYTE = 8
+GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE
+GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
+
+
+@plain_data(frozen=True, repr=False)
+class BaseSimState(metaclass=ABCMeta):
+ __slots__ = "memory",
+
+ def __init__(self, memory):
+ # type: (dict[int, int]) -> None
+ super().__init__()
+ self.memory = memory # type: dict[int, int]
+
+ def load_byte(self, addr):
+ # type: (int) -> int
+ addr &= GPR_VALUE_MASK
+ return self.memory.get(addr, 0) & 0xFF
+
+ def store_byte(self, addr, value):
+ # type: (int, int) -> None
+ addr &= GPR_VALUE_MASK
+ value &= 0xFF
+ self.memory[addr] = value
+
+ def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
+ # type: (int, int, bool) -> int
+ if addr % size_in_bytes != 0:
+ raise ValueError(f"address not aligned: {hex(addr)} "
+ f"required alignment: {size_in_bytes}")
+ retval = 0
+ for i in range(size_in_bytes):
+ retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
+ if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
+ retval -= 1 << size_in_bytes * BITS_IN_BYTE
+ return retval
+
+ def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
+ # type: (int, int, int) -> None
+ if addr % size_in_bytes != 0:
+ raise ValueError(f"address not aligned: {hex(addr)} "
+ f"required alignment: {size_in_bytes}")
+ for i in range(size_in_bytes):
+ self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)
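+
+ # For example (illustrative only, arbitrary values): memory is byte-wise and
+ # little-endian, so with a PreRASimState (defined below) built from empty
+ # dicts:
+ #   state = PreRASimState(ssa_vals={}, memory={})
+ #   state.store(0x100, 0x0123456789abcdef)
+ #   state.load_byte(0x100) == 0xef               # least-significant byte first
+ #   state.load(0x100) == 0x0123456789abcdef
+ #   state.load(0x104, size_in_bytes=4) == 0x01234567
+ #   state.load(0x103, size_in_bytes=4)           # raises ValueError: unaligned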
+
+ def _memory__repr(self):
+ # type: () -> str
+ if len(self.memory) == 0:
+ return "{}"
+ keys = sorted(self.memory.keys(), reverse=True)
+ CHUNK_SIZE = GPR_SIZE_IN_BYTES
+ items = [] # type: list[str]
+ while len(keys) != 0:
+ addr = keys[-1]
+ if (len(keys) >= CHUNK_SIZE
+ and addr % CHUNK_SIZE == 0
+ and keys[-CHUNK_SIZE:]
+ == list(reversed(range(addr, addr + CHUNK_SIZE)))):
+ value = self.load(addr, size_in_bytes=CHUNK_SIZE)
+ items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
+ keys[-CHUNK_SIZE:] = ()
+ else:
+ items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
+ if len(items) == 1:
+ return f"{{{items[0]}}}"
+ items_str = ",\n".join(items)
+ return f"{{\n{items_str}}}"
+
+ def __repr__(self):
+ # type: () -> str
+ field_vals = [] # type: list[str]
+ for name in fields(self):
+ try:
+ value = getattr(self, name)
+ except AttributeError:
+ field_vals.append(f"{name}=<not set>")
+ continue
+ repr_fn = getattr(self, f"_{name}__repr", None)
+ if callable(repr_fn):
+ field_vals.append(f"{name}={repr_fn()}")
+ else:
+ field_vals.append(f"{name}={value!r}")
+ field_vals_str = ", ".join(field_vals)
+ return f"{self.__class__.__name__}({field_vals_str})"
+
+ @abstractmethod
+ def __getitem__(self, ssa_val):
+ # type: (SSAVal) -> tuple[int, ...]
+ ...
+
+ @abstractmethod
+ def __setitem__(self, ssa_val, value):
+ # type: (SSAVal, tuple[int, ...]) -> None
+ ...
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class PreRASimState(BaseSimState):
+ __slots__ = "ssa_vals",
+
+ def __init__(self, ssa_vals, memory):
+ # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
+ super().__init__(memory)
+ self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]]
+
+ def _ssa_vals__repr(self):
+ # type: () -> str
+ if len(self.ssa_vals) == 0:
+ return "{}"
+ items = [] # type: list[str]
+ CHUNK_SIZE = 4
+ for k, v in self.ssa_vals.items():
+ element_strs = [] # type: list[str]
+ for i, el in enumerate(v):
+ if i % CHUNK_SIZE != 0:
+ element_strs.append(" " + hex(el))
+ else:
+ element_strs.append("\n " + hex(el))
+ if len(element_strs) <= CHUNK_SIZE:
+ element_strs[0] = element_strs[0].lstrip()
+ if len(element_strs) == 1:
+ element_strs.append("")
+ v_str = ",".join(element_strs)
+ items.append(f"{k!r}: ({v_str})")
+ if len(items) == 1 and "\n" not in items[0]:
+ return f"{{{items[0]}}}"
+ items_str = ",\n".join(items)
+ return f"{{\n{items_str},\n}}"
+
+ def __getitem__(self, ssa_val):
+ # type: (SSAVal) -> tuple[int, ...]
+ return self.ssa_vals[ssa_val]
+
+ def __setitem__(self, ssa_val, value):
+ # type: (SSAVal, tuple[int, ...]) -> None
+ if len(value) != ssa_val.ty.reg_len:
+ raise ValueError("value has wrong len")
+ self.ssa_vals[ssa_val] = value
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class PostRASimState(BaseSimState):
+ __slots__ = "ssa_val_to_loc_map", "loc_values"
+
+ def __init__(self, ssa_val_to_loc_map, memory, loc_values):
+ # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
+ super().__init__(memory)
+ self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
+ for ssa_val, loc in self.ssa_val_to_loc_map.items():
+ if ssa_val.ty != loc.ty:
+ raise ValueError(
+ f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
+ self.loc_values = loc_values
+ for loc in self.loc_values.keys():
+ if loc.reg_len != 1:
+ raise ValueError(
+ "loc_values must only contain Locs with reg_len=1, all "
+ "larger Locs will be split into reg_len=1 sub-Locs")
+
+ def _loc_values__repr(self):
+ # type: () -> str
+ locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind.name, v.start))
+ items = [] # type: list[str]
+ for loc in locs:
+ items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
+ items_str = ",\n".join(items)
+ return f"{{\n{items_str},\n}}"
+
+ def __getitem__(self, ssa_val):
+ # type: (SSAVal) -> tuple[int, ...]
+ loc = self.ssa_val_to_loc_map[ssa_val]
+ subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
+ retval = [] # type: list[int]
+ for i in range(loc.reg_len):
+ subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
+ retval.append(self.loc_values.get(subloc, 0))
+ return tuple(retval)
+
+ def __setitem__(self, ssa_val, value):
+ # type: (SSAVal, tuple[int, ...]) -> None
+ if len(value) != ssa_val.ty.reg_len:
+ raise ValueError("value has wrong len")
+ loc = self.ssa_val_to_loc_map[ssa_val]
+ subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
+ for i in range(loc.reg_len):
+ subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
+ self.loc_values[subloc] = value[i]
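+
+ # For example (illustrative only): if ssa_val_to_loc_map maps a 2-register
+ # SSAVal `v` to Loc(kind=LocKind.GPR, start=10, reg_len=2), then
+ #   state[v] = (0x123, 0x456)
+ # stores loc_values[Loc(kind=LocKind.GPR, start=10, reg_len=1)] = 0x123 and
+ # loc_values[Loc(kind=LocKind.GPR, start=11, reg_len=1)] = 0x456, and
+ # state[v] reads them back as the tuple (0x123, 0x456).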
+
+
+@plain_data(frozen=True)
+class GenAsmState:
+ __slots__ = "allocated_locs", "output"
+
+ def __init__(self, allocated_locs, output=None):
+ # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
+ super().__init__()
+ self.allocated_locs = FMap(allocated_locs)
+ for ssa_val, loc in self.allocated_locs.items():
+ if ssa_val.ty != loc.ty:
+ raise ValueError(
+ f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
+ if output is None:
+ output = []
+ self.output = output
+
+ __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
+
+ def loc(self, ssa_val_or_locs, expected_kinds):
+ # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
+ if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
+ ssa_val_or_locs = [ssa_val_or_locs]
+ locs = [] # type: list[Loc]
+ for i in ssa_val_or_locs:
+ if isinstance(i, SSAVal):
+ locs.append(self.allocated_locs[i])
+ else:
+ locs.append(i)
+ if len(locs) == 0:
+ raise ValueError("invalid Loc sequence: must not be empty")
+ retval = locs[0].try_concat(*locs[1:])
+ if retval is None:
+ raise ValueError("invalid Loc sequence: try_concat failed")
+ if isinstance(expected_kinds, LocKind):
+ expected_kinds = expected_kinds,
+ if retval.kind not in expected_kinds:
+ if len(expected_kinds) == 1:
+ expected_kinds = expected_kinds[0]
+ raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
+ f"{retval.kind} expected {expected_kinds}")
+ return retval
+
+ def gpr(self, ssa_val_or_locs, is_vec):
+ # type: (__SSA_VAL_OR_LOCS, bool) -> str
+ loc = self.loc(ssa_val_or_locs, LocKind.GPR)
+ vec_str = "*" if is_vec else ""
+ return vec_str + str(loc.start)
+
+ def sgpr(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ return self.gpr(ssa_val_or_locs, is_vec=False)
+
+ def vgpr(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ return self.gpr(ssa_val_or_locs, is_vec=True)
+
+ def stack(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
+ return f"{loc.start}(1)"
+
+ def writeln(self, *line_segments):
+ # type: (*str) -> None
+ line = " ".join(line_segments)
+ if isinstance(self.output, list):
+ self.output.append(line)
+ else:
+ self.output.write(line + "\n")
+++ /dev/null
-import enum
-from abc import ABCMeta, abstractmethod
-from enum import Enum, unique
-from functools import lru_cache, total_ordering
-from io import StringIO
-from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
- Mapping, Sequence, TypeVar, Union, overload)
-from weakref import WeakValueDictionary as _WeakVDict
-
-from cached_property import cached_property
-from nmutil.plain_data import fields, plain_data
-
-from bigint_presentation_code.type_util import (Literal, Self, assert_never,
- final)
-from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
- OFSet, OSet)
-
-
-@final
-class Fn:
- def __init__(self):
- self.ops = [] # type: list[Op]
- self.__op_names = _WeakVDict() # type: _WeakVDict[str, Op]
- self.__next_name_suffix = 2
-
- def _add_op_with_unused_name(self, op, name=""):
- # type: (Op, str) -> str
- if op.fn is not self:
- raise ValueError("can't add Op to wrong Fn")
- if hasattr(op, "name"):
- raise ValueError("Op already named")
- orig_name = name
- while True:
- if name != "" and name not in self.__op_names:
- self.__op_names[name] = op
- return name
- name = orig_name + str(self.__next_name_suffix)
- self.__next_name_suffix += 1
-
- def __repr__(self):
- # type: () -> str
- return "<Fn>"
-
- def append_op(self, op):
- # type: (Op) -> None
- if op.fn is not self:
- raise ValueError("can't add Op to wrong Fn")
- self.ops.append(op)
-
- def append_new_op(self, kind, input_vals=(), immediates=(), name="",
- maxvl=1):
- # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
- retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
- input_vals=input_vals, immediates=immediates, name=name)
- self.append_op(retval)
- return retval
-
- def sim(self, state):
- # type: (BaseSimState) -> None
- for op in self.ops:
- op.sim(state)
-
- def gen_asm(self, state):
- # type: (GenAsmState) -> None
- for op in self.ops:
- op.gen_asm(state)
-
- def pre_ra_insert_copies(self):
- # type: () -> None
- orig_ops = list(self.ops)
- copied_outputs = {} # type: dict[SSAVal, SSAVal]
- setvli_outputs = {} # type: dict[SSAVal, Op]
- self.ops.clear()
- for op in orig_ops:
- for i in range(len(op.input_vals)):
- inp = copied_outputs[op.input_vals[i]]
- if inp.ty.base_ty is BaseTy.I64:
- maxvl = inp.ty.reg_len
- if inp.ty.reg_len != 1:
- setvl = self.append_new_op(
- OpKind.SetVLI, immediates=[maxvl],
- name=f"{op.name}.inp{i}.setvl")
- vl = setvl.outputs[0]
- mv = self.append_new_op(
- OpKind.VecCopyToReg, input_vals=[inp, vl],
- maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
- else:
- mv = self.append_new_op(
- OpKind.CopyToReg, input_vals=[inp],
- name=f"{op.name}.inp{i}.copy")
- op.input_vals[i] = mv.outputs[0]
- elif inp.ty.base_ty is BaseTy.CA \
- or inp.ty.base_ty is BaseTy.VL_MAXVL:
- # all copies would be no-ops, so we don't need to copy,
- # though we do need to rematerialize SetVLI ops right
- # before the ops that use VL
- if inp in setvli_outputs:
- setvl = self.append_new_op(
- OpKind.SetVLI,
- immediates=setvli_outputs[inp].immediates,
- name=f"{op.name}.inp{i}.setvl")
- inp = setvl.outputs[0]
- op.input_vals[i] = inp
- else:
- assert_never(inp.ty.base_ty)
- self.ops.append(op)
- for i, out in enumerate(op.outputs):
- if op.kind is OpKind.SetVLI:
- setvli_outputs[out] = op
- if out.ty.base_ty is BaseTy.I64:
- maxvl = out.ty.reg_len
- if out.ty.reg_len != 1:
- setvl = self.append_new_op(
- OpKind.SetVLI, immediates=[maxvl],
- name=f"{op.name}.out{i}.setvl")
- vl = setvl.outputs[0]
- mv = self.append_new_op(
- OpKind.VecCopyFromReg, input_vals=[out, vl],
- maxvl=maxvl, name=f"{op.name}.out{i}.copy")
- else:
- mv = self.append_new_op(
- OpKind.CopyFromReg, input_vals=[out],
- name=f"{op.name}.out{i}.copy")
- copied_outputs[out] = mv.outputs[0]
- elif out.ty.base_ty is BaseTy.CA \
- or out.ty.base_ty is BaseTy.VL_MAXVL:
- # all copies would be no-ops, so we don't need to copy
- copied_outputs[out] = out
- else:
- assert_never(out.ty.base_ty)
-
-
-@final
-@unique
-@total_ordering
-class OpStage(Enum):
- value: Literal[0, 1] # type: ignore
-
- def __new__(cls, value):
- # type: (int) -> OpStage
- value = int(value)
- if value not in (0, 1):
- raise ValueError("invalid value")
- retval = object.__new__(cls)
- retval._value_ = value
- return retval
-
- Early = 0
- """ early stage of Op execution, where all input reads occur.
- all output writes with `write_stage == Early` occur here too, and therefore
- conflict with input reads, telling the compiler that it can't share
- that output's register with any inputs that the output isn't tied to.
-
- All outputs, even unused outputs, can't share registers with any other
- outputs, independent of `write_stage` settings.
- """
- Late = 1
- """ late stage of Op execution, where all output writes with
- `write_stage == Late` occur, and therefore don't conflict with input reads,
- telling the compiler that any inputs can safely use the same register as
- those outputs.
-
- All outputs, even unused outputs, can't share registers with any other
- outputs, independent of `write_stage` settings.
- """
-
- def __repr__(self):
- # type: () -> str
- return f"OpStage.{self._name_}"
-
- def __lt__(self, other):
- # type: (OpStage | object) -> bool
- if isinstance(other, OpStage):
- return self.value < other.value
- return NotImplemented
-
-
-assert OpStage.Early < OpStage.Late, "early must be less than late"
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-@final
-@total_ordering
-class ProgramPoint(metaclass=InternedMeta):
- __slots__ = "op_index", "stage"
-
- def __init__(self, op_index, stage):
- # type: (int, OpStage) -> None
- self.op_index = op_index
- self.stage = stage
-
- @property
- def int_value(self):
- # type: () -> int
- """ an integer representation of `self` such that it keeps ordering and
- successor/predecessor relations.
- """
- return self.op_index * 2 + self.stage.value
-
- @staticmethod
- def from_int_value(int_value):
- # type: (int) -> ProgramPoint
- op_index, stage = divmod(int_value, 2)
- return ProgramPoint(op_index=op_index, stage=OpStage(stage))
-
- def next(self, steps=1):
- # type: (int) -> ProgramPoint
- return ProgramPoint.from_int_value(self.int_value + steps)
-
- def prev(self, steps=1):
- # type: (int) -> ProgramPoint
- return self.next(steps=-steps)
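-
- # For example (following directly from int_value above):
- #   ProgramPoint(op_index=3, stage=OpStage.Early).int_value == 6
- #   ProgramPoint(op_index=3, stage=OpStage.Late).int_value == 7
- #   ProgramPoint(op_index=3, stage=OpStage.Late).next() == \
- #       ProgramPoint(op_index=4, stage=OpStage.Early)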
-
- def __lt__(self, other):
- # type: (ProgramPoint | Any) -> bool
- if not isinstance(other, ProgramPoint):
- return NotImplemented
- if self.op_index != other.op_index:
- return self.op_index < other.op_index
- return self.stage < other.stage
-
- def __repr__(self):
- # type: () -> str
- return f"<ops[{self.op_index}]:{self.stage._name_}>"
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-@final
-class ProgramRange(Sequence[ProgramPoint], metaclass=InternedMeta):
- __slots__ = "start", "stop"
-
- def __init__(self, start, stop):
- # type: (ProgramPoint, ProgramPoint) -> None
- self.start = start
- self.stop = stop
-
- @cached_property
- def int_value_range(self):
- # type: () -> range
- return range(self.start.int_value, self.stop.int_value)
-
- @staticmethod
- def from_int_value_range(int_value_range):
- # type: (range) -> ProgramRange
- if int_value_range.step != 1:
- raise ValueError("int_value_range must have step == 1")
- return ProgramRange(
- start=ProgramPoint.from_int_value(int_value_range.start),
- stop=ProgramPoint.from_int_value(int_value_range.stop))
-
- @overload
- def __getitem__(self, __idx):
- # type: (int) -> ProgramPoint
- ...
-
- @overload
- def __getitem__(self, __idx):
- # type: (slice) -> ProgramRange
- ...
-
- def __getitem__(self, __idx):
- # type: (int | slice) -> ProgramPoint | ProgramRange
- v = range(self.start.int_value, self.stop.int_value)[__idx]
- if isinstance(v, int):
- return ProgramPoint.from_int_value(v)
- return ProgramRange.from_int_value_range(v)
-
- def __len__(self):
- # type: () -> int
- return len(self.int_value_range)
-
- def __iter__(self):
- # type: () -> Iterator[ProgramPoint]
- return map(ProgramPoint.from_int_value, self.int_value_range)
-
- def __repr__(self):
- # type: () -> str
- start = repr(self.start).lstrip("<").rstrip(">")
- stop = repr(self.stop).lstrip("<").rstrip(">")
- return f"<range:{start}..{stop}>"
-
-
-@plain_data(frozen=True, eq=False, repr=False)
-@final
-class FnAnalysis:
- __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
- "def_program_ranges", "use_program_points",
- "all_program_points")
-
- def __init__(self, fn):
- # type: (Fn) -> None
- self.fn = fn
- self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
- self.all_program_points = ProgramRange(
- start=ProgramPoint(op_index=0, stage=OpStage.Early),
- stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
- def_program_ranges = {} # type: dict[SSAVal, ProgramRange]
- use_program_points = {} # type: dict[SSAUse, ProgramPoint]
- uses = {} # type: dict[SSAVal, OSet[SSAUse]]
- live_range_stops = {} # type: dict[SSAVal, ProgramPoint]
- for op in fn.ops:
- for use in op.input_uses:
- uses[use.ssa_val].add(use)
- use_program_point = self.__get_use_program_point(use)
- use_program_points[use] = use_program_point
- live_range_stops[use.ssa_val] = max(
- live_range_stops[use.ssa_val], use_program_point.next())
- for out in op.outputs:
- uses[out] = OSet()
- def_program_range = self.__get_def_program_range(out)
- def_program_ranges[out] = def_program_range
- live_range_stops[out] = def_program_range.stop
- self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
- self.def_program_ranges = FMap(def_program_ranges)
- self.use_program_points = FMap(use_program_points)
- live_ranges = {} # type: dict[SSAVal, ProgramRange]
- live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
- for ssa_val in uses.keys():
- live_ranges[ssa_val] = live_range = ProgramRange(
- start=self.def_program_ranges[ssa_val].start,
- stop=live_range_stops[ssa_val])
- for program_point in live_range:
- live_at[program_point].add(ssa_val)
- self.live_ranges = FMap(live_ranges)
- self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
-
- def __get_def_program_range(self, ssa_val):
- # type: (SSAVal) -> ProgramRange
- write_stage = ssa_val.defining_descriptor.write_stage
- start = ProgramPoint(
- op_index=self.op_indexes[ssa_val.op], stage=write_stage)
- # always include late stage of ssa_val.op, to ensure outputs always
- # overlap all other outputs.
- # stop is exclusive, so we need the next program point.
- stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
- return ProgramRange(start=start, stop=stop)
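-
- # For example (illustrative only), an output of fn.ops[2] whose write_stage
- # is Early gets the def range <ops[2]:Early>..<ops[3]:Early>, i.e. it covers
- # both stages of op 2; a Late-written output gets
- # <ops[2]:Late>..<ops[3]:Early>.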
-
- def __get_use_program_point(self, ssa_use):
- # type: (SSAUse) -> ProgramPoint
- assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
- "assumed here, ensured by GenericOpProperties.__init__"
- return ProgramPoint(
- op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)
-
- def __eq__(self, other):
- # type: (FnAnalysis | Any) -> bool
- if isinstance(other, FnAnalysis):
- return self.fn == other.fn
- return NotImplemented
-
- def __hash__(self):
- # type: () -> int
- return hash(self.fn)
-
- def __repr__(self):
- # type: () -> str
- return "<FnAnalysis>"
-
-
-@unique
-@final
-class BaseTy(Enum):
- I64 = enum.auto()
- CA = enum.auto()
- VL_MAXVL = enum.auto()
-
- @cached_property
- def only_scalar(self):
- # type: () -> bool
- if self is BaseTy.I64:
- return False
- elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
- return True
- else:
- assert_never(self)
-
- @cached_property
- def max_reg_len(self):
- # type: () -> int
- if self is BaseTy.I64:
- return 128
- elif self is BaseTy.CA or self is BaseTy.VL_MAXVL:
- return 1
- else:
- assert_never(self)
-
- def __repr__(self):
- return "BaseTy." + self._name_
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-@final
-class Ty(metaclass=InternedMeta):
- __slots__ = "base_ty", "reg_len"
-
- @staticmethod
- def validate(base_ty, reg_len):
- # type: (BaseTy, int) -> str | None
- """ return a string with the error if the combination is invalid,
- otherwise return None
- """
- if base_ty.only_scalar and reg_len != 1:
- return f"can't create a vector of an only-scalar type: {base_ty}"
- if reg_len < 1 or reg_len > base_ty.max_reg_len:
- return "reg_len out of range"
- return None
-
- def __init__(self, base_ty, reg_len):
- # type: (BaseTy, int) -> None
- msg = self.validate(base_ty=base_ty, reg_len=reg_len)
- if msg is not None:
- raise ValueError(msg)
- self.base_ty = base_ty
- self.reg_len = reg_len
-
- def __repr__(self):
- # type: () -> str
- if self.reg_len != 1:
- reg_len = f"*{self.reg_len}"
- else:
- reg_len = ""
- return f"<{self.base_ty._name_}{reg_len}>"
-
-
-@unique
-@final
-class LocKind(Enum):
- GPR = enum.auto()
- StackI64 = enum.auto()
- CA = enum.auto()
- VL_MAXVL = enum.auto()
-
- @cached_property
- def base_ty(self):
- # type: () -> BaseTy
- if self is LocKind.GPR or self is LocKind.StackI64:
- return BaseTy.I64
- if self is LocKind.CA:
- return BaseTy.CA
- if self is LocKind.VL_MAXVL:
- return BaseTy.VL_MAXVL
- else:
- assert_never(self)
-
- @cached_property
- def loc_count(self):
- # type: () -> int
- if self is LocKind.StackI64:
- return 512
- if self is LocKind.GPR or self is LocKind.CA \
- or self is LocKind.VL_MAXVL:
- return self.base_ty.max_reg_len
- else:
- assert_never(self)
-
- def __repr__(self):
- return "LocKind." + self._name_
-
-
-@final
-@unique
-class LocSubKind(Enum):
- BASE_GPR = enum.auto()
- SV_EXTRA2_VGPR = enum.auto()
- SV_EXTRA2_SGPR = enum.auto()
- SV_EXTRA3_VGPR = enum.auto()
- SV_EXTRA3_SGPR = enum.auto()
- StackI64 = enum.auto()
- CA = enum.auto()
- VL_MAXVL = enum.auto()
-
- @cached_property
- def kind(self):
- # type: () -> LocKind
- # pyright fails typechecking when using `in` here:
- # reported: https://github.com/microsoft/pyright/issues/4102
- if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
- LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
- LocSubKind.SV_EXTRA3_SGPR):
- return LocKind.GPR
- if self is LocSubKind.StackI64:
- return LocKind.StackI64
- if self is LocSubKind.CA:
- return LocKind.CA
- if self is LocSubKind.VL_MAXVL:
- return LocKind.VL_MAXVL
- assert_never(self)
-
- @property
- def base_ty(self):
- return self.kind.base_ty
-
- @lru_cache()
- def allocatable_locs(self, ty):
- # type: (Ty) -> LocSet
- if ty.base_ty != self.base_ty:
- raise ValueError("type mismatch")
- if self is LocSubKind.BASE_GPR:
- starts = range(32)
- elif self is LocSubKind.SV_EXTRA2_VGPR:
- starts = range(0, 128, 2)
- elif self is LocSubKind.SV_EXTRA2_SGPR:
- starts = range(64)
- elif self is LocSubKind.SV_EXTRA3_VGPR \
- or self is LocSubKind.SV_EXTRA3_SGPR:
- starts = range(128)
- elif self is LocSubKind.StackI64:
- starts = range(LocKind.StackI64.loc_count)
- elif self is LocSubKind.CA or self is LocSubKind.VL_MAXVL:
- return LocSet([Loc(kind=self.kind, start=0, reg_len=1)])
- else:
- assert_never(self)
- retval = [] # type: list[Loc]
- for start in starts:
- loc = Loc.try_make(kind=self.kind, start=start, reg_len=ty.reg_len)
- if loc is None:
- continue
- conflicts = False
- for special_loc in SPECIAL_GPRS:
- if loc.conflicts(special_loc):
- conflicts = True
- break
- if not conflicts:
- retval.append(loc)
- return LocSet(retval)
-
- def __repr__(self):
- return "LocSubKind." + self._name_
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GenericTy(metaclass=InternedMeta):
- __slots__ = "base_ty", "is_vec"
-
- def __init__(self, base_ty, is_vec):
- # type: (BaseTy, bool) -> None
- self.base_ty = base_ty
- if base_ty.only_scalar and is_vec:
- raise ValueError(f"base_ty={base_ty} requires is_vec=False")
- self.is_vec = is_vec
-
- def instantiate(self, maxvl):
- # type: (int) -> Ty
- # here's where subvl and elwid would be accounted for
- if self.is_vec:
- return Ty(self.base_ty, maxvl)
- return Ty(self.base_ty, 1)
-
- def can_instantiate_to(self, ty):
- # type: (Ty) -> bool
- if self.base_ty != ty.base_ty:
- return False
- if self.is_vec:
- return True
- return ty.reg_len == 1
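-
- # For example (illustrative only):
- #   GenericTy(BaseTy.I64, is_vec=True).instantiate(maxvl=32) == Ty(BaseTy.I64, 32)
- #   GenericTy(BaseTy.I64, is_vec=False).instantiate(maxvl=32) == Ty(BaseTy.I64, 1)
- #   GenericTy(BaseTy.I64, is_vec=True).can_instantiate_to(Ty(BaseTy.I64, 5)) == True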
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class Loc(metaclass=InternedMeta):
- __slots__ = "kind", "start", "reg_len"
-
- @staticmethod
- def validate(kind, start, reg_len):
- # type: (LocKind, int, int) -> str | None
- msg = Ty.validate(base_ty=kind.base_ty, reg_len=reg_len)
- if msg is not None:
- return msg
- if reg_len > kind.loc_count:
- return "invalid reg_len"
- if start < 0 or start + reg_len > kind.loc_count:
- return "start not in valid range"
- return None
-
- @staticmethod
- def try_make(kind, start, reg_len):
- # type: (LocKind, int, int) -> Loc | None
- msg = Loc.validate(kind=kind, start=start, reg_len=reg_len)
- if msg is not None:
- return None
- return Loc(kind=kind, start=start, reg_len=reg_len)
-
- def __init__(self, kind, start, reg_len):
- # type: (LocKind, int, int) -> None
- msg = self.validate(kind=kind, start=start, reg_len=reg_len)
- if msg is not None:
- raise ValueError(msg)
- self.kind = kind
- self.reg_len = reg_len
- self.start = start
-
- def conflicts(self, other):
- # type: (Loc) -> bool
- return (self.kind == other.kind
- and self.start < other.stop and other.start < self.stop)
-
- @staticmethod
- def make_ty(kind, reg_len):
- # type: (LocKind, int) -> Ty
- return Ty(base_ty=kind.base_ty, reg_len=reg_len)
-
- @cached_property
- def ty(self):
- # type: () -> Ty
- return self.make_ty(kind=self.kind, reg_len=self.reg_len)
-
- @property
- def stop(self):
- # type: () -> int
- return self.start + self.reg_len
-
- def try_concat(self, *others):
- # type: (*Loc | None) -> Loc | None
- reg_len = self.reg_len
- stop = self.stop
- for other in others:
- if other is None or other.kind != self.kind:
- return None
- if stop != other.start:
- return None
- stop = other.stop
- reg_len += other.reg_len
- return Loc(kind=self.kind, start=self.start, reg_len=reg_len)
-
- def get_subloc_at_offset(self, subloc_ty, offset):
- # type: (Ty, int) -> Loc
- if subloc_ty.base_ty != self.kind.base_ty:
- raise ValueError("BaseTy mismatch")
- if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
- raise ValueError("invalid sub-Loc: offset and/or "
- "subloc_ty.reg_len out of range")
- return Loc(kind=self.kind,
- start=self.start + offset, reg_len=subloc_ty.reg_len)
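-
- # For example (illustrative only):
- #   r3 = Loc(kind=LocKind.GPR, start=3, reg_len=1)
- #   r4 = Loc(kind=LocKind.GPR, start=4, reg_len=1)
- #   r3.try_concat(r4) == Loc(kind=LocKind.GPR, start=3, reg_len=2)
- #   r4.try_concat(r3) is None           # r3 doesn't start where r4 stops
- #   r8_11 = Loc(kind=LocKind.GPR, start=8, reg_len=4)
- #   r8_11.get_subloc_at_offset(Ty(BaseTy.I64, 1), offset=2) \
- #       == Loc(kind=LocKind.GPR, start=10, reg_len=1)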
-
-
-SPECIAL_GPRS = (
- Loc(kind=LocKind.GPR, start=0, reg_len=1),
- Loc(kind=LocKind.GPR, start=1, reg_len=1),
- Loc(kind=LocKind.GPR, start=2, reg_len=1),
- Loc(kind=LocKind.GPR, start=13, reg_len=1),
-)
-
-
-@final
-class _LocSetHashHelper(AbstractSet[Loc]):
- """helper to more quickly compute LocSet's hash"""
-
- def __init__(self, locs):
- # type: (Iterable[Loc]) -> None
- super().__init__()
- self.locs = list(locs)
-
- def __hash__(self):
- # type: () -> int
- return super()._hash()
-
- def __contains__(self, x):
- # type: (Loc | Any) -> bool
- return x in self.locs
-
- def __iter__(self):
- # type: () -> Iterator[Loc]
- return iter(self.locs)
-
- def __len__(self):
- return len(self.locs)
-
-
-@plain_data(frozen=True, eq=False, repr=False)
-@final
-class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
- __slots__ = "starts", "ty", "_LocSet__hash"
-
- def __init__(self, __locs=()):
- # type: (Iterable[Loc]) -> None
- if isinstance(__locs, LocSet):
- self.starts = __locs.starts # type: FMap[LocKind, FBitSet]
- self.ty = __locs.ty # type: Ty | None
- self._LocSet__hash = __locs._LocSet__hash # type: int
- return
- starts = {i: BitSet() for i in LocKind}
- ty = None # type: None | Ty
-
- def locs():
- # type: () -> Iterable[Loc]
- nonlocal ty
- for loc in __locs:
- if ty is None:
- ty = loc.ty
- if ty != loc.ty:
- raise ValueError(f"conflicting types: {ty} != {loc.ty}")
- starts[loc.kind].add(loc.start)
- yield loc
- self._LocSet__hash = _LocSetHashHelper(locs()).__hash__()
- self.starts = FMap(
- (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
- self.ty = ty
-
- @cached_property
- def stops(self):
- # type: () -> FMap[LocKind, FBitSet]
- if self.ty is None:
- return FMap()
- sh = self.ty.reg_len
- return FMap(
- (k, FBitSet(bits=v.bits << sh)) for k, v in self.starts.items())
-
- @property
- def kinds(self):
- # type: () -> AbstractSet[LocKind]
- return self.starts.keys()
-
- @property
- def reg_len(self):
- # type: () -> int | None
- if self.ty is None:
- return None
- return self.ty.reg_len
-
- @property
- def base_ty(self):
- # type: () -> BaseTy | None
- if self.ty is None:
- return None
- return self.ty.base_ty
-
- def concat(self, *others):
- # type: (*LocSet) -> LocSet
- if self.ty is None:
- return LocSet()
- base_ty = self.ty.base_ty
- reg_len = self.ty.reg_len
- starts = {k: BitSet(v) for k, v in self.starts.items()}
- for other in others:
- if other.ty is None:
- return LocSet()
- if other.ty.base_ty != base_ty:
- return LocSet()
- for kind, other_starts in other.starts.items():
- if kind not in starts:
- continue
- starts[kind].bits &= other_starts.bits >> reg_len
- if starts[kind] == 0:
- del starts[kind]
- if len(starts) == 0:
- return LocSet()
- reg_len += other.ty.reg_len
-
- def locs():
- # type: () -> Iterable[Loc]
- for kind, v in starts.items():
- for start in v:
- loc = Loc.try_make(kind=kind, start=start, reg_len=reg_len)
- if loc is not None:
- yield loc
- return LocSet(locs())
-
- def __contains__(self, loc):
- # type: (Loc | Any) -> bool
- if not isinstance(loc, Loc) or loc.ty != self.ty:
- return False
- if loc.kind not in self.starts:
- return False
- return loc.start in self.starts[loc.kind]
-
- def __iter__(self):
- # type: () -> Iterator[Loc]
- if self.ty is None:
- return
- for kind, starts in self.starts.items():
- for start in starts:
- yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
-
- @cached_property
- def __len(self):
- return sum((len(v) for v in self.starts.values()), 0)
-
- def __len__(self):
- return self.__len
-
- def __hash__(self):
- return self._LocSet__hash
-
- def __eq__(self, __other):
- # type: (LocSet | Any) -> bool
- if isinstance(__other, LocSet):
- return self.ty == __other.ty and self.starts == __other.starts
- return super().__eq__(__other)
-
- @lru_cache(maxsize=None, typed=True)
- def max_conflicts_with(self, other):
- # type: (LocSet | Loc) -> int
- """the largest number of Locs in `self` that a single Loc
- from `other` can conflict with
- """
- if isinstance(other, LocSet):
- return max(self.max_conflicts_with(i) for i in other)
- else:
- return sum(other.conflicts(i) for i in self)
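-
- # For example (illustrative only), with three adjacent 2-register Locs:
- #   locs = LocSet(Loc(kind=LocKind.GPR, start=s, reg_len=2)
- #                 for s in (0, 2, 4))
- #   locs.max_conflicts_with(Loc(kind=LocKind.GPR, start=1, reg_len=2)) == 2
- # since r1..r2 overlaps the Locs starting at 0 and 2 but not the one at 4.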
-
- def __repr__(self):
- items = [] # type: list[str]
- for name in fields(self):
- if name.startswith("_"):
- continue
- items.append(f"{name}={getattr(self, name)!r}")
- return f"LocSet({', '.join(items)})"
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GenericOperandDesc(metaclass=InternedMeta):
- """generic Op operand descriptor"""
- __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
- "write_stage")
-
- def __init__(
- self, ty, # type: GenericTy
- sub_kinds, # type: Iterable[LocSubKind]
- *,
- fixed_loc=None, # type: Loc | None
- tied_input_index=None, # type: int | None
- spread=False, # type: bool
- write_stage=OpStage.Early, # type: OpStage
- ):
- # type: (...) -> None
- self.ty = ty
- self.sub_kinds = OFSet(sub_kinds)
- if len(self.sub_kinds) == 0:
- raise ValueError("sub_kinds can't be empty")
- self.fixed_loc = fixed_loc
- if fixed_loc is not None:
- if tied_input_index is not None:
- raise ValueError("operand can't be both tied and fixed")
- if not ty.can_instantiate_to(fixed_loc.ty):
- raise ValueError(
- f"fixed_loc has incompatible type for given generic "
- f"type: fixed_loc={fixed_loc} generic ty={ty}")
- if len(self.sub_kinds) != 1:
- raise ValueError(
- "multiple sub_kinds not allowed for fixed operand")
- for sub_kind in self.sub_kinds:
- if fixed_loc not in sub_kind.allocatable_locs(fixed_loc.ty):
- raise ValueError(
- f"fixed_loc not in given sub_kind: "
- f"fixed_loc={fixed_loc} sub_kind={sub_kind}")
- for sub_kind in self.sub_kinds:
- if sub_kind.base_ty != ty.base_ty:
- raise ValueError(f"sub_kind is incompatible with type: "
- f"sub_kind={sub_kind} ty={ty}")
- if tied_input_index is not None and tied_input_index < 0:
- raise ValueError("invalid tied_input_index")
- self.tied_input_index = tied_input_index
- self.spread = spread
- if spread:
- if self.tied_input_index is not None:
- raise ValueError("operand can't be both spread and tied")
- if self.fixed_loc is not None:
- raise ValueError("operand can't be both spread and fixed")
- if self.ty.is_vec:
- raise ValueError("operand can't be both spread and vector")
- self.write_stage = write_stage
-
- @cached_property
- def ty_before_spread(self):
- # type: () -> GenericTy
- if self.spread:
- return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
- return self.ty
-
- def tied_to_input(self, tied_input_index):
- # type: (int) -> Self
- return GenericOperandDesc(self.ty, self.sub_kinds,
- tied_input_index=tied_input_index,
- write_stage=self.write_stage)
-
- def with_fixed_loc(self, fixed_loc):
- # type: (Loc) -> Self
- return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
- write_stage=self.write_stage)
-
- def with_write_stage(self, write_stage):
- # type: (OpStage) -> Self
- return GenericOperandDesc(self.ty, self.sub_kinds,
- fixed_loc=self.fixed_loc,
- tied_input_index=self.tied_input_index,
- spread=self.spread,
- write_stage=write_stage)
-
- def instantiate(self, maxvl):
- # type: (int) -> Iterable[OperandDesc]
- # assumes all spread operands have ty.reg_len = 1
- rep_count = 1
- if self.spread:
- rep_count = maxvl
- ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)
-
- def locs_before_spread():
- # type: () -> Iterable[Loc]
- if self.fixed_loc is not None:
- if ty_before_spread != self.fixed_loc.ty:
- raise ValueError(
- f"instantiation failed: type mismatch with fixed_loc: "
- f"instantiated type: {ty_before_spread} "
- f"fixed_loc: {self.fixed_loc}")
- yield self.fixed_loc
- return
- for sub_kind in self.sub_kinds:
- yield from sub_kind.allocatable_locs(ty_before_spread)
- loc_set_before_spread = LocSet(locs_before_spread())
- for idx in range(rep_count):
- if not self.spread:
- idx = None
- yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
- tied_input_index=self.tied_input_index,
- spread_index=idx, write_stage=self.write_stage)
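-
- # For example (illustrative only), a spread I64 operand instantiated with
- # maxvl=4 yields four OperandDescs that share one loc_set_before_spread and
- # differ only in spread_index (0..3); after spreading, each has ty <I64>
- # (reg_len 1), while a non-spread operand yields a single OperandDesc with
- # spread_index=None.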
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class OperandDesc(metaclass=InternedMeta):
- """Op operand descriptor"""
- __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
- "write_stage")
-
- def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
- write_stage):
- # type: (LocSet, int | None, int | None, OpStage) -> None
- if len(loc_set_before_spread) == 0:
- raise ValueError("loc_set_before_spread must not be empty")
- self.loc_set_before_spread = loc_set_before_spread
- self.tied_input_index = tied_input_index
- if self.tied_input_index is not None and spread_index is not None:
- raise ValueError("operand can't be both spread and tied")
- self.spread_index = spread_index
- self.write_stage = write_stage
-
- @cached_property
- def ty_before_spread(self):
- # type: () -> Ty
- ty = self.loc_set_before_spread.ty
- assert ty is not None, (
- "__init__ checked that the LocSet isn't empty, "
- "non-empty LocSets should always have ty set")
- return ty
-
- @cached_property
- def ty(self):
- """ Ty after any spread is applied """
- if self.spread_index is not None:
- # assumes all spread operands have ty.reg_len = 1
- return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
- return self.ty_before_spread
-
- @property
- def reg_offset_in_unspread(self):
- """ the number of reg-sized slots in the unspread Loc before self's Loc
-
- e.g. if the unspread Loc containing self is:
- `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
- and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
- then reg_offset_in_unspread == 2 == 10 - 8
- """
- if self.spread_index is None:
- return 0
- return self.spread_index * self.ty.reg_len
-
-
-OD_BASE_SGPR = GenericOperandDesc(
- ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.BASE_GPR])
-OD_EXTRA3_SGPR = GenericOperandDesc(
- ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA3_SGPR])
-OD_EXTRA3_VGPR = GenericOperandDesc(
- ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
- sub_kinds=[LocSubKind.SV_EXTRA3_VGPR])
-OD_EXTRA2_SGPR = GenericOperandDesc(
- ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA2_SGPR])
-OD_EXTRA2_VGPR = GenericOperandDesc(
- ty=GenericTy(base_ty=BaseTy.I64, is_vec=True),
- sub_kinds=[LocSubKind.SV_EXTRA2_VGPR])
-OD_CA = GenericOperandDesc(
- ty=GenericTy(base_ty=BaseTy.CA, is_vec=False),
- sub_kinds=[LocSubKind.CA])
-OD_VL = GenericOperandDesc(
- ty=GenericTy(base_ty=BaseTy.VL_MAXVL, is_vec=False),
- sub_kinds=[LocSubKind.VL_MAXVL])
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GenericOpProperties(metaclass=InternedMeta):
- __slots__ = ("demo_asm", "inputs", "outputs", "immediates",
- "is_copy", "is_load_immediate", "has_side_effects")
-
- def __init__(
- self, demo_asm, # type: str
- inputs, # type: Iterable[GenericOperandDesc]
- outputs, # type: Iterable[GenericOperandDesc]
- immediates=(), # type: Iterable[range]
- is_copy=False, # type: bool
- is_load_immediate=False, # type: bool
- has_side_effects=False, # type: bool
- ):
- # type: (...) -> None
- self.demo_asm = demo_asm # type: str
- self.inputs = tuple(inputs) # type: tuple[GenericOperandDesc, ...]
- for inp in self.inputs:
- if inp.tied_input_index is not None:
- raise ValueError(
- f"tied_input_index is not allowed on inputs: {inp}")
- if inp.write_stage is not OpStage.Early:
- raise ValueError(
- f"write_stage is not allowed on inputs: {inp}")
- self.outputs = tuple(outputs) # type: tuple[GenericOperandDesc, ...]
- fixed_locs = [] # type: list[tuple[Loc, int]]
- for idx, out in enumerate(self.outputs):
- if out.tied_input_index is not None:
- if out.tied_input_index >= len(self.inputs):
- raise ValueError(f"tied_input_index out of range: {out}")
- tied_inp = self.inputs[out.tied_input_index]
- expected_out = tied_inp.tied_to_input(out.tied_input_index) \
- .with_write_stage(out.write_stage)
- if expected_out != out:
- raise ValueError(f"output can't be tied to non-equivalent "
- f"input: {out} tied to {tied_inp}")
- if out.fixed_loc is not None:
- for other_fixed_loc, other_idx in fixed_locs:
- if not other_fixed_loc.conflicts(out.fixed_loc):
- continue
- raise ValueError(
- f"conflicting fixed_locs: outputs[{idx}] and "
- f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
- f"with {other_fixed_loc}")
- fixed_locs.append((out.fixed_loc, idx))
- self.immediates = tuple(immediates) # type: tuple[range, ...]
- self.is_copy = is_copy # type: bool
- self.is_load_immediate = is_load_immediate # type: bool
- self.has_side_effects = has_side_effects # type: bool
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class OpProperties(metaclass=InternedMeta):
- __slots__ = "kind", "inputs", "outputs", "maxvl"
-
- def __init__(self, kind, maxvl):
- # type: (OpKind, int) -> None
- self.kind = kind # type: OpKind
- inputs = [] # type: list[OperandDesc]
- for inp in self.generic.inputs:
- inputs.extend(inp.instantiate(maxvl=maxvl))
- self.inputs = tuple(inputs) # type: tuple[OperandDesc, ...]
- outputs = [] # type: list[OperandDesc]
- for out in self.generic.outputs:
- outputs.extend(out.instantiate(maxvl=maxvl))
- self.outputs = tuple(outputs) # type: tuple[OperandDesc, ...]
- self.maxvl = maxvl # type: int
-
- @property
- def generic(self):
- # type: () -> GenericOpProperties
- return self.kind.properties
-
- @property
- def immediates(self):
- # type: () -> tuple[range, ...]
- return self.generic.immediates
-
- @property
- def demo_asm(self):
- # type: () -> str
- return self.generic.demo_asm
-
- @property
- def is_copy(self):
- # type: () -> bool
- return self.generic.is_copy
-
- @property
- def is_load_immediate(self):
- # type: () -> bool
- return self.generic.is_load_immediate
-
- @property
- def has_side_effects(self):
- # type: () -> bool
- return self.generic.has_side_effects
-
-
-IMM_S16 = range(-1 << 15, 1 << 15)
-
-_SIM_FN = Callable[["Op", "BaseSimState"], None]
-_SIM_FN2 = Callable[[], _SIM_FN]
-_SIM_FNS = {} # type: dict[GenericOpProperties | Any, _SIM_FN2]
-_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
-_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
-_GEN_ASMS = {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
-
-
-@unique
-@final
-class OpKind(Enum):
- def __init__(self, properties):
- # type: (GenericOpProperties) -> None
- super().__init__()
- self.__properties = properties
-
- @property
- def properties(self):
- # type: () -> GenericOpProperties
- return self.__properties
-
- def instantiate(self, maxvl):
- # type: (int) -> OpProperties
- return OpProperties(self, maxvl=maxvl)
-
- def __repr__(self):
- # type: () -> str
- return "OpKind." + self._name_
-
- @cached_property
- def sim(self):
- # type: () -> _SIM_FN
- return _SIM_FNS[self.properties]()
-
- @cached_property
- def gen_asm(self):
- # type: () -> _GEN_ASM_FN
- return _GEN_ASMS[self.properties]()
-
- @staticmethod
- def __clearca_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = False,
-
- @staticmethod
- def __clearca_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
- ClearCA = GenericOpProperties(
- demo_asm="addic 0, 0, 0",
- inputs=[],
- outputs=[OD_CA.with_write_stage(OpStage.Late)],
- )
- _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
- _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm
-
- @staticmethod
- def __setca_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = True,
-
- @staticmethod
- def __setca_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- state.writeln("subfc 0, 0, 0")
- SetCA = GenericOpProperties(
- demo_asm="subfc 0, 0, 0",
- inputs=[],
- outputs=[OD_CA.with_write_stage(OpStage.Late)],
- )
- _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
- _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm
-
- @staticmethod
- def __svadde_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RA = state[op.input_vals[0]]
- RB = state[op.input_vals[1]]
- carry, = state[op.input_vals[2]]
- VL, = state[op.input_vals[3]]
- RT = [] # type: list[int]
- for i in range(VL):
- v = RA[i] + RB[i] + carry
- RT.append(v & GPR_VALUE_MASK)
- carry = (v >> GPR_SIZE_IN_BITS) != 0
- state[op.outputs[0]] = tuple(RT)
- state[op.outputs[1]] = carry,
-
- @staticmethod
- def __svadde_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RT = state.vgpr(op.outputs[0])
- RA = state.vgpr(op.input_vals[0])
- RB = state.vgpr(op.input_vals[1])
- state.writeln(f"sv.adde {RT}, {RA}, {RB}")
- SvAddE = GenericOpProperties(
- demo_asm="sv.adde *RT, *RA, *RB",
- inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
- outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
- )
- _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
- _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
-
- @staticmethod
- def __addze_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RA, = state[op.input_vals[0]]
- carry, = state[op.input_vals[1]]
- v = RA + carry
- RT = v & GPR_VALUE_MASK
- carry = (v >> GPR_SIZE_IN_BITS) != 0
- state[op.outputs[0]] = RT,
- state[op.outputs[1]] = carry,
-
- @staticmethod
- def __addze_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RT = state.sgpr(op.outputs[0])
- RA = state.sgpr(op.input_vals[0])
- state.writeln(f"addze {RT}, {RA}")
- AddZE = GenericOpProperties(
- demo_asm="addze RT, RA",
- inputs=[OD_BASE_SGPR, OD_CA],
- outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
- )
- _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
- _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm
-
- @staticmethod
- def __svsubfe_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RA = state[op.input_vals[0]]
- RB = state[op.input_vals[1]]
- carry, = state[op.input_vals[2]]
- VL, = state[op.input_vals[3]]
- RT = [] # type: list[int]
- for i in range(VL):
- v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
- RT.append(v & GPR_VALUE_MASK)
- carry = (v >> GPR_SIZE_IN_BITS) != 0
- state[op.outputs[0]] = tuple(RT)
- state[op.outputs[1]] = carry,
-
- @staticmethod
- def __svsubfe_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RT = state.vgpr(op.outputs[0])
- RA = state.vgpr(op.input_vals[0])
- RB = state.vgpr(op.input_vals[1])
- state.writeln(f"sv.subfe {RT}, {RA}, {RB}")
- SvSubFE = GenericOpProperties(
- demo_asm="sv.subfe *RT, *RA, *RB",
- inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
- outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
- )
- _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
- _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
-
- @staticmethod
- def __svmaddedu_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RA = state[op.input_vals[0]]
- RB, = state[op.input_vals[1]]
- carry, = state[op.input_vals[2]]
- VL, = state[op.input_vals[3]]
- RT = [] # type: list[int]
- for i in range(VL):
- v = RA[i] * RB + carry
- RT.append(v & GPR_VALUE_MASK)
- carry = v >> GPR_SIZE_IN_BITS
- state[op.outputs[0]] = tuple(RT)
- state[op.outputs[1]] = carry,
-
- @staticmethod
- def __svmaddedu_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RT = state.vgpr(op.outputs[0])
- RA = state.vgpr(op.input_vals[0])
- RB = state.sgpr(op.input_vals[1])
- RC = state.sgpr(op.input_vals[2])
- state.writeln(f"sv.maddedu {RT}, {RA}, {RB}, {RC}")
- SvMAddEDU = GenericOpProperties(
- demo_asm="sv.maddedu *RT, *RA, RB, RC",
- inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
- outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
- )
- _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
- _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
-
- @staticmethod
- def __setvli_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = op.immediates[0],
-
- @staticmethod
- def __setvli_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- imm = op.immediates[0]
- state.writeln(f"setvl 0, 0, {imm}, 0, 1, 1")
- SetVLI = GenericOpProperties(
- demo_asm="setvl 0, 0, imm, 0, 1, 1",
- inputs=(),
- outputs=[OD_VL.with_write_stage(OpStage.Late)],
- immediates=[range(1, 65)],
- is_load_immediate=True,
- )
- _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
- _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
-
- @staticmethod
- def __svli_sim(op, state):
- # type: (Op, BaseSimState) -> None
- VL, = state[op.input_vals[0]]
- imm = op.immediates[0] & GPR_VALUE_MASK
- state[op.outputs[0]] = (imm,) * VL
-
- @staticmethod
- def __svli_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RT = state.vgpr(op.outputs[0])
- imm = op.immediates[0]
- state.writeln(f"sv.addi {RT}, 0, {imm}")
- SvLI = GenericOpProperties(
- demo_asm="sv.addi *RT, 0, imm",
- inputs=[OD_VL],
- outputs=[OD_EXTRA3_VGPR],
- immediates=[IMM_S16],
- is_load_immediate=True,
- )
- _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
- _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
-
- @staticmethod
- def __li_sim(op, state):
- # type: (Op, BaseSimState) -> None
- imm = op.immediates[0] & GPR_VALUE_MASK
- state[op.outputs[0]] = imm,
-
- @staticmethod
- def __li_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RT = state.sgpr(op.outputs[0])
- imm = op.immediates[0]
- state.writeln(f"addi {RT}, 0, {imm}")
- LI = GenericOpProperties(
- demo_asm="addi RT, 0, imm",
- inputs=(),
- outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
- immediates=[IMM_S16],
- is_load_immediate=True,
- )
- _SIM_FNS[LI] = lambda: OpKind.__li_sim
- _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
-
- @staticmethod
- def __veccopytoreg_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = state[op.input_vals[0]]
-
- @staticmethod
- def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
- # type: (Loc, Loc, bool, GenAsmState) -> None
- sv = "sv." if is_vec else ""
- rev = ""
- if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
- rev = "/mrr"
- if src_loc == dest_loc:
- return # no-op
- if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
- raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
- if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
- raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
- if src_loc.kind is LocKind.StackI64:
- if dest_loc.kind is LocKind.StackI64:
- raise ValueError(
- f"can't copy from stack to stack: {src_loc} {dest_loc}")
- elif dest_loc.kind is not LocKind.GPR:
- assert_never(dest_loc.kind)
- src = state.stack(src_loc)
- dest = state.gpr(dest_loc, is_vec=is_vec)
- state.writeln(f"{sv}ld {dest}, {src}")
- elif dest_loc.kind is LocKind.StackI64:
- if src_loc.kind is not LocKind.GPR:
- assert_never(src_loc.kind)
- src = state.gpr(src_loc, is_vec=is_vec)
- dest = state.stack(dest_loc)
- state.writeln(f"{sv}std {src}, {dest}")
- elif src_loc.kind is LocKind.GPR:
- if dest_loc.kind is not LocKind.GPR:
- assert_never(dest_loc.kind)
- src = state.gpr(src_loc, is_vec=is_vec)
- dest = state.gpr(dest_loc, is_vec=is_vec)
- state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
- else:
- assert_never(src_loc.kind)
-
- @staticmethod
- def __veccopytoreg_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- OpKind.__copy_to_from_reg_gen_asm(
- src_loc=state.loc(
- op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
- dest_loc=state.loc(op.outputs[0], LocKind.GPR),
- is_vec=True, state=state)
-
- VecCopyToReg = GenericOpProperties(
- demo_asm="sv.mv dest, src",
- inputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=True),
- sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
- ), OD_VL],
- outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
- is_copy=True,
- )
- _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
- _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
-
- @staticmethod
- def __veccopyfromreg_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = state[op.input_vals[0]]
-
- @staticmethod
- def __veccopyfromreg_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- OpKind.__copy_to_from_reg_gen_asm(
- src_loc=state.loc(op.input_vals[0], LocKind.GPR),
- dest_loc=state.loc(
- op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
- is_vec=True, state=state)
- VecCopyFromReg = GenericOpProperties(
- demo_asm="sv.mv dest, src",
- inputs=[OD_EXTRA3_VGPR, OD_VL],
- outputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=True),
- sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
- write_stage=OpStage.Late,
- )],
- is_copy=True,
- )
- _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
- _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
-
- @staticmethod
- def __copytoreg_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = state[op.input_vals[0]]
-
- @staticmethod
- def __copytoreg_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- OpKind.__copy_to_from_reg_gen_asm(
- src_loc=state.loc(
- op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
- dest_loc=state.loc(op.outputs[0], LocKind.GPR),
- is_vec=False, state=state)
- CopyToReg = GenericOpProperties(
- demo_asm="mv dest, src",
- inputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
- LocSubKind.StackI64],
- )],
- outputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
- write_stage=OpStage.Late,
- )],
- is_copy=True,
- )
- _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
- _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
-
- @staticmethod
- def __copyfromreg_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = state[op.input_vals[0]]
-
- @staticmethod
- def __copyfromreg_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- OpKind.__copy_to_from_reg_gen_asm(
- src_loc=state.loc(op.input_vals[0], LocKind.GPR),
- dest_loc=state.loc(
- op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
- is_vec=False, state=state)
- CopyFromReg = GenericOpProperties(
- demo_asm="mv dest, src",
- inputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
- )],
- outputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
- LocSubKind.StackI64],
- write_stage=OpStage.Late,
- )],
- is_copy=True,
- )
- _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
- _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
-
- @staticmethod
- def __concat_sim(op, state):
- # type: (Op, BaseSimState) -> None
- state[op.outputs[0]] = tuple(
- state[i][0] for i in op.input_vals[:-1])
-
- @staticmethod
- def __concat_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- OpKind.__copy_to_from_reg_gen_asm(
- src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
- dest_loc=state.loc(op.outputs[0], LocKind.GPR),
- is_vec=True, state=state)
- Concat = GenericOpProperties(
- demo_asm="sv.mv dest, src",
- inputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
- spread=True,
- ), OD_VL],
- outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
- is_copy=True,
- )
- _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
- _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
-
- @staticmethod
- def __spread_sim(op, state):
- # type: (Op, BaseSimState) -> None
- for idx, inp in enumerate(state[op.input_vals[0]]):
- state[op.outputs[idx]] = inp,
-
- @staticmethod
- def __spread_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- OpKind.__copy_to_from_reg_gen_asm(
- src_loc=state.loc(op.input_vals[0], LocKind.GPR),
- dest_loc=state.loc(op.outputs, LocKind.GPR),
- is_vec=True, state=state)
- Spread = GenericOpProperties(
- demo_asm="sv.mv dest, src",
- inputs=[OD_EXTRA3_VGPR, OD_VL],
- outputs=[GenericOperandDesc(
- ty=GenericTy(BaseTy.I64, is_vec=False),
- sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
- spread=True,
- write_stage=OpStage.Late,
- )],
- is_copy=True,
- )
- _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
- _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
-
- @staticmethod
- def __svld_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RA, = state[op.input_vals[0]]
- VL, = state[op.input_vals[1]]
- addr = RA + op.immediates[0]
- RT = [] # type: list[int]
- for i in range(VL):
- v = state.load(addr + GPR_SIZE_IN_BYTES * i)
- RT.append(v & GPR_VALUE_MASK)
- state[op.outputs[0]] = tuple(RT)
-
- @staticmethod
- def __svld_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RA = state.sgpr(op.input_vals[0])
- RT = state.vgpr(op.outputs[0])
- imm = op.immediates[0]
- state.writeln(f"sv.ld {RT}, {imm}({RA})")
- SvLd = GenericOpProperties(
- demo_asm="sv.ld *RT, imm(RA)",
- inputs=[OD_EXTRA3_SGPR, OD_VL],
- outputs=[OD_EXTRA3_VGPR],
- immediates=[IMM_S16],
- )
- _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
- _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
-
- @staticmethod
- def __ld_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RA, = state[op.input_vals[0]]
- addr = RA + op.immediates[0]
- v = state.load(addr)
- state[op.outputs[0]] = v & GPR_VALUE_MASK,
-
- @staticmethod
- def __ld_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RA = state.sgpr(op.input_vals[0])
- RT = state.sgpr(op.outputs[0])
- imm = op.immediates[0]
- state.writeln(f"ld {RT}, {imm}({RA})")
- Ld = GenericOpProperties(
- demo_asm="ld RT, imm(RA)",
- inputs=[OD_BASE_SGPR],
- outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
- immediates=[IMM_S16],
- )
- _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
- _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
-
- @staticmethod
- def __svstd_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RS = state[op.input_vals[0]]
- RA, = state[op.input_vals[1]]
- VL, = state[op.input_vals[2]]
- addr = RA + op.immediates[0]
- for i in range(VL):
- state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
-
- @staticmethod
- def __svstd_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RS = state.vgpr(op.input_vals[0])
- RA = state.sgpr(op.input_vals[1])
- imm = op.immediates[0]
- state.writeln(f"sv.std {RS}, {imm}({RA})")
- SvStd = GenericOpProperties(
- demo_asm="sv.std *RS, imm(RA)",
- inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
- outputs=[],
- immediates=[IMM_S16],
- has_side_effects=True,
- )
- _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
- _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
-
- @staticmethod
- def __std_sim(op, state):
- # type: (Op, BaseSimState) -> None
- RS, = state[op.input_vals[0]]
- RA, = state[op.input_vals[1]]
- addr = RA + op.immediates[0]
- state.store(addr, value=RS)
-
- @staticmethod
- def __std_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- RS = state.sgpr(op.input_vals[0])
- RA = state.sgpr(op.input_vals[1])
- imm = op.immediates[0]
- state.writeln(f"std {RS}, {imm}({RA})")
- Std = GenericOpProperties(
- demo_asm="std RS, imm(RA)",
- inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
- outputs=[],
- immediates=[IMM_S16],
- has_side_effects=True,
- )
- _SIM_FNS[Std] = lambda: OpKind.__std_sim
- _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
-
- @staticmethod
- def __funcargr3_sim(op, state):
- # type: (Op, BaseSimState) -> None
- pass # the argument's value is assigned to outputs[0] before simulation starts
-
- @staticmethod
- def __funcargr3_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- pass # no instructions needed
- FuncArgR3 = GenericOpProperties(
- demo_asm="",
- inputs=[],
- outputs=[OD_BASE_SGPR.with_fixed_loc(
- Loc(kind=LocKind.GPR, start=3, reg_len=1))],
- )
- _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
- _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-class SSAValOrUse(metaclass=InternedMeta):
- __slots__ = "op", "operand_idx"
-
- def __init__(self, op, operand_idx):
- # type: (Op, int) -> None
- super().__init__()
- self.op = op
- if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
- raise ValueError("invalid operand_idx")
- self.operand_idx = operand_idx
-
- @abstractmethod
- def __repr__(self):
- # type: () -> str
- ...
-
- @property
- @abstractmethod
- def descriptor_array(self):
- # type: () -> tuple[OperandDesc, ...]
- ...
-
- @cached_property
- def defining_descriptor(self):
- # type: () -> OperandDesc
- return self.descriptor_array[self.operand_idx]
-
- @cached_property
- def ty(self):
- # type: () -> Ty
- return self.defining_descriptor.ty
-
- @cached_property
- def ty_before_spread(self):
- # type: () -> Ty
- return self.defining_descriptor.ty_before_spread
-
- @property
- def base_ty(self):
- # type: () -> BaseTy
- return self.ty_before_spread.base_ty
-
- @property
- def reg_offset_in_unspread(self):
- """ the number of reg-sized slots in the unspread Loc before self's Loc
-
- e.g. if the unspread Loc containing self is:
- `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
- and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
- then reg_offset_in_unspread == 2 == 10 - 8
- """
- return self.defining_descriptor.reg_offset_in_unspread
-
- @property
- def unspread_start_idx(self):
- # type: () -> int
- return self.operand_idx - (self.defining_descriptor.spread_index or 0)
-
- @property
- def unspread_start(self):
- # type: () -> Self
- return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-@final
-class SSAVal(SSAValOrUse):
- __slots__ = ()
-
- def __repr__(self):
- # type: () -> str
- return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
-
- @cached_property
- def def_loc_set_before_spread(self):
- # type: () -> LocSet
- return self.defining_descriptor.loc_set_before_spread
-
- @cached_property
- def descriptor_array(self):
- # type: () -> tuple[OperandDesc, ...]
- return self.op.properties.outputs
-
- @cached_property
- def tied_input(self):
- # type: () -> None | SSAUse
- if self.defining_descriptor.tied_input_index is None:
- return None
- return SSAUse(op=self.op,
- operand_idx=self.defining_descriptor.tied_input_index)
-
- @property
- def write_stage(self):
- # type: () -> OpStage
- return self.defining_descriptor.write_stage
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-@final
-class SSAUse(SSAValOrUse):
- __slots__ = ()
-
- @cached_property
- def use_loc_set_before_spread(self):
- # type: () -> LocSet
- return self.defining_descriptor.loc_set_before_spread
-
- @cached_property
- def descriptor_array(self):
- # type: () -> tuple[OperandDesc, ...]
- return self.op.properties.inputs
-
- def __repr__(self):
- # type: () -> str
- return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
-
- @property
- def ssa_val(self):
- # type: () -> SSAVal
- return self.op.input_vals[self.operand_idx]
-
- @ssa_val.setter
- def ssa_val(self, ssa_val):
- # type: (SSAVal) -> None
- self.op.input_vals[self.operand_idx] = ssa_val
-
-
-_T = TypeVar("_T")
-_Desc = TypeVar("_Desc")
-
-
-class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
- @abstractmethod
- def _verify_write_with_desc(self, idx, item, desc):
- # type: (int, _T | Any, _Desc) -> None
- raise NotImplementedError
-
- @final
- def _verify_write(self, idx, item):
- # type: (int | Any, _T | Any) -> int
- if not isinstance(idx, int):
- if isinstance(idx, slice):
- raise TypeError(
- f"can't write to slice of {self.__class__.__name__}")
- raise TypeError(f"can't write with index {idx!r}")
- # normalize idx, raising IndexError if it is out of range
- idx = range(len(self.descriptors))[idx]
- desc = self.descriptors[idx]
- self._verify_write_with_desc(idx, item, desc)
- return idx
-
- def _on_set(self, idx, new_item, old_item):
- # type: (int, _T, _T | None) -> None
- pass
-
- @abstractmethod
- def _get_descriptors(self):
- # type: () -> tuple[_Desc, ...]
- raise NotImplementedError
-
- @cached_property
- @final
- def descriptors(self):
- # type: () -> tuple[_Desc, ...]
- return self._get_descriptors()
-
- @property
- @final
- def op(self):
- return self.__op
-
- def __init__(self, items, op):
- # type: (Iterable[_T], Op) -> None
- super().__init__()
- self.__op = op
- self.__items = [] # type: list[_T]
- for idx, item in enumerate(items):
- if idx >= len(self.descriptors):
- raise ValueError("too many items")
- _ = self._verify_write(idx, item)
- self.__items.append(item)
- if len(self.__items) < len(self.descriptors):
- raise ValueError("not enough items")
-
- @final
- def __iter__(self):
- # type: () -> Iterator[_T]
- yield from self.__items
-
- @overload
- def __getitem__(self, idx):
- # type: (int) -> _T
- ...
-
- @overload
- def __getitem__(self, idx):
- # type: (slice) -> list[_T]
- ...
-
- @final
- def __getitem__(self, idx):
- # type: (int | slice) -> _T | list[_T]
- return self.__items[idx]
-
- @final
- def __setitem__(self, idx, item):
- # type: (int, _T) -> None
- idx = self._verify_write(idx, item)
- old_item = self.__items[idx]
- self.__items[idx] = item
- self._on_set(idx, item, old_item)
-
- @final
- def __len__(self):
- # type: () -> int
- return len(self.__items)
-
- def __repr__(self):
- # type: () -> str
- return f"{self.__class__.__name__}({self.__items}, op=...)"
-
-
-@final
-class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
- def _get_descriptors(self):
- # type: () -> tuple[OperandDesc, ...]
- return self.op.properties.inputs
-
- def _verify_write_with_desc(self, idx, item, desc):
- # type: (int, SSAVal | Any, OperandDesc) -> None
- if not isinstance(item, SSAVal):
- raise TypeError("expected value of type SSAVal")
- if item.ty != desc.ty:
- raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
- f"corresponding input's type {desc.ty!r}")
-
- def _on_set(self, idx, new_item, old_item):
- # type: (int, SSAVal, SSAVal | None) -> None
- SSAUses._on_op_input_set(self, idx, new_item, old_item) # type: ignore
-
- def __init__(self, items, op):
- # type: (Iterable[SSAVal], Op) -> None
- if hasattr(op, "inputs"):
- raise ValueError("Op.inputs already set")
- super().__init__(items, op)
-
-
-@final
-class OpImmediates(OpInputSeq[int, range]):
- def _get_descriptors(self):
- # type: () -> tuple[range, ...]
- return self.op.properties.immediates
-
- def _verify_write_with_desc(self, idx, item, desc):
- # type: (int, int | Any, range) -> None
- if not isinstance(item, int):
- raise TypeError("expected value of type int")
- if item not in desc:
- raise ValueError(f"immediate value {item!r} not in {desc!r}")
-
- def __init__(self, items, op):
- # type: (Iterable[int], Op) -> None
- if hasattr(op, "immediates"):
- raise ValueError("Op.immediates already set")
- super().__init__(items, op)
-
-
-@plain_data(frozen=True, eq=False, repr=False)
-@final
-class Op:
- __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
- "outputs", "name")
-
- def __init__(self, fn, properties, input_vals, immediates, name=""):
- # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
- self.fn = fn
- self.properties = properties
- self.input_vals = OpInputVals(input_vals, op=self)
- inputs_len = len(self.properties.inputs)
- self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
- self.immediates = OpImmediates(immediates, op=self)
- outputs_len = len(self.properties.outputs)
- self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
- self.name = fn._add_op_with_unused_name(self, name) # type: ignore
-
- @property
- def kind(self):
- # type: () -> OpKind
- return self.properties.kind
-
- def __eq__(self, other):
- # type: (Op | Any) -> bool
- if isinstance(other, Op):
- return self is other
- return NotImplemented
-
- def __hash__(self):
- # type: () -> int
- return object.__hash__(self)
-
- def __repr__(self):
- # type: () -> str
- field_vals = [] # type: list[str]
- for name in fields(self):
- if name == "properties":
- name = "kind"
- elif name == "fn":
- continue
- try:
- value = getattr(self, name)
- except AttributeError:
- field_vals.append(f"{name}=<not set>")
- continue
- if isinstance(value, OpInputSeq):
- value = list(value) # type: ignore
- field_vals.append(f"{name}={value!r}")
- field_vals_str = ", ".join(field_vals)
- return f"Op({field_vals_str})"
-
- def sim(self, state):
- # type: (BaseSimState) -> None
- for inp in self.input_vals:
- try:
- val = state[inp]
- except KeyError:
- raise ValueError(f"SSAVal {inp} not yet assigned when "
- f"running {self}")
- if len(val) != inp.ty.reg_len:
- raise ValueError(
- f"value of SSAVal {inp} has wrong number of elements: "
- f"expected {inp.ty.reg_len} found "
- f"{len(val)}: {val!r}")
- if isinstance(state, PreRASimState):
- for out in self.outputs:
- if out in state.ssa_vals:
- if self.kind is OpKind.FuncArgR3:
- continue
- raise ValueError(f"SSAVal {out} already assigned before "
- f"running {self}")
- self.kind.sim(self, state)
- for out in self.outputs:
- try:
- val = state[out]
- except KeyError:
- raise ValueError(f"running {self} failed to assign to {out}")
- if len(val) != out.ty.reg_len:
- raise ValueError(
- f"value of SSAVal {out} has wrong number of elements: "
- f"expected {out.ty.reg_len} found "
- f"{len(val)}: {val!r}")
-
- def gen_asm(self, state):
- # type: (GenAsmState) -> None
- all_loc_kinds = tuple(LocKind)
- for inp in self.input_vals:
- state.loc(inp, expected_kinds=all_loc_kinds)
- for out in self.outputs:
- state.loc(out, expected_kinds=all_loc_kinds)
- self.kind.gen_asm(self, state)
-
-
-GPR_SIZE_IN_BYTES = 8
-BITS_IN_BYTE = 8
-GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * BITS_IN_BYTE
-GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
-
-
-@plain_data(frozen=True, repr=False)
-class BaseSimState(metaclass=ABCMeta):
- __slots__ = "memory",
-
- def __init__(self, memory):
- # type: (dict[int, int]) -> None
- super().__init__()
- self.memory = memory # type: dict[int, int]
-
- def load_byte(self, addr):
- # type: (int) -> int
- addr &= GPR_VALUE_MASK
- return self.memory.get(addr, 0) & 0xFF
-
- def store_byte(self, addr, value):
- # type: (int, int) -> None
- addr &= GPR_VALUE_MASK
- value &= 0xFF
- self.memory[addr] = value
-
- def load(self, addr, size_in_bytes=GPR_SIZE_IN_BYTES, signed=False):
- # type: (int, int, bool) -> int
- if addr % size_in_bytes != 0:
- raise ValueError(f"address not aligned: {hex(addr)} "
- f"required alignment: {size_in_bytes}")
- retval = 0
- for i in range(size_in_bytes):
- retval |= self.load_byte(addr + i) << i * BITS_IN_BYTE
- if signed and retval >> (size_in_bytes * BITS_IN_BYTE - 1) != 0:
- retval -= 1 << size_in_bytes * BITS_IN_BYTE
- return retval
-
- def store(self, addr, value, size_in_bytes=GPR_SIZE_IN_BYTES):
- # type: (int, int, int) -> None
- if addr % size_in_bytes != 0:
- raise ValueError(f"address not aligned: {hex(addr)} "
- f"required alignment: {size_in_bytes}")
- for i in range(size_in_bytes):
- self.store_byte(addr + i, (value >> i * BITS_IN_BYTE) & 0xFF)
-
- def _memory__repr(self):
- # type: () -> str
- if len(self.memory) == 0:
- return "{}"
- keys = sorted(self.memory.keys(), reverse=True)
- CHUNK_SIZE = GPR_SIZE_IN_BYTES
- items = [] # type: list[str]
- while len(keys) != 0:
- addr = keys[-1]
- if (len(keys) >= CHUNK_SIZE
- and addr % CHUNK_SIZE == 0
- and keys[-CHUNK_SIZE:]
- == list(reversed(range(addr, addr + CHUNK_SIZE)))):
- value = self.load(addr, size_in_bytes=CHUNK_SIZE)
- items.append(f"0x{addr:05x}: <0x{value:0{CHUNK_SIZE * 2}x}>")
- keys[-CHUNK_SIZE:] = ()
- else:
- items.append(f"0x{addr:05x}: 0x{self.memory[keys.pop()]:02x}")
- if len(items) == 1:
- return f"{{{items[0]}}}"
- items_str = ",\n".join(items)
- return f"{{\n{items_str}}}"
-
- def __repr__(self):
- # type: () -> str
- field_vals = [] # type: list[str]
- for name in fields(self):
- try:
- value = getattr(self, name)
- except AttributeError:
- field_vals.append(f"{name}=<not set>")
- continue
- repr_fn = getattr(self, f"_{name}__repr", None)
- if callable(repr_fn):
- field_vals.append(f"{name}={repr_fn()}")
- else:
- field_vals.append(f"{name}={value!r}")
- field_vals_str = ", ".join(field_vals)
- return f"{self.__class__.__name__}({field_vals_str})"
-
- @abstractmethod
- def __getitem__(self, ssa_val):
- # type: (SSAVal) -> tuple[int, ...]
- ...
-
- @abstractmethod
- def __setitem__(self, ssa_val, value):
- # type: (SSAVal, tuple[int, ...]) -> None
- ...
-
-
-@plain_data(frozen=True, repr=False)
-@final
-class PreRASimState(BaseSimState):
- __slots__ = "ssa_vals",
-
- def __init__(self, ssa_vals, memory):
- # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
- super().__init__(memory)
- self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]]
-
- def _ssa_vals__repr(self):
- # type: () -> str
- if len(self.ssa_vals) == 0:
- return "{}"
- items = [] # type: list[str]
- CHUNK_SIZE = 4
- for k, v in self.ssa_vals.items():
- element_strs = [] # type: list[str]
- for i, el in enumerate(v):
- if i % CHUNK_SIZE != 0:
- element_strs.append(" " + hex(el))
- else:
- element_strs.append("\n " + hex(el))
- if len(element_strs) <= CHUNK_SIZE:
- element_strs[0] = element_strs[0].lstrip()
- if len(element_strs) == 1:
- element_strs.append("")
- v_str = ",".join(element_strs)
- items.append(f"{k!r}: ({v_str})")
- if len(items) == 1 and "\n" not in items[0]:
- return f"{{{items[0]}}}"
- items_str = ",\n".join(items)
- return f"{{\n{items_str},\n}}"
-
- def __getitem__(self, ssa_val):
- # type: (SSAVal) -> tuple[int, ...]
- return self.ssa_vals[ssa_val]
-
- def __setitem__(self, ssa_val, value):
- # type: (SSAVal, tuple[int, ...]) -> None
- if len(value) != ssa_val.ty.reg_len:
- raise ValueError("value has wrong len")
- self.ssa_vals[ssa_val] = value
-
-
-@plain_data(frozen=True, repr=False)
-@final
-class PostRASimState(BaseSimState):
- __slots__ = "ssa_val_to_loc_map", "loc_values"
-
- def __init__(self, ssa_val_to_loc_map, memory, loc_values):
- # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
- super().__init__(memory)
- self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
- for ssa_val, loc in self.ssa_val_to_loc_map.items():
- if ssa_val.ty != loc.ty:
- raise ValueError(
- f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
- self.loc_values = loc_values
- for loc in self.loc_values.keys():
- if loc.reg_len != 1:
- raise ValueError(
- "loc_values must only contain Locs with reg_len=1, all "
- "larger Locs will be split into reg_len=1 sub-Locs")
-
- def _loc_values__repr(self):
- # type: () -> str
- locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start))
- items = [] # type: list[str]
- for loc in locs:
- items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
- items_str = ",\n".join(items)
- return f"{{\n{items_str},\n}}"
-
- def __getitem__(self, ssa_val):
- # type: (SSAVal) -> tuple[int, ...]
- loc = self.ssa_val_to_loc_map[ssa_val]
- subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
- retval = [] # type: list[int]
- for i in range(loc.reg_len):
- subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
- retval.append(self.loc_values.get(subloc, 0))
- return tuple(retval)
-
- def __setitem__(self, ssa_val, value):
- # type: (SSAVal, tuple[int, ...]) -> None
- if len(value) != ssa_val.ty.reg_len:
- raise ValueError("value has wrong len")
- loc = self.ssa_val_to_loc_map[ssa_val]
- subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
- for i in range(loc.reg_len):
- subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
- self.loc_values[subloc] = value[i]
-
-
-@plain_data(frozen=True)
-class GenAsmState:
- __slots__ = "allocated_locs", "output"
-
- def __init__(self, allocated_locs, output=None):
- # type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
- super().__init__()
- self.allocated_locs = FMap(allocated_locs)
- for ssa_val, loc in self.allocated_locs.items():
- if ssa_val.ty != loc.ty:
- raise ValueError(
- f"Ty mismatch: ssa_val.ty:{ssa_val.ty} != loc.ty:{loc.ty}")
- if output is None:
- output = []
- self.output = output
-
- __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
-
- def loc(self, ssa_val_or_locs, expected_kinds):
- # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
- if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
- ssa_val_or_locs = [ssa_val_or_locs]
- locs = [] # type: list[Loc]
- for i in ssa_val_or_locs:
- if isinstance(i, SSAVal):
- locs.append(self.allocated_locs[i])
- else:
- locs.append(i)
- if len(locs) == 0:
- raise ValueError("invalid Loc sequence: must not be empty")
- retval = locs[0].try_concat(*locs[1:])
- if retval is None:
- raise ValueError("invalid Loc sequence: try_concat failed")
- if isinstance(expected_kinds, LocKind):
- expected_kinds = expected_kinds,
- if retval.kind not in expected_kinds:
- if len(expected_kinds) == 1:
- expected_kinds = expected_kinds[0]
- raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
- f"{retval.kind} expected {expected_kinds}")
- return retval
-
- def gpr(self, ssa_val_or_locs, is_vec):
- # type: (__SSA_VAL_OR_LOCS, bool) -> str
- loc = self.loc(ssa_val_or_locs, LocKind.GPR)
- vec_str = "*" if is_vec else ""
- return vec_str + str(loc.start)
-
- def sgpr(self, ssa_val_or_locs):
- # type: (__SSA_VAL_OR_LOCS) -> str
- return self.gpr(ssa_val_or_locs, is_vec=False)
-
- def vgpr(self, ssa_val_or_locs):
- # type: (__SSA_VAL_OR_LOCS) -> str
- return self.gpr(ssa_val_or_locs, is_vec=True)
-
- def stack(self, ssa_val_or_locs):
- # type: (__SSA_VAL_OR_LOCS) -> str
- loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
- return f"{loc.start}(1)"
-
- def writeln(self, *line_segments):
- # type: (*str) -> None
- line = " ".join(line_segments)
- if isinstance(self.output, list):
- self.output.append(line)
- else:
- self.output.write(line + "\n")
--- /dev/null
+"""
+Register Allocator for Toom-Cook algorithm generator for SVP64
+
+this uses an algorithm based on:
+[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
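+
+A rough usage sketch (illustrative only; `fn` is assumed to be an
+already-constructed `compiler_ir.Fn`):
+
+```
+ssa_val_to_loc = allocate_registers(fn)
+# ssa_val_to_loc maps each SSAVal in `fn` to the Loc chosen for it;
+# note that allocate_registers() starts by calling
+# fn.pre_ra_insert_copies(), so copy ops are added to `fn` itself
+```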
+"""
+
+from itertools import combinations
+from typing import Iterable, Iterator, Mapping, TextIO
+
+from cached_property import cached_property
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.compiler_ir import (BaseTy, Fn, FnAnalysis, Loc,
+ LocSet, ProgramRange, SSAVal,
+ Ty)
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
+
+
+class BadMergedSSAVal(ValueError):
+ pass
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class MergedSSAVal(metaclass=InternedMeta):
+ """a set of `SSAVal`s along with their offsets, all register allocated as
+ a single unit.
+
+ Definition of the term `offset` for this class:
+
+ Let `locs[x]` be the `Loc` that `x` is assigned to after register
+ allocation and let `msv` be a `MergedSSAVal` instance, then the offset
+ for each `SSAVal` `ssa_val` in `msv` is defined as:
+
+ ```
+ msv.ssa_val_offsets[ssa_val] = (msv.offset
+ + locs[ssa_val].start - locs[msv].start)
+ ```
+
+ Example:
+ ```
+ v1.ty == <I64*4>
+ v2.ty == <I64*2>
+ v3.ty == <I64>
+ msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
+ msv.ty == <I64*6>
+ ```
+ if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
+ * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
+ * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
+ * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
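+
+ A sketch of building such a merge out of single-value merges (same
+ `v1`/`v2` as above; `fn_analysis` is an assumed `FnAnalysis`):
+
+ ```
+ a = MergedSSAVal(fn_analysis, v1)  # offsets {v1: 0}
+ b = MergedSSAVal(fn_analysis, v2)  # offsets {v2: 0}
+ m = a.merged(b.with_offset_to_match(v2, additional_offset=4))
+ # m.ssa_val_offsets == {v1: 0, v2: 4} and m.ty == <I64*6>
+ ```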
+ """
+ __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
+
+ def __init__(self, fn_analysis, ssa_val_offsets):
+ # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
+ self.fn_analysis = fn_analysis
+ if isinstance(ssa_val_offsets, SSAVal):
+ ssa_val_offsets = {ssa_val_offsets: 0}
+ self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int]
+ first_ssa_val = None
+ for ssa_val in self.ssa_vals:
+ first_ssa_val = ssa_val
+ break
+ if first_ssa_val is None:
+ raise BadMergedSSAVal("MergedSSAVal can't be empty")
+ self.first_ssa_val = first_ssa_val # type: SSAVal
+ # self.ty checks for mismatched base_ty
+ reg_len = self.ty.reg_len
+ loc_set = None # type: None | LocSet
+ for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
+ def locs():
+ # type: () -> Iterable[Loc]
+ for loc in ssa_val.def_loc_set_before_spread:
+ disallowed_by_use = False
+ for use in fn_analysis.uses[ssa_val]:
+ # calculate the start for the use's Loc before spread
+ # e.g. if the def's Loc before spread starts at r6
+ # and the def's reg_offset_in_unspread is 5
+ # and the use's reg_offset_in_unspread is 3
+ # then the use's Loc before spread starts at r8
+ # because 8 == 6 + 5 - 3
+ start = (loc.start + ssa_val.reg_offset_in_unspread
+ - use.reg_offset_in_unspread)
+ use_loc = Loc.try_make(
+ loc.kind, start=start,
+ reg_len=use.ty_before_spread.reg_len)
+ if (use_loc is None or
+ use_loc not in use.use_loc_set_before_spread):
+ disallowed_by_use = True
+ break
+ if disallowed_by_use:
+ continue
+ # FIXME: add spread consistency check
+ start = loc.start - cur_offset + self.offset
+ loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
+ if loc is not None and (loc_set is None or loc in loc_set):
+ yield loc
+ loc_set = LocSet(locs())
+ assert loc_set is not None, "already checked that self isn't empty"
+ if loc_set.ty is None:
+ raise BadMergedSSAVal("there are no valid Locs left")
+ assert loc_set.ty == self.ty, "logic error somewhere"
+ self.loc_set = loc_set # type: LocSet
+
+ @cached_property
+ def __hash(self):
+ # type: () -> int
+ return hash((self.fn_analysis, self.ssa_val_offsets))
+
+ def __hash__(self):
+ # type: () -> int
+ return self.__hash
+
+ @cached_property
+ def offset(self):
+ # type: () -> int
+ return min(self.ssa_val_offsets_before_spread.values())
+
+ @property
+ def base_ty(self):
+ # type: () -> BaseTy
+ return self.first_ssa_val.base_ty
+
+ @cached_property
+ def ssa_vals(self):
+ # type: () -> OFSet[SSAVal]
+ return OFSet(self.ssa_val_offsets.keys())
+
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ reg_len = 0
+ for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
+ cur_ty = ssa_val.ty_before_spread
+ if self.base_ty != cur_ty.base_ty:
+ raise BadMergedSSAVal(
+ f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
+ reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
+ return Ty(base_ty=self.base_ty, reg_len=reg_len)
+
+ @cached_property
+ def ssa_val_offsets_before_spread(self):
+ # type: () -> FMap[SSAVal, int]
+ retval = {} # type: dict[SSAVal, int]
+ for ssa_val, offset in self.ssa_val_offsets.items():
+ retval[ssa_val] = (
+ offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
+ return FMap(retval)
+
+ def offset_by(self, amount):
+ # type: (int) -> MergedSSAVal
+ v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
+ return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)
+
+ def normalized(self):
+ # type: () -> MergedSSAVal
+ return self.offset_by(-self.offset)
+
+ def with_offset_to_match(self, target, additional_offset=0):
+ # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
+ if isinstance(target, MergedSSAVal):
+ ssa_val_offsets = target.ssa_val_offsets
+ else:
+ ssa_val_offsets = {target: 0}
+ for ssa_val, offset in self.ssa_val_offsets.items():
+ if ssa_val in ssa_val_offsets:
+ return self.offset_by(
+ ssa_val_offsets[ssa_val] + additional_offset - offset)
+ raise ValueError("can't change offset to match unrelated MergedSSAVal")
+
+ def merged(self, *others):
+ # type: (*MergedSSAVal) -> MergedSSAVal
+ retval = dict(self.ssa_val_offsets)
+ for other in others:
+ if other.fn_analysis != self.fn_analysis:
+ raise ValueError("fn_analysis mismatch")
+ for ssa_val, offset in other.ssa_val_offsets.items():
+ if ssa_val in retval and retval[ssa_val] != offset:
+ raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
+ f"{retval[ssa_val]} != {offset}")
+ retval[ssa_val] = offset
+ return MergedSSAVal(fn_analysis=self.fn_analysis,
+ ssa_val_offsets=retval)
+
+ @cached_property
+ def live_interval(self):
+ # type: () -> ProgramRange
+ live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
+ start = live_range.start
+ stop = live_range.stop
+ for ssa_val in self.ssa_vals:
+ live_range = self.fn_analysis.live_ranges[ssa_val]
+ start = min(start, live_range.start)
+ stop = max(stop, live_range.stop)
+ return ProgramRange(start=start, stop=stop)
+
+ def __repr__(self):
+ return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
+ f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
+ f"live_interval={self.live_interval})")
+
+
+@final
+class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
+ def __init__(self):
+ # type: (...) -> None
+ self.__map = {} # type: dict[SSAVal, MergedSSAVal]
+ self.__ig_node_map = MergedSSAValToIGNodeMap(
+ _private_merged_ssa_val_map=self.__map)
+
+ def __getitem__(self, __key):
+ # type: (SSAVal) -> MergedSSAVal
+ return self.__map[__key]
+
+ def __iter__(self):
+ # type: () -> Iterator[SSAVal]
+ return iter(self.__map)
+
+ def __len__(self):
+ # type: () -> int
+ return len(self.__map)
+
+ @property
+ def ig_node_map(self):
+ # type: () -> MergedSSAValToIGNodeMap
+ return self.__ig_node_map
+
+ def __repr__(self):
+ # type: () -> str
+ s = ",\n".join(repr(v) for v in self.__ig_node_map)
+ return f"SSAValToMergedSSAValMap({{{s}}})"
+
+
+@final
+class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
+ def __init__(
+ self, *,
+ _private_merged_ssa_val_map, # type: dict[SSAVal, MergedSSAVal]
+ ):
+ # type: (...) -> None
+ self.__merged_ssa_val_map = _private_merged_ssa_val_map
+ self.__map = {} # type: dict[MergedSSAVal, IGNode]
+
+ def __getitem__(self, __key):
+ # type: (MergedSSAVal) -> IGNode
+ return self.__map[__key]
+
+ def __iter__(self):
+ # type: () -> Iterator[MergedSSAVal]
+ return iter(self.__map)
+
+ def __len__(self):
+ # type: () -> int
+ return len(self.__map)
+
+ def add_node(self, merged_ssa_val):
+ # type: (MergedSSAVal) -> IGNode
+ node = self.__map.get(merged_ssa_val, None)
+ if node is not None:
+ return node
+ added = 0 # type: int | None
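+ # `added` counts how many ssa_vals have been written into
+ # __merged_ssa_val_map so far; it is set to None once everything
+ # succeeded, so the `finally` block below only rolls back a
+ # partially-completed insertion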
+ try:
+ for ssa_val in merged_ssa_val.ssa_vals:
+ if ssa_val in self.__merged_ssa_val_map:
+ raise ValueError(
+ f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
+ f"{merged_ssa_val} and "
+ f"{self.__merged_ssa_val_map[ssa_val]}")
+ self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
+ added += 1
+ retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
+ self.__map[merged_ssa_val] = retval
+ added = None
+ return retval
+ finally:
+ if added is not None:
+ # remove partially added stuff
+ for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
+ if idx >= added:
+ break
+ del self.__merged_ssa_val_map[ssa_val]
+
+ def merge_into_one_node(self, final_merged_ssa_val):
+ # type: (MergedSSAVal) -> IGNode
+ source_nodes = OSet() # type: OSet[IGNode]
+ edges = OSet() # type: OSet[IGNode]
+ loc = None # type: Loc | None
+ for ssa_val in final_merged_ssa_val.ssa_vals:
+ merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
+ source_node = self.__map[merged_ssa_val]
+ source_nodes.add(source_node)
+ for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
+ raise ValueError(
+ f"SSAVal {i} appears in source IGNode's merged_ssa_val "
+ f"but not in merged IGNode's merged_ssa_val: "
+ f"source_node={source_node} "
+ f"final_merged_ssa_val={final_merged_ssa_val}")
+ if loc is None:
+ loc = source_node.loc
+ elif source_node.loc is not None and loc != source_node.loc:
+ raise ValueError(f"can't merge IGNodes with mismatched `loc` "
+ f"values: {loc} != {source_node.loc}")
+ edges |= source_node.edges
+ if len(source_nodes) == 1:
+ return source_nodes.pop() # merging a single node is a no-op
+ # we're finished checking validity, now we can modify stuff
+ edges -= source_nodes
+ retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
+ loc=loc)
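+ # re-point every surviving neighbor at the merged node, then drop the
+ # now-stale source nodes and remap their SSAVals to the merged node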
+ for node in edges:
+ node.edges -= source_nodes
+ node.edges.add(retval)
+ for node in source_nodes:
+ del self.__map[node.merged_ssa_val]
+ self.__map[final_merged_ssa_val] = retval
+ for ssa_val in final_merged_ssa_val.ssa_vals:
+ self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
+ return retval
+
+ def __repr__(self, repr_state=None):
+ # type: (None | IGNodeReprState) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
+ return f"MergedSSAValToIGNodeMap({{{s}}})"
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class InterferenceGraph:
+ __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"
+
+ def __init__(self, fn_analysis, merged_ssa_vals):
+ # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
+ self.fn_analysis = fn_analysis
+ self.merged_ssa_val_map = SSAValToMergedSSAValMap()
+ self.nodes = self.merged_ssa_val_map.ig_node_map
+ for i in merged_ssa_vals:
+ self.nodes.add_node(i)
+
+ def merge(self, ssa_val1, ssa_val2, additional_offset=0):
+ # type: (SSAVal, SSAVal, int) -> IGNode
+ merged1 = self.merged_ssa_val_map[ssa_val1]
+ merged2 = self.merged_ssa_val_map[ssa_val2]
+ merged = merged1.with_offset_to_match(ssa_val1)
+ merged = merged.merged(merged2.with_offset_to_match(
+ ssa_val2, additional_offset=additional_offset))
+ return self.nodes.merge_into_one_node(merged)
+
+ @staticmethod
+ def minimally_merged(fn_analysis):
+ # type: (FnAnalysis) -> InterferenceGraph
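+ # merge only what correctness requires: each spread element with the
+ # start of the value it was spread from (at its offset within the
+ # unspread whole), and each tied output with its tied input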
+ retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
+ for op in fn_analysis.fn.ops:
+ for inp in op.input_uses:
+ if inp.unspread_start != inp:
+ retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
+ additional_offset=inp.reg_offset_in_unspread)
+ for out in op.outputs:
+ retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
+ if out.unspread_start != out:
+ retval.merge(out.unspread_start, out,
+ additional_offset=out.reg_offset_in_unspread)
+ if out.tied_input is not None:
+ retval.merge(out.tied_input.ssa_val, out)
+ return retval
+
+ def __repr__(self, repr_state=None):
+ # type: (None | IGNodeReprState) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ s = self.nodes.__repr__(repr_state)
+ return f"InterferenceGraph(nodes={s}, <...>)"
+
+
+@plain_data(repr=False)
+class IGNodeReprState:
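+ """ shared bookkeeping for the IGNode/InterferenceGraph __repr__
+ methods: `node_ids` gives each node a stable number and
+ `did_full_repr` makes sure a node is printed in full at most once,
+ with later references using the short `<IGNode #n>` form
+ """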
+ __slots__ = "node_ids", "did_full_repr"
+
+ def __init__(self):
+ super().__init__()
+ self.node_ids = {} # type: dict[IGNode, int]
+ self.did_full_repr = OSet() # type: OSet[IGNode]
+
+
+@final
+class IGNode:
+ """ interference graph node """
+ __slots__ = "merged_ssa_val", "edges", "loc"
+
+ def __init__(self, merged_ssa_val, edges, loc):
+ # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
+ self.merged_ssa_val = merged_ssa_val
+ self.edges = OSet(edges)
+ self.loc = loc
+
+ def add_edge(self, other):
+ # type: (IGNode) -> None
+ self.edges.add(other)
+ other.edges.add(self)
+
+ def __eq__(self, other):
+ # type: (object) -> bool
+ if isinstance(other, IGNode):
+ return self.merged_ssa_val == other.merged_ssa_val
+ return NotImplemented
+
+ def __hash__(self):
+ # type: () -> int
+ return hash(self.merged_ssa_val)
+
+ def __repr__(self, repr_state=None, short=False):
+ # type: (None | IGNodeReprState, bool) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ node_id = repr_state.node_ids.get(self, None)
+ if node_id is None:
+ repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
+ if short or self in repr_state.did_full_repr:
+ return f"<IGNode #{node_id}>"
+ repr_state.did_full_repr.add(self)
+ edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
+ return (f"IGNode(#{node_id}, "
+ f"merged_ssa_val={self.merged_ssa_val}, "
+ f"edges={{{edges}}}, "
+ f"loc={self.loc})")
+
+ @property
+ def loc_set(self):
+ # type: () -> LocSet
+ return self.merged_ssa_val.loc_set
+
+ def loc_conflicts_with_neighbors(self, loc):
+ # type: (Loc) -> bool
+ for neighbor in self.edges:
+ if neighbor.loc is not None and neighbor.loc.conflicts(loc):
+ return True
+ return False
+
+
+class AllocationFailedError(Exception):
+ def __init__(self, msg, node, interference_graph):
+ # type: (str, IGNode, InterferenceGraph) -> None
+ super().__init__(msg, node, interference_graph)
+ self.node = node
+ self.interference_graph = interference_graph
+
+ def __repr__(self, repr_state=None):
+ # type: (None | IGNodeReprState) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ return (f"{__class__.__name__}({self.args[0]!r}, "
+ f"node={self.node.__repr__(repr_state, True)}, "
+ f"interference_graph="
+ f"{self.interference_graph.__repr__(repr_state)})")
+
+ def __str__(self):
+ # type: () -> str
+ return self.__repr__()
+
+
+def allocate_registers(fn, debug_out=None):
+ # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
+
+ # pre_ra_insert_copies() inserts enough copies that no manual spilling
+ # is necessary; all spilling happens naturally when the register
+ # allocator assigns SSAVals to stack slots
+ fn.pre_ra_insert_copies()
+
+ if debug_out is not None:
+ print(f"After pre_ra_insert_copies():\n{fn.ops}",
+ file=debug_out, flush=True)
+
+ fn_analysis = FnAnalysis(fn)
+ interference_graph = InterferenceGraph.minimally_merged(fn_analysis)
+
+ if debug_out is not None:
+ print(f"After InterferenceGraph.minimally_merged():\n"
+ f"{interference_graph}", file=debug_out, flush=True)
+
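+ # add an interference edge whenever two merged values are live at the
+ # same ProgramPoint and their LocSets could overlap at all
+ # (max_conflicts_with(...) != 0)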
+ for pp, ssa_vals in fn_analysis.live_at.items():
+ live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal]
+ for ssa_val in ssa_vals:
+ live_merged_ssa_vals.add(
+ interference_graph.merged_ssa_val_map[ssa_val])
+ for i, j in combinations(live_merged_ssa_vals, 2):
+ if i.loc_set.max_conflicts_with(j.loc_set) != 0:
+ interference_graph.nodes[i].add_edge(
+ interference_graph.nodes[j])
+ if debug_out is not None:
+ print(f"processed {pp} out of {fn_analysis.all_program_points}",
+ file=debug_out, flush=True)
+
+ if debug_out is not None:
+ print(f"After adding interference graph edges:\n"
+ f"{interference_graph}", file=debug_out, flush=True)
+
+ nodes_remaining = OSet(interference_graph.nodes.values())
+
+ def local_colorability_score(node):
+ # type: (IGNode) -> int
+ """ returns a positive integer if node is locally colorable, returns
+ zero or a negative integer if node isn't known to be locally
+ colorable, the more negative the value, the less colorable
+ """
+ if node not in nodes_remaining:
+ raise ValueError("node must be in nodes_remaining")
+ retval = len(node.loc_set)
+ for neighbor in node.edges:
+ if neighbor in nodes_remaining:
+ retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
+ return retval
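+
+ # worked example (hypothetical numbers): if a node's loc_set holds 96
+ # Locs and its still-remaining neighbors can conflict with at most 10
+ # of them in total, the score is 96 - 10 = 86 > 0, so some Loc is
+ # guaranteed to stay free for it no matter how the neighbors are
+ # allocated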
+
+ # TODO: implement copy-merging
+
+ node_stack = [] # type: list[IGNode]
+ while True:
+ best_node = None # type: None | IGNode
+ best_score = 0
+ for node in nodes_remaining:
+ score = local_colorability_score(node)
+ if best_node is None or score > best_score:
+ best_node = node
+ best_score = score
+ if best_score > 0:
+ # it's locally colorable, no need to find a better one
+ break
+
+ if best_node is None:
+ break
+ node_stack.append(best_node)
+ nodes_remaining.remove(best_node)
+
+ if debug_out is not None:
+ print(f"After deciding node allocation order:\n"
+ f"{node_stack}", file=debug_out, flush=True)
+
+ retval = {} # type: dict[SSAVal, Loc]
+
+ while len(node_stack) > 0:
+ node = node_stack.pop()
+ if node.loc is not None:
+ if node.loc_conflicts_with_neighbors(node.loc):
+ raise AllocationFailedError(
+ "IGNode is pre-allocated to a conflicting Loc",
+ node=node, interference_graph=interference_graph)
+ else:
+ # pick the first Loc in node.loc_set that doesn't conflict with any
+ # already-allocated neighbor, since LocSets are ordered from most
+ # preferred to least preferred Loc.
+ for loc in node.loc_set:
+ if not node.loc_conflicts_with_neighbors(loc):
+ node.loc = loc
+ break
+ if node.loc is None:
+ raise AllocationFailedError(
+ "failed to allocate Loc for IGNode",
+ node=node, interference_graph=interference_graph)
+
+ if debug_out is not None:
+ print(f"After allocating Loc for node:\n{node}",
+ file=debug_out, flush=True)
+
+ for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
+ retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)
+
+ if debug_out is not None:
+ print(f"final Locs for all SSAVals:\n{retval}",
+ file=debug_out, flush=True)
+
+ return retval
+++ /dev/null
-"""
-Register Allocator for Toom-Cook algorithm generator for SVP64
-
-this uses an algorithm based on:
-[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
-"""
-
-from itertools import combinations
-from typing import Iterable, Iterator, Mapping, TextIO
-
-from cached_property import cached_property
-from nmutil.plain_data import plain_data
-
-from bigint_presentation_code.compiler_ir2 import (BaseTy, Fn, FnAnalysis, Loc,
- LocSet, ProgramRange,
- SSAVal, Ty)
-from bigint_presentation_code.type_util import final
-from bigint_presentation_code.util import FMap, InternedMeta, OFSet, OSet
-
-
-class BadMergedSSAVal(ValueError):
- pass
-
-
-@plain_data(frozen=True, repr=False)
-@final
-class MergedSSAVal(metaclass=InternedMeta):
- """a set of `SSAVal`s along with their offsets, all register allocated as
- a single unit.
-
- Definition of the term `offset` for this class:
-
- Let `locs[x]` be the `Loc` that `x` is assigned to after register
- allocation and let `msv` be a `MergedSSAVal` instance, then the offset
- for each `SSAVal` `ssa_val` in `msv` is defined as:
-
- ```
- msv.ssa_val_offsets[ssa_val] = (msv.offset
- + locs[ssa_val].start - locs[msv].start)
- ```
-
- Example:
- ```
- v1.ty == <I64*4>
- v2.ty == <I64*2>
- v3.ty == <I64>
- msv = MergedSSAVal({v1: 0, v2: 4, v3: 1})
- msv.ty == <I64*6>
- ```
- if `msv` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=6)`, then
- * `v1` is allocated to `Loc(kind=LocKind.GPR, start=20, reg_len=4)`
- * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
- * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
- """
- __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
-
- def __init__(self, fn_analysis, ssa_val_offsets):
- # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
- self.fn_analysis = fn_analysis
- if isinstance(ssa_val_offsets, SSAVal):
- ssa_val_offsets = {ssa_val_offsets: 0}
- self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int]
- first_ssa_val = None
- for ssa_val in self.ssa_vals:
- first_ssa_val = ssa_val
- break
- if first_ssa_val is None:
- raise BadMergedSSAVal("MergedSSAVal can't be empty")
- self.first_ssa_val = first_ssa_val # type: SSAVal
- # self.ty checks for mismatched base_ty
- reg_len = self.ty.reg_len
- loc_set = None # type: None | LocSet
- for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
- def locs():
- # type: () -> Iterable[Loc]
- for loc in ssa_val.def_loc_set_before_spread:
- disallowed_by_use = False
- for use in fn_analysis.uses[ssa_val]:
- # calculate the start for the use's Loc before spread
- # e.g. if the def's Loc before spread starts at r6
- # and the def's reg_offset_in_unspread is 5
- # and the use's reg_offset_in_unspread is 3
- # then the use's Loc before spread starts at r8
- # because 8 == 6 + 5 - 3
- start = (loc.start + ssa_val.reg_offset_in_unspread
- - use.reg_offset_in_unspread)
- use_loc = Loc.try_make(
- loc.kind, start=start,
- reg_len=use.ty_before_spread.reg_len)
- if (use_loc is None or
- use_loc not in use.use_loc_set_before_spread):
- disallowed_by_use = True
- break
- if disallowed_by_use:
- continue
- # FIXME: add spread consistency check
- start = loc.start - cur_offset + self.offset
- loc = Loc.try_make(loc.kind, start=start, reg_len=reg_len)
- if loc is not None and (loc_set is None or loc in loc_set):
- yield loc
- loc_set = LocSet(locs())
- assert loc_set is not None, "already checked that self isn't empty"
- if loc_set.ty is None:
- raise BadMergedSSAVal("there are no valid Locs left")
- assert loc_set.ty == self.ty, "logic error somewhere"
- self.loc_set = loc_set # type: LocSet
-
- @cached_property
- def __hash(self):
- # type: () -> int
- return hash((self.fn_analysis, self.ssa_val_offsets))
-
- def __hash__(self):
- # type: () -> int
- return self.__hash
-
- @cached_property
- def offset(self):
- # type: () -> int
- return min(self.ssa_val_offsets_before_spread.values())
-
- @property
- def base_ty(self):
- # type: () -> BaseTy
- return self.first_ssa_val.base_ty
-
- @cached_property
- def ssa_vals(self):
- # type: () -> OFSet[SSAVal]
- return OFSet(self.ssa_val_offsets.keys())
-
- @cached_property
- def ty(self):
- # type: () -> Ty
- reg_len = 0
- for ssa_val, offset in self.ssa_val_offsets_before_spread.items():
- cur_ty = ssa_val.ty_before_spread
- if self.base_ty != cur_ty.base_ty:
- raise BadMergedSSAVal(
- f"BaseTy mismatch: {self.base_ty} != {cur_ty.base_ty}")
- reg_len = max(reg_len, cur_ty.reg_len + offset - self.offset)
- return Ty(base_ty=self.base_ty, reg_len=reg_len)
-
- @cached_property
- def ssa_val_offsets_before_spread(self):
- # type: () -> FMap[SSAVal, int]
- retval = {} # type: dict[SSAVal, int]
- for ssa_val, offset in self.ssa_val_offsets.items():
- retval[ssa_val] = (
- offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
- return FMap(retval)
-
- def offset_by(self, amount):
- # type: (int) -> MergedSSAVal
- v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
- return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)
-
- def normalized(self):
- # type: () -> MergedSSAVal
- return self.offset_by(-self.offset)
-
- def with_offset_to_match(self, target, additional_offset=0):
- # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
- if isinstance(target, MergedSSAVal):
- ssa_val_offsets = target.ssa_val_offsets
- else:
- ssa_val_offsets = {target: 0}
- for ssa_val, offset in self.ssa_val_offsets.items():
- if ssa_val in ssa_val_offsets:
- return self.offset_by(
- ssa_val_offsets[ssa_val] + additional_offset - offset)
- raise ValueError("can't change offset to match unrelated MergedSSAVal")
-
- def merged(self, *others):
- # type: (*MergedSSAVal) -> MergedSSAVal
- retval = dict(self.ssa_val_offsets)
- for other in others:
- if other.fn_analysis != self.fn_analysis:
- raise ValueError("fn_analysis mismatch")
- for ssa_val, offset in other.ssa_val_offsets.items():
- if ssa_val in retval and retval[ssa_val] != offset:
- raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
- f"{retval[ssa_val]} != {offset}")
- retval[ssa_val] = offset
- return MergedSSAVal(fn_analysis=self.fn_analysis,
- ssa_val_offsets=retval)
-
- @cached_property
- def live_interval(self):
- # type: () -> ProgramRange
- live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
- start = live_range.start
- stop = live_range.stop
- for ssa_val in self.ssa_vals:
- live_range = self.fn_analysis.live_ranges[ssa_val]
- start = min(start, live_range.start)
- stop = max(stop, live_range.stop)
- return ProgramRange(start=start, stop=stop)
-
- def __repr__(self):
- return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
- f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
- f"live_interval={self.live_interval})")
-
-
-@final
-class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
- def __init__(self):
- # type: (...) -> None
- self.__map = {} # type: dict[SSAVal, MergedSSAVal]
- self.__ig_node_map = MergedSSAValToIGNodeMap(
- _private_merged_ssa_val_map=self.__map)
-
- def __getitem__(self, __key):
- # type: (SSAVal) -> MergedSSAVal
- return self.__map[__key]
-
- def __iter__(self):
- # type: () -> Iterator[SSAVal]
- return iter(self.__map)
-
- def __len__(self):
- # type: () -> int
- return len(self.__map)
-
- @property
- def ig_node_map(self):
- # type: () -> MergedSSAValToIGNodeMap
- return self.__ig_node_map
-
- def __repr__(self):
- # type: () -> str
- s = ",\n".join(repr(v) for v in self.__ig_node_map)
- return f"SSAValToMergedSSAValMap({{{s}}})"
-
-
-@final
-class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, "IGNode"]):
- def __init__(
- self, *,
- _private_merged_ssa_val_map, # type: dict[SSAVal, MergedSSAVal]
- ):
- # type: (...) -> None
- self.__merged_ssa_val_map = _private_merged_ssa_val_map
- self.__map = {} # type: dict[MergedSSAVal, IGNode]
-
- def __getitem__(self, __key):
- # type: (MergedSSAVal) -> IGNode
- return self.__map[__key]
-
- def __iter__(self):
- # type: () -> Iterator[MergedSSAVal]
- return iter(self.__map)
-
- def __len__(self):
- # type: () -> int
- return len(self.__map)
-
- def add_node(self, merged_ssa_val):
- # type: (MergedSSAVal) -> IGNode
- node = self.__map.get(merged_ssa_val, None)
- if node is not None:
- return node
- added = 0 # type: int | None
- try:
- for ssa_val in merged_ssa_val.ssa_vals:
- if ssa_val in self.__merged_ssa_val_map:
- raise ValueError(
- f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
- f"{merged_ssa_val} and "
- f"{self.__merged_ssa_val_map[ssa_val]}")
- self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
- added += 1
- retval = IGNode(merged_ssa_val=merged_ssa_val, edges=(), loc=None)
- self.__map[merged_ssa_val] = retval
- added = None
- return retval
- finally:
- if added is not None:
- # roll back the SSAVals added before the error
- for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
- if idx >= added:
- break
- del self.__merged_ssa_val_map[ssa_val]
-
- def merge_into_one_node(self, final_merged_ssa_val):
- # type: (MergedSSAVal) -> IGNode
- source_nodes = OSet() # type: OSet[IGNode]
- edges = OSet() # type: OSet[IGNode]
- loc = None # type: Loc | None
- for ssa_val in final_merged_ssa_val.ssa_vals:
- merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
- source_node = self.__map[merged_ssa_val]
- source_nodes.add(source_node)
- for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
- raise ValueError(
- f"SSAVal {i} appears in source IGNode's merged_ssa_val "
- f"but not in merged IGNode's merged_ssa_val: "
- f"source_node={source_node} "
- f"final_merged_ssa_val={final_merged_ssa_val}")
- if loc is None:
- loc = source_node.loc
- elif source_node.loc is not None and loc != source_node.loc:
- raise ValueError(f"can't merge IGNodes with mismatched `loc` "
- f"values: {loc} != {source_node.loc}")
- edges |= source_node.edges
- if len(source_nodes) == 1:
- return source_nodes.pop() # merging a single node is a no-op
- # validity checks are done; it is now safe to mutate the maps and edges
- edges -= source_nodes
- retval = IGNode(merged_ssa_val=final_merged_ssa_val, edges=edges,
- loc=loc)
- for node in edges:
- node.edges -= source_nodes
- node.edges.add(retval)
- for node in source_nodes:
- del self.__map[node.merged_ssa_val]
- self.__map[final_merged_ssa_val] = retval
- for ssa_val in final_merged_ssa_val.ssa_vals:
- self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
- return retval
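# A minimal standalone sketch of the edge rewiring that `merge_into_one_node`
# performs above: the coalesced node takes the union of its source nodes'
# edges minus the source nodes themselves, and every surviving neighbor is
# repointed at the new node. Plain dict-of-sets graph with hypothetical node
# names; this is not the real IGNode API.
def coalesce(graph, group, new_node):
    # type: (dict[str, set[str]], set[str], str) -> dict[str, set[str]]
    edges = set().union(*(graph[n] for n in group)) - group
    for n in group:
        del graph[n]
    for n in edges:
        graph[n] = (graph[n] - group) | {new_node}
    graph[new_node] = edges
    return graph


g = {"a": {"c"}, "b": {"c", "d"}, "c": {"a", "b"}, "d": {"b"}}
print(coalesce(g, {"a", "b"}, "ab"))
# expected (set ordering may vary): {'c': {'ab'}, 'd': {'ab'}, 'ab': {'c', 'd'}}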
-
- def __repr__(self, repr_state=None):
- # type: (None | IGNodeReprState) -> str
- if repr_state is None:
- repr_state = IGNodeReprState()
- s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
- return f"MergedSSAValToIGNodeMap({{{s}}})"
-
-
-@plain_data(frozen=True, repr=False)
-@final
-class InterferenceGraph:
- __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"
-
- def __init__(self, fn_analysis, merged_ssa_vals):
- # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
- self.fn_analysis = fn_analysis
- self.merged_ssa_val_map = SSAValToMergedSSAValMap()
- self.nodes = self.merged_ssa_val_map.ig_node_map
- for i in merged_ssa_vals:
- self.nodes.add_node(i)
-
- def merge(self, ssa_val1, ssa_val2, additional_offset=0):
- # type: (SSAVal, SSAVal, int) -> IGNode
- merged1 = self.merged_ssa_val_map[ssa_val1]
- merged2 = self.merged_ssa_val_map[ssa_val2]
- merged = merged1.with_offset_to_match(ssa_val1)
- merged = merged.merged(merged2.with_offset_to_match(
- ssa_val2, additional_offset=additional_offset))
- return self.nodes.merge_into_one_node(merged)
-
- @staticmethod
- def minimally_merged(fn_analysis):
- # type: (FnAnalysis) -> InterferenceGraph
- retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
- for op in fn_analysis.fn.ops:
- for inp in op.input_uses:
- if inp.unspread_start != inp:
- retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
- additional_offset=inp.reg_offset_in_unspread)
- for out in op.outputs:
- retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
- if out.unspread_start != out:
- retval.merge(out.unspread_start, out,
- additional_offset=out.reg_offset_in_unspread)
- if out.tied_input is not None:
- retval.merge(out.tied_input.ssa_val, out)
- return retval
-
- def __repr__(self, repr_state=None):
- # type: (None | IGNodeReprState) -> str
- if repr_state is None:
- repr_state = IGNodeReprState()
- s = self.nodes.__repr__(repr_state)
- return f"InterferenceGraph(nodes={s}, <...>)"
-
-
-@plain_data(repr=False)
-class IGNodeReprState:
- __slots__ = "node_ids", "did_full_repr"
-
- def __init__(self):
- super().__init__()
- self.node_ids = {} # type: dict[IGNode, int]
- self.did_full_repr = OSet() # type: OSet[IGNode]
-
-
-@final
-class IGNode:
- """ interference graph node """
- __slots__ = "merged_ssa_val", "edges", "loc"
-
- def __init__(self, merged_ssa_val, edges, loc):
- # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
- self.merged_ssa_val = merged_ssa_val
- self.edges = OSet(edges)
- self.loc = loc
-
- def add_edge(self, other):
- # type: (IGNode) -> None
- self.edges.add(other)
- other.edges.add(self)
-
- def __eq__(self, other):
- # type: (object) -> bool
- if isinstance(other, IGNode):
- return self.merged_ssa_val == other.merged_ssa_val
- return NotImplemented
-
- def __hash__(self):
- # type: () -> int
- return hash(self.merged_ssa_val)
-
- def __repr__(self, repr_state=None, short=False):
- # type: (None | IGNodeReprState, bool) -> str
- if repr_state is None:
- repr_state = IGNodeReprState()
- node_id = repr_state.node_ids.get(self, None)
- if node_id is None:
- repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
- if short or self in repr_state.did_full_repr:
- return f"<IGNode #{node_id}>"
- repr_state.did_full_repr.add(self)
- edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
- return (f"IGNode(#{node_id}, "
- f"merged_ssa_val={self.merged_ssa_val}, "
- f"edges={{{edges}}}, "
- f"loc={self.loc})")
-
- @property
- def loc_set(self):
- # type: () -> LocSet
- return self.merged_ssa_val.loc_set
-
- def loc_conflicts_with_neighbors(self, loc):
- # type: (Loc) -> bool
- for neighbor in self.edges:
- if neighbor.loc is not None and neighbor.loc.conflicts(loc):
- return True
- return False
-
-
-class AllocationFailedError(Exception):
- def __init__(self, msg, node, interference_graph):
- # type: (str, IGNode, InterferenceGraph) -> None
- super().__init__(msg, node, interference_graph)
- self.node = node
- self.interference_graph = interference_graph
-
- def __repr__(self, repr_state=None):
- # type: (None | IGNodeReprState) -> str
- if repr_state is None:
- repr_state = IGNodeReprState()
- return (f"{__class__.__name__}({self.args[0]!r}, "
- f"node={self.node.__repr__(repr_state, True)}, "
- f"interference_graph="
- f"{self.interference_graph.__repr__(repr_state)})")
-
- def __str__(self):
- # type: () -> str
- return self.__repr__()
-
-
-def allocate_registers(fn, debug_out=None):
- # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
-
- # inserts enough copies that no manual spilling is necessary; all
- # spilling is done by the register allocator naturally allocating
- # SSAVals to stack slots
- fn.pre_ra_insert_copies()
-
- if debug_out is not None:
- print(f"After pre_ra_insert_copies():\n{fn.ops}",
- file=debug_out, flush=True)
-
- fn_analysis = FnAnalysis(fn)
- interference_graph = InterferenceGraph.minimally_merged(fn_analysis)
-
- if debug_out is not None:
- print(f"After InterferenceGraph.minimally_merged():\n"
- f"{interference_graph}", file=debug_out, flush=True)
-
- for pp, ssa_vals in fn_analysis.live_at.items():
- live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal]
- for ssa_val in ssa_vals:
- live_merged_ssa_vals.add(
- interference_graph.merged_ssa_val_map[ssa_val])
- for i, j in combinations(live_merged_ssa_vals, 2):
- if i.loc_set.max_conflicts_with(j.loc_set) != 0:
- interference_graph.nodes[i].add_edge(
- interference_graph.nodes[j])
- if debug_out is not None:
- print(f"processed {pp} out of {fn_analysis.all_program_points}",
- file=debug_out, flush=True)
-
- if debug_out is not None:
- print(f"After adding interference graph edges:\n"
- f"{interference_graph}", file=debug_out, flush=True)
-
- nodes_remaining = OSet(interference_graph.nodes.values())
-
- def local_colorability_score(node):
- # type: (IGNode) -> int
- """ returns a positive integer if node is locally colorable, returns
- zero or a negative integer if node isn't known to be locally
- colorable, the more negative the value, the less colorable
- """
- if node not in nodes_remaining:
- raise ValueError("node must be in nodes_remaining")
- retval = len(node.loc_set)
- for neighbor in node.edges:
- if neighbor in nodes_remaining:
- retval -= node.loc_set.max_conflicts_with(neighbor.loc_set)
- return retval
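# A minimal standalone sketch of the criterion used above: a node is known to
# be locally colorable when it has more allowed locations than the total
# number of conflicts its still-remaining neighbors can cause, a
# generalization of the classic Chaitin-style "degree < k" test. The names
# and numbers here are hypothetical.
def toy_score(allowed_locs, neighbor_conflicts, remaining):
    # type: (int, dict[str, int], set[str]) -> int
    return allowed_locs - sum(c for n, c in neighbor_conflicts.items()
                              if n in remaining)


# 4 allowed locations and three remaining neighbors, one conflict each:
print(toy_score(4, {"n0": 1, "n1": 1, "n2": 1}, {"n0", "n1", "n2"}))  # 1
# with a fourth conflicting neighbor still remaining the score drops to 0,
# so the node is no longer known to be locally colorable:
print(toy_score(4, {"n0": 1, "n1": 1, "n2": 1, "n3": 1},
                {"n0", "n1", "n2", "n3"}))  # 0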
-
- # TODO: implement copy-merging
-
- node_stack = [] # type: list[IGNode]
- while True:
- best_node = None # type: None | IGNode
- best_score = 0
- for node in nodes_remaining:
- score = local_colorability_score(node)
- if best_node is None or score > best_score:
- best_node = node
- best_score = score
- if best_score > 0:
- # it's locally colorable, no need to find a better one
- break
-
- if best_node is None:
- break
- node_stack.append(best_node)
- nodes_remaining.remove(best_node)
-
- if debug_out is not None:
- print(f"After deciding node allocation order:\n"
- f"{node_stack}", file=debug_out, flush=True)
-
- retval = {} # type: dict[SSAVal, Loc]
-
- while len(node_stack) > 0:
- node = node_stack.pop()
- if node.loc is not None:
- if node.loc_conflicts_with_neighbors(node.loc):
- raise AllocationFailedError(
- "IGNode is pre-allocated to a conflicting Loc",
- node=node, interference_graph=interference_graph)
- else:
- # pick the first non-conflicting Loc in node.loc_set, since register
- # classes (and the LocSets built from them) are ordered from most
- # preferred to least preferred register.
- for loc in node.loc_set:
- if not node.loc_conflicts_with_neighbors(loc):
- node.loc = loc
- break
- if node.loc is None:
- raise AllocationFailedError(
- "failed to allocate Loc for IGNode",
- node=node, interference_graph=interference_graph)
-
- if debug_out is not None:
- print(f"After allocating Loc for node:\n{node}",
- file=debug_out, flush=True)
-
- for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
- retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)
-
- if debug_out is not None:
- print(f"final Locs for all SSAVals:\n{retval}",
- file=debug_out, flush=True)
-
- return retval
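# A minimal usage sketch for `allocate_registers` above, assuming `fn` is an
# already-constructed Fn (building one needs the OpKind-specific helpers and
# is not shown here); the register_allocator module path is an assumption,
# adjust it to wherever allocate_registers actually lives.
import sys

from bigint_presentation_code.compiler_ir import Fn
from bigint_presentation_code.register_allocator import allocate_registers


def run_allocation(fn):
    # type: (Fn) -> None
    # dump the allocator's debug trace to stderr and print the final
    # SSAVal -> Loc assignment it returns
    ssa_val_to_loc = allocate_registers(fn, debug_out=sys.stderr)
    for ssa_val, loc in ssa_val_to_loc.items():
        print(f"{ssa_val}: {loc}")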
from nmutil.plain_data import plain_data
-from bigint_presentation_code.compiler_ir2 import (Fn, OpKind, SSAVal)
+from bigint_presentation_code.compiler_ir import Fn, OpKind, SSAVal
from bigint_presentation_code.matrix import Matrix
from bigint_presentation_code.type_util import Literal, final