register allocation and simulation works for simple mul 192x192!
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 8 Nov 2022 06:53:38 +0000 (22:53 -0800)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 8 Nov 2022 06:53:38 +0000 (22:53 -0800)
src/__init__.py [new file with mode: 0644]
src/bigint_presentation_code/_tests/test_compiler_ir2.py
src/bigint_presentation_code/_tests/test_register_allocator2.py
src/bigint_presentation_code/_tests/test_toom_cook.py
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/register_allocator2.py

diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
index e53f871a588fcb73ee7e9711e4da63e2cc27a252..833dbc958983a529f6ec953435bdf0b809ebfb1c 100644 (file)
@@ -1,9 +1,9 @@
 import unittest
 
-from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
+from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, BaseTy, Fn,
                                                    FnAnalysis, GenAsmState, Loc, LocKind, OpKind, OpStage,
                                                    PreRASimState, ProgramPoint,
-                                                   SSAVal)
+                                                   SSAVal, Ty)
 
 
 class TestCompilerIR(unittest.TestCase):
@@ -807,46 +807,14 @@ class TestCompilerIR(unittest.TestCase):
                     size_in_bytes=GPR_SIZE_IN_BYTES)
         self.assertEqual(
             repr(state),
-            "PreRASimState(ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)}, "
-            "memory={\n"
+            "PreRASimState(memory={\n"
             "0x00100: <0xffffffffffffffff>,\n"
-            "0x00108: <0xabcdef0123456789>})")
-        fn.pre_ra_sim(state)
+            "0x00108: <0xabcdef0123456789>}, "
+            "ssa_vals={<arg.outputs[0]: <I64>>: (0x100,)})")
+        fn.sim(state)
         self.assertEqual(
             repr(state),
-            "PreRASimState(ssa_vals={\n"
-            "<arg.outputs[0]: <I64>>: (0x100,),\n"
-            "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
-            "<ld.outputs[0]: <I64*32>>: (\n"
-            "    0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0),\n"
-            "<li.outputs[0]: <I64*32>>: (\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0),\n"
-            "<ca.outputs[0]: <CA>>: (0x1,),\n"
-            "<add.outputs[0]: <I64*32>>: (\n"
-            "    0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0,\n"
-            "    0x0, 0x0, 0x0, 0x0),\n"
-            "<add.outputs[1]: <CA>>: (0x0,),\n"
-            "}, memory={\n"
+            "PreRASimState(memory={\n"
             "0x00100: <0x0000000000000000>,\n"
             "0x00108: <0xabcdef012345678a>,\n"
             "0x00110: <0x0000000000000000>,\n"
@@ -878,7 +846,39 @@ class TestCompilerIR(unittest.TestCase):
             "0x001e0: <0x0000000000000000>,\n"
             "0x001e8: <0x0000000000000000>,\n"
             "0x001f0: <0x0000000000000000>,\n"
-            "0x001f8: <0x0000000000000000>})")
+            "0x001f8: <0x0000000000000000>}, ssa_vals={\n"
+            "<arg.outputs[0]: <I64>>: (0x100,),\n"
+            "<vl.outputs[0]: <VL_MAXVL>>: (0x20,),\n"
+            "<ld.outputs[0]: <I64*32>>: (\n"
+            "    0xffffffffffffffff, 0xabcdef0123456789, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0),\n"
+            "<li.outputs[0]: <I64*32>>: (\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0),\n"
+            "<ca.outputs[0]: <CA>>: (0x1,),\n"
+            "<add.outputs[0]: <I64*32>>: (\n"
+            "    0x0, 0xabcdef012345678a, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0,\n"
+            "    0x0, 0x0, 0x0, 0x0),\n"
+            "<add.outputs[1]: <CA>>: (0x0,),\n"
+            "})")
 
     def test_gen_asm(self):
         fn, _arg = self.make_add_fn()
@@ -933,6 +933,136 @@ class TestCompilerIR(unittest.TestCase):
             'sv.std *32, 0(3)',
         ])
 
+    def test_spread(self):
+        fn = Fn()
+        maxvl = 4
+        vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
+                              name="vl", maxvl=maxvl).outputs[0]
+        li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+                              name="li", maxvl=maxvl).outputs[0]
+        spread_op = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
+                                     name="spread", maxvl=maxvl)
+        self.assertEqual(spread_op.outputs[0].ty_before_spread,
+                         Ty(base_ty=BaseTy.I64, reg_len=maxvl))
+        _concat = fn.append_new_op(
+            OpKind.Concat, input_vals=[*spread_op.outputs[::-1], vl],
+            name="concat", maxvl=maxvl)
+        self.assertEqual([repr(op.properties) for op in fn.ops], [
+            "OpProperties(kind=OpKind.SetVLI, inputs=("
+            "), outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), "
+            "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),"
+            "), maxvl=4)",
+            "OpProperties(kind=OpKind.SvLI, inputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), "
+            "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),"
+            "), outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),"
+            "), maxvl=4)",
+            "OpProperties(kind=OpKind.Spread, inputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), "
+            "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)"
+            "), outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=0, "
+            "write_stage=OpStage.Late), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=1, "
+            "write_stage=OpStage.Late), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=2, "
+            "write_stage=OpStage.Late), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=3, "
+            "write_stage=OpStage.Late)"
+            "), maxvl=4)",
+            "OpProperties(kind=OpKind.Concat, inputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=0, "
+            "write_stage=OpStage.Early), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=1, "
+            "write_stage=OpStage.Early), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=2, "
+            "write_stage=OpStage.Early), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=3, "
+            "write_stage=OpStage.Early), "
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.VL_MAXVL: FBitSet([0])}), "
+            "ty=<VL_MAXVL>), tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)"
+            "), outputs=("
+            "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
+            "LocKind.GPR: FBitSet([*range(3, 10), *range(14, 125)])}), "
+            "ty=<I64*4>), tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),"
+            "), maxvl=4)",
+        ])
+        self.assertEqual([repr(op) for op in fn.ops], [
+            "Op(kind=OpKind.SetVLI, input_vals=["
+            "], input_uses=("
+            "), immediates=[4], outputs=("
+            "<vl.outputs[0]: <VL_MAXVL>>,"
+            "), name='vl')",
+            "Op(kind=OpKind.SvLI, input_vals=["
+            "<vl.outputs[0]: <VL_MAXVL>>"
+            "], input_uses=("
+            "<li.input_uses[0]: <VL_MAXVL>>,"
+            "), immediates=[0], outputs=("
+            "<li.outputs[0]: <I64*4>>,"
+            "), name='li')",
+            "Op(kind=OpKind.Spread, input_vals=["
+            "<li.outputs[0]: <I64*4>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>"
+            "], input_uses=("
+            "<spread.input_uses[0]: <I64*4>>, "
+            "<spread.input_uses[1]: <VL_MAXVL>>"
+            "), immediates=[], outputs=("
+            "<spread.outputs[0]: <I64>>, "
+            "<spread.outputs[1]: <I64>>, "
+            "<spread.outputs[2]: <I64>>, "
+            "<spread.outputs[3]: <I64>>"
+            "), name='spread')",
+            "Op(kind=OpKind.Concat, input_vals=["
+            "<spread.outputs[3]: <I64>>, "
+            "<spread.outputs[2]: <I64>>, "
+            "<spread.outputs[1]: <I64>>, "
+            "<spread.outputs[0]: <I64>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>"
+            "], input_uses=("
+            "<concat.input_uses[0]: <I64>>, "
+            "<concat.input_uses[1]: <I64>>, "
+            "<concat.input_uses[2]: <I64>>, "
+            "<concat.input_uses[3]: <I64>>, "
+            "<concat.input_uses[4]: <VL_MAXVL>>"
+            "), immediates=[], outputs=("
+            "<concat.outputs[0]: <I64*4>>,"
+            "), name='concat')",
+        ])
+
 
 if __name__ == "__main__":
     _ = unittest.main()
index 697417f95b82b148beec1612770e428c81ad134b..f34ed982c2078e8318394f557b6cb52ab8168cff 100644 (file)
@@ -1,6 +1,8 @@
+import sys
 import unittest
 
