working on toom_cook.py
author Jacob Lifshay <programmerjake@gmail.com>
Sat, 8 Oct 2022 00:23:09 +0000 (17:23 -0700)
committer Jacob Lifshay <programmerjake@gmail.com>
Sat, 8 Oct 2022 00:23:09 +0000 (17:23 -0700)
src/bigint_presentation_code/toom_cook.py

index bd8169b34a7bcdd1d5961b7825ee58b110872ce1..a045edbae2bc7a18449861e9a5737f5d911e71b2 100644 (file)
@@ -84,7 +84,7 @@ class StackSlot(GPROrStackLoc):
 
 
 @plain_data(eq=False)
-class SSAVal:
+class SSAVal(metaclass=ABCMeta):
     __slots__ = "id",
 
     def __init__(self, id=None):
@@ -101,6 +101,19 @@ class SSAVal:
     def __hash__(self):
         return hash(self.id)
 
+    def _get_phys_loc(self, phys_loc_in, value_assignments=None):
+        # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
+        if phys_loc_in is not None:
+            return phys_loc_in
+        if value_assignments is not None:
+            return value_assignments.get(self)
+        return None
+
+    @abstractmethod
+    def get_phys_loc(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None
+        ...
+
 
 @plain_data(eq=False)
 @final
@@ -115,17 +128,43 @@ class SSAGPRVal(SSAVal):
     def __len__(self):
         return 1
 
+    def get_phys_loc(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None
+        loc = self._get_phys_loc(self.phys_loc, value_assignments)
+        if isinstance(loc, GPROrStackLoc):
+            return loc
+        return None
+
     def get_reg_num(self, value_assignments=None):
         # type: (dict[SSAVal, PhysLoc] | None) -> int | None
-        phys_loc = None
-        if value_assignments is not None:
-            phys_loc = value_assignments.get(self)
-        if self.phys_loc is not None:
-            phys_loc = self.phys_loc
-        if isinstance(phys_loc, GPR):
-            return phys_loc.reg_num
+        reg = self.get_reg(value_assignments)
+        if reg is not None:
+            return reg.reg_num
+        return None
+
+    def get_reg(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None
+        loc = self.get_phys_loc(value_assignments)
+        if isinstance(loc, GPR):
+            return loc
+        return None
+
+    def get_stack_slot(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None
+        loc = self.get_phys_loc(value_assignments)
+        if isinstance(loc, StackSlot):
+            return loc
         return None
 
+    def possible_reg_assignments(self, value_assignments,
+                                 conflicting_regs=set()):
+        # type: (dict[SSAVal, PhysLoc] | None, set[GPR]) -> Iterable[GPR]
+        if self.get_phys_loc(value_assignments) is not None:
+            raise ValueError("can't assign a already-assigned SSA value")
+        for reg in GPR:
+            if reg not in conflicting_regs:
+                yield reg
+
 
 @plain_data(eq=False)
 @final
@@ -136,6 +175,13 @@ class SSAXERBitVal(SSAVal):
         # type: (XERBit | None) -> None
         self.phys_loc = phys_loc
 
+    def get_phys_loc(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None
+        loc = self._get_phys_loc(self.phys_loc, value_assignments)
+        if isinstance(loc, XERBit):
+            return loc
+        return None
+
 
 @plain_data(eq=False)
 @final
@@ -146,6 +192,13 @@ class SSAMemory(SSAVal):
         # type: (GlobalMem) -> None
         self.phys_loc = phys_loc
 
+    def get_phys_loc(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem | None
+        loc = self._get_phys_loc(self.phys_loc, value_assignments)
+        if isinstance(loc, GlobalMem):
+            return loc
+        return None
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -159,46 +212,65 @@ class VecArg:
     def __len__(self):
         return len(self.regs)
 
-    def try_get_range(self, value_assignments=None, prefer_descending=False):
-        # type: (dict[SSAVal, PhysLoc] | None, int) -> range | None
+    def is_unassigned(self, value_assignments=None):
+        # type: (dict[SSAVal, PhysLoc] | None) -> bool
+        for val in self.regs:
+            if val.get_phys_loc(value_assignments) is not None:
+                return False
+        return True
+
+    def try_get_range(self, value_assignments=None, allow_unassigned=False,
+                      raise_if_invalid=False):
+        # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None
         if len(self.regs) == 0:
             return range(0)
 
-        start = self.regs[0].get_reg_num(value_assignments)
-        if start is None:
-            return None
-        if len(self.regs) == 1:
-            if prefer_descending:
-                return range(start, start - 1, -1)
-            return range(start, start + 1, 1)
-        reg = self.regs[1].get_reg_num(value_assignments)
-        if reg is None:
-            return None
-        step = reg - start
-        if abs(step) != 1:
-            return None
-        for i, reg in enumerate(self.regs):
-            reg = self.regs[1].get_reg_num(value_assignments)
+        retval = None  # type: range | None
+        for i, val in enumerate(self.regs):
+            if val.get_phys_loc(value_assignments) is None:
+                if not allow_unassigned:
+                    if raise_if_invalid:
+                        raise ValueError("not a valid register range: "
+                                         "unassigned SSA value encountered")
+                    return None
+                continue
+            reg = val.get_reg_num(value_assignments)
             if reg is None:
+                if raise_if_invalid:
+                    raise ValueError("not a valid register range: "
+                                     "non-register encountered")
                 return None
-            if start + i * step != reg:
+            expected_range = range(reg - i, reg - i + len(self.regs))
+            if retval is None:
+                retval = expected_range
+            elif retval != expected_range:
+                if raise_if_invalid:
+                    raise ValueError("not a valid register range: "
+                                     "register out of sequence")
                 return None
-        return range(start, start + len(self.regs) * step, step)
-
-    def try_get_ascending_range(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> range | None
-        r = self.try_get_range(value_assignments=value_assignments)
-        if r is not None and r.step == 1:
-            return r
-        return None
-
-    def try_get_descending_range(self, value_assignments=None):
-        # type: (dict[SSAVal, PhysLoc] | None) -> range | None
-        r = self.try_get_range(value_assignments=value_assignments,
-                               prefer_descending=True)
-        if r is not None and r.step == -1:
-            return r
-        return None
+        return retval
+
+    def possible_reg_assignments(
+        self,
+        val,  # type: SSAVal
+        value_assignments,  # type: dict[SSAVal, PhysLoc] | None
+        conflicting_regs=set(),  # type: set[GPR]
+    ):  # type: (...) -> Iterable[GPR]
+        index = self.regs.index(val)
+        alignment = 1
+        while alignment < len(self.regs):
+            alignment *= 2
+        r = self.try_get_range(value_assignments)
+        if r is not None and r.start % alignment != 0:
+            raise ValueError("must be a ascending aligned range of GPRs")
+        if r is None:
+            for i in range(0, len(GPR), alignment):
+                r = range(i, i + len(self.regs))
+                if any(GPR(reg) in conflicting_regs for reg in r):
+                    continue
+                yield GPR(r[index])
+        else:
+            yield GPR(r[index])
 
 
 @plain_data(unsafe_hash=True, frozen=True)
@@ -232,9 +304,9 @@ class Op(metaclass=ABCMeta):
         ...
 
     @abstractmethod
-    def meets_constraints(self, value_assignments):
-        # type: (dict[SSAVal, PhysLoc]) -> bool
-        return True
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        ...
 
     def __init__(self):
         pass
@@ -265,25 +337,49 @@ class OpCopy(Op):
         self.dest = dest
         self.src = src
 
-    def meets_constraints(self, value_assignments):
-        # type: (dict[SSAVal, PhysLoc]) -> bool
-        return True
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+        conflicting_regs = set()  # type: set[GPR]
+        if val in self.output_ssa_vals() and isinstance(self.dest, VecArg):
+            # OpCopy is the only op that can write to physical locations in
+            # any order, it handles figuring out the right instruction sequence
+            dest_locs = {}  # type: dict[GPROrStackLoc, SSAVal]
+            for val in self.dest.regs:
+                loc = val.get_phys_loc(value_assignments)
+                if loc is None:
+                    continue
+                if loc in dest_locs:
+                    raise ValueError(
+                        f"duplicate destination location not allowed: "
+                        f"{val} is assigned to {loc} which is also "
+                        f"written by {dest_locs[loc]}")
+                dest_locs[loc] = val
+                if isinstance(loc, GPR):
+                    conflicting_regs.add(loc)
+        if not isinstance(val, SSAGPRVal):
+            raise ValueError("invalid operand type")
+        return val.possible_reg_assignments(value_assignments,
+                                            conflicting_regs)
 
 
 def range_overlaps(range1, range2):
     # type: (range, range) -> bool
     if len(range1) == 0 or len(range2) == 0:
         return False
-    range1_last = next(reversed(range1))
-    range2_last = next(reversed(range2))
+    range1_last = range1[-1]
+    range2_last = range2[-1]
     return (range1.start in range2 or range1_last in range2 or
             range2.start in range1 or range2_last in range1)
 
-
 @plain_data(unsafe_hash=True, frozen=True)
 @final
-class OpAddE(Op):
-    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out"
+class OpAddSubE(Op):
+    __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
 
     def inputs(self):
         # type: () -> dict[str, VecArg | SSAVal]
@@ -293,8 +389,8 @@ class OpAddE(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"RT": self.RT, "CY_out": self.CY_out}
 
-    def __init__(self, RT, RA, RB, CY_in, CY_out):
-        # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None
+    def __init__(self, RT, RA, RB, CY_in, CY_out, is_sub):
+        # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal, bool) -> None
         if len(RA.regs) != len(RT.regs):
             raise TypeError(f"source length must match dest "
                             f"length: {RA} doesn't match {RT}")
@@ -306,110 +402,43 @@ class OpAddE(Op):
         self.RB = RB
         self.CY_in = CY_in
         self.CY_out = CY_out
-
-    def meets_constraints(self, value_assignments):
-        # type: (dict[SSAVal, PhysLoc]) -> bool
-        RT_range = self.RT.try_get_ascending_range(value_assignments)
-        RA_range = self.RA.try_get_ascending_range(value_assignments)
-        RB_range = self.RB.try_get_ascending_range(value_assignments)
-        if RA_range is None or RB_range is None or RT_range is None:
-            return False
-        if RA_range != RT_range and range_overlaps(RA_range, RT_range):
-            return False
-        if RB_range != RT_range and range_overlaps(RB_range, RT_range):
-            return False
-        return True
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpSubE(Op):
-    __slots__ = "RT", "lhs", "rhs", "CY_in", "CY_out"
-
-    def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
-        return {"lhs": self.lhs, "rhs": self.rhs, "CY_in": self.CY_in}
-
-    def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
-        return {"RT": self.RT, "CY_out": self.CY_out}
-
-    def __init__(self, RT, lhs, rhs, CY_in, CY_out):
-        # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None
-        if len(lhs.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {lhs} doesn't match {RT}")
-        if len(rhs.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {rhs} doesn't match {RT}")
-        self.RT = RT
-        self.lhs = lhs
-        self.rhs = rhs
-        self.CY_in = CY_in
-        self.CY_out = CY_out
-
-    def meets_constraints(self, value_assignments):
-        # type: (dict[SSAVal, PhysLoc]) -> bool
-        RT_range = self.RT.try_get_ascending_range(value_assignments)
-        lhs_range = self.lhs.try_get_ascending_range(value_assignments)
-        rhs_range = self.rhs.try_get_ascending_range(value_assignments)
-        if lhs_range is None or rhs_range is None or RT_range is None:
-            return False
-        if lhs_range != RT_range and range_overlaps(lhs_range, RT_range):
-            return False
-        if rhs_range != RT_range and range_overlaps(rhs_range, RT_range):
-            return False
-        return True
-
-
-@plain_data(unsafe_hash=True, frozen=True)
-@final
-class OpMAddEU(Op):
-    __slots__ = "RT", "RA", "RB", "RC", "RS"
-
-    def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
-        return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
-
-    def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
-        return {"RT": self.RT, "RS": self.RS}
-
-    def __init__(self, RT, RA, RB, RC, RS):
-        # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None
-        if len(RA.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {RA} doesn't match {RT}")
-        self.RT = RT
-        self.RA = RA
-        self.RB = RB
-        self.RC = RC
-        self.RS = RS
-
-    def meets_constraints(self, value_assignments):
-        # type: (dict[SSAVal, PhysLoc]) -> bool
-        RT_range = self.RT.try_get_ascending_range(value_assignments)
-        RA_range = self.RA.try_get_ascending_range(value_assignments)
-        RB_reg = self.RB.get_reg_num(value_assignments)
-        RC_reg = self.RC.get_reg_num(value_assignments)
-        RS_reg = self.RS.get_reg_num(value_assignments)
-        if RT_range is None or RA_range is None or RB_reg is None \
-                or RC_reg is None or RS_reg is None:
-            return False
-        if RA_range != RT_range and range_overlaps(RA_range, RT_range):
-            return False
-        if RB_reg in RT_range or RC_reg in RT_range or RS_reg in RT_range \
-                or RS_reg in RA_range:
-            return False
-        if RS_reg != RC_reg:
-            return False
-        return True
+        self.is_sub = is_sub
+
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+        if self.CY_in == val or self.CY_out == val:
+            yield XERBit.CY
+        elif val in self.RT.regs:
+            # since possible_reg_assignments only returns aligned
+            # vectors, all possible assignments either are the same as an
+            # input or don't overlap with an input and we avoid the incorrect
+            # results caused by partial overlaps overwriting input elements
+            # before they're read
+            yield from self.RT.possible_reg_assignments(val, value_assignments)
+        elif val in self.RA.regs:
+            yield from self.RA.possible_reg_assignments(val, value_assignments)
+        else:
+            yield from self.RB.possible_reg_assignments(val, value_assignments)
+
+
+def to_reg_set(v):
+    # type: (None | GPR | range) -> set[GPR]
+    if v is None:
+        return set()
+    if isinstance(v, range):
+        return set(map(GPR, v))
+    return {v}
 
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
-class OpDivMod2DU(Op):
-    __slots__ = "RT", "RA", "RB", "RC", "RS"
+class OpBigIntMulDiv(Op):
+    __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
 
     def inputs(self):
         # type: () -> dict[str, VecArg | SSAVal]
@@ -419,8 +448,8 @@ class OpDivMod2DU(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"RT": self.RT, "RS": self.RS}
 
-    def __init__(self, RT, RA, RB, RC, RS):
-        # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None
+    def __init__(self, RT, RA, RB, RC, RS, is_div):
+        # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal, bool) -> None
         if len(RA.regs) != len(RT.regs):
             raise TypeError(f"source length must match dest "
                             f"length: {RA} doesn't match {RT}")
@@ -429,35 +458,60 @@ class OpDivMod2DU(Op):
         self.RB = RB
         self.RC = RC
         self.RS = RS
+        self.is_div = is_div
+
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        RT_range = self.RT.try_get_range(value_assignments,
+                                         allow_unassigned=True,
+                                         raise_if_invalid=True)
+        RA_range = self.RA.try_get_range(value_assignments,
+                                         allow_unassigned=True,
+                                         raise_if_invalid=True)
+        RC_RS_reg = self.RC.get_reg(value_assignments)
+        if RC_RS_reg is None:
+            RC_RS_reg = self.RS.get_reg(value_assignments)
+
+        if self.RC == val or self.RS == val:
+            if RC_RS_reg is not None:
+                yield RC_RS_reg
+            else:
+                conflicting_regs = to_reg_set(RT_range) | to_reg_set(RA_range)
+                yield from self.RC.possible_reg_assignments(value_assignments,
+                                                            conflicting_regs)
+        elif val in self.RT.regs:
+            # since possible_reg_assignments only returns aligned
+            # vectors, all possible assignments either are the same as
+            # RA or don't overlap with RA and we avoid the incorrect
+            # results caused by partial overlaps overwriting input elements
+            # before they're read
+            yield from self.RT.possible_reg_assignments(
+                val, value_assignments,
+                conflicting_regs=to_reg_set(RA_range) | to_reg_set(RC_RS_reg))
+        else:
+            yield from self.RA.possible_reg_assignments(
+                val, value_assignments,
+                conflicting_regs=to_reg_set(RT_range) | to_reg_set(RC_RS_reg))
 
 
-@plain_data(unsafe_hash=True, frozen=True)
 @final
-class OpBigIntShl(Op):
-    __slots__ = "RT", "inp", "sh"
-
-    def inputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
-        return {"inp": self.inp, "sh": self.sh}
-
-    def outputs(self):
-        # type: () -> dict[str, VecArg | SSAVal]
-        return {"RT": self.RT}
-
-    def __init__(self, RT, inp, sh):
-        # type: (VecArg, VecArg, SSAGPRVal) -> None
-        if len(inp.regs) != len(RT.regs):
-            raise TypeError(f"source length must match dest "
-                            f"length: {inp} doesn't match {RT}")
-        self.RT = RT
-        self.inp = inp
-        self.sh = sh
+@unique
+class ShiftKind(Enum):
+    Sl = "sl"
+    Sr = "sr"
+    Sra = "sra"
 
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
-class OpBigIntShr(Op):
-    __slots__ = "RT", "inp", "sh"
+class OpBigIntShift(Op):
+    __slots__ = "RT", "inp", "sh", "kind"
 
     def inputs(self):
         # type: () -> dict[str, VecArg | SSAVal]
@@ -467,14 +521,49 @@ class OpBigIntShr(Op):
         # type: () -> dict[str, VecArg | SSAVal]
         return {"RT": self.RT}
 
-    def __init__(self, RT, inp, sh):
-        # type: (VecArg, VecArg, SSAGPRVal) -> None
+    def __init__(self, RT, inp, sh, kind):
+        # type: (VecArg, VecArg, SSAGPRVal, ShiftKind) -> None
         if len(inp.regs) != len(RT.regs):
             raise TypeError(f"source length must match dest "
                             f"length: {inp} doesn't match {RT}")
         self.RT = RT
         self.inp = inp
         self.sh = sh
+        self.kind = kind
+
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        RT_range = self.RT.try_get_range(value_assignments,
+                                         allow_unassigned=True,
+                                         raise_if_invalid=True)
+        inp_range = self.inp.try_get_range(value_assignments,
+                                           allow_unassigned=True,
+                                           raise_if_invalid=True)
+        sh_reg = self.sh.get_reg(value_assignments)
+
+        if self.sh == val:
+            conflicting_regs = to_reg_set(RT_range)
+            yield from self.sh.possible_reg_assignments(value_assignments,
+                                                        conflicting_regs)
+        elif val in self.RT.regs:
+            # since possible_reg_assignments only returns aligned
+            # vectors, all possible assignments either are the same as
+            # inp or don't overlap with inp and we avoid the incorrect
+            # results caused by partial overlaps overwriting input elements
+            # before they're read
+            yield from self.RT.possible_reg_assignments(
+                val, value_assignments,
+                conflicting_regs=to_reg_set(inp_range) | to_reg_set(sh_reg))
+        else:
+            yield from self.inp.possible_reg_assignments(
+                val, value_assignments,
+                conflicting_regs=to_reg_set(RT_range))
 
 
 @plain_data(unsafe_hash=True, frozen=True)
@@ -495,6 +584,20 @@ class OpLI(Op):
         self.out = out
         self.value = value
 
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        if isinstance(self.out, VecArg):
+            yield from self.out.possible_reg_assignments(val,
+                                                         value_assignments)
+        else:
+            yield from self.out.possible_reg_assignments(value_assignments)
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -513,6 +616,16 @@ class OpClearCY(Op):
         # type: (SSAXERBitVal) -> None
         self.out = out
 
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        yield XERBit.CY
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -534,6 +647,34 @@ class OpLoad(Op):
         self.offset = offset
         self.mem = mem
 
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        RA_reg = self.RA.get_reg(value_assignments)
+
+        if self.mem == val:
+            yield GlobalMem.GlobalMem
+        elif self.RA == val:
+            if isinstance(self.RT, VecArg):
+                conflicting_regs = to_reg_set(self.RT.try_get_range(
+                    value_assignments, allow_unassigned=True,
+                    raise_if_invalid=True))
+            else:
+                conflicting_regs = set()
+            yield from self.RA.possible_reg_assignments(value_assignments,
+                                                        conflicting_regs)
+        elif isinstance(self.RT, VecArg):
+            yield from self.RT.possible_reg_assignments(
+                val, value_assignments,
+                conflicting_regs=to_reg_set(RA_reg))
+        else:
+            yield from self.RT.possible_reg_assignments(value_assignments)
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -556,6 +697,23 @@ class OpStore(Op):
         self.mem_in = mem_in
         self.mem_out = mem_out
 
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        if self.mem_in == val or self.mem_out == val:
+            yield GlobalMem.GlobalMem
+        elif self.RA == val:
+            yield from self.RA.possible_reg_assignments(value_assignments)
+        elif isinstance(self.RS, VecArg):
+            yield from self.RS.possible_reg_assignments(val, value_assignments)
+        else:
+            yield from self.RS.possible_reg_assignments(value_assignments)
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -574,6 +732,20 @@ class OpFuncArg(Op):
         # type: (VecArg | SSAGPRVal) -> None
         self.out = out
 
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        if isinstance(self.out, VecArg):
+            yield from self.out.possible_reg_assignments(val,
+                                                         value_assignments)
+        else:
+            yield from self.out.possible_reg_assignments(value_assignments)
+
 
 @plain_data(unsafe_hash=True, frozen=True)
 @final
@@ -592,6 +764,16 @@ class OpInputMem(Op):
         # type: (SSAMemory) -> None
         self.out = out
 
+    def possible_reg_assignments(self, val, value_assignments):
+        # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc]
+        if val not in self.input_ssa_vals() \
+                and val not in self.output_ssa_vals():
+            raise ValueError(f"{val} must be an operand of {self}")
+        if val.get_phys_loc(value_assignments) is not None:
+            raise ValueError(f"{val} already assigned a physical location")
+
+        yield GlobalMem.GlobalMem
+
 
 def op_set_to_list(ops):
     # type: (Iterable[Op]) -> list[Op]