working on code
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 3 Nov 2022 07:38:58 +0000 (00:38 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 3 Nov 2022 07:41:19 +0000 (00:41 -0700)
src/bigint_presentation_code/_tests/test_compiler_ir.py
src/bigint_presentation_code/_tests/test_compiler_ir2.py
src/bigint_presentation_code/_tests/test_matrix.py
src/bigint_presentation_code/_tests/test_util.py
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/register_allocator2.py
src/bigint_presentation_code/util.py

index 68df120de6abc77056a32e91420a0d49fe8f94ce..820c305a70d995ecd2912883d6798e364630bbcf 100644 (file)
@@ -9,7 +9,6 @@ from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
                                                   RegLoc, SSAVal, XERBit,
                                                   generate_assembly,
                                                   op_set_to_list)
-import bigint_presentation_code.compiler_ir2
 
 
 class TestCompilerIR(unittest.TestCase):
index 40f02fa218216af36cf6ddc65290537814f0484b..116326a7c1a588ed2014333792aaa7f726051ec1 100644 (file)
@@ -1,13 +1,27 @@
 import unittest
 
 from bigint_presentation_code.compiler_ir2 import (GPR_SIZE_IN_BYTES, Fn,
-                                                   OpKind, PreRASimState,
+                                                   FnAnalysis, OpKind, OpStage,
+                                                   PreRASimState, ProgramPoint,
                                                    SSAVal)
 
 
 class TestCompilerIR(unittest.TestCase):
     maxDiff = None
 
+    def test_program_point(self):
+        # type: () -> None
+        expected = []  # type: list[ProgramPoint]
+        for op_index in range(5):
+            for stage in OpStage:
+                expected.append(ProgramPoint(op_index=op_index, stage=stage))
+
+        for idx, pp in enumerate(expected):
+            if idx + 1 < len(expected):
+                self.assertEqual(pp.next(), expected[idx + 1])
+
+        self.assertEqual(sorted(expected), expected)
+
     def make_add_fn(self):
         # type: () -> tuple[Fn, SSAVal]
         fn = Fn()
@@ -28,10 +42,128 @@ class TestCompilerIR(unittest.TestCase):
         op5 = fn.append_new_op(
             OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
         s = op5.outputs[0]
-        fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
-                         immediates=[0], maxvl=MAXVL, name="st")
+        _ = fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+                             immediates=[0], maxvl=MAXVL, name="st")
         return fn, arg
 
+    def test_fn_analysis(self):
+        fn, _arg = self.make_add_fn()
+        fn_analysis = FnAnalysis(fn)
+        print(repr(fn_analysis))
+        self.assertEqual(
+            repr(fn_analysis),
+            "FnAnalysis(fn=<Fn>, uses=FMap({"
+            "<arg.outputs[0]: <I64>>: OFSet(["
+            "<ld.input_uses[0]: <I64>>, <st.input_uses[1]: <I64>>]), "
+            "<vl.outputs[0]: <VL_MAXVL>>: OFSet(["
+            "<ld.input_uses[1]: <VL_MAXVL>>, <li.input_uses[0]: <VL_MAXVL>>, "
+            "<add.input_uses[3]: <VL_MAXVL>>, "
+            "<st.input_uses[2]: <VL_MAXVL>>]), "
+            "<ld.outputs[0]: <I64*32>>: OFSet(["
+            "<add.input_uses[0]: <I64*32>>]), "
+            "<li.outputs[0]: <I64*32>>: OFSet(["
+            "<add.input_uses[1]: <I64*32>>]), "
+            "<ca.outputs[0]: <CA>>: OFSet([<add.input_uses[2]: <CA>>]), "
+            "<add.outputs[0]: <I64*32>>: OFSet(["
+            "<st.input_uses[0]: <I64*32>>]), "
+            "<add.outputs[1]: <CA>>: OFSet()}), "
+            "op_indexes=FMap({"
+            "Op(kind=OpKind.FuncArgR3, input_vals=[], input_uses=(), "
+            "immediates=[], outputs=(<arg.outputs[0]: <I64>>,), "
+            "name='arg'): 0, "
+            "Op(kind=OpKind.SetVLI, input_vals=[], input_uses=(), "
+            "immediates=[32], outputs=(<vl.outputs[0]: <VL_MAXVL>>,), "
+            "name='vl'): 1, "
+            "Op(kind=OpKind.SvLd, input_vals=["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<ld.input_uses[0]: <I64>>, "
+            "<ld.input_uses[1]: <VL_MAXVL>>), immediates=[0], "
+            "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld'): 2, "
+            "Op(kind=OpKind.SvLI, input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), immediates=[0], "
+            "outputs=(<li.outputs[0]: <I64*32>>,), name='li'): 3, "
+            "Op(kind=OpKind.SetCA, input_vals=[], input_uses=(), "
+            "immediates=[], outputs=(<ca.outputs[0]: <CA>>,), name='ca'): 4, "
+            "Op(kind=OpKind.SvAddE, input_vals=["
+            "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+            "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add.input_uses[0]: <I64*32>>, "
+            "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+            "<add.input_uses[3]: <VL_MAXVL>>), immediates=[], outputs=("
+            "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
+            "name='add'): 5, "
+            "Op(kind=OpKind.SvStd, input_vals=["
+            "<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<st.input_uses[0]: <I64*32>>, "
+            "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
+            "immediates=[0], outputs=(), name='st'): 6}), "
+            "live_ranges=FMap({"
+            "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[6]:Late>, "
+            "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[6]:Late>, "
+            "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[5]:Late>, "
+            "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[5]:Late>, "
+            "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Late>, "
+            "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Late>, "
+            "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
+            "live_at=FMap({"
+            "<ops[0]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
+            "<ops[0]:Late>: OFSet([<arg.outputs[0]: <I64>>]), "
+            "<ops[1]:Early>: OFSet([<arg.outputs[0]: <I64>>]), "
+            "<ops[1]:Late>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>]), "
+            "<ops[2]:Early>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<ld.outputs[0]: <I64*32>>]), "
+            "<ops[2]:Late>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<ld.outputs[0]: <I64*32>>]), "
+            "<ops[3]:Early>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+            "<ops[3]:Late>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+            "<ops[4]:Early>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>]), "
+            "<ops[4]:Late>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+            "<ca.outputs[0]: <CA>>]), "
+            "<ops[5]:Early>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
+            "<ca.outputs[0]: <CA>>, <add.outputs[0]: <I64*32>>, "
+            "<add.outputs[1]: <CA>>]), "
+            "<ops[5]:Late>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>]), "
+            "<ops[6]:Early>: OFSet(["
+            "<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>, "
+            "<add.outputs[0]: <I64*32>>]), "
+            "<ops[6]:Late>: OFSet()}), "
+            "def_program_ranges=FMap({"
+            "<arg.outputs[0]: <I64>>: <range:ops[0]:Early..ops[1]:Early>, "
+            "<vl.outputs[0]: <VL_MAXVL>>: <range:ops[1]:Late..ops[2]:Early>, "
+            "<ld.outputs[0]: <I64*32>>: <range:ops[2]:Early..ops[3]:Early>, "
+            "<li.outputs[0]: <I64*32>>: <range:ops[3]:Early..ops[4]:Early>, "
+            "<ca.outputs[0]: <CA>>: <range:ops[4]:Late..ops[5]:Early>, "
+            "<add.outputs[0]: <I64*32>>: <range:ops[5]:Early..ops[6]:Early>, "
+            "<add.outputs[1]: <CA>>: <range:ops[5]:Early..ops[6]:Early>}), "
+            "use_program_points=FMap({"
+            "<ld.input_uses[0]: <I64>>: <ops[2]:Early>, "
+            "<ld.input_uses[1]: <VL_MAXVL>>: <ops[2]:Early>, "
+            "<li.input_uses[0]: <VL_MAXVL>>: <ops[3]:Early>, "
+            "<add.input_uses[0]: <I64*32>>: <ops[5]:Early>, "
+            "<add.input_uses[1]: <I64*32>>: <ops[5]:Early>, "
+            "<add.input_uses[2]: <CA>>: <ops[5]:Early>, "
+            "<add.input_uses[3]: <VL_MAXVL>>: <ops[5]:Early>, "
+            "<st.input_uses[0]: <I64*32>>: <ops[6]:Early>, "
+            "<st.input_uses[1]: <I64>>: <ops[6]:Early>, "
+            "<st.input_uses[2]: <VL_MAXVL>>: <ops[6]:Early>}), "
+            "all_program_points=<range:ops[0]:Early..ops[7]:Early>)")
+
     def test_repr(self):
         fn, _arg = self.make_add_fn()
         self.assertEqual([repr(i) for i in fn.ops], [
@@ -86,74 +218,91 @@ class TestCompilerIR(unittest.TestCase):
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), maxvl=1)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvLd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), maxvl=32)",
             "OpProperties(kind=OpKind.SvLI, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), maxvl=32)",
             "OpProperties(kind=OpKind.SetCA, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvAddE, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None)), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), maxvl=32)",
             "OpProperties(kind=OpKind.SvStd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=(), maxvl=32)",
         ])
 