-from bigint_presentation_code.compiler_ir2 import Fn, GenAsmState, OpKind, SSAVal
+from bigint_presentation_code.compiler_ir2 import (Fn, GenAsmState, OpKind,
+                                                   SSAVal)
 from bigint_presentation_code.register_allocator2 import allocate_registers
 
 
@@ -177,6 +179,315 @@ class TestCompilerIR(unittest.TestCase):
             'sv.std *14, 0(3)',
         ])
 
+    def test_register_allocate_spread(self):
+        fn = Fn()
+        maxvl = 32
+        vl = fn.append_new_op(OpKind.SetVLI, immediates=[maxvl],
+                              name="vl", maxvl=maxvl).outputs[0]
+        li = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+                              name="li", maxvl=maxvl).outputs[0]
+        spread = fn.append_new_op(OpKind.Spread, input_vals=[li, vl],
+                                  name="spread", maxvl=maxvl).outputs
+        _concat = fn.append_new_op(
+            OpKind.Concat, input_vals=[*spread[::-1], vl],
+            name="concat", maxvl=maxvl)
+        reg_assignments = allocate_registers(fn, debug_out=sys.stdout)
+
+        self.assertEqual(
+            repr(reg_assignments),
+            "{<concat.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<concat.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<concat.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+            "<concat.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+            "<concat.inp2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+            "<concat.inp3.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+            "<concat.inp4.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+            "<concat.inp5.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+            "<concat.inp6.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
+            "<concat.inp7.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
+            "<concat.inp8.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
+            "<concat.inp9.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+            "<concat.inp10.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
+            "<concat.inp11.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
+            "<concat.inp12.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
+            "<concat.inp13.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
+            "<concat.inp14.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
+            "<concat.inp15.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
+            "<concat.inp16.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
+            "<concat.inp17.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
+            "<concat.inp18.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
+            "<concat.inp19.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
+            "<concat.inp20.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
+            "<concat.inp21.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
+            "<concat.inp22.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
+            "<concat.inp23.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
+            "<concat.inp24.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
+            "<concat.inp25.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
+            "<concat.inp26.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
+            "<concat.inp27.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
+            "<concat.inp28.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
+            "<concat.inp29.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
+            "<concat.inp30.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
+            "<concat.inp31.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
+            "<concat.inp32.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<spread.out31.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<spread.out30.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<spread.out29.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<spread.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+            "<spread.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+            "<spread.outputs[2]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+            "<spread.outputs[3]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+            "<spread.outputs[4]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+            "<spread.outputs[5]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+            "<spread.outputs[6]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=20, reg_len=1), "
+            "<spread.outputs[7]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=21, reg_len=1), "
+            "<spread.outputs[8]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=22, reg_len=1), "
+            "<spread.outputs[9]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+            "<spread.outputs[10]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=24, reg_len=1), "
+            "<spread.outputs[11]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=25, reg_len=1), "
+            "<spread.outputs[12]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=26, reg_len=1), "
+            "<spread.outputs[13]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=27, reg_len=1), "
+            "<spread.outputs[14]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=28, reg_len=1), "
+            "<spread.outputs[15]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=29, reg_len=1), "
+            "<spread.outputs[16]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=30, reg_len=1), "
+            "<spread.outputs[17]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=31, reg_len=1), "
+            "<spread.outputs[18]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=32, reg_len=1), "
+            "<spread.outputs[19]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=33, reg_len=1), "
+            "<spread.outputs[20]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=34, reg_len=1), "
+            "<spread.outputs[21]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=35, reg_len=1), "
+            "<spread.outputs[22]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=36, reg_len=1), "
+            "<spread.outputs[23]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=37, reg_len=1), "
+            "<spread.outputs[24]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=38, reg_len=1), "
+            "<spread.outputs[25]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=39, reg_len=1), "
+            "<spread.outputs[26]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=40, reg_len=1), "
+            "<spread.outputs[27]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=41, reg_len=1), "
+            "<spread.outputs[28]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=42, reg_len=1), "
+            "<spread.outputs[29]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=43, reg_len=1), "
+            "<spread.outputs[30]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=44, reg_len=1), "
+            "<spread.outputs[31]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=45, reg_len=1), "
+            "<spread.out28.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<spread.out27.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+            "<spread.out26.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
+            "<spread.out25.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
+            "<spread.out24.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
+            "<spread.out23.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+            "<spread.out22.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+            "<spread.out21.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=46, reg_len=1), "
+            "<spread.out20.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=47, reg_len=1), "
+            "<spread.out19.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=48, reg_len=1), "
+            "<spread.out18.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=49, reg_len=1), "
+            "<spread.out17.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=50, reg_len=1), "
+            "<spread.out16.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=51, reg_len=1), "
+            "<spread.out15.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=52, reg_len=1), "
+            "<spread.out14.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=53, reg_len=1), "
+            "<spread.out13.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=54, reg_len=1), "
+            "<spread.out12.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=55, reg_len=1), "
+            "<spread.out11.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=56, reg_len=1), "
+            "<spread.out10.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=57, reg_len=1), "
+            "<spread.out9.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=58, reg_len=1), "
+            "<spread.out8.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=59, reg_len=1), "
+            "<spread.out7.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=60, reg_len=1), "
+            "<spread.out6.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=61, reg_len=1), "
+            "<spread.out5.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=62, reg_len=1), "
+            "<spread.out4.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=63, reg_len=1), "
+            "<spread.out3.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=64, reg_len=1), "
+            "<spread.out2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=65, reg_len=1), "
+            "<spread.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=66, reg_len=1), "
+            "<spread.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=67, reg_len=1), "
+            "<spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<spread.inp0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.out0.copy.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<li.outputs[0]: <I64*32>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=32), "
+            "<li.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<vl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1)}"
+        )
+        state = GenAsmState(reg_assignments)
+        fn.gen_asm(state)
+        self.assertEqual(state.output, [
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'sv.addi *14, 0, 0',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'or 67, 14, 14',
+            'or 66, 15, 15',
+            'or 65, 16, 16',
+            'or 64, 17, 17',
+            'or 63, 18, 18',
+            'or 62, 19, 19',
+            'or 61, 20, 20',
+            'or 60, 21, 21',
+            'or 59, 22, 22',
+            'or 58, 23, 23',
+            'or 57, 24, 24',
+            'or 56, 25, 25',
+            'or 55, 26, 26',
+            'or 54, 27, 27',
+            'or 53, 28, 28',
+            'or 52, 29, 29',
+            'or 51, 30, 30',
+            'or 50, 31, 31',
+            'or 49, 32, 32',
+            'or 48, 33, 33',
+            'or 47, 34, 34',
+            'or 46, 35, 35',
+            'or 12, 36, 36',
+            'or 11, 37, 37',
+            'or 10, 38, 38',
+            'or 9, 39, 39',
+            'or 8, 40, 40',
+            'or 7, 41, 41',
+            'or 6, 42, 42',
+            'or 5, 43, 43',
+            'or 4, 44, 44',
+            'or 3, 45, 45',
+            'or 14, 3, 3',
+            'or 15, 4, 4',
+            'or 16, 5, 5',
+            'or 17, 6, 6',
+            'or 18, 7, 7',
+            'or 19, 8, 8',
+            'or 20, 9, 9',
+            'or 21, 10, 10',
+            'or 22, 11, 11',
+            'or 23, 12, 12',
+            'or 24, 46, 46',
+            'or 25, 47, 47',
+            'or 26, 48, 48',
+            'or 27, 49, 49',
+            'or 28, 50, 50',
+            'or 29, 51, 51',
+            'or 30, 52, 52',
+            'or 31, 53, 53',
+            'or 32, 54, 54',
+            'or 33, 55, 55',
+            'or 34, 56, 56',
+            'or 35, 57, 57',
+            'or 36, 58, 58',
+            'or 37, 59, 59',
+            'or 38, 60, 60',
+            'or 39, 61, 61',
+            'or 40, 62, 62',
+            'or 41, 63, 63',
+            'or 42, 64, 64',
+            'or 43, 65, 65',
+            'or 44, 66, 66',
+            'or 45, 67, 67',
+            'setvl 0, 0, 32, 0, 1, 1',
+            'setvl 0, 0, 32, 0, 1, 1'])
+
 
 if __name__ == "__main__":
     _ = unittest.main()
