"""WIP rewrite of the compiler IR so that register allocation works correctly.

The module defines:

* `Fn` -- owner of a list of `Op`s and of the per-function op names.
* `RegKind`, `OperandType`, `RegShape`, `Reg`, `RegClass` -- description of
  register kinds, contiguous register ranges and sets of such ranges.
* `OpProperties`, `OpKind` -- static description of each kind of operation.
* `SSAVal`, `Op` -- SSA values (concatenations of slices of op outputs) and
  the operations that produce them.
"""
import enum
from enum import Enum, unique
from typing import (AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union,
                    overload)
from weakref import WeakValueDictionary

from cached_property import cached_property
from nmutil.plain_data import plain_data

from bigint_presentation_code.util import (OFSet, OSet, Self, assert_never,
                                           final)


@final
class Fn:
    """Container for `Op`s; also hands out unique op names."""

    def __init__(self):
        self.ops = []  # type: list[Op]
        # NOTE(review): values are held weakly, which requires `Op` instances
        # to be weak-referenceable -- confirm `plain_data` provides that for
        # slotted classes.
        op_names = WeakValueDictionary()
        self.__op_names = op_names  # type: WeakValueDictionary[str, Op]
        self.__next_name_suffix = 2

    def _add_op_with_unused_name(self, op, name=""):
        # type: (Op, str) -> str
        """Register `op` under `name`, or under `name` plus a numeric suffix
        if `name` is already taken, and return the name actually used.

        Raises `ValueError` if `op` belongs to a different `Fn` or already
        has a `name` attribute.
        """
        if op.fn is not self:
            raise ValueError("can't add Op to wrong Fn")
        if hasattr(op, "name"):
            raise ValueError("Op already named")
        orig_name = name
        while True:
            if name not in self.__op_names:
                self.__op_names[name] = op
                return name
            # the suffix counter is shared by all names in this Fn, so
            # suffixes are unique but not necessarily consecutive per name
            name = orig_name + str(self.__next_name_suffix)
            self.__next_name_suffix += 1

    def __repr__(self):
        # was an empty string, which made Fn invisible in debug output
        return "<Fn>"


@unique
@final
class RegKind(Enum):
    """The kinds of registers known to the IR."""
    GPR = enum.auto()
    CA = enum.auto()
    VL_MAXVL = enum.auto()

    @cached_property
    def only_scalar(self):
        """True if this kind can never be used as a vector operand."""
        if self is RegKind.GPR:
            return False
        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
            return True
        else:
            assert_never(self)

    @cached_property
    def reg_count(self):
        """Number of registers of this kind."""
        if self is RegKind.GPR:
            return 128
        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
            return 1
        else:
            assert_never(self)

    def __repr__(self):
        return "RegKind." + self._name_


@plain_data(frozen=True, unsafe_hash=True)
@final
class OperandType:
    """Type of an operand: a register kind plus a scalar/vector flag."""
    __slots__ = "kind", "vec"

    def __init__(self, kind, vec):
        # type: (RegKind, bool) -> None
        self.kind = kind
        if kind.only_scalar and vec:
            raise ValueError(f"kind={kind} must have vec=False")
        self.vec = vec

    def get_length(self, maxvl):
        # type: (int) -> int
        """Number of registers occupied: `maxvl` for vectors, else 1."""
        # here's where subvl and elwid would be accounted for
        if self.vec:
            return maxvl
        return 1


@plain_data(frozen=True, unsafe_hash=True)
@final
class RegShape:
    """A register kind together with a count of consecutive registers."""
    __slots__ = "kind", "length"

    def __init__(self, kind, length=1):
        # type: (RegKind, int) -> None
        self.kind = kind
        if length < 1 or length > kind.reg_count:
            raise ValueError("invalid length")
        self.length = length

    def try_concat(self, *others):
        # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
        """Return the shape of `self` followed by all of `others`.

        `Reg` and `RegClass` arguments contribute their `shape`.  Returns
        `None` if any shape is `None`, has a different kind, or if the total
        length exceeds the kind's register count.
        """
        kind = self.kind
        length = self.length
        for other in others:
            if isinstance(other, (Reg, RegClass)):
                other = other.shape
            if other is None:
                return None
            if other.kind != self.kind:
                return None
            length += other.length
        if length > kind.reg_count:
            return None
        return RegShape(kind=kind, length=length)