@@ -313,221 +462,268 @@ class TestCompilerIR(unittest.TestCase):
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([3])}), ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), maxvl=1)",
             "OpProperties(kind=OpKind.CopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
             "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.CopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
             "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvLd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), maxvl=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.VecCopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.SvLI, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), maxvl=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.VecCopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.SetCA, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.VecCopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.VecCopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.SvAddE, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.CA: FBitSet([0])}), ty=<CA>), "
-            "tied_input_index=None, spread_index=None)), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), maxvl=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.VecCopyFromReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.SetVLI, "
             "inputs=(), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.VecCopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97)), "
             "LocKind.StackI64: FBitSet(range(0, 993))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None),), maxvl=32)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=32)",
             "OpProperties(kind=OpKind.CopyToReg, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)]), "
             "LocKind.StackI64: FBitSet(range(0, 1024))}), ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early),), "
             "outputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
-            "tied_input_index=None, spread_index=None),), maxvl=1)",
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Late),), maxvl=1)",
             "OpProperties(kind=OpKind.SvStd, "
             "inputs=("
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet(range(14, 97))}), ty=<I64*32>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.GPR: FBitSet([*range(3, 13), *range(14, 128)])}), "
             "ty=<I64>), "
-            "tied_input_index=None, spread_index=None), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early), "
             "OperandDesc(loc_set_before_spread=LocSet(starts=FMap({"
             "LocKind.VL_MAXVL: FBitSet([0])}), ty=<VL_MAXVL>), "
-            "tied_input_index=None, spread_index=None)), "
+            "tied_input_index=None, spread_index=None, "
+            "write_stage=OpStage.Early)), "
             "outputs=(), maxvl=32)",
         ])
 
@@ -616,4 +812,4 @@ class TestCompilerIR(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    _ = unittest.main()
index 1a56df005b31718711305d3cbd5f128fb0b4d354..78bd990c29dd692005d2e8c34fcb1661aefd2b20 100644 (file)
@@ -100,14 +100,14 @@ class TestMatrix(unittest.TestCase):
                                        -_1_2, _1_6, _1_2, -_1_6, 2,
                                        0, 0, 0, 0, 1]))
         with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(1, 1, [0]).inverse()
+            _ = Matrix(1, 1, [0]).inverse()
         with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(2, 2, [0, 0, 1, 1]).inverse()
+            _ = Matrix(2, 2, [0, 0, 1, 1]).inverse()
         with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(2, 2, [1, 0, 1, 0]).inverse()
+            _ = Matrix(2, 2, [1, 0, 1, 0]).inverse()
         with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
-            Matrix(2, 2, [1, 1, 1, 1]).inverse()
+            _ = Matrix(2, 2, [1, 1, 1, 1]).inverse()
 
 
 if __name__ == "__main__":
-    unittest.main()
+    _ = unittest.main()
index 0bfe365b441772a8591acd3a5f677b6fa8a51361..d409df2d7e87cade7dd315b838fc68276ad00140 100644 (file)
@@ -27,4 +27,4 @@ class TestBitSet(unittest.TestCase):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    _ = unittest.main()
index 03b623dafbf5f80ebdc96e487088552a42ffe040..05dc6f49d93e9b0e5896807251352bd960185e0e 100644 (file)
@@ -1,7 +1,8 @@
+from collections import defaultdict
 import enum
 from abc import ABCMeta, abstractmethod
 from enum import Enum, unique
-from functools import lru_cache
+from functools import lru_cache, total_ordering
 from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator,
                     Sequence, TypeVar, overload)
 from weakref import WeakValueDictionary as _WeakVDict
