From d880f1f2dee6448a3603459d0d2897546f19b355 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 7 Oct 2022 17:23:09 -0700 Subject: [PATCH] working on toom_cook.py --- src/bigint_presentation_code/toom_cook.py | 544 +++++++++++++++------- 1 file changed, 363 insertions(+), 181 deletions(-) diff --git a/src/bigint_presentation_code/toom_cook.py b/src/bigint_presentation_code/toom_cook.py index bd8169b..a045edb 100644 --- a/src/bigint_presentation_code/toom_cook.py +++ b/src/bigint_presentation_code/toom_cook.py @@ -84,7 +84,7 @@ class StackSlot(GPROrStackLoc): @plain_data(eq=False) -class SSAVal: +class SSAVal(metaclass=ABCMeta): __slots__ = "id", def __init__(self, id=None): @@ -101,6 +101,19 @@ class SSAVal: def __hash__(self): return hash(self.id) + def _get_phys_loc(self, phys_loc_in, value_assignments=None): + # type: (PhysLoc | None, dict[SSAVal, PhysLoc] | None) -> PhysLoc | None + if phys_loc_in is not None: + return phys_loc_in + if value_assignments is not None: + return value_assignments.get(self) + return None + + @abstractmethod + def get_phys_loc(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> PhysLoc | None + ... + @plain_data(eq=False) @final @@ -115,17 +128,43 @@ class SSAGPRVal(SSAVal): def __len__(self): return 1 + def get_phys_loc(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> GPROrStackLoc | None + loc = self._get_phys_loc(self.phys_loc, value_assignments) + if isinstance(loc, GPROrStackLoc): + return loc + return None + def get_reg_num(self, value_assignments=None): # type: (dict[SSAVal, PhysLoc] | None) -> int | None - phys_loc = None - if value_assignments is not None: - phys_loc = value_assignments.get(self) - if self.phys_loc is not None: - phys_loc = self.phys_loc - if isinstance(phys_loc, GPR): - return phys_loc.reg_num + reg = self.get_reg(value_assignments) + if reg is not None: + return reg.reg_num + return None + + def get_reg(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> GPR | None + loc = self.get_phys_loc(value_assignments) + if isinstance(loc, GPR): + return loc + return None + + def get_stack_slot(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> StackSlot | None + loc = self.get_phys_loc(value_assignments) + if isinstance(loc, StackSlot): + return loc return None + def possible_reg_assignments(self, value_assignments, + conflicting_regs=set()): + # type: (dict[SSAVal, PhysLoc] | None, set[GPR]) -> Iterable[GPR] + if self.get_phys_loc(value_assignments) is not None: + raise ValueError("can't assign a already-assigned SSA value") + for reg in GPR: + if reg not in conflicting_regs: + yield reg + @plain_data(eq=False) @final @@ -136,6 +175,13 @@ class SSAXERBitVal(SSAVal): # type: (XERBit | None) -> None self.phys_loc = phys_loc + def get_phys_loc(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> XERBit | None + loc = self._get_phys_loc(self.phys_loc, value_assignments) + if isinstance(loc, XERBit): + return loc + return None + @plain_data(eq=False) @final @@ -146,6 +192,13 @@ class SSAMemory(SSAVal): # type: (GlobalMem) -> None self.phys_loc = phys_loc + def get_phys_loc(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> GlobalMem | None + loc = self._get_phys_loc(self.phys_loc, value_assignments) + if isinstance(loc, GlobalMem): + return loc + return None + @plain_data(unsafe_hash=True, frozen=True) @final @@ -159,46 +212,65 @@ class VecArg: def __len__(self): return len(self.regs) - def try_get_range(self, value_assignments=None, prefer_descending=False): - # type: (dict[SSAVal, PhysLoc] | None, int) -> range | None + def is_unassigned(self, value_assignments=None): + # type: (dict[SSAVal, PhysLoc] | None) -> bool + for val in self.regs: + if val.get_phys_loc(value_assignments) is not None: + return False + return True + + def try_get_range(self, value_assignments=None, allow_unassigned=False, + raise_if_invalid=False): + # type: (dict[SSAVal, PhysLoc] | None, bool, bool) -> range | None if len(self.regs) == 0: return range(0) - start = self.regs[0].get_reg_num(value_assignments) - if start is None: - return None - if len(self.regs) == 1: - if prefer_descending: - return range(start, start - 1, -1) - return range(start, start + 1, 1) - reg = self.regs[1].get_reg_num(value_assignments) - if reg is None: - return None - step = reg - start - if abs(step) != 1: - return None - for i, reg in enumerate(self.regs): - reg = self.regs[1].get_reg_num(value_assignments) + retval = None # type: range | None + for i, val in enumerate(self.regs): + if val.get_phys_loc(value_assignments) is None: + if not allow_unassigned: + if raise_if_invalid: + raise ValueError("not a valid register range: " + "unassigned SSA value encountered") + return None + continue + reg = val.get_reg_num(value_assignments) if reg is None: + if raise_if_invalid: + raise ValueError("not a valid register range: " + "non-register encountered") return None - if start + i * step != reg: + expected_range = range(reg - i, reg - i + len(self.regs)) + if retval is None: + retval = expected_range + elif retval != expected_range: + if raise_if_invalid: + raise ValueError("not a valid register range: " + "register out of sequence") return None - return range(start, start + len(self.regs) * step, step) - - def try_get_ascending_range(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> range | None - r = self.try_get_range(value_assignments=value_assignments) - if r is not None and r.step == 1: - return r - return None - - def try_get_descending_range(self, value_assignments=None): - # type: (dict[SSAVal, PhysLoc] | None) -> range | None - r = self.try_get_range(value_assignments=value_assignments, - prefer_descending=True) - if r is not None and r.step == -1: - return r - return None + return retval + + def possible_reg_assignments( + self, + val, # type: SSAVal + value_assignments, # type: dict[SSAVal, PhysLoc] | None + conflicting_regs=set(), # type: set[GPR] + ): # type: (...) -> Iterable[GPR] + index = self.regs.index(val) + alignment = 1 + while alignment < len(self.regs): + alignment *= 2 + r = self.try_get_range(value_assignments) + if r is not None and r.start % alignment != 0: + raise ValueError("must be a ascending aligned range of GPRs") + if r is None: + for i in range(0, len(GPR), alignment): + r = range(i, i + len(self.regs)) + if any(GPR(reg) in conflicting_regs for reg in r): + continue + yield GPR(r[index]) + else: + yield GPR(r[index]) @plain_data(unsafe_hash=True, frozen=True) @@ -232,9 +304,9 @@ class Op(metaclass=ABCMeta): ... @abstractmethod - def meets_constraints(self, value_assignments): - # type: (dict[SSAVal, PhysLoc]) -> bool - return True + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + ... def __init__(self): pass @@ -265,25 +337,49 @@ class OpCopy(Op): self.dest = dest self.src = src - def meets_constraints(self, value_assignments): - # type: (dict[SSAVal, PhysLoc]) -> bool - return True + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + conflicting_regs = set() # type: set[GPR] + if val in self.output_ssa_vals() and isinstance(self.dest, VecArg): + # OpCopy is the only op that can write to physical locations in + # any order, it handles figuring out the right instruction sequence + dest_locs = {} # type: dict[GPROrStackLoc, SSAVal] + for val in self.dest.regs: + loc = val.get_phys_loc(value_assignments) + if loc is None: + continue + if loc in dest_locs: + raise ValueError( + f"duplicate destination location not allowed: " + f"{val} is assigned to {loc} which is also " + f"written by {dest_locs[loc]}") + dest_locs[loc] = val + if isinstance(loc, GPR): + conflicting_regs.add(loc) + if not isinstance(val, SSAGPRVal): + raise ValueError("invalid operand type") + return val.possible_reg_assignments(value_assignments, + conflicting_regs) def range_overlaps(range1, range2): # type: (range, range) -> bool if len(range1) == 0 or len(range2) == 0: return False - range1_last = next(reversed(range1)) - range2_last = next(reversed(range2)) + range1_last = range1[-1] + range2_last = range2[-1] return (range1.start in range2 or range1_last in range2 or range2.start in range1 or range2_last in range1) - @plain_data(unsafe_hash=True, frozen=True) @final -class OpAddE(Op): - __slots__ = "RT", "RA", "RB", "CY_in", "CY_out" +class OpAddSubE(Op): + __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub" def inputs(self): # type: () -> dict[str, VecArg | SSAVal] @@ -293,8 +389,8 @@ class OpAddE(Op): # type: () -> dict[str, VecArg | SSAVal] return {"RT": self.RT, "CY_out": self.CY_out} - def __init__(self, RT, RA, RB, CY_in, CY_out): - # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None + def __init__(self, RT, RA, RB, CY_in, CY_out, is_sub): + # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal, bool) -> None if len(RA.regs) != len(RT.regs): raise TypeError(f"source length must match dest " f"length: {RA} doesn't match {RT}") @@ -306,110 +402,43 @@ class OpAddE(Op): self.RB = RB self.CY_in = CY_in self.CY_out = CY_out - - def meets_constraints(self, value_assignments): - # type: (dict[SSAVal, PhysLoc]) -> bool - RT_range = self.RT.try_get_ascending_range(value_assignments) - RA_range = self.RA.try_get_ascending_range(value_assignments) - RB_range = self.RB.try_get_ascending_range(value_assignments) - if RA_range is None or RB_range is None or RT_range is None: - return False - if RA_range != RT_range and range_overlaps(RA_range, RT_range): - return False - if RB_range != RT_range and range_overlaps(RB_range, RT_range): - return False - return True - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpSubE(Op): - __slots__ = "RT", "lhs", "rhs", "CY_in", "CY_out" - - def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] - return {"lhs": self.lhs, "rhs": self.rhs, "CY_in": self.CY_in} - - def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] - return {"RT": self.RT, "CY_out": self.CY_out} - - def __init__(self, RT, lhs, rhs, CY_in, CY_out): - # type: (VecArg, VecArg, VecArg, SSAXERBitVal, SSAXERBitVal) -> None - if len(lhs.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {lhs} doesn't match {RT}") - if len(rhs.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {rhs} doesn't match {RT}") - self.RT = RT - self.lhs = lhs - self.rhs = rhs - self.CY_in = CY_in - self.CY_out = CY_out - - def meets_constraints(self, value_assignments): - # type: (dict[SSAVal, PhysLoc]) -> bool - RT_range = self.RT.try_get_ascending_range(value_assignments) - lhs_range = self.lhs.try_get_ascending_range(value_assignments) - rhs_range = self.rhs.try_get_ascending_range(value_assignments) - if lhs_range is None or rhs_range is None or RT_range is None: - return False - if lhs_range != RT_range and range_overlaps(lhs_range, RT_range): - return False - if rhs_range != RT_range and range_overlaps(rhs_range, RT_range): - return False - return True - - -@plain_data(unsafe_hash=True, frozen=True) -@final -class OpMAddEU(Op): - __slots__ = "RT", "RA", "RB", "RC", "RS" - - def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] - return {"RA": self.RA, "RB": self.RB, "RC": self.RC} - - def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] - return {"RT": self.RT, "RS": self.RS} - - def __init__(self, RT, RA, RB, RC, RS): - # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None - if len(RA.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {RA} doesn't match {RT}") - self.RT = RT - self.RA = RA - self.RB = RB - self.RC = RC - self.RS = RS - - def meets_constraints(self, value_assignments): - # type: (dict[SSAVal, PhysLoc]) -> bool - RT_range = self.RT.try_get_ascending_range(value_assignments) - RA_range = self.RA.try_get_ascending_range(value_assignments) - RB_reg = self.RB.get_reg_num(value_assignments) - RC_reg = self.RC.get_reg_num(value_assignments) - RS_reg = self.RS.get_reg_num(value_assignments) - if RT_range is None or RA_range is None or RB_reg is None \ - or RC_reg is None or RS_reg is None: - return False - if RA_range != RT_range and range_overlaps(RA_range, RT_range): - return False - if RB_reg in RT_range or RC_reg in RT_range or RS_reg in RT_range \ - or RS_reg in RA_range: - return False - if RS_reg != RC_reg: - return False - return True + self.is_sub = is_sub + + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + if self.CY_in == val or self.CY_out == val: + yield XERBit.CY + elif val in self.RT.regs: + # since possible_reg_assignments only returns aligned + # vectors, all possible assignments either are the same as an + # input or don't overlap with an input and we avoid the incorrect + # results caused by partial overlaps overwriting input elements + # before they're read + yield from self.RT.possible_reg_assignments(val, value_assignments) + elif val in self.RA.regs: + yield from self.RA.possible_reg_assignments(val, value_assignments) + else: + yield from self.RB.possible_reg_assignments(val, value_assignments) + + +def to_reg_set(v): + # type: (None | GPR | range) -> set[GPR] + if v is None: + return set() + if isinstance(v, range): + return set(map(GPR, v)) + return {v} @plain_data(unsafe_hash=True, frozen=True) @final -class OpDivMod2DU(Op): - __slots__ = "RT", "RA", "RB", "RC", "RS" +class OpBigIntMulDiv(Op): + __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div" def inputs(self): # type: () -> dict[str, VecArg | SSAVal] @@ -419,8 +448,8 @@ class OpDivMod2DU(Op): # type: () -> dict[str, VecArg | SSAVal] return {"RT": self.RT, "RS": self.RS} - def __init__(self, RT, RA, RB, RC, RS): - # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal) -> None + def __init__(self, RT, RA, RB, RC, RS, is_div): + # type: (VecArg, VecArg, SSAGPRVal, SSAGPRVal, SSAGPRVal, bool) -> None if len(RA.regs) != len(RT.regs): raise TypeError(f"source length must match dest " f"length: {RA} doesn't match {RT}") @@ -429,35 +458,60 @@ class OpDivMod2DU(Op): self.RB = RB self.RC = RC self.RS = RS + self.is_div = is_div + + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + RT_range = self.RT.try_get_range(value_assignments, + allow_unassigned=True, + raise_if_invalid=True) + RA_range = self.RA.try_get_range(value_assignments, + allow_unassigned=True, + raise_if_invalid=True) + RC_RS_reg = self.RC.get_reg(value_assignments) + if RC_RS_reg is None: + RC_RS_reg = self.RS.get_reg(value_assignments) + + if self.RC == val or self.RS == val: + if RC_RS_reg is not None: + yield RC_RS_reg + else: + conflicting_regs = to_reg_set(RT_range) | to_reg_set(RA_range) + yield from self.RC.possible_reg_assignments(value_assignments, + conflicting_regs) + elif val in self.RT.regs: + # since possible_reg_assignments only returns aligned + # vectors, all possible assignments either are the same as + # RA or don't overlap with RA and we avoid the incorrect + # results caused by partial overlaps overwriting input elements + # before they're read + yield from self.RT.possible_reg_assignments( + val, value_assignments, + conflicting_regs=to_reg_set(RA_range) | to_reg_set(RC_RS_reg)) + else: + yield from self.RA.possible_reg_assignments( + val, value_assignments, + conflicting_regs=to_reg_set(RT_range) | to_reg_set(RC_RS_reg)) -@plain_data(unsafe_hash=True, frozen=True) @final -class OpBigIntShl(Op): - __slots__ = "RT", "inp", "sh" - - def inputs(self): - # type: () -> dict[str, VecArg | SSAVal] - return {"inp": self.inp, "sh": self.sh} - - def outputs(self): - # type: () -> dict[str, VecArg | SSAVal] - return {"RT": self.RT} - - def __init__(self, RT, inp, sh): - # type: (VecArg, VecArg, SSAGPRVal) -> None - if len(inp.regs) != len(RT.regs): - raise TypeError(f"source length must match dest " - f"length: {inp} doesn't match {RT}") - self.RT = RT - self.inp = inp - self.sh = sh +@unique +class ShiftKind(Enum): + Sl = "sl" + Sr = "sr" + Sra = "sra" @plain_data(unsafe_hash=True, frozen=True) @final -class OpBigIntShr(Op): - __slots__ = "RT", "inp", "sh" +class OpBigIntShift(Op): + __slots__ = "RT", "inp", "sh", "kind" def inputs(self): # type: () -> dict[str, VecArg | SSAVal] @@ -467,14 +521,49 @@ class OpBigIntShr(Op): # type: () -> dict[str, VecArg | SSAVal] return {"RT": self.RT} - def __init__(self, RT, inp, sh): - # type: (VecArg, VecArg, SSAGPRVal) -> None + def __init__(self, RT, inp, sh, kind): + # type: (VecArg, VecArg, SSAGPRVal, ShiftKind) -> None if len(inp.regs) != len(RT.regs): raise TypeError(f"source length must match dest " f"length: {inp} doesn't match {RT}") self.RT = RT self.inp = inp self.sh = sh + self.kind = kind + + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + RT_range = self.RT.try_get_range(value_assignments, + allow_unassigned=True, + raise_if_invalid=True) + inp_range = self.inp.try_get_range(value_assignments, + allow_unassigned=True, + raise_if_invalid=True) + sh_reg = self.sh.get_reg(value_assignments) + + if self.sh == val: + conflicting_regs = to_reg_set(RT_range) + yield from self.sh.possible_reg_assignments(value_assignments, + conflicting_regs) + elif val in self.RT.regs: + # since possible_reg_assignments only returns aligned + # vectors, all possible assignments either are the same as + # RA or don't overlap with RA and we avoid the incorrect + # results caused by partial overlaps overwriting input elements + # before they're read + yield from self.RT.possible_reg_assignments( + val, value_assignments, + conflicting_regs=to_reg_set(inp_range) | to_reg_set(sh_reg)) + else: + yield from self.inp.possible_reg_assignments( + val, value_assignments, + conflicting_regs=to_reg_set(RT_range)) @plain_data(unsafe_hash=True, frozen=True) @@ -495,6 +584,20 @@ class OpLI(Op): self.out = out self.value = value + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + if isinstance(self.out, VecArg): + yield from self.out.possible_reg_assignments(val, + value_assignments) + else: + yield from self.out.possible_reg_assignments(value_assignments) + @plain_data(unsafe_hash=True, frozen=True) @final @@ -513,6 +616,16 @@ class OpClearCY(Op): # type: (SSAXERBitVal) -> None self.out = out + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + yield XERBit.CY + @plain_data(unsafe_hash=True, frozen=True) @final @@ -534,6 +647,34 @@ class OpLoad(Op): self.offset = offset self.mem = mem + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + RA_reg = self.RA.get_reg(value_assignments) + + if self.mem == val: + yield GlobalMem.GlobalMem + elif self.RA == val: + if isinstance(self.RT, VecArg): + conflicting_regs = to_reg_set(self.RT.try_get_range( + value_assignments, allow_unassigned=True, + raise_if_invalid=True)) + else: + conflicting_regs = set() + yield from self.RA.possible_reg_assignments(value_assignments, + conflicting_regs) + elif isinstance(self.RT, VecArg): + yield from self.RT.possible_reg_assignments( + val, value_assignments, + conflicting_regs=to_reg_set(RA_reg)) + else: + yield from self.RT.possible_reg_assignments(value_assignments) + @plain_data(unsafe_hash=True, frozen=True) @final @@ -556,6 +697,23 @@ class OpStore(Op): self.mem_in = mem_in self.mem_out = mem_out + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + if self.mem_in == val or self.mem_out == val: + yield GlobalMem.GlobalMem + elif self.RA == val: + yield from self.RA.possible_reg_assignments(value_assignments) + elif isinstance(self.RS, VecArg): + yield from self.RS.possible_reg_assignments(val, value_assignments) + else: + yield from self.RS.possible_reg_assignments(value_assignments) + @plain_data(unsafe_hash=True, frozen=True) @final @@ -574,6 +732,20 @@ class OpFuncArg(Op): # type: (VecArg | SSAGPRVal) -> None self.out = out + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + if isinstance(self.out, VecArg): + yield from self.out.possible_reg_assignments(val, + value_assignments) + else: + yield from self.out.possible_reg_assignments(value_assignments) + @plain_data(unsafe_hash=True, frozen=True) @final @@ -592,6 +764,16 @@ class OpInputMem(Op): # type: (SSAMemory) -> None self.out = out + def possible_reg_assignments(self, val, value_assignments): + # type: (SSAVal, dict[SSAVal, PhysLoc]) -> Iterable[PhysLoc] + if val not in self.input_ssa_vals() \ + and val not in self.output_ssa_vals(): + raise ValueError(f"{val} must be an operand of {self}") + if val.get_phys_loc(value_assignments) is not None: + raise ValueError(f"{val} already assigned a physical location") + + yield GlobalMem.GlobalMem + def op_set_to_list(ops): # type: (Iterable[Op]) -> list[Op] -- 2.30.2