@plain_data(frozen=True, unsafe_hash=True)
@final
class Reg:
    """A concrete register range: `shape.length` registers from `start`."""
    __slots__ = "shape", "start"

    def __init__(self, shape, start):
        # type: (RegShape, int) -> None
        self.shape = shape
        if start < 0 or start + shape.length > shape.kind.reg_count:
            raise ValueError("start not in valid range")
        self.start = start

    @property
    def kind(self):
        return self.shape.kind

    @property
    def length(self):
        return self.shape.length

    def conflicts(self, other):
        # type: (Reg) -> bool
        """True if `self` and `other` are the same kind and overlap."""
        return (self.kind == other.kind
                and self.start < other.stop and other.start < self.stop)

    @property
    def stop(self):
        """One past the last register index covered by this `Reg`."""
        return self.start + self.length

    def try_concat(self, *others):
        # type: (*Reg | None) -> Reg | None
        """Return the `Reg` covering `self` followed by `others`.

        Returns `None` unless the shapes can be concatenated and every
        register starts exactly where the previous one stops.
        """
        shape = self.shape.try_concat(*others)
        if shape is None:
            return None
        stop = self.stop
        for other in others:
            assert other is not None, "already caught by RegShape.try_concat"
            if stop != other.start:
                return None
            stop = other.stop
        return Reg(shape, self.start)


@final
class RegClass(AbstractSet[Reg]):
    """An immutable set of `Reg`s that all share one `RegShape`.

    The set is stored as a bitset of start indexes; an empty class has
    `shape is None`.
    """

    def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
        # type: (Iterable[Reg | int], RegShape | None, int) -> None
        """Build a class from `Reg`s and/or start indexes and/or a bitset.

        `shape` is taken from the first `Reg` if not given.  Raises
        `ValueError` on conflicting shapes, on out-of-range starts, on a
        negative bitset, or if the class is non-empty but has no shape.
        """
        if starts_bitset < 0:
            # a negative int has infinitely many set bits -- never valid
            raise ValueError("starts_bitset must not be negative")
        for reg_or_start in regs_or_starts:
            if isinstance(reg_or_start, Reg):
                if shape is None:
                    shape = reg_or_start.shape
                elif shape != reg_or_start.shape:
                    raise ValueError(f"conflicting RegShapes: {shape} and "
                                     f"{reg_or_start.shape}")
                start = reg_or_start.start
            else:
                start = reg_or_start
            if start < 0:
                raise ValueError("a Reg's start is out of range")
            starts_bitset |= 1 << start
        if starts_bitset == 0:
            # normalize: every empty class has shape None
            shape = None
        self.__shape = shape
        self.__starts_bitset = starts_bitset
        if shape is None:
            if starts_bitset != 0:
                raise ValueError("non-empty RegClass must have non-None shape")
            return
        # A stop index equal to reg_count is valid (see Reg.__init__), so the
        # highest permitted bit in stops_bitset is bit number reg_count.
        if self.stops_bitset >= 1 << (shape.kind.reg_count + 1):
            raise ValueError("a Reg's start is out of range")

    @property
    def shape(self):
        # type: () -> RegShape | None
        return self.__shape

    @property
    def starts_bitset(self):
        # type: () -> int
        """Integer with bit `i` set iff a `Reg` starting at `i` is a member."""
        return self.__starts_bitset

    @property
    def stops_bitset(self):
        # type: () -> int
        """Integer with bit `i` set iff a member `Reg` has `stop == i`."""
        if self.__shape is None:
            return 0
        return self.__starts_bitset << self.__shape.length

    @cached_property
    def starts(self):
        # type: () -> OFSet[int]
        """Start indexes of all member `Reg`s, in increasing order."""
        if self.__shape is None:
            return OFSet()
        bits = self.__starts_bitset
        return OFSet(i for i in range(bits.bit_length()) if (bits >> i) & 1)

    @cached_property
    def stops(self):
        # type: () -> OFSet[int]
        """Stop indexes of all member `Reg`s."""
        if self.__shape is None:
            return OFSet()
        return OFSet(i + self.__shape.length for i in self.starts)

    @property
    def kind(self):
        if self.__shape is None:
            return None
        return self.__shape.kind

    @property
    def length(self):
        """length of registers in this RegClass, not to be confused with the
        number of `Reg`s in self"""
        if self.__shape is None:
            return None
        return self.__shape.length

    def concat(self, *others):
        # type: (*RegClass) -> RegClass
        """Return the class of all `Reg`s formed by a member of `self`
        immediately followed by a member of each of `others`, in order.

        Returns an empty `RegClass` if the shapes can't be concatenated.
        """
        self_shape = self.__shape
        if self_shape is None:
            return RegClass()
        shape = self_shape.try_concat(*others)
        if shape is None:
            return RegClass()
        starts = OSet(self.starts)
        # distance from the start of the combined Reg to the start of the
        # next piece; the first `other` begins right after `self`'s registers
        offset = self_shape.length
        for other in others:
            assert other.__shape is not None, \
                "already caught by RegShape.try_concat"
            starts &= OSet(i - offset for i in other.starts)
            offset += other.__shape.length
        return RegClass(starts, shape=shape)

    def __contains__(self, reg):
        # type: (Reg) -> bool
        return reg.shape == self.shape and reg.start in self.starts

    def __iter__(self):
        # type: () -> Iterator[Reg]
        if self.shape is None:
            return
        for start in self.starts:
            yield Reg(shape=self.shape, start=start)

    def __len__(self):
        return len(self.starts)

    def __hash__(self):
        # Set provides __eq__ but no __hash__; reuse its hashing helper
        return super()._hash()


