-from collections import defaultdict
import enum
+from abc import abstractmethod
from enum import Enum, unique
-from typing import AbstractSet, Any, Iterable, Iterator, NoReturn, Tuple, Union, Mapping, overload
+from functools import lru_cache
+from typing import (AbstractSet, Any, Generic, Iterable, Iterator, Sequence,
+ TypeVar, overload)
from weakref import WeakValueDictionary as _WeakVDict
from cached_property import cached_property
from nmutil.plain_data import plain_data
from bigint_presentation_code.type_util import Self, assert_never, final
-from bigint_presentation_code.util import (BaseBitSet, BitSet, FBitSet, OFSet,
- OSet, FMap)
-from functools import lru_cache
+from bigint_presentation_code.util import BitSet, FBitSet, FMap, OFSet
@final
return
for sub_kind in self.sub_kinds:
yield from sub_kind.allocatable_locs(ty)
- loc_set = LocSet(locs())
+ loc_set_before_spread = LocSet(locs())
for idx in range(rep_count):
if not self.spread:
idx = None
- yield OperandDesc(loc_set=loc_set,
+ yield OperandDesc(loc_set_before_spread=loc_set_before_spread,
tied_input_index=self.tied_input_index,
spread_index=idx)
@final
class OperandDesc:
"""Op operand descriptor"""
- __slots__ = "loc_set", "tied_input_index", "spread_index"
+ __slots__ = "loc_set_before_spread", "tied_input_index", "spread_index"
- def __init__(self, loc_set, tied_input_index, spread_index):
+ def __init__(self, loc_set_before_spread, tied_input_index, spread_index):
# type: (LocSet, int | None, int | None) -> None
- if len(loc_set) == 0:
- raise ValueError("loc_set must not be empty")
- self.loc_set = loc_set
+ if len(loc_set_before_spread) == 0:
+ raise ValueError("loc_set_before_spread must not be empty")
+ self.loc_set_before_spread = loc_set_before_spread
self.tied_input_index = tied_input_index
if self.tied_input_index is not None and self.spread_index is not None:
raise ValueError("operand can't be both spread and tied")
+ self.spread_index = spread_index
+
+ @cached_property
+ def ty_before_spread(self):
+ # type: () -> Ty
+ ty = self.loc_set_before_spread.ty
+ assert ty is not None, (
+ "__init__ checked that the LocSet isn't empty, "
+ "non-empty LocSets should always have ty set")
+ return ty
+
+ @cached_property
+ def ty(self):
+ """ Ty after any spread is applied """
+ if self.spread_index is not None:
+ return Ty(base_ty=self.ty_before_spread.base_ty, reg_len=1)
+ return self.ty_before_spread
OD_BASE_SGPR = GenericOperandDesc(
@plain_data(frozen=True, unsafe_hash=True)
@final
class OpProperties:
- __slots__ = "kind", "inputs", "outputs"
+ __slots__ = "kind", "inputs", "outputs", "maxvl"
def __init__(self, kind, maxvl):
# type: (OpKind, int) -> None
for out in self.generic.outputs:
outputs.extend(out.instantiate(maxvl=maxvl))
self.outputs = tuple(outputs)
+ self.maxvl = maxvl
@property
def generic(self):
)
-# FIXME: rewrite from here
-
-
@plain_data(frozen=True, unsafe_hash=True, repr=False)
@final
class SSAVal:
- __slots__ = "sliced_op_outputs",
+ __slots__ = "op", "output_idx"
- _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]",
- "tuple[Op, int]", "SSAVal"]
+ def __init__(self, op, output_idx):
+ # type: (Op, int) -> None
+ self.op = op
+ if output_idx < 0 or output_idx >= len(op.properties.outputs):
+ raise ValueError("invalid output_idx")
+ self.output_idx = output_idx
- @staticmethod
- def __process_sliced_op_outputs(inp):
- # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
- for v in inp:
- if isinstance(v, SSAVal):
- yield from v.sliced_op_outputs
- continue
- op = v[0]
- output_index = v[1]
- if output_index < 0 or output_index >= len(op.properties.outputs):
- raise ValueError("invalid output_index")
- cur_len = op.properties.outputs[output_index].get_length(op.maxvl)
- slice_ = slice(None) if len(v) == 2 else v[2]
- if isinstance(slice_, range):
- slice_ = slice(slice_.start, slice_.stop, slice_.step)
- if isinstance(slice_, int):
- # raise exception for out-of-range values
- idx = range(cur_len)[slice_]
- range_ = range(idx, idx + 1)
- else:
- # raise exception for out-of-range values
- range_ = range(cur_len)[slice_]
- if range_.step != 1:
- raise ValueError("slice step must be 1")
- if len(range_) == 0:
- continue
- yield op, output_index, range_
+ def __repr__(self):
+ # type: () -> str
+ return f"<{self.op.name}#{self.output_idx}>"
- def __init__(self, sliced_op_outputs):
- # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
- # we have length arg so plain_data.replace works
- if isinstance(sliced_op_outputs, SSAVal):
- inp = sliced_op_outputs.sliced_op_outputs
- else:
- inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs)
- processed = [] # type: list[tuple[Op, int, range]]
- length = 0
- for op, output_index, range_ in inp:
- length += len(range_)
- if len(processed) == 0:
- processed.append((op, output_index, range_))
- continue
- last_op, last_output_index, last_range_ = processed[-1]
- if last_op == op and last_output_index == output_index \
- and last_range_.stop == range_.start:
- # merge slices
- range_ = range(last_range_.start, range_.stop)
- processed[-1] = op, output_index, range_
- else:
- processed.append((op, output_index, range_))
- self.sliced_op_outputs = tuple(processed)
-
- def __add__(self, other):
- # type: (SSAVal | Any) -> SSAVal
- if not isinstance(other, SSAVal):
- return NotImplemented
- return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
-
- def __radd__(self, other):
- # type: (SSAVal | Any) -> SSAVal
- if isinstance(other, SSAVal):
- return other.__add__(self)
- return NotImplemented
+ @cached_property
+ def defining_descriptor(self):
+ # type: () -> OperandDesc
+ return self.op.properties.outputs[self.output_idx]
@cached_property
- def expanded_sliced_op_outputs(self):
- # type: () -> tuple[tuple[Op, int, int], ...]
- retval = [] # type: list[tuple[Op, int, int]]
- for op, output_index, range_ in self.sliced_op_outputs:
- for i in range_:
- retval.append((op, output_index, i))
- # must be tuple to not be modifiable since it's cached
- return tuple(retval)
+ def loc_set_before_spread(self):
+ # type: () -> LocSet
+ return self.defining_descriptor.loc_set_before_spread
+ @cached_property
+ def ty(self):
+ # type: () -> Ty
+ return self.defining_descriptor.ty
+
+ @cached_property
+ def ty_before_spread(self):
+ # type: () -> Ty
+ return self.defining_descriptor.ty_before_spread
+
+
+_T = TypeVar("_T")
+_Desc = TypeVar("_Desc")
+
+
+class OpInputSeq(Sequence[_T], Generic[_T, _Desc]):
+ @abstractmethod
+ def _verify_write_with_desc(self, idx, item, desc):
+ # type: (int, _T | Any, _Desc) -> None
+ raise NotImplementedError
+
+ @final
+ def _verify_write(self, idx, item):
+ # type: (int | Any, _T | Any) -> int
+ if not isinstance(idx, int):
+ if isinstance(idx, slice):
+ raise TypeError(
+ f"can't write to slice of {self.__class__.__name__}")
+ raise TypeError(f"can't write with index {idx!r}")
+ # normalize idx, raising IndexError if it is out of range
+ idx = range(len(self.descriptors))[idx]
+ desc = self.descriptors[idx]
+ self._verify_write_with_desc(idx, item, desc)
+ return idx
+
+ @abstractmethod
+ def _get_descriptors(self):
+ # type: () -> tuple[_Desc, ...]
+ raise NotImplementedError
+
+ @cached_property
+ @final
+ def descriptors(self):
+ # type: () -> tuple[_Desc, ...]
+ return self._get_descriptors()
+
+ @property
+ @final
+ def op(self):
+ return self.__op
+
+ def __init__(self, items, op):
+ # type: (Iterable[_T], Op) -> None
+ self.__op = op
+ self.__items = [] # type: list[_T]
+ for idx, item in enumerate(items):
+ if idx >= len(self.descriptors):
+ raise ValueError("too many items")
+ self._verify_write(idx, item)
+ self.__items.append(item)
+ if len(self.__items) < len(self.descriptors):
+ raise ValueError("not enough items")
+
+ @final
+ def __iter__(self):
+ # type: () -> Iterator[_T]
+ yield from self.__items
+
+ @overload
def __getitem__(self, idx):
- # type: (int | slice) -> SSAVal
- if isinstance(idx, int):
- return SSAVal([self.expanded_sliced_op_outputs[idx]])
- return SSAVal(self.expanded_sliced_op_outputs[idx])
+ # type: (int) -> _T
+ ...
+ @overload
+ def __getitem__(self, idx):
+ # type: (slice) -> list[_T]
+ ...
+
+ @final
+ def __getitem__(self, idx):
+ # type: (int | slice) -> _T | list[_T]
+ return self.__items[idx]
+
+ @final
+ def __setitem__(self, idx, item):
+ # type: (int, _T) -> None
+ idx = self._verify_write(idx, item)
+ self.__items[idx] = item
+
+ @final
def __len__(self):
- return len(self.expanded_sliced_op_outputs)
+ # type: () -> int
+ return len(self.__items)
- def __iter__(self):
- # type: () -> Iterator[SSAVal]
- for v in self.expanded_sliced_op_outputs:
- yield SSAVal([v])
- def __repr__(self):
- # type: () -> str
- if len(self.sliced_op_outputs) == 0:
- return "SSAVal([])"
- parts = [] # type: list[str]
- for op, output_index, range_ in self.sliced_op_outputs:
- out_len = op.properties.outputs[output_index].get_length(op.maxvl)
- parts.append(f"<{op.name}#{output_index}>")
- if range_ != range(out_len):
- parts[-1] += f"[{range_.start}:{range_.stop}]"
- return " + ".join(parts)
+@final
+class OpInputs(OpInputSeq[SSAVal, OperandDesc]):
+ def _get_descriptors(self):
+ # type: () -> tuple[OperandDesc, ...]
+ return self.op.properties.inputs
+
+ def _verify_write_with_desc(self, idx, item, desc):
+ # type: (int, SSAVal | Any, OperandDesc) -> None
+ if not isinstance(item, SSAVal):
+ raise TypeError("expected value of type SSAVal")
+ if item.ty != desc.ty:
+ raise ValueError(f"assigned item's type {item.ty!r} doesn't match "
+ f"corresponding input's type {desc.ty!r}")
+
+ def __init__(self, items, op):
+ # type: (Iterable[SSAVal], Op) -> None
+ if hasattr(op, "inputs"):
+ raise ValueError("Op.inputs already set")
+ super().__init__(items, op)
+
+
+@final
+class OpImmediates(OpInputSeq[int, range]):
+ def _get_descriptors(self):
+ # type: () -> tuple[range, ...]
+ return self.op.properties.immediates
+
+ def _verify_write_with_desc(self, idx, item, desc):
+ # type: (int, int | Any, range) -> None
+ if not isinstance(item, int):
+ raise TypeError("expected value of type int")
+ if item not in desc:
+ raise ValueError(f"immediate value {item!r} not in {desc!r}")
+
+ def __init__(self, items, op):
+ # type: (Iterable[int], Op) -> None
+ if hasattr(op, "immediates"):
+ raise ValueError("Op.immediates already set")
+ super().__init__(items, op)
@plain_data(frozen=True, eq=False)
@final
class Op:
- __slots__ = "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name"
+ __slots__ = "fn", "properties", "inputs", "immediates", "outputs", "name"
- def __init__(self, fn, kind, inputs, immediates, maxvl, name=""):
- # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
+ def __init__(self, fn, properties, inputs, immediates, name=""):
+ # type: (Fn, OpProperties, Iterable[SSAVal], Iterable[int], str) -> None
self.fn = fn
- self.kind = kind
- self.inputs = list(inputs)
- self.immediates = list(immediates)
- self.maxvl = maxvl
+ self.properties = properties
+ self.inputs = OpInputs(inputs, op=self)
+ self.immediates = OpImmediates(immediates, op=self)
outputs_len = len(self.properties.outputs)
- self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
+ self.outputs = tuple(SSAVal(self, i) for i in range(outputs_len))
self.name = fn._add_op_with_unused_name(self, name) # type: ignore
@property
- def properties(self):
- return self.kind.properties
+ def kind(self):
+ return self.properties.kind
def __eq__(self, other):
# type: (Op | Any) -> bool