switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 6f770842f8aa048c6afc2b8e54b5013f2e91e985..b132de56f0ecc90316b13fd1ee1f8da5b7ea4075 100644 (file)
@@ -8,350 +8,80 @@ from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
 from nmigen.cli import main
 from functools import reduce
 from operator import or_
+from ieee754.pipeline import PipelineSpec
+from nmutil.pipemodbase import PipeModBase
 
 
-
-class PartitionPoints(dict):
-    """Partition points and corresponding ``Value``s.
-
-    The points at where an ALU is partitioned along with ``Value``s that
-    specify if the corresponding partition points are enabled.
-
-    For example: ``{1: True, 5: True, 10: True}`` with
-    ``width == 16`` specifies that the ALU is split into 4 sections:
-    * bits 0 <= ``i`` < 1
-    * bits 1 <= ``i`` < 5
-    * bits 5 <= ``i`` < 10
-    * bits 10 <= ``i`` < 16
-
-    If the partition_points were instead ``{1: True, 5: a, 10: True}``
-    where ``a`` is a 1-bit ``Signal``:
-    * If ``a`` is asserted:
-        * bits 0 <= ``i`` < 1
-        * bits 1 <= ``i`` < 5
-        * bits 5 <= ``i`` < 10
-        * bits 10 <= ``i`` < 16
-    * Otherwise
-        * bits 0 <= ``i`` < 1
-        * bits 1 <= ``i`` < 10
-        * bits 10 <= ``i`` < 16
-    """
-
-    def __init__(self, partition_points=None):
-        """Create a new ``PartitionPoints``.
-
-        :param partition_points: the input partition points to values mapping.
-        """
-        super().__init__()
-        if partition_points is not None:
-            for point, enabled in partition_points.items():
-                if not isinstance(point, int):
-                    raise TypeError("point must be a non-negative integer")
-                if point < 0:
-                    raise ValueError("point must be a non-negative integer")
-                self[point] = Value.wrap(enabled)
-
-    def like(self, name=None, src_loc_at=0, mul=1):
-        """Create a new ``PartitionPoints`` with ``Signal``s for all values.
-
-        :param name: the base name for the new ``Signal``s.
-        :param mul: a multiplication factor on the indices
-        """
-        if name is None:
-            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
-        retval = PartitionPoints()
-        for point, enabled in self.items():
-            point *= mul
-            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
-        return retval
-
-    def eq(self, rhs):
-        """Assign ``PartitionPoints`` using ``Signal.eq``."""
-        if set(self.keys()) != set(rhs.keys()):
-            raise ValueError("incompatible point set")
-        for point, enabled in self.items():
-            yield enabled.eq(rhs[point])
-
-    def as_mask(self, width):
-        """Create a bit-mask from `self`.
-
-        Each bit in the returned mask is clear only if the partition point at
-        the same bit-index is enabled.
-
-        :param width: the bit width of the resulting mask
-        """
-        bits = []
-        for i in range(width):
-            if i in self:
-                bits.append(~self[i])
-            else:
-                bits.append(True)
-        return Cat(*bits)
-
-    def get_max_partition_count(self, width):
-        """Get the maximum number of partitions.
-
-        Gets the number of partitions when all partition points are enabled.
-        """
-        retval = 1
-        for point in self.keys():
-            if point < width:
-                retval += 1
-        return retval
-
-    def fits_in_width(self, width):
-        """Check if all partition points are smaller than `width`."""
-        for point in self.keys():
-            if point >= width:
-                return False
-        return True
-
-    def part_byte(self, index, mfactor=1): # mfactor used for "expanding"
-        if index == -1 or index == 7:
-            return C(True, 1)
-        assert index >= 0 and index < 8
-        return self[(index * 8 + 8)*mfactor]
-
-
-class FullAdder(Elaboratable):
-    """Full Adder.
-
-    :attribute in0: the first input
-    :attribute in1: the second input
-    :attribute in2: the third input
-    :attribute sum: the sum output
-    :attribute carry: the carry output
-
-    Rather than do individual full adders (and have an array of them,
-    which would be very slow to simulate), this module can specify the
-    bit width of the inputs and outputs: in effect it performs multiple
-    Full 3-2 Add operations "in parallel".
-    """
-
-    def __init__(self, width):
-        """Create a ``FullAdder``.
-
-        :param width: the bit width of the input and output
-        """
-        self.in0 = Signal(width, reset_less=True)
-        self.in1 = Signal(width, reset_less=True)
-        self.in2 = Signal(width, reset_less=True)
-        self.sum = Signal(width, reset_less=True)
-        self.carry = Signal(width, reset_less=True)
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
-        m.d.comb += self.carry.eq((self.in0 & self.in1)
-                                  | (self.in1 & self.in2)
-                                  | (self.in2 & self.in0))
-        return m
-
-
-class MaskedFullAdder(Elaboratable):
-    """Masked Full Adder.
-
-    :attribute mask: the carry partition mask
-    :attribute in0: the first input
-    :attribute in1: the second input
-    :attribute in2: the third input
-    :attribute sum: the sum output
-    :attribute mcarry: the masked carry output
-
-    FullAdders are always used with a "mask" on the output.  To keep
-    the graphviz "clean", this class performs the masking here rather
-    than inside a large for-loop.
-
-    See the following discussion as to why this is no longer derived
-    from FullAdder.  Each carry is shifted here *before* being ANDed
-    with the mask, so that an AOI cell may be used (which is more
-    gate-efficient)
-    https://en.wikipedia.org/wiki/AND-OR-Invert
-    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
-    """
-
-    def __init__(self, width):
-        """Create a ``MaskedFullAdder``.
-
-        :param width: the bit width of the input and output
-        """
-        self.width = width
-        self.mask = Signal(width, reset_less=True)
-        self.mcarry = Signal(width, reset_less=True)
-        self.in0 = Signal(width, reset_less=True)
-        self.in1 = Signal(width, reset_less=True)
-        self.in2 = Signal(width, reset_less=True)
-        self.sum = Signal(width, reset_less=True)
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        s1 = Signal(self.width, reset_less=True)
-        s2 = Signal(self.width, reset_less=True)
-        s3 = Signal(self.width, reset_less=True)
-        c1 = Signal(self.width, reset_less=True)
-        c2 = Signal(self.width, reset_less=True)
-        c3 = Signal(self.width, reset_less=True)
-        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
-        m.d.comb += s1.eq(Cat(0, self.in0))
-        m.d.comb += s2.eq(Cat(0, self.in1))
-        m.d.comb += s3.eq(Cat(0, self.in2))
-        m.d.comb += c1.eq(s1 & s2 & self.mask)
-        m.d.comb += c2.eq(s2 & s3 & self.mask)
-        m.d.comb += c3.eq(s3 & s1 & self.mask)
-        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
-        return m
-
-
-class PartitionedAdder(Elaboratable):
-    """Partitioned Adder.
-
-    Performs the final add.  The partition points are included in the
-    actual add (in one of the operands only), which causes a carry over
-    to the next bit.  Then the final output *removes* the extra bits from
-    the result.
-
-    partition: .... P... P... P... P... (32 bits)
-    a        : .... .... .... .... .... (32 bits)
-    b        : .... .... .... .... .... (32 bits)
-    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
-    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
-    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
-    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
-
-    :attribute width: the bit width of the input and output. Read-only.
-    :attribute a: the first input to the adder
-    :attribute b: the second input to the adder
-    :attribute output: the sum output
-    :attribute partition_points: the input partition points. Modification not
-        supported, except for by ``Signal.eq``.
-    """
-
-    def __init__(self, width, partition_points):
-        """Create a ``PartitionedAdder``.
-
-        :param width: the bit width of the input and output
-        :param partition_points: the input partition points
-        """
-        self.width = width
-        self.a = Signal(width, reset_less=True)
-        self.b = Signal(width, reset_less=True)
-        self.output = Signal(width, reset_less=True)
-        self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(width):
-            raise ValueError("partition_points doesn't fit in width")
-        expanded_width = 0
-        for i in range(self.width):
-            if i in self.partition_points:
-                expanded_width += 1
-            expanded_width += 1
-        self._expanded_width = expanded_width
-
-    def elaborate(self, platform):
-        """Elaborate this module."""
-        m = Module()
-        expanded_a = Signal(self._expanded_width, reset_less=True)
-        expanded_b = Signal(self._expanded_width, reset_less=True)
-        expanded_o = Signal(self._expanded_width, reset_less=True)
-
-        expanded_index = 0
-        # store bits in a list, use Cat later.  graphviz is much cleaner
-        al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
-
-        # partition points are "breaks" (extra zeros or 1s) in what would
-        # otherwise be a massive long add.  when the "break" points are 0,
-        # whatever is in it (in the output) is discarded.  however when
-        # there is a "1", it causes a roll-over carry to the *next* bit.
-        # we still ignore the "break" bit in the [intermediate] output,
-        # however by that time we've got the effect that we wanted: the
-        # carry has been carried *over* the break point.
-
-        for i in range(self.width):
-            if i in self.partition_points:
-                # add extra bit set to 0 + 0 for enabled partition points
-                # and 1 + 0 for disabled partition points
-                ea.append(expanded_a[expanded_index])
-                al.append(~self.partition_points[i]) # add extra bit in a
-                eb.append(expanded_b[expanded_index])
-                bl.append(C(0)) # yes, add a zero
-                expanded_index += 1 # skip the extra point.  NOT in the output
-            ea.append(expanded_a[expanded_index])
-            eb.append(expanded_b[expanded_index])
-            eo.append(expanded_o[expanded_index])
-            al.append(self.a[i])
-            bl.append(self.b[i])
-            ol.append(self.output[i])
-            expanded_index += 1
-
-        # combine above using Cat
-        m.d.comb += Cat(*ea).eq(Cat(*al))
-        m.d.comb += Cat(*eb).eq(Cat(*bl))
-        m.d.comb += Cat(*ol).eq(Cat(*eo))
-
-        # use only one addition to take advantage of look-ahead carry and
-        # special hardware on FPGAs
-        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
-        return m
+from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder
 
 
 FULL_ADDER_INPUT_COUNT = 3
 
 
 
 FULL_ADDER_INPUT_COUNT = 3
 
