working on code
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 2 Nov 2022 07:32:00 +0000 (00:32 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 2 Nov 2022 07:32:00 +0000 (00:32 -0700)
src/bigint_presentation_code/_tests/test_compiler_ir2.py
src/bigint_presentation_code/compiler_ir2.py
src/bigint_presentation_code/register_allocator2.py

index 4aa8e3eda5f31cdceaa1751d431ab6178f162838..40f02fa218216af36cf6ddc65290537814f0484b 100644 (file)
@@ -17,54 +17,66 @@ class TestCompilerIR(unittest.TestCase):
         op1 = fn.append_new_op(OpKind.SetVLI, immediates=[MAXVL], name="vl")
         vl = op1.outputs[0]
         op2 = fn.append_new_op(
-            OpKind.SvLd, inputs=[arg, vl], immediates=[0], maxvl=MAXVL,
+            OpKind.SvLd, input_vals=[arg, vl], immediates=[0], maxvl=MAXVL,
             name="ld")
         a = op2.outputs[0]
-        op3 = fn.append_new_op(
-            OpKind.SvLI, inputs=[vl], immediates=[0], maxvl=MAXVL, name="li")
+        op3 = fn.append_new_op(OpKind.SvLI, input_vals=[vl], immediates=[0],
+                               maxvl=MAXVL, name="li")
         b = op3.outputs[0]
         op4 = fn.append_new_op(OpKind.SetCA, name="ca")
         ca = op4.outputs[0]
         op5 = fn.append_new_op(
-            OpKind.SvAddE, inputs=[a, b, ca, vl], maxvl=MAXVL, name="add")
+            OpKind.SvAddE, input_vals=[a, b, ca, vl], maxvl=MAXVL, name="add")
         s = op5.outputs[0]
-        fn.append_new_op(
-            OpKind.SvStd, inputs=[s, arg, vl], immediates=[0], maxvl=MAXVL,
-            name="st")
+        fn.append_new_op(OpKind.SvStd, input_vals=[s, arg, vl],
+                         immediates=[0], maxvl=MAXVL, name="st")
         return fn, arg
 
     def test_repr(self):
         fn, _arg = self.make_add_fn()
         self.assertEqual([repr(i) for i in fn.ops], [
             "Op(kind=OpKind.FuncArgR3, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[], "
             "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
             "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
             "Op(kind=OpKind.SvLd, "
-            "inputs=[<arg.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<arg.outputs[0]: <I64>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<ld.input_uses[0]: <I64>>, "
+            "<ld.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[0], "
             "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
             "Op(kind=OpKind.SvLI, "
-            "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
             "immediates=[0], "
             "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
             "Op(kind=OpKind.SetCA, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[], "
             "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
             "Op(kind=OpKind.SvAddE, "
-            "inputs=[<ld.outputs[0]: <I64*32>>, <li.outputs[0]: <I64*32>>, "
-            "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<ld.outputs[0]: <I64*32>>, "
+            "<li.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add.input_uses[0]: <I64*32>>, "
+            "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+            "<add.input_uses[3]: <VL_MAXVL>>), "
             "immediates=[], "
             "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
             "name='add')",
             "Op(kind=OpKind.SvStd, "
-            "inputs=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
+            "input_vals=[<add.outputs[0]: <I64*32>>, <arg.outputs[0]: <I64>>, "
             "<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<st.input_uses[0]: <I64*32>>, "
+            "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
             "immediates=[0], "
             "outputs=(), name='st')",
         ])