@plain_data(frozen=True, unsafe_hash=True)
@final
class Operand:
    """An operand type with an optional set of registers.

    TODO: `regs` is stored unvalidated; its exact meaning is not defined by
    this WIP module yet.
    """
    __slots__ = "ty", "regs"

    def __init__(self, ty, regs=None):
        # type: (OperandType, OFSet[int] | None) -> None
        # the fields must be assigned, otherwise the declared slots stay
        # unset and reading them raises AttributeError
        self.ty = ty
        self.regs = regs


OT_VGPR = OperandType(RegKind.GPR, vec=True)
OT_SGPR = OperandType(RegKind.GPR, vec=False)
OT_CA = OperandType(RegKind.CA, vec=False)
OT_VL = OperandType(RegKind.VL_MAXVL, vec=False)


@plain_data(frozen=True, unsafe_hash=True)
class TiedOutput:
    """Constraint naming an input index and an output index as a pair."""
    __slots__ = "input_index", "output_index"

    def __init__(self, input_index, output_index):
        # type: (int, int) -> None
        self.input_index = input_index
        self.output_index = output_index


# NoReturn appears to be a placeholder so this stays a Union until more
# constraint types exist -- TODO confirm
Constraint = Union[TiedOutput, NoReturn]


@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
    """Static description of one kind of operation.

    All iterable arguments are stored as tuples.
    """
    __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints",
                 "is_copy", "is_load_immediate", "has_side_effects")

    def __init__(self, demo_asm,  # type: str
                 inputs,  # type: Iterable[OperandType]
                 outputs,  # type: Iterable[OperandType]
                 immediates,  # type: Iterable[range]
                 constraints,  # type: Iterable[Constraint]
                 is_copy=False,  # type: bool
                 is_load_immediate=False,  # type: bool
                 has_side_effects=False,  # type: bool
                 ):
        # type: (...) -> None
        self.demo_asm = demo_asm
        self.inputs = tuple(inputs)
        self.outputs = tuple(outputs)
        self.immediates = tuple(immediates)
        self.constraints = tuple(constraints)
        self.is_copy = is_copy
        self.is_load_immediate = is_load_immediate
        self.has_side_effects = has_side_effects


