add forgotten files from last commit
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 16 Oct 2021 01:19:11 +0000 (18:19 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 16 Oct 2021 01:19:11 +0000 (18:19 -0700)
src/ieee754/part_swizzle/__init__.py [new file with mode: 0644]
src/ieee754/part_swizzle/swizzle.py [new file with mode: 0644]

diff --git a/src/ieee754/part_swizzle/__init__.py b/src/ieee754/part_swizzle/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/ieee754/part_swizzle/swizzle.py b/src/ieee754/part_swizzle/swizzle.py
new file mode 100644 (file)
index 0000000..44c1922
--- /dev/null
@@ -0,0 +1,298 @@
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+from dataclasses import dataclass
+from functools import reduce
+from typing import Dict, FrozenSet, List, Set, Tuple
+from nmigen.hdl.ast import Cat, Const, Shape, Signal, SignalKey, Value, ValueKey
+from nmigen.hdl.dsl import Module
+from nmigen.hdl.ir import Elaboratable
+from ieee754.part.partsig import SimdSignal
+
+
+@dataclass(frozen=True, unsafe_hash=True)
+class Bit:
+    def get_value(self):
+        """get the value of this bit as a nmigen `Value`"""
+        raise NotImplementedError("called abstract method")
+
+
+@dataclass(frozen=True, unsafe_hash=True)
+class ValueBit(Bit):
+    src: ValueKey
+    bit_index: int
+
+    def __init__(self, src, bit_index):
+        if not isinstance(src, ValueKey):
+            src = ValueKey(src)
+        assert isinstance(bit_index, int)
+        assert bit_index in range(len(src.value))
+        object.__setattr__(self, "src", src)
+        object.__setattr__(self, "bit_index", bit_index)
+
+    def get_value(self):
+        """get the value of this bit as a nmigen `Value`"""
+        return self.src.value[self.bit_index]
+
+    def get_assign_target_sig(self):
+        """get the Signal that assigning to this bit would assign to"""
+        if isinstance(self.src.value, Signal):
+            return self.src.value
+        raise TypeError("not a valid assignment target")
+
+    def assign(self, value, signals_map):
+        sig = self.get_assign_target_sig()
+        return signals_map[SignalKey(sig)][self.bit_index].eq(value)
+
+
+@dataclass(frozen=True, unsafe_hash=True)
+class ConstBit(Bit):
+    bit: bool
+
+    def get_value(self):
+        return Const(self.bit, 1)
+
+
+@dataclass(frozen=True)
+class Swizzle:
+    bits: List[Bit]
+
+    def __init__(self, bits=()):
+        bits = list(bits)
+        for bit in bits:
+            assert isinstance(bit, Bit)
+        object.__setattr__(self, "bits", bits)
+
+    @staticmethod
+    def from_const(value, width):
+        return Swizzle(ConstBit((value & (1 << i)) != 0) for i in range(width))
+
+    @staticmethod
+    def from_value(value):
+        value = Value.cast(value)
+        if isinstance(value, Const):
+            return Swizzle.from_const(value.value, len(value))
+        return Swizzle(ValueBit(value, i) for i in range(len(value)))
+
+    def get_value(self):
+        return Cat(*(bit.get_value() for bit in self.bits))
+
+    def get_sign(self):
+        return self.bits[-1] if len(self.bits) != 0 else ConstBit(False)
+
+    def convert_u_to(self, shape):
+        shape = Shape.cast(shape)
+        additional = shape.width - len(self.bits)
+        self.bits[shape.width:] = [ConstBit(False)] * additional
+
+    def convert_s_to(self, shape):
+        shape = Shape.cast(shape)
+        additional = shape.width - len(self.bits)
+        self.bits[shape.width:] = [self.get_sign()] * additional
+
+    def __getitem__(self, key):
+        if isinstance(key, int):
+            return Swizzle([self.bits[key]])
+        assert isinstance(key, slice)
+        return Swizzle(self.bits[key])
+
+    def __add__(self, other):
+        if isinstance(other, Swizzle):
+            return Swizzle(self.bits + other.bits)
+        return NotImplemented
+
+    def __radd__(self, other):
+        if isinstance(other, Swizzle):
+            return Swizzle(other.bits + self.bits)
+        return NotImplemented
+
+    def __iadd__(self, other):
+        assert isinstance(other, Swizzle)
+        self.bits += other.bits
+        return self
+
+    def get_assign_target_sigs(self):
+        for b in self.bits:
+            assert isinstance(b, ValueBit)
+            yield b.get_assign_target_sig()
+
+
+@dataclass(frozen=True)
+class SwizzleKey:
+    """should be elwid or something similar.
+    importantly, all SimdSignals that are used together must have equal
+    SwizzleKeys."""
+    value: ValueKey
+    possible_values: FrozenSet[int]
+
+    @staticmethod
+    def from_simd_signal(simd_signal):
+        if isinstance(simd_signal, SwizzledSimdValue):
+            return simd_signal.swizzle_key
+
+        # can't just be PartitionPoints, since those vary between
+        # SimdSignals with different padding
+        raise NotImplementedError("TODO: implement extracting a SwizzleKey "
+                                  "from a SimdSignal")
+
+    def __init__(self, value, possible_values):
+        object.__setattr__(self, "value", ValueKey(value))
+        pvalues = []
+        shape = self.value.value.shape()
+        for value in possible_values:
+            if isinstance(value, int):
+                assert value == Const.normalize(value, shape)
+            else:
+                value = Value.cast(value)
+                assert isinstance(value, Const)
+                value = value.value
+            pvalues.append(value)
+        assert len(pvalues) != 0, "SwizzleKey can't have zero possible values"
+        object.__setattr__(self, "possible_values", frozenset(pvalues))
+
+
+class ResolveSwizzle(Elaboratable):
+    def __init__(self, swizzled_simd_value):
+        assert isinstance(swizzled_simd_value, SwizzledSimdValue)
+        self.swizzled_simd_value = swizzled_simd_value
+
+    def elaborate(self, platform):
+        m = Module()
+        swizzle_key = self.swizzled_simd_value.swizzle_key
+        swizzles = self.swizzled_simd_value.swizzles
+        output = self.swizzled_simd_value.sig
+        with m.Switch(swizzle_key.value):
+            for k in sorted(swizzle_key.possible_values):
+                swizzle = swizzles[k]
+                with m.Case(k):
+                    m.d.comb += output.eq(swizzle.get_value())
+        return m
+
+
+class AssignSwizzle(Elaboratable):
+    def __init__(self, swizzled_simd_value, src_sig):
+        assert isinstance(swizzled_simd_value, SwizzledSimdValue)
+        self.swizzled_simd_value = swizzled_simd_value
+        assert isinstance(src_sig, Signal)
+        self.src_sig = src_sig
+        self.converted_src_sig = Signal.like(swizzled_simd_value._sig_internal)
+        targets = swizzled_simd_value._get_assign_target_sigs()
+        targets = sorted({SignalKey(s) for s in targets})
+
+        def make_sig(i, s):
+            return Signal.like(s.signal, name=f"outputs_{i}")
+        self.outputs = {s: make_sig(i, s) for i, s in enumerate(targets)}
+
+    def elaborate(self, platform):
+        m = Module()
+        swizzle_key = self.swizzled_simd_value.swizzle_key
+        swizzles = self.swizzled_simd_value.swizzles
+        for k, v in self.outputs.items():
+            m.d.comb += v.eq(k.signal)
+        m.d.comb += self.converted_src_sig.eq(self.src_sig)
+        with m.Switch(swizzle_key.value):
+            for k in sorted(swizzle_key.possible_values):
+                swizzle = swizzles[k]
+                with m.Case(k):
+                    for index, bit in enumerate(swizzle.bits):
+                        rhs = self.converted_src_sig[index]
+                        assert isinstance(bit, ValueBit)
+                        m.d.comb += bit.assign(rhs, self.outputs)
+        return m
+
+
+class SwizzledSimdValue(SimdSignal):
+    """the result of any number of Cat and Slice operations on
+    Signals/SimdSignals. This is specifically intended to support assignment
+    to Cat and Slice, but is also useful for reducing the number of muxes
+    chained together down to a single layer of muxes."""
+    __next_id = 0
+
+    @staticmethod
+    def from_simd_signal(simd_signal):
+        if isinstance(simd_signal, SwizzledSimdValue):
+            return simd_signal
+        assert isinstance(simd_signal, SimdSignal)
+        swizzle_key = SwizzleKey.from_simd_signal(simd_signal)
+        swizzle = Swizzle.from_value(simd_signal.sig)
+        retval = SwizzledSimdValue(swizzle_key, swizzle)
+        retval.set_module(simd_signal.m)
+        return retval
+
+    @staticmethod
+    def __do_splat(swizzle_key, value):
+        """splat a non-simd value, returning a SimdSignal"""
+        raise NotImplementedError("TODO: need splat implementation")
+
+    def __do_convert_rhs_to_simd_signal_like_self(self, rhs):
+        """convert a value to be a SimdSignal of the same layout/shape as self,
+        returning a SimdSignal."""
+        raise NotImplementedError("TODO: need conversion implementation")
+
+    @staticmethod
+    def from_value(swizzle_key, value):
+        if not isinstance(value, SimdSignal):
+            value = SwizzledSimdValue.__do_splat(swizzle_key, value)
+        retval = SwizzledSimdValue.from_simd_signal(value)
+        assert swizzle_key == retval.swizzle_key
+        return retval
+
+    @classmethod
+    def __make_name(cls):
+        id_ = cls.__next_id
+        cls.__next_id = id_ + 1
+        return f"swizzle_{id_}"
+
+    def __init__(self, swizzle_key, swizzles):
+        assert isinstance(swizzle_key, SwizzleKey)
+        self.swizzle_key = swizzle_key
+        possible_keys = swizzle_key.possible_values
+        if isinstance(swizzles, Swizzle):
+            self.swizzles = {k: swizzles for k in possible_keys}
+        else:
+            self.swizzles = {}
+            for k in possible_keys:
+                swizzle = swizzles[k]
+                assert isinstance(swizzle, Swizzle)
+                self.swizzles[k] = swizzle
+        width = None
+        for swizzle in self.swizzles.values():
+            if width is None:
+                width = len(swizzle.bits)
+            assert width == len(swizzle.bits), \
+                "inconsistent swizzle widths"
+        assert width is not None
+        self.__sig_need_setup = False  # ignore accesses during __init__
+        super().__init__(swizzle_key.value, width, name="output")
+        self.__sig_need_setup = True
+
+    @property
+    def sig(self):
+        # override sig to handle lazily adding the ResolveSwizzle submodule
+        if self.__sig_need_setup:
+            self.__sig_need_setup = False
+            submodule = ResolveSwizzle(self)
+            setattr(self.m.submodules, self.__make_name(), submodule)
+        return self._sig_internal
+
+    @sig.setter
+    def sig(self, value):
+        assert isinstance(value, Signal)
+        self._sig_internal = value
+
+    def _get_assign_target_sigs(self):
+        for swizzle in self.swizzles.values():
+            yield from swizzle.get_assign_target_sigs()
+
+    def __Assign__(self, val, *, src_loc_at=0):
+        rhs = self.__do_convert_rhs_to_simd_signal_like_self(val)
+        assert isinstance(rhs, SimdSignal)
+        submodule = AssignSwizzle(self, rhs.sig)
+        setattr(self.m.submodules, self.__make_name(), submodule)
+        return [k.signal.eq(v) for k, v in submodule.outputs.items()]
+
+    def __Cat__(self, *args, src_loc_at=0):
+        raise NotImplementedError("TODO: implement")
+
+    def __Slice__(self, start, stop, *, src_loc_at=0):
+        raise NotImplementedError("TODO: implement")