@@ -9,7 +10,7 @@ from weakref import WeakValueDictionary as _WeakVDict
 from cached_property import cached_property
 from nmutil.plain_data import fields, plain_data
 
-from bigint_presentation_code.type_util import Self, assert_never, final
+from bigint_presentation_code.type_util import Self, assert_never, final, Literal
 from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet, OSet
 
 
@@ -111,25 +112,221 @@ class Fn:
                     assert_never(out.ty.base_ty)
 
 
+@final
+@unique
+@total_ordering
+class OpStage(Enum):
+    value: Literal[0, 1]  # type: ignore
+
+    def __new__(cls, value):
+        # type: (int) -> OpStage
+        value = int(value)
+        if value not in (0, 1):
+            raise ValueError("invalid value")
+        retval = object.__new__(cls)
+        retval._value_ = value
+        return retval
+
+    Early = 0
+    """ early stage of Op execution, where all input reads occur.
+    all output writes with `write_stage == Early` occur here too, and therefore
+    conflict with input reads, telling the compiler that it that can't share
+    that output's register with any inputs that the output isn't tied to.
+
+    All outputs, even unused outputs, can't share registers with any other
+    outputs, independent of `write_stage` settings.
+    """
+    Late = 1
+    """ late stage of Op execution, where all output writes with
+    `write_stage == Late` occur, and therefore don't conflict with input reads,
+    telling the compiler that any inputs can safely use the same register as
+    those outputs.
+
+    All outputs, even unused outputs, can't share registers with any other
+    outputs, independent of `write_stage` settings.
+    """
+
+    def __repr__(self):
+        # type: () -> str
+        return f"OpStage.{self._name_}"
+
+    def __lt__(self, other):
+        # type: (OpStage | object) -> bool
+        if isinstance(other, OpStage):
+            return self.value < other.value
+        return NotImplemented
+
+
+assert OpStage.Early < OpStage.Late, "early must be less than late"
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+@total_ordering
+class ProgramPoint:
+    __slots__ = "op_index", "stage"
+
+    def __init__(self, op_index, stage):
+        # type: (int, OpStage) -> None
+        self.op_index = op_index
+        self.stage = stage
+
+    @property
+    def int_value(self):
+        # type: () -> int
+        """ an integer representation of `self` such that it keeps ordering and
+        successor/predecessor relations.
+        """
+        return self.op_index * 2 + self.stage.value
+
+    @staticmethod
+    def from_int_value(int_value):
+        # type: (int) -> ProgramPoint
+        op_index, stage = divmod(int_value, 2)
+        return ProgramPoint(op_index=op_index, stage=OpStage(stage))
+
+    def next(self, steps=1):
+        # type: (int) -> ProgramPoint
+        return ProgramPoint.from_int_value(self.int_value + steps)
+
+    def prev(self, steps=1):
+        # type: (int) -> ProgramPoint
+        return self.next(steps=-steps)
+
+    def __lt__(self, other):
+        # type: (ProgramPoint | Any) -> bool
+        if not isinstance(other, ProgramPoint):
+            return NotImplemented
+        if self.op_index != other.op_index:
+            return self.op_index < other.op_index
+        return self.stage < other.stage
+
+    def __repr__(self):
+        # type: () -> str
+        return f"<ops[{self.op_index}]:{self.stage._name_}>"
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class ProgramRange(Sequence[ProgramPoint]):
+    __slots__ = "start", "stop"
+
+    def __init__(self, start, stop):
+        # type: (ProgramPoint, ProgramPoint) -> None
+        self.start = start
+        self.stop = stop
+
+    @cached_property
+    def int_value_range(self):
+        # type: () -> range
+        return range(self.start.int_value, self.stop.int_value)
+
+    @staticmethod
+    def from_int_value_range(int_value_range):
+        # type: (range) -> ProgramRange
+        if int_value_range.step != 1:
+            raise ValueError("int_value_range must have step == 1")
+        return ProgramRange(
+            start=ProgramPoint.from_int_value(int_value_range.start),
+            stop=ProgramPoint.from_int_value(int_value_range.stop))
+
+    @overload
+    def __getitem__(self, __idx):
+        # type: (int) -> ProgramPoint
+        ...
+
+    @overload
+    def __getitem__(self, __idx):
+        # type: (slice) -> ProgramRange
+        ...
+
+    def __getitem__(self, __idx):
+        # type: (int | slice) -> ProgramPoint | ProgramRange
+        v = range(self.start.int_value, self.stop.int_value)[__idx]
+        if isinstance(v, int):
+            return ProgramPoint.from_int_value(v)
+        return ProgramRange.from_int_value_range(v)
+
+    def __len__(self):
+        # type: () -> int
+        return len(self.int_value_range)
+
+    def __iter__(self):
+        # type: () -> Iterator[ProgramPoint]
+        return map(ProgramPoint.from_int_value, self.int_value_range)
+
+    def __repr__(self):
+        # type: () -> str
+        start = repr(self.start).lstrip("<").rstrip(">")
+        stop = repr(self.stop).lstrip("<").rstrip(">")
+        return f"<range:{start}..{stop}>"
+
+
 @plain_data(frozen=True, eq=False)
 @final
-class FnWithUses:
-    __slots__ = "fn", "uses"
+class FnAnalysis:
+    __slots__ = ("fn", "uses", "op_indexes", "live_ranges", "live_at",
+                 "def_program_ranges", "use_program_points",
+                 "all_program_points")
 
     def __init__(self, fn):
         # type: (Fn) -> None
         self.fn = fn
-        retval = {}  # type: dict[SSAVal, OSet[SSAUse]]
+        self.op_indexes = FMap((op, idx) for idx, op in enumerate(fn.ops))
+        self.all_program_points = ProgramRange(
+            start=ProgramPoint(op_index=0, stage=OpStage.Early),
+            stop=ProgramPoint(op_index=len(fn.ops), stage=OpStage.Early))
+        def_program_ranges = {}  # type: dict[SSAVal, ProgramRange]
+        use_program_points = {}  # type: dict[SSAUse, ProgramPoint]
+        uses = {}  # type: dict[SSAVal, OSet[SSAUse]]
+        live_range_stops = {}  # type: dict[SSAVal, ProgramPoint]
         for op in fn.ops:
