spelling mistake $i instead of %i
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index aa4fafb62617d74e36019f2e5e2d1f02355769af..db2ba408ac4981e18dc04e3c8ecfa9767b121523 100644 (file)
@@ -6,6 +6,8 @@ from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
 from nmigen.hdl.ast import Assign
 from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
 from nmigen.hdl.ast import Assign
 from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
+from functools import reduce
+from operator import or_
 
 
 class PartitionPoints(dict):
 
 
 class PartitionPoints(dict):
@@ -110,6 +112,11 @@ class FullAdder(Elaboratable):
     :attribute in2: the third input
     :attribute sum: the sum output
     :attribute carry: the carry output
     :attribute in2: the third input
     :attribute sum: the sum output
     :attribute carry: the carry output
+
+    Rather than do individual full adders (and have an array of them,
+    which would be very slow to simulate), this module can specify the
+    bit width of the inputs and outputs: in effect it performs multiple
+    Full 3-2 Add operations "in parallel".
     """
 
     def __init__(self, width):
     """
 
     def __init__(self, width):
@@ -163,6 +170,10 @@ class PartitionedAdder(Elaboratable):
                 expanded_width += 1
             expanded_width += 1
         self._expanded_width = expanded_width
                 expanded_width += 1
             expanded_width += 1
         self._expanded_width = expanded_width
+        # XXX these have to remain here due to some horrible nmigen
+        # simulation bugs involving sync.  it is *not* necessary to
+        # have them here, they should (under normal circumstances)
+        # be moved into elaborate, as they are entirely local
         self._expanded_a = Signal(expanded_width)
         self._expanded_b = Signal(expanded_width)
         self._expanded_output = Signal(expanded_width)
         self._expanded_a = Signal(expanded_width)
         self._expanded_b = Signal(expanded_width)
         self._expanded_output = Signal(expanded_width)
@@ -172,12 +183,8 @@ class PartitionedAdder(Elaboratable):
         m = Module()
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
         m = Module()
         expanded_index = 0
         # store bits in a list, use Cat later.  graphviz is much cleaner
-        al = []
-        bl = []
-        ol = []
-        ea = []
-        eb = []
-        eo = []
+        al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
+
         # partition points are "breaks" (extra zeros) in what would otherwise
         # be a massive long add.
         for i in range(self.width):
         # partition points are "breaks" (extra zeros) in what would otherwise
         # be a massive long add.
         for i in range(self.width):
@@ -199,7 +206,7 @@ class PartitionedAdder(Elaboratable):
         # combine above using Cat
         m.d.comb += Cat(*ea).eq(Cat(*al))
         m.d.comb += Cat(*eb).eq(Cat(*bl))
         # combine above using Cat
         m.d.comb += Cat(*ea).eq(Cat(*al))
         m.d.comb += Cat(*eb).eq(Cat(*bl))