+
 class AddReduceData:
 
 class AddReduceData:
 
-    def __init__(self, ppoints, n_inputs, output_width, n_parts):
+    def __init__(self, part_pts, n_inputs, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
-                          for i in range(n_parts)]
-        self.inputs = [Signal(output_width, name=f"inputs_{i}",
-                              reset_less=True)
-                        for i in range(n_inputs)]
-        self.reg_partition_points = ppoints.like()
-
-    def eq_from(self, reg_partition_points, inputs, part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
-               [self.inputs[i].eq(inputs[i])
-                                     for i in range(len(self.inputs))] + \
+                         for i in range(n_parts)]
+        self.terms = [Signal(output_width, name=f"terms_{i}",
+                             reset_less=True)
+                      for i in range(n_inputs)]
+        self.part_pts = part_pts.like()
+
+    def eq_from(self, part_pts, inputs, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.terms[i].eq(inputs[i])
+                for i in range(len(self.terms))] + \
                [self.part_ops[i].eq(part_ops[i])
                [self.part_ops[i].eq(part_ops[i])
-                                     for i in range(len(self.part_ops))]
+                for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.inputs, rhs.part_ops)
+        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
 
 
 class FinalReduceData:
 
 
 
 class FinalReduceData:
 
-    def __init__(self, ppoints, output_width, n_parts):
+    def __init__(self, part_pts, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
-                          for i in range(n_parts)]
+                         for i in range(n_parts)]
         self.output = Signal(output_width, reset_less=True)
         self.output = Signal(output_width, reset_less=True)