-            for idx, inp in enumerate(op.input_vals):
-                retval[inp].add(SSAUse(op, idx))
+            for use in op.input_uses:
+                uses[use.ssa_val].add(use)
+                use_program_point = self.__get_use_program_point(use)
+                use_program_points[use] = use_program_point
+                live_range_stops[use.ssa_val] = max(
+                    live_range_stops[use.ssa_val], use_program_point.next())
             for out in op.outputs:
-                retval[out] = OSet()
-        self.uses = FMap((k, OFSet(v)) for k, v in retval.items())
+                uses[out] = OSet()
+                def_program_range = self.__get_def_program_range(out)
+                def_program_ranges[out] = def_program_range
+                live_range_stops[out] = def_program_range.stop
+        self.uses = FMap((k, OFSet(v)) for k, v in uses.items())
+        self.def_program_ranges = FMap(def_program_ranges)
+        self.use_program_points = FMap(use_program_points)
+        live_ranges = {}  # type: dict[SSAVal, ProgramRange]
+        live_at = {i: OSet[SSAVal]() for i in self.all_program_points}
+        for ssa_val in uses.keys():
+            live_ranges[ssa_val] = live_range = ProgramRange(
+                start=self.def_program_ranges[ssa_val].start,
+                stop=live_range_stops[ssa_val])
+            for program_point in live_range:
+                live_at[program_point].add(ssa_val)
+        self.live_ranges = FMap(live_ranges)
+        self.live_at = FMap((k, OFSet(v)) for k, v in live_at.items())
+
+    def __get_def_program_range(self, ssa_val):
+        # type: (SSAVal) -> ProgramRange
+        write_stage = ssa_val.defining_descriptor.write_stage
+        start = ProgramPoint(
+            op_index=self.op_indexes[ssa_val.op], stage=write_stage)
+        # always include late stage of ssa_val.op, to ensure outputs always
+        # overlap all other outputs.
+        # stop is exclusive, so we need the next program point.
+        stop = ProgramPoint(op_index=start.op_index, stage=OpStage.Late).next()
+        return ProgramRange(start=start, stop=stop)
+
+    def __get_use_program_point(self, ssa_use):
+        # type: (SSAUse) -> ProgramPoint
+        assert ssa_use.defining_descriptor.write_stage is OpStage.Early, \
+            "assumed here, ensured by GenericOpProperties.__init__"
+        return ProgramPoint(
+            op_index=self.op_indexes[ssa_use.op], stage=OpStage.Early)
 
     def __eq__(self, other):
-        # type: (FnWithUses | Any) -> bool
-        if isinstance(other, FnWithUses):
+        # type: (FnAnalysis | Any) -> bool
+        if isinstance(other, FnAnalysis):
             return self.fn == other.fn
         return NotImplemented
 
@@ -255,10 +452,9 @@ class LocSubKind(Enum):
         # type: () -> LocKind
         # pyright fails typechecking when using `in` here:
         # reported: https://github.com/microsoft/pyright/issues/4102
-        if self is LocSubKind.BASE_GPR or self is LocSubKind.SV_EXTRA2_VGPR \
-                or self is LocSubKind.SV_EXTRA2_SGPR \
-                or self is LocSubKind.SV_EXTRA3_VGPR \
-                or self is LocSubKind.SV_EXTRA3_SGPR:
+        if self in (LocSubKind.BASE_GPR, LocSubKind.SV_EXTRA2_VGPR,
+                    LocSubKind.SV_EXTRA2_SGPR, LocSubKind.SV_EXTRA3_VGPR,
+                    LocSubKind.SV_EXTRA3_SGPR):
             return LocKind.GPR
         if self is LocSubKind.StackI64:
             return LocKind.StackI64
@@ -526,12 +722,24 @@ class LocSet(AbstractSet[Loc]):
     def __hash__(self):
         return self.__hash
 
+    @lru_cache(maxsize=None, typed=True)
+    def max_conflicts_with(self, other):
+        # type: (LocSet | Loc) -> int
+        """the largest number of Locs in `self` that a single Loc
+        from `other` can conflict with
+        """
+        if isinstance(other, LocSet):
+            return max(self.max_conflicts_with(i) for i in other)
+        else:
+            return sum(other.conflicts(i) for i in self)
+
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
 class GenericOperandDesc:
     """generic Op operand descriptor"""
-    __slots__ = "ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread"
+    __slots__ = ("ty", "fixed_loc", "sub_kinds", "tied_input_index", "spread",
+                 "write_stage")
 
     def __init__(
         self, ty,  # type: GenericTy
@@ -540,6 +748,7 @@ class GenericOperandDesc:
         fixed_loc=None,  # type: Loc | None
         tied_input_index=None,  # type: int | None
         spread=False,  # type: bool
+        write_stage=OpStage.Early,  # type: OpStage
     ):
         # type: (...) -> None
         self.ty = ty
@@ -577,15 +786,26 @@ class GenericOperandDesc:
                 raise ValueError("operand can't be both spread and fixed")
             if self.ty.is_vec:
                 raise ValueError("operand can't be both spread and vector")
+        self.write_stage = write_stage
 
     def tied_to_input(self, tied_input_index):
         # type: (int) -> Self
         return GenericOperandDesc(self.ty, self.sub_kinds,
-                                  tied_input_index=tied_input_index)
+                                  tied_input_index=tied_input_index,
+                                  write_stage=self.write_stage)
 
     def with_fixed_loc(self, fixed_loc):
         # type: (Loc) -> Self
-        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc)
+        return GenericOperandDesc(self.ty, self.sub_kinds, fixed_loc=fixed_loc,
+                                  write_stage=self.write_stage)
+
+    def with_write_stage(self, write_stage):
+        # type: (OpStage) -> Self
+        return GenericOperandDesc(self.ty, self.sub_kinds,
+                                  fixed_loc=self.fixed_loc,
+                                  tied_input_index=self.tied_input_index,
+                                  spread=self.spread,
+                                  write_stage=write_stage)
 
     def instantiate(self, maxvl):
         # type: (int) -> Iterable[OperandDesc]
@@ -613,17 +833,19 @@ class GenericOperandDesc:
                 idx = None
             yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
                               tied_input_index=self.tied_input_index,
-                              spread_index=idx)
+                              spread_index=idx, write_stage=self.write_stage)
 
 
 @plain_data(frozen=True, unsafe_hash=True)
 @final
 class OperandDesc:
     """Op operand descriptor"""
