add partitioned multiplier/adder
author Luke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 06:45:18 +0000 (07:45 +0100)
committer Luke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 06:45:18 +0000 (07:45 +0100)
src/ieee754/part_mul_add/__init__.py [new file with mode: 0644]
src/ieee754/part_mul_add/multiply.py [new file with mode: 0644]
src/ieee754/part_mul_add/test/__init__.py [new file with mode: 0644]
src/ieee754/part_mul_add/test/test_multiply.py [new file with mode: 0644]

diff --git a/src/ieee754/part_mul_add/__init__.py b/src/ieee754/part_mul_add/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py
new file mode 100644 (file)
index 0000000..5902967
--- /dev/null
@@ -0,0 +1,641 @@
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+"""Integer Multiplication."""
+
+from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
+from nmigen.hdl.ast import Assign
+from abc import ABCMeta, abstractmethod
+from typing import Any, NewType, Union, List, Dict, Iterable, Mapping, Optional
+from typing_extensions import final
+from nmigen.cli import main
+
+# Accepted input form for partition points: a mapping from an integer point
+# to its enable flag, given either as an nmigen ``Value`` or as a plain
+# ``bool``/``int``.
+PartitionPointsIn = Mapping[int, Union[Value, bool, int]]
+
+
+class PartitionPoints(Dict[int, Value]):
+    """Partition points and corresponding ``Value``s.
+
+    The points at where an ALU is partitioned along with ``Value``s that
+    specify if the corresponding partition points are enabled.
+
+    For example: ``{1: True, 5: True, 10: True}`` with
+    ``width == 16`` specifies that the ALU is split into 4 sections:
+    * bits 0 <= ``i`` < 1
+    * bits 1 <= ``i`` < 5
+    * bits 5 <= ``i`` < 10
+    * bits 10 <= ``i`` < 16
+
+    If the partition_points were instead ``{1: True, 5: a, 10: True}``
+    where ``a`` is a 1-bit ``Signal``:
+    * If ``a`` is asserted:
+        * bits 0 <= ``i`` < 1
+        * bits 1 <= ``i`` < 5
+        * bits 5 <= ``i`` < 10
+        * bits 10 <= ``i`` < 16
+    * Otherwise
+        * bits 0 <= ``i`` < 1
+        * bits 1 <= ``i`` < 10
+        * bits 10 <= ``i`` < 16
+    """
+
+    def __init__(self, partition_points: Optional[PartitionPointsIn] = None):
+        """Create a new ``PartitionPoints``.
+
+        :param partition_points: the input partition points to values mapping.
+        """
+        super().__init__()
+        if partition_points is not None:
+            for point, enabled in partition_points.items():
+                if not isinstance(point, int):
+                    raise TypeError("point must be a non-negative integer")
+                if point < 0:
+                    raise ValueError("point must be a non-negative integer")
+                self[point] = Value.wrap(enabled)
+
+    def like(self,
+             name: Optional[str] = None,
+             src_loc_at: int = 0) -> 'PartitionPoints':
+        """Create a new ``PartitionPoints`` with ``Signal``s for all values.
+
+        :param name: the base name for the new ``Signal``s.
+        """
+        if name is None:
+            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
+        retval = PartitionPoints()
+        for point, enabled in self.items():
+            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
+        return retval
+
+    def eq(self, rhs: 'PartitionPoints') -> Iterable[Assign]:
+        """Assign ``PartitionPoints`` using ``Signal.eq``."""
+        if set(self.keys()) != set(rhs.keys()):
+            raise ValueError("incompatible point set")
+        for point, enabled in self.items():
+            yield enabled.eq(rhs[point])
+
+    def as_mask(self, width: int) -> Value:
+        """Create a bit-mask from `self`.
+
+        Each bit in the returned mask is clear only if the partition point at
+        the same bit-index is enabled.
+
+        :param width: the bit width of the resulting mask
+        """
+        bits: List[Union[Value, bool]]
+        bits = []
+        for i in range(width):
+            if i in self:
+                bits.append(~self[i])
+            else:
+                bits.append(True)
+        return Cat(*bits)
+
+    def get_max_partition_count(self, width: int) -> int:
+        """Get the maximum number of partitions.
+
+        Gets the number of partitions when all partition points are enabled.
+        """
+        retval = 1
+        for point in self.keys():
+            if point < width:
+                retval += 1
+        return retval
+
+    def fits_in_width(self, width: int) -> bool:
+        """Check if all partition points are smaller than `width`."""
+        for point in self.keys():
+            if point >= width:
+                return False
+        return True
+
+
+@final
+class FullAdder(Elaboratable):
+    """Full Adder.
+
+    :attribute in0: the first input
+    :attribute in1: the second input
+    :attribute in2: the third input
+    :attribute sum: the sum output
+    :attribute carry: the carry output
+    """
+
+    def __init__(self, width: int):
+        """Create a ``FullAdder``.
+
+        :param width: the bit width of the input and output
+        """
+        self.in0 = Signal(width)
+        self.in1 = Signal(width)
+        self.in2 = Signal(width)
+        self.sum = Signal(width)
+        self.carry = Signal(width)
+
+    def elaborate(self, platform: Any) -> Module:
+        """Elaborate this module."""
+        m = Module()
+        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
+        m.d.comb += self.carry.eq((self.in0 & self.in1)
+                                  | (self.in1 & self.in2)
+                                  | (self.in2 & self.in0))
+        return m
+
+
+@final
+class PartitionedAdder(Elaboratable):
+    """Partitioned Adder.
+
+    :attribute width: the bit width of the input and output. Read-only.
+    :attribute a: the first input to the adder
+    :attribute b: the second input to the adder
+    :attribute output: the sum output
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
+
+    def __init__(self, width: int, partition_points: PartitionPointsIn):
+        """Create a ``PartitionedAdder``.
+
+        :param width: the bit width of the input and output
+        :param partition_points: the input partition points
+        """
+        self.width = width
+        self.a = Signal(width)
+        self.b = Signal(width)
+        self.output = Signal(width)
+        self.partition_points = PartitionPoints(partition_points)
+        if not self.partition_points.fits_in_width(width):
+            raise ValueError("partition_points doesn't fit in width")
+        expanded_width = 0
+        for i in range(self.width):
+            if i in self.partition_points:
+                expanded_width += 1
+            expanded_width += 1
+        self._expanded_width = expanded_width
+        self._expanded_a = Signal(expanded_width)
+        self._expanded_b = Signal(expanded_width)
+        self._expanded_output = Signal(expanded_width)
+
+    def elaborate(self, platform: Any) -> Module:
+        """Elaborate this module."""
+        m = Module()
+        expanded_index = 0
+        for i in range(self.width):
+            if i in self.partition_points:
+                # add extra bit set to 0 + 0 for enabled partition points
+                # and 1 + 0 for disabled partition points
+                m.d.comb += self._expanded_a[expanded_index].eq(
+                    ~self.partition_points[i])
+                m.d.comb += self._expanded_b[expanded_index].eq(0)
+                expanded_index += 1
+            m.d.comb += self._expanded_a[expanded_index].eq(self.a[i])
+            m.d.comb += self._expanded_b[expanded_index].eq(self.b[i])
+            m.d.comb += self.output[i].eq(
+                self._expanded_output[expanded_index])
+            expanded_index += 1
+        # use only one addition to take advantage of look-ahead carry and
+        # special hardware on FPGAs
+        m.d.comb += self._expanded_output.eq(
+            self._expanded_a + self._expanded_b)
+        return m
+
+
+# Number of operands consumed by one ``FullAdder``; ``AddReduce`` groups its
+# inputs in chunks of this size.
+FULL_ADDER_INPUT_COUNT = 3
+
+
+@final
+class AddReduce(Elaboratable):
+    """Add list of numbers together.
+
+    :attribute inputs: input ``Signal``s to be summed. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute register_levels: List of nesting levels that should have
+        pipeline registers.
+    :attribute output: output sum.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
+
+    def __init__(self,
+                 inputs: Iterable[Signal],
+                 output_width: int,
+                 register_levels: Iterable[int],
+                 partition_points: PartitionPointsIn):
+        """Create an ``AddReduce``.
+
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param register_levels: List of nesting levels that should have
+            pipeline registers.
+        :param partition_points: the input partition points.
+        """
+        self.inputs = list(inputs)
+        self._resized_inputs = [
+            Signal(output_width, name=f"resized_inputs[{i}]")
+            for i in range(len(self.inputs))]
+        self.register_levels = list(register_levels)
+        self.output = Signal(output_width)
+        self.partition_points = PartitionPoints(partition_points)
+        if not self.partition_points.fits_in_width(output_width):
+            raise ValueError("partition_points doesn't fit in output_width")
+        self._reg_partition_points = self.partition_points.like()
+        max_level = AddReduce.get_max_level(len(self.inputs))
+        for level in self.register_levels:
+            if level > max_level:
+                raise ValueError(
+                    "not enough adder levels for specified register levels")
+
+    @staticmethod
+    def get_max_level(input_count: int) -> int:
+        """Get the maximum level.
+
+        All ``register_levels`` must be less than or equal to the maximum
+        level.
+        """
+        retval = 0
+        while True:
+            groups = AddReduce.full_adder_groups(input_count)
+            if len(groups) == 0:
+                return retval
+            input_count %= FULL_ADDER_INPUT_COUNT
+            input_count += 2 * len(groups)
+            retval += 1
+
+    def next_register_levels(self) -> Iterable[int]:
+        """``Iterable`` of ``register_levels`` for next recursive level."""
+        for level in self.register_levels:
+            if level > 0:
+                yield level - 1
+
+    @staticmethod
+    def full_adder_groups(input_count: int) -> range:
+        """Get ``inputs`` indices for which a full adder should be built."""
+        return range(0,
+                     input_count - FULL_ADDER_INPUT_COUNT + 1,
+                     FULL_ADDER_INPUT_COUNT)
+
+    def elaborate(self, platform: Any) -> Module:
+        """Elaborate this module."""
+        m = Module()
+
+        # resize inputs to correct bit-width and optionally add in
+        # pipeline registers
+        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
+                                     for i in range(len(self.inputs))]
+        if 0 in self.register_levels:
+            m.d.sync += resized_input_assignments
+            m.d.sync += self._reg_partition_points.eq(self.partition_points)
+        else:
+            m.d.comb += resized_input_assignments
+            m.d.comb += self._reg_partition_points.eq(self.partition_points)
+
+        groups = AddReduce.full_adder_groups(len(self.inputs))
+        # if there are no full adders to create, then we handle the base cases
+        # and return, otherwise we go on to the recursive case
+        if len(groups) == 0:
+            if len(self.inputs) == 0:
+                # use 0 as the default output value
+                m.d.comb += self.output.eq(0)
+            elif len(self.inputs) == 1:
+                # handle single input
+                m.d.comb += self.output.eq(self._resized_inputs[0])
+            else:
+                # base case for adding 2 or more inputs, which get recursively
+                # reduced to 2 inputs
+                assert len(self.inputs) == 2
+                adder = PartitionedAdder(len(self.output),
+                                         self._reg_partition_points)
+                m.submodules.final_adder = adder
+                m.d.comb += adder.a.eq(self._resized_inputs[0])
+                m.d.comb += adder.b.eq(self._resized_inputs[1])
+                m.d.comb += self.output.eq(adder.output)
+            return m
+        # go on to handle recursive case
+        intermediate_terms: List[Signal]
+        intermediate_terms = []
+
+        def add_intermediate_term(value: Value) -> None:
+            intermediate_term = Signal(
+                len(self.output),
+                name=f"intermediate_terms[{len(intermediate_terms)}]")
+            intermediate_terms.append(intermediate_term)
+            m.d.comb += intermediate_term.eq(value)
+
+        part_mask = self._reg_partition_points.as_mask(len(self.output))
+
+        # create full adders for this recursive level.
+        # this shrinks N terms to 2 * (N // 3) plus the remainder
+        for i in groups:
+            adder_i = FullAdder(len(self.output))
+            setattr(m.submodules, f"adder_{i}", adder_i)
+            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
+            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
+            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
+            add_intermediate_term(adder_i.sum)
+            shifted_carry = adder_i.carry << 1
+            # mask out carry bits to prevent carries between partitions
+            add_intermediate_term((adder_i.carry << 1) & part_mask)
+        # handle the remaining inputs.
+        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
+            add_intermediate_term(self._resized_inputs[-1])
+        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
+            # Just pass the terms to the next layer, since we wouldn't gain
+            # anything by using a half adder since there would still be 2 terms
+            # and just passing the terms to the next layer saves gates.
+            add_intermediate_term(self._resized_inputs[-2])
+            add_intermediate_term(self._resized_inputs[-1])
+        else:
+            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
+        # recursive invocation of ``AddReduce``
+        next_level = AddReduce(intermediate_terms,
+                               len(self.output),
+                               self.next_register_levels(),
+                               self._reg_partition_points)
+        m.submodules.next_level = next_level
+        m.d.comb += self.output.eq(next_level.output)
+        return m
+
+
+# Operation codes accepted by ``Mul8_16_32_64.part_ops``; see the class
+# docstring of ``Mul8_16_32_64`` for the meaning of each code.
+OP_MUL_LOW = 0
+OP_MUL_SIGNED_HIGH = 1
+OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
+OP_MUL_UNSIGNED_HIGH = 3
+
+
+class Mul8_16_32_64(Elaboratable):
+    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
+
+    Supports partitioning into any combination of 8, 16, 32, and 64-bit
+    partitions on naturally-aligned boundaries. Supports the operation being
+    set for each partition independently.
+
+    :attribute part_pts: the input partition points. Has a partition point at
+        multiples of 8 in 0 < i < 64. Each partition point's associated
+        ``Value`` is a ``Signal``. Modification not supported, except for by
+        ``Signal.eq``.
+    :attribute part_ops: the operation for each byte. The operation for a
+        particular partition is selected by assigning the selected operation
+        code to each byte in the partition. The allowed operation codes are:
+
+        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
+            RISC-V's `mul` instruction.
+        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
+            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
+            instruction.
+        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
+            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
+            `mulhsu` instruction.
+        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
+            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
+            instruction.
+    :attribute a: the first 64-bit input.
+    :attribute b: the second 64-bit input.
+    :attribute output: the 64-bit result.
+    :attribute register_levels: the register levels handed to the internal
+        ``AddReduce``; one delay stage for the operation codes and partition
+        flags is created per entry.
+    """
+
+    def __init__(self, register_levels: Iterable[int] = ()):
+        """Create a ``Mul8_16_32_64``.
+
+        :param register_levels: ``AddReduce`` nesting levels that should have
+            pipeline registers.
+        """
+        self.part_pts = PartitionPoints()
+        for i in range(8, 64, 8):
+            self.part_pts[i] = Signal(name=f"part_pts_{i}")
+        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
+        self.a = Signal(64)
+        self.b = Signal(64)
+        self.output = Signal(64)
+        self.register_levels = list(register_levels)
+        self._intermediate_output = Signal(128)
+        # per-byte op codes, one copy per delay stage (stage 0 is undelayed)
+        self._delayed_part_ops = [
+            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
+             for i in range(8)]
+            for delay in range(1 + len(self.register_levels))]
+        # one flag per naturally-aligned lane of each size; driven in
+        # ``elaborate``
+        self._part_8 = [Signal(name=f"_part_8_{i}") for i in range(8)]
+        self._part_16 = [Signal(name=f"_part_16_{i}") for i in range(4)]
+        self._part_32 = [Signal(name=f"_part_32_{i}") for i in range(2)]
+        self._part_64 = [Signal(name=f"_part_64")]
+        # the same lane flags, one copy per delay stage
+        self._delayed_part_8 = [
+            [Signal(name=f"_delayed_part_8_{delay}_{i}")
+             for i in range(8)]
+            for delay in range(1 + len(self.register_levels))]
+        self._delayed_part_16 = [
+            [Signal(name=f"_delayed_part_16_{delay}_{i}")
+             for i in range(4)]
+            for delay in range(1 + len(self.register_levels))]
+        self._delayed_part_32 = [
+            [Signal(name=f"_delayed_part_32_{delay}_{i}")
+             for i in range(2)]
+            for delay in range(1 + len(self.register_levels))]
+        self._delayed_part_64 = [
+            [Signal(name=f"_delayed_part_64_{delay}")]
+            for delay in range(1 + len(self.register_levels))]
+        # candidate results, one per partition size
+        self._output_64 = Signal(64)
+        self._output_32 = Signal(64)
+        self._output_16 = Signal(64)
+        self._output_8 = Signal(64)
+        # per-byte signedness of ``a`` and ``b``, decoded from ``part_ops``
+        self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
+        self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+        # correction terms for signed operands, one set per partition size
+        self._not_a_term_8 = Signal(128)
+        self._neg_lsb_a_term_8 = Signal(128)
+        self._not_b_term_8 = Signal(128)
+        self._neg_lsb_b_term_8 = Signal(128)
+        self._not_a_term_16 = Signal(128)
+        self._neg_lsb_a_term_16 = Signal(128)
+        self._not_b_term_16 = Signal(128)
+        self._neg_lsb_b_term_16 = Signal(128)
+        self._not_a_term_32 = Signal(128)
+        self._neg_lsb_a_term_32 = Signal(128)
+        self._not_b_term_32 = Signal(128)
+        self._neg_lsb_b_term_32 = Signal(128)
+        self._not_a_term_64 = Signal(128)
+        self._neg_lsb_a_term_64 = Signal(128)
+        self._not_b_term_64 = Signal(128)
+        self._neg_lsb_b_term_64 = Signal(128)
+
+    def _part_byte(self, index: int) -> Value:
+        """Return the partition point at the top of byte ``index``.
+
+        Indices -1 and 7 (the two ends of the 64-bit word) have no entry in
+        ``part_pts`` and yield a constant 1 instead.
+        """
+        if index == -1 or index == 7:
+            return C(True, 1)
+        assert index >= 0 and index < 8
+        return self.part_pts[index * 8 + 8]
+
+    def elaborate(self, platform: Any) -> Module:
+        """Elaborate this module."""
+        m = Module()
+
+        # stage 0 of the delayed op codes follows ``part_ops``
+        # combinationally; every later stage is a register fed from the
+        # previous stage
+        for i in range(len(self.part_ops)):
+            m.d.comb += self._delayed_part_ops[0][i].eq(self.part_ops[i])
+            m.d.sync += [self._delayed_part_ops[j + 1][i]
+                         .eq(self._delayed_part_ops[j][i])
+                         for j in range(len(self.register_levels))]
+
+        # for every partition size, work out which lanes currently form a
+        # partition of exactly that size, and delay those flags the same way
+        # as the op codes
+        for parts, delayed_parts in [(self._part_64, self._delayed_part_64),
+                                     (self._part_32, self._delayed_part_32),
+                                     (self._part_16, self._delayed_part_16),
+                                     (self._part_8, self._delayed_part_8)]:
+            byte_count = 8 // len(parts)
+            for i in range(len(parts)):
+                # the partition point just below the lane must be enabled
+                value = self._part_byte(i * byte_count - 1)
+                # no partition point inside the lane may be enabled
+                for j in range(i * byte_count, (i + 1) * byte_count - 1):
+                    value &= ~self._part_byte(j)
+                # the partition point at the top of the lane must be enabled
+                value &= self._part_byte((i + 1) * byte_count - 1)
+                m.d.comb += parts[i].eq(value)
+                m.d.comb += delayed_parts[0][i].eq(parts[i])
+                m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
+                             for j in range(len(self.register_levels))]
+
+        # one 16-bit product for every (byte of a, byte of b) pair
+        products = [[
+                Signal(16, name=f"products_{i}_{j}")
+                for j in range(8)]
+            for i in range(8)]
+
+        for a_index in range(8):
+            for b_index in range(8):
+                a = self.a.part(a_index * 8, 8)
+                b = self.b.part(b_index * 8, 8)
+                m.d.comb += products[a_index][b_index].eq(a * b)
+
+        # 128-bit terms that are summed by ``AddReduce`` below
+        terms = []
+
+        def add_term(value: Value,
+                     shift: int = 0,
+                     enabled: Optional[Value] = None) -> None:
+            # Append a new 128-bit term: ``value`` shifted left by ``shift``
+            # bits, forced to 0 when ``enabled`` is given and deasserted.
+            term = Signal(128)
+            terms.append(term)
+            if enabled is not None:
+                value = Mux(enabled, value, 0)
+            if shift > 0:
+                value = Cat(Repl(C(0, 1), shift), value)
+            else:
+                assert shift == 0
+            m.d.comb += term.eq(value)
+
+        # a byte product only contributes when no partition point is enabled
+        # between the two bytes, i.e. both bytes lie in the same partition
+        for a_index in range(8):
+            for b_index in range(8):
+                term_enabled: Value = C(True, 1)
+                min_index = min(a_index, b_index)
+                max_index = max(a_index, b_index)
+                for i in range(min_index, max_index):
+                    term_enabled &= ~self._part_byte(i)
+                add_term(products[a_index][b_index],
+                         8 * (a_index + b_index),
+                         term_enabled)
+
+        # decode per-byte signedness from the op code: ``a`` is unsigned only
+        # for OP_MUL_UNSIGNED_HIGH, ``b`` is signed for OP_MUL_LOW and
+        # OP_MUL_SIGNED_HIGH
+        for i in range(8):
+            a_signed = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH
+            b_signed = (self.part_ops[i] == OP_MUL_LOW) \
+                | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
+            m.d.comb += self._a_signed[i].eq(a_signed)
+            m.d.comb += self._b_signed[i].eq(b_signed)
+
+        # it's fine to bitwise-or these together since they are never enabled
+        # at the same time
+        add_term(self._not_a_term_8 | self._not_a_term_16
+                 | self._not_a_term_32 | self._not_a_term_64)
+        add_term(self._neg_lsb_a_term_8 | self._neg_lsb_a_term_16
+                 | self._neg_lsb_a_term_32 | self._neg_lsb_a_term_64)
+        add_term(self._not_b_term_8 | self._not_b_term_16
+                 | self._not_b_term_32 | self._not_b_term_64)
+        add_term(self._neg_lsb_b_term_8 | self._neg_lsb_b_term_16
+                 | self._neg_lsb_b_term_32 | self._neg_lsb_b_term_64)
+
+        # drive the signed-correction terms for each partition size
+        for not_a_term, \
+            neg_lsb_a_term, \
+            not_b_term, \
+            neg_lsb_b_term, \
+            parts in [
+                (self._not_a_term_8,
+                 self._neg_lsb_a_term_8,
+                 self._not_b_term_8,
+                 self._neg_lsb_b_term_8,
+                 self._part_8),
+                (self._not_a_term_16,
+                 self._neg_lsb_a_term_16,
+                 self._not_b_term_16,
+                 self._neg_lsb_b_term_16,
+                 self._part_16),
+                (self._not_a_term_32,
+                 self._neg_lsb_a_term_32,
+                 self._not_b_term_32,
+                 self._neg_lsb_b_term_32,
+                 self._part_32),
+                (self._not_a_term_64,
+                 self._neg_lsb_a_term_64,
+                 self._not_b_term_64,
+                 self._neg_lsb_b_term_64,
+                 self._part_64),
+                ]:
+            byte_width = 8 // len(parts)
+            bit_width = 8 * byte_width
+            for i in range(len(parts)):
+                # the correction built from ``b`` applies when the lane is an
+                # active partition of this size, ``a`` is signed (taken from
+                # the lane's lowest byte) and the lane's MSB of ``a`` is set;
+                # the correction built from ``a`` is the mirror image
+                b_enabled = parts[i] & self.a[(i + 1) * bit_width - 1] \
+                    & self._a_signed[i * byte_width]
+                a_enabled = parts[i] & self.b[(i + 1) * bit_width - 1] \
+                    & self._b_signed[i * byte_width]
+
+                # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
+                # negation operation is split into a bitwise not and a +1.
+                # likewise for 16, 32, and 64-bit values.
+                m.d.comb += [
+                    not_a_term.part(bit_width * 2 * i, bit_width * 2)
+                    .eq(Mux(a_enabled,
+                            Cat(Repl(0, bit_width),
+                                ~self.a.part(bit_width * i, bit_width)),
+                            0)),
+
+                    neg_lsb_a_term.part(bit_width * 2 * i, bit_width * 2)
+                    .eq(Cat(Repl(0, bit_width), a_enabled)),
+
+                    not_b_term.part(bit_width * 2 * i, bit_width * 2)
+                    .eq(Mux(b_enabled,
+                            Cat(Repl(0, bit_width),
+                                ~self.b.part(bit_width * i, bit_width)),
+                            0)),
+
+                    neg_lsb_b_term.part(bit_width * 2 * i, bit_width * 2)
+                    .eq(Cat(Repl(0, bit_width), b_enabled))]
+
+        # the terms are 128 bits wide and each lane's product takes twice the
+        # lane's width, so every partition point moves to twice its bit index
+        expanded_part_pts = PartitionPoints()
+        for i, v in self.part_pts.items():
+            signal = Signal(name=f"expanded_part_pts_{i*2}")
+            expanded_part_pts[i * 2] = signal
+            m.d.comb += signal.eq(v)
+
+        add_reduce = AddReduce(terms,
+                               128,
+                               self.register_levels,
+                               expanded_part_pts)
+        m.submodules.add_reduce = add_reduce
+        m.d.comb += self._intermediate_output.eq(add_reduce.output)
+        # for each partition size, take the low or the high half of every
+        # lane's double-width product, selected by the last-stage delayed op
+        # code of the lane's lowest byte
+        m.d.comb += self._output_64.eq(
+            Mux(self._delayed_part_ops[-1][0] == OP_MUL_LOW,
+                self._intermediate_output.part(0, 64),
+                self._intermediate_output.part(64, 64)))
+        for i in range(2):
+            m.d.comb += self._output_32.part(i * 32, 32).eq(
+                Mux(self._delayed_part_ops[-1][4 * i] == OP_MUL_LOW,
+                    self._intermediate_output.part(i * 64, 32),
+                    self._intermediate_output.part(i * 64 + 32, 32)))
+        for i in range(4):
+            m.d.comb += self._output_16.part(i * 16, 16).eq(
+                Mux(self._delayed_part_ops[-1][2 * i] == OP_MUL_LOW,
+                    self._intermediate_output.part(i * 32, 16),
+                    self._intermediate_output.part(i * 32 + 16, 16)))
+        for i in range(8):
+            m.d.comb += self._output_8.part(i * 8, 8).eq(
+                Mux(self._delayed_part_ops[-1][i] == OP_MUL_LOW,
+                    self._intermediate_output.part(i * 16, 8),
+                    self._intermediate_output.part(i * 16 + 8, 8)))
+        # per output byte, pick the candidate result whose partition size
+        # matches the (last-stage delayed) lane flags covering that byte;
+        # the 64-bit candidate is the fallback
+        for i in range(8):
+            m.d.comb += self.output.part(i * 8, 8).eq(
+                Mux(self._delayed_part_8[-1][i]
+                    | self._delayed_part_16[-1][i // 2],
+                    Mux(self._delayed_part_8[-1][i],
+                        self._output_8.part(i * 8, 8),
+                        self._output_16.part(i * 8, 8)),
+                    Mux(self._delayed_part_32[-1][i // 4],
+                        self._output_32.part(i * 8, 8),
+                        self._output_64.part(i * 8, 8))))
+        return m
+
+
+# When run as a script, hand a multiplier built with the default (empty)
+# ``register_levels`` to nmigen's command-line ``main``, exposing its inputs,
+# outputs and the 128-bit intermediate sum as ports.
+if __name__ == "__main__":
+    m = Mul8_16_32_64()
+    main(m, ports=[m.a,
+                   m.b,
+                   m._intermediate_output,
+                   m.output,
+                   *m.part_ops,
+                   *m.part_pts.values()])
diff --git a/src/ieee754/part_mul_add/test/__init__.py b/src/ieee754/part_mul_add/test/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/src/ieee754/part_mul_add/test/test_multiply.py b/src/ieee754/part_mul_add/test/test_multiply.py
new file mode 100644 (file)
index 0000000..ec83307
--- /dev/null
@@ -0,0 +1,770 @@
+#!/usr/bin/env python3
+# SPDX-License-Identifier: LGPL-2.1-or-later
+# See Notices.txt for copyright information
+
+from src.multiply import PartitionPoints, PartitionedAdder, AddReduce, \
+    Mul8_16_32_64, OP_MUL_LOW, OP_MUL_SIGNED_HIGH, \
+    OP_MUL_SIGNED_UNSIGNED_HIGH, OP_MUL_UNSIGNED_HIGH
+from nmigen import Signal, Module
+from nmigen.back.pysim import Simulator, Delay, Tick, Passive
+from nmigen.hdl.ast import Assign, Value
+from typing import Any, Generator, List, Union, Optional, Tuple, Iterable
+import unittest
+from hashlib import sha256
+import enum
+import pdb
+
+
+def create_simulator(module: Any,
+                     traces: List[Signal],
+                     test_name: str) -> Simulator:
+    """Build a ``Simulator`` for ``module`` that is given two dump files.
+
+    ``<test_name>.vcd`` and ``<test_name>.gtkw`` are opened for writing
+    (relative to the current directory) and handed to the ``Simulator``
+    constructor along with ``traces``.
+    """
+    vcd_file = open(test_name + ".vcd", "w")
+    gtkw_file = open(test_name + ".gtkw", "w")
+    return Simulator(module,
+                     vcd_file=vcd_file,
+                     gtkw_file=gtkw_file,
+                     traces=traces)
+
+
+# Type aliases for the generator-based simulator processes in this file.
+# Objects a process may yield to the simulator:
+AsyncProcessCommand = Union[Delay, Tick, Passive, Assign, Value]
+# The same set of commands, with ``None`` additionally allowed:
+ProcessCommand = Optional[AsyncProcessCommand]
+# Generators that yield commands and are sent an ``int`` or ``None`` back:
+AsyncProcessGenerator = Generator[AsyncProcessCommand, Union[int, None], None]
+ProcessGenerator = Generator[ProcessCommand, Union[int, None], None]
+
+
+class TestPartitionPoints(unittest.TestCase):
+    # Checks ``PartitionPoints`` lookups and ``as_mask`` for a 16-bit width
+    # using one constant-True point, one constant-False point and one point
+    # driven by a signal.
+    def test(self) -> None:
+        module = Module()
+        width = 16
+        mask = Signal(width)
+        partition_point_10 = Signal()
+        partition_points = PartitionPoints({1: True,
+                                            5: False,
+                                            10: partition_point_10})
+        # Expose the computed mask on a signal so the simulator can read it.
+        module.d.comb += mask.eq(partition_points.as_mask(width))
+        with create_simulator(module,
+                              [mask, partition_point_10],
+                              "partition_points") as sim:
+            def async_process() -> AsyncProcessGenerator:
+                # The constant points must read back as the values given.
+                self.assertEqual((yield partition_points[1]), True)
+                self.assertEqual((yield partition_points[5]), False)
+                # Point 10 disabled: only bit 1 is expected to be clear.
+                yield partition_point_10.eq(0)
+                yield Delay(1e-6)
+                self.assertEqual((yield mask), 0xFFFD)
+                # Point 10 enabled: bits 1 and 10 are expected to be clear.
+                yield partition_point_10.eq(1)
+                yield Delay(1e-6)
+                self.assertEqual((yield mask), 0xFBFD)
+
+            sim.add_process(async_process)
+            sim.run()
+
+
+class TestPartitionedAdder(unittest.TestCase):
+    # Exercises a 16-bit ``PartitionedAdder`` as one 16-bit lane, two 8-bit
+    # lanes and four 4-bit lanes, comparing its output against a lane-wise
+    # sum computed in plain Python.
+    def test(self) -> None:
+        width = 16
+        partition_nibbles = Signal()
+        partition_bytes = Signal()
+        # The point at bit 8 is enabled in both byte mode and nibble mode;
+        # the points at bits 4 and 12 are enabled in nibble mode only.
+        module = PartitionedAdder(width,
+                                  {0x4: partition_nibbles,
+                                   0x8: partition_bytes | partition_nibbles,
+                                   0xC: partition_nibbles})
+        with create_simulator(module,
+                              [partition_nibbles,
+                               partition_bytes,
+                               module.a,
+                               module.b,
+                               module.output],
+                              "partitioned_adder") as sim:
+            def async_process() -> AsyncProcessGenerator:
+                # Drive a fixed list of operand pairs and compare the adder
+                # output with the expected value; each entry of ``mask_list``
+                # selects the bits of one lane.
+                def test_add(msg_prefix: str,
+                             *mask_list: Tuple[int, ...]) -> Any:
+                    for a, b in [(0x0000, 0x0000),
+                                 (0x1234, 0x1234),
+                                 (0xABCD, 0xABCD),
+                                 (0xFFFF, 0x0000),
+                                 (0x0000, 0x0000),
+                                 (0xFFFF, 0xFFFF),
+                                 (0x0000, 0xFFFF)]:
+                        yield module.a.eq(a)
+                        yield module.b.eq(b)
+                        yield Delay(1e-6)
+                        # Expected result: add each lane separately and keep
+                        # only that lane's bits, so carries out of a lane are
+                        # discarded.
+                        y = 0
+                        for mask in mask_list:
+                            y |= mask & ((a & mask) + (b & mask))
+                        output = (yield module.output)
+                        msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \
+                            f" => 0x{y:X} != 0x{output:X}"
+                        self.assertEqual(y, output, msg)
+                # No partition point enabled: a single 16-bit lane.
+                yield partition_nibbles.eq(0)
+                yield partition_bytes.eq(0)
+                yield from test_add("16-bit", 0xFFFF)
+                # Only the point at bit 8 enabled: two 8-bit lanes.
+                yield partition_nibbles.eq(0)
+                yield partition_bytes.eq(1)
+                yield from test_add("8-bit", 0xFF00, 0x00FF)
+                # All three points enabled: four 4-bit lanes.
+                yield partition_nibbles.eq(1)
+                yield partition_bytes.eq(0)
+                yield from test_add("4-bit", 0xF000, 0x0F00, 0x00F0, 0x000F)
+
+            sim.add_process(async_process)
+            sim.run()
+
+
+class GenOrCheck(enum.Enum):
+    """Role of a simulator process: drive the inputs or verify the outputs."""
+    Generate = 1
+    Check = 2
+
+
+class TestAddReduce(unittest.TestCase):
+    def calculate_input_values(self,
+                               input_count: int,
+                               key: int,
+                               extra_keys: Optional[List[int]] = None
+                               ) -> Tuple[List[int], List[str]]:
+        """Return ``input_count`` deterministic 16-bit test values for ``key``.
+
+        Keys 0, 1 and 2 select the fixed patterns ``0x0000``, ``0xFFFF`` and
+        ``0x0111``.  Any other key derives each value from a SHA-256 hash of
+        ``input_count``, the value index, ``key`` and ``extra_keys``, masked
+        to 16 bits.
+
+        :param input_count: number of values to produce.
+        :param key: selects the pattern (see above).
+        :param extra_keys: extra data mixed into the hash; ``None`` (the
+            default) is treated exactly like an empty list.
+        :returns: ``(values, strings)`` where ``strings`` holds each value
+            formatted as ``0xXXXX``.
+        """
+        # ``None`` replaces a mutable ``[]`` default; normalise it so the
+        # hash input is unchanged.
+        if extra_keys is None:
+            extra_keys = []
+        input_values = []
+        input_values_str = []
+        for i in range(input_count):
+            if key == 0:
+                value = 0
+            elif key == 1:
+                value = 0xFFFF
+            elif key == 2:
+                value = 0x0111
+            else:
+                hash_input = f"{input_count} {i} {key} {extra_keys}"
+                digest = sha256(hash_input.encode()).digest()
+                value = int.from_bytes(digest, byteorder="little")
+                value &= 0xFFFF
+            input_values.append(value)
+            input_values_str.append(f"0x{value:04X}")
+        return input_values, input_values_str
+
+    def subtest_value(self,
+                      inputs: List[Signal],
+                      module: AddReduce,
+                      mask_list: List[int],
+                      gen_or_check: GenOrCheck,
+                      values: List[int]) -> AsyncProcessGenerator:
+        """Simulator step for one set of input values.
+
+        In ``Generate`` mode the values are assigned to ``inputs``.  After a
+        settling delay the module output is read and, in ``Check`` mode,
+        compared against the lane-wise sum of ``values``.  The step ends with
+        one clock tick.
+        """
+        if gen_or_check is GenOrCheck.Generate:
+            for signal, value in zip(inputs, values):
+                yield signal.eq(value)
+        yield Delay(1e-6)
+        # Expected result: sum every lane separately and keep only that
+        # lane's bits.
+        expected = 0
+        for mask in mask_list:
+            expected |= mask & sum(value & mask for value in values)
+        actual = (yield module.output)
+        if gen_or_check is GenOrCheck.Check:
+            self.assertEqual(expected, actual,
+                             f"0x{expected:X} != 0x{actual:X}")
+        yield Tick()
+
+    def subtest_key(self,
+                    input_count: int,
+                    inputs: List[Signal],
+                    module: AddReduce,
+                    key: int,
+                    mask_list: List[int],
+                    gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
+        """Run ``subtest_value`` on the input values selected by ``key``.
+
+        In ``Check`` mode the step runs inside a ``subTest`` labelled with
+        the formatted input values.
+        """
+        values, values_str = self.calculate_input_values(input_count, key)
+        step = self.subtest_value(inputs,
+                                  module,
+                                  mask_list,
+                                  gen_or_check,
+                                  values)
+        if gen_or_check != GenOrCheck.Check:
+            yield from step
+            return
+        with self.subTest(inputs=values_str):
+            yield from step
+
+    def subtest_run_sim(self,
+                        input_count: int,
+                        sim: Simulator,
+                        partition_4: Signal,
+                        partition_8: Signal,
+                        inputs: List[Signal],
+                        module: AddReduce,
+                        delay_cycles: int) -> None:
+        # Runs the simulation with two processes that walk the same sequence
+        # of partition settings and keys: one drives the inputs, the other
+        # checks the output and starts ``delay_cycles`` clock ticks later.
+        def generic_process(gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
+            # Each entry: value for partition_4, value for partition_8 and
+            # the lane masks used to compute the expected output.
+            for partition_4_value, partition_8_value, mask_list in [
+                    (0, 0, [0xFFFF]),
+                    (0, 1, [0xFF00, 0x00FF]),
+                    (1, 0, [0xFFF0, 0x000F]),
+                    (1, 1, [0xFF00, 0x00F0, 0x000F])]:
+                key_count = 8
+                if gen_or_check == GenOrCheck.Check:
+                    # Checking side: label failures with the partition
+                    # setting and key; nothing is driven here.
+                    with self.subTest(partition_4=partition_4_value,
+                                      partition_8=partition_8_value):
+                        for key in range(key_count):
+                            with self.subTest(key=key):
+                                yield from self.subtest_key(input_count,
+                                                            inputs,
+                                                            module,
+                                                            key,
+                                                            mask_list,
+                                                            gen_or_check)
+                else:
+                    # Driving side: set the partition signals before
+                    # stepping through the keys.
+                    if gen_or_check == GenOrCheck.Generate:
+                        yield partition_4.eq(partition_4_value)
+                        yield partition_8.eq(partition_8_value)
+                    for key in range(key_count):
+                        yield from self.subtest_key(input_count,
+                                                    inputs,
+                                                    module,
+                                                    key,
+                                                    mask_list,
+                                                    gen_or_check)
+
+        def generate_process() -> AsyncProcessGenerator:
+            yield from generic_process(GenOrCheck.Generate)
+
+        def check_process() -> AsyncProcessGenerator:
+            # Wait ``delay_cycles`` ticks before the first check.
+            if delay_cycles != 0:
+                for _ in range(delay_cycles):
+                    yield Tick()
+            yield from generic_process(GenOrCheck.Check)
+
+        sim.add_clock(2e-6)
+        sim.add_process(generate_process)
+        sim.add_process(check_process)
+        sim.run()
+
+    def subtest_file(self,
+                     input_count: int,
+                     register_levels: List[int]) -> None:
+        # Builds a 16-bit ``AddReduce`` with ``input_count`` inputs and
+        # partition points at bits 4 and 8, then simulates it; the simulator
+        # dump files are named after the configuration.
+        max_level = AddReduce.get_max_level(input_count)
+        # Nothing is run when any requested register level is above the
+        # maximum level reported for this input count.
+        for level in register_levels:
+            if level > max_level:
+                return
+        partition_4 = Signal()
+        partition_8 = Signal()
+        partition_points = PartitionPoints()
+        partition_points[4] = partition_4
+        partition_points[8] = partition_8
+        width = 16
+        inputs = [Signal(width, name=f"input_{i}")
+                  for i in range(input_count)]
+        module = AddReduce(inputs,
+                           width,
+                           register_levels,
+                           partition_points)
+        # File name: "add_reduce[-<levels joined by _>]-<input count>".
+        file_name = "add_reduce"
+        if len(register_levels) != 0:
+            file_name += f"-{'_'.join(map(repr, register_levels))}"
+        file_name += f"-{input_count:02d}"
+        with create_simulator(module,
+                              [partition_4,
+                               partition_8,
+                               *inputs,
+                               module.output],
+                              file_name) as sim:
+            # The number of register levels is passed as the number of clock
+            # ticks by which checking lags behind driving.
+            self.subtest_run_sim(input_count,
+                                 sim,
+                                 partition_4,
+                                 partition_8,
+                                 inputs,
+                                 module,
+                                 len(register_levels))
+
+    def subtest_register_levels(self, register_levels: List[int]) -> None:
+        """Run ``subtest_file`` for every input count from 0 to 15."""
+        for input_count in range(16):
+            labelled = self.subTest(input_count=input_count,
+                                    register_levels=repr(register_levels))
+            with labelled:
+                self.subtest_file(input_count, register_levels)
+
+    # One test per register-level configuration.  Each method name lists the
+    # levels it passes on.  ``test_0`` used to be defined twice with the same
+    # body, the second definition silently replacing the first; it is now
+    # defined once.
+
+    def test_empty(self) -> None:
+        self.subtest_register_levels([])
+
+    def test_0(self) -> None:
+        self.subtest_register_levels([0])
+
+    def test_1(self) -> None:
+        self.subtest_register_levels([1])
+
+    def test_2(self) -> None:
+        self.subtest_register_levels([2])
+
+    def test_3(self) -> None:
+        self.subtest_register_levels([3])
+
+    def test_4(self) -> None:
+        self.subtest_register_levels([4])
+
+    def test_5(self) -> None:
+        self.subtest_register_levels([5])
+
+    def test_0_1(self) -> None:
+        self.subtest_register_levels([0, 1])
+
+    def test_0_1_2(self) -> None:
+        self.subtest_register_levels([0, 1, 2])
+
+    def test_0_1_2_3(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3])
+
+    def test_0_1_2_3_4(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4])
+
+    def test_0_1_2_3_4_5(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4, 5])
+
+    def test_0_2(self) -> None:
+        self.subtest_register_levels([0, 2])
+
+    def test_0_3(self) -> None:
+        self.subtest_register_levels([0, 3])
+
+    def test_0_4(self) -> None:
+        self.subtest_register_levels([0, 4])
+
+    def test_0_5(self) -> None:
+        self.subtest_register_levels([0, 5])
+
+
+class SIMDMulLane:
+    """Settings for one lane of a SIMD multiply.
+
+    Stores the signedness of each operand, the lane width in bits and
+    whether the high half of the product is selected.
+    """
+
+    def __init__(self,
+                 a_signed: bool,
+                 b_signed: bool,
+                 bit_width: int,
+                 high_half: bool):
+        self.a_signed, self.b_signed = a_signed, b_signed
+        self.bit_width, self.high_half = bit_width, high_half
+
+    def __repr__(self):
+        return "SIMDMulLane({}, {}, {}, {})".format(self.a_signed,
+                                                    self.b_signed,
+                                                    self.bit_width,
+                                                    self.high_half)
+
+
+class TestMul8_16_32_64(unittest.TestCase):
+    @staticmethod
+    def simd_mul(a: int, b: int, lanes: List[SIMDMulLane]) -> Tuple[int, int]:
+        """Software model of the lane-wise multiply.
+
+        ``lanes`` are laid out from the least-significant bits upwards.  For
+        each lane the operands are extracted, optionally reinterpreted as
+        two's complement, and multiplied.
+
+        :returns: ``(output, intermediate_output)``.  ``output`` packs the
+            selected half of every product at the lane's own bit offset;
+            ``intermediate_output`` packs every full double-width product at
+            twice that offset.
+        """
+        def lane_operand(packed: int, offset: int, width: int,
+                         signed: bool) -> int:
+            # Extract ``width`` bits at ``offset``; when ``signed`` is set
+            # and the top bit is 1, reinterpret as a negative number.
+            sign_bit = 1 << (width - 1)
+            part = (packed >> offset) & ((1 << width) - 1)
+            if signed and (part & sign_bit) != 0:
+                part -= 1 << width
+            return part
+
+        output = 0
+        intermediate_output = 0
+        shift = 0
+        for lane in lanes:
+            width = lane.bit_width
+            # Lanes that select the low half are always modelled as
+            # signed * signed, whatever their signedness flags say.
+            low_half = not lane.high_half
+            a_part = lane_operand(a, shift, width, lane.a_signed or low_half)
+            b_part = lane_operand(b, shift, width, lane.b_signed or low_half)
+            product = (a_part * b_part) & ((1 << (width * 2)) - 1)
+            intermediate_output |= product << (shift * 2)
+            selected = product >> width if lane.high_half else product
+            output |= (selected & ((1 << width) - 1)) << shift
+            shift += width
+        return output, intermediate_output
+
+    @staticmethod
+    def get_test_cases(lanes: List[SIMDMulLane],
+                       keys: Iterable[int]) -> Iterable[Tuple[int, int]]:
+        """Yield ``(a, b)`` operand pairs for the given lane layout.
+
+        Eight pairs are derived from SHA-256 hashes of the pair index,
+        ``lanes`` and ``keys``.  One final pair has only the top bit of every
+        lane set in both operands.
+
+        :param lanes: lane layout; its ``repr`` is part of the hash input and
+            its widths place the bits of the final pair.
+        :param keys: extra values mixed into the hash.
+        """
+        mask = (1 << 64) - 1
+        # Materialise ``keys`` once: a one-shot iterable would otherwise be
+        # exhausted after the first hash and contribute ``[]`` to the rest.
+        key_list = list(keys)
+        for i in range(8):
+            hash_input = f"{i} {lanes} {key_list}"
+            digest = sha256(hash_input.encode()).digest()
+            value = int.from_bytes(digest, byteorder="little")
+            # Mask the second operand as well; the digest is 256 bits wide,
+            # so ``value >> 64`` alone can be far wider than 64 bits.
+            yield (value & mask, (value >> 64) & mask)
+        a = 0
+        b = 0
+        shift = 0
+        for lane in lanes:
+            top_bit = 1 << (shift + lane.bit_width - 1)
+            a |= top_bit
+            b |= top_bit
+            shift += lane.bit_width
+        yield a, b
+
+    def test_simd_mul_lane(self):
+        # Formatting a lane must give back its constructor-style text.
+        lane = SIMDMulLane(True, True, 8, False)
+        expected = "SIMDMulLane(True, True, 8, False)"
+        self.assertEqual(f"{lane}", expected)
+
+    def test_simd_mul(self):
+        # Check the software model ``simd_mul`` against fixed expected values
+        # for a mixed 8/8/16/32-bit lane layout.
+        lanes = [SIMDMulLane(True, True, 8, True),
+                 SIMDMulLane(False, False, 8, True),
+                 SIMDMulLane(True, True, 16, False),
+                 SIMDMulLane(True, False, 32, True)]
+        b = 0xFEDCBA9876543210
+        # Each case: (a, expected output, expected intermediate output).
+        cases = [
+            (0x0123456789ABCDEF,
+             0x0121FA00FE1C28FE,
+             0x0121FA0023E20B28C94DFE1C280AFEF0),
+            (0x8123456789ABCDEF,
+             0x81B39CB4FE1C28FE,
+             0x81B39CB423E20B28C94DFE1C280AFEF0),
+        ]
+        for a, output, intermediate_output in cases:
+            self.assertEqual(self.simd_mul(a, b, lanes),
+                             (output, intermediate_output))
+
+    def test_signed_mul_from_unsigned(self):
+        # Exhaustively check, for 4-bit operands, that the signed*unsigned
+        # and signed*signed products can be obtained from the unsigned
+        # product by adding correction terms (modulo 2**8).
+        for i in range(0, 0x10):
+            for j in range(0, 0x10):
+                # Two's-complement reinterpretation: a value with bit 3 set
+                # is negative.  (The condition was previously inverted, so
+                # the "signed" values fell outside the 4-bit signed range.)
+                si = i - 0x10 if i & 8 else i  # signed i
+                sj = j - 0x10 if j & 8 else j  # signed j
+                mulu = i * j
+                mulsu = si * j
+                mul = si * sj
+                with self.subTest(i=i, j=j, si=si, sj=sj,
+                                  mulu=mulu, mulsu=mulsu, mul=mul):
+                    # A negative ``si`` equals ``i - 16``, so subtract
+                    # ``j << 4``, written here as (~j << 4) + (1 << 4).
+                    mulsu2 = mulu
+                    if si < 0:
+                        mulsu2 += ~j << 4
+                        mulsu2 += 1 << 4
+                    self.assertEqual(mulsu & 0xFF, mulsu2 & 0xFF)
+                    # Same correction for a negative ``sj``, using ``i``.
+                    mul2 = mulsu2
+                    if sj < 0:
+                        mul2 += ~i << 4
+                        mul2 += 1 << 4
+                    self.assertEqual(mul & 0xFF, mul2 & 0xFF)
+
+    def subtest_value(self,
+                      a: int,
+                      b: int,
+                      module: Mul8_16_32_64,
+                      lanes: List[SIMDMulLane],
+                      gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
+        """Simulator step for one operand pair.
+
+        In ``Generate`` mode ``a`` and ``b`` are assigned to the module
+        inputs.  After a settling delay, ``Check`` mode compares the module's
+        intermediate and final outputs with the ``simd_mul`` model.  The step
+        ends with one clock tick.
+        """
+        if gen_or_check is GenOrCheck.Generate:
+            yield module.a.eq(a)
+            yield module.b.eq(b)
+        expected_output, expected_intermediate = self.simd_mul(a, b, lanes)
+        yield Delay(1e-6)
+        if gen_or_check is GenOrCheck.Check:
+            intermediate_output = (yield module._intermediate_output)
+            intermediate_msg = (f"0x{intermediate_output:X} "
+                                f"!= 0x{expected_intermediate:X}")
+            self.assertEqual(intermediate_output,
+                             expected_intermediate,
+                             intermediate_msg)
+            output = (yield module.output)
+            output_msg = f"0x{output:X} != 0x{expected_output:X}"
+            self.assertEqual(output, expected_output, output_msg)
+        yield Tick()
+
+    def subtest_lanes_2(self,
+                        lanes: List[SIMDMulLane],
+                        module: Mul8_16_32_64,
+                        gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
+        # Configures the module for the given lane layout (only in Generate
+        # mode) and then steps through the test cases for that layout.
+        # ``bit_index`` walks the byte boundaries at bits 8, 16, ...;
+        # ``part_index`` walks the per-byte operation inputs.
+        bit_index = 8
+        part_index = 0
+        for lane in lanes:
+            # Choose the operation code from the lane's settings.
+            if lane.high_half:
+                if lane.a_signed:
+                    if lane.b_signed:
+                        op = OP_MUL_SIGNED_HIGH
+                    else:
+                        op = OP_MUL_SIGNED_UNSIGNED_HIGH
+                else:
+                    self.assertFalse(lane.b_signed,
+                                     "unsigned * signed not supported")
+                    op = OP_MUL_UNSIGNED_HIGH
+            else:
+                op = OP_MUL_LOW
+            self.assertEqual(lane.bit_width % 8, 0)
+            # Every byte of the lane gets the lane's operation.
+            for i in range(lane.bit_width // 8):
+                if gen_or_check == GenOrCheck.Generate:
+                    yield module.part_ops[part_index].eq(op)
+                part_index += 1
+            # Byte boundaries inside the lane are not partition points.
+            for i in range(lane.bit_width // 8 - 1):
+                if gen_or_check == GenOrCheck.Generate:
+                    yield module.part_pts[bit_index].eq(0)
+                bit_index += 8
+            # The boundary at the end of the lane is a partition point,
+            # unless it is the top of the 64-bit word.
+            if bit_index < 64 and gen_or_check == GenOrCheck.Generate:
+                yield module.part_pts[bit_index].eq(1)
+            bit_index += 8
+        # The lanes together must cover exactly 8 bytes.
+        self.assertEqual(part_index, 8)
+        for a, b in self.get_test_cases(lanes, ()):
+            if gen_or_check == GenOrCheck.Check:
+                with self.subTest(a=f"{a:X}", b=f"{b:X}"):
+                    yield from self.subtest_value(a, b, module, lanes, gen_or_check)
+            else:
+                yield from self.subtest_value(a, b, module, lanes, gen_or_check)
+
+    def subtest_lanes(self,
+                      lanes: List[SIMDMulLane],
+                      module: Mul8_16_32_64,
+                      gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
+        """Run ``subtest_lanes_2`` for one lane layout.
+
+        In ``Check`` mode the run is wrapped in a ``subTest`` labelled with
+        the layout's ``repr``.
+        """
+        run = self.subtest_lanes_2(lanes, module, gen_or_check)
+        if gen_or_check != GenOrCheck.Check:
+            yield from run
+            return
+        with self.subTest(lanes=repr(lanes)):
+            yield from run
+
+    def subtest_file(self,
+                     register_levels: List[int]) -> None:
+        # Builds a ``Mul8_16_32_64`` with the given register levels and
+        # simulates it over a set of lane layouts, using one process that
+        # drives the inputs and one that checks the outputs.
+        module = Mul8_16_32_64(register_levels)
+        # File name: "mul8_16_32_64[-<levels joined by _>]".
+        file_name = "mul8_16_32_64"
+        if len(register_levels) != 0:
+            file_name += f"-{'_'.join(map(repr, register_levels))}"
+        # Signals handed to the simulator as traces: the public ports
+        # followed by a number of the module's internal signals.
+        ports = [module.a,
+                 module.b,
+                 module._intermediate_output,
+                 module.output]
+        ports.extend(module.part_ops)
+        ports.extend(module.part_pts.values())
+        for signals in module._delayed_part_ops:
+            ports.extend(signals)
+        ports.extend(module._part_8)
+        ports.extend(module._part_16)
+        ports.extend(module._part_32)
+        ports.extend(module._part_64)
+        for signals in module._delayed_part_8:
+            ports.extend(signals)
+        for signals in module._delayed_part_16:
+            ports.extend(signals)
+        for signals in module._delayed_part_32:
+            ports.extend(signals)
+        for signals in module._delayed_part_64:
+            ports.extend(signals)
+        ports += [module._output_64,
+                  module._output_32,
+                  module._output_16,
+                  module._output_8]
+        ports.extend(module._a_signed)
+        ports.extend(module._b_signed)
+        ports += [module._not_a_term_8,
+                  module._neg_lsb_a_term_8,
+                  module._not_b_term_8,
+                  module._neg_lsb_b_term_8,
+                  module._not_a_term_16,
+                  module._neg_lsb_a_term_16,
+                  module._not_b_term_16,
+                  module._neg_lsb_b_term_16,
+                  module._not_a_term_32,
+                  module._neg_lsb_a_term_32,
+                  module._not_b_term_32,
+                  module._neg_lsb_b_term_32,
+                  module._not_a_term_64,
+                  module._neg_lsb_a_term_64,
+                  module._not_b_term_64,
+                  module._neg_lsb_b_term_64]
+        with create_simulator(module, ports, file_name) as sim:
+            def process(gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
+                # Uniform layouts: one 64-bit lane, two 32-bit lanes, four
+                # 16-bit lanes and eight 8-bit lanes per accepted combination
+                # of signedness and result half.
+                for a_signed in False, True:
+                    for b_signed in False, True:
+                        # unsigned * signed is skipped (subtest_lanes_2
+                        # rejects it for high-half lanes).
+                        if not a_signed and b_signed:
+                            continue
+                        for high_half in False, True:
+                            # The low half is only run for signed * signed.
+                            if not high_half and not (a_signed and b_signed):
+                                continue
+                            yield from self.subtest_lanes(
+                                [SIMDMulLane(a_signed,
+                                             b_signed,
+                                             64,
+                                             high_half)],
+                                module,
+                                gen_or_check)
+                            yield from self.subtest_lanes(
+                                [SIMDMulLane(a_signed,
+                                             b_signed,
+                                             32,
+                                             high_half)] * 2,
+                                module,
+                                gen_or_check)
+                            yield from self.subtest_lanes(
+                                [SIMDMulLane(a_signed,
+                                             b_signed,
+                                             16,
+                                             high_half)] * 4,
+                                module,
+                                gen_or_check)
+                            yield from self.subtest_lanes(
+                                [SIMDMulLane(a_signed,
+                                             b_signed,
+                                             8,
+                                             high_half)] * 8,
+                                module,
+                                gen_or_check)
+                # Mixed-width layout, all lanes unsigned high half.
+                yield from self.subtest_lanes([SIMDMulLane(False,
+                                                           False,
+                                                           32,
+                                                           True),
+                                               SIMDMulLane(False,
+                                                           False,
+                                                           16,
+                                                           True),
+                                               SIMDMulLane(False,
+                                                           False,
+                                                           8,
+                                                           True),
+                                               SIMDMulLane(False,
+                                                           False,
+                                                           8,
+                                                           True)],
+                                              module,
+                                              gen_or_check)
+                # Mixed-width layout with a different operation per lane.
+                yield from self.subtest_lanes([SIMDMulLane(True,
+                                                           False,
+                                                           32,
+                                                           True),
+                                               SIMDMulLane(True,
+                                                           True,
+                                                           16,
+                                                           False),
+                                               SIMDMulLane(True,
+                                                           True,
+                                                           8,
+                                                           True),
+                                               SIMDMulLane(False,
+                                                           False,
+                                                           8,
+                                                           True)],
+                                              module,
+                                              gen_or_check)
+                # The same four lanes as above, in reverse order.
+                yield from self.subtest_lanes([SIMDMulLane(True,
+                                                           True,
+                                                           8,
+                                                           True),
+                                               SIMDMulLane(False,
+                                                           False,
+                                                           8,
+                                                           True),
+                                               SIMDMulLane(True,
+                                                           True,
+                                                           16,
+                                                           False),
+                                               SIMDMulLane(True,
+                                                           False,
+                                                           32,
+                                                           True)],
+                                              module,
+                                              gen_or_check)
+
+            def generate_process() -> AsyncProcessGenerator:
+                yield from process(GenOrCheck.Generate)
+
+            def check_process() -> AsyncProcessGenerator:
+                # Wait one clock tick per entry of ``register_levels`` before
+                # the first check.
+                if len(register_levels) != 0:
+                    for _ in register_levels:
+                        yield Tick()
+                yield from process(GenOrCheck.Check)
+
+            sim.add_clock(2e-6)
+            sim.add_process(generate_process)
+            sim.add_process(check_process)
+            sim.run()
+
+    def subtest_register_levels(self, register_levels: List[int]) -> None:
+        """Run ``subtest_file`` inside a ``subTest`` labelled with the levels."""
+        labelled = self.subTest(register_levels=repr(register_levels))
+        with labelled:
+            self.subtest_file(register_levels)
+
+    # One test per register-level configuration.  Each method name lists the
+    # levels it passes on.  ``test_0`` used to be defined twice with the same
+    # body, the second definition silently replacing the first; it is now
+    # defined once.
+
+    def test_empty(self) -> None:
+        self.subtest_register_levels([])
+
+    def test_0(self) -> None:
+        self.subtest_register_levels([0])
+
+    def test_1(self) -> None:
+        self.subtest_register_levels([1])
+
+    def test_2(self) -> None:
+        self.subtest_register_levels([2])
+
+    def test_3(self) -> None:
+        self.subtest_register_levels([3])
+
+    def test_4(self) -> None:
+        self.subtest_register_levels([4])
+
+    def test_5(self) -> None:
+        self.subtest_register_levels([5])
+
+    def test_6(self) -> None:
+        self.subtest_register_levels([6])
+
+    def test_7(self) -> None:
+        self.subtest_register_levels([7])
+
+    def test_8(self) -> None:
+        self.subtest_register_levels([8])
+
+    def test_9(self) -> None:
+        self.subtest_register_levels([9])
+
+    def test_10(self) -> None:
+        self.subtest_register_levels([10])
+
+    def test_0_1(self) -> None:
+        self.subtest_register_levels([0, 1])
+
+    def test_0_1_2(self) -> None:
+        self.subtest_register_levels([0, 1, 2])
+
+    def test_0_1_2_3(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3])
+
+    def test_0_1_2_3_4(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4])
+
+    def test_0_1_2_3_4_5(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4, 5])
+
+    def test_0_1_2_3_4_5_6(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6])
+
+    def test_0_1_2_3_4_5_6_7(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7])
+
+    def test_0_1_2_3_4_5_6_7_8(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7, 8])
+
+    def test_0_1_2_3_4_5_6_7_8_9(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
+
+    def test_0_1_2_3_4_5_6_7_8_9_10(self) -> None:
+        self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
+
+    def test_0_2(self) -> None:
+        self.subtest_register_levels([0, 2])
+
+    def test_0_3(self) -> None:
+        self.subtest_register_levels([0, 3])
+
+    def test_0_4(self) -> None:
+        self.subtest_register_levels([0, 4])
+
+    def test_0_5(self) -> None:
+        self.subtest_register_levels([0, 5])
+
+    def test_0_6(self) -> None:
+        self.subtest_register_levels([0, 6])
+
+    def test_0_7(self) -> None:
+        self.subtest_register_levels([0, 7])
+
+    def test_0_8(self) -> None:
+        self.subtest_register_levels([0, 8])
+
+    def test_0_9(self) -> None:
+        self.subtest_register_levels([0, 9])
+
+    def test_0_10(self) -> None:
+        self.subtest_register_levels([0, 10])
+
+# Run every test case in this file when it is executed as a script.
+if __name__ == '__main__':
+    unittest.main()