import unittest
from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
- FnAnalysis, OpKind, OpStage,
+ FnAnalysis, GenAsmState, Loc, LocKind, OpKind, OpStage,
PreRASimState, ProgramPoint,
SSAVal)
def test_fn_analysis(self):
fn, _arg = self.make_add_fn()
fn_analysis = FnAnalysis(fn)
- print(repr(fn_analysis))
self.assertEqual(
- repr(fn_analysis),
- "FnAnalysis(fn=<Fn>, uses=FMap({"
+ repr(fn_analysis.uses),
+ "FMap({"
"<arg.outputs[0]: <I64>>: OFSet(["
"<ld.input_uses[0]: <I64>>, <st.input_uses[1]: <I64>>]), "
"<vl.outputs[0]: <VL_MAXVL>>: OFSet(["
"<ca.outputs[0]: <CA>>: OFSet([<add.input_uses[2]: <CA>>]), "
"<add.outputs[0]: <I64*32>>: OFSet(["
"<st.input_uses[0]: <I64*32>>]), "
- "<add.outputs[1]: <CA>>: OFSet()}), "
- "op_indexes=FMap({"
+ "<add.outputs[1]: <CA>>: OFSet()})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.op_indexes),
+ "FMap({"
"Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), "
"immediates=[], outputs=(<arg.outputs[0]: <I64>>,), "
"name='arg'): 0, "
"<vl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<st.input_uses[0]: <I64*32>>, "
"<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
- "immediates=[0], outputs=(), name='st'): 6}), "
- "live_ranges=FMap({"
+ "immediates=[0], outputs=(), name='st'): 6})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.live_ranges),
+ "FMap({"
"<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[6]:Late>, "
"<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[6]:Late>, "
"<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[5]:Late>, "
"<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[5]:Late>, "
"<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Late>, "
"<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Late>, "
- "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
- "live_at=FMap({"
+ "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.live_at),
+ "FMap({"
"<ops[0]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
"<ops[0]:Late>: OFSet([<arg.outputs[0]: <I64>>]), "
"<ops[1]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
"<ops[6]:Early>: OFSet(["
"<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
"<add.outputs[0]: <I64*32>>]), "
- "<ops[6]:Late>: OFSet()}), "
- "def_program_ranges=FMap({"
+ "<ops[6]:Late>: OFSet()})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.def_program_ranges),
+ "FMap({"
"<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[1]:Early>, "
"<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[2]:Early>, "
"<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[3]:Early>, "
"<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[4]:Early>, "
"<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Early>, "
"<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Early>, "
- "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
- "use_program_points=FMap({"
+ "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.use_program_points),
+ "FMap({"
"<ld.input_uses[0]: <I64>>: <ops[2]:Early>, "
"<ld.input_uses[1]: <VL_MAXVL>>: <ops[2]:Early>, "
"<li.input_uses[0]: <VL_MAXVL>>: <ops[3]:Early>, "
"<add.input_uses[3]: <VL_MAXVL>>: <ops[5]:Early>, "
"<st.input_uses[0]: <I64*32>>: <ops[6]:Early>, "
"<st.input_uses[1]: <I64>>: <ops[6]:Early>, "
- "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>}), "
- "all_program_points=<range:ops[0]:Early..ops[7]:Early>)")
+ "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>})"
+ )
+ self.assertEqual(
+ repr(fn_analysis.all_program_points),
+ "<range:ops[0]:Early..ops[7]:Early>")
def test_repr(self):
fn, _arg = self.make_add_fn()
"write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None, "
+ "tied_input_index=2, spread_index=None, "
"write_stage=OpStage.Early)), maxvl=32)",
"OpProperties(kind=OpKind.SvStd, "
"inputs=("
"input_uses=(<ld.inp0.copy.input_uses[0]: <I64>>,), "
"immediates=[], "
"outputs=(<ld.inp0.copy.outputs[0]: <I64>>,), name='ld.inp0.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='ld.inp1.setvl')",
"Op(kind=OpKind.SvLd, "
"input_vals=[<ld.inp0.copy.outputs[0]: <I64>>, "
- "<vl.outputs[0]: <VL_MAXVL>>], "
+ "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<ld.input_uses[0]: <I64>>, "
"<ld.input_uses[1]: <VL_MAXVL>>), "
"immediates=[0], "
"immediates=[], "
"outputs=(<ld.out0.copy.outputs[0]: <I64*32>>,), "
"name='ld.out0.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<li.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='li.inp0.setvl')",
"Op(kind=OpKind.SvLI, "
- "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+ "input_vals=[<li.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
"immediates=[0], "
"outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
"immediates=[], "
"outputs=(<add.inp1.copy.outputs[0]: <I64*32>>,), "
"name='add.inp1.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<add.inp3.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='add.inp3.setvl')",
"Op(kind=OpKind.SvAddE, "
"input_vals=[<add.inp0.copy.outputs[0]: <I64*32>>, "
"<add.inp1.copy.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
- "<vl.outputs[0]: <VL_MAXVL>>], "
+ "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<add.input_uses[0]: <I64*32>>, "
"<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
"<add.input_uses[3]: <VL_MAXVL>>), "
"immediates=[], "
"outputs=(<st.inp1.copy.outputs[0]: <I64>>,), "
"name='st.inp1.copy')",
+ "Op(kind=OpKind.SetVLI, "
+ "input_vals=[], "
+ "input_uses=(), "
+ "immediates=[32], "
+ "outputs=(<st.inp2.setvl.outputs[0]: <VL_MAXVL>>,), "
+ "name='st.inp2.setvl')",
"Op(kind=OpKind.SvStd, "
"input_vals=[<st.inp0.copy.outputs[0]: <I64*32>>, "
- "<st.inp1.copy.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+ "<st.inp1.copy.outputs[0]: <I64>>, "
+ "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>], "
"input_uses=(<st.input_uses[0]: <I64*32>>, "
"<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
"immediates=[0], "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
- "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
+ "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
- "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
+ "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Early),), "
"outputs=("
"ty=<I64>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvLd, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvLI, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.SetCA, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Late),), maxvl=32)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvAddE, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.CA: FBitSet([0])}), ty=<CA>), "
- "tied_input_index=None, spread_index=None, "
+ "tied_input_index=2, spread_index=None, "
"write_stage=OpStage.Early)), maxvl=32)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=(), "
"outputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Late),), maxvl=32)",
"OpProperties(kind=OpKind.SetVLI, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet(range(14, 97)), "
- "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
+ "LocKind.StackI64: FBitSet(range(0, 481))}), ty=<I64*32>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Early), "
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
- "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
+ "LocKind.StackI64: FBitSet(range(0, 512))}), ty=<I64>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Early),), "
"outputs=("
"ty=<I64>), "
"tied_input_index=None, spread_index=None, "
"write_stage=OpStage.Late),), maxvl=1)",
+ "OpProperties(kind=OpKind.SetVLI, "
+ "inputs=(), "
+ "outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
+ "tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),), maxvl=1)",
"OpProperties(kind=OpKind.SvStd, "
"inputs=("
"OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
"0x001f0: <0x0000000000000000>,\n"
"0x001f8: <0x0000000000000000>})")
+ def test_gen_asm(self):
+ fn, _arg = self.make_add_fn()
+ fn.pre_ra_insert_copies()
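+ # hardcoded register assignments for every op output, so assembly
+ # generation can be tested without running the register allocator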
+ VL_LOC = Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)
+ CA_LOC = Loc(kind=LocKind.CA, start=0, reg_len=1)
+ state = GenAsmState(allocated_locs={
+ fn.ops[0].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[1].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[2].outputs[0]: VL_LOC,
+ fn.ops[3].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[4].outputs[0]: VL_LOC,
+ fn.ops[5].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[6].outputs[0]: VL_LOC,
+ fn.ops[7].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[8].outputs[0]: VL_LOC,
+ fn.ops[9].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+ fn.ops[10].outputs[0]: VL_LOC,
+ fn.ops[11].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+ fn.ops[12].outputs[0]: CA_LOC,
+ fn.ops[13].outputs[0]: VL_LOC,
+ fn.ops[14].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[15].outputs[0]: VL_LOC,
+ fn.ops[16].outputs[0]: Loc(kind=LocKind.GPR, start=64, reg_len=32),
+ fn.ops[17].outputs[0]: VL_LOC,
+ fn.ops[18].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[18].outputs[1]: CA_LOC,
+ fn.ops[19].outputs[0]: VL_LOC,
+ fn.ops[20].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[21].outputs[0]: VL_LOC,
+ fn.ops[22].outputs[0]: Loc(kind=LocKind.GPR, start=32, reg_len=32),
+ fn.ops[23].outputs[0]: Loc(kind=LocKind.GPR, start=3, reg_len=1),
+ fn.ops[24].outputs[0]: VL_LOC,
+ })
+ fn.gen_asm(state)
+ self.assertEqual(state.output, [
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.ld *32, 0(3)',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.addi *64, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'subfc 0, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.adde *32, *32, *64',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.std *32, 0(3)',
+ ])
+
if __name__ == "__main__":
_ = unittest.main()
--- /dev/null
+import unittest
+
+from bigint_presentation_code.compiler_ir2 import Fn, GenAsmState, OpKind, SSAVal
+from bigint_presentation_code.register_allocator2 import allocate_registers
+
+
+class TestCompilerIR(unittest.TestCase):
+ maxDiff = None
+
+ def make_add_fn(self):
+ # type: () -> tuple[Fn, SSAVal]
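+ # build a small test function: load a 32-element bigint from the
+ # address in r3, add 0 to it with the carry flag set, store it back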
+ fn = Fn()
+ op0 = fn.append_new_op(OpKind.FuncArgR3, name="arg")
+ arg = op0.outputs[0]
+ MAXVL = 32
+ op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
+ vl = op1.outputs[0]
+ op2 = fn.append_new_op(
+ OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
+ name="ld")
+ a = op2.outputs[0]
+ op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ maxvl=MAXVL, name="li")
+ b = op3.outputs[0]
+ op4 = fn.append_new_op(OpKind.SetCA, name="ca")
+ ca = op4.outputs[0]
+ op5 = fn.append_new_op(
+ OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
+ s = op5.outputs[0]
+ _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+ immediates=[0], maxvl=MAXVL, name="st")
+ return fn, arg
+
+ def test_register_allocate(self):
+ fn, _arg = self.make_add_fn()
+ reg_assignments = allocate_registers(fn)
+
+ self.assertEqual(
+ repr(reg_assignments),
+ "{<add.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.inp1.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<add.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
+ "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<st.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<st.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ca.outputs[0]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<vl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<arg.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+ )
+
+ def test_gen_asm(self):
+ fn, _arg = self.make_add_fn()
+ reg_assignments = allocate_registers(fn)
+
+ self.assertEqual(
+ repr(reg_assignments),
+ "{<add.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.inp1.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<add.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=78, reg_len=32), "
+ "<st.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<st.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<st.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<add.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ca.outputs[0]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=32), "
+ "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<ld.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ld.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<vl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<arg.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<arg.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1)}"
+ )
+ state = GenAsmState(reg_assignments)
+ fn.gen_asm(state)
+ self.assertEqual(state.output, [
+ 'or 4, 3, 3',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'or 3, 4, 4',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.ld *14, 0(3)',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.or *46, *14, *14',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.addi *14, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'subfc 0, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.or *78, *46, *46',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.or *46, *14, *14',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.adde *14, *78, *46',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'or 3, 4, 4',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.std *14, 0(3)',
+ ])
+
+
+if __name__ == "__main__":
+ _ = unittest.main()
from functools import lru_cache, total_ordering
from io import StringIO
from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
- Mapping, Sequence, TypeVar, overload)
+ Mapping, Sequence, TypeVar, Union, overload)
from weakref import WeakValueDictionary as _WeakVDict
from cached_property import cached_property
for op in self.ops:
op.pre_ra_sim(state)
+ def gen_asm(self, state):
+ # type: (GenAsmState) -> None
+ for op in self.ops:
+ op.gen_asm(state)
+
def pre_ra_insert_copies(self):
# type: () -> None
orig_ops = list(self.ops)
copied_outputs = {} # type: dict[SSAVal, SSAVal]
+ setvli_outputs = {} # type: dict[SSAVal, Op]
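+ # records which SSAVals were produced by SetVLI ops so they can be
+ # rematerialized immediately before each op that uses them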
self.ops.clear()
for op in orig_ops:
for i in range(len(op.input_vals)):
op.input_vals[i] = mv.outputs[0]
elif inp.ty.base_ty is BaseTy.CA \
or inp.ty.base_ty is BaseTy.VL_MAXVL:
- # all copies would be no-ops, so we don't need to copy
+ # all copies would be no-ops, so we don't need to copy,
+ # though we do need to rematerialize SetVLI ops right
+ # before the ops that use the VL value
+ if inp in setvli_outputs:
+ setvl = self.append_new_op(
+ OpKind.SetVLI,
+ immediates=setvli_outputs[inp].immediates,
+ name=f"{op.name}.inp{i}.setvl")
+ inp = setvl.outputs[0]
op.input_vals[i] = inp
else:
assert_never(inp.ty.base_ty)
self.ops.append(op)
for i, out in enumerate(op.outputs):
+ if op.kind is OpKind.SetVLI:
+ setvli_outputs[out] = op
if out.ty.base_ty is BaseTy.I64:
maxvl = out.ty.reg_len
if out.ty.reg_len != 1:
return f"<range:{start}..{stop}>"
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
@final
class FnAnalysis:
__slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
# type: () -> int
return hash(self.fn)
+ def __repr__(self):
+ # type: () -> str
+ return "<FnAnalysis>"
+
@unique
@final
def loc_count(self):
# type: () -> int
if self is LocKind.StackI64:
- return 1024
+ return 512
if self is LocKind.GPR or self is LocKind.CA \
or self is LocKind.VL_MAXVL:
return self.base_ty.max_reg_len
# type: (Ty, int) -> Loc
if subloc_ty.base_ty != self.kind.base_ty:
raise ValueError("BaseTy mismatch")
- start = self.start + offset
- if offset < 0 or start + subloc_ty.reg_len > self.reg_len:
+ if offset < 0 or offset + subloc_ty.reg_len > self.reg_len:
raise ValueError("invalid sub-Loc: offset and/or "
"subloc_ty.reg_len out of range")
- return Loc(kind=self.kind, start=start, reg_len=subloc_ty.reg_len)
+ return Loc(kind=self.kind,
+ start=self.start + offset, reg_len=subloc_ty.reg_len)
SPECIAL_GPRS = (
def __len__(self):
return self.__len
+ __HASHES = {} # type: dict[tuple[Ty | None, FMap[LocKind, FBitSet]], int]
+
@cached_property
def __hash(self):
- return super()._hash()
+ # cache hashes to avoid slow LocSet iteration
+ key = self.ty, self.starts
+ retval = self.__HASHES.get(key, None)
+ if retval is None:
+ self.__HASHES[key] = retval = super(LocSet, self)._hash()
+ return retval
def __hash__(self):
return self.__hash
+ def __eq__(self, __other):
+ # type: (LocSet | Any) -> bool
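+ # compare by (ty, starts) directly; this is consistent with __hash__
+ # and avoids the much slower element-by-element set comparison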
+ if isinstance(__other, LocSet):
+ return self.ty == __other.ty and self.starts == __other.starts
+ return super().__eq__(__other)
+
@lru_cache(maxsize=None, typed=True)
def max_conflicts_with(self, other):
# type: (LocSet | Loc) -> int
raise ValueError("loc_set_before_spread must not be empty")
self.loc_set_before_spread = loc_set_before_spread
self.tied_input_index = tied_input_index
- if self.tied_input_index is not None and self.spread_index is not None:
+ if self.tied_input_index is not None and spread_index is not None:
raise ValueError("operand can't be both spread and tied")
self.spread_index = spread_index
self.write_stage = write_stage
SvAddE = GenericOpProperties(
demo_asm="sv.adde *RT, *RA, *RB",
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
- outputs=[OD_EXTRA3_VGPR, OD_CA],
+ outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
_GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
SvSubFE = GenericOpProperties(
demo_asm="sv.subfe *RT, *RA, *RB",
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
- outputs=[OD_EXTRA3_VGPR, OD_CA],
+ outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
_PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
_GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
@staticmethod
- def __veccopytoreg_gen_asm(op, state):
- # type: (Op, GenAsmState) -> None
- src_loc = state.loc(op.input_vals[0], (LocKind.GPR, LocKind.StackI64))
- dest_loc = state.loc(op.outputs[0], LocKind.GPR)
- RT = state.vgpr(dest_loc)
+ def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
+ # type: (Loc, Loc, bool, GenAsmState) -> None
+ sv = "sv." if is_vec else ""
+ rev = ""
+ if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
+ rev = "/mrr"
if src_loc == dest_loc:
return # no-op
- assert src_loc.kind in (LocKind.GPR, LocKind.StackI64), \
- "checked by loc()"
+ if src_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+ raise ValueError(f"invalid src_loc.kind: {src_loc.kind}")
+ if dest_loc.kind not in (LocKind.GPR, LocKind.StackI64):
+ raise ValueError(f"invalid dest_loc.kind: {dest_loc.kind}")
if src_loc.kind is LocKind.StackI64:
+ if dest_loc.kind is LocKind.StackI64:
+ raise ValueError(
+ f"can't copy from stack to stack: {src_loc} {dest_loc}")
+ elif dest_loc.kind is not LocKind.GPR:
+ assert_never(dest_loc.kind)
src = state.stack(src_loc)
- state.writeln(f"sv.ld {RT}, {src}")
- return
- elif src_loc.kind is not LocKind.GPR:
+ dest = state.gpr(dest_loc, is_vec=is_vec)
+ state.writeln(f"{sv}ld {dest}, {src}")
+ elif dest_loc.kind is LocKind.StackI64:
+ if src_loc.kind is not LocKind.GPR:
+ assert_never(src_loc.kind)
+ src = state.gpr(src_loc, is_vec=is_vec)
+ dest = state.stack(dest_loc)
+ state.writeln(f"{sv}std {src}, {dest}")
+ elif src_loc.kind is LocKind.GPR:
+ if dest_loc.kind is not LocKind.GPR:
+ assert_never(dest_loc.kind)
+ src = state.gpr(src_loc, is_vec=is_vec)
+ dest = state.gpr(dest_loc, is_vec=is_vec)
+ state.writeln(f"{sv}or{rev} {dest}, {src}, {src}")
+ else:
assert_never(src_loc.kind)
- rev = ""
- if src_loc.conflicts(dest_loc) and src_loc.start < dest_loc.start:
- rev = "/mrr"
- src = state.vgpr(src_loc)
- state.writeln(f"sv.or{rev} {RT}, {src}, {src}")
+
+ @staticmethod
+ def __veccopytoreg_gen_asm(op, state):
+ # type: (Op, GenAsmState) -> None
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(
+ op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=True, state=state)
VecCopyToReg = GenericOpProperties(
demo_asm="sv.mv dest, src",
# type: (Op, PreRASimState) -> None
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __veccopyfromreg_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(
+ op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+ is_vec=True, state=state)
VecCopyFromReg = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[OD_EXTRA3_VGPR, OD_VL],
# type: (Op, PreRASimState) -> None
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __copytoreg_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(
+ op.input_vals[0], (LocKind.GPR, LocKind.StackI64)),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=False, state=state)
CopyToReg = GenericOpProperties(
demo_asm="mv dest, src",
inputs=[GenericOperandDesc(
# type: (Op, PreRASimState) -> None
state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __copyfromreg_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(
+ op.outputs[0], (LocKind.GPR, LocKind.StackI64)),
+ is_vec=False, state=state)
CopyFromReg = GenericOpProperties(
demo_asm="mv dest, src",
inputs=[GenericOperandDesc(
state.ssa_vals[op.outputs[0]] = tuple(
state.ssa_vals[i][0] for i in op.input_vals[:-1])
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __concat_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0:-1], LocKind.GPR),
+ dest_loc=state.loc(op.outputs[0], LocKind.GPR),
+ is_vec=True, state=state)
Concat = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[GenericOperandDesc(
for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
state.ssa_vals[op.outputs[idx]] = inp,
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __spread_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ OpKind.__copy_to_from_reg_gen_asm(
+ src_loc=state.loc(op.input_vals[0], LocKind.GPR),
+ dest_loc=state.loc(op.outputs, LocKind.GPR),
+ is_vec=True, state=state)
Spread = GenericOpProperties(
demo_asm="sv.mv dest, src",
inputs=[OD_EXTRA3_VGPR, OD_VL],
RT.append(v & GPR_VALUE_MASK)
state.ssa_vals[op.outputs[0]] = tuple(RT)
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __svld_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RA = state.sgpr(op.input_vals[0])
+ RT = state.vgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"sv.ld {RT}, {imm}({RA})")
SvLd = GenericOpProperties(
demo_asm="sv.ld *RT, imm(RA)",
inputs=[OD_EXTRA3_SGPR, OD_VL],
v = state.load(addr)
state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __ld_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RA = state.sgpr(op.input_vals[0])
+ RT = state.sgpr(op.outputs[0])
+ imm = op.immediates[0]
+ state.writeln(f"ld {RT}, {imm}({RA})")
Ld = GenericOpProperties(
demo_asm="ld RT, imm(RA)",
inputs=[OD_BASE_SGPR],
for i in range(VL):
state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __svstd_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RS = state.vgpr(op.input_vals[0])
+ RA = state.sgpr(op.input_vals[1])
+ imm = op.immediates[0]
+ state.writeln(f"sv.std {RS}, {imm}({RA})")
SvStd = GenericOpProperties(
demo_asm="sv.std *RS, imm(RA)",
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_SGPR, OD_VL],
addr = RA + op.immediates[0]
state.store(addr, value=RS)
- # FIXME: change to correct __*_gen_asm function
@staticmethod
- def __clearca_gen_asm(op, state):
+ def __std_gen_asm(op, state):
# type: (Op, GenAsmState) -> None
- state.writeln("addic 0, 0, 0")
+ RS = state.sgpr(op.input_vals[0])
+ RA = state.sgpr(op.input_vals[1])
+ imm = op.immediates[0]
+ state.writeln(f"std {RS}, {imm}({RA})")
Std = GenericOpProperties(
- demo_asm="std RT, imm(RA)",
+ demo_asm="std RS, imm(RA)",
inputs=[OD_BASE_SGPR, OD_BASE_SGPR],
outputs=[],
immediates=[IMM_S16],
f"expected {out.ty.reg_len} found "
f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
+ def gen_asm(self, state):
+ # type: (GenAsmState) -> None
+ all_loc_kinds = tuple(LocKind)
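+ # verify every input and output has an allocated Loc before emitting
+ # any assembly for this op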
+ for inp in self.input_vals:
+ state.loc(inp, expected_kinds=all_loc_kinds)
+ for out in self.outputs:
+ state.loc(out, expected_kinds=all_loc_kinds)
+ self.kind.gen_asm(self, state)
+
GPR_SIZE_IN_BYTES = 8
BITS_IN_BYTE = 8
class GenAsmState:
__slots__ = "allocated_locs", "output"
- def __init__(self, allocated_locs, output):
+ def __init__(self, allocated_locs, output=None):
# type: (Mapping[SSAVal, Loc], StringIO | list[str] | None) -> None
super().__init__()
self.allocated_locs = FMap(allocated_locs)
output = []
self.output = output
- def loc(self, ssa_val_or_loc, expected_kinds):
- # type: (SSAVal | Loc, LocKind | tuple[LocKind, ...]) -> Loc
- if isinstance(ssa_val_or_loc, SSAVal):
- retval = self.allocated_locs[ssa_val_or_loc]
- else:
- retval = ssa_val_or_loc
+ __SSA_VAL_OR_LOCS = Union[SSAVal, Loc, Sequence["SSAVal | Loc"]]
+
+ def loc(self, ssa_val_or_locs, expected_kinds):
+ # type: (__SSA_VAL_OR_LOCS, LocKind | tuple[LocKind, ...]) -> Loc
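+ # accepts a single SSAVal/Loc or a sequence of them; a sequence is
+ # combined into a single Loc via try_concat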
+ if isinstance(ssa_val_or_locs, (SSAVal, Loc)):
+ ssa_val_or_locs = [ssa_val_or_locs]
+ locs = [] # type: list[Loc]
+ for i in ssa_val_or_locs:
+ if isinstance(i, SSAVal):
+ locs.append(self.allocated_locs[i])
+ else:
+ locs.append(i)
+ if len(locs) == 0:
+ raise ValueError("invalid Loc sequence: must not be empty")
+ retval = locs[0].try_concat(*locs[1:])
+ if retval is None:
+ raise ValueError("invalid Loc sequence: try_concat failed")
if isinstance(expected_kinds, LocKind):
expected_kinds = expected_kinds,
if retval.kind not in expected_kinds:
if len(expected_kinds) == 1:
expected_kinds = expected_kinds[0]
- raise ValueError(f"LocKind mismatch: {ssa_val_or_loc}: found "
+ raise ValueError(f"LocKind mismatch: {ssa_val_or_locs}: found "
f"{retval.kind} expected {expected_kinds}")
return retval
- def gpr(self, ssa_val_or_loc, is_vec):
- # type: (SSAVal | Loc, bool) -> str
- loc = self.loc(ssa_val_or_loc, LocKind.GPR)
+ def gpr(self, ssa_val_or_locs, is_vec):
+ # type: (__SSA_VAL_OR_LOCS, bool) -> str
+ loc = self.loc(ssa_val_or_locs, LocKind.GPR)
vec_str = "*" if is_vec else ""
return vec_str + str(loc.start)
- def sgpr(self, ssa_val_or_loc):
- # type: (SSAVal | Loc) -> str
- return self.gpr(ssa_val_or_loc, is_vec=False)
+ def sgpr(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ return self.gpr(ssa_val_or_locs, is_vec=False)
- def vgpr(self, ssa_val_or_loc):
- # type: (SSAVal | Loc) -> str
- return self.gpr(ssa_val_or_loc, is_vec=True)
+ def vgpr(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ return self.gpr(ssa_val_or_locs, is_vec=True)
- def stack(self, ssa_val_or_loc):
- # type: (SSAVal | Loc) -> str
- loc = self.loc(ssa_val_or_loc, LocKind.StackI64)
+ def stack(self, ssa_val_or_locs):
+ # type: (__SSA_VAL_OR_LOCS) -> str
+ loc = self.loc(ssa_val_or_locs, LocKind.StackI64)
return f"{loc.start}(1)"
def writeln(self, *line_segments):
pass
-@plain_data(frozen=True)
+@plain_data(frozen=True, repr=False)
@final
class MergedSSAVal:
"""a set of `SSAVal`s along with their offsets, all register allocated as
stop = max(stop, live_range.stop)
return ProgramRange(start=start, stop=stop)
+ def __repr__(self):
+ # type: () -> str
+ return (f"MergedSSAVal({self.fn_analysis}, "
+ f"ssa_val_offsets={self.ssa_val_offsets})")
+
@final
class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
self.__merged_ssa_val_map[ssa_val] = final_merged_ssa_val
return retval
- def __repr__(self):
- # type: () -> str
- s = ",\n".join(repr(v) for v in self.__map.values())
+ def __repr__(self, repr_state=None):
+ # type: (None | IGNodeReprState) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ s = ",\n".join(v.__repr__(repr_state) for v in self.__map.values())
return f"MergedSSAValToIGNodeMap({{{s}}})"
-@plain_data(frozen=True)
+@plain_data(frozen=True, repr=False)
@final
class InterferenceGraph:
__slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"
retval.merge(out.tied_input.ssa_val, out)
return retval
+ def __repr__(self, repr_state=None):
+ # type: (None | IGNodeReprState) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ s = self.nodes.__repr__(repr_state)
+ return f"InterferenceGraph(nodes={s}, <...>)"
+
+
+@plain_data(repr=False)
+class IGNodeReprState:
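+ """shared state for the recursive __repr__ methods: assigns each
+ IGNode a short id and tracks which nodes have already been printed
+ in full, so cycles in the graph print as back-references
+ """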
+ __slots__ = "node_ids", "did_full_repr"
+
+ def __init__(self):
+ super().__init__()
+ self.node_ids = {} # type: dict[IGNode, int]
+ self.did_full_repr = OSet() # type: OSet[IGNode]
+
@final
class IGNode:
# type: () -> int
return hash(self.merged_ssa_val)
- def __repr__(self, nodes=None):
- # type: (None | dict[IGNode, int]) -> str
- if nodes is None:
- nodes = {}
- if self in nodes:
- return f"<IGNode #{nodes[self]}>"
- nodes[self] = len(nodes)
- edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
- return (f"IGNode(#{nodes[self]}, "
+ def __repr__(self, repr_state=None, short=False):
+ # type: (None | IGNodeReprState, bool) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ node_id = repr_state.node_ids.get(self, None)
+ if node_id is None:
+ repr_state.node_ids[self] = node_id = len(repr_state.node_ids)
+ if short or self in repr_state.did_full_repr:
+ return f"<IGNode #{node_id}>"
+ repr_state.did_full_repr.add(self)
+ edges = ", ".join(i.__repr__(repr_state, True) for i in self.edges)
+ return (f"IGNode(#{node_id}, "
f"merged_ssa_val={self.merged_ssa_val}, "
- f"edges={edges}, "
+ f"edges={{{edges}}}, "
f"loc={self.loc})")
@property
self.node = node
self.interference_graph = interference_graph
+ def __repr__(self, repr_state=None):
+ # type: (None | IGNodeReprState) -> str
+ if repr_state is None:
+ repr_state = IGNodeReprState()
+ return (f"{__class__.__name__}({self.args[0]!r}, "
+ f"node={self.node.__repr__(repr_state, True)}, "
+ f"interference_graph="
+ f"{self.interference_graph.__repr__(repr_state)})")
+
+ def __str__(self):
+ # type: () -> str
+ return self.__repr__()
+
def allocate_registers(fn):
# type: (Fn) -> dict[SSAVal, Loc]