import unittest
from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
- OpKind, PreRASimState,
+ FnAnalysis, OpKind, OpStage,
+ PreRASimState, ProgramPoint,
SSAVal)
class TestCompilerIR(unittest.TestCase):
maxDiff = None
+ def test_program_point(self):
+ # type: () -> None
+ expected = [] # type: list[ProgramPoint]
+ for op_index in range(5):
+ for stage in OpStage:
+ expected.append(ProgramPoint(op_index=op_index, stage=stage))
+
+ for idx, pp in enumerate(expected):
+ if idx + 1 < len(expected):
+ self.assertEqual(pp.next(), expected[idx + 1])
+
+ self.assertEqual(sorted(expected), expected)
+
def make_add_fn(self):
# type: () -> tuple[Fn, SSAVal]
fn = Fn()
op5 = fn.append_new_op(
OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
s = op5.outputs[0]
- fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
- immediates=[0], maxvl=MAXVL, name="st")
+ _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+ immediates=[0], maxvl=MAXVL, name="st")
return fn, arg
+ def test_fn_analysis(self):
+ fn, _arg = self.make_add_fn()
+ fn_analysis = FnAnalysis(fn)
+ print(repr(fn_analysis))
+ self.assertEqual(
+ repr(fn_analysis),
+ "FnAnalysis(fn=<Fn>, uses=FMap({"
+ "<arg.outputs[0]: <I64>>: OFSet(["
+ "<ld.input_uses[0]: <I64>>, <st.input_uses[1]: <I64>>]), "
+ "<vl.outputs[0]: <VL_MAXVL>>: OFSet(["
+ "<ld.input_uses[1]: <VL_MAXVL>>, <li.input_uses[0]: <VL_MAXVL>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>, "
+ "<st.input_uses[2]: <VL_MAXVL>>]), "
+ "<ld.outputs[0]: <I64*32>>: OFSet(["
+ "<add.input_uses[0]: <I64*32>>]), "
+ "<li.outputs[0]: <I64*32>>: OFSet(["
+ "<add.input_uses[1]: <I64*32>>]), "
+ "<ca.outputs[0]: <CA>>: OFSet([<add.input_uses[2]: <CA>>]), "
+ "<add.outputs[0]: <I64*32>>: OFSet(["
+ "<st.input_uses[0]: <I64*32>>]), "
+ "<add.outputs[1]: <CA>>: OFSet()}), "
+ "op_indexes=FMap({"
+ "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), "
+ "immediates=[], outputs=(<arg.outputs[0]: <I64>>,), "
+ "name='arg'): 0, "
+ "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), "
+ "immediates=[32], outputs=(<vl.outputs[0]: <VL_MAXVL>>,), "
+ "name='vl'): 1, "
+ "Op(kind=OpKind.SvLd, input_vals=["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<ld.input_uses[0]: <I64>>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>), immediates=[0], "
+ "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld'): 2, "
+ "Op(kind=OpKind.SvLI, input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), immediates=[0], "
+ "outputs=(<li.outputs[0]: <I64*32>>,), name='li'): 3, "
+ "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), "
+ "immediates=[], outputs=(<ca.outputs[0]: <CA>>,), name='ca'): 4, "
+ "Op(kind=OpKind.SvAddE, input_vals=["
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<add.input_uses[0]: <I64*32>>, "
+ "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+ "<add.input_uses[3]: <VL_MAXVL>>), immediates=[], outputs=("
+ "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+ "name='add'): 5, "
+ "Op(kind=OpKind.SvStd, input_vals=["
+ "<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_uses=(<st.input_uses[0]: <I64*32>>, "
+ "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
+ "immediates=[0], outputs=(), name='st'): 6}), "
+ "live_ranges=FMap({"
+ "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[6]:Late>, "
+ "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[6]:Late>, "
+ "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[5]:Late>, "
+ "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[5]:Late>, "
+ "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Late>, "
+ "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Late>, "
+ "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
+ "live_at=FMap({"
+ "<ops[0]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
+ "<ops[0]:Late>: OFSet([<arg.outputs[0]: <I64>>]), "
+ "<ops[1]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
+ "<ops[1]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>]), "
+ "<ops[2]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>]), "
+ "<ops[2]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>]), "
+ "<ops[3]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+ "<ops[3]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+ "<ops[4]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+ "<ops[4]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>]), "
+ "<ops[5]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+ "<ca.outputs[0]: <CA>>, <add.outputs[0]: <I64*32>>, "
+ "<add.outputs[1]: <CA>>]), "
+ "<ops[5]:Late>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>]), "
+ "<ops[6]:Early>: OFSet(["
+ "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+ "<add.outputs[0]: <I64*32>>]), "
+ "<ops[6]:Late>: OFSet()}), "
+ "def_program_ranges=FMap({"
+ "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[1]:Early>, "
+ "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[2]:Early>, "
+ "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[3]:Early>, "
+ "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[4]:Early>, "
+ "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Early>, "
+ "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Early>, "
+ "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
+ "use_program_points=FMap({"
+ "<ld.input_uses[0]: <I64>>: <ops[2]:Early>, "
+ "<ld.input_uses[1]: <VL_MAXVL>>: <ops[2]:Early>, "
+ "<li.input_uses[0]: <VL_MAXVL>>: <ops[3]:Early>, "
+ "<add.input_uses[0]: <I64*32>>: <ops[5]:Early>, "
+ "<add.input_uses[1]: <I64*32>>: <ops[5]:Early>, "
+ "<add.input_uses[2]: <CA>>: <ops[5]:Early>, "
+ "<add.input_uses[3]: <VL_MAXVL>>: <ops[5]:Early>, "
+ "<st.input_uses[0]: <I64*32>>: <ops[6]:Early>, "
+ "<st.input_uses[1]: <I64>>: <ops[6]:Early>, "
+ "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>}), "
+ "all_program_points=<range:ops[0]:Early..ops[7]:Early>)")
+
def test_repr(self):
fn, _arg = self.make_add_fn()
self.assertEqual([repr(i) for i in fn.ops], [
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([3])}), ty=<I64>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=1)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvLd, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
"ty=<I64>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
"OpProperties(kind=OpKind.SvLI, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
"OpProperties(kind=OpKind.SetCA, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvAddE, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None)), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), maxvl=32)",
"OpProperties(kind=OpKind.SvStd, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
"ty=<I64>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=(), maxvl=32)",
])
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([3])}), ty=<I64>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=1)",
"OpProperties(kind=OpKind.CopyFromReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
"ty=<I64>), "
- "tied_input_index=None, spread_index=None),), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
"LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.CopyToReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
"LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
- "tied_input_index=None, spread_index=None),), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
"ty=<I64>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvLd, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
"ty=<I64>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.VecCopyFromReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
"LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.SvLI, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), maxvl=32)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.VecCopyFromReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
"LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.SetCA, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.VecCopyToReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
"LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.VecCopyToReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
"LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.SvAddE, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None)), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), maxvl=32)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.VecCopyFromReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
"LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.VecCopyToReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
"LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None),), maxvl=32)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.CopyToReg, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
"LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
- "tied_input_index=None, spread_index=None),), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
"ty=<I64>), "
- "tied_input_index=None, spread_index=None),), maxvl=1)",
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvStd, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
"ty=<I64>), "
- "tied_input_index=None, spread_index=None), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
- "tied_input_index=None, spread_index=None)), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)), "
"outputs=(), maxvl=32)",
])
if __name__ == "__main__":
- unittest.main()
+ _ = unittest.main()
+from collections import defaultdict
import enum
from abc import ABCMeta, abstractmethod
from enum import Enum, unique
-from functools import lru_cache
+from functools import lru_cache, total_ordering
from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
Sequence, TypeVar, overload)
from weakref import WeakValueDictionary as _WeakVDict
from cached_property import cached_property
from nmutil.plain_data import fields, plain_data
-from bigint_presentation_code.type_util import Self, assert_never, final
+from bigint_presentation_code.type_util import Self, assert_never, final, Literal
from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
assert_never(out.ty.base_ty)
+@final
+@unique
+@total_ordering
+class OpStage(Enum):
+ value: Literal[0, 1] # type: ignore
+
+ def __new__(cls, value):
+ # type: (int) -> OpStage
+ value = int(value)
+ if value not in (0, 1):
+ raise ValueError("invalid value")
+ retval = object.__new__(cls)
+ retval._value_ = value
+ return retval
+
+ Early = 0
+ """ early stage of Op execution, where all input reads occur.
+ all output writes with `write_stage == Early` occur here too, and therefore
+ conflict with input reads, telling the compiler that it that can't share
+ that output's register with any inputs that the output isn't tied to.
+
+ All outputs, even unused outputs, can't share registers with any other
+ outputs, independent of `write_stage` settings.
+ """
+ Late = 1
+ """ late stage of Op execution, where all output writes with
+ `write_stage == Late` occur, and therefore don't conflict with input reads,
+ telling the compiler that any inputs can safely use the same register as
+ those outputs.
+
+ All outputs, even unused outputs, can't share registers with any other
+ outputs, independent of `write_stage` settings.
+ """
+
+ def __repr__(self):
+ # type: () -> str
+ return f"OpStage.{self._name_}"
+
+ def __lt__(self, other):
+ # type: (OpStage | object) -> bool
+ if isinstance(other, OpStage):
+ return self.value < other.value
+ return NotImplemented
+
+
+assert OpStage.Early < OpStage.Late, "early must be less than late"
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+@total_ordering
+class ProgramPoint:
+ __slots__ = "op_index", "stage"
+
+ def __init__(self, op_index, stage):
+ # type: (int, OpStage) -> None
+ self.op_index = op_index
+ self.stage = stage
+
+ @property
+ def int_value(self):
+ # type: () -> int
+ """ an integer representation of `self` such that it keeps ordering and
+ successor/predecessor relations.
+ """
+ return self.op_index * 2 + self.stage.value
+
+ @staticmethod
+ def from_int_value(int_value):
+ # type: (int) -> ProgramPoint
+ op_index, stage = divmod(int_value, 2)
+ return ProgramPoint(op_index=op_index, stage=OpStage(stage))
+
+ def next(self, steps=1):
+ # type: (int) -> ProgramPoint
+ return ProgramPoint.from_int_value(self.int_value + steps)
+
+ def prev(self, steps=1):
+ # type: (int) -> ProgramPoint
+ return self.next(steps=-steps)
+
+ def __lt__(self, other):
+ # type: (ProgramPoint | Any) -> bool
+ if not isinstance(other, ProgramPoint):
+ return NotImplemented
+ if self.op_index != other.op_index:
+ return self.op_index < other.op_index
+ return self.stage < other.stage
+
+ def __repr__(self):
+ # type: () -> str
+ return f"<ops[{self.op_index}]:{self.stage._name_}>"
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class ProgramRange(Sequence[ProgramPoint]):
+ __slots__ = "start", "stop"
+
+ def __init__(self, start, stop):
+ # type: (ProgramPoint, ProgramPoint) -> None
+ self.start = start
+ self.stop = stop
+
+ @cached_property
+ def int_value_range(self):
+ # type: () -> range
+ return range(self.start.int_value, self.stop.int_value)
+
+ @staticmethod
+ def from_int_value_range(int_value_range):
+ # type: (range) -> ProgramRange
+ if int_value_range.step != 1:
+ raise ValueError("int_value_range must have step == 1")
+ return ProgramRange(
+ start=ProgramPoint.from_int_value(int_value_range.start),
+ stop=ProgramPoint.from_int_value(int_value_range.stop))
+
+ @overload
+ def __getitem__(self, __idx):
+ # type: (int) -> ProgramPoint
+ ...
+
+ @overload
+ def __getitem__(self, __idx):
+ # type: (slice) -> ProgramRange
+ ...
+
+ def __getitem__(self, __idx):
+ # type: (int | slice) -> ProgramPoint | ProgramRange
+ v = range(self.start.int_value, self.stop.int_value)[__idx]
+ if isinstance(v, int):
+ return ProgramPoint.from_int_value(v)
+ return ProgramRange.from_int_value_range(v)
+
+ def __len__(self):
+ # type: () -> int
+ return len(self.int_value_range)
+
+ def __iter__(self):
+ # type: () -> Iterator[ProgramPoint]
+ return map(ProgramPoint.from_int_value, self.int_value_range)
+
+ def __repr__(self):
+ # type: () -> str
+ start = repr(self.start).lstrip("<").rstrip(">")
+ stop = repr(self.stop).lstrip("<").rstrip(">")
+ return f"<range:{start}..{stop}>"
+
+
@plain_data(frozen=True, eq=False)
@final
-class FnWithUses:
- __slots__ = "fn", "uses"
+class FnAnalysis:
+ __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
+ "def_program_ranges", "use_program_points",
+ "all_program_points")
def __init__(self, fn):
# type: (Fn) -> None
self.fn = fn
- retval = {} # type: dict[SSAVal, OSet[SSAUse]]
+ self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
+ self.all_program_points = ProgramRange(
+ start=ProgramPoint(op_index=0, stage=OpStage.Early),
+ stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
+ def_program_ranges = {} # type: dict[SSAVal, ProgramRange]
+ use_program_points = {} # type: dict[SSAUse, ProgramPoint]
+ uses = {} # type: dict[SSAVal, OSet[SSAUse]]
+ live_range_stops = {} # type: dict[SSAVal, ProgramPoint]
for op in fn.ops:
- for idx, inp in enumerate(op.input_vals):
- retval[inp].add(SSAUse(op, idx))
+ for use in op.input_uses:
+ uses[use.ssa_val].add(use)
+ use_program_point = self.__get_use_program_point(use)
+ use_program_points[use] = use_program_point
+ live_range_stops[use.ssa_val] = max(
+ live_range_stops[use.ssa_val], use_program_point.next())
for out in op.outputs:
- retval[out] = OSet()
- self.uses = FMap((k, OFSet(v)) for k, v in retval.items())
+ uses[out] = OSet()
+ def_program_range = self.__get_def_program_range(out)
+ def_program_ranges[out] = def_program_range
+ live_range_stops[out] = def_program_range.stop
+ self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
+ self.def_program_ranges = FMap(def_program_ranges)
+ self.use_program_points = FMap(use_program_points)
+ live_ranges = {} # type: dict[SSAVal, ProgramRange]
+ live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
+ for ssa_val in uses.keys():
+ live_ranges[ssa_val] = live_range = ProgramRange(
+ start=self.def_program_ranges[ssa_val].start,
+ stop=live_range_stops[ssa_val])
+ for program_point in live_range:
+ live_at[program_point].add(ssa_val)
+ self.live_ranges = FMap(live_ranges)
+ self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
+
+ def __get_def_program_range(self, ssa_val):
+ # type: (SSAVal) -> ProgramRange
+ write_stage = ssa_val.defining_descriptor.write_stage
+ start = ProgramPoint(
+ op_index=self.op_indexes[ssa_val.op], stage=write_stage)
+ # always include late stage of ssa_val.op, to ensure outputs always
+ # overlap all other outputs.
+ # stop is exclusive, so we need the next program point.
+ stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
+ return ProgramRange(start=start, stop=stop)
+
+ def __get_use_program_point(self, ssa_use):
+ # type: (SSAUse) -> ProgramPoint
+ assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
+ "assumed here, ensured by GenericOpProperties.__init__"
+ return ProgramPoint(
+ op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)
def __eq__(self, other):
- # type: (FnWithUses | Any) -> bool
- if isinstance(other, FnWithUses):
+ # type: (FnAnalysis | Any) -> bool
+ if isinstance(other, FnAnalysis):
return self.fn == other.fn
return NotImplemented
# type: () -> LocKind
# pyright fails typechecking when using `in` here:
# reported: https://github.com/microsoft/pyright/issues/4102
- if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
- or self is LocSubKind.SV_EXTRA2_SGPR \
- or self is LocSubKind.SV_EXTRA3_VGPR \
- or self is LocSubKind.SV_EXTRA3_SGPR:
+ if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
+ LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
+ LocSubKind.SV_EXTRA3_SGPR):
return LocKind.GPR
if self is LocSubKind.StackI64:
return LocKind.StackI64
def __hash__(self):
return self.__hash
+ @lru_cache(maxsize=None, typed=True)
+ def max_conflicts_with(self, other):
+ # type: (LocSet | Loc) -> int
+ """the largest number of Locs in `self` that a single Loc
+ from `other` can conflict with
+ """
+ if isinstance(other, LocSet):
+ return max(self.max_conflicts_with(i) for i in other)
+ else:
+ return sum(other.conflicts(i) for i in self)
+
@plain_data(frozen=True, unsafe_hash=True)
@final
class GenericOperandDesc:
"""generic Op operand descriptor"""
- __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"
+ __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
+ "write_stage")
def __init__(
self, ty, # type: GenericTy
fixed_loc=None, # type: Loc | None
tied_input_index=None, # type: int | None
spread=False, # type: bool
+ write_stage=OpStage.Early, # type: OpStage
):
# type: (...) -> None
self.ty = ty
raise ValueError("operand can't be both spread and fixed")
if self.ty.is_vec:
raise ValueError("operand can't be both spread and vector")
+ self.write_stage = write_stage
def tied_to_input(self, tied_input_index):
# type: (int) -> Self
return GenericOperandDesc(self.ty, self.sub_kinds,
- tied_input_index=tied_input_index)
+ tied_input_index=tied_input_index,
+ write_stage=self.write_stage)
def with_fixed_loc(self, fixed_loc):
# type: (Loc) -> Self
- return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)
+ return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
+ write_stage=self.write_stage)
+
+ def with_write_stage(self, write_stage):
+ # type: (OpStage) -> Self
+ return GenericOperandDesc(self.ty, self.sub_kinds,
+ fixed_loc=self.fixed_loc,
+ tied_input_index=self.tied_input_index,
+ spread=self.spread,
+ write_stage=write_stage)
def instantiate(self, maxvl):
# type: (int) -> Iterable[OperandDesc]
idx = None
yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
tied_input_index=self.tied_input_index,
- spread_index=idx)
+ spread_index=idx, write_stage=self.write_stage)
@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandDesc:
"""Op operand descriptor"""
- __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"
+ __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
+ "write_stage")
- def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
- # type: (LocSet, int | None, int | None) -> None
+ def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
+ write_stage):
+ # type: (LocSet, int | None, int | None, OpStage) -> None
if len(loc_set_before_spread) == 0:
raise ValueError("loc_set_before_spread must not be empty")
self.loc_set_before_spread = loc_set_before_spread
if self.tied_input_index is not None and self.spread_index is not None:
raise ValueError("operand can't be both spread and tied")
self.spread_index = spread_index
+ self.write_stage = write_stage
@cached_property
def ty_before_spread(self):
has_side_effects=False, # type: bool
):
# type: (...) -> None
- self.demo_asm = demo_asm
- self.inputs = tuple(inputs)
+ self.demo_asm = demo_asm # type: str
+ self.inputs = tuple(inputs) # type: tuple[GenericOperandDesc, ...]
for inp in self.inputs:
if inp.tied_input_index is not None:
raise ValueError(
f"tied_input_index is not allowed on inputs: {inp}")
- self.outputs = tuple(outputs)
+ if inp.write_stage is not OpStage.Early:
+ raise ValueError(
+ f"write_stage is not allowed on inputs: {inp}")
+ self.outputs = tuple(outputs) # type: tuple[GenericOperandDesc, ...]
fixed_locs = [] # type: list[tuple[Loc, int]]
for idx, out in enumerate(self.outputs):
if out.tied_input_index is not None:
f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
f"with {other_fixed_loc}")
fixed_locs.append((out.fixed_loc, idx))
- self.immediates = tuple(immediates)
- self.is_copy = is_copy
- self.is_load_immediate = is_load_immediate
- self.has_side_effects = has_side_effects
+ self.immediates = tuple(immediates) # type: tuple[range, ...]
+ self.is_copy = is_copy # type: bool
+ self.is_load_immediate = is_load_immediate # type: bool
+ self.has_side_effects = has_side_effects # type: bool
@plain_data(frozen=True, unsafe_hash=True)
def __init__(self, kind, maxvl):
# type: (OpKind, int) -> None
- self.kind = kind
+ self.kind = kind # type: OpKind
inputs = [] # type: list[OperandDesc]
for inp in self.generic.inputs:
inputs.extend(inp.instantiate(maxvl=maxvl))
- self.inputs = tuple(inputs)
+ self.inputs = tuple(inputs) # type: tuple[OperandDesc, ...]
outputs = [] # type: list[OperandDesc]
for out in self.generic.outputs:
outputs.extend(out.instantiate(maxvl=maxvl))
- self.outputs = tuple(outputs)
- self.maxvl = maxvl
+ self.outputs = tuple(outputs) # type: tuple[OperandDesc, ...]
+ self.maxvl = maxvl # type: int
@property
def generic(self):
return OpProperties(self, maxvl=maxvl)
def __repr__(self):
+ # type: () -> str
return "OpKind." + self._name_
@cached_property
ClearCA = GenericOpProperties(
demo_asm="addic 0, 0, 0",
inputs=[],
- outputs=[OD_CA],
+ outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
_PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
SetCA = GenericOpProperties(
demo_asm="subfc 0, 0, 0",
inputs=[],
- outputs=[OD_CA],
+ outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
_PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
SetVLI = GenericOpProperties(
demo_asm="setvl 0, 0, imm, 0, 1, 1",
inputs=(),
- outputs=[OD_VL],
+ outputs=[OD_VL.with_write_stage(OpStage.Late)],
immediates=[range(1, 65)],
is_load_immediate=True,
)
LI = GenericOpProperties(
demo_asm="addi RT, 0, imm",
inputs=(),
- outputs=[OD_BASE_SGPR],
+ outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
immediates=[IMM_S16],
is_load_immediate=True,
)
ty=GenericTy(BaseTy.I64, is_vec=True),
sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
), OD_VL],
- outputs=[OD_EXTRA3_VGPR],
+ outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
is_copy=True,
)
_PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
outputs=[GenericOperandDesc(
ty=GenericTy(BaseTy.I64, is_vec=True),
sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
+ write_stage=OpStage.Late,
)],
is_copy=True,
)
outputs=[GenericOperandDesc(
ty=GenericTy(BaseTy.I64, is_vec=False),
sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
+ write_stage=OpStage.Late,
)],
is_copy=True,
)
ty=GenericTy(BaseTy.I64, is_vec=False),
sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
LocSubKind.StackI64],
+ write_stage=OpStage.Late,
)],
is_copy=True,
)
sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
spread=True,
), OD_VL],
- outputs=[OD_EXTRA3_VGPR],
+ outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
is_copy=True,
)
_PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
ty=GenericTy(BaseTy.I64, is_vec=False),
sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
spread=True,
+ write_stage=OpStage.Late,
)],
is_copy=True,
)
Ld = GenericOpProperties(
demo_asm="ld RT, imm(RA)",
inputs=[OD_BASE_SGPR],
- outputs=[OD_BASE_SGPR],
+ outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
immediates=[IMM_S16],
)
_PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
def __init__(self, op, operand_idx):
# type: (Op, int) -> None
+ super().__init__()
self.op = op
if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
raise ValueError("invalid operand_idx")
# type: () -> tuple[OperandDesc, ...]
...
- @property
+ @cached_property
def defining_descriptor(self):
# type: () -> OperandDesc
return self.descriptor_array[self.operand_idx]
return SSAUse(op=self.op,
operand_idx=self.defining_descriptor.tied_input_index)
+ @property
+ def write_stage(self):
+ # type: () -> OpStage
+ return self.defining_descriptor.write_stage
+
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
def __init__(self, items, op):
# type: (Iterable[_T], Op) -> None
+ super().__init__()
self.__op = op
self.__items = [] # type: list[_T]
for idx, item in enumerate(items):
if idx >= len(self.descriptors):
raise ValueError("too many items")
- self._verify_write(idx, item)
+ _ = self._verify_write(idx, item)
self.__items.append(item)
if len(self.__items) < len(self.descriptors):
raise ValueError("not enough items")
return len(self.__items)
def __repr__(self):
+ # type: () -> str
return f"{self.__class__.__name__}({self.__items}, op=...)"
@property
def kind(self):
+ # type: () -> OpKind
return self.properties.kind
def __eq__(self, other):
return NotImplemented
def __hash__(self):
+ # type: () -> int
return object.__hash__(self)
def __repr__(self):
def __init__(self, ssa_vals, memory):
# type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
- self.ssa_vals = ssa_vals
- self.memory = memory
+ self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]]
+ self.memory = memory # type: dict[int, int]
def load_byte(self, addr):
# type: (int) -> int
"""
from itertools import combinations
-from typing import Any, Generic, Iterable, Iterator, Mapping, MutableSet
+from typing import Any, Iterable, Iterator, Mapping, MutableSet
from cached_property import cached_property
from nmutil.plain_data import plain_data
-from bigint_presentation_code.compiler_ir2 import (BaseTy, FnWithUses, Loc,
- LocSet, Op, SSAVal, Ty)
+from bigint_presentation_code.compiler_ir2 import (BaseTy, FnAnalysis, Loc,
+ LocSet, Op, ProgramRange,
+ SSAVal, Ty)
from bigint_presentation_code.type_util import final
from bigint_presentation_code.util import FMap, OFSet, OSet
def __init__(self, first_write, last_use=None):
# type: (int, int | None) -> None
+ super().__init__()
if last_use is None:
last_use = first_write
if last_use < first_write:
* `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
* `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
"""
- __slots__ = "fn_with_uses", "ssa_val_offsets", "first_ssa_val", "loc_set"
+ __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
- def __init__(self, fn_with_uses, ssa_val_offsets):
- # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
- self.fn_with_uses = fn_with_uses
+ def __init__(self, fn_analysis, ssa_val_offsets):
+ # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
+ self.fn_analysis = fn_analysis
if isinstance(ssa_val_offsets, SSAVal):
ssa_val_offsets = {ssa_val_offsets: 0}
self.ssa_val_offsets = FMap(ssa_val_offsets) # type: FMap[SSAVal, int]
# type: () -> Iterable[Loc]
for loc in ssa_val.def_loc_set_before_spread:
disallowed_by_use = False
- for use in fn_with_uses.uses[ssa_val]:
+ for use in fn_analysis.uses[ssa_val]:
use_spread_idx = \
use.defining_descriptor.spread_index or 0
# calculate the start for the use's Loc before spread
@cached_property
def __hash(self):
# type: () -> int
- return hash((self.fn_with_uses, self.ssa_val_offsets))
+ return hash((self.fn_analysis, self.ssa_val_offsets))
def __hash__(self):
# type: () -> int
def offset_by(self, amount):
# type: (int) -> MergedSSAVal
v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
- return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v)
+ return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)
def normalized(self):
# type: () -> MergedSSAVal
# type: (*MergedSSAVal) -> MergedSSAVal
retval = dict(self.ssa_val_offsets)
for other in others:
- if other.fn_with_uses != self.fn_with_uses:
- raise ValueError("fn_with_uses mismatch")
+ if other.fn_analysis != self.fn_analysis:
+ raise ValueError("fn_analysis mismatch")
for ssa_val, offset in other.ssa_val_offsets.items():
if ssa_val in retval and retval[ssa_val] != offset:
raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
f"{retval[ssa_val]} != {offset}")
retval[ssa_val] = offset
- return MergedSSAVal(fn_with_uses=self.fn_with_uses,
+ return MergedSSAVal(fn_analysis=self.fn_analysis,
ssa_val_offsets=retval)
+ @cached_property
+ def live_interval(self):
+ # type: () -> ProgramRange
+ live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
+ start = live_range.start
+ stop = live_range.stop
+ for ssa_val in self.ssa_vals:
+ live_range = self.fn_analysis.live_ranges[ssa_val]
+ start = min(start, live_range.start)
+ stop = max(stop, live_range.stop)
+ return ProgramRange(start=start, stop=stop)
+
@final
class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]):
@plain_data(frozen=True)
@final
class MergedSSAVals:
- __slots__ = "fn_with_uses", "merge_map", "merged_ssa_vals"
+ __slots__ = "fn_analysis", "merge_map", "merged_ssa_vals"
- def __init__(self, fn_with_uses, merged_ssa_vals):
- # type: (FnWithUses, Iterable[MergedSSAVal]) -> None
- self.fn_with_uses = fn_with_uses
+ def __init__(self, fn_analysis, merged_ssa_vals):
+ # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
+ self.fn_analysis = fn_analysis
self.merge_map = MergedSSAValsMap()
self.merged_ssa_vals = self.merge_map.values_set
for i in merged_ssa_vals:
return merged
@staticmethod
- def minimally_merged(fn_with_uses):
- # type: (FnWithUses) -> MergedSSAVals
- retval = MergedSSAVals(fn_with_uses=fn_with_uses, merged_ssa_vals=())
- for op in fn_with_uses.fn.ops:
+ def minimally_merged(fn_analysis):
+ # type: (FnAnalysis) -> MergedSSAVals
+ retval = MergedSSAVals(fn_analysis=fn_analysis, merged_ssa_vals=())
+ for op in fn_analysis.fn.ops:
for inp in op.input_uses:
if inp.unspread_start != inp:
retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
return retval
-# FIXME: work on code from here
-
-
-@final
-class LiveIntervals(Mapping[MergedSSAVal, LiveInterval]):
- def __init__(self, merged_ssa_vals):
- # type: (list[Op]) -> None
- self.__merged_reg_sets = MergedRegSets(ops)
- live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval]
- for op_idx, op in enumerate(ops):
- for val in op.inputs().values():
- live_intervals[self.__merged_reg_sets[val]] += op_idx
- for val in op.outputs().values():
- reg_set = self.__merged_reg_sets[val]
- if reg_set not in live_intervals:
- live_intervals[reg_set] = LiveInterval(op_idx)
- else:
- live_intervals[reg_set] += op_idx
- self.__live_intervals = live_intervals
- live_after = [] # type: list[OSet[MergedRegSet[_RegType]]]
- live_after += (OSet() for _ in ops)
- for reg_set, live_interval in self.__live_intervals.items():
- for i in live_interval.live_after_op_range:
- live_after[i].add(reg_set)
- self.__live_after = [OFSet(i) for i in live_after]
-
- @property
- def merged_reg_sets(self):
- return self.__merged_reg_sets
-
- def __getitem__(self, key):
- # type: (MergedRegSet[_RegType]) -> LiveInterval
- return self.__live_intervals[key]
-
- def __iter__(self):
- return iter(self.__live_intervals)
-
- def __len__(self):
- return len(self.__live_intervals)
-
- def reg_sets_live_after(self, op_index):
- # type: (int) -> OFSet[MergedRegSet[_RegType]]
- return self.__live_after[op_index]
-
- def __repr__(self):
- reg_sets_live_after = dict(enumerate(self.__live_after))
- return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
- f"merged_reg_sets={self.merged_reg_sets}, "
- f"reg_sets_live_after={reg_sets_live_after})")
-
-
@final
-class IGNode(Generic[_RegType]):
+class IGNode:
""" interference graph node """
- __slots__ = "merged_reg_set", "edges", "reg"
+ __slots__ = "merged_ssa_val", "edges", "loc"
- def __init__(self, merged_reg_set, edges=(), reg=None):
- # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
- self.merged_reg_set = merged_reg_set
+ def __init__(self, merged_ssa_val, edges=(), loc=None):
+ # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
+ self.merged_ssa_val = merged_ssa_val
self.edges = OSet(edges)
- self.reg = reg
+ self.loc = loc
def add_edge(self, other):
# type: (IGNode) -> None
def __eq__(self, other):
# type: (object) -> bool
if isinstance(other, IGNode):
- return self.merged_reg_set == other.merged_reg_set
+ return self.merged_ssa_val == other.merged_ssa_val
return NotImplemented
def __hash__(self):
- return hash(self.merged_reg_set)
+ return hash(self.merged_ssa_val)
def __repr__(self, nodes=None):
# type: (None | dict[IGNode, int]) -> str
nodes[self] = len(nodes)
edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
return (f"IGNode(#{nodes[self]}, "
- f"merged_reg_set={self.merged_reg_set}, "
+ f"merged_ssa_val={self.merged_ssa_val}, "
f"edges={edges}, "
- f"reg={self.reg})")
+ f"loc={self.loc})")
@property
- def reg_class(self):
- # type: () -> RegClass
- return self.merged_reg_set.ty.reg_class
+ def loc_set(self):
+ # type: () -> LocSet
+ return self.merged_ssa_val.loc_set
- def reg_conflicts_with_neighbors(self, reg):
- # type: (RegLoc) -> bool
+ def loc_conflicts_with_neighbors(self, loc):
+ # type: (Loc) -> bool
for neighbor in self.edges:
- if neighbor.reg is not None and neighbor.reg.conflicts(reg):
+ if neighbor.loc is not None and neighbor.loc.conflicts(loc):
return True
return False
-@final
-class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
- def __init__(self, merged_reg_sets):
- # type: (Iterable[MergedRegSet[_RegType]]) -> None
- self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
-
- def __getitem__(self, key):
- # type: (MergedRegSet[_RegType]) -> IGNode
- return self.__nodes[key]
-
- def __iter__(self):
- return iter(self.__nodes)
-
- def __len__(self):
- return len(self.__nodes)
-
- def __repr__(self):
- nodes = {}
- nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
- nodes_text = ", ".join(nodes_text)
- return f"InterferenceGraph(nodes={{{nodes_text}}})"
-
-
@plain_data()
class AllocationFailed:
- __slots__ = "node", "live_intervals", "interference_graph"
+ __slots__ = "node", "merged_ssa_vals", "interference_graph"
- def __init__(self, node, live_intervals, interference_graph):
- # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+ def __init__(self, node, merged_ssa_vals, interference_graph):
+ # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None
+ super().__init__()
self.node = node
- self.live_intervals = live_intervals
+ self.merged_ssa_vals = merged_ssa_vals
self.interference_graph = interference_graph
self.allocation_failed = allocation_failed
-def try_allocate_registers_without_spilling(ops):
- # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
+def try_allocate_registers_without_spilling(merged_ssa_vals):
+ # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed
- live_intervals = LiveIntervals(ops)
- merged_reg_sets = live_intervals.merged_reg_sets
- interference_graph = InterferenceGraph(merged_reg_sets.values())
- for op_idx, op in enumerate(ops):
- reg_sets = live_intervals.reg_sets_live_after(op_idx)
- for i, j in combinations(reg_sets, 2):
- if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
- interference_graph[i].add_edge(interference_graph[j])
- for i, j in op.get_extra_interferences():
- i = merged_reg_sets[i]
- j = merged_reg_sets[j]
- if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+ interference_graph = {
+ i: IGNode(i) for i in merged_ssa_vals.merged_ssa_vals}
+ fn_analysis = merged_ssa_vals.fn_analysis
+ for ssa_vals in fn_analysis.live_at.values():
+ live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal]
+ for ssa_val in ssa_vals:
+ live_merged_ssa_vals.add(merged_ssa_vals.merge_map[ssa_val])
+ for i, j in combinations(live_merged_ssa_vals, 2):
+ if i.loc_set.max_conflicts_with(j.loc_set) != 0:
interference_graph[i].add_edge(interference_graph[j])
nodes_remaining = OSet(interference_graph.values())
+# FIXME: work on code from here
+
def local_colorability_score(node):
# type: (IGNode) -> int
""" returns a positive integer if node is locally colorable, returns
"""
if node not in nodes_remaining:
raise ValueError()
- retval = len(node.reg_class)
+ retval = len(node.loc_set)
for neighbor in node.edges:
if neighbor in nodes_remaining:
retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)