import unittest
-from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
+from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, BaseTy, Fn,
FnAnalysis, GenAsmState, Loc, LocKind, OpKind, OpStage,
PreRASimState, ProgramPoint,
- SSAVal)
+ SSAVal, Ty)
class TestCompilerIR(unittest.TestCase):
size_in_bytes=GPR_SIZE_IN_BYTES)
self.assertEqual(
repr(state),
- "PreRASimState(ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)}, "
- "memory={\n"
+ "PreRASimState(memory={\n"
"0x00100: <0xffffffffffffffff>,\n"
- "0x00108: <0xabcdef0123456789>})")
- fn.pre_ra_sim(state)
+ "0x00108: <0xabcdef0123456789>}, "
+ "ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)})")
+ fn.sim(state)
self.assertEqual(
repr(state),
- "PreRASimState(ssa_vals={\n"
- "<arg.outputs[0]: <I64>>: (0x100,),\n"
- "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
- "<ld.outputs[0]: <I64*32>>: (\n"
- " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0),\n"
- "<li.outputs[0]: <I64*32>>: (\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0),\n"
- "<ca.outputs[0]: <CA>>: (0x1,),\n"
- "<add.outputs[0]: <I64*32>>: (\n"
- " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0,\n"
- " 0x0, 0x0, 0x0, 0x0),\n"
- "<add.outputs[1]: <CA>>: (0x0,),\n"
- "}, memory={\n"
+ "PreRASimState(memory={\n"
"0x00100: <0x0000000000000000>,\n"
"0x00108: <0xabcdef012345678a>,\n"
"0x00110: <0x0000000000000000>,\n"
"0x001e0: <0x0000000000000000>,\n"
"0x001e8: <0x0000000000000000>,\n"
"0x001f0: <0x0000000000000000>,\n"
- "0x001f8: <0x0000000000000000>})")
+ "0x001f8: <0x0000000000000000>}, ssa_vals={\n"
+ "<arg.outputs[0]: <I64>>: (0x100,),\n"
+ "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
+ "<ld.outputs[0]: <I64*32>>: (\n"
+ " 0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0),\n"
+ "<li.outputs[0]: <I64*32>>: (\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0),\n"
+ "<ca.outputs[0]: <CA>>: (0x1,),\n"
+ "<add.outputs[0]: <I64*32>>: (\n"
+ " 0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0,\n"
+ " 0x0, 0x0, 0x0, 0x0),\n"
+ "<add.outputs[1]: <CA>>: (0x0,),\n"
+ "})")
def test_gen_asm(self):
fn, _arg = self.make_add_fn()
'sv.std *32, 0(3)',
])
+ def test_spread(self):
+ fn = Fn()
+ maxvl = 4
+ vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
+ name="vl", maxvl=maxvl).outputs[0]
+ li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ name="li", maxvl=maxvl).outputs[0]
+ spread_op = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
+ name="spread", maxvl=maxvl)
+ self.assertEqual(spread_op.outputs[0].ty_before_spread,
+ Ty(base_ty=BaseTy.I64, reg_len=maxvl))
+ _concat = fn.append_new_op(
+ OpKind.Concat, input_vals=[*spread_op.outputs[::-1], vl],
+ name="concat", maxvl=maxvl)
+ self.assertEqual([repr(op.properties) for op in fn.ops], [
+ "OpProperties(kind=OpKind.SetVLI, inputs=("
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),"
+ "), maxvl=4)",
+ "OpProperties(kind=OpKind.SvLI, inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),"
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early),"
+ "), maxvl=4)",
+ "OpProperties(kind=OpKind.Spread, inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)"
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=0, "
+ "write_stage=OpStage.Late), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=1, "
+ "write_stage=OpStage.Late), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=2, "
+ "write_stage=OpStage.Late), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=3, "
+ "write_stage=OpStage.Late)"
+ "), maxvl=4)",
+ "OpProperties(kind=OpKind.Concat, inputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=0, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=1, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=2, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=3, "
+ "write_stage=OpStage.Early), "
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.VL_MAXVL: FBitSet([0])}), "
+ "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Early)"
+ "), outputs=("
+ "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+ "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+ "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+ "write_stage=OpStage.Late),"
+ "), maxvl=4)",
+ ])
+ self.assertEqual([repr(op) for op in fn.ops], [
+ "Op(kind=OpKind.SetVLI, input_vals=["
+ "], input_uses=("
+ "), immediates=[4], outputs=("
+ "<vl.outputs[0]: <VL_MAXVL>>,"
+ "), name='vl')",
+ "Op(kind=OpKind.SvLI, input_vals=["
+ "<vl.outputs[0]: <VL_MAXVL>>"
+ "], input_uses=("
+ "<li.input_uses[0]: <VL_MAXVL>>,"
+ "), immediates=[0], outputs=("
+ "<li.outputs[0]: <I64*4>>,"
+ "), name='li')",
+ "Op(kind=OpKind.Spread, input_vals=["
+ "<li.outputs[0]: <I64*4>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>"
+ "], input_uses=("
+ "<spread.input_uses[0]: <I64*4>>, "
+ "<spread.input_uses[1]: <VL_MAXVL>>"
+ "), immediates=[], outputs=("
+ "<spread.outputs[0]: <I64>>, "
+ "<spread.outputs[1]: <I64>>, "
+ "<spread.outputs[2]: <I64>>, "
+ "<spread.outputs[3]: <I64>>"
+ "), name='spread')",
+ "Op(kind=OpKind.Concat, input_vals=["
+ "<spread.outputs[3]: <I64>>, "
+ "<spread.outputs[2]: <I64>>, "
+ "<spread.outputs[1]: <I64>>, "
+ "<spread.outputs[0]: <I64>>, "
+ "<vl.outputs[0]: <VL_MAXVL>>"
+ "], input_uses=("
+ "<concat.input_uses[0]: <I64>>, "
+ "<concat.input_uses[1]: <I64>>, "
+ "<concat.input_uses[2]: <I64>>, "
+ "<concat.input_uses[3]: <I64>>, "
+ "<concat.input_uses[4]: <VL_MAXVL>>"
+ "), immediates=[], outputs=("
+ "<concat.outputs[0]: <I64*4>>,"
+ "), name='concat')",
+ ])
+
if __name__ == "__main__":
_ = unittest.main()
+import sys
import unittest
-from bigint_presentation_code.compiler_ir2 import Fn, GenAsmState, OpKind, SSAVal
+from bigint_presentation_code.compiler_ir2 import (Fn, GenAsmState, OpKind,
+ SSAVal)
from bigint_presentation_code.register_allocator2 import allocate_registers
'sv.std *14, 0(3)',
])
+ def test_register_allocate_spread(self):
+ fn = Fn()
+ maxvl = 32
+ vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
+ name="vl", maxvl=maxvl).outputs[0]
+ li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+ name="li", maxvl=maxvl).outputs[0]
+ spread = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
+ name="spread", maxvl=maxvl).outputs
+ _concat = fn.append_new_op(
+ OpKind.Concat, input_vals=[*spread[::-1], vl],
+ name="concat", maxvl=maxvl)
+ reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+
+ self.assertEqual(
+ repr(reg_assignments),
+ "{<concat.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<concat.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<concat.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+ "<concat.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+ "<concat.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+ "<concat.inp3.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+ "<concat.inp4.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+ "<concat.inp5.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+ "<concat.inp6.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
+ "<concat.inp7.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
+ "<concat.inp8.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
+ "<concat.inp9.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+ "<concat.inp10.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
+ "<concat.inp11.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
+ "<concat.inp12.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
+ "<concat.inp13.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
+ "<concat.inp14.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
+ "<concat.inp15.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
+ "<concat.inp16.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
+ "<concat.inp17.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
+ "<concat.inp18.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
+ "<concat.inp19.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
+ "<concat.inp20.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
+ "<concat.inp21.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
+ "<concat.inp22.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
+ "<concat.inp23.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
+ "<concat.inp24.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
+ "<concat.inp25.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
+ "<concat.inp26.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
+ "<concat.inp27.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
+ "<concat.inp28.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
+ "<concat.inp29.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
+ "<concat.inp30.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
+ "<concat.inp31.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
+ "<concat.inp32.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<spread.out31.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<spread.out30.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<spread.out29.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<spread.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+ "<spread.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+ "<spread.outputs[2]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+ "<spread.outputs[3]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+ "<spread.outputs[4]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+ "<spread.outputs[5]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+ "<spread.outputs[6]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
+ "<spread.outputs[7]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
+ "<spread.outputs[8]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
+ "<spread.outputs[9]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+ "<spread.outputs[10]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
+ "<spread.outputs[11]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
+ "<spread.outputs[12]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
+ "<spread.outputs[13]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
+ "<spread.outputs[14]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
+ "<spread.outputs[15]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
+ "<spread.outputs[16]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
+ "<spread.outputs[17]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
+ "<spread.outputs[18]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
+ "<spread.outputs[19]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
+ "<spread.outputs[20]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
+ "<spread.outputs[21]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
+ "<spread.outputs[22]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
+ "<spread.outputs[23]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
+ "<spread.outputs[24]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
+ "<spread.outputs[25]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
+ "<spread.outputs[26]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
+ "<spread.outputs[27]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
+ "<spread.outputs[28]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
+ "<spread.outputs[29]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
+ "<spread.outputs[30]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
+ "<spread.outputs[31]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
+ "<spread.out28.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<spread.out27.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+ "<spread.out26.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
+ "<spread.out25.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
+ "<spread.out24.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
+ "<spread.out23.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+ "<spread.out22.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+ "<spread.out21.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=46, reg_len=1), "
+ "<spread.out20.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=47, reg_len=1), "
+ "<spread.out19.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=48, reg_len=1), "
+ "<spread.out18.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=49, reg_len=1), "
+ "<spread.out17.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=50, reg_len=1), "
+ "<spread.out16.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=51, reg_len=1), "
+ "<spread.out15.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=52, reg_len=1), "
+ "<spread.out14.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=53, reg_len=1), "
+ "<spread.out13.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=54, reg_len=1), "
+ "<spread.out12.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=55, reg_len=1), "
+ "<spread.out11.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=56, reg_len=1), "
+ "<spread.out10.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=57, reg_len=1), "
+ "<spread.out9.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=58, reg_len=1), "
+ "<spread.out8.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=59, reg_len=1), "
+ "<spread.out7.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=60, reg_len=1), "
+ "<spread.out6.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=61, reg_len=1), "
+ "<spread.out5.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=62, reg_len=1), "
+ "<spread.out4.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=63, reg_len=1), "
+ "<spread.out3.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=64, reg_len=1), "
+ "<spread.out2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=65, reg_len=1), "
+ "<spread.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=66, reg_len=1), "
+ "<spread.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=67, reg_len=1), "
+ "<spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<spread.inp0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.out0.copy.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<li.outputs[0]: <I64*32>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+ "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<vl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)}"
+ )
+ state = GenAsmState(reg_assignments)
+ fn.gen_asm(state)
+ self.assertEqual(state.output, [
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'sv.addi *14, 0, 0',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'or 67, 14, 14',
+ 'or 66, 15, 15',
+ 'or 65, 16, 16',
+ 'or 64, 17, 17',
+ 'or 63, 18, 18',
+ 'or 62, 19, 19',
+ 'or 61, 20, 20',
+ 'or 60, 21, 21',
+ 'or 59, 22, 22',
+ 'or 58, 23, 23',
+ 'or 57, 24, 24',
+ 'or 56, 25, 25',
+ 'or 55, 26, 26',
+ 'or 54, 27, 27',
+ 'or 53, 28, 28',
+ 'or 52, 29, 29',
+ 'or 51, 30, 30',
+ 'or 50, 31, 31',
+ 'or 49, 32, 32',
+ 'or 48, 33, 33',
+ 'or 47, 34, 34',
+ 'or 46, 35, 35',
+ 'or 12, 36, 36',
+ 'or 11, 37, 37',
+ 'or 10, 38, 38',
+ 'or 9, 39, 39',
+ 'or 8, 40, 40',
+ 'or 7, 41, 41',
+ 'or 6, 42, 42',
+ 'or 5, 43, 43',
+ 'or 4, 44, 44',
+ 'or 3, 45, 45',
+ 'or 14, 3, 3',
+ 'or 15, 4, 4',
+ 'or 16, 5, 5',
+ 'or 17, 6, 6',
+ 'or 18, 7, 7',
+ 'or 19, 8, 8',
+ 'or 20, 9, 9',
+ 'or 21, 10, 10',
+ 'or 22, 11, 11',
+ 'or 23, 12, 12',
+ 'or 24, 46, 46',
+ 'or 25, 47, 47',
+ 'or 26, 48, 48',
+ 'or 27, 49, 49',
+ 'or 28, 50, 50',
+ 'or 29, 51, 51',
+ 'or 30, 52, 52',
+ 'or 31, 53, 53',
+ 'or 32, 54, 54',
+ 'or 33, 55, 55',
+ 'or 34, 56, 56',
+ 'or 35, 57, 57',
+ 'or 36, 58, 58',
+ 'or 37, 59, 59',
+ 'or 38, 60, 60',
+ 'or 39, 61, 61',
+ 'or 40, 62, 62',
+ 'or 41, 63, 63',
+ 'or 42, 64, 64',
+ 'or 43, 65, 65',
+ 'or 44, 66, 66',
+ 'or 45, 67, 67',
+ 'setvl 0, 0, 32, 0, 1, 1',
+ 'setvl 0, 0, 32, 0, 1, 1'])
+
if __name__ == "__main__":
_ = unittest.main()
import unittest
+from typing import Callable
-from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
+from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES,
+ BaseSimState, Fn,
GenAsmState, OpKind,
+ PostRASimState,
PreRASimState)
from bigint_presentation_code.register_allocator2 import allocate_registers
from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
)
def test_simple_mul_192x192_pre_ra_sim(self):
+ def create_sim_state(code):
+ # type: (SimpleMul192x192) -> BaseSimState
+ return PreRASimState(ssa_vals={}, memory={})
+ self.tst_simple_mul_192x192_sim(create_sim_state)
+
+ def test_simple_mul_192x192_post_ra_sim(self):
+ def create_sim_state(code):
+ # type: (SimpleMul192x192) -> BaseSimState
+ ssa_val_to_loc_map = allocate_registers(code.fn)
+ return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map,
+ memory={}, loc_values={})
+ self.tst_simple_mul_192x192_sim(create_sim_state)
+
+ def tst_simple_mul_192x192_sim(self, create_sim_state):
+ # type: (Callable[[SimpleMul192x192], BaseSimState]) -> None
# test multiplying:
# 0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
# * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
# == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
# 'little')
code = SimpleMul192x192()
+ state = create_sim_state(code)
ptr_in = 0x100
dest_ptr = ptr_in + code.dest_offset
lhs_ptr = ptr_in + code.lhs_offset
rhs_ptr = ptr_in + code.rhs_offset
- state = PreRASimState(ssa_vals={code.ptr_in: (ptr_in,)}, memory={})
+ state[code.ptr_in] = ptr_in,
state.store(lhs_ptr, 0x821a2342132c5b57)
state.store(lhs_ptr + 8, 0x4c6b5f2b19e1a53e)
state.store(lhs_ptr + 16, 0x000191acb262e15b)
state.store(rhs_ptr, 0x208a49071aeec507)
state.store(rhs_ptr + 8, 0xcf1f597598194ae6)
state.store(rhs_ptr + 16, 0x4a37c0567bcbab53)
- code.fn.pre_ra_sim(state)
+ code.fn.sim(state)
expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
"name='store_dest')",
])
- # FIXME: register allocator currently allocates wrong registers
- @unittest.expectedFailure
def test_simple_mul_192x192_reg_alloc(self):
code = SimpleMul192x192()
fn = code.fn
assigned_registers = allocate_registers(fn)
- self.assertEqual(assigned_registers, {
- })
- self.fail("register allocator currently allocates wrong registers")
+ self.assertEqual(
+ repr(assigned_registers), "{"
+ "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<store_dest.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<store_dest.inp0.copy.outputs[0]: <I64*6>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=6), "
+ "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<setvl6.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<concat_retval.out0.copy.outputs[0]: <I64*6>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
+ "<concat_retval.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<concat_retval.outputs[0]: <I64*6>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
+ "<concat_retval.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<concat_retval.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<concat_retval.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<concat_retval.inp3.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<concat_retval.inp4.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+ "<concat_retval.inp5.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
+ "<concat_retval.inp6.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<retval_setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add_hi2.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
+ "<clear_ca2.outputs[0]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add2.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add_hi2.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add_hi2.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<add_hi2.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<add2_rt_spread.out2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
+ "<add2_rt_spread.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+ "<add2_rt_spread.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+ "<add2_rt_spread.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<add2_rt_spread.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<add2_rt_spread.outputs[2]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<add2_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add2_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add2_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add2.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add2.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add2.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add2.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add2.inp1.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+ "<add2.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add2.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=9, reg_len=3), "
+ "<add2.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add2_rb_concat.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add2_rb_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add2_rb_concat.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add2_rb_concat.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<add2_rb_concat.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<add2_rb_concat.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<add2_rb_concat.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul2.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+ "<mul2.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+ "<mul2.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul2.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<mul2.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<mul2.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=3), "
+ "<mul2.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul2.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+ "<mul2.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
+ "<mul2.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add_hi1.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+ "<clear_ca1.outputs[0]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add1.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add_hi1.outputs[1]: <CA>>: "
+ "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+ "<add_hi1.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<add_hi1.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<add1_rt_spread.out2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+ "<add1_rt_spread.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+ "<add1_rt_spread.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+ "<add1_rt_spread.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<add1_rt_spread.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<add1_rt_spread.outputs[2]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<add1_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add1_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add1_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add1.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add1.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add1.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add1.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add1.inp1.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+ "<add1.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add1.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=9, reg_len=3), "
+ "<add1.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add1_rb_concat.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add1_rb_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<add1_rb_concat.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<add1_rb_concat.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<add1_rb_concat.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<add1_rb_concat.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<add1_rb_concat.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul1.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+ "<mul1.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+ "<mul1.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul1.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<mul1.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<mul1.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=3), "
+ "<mul1.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul1.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+ "<mul1.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
+ "<mul1.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul0_rt_spread.out2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+ "<mul0_rt_spread.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+ "<mul0_rt_spread.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+ "<mul0_rt_spread.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<mul0_rt_spread.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<mul0_rt_spread.outputs[2]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<mul0_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul0_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<mul0_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul0.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+ "<mul0.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<mul0.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul0.inp2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<mul0.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<mul0.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<mul0.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<mul0.inp1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+ "<mul0.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
+ "<mul0.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<zero.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+ "<zero.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+ "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<rhs_spread.out2.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+ "<rhs_spread.out1.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+ "<rhs_spread.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+ "<rhs_spread.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+ "<rhs_spread.outputs[1]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<rhs_spread.outputs[2]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+ "<rhs_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<rhs_spread.inp0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<load_rhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<load_rhs.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<load_rhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
+ "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<load_lhs.outputs[0]: <I64*3>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+ "<load_lhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+ "<setvl3.outputs[0]: <VL_MAXVL>>: "
+ "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+ "<ptr_in.out0.copy.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+ "<ptr_in.outputs[0]: <I64>>: "
+ "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
+ "}")
- # FIXME: register allocator currently allocates wrong registers
- @unittest.expectedFailure
def test_simple_mul_192x192_asm(self):
- self.skipTest("WIP")
code = SimpleMul192x192()
fn = code.fn
assigned_registers = allocate_registers(fn)
gen_asm_state = GenAsmState(assigned_registers)
fn.gen_asm(gen_asm_state)
self.assertEqual(gen_asm_state.output, [
+ 'or 23, 3, 3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'or 6, 23, 23',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.ld *3, 48(6)',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *20, *3, *3',
+ 'or 6, 23, 23',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.ld *3, 72(6)',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or/mrr *5, *3, *3',
+ 'or 4, 5, 5',
+ 'or 14, 6, 6',
+ 'or 19, 7, 7',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'addi 3, 0, 0',
+ 'or 18, 3, 3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *8, *20, *20',
+ 'or 7, 4, 4',
+ 'or 6, 18, 18',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.maddedu *3, *8, 7, 6',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'or 15, 6, 6',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'or 17, 3, 3',
+ 'or 12, 4, 4',
+ 'or 11, 5, 5',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *8, *20, *20',
+ 'or 7, 14, 14',
+ 'or 3, 18, 18',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.maddedu *4, *8, 7, 3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or/mrr *6, *4, *4',
+ 'or 14, 3, 3',
+ 'or 3, 12, 12',
+ 'or 4, 11, 11',
+ 'or 5, 15, 15',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'addic 0, 0, 0',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *9, *6, *6',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *6, *3, *3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.adde *3, *9, *6',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'or 16, 3, 3',
+ 'or 15, 4, 4',
+ 'or 12, 5, 5',
+ 'or 4, 14, 14',
+ 'addze *3, *4',
+ 'or 11, 3, 3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *8, *20, *20',
+ 'or 7, 19, 19',
+ 'or 3, 18, 18',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.maddedu *4, *8, 7, 3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or/mrr *6, *4, *4',
+ 'or 14, 3, 3',
+ 'or 3, 15, 15',
+ 'or 4, 12, 12',
+ 'or 5, 11, 11',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'addic 0, 0, 0',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *9, *6, *6',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.or *6, *3, *3',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'sv.adde *3, *9, *6',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'setvl 0, 0, 3, 0, 1, 1',
+ 'or 12, 3, 3',
+ 'or 11, 4, 4',
+ 'or 10, 5, 5',
+ 'or 4, 14, 14',
+ 'addze *3, *4',
+ 'or 9, 3, 3',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'or 3, 17, 17',
+ 'or 4, 16, 16',
+ 'or 5, 12, 12',
+ 'or 6, 11, 11',
+ 'or 7, 10, 10',
+ 'or 8, 9, 9',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'sv.or/mrr *4, *3, *3',
+ 'or 3, 23, 23',
+ 'setvl 0, 0, 6, 0, 1, 1',
+ 'sv.std *4, 0(3)'
])
- self.fail("register allocator currently allocates wrong registers")
if __name__ == "__main__":
from bigint_presentation_code.type_util import (Literal, Self, assert_never,
final)
-from bigint_presentation_code.util import BitSet, FBitSet, FMap, InternedMeta, OFSet, OSet
+from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
+ OFSet, OSet)
@final
self.append_op(retval)
return retval
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
+ def sim(self, state):
+ # type: (BaseSimState) -> None
for op in self.ops:
- op.pre_ra_sim(state)
+ op.sim(state)
def gen_asm(self, state):
# type: (GenAsmState) -> None
)
-@plain_data(frozen=True, eq=False)
+@final
+class _LocSetHashHelper(AbstractSet[Loc]):
+ """helper to more quickly compute LocSet's hash"""
+
+ def __init__(self, locs):
+ # type: (Iterable[Loc]) -> None
+ super().__init__()
+ self.locs = list(locs)
+
+ def __hash__(self):
+ # type: () -> int
+ return super()._hash()
+
+ def __contains__(self, x):
+ # type: (Loc | Any) -> bool
+ return x in self.locs
+
+ def __iter__(self):
+ # type: () -> Iterator[Loc]
+ return iter(self.locs)
+
+ def __len__(self):
+ return len(self.locs)
+
+
+@plain_data(frozen=True, eq=False, repr=False)
@final
class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
- __slots__ = "starts", "ty"
+ __slots__ = "starts", "ty", "_LocSet__hash"
def __init__(self, __locs=()):
# type: (Iterable[Loc]) -> None
if isinstance(__locs, LocSet):
self.starts = __locs.starts # type: FMap[LocKind, FBitSet]
self.ty = __locs.ty # type: Ty | None
+ self._LocSet__hash = __locs._LocSet__hash # type: int
return
starts = {i: BitSet() for i in LocKind}
- ty = None
- for loc in __locs:
- if ty is None:
- ty = loc.ty
- if ty != loc.ty:
- raise ValueError(f"conflicting types: {ty} != {loc.ty}")
- starts[loc.kind].add(loc.start)
+ ty = None # type: None | Ty
+
+ def locs():
+ # type: () -> Iterable[Loc]
+ nonlocal ty
+ for loc in __locs:
+ if ty is None:
+ ty = loc.ty
+ if ty != loc.ty:
+ raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+ starts[loc.kind].add(loc.start)
+ yield loc
+ self._LocSet__hash = _LocSetHashHelper(locs()).__hash__()
self.starts = FMap(
(k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
self.ty = ty
return self.__len
def __hash__(self):
- return super()._hash()
+ return self._LocSet__hash
def __eq__(self, __other):
# type: (LocSet | Any) -> bool
else:
return sum(other.conflicts(i) for i in self)
+ def __repr__(self):
+ items = [] # type: list[str]
+ for name in fields(self):
+ if name.startswith("_"):
+ continue
+ items.append(f"{name}={getattr(self, name)!r}")
+ return f"LocSet({', '.join(items)})"
+
@plain_data(frozen=True, unsafe_hash=True)
@final
raise ValueError("operand can't be both spread and vector")
self.write_stage = write_stage
+ @cached_property
+ def ty_before_spread(self):
+ # type: () -> GenericTy
+ if self.spread:
+ return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
+ return self.ty
+
def tied_to_input(self, tied_input_index):
# type: (int) -> Self
return GenericOperandDesc(self.ty, self.sub_kinds,
rep_count = 1
if self.spread:
rep_count = maxvl
- maxvl = 1
- ty = self.ty.instantiate(maxvl=maxvl)
+ ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)
- def locs():
+ def locs_before_spread():
# type: () -> Iterable[Loc]
if self.fixed_loc is not None:
- if ty != self.fixed_loc.ty:
+ if ty_before_spread != self.fixed_loc.ty:
raise ValueError(
f"instantiation failed: type mismatch with fixed_loc: "
- f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
+ f"instantiated type: {ty_before_spread} "
+ f"fixed_loc: {self.fixed_loc}")
yield self.fixed_loc
return
for sub_kind in self.sub_kinds:
- yield from sub_kind.allocatable_locs(ty)
- loc_set_before_spread = LocSet(locs())
+ yield from sub_kind.allocatable_locs(ty_before_spread)
+ loc_set_before_spread = LocSet(locs_before_spread())
for idx in range(rep_count):
if not self.spread:
idx = None
IMM_S16 = range(-1 << 15, 1 << 15)
-_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
-_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
-_PRE_RA_SIMS = {} # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
+_SIM_FN = Callable[["Op", "BaseSimState"], None]
+_SIM_FN2 = Callable[[], _SIM_FN]
+_SIM_FNS = {} # type: dict[GenericOpProperties | Any, _SIM_FN2]
_GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
_GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
_GEN_ASMS = {} # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
return "OpKind." + self._name_
@cached_property
- def pre_ra_sim(self):
- # type: () -> _PRE_RA_SIM_FN
- return _PRE_RA_SIMS[self.properties]()
+ def sim(self):
+ # type: () -> _SIM_FN
+ return _SIM_FNS[self.properties]()
@cached_property
def gen_asm(self):
return _GEN_ASMS[self.properties]()
@staticmethod
- def __clearca_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = False,
+ def __clearca_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = False,
@staticmethod
def __clearca_gen_asm(op, state):
inputs=[],
outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
- _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
+ _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
_GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm
@staticmethod
- def __setca_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = True,
+ def __setca_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = True,
@staticmethod
def __setca_gen_asm(op, state):
inputs=[],
outputs=[OD_CA.with_write_stage(OpStage.Late)],
)
- _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
+ _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
_GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm
@staticmethod
- def __svadde_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RA = state.ssa_vals[op.input_vals[0]]
- RB = state.ssa_vals[op.input_vals[1]]
- carry, = state.ssa_vals[op.input_vals[2]]
- VL, = state.ssa_vals[op.input_vals[3]]
+ def __svadde_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA = state[op.input_vals[0]]
+ RB = state[op.input_vals[1]]
+ carry, = state[op.input_vals[2]]
+ VL, = state[op.input_vals[3]]
RT = [] # type: list[int]
for i in range(VL):
v = RA[i] + RB[i] + carry
RT.append(v & GPR_VALUE_MASK)
carry = (v >> GPR_SIZE_IN_BITS) != 0
- state.ssa_vals[op.outputs[0]] = tuple(RT)
- state.ssa_vals[op.outputs[1]] = carry,
+ state[op.outputs[0]] = tuple(RT)
+ state[op.outputs[1]] = carry,
@staticmethod
def __svadde_gen_asm(op, state):
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
- _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
+ _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
_GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
@staticmethod
- def __addze_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RA, = state.ssa_vals[op.input_vals[0]]
- carry, = state.ssa_vals[op.input_vals[1]]
+ def __addze_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA, = state[op.input_vals[0]]
+ carry, = state[op.input_vals[1]]
v = RA + carry
RT = v & GPR_VALUE_MASK
carry = (v >> GPR_SIZE_IN_BITS) != 0
- state.ssa_vals[op.outputs[0]] = RT,
- state.ssa_vals[op.outputs[1]] = carry,
+ state[op.outputs[0]] = RT,
+ state[op.outputs[1]] = carry,
@staticmethod
def __addze_gen_asm(op, state):
inputs=[OD_BASE_SGPR, OD_CA],
outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
)
- _PRE_RA_SIMS[AddZE] = lambda: OpKind.__addze_pre_ra_sim
+ _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
_GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm
@staticmethod
- def __svsubfe_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RA = state.ssa_vals[op.input_vals[0]]
- RB = state.ssa_vals[op.input_vals[1]]
- carry, = state.ssa_vals[op.input_vals[2]]
- VL, = state.ssa_vals[op.input_vals[3]]
+ def __svsubfe_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA = state[op.input_vals[0]]
+ RB = state[op.input_vals[1]]
+ carry, = state[op.input_vals[2]]
+ VL, = state[op.input_vals[3]]
RT = [] # type: list[int]
for i in range(VL):
v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
RT.append(v & GPR_VALUE_MASK)
carry = (v >> GPR_SIZE_IN_BITS) != 0
- state.ssa_vals[op.outputs[0]] = tuple(RT)
- state.ssa_vals[op.outputs[1]] = carry,
+ state[op.outputs[0]] = tuple(RT)
+ state[op.outputs[1]] = carry,
@staticmethod
def __svsubfe_gen_asm(op, state):
inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
)
- _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
+ _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
_GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
@staticmethod
- def __svmaddedu_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RA = state.ssa_vals[op.input_vals[0]]
- RB, = state.ssa_vals[op.input_vals[1]]
- carry, = state.ssa_vals[op.input_vals[2]]
- VL, = state.ssa_vals[op.input_vals[3]]
+ def __svmaddedu_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA = state[op.input_vals[0]]
+ RB, = state[op.input_vals[1]]
+ carry, = state[op.input_vals[2]]
+ VL, = state[op.input_vals[3]]
RT = [] # type: list[int]
for i in range(VL):
v = RA[i] * RB + carry
RT.append(v & GPR_VALUE_MASK)
carry = v >> GPR_SIZE_IN_BITS
- state.ssa_vals[op.outputs[0]] = tuple(RT)
- state.ssa_vals[op.outputs[1]] = carry,
+ state[op.outputs[0]] = tuple(RT)
+ state[op.outputs[1]] = carry,
@staticmethod
def __svmaddedu_gen_asm(op, state):
inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
)
- _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim
+ _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
_GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
@staticmethod
- def __setvli_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = op.immediates[0],
+ def __setvli_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = op.immediates[0],
@staticmethod
def __setvli_gen_asm(op, state):
immediates=[range(1, 65)],
is_load_immediate=True,
)
- _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim
+ _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
_GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
@staticmethod
- def __svli_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- VL, = state.ssa_vals[op.input_vals[0]]
+ def __svli_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ VL, = state[op.input_vals[0]]
imm = op.immediates[0] & GPR_VALUE_MASK
- state.ssa_vals[op.outputs[0]] = (imm,) * VL
+ state[op.outputs[0]] = (imm,) * VL
@staticmethod
def __svli_gen_asm(op, state):
immediates=[IMM_S16],
is_load_immediate=True,
)
- _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim
+ _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
_GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
@staticmethod
- def __li_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
+ def __li_sim(op, state):
+ # type: (Op, BaseSimState) -> None
imm = op.immediates[0] & GPR_VALUE_MASK
- state.ssa_vals[op.outputs[0]] = imm,
+ state[op.outputs[0]] = imm,
@staticmethod
def __li_gen_asm(op, state):
immediates=[IMM_S16],
is_load_immediate=True,
)
- _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim
+ _SIM_FNS[LI] = lambda: OpKind.__li_sim
_GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
@staticmethod
- def __veccopytoreg_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+ def __veccopytoreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
@staticmethod
def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
is_copy=True,
)
- _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
+ _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
_GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
@staticmethod
- def __veccopyfromreg_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+ def __veccopyfromreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
@staticmethod
def __veccopyfromreg_gen_asm(op, state):
)],
is_copy=True,
)
- _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim
+ _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
_GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
@staticmethod
- def __copytoreg_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+ def __copytoreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
@staticmethod
def __copytoreg_gen_asm(op, state):
)],
is_copy=True,
)
- _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim
+ _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
_GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
@staticmethod
- def __copyfromreg_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+ def __copyfromreg_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = state[op.input_vals[0]]
@staticmethod
def __copyfromreg_gen_asm(op, state):
)],
is_copy=True,
)
- _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim
+ _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
_GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
@staticmethod
- def __concat_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- state.ssa_vals[op.outputs[0]] = tuple(
- state.ssa_vals[i][0] for i in op.input_vals[:-1])
+ def __concat_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ state[op.outputs[0]] = tuple(
+ state[i][0] for i in op.input_vals[:-1])
@staticmethod
def __concat_gen_asm(op, state):
outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
is_copy=True,
)
- _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
+ _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
_GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
@staticmethod
- def __spread_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
- state.ssa_vals[op.outputs[idx]] = inp,
+ def __spread_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ for idx, inp in enumerate(state[op.input_vals[0]]):
+ state[op.outputs[idx]] = inp,
@staticmethod
def __spread_gen_asm(op, state):
)],
is_copy=True,
)
- _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim
+ _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
_GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
@staticmethod
- def __svld_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RA, = state.ssa_vals[op.input_vals[0]]
- VL, = state.ssa_vals[op.input_vals[1]]
+ def __svld_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA, = state[op.input_vals[0]]
+ VL, = state[op.input_vals[1]]
addr = RA + op.immediates[0]
RT = [] # type: list[int]
for i in range(VL):
v = state.load(addr + GPR_SIZE_IN_BYTES * i)
RT.append(v & GPR_VALUE_MASK)
- state.ssa_vals[op.outputs[0]] = tuple(RT)
+ state[op.outputs[0]] = tuple(RT)
@staticmethod
def __svld_gen_asm(op, state):
outputs=[OD_EXTRA3_VGPR],
immediates=[IMM_S16],
)
- _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim
+ _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
_GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
@staticmethod
- def __ld_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RA, = state.ssa_vals[op.input_vals[0]]
+ def __ld_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RA, = state[op.input_vals[0]]
addr = RA + op.immediates[0]
v = state.load(addr)
- state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
+ state[op.outputs[0]] = v & GPR_VALUE_MASK,
@staticmethod
def __ld_gen_asm(op, state):
outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
immediates=[IMM_S16],
)
- _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
+ _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
_GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
@staticmethod
- def __svstd_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RS = state.ssa_vals[op.input_vals[0]]
- RA, = state.ssa_vals[op.input_vals[1]]
- VL, = state.ssa_vals[op.input_vals[2]]
+ def __svstd_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RS = state[op.input_vals[0]]
+ RA, = state[op.input_vals[1]]
+ VL, = state[op.input_vals[2]]
addr = RA + op.immediates[0]
for i in range(VL):
state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
immediates=[IMM_S16],
has_side_effects=True,
)
- _PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim
+ _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
_GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
@staticmethod
- def __std_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
- RS, = state.ssa_vals[op.input_vals[0]]
- RA, = state.ssa_vals[op.input_vals[1]]
+ def __std_sim(op, state):
+ # type: (Op, BaseSimState) -> None
+ RS, = state[op.input_vals[0]]
+ RA, = state[op.input_vals[1]]
addr = RA + op.immediates[0]
state.store(addr, value=RS)
immediates=[IMM_S16],
has_side_effects=True,
)
- _PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim
+ _SIM_FNS[Std] = lambda: OpKind.__std_sim
_GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
@staticmethod
- def __funcargr3_pre_ra_sim(op, state):
- # type: (Op, PreRASimState) -> None
+ def __funcargr3_sim(op, state):
+ # type: (Op, BaseSimState) -> None
pass # return value set before simulation
@staticmethod
outputs=[OD_BASE_SGPR.with_fixed_loc(
Loc(kind=LocKind.GPR, start=3, reg_len=1))],
)
- _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
+ _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
_GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
field_vals_str = ", ".join(field_vals)
return f"Op({field_vals_str})"
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
+ def sim(self, state):
+ # type: (BaseSimState) -> None
for inp in self.input_vals:
- if inp not in state.ssa_vals:
+ try:
+ val = state[inp]
+ except KeyError:
raise ValueError(f"SSAVal {inp} not yet assigned when "
f"running {self}")
- if len(state.ssa_vals[inp]) != inp.ty.reg_len:
+ if len(val) != inp.ty.reg_len:
raise ValueError(
f"value of SSAVal {inp} has wrong number of elements: "
f"expected {inp.ty.reg_len} found "
- f"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}")
- for out in self.outputs:
- if out in state.ssa_vals:
- if self.kind is OpKind.FuncArgR3:
- continue
- raise ValueError(f"SSAVal {out} already assigned before "
- f"running {self}")
- self.kind.pre_ra_sim(self, state)
+ f"{len(val)}: {val!r}")
+ if isinstance(state, PreRASimState):
+ for out in self.outputs:
+ if out in state.ssa_vals:
+ if self.kind is OpKind.FuncArgR3:
+ continue
+ raise ValueError(f"SSAVal {out} already assigned before "
+ f"running {self}")
+ self.kind.sim(self, state)
for out in self.outputs:
- if out not in state.ssa_vals:
+ try:
+ val = state[out]
+ except KeyError:
raise ValueError(f"running {self} failed to assign to {out}")
- if len(state.ssa_vals[out]) != out.ty.reg_len:
+ if len(val) != out.ty.reg_len:
raise ValueError(
f"value of SSAVal {out} has wrong number of elements: "
f"expected {out.ty.reg_len} found "
- f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
+ f"{len(val)}: {val!r}")
def gen_asm(self, state):
# type: (GenAsmState) -> None
@plain_data(frozen=True, repr=False)
-@final
-class PreRASimState:
- __slots__ = "ssa_vals", "memory"
+class BaseSimState(metaclass=ABCMeta):
+ __slots__ = "memory",
- def __init__(self, ssa_vals, memory):
- # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
- self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]]
+ def __init__(self, memory):
+ # type: (dict[int, int]) -> None
+ super().__init__()
self.memory = memory # type: dict[int, int]
def load_byte(self, addr):
items_str = ",\n".join(items)
return f"{{\n{items_str}}}"
+ def __repr__(self):
+ # type: () -> str
+ field_vals = [] # type: list[str]
+ for name in fields(self):
+ try:
+ value = getattr(self, name)
+ except AttributeError:
+ field_vals.append(f"{name}=<not set>")
+ continue
+ repr_fn = getattr(self, f"_{name}__repr", None)
+ if callable(repr_fn):
+ field_vals.append(f"{name}={repr_fn()}")
+ else:
+ field_vals.append(f"{name}={value!r}")
+ field_vals_str = ", ".join(field_vals)
+ return f"{self.__class__.__name__}({field_vals_str})"
+
+ @abstractmethod
+ def __getitem__(self, ssa_val):
+ # type: (SSAVal) -> tuple[int, ...]
+ ...
+
+ @abstractmethod
+ def __setitem__(self, ssa_val, value):
+ # type: (SSAVal, tuple[int, ...]) -> None
+ ...
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class PreRASimState(BaseSimState):
+ __slots__ = "ssa_vals",
+
+ def __init__(self, ssa_vals, memory):
+ # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
+ super().__init__(memory)
+ self.ssa_vals = ssa_vals # type: dict[SSAVal, tuple[int, ...]]
+
def _ssa_vals__repr(self):
# type: () -> str
if len(self.ssa_vals) == 0:
items_str = ",\n".join(items)
return f"{{\n{items_str},\n}}"
- def __repr__(self):
+ def __getitem__(self, ssa_val):
+ # type: (SSAVal) -> tuple[int, ...]
+ return self.ssa_vals[ssa_val]
+
+ def __setitem__(self, ssa_val, value):
+ # type: (SSAVal, tuple[int, ...]) -> None
+ if len(value) != ssa_val.ty.reg_len:
+ raise ValueError("value has wrong len")
+ self.ssa_vals[ssa_val] = value
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class PostRASimState(BaseSimState):
+ __slots__ = "ssa_val_to_loc_map", "loc_values"
+
+ def __init__(self, ssa_val_to_loc_map, memory, loc_values):
+ # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
+ super().__init__(memory)
+ self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
+ for ssa_val, loc in self.ssa_val_to_loc_map.items():
+ if ssa_val.ty != loc.ty:
+ raise ValueError(
+ f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
+ self.loc_values = loc_values
+ for loc in self.loc_values.keys():
+ if loc.reg_len != 1:
+ raise ValueError(
+ "loc_values must only contain Locs with reg_len=1, all "
+ "larger Locs will be split into reg_len=1 sub-Locs")
+
+ def _loc_values__repr(self):
# type: () -> str
- field_vals = [] # type: list[str]
- for name in fields(self):
- try:
- value = getattr(self, name)
- except AttributeError:
- field_vals.append(f"{name}=<not set>")
- continue
- repr_fn = getattr(self, f"_{name}__repr", None)
- if callable(repr_fn):
- field_vals.append(f"{name}={repr_fn()}")
- else:
- field_vals.append(f"{name}={value!r}")
- field_vals_str = ", ".join(field_vals)
- return f"PreRASimState({field_vals_str})"
+ locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start))
+ items = [] # type: list[str]
+ for loc in locs:
+ items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
+ items_str = ",\n".join(items)
+ return f"{{\n{items_str},\n}}"
+
+ def __getitem__(self, ssa_val):
+ # type: (SSAVal) -> tuple[int, ...]
+ loc = self.ssa_val_to_loc_map[ssa_val]
+ subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
+ retval = [] # type: list[int]
+ for i in range(loc.reg_len):
+ subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
+ retval.append(self.loc_values.get(subloc, 0))
+ return tuple(retval)
+
+ def __setitem__(self, ssa_val, value):
+ # type: (SSAVal, tuple[int, ...]) -> None
+ if len(value) != ssa_val.ty.reg_len:
+ raise ValueError("value has wrong len")
+ loc = self.ssa_val_to_loc_map[ssa_val]
+ subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
+ for i in range(loc.reg_len):
+ subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
+ self.loc_values[subloc] = value[i]
@plain_data(frozen=True)
"""
from itertools import combinations
-from typing import Iterable, Iterator, Mapping
+from typing import Iterable, Iterator, Mapping, TextIO
from cached_property import cached_property
from nmutil.plain_data import plain_data
reg_len = self.ty.reg_len
loc_set = None # type: None | LocSet
for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
- def_spread_idx = ssa_val.defining_descriptor.spread_index or 0
-
def locs():
# type: () -> Iterable[Loc]
for loc in ssa_val.def_loc_set_before_spread:
disallowed_by_use = False
for use in fn_analysis.uses[ssa_val]:
- use_spread_idx = \
- use.defining_descriptor.spread_index or 0
# calculate the start for the use's Loc before spread
# e.g. if the def's Loc before spread starts at r6
- # and the def's spread_index is 5
- # and the use's spread_index is 3
+ # and the def's reg_offset_in_unspread is 5
+ # and the use's reg_offset_in_unspread is 3
# then the use's Loc before spread starts at r8
# because 8 == 6 + 5 - 3
- start = loc.start + def_spread_idx - use_spread_idx
+ start = (loc.start + ssa_val.reg_offset_in_unspread
+ - use.reg_offset_in_unspread)
use_loc = Loc.try_make(
loc.kind, start=start,
reg_len=use.ty_before_spread.reg_len)
return ProgramRange(start=start, stop=stop)
def __repr__(self):
- return (f"MergedSSAVal({self.fn_analysis}, "
- f"ssa_val_offsets={self.ssa_val_offsets})")
+ return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
+ f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
+ f"live_interval={self.live_interval})")
@final
return self.__repr__()
-def allocate_registers(fn):
- # type: (Fn) -> dict[SSAVal, Loc]
+def allocate_registers(fn, debug_out=None):
+ # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
# inserts enough copies that no manual spilling is necessary, all
# spilling is done by the register allocator naturally allocating SSAVals
# to stack slots
fn.pre_ra_insert_copies()
+ if debug_out is not None:
+ print(f"After pre_ra_insert_copies():\n{fn.ops}",
+ file=debug_out, flush=True)
+
fn_analysis = FnAnalysis(fn)
interference_graph = InterferenceGraph.minimally_merged(fn_analysis)
- for ssa_vals in fn_analysis.live_at.values():
+ if debug_out is not None:
+ print(f"After InterferenceGraph.minimally_merged():\n"
+ f"{interference_graph}", file=debug_out, flush=True)
+
+ for pp, ssa_vals in fn_analysis.live_at.items():
live_merged_ssa_vals = OSet() # type: OSet[MergedSSAVal]
for ssa_val in ssa_vals:
live_merged_ssa_vals.add(
if i.loc_set.max_conflicts_with(j.loc_set) != 0:
interference_graph.nodes[i].add_edge(
interference_graph.nodes[j])
+ if debug_out is not None:
+ print(f"processed {pp} out of {fn_analysis.all_program_points}",
+ file=debug_out, flush=True)
+
+ if debug_out is not None:
+ print(f"After adding interference graph edges:\n"
+ f"{interference_graph}", file=debug_out, flush=True)
nodes_remaining = OSet(interference_graph.nodes.values())
node_stack.append(best_node)
nodes_remaining.remove(best_node)
+ if debug_out is not None:
+ print(f"After deciding node allocation order:\n"
+ f"{node_stack}", file=debug_out, flush=True)
+
retval = {} # type: dict[SSAVal, Loc]
while len(node_stack) > 0:
"failed to allocate Loc for IGNode",
node=node, interference_graph=interference_graph)
+ if debug_out is not None:
+ print(f"After allocating Loc for node:\n{node}",
+ file=debug_out, flush=True)
+
for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)
+ if debug_out is not None:
+ print(f"final Locs for all SSAVals:\n{retval}",
+ file=debug_out, flush=True)
+
return retval