-    __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"
+    __slots__ = ("loc_set_before_spread", "tied_input_index", "spread_index",
+                 "write_stage")
 
-    def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
-        # type: (LocSet, int | None, int | None) -> None
+    def __init__(self, loc_set_before_spread, tied_input_index, spread_index,
+                 write_stage):
+        # type: (LocSet, int | None, int | None, OpStage) -> None
         if len(loc_set_before_spread) == 0:
             raise ValueError("loc_set_before_spread must not be empty")
         self.loc_set_before_spread = loc_set_before_spread
@@ -631,6 +853,7 @@ class OperandDesc:
         if self.tied_input_index is not None and self.spread_index is not None:
             raise ValueError("operand can't be both spread and tied")
         self.spread_index = spread_index
+        self.write_stage = write_stage
 
     @cached_property
     def ty_before_spread(self):
@@ -702,13 +925,16 @@ class GenericOpProperties:
         has_side_effects=False,  # type: bool
     ):
         # type: (...) -> None
-        self.demo_asm = demo_asm
-        self.inputs = tuple(inputs)
+        self.demo_asm = demo_asm  # type: str
+        self.inputs = tuple(inputs)  # type: tuple[GenericOperandDesc, ...]
         for inp in self.inputs:
             if inp.tied_input_index is not None:
                 raise ValueError(
                     f"tied_input_index is not allowed on inputs: {inp}")
-        self.outputs = tuple(outputs)
+            if inp.write_stage is not OpStage.Early:
+                raise ValueError(
+                    f"write_stage is not allowed on inputs: {inp}")
+        self.outputs = tuple(outputs)  # type: tuple[GenericOperandDesc, ...]
         fixed_locs = []  # type: list[tuple[Loc, int]]
         for idx, out in enumerate(self.outputs):
             if out.tied_input_index is not None:
@@ -727,10 +953,10 @@ class GenericOpProperties:
                         f"outputs[{other_idx}]: {out.fixed_loc} conflicts "
                         f"with {other_fixed_loc}")
                 fixed_locs.append((out.fixed_loc, idx))
-        self.immediates = tuple(immediates)
-        self.is_copy = is_copy
-        self.is_load_immediate = is_load_immediate
-        self.has_side_effects = has_side_effects
+        self.immediates = tuple(immediates)  # type: tuple[range, ...]
+        self.is_copy = is_copy  # type: bool
+        self.is_load_immediate = is_load_immediate  # type: bool
+        self.has_side_effects = has_side_effects  # type: bool
 
 
 @plain_data(frozen=True, unsafe_hash=True)
@@ -740,16 +966,16 @@ class OpProperties:
 
     def __init__(self, kind, maxvl):
         # type: (OpKind, int) -> None
-        self.kind = kind
+        self.kind = kind  # type: OpKind
         inputs = []  # type: list[OperandDesc]
         for inp in self.generic.inputs:
             inputs.extend(inp.instantiate(maxvl=maxvl))
-        self.inputs = tuple(inputs)
+        self.inputs = tuple(inputs)  # type: tuple[OperandDesc, ...]
         outputs = []  # type: list[OperandDesc]
         for out in self.generic.outputs:
             outputs.extend(out.instantiate(maxvl=maxvl))
-        self.outputs = tuple(outputs)
-        self.maxvl = maxvl
+        self.outputs = tuple(outputs)  # type: tuple[OperandDesc, ...]
+        self.maxvl = maxvl  # type: int
 
     @property
     def generic(self):
@@ -807,6 +1033,7 @@ class OpKind(Enum):
         return OpProperties(self, maxvl=maxvl)
 
     def __repr__(self):
+        # type: () -> str
         return "OpKind." + self._name_
 
     @cached_property
@@ -821,7 +1048,7 @@ class OpKind(Enum):
     ClearCA = GenericOpProperties(
         demo_asm="addic 0, 0, 0",
         inputs=[],
-        outputs=[OD_CA],
+        outputs=[OD_CA.with_write_stage(OpStage.Late)],
     )
     _PRE_RA_SIMS[ClearCA] = lambda: OpKind.__clearca_pre_ra_sim
 
@@ -832,7 +1059,7 @@ class OpKind(Enum):
     SetCA = GenericOpProperties(
         demo_asm="subfc 0, 0, 0",
         inputs=[],
-        outputs=[OD_CA],
+        outputs=[OD_CA.with_write_stage(OpStage.Late)],
     )
     _PRE_RA_SIMS[SetCA] = lambda: OpKind.__setca_pre_ra_sim
 
@@ -906,7 +1133,7 @@ class OpKind(Enum):
     SetVLI = GenericOpProperties(
         demo_asm="setvl 0, 0, imm, 0, 1, 1",
         inputs=(),
-        outputs=[OD_VL],
+        outputs=[OD_VL.with_write_stage(OpStage.Late)],
         immediates=[range(1, 65)],
         is_load_immediate=True,
     )
@@ -935,7 +1162,7 @@ class OpKind(Enum):
     LI = GenericOpProperties(
         demo_asm="addi RT, 0, imm",
         inputs=(),
-        outputs=[OD_BASE_SGPR],
+        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
         immediates=[IMM_S16],
         is_load_immediate=True,
     )
@@ -951,7 +1178,7 @@ class OpKind(Enum):
             ty=GenericTy(BaseTy.I64, is_vec=True),
             sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
         ), OD_VL],
-        outputs=[OD_EXTRA3_VGPR],
+        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
         is_copy=True,
     )
     _PRE_RA_SIMS[VecCopyToReg] = lambda: OpKind.__veccopytoreg_pre_ra_sim
@@ -966,6 +1193,7 @@ class OpKind(Enum):
         outputs=[GenericOperandDesc(
             ty=GenericTy(BaseTy.I64, is_vec=True),
             sub_kinds=[LocSubKind.SV_EXTRA3_VGPR, LocSubKind.StackI64],
+            write_stage=OpStage.Late,
         )],
         is_copy=True,
     )
@@ -985,6 +1213,7 @@ class OpKind(Enum):
         outputs=[GenericOperandDesc(
             ty=GenericTy(BaseTy.I64, is_vec=False),
             sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR],
+            write_stage=OpStage.Late,
         )],
         is_copy=True,
     )
