+++ /dev/null
-# type: ignore
-"""
-Compiler IR for Toom-Cook algorithm generator for SVP64
-
-This assumes VL != 0 throughout.
-"""
-
-from abc import ABCMeta, abstractmethod
-from collections import defaultdict
-from enum import Enum, EnumMeta, unique
-from functools import lru_cache
-from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast
-
-from nmutil.plain_data import fields, plain_data
-
-from bigint_presentation_code.type_util import final
-from bigint_presentation_code.util import FMap, OFSet, OSet
-
-
-class ABCEnumMeta(EnumMeta, ABCMeta):
- pass
-
-
-class RegLoc(metaclass=ABCMeta):
- __slots__ = ()
-
- @abstractmethod
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- ...
-
- def get_subreg_at_offset(self, subreg_type, offset):
- # type: (RegType, int) -> RegLoc
- if self not in subreg_type.reg_class:
- raise ValueError(f"register not a member of subreg_type: "
- f"reg={self} subreg_type={subreg_type}")
- if offset != 0:
- raise ValueError(f"non-zero sub-register offset not supported "
- f"for register: {self}")
- return self
-
-
-GPR_COUNT = 128
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-@final
-class GPRRange(RegLoc, Sequence["GPRRange"]):
- __slots__ = "start", "length"
-
- def __init__(self, start, length=None):
- # type: (int | range, int | None) -> None
- if isinstance(start, range):
- if length is not None:
- raise TypeError("can't specify length when input is a range")
- if start.step != 1:
- raise ValueError("range must have a step of 1")
- length = len(start)
- start = start.start
- elif length is None:
- length = 1
- if length <= 0 or start < 0 or start + length > GPR_COUNT:
- raise ValueError("invalid GPRRange")
- self.start = start
- self.length = length
-
- @property
- def stop(self):
- return self.start + self.length
-
- @property
- def step(self):
- return 1
-
- @property
- def range(self):
- return range(self.start, self.stop, self.step)
-
- def __len__(self):
- return self.length
-
- def __getitem__(self, item):
- # type: (int | slice) -> GPRRange
- return GPRRange(self.range[item])
-
- def __contains__(self, value):
- # type: (GPRRange) -> bool
- return value.start >= self.start and value.stop <= self.stop
-
- def index(self, sub, start=None, end=None):
- # type: (GPRRange, int | None, int | None) -> int
- r = self.range[start:end]
- if sub.start < r.start or sub.stop > r.stop:
- raise ValueError("GPR range not found")
- return sub.start - self.start
-
- def count(self, sub, start=None, end=None):
- # type: (GPRRange, int | None, int | None) -> int
- r = self.range[start:end]
- if len(r) == 0:
- return 0
- return int(sub in GPRRange(r))
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, GPRRange):
- return self.stop > other.start and other.stop > self.start
- return False
-
- def get_subreg_at_offset(self, subreg_type, offset):
- # type: (RegType, int) -> GPRRange
- if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
- raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
- f"GPRRangeType: {subreg_type}")
- if offset < 0 or offset + subreg_type.length > self.stop:
- raise ValueError(f"sub-register offset is out of range: {offset}")
- return GPRRange(self.start + offset, subreg_type.length)
-
- def __repr__(self):
- if self.length == 1:
- return f"<r{self.start}>"
- return f"<r{self.start}..len={self.length}>"
-
-
-SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13)
-
-
-@final
-@unique
-class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
- CA = "CA"
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, XERBit):
- return self == other
- return False
-
-
-@final
-@unique
-class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta):
- """singleton representing all non-StackSlot memory -- treated as a single
- physical register for register allocation purposes.
- """
- GlobalMem = "GlobalMem"
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, GlobalMem):
- return self == other
- return False
-
-
-@final
-@unique
-class VL(RegLoc, Enum, metaclass=ABCEnumMeta):
- VL_MAXVL = "VL_MAXVL"
- """VL and MAXVL"""
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, VL):
- return self == other
- return False
-
-
-@final
-class RegClass(OFSet[RegLoc]):
- """ an ordered set of registers.
- earlier registers are preferred by the register allocator.
- """
-
- @lru_cache(maxsize=None, typed=True)
- def max_conflicts_with(self, other):
- # type: (RegClass | RegLoc) -> int
- """the largest number of registers in `self` that a single register
- from `other` can conflict with
- """
- if isinstance(other, RegClass):
- return max(self.max_conflicts_with(i) for i in other)
- else:
- return sum(other.conflicts(i) for i in self)
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-class RegType(metaclass=ABCMeta):
- __slots__ = ()
-
- @property
- @abstractmethod
- def reg_class(self):
- # type: () -> RegClass
- return ...
-
-
-_RegType = TypeVar("_RegType", bound=RegType)
-_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
-
-
-@plain_data(frozen=True, eq=False, repr=False)
-@final
-class GPRRangeType(RegType):
- __slots__ = "length",
-
- def __init__(self, length=1):
- # type: (int) -> None
- if length < 1 or length > GPR_COUNT:
- raise ValueError("invalid length")
- self.length = length
-
- @staticmethod
- @lru_cache(maxsize=None)
- def __get_reg_class(length):
- # type: (int) -> RegClass
- regs = []
- for start in range(GPR_COUNT - length):
- reg = GPRRange(start, length)
- if any(i in reg for i in SPECIAL_GPRS):
- continue
- regs.append(reg)
- return RegClass(regs)
-
- @property
- @final
- def reg_class(self):
- # type: () -> RegClass
- return GPRRangeType.__get_reg_class(self.length)
-
- @final
- def __eq__(self, other):
- if isinstance(other, GPRRangeType):
- return self.length == other.length
- return False
-
- @final
- def __hash__(self):
- return hash(self.length)
-
- def __repr__(self):
- return f"<gpr_ty[{self.length}]>"
-
-
-GPRType = GPRRangeType
-"""a length=1 GPRRangeType"""
-
-
-@plain_data(frozen=True, unsafe_hash=True, repr=False)
-@final
-class FixedGPRRangeType(RegType):
- __slots__ = "reg",
-
- def __init__(self, reg):
- # type: (GPRRange) -> None
- self.reg = reg
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return RegClass([self.reg])
-
- @property
- def length(self):
- # type: () -> int
- return self.reg.length
-
- def __repr__(self):
- return f"<fixed({self.reg})>"
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class CAType(RegType):
- __slots__ = ()
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return RegClass([XERBit.CA])
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class GlobalMemType(RegType):
- __slots__ = ()
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return RegClass([GlobalMem.GlobalMem])
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class KnownVLType(RegType):
- __slots__ = "length",
-
- def __init__(self, length):
- # type: (int) -> None
- if not (0 < length <= 64):
- raise ValueError("invalid VL value")
- self.length = length
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return RegClass([VL.VL_MAXVL])
-
-
-def assert_vl_is(vl, expected_vl):
- # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
- if vl is None:
- vl = 1
- elif isinstance(vl, SSAVal):
- vl = vl.ty.length
- elif isinstance(vl, KnownVLType):
- vl = vl.length
- if vl != expected_vl:
- raise ValueError(
- f"wrong VL: expected {expected_vl} got {vl}")
-
-
-STACK_SLOT_SIZE = 8
-
-
-@plain_data(frozen=True, unsafe_hash=True)
-@final
-class StackSlot(RegLoc):
- __slots__ = "start_slot", "length_in_slots",
-
- def __init__(self, start_slot, length_in_slots):
- # type: (int, int) -> None
- self.start_slot = start_slot
- if length_in_slots < 1:
- raise ValueError("invalid length_in_slots")
- self.length_in_slots = length_in_slots
-
- @property
- def stop_slot(self):
- return self.start_slot + self.length_in_slots
-
- @property
- def start_byte(self):
- return self.start_slot * STACK_SLOT_SIZE
-
- def conflicts(self, other):
- # type: (RegLoc) -> bool
- if isinstance(other, StackSlot):
- return (self.stop_slot > other.start_slot
- and other.stop_slot > self.start_slot)
- return False
-
- def get_subreg_at_offset(self, subreg_type, offset):
- # type: (RegType, int) -> StackSlot
- if not isinstance(subreg_type, StackSlotType):
- raise ValueError(f"subreg_type is not a "
- f"StackSlotType: {subreg_type}")
- if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot:
- raise ValueError(f"sub-register offset is out of range: {offset}")
- return StackSlot(self.start_slot + offset, subreg_type.length_in_slots)
-
-
-STACK_SLOT_COUNT = 128
-
-
-@plain_data(frozen=True, eq=False)
-@final
-class StackSlotType(RegType):
- __slots__ = "length_in_slots",
-
- def __init__(self, length_in_slots=1):
- # type: (int) -> None
- if length_in_slots < 1:
- raise ValueError("invalid length_in_slots")
- self.length_in_slots = length_in_slots
-
- @staticmethod
- @lru_cache(maxsize=None)
- def __get_reg_class(length_in_slots):
- # type: (int) -> RegClass
- regs = []
- for start in range(STACK_SLOT_COUNT - length_in_slots):
- reg = StackSlot(start, length_in_slots)
- regs.append(reg)
- return RegClass(regs)
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return StackSlotType.__get_reg_class(self.length_in_slots)
-
- @final
- def __eq__(self, other):
- if isinstance(other, StackSlotType):
- return self.length_in_slots == other.length_in_slots
- return False
-
- @final
- def __hash__(self):
- return hash(self.length_in_slots)
-
-
-@plain_data(frozen=True, eq=False, repr=False)
-@final
-class SSAVal(Generic[_RegType]):
- __slots__ = "op", "arg_name", "ty",
-
- def __init__(self, op, arg_name, ty):
- # type: (Op, str, _RegType) -> None
- self.op = op
- """the Op that writes this SSAVal"""
-
- self.arg_name = arg_name
- """the name of the argument of self.op that writes this SSAVal"""
-
- self.ty = ty
-
- def __eq__(self, rhs):
- if isinstance(rhs, SSAVal):
- return (self.op is rhs.op
- and self.arg_name == rhs.arg_name)
- return False
-
- def __hash__(self):
- return hash((id(self.op), self.arg_name))
-
- def __repr__(self):
- return f"<#{self.op.id}.{self.arg_name}: {self.ty}>"
-
-
-SSAGPRRange = SSAVal[GPRRangeType]
-SSAGPR = SSAVal[GPRType]
-SSAKnownVL = SSAVal[KnownVLType]
-
-
-@final
-@plain_data(unsafe_hash=True, frozen=True)
-class EqualityConstraint:
- __slots__ = "lhs", "rhs"
-
- def __init__(self, lhs, rhs):
- # type: (list[SSAVal], list[SSAVal]) -> None
- self.lhs = lhs
- self.rhs = rhs
- if len(lhs) == 0 or len(rhs) == 0:
- raise ValueError("can't constrain an empty list to be equal")
-
-
-@final
-class Fn:
- __slots__ = "ops",
-
- def __init__(self):
- # type: () -> None
- self.ops = [] # type: list[Op]
-
- def __repr__(self, short=False):
- if short:
- return "<Fn>"
- ops = ", ".join(op.__repr__(just_id=True) for op in self.ops)
- return f"<Fn([{ops}])>"
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- for op in self.ops:
- op.pre_ra_sim(state)
-
-
-class _NotSet:
- """ helper for __repr__ for when fields aren't set """
-
- def __repr__(self):
- return "<not set>"
-
-
-_NOT_SET = _NotSet()
-
-
-@final
-class AsmContext:
- def __init__(self, assigned_registers):
- # type: (dict[SSAVal, RegLoc]) -> None
- self.__assigned_registers = assigned_registers
-
- def reg(self, ssa_val, expected_ty):
- # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
- try:
- reg = self.__assigned_registers[ssa_val]
- except KeyError as e:
- raise ValueError(f"SSAVal not assigned a register: {ssa_val}")
- wrong_len = (isinstance(reg, GPRRange)
- and reg.length != ssa_val.ty.length)
- if not isinstance(reg, expected_ty) or wrong_len:
- raise TypeError(
- f"SSAVal is assigned a register of the wrong type: "
- f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
- return reg
-
- def gpr_range(self, ssa_val):
- # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
- return self.reg(ssa_val, GPRRange)
-
- def stack_slot(self, ssa_val):
- # type: (SSAVal[StackSlotType]) -> StackSlot
- return self.reg(ssa_val, StackSlot)
-
- def gpr(self, ssa_val, vec, offset=0):
- # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
- reg = self.gpr_range(ssa_val).start + offset
- return "*" * vec + str(reg)
-
- def vgpr(self, ssa_val, offset=0):
- # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
- return self.gpr(ssa_val=ssa_val, vec=True, offset=offset)
-
- def sgpr(self, ssa_val, offset=0):
- # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
- return self.gpr(ssa_val=ssa_val, vec=False, offset=offset)
-
- def needs_sv(self, *regs):
- # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
- for reg in regs:
- reg = self.gpr_range(reg)
- if reg.length != 1 or reg.start >= 32:
- return True
- return False
-
-
-GPR_SIZE_IN_BYTES = 8
-GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * 8
-GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1
-
-
-@plain_data(frozen=True)
-@final
-class PreRASimState:
- __slots__ = ("gprs", "VLs", "CAs",
- "global_mems", "stack_slots",
- "fixed_gprs")
-
- def __init__(
- self,
- gprs, # type: dict[SSAGPRRange, tuple[int, ...]]
- VLs, # type: dict[SSAKnownVL, int]
- CAs, # type: dict[SSAVal[CAType], bool]
- global_mems, # type: dict[SSAVal[GlobalMemType], FMap[int, int]]
- stack_slots, # type: dict[SSAVal[StackSlotType], tuple[int, ...]]
- fixed_gprs, # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]]
- ):
- # type: (...) -> None
- self.gprs = gprs
- self.VLs = VLs
- self.CAs = CAs
- self.global_mems = global_mems
- self.stack_slots = stack_slots
- self.fixed_gprs = fixed_gprs
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-class Op(metaclass=ABCMeta):
- __slots__ = "id", "fn"
-
- @abstractmethod
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- ...
-
- @abstractmethod
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- ...
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- if False:
- yield ...
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- if False:
- yield ...
-
- def __init__(self, fn):
- # type: (Fn) -> None
- self.id = len(fn.ops)
- fn.ops.append(self)
- self.fn = fn
-
- @final
- def __repr__(self, just_id=False):
- fields_list = [f"#{self.id}"]
- outputs = None
- try:
- outputs = self.outputs()
- except AttributeError:
- pass
- if not just_id:
- for name in fields(self):
- if name in ("id", "fn"):
- continue
- v = getattr(self, name, _NOT_SET)
- if (outputs is not None and name in outputs
- and outputs[name] is v):
- fields_list.append(repr(v))
- else:
- fields_list.append(f"{name}={v!r}")
- fields_str = ', '.join(fields_list)
- return f"{self.__class__.__name__}({fields_str})"
-
- @abstractmethod
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- """get the lines of assembly for this Op"""
- ...
-
- @abstractmethod
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- """simulate op before register allocation"""
- ...
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpLoadFromStackSlot(Op):
- __slots__ = "dest", "src", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {"src": self.src} # type: dict[str, SSAVal[Any]]
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, fn, src, vl=None):
- # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
- super().__init__(fn)
- self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
- self.src = src
- self.vl = vl
- assert_vl_is(vl, self.dest.ty.length)
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- dest = ctx.gpr(self.dest, vec=self.dest.ty.length != 1)
- src = ctx.stack_slot(self.src)
- if ctx.needs_sv(self.dest):
- return [f"sv.ld {dest}, {src.start_byte}(1)"]
- return [f"ld {dest}, {src.start_byte}(1)"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- """simulate op before register allocation"""
- state.gprs[self.dest] = state.stack_slots[self.src]
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpStoreToStackSlot(Op):
- __slots__ = "dest", "src", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {"src": self.src} # type: dict[str, SSAVal[Any]]
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, fn, src, vl=None):
- # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
- super().__init__(fn)
- self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
- self.src = src
- self.vl = vl
- assert_vl_is(vl, src.ty.length)
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- src = ctx.gpr(self.src, vec=self.src.ty.length != 1)
- dest = ctx.stack_slot(self.dest)
- if ctx.needs_sv(self.src):
- return [f"sv.std {src}, {dest.start_byte}(1)"]
- return [f"std {src}, {dest.start_byte}(1)"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- """simulate op before register allocation"""
- state.stack_slots[self.dest] = state.gprs[self.src]
-
-
-_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpCopy(Op, Generic[_RegSrcType, _RegType]):
- __slots__ = "dest", "src", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {"src": self.src} # type: dict[str, SSAVal[Any]]
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, fn, src, dest_ty=None, vl=None):
- # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
- super().__init__(fn)
- if dest_ty is None:
- dest_ty = cast(_RegType, src.ty)
- if isinstance(src.ty, GPRRangeType) \
- and isinstance(dest_ty, FixedGPRRangeType):
- if src.ty.length != dest_ty.reg.length:
- raise ValueError(f"incompatible source and destination "
- f"types: {src.ty} and {dest_ty}")
- length = src.ty.length
- elif isinstance(src.ty, FixedGPRRangeType) \
- and isinstance(dest_ty, GPRRangeType):
- if src.ty.reg.length != dest_ty.length:
- raise ValueError(f"incompatible source and destination "
- f"types: {src.ty} and {dest_ty}")
- length = src.ty.length
- elif src.ty != dest_ty:
- raise ValueError(f"incompatible source and destination "
- f"types: {src.ty} and {dest_ty}")
- elif isinstance(src.ty, StackSlotType):
- raise ValueError("can't use OpCopy on stack slots")
- elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
- length = src.ty.length
- else:
- length = 1
-
- self.dest = SSAVal(self, "dest", dest_ty) # type: SSAVal[_RegType]
- self.src = src
- self.vl = vl
- assert_vl_is(vl, length)
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc):
- return []
- if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
- isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
- vec = self.dest.ty.length != 1
- dest = ctx.gpr_range(self.dest) # type: ignore
- src = ctx.gpr_range(self.src) # type: ignore
- dest_s = ctx.gpr(self.dest, vec=vec) # type: ignore
- src_s = ctx.gpr(self.src, vec=vec) # type: ignore
- mrr = ""
- if src.conflicts(dest) and src.start > dest.start:
- mrr = "/mrr"
- if ctx.needs_sv(self.src, self.dest): # type: ignore
- return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
- return [f"or {dest_s}, {src_s}, {src_s}"]
- raise NotImplementedError
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
- isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
- if isinstance(self.src.ty, GPRRangeType):
- v = state.gprs[self.src] # type: ignore
- else:
- v = state.fixed_gprs[self.src] # type: ignore
- if isinstance(self.dest.ty, GPRRangeType):
- state.gprs[self.dest] = v # type: ignore
- else:
- state.fixed_gprs[self.dest] = v # type: ignore
- elif (isinstance(self.src.ty, FixedGPRRangeType) and
- isinstance(self.dest.ty, GPRRangeType)):
- state.gprs[self.dest] = state.fixed_gprs[self.src] # type: ignore
- elif (isinstance(self.src.ty, GPRRangeType) and
- isinstance(self.dest.ty, FixedGPRRangeType)):
- state.fixed_gprs[self.dest] = state.gprs[self.src] # type: ignore
- elif (isinstance(self.src.ty, CAType) and
- self.src.ty == self.dest.ty):
- state.CAs[self.dest] = state.CAs[self.src] # type: ignore
- elif (isinstance(self.src.ty, KnownVLType) and
- self.src.ty == self.dest.ty):
- state.VLs[self.dest] = state.VLs[self.src] # type: ignore
- elif (isinstance(self.src.ty, GlobalMemType) and
- self.src.ty == self.dest.ty):
- v = state.global_mems[self.src] # type: ignore
- state.global_mems[self.dest] = v # type: ignore
- else:
- raise NotImplementedError
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpConcat(Op):
- __slots__ = "dest", "sources"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {f"sources[{i}]": v for i, v in enumerate(self.sources)}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"dest": self.dest}
-
- def __init__(self, fn, sources):
- # type: (Fn, Iterable[SSAGPRRange]) -> None
- super().__init__(fn)
- sources = tuple(sources)
- self.dest = SSAVal(self, "dest", GPRRangeType(
- sum(i.ty.length for i in sources)))
- self.sources = sources
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- yield EqualityConstraint([self.dest], [*self.sources])
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- return []
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- v = []
- for src in self.sources:
- v.extend(state.gprs[src])
- state.gprs[self.dest] = tuple(v)
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpSplit(Op):
- __slots__ = "results", "src"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"src": self.src}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {i.arg_name: i for i in self.results}
-
- def __init__(self, fn, src, split_indexes):
- # type: (Fn, SSAGPRRange, Iterable[int]) -> None
- super().__init__(fn)
- ranges = [] # type: list[GPRRangeType]
- last = 0
- for i in split_indexes:
- if not (0 < i < src.ty.length):
- raise ValueError(f"invalid split index: {i}, must be in "
- f"0 < i < {src.ty.length}")
- ranges.append(GPRRangeType(i - last))
- last = i
- ranges.append(GPRRangeType(src.ty.length - last))
- self.src = src
- self.results = tuple(
- SSAVal(self, f"results[{i}]", r) for i, r in enumerate(ranges))
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- yield EqualityConstraint([*self.results], [self.src])
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- return []
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- rest = state.gprs[self.src]
- for dest in reversed(self.results):
- state.gprs[dest] = rest[-dest.ty.length:]
- rest = rest[:-dest.ty.length]
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpBigIntAddSub(Op):
- __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {} # type: dict[str, SSAVal[Any]]
- retval["lhs"] = self.lhs
- retval["rhs"] = self.rhs
- retval["CA_in"] = self.CA_in
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out, "CA_out": self.CA_out}
-
- def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None):
- # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
- super().__init__(fn)
- if lhs.ty != rhs.ty:
- raise TypeError(f"source types must match: "
- f"{lhs} doesn't match {rhs}")
- self.out = SSAVal(self, "out", lhs.ty)
- self.lhs = lhs
- self.rhs = rhs
- self.CA_in = CA_in
- self.CA_out = SSAVal(self, "CA_out", CA_in.ty)
- self.is_sub = is_sub
- self.vl = vl
- assert_vl_is(vl, lhs.ty.length)
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.out, self.lhs
- yield self.out, self.rhs
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- vec = self.out.ty.length != 1
- out = ctx.gpr(self.out, vec=vec)
- RA = ctx.gpr(self.lhs, vec=vec)
- RB = ctx.gpr(self.rhs, vec=vec)
- mnemonic = "adde"
- if self.is_sub:
- mnemonic = "subfe"
- RA, RB = RB, RA # reorder to match subfe
- if ctx.needs_sv(self.out, self.lhs, self.rhs):
- return [f"sv.{mnemonic} {out}, {RA}, {RB}"]
- return [f"{mnemonic} {out}, {RA}, {RB}"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- carry = state.CAs[self.CA_in]
- out = [] # type: list[int]
- for l, r in zip(state.gprs[self.lhs], state.gprs[self.rhs]):
- if self.is_sub:
- r = r ^ GPR_VALUE_MASK
- s = l + r + carry
- carry = s != (s & GPR_VALUE_MASK)
- out.append(s & GPR_VALUE_MASK)
- state.CAs[self.CA_out] = carry
- state.gprs[self.out] = tuple(out)
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpBigIntMulDiv(Op):
- __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {} # type: dict[str, SSAVal[Any]]
- retval["RA"] = self.RA
- retval["RB"] = self.RB
- retval["RC"] = self.RC
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RT": self.RT, "RS": self.RS}
-
- def __init__(self, fn, RA, RB, RC, is_div, vl):
- # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
- super().__init__(fn)
- self.RT = SSAVal(self, "RT", RA.ty)
- self.RA = RA
- self.RB = RB
- self.RC = RC
- self.RS = SSAVal(self, "RS", RC.ty)
- self.is_div = is_div
- self.vl = vl
- assert_vl_is(vl, RA.ty.length)
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- yield EqualityConstraint([self.RC], [self.RS])
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.RT, self.RA
- yield self.RT, self.RB
- yield self.RT, self.RC
- yield self.RT, self.RS
- yield self.RS, self.RA
- yield self.RS, self.RB
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- vec = self.RT.ty.length != 1
- RT = ctx.gpr(self.RT, vec=vec)
- RA = ctx.gpr(self.RA, vec=vec)
- RB = ctx.sgpr(self.RB)
- RC = ctx.sgpr(self.RC)
- mnemonic = "maddedu"
- if self.is_div:
- mnemonic = "divmod2du/mrr"
- return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- carry = state.gprs[self.RC][0]
- RA = state.gprs[self.RA]
- RB = state.gprs[self.RB][0]
- RT = [0] * self.RT.ty.length
- if self.is_div:
- for i in reversed(range(self.RT.ty.length)):
- if carry < RB and RB != 0:
- div, mod = divmod((carry << 64) | RA[i], RB)
- RT[i] = div & GPR_VALUE_MASK
- carry = mod & GPR_VALUE_MASK
- else:
- RT[i] = GPR_VALUE_MASK
- carry = 0
- else:
- for i in range(self.RT.ty.length):
- v = RA[i] * RB + carry
- carry = v >> 64
- RT[i] = v & GPR_VALUE_MASK
- state.gprs[self.RS] = carry,
- state.gprs[self.RT] = tuple(RT)
-
-
-@final
-@unique
-class ShiftKind(Enum):
- Sl = "sl"
- Sr = "sr"
- Sra = "sra"
-
- def make_big_int_carry_in(self, fn, inp):
- # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
- if self is ShiftKind.Sl or self is ShiftKind.Sr:
- li = OpLI(fn, 0)
- return li.out, [li]
- else:
- assert self is ShiftKind.Sra
- split = OpSplit(fn, inp, [inp.ty.length - 1])
- shr = OpShiftImm(fn, split.results[1], sh=63, kind=ShiftKind.Sra)
- return shr.out, [split, shr]
-
- def make_big_int_shift(self, fn, inp, sh, vl):
- # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
- carry_in, ops = self.make_big_int_carry_in(fn, inp)
- big_int_shift = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl)
- ops.append(big_int_shift)
- return big_int_shift.out, ops
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpBigIntShift(Op):
- __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {} # type: dict[str, SSAVal[Any]]
- retval["inp"] = self.inp
- retval["sh"] = self.sh
- retval["carry_in"] = self.carry_in
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out, "_out_padding": self._out_padding}
-
- def __init__(self, fn, inp, sh, carry_in, kind, vl=None):
- # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
- super().__init__(fn)
- self.out = SSAVal(self, "out", inp.ty)
- self._out_padding = SSAVal(self, "_out_padding", GPRRangeType())
- self.carry_in = carry_in
- self.inp = inp
- self.sh = sh
- self.kind = kind
- self.vl = vl
- assert_vl_is(vl, inp.ty.length)
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.out, self.sh
-
- def get_equality_constraints(self):
- # type: () -> Iterable[EqualityConstraint]
- if self.kind is ShiftKind.Sl:
- yield EqualityConstraint([self.carry_in, self.inp],
- [self.out, self._out_padding])
- else:
- assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
- yield EqualityConstraint([self.inp, self.carry_in],
- [self._out_padding, self.out])
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- vec = self.out.ty.length != 1
- if self.kind is ShiftKind.Sl:
- RT = ctx.gpr(self.out, vec=vec)
- RA = ctx.gpr(self.out, vec=vec, offset=-1)
- RB = ctx.sgpr(self.sh)
- mrr = "/mrr" if vec else ""
- return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
- else:
- assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
- RT = ctx.gpr(self.out, vec=vec)
- RA = ctx.gpr(self.out, vec=vec, offset=1)
- RB = ctx.sgpr(self.sh)
- return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- out = [0] * self.out.ty.length
- carry = state.gprs[self.carry_in][0]
- sh = state.gprs[self.sh][0] % 64
- if self.kind is ShiftKind.Sl:
- inp = carry, *state.gprs[self.inp]
- for i in reversed(range(self.out.ty.length)):
- v = inp[i] | (inp[i + 1] << 64)
- v <<= sh
- out[i] = (v >> 64) & GPR_VALUE_MASK
- else:
- assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
- inp = *state.gprs[self.inp], carry
- for i in range(self.out.ty.length):
- v = inp[i] | (inp[i + 1] << 64)
- v >>= sh
- out[i] = v & GPR_VALUE_MASK
- # state.gprs[self._out_padding] is intentionally not written
- state.gprs[self.out] = tuple(out)
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpShiftImm(Op):
- __slots__ = "out", "inp", "sh", "kind", "ca_out"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {"inp": self.inp}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- if self.ca_out is not None:
- return {"out": self.out, "ca_out": self.ca_out}
- return {"out": self.out}
-
- def __init__(self, fn, inp, sh, kind):
- # type: (Fn, SSAGPR, int, ShiftKind) -> None
- super().__init__(fn)
- self.out = SSAVal(self, "out", inp.ty)
- self.inp = inp
- if not (0 <= sh < 64):
- raise ValueError("shift amount out of range")
- self.sh = sh
- self.kind = kind
- if self.kind is ShiftKind.Sra:
- self.ca_out = SSAVal(self, "ca_out", CAType())
- else:
- self.ca_out = None
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- out = ctx.sgpr(self.out)
- inp = ctx.sgpr(self.inp)
- if self.kind is ShiftKind.Sl:
- mnemonic = "rldicr"
- args = f"{self.sh}, {63 - self.sh}"
- elif self.kind is ShiftKind.Sr:
- mnemonic = "rldicl"
- v = (64 - self.sh) % 64
- args = f"{v}, {self.sh}"
- else:
- assert self.kind is ShiftKind.Sra
- mnemonic = "sradi"
- args = f"{self.sh}"
- if ctx.needs_sv(self.out, self.inp):
- return [f"sv.{mnemonic} {out}, {inp}, {args}"]
- return [f"{mnemonic} {out}, {inp}, {args}"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- inp = state.gprs[self.inp][0]
- if self.kind is ShiftKind.Sl:
- assert self.ca_out is None
- out = inp << self.sh
- elif self.kind is ShiftKind.Sr:
- assert self.ca_out is None
- out = inp >> self.sh
- else:
- assert self.kind is ShiftKind.Sra
- assert self.ca_out is not None
- if inp & (1 << 63): # sign extend
- inp -= 1 << 64
- out = inp >> self.sh
- ca = inp < 0 and (out << self.sh) != inp
- state.CAs[self.ca_out] = ca
- state.gprs[self.out] = out,
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpLI(Op):
- __slots__ = "out", "value", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {} # type: dict[str, SSAVal[Any]]
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self, fn, value, vl=None):
- # type: (Fn, int, SSAKnownVL | None) -> None
- super().__init__(fn)
- if vl is None:
- length = 1
- else:
- length = vl.ty.length
- self.out = SSAVal(self, "out", GPRRangeType(length))
- if not (-1 << 15 <= value <= (1 << 15) - 1):
- raise ValueError(f"value out of range: {value}")
- self.value = value
- self.vl = vl
- assert_vl_is(vl, length)
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- vec = self.out.ty.length != 1
- out = ctx.gpr(self.out, vec=vec)
- if ctx.needs_sv(self.out):
- return [f"sv.addi {out}, 0, {self.value}"]
- return [f"addi {out}, 0, {self.value}"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- value = self.value & GPR_VALUE_MASK
- state.gprs[self.out] = (value,) * self.out.ty.length
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpSetCA(Op):
- __slots__ = "out", "value"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self, fn, value):
- # type: (Fn, bool) -> None
- super().__init__(fn)
- self.out = SSAVal(self, "out", CAType())
- self.value = value
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- if self.value:
- return ["subfic 0, 0, -1"]
- return ["addic 0, 0, 0"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- state.CAs[self.out] = self.value
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpLoad(Op):
- __slots__ = "RT", "RA", "offset", "mem", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {} # type: dict[str, SSAVal[Any]]
- retval["RA"] = self.RA
- retval["mem"] = self.mem
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"RT": self.RT}
-
- def __init__(self, fn, RA, offset, mem, vl=None):
- # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
- super().__init__(fn)
- if vl is None:
- length = 1
- else:
- length = vl.ty.length
- self.RT = SSAVal(self, "RT", GPRRangeType(length))
- self.RA = RA
- if not (-1 << 15 <= offset <= (1 << 15) - 1):
- raise ValueError(f"offset out of range: {offset}")
- if offset % 4 != 0:
- raise ValueError(f"offset not aligned: {offset}")
- self.offset = offset
- self.mem = mem
- self.vl = vl
- assert_vl_is(vl, length)
-
- def get_extra_interferences(self):
- # type: () -> Iterable[tuple[SSAVal, SSAVal]]
- if self.RT.ty.length > 1:
- yield self.RT, self.RA
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1)
- RA = ctx.sgpr(self.RA)
- if ctx.needs_sv(self.RT, self.RA):
- return [f"sv.ld {RT}, {self.offset}({RA})"]
- return [f"ld {RT}, {self.offset}({RA})"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- addr = state.gprs[self.RA][0]
- addr += self.offset
- RT = [0] * self.RT.ty.length
- mem = state.global_mems[self.mem]
- for i in range(self.RT.ty.length):
- cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
- if cur_addr % GPR_SIZE_IN_BYTES != 0:
- raise ValueError(f"can't load from unaligned address: "
- f"{cur_addr:#x}")
- for j in range(GPR_SIZE_IN_BYTES):
- byte_val = mem.get(cur_addr + j, 0) & 0xFF
- RT[i] |= byte_val << (j * 8)
- state.gprs[self.RT] = tuple(RT)
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpStore(Op):
- __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl"
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- retval = {} # type: dict[str, SSAVal[Any]]
- retval["RS"] = self.RS
- retval["RA"] = self.RA
- retval["mem_in"] = self.mem_in
- if self.vl is not None:
- retval["vl"] = self.vl
- return retval
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"mem_out": self.mem_out}
-
- def __init__(self, fn, RS, RA, offset, mem_in, vl=None):
- # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
- super().__init__(fn)
- self.RS = RS
- self.RA = RA
- if not (-1 << 15 <= offset <= (1 << 15) - 1):
- raise ValueError(f"offset out of range: {offset}")
- if offset % 4 != 0:
- raise ValueError(f"offset not aligned: {offset}")
- self.offset = offset
- self.mem_in = mem_in
- self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
- self.vl = vl
- assert_vl_is(vl, RS.ty.length)
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1)
- RA = ctx.sgpr(self.RA)
- if ctx.needs_sv(self.RS, self.RA):
- return [f"sv.std {RS}, {self.offset}({RA})"]
- return [f"std {RS}, {self.offset}({RA})"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- mem = dict(state.global_mems[self.mem_in])
- addr = state.gprs[self.RA][0]
- addr += self.offset
- RS = state.gprs[self.RS]
- for i in range(self.RS.ty.length):
- cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK
- if cur_addr % GPR_SIZE_IN_BYTES != 0:
- raise ValueError(f"can't store to unaligned address: "
- f"{cur_addr:#x}")
- for j in range(GPR_SIZE_IN_BYTES):
- mem[cur_addr + j] = (RS[i] >> (j * 8)) & 0xFF
- state.global_mems[self.mem_out] = FMap(mem)
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpFuncArg(Op):
- __slots__ = "out",
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self, fn, ty):
- # type: (Fn, FixedGPRRangeType) -> None
- super().__init__(fn)
- self.out = SSAVal(self, "out", ty)
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- return []
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- if self.out not in state.fixed_gprs:
- state.fixed_gprs[self.out] = (0,) * self.out.ty.length
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpInputMem(Op):
- __slots__ = "out",
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self, fn):
- # type: (Fn) -> None
- super().__init__(fn)
- self.out = SSAVal(self, "out", GlobalMemType())
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- return []
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- if self.out not in state.global_mems:
- state.global_mems[self.out] = FMap()
-
-
-@plain_data(unsafe_hash=True, frozen=True, repr=False)
-@final
-class OpSetVLImm(Op):
- __slots__ = "out",
-
- def inputs(self):
- # type: () -> dict[str, SSAVal]
- return {}
-
- def outputs(self):
- # type: () -> dict[str, SSAVal]
- return {"out": self.out}
-
- def __init__(self, fn, length):
- # type: (Fn, int) -> None
- super().__init__(fn)
- self.out = SSAVal(self, "out", KnownVLType(length))
-
- def get_asm_lines(self, ctx):
- # type: (AsmContext) -> list[str]
- return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
-
- def pre_ra_sim(self, state):
- # type: (PreRASimState) -> None
- state.VLs[self.out] = self.out.ty.length
-
-
-def op_set_to_list(ops):
- # type: (Iterable[Op]) -> list[Op]
- worklists = [{}] # type: list[dict[Op, None]]
- inps_to_ops_map = defaultdict(dict) # type: dict[SSAVal, dict[Op, None]]
- ops_to_pending_input_count_map = {} # type: dict[Op, int]
- for op in ops:
- input_count = 0
- for val in op.inputs().values():
- input_count += 1
- inps_to_ops_map[val][op] = None
- while len(worklists) <= input_count:
- worklists.append({})
- ops_to_pending_input_count_map[op] = input_count
- worklists[input_count][op] = None
- retval = [] # type: list[Op]
- ready_vals = OSet() # type: OSet[SSAVal]
- while len(worklists[0]) != 0:
- writing_op = next(iter(worklists[0]))
- del worklists[0][writing_op]
- retval.append(writing_op)
- for val in writing_op.outputs().values():
- if val in ready_vals:
- raise ValueError(f"multiple instructions must not write "
- f"to the same SSA value: {val}")
- ready_vals.add(val)
- for reading_op in inps_to_ops_map[val]:
- pending = ops_to_pending_input_count_map[reading_op]
- del worklists[pending][reading_op]
- pending -= 1
- worklists[pending][reading_op] = None
- ops_to_pending_input_count_map[reading_op] = pending
- for worklist in worklists:
- for op in worklist:
- raise ValueError(f"instruction is part of a dependency loop or "
- f"its inputs are never written: {op}")
- return retval
-
-
-def generate_assembly(ops, assigned_registers=None):
- # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str]
- if assigned_registers is None:
- from bigint_presentation_code.register_allocator import \
- allocate_registers
- assigned_registers = allocate_registers(ops)
- ctx = AsmContext(assigned_registers)
- retval = [] # list[str]
- for op in ops:
- retval.extend(op.get_asm_lines(ctx))
- retval.append("bclr 20, 0, 0")
- return retval
+++ /dev/null
-"""
-Register Allocator for Toom-Cook algorithm generator for SVP64
-
-this uses an algorithm based on:
-[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf)
-"""
-
-from itertools import combinations
-from typing import Generic, Iterable, Mapping, TypeVar
-
-from nmutil.plain_data import plain_data
-
-from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
- RegLoc, RegType, SSAVal)
-from bigint_presentation_code.type_util import final
-from bigint_presentation_code.util import OFSet, OSet
-
-_RegType = TypeVar("_RegType", bound=RegType)
-
-
-@plain_data(unsafe_hash=True, order=True, frozen=True)
-class LiveInterval:
- __slots__ = "first_write", "last_use"
-
- def __init__(self, first_write, last_use=None):
- # type: (int, int | None) -> None
- if last_use is None:
- last_use = first_write
- if last_use < first_write:
- raise ValueError("uses must be after first_write")
- if first_write < 0 or last_use < 0:
- raise ValueError("indexes must be nonnegative")
- self.first_write = first_write
- self.last_use = last_use
-
- def overlaps(self, other):
- # type: (LiveInterval) -> bool
- if self.first_write == other.first_write:
- return True
- return self.last_use > other.first_write \
- and other.last_use > self.first_write
-
- def __add__(self, use):
- # type: (int) -> LiveInterval
- last_use = max(self.last_use, use)
- return LiveInterval(first_write=self.first_write, last_use=last_use)
-
- @property
- def live_after_op_range(self):
- """the range of op indexes where self is live immediately after the
- Op at each index
- """
- return range(self.first_write, self.last_use)
-
-
-@final
-class MergedRegSet(Mapping[SSAVal[_RegType], int]):
- def __init__(self, reg_set):
- # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
- self.__items = {} # type: dict[SSAVal[_RegType], int]
- if isinstance(reg_set, SSAVal):
- reg_set = [(reg_set, 0)]
- for ssa_val, offset in reg_set:
- if ssa_val in self.__items:
- other = self.__items[ssa_val]
- if offset != other:
- raise ValueError(
- f"can't merge register sets: conflicting offsets: "
- f"for {ssa_val}: {offset} != {other}")
- else:
- self.__items[ssa_val] = offset
- first_item = None
- for i in self.__items.items():
- first_item = i
- break
- if first_item is None:
- raise ValueError("can't have empty MergedRegs")
- first_ssa_val, start = first_item
- ty = first_ssa_val.ty
- if isinstance(ty, GPRRangeType):
- stop = start + ty.length
- for ssa_val, offset in self.__items.items():
- if not isinstance(ssa_val.ty, GPRRangeType):
- raise ValueError(f"can't merge incompatible types: "
- f"{ssa_val.ty} and {ty}")
- stop = max(stop, offset + ssa_val.ty.length)
- start = min(start, offset)
- ty = GPRRangeType(stop - start)
- else:
- stop = 1
- for ssa_val, offset in self.__items.items():
- if offset != 0:
- raise ValueError(f"can't have non-zero offset "
- f"for {ssa_val.ty}")
- if ty != ssa_val.ty:
- raise ValueError(f"can't merge incompatible types: "
- f"{ssa_val.ty} and {ty}")
- self.__start = start # type: int
- self.__stop = stop # type: int
- self.__ty = ty # type: RegType
- self.__hash = hash(OFSet(self.items()))
-
- @staticmethod
- def from_equality_constraint(constraint_sequence):
- # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
- if len(constraint_sequence) == 1:
- # any type allowed with len = 1
- return MergedRegSet(constraint_sequence[0])
- offset = 0
- retval = []
- for val in constraint_sequence:
- if not isinstance(val.ty, GPRRangeType):
- raise ValueError("equality constraint sequences must only "
- "have SSAVal type GPRRangeType")
- retval.append((val, offset))
- offset += val.ty.length
- return MergedRegSet(retval)
-
- @property
- def ty(self):
- return self.__ty
-
- @property
- def stop(self):
- return self.__stop
-
- @property
- def start(self):
- return self.__start
-
- @property
- def range(self):
- return range(self.__start, self.__stop)
-
- def offset_by(self, amount):
- # type: (int) -> MergedRegSet[_RegType]
- return MergedRegSet((k, v + amount) for k, v in self.items())
-
- def normalized(self):
- # type: () -> MergedRegSet[_RegType]
- return self.offset_by(-self.start)
-
- def with_offset_to_match(self, target):
- # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
- for ssa_val, offset in self.items():
- if ssa_val in target:
- return self.offset_by(target[ssa_val] - offset)
- raise ValueError("can't change offset to match unrelated MergedRegSet")
-
- def __getitem__(self, item):
- # type: (SSAVal[_RegType]) -> int
- return self.__items[item]
-
- def __iter__(self):
- return iter(self.__items)
-
- def __len__(self):
- return len(self.__items)
-
- def __hash__(self):
- return self.__hash
-
- def __repr__(self):
- return f"MergedRegSet({list(self.__items.items())})"
-
-
-@final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
- def __init__(self, ops):
- # type: (Iterable[Op]) -> None
- merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegType]]
- for op in ops:
- for val in (*op.inputs().values(), *op.outputs().values()):
- if val not in merged_sets:
- merged_sets[val] = MergedRegSet(val)
- for e in op.get_equality_constraints():
- lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
- rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
- items = [] # type: list[tuple[SSAVal, int]]
- for i in e.lhs:
- s = merged_sets[i].with_offset_to_match(lhs_set)
- items.extend(s.items())
- for i in e.rhs:
- s = merged_sets[i].with_offset_to_match(rhs_set)
- items.extend(s.items())
- full_set = MergedRegSet(items)
- for val in full_set.keys():
- merged_sets[val] = full_set
-
- self.__map = {k: v.normalized() for k, v in merged_sets.items()}
-
- def __getitem__(self, key):
- # type: (SSAVal) -> MergedRegSet
- return self.__map[key]
-
- def __iter__(self):
- return iter(self.__map)
-
- def __len__(self):
- return len(self.__map)
-
- def __repr__(self):
- return f"MergedRegSets(data={self.__map})"
-
-
-@final
-class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
- def __init__(self, ops):
- # type: (list[Op]) -> None
- self.__merged_reg_sets = MergedRegSets(ops)
- live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval]
- for op_idx, op in enumerate(ops):
- for val in op.inputs().values():
- live_intervals[self.__merged_reg_sets[val]] += op_idx
- for val in op.outputs().values():
- reg_set = self.__merged_reg_sets[val]
- if reg_set not in live_intervals:
- live_intervals[reg_set] = LiveInterval(op_idx)
- else:
- live_intervals[reg_set] += op_idx
- self.__live_intervals = live_intervals
- live_after = [] # type: list[OSet[MergedRegSet[_RegType]]]
- live_after += (OSet() for _ in ops)
- for reg_set, live_interval in self.__live_intervals.items():
- for i in live_interval.live_after_op_range:
- live_after[i].add(reg_set)
- self.__live_after = [OFSet(i) for i in live_after]
-
- @property
- def merged_reg_sets(self):
- return self.__merged_reg_sets
-
- def __getitem__(self, key):
- # type: (MergedRegSet[_RegType]) -> LiveInterval
- return self.__live_intervals[key]
-
- def __iter__(self):
- return iter(self.__live_intervals)
-
- def __len__(self):
- return len(self.__live_intervals)
-
- def reg_sets_live_after(self, op_index):
- # type: (int) -> OFSet[MergedRegSet[_RegType]]
- return self.__live_after[op_index]
-
- def __repr__(self):
- reg_sets_live_after = dict(enumerate(self.__live_after))
- return (f"LiveIntervals(live_intervals={self.__live_intervals}, "
- f"merged_reg_sets={self.merged_reg_sets}, "
- f"reg_sets_live_after={reg_sets_live_after})")
-
-
-@final
-class IGNode(Generic[_RegType]):
- """ interference graph node """
- __slots__ = "merged_reg_set", "edges", "reg"
-
- def __init__(self, merged_reg_set, edges=(), reg=None):
- # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
- self.merged_reg_set = merged_reg_set
- self.edges = OSet(edges)
- self.reg = reg
-
- def add_edge(self, other):
- # type: (IGNode) -> None
- self.edges.add(other)
- other.edges.add(self)
-
- def __eq__(self, other):
- # type: (object) -> bool
- if isinstance(other, IGNode):
- return self.merged_reg_set == other.merged_reg_set
- return NotImplemented
-
- def __hash__(self):
- return hash(self.merged_reg_set)
-
- def __repr__(self, nodes=None):
- # type: (None | dict[IGNode, int]) -> str
- if nodes is None:
- nodes = {}
- if self in nodes:
- return f"<IGNode #{nodes[self]}>"
- nodes[self] = len(nodes)
- edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}"
- return (f"IGNode(#{nodes[self]}, "
- f"merged_reg_set={self.merged_reg_set}, "
- f"edges={edges}, "
- f"reg={self.reg})")
-
- @property
- def reg_class(self):
- # type: () -> RegClass
- return self.merged_reg_set.ty.reg_class
-
- def reg_conflicts_with_neighbors(self, reg):
- # type: (RegLoc) -> bool
- for neighbor in self.edges:
- if neighbor.reg is not None and neighbor.reg.conflicts(reg):
- return True
- return False
-
-
-@final
-class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
- def __init__(self, merged_reg_sets):
- # type: (Iterable[MergedRegSet[_RegType]]) -> None
- self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
-
- def __getitem__(self, key):
- # type: (MergedRegSet[_RegType]) -> IGNode
- return self.__nodes[key]
-
- def __iter__(self):
- return iter(self.__nodes)
-
- def __len__(self):
- return len(self.__nodes)
-
- def __repr__(self):
- nodes = {}
- nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
- nodes_text = ", ".join(nodes_text)
- return f"InterferenceGraph(nodes={{{nodes_text}}})"
-
-
-@plain_data()
-class AllocationFailed:
- __slots__ = "node", "live_intervals", "interference_graph"
-
- def __init__(self, node, live_intervals, interference_graph):
- # type: (IGNode, LiveIntervals, InterferenceGraph) -> None
- self.node = node
- self.live_intervals = live_intervals
- self.interference_graph = interference_graph
-
-
-class AllocationFailedError(Exception):
- def __init__(self, msg, allocation_failed):
- # type: (str, AllocationFailed) -> None
- super().__init__(msg, allocation_failed)
- self.allocation_failed = allocation_failed
-
-
-def try_allocate_registers_without_spilling(ops):
- # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed
-
- live_intervals = LiveIntervals(ops)
- merged_reg_sets = live_intervals.merged_reg_sets
- interference_graph = InterferenceGraph(merged_reg_sets.values())
- for op_idx, op in enumerate(ops):
- reg_sets = live_intervals.reg_sets_live_after(op_idx)
- for i, j in combinations(reg_sets, 2):
- if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
- interference_graph[i].add_edge(interference_graph[j])
- for i, j in op.get_extra_interferences():
- i = merged_reg_sets[i]
- j = merged_reg_sets[j]
- if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
- interference_graph[i].add_edge(interference_graph[j])
-
- nodes_remaining = OSet(interference_graph.values())
-
- def local_colorability_score(node):
- # type: (IGNode) -> int
- """ returns a positive integer if node is locally colorable, returns
- zero or a negative integer if node isn't known to be locally
- colorable, the more negative the value, the less colorable
- """
- if node not in nodes_remaining:
- raise ValueError()
- retval = len(node.reg_class)
- for neighbor in node.edges:
- if neighbor in nodes_remaining:
- retval -= node.reg_class.max_conflicts_with(neighbor.reg_class)
- return retval
-
- node_stack = [] # type: list[IGNode]
- while True:
- best_node = None # type: None | IGNode
- best_score = 0
- for node in nodes_remaining:
- score = local_colorability_score(node)
- if best_node is None or score > best_score:
- best_node = node
- best_score = score
- if best_score > 0:
- # it's locally colorable, no need to find a better one
- break
-
- if best_node is None:
- break
- node_stack.append(best_node)
- nodes_remaining.remove(best_node)
-
- retval = {} # type: dict[SSAVal, RegLoc]
-
- while len(node_stack) > 0:
- node = node_stack.pop()
- if node.reg is not None:
- if node.reg_conflicts_with_neighbors(node.reg):
- return AllocationFailed(node=node,
- live_intervals=live_intervals,
- interference_graph=interference_graph)
- else:
- # pick the first non-conflicting register in node.reg_class, since
- # register classes are ordered from most preferred to least
- # preferred register.
- for reg in node.reg_class:
- if not node.reg_conflicts_with_neighbors(reg):
- node.reg = reg
- break
- if node.reg is None:
- return AllocationFailed(node=node,
- live_intervals=live_intervals,
- interference_graph=interference_graph)
-
- for ssa_val, offset in node.merged_reg_set.items():
- retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset)
-
- return retval
-
-
-def allocate_registers(ops):
- # type: (list[Op]) -> dict[SSAVal, RegLoc]
- retval = try_allocate_registers_without_spilling(ops)
- if isinstance(retval, AllocationFailed):
- # TODO: implement spilling
- raise AllocationFailedError(
- "spilling required but not yet implemented", retval)
- return retval