@unique
@final
class OpKind(Enum):
    """All kinds of `Op`; each member's value is its `OpProperties`, also
    available as the `properties` attribute."""

    def __init__(self, properties):
        # type: (OpProperties) -> None
        super().__init__()
        self.properties = properties

    SvAddE = OpProperties(
        demo_asm="sv.adde *RT, *RA, *RB",
        inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
        outputs=(OT_VGPR, OT_CA),
        immediates=(),
        constraints=(),
    )
    SvSubFE = OpProperties(
        demo_asm="sv.subfe *RT, *RA, *RB",
        inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
        outputs=(OT_VGPR, OT_CA),
        immediates=(),
        constraints=(),
    )
    SvMAddEDU = OpProperties(
        demo_asm="sv.maddedu *RT, *RA, RB, RC",
        inputs=(OT_VGPR, OT_SGPR, OT_SGPR, OT_VL),
        outputs=(OT_VGPR, OT_SGPR),
        immediates=(),
        constraints=(),
    )
    SetVLI = OpProperties(
        demo_asm="setvl 0, 0, imm, 0, 1, 1",
        inputs=(),
        outputs=(OT_VL,),
        immediates=(range(1, 65),),
        constraints=(),
        is_load_immediate=True,
    )
    SvLI = OpProperties(
        demo_asm="sv.addi *RT, 0, imm",
        inputs=(OT_VL,),
        outputs=(OT_VGPR,),
        immediates=(range(-2 ** 15, 2 ** 15),),
        constraints=(),
        is_load_immediate=True,
    )
    LI = OpProperties(
        demo_asm="addi RT, 0, imm",
        inputs=(),
        outputs=(OT_SGPR,),
        immediates=(range(-2 ** 15, 2 ** 15),),
        constraints=(),
        is_load_immediate=True,
    )
    SvMv = OpProperties(
        demo_asm="sv.or *RT, *src, *src",
        inputs=(OT_VGPR, OT_VL),
        outputs=(OT_VGPR,),
        immediates=(),
        constraints=(),
        is_copy=True,
    )
    Mv = OpProperties(
        demo_asm="mv RT, src",
        inputs=(OT_SGPR,),
        outputs=(OT_SGPR,),
        immediates=(),
        constraints=(),
        is_copy=True,
    )


