"""
Compiler IR for Toom-Cook algorithm generator for SVP64
+
+This assumes VL != 0 throughout.
"""
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from enum import Enum, EnumMeta, unique
from functools import lru_cache
-from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast
+from typing import (TYPE_CHECKING, Any, Generic, Iterable, Sequence, Type,
+ TypeVar, cast)
-from cached_property import cached_property
from nmutil.plain_data import fields, plain_data
from bigint_presentation_code.ordered_set import OFSet, OSet
@final
@unique
class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta):
    """XER status bits usable as register-allocation locations."""
    CA = "CA"  # the XER carry bit, threaded through adde/subfe chains

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        # NOTE(review): unlike VL.conflicts, this never reports a conflict,
        # not even with itself -- confirm the allocator relies on identity
        # rather than conflicts() for XER bits
        return False
@final
@unique
class VL(RegLoc, Enum, metaclass=ABCEnumMeta):
    """the SVP64 VL/MAXVL state, modeled as one allocatable location"""
    VL_MAXVL = "VL_MAXVL"
    """VL and MAXVL"""

    def conflicts(self, other):
        # type: (RegLoc) -> bool
        # two VL locations conflict only when they are the same member
        if isinstance(other, VL):
            return self == other
        return False
+
+
@final
class RegClass(OFSet[RegLoc]):
""" an ordered set of registers.
_RegType = TypeVar("_RegType", bound=RegType)
+_RegLoc = TypeVar("_RegLoc", bound=RegLoc)
@plain_data(frozen=True, eq=False)
+@final
class GPRRangeType(RegType):
__slots__ = "length",
- def __init__(self, length):
+ def __init__(self, length=1):
# type: (int) -> None
if length < 1 or length > GPR_COUNT:
raise ValueError("invalid length")
return RegClass(regs)
@property
+ @final
def reg_class(self):
# type: () -> RegClass
return GPRRangeType.__get_reg_class(self.length)
return hash(self.length)
-@plain_data(frozen=True, eq=False)
-@final
-class GPRType(GPRRangeType):
- __slots__ = ()
-
- def __init__(self, length=1):
- if length != 1:
- raise ValueError("length must be 1")
- super().__init__(length=1)
+GPRType = GPRRangeType
+"""a length=1 GPRRangeType"""
@plain_data(frozen=True, unsafe_hash=True)
@plain_data(frozen=True, unsafe_hash=True)
@final
class CAType(RegType):
    """the type of an SSA value holding the XER.CA (carry) bit"""
    __slots__ = ()

    @property
    def reg_class(self):
        # type: () -> RegClass
        # CA can only ever live in the XER carry bit
        return RegClass([XERBit.CA])
@plain_data(frozen=True, unsafe_hash=True)
return RegClass([GlobalMem.GlobalMem])
@plain_data(frozen=True, unsafe_hash=True)
@final
class KnownVLType(RegType):
    """the type of an SSA value holding VL/MAXVL set to a known constant"""
    __slots__ = "length",

    def __init__(self, length):
        # type: (int) -> None
        # length is the constant VL value; capped at 64 -- presumably
        # because vectors must fit in the 64 GPRs (TODO confirm bound)
        if not (0 < length <= 64):
            raise ValueError("invalid VL value")
        self.length = length

    @property
    def reg_class(self):
        # type: () -> RegClass
        return RegClass([VL.VL_MAXVL])
+
+
def assert_vl_is(vl, expected_vl):
    # type: (SSAKnownVL | KnownVLType | int | None, int) -> None
    """raise ValueError unless `vl` denotes exactly `expected_vl` elements"""
    if vl is None:
        actual = 1  # no VL given means scalar, i.e. a single element
    elif isinstance(vl, SSAVal):
        actual = vl.ty.length
    elif isinstance(vl, KnownVLType):
        actual = vl.length
    else:
        actual = vl
    if actual != expected_vl:
        raise ValueError(
            f"wrong VL: expected {expected_vl} got {actual}")
+
+
+STACK_SLOT_SIZE = 8
+
+
@plain_data(frozen=True, unsafe_hash=True)
@final
class StackSlot(RegLoc):
def stop_slot(self):
return self.start_slot + self.length_in_slots
+ @property
+ def start_byte(self):
+ return self.start_slot * STACK_SLOT_SIZE
+
def conflicts(self, other):
# type: (RegLoc) -> bool
if isinstance(other, StackSlot):
return f"SSAVal({fields_str})"
+SSAGPRRange = SSAVal[GPRRangeType]
+SSAGPR = SSAVal[GPRType]
+SSAKnownVL = SSAVal[KnownVLType]
+
+
@final
@plain_data(unsafe_hash=True, frozen=True)
class EqualityConstraint:
_NOT_SET = _NotSet()
@plain_data(frozen=True, unsafe_hash=True)
class AsmTemplateSegment(Generic[_RegType], metaclass=ABCMeta):
    """base class for a placeholder in an AsmTemplate that renders the
    register assigned to a particular SSA value
    """
    __slots__ = "ssa_val",

    def __init__(self, ssa_val):
        # type: (SSAVal[_RegType]) -> None
        self.ssa_val = ssa_val

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        # look up the register allocated to our SSA value and render it
        return self._render(regs[self.ssa_val])

    @abstractmethod
    def _render(self, reg):
        # type: (RegLoc) -> str
        """render the assigned register `reg` as assembly text"""
        ...
+
+
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSGPR(AsmTemplateSegment[GPRRangeType]):
    """renders as the GPR number assigned to `ssa_val`, plus `offset`"""
    __slots__ = "offset",

    def __init__(self, ssa_val, offset=0):
        # type: (SSAGPRRange, int) -> None
        super().__init__(ssa_val)
        self.offset = offset

    def _render(self, reg):
        # type: (RegLoc) -> str
        if not isinstance(reg, GPRRange):
            raise TypeError()
        return str(reg.start + self.offset)
+
+
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSStackSlot(AsmTemplateSegment[StackSlotType]):
    """renders as a `D(1)` memory operand for the assigned stack slot,
    i.e. a byte displacement off the stack pointer r1
    """
    __slots__ = ()

    def _render(self, reg):
        # type: (RegLoc) -> str
        if not isinstance(reg, StackSlot):
            raise TypeError()
        # ld/std displacements are in bytes, so use start_byte
        # (start_slot * STACK_SLOT_SIZE) -- matches the asm emitted by
        # OpLoadFromStackSlot/OpStoreToStackSlot; start_slot alone would
        # address the wrong slot for every slot past the first
        return f"{reg.start_byte}(1)"
+
+
@plain_data(frozen=True, unsafe_hash=True)
@final
class ATSCopyGPRRange(AsmTemplateSegment["GPRRangeType | FixedGPRRangeType"]):
    """renders as a GPR-range copy from `src_ssa_val` to `ssa_val`,
    emitting nothing when both were allocated the same registers
    """
    __slots__ = "src_ssa_val",

    def __init__(self, ssa_val, src_ssa_val):
        # type: (SSAVal[GPRRangeType | FixedGPRRangeType], SSAVal[GPRRangeType | FixedGPRRangeType]) -> None
        # NOTE(review): deliberately skips super().__init__ and assigns
        # ssa_val directly -- confirm this is equivalent under plain_data
        self.ssa_val = ssa_val
        self.src_ssa_val = src_ssa_val

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        src = regs[self.src_ssa_val]
        dest = regs[self.ssa_val]
        if not isinstance(dest, GPRRange):
            raise TypeError()
        if not isinstance(src, GPRRange):
            raise TypeError()
        if src.length != dest.length:
            raise ValueError()
        if src == dest:
            return ""  # same registers -- nothing to copy
        mrr = ""
        sv_ = "sv."
        if src.length == 1:
            # NOTE(review): the scalar path drops the sv. prefix but the
            # operands below are still rendered with a `*` vector marker --
            # confirm plain `or` accepts that form
            sv_ = ""
        elif src.conflicts(dest) and src.start > dest.start:
            # overlapping ranges: /mrr reverses element order so elements
            # aren't clobbered before being read
            # NOTE(review): confirm the direction -- memmove-style
            # reasoning suggests reversal is needed when dest > src instead
            mrr = "/mrr"
        return f"{sv_}or{mrr} *{dest.start}, *{src.start}, *{src.start}\n"

    def _render(self, reg):
        # type: (RegLoc) -> str
        raise TypeError("must call self.render")
+
+
@final
class AsmTemplate(Sequence["str | AsmTemplateSegment"]):
    """an immutable sequence of literal assembly text and
    AsmTemplateSegment placeholders; nested AsmTemplates are flattened
    into their constituent segments on construction
    """

    @staticmethod
    def __process_segments(segments):
        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> Iterable[str | AsmTemplateSegment]
        # flatten any nested AsmTemplate into its segments
        for segment in segments:
            if not isinstance(segment, AsmTemplate):
                yield segment
            else:
                yield from segment

    def __init__(self, segments=()):
        # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> None
        self.__segments = tuple(self.__process_segments(segments))

    def __getitem__(self, index):
        # type: (int) -> str | AsmTemplateSegment
        return self.__segments[index]

    def __len__(self):
        return len(self.__segments)

    def __iter__(self):
        return iter(self.__segments)

    def __hash__(self):
        # hashes as the tuple of segments, so equal templates collide
        return hash(self.__segments)

    def render(self, regs):
        # type: (dict[SSAVal, RegLoc]) -> str
        """substitute assigned registers into the template, returning the
        final assembly text
        """
        parts = []  # type: list[str]
        for segment in self:
            text = (segment.render(regs)
                    if isinstance(segment, AsmTemplateSegment)
                    else segment)
            parts.append(text)
        return "".join(parts)
+
+
@final
class AsmContext:
    """helper passed to Op.get_asm_lines that maps SSA values to the
    registers the allocator assigned them and renders them as operands
    """

    def __init__(self, assigned_registers):
        # type: (dict[SSAVal, RegLoc]) -> None
        self.__assigned_registers = assigned_registers

    def reg(self, ssa_val, expected_ty):
        # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc
        """return the register assigned to `ssa_val`

        raises ValueError when no register was assigned and TypeError when
        the assigned register isn't an `expected_ty` (or is a GPRRange of
        the wrong length)
        """
        try:
            reg = self.__assigned_registers[ssa_val]
        except KeyError as e:
            # chain the KeyError so the failed lookup stays visible
            raise ValueError(
                f"SSAVal not assigned a register: {ssa_val}") from e
        wrong_len = (isinstance(reg, GPRRange)
                     and reg.length != ssa_val.ty.length)
        if not isinstance(reg, expected_ty) or wrong_len:
            raise TypeError(
                f"SSAVal is assigned a register of the wrong type: "
                f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}")
        return reg

    def gpr_range(self, ssa_val):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange
        """the GPRRange assigned to `ssa_val`"""
        return self.reg(ssa_val, GPRRange)

    def stack_slot(self, ssa_val):
        # type: (SSAVal[StackSlotType]) -> StackSlot
        """the StackSlot assigned to `ssa_val`"""
        return self.reg(ssa_val, StackSlot)

    def gpr(self, ssa_val, vec, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str
        """render the GPR assigned to `ssa_val` (plus `offset`) as an asm
        operand, `*`-prefixed when `vec` is True
        """
        reg = self.gpr_range(ssa_val).start + offset
        # bool multiplication: "*" when vec, "" otherwise
        return "*" * vec + str(reg)

    def vgpr(self, ssa_val, offset=0):
        # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str
        """render as a vector (`*`-prefixed) GPR operand"""
        return self.gpr(ssa_val=ssa_val, vec=True, offset=offset)

    def sgpr(self, ssa_val, offset=0):
        # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str
        """render as a scalar GPR operand"""
        return self.gpr(ssa_val=ssa_val, vec=False, offset=offset)

    def needs_sv(self, *regs):
        # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool
        """True when any of `regs` needs the SVP64 `sv.` prefix: vectors,
        or registers past r31 (unreachable via 5-bit register fields)
        """
        for reg in regs:
            reg = self.gpr_range(reg)
            if reg.length != 1 or reg.start >= 32:
                return True
        return False
+
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
__slots__ = "id", "fn"
fields_str = ', '.join(fields_list)
return f"{self.__class__.__name__}({fields_str})"
+ @abstractmethod
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ """get the lines of assembly for this Op"""
+ ...
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoadFromStackSlot(Op):
- __slots__ = "dest", "src"
+ __slots__ = "dest", "src", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"src": self.src}
+ retval = {"src": self.src} # type: dict[str, SSAVal[Any]]
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
return {"dest": self.dest}
- def __init__(self, fn, src):
- # type: (Fn, SSAVal[GPRRangeType]) -> None
+ def __init__(self, fn, src, vl=None):
+ # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None
super().__init__(fn)
- self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
+ self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
self.src = src
+ self.vl = vl
+ assert_vl_is(vl, self.dest.ty.length)
+
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ dest = ctx.gpr(self.dest, vec=self.dest.ty.length != 1)
+ src = ctx.stack_slot(self.src)
+ if ctx.needs_sv(self.dest):
+ return [f"sv.ld {dest}, {src.start_byte}(1)"]
+ return [f"ld {dest}, {src.start_byte}(1)"]
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStoreToStackSlot(Op):
- __slots__ = "dest", "src"
+ __slots__ = "dest", "src", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"src": self.src}
+ retval = {"src": self.src} # type: dict[str, SSAVal[Any]]
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
return {"dest": self.dest}
- def __init__(self, fn, src):
- # type: (Fn, SSAVal[StackSlotType]) -> None
+ def __init__(self, fn, src, vl=None):
+ # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None
super().__init__(fn)
- self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
+ self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
self.src = src
+ self.vl = vl
+ assert_vl_is(vl, src.ty.length)
+
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ src = ctx.gpr(self.src, vec=self.src.ty.length != 1)
+ dest = ctx.stack_slot(self.dest)
+ if ctx.needs_sv(self.src):
+ return [f"sv.std {src}, {dest.start_byte}(1)"]
+ return [f"std {src}, {dest.start_byte}(1)"]
_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpCopy(Op, Generic[_RegSrcType, _RegType]):
- __slots__ = "dest", "src"
+ __slots__ = "dest", "src", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"src": self.src}
+ retval = {"src": self.src} # type: dict[str, SSAVal[Any]]
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
return {"dest": self.dest}
- def __init__(self, fn, src, dest_ty=None):
- # type: (Fn, SSAVal[_RegSrcType], _RegType | None) -> None
+ def __init__(self, fn, src, dest_ty=None, vl=None):
+ # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None
super().__init__(fn)
if dest_ty is None:
dest_ty = cast(_RegType, src.ty)
if src.ty.length != dest_ty.reg.length:
raise ValueError(f"incompatible source and destination "
f"types: {src.ty} and {dest_ty}")
+ length = src.ty.length
elif isinstance(src.ty, FixedGPRRangeType) \
and isinstance(dest_ty, GPRRangeType):
if src.ty.reg.length != dest_ty.length:
raise ValueError(f"incompatible source and destination "
f"types: {src.ty} and {dest_ty}")
+ length = src.ty.length
elif src.ty != dest_ty:
raise ValueError(f"incompatible source and destination "
f"types: {src.ty} and {dest_ty}")
+ elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)):
+ length = src.ty.length
+ else:
+ length = 1
self.dest = SSAVal(self, "dest", dest_ty) # type: SSAVal[_RegType]
self.src = src
+ self.vl = vl
+ assert_vl_is(vl, length)
+
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc):
+ return []
+ if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and
+ isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))):
+ vec = self.dest.ty.length != 1
+ dest = ctx.gpr_range(self.dest) # type: ignore
+ src = ctx.gpr_range(self.src) # type: ignore
+ dest_s = ctx.gpr(self.dest, vec=vec) # type: ignore
+ src_s = ctx.gpr(self.src, vec=vec) # type: ignore
+ mrr = ""
+ if src.conflicts(dest) and src.start > dest.start:
+ mrr = "/mrr"
+ if ctx.needs_sv(self.src, self.dest): # type: ignore
+ return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"]
+ return [f"or {dest_s}, {src_s}, {src_s}"]
+ raise NotImplementedError
@plain_data(unsafe_hash=True, frozen=True, repr=False)
return {"dest": self.dest}
def __init__(self, fn, sources):
- # type: (Fn, Iterable[SSAVal[GPRRangeType]]) -> None
+ # type: (Fn, Iterable[SSAGPRRange]) -> None
super().__init__(fn)
sources = tuple(sources)
self.dest = SSAVal(self, "dest", GPRRangeType(
# type: () -> Iterable[EqualityConstraint]
yield EqualityConstraint([self.dest], [*self.sources])
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ return []
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
return {i.arg_name: i for i in self.results}
def __init__(self, fn, src, split_indexes):
- # type: (Fn, SSAVal[GPRRangeType], Iterable[int]) -> None
+ # type: (Fn, SSAGPRRange, Iterable[int]) -> None
super().__init__(fn)
ranges = [] # type: list[GPRRangeType]
last = 0
# type: () -> Iterable[EqualityConstraint]
yield EqualityConstraint([*self.results], [self.src])
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ return []
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
-class OpAddSubE(Op):
- __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
+class OpBigIntAddSub(Op):
+ __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in}
+ retval = {} # type: dict[str, SSAVal[Any]]
+ retval["lhs"] = self.lhs
+ retval["rhs"] = self.rhs
+ retval["CA_in"] = self.CA_in
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
- return {"RT": self.RT, "CY_out": self.CY_out}
+ return {"out": self.out, "CA_out": self.CA_out}
- def __init__(self, fn, RA, RB, CY_in, is_sub):
- # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+ def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None):
+ # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None
super().__init__(fn)
- if RA.ty != RB.ty:
+ if lhs.ty != rhs.ty:
raise TypeError(f"source types must match: "
- f"{RA} doesn't match {RB}")
- self.RT = SSAVal(self, "RT", RA.ty)
- self.RA = RA
- self.RB = RB
- self.CY_in = CY_in
- self.CY_out = SSAVal(self, "CY_out", CY_in.ty)
+ f"{lhs} doesn't match {rhs}")
+ self.out = SSAVal(self, "out", lhs.ty)
+ self.lhs = lhs
+ self.rhs = rhs
+ self.CA_in = CA_in
+ self.CA_out = SSAVal(self, "CA_out", CA_in.ty)
self.is_sub = is_sub
+ self.vl = vl
+ assert_vl_is(vl, lhs.ty.length)
def get_extra_interferences(self):
# type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.RT, self.RA
- yield self.RT, self.RB
+ yield self.out, self.lhs
+ yield self.out, self.rhs
+
    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """emit one carry-chained add/subtract across the whole range"""
        vec = self.out.ty.length != 1
        out = ctx.gpr(self.out, vec=vec)
        RA = ctx.gpr(self.lhs, vec=vec)
        RB = ctx.gpr(self.rhs, vec=vec)
        mnemonic = "adde"
        if self.is_sub:
            mnemonic = "subfe"
            # subfe computes RT = ~RA + RB + CA (i.e. RB - RA with borrow),
            # so swap the operands to get lhs - rhs
            RA, RB = RB, RA  # reorder to match subfe
        if ctx.needs_sv(self.out, self.lhs, self.rhs):
            return [f"sv.{mnemonic} {out}, {RA}, {RB}"]
        return [f"{mnemonic} {out}, {RA}, {RB}"]
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntMulDiv(Op):
- __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
+ __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"RA": self.RA, "RB": self.RB, "RC": self.RC}
+ retval = {} # type: dict[str, SSAVal[Any]]
+ retval["RA"] = self.RA
+ retval["RB"] = self.RB
+ retval["RC"] = self.RC
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
return {"RT": self.RT, "RS": self.RS}
- def __init__(self, fn, RA, RB, RC, is_div):
- # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+ def __init__(self, fn, RA, RB, RC, is_div, vl):
+ # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None
super().__init__(fn)
self.RT = SSAVal(self, "RT", RA.ty)
self.RA = RA
self.RC = RC
self.RS = SSAVal(self, "RS", RC.ty)
self.is_div = is_div
+ self.vl = vl
+ assert_vl_is(vl, RA.ty.length)
def get_equality_constraints(self):
# type: () -> Iterable[EqualityConstraint]
yield self.RS, self.RA
yield self.RS, self.RB
    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """emit a bigint-by-scalar multiply (maddedu) or divide
        (divmod2du), with the carry word chained through RC/RS
        """
        vec = self.RT.ty.length != 1
        RT = ctx.gpr(self.RT, vec=vec)
        RA = ctx.gpr(self.RA, vec=vec)
        RB = ctx.sgpr(self.RB)
        RC = ctx.sgpr(self.RC)
        mnemonic = "maddedu"
        if self.is_div:
            # presumably division must walk from the most-significant
            # element downward, hence /mrr (reversed element order) --
            # confirm against the SVP64 divmod2du spec
            mnemonic = "divmod2du/mrr"
        return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"]
+
@final
@unique
Sr = "sr"
Sra = "sra"
    def make_big_int_carry_in(self, fn, inp):
        # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]]
        """build the word shifted in at the boundary of a bigint shift:
        zero for logical shifts (Sl/Sr), or the sign word of `inp` (its
        top word shifted right arithmetically by 63) for Sra

        returns the fill word's SSA value and the ops that compute it
        """
        if self is ShiftKind.Sl or self is ShiftKind.Sr:
            li = OpLI(fn, 0)
            return li.out, [li]
        else:
            assert self is ShiftKind.Sra
            # split off the most-significant word, then broadcast its
            # sign bit with sradi 63
            split = OpSplit(fn, inp, [inp.ty.length - 1])
            shr = OpShiftImm(fn, split.results[1], sh=63, kind=ShiftKind.Sra)
            return shr.out, [split, shr]
+
    def make_big_int_shift(self, fn, inp, sh, vl):
        # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]]
        """build a whole-bigint shift of `inp` by the amount in GPR `sh`

        returns the shifted SSA value and the ops performing the shift
        """
        carry_in, ops = self.make_big_int_carry_in(fn, inp)
        big_int_shift = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl)
        ops.append(big_int_shift)
        return big_int_shift.out, ops
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntShift(Op):
- __slots__ = "RT", "inp", "sh", "kind"
+ __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"inp": self.inp, "sh": self.sh}
+ retval = {} # type: dict[str, SSAVal[Any]]
+ retval["inp"] = self.inp
+ retval["sh"] = self.sh
+ retval["carry_in"] = self.carry_in
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
- return {"RT": self.RT}
+ return {"out": self.out, "_out_padding": self._out_padding}
- def __init__(self, fn, inp, sh, kind):
- # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+ def __init__(self, fn, inp, sh, carry_in, kind, vl=None):
+ # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None
super().__init__(fn)
- self.RT = SSAVal(self, "RT", inp.ty)
+ self.out = SSAVal(self, "out", inp.ty)
+ self._out_padding = SSAVal(self, "_out_padding", GPRRangeType())
+ self.carry_in = carry_in
self.inp = inp
self.sh = sh
self.kind = kind
+ self.vl = vl
+ assert_vl_is(vl, inp.ty.length)
def get_extra_interferences(self):
# type: () -> Iterable[tuple[SSAVal, SSAVal]]
- yield self.RT, self.inp
- yield self.RT, self.sh
+ yield self.out, self.sh
+
+ def get_equality_constraints(self):
+ # type: () -> Iterable[EqualityConstraint]
+ if self.kind is ShiftKind.Sl:
+ yield EqualityConstraint([self.carry_in, self.inp],
+ [self.out, self._out_padding])
+ else:
+ assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
+ yield EqualityConstraint([self.inp, self.carry_in],
+ [self._out_padding, self.out])
+
    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        """emit a double-register shift across the range using dsld/dsrd

        relies on get_equality_constraints having placed carry_in/inp
        adjacent to out, so reading out at offset -1/+1 picks up the word
        shifted in at each element boundary
        """
        vec = self.out.ty.length != 1
        if self.kind is ShiftKind.Sl:
            RT = ctx.gpr(self.out, vec=vec)
            # each output word also reads the previous (lower) register
            RA = ctx.gpr(self.out, vec=vec, offset=-1)
            RB = ctx.sgpr(self.sh)
            # /mrr: process elements in reverse -- presumably so the
            # in-place left shift reads each word before overwriting it
            mrr = "/mrr" if vec else ""
            return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"]
        else:
            assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra
            RT = ctx.gpr(self.out, vec=vec)
            # right shifts read the next (higher) register and can run in
            # normal ascending element order
            RA = ctx.gpr(self.out, vec=vec, offset=1)
            RB = ctx.sgpr(self.sh)
            return [f"sv.dsrd {RT}, {RA}, {RB}, 1"]
+
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpShiftImm(Op):
    """shift a single GPR left/right by an immediate amount

    for Sra it also produces the carry-out that sradi writes to XER.CA
    """
    __slots__ = "out", "inp", "sh", "kind", "ca_out"

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {"inp": self.inp}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        if self.ca_out is not None:
            return {"out": self.out, "ca_out": self.ca_out}
        return {"out": self.out}

    def __init__(self, fn, inp, sh, kind):
        # type: (Fn, SSAGPR, int, ShiftKind) -> None
        super().__init__(fn)
        self.out = SSAVal(self, "out", inp.ty)
        self.inp = inp
        if not (0 <= sh < 64):
            raise ValueError("shift amount out of range")
        self.sh = sh
        self.kind = kind
        if self.kind is ShiftKind.Sra:
            # only sradi writes a carry (XER.CA), so only Sra has ca_out
            self.ca_out = SSAVal(self, "ca_out", CAType())
        else:
            self.ca_out = None

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        out = ctx.sgpr(self.out)
        inp = ctx.sgpr(self.inp)
        if self.kind is ShiftKind.Sl:
            # sldi out, inp, sh == rldicr out, inp, sh, 63-sh
            mnemonic = "rldicr"
            args = f"{self.sh}, {63 - self.sh}"
        elif self.kind is ShiftKind.Sr:
            # srdi out, inp, sh == rldicl out, inp, 64-sh, sh
            # (% 64 makes sh == 0 rotate by 0 rather than 64)
            mnemonic = "rldicl"
            v = (64 - self.sh) % 64
            args = f"{v}, {self.sh}"
        else:
            assert self.kind is ShiftKind.Sra
            mnemonic = "sradi"
            args = f"{self.sh}"
        if ctx.needs_sv(self.out, self.inp):
            return [f"sv.{mnemonic} {out}, {inp}, {args}"]
        return [f"{mnemonic} {out}, {inp}, {args}"]
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLI(Op):
- __slots__ = "out", "value"
+ __slots__ = "out", "value", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {}
+ retval = {} # type: dict[str, SSAVal[Any]]
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
return {"out": self.out}
- def __init__(self, fn, value, length=1):
- # type: (Fn, int, int) -> None
+ def __init__(self, fn, value, vl=None):
+ # type: (Fn, int, SSAKnownVL | None) -> None
super().__init__(fn)
+ if vl is None:
+ length = 1
+ else:
+ length = vl.ty.length
self.out = SSAVal(self, "out", GPRRangeType(length))
+ if not (-1 << 15 <= value <= (1 << 15) - 1):
+ raise ValueError(f"value out of range: {value}")
self.value = value
+ self.vl = vl
+ assert_vl_is(vl, length)
+
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ vec = self.out.ty.length != 1
+ out = ctx.gpr(self.out, vec=vec)
+ if ctx.needs_sv(self.out):
+ return [f"sv.addi {out}, 0, {self.value}"]
+ return [f"addi {out}, 0, {self.value}"]
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
-class OpClearCY(Op):
- __slots__ = "out",
+class OpSetCA(Op):
+ __slots__ = "out", "value"
def inputs(self):
# type: () -> dict[str, SSAVal]
# type: () -> dict[str, SSAVal]
return {"out": self.out}
- def __init__(self, fn):
- # type: (Fn) -> None
+ def __init__(self, fn, value):
+ # type: (Fn, bool) -> None
super().__init__(fn)
- self.out = SSAVal(self, "out", CYType())
+ self.out = SSAVal(self, "out", CAType())
+ self.value = value
+
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ if self.value:
+ return ["subfic 0, 0, -1"]
+ return ["addic 0, 0, 0"]
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoad(Op):
- __slots__ = "RT", "RA", "offset", "mem"
+ __slots__ = "RT", "RA", "offset", "mem", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"RA": self.RA, "mem": self.mem}
+ retval = {} # type: dict[str, SSAVal[Any]]
+ retval["RA"] = self.RA
+ retval["mem"] = self.mem
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
return {"RT": self.RT}
- def __init__(self, fn, RA, offset, mem, length=1):
- # type: (Fn, SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+ def __init__(self, fn, RA, offset, mem, vl=None):
+ # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
super().__init__(fn)
+ if vl is None:
+ length = 1
+ else:
+ length = vl.ty.length
self.RT = SSAVal(self, "RT", GPRRangeType(length))
self.RA = RA
+ if not (-1 << 15 <= offset <= (1 << 15) - 1):
+ raise ValueError(f"offset out of range: {offset}")
+ if offset % 4 != 0:
+ raise ValueError(f"offset not aligned: {offset}")
self.offset = offset
self.mem = mem
+ self.vl = vl
+ assert_vl_is(vl, length)
def get_extra_interferences(self):
# type: () -> Iterable[tuple[SSAVal, SSAVal]]
if self.RT.ty.length > 1:
yield self.RT, self.RA
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1)
+ RA = ctx.sgpr(self.RA)
+ if ctx.needs_sv(self.RT, self.RA):
+ return [f"sv.ld {RT}, {self.offset}({RA})"]
+ return [f"ld {RT}, {self.offset}({RA})"]
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStore(Op):
- __slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
+ __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl"
def inputs(self):
# type: () -> dict[str, SSAVal]
- return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in}
+ retval = {} # type: dict[str, SSAVal[Any]]
+ retval["RS"] = self.RS
+ retval["RA"] = self.RA
+ retval["mem_in"] = self.mem_in
+ if self.vl is not None:
+ retval["vl"] = self.vl
+ return retval
def outputs(self):
# type: () -> dict[str, SSAVal]
return {"mem_out": self.mem_out}
- def __init__(self, fn, RS, RA, offset, mem_in):
- # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+ def __init__(self, fn, RS, RA, offset, mem_in, vl=None):
+ # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None
super().__init__(fn)
self.RS = RS
self.RA = RA
+ if not (-1 << 15 <= offset <= (1 << 15) - 1):
+ raise ValueError(f"offset out of range: {offset}")
+ if offset % 4 != 0:
+ raise ValueError(f"offset not aligned: {offset}")
self.offset = offset
self.mem_in = mem_in
self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
+ self.vl = vl
+ assert_vl_is(vl, RS.ty.length)
+
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1)
+ RA = ctx.sgpr(self.RA)
+ if ctx.needs_sv(self.RS, self.RA):
+ return [f"sv.std {RS}, {self.offset}({RA})"]
+ return [f"std {RS}, {self.offset}({RA})"]
@plain_data(unsafe_hash=True, frozen=True, repr=False)
super().__init__(fn)
self.out = SSAVal(self, "out", ty)
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ return []
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
super().__init__(fn)
self.out = SSAVal(self, "out", GlobalMemType())
+ def get_asm_lines(self, ctx):
+ # type: (AsmContext) -> list[str]
+ return []
+
+
@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSetVLImm(Op):
    """set VL and MAXVL to a compile-time constant via setvl"""
    __slots__ = "out",

    def inputs(self):
        # type: () -> dict[str, SSAVal]
        return {}

    def outputs(self):
        # type: () -> dict[str, SSAVal]
        return {"out": self.out}

    def __init__(self, fn, length):
        # type: (Fn, int) -> None
        super().__init__(fn)
        # range checking of length happens inside KnownVLType
        self.out = SSAVal(self, "out", KnownVLType(length))

    def get_asm_lines(self, ctx):
        # type: (AsmContext) -> list[str]
        # presumably: setvl Rt=0, Ra=0, SVi=length, vf=0, vs=1, ms=1,
        # setting both VL and MAXVL to the immediate -- confirm operand
        # order against the SVP64 setvl specification
        return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"]
+
def op_set_to_list(ops):
# type: (Iterable[Op]) -> list[Op]
import unittest
-from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange,
- GPRType, GlobalMem, Op, OpAddSubE,
- OpClearCY, OpConcat, OpCopy,
- OpFuncArg, OpInputMem, OpLI,
- OpLoad, OpStore, XERBit)
+from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn,
+ GlobalMem, GPRRange, GPRType,
+ OpBigIntAddSub, OpConcat,
+ OpCopy, OpFuncArg,
+ OpInputMem, OpLI, OpLoad,
+ OpSetCA, OpSetVLImm, OpStore,
+ XERBit)
from bigint_presentation_code.register_allocator import (
- AllocationFailed, allocate_registers, MergedRegSet,
+ AllocationFailed, MergedRegSet, allocate_registers,
try_allocate_registers_without_spilling)
def test_from_equality_constraint(self):
fn = Fn()
- op0 = OpLI(fn, 0, length=1)
- op1 = OpLI(fn, 0, length=2)
- op2 = OpLI(fn, 0, length=3)
+ li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out)
+ li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out)
+ li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out)
self.assertEqual(MergedRegSet.from_equality_constraint([
- op0.out,
- op1.out,
- op2.out,
+ li0x1.out,
+ li0x2.out,
+ li0x3.out,
]), MergedRegSet({
- op0.out: 0,
- op1.out: 1,
- op2.out: 3,
+ li0x1.out: 0,
+ li0x2.out: 1,
+ li0x3.out: 3,
}.items()))
self.assertEqual(MergedRegSet.from_equality_constraint([
- op1.out,
- op0.out,
- op2.out,
+ li0x2.out,
+ li0x1.out,
+ li0x3.out,
]), MergedRegSet({
- op1.out: 0,
- op0.out: 2,
- op2.out: 3,
+ li0x2.out: 0,
+ li0x1.out: 2,
+ li0x3.out: 3,
}.items()))
def test_try_alloc_fail(self):
fn = Fn()
- op0 = OpLI(fn, 0, length=52)
- op1 = OpLI(fn, 0, length=64)
- op2 = OpConcat(fn, [op0.out, op1.out])
+ op0 = OpSetVLImm(fn, 52)
+ op1 = OpLI(fn, 0, vl=op0.out)
+ op2 = OpSetVLImm(fn, 64)
+ op3 = OpLI(fn, 0, vl=op2.out)
+ op4 = OpConcat(fn, [op1.out, op3.out])
reg_assignments = try_allocate_registers_without_spilling(fn.ops)
self.assertEqual(
repr(reg_assignments),
"AllocationFailed("
"node=IGNode(#0, merged_reg_set=MergedRegSet(["
- "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
"edges={}, reg=None), "
"live_intervals=LiveIntervals("
"live_intervals={"
- "MergedRegSet([(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]): "
- "LiveInterval(first_write=0, last_use=2)}, "
+ "MergedRegSet([(<#0.out>, 0)]): "
+ "LiveInterval(first_write=0, last_use=1), "
+ "MergedRegSet([(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]): "
+ "LiveInterval(first_write=1, last_use=4), "
+ "MergedRegSet([(<#2.out>, 0)]): "
+ "LiveInterval(first_write=2, last_use=3)}, "
"merged_reg_sets=MergedRegSets(data={"
- "<#0.out>: MergedRegSet(["
- "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+ "<#0.out>: MergedRegSet([(<#0.out>, 0)]), "
"<#1.out>: MergedRegSet(["
- "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
- "<#2.dest>: MergedRegSet(["
- "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])}), "
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+ "<#2.out>: MergedRegSet([(<#2.out>, 0)]), "
+ "<#3.out>: MergedRegSet(["
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+ "<#4.dest>: MergedRegSet(["
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])}), "
"reg_sets_live_after={"
- "0: OFSet([MergedRegSet(["
- "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+ "0: OFSet([MergedRegSet([(<#0.out>, 0)])]), "
"1: OFSet([MergedRegSet(["
- "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
- "2: OFSet()}), "
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+ "2: OFSet([MergedRegSet(["
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+ "MergedRegSet([(<#2.out>, 0)])]), "
+ "3: OFSet([MergedRegSet(["
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), "
+ "4: OFSet()}), "
"interference_graph=InterferenceGraph(nodes={"
- "...: IGNode(#0, "
- "merged_reg_set=MergedRegSet(["
- "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+ "...: IGNode(#0, merged_reg_set=MergedRegSet([(<#0.out>, 0)]), "
+ "edges={}, reg=None), "
+ "...: IGNode(#1, merged_reg_set=MergedRegSet(["
+ "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), "
+ "edges={}, reg=None), "
+ "...: IGNode(#2, merged_reg_set=MergedRegSet([(<#2.out>, 0)]), "
"edges={}, reg=None)}))"
)
arg = op1.dest
op2 = OpInputMem(fn)
mem = op2.out
- op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32)
- a = op3.RT
- op4 = OpLI(fn, 1)
- b_0 = op4.out
- op5 = OpLI(fn, 0, length=31)
- b_rest = op5.out
- op6 = OpConcat(fn, [b_0, b_rest])
- b = op6.dest
- op7 = OpClearCY(fn)
- cy = op7.out
- op8 = OpAddSubE(fn, a, b, cy, is_sub=False)
- s = op8.RT
- op9 = OpStore(fn, s, arg, offset=0, mem_in=mem)
- mem = op9.mem_out
+ op3 = OpSetVLImm(fn, 32)
+ vl = op3.out
+ op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl)
+ a = op4.RT
+ op5 = OpLI(fn, 0, vl=vl)
+ b = op5.out
+ op6 = OpSetCA(fn, True)
+ ca = op6.out
+ op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl)
+ s = op7.out
+ op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl)
+ mem = op8.mem_out
reg_assignments = try_allocate_registers_without_spilling(fn.ops)
op0.out: GPRRange(start=3, length=1),
op1.dest: GPRRange(start=3, length=1),
op2.out: GlobalMem.GlobalMem,
- op3.RT: GPRRange(start=78, length=32),
- op4.out: GPRRange(start=46, length=1),
- op5.out: GPRRange(start=47, length=31),
- op6.dest: GPRRange(start=46, length=32),
- op7.out: XERBit.CY,
- op8.RT: GPRRange(start=14, length=32),
- op8.CY_out: XERBit.CY,
- op9.mem_out: GlobalMem.GlobalMem,
+ op3.out: VL.VL_MAXVL,
+ op4.RT: GPRRange(start=78, length=32),
+ op5.out: GPRRange(start=46, length=32),
+ op6.out: XERBit.CA,
+ op7.out: GPRRange(start=14, length=32),
+ op7.CA_out: XERBit.CA,
+ op8.mem_out: GlobalMem.GlobalMem,
}
self.assertEqual(reg_assignments, expected_reg_assignments)
def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
# type: (list[GPRRange], GPRRange) -> None
fn = Fn()
- li_ops = [OpLI(fn, i, r.length) for i, r in enumerate(expected_regs)]
- concat = OpConcat(fn, [i.out for i in li_ops])
+ inputs = []
+ expected_reg_assignments = {}
+ for i, r in enumerate(expected_regs):
+ vl = OpSetVLImm(fn, r.length).out
+ expected_reg_assignments[vl] = VL.VL_MAXVL
+ inp = OpLI(fn, i, vl=vl).out
+ inputs.append(inp)
+ expected_reg_assignments[inp] = r
+ concat = OpConcat(fn, inputs)
+ expected_reg_assignments[concat.dest] = expected_dest_reg
reg_assignments = try_allocate_registers_without_spilling(fn.ops)
- expected_reg_assignments = {concat.dest: expected_dest_reg}
- for li_op, reg in zip(li_ops, expected_regs):
- expected_reg_assignments[li_op.out] = reg
+ for inp, reg in zip(inputs, expected_regs):
+ expected_reg_assignments[inp] = reg
self.assertEqual(reg_assignments, expected_reg_assignments)