@@ -1004,6 +1233,7 @@ class OpKind(Enum):
             ty=GenericTy(BaseTy.I64, is_vec=False),
             sub_kinds=[LocSubKind.SV_EXTRA3_SGPR, LocSubKind.BASE_GPR,
                        LocSubKind.StackI64],
+            write_stage=OpStage.Late,
         )],
         is_copy=True,
     )
@@ -1021,7 +1251,7 @@ class OpKind(Enum):
             sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
             spread=True,
         ), OD_VL],
-        outputs=[OD_EXTRA3_VGPR],
+        outputs=[OD_EXTRA3_VGPR.with_write_stage(OpStage.Late)],
         is_copy=True,
     )
     _PRE_RA_SIMS[Concat] = lambda: OpKind.__concat_pre_ra_sim
@@ -1038,6 +1268,7 @@ class OpKind(Enum):
             ty=GenericTy(BaseTy.I64, is_vec=False),
             sub_kinds=[LocSubKind.SV_EXTRA3_VGPR],
             spread=True,
+            write_stage=OpStage.Late,
         )],
         is_copy=True,
     )
@@ -1072,7 +1303,7 @@ class OpKind(Enum):
     Ld = GenericOpProperties(
         demo_asm="ld RT, imm(RA)",
         inputs=[OD_BASE_SGPR],
-        outputs=[OD_BASE_SGPR],
+        outputs=[OD_BASE_SGPR.with_write_stage(OpStage.Late)],
         immediates=[IMM_S16],
     )
     _PRE_RA_SIMS[Ld] = lambda: OpKind.__ld_pre_ra_sim
@@ -1130,6 +1361,7 @@ class SSAValOrUse(metaclass=ABCMeta):
 
     def __init__(self, op, operand_idx):
         # type: (Op, int) -> None
+        super().__init__()
         self.op = op
         if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
             raise ValueError("invalid operand_idx")
@@ -1146,7 +1378,7 @@ class SSAValOrUse(metaclass=ABCMeta):
         # type: () -> tuple[OperandDesc, ...]
         ...
 
-    @property
+    @cached_property
     def defining_descriptor(self):
         # type: () -> OperandDesc
         return self.descriptor_array[self.operand_idx]
@@ -1215,6 +1447,11 @@ class SSAVal(SSAValOrUse):
         return SSAUse(op=self.op,
                       operand_idx=self.defining_descriptor.tied_input_index)
 
+    @property
+    def write_stage(self):
+        # type: () -> OpStage
+        return self.defining_descriptor.write_stage
+
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
@@ -1292,12 +1529,13 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
 
     def __init__(self, items, op):
         # type: (Iterable[_T], Op) -> None
+        super().__init__()
         self.__op = op
         self.__items = []  # type: list[_T]
         for idx, item in enumerate(items):
             if idx >= len(self.descriptors):
                 raise ValueError("too many items")
-            self._verify_write(idx, item)
+            _ = self._verify_write(idx, item)
             self.__items.append(item)
         if len(self.__items) < len(self.descriptors):
             raise ValueError("not enough items")
@@ -1334,6 +1572,7 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
         return len(self.__items)
 
     def __repr__(self):
+        # type: () -> str
         return f"{self.__class__.__name__}({self.__items}, op=...)"
 
 
@@ -1402,6 +1641,7 @@ class Op:
 
     @property
     def kind(self):
+        # type: () -> OpKind
         return self.properties.kind
 
     def __eq__(self, other):
@@ -1411,6 +1651,7 @@ class Op:
         return NotImplemented
 
     def __hash__(self):
+        # type: () -> int
         return object.__hash__(self)
 
     def __repr__(self):
@@ -1473,8 +1714,8 @@ class PreRASimState:
 
     def __init__(self, ssa_vals, memory):
         # type: (dict[SSAVal, tuple[int, ...]], dict[int, int]) -> None
-        self.ssa_vals = ssa_vals
-        self.memory = memory
+        self.ssa_vals = ssa_vals  # type: dict[SSAVal, tuple[int, ...]]
+        self.memory = memory  # type: dict[int, int]
 
     def load_byte(self, addr):
         # type: (int) -> int
index 68443d9a6e65054760934ede353bb4a67435bcde..c7d9a88a79b873ed705441be57200867da86924d 100644 (file)
@@ -6,13 +6,14 @@ this uses an algorithm based on:
 """
 
 from itertools import combinations
-from typing import Any, Generic, Iterable, Iterator, Mapping, MutableSet
+from typing import Any, Iterable, Iterator, Mapping, MutableSet
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir2 import (BaseTy, FnWithUses, Loc,
-                                                   LocSet, Op, SSAVal, Ty)
+from bigint_presentation_code.compiler_ir2 import (BaseTy, FnAnalysis, Loc,
+                                                   LocSet, Op, ProgramRange,
+                                                   SSAVal, Ty)
 from bigint_presentation_code.type_util import final
 from bigint_presentation_code.util import FMap, OFSet, OSet
 
@@ -23,6 +24,7 @@ class LiveInterval:
 
     def __init__(self, first_write, last_use=None):
         # type: (int, int | None) -> None
+        super().__init__()
         if last_use is None:
             last_use = first_write
         if last_use < first_write:
@@ -86,11 +88,11 @@ class MergedSSAVal:
     * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
     * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
     """
-    __slots__ = "fn_with_uses", "ssa_val_offsets", "first_ssa_val", "loc_set"
+    __slots__ = "fn_analysis", "ssa_val_offsets", "first_ssa_val", "loc_set"
 
-    def __init__(self, fn_with_uses, ssa_val_offsets):
-        # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
-        self.fn_with_uses = fn_with_uses
+    def __init__(self, fn_analysis, ssa_val_offsets):
+        # type: (FnAnalysis, Mapping[SSAVal, int] | SSAVal) -> None
+        self.fn_analysis = fn_analysis
         if isinstance(ssa_val_offsets, SSAVal):
             ssa_val_offsets = {ssa_val_offsets: 0}
         self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
@@ -111,7 +113,7 @@ class MergedSSAVal:
                 # type: () -> Iterable[Loc]
                 for loc in ssa_val.def_loc_set_before_spread:
                     disallowed_by_use = False
-                    for use in fn_with_uses.uses[ssa_val]:
+                    for use in fn_analysis.uses[ssa_val]:
                         use_spread_idx = \
                             use.defining_descriptor.spread_index or 0
                         # calculate the start for the use's Loc before spread