index 994c9510ea797c016c6a1acdbc0c0b333b4cb858..31884304630b285eb4483b4ac59ac060330620ad 100644 (file)
@@ -1,7 +1,10 @@
 import unittest
+from typing import Callable
 
-from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
+from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES,
+                                                   BaseSimState, Fn,
                                                    GenAsmState, OpKind,
+                                                   PostRASimState,
                                                    PreRASimState)
 from bigint_presentation_code.register_allocator2 import allocate_registers
 from bigint_presentation_code.toom_cook import ToomCookInstance, simple_mul
@@ -204,6 +207,21 @@ class TestToomCook(unittest.TestCase):
         )
 
     def test_simple_mul_192x192_pre_ra_sim(self):
+        def create_sim_state(code):
+            # type: (SimpleMul192x192) -> BaseSimState
+            return PreRASimState(ssa_vals={}, memory={})
+        self.tst_simple_mul_192x192_sim(create_sim_state)
+
+    def test_simple_mul_192x192_post_ra_sim(self):
+        def create_sim_state(code):
+            # type: (SimpleMul192x192) -> BaseSimState
+            ssa_val_to_loc_map = allocate_registers(code.fn)
+            return PostRASimState(ssa_val_to_loc_map=ssa_val_to_loc_map,
+                                  memory={}, loc_values={})
+        self.tst_simple_mul_192x192_sim(create_sim_state)
+
+    def tst_simple_mul_192x192_sim(self, create_sim_state):
+        # type: (Callable[[SimpleMul192x192], BaseSimState]) -> None
         # test multiplying:
         #   0x000191acb262e15b_4c6b5f2b19e1a53e_821a2342132c5b57
         # * 0x4a37c0567bcbab53_cf1f597598194ae6_208a49071aeec507
@@ -214,18 +232,19 @@ class TestToomCook(unittest.TestCase):
         # == int.from_bytes(b"arbitrary 192x192->384-bit multiplication test",
         #                   'little')
         code = SimpleMul192x192()
+        state = create_sim_state(code)
         ptr_in = 0x100
         dest_ptr = ptr_in + code.dest_offset
         lhs_ptr = ptr_in + code.lhs_offset
         rhs_ptr = ptr_in + code.rhs_offset
-        state = PreRASimState(ssa_vals={code.ptr_in: (ptr_in,)}, memory={})
+        state[code.ptr_in] = ptr_in,
         state.store(lhs_ptr, 0x821a2342132c5b57)
         state.store(lhs_ptr + 8, 0x4c6b5f2b19e1a53e)
         state.store(lhs_ptr + 16, 0x000191acb262e15b)
         state.store(rhs_ptr, 0x208a49071aeec507)
         state.store(rhs_ptr + 8, 0xcf1f597598194ae6)
         state.store(rhs_ptr + 16, 0x4a37c0567bcbab53)
-        code.fn.pre_ra_sim(state)
+        code.fn.sim(state)
         expected_bytes = b"arbitrary 192x192->384-bit multiplication test"
         OUT_BYTE_COUNT = 6 * GPR_SIZE_IN_BYTES
         expected_bytes = expected_bytes.ljust(OUT_BYTE_COUNT, b'\0')
@@ -458,28 +477,413 @@ class TestToomCook(unittest.TestCase):
             "name='store_dest')",
         ])
 
-    # FIXME: register allocator currently allocates wrong registers
-    @unittest.expectedFailure
     def test_simple_mul_192x192_reg_alloc(self):
         code = SimpleMul192x192()
         fn = code.fn
         assigned_registers = allocate_registers(fn)
-        self.assertEqual(assigned_registers, {
-        })
-        self.fail("register allocator currently allocates wrong registers")
+        self.assertEqual(
+            repr(assigned_registers), "{"
+            "<store_dest.inp2.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<store_dest.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<store_dest.inp0.copy.outputs[0]: <I64*6>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=6), "
+            "<store_dest.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<setvl6.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<concat_retval.out0.copy.outputs[0]: <I64*6>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
+            "<concat_retval.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<concat_retval.outputs[0]: <I64*6>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=6), "
+            "<concat_retval.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<concat_retval.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<concat_retval.inp2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<concat_retval.inp3.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<concat_retval.inp4.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+            "<concat_retval.inp5.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=8, reg_len=1), "
+            "<concat_retval.inp6.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<retval_setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add_hi2.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=9, reg_len=1), "
+            "<clear_ca2.outputs[0]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add2.outputs[1]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add_hi2.outputs[1]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add_hi2.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<add_hi2.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<add2_rt_spread.out2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=10, reg_len=1), "
+            "<add2_rt_spread.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+            "<add2_rt_spread.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+            "<add2_rt_spread.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<add2_rt_spread.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<add2_rt_spread.outputs[2]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<add2_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add2_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add2_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add2.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add2.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add2.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add2.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add2.inp1.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+            "<add2.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add2.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=9, reg_len=3), "
+            "<add2.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add2_rb_concat.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add2_rb_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add2_rb_concat.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add2_rb_concat.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<add2_rb_concat.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<add2_rb_concat.inp2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<add2_rb_concat.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul2.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+            "<mul2.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+            "<mul2.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul2.inp2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<mul2.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<mul2.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=3), "
+            "<mul2.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul2.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+            "<mul2.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
+            "<mul2.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add_hi1.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+            "<clear_ca1.outputs[0]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add1.outputs[1]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add_hi1.outputs[1]: <CA>>: "
+            "Loc(kind=LocKind.CA, start=0, reg_len=1), "
+            "<add_hi1.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<add_hi1.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<add1_rt_spread.out2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+            "<add1_rt_spread.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+            "<add1_rt_spread.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=16, reg_len=1), "
+            "<add1_rt_spread.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<add1_rt_spread.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<add1_rt_spread.outputs[2]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<add1_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add1_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add1_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add1.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add1.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add1.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add1.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add1.inp1.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+            "<add1.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add1.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=9, reg_len=3), "
+            "<add1.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add1_rb_concat.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add1_rb_concat.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<add1_rb_concat.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<add1_rb_concat.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<add1_rb_concat.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<add1_rb_concat.inp2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<add1_rb_concat.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul1.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+            "<mul1.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=3), "
+            "<mul1.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul1.inp2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<mul1.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<mul1.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=3), "
+            "<mul1.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul1.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+            "<mul1.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
+            "<mul1.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul0_rt_spread.out2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=11, reg_len=1), "
+            "<mul0_rt_spread.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=12, reg_len=1), "
+            "<mul0_rt_spread.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=17, reg_len=1), "
+            "<mul0_rt_spread.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<mul0_rt_spread.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<mul0_rt_spread.outputs[2]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<mul0_rt_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul0_rt_spread.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<mul0_rt_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul0.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=15, reg_len=1), "
+            "<mul0.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<mul0.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul0.inp2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<mul0.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<mul0.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<mul0.inp3.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<mul0.inp1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+            "<mul0.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=8, reg_len=3), "
+            "<mul0.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<zero.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=18, reg_len=1), "
+            "<zero.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1), "
+            "<lhs_setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<rhs_spread.out2.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=19, reg_len=1), "
+            "<rhs_spread.out1.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=14, reg_len=1), "
+            "<rhs_spread.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=4, reg_len=1), "
+            "<rhs_spread.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=5, reg_len=1), "
+            "<rhs_spread.outputs[1]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<rhs_spread.outputs[2]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=7, reg_len=1), "
+            "<rhs_spread.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<rhs_spread.inp0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<rhs_spread.inp0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<rhs_setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<load_rhs.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<load_rhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<load_rhs.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<load_rhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<load_rhs.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<load_lhs.out0.copy.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=20, reg_len=3), "
+            "<load_lhs.out0.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<load_lhs.outputs[0]: <I64*3>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=3), "
+            "<load_lhs.inp1.setvl.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<load_lhs.inp0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=6, reg_len=1), "
+            "<setvl3.outputs[0]: <VL_MAXVL>>: "
+            "Loc(kind=LocKind.VL_MAXVL, start=0, reg_len=1), "
+            "<ptr_in.out0.copy.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=23, reg_len=1), "
+            "<ptr_in.outputs[0]: <I64>>: "
+            "Loc(kind=LocKind.GPR, start=3, reg_len=1)"
+            "}")
 
