From 18e284fbde4034cf5111cf5bdcff752efa79c086 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Sun, 30 Oct 2022 02:20:48 -0700 Subject: [PATCH] implement more of new compiler ir --- src/bigint_presentation_code/compiler_ir2.py | 305 +++++++++++-------- 1 file changed, 186 insertions(+), 119 deletions(-) diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py index 3e2ad0c..b04848b 100644 --- a/src/bigint_presentation_code/compiler_ir2.py +++ b/src/bigint_presentation_code/compiler_ir2.py @@ -1,16 +1,16 @@ -from collections import defaultdict import enum +from abc import abstractmethod from enum import Enum, unique -from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload +from functools import lru_cache +from typing import (AbstractSet, Any, Generic, Iterable, Iterator, Sequence, + TypeVar, overload) from weakref import WeakValueDictionary as _WeakVDict from cached_property import cached_property from nmutil.plain_data import plain_data from bigint_presentation_code.type_util import Self, assert_never, final -from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet, - OSet, FMap) -from functools import lru_cache +from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet @final @@ -461,11 +461,11 @@ class GenericOperandDesc: return for sub_kind in self.sub_kinds: yield from sub_kind.allocatable_locs(ty) - loc_set = LocSet(locs()) + loc_set_before_spread = LocSet(locs()) for idx in range(rep_count): if not self.spread: idx = None - yield OperandDesc(loc_set=loc_set, + yield OperandDesc(loc_set_before_spread=loc_set_before_spread, tied_input_index=self.tied_input_index, spread_index=idx) @@ -474,16 +474,33 @@ class GenericOperandDesc: @final class OperandDesc: """Op operand descriptor""" - __slots__ = "loc_set", "tied_input_index", "spread_index" + __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index" - def __init__(self, loc_set, tied_input_index, spread_index): + def __init__(self, loc_set_before_spread, tied_input_index, spread_index): # type: (LocSet, int | None, int | None) -> None - if len(loc_set) == 0: - raise ValueError("loc_set must not be empty") - self.loc_set = loc_set + if len(loc_set_before_spread) == 0: + raise ValueError("loc_set_before_spread must not be empty") + self.loc_set_before_spread = loc_set_before_spread self.tied_input_index = tied_input_index if self.tied_input_index is not None and self.spread_index is not None: raise ValueError("operand can't be both spread and tied") + self.spread_index = spread_index + + @cached_property + def ty_before_spread(self): + # type: () -> Ty + ty = self.loc_set_before_spread.ty + assert ty is not None, ( + "__init__ checked that the LocSet isn't empty, " + "non-empty LocSets should always have ty set") + return ty + + @cached_property + def ty(self): + """ Ty after any spread is applied """ + if self.spread_index is not None: + return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1) + return self.ty_before_spread OD_BASE_SGPR = GenericOperandDesc( @@ -554,7 +571,7 @@ class GenericOpProperties: @plain_data(frozen=True, unsafe_hash=True) @final class OpProperties: - __slots__ = "kind", "inputs", "outputs" + __slots__ = "kind", "inputs", "outputs", "maxvl" def __init__(self, kind, maxvl): # type: (OpKind, int) -> None @@ -567,6 +584,7 @@ class OpProperties: for out in self.generic.outputs: outputs.extend(out.instantiate(maxvl=maxvl)) self.outputs = tuple(outputs) + self.maxvl = maxvl @property def generic(self): @@ -715,137 +733,186 @@ class OpKind(Enum): ) -# FIXME: rewrite from here - - @plain_data(frozen=True, unsafe_hash=True, repr=False) @final class SSAVal: - __slots__ = "sliced_op_outputs", + __slots__ = "op", "output_idx" - _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]", - "tuple[Op, int]", "SSAVal"] + def __init__(self, op, output_idx): + # type: (Op, int) -> None + self.op = op + if output_idx < 0 or output_idx >= len(op.properties.outputs): + raise ValueError("invalid output_idx") + self.output_idx = output_idx - @staticmethod - def __process_sliced_op_outputs(inp): - # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]] - for v in inp: - if isinstance(v, SSAVal): - yield from v.sliced_op_outputs - continue - op = v[0] - output_index = v[1] - if output_index < 0 or output_index >= len(op.properties.outputs): - raise ValueError("invalid output_index") - cur_len = op.properties.outputs[output_index].get_length(op.maxvl) - slice_ = slice(None) if len(v) == 2 else v[2] - if isinstance(slice_, range): - slice_ = slice(slice_.start, slice_.stop, slice_.step) - if isinstance(slice_, int): - # raise exception for out-of-range values - idx = range(cur_len)[slice_] - range_ = range(idx, idx + 1) - else: - # raise exception for out-of-range values - range_ = range(cur_len)[slice_] - if range_.step != 1: - raise ValueError("slice step must be 1") - if len(range_) == 0: - continue - yield op, output_index, range_ + def __repr__(self): + # type: () -> str + return f"<{self.op.name}#{self.output_idx}>" - def __init__(self, sliced_op_outputs): - # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None - # we have length arg so plain_data.replace works - if isinstance(sliced_op_outputs, SSAVal): - inp = sliced_op_outputs.sliced_op_outputs - else: - inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs) - processed = [] # type: list[tuple[Op, int, range]] - length = 0 - for op, output_index, range_ in inp: - length += len(range_) - if len(processed) == 0: - processed.append((op, output_index, range_)) - continue - last_op, last_output_index, last_range_ = processed[-1] - if last_op == op and last_output_index == output_index \ - and last_range_.stop == range_.start: - # merge slices - range_ = range(last_range_.start, range_.stop) - processed[-1] = op, output_index, range_ - else: - processed.append((op, output_index, range_)) - self.sliced_op_outputs = tuple(processed) - - def __add__(self, other): - # type: (SSAVal | Any) -> SSAVal - if not isinstance(other, SSAVal): - return NotImplemented - return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs) - - def __radd__(self, other): - # type: (SSAVal | Any) -> SSAVal - if isinstance(other, SSAVal): - return other.__add__(self) - return NotImplemented + @cached_property + def defining_descriptor(self): + # type: () -> OperandDesc + return self.op.properties.outputs[self.output_idx] @cached_property - def expanded_sliced_op_outputs(self): - # type: () -> tuple[tuple[Op, int, int], ...] - retval = [] # type: list[tuple[Op, int, int]] - for op, output_index, range_ in self.sliced_op_outputs: - for i in range_: - retval.append((op, output_index, i)) - # must be tuple to not be modifiable since it's cached - return tuple(retval) + def loc_set_before_spread(self): + # type: () -> LocSet + return self.defining_descriptor.loc_set_before_spread + @cached_property + def ty(self): + # type: () -> Ty + return self.defining_descriptor.ty + + @cached_property + def ty_before_spread(self): + # type: () -> Ty + return self.defining_descriptor.ty_before_spread + + +_T = TypeVar("_T") +_Desc = TypeVar("_Desc") + + +class OpInputSeq(Sequence[_T], Generic[_T, _Desc]): + @abstractmethod + def _verify_write_with_desc(self, idx, item, desc): + # type: (int, _T | Any, _Desc) -> None + raise NotImplementedError + + @final + def _verify_write(self, idx, item): + # type: (int | Any, _T | Any) -> int + if not isinstance(idx, int): + if isinstance(idx, slice): + raise TypeError( + f"can't write to slice of {self.__class__.__name__}") + raise TypeError(f"can't write with index {idx!r}") + # normalize idx, raising IndexError if it is out of range + idx = range(len(self.descriptors))[idx] + desc = self.descriptors[idx] + self._verify_write_with_desc(idx, item, desc) + return idx + + @abstractmethod + def _get_descriptors(self): + # type: () -> tuple[_Desc, ...] + raise NotImplementedError + + @cached_property + @final + def descriptors(self): + # type: () -> tuple[_Desc, ...] + return self._get_descriptors() + + @property + @final + def op(self): + return self.__op + + def __init__(self, items, op): + # type: (Iterable[_T], Op) -> None + self.__op = op + self.__items = [] # type: list[_T] + for idx, item in enumerate(items): + if idx >= len(self.descriptors): + raise ValueError("too many items") + self._verify_write(idx, item) + self.__items.append(item) + if len(self.__items) < len(self.descriptors): + raise ValueError("not enough items") + + @final + def __iter__(self): + # type: () -> Iterator[_T] + yield from self.__items + + @overload def __getitem__(self, idx): - # type: (int | slice) -> SSAVal - if isinstance(idx, int): - return SSAVal([self.expanded_sliced_op_outputs[idx]]) - return SSAVal(self.expanded_sliced_op_outputs[idx]) + # type: (int) -> _T + ... + @overload + def __getitem__(self, idx): + # type: (slice) -> list[_T] + ... + + @final + def __getitem__(self, idx): + # type: (int | slice) -> _T | list[_T] + return self.__items[idx] + + @final + def __setitem__(self, idx, item): + # type: (int, _T) -> None + idx = self._verify_write(idx, item) + self.__items[idx] = item + + @final def __len__(self): - return len(self.expanded_sliced_op_outputs) + # type: () -> int + return len(self.__items) - def __iter__(self): - # type: () -> Iterator[SSAVal] - for v in self.expanded_sliced_op_outputs: - yield SSAVal([v]) - def __repr__(self): - # type: () -> str - if len(self.sliced_op_outputs) == 0: - return "SSAVal([])" - parts = [] # type: list[str] - for op, output_index, range_ in self.sliced_op_outputs: - out_len = op.properties.outputs[output_index].get_length(op.maxvl) - parts.append(f"<{op.name}#{output_index}>") - if range_ != range(out_len): - parts[-1] += f"[{range_.start}:{range_.stop}]" - return " + ".join(parts) +@final +class OpInputs(OpInputSeq[SSAVal, OperandDesc]): + def _get_descriptors(self): + # type: () -> tuple[OperandDesc, ...] + return self.op.properties.inputs + + def _verify_write_with_desc(self, idx, item, desc): + # type: (int, SSAVal | Any, OperandDesc) -> None + if not isinstance(item, SSAVal): + raise TypeError("expected value of type SSAVal") + if item.ty != desc.ty: + raise ValueError(f"assigned item's type {item.ty!r} doesn't match " + f"corresponding input's type {desc.ty!r}") + + def __init__(self, items, op): + # type: (Iterable[SSAVal], Op) -> None + if hasattr(op, "inputs"): + raise ValueError("Op.inputs already set") + super().__init__(items, op) + + +@final +class OpImmediates(OpInputSeq[int, range]): + def _get_descriptors(self): + # type: () -> tuple[range, ...] + return self.op.properties.immediates + + def _verify_write_with_desc(self, idx, item, desc): + # type: (int, int | Any, range) -> None + if not isinstance(item, int): + raise TypeError("expected value of type int") + if item not in desc: + raise ValueError(f"immediate value {item!r} not in {desc!r}") + + def __init__(self, items, op): + # type: (Iterable[int], Op) -> None + if hasattr(op, "immediates"): + raise ValueError("Op.immediates already set") + super().__init__(items, op) @plain_data(frozen=True, eq=False) @final class Op: - __slots__ = "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name" + __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name" - def __init__(self, fn, kind, inputs, immediates, maxvl, name=""): - # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None + def __init__(self, fn, properties, inputs, immediates, name=""): + # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None self.fn = fn - self.kind = kind - self.inputs = list(inputs) - self.immediates = list(immediates) - self.maxvl = maxvl + self.properties = properties + self.inputs = OpInputs(inputs, op=self) + self.immediates = OpImmediates(immediates, op=self) outputs_len = len(self.properties.outputs) - self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len)) + self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len)) self.name = fn._add_op_with_unused_name(self, name) # type: ignore @property - def properties(self): - return self.kind.properties + def kind(self): + return self.properties.kind def __eq__(self, other): # type: (Op | Any) -> bool -- 2.30.2