ALL_ELWIDTHS = (*FpElWid, *IntElWid)
__slots__ = ("__map",)
+ @staticmethod
+ def extract_value_algo(values, default=None, *, simd_map_get, mapping_get):
+ step = 0
+ while values is not None:
+ # specifically use base class to catch all SimdMap instances
+ if isinstance(values, SimdMap):
+ values = simd_map_get(values)
+ elif isinstance(values, Mapping):
+ values = mapping_get(values)
+ else:
+ return values
+ step += 1
+ # use object.__repr__ since repr() would probably recurse forever
+ assert step < 10000, (f"can't resolve infinitely recursive "
+ f"value {object.__repr__(values)}")
+ return default
+
@classmethod
def extract_value(cls, elwid, values, default=None):
"""get the value for elwid.
}) == 5
"""
assert elwid in cls.ALL_ELWIDTHS
- step = 0
- while values is not None:
- # specifically use base class to catch all SimdMap instances
- if isinstance(values, SimdMap):
- values = values.__map.get(elwid)
- elif isinstance(values, Mapping):
- values = values.get(elwid)
- else:
- return values
- step += 1
- # use object.__repr__ since repr() would probably recurse forever
- assert step < 10000, (f"can't resolve infinitely recursive "
- f"value {object.__repr__(values)}")
- return default
+ return SimdMap.extract_value_algo(
+ values, default,
+ simd_map_get=lambda v: v.__map.get(elwid),
+ mapping_get=lambda v: v.get(elwid))
def __init__(self, values=None):
"""construct a SimdMap"""
def __iter__(self):
"""return an iterator of (elwid, value) pairs"""
- return self.__map.items()
+ return iter(self.__map.items())
def __add__(self, other):
return self.map(operator.add, self, other)
yield elwid
+class SimdWHintMap(SimdMap):
+ """SimdMap with a width hint."""
+
+ __slots__ = ("__width_hint",)
+
+ @classmethod
+ def extract_width_hint(cls, values, default=None):
+ """get the value for width hint."""
+ def simd_map_get(v):
+ return v.width_hint if isinstance(v, SimdWHintMap) else None
+ return SimdMap.extract_value_algo(
+ values, default,
+ simd_map_get=simd_map_get,
+ mapping_get=lambda _: None)
+
+ def __init__(self, values=None, *, width_hint=None):
+ """construct a SimdWHintMap"""
+ super().__init__(values)
+ if width_hint is None:
+ width_hint = values
+
+ self.__width_hint = self.extract_width_hint(width_hint)
+
+ @property
+ def width_hint(self):
+ return self.__width_hint
+
+ def __eq__(self, other):
+ if isinstance(other, SimdMap):
+ return self.mapping == other.mapping \
+ and self.width_hint == self.extract_width_hint(other)
+ return NotImplemented
+
+ def __hash__(self):
+ if self.width_hint is None:
+ return super().__hash__()
+ return hash((tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS),
+ self.width_hint))
+
+ def __repr__(self):
+ wh = ""
+ if self.width_hint is not None:
+ wh = f", width_hint={self.width_hint!r}"
+ return f"{self.__class__.__name__}({dict(self.mapping)}{wh})"
+
+ @classmethod
+ def map(cls, f, *args):
+ """get the SimdWHintMap of the results of calling
+ `f(value1, value2, value3, ...)` where
+ `value1`, `value2`, `value3`, ... are the results of calling
+ `cls.extract_value` on each `args`.
+ """
+ retval = {}
+ for elwid in cls.ALL_ELWIDTHS:
+ extracted_args = [cls.extract_value(elwid, arg) for arg in args]
+ if None not in extracted_args:
+ retval[elwid] = f(*extracted_args)
+ width_hint = None
+ try:
+ extracted_args = [cls.extract_width_hint(arg) for arg in args]
+ if None not in extracted_args:
+ width_hint = f(*extracted_args)
+ except (ArithmeticError, LookupError, ValueError):
+ # ignore some errors and just clear width_hint
+ pass
+ return cls(retval, width_hint=width_hint)
+
+
def _check_for_missing_elwidths(name, all_elwidths=None):
missing = list(globals()[name].missing_elwidths(all_elwidths=all_elwidths))
assert missing == [], f"{name} is missing entries for {missing}"
-XLEN = SimdMap({
+XLEN = SimdWHintMap({
IntElWid.I64: 64,
IntElWid.I32: 32,
IntElWid.I16: 16,
FpElWid.F32: 32,
FpElWid.F16: 16,
FpElWid.BF16: 16,
-})
+}, width_hint=64)
DEFAULT_FP_VEC_EL_COUNTS = SimdMap({
FpElWid.F64: 1,
from enum import Enum
from typing import (Any, Callable, ClassVar, Generic, ItemsView, Iterable,
- KeysView, Literal, Mapping, Optional, Tuple, TypeVar,
- Union, ValuesView, overload)
+ Iterator, KeysView, Literal, Mapping, Optional, Tuple,
+ TypeVar, Union, ValuesView, overload)
class ElWid(Enum):
__map: Mapping[_ElWid, _T]
+ @overload
+ @staticmethod
+ def extract_value_algo(values: None,
+ default: _T2 = None, *,
+ simd_map_get: Callable[["SimdMap[_T]"], _T],
+ mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+ ) -> _T2: ...
+
+ @overload
+ @staticmethod
+ def extract_value_algo(values: SimdMap[_T],
+ default: _T2 = None, *,
+ simd_map_get: Callable[["SimdMap[_T]"], _T],
+ mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+ ) -> Union[_T, _T2]: ...
+
+ @overload
+ @staticmethod
+ def extract_value_algo(values: Mapping[_ElWid, _T],
+ default: _T2 = None, *,
+ simd_map_get: Callable[["SimdMap[_T]"], _T],
+ mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+ ) -> Union[_T, _T2]: ...
+
+ @overload
+ @staticmethod
+ def extract_value_algo(values: _T,
+ default: _T2 = None, *,
+ simd_map_get: Callable[["SimdMap[_T]"], _T],
+ mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+ ) -> Union[_T, _T2]: ...
+
@overload
@classmethod
def extract_value(cls,
def get(self, elwid: _ElWid, default: _T2 = None, *,
raise_key_error: bool = False) -> Union[_T, _T2]: ...
- def __iter__(self) -> Iterable[Tuple[_ElWid, _T]]: ...
+ def __iter__(self) -> Iterator[Tuple[_ElWid, _T]]: ...
@overload
def __add__(self, other: SimdMap[_T]) -> SimdMap[_T]: ...
) -> Iterable[_ElWid]: ...
-XLEN: SimdMap[int] = ...
+class SimdWHintMap(SimdMap[_T]):
+ @overload
+ @classmethod
+ def extract_width_hint(cls,
+ values: None,
+ default: _T2 = None) -> _T2: ...
+
+ @overload
+ @classmethod
+ def extract_width_hint(cls,
+ values: SimdMap[_T],
+ default: _T2 = None) -> Union[_T, _T2]: ...
+
+ @overload
+ @classmethod
+ def extract_width_hint(cls,
+ values: Mapping[_ElWid, _T],
+ default: _T2 = None) -> Union[_T, _T2]: ...
+
+ @overload
+ @classmethod
+ def extract_width_hint(cls,
+ values: _T,
+ default: _T2 = None) -> Union[_T, _T2]: ...
+
+ @overload
+ def __init__(self, values: Optional[SimdMap[_T]] = None, *,
+ width_hint: Optional[SimdMap[_T]] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *,
+ width_hint: Optional[SimdMap[_T]] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[_T] = None, *,
+ width_hint: Optional[SimdMap[_T]] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[SimdMap[_T]] = None, *,
+ width_hint: Optional[Mapping[_ElWid, _T]] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *,
+ width_hint: Optional[Mapping[_ElWid, _T]] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[_T] = None, *,
+ width_hint: Optional[Mapping[_ElWid, _T]] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[SimdMap[_T]] = None, *,
+ width_hint: Optional[_T] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *,
+ width_hint: Optional[_T] = None): ...
+
+ @overload
+ def __init__(self, values: Optional[_T] = None, *,
+ width_hint: Optional[_T] = None): ...
+
+ @property
+ def width_hint(self) -> _T: ...
+
+
+XLEN: SimdWHintMap[int] = ...
DEFAULT_FP_VEC_EL_COUNTS: SimdMap[int] = ...