-    # FIXME: register allocator currently allocates wrong registers
-    @unittest.expectedFailure
     def test_simple_mul_192x192_asm(self):
-        self.skipTest("WIP")
         code = SimpleMul192x192()
         fn = code.fn
         assigned_registers = allocate_registers(fn)
         gen_asm_state = GenAsmState(assigned_registers)
         fn.gen_asm(gen_asm_state)
         self.assertEqual(gen_asm_state.output, [
+            'or 23, 3, 3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'or 6, 23, 23',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.ld *3, 48(6)',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *20, *3, *3',
+            'or 6, 23, 23',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.ld *3, 72(6)',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or/mrr *5, *3, *3',
+            'or 4, 5, 5',
+            'or 14, 6, 6',
+            'or 19, 7, 7',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'addi 3, 0, 0',
+            'or 18, 3, 3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *8, *20, *20',
+            'or 7, 4, 4',
+            'or 6, 18, 18',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.maddedu *3, *8, 7, 6',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'or 15, 6, 6',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'or 17, 3, 3',
+            'or 12, 4, 4',
+            'or 11, 5, 5',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *8, *20, *20',
+            'or 7, 14, 14',
+            'or 3, 18, 18',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.maddedu *4, *8, 7, 3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or/mrr *6, *4, *4',
+            'or 14, 3, 3',
+            'or 3, 12, 12',
+            'or 4, 11, 11',
+            'or 5, 15, 15',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'addic 0, 0, 0',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *9, *6, *6',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *6, *3, *3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.adde *3, *9, *6',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'or 16, 3, 3',
+            'or 15, 4, 4',
+            'or 12, 5, 5',
+            'or 4, 14, 14',
+            'addze *3, *4',
+            'or 11, 3, 3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *8, *20, *20',
+            'or 7, 19, 19',
+            'or 3, 18, 18',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.maddedu *4, *8, 7, 3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or/mrr *6, *4, *4',
+            'or 14, 3, 3',
+            'or 3, 15, 15',
+            'or 4, 12, 12',
+            'or 5, 11, 11',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'addic 0, 0, 0',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *9, *6, *6',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.or *6, *3, *3',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'sv.adde *3, *9, *6',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'setvl 0, 0, 3, 0, 1, 1',
+            'or 12, 3, 3',
+            'or 11, 4, 4',
+            'or 10, 5, 5',
+            'or 4, 14, 14',
+            'addze *3, *4',
+            'or 9, 3, 3',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'or 3, 17, 17',
+            'or 4, 16, 16',
+            'or 5, 12, 12',
+            'or 6, 11, 11',
+            'or 7, 10, 10',
+            'or 8, 9, 9',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'sv.or/mrr *4, *3, *3',
+            'or 3, 23, 23',
+            'setvl 0, 0, 6, 0, 1, 1',
+            'sv.std *4, 0(3)'
         ])
-        self.fail("register allocator currently allocates wrong registers")
 
 
 if __name__ == "__main__":
index bd3d38c77f1a7d648387a7d8e16abafe93b5009f..d3b52e85c4a5cbb654bf7bedd39c25f42924da2f 100644 (file)
@@ -12,7 +12,8 @@ from nmutil.plain_data import fields, plain_data
 
 from bigint_presentation_code.type_util import (Literal, Self, assert_never,
                                                 final)
-from bigint_presentation_code.util import BitSet, FBitSet, FMap, InternedMeta, OFSet, OSet
+from bigint_presentation_code.util import (BitSet, FBitSet, FMap, InternedMeta,
+                                           OFSet, OSet)
 
 
 @final
@@ -54,10 +55,10 @@ class Fn:
         self.append_op(retval)
         return retval
 
-    def pre_ra_sim(self, state):
-        # type: (PreRASimState) -> None
+    def sim(self, state):
+        # type: (BaseSimState) -> None
         for op in self.ops:
-            op.pre_ra_sim(state)
+            op.sim(state)
 
     def gen_asm(self, state):
         # type: (GenAsmState) -> None
@@ -641,25 +642,57 @@ SPECIAL_GPRS = (
 )
 
 
-@plain_data(frozen=True, eq=False)
+@final
+class _LocSetHashHelper(AbstractSet[Loc]):
+    """helper to more quickly compute LocSet's hash"""
+
+    def __init__(self, locs):
+        # type: (Iterable[Loc]) -> None
+        super().__init__()
+        self.locs = list(locs)
+
+    def __hash__(self):
+        # type: () -> int
+        return super()._hash()
+
+    def __contains__(self, x):
+        # type: (Loc | Any) -> bool
+        return x in self.locs
+
+    def __iter__(self):
+        # type: () -> Iterator[Loc]
+        return iter(self.locs)
+
+    def __len__(self):
+        return len(self.locs)
+
+
+@plain_data(frozen=True, eq=False, repr=False)
 @final
 class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
-    __slots__ = "starts", "ty"
+    __slots__ = "starts", "ty", "_LocSet__hash"
 
     def __init__(self, __locs=()):
         # type: (Iterable[Loc]) -> None
         if isinstance(__locs, LocSet):
             self.starts = __locs.starts  # type: FMap[LocKind, FBitSet]
             self.ty = __locs.ty  # type: Ty | None
+            self._LocSet__hash = __locs._LocSet__hash  # type: int
             return
         starts = {i: BitSet() for i in LocKind}