-        m.d.comb += Cat(*eo).eq(Cat(*ol))
+        m.d.comb += Cat(*ol).eq(Cat(*eo))
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
         m.d.comb += self._expanded_output.eq(
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
         m.d.comb += self._expanded_output.eq(
@@ -366,6 +373,369 @@ OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
 OP_MUL_UNSIGNED_HIGH = 3
 
 
 OP_MUL_UNSIGNED_HIGH = 3
 
 
+def get_term(value, shift=0, enabled=None):
+    if enabled is not None:
+        value = Mux(enabled, value, 0)
+    if shift > 0:
+        value = Cat(Repl(C(0, 1), shift), value)
+    else:
+        assert shift == 0
+    return value
+
+
+class ProductTerm(Elaboratable):
+    """ this class creates a single product term (a[..]*b[..]).
+        it has a design flaw in that is the *output* that is selected,
+        where the multiplication(s) are combinatorially generated
+        all the time.
+    """
+
+    def __init__(self, width, twidth, pbwid, a_index, b_index):
+        self.a_index = a_index
+        self.b_index = b_index
+        shift = 8 * (self.a_index + self.b_index)
+        self.pwidth = width
+        self.twidth = twidth
+        self.width = width*2
+        self.shift = shift
+
+        self.ti = Signal(self.width, reset_less=True)
+        self.term = Signal(twidth, reset_less=True)
+        self.a = Signal(twidth//2, reset_less=True)
+        self.b = Signal(twidth//2, reset_less=True)
+        self.pb_en = Signal(pbwid, reset_less=True)
+
+        self.tl = tl = []
+        min_index = min(self.a_index, self.b_index)
+        max_index = max(self.a_index, self.b_index)
+        for i in range(min_index, max_index):
+            tl.append(self.pb_en[i])
+        name = "te_%d_%d" % (self.a_index, self.b_index)
+        if len(tl) > 0:
+            term_enabled = Signal(name=name, reset_less=True)
+        else:
+            term_enabled = None
+        self.enabled = term_enabled
+        self.term.name = "term_%d_%d" % (a_index, b_index) # rename
+
+    def elaborate(self, platform):
+
+        m = Module()
+        if self.enabled is not None:
+            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
+
+        bsa = Signal(self.width, reset_less=True)
+        bsb = Signal(self.width, reset_less=True)
+        a_index, b_index = self.a_index, self.b_index
+        pwidth = self.pwidth
+        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
+        m.d.comb += self.ti.eq(bsa * bsb)
+        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
+        """
+        #TODO: sort out width issues, get inputs a/b switched on/off.
+        #data going into Muxes is 1/2 the required width
+
+        pwidth = self.pwidth
+        width = self.width
+        bsa = Signal(self.twidth//2, reset_less=True)
+        bsb = Signal(self.twidth//2, reset_less=True)
+        asel = Signal(width, reset_less=True)
+        bsel = Signal(width, reset_less=True)
+        a_index, b_index = self.a_index, self.b_index
+        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
+        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
+        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
+        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
+        m.d.comb += self.ti.eq(bsa * bsb)
+        m.d.comb += self.term.eq(self.ti)
+        """
+
+        return m
+
+
+class ProductTerms(Elaboratable):
+    """ creates a bank of product terms.  also performs the actual bit-selection
+        this class is to be wrapped with a for-loop on the "a" operand.
+        it creates a second-level for-loop on the "b" operand.
+    """
+    def __init__(self, width, twidth, pbwid, a_index, blen):
+        self.a_index = a_index
+        self.blen = blen
+        self.pwidth = width
+        self.twidth = twidth
+        self.pbwid = pbwid
+        self.a = Signal(twidth//2, reset_less=True)
+        self.b = Signal(twidth//2, reset_less=True)
+        self.pb_en = Signal(pbwid, reset_less=True)
+        self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
+                            for i in range(blen)]
+
+    def elaborate(self, platform):
+
+        m = Module()
+
+        for b_index in range(self.blen):
+            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
+                            self.a_index, b_index)
+            setattr(m.submodules, "term_%d" % b_index, t)
+
+            m.d.comb += t.a.eq(self.a)
+            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.pb_en.eq(self.pb_en)
+
+            m.d.comb += self.terms[b_index].eq(t.term)
+
+        return m
+
+class LSBNegTerm(Elaboratable):
+
+    def __init__(self, bit_width):
+        self.bit_width = bit_width
+        self.part = Signal(reset_less=True)
+        self.signed = Signal(reset_less=True)
+        self.op = Signal(bit_width, reset_less=True)
+        self.msb = Signal(reset_less=True)
+        self.nt = Signal(bit_width*2, reset_less=True)
+        self.nl = Signal(bit_width*2, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        bit_wid = self.bit_width
+        ext = Repl(0, bit_wid) # extend output to HI part
+
+        # determine sign of each incoming number *in this partition*
+        enabled = Signal(reset_less=True)
+        m.d.comb += enabled.eq(self.part & self.msb & self.signed)
+
+        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
+        # negation operation is split into a bitwise not and a +1.
+        # likewise for 16, 32, and 64-bit values.
+
+        # width-extended 1s complement if a is signed, otherwise zero
+        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))
+
+        # add 1 if signed, otherwise add zero
+        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))
+
+        return m
+
+
+class Part(Elaboratable):
+    """ a key class which, depending on the partitioning, will determine
+        what action to take when parts of the output are signed or unsigned.
+
+        this requires 2 pieces of data *per operand, per partition*:
+        whether the MSB is HI/LO (per partition!), and whether a signed
+        or unsigned operation has been *requested*.
+
+        once that is determined, signed is basically carried out
+        by splitting 2's complement into 1's complement plus one.
+        1's complement is just a bit-inversion.
+
+        the extra terms - as separate terms - are then thrown at the
+        AddReduce alongside the multiplication part-results.
+    """
+    def __init__(self, width, n_parts, n_levels, pbwid):
+
+        # inputs
+        self.a = Signal(64)
+        self.b = Signal(64)
+        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
+        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+        self.pbs = Signal(pbwid, reset_less=True)
+
+        # outputs
+        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
+        self.delayed_parts = [
+            [Signal(name=f"delayed_part_{delay}_{i}")
+             for i in range(n_parts)]
+                for delay in range(n_levels)]
+        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
+        self.dplast = [Signal(name=f"dplast_{i}")
+                         for i in range(n_parts)]
+
+        self.not_a_term = Signal(width)
+        self.neg_lsb_a_term = Signal(width)
+        self.not_b_term = Signal(width)
+        self.neg_lsb_b_term = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
+        # negated-temporary copy of partition bits
+        npbs = Signal.like(pbs, reset_less=True)
+        m.d.comb += npbs.eq(~pbs)
+        byte_count = 8 // len(parts)
+        for i in range(len(parts)):
+            pbl = []
+            pbl.append(npbs[i * byte_count - 1])
+            for j in range(i * byte_count, (i + 1) * byte_count - 1):
+                pbl.append(pbs[j])
+            pbl.append(npbs[(i + 1) * byte_count - 1])
+            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
+            m.d.comb += value.eq(Cat(*pbl))
+            m.d.comb += parts[i].eq(~(value).bool())
+            m.d.comb += delayed_parts[0][i].eq(parts[i])
+            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
+                         for j in range(len(delayed_parts)-1)]
+            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])
+
+        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
+                self.not_a_term, self.neg_lsb_a_term, \
+                self.not_b_term, self.neg_lsb_b_term
+
+        byte_width = 8 // len(parts) # byte width
+        bit_wid = 8 * byte_width     # bit width
+        nat, nbt, nla, nlb = [], [], [], []
+        for i in range(len(parts)):
+            # work out bit-inverted and +1 term for a.
+            pa = LSBNegTerm(bit_wid)
+            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
+            m.d.comb += pa.part.eq(parts[i])
+            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
+            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
+            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
+            nat.append(pa.nt)
+            nla.append(pa.nl)
+
+            # work out bit-inverted and +1 term for b
+            pb = LSBNegTerm(bit_wid)
+            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
+            m.d.comb += pb.part.eq(parts[i])
+            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
+            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
+            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
+            nbt.append(pb.nt)
+            nlb.append(pb.nl)
+
+        # concatenate together and return all 4 results.
+        m.d.comb += [not_a_term.eq(Cat(*nat)),
+                     not_b_term.eq(Cat(*nbt)),
+                     neg_lsb_a_term.eq(Cat(*nla)),
+                     neg_lsb_b_term.eq(Cat(*nlb)),
+                    ]
+
+        return m
+
+
+class IntermediateOut(Elaboratable):
+    """ selects the HI/LO part of the multiplication, for a given bit-width
+        the output is also reconstructed in its SIMD (partition) lanes.
+    """
+    def __init__(self, width, out_wid, n_parts):
+        self.width = width
+        self.n_parts = n_parts
+        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
+                                     for i in range(8)]
+        self.intermed = Signal(out_wid, reset_less=True)
+        self.output = Signal(out_wid//2, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        ol = []
+        w = self.width
+        sel = w // 8
+        for i in range(self.n_parts):
+            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
+            m.d.comb += op.eq(
+                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
+                    self.intermed.bit_select(i * w*2, w),
+                    self.intermed.bit_select(i * w*2 + w, w)))
+            ol.append(op)
+        m.d.comb += self.output.eq(Cat(*ol))
+
+        return m
+
+
+class FinalOut(Elaboratable):
+    """ selects the final output based on the partitioning.
+
+        each byte is selectable independently, i.e. it is possible
+        that some partitions requested 8-bit computation whilst others
+        requested 16 or 32 bit.
+    """
+    def __init__(self, out_wid):
+        # inputs
+        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
+        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
+        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
+
+        self.i8 = Signal(out_wid, reset_less=True)
+        self.i16 = Signal(out_wid, reset_less=True)
+        self.i32 = Signal(out_wid, reset_less=True)
+        self.i64 = Signal(out_wid, reset_less=True)
+
+        # output
+        self.out = Signal(out_wid, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        ol = []
+        for i in range(8):
+            # select one of the outputs: d8 selects i8, d16 selects i16
+            # d32 selects i32, and the default is i64.
+            # d8 and d16 are ORed together in the first Mux
+            # then the 2nd selects either i8 or i16.
+            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
+            op = Signal(8, reset_less=True, name="op_%d" % i)
+            m.d.comb += op.eq(
+                Mux(self.d8[i] | self.d16[i // 2],
+                    Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
+                                     self.i16.bit_select(i * 8, 8)),
+                    Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
+                                          self.i64.bit_select(i * 8, 8))))
+            ol.append(op)
+        m.d.comb += self.out.eq(Cat(*ol))
+        return m
+
+
+class OrMod(Elaboratable):
+    """ ORs four values together in a hierarchical tree
+    """
+    def __init__(self, wid):
+        self.wid = wid
+        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
+                     for i in range(4)]
+        self.orout = Signal(wid, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        or1 = Signal(self.wid, reset_less=True)
+        or2 = Signal(self.wid, reset_less=True)
+        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
+        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
+        m.d.comb += self.orout.eq(or1 | or2)
+
+        return m
+
+
+class Signs(Elaboratable):
+    """ determines whether a or b are signed numbers
+        based on the required operation type (OP_MUL_*)
+    """
+
+    def __init__(self):
+        self.part_ops = Signal(2, reset_less=True)
+        self.a_signed = Signal(reset_less=True)
+        self.b_signed = Signal(reset_less=True)
+
+    def elaborate(self, platform):
+
+        m = Module()
+
+        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
+        bsig = (self.part_ops == OP_MUL_LOW) \
+                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
+        m.d.comb += self.a_signed.eq(asig)
+        m.d.comb += self.b_signed.eq(bsig)
+
+        return m
+
+
 class Mul8_16_32_64(Elaboratable):
     """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
 
 class Mul8_16_32_64(Elaboratable):
     """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
 
@@ -394,61 +764,27 @@ class Mul8_16_32_64(Elaboratable):
             instruction.
     """
 
             instruction.
     """
 
-    def __init__(self, register_levels= ()):
+    def __init__(self, register_levels=()):
+        """ register_levels: specifies the points in the cascade at which
+            flip-flops are to be inserted.
+        """
+
+        # parameter(s)
+        self.register_levels = list(register_levels)
+
+        # inputs
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
         self.a = Signal(64)
         self.b = Signal(64)
         self.part_pts = PartitionPoints()
         for i in range(8, 64, 8):
             self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
         self.a = Signal(64)
         self.b = Signal(64)
-        self.output = Signal(64)
-        self.register_levels = list(register_levels)
+
+        # intermediates (needed for unit tests)
         self._intermediate_output = Signal(128)
         self._intermediate_output = Signal(128)
-        self._delayed_part_ops = [
-            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
-             for i in range(8)]
-            for delay in range(1 + len(self.register_levels))]
-        self._part_8 = [Signal(name=f"_part_8_{i}") for i in range(8)]
-        self._part_16 = [Signal(name=f"_part_16_{i}") for i in range(4)]
-        self._part_32 = [Signal(name=f"_part_32_{i}") for i in range(2)]
-        self._part_64 = [Signal(name=f"_part_64")]
-        self._delayed_part_8 = [
-            [Signal(name=f"_delayed_part_8_{delay}_{i}")
-             for i in range(8)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_16 = [
-            [Signal(name=f"_delayed_part_16_{delay}_{i}")
-             for i in range(4)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_32 = [
-            [Signal(name=f"_delayed_part_32_{delay}_{i}")
-             for i in range(2)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_64 = [
-            [Signal(name=f"_delayed_part_64_{delay}")]
-            for delay in range(1 + len(self.register_levels))]
-        self._output_64 = Signal(64)
-        self._output_32 = Signal(64)
-        self._output_16 = Signal(64)
-        self._output_8 = Signal(64)
-        self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
-        self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
-        self._not_a_term_8 = Signal(128)
-        self._neg_lsb_a_term_8 = Signal(128)
-        self._not_b_term_8 = Signal(128)
-        self._neg_lsb_b_term_8 = Signal(128)
-        self._not_a_term_16 = Signal(128)
-        self._neg_lsb_a_term_16 = Signal(128)
-        self._not_b_term_16 = Signal(128)
-        self._neg_lsb_b_term_16 = Signal(128)
-        self._not_a_term_32 = Signal(128)
-        self._neg_lsb_a_term_32 = Signal(128)
-        self._not_b_term_32 = Signal(128)
-        self._neg_lsb_b_term_32 = Signal(128)
-        self._not_a_term_64 = Signal(128)
-        self._neg_lsb_a_term_64 = Signal(128)
-        self._not_b_term_64 = Signal(128)
-        self._neg_lsb_b_term_64 = Signal(128)
+
+        # output
+        self.output = Signal(64)
 
     def _part_byte(self, index):
         if index == -1 or index == 7:
 
     def _part_byte(self, index):
         if index == -1 or index == 7:
@@ -459,147 +795,80 @@ class Mul8_16_32_64(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
     def elaborate(self, platform):
         m = Module()
 
+        # collect part-bytes
+        pbs = Signal(8, reset_less=True)
+        tl = []
+        for i in range(8):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(self._part_byte(i))
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
+
+        # local variables
+        signs = []
+        for i in range(8):
+            s = Signs()
+            signs.append(s)
+            setattr(m.submodules, "signs%d" % i, s)
+            m.d.comb += s.part_ops.eq(self.part_ops[i])
+
+        delayed_part_ops = [
+            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
+             for i in range(8)]
+            for delay in range(1 + len(self.register_levels))]
         for i in range(len(self.part_ops)):
         for i in range(len(self.part_ops)):
-            m.d.comb += self._delayed_part_ops[0][i].eq(self.part_ops[i])
-            m.d.sync += [self._delayed_part_ops[j + 1][i]
-                         .eq(self._delayed_part_ops[j][i])
+            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
+            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                          for j in range(len(self.register_levels))]
 
                          for j in range(len(self.register_levels))]
 
-        def add_intermediate_value(value):
-            intermediate_value = Signal(len(value), reset_less=True)
-            m.d.comb += intermediate_value.eq(value)
-            return intermediate_value
-
-        for parts, delayed_parts in [(self._part_64, self._delayed_part_64),
-                                     (self._part_32, self._delayed_part_32),
-                                     (self._part_16, self._delayed_part_16),
-                                     (self._part_8, self._delayed_part_8)]:
-            byte_count = 8 // len(parts)
-            for i in range(len(parts)):
-                pb = self._part_byte(i * byte_count - 1)
-                value = add_intermediate_value(pb)
-                for j in range(i * byte_count, (i + 1) * byte_count - 1):
-                    pb = add_intermediate_value(~self._part_byte(j))
-                    value = add_intermediate_value(value & pb)
-                pb = self._part_byte((i + 1) * byte_count - 1)
-                value = add_intermediate_value(value & pb)
-                m.d.comb += parts[i].eq(value)
-                m.d.comb += delayed_parts[0][i].eq(parts[i])
-                m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
-                             for j in range(len(self.register_levels))]
-
-        products = [[
-                Signal(16, name=f"products_{i}_{j}")
-                for j in range(8)]
-            for i in range(8)]
-
-        for a_index in range(8):
-            for b_index in range(8):
-                a = self.a.part(a_index * 8, 8)
-                b = self.b.part(b_index * 8, 8)
-                m.d.comb += products[a_index][b_index].eq(a * b)
+        n_levels = len(self.register_levels)+1
+        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
+        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
+        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
+        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
+        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
+        for mod in [part_8, part_16, part_32, part_64]:
+            m.d.comb += mod.a.eq(self.a)
+            m.d.comb += mod.b.eq(self.b)
+            for i in range(len(signs)):
+                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
+                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
+            m.d.comb += mod.pbs.eq(pbs)
+            nat_l.append(mod.not_a_term)
+            nbt_l.append(mod.not_b_term)
+            nla_l.append(mod.neg_lsb_a_term)
+            nlb_l.append(mod.neg_lsb_b_term)
 
         terms = []
 
 
         terms = []
 
-        def add_term(value, shift=0, enabled=None):
-            term = Signal(128)
-            terms.append(term)
-            if enabled is not None:
-                value = Mux(enabled, value, 0)
-            if shift > 0:
-                value = Cat(Repl(C(0, 1), shift), value)
-            else:
-                assert shift == 0
-            m.d.comb += term.eq(value)
-
         for a_index in range(8):
         for a_index in range(8):
-            for b_index in range(8):
-                term_enabled: Value = C(True, 1)
-                min_index = min(a_index, b_index)
-                max_index = max(a_index, b_index)
-                for i in range(min_index, max_index):
-                    term_enabled &= ~self._part_byte(i)
-                add_term(products[a_index][b_index],
-                         8 * (a_index + b_index),
-                         term_enabled)
+            t = ProductTerms(8, 128, 8, a_index, 8)
+            setattr(m.submodules, "terms_%d" % a_index, t)
 
 
-        for i in range(8):
-            a_signed = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH
-            b_signed = (self.part_ops[i] == OP_MUL_LOW) \
-                | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
-            m.d.comb += self._a_signed[i].eq(a_signed)
-            m.d.comb += self._b_signed[i].eq(b_signed)
+            m.d.comb += t.a.eq(self.a)
+            m.d.comb += t.b.eq(self.b)
+            m.d.comb += t.pb_en.eq(pbs)
+
+            for term in t.terms:
+                terms.append(term)
 
 
-        # it's fine to bitwise-or these together since they are never enabled
+        # it's fine to bitwise-or data together since they are never enabled
         # at the same time
         # at the same time
-        add_term(self._not_a_term_8 | self._not_a_term_16
-                 | self._not_a_term_32 | self._not_a_term_64)
-        add_term(self._neg_lsb_a_term_8 | self._neg_lsb_a_term_16
-                 | self._neg_lsb_a_term_32 | self._neg_lsb_a_term_64)
-        add_term(self._not_b_term_8 | self._not_b_term_16
-                 | self._not_b_term_32 | self._not_b_term_64)
-        add_term(self._neg_lsb_b_term_8 | self._neg_lsb_b_term_16
-                 | self._neg_lsb_b_term_32 | self._neg_lsb_b_term_64)
-
-        for not_a_term, \
-            neg_lsb_a_term, \
-            not_b_term, \
-            neg_lsb_b_term, \
-            parts in [
-                (self._not_a_term_8,
-                 self._neg_lsb_a_term_8,
-                 self._not_b_term_8,
-                 self._neg_lsb_b_term_8,
-                 self._part_8),
-                (self._not_a_term_16,
-                 self._neg_lsb_a_term_16,
-                 self._not_b_term_16,
-                 self._neg_lsb_b_term_16,
-                 self._part_16),
-                (self._not_a_term_32,
-                 self._neg_lsb_a_term_32,
-                 self._not_b_term_32,
-                 self._neg_lsb_b_term_32,
-                 self._part_32),
-                (self._not_a_term_64,
-                 self._neg_lsb_a_term_64,
-                 self._not_b_term_64,
-                 self._neg_lsb_b_term_64,
-                 self._part_64),
-                ]:
-            byte_width = 8 // len(parts)
-            bit_width = 8 * byte_width
-            for i in range(len(parts)):
-                b_enabled = parts[i] & self.a[(i + 1) * bit_width - 1] \
-                    & self._a_signed[i * byte_width]
-                a_enabled = parts[i] & self.b[(i + 1) * bit_width - 1] \
-                    & self._b_signed[i * byte_width]
-
-                # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
-                # negation operation is split into a bitwise not and a +1.
-                # likewise for 16, 32, and 64-bit values.
-                m.d.comb += [
-                    not_a_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(a_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.a.part(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_a_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), a_enabled)),
-
-                    not_b_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(b_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.b.part(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_b_term.part(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), b_enabled))]
+        m.submodules.nat_or = nat_or = OrMod(128)
+        m.submodules.nbt_or = nbt_or = OrMod(128)
+        m.submodules.nla_or = nla_or = OrMod(128)
+        m.submodules.nlb_or = nlb_or = OrMod(128)
+        for l, mod in [(nat_l, nat_or),
+                             (nbt_l, nbt_or),
+                             (nla_l, nla_or),
+                             (nlb_l, nlb_or)]:
+            for i in range(len(l)):
+                m.d.comb += mod.orin[i].eq(l[i])
+            terms.append(mod.orout)
 
         expanded_part_pts = PartitionPoints()
         for i, v in self.part_pts.items():
 
         expanded_part_pts = PartitionPoints()
         for i, v in self.part_pts.items():
-            signal = Signal(name=f"expanded_part_pts_{i*2}")
+            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
             expanded_part_pts[i * 2] = signal
             m.d.comb += signal.eq(v)
 
             expanded_part_pts[i * 2] = signal
             m.d.comb += signal.eq(v)
 
@@ -609,35 +878,44 @@ class Mul8_16_32_64(Elaboratable):
                                expanded_part_pts)
         m.submodules.add_reduce = add_reduce
         m.d.comb += self._intermediate_output.eq(add_reduce.output)
                                expanded_part_pts)
         m.submodules.add_reduce = add_reduce
         m.d.comb += self._intermediate_output.eq(add_reduce.output)
-        m.d.comb += self._output_64.eq(
-            Mux(self._delayed_part_ops[-1][0] == OP_MUL_LOW,
-                self._intermediate_output.part(0, 64),
-                self._intermediate_output.part(64, 64)))
-        for i in range(2):
-            m.d.comb += self._output_32.part(i * 32, 32).eq(
-                Mux(self._delayed_part_ops[-1][4 * i] == OP_MUL_LOW,
-                    self._intermediate_output.part(i * 64, 32),
-                    self._intermediate_output.part(i * 64 + 32, 32)))
-        for i in range(4):
-            m.d.comb += self._output_16.part(i * 16, 16).eq(
-                Mux(self._delayed_part_ops[-1][2 * i] == OP_MUL_LOW,
-                    self._intermediate_output.part(i * 32, 16),
-                    self._intermediate_output.part(i * 32 + 16, 16)))
+        # create _output_64
+        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
+        m.d.comb += io64.intermed.eq(self._intermediate_output)
+        for i in range(8):
+            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
+
+        # create _output_32
+        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
+        m.d.comb += io32.intermed.eq(self._intermediate_output)
         for i in range(8):
         for i in range(8):
-            m.d.comb += self._output_8.part(i * 8, 8).eq(
-                Mux(self._delayed_part_ops[-1][i] == OP_MUL_LOW,
-                    self._intermediate_output.part(i * 16, 8),
-                    self._intermediate_output.part(i * 16 + 8, 8)))
+            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
+
+        # create _output_16
+        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
+        m.d.comb += io16.intermed.eq(self._intermediate_output)
+        for i in range(8):
+            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
+
+        # create _output_8
+        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
+        m.d.comb += io8.intermed.eq(self._intermediate_output)
         for i in range(8):
         for i in range(8):
-            m.d.comb += self.output.part(i * 8, 8).eq(
-                Mux(self._delayed_part_8[-1][i]
-                    | self._delayed_part_16[-1][i // 2],
-                    Mux(self._delayed_part_8[-1][i],
-                        self._output_8.part(i * 8, 8),
-                        self._output_16.part(i * 8, 8)),
-                    Mux(self._delayed_part_32[-1][i // 4],
-                        self._output_32.part(i * 8, 8),
-                        self._output_64.part(i * 8, 8))))
+            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
+
+        # final output
+        m.submodules.finalout = finalout = FinalOut(64)
+        for i in range(len(part_8.delayed_parts[-1])):
+            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
+        for i in range(len(part_16.delayed_parts[-1])):
+            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
+        for i in range(len(part_32.delayed_parts[-1])):
+            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
+        m.d.comb += finalout.i8.eq(io8.output)
+        m.d.comb += finalout.i16.eq(io16.output)
+        m.d.comb += finalout.i32.eq(io32.output)
+        m.d.comb += finalout.i64.eq(io64.output)
+        m.d.comb += self.output.eq(finalout.out)
+
         return m
 
 
         return m