From 45b3d82e222589403ca5e4f128e8a8737d85d30e Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 13 Oct 2021 01:47:38 -0700 Subject: [PATCH] add SimdMap and SimdScope and XLEN --- src/ieee754/part/util.py | 397 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 397 insertions(+) create mode 100644 src/ieee754/part/util.py diff --git a/src/ieee754/part/util.py b/src/ieee754/part/util.py new file mode 100644 index 00000000..8cadf8c1 --- /dev/null +++ b/src/ieee754/part/util.py @@ -0,0 +1,397 @@ +# SPDX-License-Identifier: LGPL-3-or-later +# See Notices.txt for copyright information + +from enum import Enum +from typing import Mapping +import operator +import math +from types import MappingProxyType +from contextlib import contextmanager + +from nmigen.hdl.ast import Signal + + +class ElWid(Enum): + def __repr__(self): + return super().__str__() + + +class FpElWid(ElWid): + F64 = 0 + F32 = 1 + F16 = 2 + BF16 = 3 + + +class IntElWid(ElWid): + I64 = 0 + I32 = 1 + I16 = 2 + I8 = 3 + + +class SimdMap: + """A map from ElWid values to Python values. + SimdMap instances are immutable.""" + + ALL_ELWIDTHS = (*FpElWid, *IntElWid) + __slots__ = ("__map",) + + @classmethod + def extract_value(cls, elwid, values, default=None): + """get the value for elwid. + if `values` is a `SimdMap` or a `Mapping`, then return the + corresponding value for `elwid`, recursing until finding a non-map. + if `values` ever ends up not existing (in the case of a map) or being + `None`, return `default`. + + Examples: + SimdMap.extract_value(IntElWid.I8, 5) == 5 + SimdMap.extract_value(IntElWid.I8, None) == None + SimdMap.extract_value(IntElWid.I8, None, 3) == 3 + SimdMap.extract_value(IntElWid.I8, {}) == None + SimdMap.extract_value(IntElWid.I8, {IntElWid.I8: 5}) == 5 + SimdMap.extract_value(IntElWid.I8, { + IntElWid.I8: {IntElWid.I8: 5}, + }) == 5 + SimdMap.extract_value(IntElWid.I8, { + IntElWid.I8: SimdMap({IntElWid.I8: 5}), + }) == 5 + """ + assert elwid in cls.ALL_ELWIDTHS + step = 0 + while values is not None: + # specifically use base class to catch all SimdMap instances + if isinstance(values, SimdMap): + values = values.__map.get(elwid) + elif isinstance(values, Mapping): + values = values.get(elwid) + else: + return values + step += 1 + # use object.__repr__ since repr() would probably recurse forever + assert step < 10000, (f"can't resolve infinitely recursive " + f"value {object.__repr__(values)}") + return default + + def __init__(self, values=None): + """construct a SimdMap""" + mapping = {} + for elwid in self.ALL_ELWIDTHS: + v = self.extract_value(elwid, values) + if v is not None: + mapping[elwid] = v + self.__map = MappingProxyType(mapping) + + @property + def mapping(self): + """the values as a read-only Mapping[ElWid, Any]""" + return self.__map + + def values(self): + return self.__map.values() + + def keys(self): + return self.__map.keys() + + def items(self): + return self.__map.items() + + @classmethod + def map_with_elwid(cls, f, *args): + """get the SimdMap of the results of calling + `f(elwid, value1, value2, value3, ...)` where + `value1`, `value2`, `value3`, ... are the results of calling + `cls.extract_value` on each `args`. + + This is similar to Python's built-in `map` function. + + Examples: + SimdMap.map_with_elwid(lambda elwid, a: a + 1, {IntElWid.I32: 5}) == + SimdMap({IntElWid.I32: 6}) + SimdMap.map_with_elwid(lambda elwid, a: a + 1, 3) == + SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...}) + SimdMap.map_with_elwid(lambda elwid, a, b: a + b, + 3, {IntElWid.I8: 4}, + ) == SimdMap({IntElWid.I8: 7}) + SimdMap.map_with_elwid(lambda elwid: elwid.name) == + SimdMap({IntElWid.I8: "I8", IntElWid.I16: "I16"}) + """ + retval = {} + for elwid in cls.ALL_ELWIDTHS: + extracted_args = [cls.extract_value(elwid, arg) for arg in args] + if None not in extracted_args: + retval[elwid] = f(elwid, *extracted_args) + return cls(retval) + + @classmethod + def map(cls, f, *args): + """get the SimdMap of the results of calling + `f(value1, value2, value3, ...)` where + `value1`, `value2`, `value3`, ... are the results of calling + `cls.extract_value` on each `args`. + + This is similar to Python's built-in `map` function. + + Examples: + SimdMap.map(lambda a: a + 1, {IntElWid.I32: 5}) == + SimdMap({IntElWid.I32: 6}) + SimdMap.map(lambda a: a + 1, 3) == + SimdMap({IntElWid.I8: 4, IntElWid.I16: 4, ...}) + SimdMap.map(lambda a, b: a + b, + 3, {IntElWid.I8: 4}, + ) == SimdMap({IntElWid.I8: 7}) + """ + return cls.map_with_elwid(lambda elwid, *args2: f(*args2), *args) + + def get(self, elwid, default=None, *, raise_key_error=False): + if raise_key_error: + retval = self.extract_value(elwid, self) + if retval is None: + raise KeyError() + return retval + return self.extract_value(elwid, self, default) + + def __iter__(self): + """return an iterator of (elwid, value) pairs""" + return self.__map.items() + + def __add__(self, other): + return self.map(operator.add, self, other) + + def __radd__(self, other): + return self.map(operator.add, other, self) + + def __sub__(self, other): + return self.map(operator.sub, self, other) + + def __rsub__(self, other): + return self.map(operator.sub, other, self) + + def __mul__(self, other): + return self.map(operator.mul, self, other) + + def __rmul__(self, other): + return self.map(operator.mul, other, self) + + def __floordiv__(self, other): + return self.map(operator.floordiv, self, other) + + def __rfloordiv__(self, other): + return self.map(operator.floordiv, other, self) + + def __truediv__(self, other): + return self.map(operator.truediv, self, other) + + def __rtruediv__(self, other): + return self.map(operator.truediv, other, self) + + def __mod__(self, other): + return self.map(operator.mod, self, other) + + def __rmod__(self, other): + return self.map(operator.mod, other, self) + + def __abs__(self): + return self.map(abs, self) + + def __and__(self, other): + return self.map(operator.and_, self, other) + + def __rand__(self, other): + return self.map(operator.and_, other, self) + + def __divmod__(self, other): + return self.map(divmod, self, other) + + def __ceil__(self): + return self.map(math.ceil, self) + + def __float__(self): + return self.map(float, self) + + def __floor__(self): + return self.map(math.floor, self) + + def __eq__(self, other): + if isinstance(other, SimdMap): + return self.mapping == other.mapping + return NotImplemented + + def __hash__(self): + return hash(tuple(self.mapping.get(i) for i in self.ALL_ELWIDTHS)) + + def __repr__(self): + return f"{self.__class__.__name__}({dict(self.mapping)})" + + def __invert__(self): + return self.map(operator.invert, self) + + def __lshift__(self, other): + return self.map(operator.lshift, self, other) + + def __rlshift__(self, other): + return self.map(operator.lshift, other, self) + + def __rshift__(self, other): + return self.map(operator.rshift, self, other) + + def __rrshift__(self, other): + return self.map(operator.rshift, other, self) + + def __neg__(self): + return self.map(operator.neg, self) + + def __pos__(self): + return self.map(operator.pos, self) + + def __or__(self, other): + return self.map(operator.or_, self, other) + + def __ror__(self, other): + return self.map(operator.or_, other, self) + + def __xor__(self, other): + return self.map(operator.xor, self, other) + + def __rxor__(self, other): + return self.map(operator.xor, other, self) + + def missing_elwidths(self, *, all_elwidths=None): + """an iterator of the elwidths where self doesn't have a corresponding + value""" + if all_elwidths is None: + all_elwidths = self.ALL_ELWIDTHS + for elwid in all_elwidths: + if elwid not in self.keys(): + yield elwid + + +def _check_for_missing_elwidths(name, all_elwidths=None): + missing = list(globals()[name].missing_elwidths(all_elwidths=all_elwidths)) + assert missing == [], f"{name} is missing entries for {missing}" + + +XLEN = SimdMap({ + IntElWid.I64: 64, + IntElWid.I32: 32, + IntElWid.I16: 16, + IntElWid.I8: 8, + FpElWid.F64: 64, + FpElWid.F32: 32, + FpElWid.F16: 16, + FpElWid.BF16: 16, +}) + +DEFAULT_FP_PART_COUNTS = SimdMap({ + FpElWid.F64: 4, + FpElWid.F32: 2, + FpElWid.F16: 1, + FpElWid.BF16: 1, +}) + +DEFAULT_INT_PART_COUNTS = SimdMap({ + IntElWid.I64: 8, + IntElWid.I32: 4, + IntElWid.I16: 2, + IntElWid.I8: 1, +}) + +_check_for_missing_elwidths("XLEN") +_check_for_missing_elwidths("DEFAULT_FP_PART_COUNTS", FpElWid) +_check_for_missing_elwidths("DEFAULT_INT_PART_COUNTS", IntElWid) + + +class SimdScope: + """The global scope object for SimdSignal and friends + + Members: + * part_counts: SimdMap + a map from `ElWid` values `k` to the number of parts in an element + when `self.elwid == k`. Values should be minimized, since higher values + often create bigger circuits. + + Example: + # here, an I8 element is 1 part wide + part_counts = {ElWid.I8: 1, ElWid.I16: 2, ElWid.I32: 4, ElWid.I64: 8} + + Another Example: + # here, an F16 element is 1 part wide + part_counts = {ElWid.F16: 1, ElWid.BF16: 1, ElWid.F32: 2, ElWid.F64: 4} + * simd_full_width_hint: int + the default value for SimdLayout's full_width argument, the full number + of bits in a SIMD value. + * elwid: ElWid or nmigen Value with a shape of some ElWid class + the current elwid (simd element type) + """ + + __SCOPE_STACK = [] + + @classmethod + def get(cls): + """get the current SimdScope. + + Example: + SimdScope.get(None) is None + SimdScope.get() raises ValueError + with SimdScope(...) as s: + SimdScope.get() is s + """ + if len(cls.__SCOPE_STACK) > 0: + retval = cls.__SCOPE_STACK[-1] + assert isinstance(retval, SimdScope), "inconsistent scope stack" + return retval + raise ValueError("not in a `with SimdScope()` statement") + + def __enter__(self): + self.__SCOPE_STACK.append(self) + return self + + def __exit__(self, exc_type, exc_value, traceback): + assert self.__SCOPE_STACK.pop() is self, "inconsistent scope stack" + return False + + def __init__(self, *, simd_full_width_hint=64, elwid=None, + part_counts=None, elwid_type=IntElWid, scalar=False): + # TODO: add more arguments/members and processing for integration with + self.simd_full_width_hint = simd_full_width_hint + if isinstance(elwid, (IntElWid, FpElWid)): + elwid_type = type(elwid) + if part_counts is None: + part_counts = SimdMap({elwid: 1}) + assert issubclass(elwid_type, (IntElWid, FpElWid)) + self.elwid_type = elwid_type + scalar_elwid = elwid_type(0) + if part_counts is None: + if scalar: + part_counts = SimdMap({scalar_elwid: 1}) + elif issubclass(elwid_type, FpElWid): + part_counts = DEFAULT_FP_PART_COUNTS + else: + part_counts = DEFAULT_INT_PART_COUNTS + + def check(elwid, part_count): + assert type(elwid) == elwid_type, "inconsistent ElWid types" + part_count = int(part_count) + assert part_count != 0 and (part_count & (part_count - 1)) == 0,\ + "part_counts values must all be powers of two" + return part_count + + self.part_counts = SimdMap.map_with_elwid(check, part_counts) + self.full_part_count = max(part_counts.values()) + assert self.simd_full_width_hint % self.full_part_count == 0,\ + "simd_full_width_hint must be a multiple of full_part_count" + if elwid is not None: + self.elwid = elwid + elif scalar: + self.elwid = scalar_elwid + else: + self.elwid = Signal(elwid_type) + + def __repr__(self): + return (f"SimdScope(\n" + f" simd_full_width_hint={self.simd_full_width_hint},\n" + f" elwid={self.elwid},\n" + f" elwid_type={self.elwid_type},\n" + f" part_counts={self.part_counts},\n" + f" full_part_count={self.full_part_count})") -- 2.30.2