-        ty = None
-        for loc in __locs:
-            if ty is None:
-                ty = loc.ty
-            if ty != loc.ty:
-                raise ValueError(f"conflicting types: {ty} != {loc.ty}")
-            starts[loc.kind].add(loc.start)
+        ty = None  # type: None | Ty
+
+        def locs():
+            # type: () -> Iterable[Loc]
+            nonlocal ty
+            for loc in __locs:
+                if ty is None:
+                    ty = loc.ty
+                if ty != loc.ty:
+                    raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+                starts[loc.kind].add(loc.start)
+                yield loc
+        self._LocSet__hash = _LocSetHashHelper(locs()).__hash__()
         self.starts = FMap(
             (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
         self.ty = ty
@@ -747,7 +780,7 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
         return self.__len
 
     def __hash__(self):
-        return super()._hash()
+        return self._LocSet__hash
 
     def __eq__(self, __other):
         # type: (LocSet | Any) -> bool
@@ -766,6 +799,14 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
         else:
             return sum(other.conflicts(i) for i in self)
 
+    def __repr__(self):
+        items = []  # type: list[str]
+        for name in fields(self):
+            if name.startswith("_"):
+                continue
+            items.append(f"{name}={getattr(self, name)!r}")
+        return f"LocSet({', '.join(items)})"
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
@@ -821,6 +862,13 @@ class GenericOperandDesc(metaclass=InternedMeta):
                 raise ValueError("operand can't be both spread and vector")
         self.write_stage = write_stage
 
+    @cached_property
+    def ty_before_spread(self):
+        # type: () -> GenericTy
+        if self.spread:
+            return GenericTy(base_ty=self.ty.base_ty, is_vec=True)
+        return self.ty
+
     def tied_to_input(self, tied_input_index):
         # type: (int) -> Self
         return GenericOperandDesc(self.ty, self.sub_kinds,
@@ -846,21 +894,21 @@ class GenericOperandDesc(metaclass=InternedMeta):
         rep_count = 1
         if self.spread:
             rep_count = maxvl
-            maxvl = 1
-        ty = self.ty.instantiate(maxvl=maxvl)
+        ty_before_spread = self.ty_before_spread.instantiate(maxvl=maxvl)
 
-        def locs():
+        def locs_before_spread():
             # type: () -> Iterable[Loc]
             if self.fixed_loc is not None:
-                if ty != self.fixed_loc.ty:
+                if ty_before_spread != self.fixed_loc.ty:
                     raise ValueError(
                         f"instantiation failed: type mismatch with fixed_loc: "
-                        f"instantiated type: {ty} fixed_loc: {self.fixed_loc}")
+                        f"instantiated type: {ty_before_spread} "
+                        f"fixed_loc: {self.fixed_loc}")
                 yield self.fixed_loc
                 return
             for sub_kind in self.sub_kinds:
-                yield from sub_kind.allocatable_locs(ty)
-        loc_set_before_spread = LocSet(locs())
+                yield from sub_kind.allocatable_locs(ty_before_spread)
+        loc_set_before_spread = LocSet(locs_before_spread())
         for idx in range(rep_count):
             if not self.spread:
                 idx = None
@@ -1045,9 +1093,9 @@ class OpProperties(metaclass=InternedMeta):
 
 IMM_S16 = range(-1 << 15, 1 << 15)
 
-_PRE_RA_SIM_FN = Callable[["Op", "PreRASimState"], None]
-_PRE_RA_SIM_FN2 = Callable[[], _PRE_RA_SIM_FN]
-_PRE_RA_SIMS = {}  # type: dict[GenericOpProperties | Any, _PRE_RA_SIM_FN2]
+_SIM_FN = Callable[["Op", "BaseSimState"], None]
+_SIM_FN2 = Callable[[], _SIM_FN]
+_SIM_FNS = {}  # type: dict[GenericOpProperties | Any, _SIM_FN2]
 _GEN_ASM_FN = Callable[["Op", "GenAsmState"], None]
 _GEN_ASM_FN2 = Callable[[], _GEN_ASM_FN]
 _GEN_ASMS = {}  # type: dict[GenericOpProperties | Any, _GEN_ASM_FN2]
@@ -1075,9 +1123,9 @@ class OpKind(Enum):
         return "OpKind." + self._name_
 
     @cached_property
-    def pre_ra_sim(self):
-        # type: () -> _PRE_RA_SIM_FN
-        return _PRE_RA_SIMS[self.properties]()
+    def sim(self):
+        # type: () -> _SIM_FN
+        return _SIM_FNS[self.properties]()
 
     @cached_property
     def gen_asm(self):
@@ -1085,9 +1133,9 @@ class OpKind(Enum):
         return _GEN_ASMS[self.properties]()
 
     @staticmethod
-    def __clearca_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = False,
+    def __clearca_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = False,
 
     @staticmethod
     def __clearca_gen_asm(op, state):
@@ -1098,13 +1146,13 @@ class OpKind(Enum):
         inputs=[],
         outputs=[OD_CA.with_write_stage(OpStage.Late)],
     )
-    _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
+    _SIM_FNS[ClearCA] = lambda: OpKind.__clearca_sim
     _GEN_ASMS[ClearCA] = lambda: OpKind.__clearca_gen_asm
 
     @staticmethod
-    def __setca_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = True,
+    def __setca_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = True,
 
     @staticmethod
     def __setca_gen_asm(op, state):
@@ -1115,23 +1163,23 @@ class OpKind(Enum):
         inputs=[],
         outputs=[OD_CA.with_write_stage(OpStage.Late)],
     )
-    _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
+    _SIM_FNS[SetCA] = lambda: OpKind.__setca_sim
     _GEN_ASMS[SetCA] = lambda: OpKind.__setca_gen_asm
 
     @staticmethod
-    def __svadde_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RA = state.ssa_vals[op.input_vals[0]]
-        RB = state.ssa_vals[op.input_vals[1]]
-        carry, = state.ssa_vals[op.input_vals[2]]
-        VL, = state.ssa_vals[op.input_vals[3]]
+    def __svadde_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RA = state[op.input_vals[0]]
+        RB = state[op.input_vals[1]]
+        carry, = state[op.input_vals[2]]
+        VL, = state[op.input_vals[3]]
         RT = []  # type: list[int]
         for i in range(VL):
             v = RA[i] + RB[i] + carry
             RT.append(v & GPR_VALUE_MASK)
             carry = (v >> GPR_SIZE_IN_BITS) != 0
-        state.ssa_vals[op.outputs[0]] = tuple(RT)
-        state.ssa_vals[op.outputs[1]] = carry,
+        state[op.outputs[0]] = tuple(RT)
+        state[op.outputs[1]] = carry,
 
     @staticmethod
     def __svadde_gen_asm(op, state):
@@ -1145,19 +1193,19 @@ class OpKind(Enum):
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
         outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
     )
-    _PRE_RA_SIMS[SvAddE] = lambda: OpKind.__svadde_pre_ra_sim
+    _SIM_FNS[SvAddE] = lambda: OpKind.__svadde_sim
     _GEN_ASMS[SvAddE] = lambda: OpKind.__svadde_gen_asm
 
     @staticmethod
-    def __addze_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RA, = state.ssa_vals[op.input_vals[0]]
-        carry, = state.ssa_vals[op.input_vals[1]]
+    def __addze_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RA, = state[op.input_vals[0]]
+        carry, = state[op.input_vals[1]]
         v = RA + carry
         RT = v & GPR_VALUE_MASK
         carry = (v >> GPR_SIZE_IN_BITS) != 0
-        state.ssa_vals[op.outputs[0]] = RT,
-        state.ssa_vals[op.outputs[1]] = carry,
+        state[op.outputs[0]] = RT,
+        state[op.outputs[1]] = carry,
 
     @staticmethod
     def __addze_gen_asm(op, state):
@@ -1170,23 +1218,23 @@ class OpKind(Enum):
         inputs=[OD_BASE_SGPR, OD_CA],
         outputs=[OD_BASE_SGPR, OD_CA.tied_to_input(1)],
     )
-    _PRE_RA_SIMS[AddZE] = lambda: OpKind.__addze_pre_ra_sim
+    _SIM_FNS[AddZE] = lambda: OpKind.__addze_sim
     _GEN_ASMS[AddZE] = lambda: OpKind.__addze_gen_asm
 
     @staticmethod
-    def __svsubfe_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RA = state.ssa_vals[op.input_vals[0]]
-        RB = state.ssa_vals[op.input_vals[1]]
-        carry, = state.ssa_vals[op.input_vals[2]]
-        VL, = state.ssa_vals[op.input_vals[3]]
+    def __svsubfe_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RA = state[op.input_vals[0]]
+        RB = state[op.input_vals[1]]
+        carry, = state[op.input_vals[2]]
+        VL, = state[op.input_vals[3]]
         RT = []  # type: list[int]
         for i in range(VL):
             v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
             RT.append(v & GPR_VALUE_MASK)
             carry = (v >> GPR_SIZE_IN_BITS) != 0
-        state.ssa_vals[op.outputs[0]] = tuple(RT)
-        state.ssa_vals[op.outputs[1]] = carry,
+        state[op.outputs[0]] = tuple(RT)
+        state[op.outputs[1]] = carry,
 
     @staticmethod
     def __svsubfe_gen_asm(op, state):
@@ -1200,23 +1248,23 @@ class OpKind(Enum):
         inputs=[OD_EXTRA3_VGPR, OD_EXTRA3_VGPR, OD_CA, OD_VL],
         outputs=[OD_EXTRA3_VGPR, OD_CA.tied_to_input(2)],
     )
-    _PRE_RA_SIMS[SvSubFE] = lambda: OpKind.__svsubfe_pre_ra_sim
+    _SIM_FNS[SvSubFE] = lambda: OpKind.__svsubfe_sim
     _GEN_ASMS[SvSubFE] = lambda: OpKind.__svsubfe_gen_asm
 
     @staticmethod