@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal:
    """An SSA value: a concatenation of slices of `Op` outputs.

    `sliced_op_outputs` is a tuple of `(op, output_index, range)` entries;
    adjacent entries that continue the same output are merged.
    """
    # NOTE(review): `expanded_sliced_op_outputs` is a cached_property, which
    # needs somewhere to store its value on a slotted class -- confirm that
    # `plain_data` supports this.
    __slots__ = "sliced_op_outputs",

    _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]",
                              "tuple[Op, int]", "SSAVal"]

    @staticmethod
    def __process_sliced_op_outputs(inp):
        # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
        """Normalize every input entry to `(op, output_index, range)`.

        Nested `SSAVal`s are flattened, a missing slice means the whole
        output, and empty slices are dropped.  Raises `ValueError` for an
        invalid output index or a slice step other than 1; an out-of-range
        integer index raises whatever `range.__getitem__` raises.
        """
        for v in inp:
            if isinstance(v, SSAVal):
                yield from v.sliced_op_outputs
                continue
            op = v[0]
            output_index = v[1]
            if output_index < 0 or output_index >= len(op.properties.outputs):
                raise ValueError("invalid output_index")
            cur_len = op.properties.outputs[output_index].get_length(op.maxvl)
            slice_ = slice(None) if len(v) == 2 else v[2]
            if isinstance(slice_, range):
                slice_ = slice(slice_.start, slice_.stop, slice_.step)
            if isinstance(slice_, int):
                # raise exception for out-of-range values
                idx = range(cur_len)[slice_]
                range_ = range(idx, idx + 1)
            else:
                # raise exception for out-of-range values
                range_ = range(cur_len)[slice_]
                if range_.step != 1:
                    raise ValueError("slice step must be 1")
            if len(range_) == 0:
                continue
            yield op, output_index, range_

    def __init__(self, sliced_op_outputs):
        # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
        # an existing SSAVal is already normalized, so reuse its entries
        if isinstance(sliced_op_outputs, SSAVal):
            inp = sliced_op_outputs.sliced_op_outputs
        else:
            inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs)
        processed = []  # type: list[tuple[Op, int, range]]
        for op, output_index, range_ in inp:
            if len(processed) == 0:
                processed.append((op, output_index, range_))
                continue
            last_op, last_output_index, last_range_ = processed[-1]
            if last_op == op and last_output_index == output_index \
                    and last_range_.stop == range_.start:
                # merge slices
                range_ = range(last_range_.start, range_.stop)
                processed[-1] = op, output_index, range_
            else:
                processed.append((op, output_index, range_))
        self.sliced_op_outputs = tuple(processed)

    def __add__(self, other):
        # type: (SSAVal) -> SSAVal
        """Concatenate two `SSAVal`s."""
        if not isinstance(other, SSAVal):
            return NotImplemented
        return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)

    def __radd__(self, other):
        # type: (SSAVal) -> SSAVal
        if isinstance(other, SSAVal):
            return other.__add__(self)
        return NotImplemented

    @cached_property
    def expanded_sliced_op_outputs(self):
        # type: () -> tuple[tuple[Op, int, int], ...]
        """One `(op, output_index, element_index)` entry per element."""
        retval = []
        for op, output_index, range_ in self.sliced_op_outputs:
            for i in range_:
                retval.append((op, output_index, i))
        # must be tuple to not be modifiable since it's cached
        return tuple(retval)

    def __getitem__(self, idx):
        # type: (int | slice) -> SSAVal
        if isinstance(idx, int):
            return SSAVal([self.expanded_sliced_op_outputs[idx]])
        return SSAVal(self.expanded_sliced_op_outputs[idx])

    def __len__(self):
        return len(self.expanded_sliced_op_outputs)

    def __iter__(self):
        # type: () -> Iterator[SSAVal]
        for v in self.expanded_sliced_op_outputs:
            yield SSAVal([v])

    def __repr__(self):
        # type: () -> str
        if len(self.sliced_op_outputs) == 0:
            return "SSAVal([])"
        parts = []
        for op, output_index, range_ in self.sliced_op_outputs:
            out_len = op.properties.outputs[output_index].get_length(op.maxvl)
            parts.append(f"<{op.name}#{output_index}>")
            # only show the slice when it isn't the whole output
            if range_ != range(out_len):
                parts[-1] += f"[{range_.start}:{range_.stop}]"
        return " + ".join(parts)


@plain_data(frozen=True, eq=False)
@final
class Op:
    """One operation in a `Fn`; compared and hashed by identity."""
    __slots__ = "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name"

    def __init__(self, fn, kind, inputs, immediates, maxvl, name=""):
        # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
        self.fn = fn
        self.kind = kind
        self.inputs = list(inputs)
        self.immediates = list(immediates)
        # maxvl must be set before the outputs are built: SSAVal reads it
        self.maxvl = maxvl
        outputs_len = len(self.properties.outputs)
        self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
        # name is assigned last: _add_op_with_unused_name rejects an Op that
        # already has a name
        self.name = fn._add_op_with_unused_name(self, name)

    @property
    def properties(self):
        return self.kind.properties

    def __eq__(self, other):
        if isinstance(other, Op):
            return self is other
        return NotImplemented

    def __hash__(self):
        return object.__hash__(self)