@@ -150,95 +162,148 @@ class TestCompilerIR(unittest.TestCase):
         fn.pre_ra_insert_copies()
         self.assertEqual([repr(i) for i in fn.ops], [
             "Op(kind=OpKind.FuncArgR3, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[], "
             "outputs=(<arg.outputs[0]: <I64>>,), name='arg')",
             "Op(kind=OpKind.CopyFromReg, "
-            "inputs=[<arg.outputs[0]: <I64>>], "
+            "input_vals=[<arg.outputs[0]: <I64>>], "
+            "input_uses=(<arg.out0.copy.input_uses[0]: <I64>>,), "
             "immediates=[], "
-            "outputs=(<2.outputs[0]: <I64>>,), name='2')",
+            "outputs=(<arg.out0.copy.outputs[0]: <I64>>,), "
+            "name='arg.out0.copy')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
             "outputs=(<vl.outputs[0]: <VL_MAXVL>>,), name='vl')",
             "Op(kind=OpKind.CopyToReg, "
-            "inputs=[<2.outputs[0]: <I64>>], "
+            "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
+            "input_uses=(<ld.inp0.copy.input_uses[0]: <I64>>,), "
             "immediates=[], "
-            "outputs=(<3.outputs[0]: <I64>>,), name='3')",
+            "outputs=(<ld.inp0.copy.outputs[0]: <I64>>,), name='ld.inp0.copy')",
             "Op(kind=OpKind.SvLd, "
-            "inputs=[<3.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<ld.inp0.copy.outputs[0]: <I64>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<ld.input_uses[0]: <I64>>, "
+            "<ld.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[0], "
             "outputs=(<ld.outputs[0]: <I64*32>>,), name='ld')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
-            "outputs=(<4.outputs[0]: <VL_MAXVL>>,), name='4')",
+            "outputs=(<ld.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='ld.out0.setvl')",
             "Op(kind=OpKind.VecCopyFromReg, "
-            "inputs=[<ld.outputs[0]: <I64*32>>, <4.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<ld.outputs[0]: <I64*32>>, "
+            "<ld.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<ld.out0.copy.input_uses[0]: <I64*32>>, "
+            "<ld.out0.copy.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[], "
-            "outputs=(<5.outputs[0]: <I64*32>>,), name='5')",
+            "outputs=(<ld.out0.copy.outputs[0]: <I64*32>>,), "
+            "name='ld.out0.copy')",
             "Op(kind=OpKind.SvLI, "
-            "inputs=[<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<li.input_uses[0]: <VL_MAXVL>>,), "
             "immediates=[0], "
             "outputs=(<li.outputs[0]: <I64*32>>,), name='li')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
-            "outputs=(<6.outputs[0]: <VL_MAXVL>>,), name='6')",
+            "outputs=(<li.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='li.out0.setvl')",
             "Op(kind=OpKind.VecCopyFromReg, "
-            "inputs=[<li.outputs[0]: <I64*32>>, <6.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<li.outputs[0]: <I64*32>>, "
+            "<li.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<li.out0.copy.input_uses[0]: <I64*32>>, "
+            "<li.out0.copy.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[], "
-            "outputs=(<7.outputs[0]: <I64*32>>,), name='7')",
+            "outputs=(<li.out0.copy.outputs[0]: <I64*32>>,), "
+            "name='li.out0.copy')",
             "Op(kind=OpKind.SetCA, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[], "
             "outputs=(<ca.outputs[0]: <CA>>,), name='ca')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
-            "outputs=(<8.outputs[0]: <VL_MAXVL>>,), name='8')",
+            "outputs=(<add.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='add.inp0.setvl')",
             "Op(kind=OpKind.VecCopyToReg, "
-            "inputs=[<5.outputs[0]: <I64*32>>, <8.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<ld.out0.copy.outputs[0]: <I64*32>>, "
+            "<add.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add.inp0.copy.input_uses[0]: <I64*32>>, "
+            "<add.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[], "
-            "outputs=(<9.outputs[0]: <I64*32>>,), name='9')",
+            "outputs=(<add.inp0.copy.outputs[0]: <I64*32>>,), "
+            "name='add.inp0.copy')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
-            "outputs=(<10.outputs[0]: <VL_MAXVL>>,), name='10')",
+            "outputs=(<add.inp1.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='add.inp1.setvl')",
             "Op(kind=OpKind.VecCopyToReg, "
-            "inputs=[<7.outputs[0]: <I64*32>>, <10.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<li.out0.copy.outputs[0]: <I64*32>>, "
+            "<add.inp1.setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add.inp1.copy.input_uses[0]: <I64*32>>, "
+            "<add.inp1.copy.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[], "
-            "outputs=(<11.outputs[0]: <I64*32>>,), name='11')",
+            "outputs=(<add.inp1.copy.outputs[0]: <I64*32>>,), "
+            "name='add.inp1.copy')",
             "Op(kind=OpKind.SvAddE, "
-            "inputs=[<9.outputs[0]: <I64*32>>, <11.outputs[0]: <I64*32>>, "
-            "<ca.outputs[0]: <CA>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<add.inp0.copy.outputs[0]: <I64*32>>, "
+            "<add.inp1.copy.outputs[0]: <I64*32>>, <ca.outputs[0]: <CA>>, "
+            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add.input_uses[0]: <I64*32>>, "
+            "<add.input_uses[1]: <I64*32>>, <add.input_uses[2]: <CA>>, "
+            "<add.input_uses[3]: <VL_MAXVL>>), "
             "immediates=[], "
             "outputs=(<add.outputs[0]: <I64*32>>, <add.outputs[1]: <CA>>), "
             "name='add')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
-            "outputs=(<12.outputs[0]: <VL_MAXVL>>,), name='12')",
+            "outputs=(<add.out0.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='add.out0.setvl')",
             "Op(kind=OpKind.VecCopyFromReg, "
-            "inputs=[<add.outputs[0]: <I64*32>>, "
-            "<12.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<add.outputs[0]: <I64*32>>, "
+            "<add.out0.setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<add.out0.copy.input_uses[0]: <I64*32>>, "
+            "<add.out0.copy.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[], "
-            "outputs=(<13.outputs[0]: <I64*32>>,), name='13')",
+            "outputs=(<add.out0.copy.outputs[0]: <I64*32>>,), "
+            "name='add.out0.copy')",
             "Op(kind=OpKind.SetVLI, "
-            "inputs=[], "
+            "input_vals=[], "
+            "input_uses=(), "
             "immediates=[32], "
-            "outputs=(<14.outputs[0]: <VL_MAXVL>>,), name='14')",
+            "outputs=(<st.inp0.setvl.outputs[0]: <VL_MAXVL>>,), "
+            "name='st.inp0.setvl')",
             "Op(kind=OpKind.VecCopyToReg, "
-            "inputs=[<13.outputs[0]: <I64*32>>, <14.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<add.out0.copy.outputs[0]: <I64*32>>, "
+            "<st.inp0.setvl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<st.inp0.copy.input_uses[0]: <I64*32>>, "
+            "<st.inp0.copy.input_uses[1]: <VL_MAXVL>>), "
             "immediates=[], "
-            "outputs=(<15.outputs[0]: <I64*32>>,), name='15')",
+            "outputs=(<st.inp0.copy.outputs[0]: <I64*32>>,), "
+            "name='st.inp0.copy')",
             "Op(kind=OpKind.CopyToReg, "
-            "inputs=[<2.outputs[0]: <I64>>], "
+            "input_vals=[<arg.out0.copy.outputs[0]: <I64>>], "
+            "input_uses=(<st.inp1.copy.input_uses[0]: <I64>>,), "
             "immediates=[], "
-            "outputs=(<16.outputs[0]: <I64>>,), name='16')",
+            "outputs=(<st.inp1.copy.outputs[0]: <I64>>,), "
+            "name='st.inp1.copy')",
             "Op(kind=OpKind.SvStd, "
-            "inputs=[<15.outputs[0]: <I64*32>>, <16.outputs[0]: <I64>>, "
-            "<vl.outputs[0]: <VL_MAXVL>>], "
+            "input_vals=[<st.inp0.copy.outputs[0]: <I64*32>>, "
+            "<st.inp1.copy.outputs[0]: <I64>>, <vl.outputs[0]: <VL_MAXVL>>], "
+            "input_uses=(<st.input_uses[0]: <I64*32>>, "
+            "<st.input_uses[1]: <I64>>, <st.input_uses[2]: <VL_MAXVL>>), "
             "immediates=[0], "
             "outputs=(), name='st')",
         ])
index 8e365096dee70ed15da4f32303cd611d6c1ad275..03b623dafbf5f80ebdc96e487088552a42ffe040 100644 (file)
@@ -44,10 +44,11 @@ class Fn:
             raise ValueError("can't add Op to wrong Fn")
         self.ops.append(op)
 
-    def append_new_op(self, kind, inputs=(), immediates=(), name="", maxvl=1):
+    def append_new_op(self, kind, input_vals=(), immediates=(), name="",
+                      maxvl=1):
         # type: (OpKind, Iterable[SSAVal], Iterable[int], str, int) -> Op
         retval = Op(fn=self, properties=kind.instantiate(maxvl=maxvl),
-                    inputs=inputs, immediates=immediates, name=name)
+                    input_vals=input_vals, immediates=immediates, name=name)
         self.append_op(retval)
         return retval
 
@@ -62,38 +63,45 @@ class Fn:
         copied_outputs = {}  # type: dict[SSAVal, SSAVal]
         self.ops.clear()
         for op in orig_ops:
-            for i in range(len(op.inputs)):
-                inp = copied_outputs[op.inputs[i]]
+            for i in range(len(op.input_vals)):
+                inp = copied_outputs[op.input_vals[i]]
                 if inp.ty.base_ty is BaseTy.I64:
                     maxvl = inp.ty.reg_len
                     if inp.ty.reg_len != 1:
-                        setvl = self.append_new_op(OpKind.SetVLI,
-                                                   immediates=[maxvl])
+                        setvl = self.append_new_op(
+                            OpKind.SetVLI, immediates=[maxvl],
+                            name=f"{op.name}.inp{i}.setvl")
                         vl = setvl.outputs[0]
-                        mv = self.append_new_op(OpKind.VecCopyToReg,
-                                                inputs=[inp, vl], maxvl=maxvl)
+                        mv = self.append_new_op(
+                            OpKind.VecCopyToReg, input_vals=[inp, vl],
+                            maxvl=maxvl, name=f"{op.name}.inp{i}.copy")
                     else:
-                        mv = self.append_new_op(OpKind.CopyToReg, inputs=[inp])
-                    op.inputs[i] = mv.outputs[0]
+                        mv = self.append_new_op(
+                            OpKind.CopyToReg, input_vals=[inp],
+                            name=f"{op.name}.inp{i}.copy")
+                    op.input_vals[i] = mv.outputs[0]
                 elif inp.ty.base_ty is BaseTy.CA \
                         or inp.ty.base_ty is BaseTy.VL_MAXVL:
                     # all copies would be no-ops, so we don't need to copy
-                    op.inputs[i] = inp
+                    op.input_vals[i] = inp
                 else:
                     assert_never(inp.ty.base_ty)
             self.ops.append(op)
-            for out in op.outputs:
+            for i, out in enumerate(op.outputs):
                 if out.ty.base_ty is BaseTy.I64:
                     maxvl = out.ty.reg_len
                     if out.ty.reg_len != 1:
-                        setvl = self.append_new_op(OpKind.SetVLI,
-                                                   immediates=[maxvl])
+                        setvl = self.append_new_op(
+                            OpKind.SetVLI, immediates=[maxvl],
+                            name=f"{op.name}.out{i}.setvl")
                         vl = setvl.outputs[0]
-                        mv = self.append_new_op(OpKind.VecCopyFromReg,
-                                                inputs=[out, vl], maxvl=maxvl)
+                        mv = self.append_new_op(
+                            OpKind.VecCopyFromReg, input_vals=[out, vl],
+                            maxvl=maxvl, name=f"{op.name}.out{i}.copy")
                     else:
-                        mv = self.append_new_op(OpKind.CopyFromReg,
-                                                inputs=[out])
+                        mv = self.append_new_op(
+                            OpKind.CopyFromReg, input_vals=[out],
+                            name=f"{op.name}.out{i}.copy")
                     copied_outputs[out] = mv.outputs[0]
                 elif out.ty.base_ty is BaseTy.CA \
                         or out.ty.base_ty is BaseTy.VL_MAXVL:
@@ -113,7 +121,7 @@ class FnWithUses:
         self.fn = fn
         retval = {}  # type: dict[SSAVal, OSet[SSAUse]]
         for op in fn.ops:
-            for idx, inp in enumerate(op.inputs):
+            for idx, inp in enumerate(op.input_vals):
                 retval[inp].add(SSAUse(op, idx))
             for out in op.outputs:
                 retval[out] = OSet()
@@ -581,6 +589,7 @@ class GenericOperandDesc:
 
     def instantiate(self, maxvl):
         # type: (int) -> Iterable[OperandDesc]
+        # assumes all spread operands have ty.reg_len = 1
         rep_count = 1
         if self.spread:
             rep_count = maxvl
@@ -636,9 +645,23 @@ class OperandDesc:
     def ty(self):
         """ Ty after any spread is applied """
         if self.spread_index is not None:
+            # assumes all spread operands have ty.reg_len = 1
             return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
         return self.ty_before_spread
 
+    @property
+    def reg_offset_in_unspread(self):
+        """ the number of reg-sized slots in the unspread Loc before self's Loc
+
+        e.g. if the unspread Loc containing self is:
+        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
+        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
+        then reg_offset_into_unspread == 2 == 10 - 8
+        """
+        if self.spread_index is None:
+            return 0
+        return self.spread_index * self.ty.reg_len
+
 
 OD_BASE_SGPR = GenericOperandDesc(
     ty=GenericTy(base_ty=BaseTy.I64, is_vec=False),
@@ -816,10 +839,10 @@ class OpKind(Enum):
     @staticmethod
     def __svadde_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        RA = state.ssa_vals[op.inputs[0]]
-        RB = state.ssa_vals[op.inputs[1]]
-        carry, = state.ssa_vals[op.inputs[2]]
-        VL, = state.ssa_vals[op.inputs[3]]
+        RA = state.ssa_vals[op.input_vals[0]]
+        RB = state.ssa_vals[op.input_vals[1]]
+        carry, = state.ssa_vals[op.input_vals[2]]
+        VL, = state.ssa_vals[op.input_vals[3]]
         RT = []  # type: list[int]
         for i in range(VL):
             v = RA[i] + RB[i] + carry
@@ -837,10 +860,10 @@ class OpKind(Enum):
     @staticmethod
     def __svsubfe_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        RA = state.ssa_vals[op.inputs[0]]
-        RB = state.ssa_vals[op.inputs[1]]
-        carry, = state.ssa_vals[op.inputs[2]]
-        VL, = state.ssa_vals[op.inputs[3]]
+        RA = state.ssa_vals[op.input_vals[0]]
+        RB = state.ssa_vals[op.input_vals[1]]
+        carry, = state.ssa_vals[op.input_vals[2]]
+        VL, = state.ssa_vals[op.input_vals[3]]
         RT = []  # type: list[int]
         for i in range(VL):
             v = (~RA[i] & GPR_VALUE_MASK) + RB[i] + carry
@@ -858,10 +881,10 @@ class OpKind(Enum):
     @staticmethod
     def __svmaddedu_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        RA = state.ssa_vals[op.inputs[0]]
-        RB, = state.ssa_vals[op.inputs[1]]
-        carry, = state.ssa_vals[op.inputs[2]]
-        VL, = state.ssa_vals[op.inputs[3]]
+        RA = state.ssa_vals[op.input_vals[0]]
+        RB, = state.ssa_vals[op.input_vals[1]]
+        carry, = state.ssa_vals[op.input_vals[2]]
+        VL, = state.ssa_vals[op.input_vals[3]]
         RT = []  # type: list[int]
         for i in range(VL):
             v = RA[i] * RB + carry
@@ -892,7 +915,7 @@ class OpKind(Enum):
     @staticmethod
     def __svli_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        VL, = state.ssa_vals[op.inputs[0]]
+        VL, = state.ssa_vals[op.input_vals[0]]
         imm = op.immediates[0] & GPR_VALUE_MASK
         state.ssa_vals[op.outputs[0]] = (imm,) * VL
     SvLI = GenericOpProperties(
@@ -921,7 +944,7 @@ class OpKind(Enum):
     @staticmethod
     def __veccopytoreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
     VecCopyToReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[GenericOperandDesc(
@@ -936,7 +959,7 @@ class OpKind(Enum):
     @staticmethod
     def __veccopyfromreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
     VecCopyFromReg = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[OD_EXTRA3_VGPR, OD_VL],
@@ -951,7 +974,7 @@ class OpKind(Enum):
     @staticmethod
     def __copytoreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
     CopyToReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -970,7 +993,7 @@ class OpKind(Enum):
     @staticmethod
     def __copyfromreg_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.inputs[0]]
+        state.ssa_vals[op.outputs[0]] = state.ssa_vals[op.input_vals[0]]
     CopyFromReg = GenericOpProperties(
         demo_asm="mv dest, src",
         inputs=[GenericOperandDesc(
@@ -990,7 +1013,7 @@ class OpKind(Enum):
     def __concat_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
         state.ssa_vals[op.outputs[0]] = tuple(
-            state.ssa_vals[i][0] for i in op.inputs[:-1])
+            state.ssa_vals[i][0] for i in op.input_vals[:-1])
     Concat = GenericOpProperties(
         demo_asm="sv.mv dest, src",
         inputs=[GenericOperandDesc(
@@ -1006,7 +1029,7 @@ class OpKind(Enum):
     @staticmethod
     def __spread_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        for idx, inp in enumerate(state.ssa_vals[op.inputs[0]]):
+        for idx, inp in enumerate(state.ssa_vals[op.input_vals[0]]):
             state.ssa_vals[op.outputs[idx]] = inp,
     Spread = GenericOpProperties(
         demo_asm="sv.mv dest, src",
@@ -1023,8 +1046,8 @@ class OpKind(Enum):
     @staticmethod
     def __svld_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        RA, = state.ssa_vals[op.inputs[0]]
-        VL, = state.ssa_vals[op.inputs[1]]
+        RA, = state.ssa_vals[op.input_vals[0]]
+        VL, = state.ssa_vals[op.input_vals[1]]
         addr = RA + op.immediates[0]
         RT = []  # type: list[int]
         for i in range(VL):
@@ -1042,7 +1065,7 @@ class OpKind(Enum):
     @staticmethod
     def __ld_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        RA, = state.ssa_vals[op.inputs[0]]
+        RA, = state.ssa_vals[op.input_vals[0]]
         addr = RA + op.immediates[0]
         v = state.load(addr)
         state.ssa_vals[op.outputs[0]] = v & GPR_VALUE_MASK,
@@ -1057,9 +1080,9 @@ class OpKind(Enum):
     @staticmethod
     def __svstd_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        RS = state.ssa_vals[op.inputs[0]]
-        RA, = state.ssa_vals[op.inputs[1]]
-        VL, = state.ssa_vals[op.inputs[2]]
+        RS = state.ssa_vals[op.input_vals[0]]
+        RA, = state.ssa_vals[op.input_vals[1]]
+        VL, = state.ssa_vals[op.input_vals[2]]
         addr = RA + op.immediates[0]
         for i in range(VL):
             state.store(addr + GPR_SIZE_IN_BYTES * i, value=RS[i])
@@ -1075,8 +1098,8 @@ class OpKind(Enum):
     @staticmethod
     def __std_pre_ra_sim(op, state):
         # type: (Op, PreRASimState) -> None
-        RS, = state.ssa_vals[op.inputs[0]]
-        RA, = state.ssa_vals[op.inputs[1]]
+        RS, = state.ssa_vals[op.input_vals[0]]
+        RA, = state.ssa_vals[op.input_vals[1]]
         addr = RA + op.immediates[0]
         state.store(addr, value=RS)
     Std = GenericOpProperties(
@@ -1103,11 +1126,14 @@ class OpKind(Enum):
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 class SSAValOrUse(metaclass=ABCMeta):
-    __slots__ = "op",
+    __slots__ = "op", "operand_idx"
 
-    def __init__(self, op):
-        # type: (Op) -> None
+    def __init__(self, op, operand_idx):
+        # type: (Op, int) -> None
         self.op = op
+        if operand_idx < 0 or operand_idx >= len(self.descriptor_array):
+            raise ValueError("invalid operand_idx")
+        self.operand_idx = operand_idx
 
     @abstractmethod
     def __repr__(self):
@@ -1116,9 +1142,14 @@ class SSAValOrUse(metaclass=ABCMeta):
 
     @property
     @abstractmethod
+    def descriptor_array(self):
+        # type: () -> tuple[OperandDesc, ...]
+        ...
+
+    @property
     def defining_descriptor(self):
         # type: () -> OperandDesc
-        ...
+        return self.descriptor_array[self.operand_idx]
 
     @cached_property
     def ty(self):
@@ -1135,22 +1166,36 @@ class SSAValOrUse(metaclass=ABCMeta):
         # type: () -> BaseTy
         return self.ty_before_spread.base_ty
 
+    @property
+    def reg_offset_in_unspread(self):
+        """ the number of reg-sized slots in the unspread Loc before self's Loc
+
+        e.g. if the unspread Loc containing self is:
+        `Loc(kind=LocKind.GPR, start=8, reg_len=4)`
+        and self's Loc is `Loc(kind=LocKind.GPR, start=10, reg_len=1)`
+        then reg_offset_into_unspread == 2 == 10 - 8
+        """
+        return self.defining_descriptor.reg_offset_in_unspread
+
+    @property
+    def unspread_start_idx(self):
+        # type: () -> int
+        return self.operand_idx - (self.defining_descriptor.spread_index or 0)
+
+    @property
+    def unspread_start(self):
+        # type: () -> Self
+        return self.__class__(op=self.op, operand_idx=self.unspread_start_idx)
+
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 class SSAVal(SSAValOrUse):
-    __slots__ = "output_idx",
-
-    def __init__(self, op, output_idx):
-        # type: (Op, int) -> None
-        super().__init__(op)
-        if output_idx < 0 or output_idx >= len(op.properties.outputs):
-            raise ValueError("invalid output_idx")
-        self.output_idx = output_idx
+    __slots__ = ()
 
     def __repr__(self):
         # type: () -> str
-        return f"<{self.op.name}.outputs[{self.output_idx}]: {self.ty}>"
+        return f"<{self.op.name}.outputs[{self.operand_idx}]: {self.ty}>"
 
     @cached_property
     def def_loc_set_before_spread(self):
@@ -1158,22 +1203,23 @@ class SSAVal(SSAValOrUse):
         return self.defining_descriptor.loc_set_before_spread
 
     @cached_property
-    def defining_descriptor(self):
-        # type: () -> OperandDesc
-        return self.op.properties.outputs[self.output_idx]
+    def descriptor_array(self):
+        # type: () -> tuple[OperandDesc, ...]
+        return self.op.properties.outputs
+
+    @cached_property
+    def tied_input(self):
+        # type: () -> None | SSAUse
+        if self.defining_descriptor.tied_input_index is None:
+            return None
+        return SSAUse(op=self.op,
+                      operand_idx=self.defining_descriptor.tied_input_index)
 
 
 @plain_data(frozen=True, unsafe_hash=True, repr=False)
 @final
 class SSAUse(SSAValOrUse):
-    __slots__ = "input_idx",
-
-    def __init__(self, op, input_idx):
-        # type: (Op, int) -> None
-        super().__init__(op)
-        self.input_idx = input_idx
-        if input_idx < 0 or input_idx >= len(op.inputs):
-            raise ValueError("input_idx out of range")
+    __slots__ = ()
 
     @cached_property
     def use_loc_set_before_spread(self):
@@ -1181,13 +1227,23 @@ class SSAUse(SSAValOrUse):
         return self.defining_descriptor.loc_set_before_spread
 
     @cached_property
-    def defining_descriptor(self):
-        # type: () -> OperandDesc
-        return self.op.properties.inputs[self.input_idx]
+    def descriptor_array(self):
+        # type: () -> tuple[OperandDesc, ...]
+        return self.op.properties.inputs
 
     def __repr__(self):
         # type: () -> str
-        return f"<{self.op.name}.inputs[{self.input_idx}]: {self.ty}>"
+        return f"<{self.op.name}.input_uses[{self.operand_idx}]: {self.ty}>"
+
+    @property
+    def ssa_val(self):
+        # type: () -> SSAVal
+        return self.op.input_vals[self.operand_idx]
+
+    @ssa_val.setter
+    def ssa_val(self, ssa_val):
+        # type: (SSAVal) -> None
+        self.op.input_vals[self.operand_idx] = ssa_val
 
 
 _T = TypeVar("_T")
@@ -1282,7 +1338,7 @@ class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
 
 
 @final
-class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
+class OpInputVals(OpInputSeq[SSAVal, OperandDesc]):
     def _get_descriptors(self):
         # type: () -> tuple[OperandDesc, ...]
         return self.op.properties.inputs
@@ -1329,13 +1385,16 @@ class OpImmediates(OpInputSeq[int, range]):
 @plain_data(frozen=True, eq=False, repr=False)
 @final
 class Op:
-    __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name"
+    __slots__ = ("fn", "properties", "input_vals", "input_uses", "immediates",
+                 "outputs", "name")
 
-    def __init__(self, fn, properties, inputs, immediates, name=""):
+    def __init__(self, fn, properties, input_vals, immediates, name=""):
         # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
         self.fn = fn
         self.properties = properties
-        self.inputs = OpInputs(inputs, op=self)
+        self.input_vals = OpInputVals(input_vals, op=self)
+        inputs_len = len(self.properties.inputs)
+        self.input_uses = tuple(SSAUse(self, i) for i in range(inputs_len))
         self.immediates = OpImmediates(immediates, op=self)
         outputs_len = len(self.properties.outputs)
         self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
@@ -1375,7 +1434,7 @@ class Op:
 
     def pre_ra_sim(self, state):
         # type: (PreRASimState) -> None
-        for inp in self.inputs:
+        for inp in self.input_vals:
             if inp not in state.ssa_vals:
                 raise ValueError(f"SSAVal {inp} not yet assigned when "
                                  f"running {self}")
index 962a0217d325e32a2a447207cab6e6b555da2c27..68443d9a6e65054760934ede353bb4a67435bcde 100644 (file)
@@ -6,17 +6,15 @@ this uses an algorithm based on:
 """
 
 from itertools import combinations
-from functools import reduce
-from typing import Generic, Iterable, Mapping
-from cached_property import cached_property
-import operator
+from typing import Any, Generic, Iterable, Iterator, Mapping, MutableSet
 
+from cached_property import cached_property
 from nmutil.plain_data import plain_data
 
-from bigint_presentation_code.compiler_ir2 import (
-    Op, LocSet, Ty, SSAVal, BaseTy, Loc, FnWithUses)
-from bigint_presentation_code.type_util import final, Self
-from bigint_presentation_code.util import OFSet, OSet, FMap
+from bigint_presentation_code.compiler_ir2 import (BaseTy, FnWithUses, Loc,
+                                                   LocSet, Op, SSAVal, Ty)
+from bigint_presentation_code.type_util import final
+from bigint_presentation_code.util import FMap, OFSet, OSet
 
 
 @plain_data(unsafe_hash=True, order=True, frozen=True)
@@ -58,7 +56,7 @@ class BadMergedSSAVal(ValueError):
     pass
 
 
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True)
 @final
 class MergedSSAVal:
     """a set of `SSAVal`s along with their offsets, all register allocated as
@@ -88,7 +86,7 @@ class MergedSSAVal:
     * `v2` is allocated to `Loc(kind=LocKind.GPR, start=24, reg_len=2)`
     * `v3` is allocated to `Loc(kind=LocKind.GPR, start=21, reg_len=1)`
     """
-    __slots__ = "fn_with_uses", "ssa_val_offsets", "base_ty", "loc_set"
+    __slots__ = "fn_with_uses", "ssa_val_offsets", "first_ssa_val", "loc_set"
 
     def __init__(self, fn_with_uses, ssa_val_offsets):
         # type: (FnWithUses, Mapping[SSAVal, int] | SSAVal) -> None
@@ -96,13 +94,13 @@ class MergedSSAVal:
         if isinstance(ssa_val_offsets, SSAVal):
             ssa_val_offsets = {ssa_val_offsets: 0}
         self.ssa_val_offsets = FMap(ssa_val_offsets)  # type: FMap[SSAVal, int]
-        base_ty = None
-        for ssa_val in self.ssa_val_offsets.keys():
-            base_ty = ssa_val.base_ty
+        first_ssa_val = None
+        for ssa_val in self.ssa_vals:
+            first_ssa_val = ssa_val
             break
-        if base_ty is None:
+        if first_ssa_val is None:
             raise BadMergedSSAVal("MergedSSAVal can't be empty")
-        self.base_ty = base_ty  # type: BaseTy
+        self.first_ssa_val = first_ssa_val  # type: SSAVal
         # self.ty checks for mismatched base_ty
         reg_len = self.ty.reg_len
         loc_set = None  # type: None | LocSet
@@ -144,11 +142,30 @@ class MergedSSAVal:
         assert loc_set.ty == self.ty, "logic error somewhere"
         self.loc_set = loc_set  # type: LocSet
 
+    @cached_property
+    def __hash(self):
+        # type: () -> int
+        return hash((self.fn_with_uses, self.ssa_val_offsets))
+
+    def __hash__(self):
+        # type: () -> int
+        return self.__hash
+
     @cached_property
     def offset(self):
         # type: () -> int
         return min(self.ssa_val_offsets_before_spread.values())
 
+    @property
+    def base_ty(self):
+        # type: () -> BaseTy
+        return self.first_ssa_val.base_ty
+
+    @cached_property
+    def ssa_vals(self):
+        # type: () -> OFSet[SSAVal]
+        return OFSet(self.ssa_val_offsets.keys())
+
     @cached_property
     def ty(self):
         # type: () -> Ty
@@ -166,15 +183,8 @@ class MergedSSAVal:
         # type: () -> FMap[SSAVal, int]
         retval = {}  # type: dict[SSAVal, int]
         for ssa_val, offset in self.ssa_val_offsets.items():
-            offset_before_spread = offset
-            spread_index = ssa_val.defining_descriptor.spread_index
-            if spread_index is not None:
-                assert ssa_val.ty.reg_len == 1, (
-                    "this function assumes spreading always converts a vector "
-                    "to a contiguous sequence of scalars, if that's changed "
-                    "in the future, then this function needs to be adjusted")
-                offset_before_spread -= spread_index
-            retval[ssa_val] = offset_before_spread
+            retval[ssa_val] = (
+                offset - ssa_val.defining_descriptor.reg_offset_in_unspread)
         return FMap(retval)
 
     def offset_by(self, amount):
@@ -186,65 +196,178 @@ class MergedSSAVal:
         # type: () -> MergedSSAVal
         return self.offset_by(-self.offset)
 
-    def with_offset_to_match(self, target):
-        # type: (MergedSSAVal) -> MergedSSAVal
+    def with_offset_to_match(self, target, additional_offset=0):
+        # type: (MergedSSAVal | SSAVal, int) -> MergedSSAVal
+        if isinstance(target, MergedSSAVal):
+            ssa_val_offsets = target.ssa_val_offsets
+        else:
+            ssa_val_offsets = {target: 0}
         for ssa_val, offset in self.ssa_val_offsets.items():
-            if ssa_val in target.ssa_val_offsets:
-                return self.offset_by(target.ssa_val_offsets[ssa_val] - offset)
+            if ssa_val in ssa_val_offsets:
+                return self.offset_by(
+                    ssa_val_offsets[ssa_val] + additional_offset - offset)
         raise ValueError("can't change offset to match unrelated MergedSSAVal")
 
+    def merged(self, *others):
+        # type: (*MergedSSAVal) -> MergedSSAVal
+        retval = dict(self.ssa_val_offsets)
+        for other in others:
+            if other.fn_with_uses != self.fn_with_uses:
+                raise ValueError("fn_with_uses mismatch")
+            for ssa_val, offset in other.ssa_val_offsets.items():
+                if ssa_val in retval and retval[ssa_val] != offset:
+                    raise BadMergedSSAVal(f"offset mismatch for {ssa_val}: "
+                                          f"{retval[ssa_val]} != {offset}")
+                retval[ssa_val] = offset
+        return MergedSSAVal(fn_with_uses=self.fn_with_uses,
+                            ssa_val_offsets=retval)
+
+
+@final
+class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]):
+    def __init__(self):
+        # type: (...) -> None
+        self.__merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
+        self.__values_set = MergedSSAValsSet(
+            _private_merge_map=self.__merge_map,
+            _private_values_set=OSet())
+
+    def __getitem__(self, __key):
+        # type: (SSAVal) -> MergedSSAVal
+        return self.__merge_map[__key]
+
+    def __iter__(self):
+        # type: () -> Iterator[SSAVal]
+        return iter(self.__merge_map)
+
+    def __len__(self):
+        # type: () -> int
+        return len(self.__merge_map)
+
+    @property
+    def values_set(self):
+        # type: () -> MergedSSAValsSet
+        return self.__values_set
+
+    def __repr__(self):
+        # type: () -> str
+        s = ",\n".join(repr(v) for v in self.__values_set)
+        return f"MergedSSAValsMap({{{s}}})"
+
 
 @final
-class MergedSSAVals(OFSet[MergedSSAVal]):
-    def __init__(self, merged_ssa_vals=()):
-        # type: (Iterable[MergedSSAVal]) -> None
-        super().__init__(merged_ssa_vals)
-        merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
-        for merged_ssa_val in self:
-            for ssa_val in merged_ssa_val.ssa_val_offsets.keys():
-                if ssa_val in merge_map:
+class MergedSSAValsSet(MutableSet[MergedSSAVal]):
+    def __init__(self, *,
+                 _private_merge_map,  # type: dict[SSAVal, MergedSSAVal]
+                 _private_values_set,  # type: OSet[MergedSSAVal]
+                 ):
+        # type: (...) -> None
+        self.__merge_map = _private_merge_map
+        self.__values_set = _private_values_set
+
+    @classmethod
+    def _from_iterable(cls, it):
+        # type: (Iterable[MergedSSAVal]) -> OSet[MergedSSAVal]
+        return OSet(it)
+
+    def __contains__(self, value):
+        # type: (MergedSSAVal | Any) -> bool
+        return value in self.__values_set
+
+    def __iter__(self):
+        # type: () -> Iterator[MergedSSAVal]
+        return iter(self.__values_set)
+
+    def __len__(self):
+        # type: () -> int
+        return len(self.__values_set)
+
+    def add(self, value):
+        # type: (MergedSSAVal) -> None
+        if value in self:
+            return
+        added = 0  # type: int | None
+        try:
+            for ssa_val in value.ssa_vals:
+                if ssa_val in self.__merge_map:
                     raise ValueError(
                         f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
-                        f"{merged_ssa_val} and {merge_map[ssa_val]}")
-                merge_map[ssa_val] = merged_ssa_val
-        self.__merge_map = FMap(merge_map)
+                        f"{value} and {self.__merge_map[ssa_val]}")
+                self.__merge_map[ssa_val] = value
+                added += 1
+            self.__values_set.add(value)
+            added = None
+        finally:
+            if added is not None:
+                # remove partially added stuff
+                for idx, ssa_val in enumerate(value.ssa_vals):
+                    if idx >= added:
+                        break
+                    del self.__merge_map[ssa_val]
+
+    def discard(self, value):
+        # type: (MergedSSAVal) -> None
+        if value not in self:
+            return
+        self.__values_set.discard(value)
+        for ssa_val in value.ssa_val_offsets.keys():
+            del self.__merge_map[ssa_val]
 
-    @cached_property
-    def merge_map(self):
-        # type: () -> FMap[SSAVal, MergedSSAVal]
-        return self.__merge_map
+    def __repr__(self):
+        # type: () -> str
+        s = ",\n".join(repr(v) for v in self.__values_set)
+        return f"MergedSSAValsSet({{{s}}})"
 
-# FIXME: work on code from here
+
+@plain_data(frozen=True)
+@final
+class MergedSSAVals:
+    __slots__ = "fn_with_uses", "merge_map", "merged_ssa_vals"
+
+    def __init__(self, fn_with_uses, merged_ssa_vals):
+        # type: (FnWithUses, Iterable[MergedSSAVal]) -> None
+        self.fn_with_uses = fn_with_uses
+        self.merge_map = MergedSSAValsMap()
+        self.merged_ssa_vals = self.merge_map.values_set
+        for i in merged_ssa_vals:
+            self.merged_ssa_vals.add(i)
+
+    def merge(self, ssa_val1, ssa_val2, additional_offset=0):
+        # type: (SSAVal, SSAVal, int) -> MergedSSAVal
+        merged1 = self.merge_map[ssa_val1]
+        merged2 = self.merge_map[ssa_val2]
+        merged = merged1.with_offset_to_match(ssa_val1)
+        merged = merged.merged(merged2.with_offset_to_match(
+            ssa_val2, additional_offset=additional_offset))
+        self.merged_ssa_vals.remove(merged1)
+        self.merged_ssa_vals.remove(merged2)
+        self.merged_ssa_vals.add(merged)
+        return merged
 
     @staticmethod
     def minimally_merged(fn_with_uses):
         # type: (FnWithUses) -> MergedSSAVals
-        merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
+        retval = MergedSSAVals(fn_with_uses=fn_with_uses, merged_ssa_vals=())
         for op in fn_with_uses.fn.ops:
-            for fn
-            for val in (*op.inputs().values(), *op.outputs().values()):
-                if val not in merged_sets:
-                    merged_sets[val] = MergedRegSet(val)
-            for e in op.get_equality_constraints():
-                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
-                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
-                items = []  # type: list[tuple[SSAVal, int]]
-                for i in e.lhs:
-                    s = merged_sets[i].with_offset_to_match(lhs_set)
-                    items.extend(s.items())
-                for i in e.rhs:
-                    s = merged_sets[i].with_offset_to_match(rhs_set)
-                    items.extend(s.items())
-                full_set = MergedRegSet(items)
-                for val in full_set.keys():
-                    merged_sets[val] = full_set
-
-        self.__map = {k: v.normalized() for k, v in merged_sets.items()}
+            for inp in op.input_uses:
+                if inp.unspread_start != inp:
+                    retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
+                                 additional_offset=inp.reg_offset_in_unspread)
+            for out in op.outputs:
+                if out.unspread_start != out:
+                    retval.merge(out.unspread_start, out,
+                                 additional_offset=out.reg_offset_in_unspread)
+                if out.tied_input is not None:
+                    retval.merge(out.tied_input.ssa_val, out)
+        return retval
+
+
+# FIXME: work on code from here
 
 
 @final
-class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
-    def __init__(self, ops):
+class LiveIntervals(Mapping[MergedSSAVal, LiveInterval]):
+    def __init__(self, merged_ssa_vals):
         # type: (list[Op]) -> None
         self.__merged_reg_sets = MergedRegSets(ops)
         live_intervals = {}  # type: dict[MergedRegSet[_RegType], LiveInterval]