-    def __svmaddedu_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RA = state.ssa_vals[op.input_vals[0]]
-        RB, = state.ssa_vals[op.input_vals[1]]
-        carry, = state.ssa_vals[op.input_vals[2]]
-        VL, = state.ssa_vals[op.input_vals[3]]
+    def __svmaddedu_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RA = state[op.input_vals[0]]
+        RB, = state[op.input_vals[1]]
+        carry, = state[op.input_vals[2]]
+        VL, = state[op.input_vals[3]]
         RT = []  # type: list[int]
         for i in range(VL):
             v = RA[i] * RB + carry
             RT.append(v & GPR_VALUE_MASK)
             carry = v >> GPR_SIZE_IN_BITS
-        state.ssa_vals[op.outputs[0]] = tuple(RT)
-        state.ssa_vals[op.outputs[1]] = carry,
+        state[op.outputs[0]] = tuple(RT)
+        state[op.outputs[1]] = carry,
 
     @staticmethod
     def __svmaddedu_gen_asm(op, state):
@@ -1231,13 +1279,13 @@ class OpKind(Enum):
         inputs=[OD_EXTRA2_VGPR, OD_EXTRA2_SGPR, OD_EXTRA2_SGPR, OD_VL],
         outputs=[OD_EXTRA3_VGPR, OD_EXTRA2_SGPR.tied_to_input(2)],
     )
-    _PRE_RA_SIMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_pre_ra_sim
+    _SIM_FNS[SvMAddEDU] = lambda: OpKind.__svmaddedu_sim
     _GEN_ASMS[SvMAddEDU] = lambda: OpKind.__svmaddedu_gen_asm
 
     @staticmethod
-    def __setvli_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = op.immediates[0],
+    def __setvli_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = op.immediates[0],
 
     @staticmethod
     def __setvli_gen_asm(op, state):
@@ -1251,15 +1299,15 @@ class OpKind(Enum):
         immediates=[range(1, 65)],
         is_load_immediate=True,
     )
-    _PRE_RA_SIMS[SetVLI] = lambda: OpKind.__setvli_pre_ra_sim
+    _SIM_FNS[SetVLI] = lambda: OpKind.__setvli_sim
     _GEN_ASMS[SetVLI] = lambda: OpKind.__setvli_gen_asm
 
     @staticmethod
-    def __svli_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        VL, = state.ssa_vals[op.input_vals[0]]
+    def __svli_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        VL, = state[op.input_vals[0]]
         imm = op.immediates[0] & GPR_VALUE_MASK
-        state.ssa_vals[op.outputs[0]] = (imm,) * VL
+        state[op.outputs[0]] = (imm,) * VL
 
     @staticmethod
     def __svli_gen_asm(op, state):
@@ -1274,14 +1322,14 @@ class OpKind(Enum):
         immediates=[IMM_S16],
         is_load_immediate=True,
     )
-    _PRE_RA_SIMS[SvLI] = lambda: OpKind.__svli_pre_ra_sim
+    _SIM_FNS[SvLI] = lambda: OpKind.__svli_sim
     _GEN_ASMS[SvLI] = lambda: OpKind.__svli_gen_asm
 
     @staticmethod
-    def __li_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
+    def __li_sim(op, state):
+        # type: (Op, BaseSimState) -> None
         imm = op.immediates[0] & GPR_VALUE_MASK
-        state.ssa_vals[op.outputs[0]] = imm,
+        state[op.outputs[0]] = imm,
 
     @staticmethod
     def __li_gen_asm(op, state):
@@ -1296,13 +1344,13 @@ class OpKind(Enum):
         immediates=[IMM_S16],
         is_load_immediate=True,
     )
-    _PRE_RA_SIMS[LI] = lambda: OpKind.__li_pre_ra_sim
+    _SIM_FNS[LI] = lambda: OpKind.__li_sim
     _GEN_ASMS[LI] = lambda: OpKind.__li_gen_asm
 
     @staticmethod
-    def __veccopytoreg_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+    def __veccopytoreg_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = state[op.input_vals[0]]
 
     @staticmethod
     def __copy_to_from_reg_gen_asm(src_loc, dest_loc, is_vec, state):
@@ -1359,13 +1407,13 @@ class OpKind(Enum):
         outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
         is_copy=True,
     )
-    _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
+    _SIM_FNS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_sim
     _GEN_ASMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_gen_asm
 
     @staticmethod
-    def __veccopyfromreg_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+    def __veccopyfromreg_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = state[op.input_vals[0]]
 
     @staticmethod
     def __veccopyfromreg_gen_asm(op, state):
@@ -1385,13 +1433,13 @@ class OpKind(Enum):
         )],
         is_copy=True,
     )
-    _PRE_RA_SIMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_pre_ra_sim
+    _SIM_FNS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_sim
     _GEN_ASMS[VecCopyFromReg] = lambda: OpKind.__veccopyfromreg_gen_asm
 
     @staticmethod
-    def __copytoreg_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+    def __copytoreg_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = state[op.input_vals[0]]
 
     @staticmethod
     def __copytoreg_gen_asm(op, state):
@@ -1415,13 +1463,13 @@ class OpKind(Enum):
         )],
         is_copy=True,
     )
-    _PRE_RA_SIMS[CopyToReg] = lambda: OpKind.__copytoreg_pre_ra_sim
+    _SIM_FNS[CopyToReg] = lambda: OpKind.__copytoreg_sim
     _GEN_ASMS[CopyToReg] = lambda: OpKind.__copytoreg_gen_asm
 
     @staticmethod
-    def __copyfromreg_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
+    def __copyfromreg_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = state[op.input_vals[0]]
 
     @staticmethod
     def __copyfromreg_gen_asm(op, state):
@@ -1445,14 +1493,14 @@ class OpKind(Enum):
         )],
         is_copy=True,
     )
-    _PRE_RA_SIMS[CopyFromReg] = lambda: OpKind.__copyfromreg_pre_ra_sim
+    _SIM_FNS[CopyFromReg] = lambda: OpKind.__copyfromreg_sim
     _GEN_ASMS[CopyFromReg] = lambda: OpKind.__copyfromreg_gen_asm
 
     @staticmethod
-    def __concat_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = tuple(
-            state.ssa_vals[i][0] for i in op.input_vals[:-1])
+    def __concat_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        state[op.outputs[0]] = tuple(
+            state[i][0] for i in op.input_vals[:-1])
 
     @staticmethod
     def __concat_gen_asm(op, state):
@@ -1471,14 +1519,14 @@ class OpKind(Enum):
         outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
         is_copy=True,
     )
-    _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
+    _SIM_FNS[Concat] = lambda: OpKind.__concat_sim
     _GEN_ASMS[Concat] = lambda: OpKind.__concat_gen_asm
 
     @staticmethod
-    def __spread_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
-            state.ssa_vals[op.outputs[idx]] = inp,
+    def __spread_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        for idx, inp in enumerate(state[op.input_vals[0]]):
+            state[op.outputs[idx]] = inp,
 
     @staticmethod
     def __spread_gen_asm(op, state):
@@ -1498,20 +1546,20 @@ class OpKind(Enum):
         )],
         is_copy=True,
     )
-    _PRE_RA_SIMS[Spread] = lambda: OpKind.__spread_pre_ra_sim
+    _SIM_FNS[Spread] = lambda: OpKind.__spread_sim
     _GEN_ASMS[Spread] = lambda: OpKind.__spread_gen_asm
 
     @staticmethod
-    def __svld_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RA, = state.ssa_vals[op.input_vals[0]]
-        VL, = state.ssa_vals[op.input_vals[1]]
+    def __svld_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RA, = state[op.input_vals[0]]
+        VL, = state[op.input_vals[1]]
         addr = RA + op.immediates[0]
         RT = []  # type: list[int]
         for i in range(VL):
             v = state.load(addr + GPR_SIZE_IN_BYTES * i)
             RT.append(v & GPR_VALUE_MASK)
-        state.ssa_vals[op.outputs[0]] = tuple(RT)
+        state[op.outputs[0]] = tuple(RT)
 
     @staticmethod
     def __svld_gen_asm(op, state):
@@ -1526,16 +1574,16 @@ class OpKind(Enum):
         outputs=[OD_EXTRA3_VGPR],
         immediates=[IMM_S16],
     )