@@ -145,7 +147,7 @@ class MergedSSAVal:
     @cached_property
     def __hash(self):
         # type: () -> int
-        return hash((self.fn_with_uses, self.ssa_val_offsets))
+        return hash((self.fn_analysis, self.ssa_val_offsets))
 
     def __hash__(self):
         # type: () -> int
@@ -190,7 +192,7 @@ class MergedSSAVal:
     def offset_by(self, amount):
         # type: (int) -> MergedSSAVal
         v = {k: v + amount for k, v in self.ssa_val_offsets.items()}
-        return MergedSSAVal(fn_with_uses=self.fn_with_uses, ssa_val_offsets=v)
+        return MergedSSAVal(fn_analysis=self.fn_analysis, ssa_val_offsets=v)
 
     def normalized(self):
         # type: () -> MergedSSAVal
@@ -212,16 +214,28 @@ class MergedSSAVal:
         # type: (*MergedSSAVal) -> MergedSSAVal
         retval = dict(self.ssa_val_offsets)
         for other in others:
-            if other.fn_with_uses != self.fn_with_uses:
-                raise ValueError("fn_with_uses mismatch")
+            if other.fn_analysis != self.fn_analysis:
+                raise ValueError("fn_analysis mismatch")
             for ssa_val, offset in other.ssa_val_offsets.items():
                 if ssa_val in retval and retval[ssa_val] != offset:
                     raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
                                           f"{retval[ssa_val]} != {offset}")
                 retval[ssa_val] = offset
-        return MergedSSAVal(fn_with_uses=self.fn_with_uses,
+        return MergedSSAVal(fn_analysis=self.fn_analysis,
                             ssa_val_offsets=retval)
 
+    @cached_property
+    def live_interval(self):
+        # type: () -> ProgramRange
+        live_range = self.fn_analysis.live_ranges[self.first_ssa_val]
+        start = live_range.start
+        stop = live_range.stop
+        for ssa_val in self.ssa_vals:
+            live_range = self.fn_analysis.live_ranges[ssa_val]
+            start = min(start, live_range.start)
+            stop = max(stop, live_range.stop)
+        return ProgramRange(start=start, stop=stop)
+
 
 @final
 class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]):
@@ -322,11 +336,11 @@ class MergedSSAValsSet(MutableSet[MergedSSAVal]):
 @plain_data(frozen=True)
 @final
 class MergedSSAVals:
-    __slots__ = "fn_with_uses", "merge_map", "merged_ssa_vals"
+    __slots__ = "fn_analysis", "merge_map", "merged_ssa_vals"
 
-    def __init__(self, fn_with_uses, merged_ssa_vals):
-        # type: (FnWithUses, Iterable[MergedSSAVal]) -> None
-        self.fn_with_uses = fn_with_uses
+    def __init__(self, fn_analysis, merged_ssa_vals):
+        # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
+        self.fn_analysis = fn_analysis
         self.merge_map = MergedSSAValsMap()
         self.merged_ssa_vals = self.merge_map.values_set
         for i in merged_ssa_vals:
@@ -345,10 +359,10 @@ class MergedSSAVals:
         return merged
 
     @staticmethod
-    def minimally_merged(fn_with_uses):
-        # type: (FnWithUses) -> MergedSSAVals
-        retval = MergedSSAVals(fn_with_uses=fn_with_uses, merged_ssa_vals=())
-        for op in fn_with_uses.fn.ops:
+    def minimally_merged(fn_analysis):
+        # type: (FnAnalysis) -> MergedSSAVals
+        retval = MergedSSAVals(fn_analysis=fn_analysis, merged_ssa_vals=())
+        for op in fn_analysis.fn.ops:
             for inp in op.input_uses:
                 if inp.unspread_start != inp:
                     retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
@@ -362,67 +376,16 @@ class MergedSSAVals:
         return retval
 
 
-# FIXME: work on code from here
-
-
-@final
-class LiveIntervals(Mapping[MergedSSAVal, LiveInterval]):
-    def __init__(self, merged_ssa_vals):
-        # type: (list[Op]) -> None
-        self.__merged_reg_sets = MergedRegSets(ops)
-        live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]
-        for op_idx, op in enumerate(ops):
-            for val in op.inputs().values():
-                live_intervals[self.__merged_reg_sets[val]] += op_idx
-            for val in op.outputs().values():
-                reg_set = self.__merged_reg_sets[val]
-                if reg_set not in live_intervals:
-                    live_intervals[reg_set] = LiveInterval(op_idx)
-                else:
-                    live_intervals[reg_set] += op_idx
-        self.__live_intervals = live_intervals
-        live_after = []  # type: list[OSet[MergedRegSet[_RegType]]]
-        live_after += (OSet() for _ in ops)
-        for reg_set, live_interval in self.__live_intervals.items():
-            for i in live_interval.live_after_op_range:
-                live_after[i].add(reg_set)
-        self.__live_after = [OFSet(i) for i in live_after]
-
-    @property
-    def merged_reg_sets(self):
-        return self.__merged_reg_sets
-
-    def __getitem__(self, key):
-        # type: (MergedRegSet[_RegType]) -> LiveInterval
-        return self.__live_intervals[key]
-
-    def __iter__(self):
-        return iter(self.__live_intervals)
-
-    def __len__(self):
-        return len(self.__live_intervals)
-
-    def reg_sets_live_after(self, op_index):
-        # type: (int) -> OFSet[MergedRegSet[_RegType]]
-        return self.__live_after[op_index]
-
-    def __repr__(self):
-        reg_sets_live_after = dict(enumerate(self.__live_after))
-        return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
-                f"merged_reg_sets={self.merged_reg_sets}, "
-                f"reg_sets_live_after={reg_sets_live_after})")
-
-
 @final
-class IGNode(Generic[_RegType]):
+class IGNode:
     """ interference graph node """
-    __slots__ = "merged_reg_set", "edges", "reg"
+    __slots__ = "merged_ssa_val", "edges", "loc"
 
-    def __init__(self, merged_reg_set, edges=(), reg=None):
-        # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
-        self.merged_reg_set = merged_reg_set
+    def __init__(self, merged_ssa_val, edges=(), loc=None):
+        # type: (MergedSSAVal, Iterable[IGNode], Loc | None) -> None
+        self.merged_ssa_val = merged_ssa_val
         self.edges = OSet(edges)
