add SimdWHintMap to support tracking width_hint for XLEN
authorJacob Lifshay <programmerjake@gmail.com>
Wed, 27 Oct 2021 09:10:24 +0000 (02:10 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Wed, 27 Oct 2021 09:10:24 +0000 (02:10 -0700)
src/ieee754/part/util.py
src/ieee754/part/util.pyi

index 6a9bcb13d4c92e17f995d964e357b655a63df813..bdb33279c1157dca6db7cf6ef05ea1a4402df094 100644 (file)
@@ -34,6 +34,23 @@ class SimdMap:
     ALL_ELWIDTHS = (*FpElWid, *IntElWid)
     __slots__ = ("__map",)
 
+    @staticmethod
+    def extract_value_algo(values, default=None, *, simd_map_get, mapping_get):
+        step = 0
+        while values is not None:
+            # specifically use base class to catch all SimdMap instances
+            if isinstance(values, SimdMap):
+                values = simd_map_get(values)
+            elif isinstance(values, Mapping):
+                values = mapping_get(values)
+            else:
+                return values
+            step += 1
+            # use object.__repr__ since repr() would probably recurse forever
+            assert step < 10000, (f"can't resolve infinitely recursive "
+                                  f"value {object.__repr__(values)}")
+        return default
+
     @classmethod
     def extract_value(cls, elwid, values, default=None):
         """get the value for elwid.
@@ -56,20 +73,10 @@ class SimdMap:
         }) == 5
         """
         assert elwid in cls.ALL_ELWIDTHS
-        step = 0
-        while values is not None:
-            # specifically use base class to catch all SimdMap instances
-            if isinstance(values, SimdMap):
-                values = values.__map.get(elwid)
-            elif isinstance(values, Mapping):
-                values = values.get(elwid)
-            else:
-                return values
-            step += 1
-            # use object.__repr__ since repr() would probably recurse forever
-            assert step < 10000, (f"can't resolve infinitely recursive "
-                                  f"value {object.__repr__(values)}")
-        return default
+        return SimdMap.extract_value_algo(
+            values, default,
+            simd_map_get=lambda v: v.__map.get(elwid),
+            mapping_get=lambda v: v.get(elwid))
 
     def __init__(self, values=None):
         """construct a SimdMap"""
@@ -151,7 +158,7 @@ class SimdMap:
 
     def __iter__(self):
         """return an iterator of (elwid, value) pairs"""
-        return self.__map.items()
+        return iter(self.__map.items())
 
     def __add__(self, other):
         return self.map(operator.add, self, other)
@@ -264,12 +271,80 @@ class SimdMap:
                 yield elwid
 
 
+class SimdWHintMap(SimdMap):
+    """SimdMap with a width hint."""
+
+    __slots__ = ("__width_hint",)
+
+    @classmethod
+    def extract_width_hint(cls, values, default=None):
+        """get the value for width hint."""
+        def simd_map_get(v):
+            return v.width_hint if isinstance(v, SimdWHintMap) else None
+        return SimdMap.extract_value_algo(
+            values, default,
+            simd_map_get=simd_map_get,
+            mapping_get=lambda _: None)
+
+    def __init__(self, values=None, *, width_hint=None):
+        """construct a SimdWHintMap"""
+        super().__init__(values)
+        if width_hint is None:
+            width_hint = values
+
+        self.__width_hint = self.extract_width_hint(width_hint)
+
+    @property
+    def width_hint(self):
+        return self.__width_hint
+
+    def __eq__(self, other):
+        if isinstance(other, SimdMap):
+            return self.mapping == other.mapping \
+                and self.width_hint == self.extract_width_hint(other)
+        return NotImplemented
+
+    def __hash__(self):
+        if self.width_hint is None:
+            return super().__hash__()
+        return hash((tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS),
+                     self.width_hint))
+
+    def __repr__(self):
+        wh = ""
+        if self.width_hint is not None:
+            wh = f", width_hint={self.width_hint!r}"
+        return f"{self.__class__.__name__}({dict(self.mapping)}{wh})"
+
+    @classmethod
+    def map(cls, f, *args):
+        """get the SimdWHintMap of the results of calling
+        `f(value1, value2, value3, ...)` where
+        `value1`, `value2`, `value3`, ... are the results of calling
+        `cls.extract_value` on each `args`.
+        """
+        retval = {}
+        for elwid in cls.ALL_ELWIDTHS:
+            extracted_args = [cls.extract_value(elwid, arg) for arg in args]
+            if None not in extracted_args:
+                retval[elwid] = f(*extracted_args)
+        width_hint = None
+        try:
+            extracted_args = [cls.extract_width_hint(arg) for arg in args]
+            if None not in extracted_args:
+                width_hint = f(*extracted_args)
+        except (ArithmeticError, LookupError, ValueError):
+            # ignore some errors and just clear width_hint
+            pass
+        return cls(retval, width_hint=width_hint)
+
+
 def _check_for_missing_elwidths(name, all_elwidths=None):
     missing = list(globals()[name].missing_elwidths(all_elwidths=all_elwidths))
     assert missing == [], f"{name} is missing entries for {missing}"
 
 
-XLEN = SimdMap({
+XLEN = SimdWHintMap({
     IntElWid.I64: 64,
     IntElWid.I32: 32,
     IntElWid.I16: 16,
@@ -278,7 +353,7 @@ XLEN = SimdMap({
     FpElWid.F32: 32,
     FpElWid.F16: 16,
     FpElWid.BF16: 16,
-})
+}, width_hint=64)
 
 DEFAULT_FP_VEC_EL_COUNTS = SimdMap({
     FpElWid.F64: 1,
index 97486600ee0931660a956a0fbe29f5d7d3cc18e8..5d2189bc237e41eab7a749ef36602633a08baa55 100644 (file)
@@ -3,8 +3,8 @@
 
 from enum import Enum
 from typing import (Any, Callable, ClassVar, Generic, ItemsView, Iterable,
-                    KeysView, Literal, Mapping, Optional, Tuple, TypeVar,
-                    Union, ValuesView, overload)
+                    Iterator, KeysView, Literal, Mapping, Optional, Tuple,
+                    TypeVar, Union, ValuesView, overload)
 
 
 class ElWid(Enum):
@@ -37,6 +37,38 @@ class SimdMap(Generic[_T]):
 
     __map: Mapping[_ElWid, _T]
 
+    @overload
+    @staticmethod
+    def extract_value_algo(values: None,
+                           default: _T2 = None, *,
+                           simd_map_get: Callable[["SimdMap[_T]"], _T],
+                           mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+                           ) -> _T2: ...
+
+    @overload
+    @staticmethod
+    def extract_value_algo(values: SimdMap[_T],
+                           default: _T2 = None, *,
+                           simd_map_get: Callable[["SimdMap[_T]"], _T],
+                           mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+                           ) -> Union[_T, _T2]: ...
+
+    @overload
+    @staticmethod
+    def extract_value_algo(values: Mapping[_ElWid, _T],
+                           default: _T2 = None, *,
+                           simd_map_get: Callable[["SimdMap[_T]"], _T],
+                           mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+                           ) -> Union[_T, _T2]: ...
+
+    @overload
+    @staticmethod
+    def extract_value_algo(values: _T,
+                           default: _T2 = None, *,
+                           simd_map_get: Callable[["SimdMap[_T]"], _T],
+                           mapping_get: Callable[[Mapping[_ElWid, _T]], _T],
+                           ) -> Union[_T, _T2]: ...
+
     @overload
     @classmethod
     def extract_value(cls,
@@ -803,7 +835,7 @@ class SimdMap(Generic[_T]):
     def get(self, elwid: _ElWid, default: _T2 = None, *,
             raise_key_error: bool = False) -> Union[_T, _T2]: ...
 
-    def __iter__(self) -> Iterable[Tuple[_ElWid, _T]]: ...
+    def __iter__(self) -> Iterator[Tuple[_ElWid, _T]]: ...
 
     @overload
     def __add__(self, other: SimdMap[_T]) -> SimdMap[_T]: ...
@@ -1109,7 +1141,72 @@ class SimdMap(Generic[_T]):
                          ) -> Iterable[_ElWid]: ...
 
 
-XLEN: SimdMap[int] = ...
+class SimdWHintMap(SimdMap[_T]):
+    @overload
+    @classmethod
+    def extract_width_hint(cls,
+                           values: None,
+                           default: _T2 = None) -> _T2: ...
+
+    @overload
+    @classmethod
+    def extract_width_hint(cls,
+                           values: SimdMap[_T],
+                           default: _T2 = None) -> Union[_T, _T2]: ...
+
+    @overload
+    @classmethod
+    def extract_width_hint(cls,
+                           values: Mapping[_ElWid, _T],
+                           default: _T2 = None) -> Union[_T, _T2]: ...
+
+    @overload
+    @classmethod
+    def extract_width_hint(cls,
+                           values: _T,
+                           default: _T2 = None) -> Union[_T, _T2]: ...
+
+    @overload
+    def __init__(self, values: Optional[SimdMap[_T]] = None, *,
+                 width_hint: Optional[SimdMap[_T]] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *,
+                 width_hint: Optional[SimdMap[_T]] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[_T] = None, *,
+                 width_hint: Optional[SimdMap[_T]] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[SimdMap[_T]] = None, *,
+                 width_hint: Optional[Mapping[_ElWid, _T]] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *,
+                 width_hint: Optional[Mapping[_ElWid, _T]] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[_T] = None, *,
+                 width_hint: Optional[Mapping[_ElWid, _T]] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[SimdMap[_T]] = None, *,
+                 width_hint: Optional[_T] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[Mapping[_ElWid, _T]] = None, *,
+                 width_hint: Optional[_T] = None): ...
+
+    @overload
+    def __init__(self, values: Optional[_T] = None, *,
+                 width_hint: Optional[_T] = None): ...
+
+    @property
+    def width_hint(self) -> _T: ...
+
+
+XLEN: SimdWHintMap[int] = ...
 
 DEFAULT_FP_VEC_EL_COUNTS: SimdMap[int] = ...