-    _PRE_RA_SIMS[SvLd] = lambda: OpKind.__svld_pre_ra_sim
+    _SIM_FNS[SvLd] = lambda: OpKind.__svld_sim
     _GEN_ASMS[SvLd] = lambda: OpKind.__svld_gen_asm
 
     @staticmethod
-    def __ld_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RA, = state.ssa_vals[op.input_vals[0]]
+    def __ld_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RA, = state[op.input_vals[0]]
         addr = RA + op.immediates[0]
         v = state.load(addr)
-        state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
+        state[op.outputs[0]] = v & GPR_VALUE_MASK,
 
     @staticmethod
     def __ld_gen_asm(op, state):
@@ -1550,15 +1598,15 @@ class OpKind(Enum):
         outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
         immediates=[IMM_S16],
     )
-    _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
+    _SIM_FNS[Ld] = lambda: OpKind.__ld_sim
     _GEN_ASMS[Ld] = lambda: OpKind.__ld_gen_asm
 
     @staticmethod
-    def __svstd_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RS = state.ssa_vals[op.input_vals[0]]
-        RA, = state.ssa_vals[op.input_vals[1]]
-        VL, = state.ssa_vals[op.input_vals[2]]
+    def __svstd_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RS = state[op.input_vals[0]]
+        RA, = state[op.input_vals[1]]
+        VL, = state[op.input_vals[2]]
         addr = RA + op.immediates[0]
         for i in range(VL):
             state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
@@ -1577,14 +1625,14 @@ class OpKind(Enum):
         immediates=[IMM_S16],
         has_side_effects=True,
     )
-    _PRE_RA_SIMS[SvStd] = lambda: OpKind.__svstd_pre_ra_sim
+    _SIM_FNS[SvStd] = lambda: OpKind.__svstd_sim
     _GEN_ASMS[SvStd] = lambda: OpKind.__svstd_gen_asm
 
     @staticmethod
-    def __std_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
-        RS, = state.ssa_vals[op.input_vals[0]]
-        RA, = state.ssa_vals[op.input_vals[1]]
+    def __std_sim(op, state):
+        # type: (Op, BaseSimState) -> None
+        RS, = state[op.input_vals[0]]
+        RA, = state[op.input_vals[1]]
         addr = RA + op.immediates[0]
         state.store(addr, value=RS)
 
@@ -1602,12 +1650,12 @@ class OpKind(Enum):
         immediates=[IMM_S16],
         has_side_effects=True,
     )
-    _PRE_RA_SIMS[Std] = lambda: OpKind.__std_pre_ra_sim
+    _SIM_FNS[Std] = lambda: OpKind.__std_sim
     _GEN_ASMS[Std] = lambda: OpKind.__std_gen_asm
 
     @staticmethod
-    def __funcargr3_pre_ra_sim(op, state):
-        # type: (Op, PreRASimState) -> None
+    def __funcargr3_sim(op, state):
+        # type: (Op, BaseSimState) -> None
         pass  # return value set before simulation
 
     @staticmethod
@@ -1620,7 +1668,7 @@ class OpKind(Enum):
         outputs=[OD_BASE_SGPR.with_fixed_loc(
             Loc(kind=LocKind.GPR, start=3, reg_len=1))],
     )
-    _PRE_RA_SIMS[FuncArgR3] = lambda: OpKind.__funcargr3_pre_ra_sim
+    _SIM_FNS[FuncArgR3] = lambda: OpKind.__funcargr3_sim
     _GEN_ASMS[FuncArgR3] = lambda: OpKind.__funcargr3_gen_asm
 
 
@@ -1942,32 +1990,37 @@ class Op:
         field_vals_str = ", ".join(field_vals)
         return f"Op({field_vals_str})"
 
-    def pre_ra_sim(self, state):
-        # type: (PreRASimState) -> None
+    def sim(self, state):
+        # type: (BaseSimState) -> None
         for inp in self.input_vals:
-            if inp not in state.ssa_vals:
+            try:
+                val = state[inp]
+            except KeyError:
                 raise ValueError(f"SSAVal {inp} not yet assigned when "
                                  f"running {self}")
-            if len(state.ssa_vals[inp]) != inp.ty.reg_len:
+            if len(val) != inp.ty.reg_len:
                 raise ValueError(
                     f"value of SSAVal {inp} has wrong number of elements: "
                     f"expected {inp.ty.reg_len} found "
-                    f"{len(state.ssa_vals[inp])}: {state.ssa_vals[inp]!r}")
-        for out in self.outputs:
-            if out in state.ssa_vals:
-                if self.kind is OpKind.FuncArgR3:
-                    continue
-                raise ValueError(f"SSAVal {out} already assigned before "
-                                 f"running {self}")
-        self.kind.pre_ra_sim(self, state)
+                    f"{len(val)}: {val!r}")
+        if isinstance(state, PreRASimState):
+            for out in self.outputs:
+                if out in state.ssa_vals:
+                    if self.kind is OpKind.FuncArgR3:
+                        continue
+                    raise ValueError(f"SSAVal {out} already assigned before "
+                                     f"running {self}")
+        self.kind.sim(self, state)
         for out in self.outputs:
-            if out not in state.ssa_vals:
+            try:
+                val = state[out]
+            except KeyError:
                 raise ValueError(f"running {self} failed to assign to {out}")
-            if len(state.ssa_vals[out]) != out.ty.reg_len:
+            if len(val) != out.ty.reg_len:
                 raise ValueError(
                     f"value of SSAVal {out} has wrong number of elements: "
                     f"expected {out.ty.reg_len} found "
-                    f"{len(state.ssa_vals[out])}: {state.ssa_vals[out]!r}")
+                    f"{len(val)}: {val!r}")
 
     def gen_asm(self, state):
         # type: (GenAsmState) -> None
@@ -1986,13 +2039,12 @@ GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
 
 
 @plain_data(frozen=True, repr=False)
-@final
-class PreRASimState:
-    __slots__ = "ssa_vals", "memory"
+class BaseSimState(metaclass=ABCMeta):
+    __slots__ = "memory",
 
-    def __init__(self, ssa_vals, memory):
-        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
-        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]
+    def __init__(self, memory):
+        # type: (dict[int, int]) -> None
+        super().__init__()
         self.memory = memory  # type: dict[int, int]
 
     def load_byte(self, addr):
@@ -2049,6 +2101,44 @@ class PreRASimState:
         items_str = ",\n".join(items)
         return f"{{\n{items_str}}}"
 
+    def __repr__(self):
+        # type: () -> str
+        field_vals = []  # type: list[str]
+        for name in fields(self):
+            try:
+                value = getattr(self, name)
+            except AttributeError:
+                field_vals.append(f"{name}=<not set>")
+                continue
+            repr_fn = getattr(self, f"_{name}__repr", None)
+            if callable(repr_fn):
+                field_vals.append(f"{name}={repr_fn()}")
+            else:
+                field_vals.append(f"{name}={value!r}")
+        field_vals_str = ", ".join(field_vals)
+        return f"{self.__class__.__name__}({field_vals_str})"
+
+    @abstractmethod
+    def __getitem__(self, ssa_val):
+        # type: (SSAVal) -> tuple[int, ...]
+        ...
+
+    @abstractmethod
+    def __setitem__(self, ssa_val, value):
+        # type: (SSAVal, tuple[int, ...]) -> None
+        ...
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class PreRASimState(BaseSimState):
+    __slots__ = "ssa_vals",
+
+    def __init__(self, ssa_vals, memory):
+        # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
+        super().__init__(memory)
+        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]
+
     def _ssa_vals__repr(self):
         # type: () -> str
         if len(self.ssa_vals) == 0:
@@ -2073,22 +2163,65 @@ class PreRASimState:
         items_str = ",\n".join(items)
         return f"{{\n{items_str},\n}}"
 
