From af59f922ae9c30b56fddbfb6af33ae03a2dc5ac7 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 18 Oct 2022 00:36:02 -0700 Subject: [PATCH] working on generating output assembly --- src/bigint_presentation_code/compiler_ir.py | 688 ++++++++++++++++-- src/bigint_presentation_code/matrix.py | 2 +- .../register_allocator.py | 16 +- .../test_compiler_ir.py | 65 +- .../test_register_allocator.py | 153 ++-- 5 files changed, 754 insertions(+), 170 deletions(-) diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index aa37fa4..f264b01 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -1,14 +1,16 @@ """ Compiler IR for Toom-Cook algorithm generator for SVP64 + +This assumes VL != 0 throughout. """ from abc import ABCMeta, abstractmethod from collections import defaultdict from enum import Enum, EnumMeta, unique from functools import lru_cache -from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast +from typing import (TYPE_CHECKING, Any, Generic, Iterable, Sequence, Type, + TypeVar, cast) -from cached_property import cached_property from nmutil.plain_data import fields, plain_data from bigint_presentation_code.ordered_set import OFSet, OSet @@ -126,7 +128,7 @@ SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13) @final @unique class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta): - CY = "CY" + CA = "CA" def conflicts(self, other): # type: (RegLoc) -> bool @@ -150,6 +152,19 @@ class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta): return False +@final +@unique +class VL(RegLoc, Enum, metaclass=ABCEnumMeta): + VL_MAXVL = "VL_MAXVL" + """VL and MAXVL""" + + def conflicts(self, other): + # type: (RegLoc) -> bool + if isinstance(other, VL): + return self == other + return False + + @final class RegClass(OFSet[RegLoc]): """ an ordered set of registers. @@ -180,13 +195,15 @@ class RegType(metaclass=ABCMeta): _RegType = TypeVar("_RegType", bound=RegType) +_RegLoc = TypeVar("_RegLoc", bound=RegLoc) @plain_data(frozen=True, eq=False) +@final class GPRRangeType(RegType): __slots__ = "length", - def __init__(self, length): + def __init__(self, length=1): # type: (int) -> None if length < 1 or length > GPR_COUNT: raise ValueError("invalid length") @@ -205,6 +222,7 @@ class GPRRangeType(RegType): return RegClass(regs) @property + @final def reg_class(self): # type: () -> RegClass return GPRRangeType.__get_reg_class(self.length) @@ -220,15 +238,8 @@ class GPRRangeType(RegType): return hash(self.length) -@plain_data(frozen=True, eq=False) -@final -class GPRType(GPRRangeType): - __slots__ = () - - def __init__(self, length=1): - if length != 1: - raise ValueError("length must be 1") - super().__init__(length=1) +GPRType = GPRRangeType +"""a length=1 GPRRangeType""" @plain_data(frozen=True, unsafe_hash=True) @@ -253,13 +264,13 @@ class FixedGPRRangeType(RegType): @plain_data(frozen=True, unsafe_hash=True) @final -class CYType(RegType): +class CAType(RegType): __slots__ = () @property def reg_class(self): # type: () -> RegClass - return RegClass([XERBit.CY]) + return RegClass([XERBit.CA]) @plain_data(frozen=True, unsafe_hash=True) @@ -273,6 +284,39 @@ class GlobalMemType(RegType): return RegClass([GlobalMem.GlobalMem]) +@plain_data(frozen=True, unsafe_hash=True) +@final +class KnownVLType(RegType): + __slots__ = "length", + + def __init__(self, length): + # type: (int) -> None + if not (0 < length <= 64): + raise ValueError("invalid VL value") + self.length = length + + @property + def reg_class(self): + # type: () -> RegClass + return RegClass([VL.VL_MAXVL]) + + +def assert_vl_is(vl, expected_vl): + # type: (SSAKnownVL | KnownVLType | int | None, int) -> None + if vl is None: + vl = 1 + elif isinstance(vl, SSAVal): + vl = vl.ty.length + elif isinstance(vl, KnownVLType): + vl = vl.length + if vl != expected_vl: + raise ValueError( + f"wrong VL: expected {expected_vl} got {vl}") + + +STACK_SLOT_SIZE = 8 + + @plain_data(frozen=True, unsafe_hash=True) @final class StackSlot(RegLoc): @@ -289,6 +333,10 @@ class StackSlot(RegLoc): def stop_slot(self): return self.start_slot + self.length_in_slots + @property + def start_byte(self): + return self.start_slot * STACK_SLOT_SIZE + def conflicts(self, other): # type: (RegLoc) -> bool if isinstance(other, StackSlot): @@ -386,6 +434,11 @@ class SSAVal(Generic[_RegType]): return f"SSAVal({fields_str})" +SSAGPRRange = SSAVal[GPRRangeType] +SSAGPR = SSAVal[GPRType] +SSAKnownVL = SSAVal[KnownVLType] + + @final @plain_data(unsafe_hash=True, frozen=True) class EqualityConstraint: @@ -424,6 +477,177 @@ class _NotSet: _NOT_SET = _NotSet() +@plain_data(frozen=True, unsafe_hash=True) +class AsmTemplateSegment(Generic[_RegType], metaclass=ABCMeta): + __slots__ = "ssa_val", + + def __init__(self, ssa_val): + # type: (SSAVal[_RegType]) -> None + self.ssa_val = ssa_val + + def render(self, regs): + # type: (dict[SSAVal, RegLoc]) -> str + return self._render(regs[self.ssa_val]) + + @abstractmethod + def _render(self, reg): + # type: (RegLoc) -> str + ... + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class ATSGPR(AsmTemplateSegment[GPRRangeType]): + __slots__ = "offset", + + def __init__(self, ssa_val, offset=0): + # type: (SSAGPRRange, int) -> None + super().__init__(ssa_val) + self.offset = offset + + def _render(self, reg): + # type: (RegLoc) -> str + if not isinstance(reg, GPRRange): + raise TypeError() + return str(reg.start + self.offset) + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class ATSStackSlot(AsmTemplateSegment[StackSlotType]): + __slots__ = () + + def _render(self, reg): + # type: (RegLoc) -> str + if not isinstance(reg, StackSlot): + raise TypeError() + return f"{reg.start_slot}(1)" + + +@plain_data(frozen=True, unsafe_hash=True) +@final +class ATSCopyGPRRange(AsmTemplateSegment["GPRRangeType | FixedGPRRangeType"]): + __slots__ = "src_ssa_val", + + def __init__(self, ssa_val, src_ssa_val): + # type: (SSAVal[GPRRangeType | FixedGPRRangeType], SSAVal[GPRRangeType | FixedGPRRangeType]) -> None + self.ssa_val = ssa_val + self.src_ssa_val = src_ssa_val + + def render(self, regs): + # type: (dict[SSAVal, RegLoc]) -> str + src = regs[self.src_ssa_val] + dest = regs[self.ssa_val] + if not isinstance(dest, GPRRange): + raise TypeError() + if not isinstance(src, GPRRange): + raise TypeError() + if src.length != dest.length: + raise ValueError() + if src == dest: + return "" + mrr = "" + sv_ = "sv." + if src.length == 1: + sv_ = "" + elif src.conflicts(dest) and src.start > dest.start: + mrr = "/mrr" + return f"{sv_}or{mrr} *{dest.start}, *{src.start}, *{src.start}\n" + + def _render(self, reg): + # type: (RegLoc) -> str + raise TypeError("must call self.render") + + +@final +class AsmTemplate(Sequence["str | AsmTemplateSegment"]): + @staticmethod + def __process_segments(segments): + # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> Iterable[str | AsmTemplateSegment] + for i in segments: + if isinstance(i, AsmTemplate): + yield from i + else: + yield i + + def __init__(self, segments=()): + # type: (Iterable[str | AsmTemplateSegment | AsmTemplate]) -> None + self.__segments = tuple(self.__process_segments(segments)) + + def __getitem__(self, index): + # type: (int) -> str | AsmTemplateSegment + return self.__segments[index] + + def __len__(self): + return len(self.__segments) + + def __iter__(self): + return iter(self.__segments) + + def __hash__(self): + return hash(self.__segments) + + def render(self, regs): + # type: (dict[SSAVal, RegLoc]) -> str + retval = [] # type: list[str] + for segment in self: + if isinstance(segment, AsmTemplateSegment): + retval.append(segment.render(regs)) + else: + retval.append(segment) + return "".join(retval) + + +@final +class AsmContext: + def __init__(self, assigned_registers): + # type: (dict[SSAVal, RegLoc]) -> None + self.__assigned_registers = assigned_registers + + def reg(self, ssa_val, expected_ty): + # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc + try: + reg = self.__assigned_registers[ssa_val] + except KeyError as e: + raise ValueError(f"SSAVal not assigned a register: {ssa_val}") + wrong_len = (isinstance(reg, GPRRange) + and reg.length != ssa_val.ty.length) + if not isinstance(reg, expected_ty) or wrong_len: + raise TypeError( + f"SSAVal is assigned a register of the wrong type: " + f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}") + return reg + + def gpr_range(self, ssa_val): + # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange + return self.reg(ssa_val, GPRRange) + + def stack_slot(self, ssa_val): + # type: (SSAVal[StackSlotType]) -> StackSlot + return self.reg(ssa_val, StackSlot) + + def gpr(self, ssa_val, vec, offset=0): + # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str + reg = self.gpr_range(ssa_val).start + offset + return "*" * vec + str(reg) + + def vgpr(self, ssa_val, offset=0): + # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str + return self.gpr(ssa_val=ssa_val, vec=True, offset=offset) + + def sgpr(self, ssa_val, offset=0): + # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str + return self.gpr(ssa_val=ssa_val, vec=False, offset=offset) + + def needs_sv(self, *regs): + # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool + for reg in regs: + reg = self.gpr_range(reg) + if reg.length != 1 or reg.start >= 32: + return True + return False + + @plain_data(unsafe_hash=True, frozen=True, repr=False) class Op(metaclass=ABCMeta): __slots__ = "id", "fn" @@ -476,45 +700,77 @@ class Op(metaclass=ABCMeta): fields_str = ', '.join(fields_list) return f"{self.__class__.__name__}({fields_str})" + @abstractmethod + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + """get the lines of assembly for this Op""" + ... + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpLoadFromStackSlot(Op): - __slots__ = "dest", "src" + __slots__ = "dest", "src", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"src": self.src} + retval = {"src": self.src} # type: dict[str, SSAVal[Any]] + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] return {"dest": self.dest} - def __init__(self, fn, src): - # type: (Fn, SSAVal[GPRRangeType]) -> None + def __init__(self, fn, src, vl=None): + # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None super().__init__(fn) - self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) + self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) self.src = src + self.vl = vl + assert_vl_is(vl, self.dest.ty.length) + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + dest = ctx.gpr(self.dest, vec=self.dest.ty.length != 1) + src = ctx.stack_slot(self.src) + if ctx.needs_sv(self.dest): + return [f"sv.ld {dest}, {src.start_byte}(1)"] + return [f"ld {dest}, {src.start_byte}(1)"] @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpStoreToStackSlot(Op): - __slots__ = "dest", "src" + __slots__ = "dest", "src", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"src": self.src} + retval = {"src": self.src} # type: dict[str, SSAVal[Any]] + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] return {"dest": self.dest} - def __init__(self, fn, src): - # type: (Fn, SSAVal[StackSlotType]) -> None + def __init__(self, fn, src, vl=None): + # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None super().__init__(fn) - self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) + self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) self.src = src + self.vl = vl + assert_vl_is(vl, src.ty.length) + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + src = ctx.gpr(self.src, vec=self.src.ty.length != 1) + dest = ctx.stack_slot(self.dest) + if ctx.needs_sv(self.src): + return [f"sv.std {src}, {dest.start_byte}(1)"] + return [f"std {src}, {dest.start_byte}(1)"] _RegSrcType = TypeVar("_RegSrcType", bound=RegType) @@ -523,18 +779,21 @@ _RegSrcType = TypeVar("_RegSrcType", bound=RegType) @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpCopy(Op, Generic[_RegSrcType, _RegType]): - __slots__ = "dest", "src" + __slots__ = "dest", "src", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"src": self.src} + retval = {"src": self.src} # type: dict[str, SSAVal[Any]] + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] return {"dest": self.dest} - def __init__(self, fn, src, dest_ty=None): - # type: (Fn, SSAVal[_RegSrcType], _RegType | None) -> None + def __init__(self, fn, src, dest_ty=None, vl=None): + # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None super().__init__(fn) if dest_ty is None: dest_ty = cast(_RegType, src.ty) @@ -543,17 +802,44 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]): if src.ty.length != dest_ty.reg.length: raise ValueError(f"incompatible source and destination " f"types: {src.ty} and {dest_ty}") + length = src.ty.length elif isinstance(src.ty, FixedGPRRangeType) \ and isinstance(dest_ty, GPRRangeType): if src.ty.reg.length != dest_ty.length: raise ValueError(f"incompatible source and destination " f"types: {src.ty} and {dest_ty}") + length = src.ty.length elif src.ty != dest_ty: raise ValueError(f"incompatible source and destination " f"types: {src.ty} and {dest_ty}") + elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)): + length = src.ty.length + else: + length = 1 self.dest = SSAVal(self, "dest", dest_ty) # type: SSAVal[_RegType] self.src = src + self.vl = vl + assert_vl_is(vl, length) + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc): + return [] + if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and + isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))): + vec = self.dest.ty.length != 1 + dest = ctx.gpr_range(self.dest) # type: ignore + src = ctx.gpr_range(self.src) # type: ignore + dest_s = ctx.gpr(self.dest, vec=vec) # type: ignore + src_s = ctx.gpr(self.src, vec=vec) # type: ignore + mrr = "" + if src.conflicts(dest) and src.start > dest.start: + mrr = "/mrr" + if ctx.needs_sv(self.src, self.dest): # type: ignore + return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"] + return [f"or {dest_s}, {src_s}, {src_s}"] + raise NotImplementedError @plain_data(unsafe_hash=True, frozen=True, repr=False) @@ -570,7 +856,7 @@ class OpConcat(Op): return {"dest": self.dest} def __init__(self, fn, sources): - # type: (Fn, Iterable[SSAVal[GPRRangeType]]) -> None + # type: (Fn, Iterable[SSAGPRRange]) -> None super().__init__(fn) sources = tuple(sources) self.dest = SSAVal(self, "dest", GPRRangeType( @@ -581,6 +867,10 @@ class OpConcat(Op): # type: () -> Iterable[EqualityConstraint] yield EqualityConstraint([self.dest], [*self.sources]) + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + return [] + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -596,7 +886,7 @@ class OpSplit(Op): return {i.arg_name: i for i in self.results} def __init__(self, fn, src, split_indexes): - # type: (Fn, SSAVal[GPRRangeType], Iterable[int]) -> None + # type: (Fn, SSAGPRRange, Iterable[int]) -> None super().__init__(fn) ranges = [] # type: list[GPRRangeType] last = 0 @@ -615,54 +905,86 @@ class OpSplit(Op): # type: () -> Iterable[EqualityConstraint] yield EqualityConstraint([*self.results], [self.src]) + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + return [] + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final -class OpAddSubE(Op): - __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub" +class OpBigIntAddSub(Op): + __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"RA": self.RA, "RB": self.RB, "CY_in": self.CY_in} + retval = {} # type: dict[str, SSAVal[Any]] + retval["lhs"] = self.lhs + retval["rhs"] = self.rhs + retval["CA_in"] = self.CA_in + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] - return {"RT": self.RT, "CY_out": self.CY_out} + return {"out": self.out, "CA_out": self.CA_out} - def __init__(self, fn, RA, RB, CY_in, is_sub): - # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None + def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None): + # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None super().__init__(fn) - if RA.ty != RB.ty: + if lhs.ty != rhs.ty: raise TypeError(f"source types must match: " - f"{RA} doesn't match {RB}") - self.RT = SSAVal(self, "RT", RA.ty) - self.RA = RA - self.RB = RB - self.CY_in = CY_in - self.CY_out = SSAVal(self, "CY_out", CY_in.ty) + f"{lhs} doesn't match {rhs}") + self.out = SSAVal(self, "out", lhs.ty) + self.lhs = lhs + self.rhs = rhs + self.CA_in = CA_in + self.CA_out = SSAVal(self, "CA_out", CA_in.ty) self.is_sub = is_sub + self.vl = vl + assert_vl_is(vl, lhs.ty.length) def get_extra_interferences(self): # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.RT, self.RA - yield self.RT, self.RB + yield self.out, self.lhs + yield self.out, self.rhs + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + vec = self.out.ty.length != 1 + out = ctx.gpr(self.out, vec=vec) + RA = ctx.gpr(self.lhs, vec=vec) + RB = ctx.gpr(self.rhs, vec=vec) + mnemonic = "adde" + if self.is_sub: + mnemonic = "subfe" + RA, RB = RB, RA # reorder to match subfe + if ctx.needs_sv(self.out, self.lhs, self.rhs): + return [f"sv.{mnemonic} {out}, {RA}, {RB}"] + return [f"{mnemonic} {out}, {RA}, {RB}"] @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpBigIntMulDiv(Op): - __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div" + __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"RA": self.RA, "RB": self.RB, "RC": self.RC} + retval = {} # type: dict[str, SSAVal[Any]] + retval["RA"] = self.RA + retval["RB"] = self.RB + retval["RC"] = self.RC + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] return {"RT": self.RT, "RS": self.RS} - def __init__(self, fn, RA, RB, RC, is_div): - # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None + def __init__(self, fn, RA, RB, RC, is_div, vl): + # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None super().__init__(fn) self.RT = SSAVal(self, "RT", RA.ty) self.RA = RA @@ -670,6 +992,8 @@ class OpBigIntMulDiv(Op): self.RC = RC self.RS = SSAVal(self, "RS", RC.ty) self.is_div = is_div + self.vl = vl + assert_vl_is(vl, RA.ty.length) def get_equality_constraints(self): # type: () -> Iterable[EqualityConstraint] @@ -684,6 +1008,18 @@ class OpBigIntMulDiv(Op): yield self.RS, self.RA yield self.RS, self.RB + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + vec = self.RT.ty.length != 1 + RT = ctx.gpr(self.RT, vec=vec) + RA = ctx.gpr(self.RA, vec=vec) + RB = ctx.sgpr(self.RB) + RC = ctx.sgpr(self.RC) + mnemonic = "maddedu" + if self.is_div: + mnemonic = "divmod2du/mrr" + return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"] + @final @unique @@ -692,58 +1028,179 @@ class ShiftKind(Enum): Sr = "sr" Sra = "sra" + def make_big_int_carry_in(self, fn, inp): + # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]] + if self is ShiftKind.Sl or self is ShiftKind.Sr: + li = OpLI(fn, 0) + return li.out, [li] + else: + assert self is ShiftKind.Sra + split = OpSplit(fn, inp, [inp.ty.length - 1]) + shr = OpShiftImm(fn, split.results[1], sh=63, kind=ShiftKind.Sra) + return shr.out, [split, shr] + + def make_big_int_shift(self, fn, inp, sh, vl): + # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]] + carry_in, ops = self.make_big_int_carry_in(fn, inp) + big_int_shift = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl) + ops.append(big_int_shift) + return big_int_shift.out, ops + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpBigIntShift(Op): - __slots__ = "RT", "inp", "sh", "kind" + __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"inp": self.inp, "sh": self.sh} + retval = {} # type: dict[str, SSAVal[Any]] + retval["inp"] = self.inp + retval["sh"] = self.sh + retval["carry_in"] = self.carry_in + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] - return {"RT": self.RT} + return {"out": self.out, "_out_padding": self._out_padding} - def __init__(self, fn, inp, sh, kind): - # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None + def __init__(self, fn, inp, sh, carry_in, kind, vl=None): + # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None super().__init__(fn) - self.RT = SSAVal(self, "RT", inp.ty) + self.out = SSAVal(self, "out", inp.ty) + self._out_padding = SSAVal(self, "_out_padding", GPRRangeType()) + self.carry_in = carry_in self.inp = inp self.sh = sh self.kind = kind + self.vl = vl + assert_vl_is(vl, inp.ty.length) def get_extra_interferences(self): # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.RT, self.inp - yield self.RT, self.sh + yield self.out, self.sh + + def get_equality_constraints(self): + # type: () -> Iterable[EqualityConstraint] + if self.kind is ShiftKind.Sl: + yield EqualityConstraint([self.carry_in, self.inp], + [self.out, self._out_padding]) + else: + assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra + yield EqualityConstraint([self.inp, self.carry_in], + [self._out_padding, self.out]) + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + vec = self.out.ty.length != 1 + if self.kind is ShiftKind.Sl: + RT = ctx.gpr(self.out, vec=vec) + RA = ctx.gpr(self.out, vec=vec, offset=-1) + RB = ctx.sgpr(self.sh) + mrr = "/mrr" if vec else "" + return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"] + else: + assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra + RT = ctx.gpr(self.out, vec=vec) + RA = ctx.gpr(self.out, vec=vec, offset=1) + RB = ctx.sgpr(self.sh) + return [f"sv.dsrd {RT}, {RA}, {RB}, 1"] + + +@plain_data(unsafe_hash=True, frozen=True, repr=False) +@final +class OpShiftImm(Op): + __slots__ = "out", "inp", "sh", "kind", "ca_out" + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {"inp": self.inp} + + def outputs(self): + # type: () -> dict[str, SSAVal] + if self.ca_out is not None: + return {"out": self.out, "ca_out": self.ca_out} + return {"out": self.out} + + def __init__(self, fn, inp, sh, kind): + # type: (Fn, SSAGPR, int, ShiftKind) -> None + super().__init__(fn) + self.out = SSAVal(self, "out", inp.ty) + self.inp = inp + if not (0 <= sh < 64): + raise ValueError("shift amount out of range") + self.sh = sh + self.kind = kind + if self.kind is ShiftKind.Sra: + self.ca_out = SSAVal(self, "ca_out", CAType()) + else: + self.ca_out = None + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + out = ctx.sgpr(self.out) + inp = ctx.sgpr(self.inp) + if self.kind is ShiftKind.Sl: + mnemonic = "rldicr" + args = f"{self.sh}, {63 - self.sh}" + elif self.kind is ShiftKind.Sr: + mnemonic = "rldicl" + v = (64 - self.sh) % 64 + args = f"{v}, {self.sh}" + else: + assert self.kind is ShiftKind.Sra + mnemonic = "sradi" + args = f"{self.sh}" + if ctx.needs_sv(self.out, self.inp): + return [f"sv.{mnemonic} {out}, {inp}, {args}"] + return [f"{mnemonic} {out}, {inp}, {args}"] @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpLI(Op): - __slots__ = "out", "value" + __slots__ = "out", "value", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {} + retval = {} # type: dict[str, SSAVal[Any]] + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self, fn, value, length=1): - # type: (Fn, int, int) -> None + def __init__(self, fn, value, vl=None): + # type: (Fn, int, SSAKnownVL | None) -> None super().__init__(fn) + if vl is None: + length = 1 + else: + length = vl.ty.length self.out = SSAVal(self, "out", GPRRangeType(length)) + if not (-1 << 15 <= value <= (1 << 15) - 1): + raise ValueError(f"value out of range: {value}") self.value = value + self.vl = vl + assert_vl_is(vl, length) + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + vec = self.out.ty.length != 1 + out = ctx.gpr(self.out, vec=vec) + if ctx.needs_sv(self.out): + return [f"sv.addi {out}, 0, {self.value}"] + return [f"addi {out}, 0, {self.value}"] @plain_data(unsafe_hash=True, frozen=True, repr=False) @final -class OpClearCY(Op): - __slots__ = "out", +class OpSetCA(Op): + __slots__ = "out", "value" def inputs(self): # type: () -> dict[str, SSAVal] @@ -753,60 +1210,110 @@ class OpClearCY(Op): # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self, fn): - # type: (Fn) -> None + def __init__(self, fn, value): + # type: (Fn, bool) -> None super().__init__(fn) - self.out = SSAVal(self, "out", CYType()) + self.out = SSAVal(self, "out", CAType()) + self.value = value + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + if self.value: + return ["subfic 0, 0, -1"] + return ["addic 0, 0, 0"] @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpLoad(Op): - __slots__ = "RT", "RA", "offset", "mem" + __slots__ = "RT", "RA", "offset", "mem", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"RA": self.RA, "mem": self.mem} + retval = {} # type: dict[str, SSAVal[Any]] + retval["RA"] = self.RA + retval["mem"] = self.mem + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] return {"RT": self.RT} - def __init__(self, fn, RA, offset, mem, length=1): - # type: (Fn, SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None + def __init__(self, fn, RA, offset, mem, vl=None): + # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None super().__init__(fn) + if vl is None: + length = 1 + else: + length = vl.ty.length self.RT = SSAVal(self, "RT", GPRRangeType(length)) self.RA = RA + if not (-1 << 15 <= offset <= (1 << 15) - 1): + raise ValueError(f"offset out of range: {offset}") + if offset % 4 != 0: + raise ValueError(f"offset not aligned: {offset}") self.offset = offset self.mem = mem + self.vl = vl + assert_vl_is(vl, length) def get_extra_interferences(self): # type: () -> Iterable[tuple[SSAVal, SSAVal]] if self.RT.ty.length > 1: yield self.RT, self.RA + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1) + RA = ctx.sgpr(self.RA) + if ctx.needs_sv(self.RT, self.RA): + return [f"sv.ld {RT}, {self.offset}({RA})"] + return [f"ld {RT}, {self.offset}({RA})"] + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpStore(Op): - __slots__ = "RS", "RA", "offset", "mem_in", "mem_out" + __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl" def inputs(self): # type: () -> dict[str, SSAVal] - return {"RS": self.RS, "RA": self.RA, "mem_in": self.mem_in} + retval = {} # type: dict[str, SSAVal[Any]] + retval["RS"] = self.RS + retval["RA"] = self.RA + retval["mem_in"] = self.mem_in + if self.vl is not None: + retval["vl"] = self.vl + return retval def outputs(self): # type: () -> dict[str, SSAVal] return {"mem_out": self.mem_out} - def __init__(self, fn, RS, RA, offset, mem_in): - # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None + def __init__(self, fn, RS, RA, offset, mem_in, vl=None): + # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None super().__init__(fn) self.RS = RS self.RA = RA + if not (-1 << 15 <= offset <= (1 << 15) - 1): + raise ValueError(f"offset out of range: {offset}") + if offset % 4 != 0: + raise ValueError(f"offset not aligned: {offset}") self.offset = offset self.mem_in = mem_in self.mem_out = SSAVal(self, "mem_out", mem_in.ty) + self.vl = vl + assert_vl_is(vl, RS.ty.length) + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1) + RA = ctx.sgpr(self.RA) + if ctx.needs_sv(self.RS, self.RA): + return [f"sv.std {RS}, {self.offset}({RA})"] + return [f"std {RS}, {self.offset}({RA})"] @plain_data(unsafe_hash=True, frozen=True, repr=False) @@ -827,6 +1334,10 @@ class OpFuncArg(Op): super().__init__(fn) self.out = SSAVal(self, "out", ty) + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + return [] + @plain_data(unsafe_hash=True, frozen=True, repr=False) @final @@ -846,6 +1357,33 @@ class OpInputMem(Op): super().__init__(fn) self.out = SSAVal(self, "out", GlobalMemType()) + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + return [] + + +@plain_data(unsafe_hash=True, frozen=True, repr=False) +@final +class OpSetVLImm(Op): + __slots__ = "out", + + def inputs(self): + # type: () -> dict[str, SSAVal] + return {} + + def outputs(self): + # type: () -> dict[str, SSAVal] + return {"out": self.out} + + def __init__(self, fn, length): + # type: (Fn, int) -> None + super().__init__(fn) + self.out = SSAVal(self, "out", KnownVLType(length)) + + def get_asm_lines(self, ctx): + # type: (AsmContext) -> list[str] + return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"] + def op_set_to_list(ops): # type: (Iterable[Op]) -> list[Op] diff --git a/src/bigint_presentation_code/matrix.py b/src/bigint_presentation_code/matrix.py index 2636be8..3e1e154 100644 --- a/src/bigint_presentation_code/matrix.py +++ b/src/bigint_presentation_code/matrix.py @@ -1,7 +1,7 @@ import operator -from typing import Callable, Iterable from fractions import Fraction from numbers import Rational +from typing import Callable, Iterable class Matrix: diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index b44c32d..a22299f 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -342,6 +342,13 @@ class AllocationFailed: self.interference_graph = interference_graph +class AllocationFailedError(Exception): + def __init__(self, msg, allocation_failed): + # type: (str, AllocationFailed) -> None + super().__init__(msg, allocation_failed) + self.allocation_failed = allocation_failed + + def try_allocate_registers_without_spilling(ops): # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed @@ -422,5 +429,10 @@ def try_allocate_registers_without_spilling(ops): def allocate_registers(ops): - # type: (list[Op]) -> None - raise NotImplementedError + # type: (list[Op]) -> dict[SSAVal, RegLoc] + retval = try_allocate_registers_without_spilling(ops) + if isinstance(retval, AllocationFailed): + # TODO: implement spilling + raise AllocationFailedError( + "spilling required but not yet implemented", retval) + return retval diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py index ff52641..0c8a3ee 100644 --- a/src/bigint_presentation_code/test_compiler_ir.py +++ b/src/bigint_presentation_code/test_compiler_ir.py @@ -1,7 +1,11 @@ import unittest -from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange, GPRType, - Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore, +from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, + GPRRange, GPRType, + OpBigIntAddSub, OpConcat, + OpCopy, OpFuncArg, + OpInputMem, OpLI, OpLoad, + OpSetCA, OpSetVLImm, OpStore, op_set_to_list) @@ -15,32 +19,41 @@ class TestCompilerIR(unittest.TestCase): arg = op1.dest op2 = OpInputMem(fn) mem = op2.out - op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32) - a = op3.RT - op4 = OpLI(fn, 1) - b_0 = op4.out - op5 = OpLI(fn, 0, length=31) - b_rest = op5.out - op6 = OpConcat(fn, [b_0, b_rest]) - b = op6.dest - op7 = OpClearCY(fn) - cy = op7.out - op8 = OpAddSubE(fn, a, b, cy, is_sub=False) - s = op8.RT - op9 = OpStore(fn, s, arg, offset=0, mem_in=mem) - mem = op9.mem_out + op3 = OpSetVLImm(fn, 32) + vl = op3.out + op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) + a = op4.RT + op5 = OpLI(fn, 1) + b_0 = op5.out + op6 = OpSetVLImm(fn, 31) + vl = op6.out + op7 = OpLI(fn, 0, vl=vl) + b_rest = op7.out + op8 = OpConcat(fn, [b_0, b_rest]) + b = op8.dest + op9 = OpSetVLImm(fn, 32) + vl = op9.out + op10 = OpSetCA(fn, False) + ca = op10.out + op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) + s = op11.out + op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) + mem = op12.mem_out expected_ops = [ - op7, # OpClearCY() - op5, # OpLI(0, length=31) - op4, # OpLI(1) - op2, # OpInputMem() - op0, # OpFuncArg(FixedGPRRangeType(GPRRange(3))) - op6, # OpConcat([b_0, b_rest]) - op1, # OpCopy(op0.out, GPRType()) - op3, # OpLoad(arg, offset=0, mem=mem, length=32) - op8, # OpAddSubE(a, b, cy, is_sub=False) - op9, # OpStore(s, arg, offset=0, mem_in=mem) + op10, # OpSetCA(fn, False) + op9, # OpSetVLImm(fn, 32) + op6, # OpSetVLImm(fn, 31) + op5, # OpLI(fn, 1) + op3, # OpSetVLImm(fn, 32) + op2, # OpInputMem(fn) + op0, # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) + op7, # OpLI(fn, 0, vl=vl) + op1, # OpCopy(fn, op0.out, GPRType()) + op8, # OpConcat(fn, [b_0, b_rest]) + op4, # OpLoad(fn, arg, offset=0, mem=mem, vl=vl) + op11, # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) + op12, # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) ] ops = op_set_to_list(fn.ops[::-1]) diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py index bdc1938..6558264 100644 --- a/src/bigint_presentation_code/test_register_allocator.py +++ b/src/bigint_presentation_code/test_register_allocator.py @@ -1,12 +1,14 @@ import unittest -from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange, - GPRType, GlobalMem, Op, OpAddSubE, - OpClearCY, OpConcat, OpCopy, - OpFuncArg, OpInputMem, OpLI, - OpLoad, OpStore, XERBit) +from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, + GlobalMem, GPRRange, GPRType, + OpBigIntAddSub, OpConcat, + OpCopy, OpFuncArg, + OpInputMem, OpLI, OpLoad, + OpSetCA, OpSetVLImm, OpStore, + XERBit) from bigint_presentation_code.register_allocator import ( - AllocationFailed, allocate_registers, MergedRegSet, + AllocationFailed, MergedRegSet, allocate_registers, try_allocate_registers_without_spilling) @@ -15,26 +17,26 @@ class TestMergedRegSet(unittest.TestCase): def test_from_equality_constraint(self): fn = Fn() - op0 = OpLI(fn, 0, length=1) - op1 = OpLI(fn, 0, length=2) - op2 = OpLI(fn, 0, length=3) + li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out) + li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out) + li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out) self.assertEqual(MergedRegSet.from_equality_constraint([ - op0.out, - op1.out, - op2.out, + li0x1.out, + li0x2.out, + li0x3.out, ]), MergedRegSet({ - op0.out: 0, - op1.out: 1, - op2.out: 3, + li0x1.out: 0, + li0x2.out: 1, + li0x3.out: 3, }.items())) self.assertEqual(MergedRegSet.from_equality_constraint([ - op1.out, - op0.out, - op2.out, + li0x2.out, + li0x1.out, + li0x3.out, ]), MergedRegSet({ - op1.out: 0, - op0.out: 2, - op2.out: 3, + li0x2.out: 0, + li0x1.out: 2, + li0x3.out: 3, }.items())) @@ -43,38 +45,53 @@ class TestRegisterAllocator(unittest.TestCase): def test_try_alloc_fail(self): fn = Fn() - op0 = OpLI(fn, 0, length=52) - op1 = OpLI(fn, 0, length=64) - op2 = OpConcat(fn, [op0.out, op1.out]) + op0 = OpSetVLImm(fn, 52) + op1 = OpLI(fn, 0, vl=op0.out) + op2 = OpSetVLImm(fn, 64) + op3 = OpLI(fn, 0, vl=op2.out) + op4 = OpConcat(fn, [op1.out, op3.out]) reg_assignments = try_allocate_registers_without_spilling(fn.ops) self.assertEqual( repr(reg_assignments), "AllocationFailed(" "node=IGNode(#0, merged_reg_set=MergedRegSet([" - "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " "edges={}, reg=None), " "live_intervals=LiveIntervals(" "live_intervals={" - "MergedRegSet([(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]): " - "LiveInterval(first_write=0, last_use=2)}, " + "MergedRegSet([(<#0.out>, 0)]): " + "LiveInterval(first_write=0, last_use=1), " + "MergedRegSet([(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]): " + "LiveInterval(first_write=1, last_use=4), " + "MergedRegSet([(<#2.out>, 0)]): " + "LiveInterval(first_write=2, last_use=3)}, " "merged_reg_sets=MergedRegSets(data={" - "<#0.out>: MergedRegSet([" - "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " + "<#0.out>: MergedRegSet([(<#0.out>, 0)]), " "<#1.out>: MergedRegSet([" - "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " - "<#2.dest>: MergedRegSet([" - "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])}), " + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " + "<#2.out>: MergedRegSet([(<#2.out>, 0)]), " + "<#3.out>: MergedRegSet([" + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " + "<#4.dest>: MergedRegSet([" + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])}), " "reg_sets_live_after={" - "0: OFSet([MergedRegSet([" - "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), " + "0: OFSet([MergedRegSet([(<#0.out>, 0)])]), " "1: OFSet([MergedRegSet([" - "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), " - "2: OFSet()}), " + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), " + "2: OFSet([MergedRegSet([" + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " + "MergedRegSet([(<#2.out>, 0)])]), " + "3: OFSet([MergedRegSet([" + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)])]), " + "4: OFSet()}), " "interference_graph=InterferenceGraph(nodes={" - "...: IGNode(#0, " - "merged_reg_set=MergedRegSet([" - "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " + "...: IGNode(#0, merged_reg_set=MergedRegSet([(<#0.out>, 0)]), " + "edges={}, reg=None), " + "...: IGNode(#1, merged_reg_set=MergedRegSet([" + "(<#4.dest>, 0), (<#1.out>, 0), (<#3.out>, 52)]), " + "edges={}, reg=None), " + "...: IGNode(#2, merged_reg_set=MergedRegSet([(<#2.out>, 0)]), " "edges={}, reg=None)}))" ) @@ -85,20 +102,18 @@ class TestRegisterAllocator(unittest.TestCase): arg = op1.dest op2 = OpInputMem(fn) mem = op2.out - op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32) - a = op3.RT - op4 = OpLI(fn, 1) - b_0 = op4.out - op5 = OpLI(fn, 0, length=31) - b_rest = op5.out - op6 = OpConcat(fn, [b_0, b_rest]) - b = op6.dest - op7 = OpClearCY(fn) - cy = op7.out - op8 = OpAddSubE(fn, a, b, cy, is_sub=False) - s = op8.RT - op9 = OpStore(fn, s, arg, offset=0, mem_in=mem) - mem = op9.mem_out + op3 = OpSetVLImm(fn, 32) + vl = op3.out + op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) + a = op4.RT + op5 = OpLI(fn, 0, vl=vl) + b = op5.out + op6 = OpSetCA(fn, True) + ca = op6.out + op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) + s = op7.out + op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) + mem = op8.mem_out reg_assignments = try_allocate_registers_without_spilling(fn.ops) @@ -106,14 +121,13 @@ class TestRegisterAllocator(unittest.TestCase): op0.out: GPRRange(start=3, length=1), op1.dest: GPRRange(start=3, length=1), op2.out: GlobalMem.GlobalMem, - op3.RT: GPRRange(start=78, length=32), - op4.out: GPRRange(start=46, length=1), - op5.out: GPRRange(start=47, length=31), - op6.dest: GPRRange(start=46, length=32), - op7.out: XERBit.CY, - op8.RT: GPRRange(start=14, length=32), - op8.CY_out: XERBit.CY, - op9.mem_out: GlobalMem.GlobalMem, + op3.out: VL.VL_MAXVL, + op4.RT: GPRRange(start=78, length=32), + op5.out: GPRRange(start=46, length=32), + op6.out: XERBit.CA, + op7.out: GPRRange(start=14, length=32), + op7.CA_out: XERBit.CA, + op8.mem_out: GlobalMem.GlobalMem, } self.assertEqual(reg_assignments, expected_reg_assignments) @@ -121,14 +135,21 @@ class TestRegisterAllocator(unittest.TestCase): def tst_try_alloc_concat(self, expected_regs, expected_dest_reg): # type: (list[GPRRange], GPRRange) -> None fn = Fn() - li_ops = [OpLI(fn, i, r.length) for i, r in enumerate(expected_regs)] - concat = OpConcat(fn, [i.out for i in li_ops]) + inputs = [] + expected_reg_assignments = {} + for i, r in enumerate(expected_regs): + vl = OpSetVLImm(fn, r.length).out + expected_reg_assignments[vl] = VL.VL_MAXVL + inp = OpLI(fn, i, vl=vl).out + inputs.append(inp) + expected_reg_assignments[inp] = r + concat = OpConcat(fn, inputs) + expected_reg_assignments[concat.dest] = expected_dest_reg reg_assignments = try_allocate_registers_without_spilling(fn.ops) - expected_reg_assignments = {concat.dest: expected_dest_reg} - for li_op, reg in zip(li_ops, expected_regs): - expected_reg_assignments[li_op.out] = reg + for inp, reg in zip(inputs, expected_regs): + expected_reg_assignments[inp] = reg self.assertEqual(reg_assignments, expected_reg_assignments) -- 2.30.2