-        self.reg_partition_points = ppoints.like()
+        self.part_pts = part_pts.like()
 
 
-    def eq_from(self, reg_partition_points, output, part_ops):
-        return [self.reg_partition_points.eq(reg_partition_points)] + \
+    def eq_from(self, part_pts, output, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
                [self.output.eq(output)] + \
                [self.part_ops[i].eq(part_ops[i])
                [self.output.eq(output)] + \
                [self.part_ops[i].eq(part_ops[i])
-                                     for i in range(len(self.part_ops))]
+                for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
 
     def eq(self, rhs):
-        return self.eq_from(rhs.reg_partition_points, rhs.output, rhs.part_ops)
+        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
 
 
 
 
-class FinalAdd(Elaboratable):
+class FinalAdd(PipeModBase):
     """ Final stage of add reduce
     """
 
     """ Final stage of add reduce
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
-        self.output_width = output_width
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                 partition_step=1):
+        self.lidx = lidx
+        self.partition_step = partition_step
+        self.output_width = pspec.width * 2
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.register_levels = list(register_levels)
+        self.n_parts = pspec.n_parts
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
             raise ValueError("partition_points doesn't fit in output_width")
 
+        super().__init__(pspec, "finaladd")
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return FinalReduceData(self.partition_points,
+                               self.output_width, self.n_parts)
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
@@ -363,24 +93,25 @@ class FinalAdd(Elaboratable):
             m.d.comb += output.eq(0)
         elif self.n_inputs == 1:
             # handle single input
             m.d.comb += output.eq(0)
         elif self.n_inputs == 1:
             # handle single input
-            m.d.comb += output.eq(self.i.inputs[0])
+            m.d.comb += output.eq(self.i.terms[0])
         else:
             # base case for adding 2 inputs
             assert self.n_inputs == 2
         else:
             # base case for adding 2 inputs
             assert self.n_inputs == 2
-            adder = PartitionedAdder(output_width, self.i.reg_partition_points)
+            adder = PartitionedAdder(output_width,
+                                     self.i.part_pts, self.partition_step)
             m.submodules.final_adder = adder
             m.submodules.final_adder = adder
-            m.d.comb += adder.a.eq(self.i.inputs[0])
-            m.d.comb += adder.b.eq(self.i.inputs[1])
+            m.d.comb += adder.a.eq(self.i.terms[0])
+            m.d.comb += adder.b.eq(self.i.terms[1])
             m.d.comb += output.eq(adder.output)
 
         # create output
             m.d.comb += output.eq(adder.output)
 
         # create output
-        m.d.comb += self.o.eq_from(self.i.reg_partition_points, output,
+        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                    self.i.part_ops)
 
         return m
 
 
                                    self.i.part_ops)
 
         return m
 
 
-class AddReduceSingle(Elaboratable):
+class AddReduceSingle(PipeModBase):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -392,44 +123,46 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
+                 partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         :param output_width: bit-width of ``output``.
-        :param register_levels: List of nesting levels that should have
-            pipeline registers.
         :param partition_points: the input partition points.
         """
         :param partition_points: the input partition points.
         """
+        self.lidx = lidx
+        self.partition_step = partition_step
         self.n_inputs = n_inputs
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
-        self.i = AddReduceData(partition_points, n_inputs,
-                               output_width, n_parts)
-        self.register_levels = list(register_levels)
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
         self.partition_points = PartitionPoints(partition_points)
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
             raise ValueError("partition_points doesn't fit in output_width")
 
-        max_level = AddReduceSingle.get_max_level(n_inputs)
-        for level in self.register_levels:
-            if level > max_level:
-                raise ValueError(
-                    "not enough adder levels for specified register levels")
-
-        # this is annoying.  we have to create the modules (and terms)
-        # because we need to know what they are (in order to set up the
-        # interconnects back in AddReduce), but cannot do the m.d.comb +=
-        # etc because this is not in elaboratable.
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
-        self._intermediate_terms = []
-        self.adders = []
-        if len(self.groups) != 0:
-            self.create_next_terms()
+        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
 
 
-        self.o = AddReduceData(partition_points, len(self._intermediate_terms),
-                               output_width, n_parts)
+        super().__init__(pspec, "addreduce_%d" % lidx)
+
+    def ispec(self):
+        return AddReduceData(self.partition_points, self.n_inputs,
+                             self.output_width, self.n_parts)
+
+    def ospec(self):
+        return AddReduceData(self.partition_points, self.n_terms,
+                             self.output_width, self.n_parts)
+
+    @staticmethod
+    def calc_n_inputs(n_inputs, groups):
+        retval = len(groups)*2
+        if n_inputs % FULL_ADDER_INPUT_COUNT == 1:
+            retval += 1
+        elif n_inputs % FULL_ADDER_INPUT_COUNT == 2:
+            retval += 2
+        else:
+            assert n_inputs % FULL_ADDER_INPUT_COUNT == 0
+        return retval
 
     @staticmethod
     def get_max_level(input_count):
 
     @staticmethod
     def get_max_level(input_count):
@@ -454,68 +187,124 @@ class AddReduceSingle(Elaboratable):
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
+    def create_next_terms(self):
+        """ create next intermediate terms, for linking up in elaborate, below
+        """
+        terms = []
+        adders = []
+
+        # create full adders for this recursive level.
+        # this shrinks N terms to 2 * (N // 3) plus the remainder
+        for i in self.groups:
+            adder_i = MaskedFullAdder(self.output_width)
+            adders.append((i, adder_i))
+            # add both the sum and the masked-carry to the next level.
+            # 3 inputs have now been reduced to 2...
+            terms.append(adder_i.sum)
+            terms.append(adder_i.mcarry)
+        # handle the remaining inputs.
+        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
+            terms.append(self.i.terms[-1])
+        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
+            # Just pass the terms to the next layer, since we wouldn't gain
+            # anything by using a half adder since there would still be 2 terms
+            # and just passing the terms to the next layer saves gates.
+            terms.append(self.i.terms[-2])
+            terms.append(self.i.terms[-1])
+        else:
+            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
+
+        return terms, adders
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
+        terms, adders = self.create_next_terms()
+
         # copy the intermediate terms to the output
         # copy the intermediate terms to the output
-        for i, value in enumerate(self._intermediate_terms):
-            m.d.comb += self.o.inputs[i].eq(value)
+        for i, value in enumerate(terms):
+            m.d.comb += self.o.terms[i].eq(value)
 
         # copy reg part points and part ops to output
 
         # copy reg part points and part ops to output
-        m.d.comb += self.o.reg_partition_points.eq(self.i.reg_partition_points)
+        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
-                                     for i in range(len(self.i.part_ops))]
+                     for i in range(len(self.i.part_ops))]
 
         # set up the partition mask (for the adders)
         part_mask = Signal(self.output_width, reset_less=True)
 
 
         # set up the partition mask (for the adders)
         part_mask = Signal(self.output_width, reset_less=True)
 
-        mask = self.i.reg_partition_points.as_mask(self.output_width)
+        # get partition points as a mask
+        mask = self.i.part_pts.as_mask(self.output_width,
+                                       mul=self.partition_step)
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
         m.d.comb += part_mask.eq(mask)
 
         # add and link the intermediate term modules
-        for i, (iidx, adder_i) in enumerate(self.adders):
+        for i, (iidx, adder_i) in enumerate(adders):
             setattr(m.submodules, f"adder_{i}", adder_i)
 
             setattr(m.submodules, f"adder_{i}", adder_i)
 
-            m.d.comb += adder_i.in0.eq(self.i.inputs[iidx])
-            m.d.comb += adder_i.in1.eq(self.i.inputs[iidx + 1])
-            m.d.comb += adder_i.in2.eq(self.i.inputs[iidx + 2])
+            m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
+            m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
+            m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
             m.d.comb += adder_i.mask.eq(part_mask)
 
         return m
 
             m.d.comb += adder_i.mask.eq(part_mask)
 
         return m
 
-    def create_next_terms(self):
 
 
-        _intermediate_terms = []
+class AddReduceInternal:
+    """Iteratively Add list of numbers together.
 
 
-        def add_intermediate_term(value):
-            _intermediate_terms.append(value)
+    :attribute inputs: input ``Signal``s to be summed. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute register_levels: List of nesting levels that should have
+        pipeline registers.
+    :attribute output: output sum.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
 
 
-        # create full adders for this recursive level.
-        # this shrinks N terms to 2 * (N // 3) plus the remainder
-        for i in self.groups:
-            adder_i = MaskedFullAdder(self.output_width)
-            self.adders.append((i, adder_i))
-            # add both the sum and the masked-carry to the next level.
-            # 3 inputs have now been reduced to 2...
-            add_intermediate_term(adder_i.sum)
-            add_intermediate_term(adder_i.mcarry)
-        # handle the remaining inputs.
-        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
-            add_intermediate_term(self.i.inputs[-1])
-        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
-            # Just pass the terms to the next layer, since we wouldn't gain
-            # anything by using a half adder since there would still be 2 terms
-            # and just passing the terms to the next layer saves gates.
-            add_intermediate_term(self.i.inputs[-2])
-            add_intermediate_term(self.i.inputs[-1])
-        else:
-            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0
+    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
+        """Create an ``AddReduce``.
+
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param partition_points: the input partition points.
+        """
+        self.pspec = pspec
+        self.n_inputs = n_inputs
+        self.output_width = pspec.width * 2
+        self.partition_points = part_pts
+        self.partition_step = partition_step
+
+        self.create_levels()
+
+    def create_levels(self):
+        """creates reduction levels"""
+
+        mods = []
+        partition_points = self.partition_points
+        ilen = self.n_inputs
+        while True:
+            groups = AddReduceSingle.full_adder_groups(ilen)
+            if len(groups) == 0:
+                break
+            lidx = len(mods)
+            next_level = AddReduceSingle(self.pspec, lidx, ilen,
+                                         partition_points,
+                                         self.partition_step)
+            mods.append(next_level)
+            partition_points = next_level.i.part_pts
+            ilen = len(next_level.o.terms)
 
 
-        self._intermediate_terms = _intermediate_terms
+        lidx = len(mods)
+        next_level = FinalAdd(self.pspec, lidx, ilen,
+                              partition_points, self.partition_step)
+        mods.append(next_level)
+
+        self.levels = mods
 
 
 
 
-class AddReduce(Elaboratable):
+class AddReduce(AddReduceInternal, Elaboratable):
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -527,8 +316,8 @@ class AddReduce(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, inputs, output_width, register_levels, partition_points,
-                       part_ops):
+    def __init__(self, inputs, output_width, register_levels, part_pts,
+                 part_ops, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -537,15 +326,16 @@ class AddReduce(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        self.inputs = inputs
-        self.part_ops = part_ops
+        self._inputs = inputs
+        self._part_pts = part_pts
+        self._part_ops = part_ops
         n_parts = len(part_ops)
         n_parts = len(part_ops)
-        self.o = FinalReduceData(partition_points, output_width, n_parts)
-        self.output_width = output_width
+        self.i = AddReduceData(part_pts, len(inputs),
+                               output_width, n_parts)
+        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
+                                   partition_step)
+        self.o = FinalReduceData(part_pts, output_width, n_parts)
         self.register_levels = register_levels
         self.register_levels = register_levels
-        self.partition_points = partition_points
-
-        self.create_levels()
 
     @staticmethod
     def get_max_level(input_count):
 
     @staticmethod
     def get_max_level(input_count):
@@ -558,57 +348,24 @@ class AddReduce(Elaboratable):
             if level > 0:
                 yield level - 1
 
             if level > 0:
                 yield level - 1
 
-    def create_levels(self):
-        """creates reduction levels"""
-
-        mods = []
-        next_levels = self.register_levels
-        partition_points = self.partition_points
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        inputs = self.inputs
-        ilen = len(inputs)
-        while True:
-            groups = AddReduceSingle.full_adder_groups(len(inputs))
-            if len(groups) == 0:
-                break
-            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
-                                         next_levels, partition_points)
-            mods.append(next_level)
-            next_levels = list(AddReduce.next_register_levels(next_levels))
-            partition_points = next_level.i.reg_partition_points
-            inputs = next_level.o.inputs
-            ilen = len(inputs)
-            part_ops = next_level.i.part_ops
-
-        next_level = FinalAdd(ilen, self.output_width, n_parts,
-                              next_levels, partition_points)
-        mods.append(next_level)
-
-        self.levels = mods
-
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
+        m.d.comb += self.i.eq_from(self._part_pts,
+                                   self._inputs, self._part_ops)
+
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
 
-        partition_points = self.partition_points
-        inputs = self.inputs
-        part_ops = self.part_ops
-        n_parts = len(part_ops)
-        n_inputs = len(inputs)
-        output_width = self.output_width
-        i = AddReduceData(partition_points, n_inputs, output_width, n_parts)
-        m.d.comb += i.eq_from(partition_points, inputs, part_ops)
+        i = self.i
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
         for idx in range(len(self.levels)):
             mcur = self.levels[idx]
-            if 0 in mcur.register_levels:
+            if idx in self.register_levels:
                 m.d.sync += mcur.i.eq(i)
             else:
                 m.d.comb += mcur.i.eq(i)
                 m.d.sync += mcur.i.eq(i)
             else:
                 m.d.comb += mcur.i.eq(i)
-            i = mcur.o # for next loop
+            i = mcur.o  # for next loop
 
         # output comes from last module
         m.d.comb += self.o.eq(i)
 
         # output comes from last module
         m.d.comb += self.o.eq(i)
@@ -665,7 +422,7 @@ class ProductTerm(Elaboratable):
         else:
             term_enabled = None
         self.enabled = term_enabled
         else:
             term_enabled = None
         self.enabled = term_enabled
-        self.term.name = "term_%d_%d" % (a_index, b_index) # rename
+        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename
 
     def elaborate(self, platform):
 
 
     def elaborate(self, platform):
 
@@ -677,8 +434,8 @@ class ProductTerm(Elaboratable):
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
         bsb = Signal(self.width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         pwidth = self.pwidth
-        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
         """
@@ -692,8 +449,8 @@ class ProductTerm(Elaboratable):
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
         asel = Signal(width, reset_less=True)
         bsel = Signal(width, reset_less=True)
         a_index, b_index = self.a_index, self.b_index
-        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
-        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
+        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
         m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
         m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
         m.d.comb += self.ti.eq(bsa * bsb)
@@ -708,6 +465,7 @@ class ProductTerms(Elaboratable):
         this class is to be wrapped with a for-loop on the "a" operand.
         it creates a second-level for-loop on the "b" operand.
     """
         this class is to be wrapped with a for-loop on the "a" operand.
         it creates a second-level for-loop on the "b" operand.
     """
+
     def __init__(self, width, twidth, pbwid, a_index, blen):
         self.a_index = a_index
         self.blen = blen
     def __init__(self, width, twidth, pbwid, a_index, blen):
         self.a_index = a_index
         self.blen = blen
@@ -717,8 +475,8 @@ class ProductTerms(Elaboratable):
         self.a = Signal(twidth//2, reset_less=True)
         self.b = Signal(twidth//2, reset_less=True)
         self.pb_en = Signal(pbwid, reset_less=True)
         self.a = Signal(twidth//2, reset_less=True)
         self.b = Signal(twidth//2, reset_less=True)
         self.pb_en = Signal(pbwid, reset_less=True)
-        self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
-                            for i in range(blen)]
+        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
+                      for i in range(blen)]
 
     def elaborate(self, platform):
 
 
     def elaborate(self, platform):
 
@@ -753,7 +511,7 @@ class LSBNegTerm(Elaboratable):
         m = Module()
         comb = m.d.comb
         bit_wid = self.bit_width
         m = Module()
         comb = m.d.comb
         bit_wid = self.bit_width
-        ext = Repl(0, bit_wid) # extend output to HI part
+        ext = Repl(0, bit_wid)  # extend output to HI part
 
         # determine sign of each incoming number *in this partition*
         enabled = Signal(reset_less=True)
 
         # determine sign of each incoming number *in this partition*
         enabled = Signal(reset_less=True)
@@ -774,10 +532,10 @@ class LSBNegTerm(Elaboratable):
 
 class Parts(Elaboratable):
 
 
 class Parts(Elaboratable):
 
-    def __init__(self, pbwid, epps, n_parts):
+    def __init__(self, pbwid, part_pts, n_parts):
         self.pbwid = pbwid
         # inputs
         self.pbwid = pbwid
         # inputs
-        self.epps = PartitionPoints.like(epps, name="epps") # expanded points
+        self.part_pts = PartitionPoints.like(part_pts)
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
                       for i in range(n_parts)]
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
                       for i in range(n_parts)]
@@ -785,13 +543,13 @@ class Parts(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
 
-        epps, parts = self.epps, self.parts
+        part_pts, parts = self.part_pts, self.parts
         # collect part-bytes (double factor because the input is extended)
         pbs = Signal(self.pbwid, reset_less=True)
         tl = []
         for i in range(self.pbwid):
             pb = Signal(name="pb%d" % i, reset_less=True)
         # collect part-bytes (double factor because the input is extended)
         pbs = Signal(self.pbwid, reset_less=True)
         tl = []
         for i in range(self.pbwid):
             pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2)) # double
+            m.d.comb += pb.eq(part_pts.part_byte(i))
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
@@ -827,23 +585,24 @@ class Part(Elaboratable):
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
-    def __init__(self, epps, width, n_parts, n_levels, pbwid):
+
+    def __init__(self, part_pts, width, n_parts, pbwid):
 
         self.pbwid = pbwid
 
         self.pbwid = pbwid
-        self.epps = epps
+        self.part_pts = part_pts
 
         # inputs
         self.a = Signal(64, reset_less=True)
         self.b = Signal(64, reset_less=True)
         self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
 
         # inputs
         self.a = Signal(64, reset_less=True)
         self.b = Signal(64, reset_less=True)
         self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
-                            for i in range(8)]
+                         for i in range(8)]
         self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
         self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
-                            for i in range(8)]
+                         for i in range(8)]
         self.pbs = Signal(pbwid, reset_less=True)
 
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
         self.pbs = Signal(pbwid, reset_less=True)
 
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
-                            for i in range(n_parts)]
+                      for i in range(n_parts)]
 
         self.not_a_term = Signal(width, reset_less=True)
         self.neg_lsb_a_term = Signal(width, reset_less=True)
 
         self.not_a_term = Signal(width, reset_less=True)
         self.neg_lsb_a_term = Signal(width, reset_less=True)
@@ -854,18 +613,18 @@ class Part(Elaboratable):
         m = Module()
 
         pbs, parts = self.pbs, self.parts
         m = Module()
 
         pbs, parts = self.pbs, self.parts
-        epps = self.epps
-        m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
-        m.d.comb += p.epps.eq(epps)
+        part_pts = self.part_pts
+        m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
+        m.d.comb += p.part_pts.eq(part_pts)
         parts = p.parts
 
         byte_count = 8 // len(parts)
 
         not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
         parts = p.parts
 
         byte_count = 8 // len(parts)
 
         not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
-                self.not_a_term, self.neg_lsb_a_term,
-                self.not_b_term, self.neg_lsb_b_term)
+            self.not_a_term, self.neg_lsb_a_term,
+            self.not_b_term, self.neg_lsb_b_term)
 
 
-        byte_width = 8 // len(parts) # byte width
+        byte_width = 8 // len(parts)  # byte width
         bit_wid = 8 * byte_width     # bit width
         nat, nbt, nla, nlb = [], [], [], []
         for i in range(len(parts)):
         bit_wid = 8 * byte_width     # bit width
         nat, nbt, nla, nlb = [], [], [], []
         for i in range(len(parts)):
@@ -873,9 +632,9 @@ class Part(Elaboratable):
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
             pa = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
-            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
-            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
-            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
+            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
+            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
+            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
             nat.append(pa.nt)
             nla.append(pa.nl)
 
             nat.append(pa.nt)
             nla.append(pa.nl)
 
@@ -883,9 +642,9 @@ class Part(Elaboratable):
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
             pb = LSBNegTerm(bit_wid)
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
-            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
-            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
-            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
+            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
+            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
+            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
             nbt.append(pb.nt)
             nlb.append(pb.nl)
 
             nbt.append(pb.nt)
             nlb.append(pb.nl)
 
@@ -894,7 +653,7 @@ class Part(Elaboratable):
                      not_b_term.eq(Cat(*nbt)),
                      neg_lsb_a_term.eq(Cat(*nla)),
                      neg_lsb_b_term.eq(Cat(*nlb)),
                      not_b_term.eq(Cat(*nbt)),
                      neg_lsb_a_term.eq(Cat(*nla)),
                      neg_lsb_b_term.eq(Cat(*nlb)),
-                    ]
+                     ]
 
         return m
 
 
         return m
 
@@ -903,11 +662,12 @@ class IntermediateOut(Elaboratable):
     """ selects the HI/LO part of the multiplication, for a given bit-width
         the output is also reconstructed in its SIMD (partition) lanes.
     """
     """ selects the HI/LO part of the multiplication, for a given bit-width
         the output is also reconstructed in its SIMD (partition) lanes.
     """
+
     def __init__(self, width, out_wid, n_parts):
         self.width = width
         self.n_parts = n_parts
         self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
     def __init__(self, width, out_wid, n_parts):
         self.width = width
         self.n_parts = n_parts
         self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
-                                     for i in range(8)]
+                         for i in range(8)]
         self.intermed = Signal(out_wid, reset_less=True)
         self.output = Signal(out_wid//2, reset_less=True)
 
         self.intermed = Signal(out_wid, reset_less=True)
         self.output = Signal(out_wid//2, reset_less=True)
 
@@ -921,37 +681,74 @@ class IntermediateOut(Elaboratable):
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
             op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
             m.d.comb += op.eq(
                 Mux(self.part_ops[sel * i] == OP_MUL_LOW,
-                    self.intermed.part(i * w*2, w),
-                    self.intermed.part(i * w*2 + w, w)))
+                    self.intermed.bit_select(i * w*2, w),
+                    self.intermed.bit_select(i * w*2 + w, w)))
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
             ol.append(op)
         m.d.comb += self.output.eq(Cat(*ol))
 
         return m
 
 
-class FinalOut(Elaboratable):
+class FinalOut(PipeModBase):
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
     """ selects the final output based on the partitioning.
 
         each byte is selectable independently, i.e. it is possible
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
-    def __init__(self, out_wid):
-        # inputs
-        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
-        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
-        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
 
 
-        self.i8 = Signal(out_wid, reset_less=True)
-        self.i16 = Signal(out_wid, reset_less=True)
-        self.i32 = Signal(out_wid, reset_less=True)
-        self.i64 = Signal(out_wid, reset_less=True)
+    def __init__(self, pspec, part_pts):
 
 
-        # output
-        self.out = Signal(out_wid, reset_less=True)
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+        self.out_wid = pspec.width
+
+        super().__init__(pspec, "finalout")
+
+    def ispec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
+
+        part_pts = self.part_pts
+        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
+        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
+        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
+        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)
+
+        out_part_pts = self.i.part_pts
+
+        # temporaries
+        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
+        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
+        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
+
+        i8 = Signal(self.out_wid, reset_less=True)
+        i16 = Signal(self.out_wid, reset_less=True)
+        i32 = Signal(self.out_wid, reset_less=True)
+        i64 = Signal(self.out_wid, reset_less=True)
+
+        m.d.comb += p_8.part_pts.eq(out_part_pts)
+        m.d.comb += p_16.part_pts.eq(out_part_pts)
+        m.d.comb += p_32.part_pts.eq(out_part_pts)
+        m.d.comb += p_64.part_pts.eq(out_part_pts)
+
+        for i in range(len(p_8.parts)):
+            m.d.comb += d8[i].eq(p_8.parts[i])
+        for i in range(len(p_16.parts)):
+            m.d.comb += d16[i].eq(p_16.parts[i])
+        for i in range(len(p_32.parts)):
+            m.d.comb += d32[i].eq(p_32.parts[i])
+        m.d.comb += i8.eq(self.i.outputs[0])
+        m.d.comb += i16.eq(self.i.outputs[1])
+        m.d.comb += i32.eq(self.i.outputs[2])
+        m.d.comb += i64.eq(self.i.outputs[3])
+
         ol = []
         for i in range(8):
             # select one of the outputs: d8 selects i8, d16 selects i16
         ol = []
         for i in range(8):
             # select one of the outputs: d8 selects i8, d16 selects i16
@@ -961,19 +758,24 @@ class FinalOut(Elaboratable):
             # if neither d8 nor d16 are set, d32 selects either i32 or i64.
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
             # if neither d8 nor d16 are set, d32 selects either i32 or i64.
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
-                Mux(self.d8[i] | self.d16[i // 2],
-                    Mux(self.d8[i], self.i8.part(i * 8, 8),
-                                     self.i16.part(i * 8, 8)),
-                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
-                                          self.i64.part(i * 8, 8))))
+                Mux(d8[i] | d16[i // 2],
+                    Mux(d8[i], i8.bit_select(i * 8, 8),
+                        i16.bit_select(i * 8, 8)),
+                    Mux(d32[i // 4], i32.bit_select(i * 8, 8),
+                        i64.bit_select(i * 8, 8))))
             ol.append(op)
             ol.append(op)
-        m.d.comb += self.out.eq(Cat(*ol))
+
+        # create outputs
+        m.d.comb += self.o.output.eq(Cat(*ol))
+        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
+
         return m
 
 
 class OrMod(Elaboratable):
     """ ORs four values together in a hierarchical tree
     """
         return m
 
 
 class OrMod(Elaboratable):
     """ ORs four values together in a hierarchical tree
     """
+
     def __init__(self, wid):
         self.wid = wid
         self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
     def __init__(self, wid):
         self.wid = wid
         self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
@@ -1007,99 +809,118 @@ class Signs(Elaboratable):
 
         asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
         bsig = (self.part_ops == OP_MUL_LOW) \
 
         asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
         bsig = (self.part_ops == OP_MUL_LOW) \
-                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
+            | (self.part_ops == OP_MUL_SIGNED_HIGH)
         m.d.comb += self.a_signed.eq(asig)
         m.d.comb += self.b_signed.eq(bsig)
 
         return m
 
 
         m.d.comb += self.a_signed.eq(asig)
         m.d.comb += self.b_signed.eq(bsig)
 
         return m
 
 
-class Mul8_16_32_64(Elaboratable):
-    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
-
-    Supports partitioning into any combination of 8, 16, 32, and 64-bit
-    partitions on naturally-aligned boundaries. Supports the operation being
-    set for each partition independently.
+class IntermediateData:
 
 
-    :attribute part_pts: the input partition points. Has a partition point at
-        multiples of 8 in 0 < i < 64. Each partition point's associated
-        ``Value`` is a ``Signal``. Modification not supported, except for by
-        ``Signal.eq``.
-    :attribute part_ops: the operation for each byte. The operation for a
-        particular partition is selected by assigning the selected operation
-        code to each byte in the partition. The allowed operation codes are:
+    def __init__(self, part_pts, output_width, n_parts):
+        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
+                         for i in range(n_parts)]
+        self.part_pts = part_pts.like()
+        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
+                        for i in range(4)]
+        # intermediates (needed for unit tests)
+        self.intermediate_output = Signal(output_width)
+
+    def eq_from(self, part_pts, outputs, intermediate_output,
+                part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.intermediate_output.eq(intermediate_output)] + \
+               [self.outputs[i].eq(outputs[i])
+                for i in range(4)] + \
+               [self.part_ops[i].eq(part_ops[i])
+                for i in range(len(self.part_ops))]
 
 
-        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
-            RISC-V's `mul` instruction.
-        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
-            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
-            instruction.
-        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
-            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
-            `mulhsu` instruction.
-        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
-            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
-            instruction.
-    """
+    def eq(self, rhs):
+        return self.eq_from(rhs.part_pts, rhs.outputs,
+                            rhs.intermediate_output, rhs.part_ops)
 
 
-    def __init__(self, register_levels=()):
-        """ register_levels: specifies the points in the cascade at which
-            flip-flops are to be inserted.
-        """
 
 
-        # parameter(s)
-        self.register_levels = list(register_levels)
+class InputData:
 
 
-        # inputs
+    def __init__(self):
+        self.a = Signal(64)
+        self.b = Signal(64)
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
-        self.a = Signal(64)
-        self.b = Signal(64)
 
 
-        # intermediates (needed for unit tests)
-        self.intermediate_output = Signal(128)
+    def eq_from(self, part_pts, a, b, part_ops):
+        return [self.part_pts.eq(part_pts)] + \
+               [self.a.eq(a), self.b.eq(b)] + \
+               [self.part_ops[i].eq(part_ops[i])
+                for i in range(len(self.part_ops))]
 
 
-        # output
+    def eq(self, rhs):
+        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
+
+
+class OutputData:
+
+    def __init__(self):
+        self.intermediate_output = Signal(128)  # needed for unit tests
         self.output = Signal(64)
 
         self.output = Signal(64)
 
+    def eq(self, rhs):
+        return [self.intermediate_output.eq(rhs.intermediate_output),
+                self.output.eq(rhs.output)]
+
+
+class AllTerms(PipeModBase):
+    """Set of terms to be added together
+    """
+
+    def __init__(self, pspec, n_inputs):
+        """Create an ``AllTerms``.
+        """
+        self.n_inputs = n_inputs
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
+        super().__init__(pspec, "allterms")
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return AddReduceData(self.i.part_pts, self.n_inputs,
+                             self.output_width, self.n_parts)
+
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
 
+        eps = self.i.part_pts
+
         # collect part-bytes
         pbs = Signal(8, reset_less=True)
         tl = []
         for i in range(8):
             pb = Signal(name="pb%d" % i, reset_less=True)
         # collect part-bytes
         pbs = Signal(8, reset_less=True)
         tl = []
         for i in range(8):
             pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(self.part_pts.part_byte(i))
+            m.d.comb += pb.eq(eps.part_byte(i))
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
             tl.append(pb)
         m.d.comb += pbs.eq(Cat(*tl))
 
-        # create (doubled) PartitionPoints (output is double input width)
-        expanded_part_pts = eps = PartitionPoints()
-        for i, v in self.part_pts.items():
-            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
-            expanded_part_pts[i * 2] = ep
-            m.d.comb += ep.eq(v)
-
         # local variables
         signs = []
         for i in range(8):
             s = Signs()
             signs.append(s)
             setattr(m.submodules, "signs%d" % i, s)
         # local variables
         signs = []
         for i in range(8):
             s = Signs()
             signs.append(s)
             setattr(m.submodules, "signs%d" % i, s)
-            m.d.comb += s.part_ops.eq(self.part_ops[i])
+            m.d.comb += s.part_ops.eq(self.i.part_ops[i])
 
 
-        n_levels = len(self.register_levels)+1
-        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
-        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
-        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
-        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
+        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
+        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
+        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
+        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
         nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
         for mod in [part_8, part_16, part_32, part_64]:
-            m.d.comb += mod.a.eq(self.a)
-            m.d.comb += mod.b.eq(self.b)
+            m.d.comb += mod.a.eq(self.i.a)
+            m.d.comb += mod.b.eq(self.i.b)
             for i in range(len(signs)):
                 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
             for i in range(len(signs)):
                 m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                 m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
@@ -1115,8 +936,8 @@ class Mul8_16_32_64(Elaboratable):
             t = ProductTerms(8, 128, 8, a_index, 8)
             setattr(m.submodules, "terms_%d" % a_index, t)
 
             t = ProductTerms(8, 128, 8, a_index, 8)
             setattr(m.submodules, "terms_%d" % a_index, t)
 
-            m.d.comb += t.a.eq(self.a)
-            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.a.eq(self.i.a)
+            m.d.comb += t.b.eq(self.i.b)
             m.d.comb += t.pb_en.eq(pbs)
 
             for term in t.terms:
             m.d.comb += t.pb_en.eq(pbs)
 
             for term in t.terms:
@@ -1129,71 +950,179 @@ class Mul8_16_32_64(Elaboratable):
         m.submodules.nla_or = nla_or = OrMod(128)
         m.submodules.nlb_or = nlb_or = OrMod(128)
         for l, mod in [(nat_l, nat_or),
         m.submodules.nla_or = nla_or = OrMod(128)
         m.submodules.nlb_or = nlb_or = OrMod(128)
         for l, mod in [(nat_l, nat_or),
-                             (nbt_l, nbt_or),
-                             (nla_l, nla_or),
-                             (nlb_l, nlb_or)]:
+                       (nbt_l, nbt_or),
+                       (nla_l, nla_or),
+                       (nlb_l, nlb_or)]:
             for i in range(len(l)):
                 m.d.comb += mod.orin[i].eq(l[i])
             terms.append(mod.orout)
 
             for i in range(len(l)):
                 m.d.comb += mod.orin[i].eq(l[i])
             terms.append(mod.orout)
 
-        add_reduce = AddReduce(terms,
-                               128,
-                               self.register_levels,
-                               expanded_part_pts,
-                               self.part_ops)
+        # copy the intermediate terms to the output
+        for i, value in enumerate(terms):
+            m.d.comb += self.o.terms[i].eq(value)
+
+        # copy reg part points and part ops to output
+        m.d.comb += self.o.part_pts.eq(eps)
+        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
+                     for i in range(len(self.i.part_ops))]
+
+        return m
+
+
+class Intermediates(PipeModBase):
+    """ Intermediate output modules
+    """
+
+    def __init__(self, pspec, part_pts):
+        self.part_pts = part_pts
+        self.output_width = pspec.width * 2
+        self.n_parts = pspec.n_parts
+
+        super().__init__(pspec, "intermediates")
+
+    def ispec(self):
+        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def elaborate(self, platform):
+        m = Module()
 
 
-        out_part_ops = add_reduce.o.part_ops
-        out_part_pts = add_reduce.o.reg_partition_points
+        out_part_ops = self.i.part_ops
+        out_part_pts = self.i.part_pts
 
 
-        m.submodules.add_reduce = add_reduce
-        m.d.comb += self.intermediate_output.eq(add_reduce.o.output)
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
         # create _output_64
         m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
-        m.d.comb += io64.intermed.eq(self.intermediate_output)
+        m.d.comb += io64.intermed.eq(self.i.output)
         for i in range(8):
             m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
         for i in range(8):
             m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[3].eq(io64.output)
 
         # create _output_32
         m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
 
         # create _output_32
         m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
-        m.d.comb += io32.intermed.eq(self.intermediate_output)
+        m.d.comb += io32.intermed.eq(self.i.output)
         for i in range(8):
             m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
         for i in range(8):
             m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[2].eq(io32.output)
 
         # create _output_16
         m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
 
         # create _output_16
         m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
-        m.d.comb += io16.intermed.eq(self.intermediate_output)
+        m.d.comb += io16.intermed.eq(self.i.output)
         for i in range(8):
             m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
         for i in range(8):
             m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[1].eq(io16.output)
 
         # create _output_8
         m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
 
         # create _output_8
         m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
-        m.d.comb += io8.intermed.eq(self.intermediate_output)
+        m.d.comb += io8.intermed.eq(self.i.output)
         for i in range(8):
             m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
         for i in range(8):
             m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.outputs[0].eq(io8.output)
+
+        for i in range(8):
+            m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
+        m.d.comb += self.o.part_pts.eq(out_part_pts)
+        m.d.comb += self.o.intermediate_output.eq(self.i.output)
+
+        return m
+
+
+class Mul8_16_32_64(Elaboratable):
+    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
+
+    XXX NOTE: this class is intended for unit test purposes ONLY.
+
+    Supports partitioning into any combination of 8, 16, 32, and 64-bit
+    partitions on naturally-aligned boundaries. Supports the operation being
+    set for each partition independently.
+
+    :attribute part_pts: the input partition points. Has a partition point at
+        multiples of 8 in 0 < i < 64. Each partition point's associated
+        ``Value`` is a ``Signal``. Modification not supported, except for by
+        ``Signal.eq``.
+    :attribute part_ops: the operation for each byte. The operation for a
+        particular partition is selected by assigning the selected operation
+        code to each byte in the partition. The allowed operation codes are:
+
+        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
+            RISC-V's `mul` instruction.
+        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
+            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
+            instruction.
+        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
+            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
+            `mulhsu` instruction.
+        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
+            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
+            instruction.
+    """
+
+    def __init__(self, register_levels=()):
+        """ register_levels: specifies the points in the cascade at which
+            flip-flops are to be inserted.
+        """
+
+        self.id_wid = 0  # num_bits(num_rows)
+        self.op_wid = 0
+        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
+        self.pspec.n_parts = 8
+
+        # parameter(s)
+        self.register_levels = list(register_levels)
+
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+        # inputs
+        self.part_pts = self.i.part_pts
+        self.part_ops = self.i.part_ops
+        self.a = self.i.a
+        self.b = self.i.b
+
+        # output
+        self.intermediate_output = self.o.intermediate_output
+        self.output = self.o.output
 
 
-        m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
-        m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
-        m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
-        m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return OutputData()
+
+    def elaborate(self, platform):
+        m = Module()
+
+        part_pts = self.part_pts
+
+        n_inputs = 64 + 4
+        t = AllTerms(self.pspec, n_inputs)
+        t.setup(m, self.i)
+
+        terms = t.o.terms
+
+        at = AddReduceInternal(self.pspec, n_inputs,
+                               part_pts, partition_step=2)
+
+        i = t.o
+        for idx in range(len(at.levels)):
+            mcur = at.levels[idx]
+            mcur.setup(m, i)
+            o = mcur.ospec()
+            if idx in self.register_levels:
+                m.d.sync += o.eq(mcur.process(i))
+            else:
+                m.d.comb += o.eq(mcur.process(i))
+            i = o  # for next loop
 
 
-        m.d.comb += p_8.epps.eq(out_part_pts)
-        m.d.comb += p_16.epps.eq(out_part_pts)
-        m.d.comb += p_32.epps.eq(out_part_pts)
-        m.d.comb += p_64.epps.eq(out_part_pts)
+        interm = Intermediates(self.pspec, part_pts)
+        interm.setup(m, i)
+        o = interm.process(interm.i)
 
         # final output
 
         # final output
-        m.submodules.finalout = finalout = FinalOut(64)
-        for i in range(len(part_8.parts)):
-            m.d.comb += finalout.d8[i].eq(p_8.parts[i])
-        for i in range(len(part_16.parts)):
-            m.d.comb += finalout.d16[i].eq(p_16.parts[i])
-        for i in range(len(part_32.parts)):
-            m.d.comb += finalout.d32[i].eq(p_32.parts[i])
-        m.d.comb += finalout.i8.eq(io8.output)
-        m.d.comb += finalout.i16.eq(io16.output)
-        m.d.comb += finalout.i32.eq(io32.output)
-        m.d.comb += finalout.i64.eq(io64.output)
-        m.d.comb += self.output.eq(finalout.out)
+        finalout = FinalOut(self.pspec, part_pts)
+        finalout.setup(m, o)
+        m.d.comb += self.o.eq(finalout.process(o))
 
         return m
 
 
         return m