From 387de2702ad348b62c6067135da8a1e8b43b175e Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Sat, 17 Aug 2019 07:45:18 +0100 Subject: [PATCH] add partitioned multiplier/adder --- src/ieee754/part_mul_add/__init__.py | 0 src/ieee754/part_mul_add/multiply.py | 641 +++++++++++++++ src/ieee754/part_mul_add/test/__init__.py | 0 .../part_mul_add/test/test_multiply.py | 770 ++++++++++++++++++ 4 files changed, 1411 insertions(+) create mode 100644 src/ieee754/part_mul_add/__init__.py create mode 100644 src/ieee754/part_mul_add/multiply.py create mode 100644 src/ieee754/part_mul_add/test/__init__.py create mode 100644 src/ieee754/part_mul_add/test/test_multiply.py diff --git a/src/ieee754/part_mul_add/__init__.py b/src/ieee754/part_mul_add/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py new file mode 100644 index 00000000..5902967c --- /dev/null +++ b/src/ieee754/part_mul_add/multiply.py @@ -0,0 +1,641 @@ +# SPDX-License-Identifier: LGPL-2.1-or-later +# See Notices.txt for copyright information +"""Integer Multiplication.""" + +from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl +from nmigen.hdl.ast import Assign +from abc import ABCMeta, abstractmethod +from typing import Any, NewType, Union, List, Dict, Iterable, Mapping, Optional +from typing_extensions import final +from nmigen.cli import main + +PartitionPointsIn = Mapping[int, Union[Value, bool, int]] + + +class PartitionPoints(Dict[int, Value]): + """Partition points and corresponding ``Value``s. + + The points at where an ALU is partitioned along with ``Value``s that + specify if the corresponding partition points are enabled. + + For example: ``{1: True, 5: True, 10: True}`` with + ``width == 16`` specifies that the ALU is split into 4 sections: + * bits 0 <= ``i`` < 1 + * bits 1 <= ``i`` < 5 + * bits 5 <= ``i`` < 10 + * bits 10 <= ``i`` < 16 + + If the partition_points were instead ``{1: True, 5: a, 10: True}`` + where ``a`` is a 1-bit ``Signal``: + * If ``a`` is asserted: + * bits 0 <= ``i`` < 1 + * bits 1 <= ``i`` < 5 + * bits 5 <= ``i`` < 10 + * bits 10 <= ``i`` < 16 + * Otherwise + * bits 0 <= ``i`` < 1 + * bits 1 <= ``i`` < 10 + * bits 10 <= ``i`` < 16 + """ + + def __init__(self, partition_points: Optional[PartitionPointsIn] = None): + """Create a new ``PartitionPoints``. + + :param partition_points: the input partition points to values mapping. + """ + super().__init__() + if partition_points is not None: + for point, enabled in partition_points.items(): + if not isinstance(point, int): + raise TypeError("point must be a non-negative integer") + if point < 0: + raise ValueError("point must be a non-negative integer") + self[point] = Value.wrap(enabled) + + def like(self, + name: Optional[str] = None, + src_loc_at: int = 0) -> 'PartitionPoints': + """Create a new ``PartitionPoints`` with ``Signal``s for all values. + + :param name: the base name for the new ``Signal``s. + """ + if name is None: + name = Signal(src_loc_at=1+src_loc_at).name # get variable name + retval = PartitionPoints() + for point, enabled in self.items(): + retval[point] = Signal(enabled.shape(), name=f"{name}_{point}") + return retval + + def eq(self, rhs: 'PartitionPoints') -> Iterable[Assign]: + """Assign ``PartitionPoints`` using ``Signal.eq``.""" + if set(self.keys()) != set(rhs.keys()): + raise ValueError("incompatible point set") + for point, enabled in self.items(): + yield enabled.eq(rhs[point]) + + def as_mask(self, width: int) -> Value: + """Create a bit-mask from `self`. + + Each bit in the returned mask is clear only if the partition point at + the same bit-index is enabled. + + :param width: the bit width of the resulting mask + """ + bits: List[Union[Value, bool]] + bits = [] + for i in range(width): + if i in self: + bits.append(~self[i]) + else: + bits.append(True) + return Cat(*bits) + + def get_max_partition_count(self, width: int) -> int: + """Get the maximum number of partitions. + + Gets the number of partitions when all partition points are enabled. + """ + retval = 1 + for point in self.keys(): + if point < width: + retval += 1 + return retval + + def fits_in_width(self, width: int) -> bool: + """Check if all partition points are smaller than `width`.""" + for point in self.keys(): + if point >= width: + return False + return True + + +@final +class FullAdder(Elaboratable): + """Full Adder. + + :attribute in0: the first input + :attribute in1: the second input + :attribute in2: the third input + :attribute sum: the sum output + :attribute carry: the carry output + """ + + def __init__(self, width: int): + """Create a ``FullAdder``. + + :param width: the bit width of the input and output + """ + self.in0 = Signal(width) + self.in1 = Signal(width) + self.in2 = Signal(width) + self.sum = Signal(width) + self.carry = Signal(width) + + def elaborate(self, platform: Any) -> Module: + """Elaborate this module.""" + m = Module() + m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2) + m.d.comb += self.carry.eq((self.in0 & self.in1) + | (self.in1 & self.in2) + | (self.in2 & self.in0)) + return m + + +@final +class PartitionedAdder(Elaboratable): + """Partitioned Adder. + + :attribute width: the bit width of the input and output. Read-only. + :attribute a: the first input to the adder + :attribute b: the second input to the adder + :attribute output: the sum output + :attribute partition_points: the input partition points. Modification not + supported, except for by ``Signal.eq``. + """ + + def __init__(self, width: int, partition_points: PartitionPointsIn): + """Create a ``PartitionedAdder``. + + :param width: the bit width of the input and output + :param partition_points: the input partition points + """ + self.width = width + self.a = Signal(width) + self.b = Signal(width) + self.output = Signal(width) + self.partition_points = PartitionPoints(partition_points) + if not self.partition_points.fits_in_width(width): + raise ValueError("partition_points doesn't fit in width") + expanded_width = 0 + for i in range(self.width): + if i in self.partition_points: + expanded_width += 1 + expanded_width += 1 + self._expanded_width = expanded_width + self._expanded_a = Signal(expanded_width) + self._expanded_b = Signal(expanded_width) + self._expanded_output = Signal(expanded_width) + + def elaborate(self, platform: Any) -> Module: + """Elaborate this module.""" + m = Module() + expanded_index = 0 + for i in range(self.width): + if i in self.partition_points: + # add extra bit set to 0 + 0 for enabled partition points + # and 1 + 0 for disabled partition points + m.d.comb += self._expanded_a[expanded_index].eq( + ~self.partition_points[i]) + m.d.comb += self._expanded_b[expanded_index].eq(0) + expanded_index += 1 + m.d.comb += self._expanded_a[expanded_index].eq(self.a[i]) + m.d.comb += self._expanded_b[expanded_index].eq(self.b[i]) + m.d.comb += self.output[i].eq( + self._expanded_output[expanded_index]) + expanded_index += 1 + # use only one addition to take advantage of look-ahead carry and + # special hardware on FPGAs + m.d.comb += self._expanded_output.eq( + self._expanded_a + self._expanded_b) + return m + + +FULL_ADDER_INPUT_COUNT = 3 + + +@final +class AddReduce(Elaboratable): + """Add list of numbers together. + + :attribute inputs: input ``Signal``s to be summed. Modification not + supported, except for by ``Signal.eq``. + :attribute register_levels: List of nesting levels that should have + pipeline registers. + :attribute output: output sum. + :attribute partition_points: the input partition points. Modification not + supported, except for by ``Signal.eq``. + """ + + def __init__(self, + inputs: Iterable[Signal], + output_width: int, + register_levels: Iterable[int], + partition_points: PartitionPointsIn): + """Create an ``AddReduce``. + + :param inputs: input ``Signal``s to be summed. + :param output_width: bit-width of ``output``. + :param register_levels: List of nesting levels that should have + pipeline registers. + :param partition_points: the input partition points. + """ + self.inputs = list(inputs) + self._resized_inputs = [ + Signal(output_width, name=f"resized_inputs[{i}]") + for i in range(len(self.inputs))] + self.register_levels = list(register_levels) + self.output = Signal(output_width) + self.partition_points = PartitionPoints(partition_points) + if not self.partition_points.fits_in_width(output_width): + raise ValueError("partition_points doesn't fit in output_width") + self._reg_partition_points = self.partition_points.like() + max_level = AddReduce.get_max_level(len(self.inputs)) + for level in self.register_levels: + if level > max_level: + raise ValueError( + "not enough adder levels for specified register levels") + + @staticmethod + def get_max_level(input_count: int) -> int: + """Get the maximum level. + + All ``register_levels`` must be less than or equal to the maximum + level. + """ + retval = 0 + while True: + groups = AddReduce.full_adder_groups(input_count) + if len(groups) == 0: + return retval + input_count %= FULL_ADDER_INPUT_COUNT + input_count += 2 * len(groups) + retval += 1 + + def next_register_levels(self) -> Iterable[int]: + """``Iterable`` of ``register_levels`` for next recursive level.""" + for level in self.register_levels: + if level > 0: + yield level - 1 + + @staticmethod + def full_adder_groups(input_count: int) -> range: + """Get ``inputs`` indices for which a full adder should be built.""" + return range(0, + input_count - FULL_ADDER_INPUT_COUNT + 1, + FULL_ADDER_INPUT_COUNT) + + def elaborate(self, platform: Any) -> Module: + """Elaborate this module.""" + m = Module() + + # resize inputs to correct bit-width and optionally add in + # pipeline registers + resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i]) + for i in range(len(self.inputs))] + if 0 in self.register_levels: + m.d.sync += resized_input_assignments + m.d.sync += self._reg_partition_points.eq(self.partition_points) + else: + m.d.comb += resized_input_assignments + m.d.comb += self._reg_partition_points.eq(self.partition_points) + + groups = AddReduce.full_adder_groups(len(self.inputs)) + # if there are no full adders to create, then we handle the base cases + # and return, otherwise we go on to the recursive case + if len(groups) == 0: + if len(self.inputs) == 0: + # use 0 as the default output value + m.d.comb += self.output.eq(0) + elif len(self.inputs) == 1: + # handle single input + m.d.comb += self.output.eq(self._resized_inputs[0]) + else: + # base case for adding 2 or more inputs, which get recursively + # reduced to 2 inputs + assert len(self.inputs) == 2 + adder = PartitionedAdder(len(self.output), + self._reg_partition_points) + m.submodules.final_adder = adder + m.d.comb += adder.a.eq(self._resized_inputs[0]) + m.d.comb += adder.b.eq(self._resized_inputs[1]) + m.d.comb += self.output.eq(adder.output) + return m + # go on to handle recursive case + intermediate_terms: List[Signal] + intermediate_terms = [] + + def add_intermediate_term(value: Value) -> None: + intermediate_term = Signal( + len(self.output), + name=f"intermediate_terms[{len(intermediate_terms)}]") + intermediate_terms.append(intermediate_term) + m.d.comb += intermediate_term.eq(value) + + part_mask = self._reg_partition_points.as_mask(len(self.output)) + + # create full adders for this recursive level. + # this shrinks N terms to 2 * (N // 3) plus the remainder + for i in groups: + adder_i = FullAdder(len(self.output)) + setattr(m.submodules, f"adder_{i}", adder_i) + m.d.comb += adder_i.in0.eq(self._resized_inputs[i]) + m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1]) + m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2]) + add_intermediate_term(adder_i.sum) + shifted_carry = adder_i.carry << 1 + # mask out carry bits to prevent carries between partitions + add_intermediate_term((adder_i.carry << 1) & part_mask) + # handle the remaining inputs. + if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1: + add_intermediate_term(self._resized_inputs[-1]) + elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2: + # Just pass the terms to the next layer, since we wouldn't gain + # anything by using a half adder since there would still be 2 terms + # and just passing the terms to the next layer saves gates. + add_intermediate_term(self._resized_inputs[-2]) + add_intermediate_term(self._resized_inputs[-1]) + else: + assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0 + # recursive invocation of ``AddReduce`` + next_level = AddReduce(intermediate_terms, + len(self.output), + self.next_register_levels(), + self._reg_partition_points) + m.submodules.next_level = next_level + m.d.comb += self.output.eq(next_level.output) + return m + + +OP_MUL_LOW = 0 +OP_MUL_SIGNED_HIGH = 1 +OP_MUL_SIGNED_UNSIGNED_HIGH = 2 # a is signed, b is unsigned +OP_MUL_UNSIGNED_HIGH = 3 + + +class Mul8_16_32_64(Elaboratable): + """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier. + + Supports partitioning into any combination of 8, 16, 32, and 64-bit + partitions on naturally-aligned boundaries. Supports the operation being + set for each partition independently. + + :attribute part_pts: the input partition points. Has a partition point at + multiples of 8 in 0 < i < 64. Each partition point's associated + ``Value`` is a ``Signal``. Modification not supported, except for by + ``Signal.eq``. + :attribute part_ops: the operation for each byte. The operation for a + particular partition is selected by assigning the selected operation + code to each byte in the partition. The allowed operation codes are: + + :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to + RISC-V's `mul` instruction. + :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both + ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh` + instruction. + :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product + where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's + `mulhsu` instruction. + :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both + ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu` + instruction. + """ + + def __init__(self, register_levels: Iterable[int] = ()): + self.part_pts = PartitionPoints() + for i in range(8, 64, 8): + self.part_pts[i] = Signal(name=f"part_pts_{i}") + self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)] + self.a = Signal(64) + self.b = Signal(64) + self.output = Signal(64) + self.register_levels = list(register_levels) + self._intermediate_output = Signal(128) + self._delayed_part_ops = [ + [Signal(2, name=f"_delayed_part_ops_{delay}_{i}") + for i in range(8)] + for delay in range(1 + len(self.register_levels))] + self._part_8 = [Signal(name=f"_part_8_{i}") for i in range(8)] + self._part_16 = [Signal(name=f"_part_16_{i}") for i in range(4)] + self._part_32 = [Signal(name=f"_part_32_{i}") for i in range(2)] + self._part_64 = [Signal(name=f"_part_64")] + self._delayed_part_8 = [ + [Signal(name=f"_delayed_part_8_{delay}_{i}") + for i in range(8)] + for delay in range(1 + len(self.register_levels))] + self._delayed_part_16 = [ + [Signal(name=f"_delayed_part_16_{delay}_{i}") + for i in range(4)] + for delay in range(1 + len(self.register_levels))] + self._delayed_part_32 = [ + [Signal(name=f"_delayed_part_32_{delay}_{i}") + for i in range(2)] + for delay in range(1 + len(self.register_levels))] + self._delayed_part_64 = [ + [Signal(name=f"_delayed_part_64_{delay}")] + for delay in range(1 + len(self.register_levels))] + self._output_64 = Signal(64) + self._output_32 = Signal(64) + self._output_16 = Signal(64) + self._output_8 = Signal(64) + self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)] + self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)] + self._not_a_term_8 = Signal(128) + self._neg_lsb_a_term_8 = Signal(128) + self._not_b_term_8 = Signal(128) + self._neg_lsb_b_term_8 = Signal(128) + self._not_a_term_16 = Signal(128) + self._neg_lsb_a_term_16 = Signal(128) + self._not_b_term_16 = Signal(128) + self._neg_lsb_b_term_16 = Signal(128) + self._not_a_term_32 = Signal(128) + self._neg_lsb_a_term_32 = Signal(128) + self._not_b_term_32 = Signal(128) + self._neg_lsb_b_term_32 = Signal(128) + self._not_a_term_64 = Signal(128) + self._neg_lsb_a_term_64 = Signal(128) + self._not_b_term_64 = Signal(128) + self._neg_lsb_b_term_64 = Signal(128) + + def _part_byte(self, index: int) -> Value: + if index == -1 or index == 7: + return C(True, 1) + assert index >= 0 and index < 8 + return self.part_pts[index * 8 + 8] + + def elaborate(self, platform: Any) -> Module: + m = Module() + + for i in range(len(self.part_ops)): + m.d.comb += self._delayed_part_ops[0][i].eq(self.part_ops[i]) + m.d.sync += [self._delayed_part_ops[j + 1][i] + .eq(self._delayed_part_ops[j][i]) + for j in range(len(self.register_levels))] + + for parts, delayed_parts in [(self._part_64, self._delayed_part_64), + (self._part_32, self._delayed_part_32), + (self._part_16, self._delayed_part_16), + (self._part_8, self._delayed_part_8)]: + byte_count = 8 // len(parts) + for i in range(len(parts)): + value = self._part_byte(i * byte_count - 1) + for j in range(i * byte_count, (i + 1) * byte_count - 1): + value &= ~self._part_byte(j) + value &= self._part_byte((i + 1) * byte_count - 1) + m.d.comb += parts[i].eq(value) + m.d.comb += delayed_parts[0][i].eq(parts[i]) + m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i]) + for j in range(len(self.register_levels))] + + products = [[ + Signal(16, name=f"products_{i}_{j}") + for j in range(8)] + for i in range(8)] + + for a_index in range(8): + for b_index in range(8): + a = self.a.part(a_index * 8, 8) + b = self.b.part(b_index * 8, 8) + m.d.comb += products[a_index][b_index].eq(a * b) + + terms = [] + + def add_term(value: Value, + shift: int = 0, + enabled: Optional[Value] = None) -> None: + term = Signal(128) + terms.append(term) + if enabled is not None: + value = Mux(enabled, value, 0) + if shift > 0: + value = Cat(Repl(C(0, 1), shift), value) + else: + assert shift == 0 + m.d.comb += term.eq(value) + + for a_index in range(8): + for b_index in range(8): + term_enabled: Value = C(True, 1) + min_index = min(a_index, b_index) + max_index = max(a_index, b_index) + for i in range(min_index, max_index): + term_enabled &= ~self._part_byte(i) + add_term(products[a_index][b_index], + 8 * (a_index + b_index), + term_enabled) + + for i in range(8): + a_signed = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH + b_signed = (self.part_ops[i] == OP_MUL_LOW) \ + | (self.part_ops[i] == OP_MUL_SIGNED_HIGH) + m.d.comb += self._a_signed[i].eq(a_signed) + m.d.comb += self._b_signed[i].eq(b_signed) + + # it's fine to bitwise-or these together since they are never enabled + # at the same time + add_term(self._not_a_term_8 | self._not_a_term_16 + | self._not_a_term_32 | self._not_a_term_64) + add_term(self._neg_lsb_a_term_8 | self._neg_lsb_a_term_16 + | self._neg_lsb_a_term_32 | self._neg_lsb_a_term_64) + add_term(self._not_b_term_8 | self._not_b_term_16 + | self._not_b_term_32 | self._not_b_term_64) + add_term(self._neg_lsb_b_term_8 | self._neg_lsb_b_term_16 + | self._neg_lsb_b_term_32 | self._neg_lsb_b_term_64) + + for not_a_term, \ + neg_lsb_a_term, \ + not_b_term, \ + neg_lsb_b_term, \ + parts in [ + (self._not_a_term_8, + self._neg_lsb_a_term_8, + self._not_b_term_8, + self._neg_lsb_b_term_8, + self._part_8), + (self._not_a_term_16, + self._neg_lsb_a_term_16, + self._not_b_term_16, + self._neg_lsb_b_term_16, + self._part_16), + (self._not_a_term_32, + self._neg_lsb_a_term_32, + self._not_b_term_32, + self._neg_lsb_b_term_32, + self._part_32), + (self._not_a_term_64, + self._neg_lsb_a_term_64, + self._not_b_term_64, + self._neg_lsb_b_term_64, + self._part_64), + ]: + byte_width = 8 // len(parts) + bit_width = 8 * byte_width + for i in range(len(parts)): + b_enabled = parts[i] & self.a[(i + 1) * bit_width - 1] \ + & self._a_signed[i * byte_width] + a_enabled = parts[i] & self.b[(i + 1) * bit_width - 1] \ + & self._b_signed[i * byte_width] + + # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the + # negation operation is split into a bitwise not and a +1. + # likewise for 16, 32, and 64-bit values. + m.d.comb += [ + not_a_term.part(bit_width * 2 * i, bit_width * 2) + .eq(Mux(a_enabled, + Cat(Repl(0, bit_width), + ~self.a.part(bit_width * i, bit_width)), + 0)), + + neg_lsb_a_term.part(bit_width * 2 * i, bit_width * 2) + .eq(Cat(Repl(0, bit_width), a_enabled)), + + not_b_term.part(bit_width * 2 * i, bit_width * 2) + .eq(Mux(b_enabled, + Cat(Repl(0, bit_width), + ~self.b.part(bit_width * i, bit_width)), + 0)), + + neg_lsb_b_term.part(bit_width * 2 * i, bit_width * 2) + .eq(Cat(Repl(0, bit_width), b_enabled))] + + expanded_part_pts = PartitionPoints() + for i, v in self.part_pts.items(): + signal = Signal(name=f"expanded_part_pts_{i*2}") + expanded_part_pts[i * 2] = signal + m.d.comb += signal.eq(v) + + add_reduce = AddReduce(terms, + 128, + self.register_levels, + expanded_part_pts) + m.submodules.add_reduce = add_reduce + m.d.comb += self._intermediate_output.eq(add_reduce.output) + m.d.comb += self._output_64.eq( + Mux(self._delayed_part_ops[-1][0] == OP_MUL_LOW, + self._intermediate_output.part(0, 64), + self._intermediate_output.part(64, 64))) + for i in range(2): + m.d.comb += self._output_32.part(i * 32, 32).eq( + Mux(self._delayed_part_ops[-1][4 * i] == OP_MUL_LOW, + self._intermediate_output.part(i * 64, 32), + self._intermediate_output.part(i * 64 + 32, 32))) + for i in range(4): + m.d.comb += self._output_16.part(i * 16, 16).eq( + Mux(self._delayed_part_ops[-1][2 * i] == OP_MUL_LOW, + self._intermediate_output.part(i * 32, 16), + self._intermediate_output.part(i * 32 + 16, 16))) + for i in range(8): + m.d.comb += self._output_8.part(i * 8, 8).eq( + Mux(self._delayed_part_ops[-1][i] == OP_MUL_LOW, + self._intermediate_output.part(i * 16, 8), + self._intermediate_output.part(i * 16 + 8, 8))) + for i in range(8): + m.d.comb += self.output.part(i * 8, 8).eq( + Mux(self._delayed_part_8[-1][i] + | self._delayed_part_16[-1][i // 2], + Mux(self._delayed_part_8[-1][i], + self._output_8.part(i * 8, 8), + self._output_16.part(i * 8, 8)), + Mux(self._delayed_part_32[-1][i // 4], + self._output_32.part(i * 8, 8), + self._output_64.part(i * 8, 8)))) + return m + + +if __name__ == "__main__": + m = Mul8_16_32_64() + main(m, ports=[m.a, + m.b, + m._intermediate_output, + m.output, + *m.part_ops, + *m.part_pts.values()]) diff --git a/src/ieee754/part_mul_add/test/__init__.py b/src/ieee754/part_mul_add/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ieee754/part_mul_add/test/test_multiply.py b/src/ieee754/part_mul_add/test/test_multiply.py new file mode 100644 index 00000000..ec833073 --- /dev/null +++ b/src/ieee754/part_mul_add/test/test_multiply.py @@ -0,0 +1,770 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-2.1-or-later +# See Notices.txt for copyright information + +from src.multiply import PartitionPoints, PartitionedAdder, AddReduce, \ + Mul8_16_32_64, OP_MUL_LOW, OP_MUL_SIGNED_HIGH, \ + OP_MUL_SIGNED_UNSIGNED_HIGH, OP_MUL_UNSIGNED_HIGH +from nmigen import Signal, Module +from nmigen.back.pysim import Simulator, Delay, Tick, Passive +from nmigen.hdl.ast import Assign, Value +from typing import Any, Generator, List, Union, Optional, Tuple, Iterable +import unittest +from hashlib import sha256 +import enum +import pdb + + +def create_simulator(module: Any, + traces: List[Signal], + test_name: str) -> Simulator: + return Simulator(module, + vcd_file=open(test_name + ".vcd", "w"), + gtkw_file=open(test_name + ".gtkw", "w"), + traces=traces) + + +AsyncProcessCommand = Union[Delay, Tick, Passive, Assign, Value] +ProcessCommand = Optional[AsyncProcessCommand] +AsyncProcessGenerator = Generator[AsyncProcessCommand, Union[int, None], None] +ProcessGenerator = Generator[ProcessCommand, Union[int, None], None] + + +class TestPartitionPoints(unittest.TestCase): + def test(self) -> None: + module = Module() + width = 16 + mask = Signal(width) + partition_point_10 = Signal() + partition_points = PartitionPoints({1: True, + 5: False, + 10: partition_point_10}) + module.d.comb += mask.eq(partition_points.as_mask(width)) + with create_simulator(module, + [mask, partition_point_10], + "partition_points") as sim: + def async_process() -> AsyncProcessGenerator: + self.assertEqual((yield partition_points[1]), True) + self.assertEqual((yield partition_points[5]), False) + yield partition_point_10.eq(0) + yield Delay(1e-6) + self.assertEqual((yield mask), 0xFFFD) + yield partition_point_10.eq(1) + yield Delay(1e-6) + self.assertEqual((yield mask), 0xFBFD) + + sim.add_process(async_process) + sim.run() + + +class TestPartitionedAdder(unittest.TestCase): + def test(self) -> None: + width = 16 + partition_nibbles = Signal() + partition_bytes = Signal() + module = PartitionedAdder(width, + {0x4: partition_nibbles, + 0x8: partition_bytes | partition_nibbles, + 0xC: partition_nibbles}) + with create_simulator(module, + [partition_nibbles, + partition_bytes, + module.a, + module.b, + module.output], + "partitioned_adder") as sim: + def async_process() -> AsyncProcessGenerator: + def test_add(msg_prefix: str, + *mask_list: Tuple[int, ...]) -> Any: + for a, b in [(0x0000, 0x0000), + (0x1234, 0x1234), + (0xABCD, 0xABCD), + (0xFFFF, 0x0000), + (0x0000, 0x0000), + (0xFFFF, 0xFFFF), + (0x0000, 0xFFFF)]: + yield module.a.eq(a) + yield module.b.eq(b) + yield Delay(1e-6) + y = 0 + for mask in mask_list: + y |= mask & ((a & mask) + (b & mask)) + output = (yield module.output) + msg = f"{msg_prefix}: 0x{a:X} + 0x{b:X}" + \ + f" => 0x{y:X} != 0x{output:X}" + self.assertEqual(y, output, msg) + yield partition_nibbles.eq(0) + yield partition_bytes.eq(0) + yield from test_add("16-bit", 0xFFFF) + yield partition_nibbles.eq(0) + yield partition_bytes.eq(1) + yield from test_add("8-bit", 0xFF00, 0x00FF) + yield partition_nibbles.eq(1) + yield partition_bytes.eq(0) + yield from test_add("4-bit", 0xF000, 0x0F00, 0x00F0, 0x000F) + + sim.add_process(async_process) + sim.run() + + +class GenOrCheck(enum.Enum): + Generate = enum.auto() + Check = enum.auto() + + +class TestAddReduce(unittest.TestCase): + def calculate_input_values(self, + input_count: int, + key: int, + extra_keys: List[int] = [] + ) -> (List[int], List[str]): + input_values = [] + input_values_str = [] + for i in range(input_count): + if key == 0: + value = 0 + elif key == 1: + value = 0xFFFF + elif key == 2: + value = 0x0111 + else: + hash_input = f"{input_count} {i} {key} {extra_keys}" + hash = sha256(hash_input.encode()).digest() + value = int.from_bytes(hash, byteorder="little") + value &= 0xFFFF + input_values.append(value) + input_values_str.append(f"0x{value:04X}") + return input_values, input_values_str + + def subtest_value(self, + inputs: List[Signal], + module: AddReduce, + mask_list: List[int], + gen_or_check: GenOrCheck, + values: List[int]) -> AsyncProcessGenerator: + if gen_or_check == GenOrCheck.Generate: + for i, v in zip(inputs, values): + yield i.eq(v) + yield Delay(1e-6) + y = 0 + for mask in mask_list: + v = 0 + for value in values: + v += value & mask + y |= mask & v + output = (yield module.output) + if gen_or_check == GenOrCheck.Check: + self.assertEqual(y, output, f"0x{y:X} != 0x{output:X}") + yield Tick() + + def subtest_key(self, + input_count: int, + inputs: List[Signal], + module: AddReduce, + key: int, + mask_list: List[int], + gen_or_check: GenOrCheck) -> AsyncProcessGenerator: + values, values_str = self.calculate_input_values(input_count, key) + if gen_or_check == GenOrCheck.Check: + with self.subTest(inputs=values_str): + yield from self.subtest_value(inputs, + module, + mask_list, + gen_or_check, + values) + else: + yield from self.subtest_value(inputs, + module, + mask_list, + gen_or_check, + values) + + def subtest_run_sim(self, + input_count: int, + sim: Simulator, + partition_4: Signal, + partition_8: Signal, + inputs: List[Signal], + module: AddReduce, + delay_cycles: int) -> None: + def generic_process(gen_or_check: GenOrCheck) -> AsyncProcessGenerator: + for partition_4_value, partition_8_value, mask_list in [ + (0, 0, [0xFFFF]), + (0, 1, [0xFF00, 0x00FF]), + (1, 0, [0xFFF0, 0x000F]), + (1, 1, [0xFF00, 0x00F0, 0x000F])]: + key_count = 8 + if gen_or_check == GenOrCheck.Check: + with self.subTest(partition_4=partition_4_value, + partition_8=partition_8_value): + for key in range(key_count): + with self.subTest(key=key): + yield from self.subtest_key(input_count, + inputs, + module, + key, + mask_list, + gen_or_check) + else: + if gen_or_check == GenOrCheck.Generate: + yield partition_4.eq(partition_4_value) + yield partition_8.eq(partition_8_value) + for key in range(key_count): + yield from self.subtest_key(input_count, + inputs, + module, + key, + mask_list, + gen_or_check) + + def generate_process() -> AsyncProcessGenerator: + yield from generic_process(GenOrCheck.Generate) + + def check_process() -> AsyncProcessGenerator: + if delay_cycles != 0: + for _ in range(delay_cycles): + yield Tick() + yield from generic_process(GenOrCheck.Check) + + sim.add_clock(2e-6) + sim.add_process(generate_process) + sim.add_process(check_process) + sim.run() + + def subtest_file(self, + input_count: int, + register_levels: List[int]) -> None: + max_level = AddReduce.get_max_level(input_count) + for level in register_levels: + if level > max_level: + return + partition_4 = Signal() + partition_8 = Signal() + partition_points = PartitionPoints() + partition_points[4] = partition_4 + partition_points[8] = partition_8 + width = 16 + inputs = [Signal(width, name=f"input_{i}") + for i in range(input_count)] + module = AddReduce(inputs, + width, + register_levels, + partition_points) + file_name = "add_reduce" + if len(register_levels) != 0: + file_name += f"-{'_'.join(map(repr, register_levels))}" + file_name += f"-{input_count:02d}" + with create_simulator(module, + [partition_4, + partition_8, + *inputs, + module.output], + file_name) as sim: + self.subtest_run_sim(input_count, + sim, + partition_4, + partition_8, + inputs, + module, + len(register_levels)) + + def subtest_register_levels(self, register_levels: List[int]) -> None: + for input_count in range(0, 16): + with self.subTest(input_count=input_count, + register_levels=repr(register_levels)): + self.subtest_file(input_count, register_levels) + + def test_empty(self) -> None: + self.subtest_register_levels([]) + + def test_0(self) -> None: + self.subtest_register_levels([0]) + + def test_1(self) -> None: + self.subtest_register_levels([1]) + + def test_2(self) -> None: + self.subtest_register_levels([2]) + + def test_3(self) -> None: + self.subtest_register_levels([3]) + + def test_4(self) -> None: + self.subtest_register_levels([4]) + + def test_5(self) -> None: + self.subtest_register_levels([5]) + + def test_0(self) -> None: + self.subtest_register_levels([0]) + + def test_0_1(self) -> None: + self.subtest_register_levels([0, 1]) + + def test_0_1_2(self) -> None: + self.subtest_register_levels([0, 1, 2]) + + def test_0_1_2_3(self) -> None: + self.subtest_register_levels([0, 1, 2, 3]) + + def test_0_1_2_3_4(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4]) + + def test_0_1_2_3_4_5(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4, 5]) + + def test_0_2(self) -> None: + self.subtest_register_levels([0, 2]) + + def test_0_3(self) -> None: + self.subtest_register_levels([0, 3]) + + def test_0_4(self) -> None: + self.subtest_register_levels([0, 4]) + + def test_0_5(self) -> None: + self.subtest_register_levels([0, 5]) + + +class SIMDMulLane: + def __init__(self, + a_signed: bool, + b_signed: bool, + bit_width: int, + high_half: bool): + self.a_signed = a_signed + self.b_signed = b_signed + self.bit_width = bit_width + self.high_half = high_half + + def __repr__(self): + return f"SIMDMulLane({self.a_signed}, {self.b_signed}, " +\ + f"{self.bit_width}, {self.high_half})" + + +class TestMul8_16_32_64(unittest.TestCase): + @staticmethod + def simd_mul(a: int, b: int, lanes: List[SIMDMulLane]) -> Tuple[int, int]: + output = 0 + intermediate_output = 0 + shift = 0 + for lane in lanes: + a_signed = lane.a_signed or not lane.high_half + b_signed = lane.b_signed or not lane.high_half + mask = (1 << lane.bit_width) - 1 + sign_bit = 1 << (lane.bit_width - 1) + a_part = (a >> shift) & mask + if a_signed and (a_part & sign_bit) != 0: + a_part -= 1 << lane.bit_width + b_part = (b >> shift) & mask + if b_signed and (b_part & sign_bit) != 0: + b_part -= 1 << lane.bit_width + value = a_part * b_part + value &= (1 << (lane.bit_width * 2)) - 1 + intermediate_output |= value << (shift * 2) + if lane.high_half: + value >>= lane.bit_width + value &= mask + output |= value << shift + shift += lane.bit_width + return output, intermediate_output + + @staticmethod + def get_test_cases(lanes: List[SIMDMulLane], + keys: Iterable[int]) -> Iterable[Tuple[int, int]]: + mask = (1 << 64) - 1 + for i in range(8): + hash_input = f"{i} {lanes} {list(keys)}" + hash = sha256(hash_input.encode()).digest() + value = int.from_bytes(hash, byteorder="little") + yield (value & mask, value >> 64) + a = 0 + b = 0 + shift = 0 + for lane in lanes: + a |= 1 << (shift + lane.bit_width - 1) + b |= 1 << (shift + lane.bit_width - 1) + shift += lane.bit_width + yield a, b + + def test_simd_mul_lane(self): + self.assertEqual(f"{SIMDMulLane(True, True, 8, False)}", + "SIMDMulLane(True, True, 8, False)") + + def test_simd_mul(self): + lanes = [SIMDMulLane(True, + True, + 8, + True), + SIMDMulLane(False, + False, + 8, + True), + SIMDMulLane(True, + True, + 16, + False), + SIMDMulLane(True, + False, + 32, + True)] + a = 0x0123456789ABCDEF + b = 0xFEDCBA9876543210 + output = 0x0121FA00FE1C28FE + intermediate_output = 0x0121FA0023E20B28C94DFE1C280AFEF0 + self.assertEqual(self.simd_mul(a, b, lanes), + (output, intermediate_output)) + a = 0x8123456789ABCDEF + b = 0xFEDCBA9876543210 + output = 0x81B39CB4FE1C28FE + intermediate_output = 0x81B39CB423E20B28C94DFE1C280AFEF0 + self.assertEqual(self.simd_mul(a, b, lanes), + (output, intermediate_output)) + + def test_signed_mul_from_unsigned(self): + for i in range(0, 0x10): + for j in range(0, 0x10): + si = i if i & 8 else i - 0x10 # signed i + sj = j if j & 8 else j - 0x10 # signed j + mulu = i * j + mulsu = si * j + mul = si * sj + with self.subTest(i=i, j=j, si=si, sj=sj, + mulu=mulu, mulsu=mulsu, mul=mul): + mulsu2 = mulu + if si < 0: + mulsu2 += ~j << 4 + mulsu2 += 1 << 4 + self.assertEqual(mulsu & 0xFF, mulsu2 & 0xFF) + mul2 = mulsu2 + if sj < 0: + mul2 += ~i << 4 + mul2 += 1 << 4 + self.assertEqual(mul & 0xFF, mul2 & 0xFF) + + def subtest_value(self, + a: int, + b: int, + module: Mul8_16_32_64, + lanes: List[SIMDMulLane], + gen_or_check: GenOrCheck) -> AsyncProcessGenerator: + if gen_or_check == GenOrCheck.Generate: + yield module.a.eq(a) + yield module.b.eq(b) + output2, intermediate_output2 = self.simd_mul(a, b, lanes) + yield Delay(1e-6) + if gen_or_check == GenOrCheck.Check: + intermediate_output = (yield module._intermediate_output) + self.assertEqual(intermediate_output, + intermediate_output2, + f"0x{intermediate_output:X} " + + f"!= 0x{intermediate_output2:X}") + output = (yield module.output) + self.assertEqual(output, output2, f"0x{output:X} != 0x{output2:X}") + yield Tick() + + def subtest_lanes_2(self, + lanes: List[SIMDMulLane], + module: Mul8_16_32_64, + gen_or_check: GenOrCheck) -> AsyncProcessGenerator: + bit_index = 8 + part_index = 0 + for lane in lanes: + if lane.high_half: + if lane.a_signed: + if lane.b_signed: + op = OP_MUL_SIGNED_HIGH + else: + op = OP_MUL_SIGNED_UNSIGNED_HIGH + else: + self.assertFalse(lane.b_signed, + "unsigned * signed not supported") + op = OP_MUL_UNSIGNED_HIGH + else: + op = OP_MUL_LOW + self.assertEqual(lane.bit_width % 8, 0) + for i in range(lane.bit_width // 8): + if gen_or_check == GenOrCheck.Generate: + yield module.part_ops[part_index].eq(op) + part_index += 1 + for i in range(lane.bit_width // 8 - 1): + if gen_or_check == GenOrCheck.Generate: + yield module.part_pts[bit_index].eq(0) + bit_index += 8 + if bit_index < 64 and gen_or_check == GenOrCheck.Generate: + yield module.part_pts[bit_index].eq(1) + bit_index += 8 + self.assertEqual(part_index, 8) + for a, b in self.get_test_cases(lanes, ()): + if gen_or_check == GenOrCheck.Check: + with self.subTest(a=f"{a:X}", b=f"{b:X}"): + yield from self.subtest_value(a, b, module, lanes, gen_or_check) + else: + yield from self.subtest_value(a, b, module, lanes, gen_or_check) + + def subtest_lanes(self, + lanes: List[SIMDMulLane], + module: Mul8_16_32_64, + gen_or_check: GenOrCheck) -> AsyncProcessGenerator: + if gen_or_check == GenOrCheck.Check: + with self.subTest(lanes=repr(lanes)): + yield from self.subtest_lanes_2(lanes, module, gen_or_check) + else: + yield from self.subtest_lanes_2(lanes, module, gen_or_check) + + def subtest_file(self, + register_levels: List[int]) -> None: + module = Mul8_16_32_64(register_levels) + file_name = "mul8_16_32_64" + if len(register_levels) != 0: + file_name += f"-{'_'.join(map(repr, register_levels))}" + ports = [module.a, + module.b, + module._intermediate_output, + module.output] + ports.extend(module.part_ops) + ports.extend(module.part_pts.values()) + for signals in module._delayed_part_ops: + ports.extend(signals) + ports.extend(module._part_8) + ports.extend(module._part_16) + ports.extend(module._part_32) + ports.extend(module._part_64) + for signals in module._delayed_part_8: + ports.extend(signals) + for signals in module._delayed_part_16: + ports.extend(signals) + for signals in module._delayed_part_32: + ports.extend(signals) + for signals in module._delayed_part_64: + ports.extend(signals) + ports += [module._output_64, + module._output_32, + module._output_16, + module._output_8] + ports.extend(module._a_signed) + ports.extend(module._b_signed) + ports += [module._not_a_term_8, + module._neg_lsb_a_term_8, + module._not_b_term_8, + module._neg_lsb_b_term_8, + module._not_a_term_16, + module._neg_lsb_a_term_16, + module._not_b_term_16, + module._neg_lsb_b_term_16, + module._not_a_term_32, + module._neg_lsb_a_term_32, + module._not_b_term_32, + module._neg_lsb_b_term_32, + module._not_a_term_64, + module._neg_lsb_a_term_64, + module._not_b_term_64, + module._neg_lsb_b_term_64] + with create_simulator(module, ports, file_name) as sim: + def process(gen_or_check: GenOrCheck) -> AsyncProcessGenerator: + for a_signed in False, True: + for b_signed in False, True: + if not a_signed and b_signed: + continue + for high_half in False, True: + if not high_half and not (a_signed and b_signed): + continue + yield from self.subtest_lanes( + [SIMDMulLane(a_signed, + b_signed, + 64, + high_half)], + module, + gen_or_check) + yield from self.subtest_lanes( + [SIMDMulLane(a_signed, + b_signed, + 32, + high_half)] * 2, + module, + gen_or_check) + yield from self.subtest_lanes( + [SIMDMulLane(a_signed, + b_signed, + 16, + high_half)] * 4, + module, + gen_or_check) + yield from self.subtest_lanes( + [SIMDMulLane(a_signed, + b_signed, + 8, + high_half)] * 8, + module, + gen_or_check) + yield from self.subtest_lanes([SIMDMulLane(False, + False, + 32, + True), + SIMDMulLane(False, + False, + 16, + True), + SIMDMulLane(False, + False, + 8, + True), + SIMDMulLane(False, + False, + 8, + True)], + module, + gen_or_check) + yield from self.subtest_lanes([SIMDMulLane(True, + False, + 32, + True), + SIMDMulLane(True, + True, + 16, + False), + SIMDMulLane(True, + True, + 8, + True), + SIMDMulLane(False, + False, + 8, + True)], + module, + gen_or_check) + yield from self.subtest_lanes([SIMDMulLane(True, + True, + 8, + True), + SIMDMulLane(False, + False, + 8, + True), + SIMDMulLane(True, + True, + 16, + False), + SIMDMulLane(True, + False, + 32, + True)], + module, + gen_or_check) + + def generate_process() -> AsyncProcessGenerator: + yield from process(GenOrCheck.Generate) + + def check_process() -> AsyncProcessGenerator: + if len(register_levels) != 0: + for _ in register_levels: + yield Tick() + yield from process(GenOrCheck.Check) + + sim.add_clock(2e-6) + sim.add_process(generate_process) + sim.add_process(check_process) + sim.run() + + def subtest_register_levels(self, register_levels: List[int]) -> None: + with self.subTest(register_levels=repr(register_levels)): + self.subtest_file(register_levels) + + def test_empty(self) -> None: + self.subtest_register_levels([]) + + def test_0(self) -> None: + self.subtest_register_levels([0]) + + def test_1(self) -> None: + self.subtest_register_levels([1]) + + def test_2(self) -> None: + self.subtest_register_levels([2]) + + def test_3(self) -> None: + self.subtest_register_levels([3]) + + def test_4(self) -> None: + self.subtest_register_levels([4]) + + def test_5(self) -> None: + self.subtest_register_levels([5]) + + def test_6(self) -> None: + self.subtest_register_levels([6]) + + def test_7(self) -> None: + self.subtest_register_levels([7]) + + def test_8(self) -> None: + self.subtest_register_levels([8]) + + def test_9(self) -> None: + self.subtest_register_levels([9]) + + def test_10(self) -> None: + self.subtest_register_levels([10]) + + def test_0(self) -> None: + self.subtest_register_levels([0]) + + def test_0_1(self) -> None: + self.subtest_register_levels([0, 1]) + + def test_0_1_2(self) -> None: + self.subtest_register_levels([0, 1, 2]) + + def test_0_1_2_3(self) -> None: + self.subtest_register_levels([0, 1, 2, 3]) + + def test_0_1_2_3_4(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4]) + + def test_0_1_2_3_4_5(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4, 5]) + + def test_0_1_2_3_4_5_6(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6]) + + def test_0_1_2_3_4_5_6_7(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7]) + + def test_0_1_2_3_4_5_6_7_8(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7, 8]) + + def test_0_1_2_3_4_5_6_7_8_9(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + + def test_0_1_2_3_4_5_6_7_8_9_10(self) -> None: + self.subtest_register_levels([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + + def test_0_2(self) -> None: + self.subtest_register_levels([0, 2]) + + def test_0_3(self) -> None: + self.subtest_register_levels([0, 3]) + + def test_0_4(self) -> None: + self.subtest_register_levels([0, 4]) + + def test_0_5(self) -> None: + self.subtest_register_levels([0, 5]) + + def test_0_6(self) -> None: + self.subtest_register_levels([0, 6]) + + def test_0_7(self) -> None: + self.subtest_register_levels([0, 7]) + + def test_0_8(self) -> None: + self.subtest_register_levels([0, 8]) + + def test_0_9(self) -> None: + self.subtest_register_levels([0, 9]) + + def test_0_10(self) -> None: + self.subtest_register_levels([0, 10]) + +if __name__ == '__main__': + unittest.main() -- 2.30.2