-    def __repr__(self):
+    def __getitem__(self, ssa_val):
+        # type: (SSAVal) -> tuple[int, ...]
+        return self.ssa_vals[ssa_val]
+
+    def __setitem__(self, ssa_val, value):
+        # type: (SSAVal, tuple[int, ...]) -> None
+        if len(value) != ssa_val.ty.reg_len:
+            raise ValueError("value has wrong len")
+        self.ssa_vals[ssa_val] = value
+
+
+@plain_data(frozen=True, repr=False)
+@final
+class PostRASimState(BaseSimState):
+    __slots__ = "ssa_val_to_loc_map", "loc_values"
+
+    def __init__(self, ssa_val_to_loc_map, memory, loc_values):
+        # type: (dict[SSAVal, Loc], dict[int, int], dict[Loc, int]) -> None
+        super().__init__(memory)
+        self.ssa_val_to_loc_map = FMap(ssa_val_to_loc_map)
+        for ssa_val, loc in self.ssa_val_to_loc_map.items():
+            if ssa_val.ty != loc.ty:
+                raise ValueError(
+                    f"type mismatch for SSAVal and Loc: {ssa_val} {loc}")
+        self.loc_values = loc_values
+        for loc in self.loc_values.keys():
+            if loc.reg_len != 1:
+                raise ValueError(
+                    "loc_values must only contain Locs with reg_len=1, all "
+                    "larger Locs will be split into reg_len=1 sub-Locs")
+
+    def _loc_values__repr(self):
         # type: () -> str
-        field_vals = []  # type: list[str]
-        for name in fields(self):
-            try:
-                value = getattr(self, name)
-            except AttributeError:
-                field_vals.append(f"{name}=<not set>")
-                continue
-            repr_fn = getattr(self, f"_{name}__repr", None)
-            if callable(repr_fn):
-                field_vals.append(f"{name}={repr_fn()}")
-            else:
-                field_vals.append(f"{name}={value!r}")
-        field_vals_str = ", ".join(field_vals)
-        return f"PreRASimState({field_vals_str})"
+        locs = sorted(self.loc_values.keys(), key=lambda v: (v.kind, v.start))
+        items = []  # type: list[str]
+        for loc in locs:
+            items.append(f"{loc}: 0x{self.loc_values[loc]:x}")
+        items_str = ",\n".join(items)
+        return f"{{\n{items_str},\n}}"
+
+    def __getitem__(self, ssa_val):
+        # type: (SSAVal) -> tuple[int, ...]
+        loc = self.ssa_val_to_loc_map[ssa_val]
+        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
+        retval = []  # type: list[int]
+        for i in range(loc.reg_len):
+            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
+            retval.append(self.loc_values.get(subloc, 0))
+        return tuple(retval)
+
+    def __setitem__(self, ssa_val, value):
+        # type: (SSAVal, tuple[int, ...]) -> None
+        if len(value) != ssa_val.ty.reg_len:
+            raise ValueError("value has wrong len")
+        loc = self.ssa_val_to_loc_map[ssa_val]
+        subloc_ty = Ty(base_ty=loc.ty.base_ty, reg_len=1)
+        for i in range(loc.reg_len):
+            subloc = loc.get_subloc_at_offset(subloc_ty=subloc_ty, offset=i)
+            self.loc_values[subloc] = value[i]
 
 
 @plain_data(frozen=True)
index 198eebcc9593f13050232e216b58e6699ad2e9c5..278693d165368850d318d40eadaf4bb43321f6d8 100644 (file)
@@ -6,7 +6,7 @@ this uses an algorithm based on:
 """
 
 from itertools import combinations
-from typing import Iterable, Iterator, Mapping
+from typing import Iterable, Iterator, Mapping, TextIO
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
@@ -71,22 +71,19 @@ class MergedSSAVal(metaclass=InternedMeta):
         reg_len = self.ty.reg_len
         loc_set = None  # type: None | LocSet
         for ssa_val, cur_offset in self.ssa_val_offsets_before_spread.items():
-            def_spread_idx = ssa_val.defining_descriptor.spread_index or 0
-
             def locs():
                 # type: () -> Iterable[Loc]
                 for loc in ssa_val.def_loc_set_before_spread:
                     disallowed_by_use = False
                     for use in fn_analysis.uses[ssa_val]:
-                        use_spread_idx = \
-                            use.defining_descriptor.spread_index or 0
                         # calculate the start for the use's Loc before spread
                         # e.g. if the def's Loc before spread starts at r6
-                        # and the def's spread_index is 5
-                        # and the use's spread_index is 3
+                        # and the def's reg_offset_in_unspread is 5
+                        # and the use's reg_offset_in_unspread is 3
                         # then the use's Loc before spread starts at r8
                         # because 8 == 6 + 5 - 3
-                        start = loc.start + def_spread_idx - use_spread_idx
+                        start = (loc.start + ssa_val.reg_offset_in_unspread
+                                 - use.reg_offset_in_unspread)
                         use_loc = Loc.try_make(
                             loc.kind, start=start,
                             reg_len=use.ty_before_spread.reg_len)
@@ -201,8 +198,9 @@ class MergedSSAVal(metaclass=InternedMeta):
         return ProgramRange(start=start, stop=stop)
 
     def __repr__(self):
-        return (f"MergedSSAVal({self.fn_analysis}, "
-                f"ssa_val_offsets={self.ssa_val_offsets})")
+        return (f"MergedSSAVal(ssa_val_offsets={self.ssa_val_offsets}, "
+                f"offset={self.offset}, ty={self.ty}, loc_set={self.loc_set}, "
+                f"live_interval={self.live_interval})")
 
 
 @final
@@ -464,18 +462,26 @@ class AllocationFailedError(Exception):
         return self.__repr__()
 
 
-def allocate_registers(fn):
-    # type: (Fn) -> dict[SSAVal, Loc]
+def allocate_registers(fn, debug_out=None):
+    # type: (Fn, TextIO | None) -> dict[SSAVal, Loc]
 
     # inserts enough copies that no manual spilling is necessary, all
     # spilling is done by the register allocator naturally allocating SSAVals
     # to stack slots
     fn.pre_ra_insert_copies()
 
+    if debug_out is not None:
+        print(f"After pre_ra_insert_copies():\n{fn.ops}",
+              file=debug_out, flush=True)
+
     fn_analysis = FnAnalysis(fn)
     interference_graph = InterferenceGraph.minimally_merged(fn_analysis)
 
-    for ssa_vals in fn_analysis.live_at.values():
+    if debug_out is not None:
+        print(f"After InterferenceGraph.minimally_merged():\n"
+              f"{interference_graph}", file=debug_out, flush=True)
+
+    for pp, ssa_vals in fn_analysis.live_at.items():
         live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
         for ssa_val in ssa_vals:
             live_merged_ssa_vals.add(
@@ -484,6 +490,13 @@ def allocate_registers(fn):
             if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                 interference_graph.nodes[i].add_edge(
                     interference_graph.nodes[j])
+        if debug_out is not None:
+            print(f"processed {pp} out of {fn_analysis.all_program_points}",
+                  file=debug_out, flush=True)
+
+    if debug_out is not None:
+        print(f"After adding interference graph edges:\n"
+              f"{interference_graph}", file=debug_out, flush=True)
 
     nodes_remaining = OSet(interference_graph.nodes.values())
 
@@ -521,6 +534,10 @@ def allocate_registers(fn):
         node_stack.append(best_node)
         nodes_remaining.remove(best_node)
 
+    if debug_out is not None:
+        print(f"After deciding node allocation order:\n"
+              f"{node_stack}", file=debug_out, flush=True)
+
     retval = {}  # type: dict[SSAVal, Loc]
 
     while len(node_stack) > 0:
@@ -543,7 +560,15 @@ def allocate_registers(fn):
                     "failed to allocate Loc for IGNode",
                     node=node, interference_graph=interference_graph)
 
+        if debug_out is not None:
+            print(f"After allocating Loc for node:\n{node}",
+                  file=debug_out, flush=True)
+
         for ssa_val, offset in node.merged_ssa_val.ssa_val_offsets.items():
             retval[ssa_val] = node.loc.get_subloc_at_offset(ssa_val.ty, offset)
 
+    if debug_out is not None:
+        print(f"final Locs for all SSAVals:\n{retval}",
+              file=debug_out, flush=True)
+
     return retval