-        self.reg = reg
+        self.loc = loc
 
     def add_edge(self, other):
         # type: (IGNode) -> None
@@ -432,11 +395,11 @@ class IGNode(Generic[_RegType]):
     def __eq__(self, other):
         # type: (object) -> bool
         if isinstance(other, IGNode):
-            return self.merged_reg_set == other.merged_reg_set
+            return self.merged_ssa_val == other.merged_ssa_val
         return NotImplemented
 
     def __hash__(self):
-        return hash(self.merged_reg_set)
+        return hash(self.merged_ssa_val)
 
     def __repr__(self, nodes=None):
         # type: (None | dict[IGNode, int]) -> str
@@ -447,54 +410,32 @@ class IGNode(Generic[_RegType]):
         nodes[self] = len(nodes)
         edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
         return (f"IGNode(#{nodes[self]}, "
-                f"merged_reg_set={self.merged_reg_set}, "
+                f"merged_ssa_val={self.merged_ssa_val}, "
                 f"edges={edges}, "
-                f"reg={self.reg})")
+                f"loc={self.loc})")
 
     @property
-    def reg_class(self):
-        # type: () -> RegClass
-        return self.merged_reg_set.ty.reg_class
+    def loc_set(self):
+        # type: () -> LocSet
+        return self.merged_ssa_val.loc_set
 
-    def reg_conflicts_with_neighbors(self, reg):
-        # type: (RegLoc) -> bool
+    def loc_conflicts_with_neighbors(self, loc):
+        # type: (Loc) -> bool
         for neighbor in self.edges:
-            if neighbor.reg is not None and neighbor.reg.conflicts(reg):
+            if neighbor.loc is not None and neighbor.loc.conflicts(loc):
                 return True
         return False
 
 
-@final
-class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
-    def __init__(self, merged_reg_sets):
-        # type: (Iterable[MergedRegSet[_RegType]]) -> None
-        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
-
-    def __getitem__(self, key):
-        # type: (MergedRegSet[_RegType]) -> IGNode
-        return self.__nodes[key]
-
-    def __iter__(self):
-        return iter(self.__nodes)
-
-    def __len__(self):
-        return len(self.__nodes)
-
-    def __repr__(self):
-        nodes = {}
-        nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
-        nodes_text = ", ".join(nodes_text)
-        return f"InterferenceGraph(nodes={{{nodes_text}}})"
-
-
 @plain_data()
 class AllocationFailed:
-    __slots__ = "node", "live_intervals", "interference_graph"
+    __slots__ = "node", "merged_ssa_vals", "interference_graph"
 
-    def __init__(self, node, live_intervals, interference_graph):
-        # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
+    def __init__(self, node, merged_ssa_vals, interference_graph):
+        # type: (IGNode, MergedSSAVals, dict[MergedSSAVal, IGNode]) -> None
+        super().__init__()
         self.node = node
-        self.live_intervals = live_intervals
+        self.merged_ssa_vals = merged_ssa_vals
         self.interference_graph = interference_graph
 
 
@@ -505,25 +446,24 @@ class AllocationFailedError(Exception):
         self.allocation_failed = allocation_failed
 
 
-def try_allocate_registers_without_spilling(ops):
-    # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
+def try_allocate_registers_without_spilling(merged_ssa_vals):
+    # type: (MergedSSAVals) -> dict[SSAVal, Loc] | AllocationFailed
 
-    live_intervals = LiveIntervals(ops)
-    merged_reg_sets = live_intervals.merged_reg_sets
-    interference_graph = InterferenceGraph(merged_reg_sets.values())
-    for op_idx, op in enumerate(ops):
-        reg_sets = live_intervals.reg_sets_live_after(op_idx)
-        for i, j in combinations(reg_sets, 2):
-            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
-                interference_graph[i].add_edge(interference_graph[j])
-        for i, j in op.get_extra_interferences():
-            i = merged_reg_sets[i]
-            j = merged_reg_sets[j]
-            if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
+    interference_graph = {
+        i: IGNode(i) for i in merged_ssa_vals.merged_ssa_vals}
+    fn_analysis = merged_ssa_vals.fn_analysis
+    for ssa_vals in fn_analysis.live_at.values():
+        live_merged_ssa_vals = OSet()  # type: OSet[MergedSSAVal]
+        for ssa_val in ssa_vals:
+            live_merged_ssa_vals.add(merged_ssa_vals.merge_map[ssa_val])
+        for i, j in combinations(live_merged_ssa_vals, 2):
+            if i.loc_set.max_conflicts_with(j.loc_set) != 0:
                 interference_graph[i].add_edge(interference_graph[j])
 
     nodes_remaining = OSet(interference_graph.values())
 
+# FIXME: work on code from here
+
     def local_colorability_score(node):
         # type: (IGNode) -> int
         """ returns a positive integer if node is locally colorable, returns
@@ -532,7 +472,7 @@ def try_allocate_registers_without_spilling(ops):
         """
         if node not in nodes_remaining:
             raise ValueError()
-        retval = len(node.reg_class)
+        retval = len(node.loc_set)
         for neighbor in node.edges:
             if neighbor in nodes_remaining:
                 retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
index b85b3ac0e959d82f66df263c5d7e0eebe6fca41a..757f267ecdfabcc923980e3d863666b5c55d0305 100644 (file)
@@ -26,6 +26,7 @@ class OFSet(AbstractSet[_T_co]):
 
     def __init__(self, items=()):
         # type: (Iterable[_T_co]) -> None
+        super().__init__()
         self.__items = {v: None for v in items}
 
     def __contains__(self, x):
@@ -57,6 +58,7 @@ class OSet(MutableSet[_T]):
 
     def __init__(self, items=()):
         # type: (Iterable[_T]) -> None
+        super().__init__()
         self.__items = {v: None for v in items}
 
     def __contains__(self, x):
@@ -107,6 +109,7 @@ class FMap(Mapping[_T, _T_co]):
 
     def __init__(self, items=()):
         # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
+        super().__init__()
         self.__items = dict(items)  # type: dict[_T, _T_co]
         self.__hash = None  # type: None | int
 
@@ -179,6 +182,7 @@ class BaseBitSet(AbstractSet[int]):
 
     def __init__(self, items=(), bits=0):
         # type: (Iterable[int], int) -> None
+        super().__init__()
         if isinstance(items, BaseBitSet):
             bits |= items.bits
         else: