work on switching algorithms
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 12 Oct 2022 06:49:59 +0000 (23:49 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 12 Oct 2022 06:49:59 +0000 (23:49 -0700)
src/bigint_presentation_code/toom_cook.py

index dd9061b321108cfa455b3642b54d2e2e58983d07..02ddb4b0af6d524beaecf798886601eb1c74d6d0 100644 (file)
+"""
+Toom-Cook algorithm generator for SVP64
+
+the register allocator uses an algorithm based on:
+[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
+"""
+
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from enum import Enum, unique
-from typing import AbstractSet, Iterable, Mapping, TYPE_CHECKING
+from functools import lru_cache
+from typing import (Sequence, AbstractSet, Iterable, Mapping,
+                    TYPE_CHECKING, Sequence, TypeVar)
 
 from nmutil.plain_data import plain_data
 
 if TYPE_CHECKING:
     from typing_extensions import final, Self
+    from typing import Generic
 else:
     def final(v):
         return v
 
+    # make plain_data work with Generics
+    class Generic:
+        def __class_getitem__(cls, item):
+            return object
+
 
 @plain_data(frozen=True, unsafe_hash=True)
-class PhysLoc:
-    pass
+class PhysLoc(metaclass=ABCMeta):
+    __slots__ = ()
 
 
 @plain_data(frozen=True, unsafe_hash=True)
-class GPROrStackLoc(PhysLoc):
-    pass
+class RegLoc(PhysLoc):
+    __slots__ = ()
+
+    @abstractmethod
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        ...
+
 
+@plain_data(frozen=True, unsafe_hash=True)
+class GPRRangeOrStackLoc(PhysLoc):
+    __slots__ = ()
 
+    @abstractmethod
+    def __len__(self):
+        # type: () -> int
+        ...
+
+
+GPR_COUNT = 128
+
+
+@plain_data(frozen=True, unsafe_hash=True)
 @final
-class GPR(GPROrStackLoc, Enum):
-    def __init__(self, reg_num):
-        # type: (int) -> None
-        self.reg_num = reg_num
-    # fmt: off
-    R0 = 0;     R1 = 1;     R2 = 2;     R3 = 3;     R4 = 4;     R5 = 5
-    R6 = 6;     R7 = 7;     R8 = 8;     R9 = 9;     R10 = 10;   R11 = 11
-    R12 = 12;   R13 = 13;   R14 = 14;   R15 = 15;   R16 = 16;   R17 = 17
-    R18 = 18;   R19 = 19;   R20 = 20;   R21 = 21;   R22 = 22;   R23 = 23
-    R24 = 24;   R25 = 25;   R26 = 26;   R27 = 27;   R28 = 28;   R29 = 29
-    R30 = 30;   R31 = 31;   R32 = 32;   R33 = 33;   R34 = 34;   R35 = 35
-    R36 = 36;   R37 = 37;   R38 = 38;   R39 = 39;   R40 = 40;   R41 = 41
-    R42 = 42;   R43 = 43;   R44 = 44;   R45 = 45;   R46 = 46;   R47 = 47
-    R48 = 48;   R49 = 49;   R50 = 50;   R51 = 51;   R52 = 52;   R53 = 53
-    R54 = 54;   R55 = 55;   R56 = 56;   R57 = 57;   R58 = 58;   R59 = 59
-    R60 = 60;   R61 = 61;   R62 = 62;   R63 = 63;   R64 = 64;   R65 = 65
-    R66 = 66;   R67 = 67;   R68 = 68;   R69 = 69;   R70 = 70;   R71 = 71
-    R72 = 72;   R73 = 73;   R74 = 74;   R75 = 75;   R76 = 76;   R77 = 77
-    R78 = 78;   R79 = 79;   R80 = 80;   R81 = 81;   R82 = 82;   R83 = 83
-    R84 = 84;   R85 = 85;   R86 = 86;   R87 = 87;   R88 = 88;   R89 = 89
-    R90 = 90;   R91 = 91;   R92 = 92;   R93 = 93;   R94 = 94;   R95 = 95
-    R96 = 96;   R97 = 97;   R98 = 98;   R99 = 99;   R100 = 100; R101 = 101
-    R102 = 102; R103 = 103; R104 = 104; R105 = 105; R106 = 106; R107 = 107
-    R108 = 108; R109 = 109; R110 = 110; R111 = 111; R112 = 112; R113 = 113
-    R114 = 114; R115 = 115; R116 = 116; R117 = 117; R118 = 118; R119 = 119
-    R120 = 120; R121 = 121; R122 = 122; R123 = 123; R124 = 124; R125 = 125
-    R126 = 126; R127 = 127
-    # fmt: on
-    SP = 1
-    TOC = 2
-
-
-SPECIAL_GPRS = GPR.R0, GPR.SP, GPR.TOC, GPR.R13
+class GPRRange(RegLoc, GPRRangeOrStackLoc, Sequence["GPRRange"]):
+    __slots__ = "start", "length"
+
+    def __init__(self, start, length=None):
+        # type: (int | range, int | None) -> None
+        if isinstance(start, range):
+            if length is not None:
+                raise TypeError("can't specify length when input is a range")
+            if start.step != 1:
+                raise ValueError("range must have a step of 1")
+            length = len(start)
+            start = start.start
+        elif length is None:
+            length = 1
+        if length <= 0 or start < 0 or start + length > GPR_COUNT:
+            raise ValueError("invalid GPRRange")
+        self.start = start
+        self.length = length
+
+    @property
+    def stop(self):
+        return self.start + self.length
+
+    @property
+    def step(self):
+        return 1
+
+    @property
+    def range(self):
+        return range(self.start, self.stop, self.step)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, item):
+        # type: (int | slice) -> GPRRange
+        return GPRRange(self.range[item])
+
+    def __contains__(self, value):
+        # type: (GPRRange) -> bool
+        return value.start >= self.start and value.stop <= self.stop
+
+    def index(self, sub, start=None, end=None):
+        # type: (GPRRange, int | None, int | None) -> int
+        r = self.range[start:end]
+        if sub.start < r.start or sub.stop > r.stop:
+            raise ValueError("GPR range not found")
+        return sub.start - self.start
+
+    def count(self, sub, start=None, end=None):
+        # type: (GPRRange, int | None, int | None) -> int
+        r = self.range[start:end]
+        if len(r) == 0:
+            return 0
+        return int(sub in GPRRange(r))
+
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, GPRRange):
+            return self.stop > other.start and other.stop > self.start
+        return False
+
+
+SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
 
 
 @final
 @unique
-class XERBit(Enum, PhysLoc):
+class XERBit(Enum, RegLoc):
     CY = "CY"
 
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, XERBit):
+            return self == other
+        return False
+
 
 @final
 @unique
-class GlobalMem(Enum, PhysLoc):
-    """singleton representing all non-StackSlot memory"""
+class GlobalMem(Enum, RegLoc):
+    """singleton representing all non-StackSlot memory -- treated as a single
+    physical register for register allocation purposes.
+    """
     GlobalMem = "GlobalMem"
 
-
-ALLOCATABLE_REGS = frozenset((set(GPR) - set(SPECIAL_GPRS))
-                             | set(XERBit) | set(GlobalMem))
+    def conflicts(self, other):
+        # type: (RegLoc) -> bool
+        if isinstance(other, GlobalMem):
+            return self == other
+        return False
 
 
-@plain_data()
 @final
-class StackSlot(GPROrStackLoc):
-    """a stack slot. Use OpCopy to load from/store into this stack slot."""
-    __slots__ = "offset",
+class RegClass(AbstractSet[RegLoc]):
+    def __init__(self, regs):
+        # type: (Iterable[RegLoc]) -> None
+        self.__regs = frozenset(regs)
 
-    def __init__(self, offset=None):
-        # type: (int | None) -> None
-        self.offset = offset
+    def __len__(self):
+        return len(self.__regs)
 
+    def __iter__(self):
+        return iter(self.__regs)
 
-class SSAVal(metaclass=ABCMeta):
-    __slots__ = "op", "arg_name", "element_index"
+    def __contains__(self, v):
+        # type: (RegLoc) -> bool
+        return v in self.__regs
 
-    def __init__(self, op, arg_name, element_index):
-        # type: (Op, str, int) -> None
-        self.op = op
-        """the Op that writes this SSAVal"""
+    def __hash__(self):
+        return super()._hash()
 
-        self.arg_name = arg_name
-        self.element_index = element_index
 
-    @final
-    def __eq__(self, rhs):
-        if isinstance(rhs, SSAVal):
-            return (self.op is rhs.op
-                    and self.arg_name == rhs.arg_name
-                    and self.element_index == rhs.element_index)
-        return False
+@plain_data(frozen=True, unsafe_hash=True)
+class RegType(metaclass=ABCMeta):
+    __slots__ = ()
 
-    @final
-    def __hash__(self):
-        return hash((id(self.op), self.arg_name, self.element_index))
+    @property
+    @abstractmethod
+    def reg_class(self):
+        # type: () -> RegClass
+        return ...
 
-    def _get_phys_loc(self, phys_loc_in, value_assignments=None):
-        # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
-        if phys_loc_in is not None:
-            return phys_loc_in
-        if value_assignments is not None:
-            return value_assignments.get(self)
-        return None
 
-    @abstractmethod
-    def get_phys_loc(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
-        ...
+@plain_data(frozen=True, eq=False)
+class GPRRangeType(RegType):
+    __slots__ = "length",
+
+    def __init__(self, length):
+        # type: (int) -> None
+        if length < 1 or length > GPR_COUNT:
+            raise ValueError("invalid length")
+        self.length = length
+
+    @staticmethod
+    @lru_cache()
+    def __get_reg_class(length):
+        # type: (int) -> RegClass
+        regs = []
+        for start in range(GPR_COUNT - length):
+            reg = GPRRange(start, length)
+            if any(i in reg for i in SPECIAL_GPRS):
+                continue
+            regs.append(reg)
+        return RegClass(regs)
+
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return GPRRangeType.__get_reg_class(self.length)
 
     @final
-    def __repr__(self):
-        name = self.__class__.__name__
-        op = object.__repr__(self.op)
-        phys_loc = self.get_phys_loc()
-        return (f"{name}(op={op}, arg_name={self.arg_name}, "
-                f"element_index={self.element_index}, phys_loc={phys_loc})")
+    def __eq__(self, other):
+        if isinstance(other, GPRRangeType):
+            return self.length == other.length
+        return False
 
     @final
-    def like(self, op, arg_name):
-        # type: (Op, str) -> Self
-        """create a new SSAVal based off of self's type.
-        has same signature as VecArg.like.
-        """
-        return self.__class__(op=op, arg_name=arg_name,
-                              element_index=0)
+    def __hash__(self):
+        return hash(self.length)
 
 
+@plain_data(frozen=True, eq=False)
 @final
-class SSAGPRVal(SSAVal):
-    __slots__ = "phys_loc",
+class GPRType(GPRRangeType):
+    __slots__ = ()
 
-    def __init__(self, op, arg_name, element_index, phys_loc=None):
-        # type: (Op, str, int, GPROrStackLoc | None) -> None
-        super().__init__(op, arg_name, element_index)
-        self.phys_loc = phys_loc
+    def __init__(self, length=1):
+        if length != 1:
+            raise ValueError("length must be 1")
+        super().__init__(length=1)
 
-    def __len__(self):
-        return 1
 
-    def get_phys_loc(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None
-        loc = self._get_phys_loc(self.phys_loc, value_assignments)
-        if isinstance(loc, GPROrStackLoc):
-            return loc
-        return None
-
-    def get_reg_num(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> int | None
-        reg = self.get_reg(value_assignments)
-        if reg is not None:
-            return reg.reg_num
-        return None
-
-    def get_reg(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None
-        loc = self.get_phys_loc(value_assignments)
-        if isinstance(loc, GPR):
-            return loc
-        return None
-
-    def get_stack_slot(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None
-        loc = self.get_phys_loc(value_assignments)
-        if isinstance(loc, StackSlot):
-            return loc
-        return None
-
-    def possible_reg_assignments(self, value_assignments,
-                                 conflicting_regs=set()):
-        # type: (dict[SSAVal, PhysLoc] | None, set[GPR]) -> Iterable[GPR]
-        if self.get_phys_loc(value_assignments) is not None:
-            raise ValueError("can't assign a already-assigned SSA value")
-        for reg in GPR:
-            if reg not in conflicting_regs:
-                yield reg
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class CYType(RegType):
+    __slots__ = ()
 
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return RegClass([XERBit.CY])
 
-@final
-class SSAXERBitVal(SSAVal):
-    __slots__ = "phys_loc",
 
-    def __init__(self, op, arg_name, element_index, phys_loc=None):
-        # type: (Op, str, int, XERBit | None) -> None
-        super().__init__(op, arg_name, element_index)
-        self.phys_loc = phys_loc
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class GlobalMemType(RegType):
+    __slots__ = ()
 
-    def get_phys_loc(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None
-        loc = self._get_phys_loc(self.phys_loc, value_assignments)
-        if isinstance(loc, XERBit):
-            return loc
-        return None
+    @property
+    def reg_class(self):
+        # type: () -> RegClass
+        return RegClass([GlobalMem.GlobalMem])
 
 
+@plain_data()
 @final
-class SSAMemory(SSAVal):
-    __slots__ = "phys_loc",
+class StackSlot(GPRRangeOrStackLoc):
+    """a stack slot. Use OpCopy to load from/store into this stack slot."""
+    __slots__ = "offset", "length"
 
-    def __init__(self, op, arg_name, element_index,
-                 phys_loc=GlobalMem.GlobalMem):
-        # type: (Op, str, int, GlobalMem) -> None
-        super().__init__(op, arg_name, element_index)
-        self.phys_loc = phys_loc
+    def __init__(self, offset=None, length=1):
+        # type: (int | None, int) -> None
+        self.offset = offset
+        if length < 1:
+            raise ValueError("invalid length")
+        self.length = length
 
-    def get_phys_loc(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem
-        loc = self._get_phys_loc(self.phys_loc, value_assignments)
-        if isinstance(loc, GlobalMem):
-            return loc
-        return self.phys_loc
+    def __len__(self):
+        return self.length
 
 
-@plain_data(unsafe_hash=True, frozen=True)
+_RegType = TypeVar("_RegType", bound=RegType)
+
+
+@plain_data(frozen=True, eq=False)
 @final
-class VecArg:
-    __slots__ = "regs",
+class SSAVal(Generic[_RegType]):
+    __slots__ = "op", "arg_name", "ty", "arg_index"
 
-    def __init__(self, regs):
-        # type: (Iterable[SSAGPRVal]) -> None
-        self.regs = tuple(regs)
+    def __init__(self, op, arg_name, ty):
+        # type: (Op, str, _RegType) -> None
+        self.op = op
+        """the Op that writes this SSAVal"""
 
-    def __len__(self):
-        return len(self.regs)
-
-    def is_unassigned(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> bool
-        for val in self.regs:
-            if val.get_phys_loc(value_assignments) is not None:
-                return False
-        return True
-
-    def try_get_range(self, value_assignments=None, allow_unassigned=False,
-                      raise_if_invalid=False):
-        # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None
-        if len(self.regs) == 0:
-            return range(0)
-
-        retval = None  # type: range | None
-        for i, val in enumerate(self.regs):
-            if val.get_phys_loc(value_assignments) is None:
-                if not allow_unassigned:
-                    if raise_if_invalid:
-                        raise ValueError("not a valid register range: "
-                                         "unassigned SSA value encountered")
-                    return None
-                continue
-            reg = val.get_reg_num(value_assignments)
-            if reg is None:
-                if raise_if_invalid:
-                    raise ValueError("not a valid register range: "
-                                     "non-register encountered")
-                return None
-            expected_range = range(reg - i, reg - i + len(self.regs))
-            if retval is None:
-                retval = expected_range
-            elif retval != expected_range:
-                if raise_if_invalid:
-                    raise ValueError("not a valid register range: "
-                                     "register out of sequence")
-                return None
-        return retval
-
-    def possible_reg_assignments(
-        self,
-        val,  # type: SSAVal
-        value_assignments,  # type: dict[SSAVal, PhysLoc] | None
-        conflicting_regs=set(),  # type: set[GPR]
-    ):  # type: (...) -> Iterable[GPR]
-        index = self.regs.index(val)
-        alignment = 1
-        while alignment < len(self.regs):
-            alignment *= 2
-        r = self.try_get_range(value_assignments)
-        if r is not None and r.start % alignment != 0:
-            raise ValueError("must be a ascending aligned range of GPRs")
-        if r is None:
-            for i in range(0, len(GPR), alignment):
-                r = range(i, i + len(self.regs))
-                if any(GPR(reg) in conflicting_regs for reg in r):
-                    continue
-                yield GPR(r[index])
-        else:
-            yield GPR(r[index])
+        self.arg_name = arg_name
+        """the name of the argument of self.op that writes this SSAVal"""
 
-    def like(self, op, arg_name):
-        # type: (Op, str) -> VecArg
-        """create a new VecArg based off of self's type.
-        has same signature as SSAVal.like.
-        """
-        return VecArg(
-            SSAGPRVal(op, arg_name, i) for i in range(len(self.regs)))
+        self.ty = ty
 
+    def __eq__(self, rhs):
+        if isinstance(rhs, SSAVal):
+            return (self.op is rhs.op
+                    and self.arg_name == rhs.arg_name)
+        return False
 
-def vec_or_scalar_arg(element_count, op, arg_name):
-    # type: (int | None, Op, str) -> VecArg | SSAGPRVal
-    if element_count is None:
-        return SSAGPRVal(op, arg_name, 0)
-    else:
-        return VecArg(SSAGPRVal(op, arg_name, i) for i in range(element_count))
+    def __hash__(self):
+        return hash((id(self.op), self.arg_name))
 
 
 @final
@@ -320,44 +303,25 @@ class EqualityConstraint:
     __slots__ = "lhs", "rhs"
 
     def __init__(self, lhs, rhs):
-        # type: (SSAVal, SSAVal) -> None
+        # type: (list[SSAVal], list[SSAVal]) -> None
         self.lhs = lhs
         self.rhs = rhs
+        if len(lhs) == 0 or len(rhs) == 0:
+            raise ValueError("can't constrain an empty list to be equal")
 
 
 @plain_data(unsafe_hash=True, frozen=True)
 class Op(metaclass=ABCMeta):
     __slots__ = ()
 
-    def input_ssa_vals(self):
-        # type: () -> Iterable[SSAVal]
-        for arg in self.inputs().values():
-            if isinstance(arg, VecArg):
-                yield from arg.regs
-            else:
-                yield arg
-
-    def output_ssa_vals(self):
-        # type: () -> Iterable[SSAVal]
-        for arg in self.outputs().values():
-            if isinstance(arg, VecArg):
-                yield from arg.regs
-            else:
-                yield arg
-
     @abstractmethod
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         ...
 
     @abstractmethod
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
-        ...
-
-    @abstractmethod
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        # type: () -> dict[str, SSAVal]
         ...
 
     def get_equality_constraints(self):
@@ -371,42 +335,79 @@ class Op(metaclass=ABCMeta):
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
-class OpCopy(Op):
+class OpCopy(Op, Generic[_RegType]):
     __slots__ = "dest", "src"
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"src": self.src}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"dest": self.dest}
 
     def __init__(self, src):
-        # type: (SSAGPRVal) -> None
-        self.dest = src.like(op=self, arg_name="dest")
+        # type: (SSAVal[_RegType]) -> None
+        self.dest = SSAVal(self, "dest", src.ty)
         self.src = src
 
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-        if not isinstance(val, SSAGPRVal):
-            raise ValueError("invalid operand type")
-        return val.possible_reg_assignments(value_assignments)
-
-
-def range_overlaps(range1, range2):
-    # type: (range, range) -> bool
-    if len(range1) == 0 or len(range2) == 0:
-        return False
-    range1_last = range1[-1]
-    range2_last = range2[-1]
-    return (range1.start in range2 or range1_last in range2 or
-            range2.start in range1 or range2_last in range1)
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpConcat(Op):
+    __slots__ = "dest", "sources"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {f"sources[{i}]": v for i, v in enumerate(self.sources)}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"dest": self.dest}
+
+    def __init__(self, sources):
+        # type: (Iterable[SSAVal[GPRRangeType]]) -> None
+        sources = tuple(sources)
+        self.dest = SSAVal(self, "dest", GPRRangeType(
+            sum(i.ty.length for i in sources)))
+        self.sources = sources
+
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint([self.dest], [*self.sources])
+
+
+@plain_data(unsafe_hash=True, frozen=True)
+@final
+class OpSplit(Op):
+    __slots__ = "results", "src"
+
+    def inputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {"src": self.src}
+
+    def outputs(self):
+        # type: () -> dict[str, SSAVal]
+        return {i.arg_name: i for i in self.results}
+
+    def __init__(self, src, split_indexes):
+        # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
+        ranges = []  # type: list[GPRRangeType]
+        last = 0
+        for i in split_indexes:
+            if not (0 < i < src.ty.length):
+                raise ValueError(f"invalid split index: {i}, must be in "
+                                 f"0 < i < {src.ty.length}")
+            ranges.append(GPRRangeType(i - last))
+            last = i
+        ranges.append(GPRRangeType(src.ty.length - last))
+        self.src = src
+        self.results = tuple(
+            SSAVal(self, f"results{i}", r) for i, r in enumerate(ranges))
+
+    def get_equality_constraints(self):
+        # type: () -> Iterable[EqualityConstraint]
+        yield EqualityConstraint([*self.results], [self.src])
 
 
 @plain_data(unsafe_hash=True, frozen=True)
@@ -415,59 +416,25 @@ class OpAddSubE(Op):
     __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RT": self.RT, "CY_out": self.CY_out}
 
     def __init__(self, RA, RB, CY_in, is_sub):
-        # type: (VecArg, VecArg, SSAXERBitVal, bool) -> None
-        if len(RA.regs) != len(RB.regs):
-            raise TypeError(f"source lengths must match: "
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+        if RA.ty != RB.ty:
+            raise TypeError(f"source types must match: "
                             f"{RA} doesn't match {RB}")
-        self.RT = RA.like(op=self, arg_name="RT")
+        self.RT = SSAVal(self, "RT", RA.ty)
         self.RA = RA
         self.RB = RB
         self.CY_in = CY_in
-        self.CY_out = CY_in.like(op=self, arg_name="CY_out")
+        self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
         self.is_sub = is_sub
 
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-        if self.CY_in == val or self.CY_out == val:
-            yield XERBit.CY
-        elif val in self.RT.regs:
-            # since possible_reg_assignments only returns aligned
-            # vectors, all possible assignments either are the same as an
-            # input or don't overlap with an input and we avoid the incorrect
-            # results caused by partial overlaps overwriting input elements
-            # before they're read
-            yield from self.RT.possible_reg_assignments(val, value_assignments)
-        elif val in self.RA.regs:
-            yield from self.RA.possible_reg_assignments(val, value_assignments)
-        else:
-            yield from self.RB.possible_reg_assignments(val, value_assignments)
-
-    def get_equality_constraints(self):
-        # type: () -> Iterable[EqualityConstraint]
-        yield EqualityConstraint(self.CY_in, self.CY_out)
-
-
-def to_reg_set(v):
-    # type: (None | GPR | range) -> set[GPR]
-    if v is None:
-        return set()
-    if isinstance(v, range):
-        return set(map(GPR, v))
-    return {v}
-
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -475,64 +442,25 @@ class OpBigIntMulDiv(Op):
     __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RT": self.RT, "RS": self.RS}
 
     def __init__(self, RA, RB, RC, is_div):
-        # type: (VecArg, SSAGPRVal, SSAGPRVal, bool) -> None
-        self.RT = RA.like(op=self, arg_name="RT")
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+        self.RT = SSAVal(self, "RT", RA.ty)
         self.RA = RA
         self.RB = RB
         self.RC = RC
-        self.RS = RC.like(op=self, arg_name="RS")
+        self.RS = SSAVal(self, "RS", RC.ty)
         self.is_div = is_div
 
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        RT_range = self.RT.try_get_range(value_assignments,
-                                         allow_unassigned=True,
-                                         raise_if_invalid=True)
-        RA_range = self.RA.try_get_range(value_assignments,
-                                         allow_unassigned=True,
-                                         raise_if_invalid=True)
-        RC_RS_reg = self.RC.get_reg(value_assignments)
-        if RC_RS_reg is None:
-            RC_RS_reg = self.RS.get_reg(value_assignments)
-
-        if self.RC == val or self.RS == val:
-            if RC_RS_reg is not None:
-                yield RC_RS_reg
-            else:
-                conflicting_regs = to_reg_set(RT_range) | to_reg_set(RA_range)
-                yield from self.RC.possible_reg_assignments(value_assignments,
-                                                            conflicting_regs)
-        elif val in self.RT.regs:
-            # since possible_reg_assignments only returns aligned
-            # vectors, all possible assignments either are the same as
-            # RA or don't overlap with RA and we avoid the incorrect
-            # results caused by partial overlaps overwriting input elements
-            # before they're read
-            yield from self.RT.possible_reg_assignments(
-                val, value_assignments,
-                conflicting_regs=to_reg_set(RA_range) | to_reg_set(RC_RS_reg))
-        else:
-            yield from self.RA.possible_reg_assignments(
-                val, value_assignments,
-                conflicting_regs=to_reg_set(RT_range) | to_reg_set(RC_RS_reg))
-
     def get_equality_constraints(self):
         # type: () -> Iterable[EqualityConstraint]
-        yield EqualityConstraint(self.RC, self.RS)
+        yield EqualityConstraint([self.RC], [self.RS])
 
 
 @final
@@ -549,54 +477,20 @@ class OpBigIntShift(Op):
     __slots__ = "RT", "inp", "sh", "kind"
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"inp": self.inp, "sh": self.sh}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RT": self.RT}
 
     def __init__(self, inp, sh, kind):
-        # type: (VecArg, SSAGPRVal, ShiftKind) -> None
-        self.RT = inp.like(op=self, arg_name="RT")
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+        self.RT = SSAVal(self, "RT", inp.ty)
         self.inp = inp
         self.sh = sh
         self.kind = kind
 
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        RT_range = self.RT.try_get_range(value_assignments,
-                                         allow_unassigned=True,
-                                         raise_if_invalid=True)
-        inp_range = self.inp.try_get_range(value_assignments,
-                                           allow_unassigned=True,
-                                           raise_if_invalid=True)
-        sh_reg = self.sh.get_reg(value_assignments)
-
-        if self.sh == val:
-            conflicting_regs = to_reg_set(RT_range)
-            yield from self.sh.possible_reg_assignments(value_assignments,
-                                                        conflicting_regs)
-        elif val in self.RT.regs:
-            # since possible_reg_assignments only returns aligned
-            # vectors, all possible assignments either are the same as
-            # RA or don't overlap with RA and we avoid the incorrect
-            # results caused by partial overlaps overwriting input elements
-            # before they're read
-            yield from self.RT.possible_reg_assignments(
-                val, value_assignments,
-                conflicting_regs=to_reg_set(inp_range) | to_reg_set(sh_reg))
-        else:
-            yield from self.inp.possible_reg_assignments(
-                val, value_assignments,
-                conflicting_regs=to_reg_set(RT_range))
-
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -604,32 +498,18 @@ class OpLI(Op):
     __slots__ = "out", "value"
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self, value, element_count=None):
-        # type: (int, int | None) -> None
-        self.out = vec_or_scalar_arg(element_count, op=self, arg_name="out")
+    def __init__(self, value, length=1):
+        # type: (int, int) -> None
+        self.out = SSAVal(self, "out", GPRRangeType(length))
         self.value = value
 
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        if isinstance(self.out, VecArg):
-            yield from self.out.possible_reg_assignments(val,
-                                                         value_assignments)
-        else:
-            yield from self.out.possible_reg_assignments(value_assignments)
-
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -637,27 +517,16 @@ class OpClearCY(Op):
     __slots__ = "out",
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
     def __init__(self):
         # type: () -> None
-        self.out = SSAXERBitVal(op=self, arg_name="out", element_index=0,
-                                phys_loc=XERBit.CY)
-
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        yield XERBit.CY
+        self.out = SSAVal(self, "out", CYType())
 
 
 @plain_data(unsafe_hash=True, frozen=True)
@@ -666,48 +535,20 @@ class OpLoad(Op):
     __slots__ = "RT", "RA", "offset", "mem"
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RA": self.RA, "mem": self.mem}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RT": self.RT}
 
-    def __init__(self, RA, offset, mem, element_count=None):
-        # type: (SSAGPRVal, int, SSAMemory, int | None) -> None
-        self.RT = vec_or_scalar_arg(element_count, op=self, arg_name="RT")
+    def __init__(self, RA, offset, mem, length=1):
+        # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+        self.RT = SSAVal(self, "RT", GPRRangeType(length))
         self.RA = RA
         self.offset = offset
         self.mem = mem
 
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        RA_reg = self.RA.get_reg(value_assignments)
-
-        if self.mem == val:
-            yield GlobalMem.GlobalMem
-        elif self.RA == val:
-            if isinstance(self.RT, VecArg):
-                conflicting_regs = to_reg_set(self.RT.try_get_range(
-                    value_assignments, allow_unassigned=True,
-                    raise_if_invalid=True))
-            else:
-                conflicting_regs = set()
-            yield from self.RA.possible_reg_assignments(value_assignments,
-                                                        conflicting_regs)
-        elif isinstance(self.RT, VecArg):
-            yield from self.RT.possible_reg_assignments(
-                val, value_assignments,
-                conflicting_regs=to_reg_set(RA_reg))
-        else:
-            yield from self.RT.possible_reg_assignments(value_assignments)
-
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -715,41 +556,20 @@ class OpStore(Op):
     __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"mem_out": self.mem_out}
 
     def __init__(self, RS, RA, offset, mem_in):
-        # type: (VecArg | SSAGPRVal, SSAGPRVal, int, SSAMemory) -> None
+        # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
         self.RS = RS
         self.RA = RA
         self.offset = offset
         self.mem_in = mem_in
-        self.mem_out = mem_in.like(op=self, arg_name="mem_out")
-
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        if self.mem_in == val or self.mem_out == val:
-            yield GlobalMem.GlobalMem
-        elif self.RA == val:
-            yield from self.RA.possible_reg_assignments(value_assignments)
-        elif isinstance(self.RS, VecArg):
-            yield from self.RS.possible_reg_assignments(val, value_assignments)
-        else:
-            yield from self.RS.possible_reg_assignments(value_assignments)
-
-    def get_equality_constraints(self):
-        # type: () -> Iterable[EqualityConstraint]
-        yield EqualityConstraint(self.mem_in, self.mem_out)
+        self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
 
 
 @plain_data(unsafe_hash=True, frozen=True)
@@ -758,34 +578,16 @@ class OpFuncArg(Op):
     __slots__ = "out",
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
-    def __init__(self, phys_loc):
-        # type: (GPROrStackLoc | Iterable[GPROrStackLoc]) -> None
-        if isinstance(phys_loc, GPROrStackLoc):
-            self.out = SSAGPRVal(self, "out", 0, phys_loc)
-        else:
-            self.out = VecArg(
-                SSAGPRVal(self, "out", i, v) for i, v in enumerate(phys_loc))
-
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        if isinstance(self.out, VecArg):
-            yield from self.out.possible_reg_assignments(val,
-                                                         value_assignments)
-        else:
-            yield from self.out.possible_reg_assignments(value_assignments)
+    def __init__(self, ty):
+        # type: (RegType) -> None
+        self.out = SSAVal(self, "out", ty)
 
 
 @plain_data(unsafe_hash=True, frozen=True)
@@ -794,26 +596,16 @@ class OpInputMem(Op):
     __slots__ = "out",
 
     def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {}
 
     def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
+        # type: () -> dict[str, SSAVal]
         return {"out": self.out}
 
     def __init__(self):
         # type: () -> None
-        self.out = SSAMemory(op=self, arg_name="out", element_index=0)
-
-    def possible_reg_assignments(self, val, value_assignments):
-        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
-        if val not in self.input_ssa_vals() \
-                and val not in self.output_ssa_vals():
-            raise ValueError(f"{val} must be an operand of {self}")
-        if val.get_phys_loc(value_assignments) is not None:
-            raise ValueError(f"{val} already assigned a physical location")
-
-        yield GlobalMem.GlobalMem
+        self.out = SSAVal(self, "out", GlobalMemType())
 
 
 def op_set_to_list(ops):
@@ -823,7 +615,7 @@ def op_set_to_list(ops):
     ops_to_pending_input_count_map = {}  # type: dict[Op, int]
     for op in ops:
         input_count = 0
-        for val in op.input_ssa_vals():
+        for val in op.inputs().values():
             input_count += 1
             input_vals_to_ops_map[val].add(op)
         while len(worklists) <= input_count:
@@ -835,7 +627,7 @@ def op_set_to_list(ops):
     while len(worklists[0]) != 0:
         writing_op = worklists[0].pop()
         retval.append(writing_op)
-        for val in writing_op.output_ssa_vals():
+        for val in writing_op.outputs().values():
             if val in ready_vals:
                 raise ValueError(f"multiple instructions must not write "
                                  f"to the same SSA value: {val}")
@@ -882,14 +674,101 @@ class LiveInterval:
 
 
 @final
-class EqualitySet(AbstractSet[SSAVal]):
-    def __init__(self, items):
-        # type: (Iterable[SSAVal]) -> None
-        self.__items = frozenset(items)
+class MergedRegSet(Mapping[SSAVal[_RegType], int]):
+    def __init__(self, reg_set):
+        # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
+        self.__items = {}  # type: dict[SSAVal[_RegType], int]
+        if isinstance(reg_set, SSAVal):
+            reg_set = [(reg_set, 0)]
+        for ssa_val, offset in reg_set:
+            if ssa_val in self.__items:
+                other = self.__items[ssa_val]
+                if offset != other:
+                    raise ValueError(
+                        f"can't merge register sets: conflicting offsets: "
+                        f"for {ssa_val}: {offset} != {other}")
+            else:
+                self.__items[ssa_val] = offset
+        first_item = None
+        for i in self.__items.items():
+            first_item = i
+            break
+        if first_item is None:
+            raise ValueError("can't have empty MergedRegs")
+        first_ssa_val, start = first_item
+        ty = first_ssa_val.ty
+        if isinstance(ty, GPRRangeType):
+            stop = start + ty.length
+            for ssa_val, offset in self.__items.items():
+                if not isinstance(ssa_val.ty, GPRRangeType):
+                    raise ValueError(f"can't merge incompatible types: "
+                                     f"{ssa_val.ty} and {ty}")
+                stop = max(stop, offset + ssa_val.ty.length)
+                start = min(start, offset)
+            ty = GPRRangeType(stop - start)
+        else:
+            stop = 1
+            for ssa_val, offset in self.__items.items():
+                if offset != 0:
+                    raise ValueError(f"can't have non-zero offset "
+                                     f"for {ssa_val.ty}")
+                if ty != ssa_val.ty:
+                    raise ValueError(f"can't merge incompatible types: "
+                                     f"{ssa_val.ty} and {ty}")
+        self.__start = start  # type: int
+        self.__stop = stop  # type: int
+        self.__ty = ty  # type: RegType
+
+    @staticmethod
+    def from_equality_constraint(constraint_sequence):
+        # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
+        if len(constraint_sequence) == 1:
+            # any type allowed with len = 1
+            return MergedRegSet(constraint_sequence[0])
+        offset = 0
+        retval = []
+        for val in constraint_sequence:
+            if not isinstance(val.ty, GPRRangeType):
+                raise ValueError("equality constraint sequences must only "
+                                 "have SSAVal type GPRRangeType")
+            retval.append((val, offset))
+            offset += val.ty.length
+        return MergedRegSet(retval)
 
-    def __contains__(self, x):
-        # type: (object) -> bool
-        return x in self.__items
+    @property
+    def ty(self):
+        return self.__ty
+
+    @property
+    def stop(self):
+        return self.__stop
+
+    @property
+    def start(self):
+        return self.__start
+
+    @property
+    def range(self):
+        return range(self.__start, self.__stop)
+
+    def offset_by(self, amount):
+        # type: (int) -> MergedRegSet[_RegType]
+        return MergedRegSet((k, v + amount) for k, v in self.items())
+
+    def normalized(self):
+        # type: () -> MergedRegSet[_RegType]
+        return self.offset_by(-self.start)
+
+    def with_offset_to_match(self, target):
+        # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
+        for ssa_val, offset in self.items():
+            if ssa_val in target:
+                return self.offset_by(target[ssa_val] - offset)
+        raise ValueError("can't change offset to match unrelated MergedRegSet")
+
+    def __getitem__(self, item):
+        # type: (SSAVal[_RegType]) -> int
+        return self.__items[item]
 
     def __iter__(self):
         return iter(self.__items)
@@ -898,60 +777,66 @@ class EqualitySet(AbstractSet[SSAVal]):
         return len(self.__items)
 
     def __hash__(self):
-        return super()._hash()
+        return hash(frozenset(self.items()))
+
+    def __repr__(self):
+        return f"MergedRegSet({list(self.__items.items())})"
 
 
 @final
-class EqualitySets(Mapping[SSAVal, EqualitySet]):
+class MergedRegSets(Mapping[SSAVal, MergedRegSet]):
     def __init__(self, ops):
         # type: (Iterable[Op]) -> None
-        indexes = {}  # type: dict[SSAVal, int]
-        sets = []  # type: list[set[SSAVal]]
+        merged_sets = {}  # type: dict[SSAVal, MergedRegSet]
         for op in ops:
-            for val in (*op.input_ssa_vals(), *op.output_ssa_vals()):
-                if val not in indexes:
-                    indexes[val] = len(sets)
-                    sets.append({val})
+            for val in (*op.inputs().values(), *op.outputs().values()):
+                if val not in merged_sets:
+                    merged_sets[val] = MergedRegSet(val)
             for e in op.get_equality_constraints():
-                lhs_index = indexes[e.lhs]
-                rhs_index = indexes[e.rhs]
-                sets[lhs_index] |= sets[rhs_index]
-                for val in sets[rhs_index]:
-                    indexes[val] = lhs_index
+                lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
+                rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
+                lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
+                rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
+                full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
+                for val in full_set.keys():
+                    merged_sets[val] = full_set
 
-        equality_sets = [EqualitySet(i) for i in sets]
-        self.__map = {k: equality_sets[v] for k, v in indexes.items()}
+        self.__map = {k: v.normalized() for k, v in merged_sets.items()}
 
     def __getitem__(self, key):
-        # type: (SSAVal) -> EqualitySet
+        # type: (SSAVal) -> MergedRegSet
         return self.__map[key]
 
     def __iter__(self):
         return iter(self.__map)
 
+    def __len__(self):
+        return len(self.__map)
+
 
 @final
-class LiveIntervals(Mapping[EqualitySet, LiveInterval]):
+class LiveIntervals(Mapping[MergedRegSet, LiveInterval]):
     def __init__(self, ops):
         # type: (list[Op]) -> None
-        self.__equality_sets = eqsets = EqualitySets(ops)
-        live_intervals = {}  # type: dict[EqualitySet, LiveInterval]
+        self.__merges_reg_sets = MergedRegSets(ops)
+        live_intervals = {}  # type: dict[MergedRegSet, LiveInterval]
         for op_idx, op in enumerate(ops):
-            for val in op.input_ssa_vals():
-                live_intervals[eqsets[val]] += op_idx
-            for val in op.output_ssa_vals():
-                if eqsets[val] not in live_intervals:
-                    live_intervals[eqsets[val]] = LiveInterval(op_idx)
+            for val in op.inputs().values():
+                live_intervals[self.__merges_reg_sets[val]] += op_idx
+            for val in op.outputs().values():
+                reg_set = self.__merges_reg_sets[val]
+                if reg_set not in live_intervals:
+                    live_intervals[reg_set] = LiveInterval(op_idx)
                 else:
-                    live_intervals[eqsets[val]] += op_idx
+                    live_intervals[reg_set] += op_idx
         self.__live_intervals = live_intervals
 
     @property
-    def equality_sets(self):
-        return self.__equality_sets
+    def merges_reg_sets(self):
+        return self.__merges_reg_sets
 
     def __getitem__(self, key):
-        # type: (EqualitySet) -> LiveInterval
+        # type: (MergedRegSet) -> LiveInterval
         return self.__live_intervals[key]
 
     def __iter__(self):
@@ -961,11 +846,11 @@ class LiveIntervals(Mapping[EqualitySet, LiveInterval]):
 @final
 class IGNode:
     """ interference graph node """
-    __slots__ = "equality_set", "edges"
+    __slots__ = "merged_reg_set", "edges"
 
-    def __init__(self, equality_set, edges=()):
-        # type: (EqualitySet, Iterable[IGNode]) -> None
-        self.equality_set = equality_set
+    def __init__(self, merged_reg_set, edges=()):
+        # type: (MergedRegSet, Iterable[IGNode]) -> None
+        self.merged_reg_set = merged_reg_set
         self.edges = set(edges)
 
     def add_edge(self, other):
@@ -976,11 +861,11 @@ class IGNode:
     def __eq__(self, other):
         # type: (object) -> bool
         if isinstance(other, IGNode):
-            return self.equality_set == other.equality_set
+            return self.merged_reg_set == other.merged_reg_set
         return NotImplemented
 
     def __hash__(self):
-        return self.equality_set.__hash__()
+        return hash(self.merged_reg_set)
 
     def __repr__(self, nodes=None):
         # type: (None | dict[IGNode, int]) -> str
@@ -991,18 +876,18 @@ class IGNode:
         nodes[self] = len(nodes)
         edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
         return (f"IGNode(#{nodes[self]}, "
-                f"equality_set={self.equality_set}, "
+                f"merged_reg_set={self.merged_reg_set}, "
                 f"edges={edges})")
 
 
 @final
-class InterferenceGraph(Mapping[EqualitySet, IGNode]):
-    def __init__(self, equality_sets):
-        # type: (Iterable[EqualitySet]) -> None
-        self.__nodes = {i: IGNode(i) for i in equality_sets}
+class InterferenceGraph(Mapping[MergedRegSet, IGNode]):
+    def __init__(self, merged_reg_sets):
+        # type: (Iterable[MergedRegSet]) -> None
+        self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
 
     def __getitem__(self, key):
-        # type: (EqualitySet) -> IGNode
+        # type: (MergedRegSet) -> IGNode
         return self.__nodes[key]
 
     def __iter__(self):
@@ -1014,7 +899,7 @@ class AllocationFailed:
     __slots__ = "op_idx", "arg", "live_intervals"
 
     def __init__(self, op_idx, arg, live_intervals):
-        # type: (int, SSAVal | VecArg, LiveIntervals) -> None
+        # type: (int, SSAVal, LiveIntervals) -> None
         self.op_idx = op_idx
         self.arg = arg
         self.live_intervals = live_intervals
@@ -1022,11 +907,8 @@ class AllocationFailed:
 
 def try_allocate_registers_without_spilling(ops):
     # type: (list[Op]) -> dict[SSAVal, PhysLoc] | AllocationFailed
-    live_intervals = LiveIntervals(ops)
 
-    def is_constrained(node):
-        # type: (EqualitySet) -> bool
-        raise NotImplementedError
+    live_intervals = LiveIntervals(ops)
 
     raise NotImplementedError