add to docstrings in PartitionedAdder
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index db1f267bdadc2e79ed2c188dd735e6b6c7af10eb..1d3ceea2c5b4062d0ea663f786ee7b9866fabf0e 100644 (file)
@@ -140,9 +140,53 @@ class FullAdder(Elaboratable):
         return m
 
 
+class MaskedFullAdder(FullAdder):
+    """Masked Full Adder.
+
+    :attribute mask: the carry partition mask
+    :attribute in0: the first input
+    :attribute in1: the second input
+    :attribute in2: the third input
+    :attribute sum: the sum output
+    :attribute mcarry: the masked carry output
+
+    FullAdders are always used with a "mask" on the output.  To keep
+    the graphviz "clean", this class performs the masking here rather
+    than inside a large for-loop.
+    """
+
+    def __init__(self, width):
+        """Create a ``MaskedFullAdder``.
+
+        :param width: the bit width of the input and output
+        """
+        FullAdder.__init__(self, width)
+        self.mask = Signal(width)
+        self.mcarry = Signal(width)
+
+    def elaborate(self, platform):
+        """Elaborate this module."""
+        m = FullAdder.elaborate(self, platform)
+        m.d.comb += self.mcarry.eq((self.carry << 1) & self.mask)
+        return m
+
+
 class PartitionedAdder(Elaboratable):
     """Partitioned Adder.
 
+    Performs the final add.  The partition points are included in the
+    actual add (in one of the operands only), which causes a carry over
+    to the next bit.  Then the final output *removes* the extra bits from
+    the result.
+
+    partition: .... P... P... P... P... (32 bits)
+    a        : .... .... .... .... .... (32 bits)
+    b        : .... .... .... .... .... (32 bits)
+    exp-a    : ....P....P....P....P.... (32+4 bits)
+    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
+    exp-o    : ....xN...xN...xN...xN... (32+4 bits)
+    o        : .... N... N... N... N... (32 bits)
+
     :attribute width: the bit width of the input and output. Read-only.
     :attribute a: the first input to the adder
     :attribute b: the second input to the adder
@@ -174,9 +218,9 @@ class PartitionedAdder(Elaboratable):
         # simulation bugs involving sync.  it is *not* necessary to
         # have them here, they should (under normal circumstances)
         # be moved into elaborate, as they are entirely local
-        self._expanded_a = Signal(expanded_width)
-        self._expanded_b = Signal(expanded_width)
-        self._expanded_output = Signal(expanded_width)
+        self._expanded_a = Signal(expanded_width) # includes extra part-points
+        self._expanded_b = Signal(expanded_width) # likewise.
+        self._expanded_o = Signal(expanded_width) # likewise.
 
     def elaborate(self, platform):
         """Elaborate this module."""
@@ -185,31 +229,39 @@ class PartitionedAdder(Elaboratable):
         # store bits in a list, use Cat later.  graphviz is much cleaner
         al, bl, ol, ea, eb, eo = [],[],[],[],[],[]
 
-        # partition points are "breaks" (extra zeros) in what would otherwise
-        # be a massive long add.
+        # partition points are "breaks" (extra zeros or 1s) in what would
+        # otherwise be a massive long add.  when the "break" points are 0,
+        # whatever is in it (in the output) is discarded.  however when
+        # there is a "1", it causes a roll-over carry to the *next* bit.
+        # we still ignore the "break" bit in the [intermediate] output,
+        # however by that time we've got the effect that we wanted: the
+        # carry has been carried *over* the break point.
+
         for i in range(self.width):
             if i in self.partition_points:
                 # add extra bit set to 0 + 0 for enabled partition points
                 # and 1 + 0 for disabled partition points
                 ea.append(self._expanded_a[expanded_index])
-                al.append(~self.partition_points[i])
+                al.append(~self.partition_points[i]) # add extra bit in a
                 eb.append(self._expanded_b[expanded_index])
-                bl.append(C(0))
-                expanded_index += 1
+                bl.append(C(0)) # do *not* add extra bit into b.
+                expanded_index += 1 # skip the extra point.  NOT in the output
             ea.append(self._expanded_a[expanded_index])
-            al.append(self.a[i])
             eb.append(self._expanded_b[expanded_index])
+            eo.append(self._expanded_o[expanded_index])
+            al.append(self.a[i])
             bl.append(self.b[i])
-            eo.append(self._expanded_output[expanded_index])
             ol.append(self.output[i])
             expanded_index += 1
+
         # combine above using Cat
         m.d.comb += Cat(*ea).eq(Cat(*al))
         m.d.comb += Cat(*eb).eq(Cat(*bl))
         m.d.comb += Cat(*ol).eq(Cat(*eo))
+
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
-        m.d.comb += self._expanded_output.eq(
+        m.d.comb += self._expanded_o.eq(
             self._expanded_a + self._expanded_b)
         return m
 
@@ -337,15 +389,15 @@ class AddReduce(Elaboratable):
         # create full adders for this recursive level.
         # this shrinks N terms to 2 * (N // 3) plus the remainder
         for i in groups:
-            adder_i = FullAdder(len(self.output))
+            adder_i = MaskedFullAdder(len(self.output))
             setattr(m.submodules, f"adder_{i}", adder_i)
             m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
             m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
             m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
+            m.d.comb += adder_i.mask.eq(part_mask)
             add_intermediate_term(adder_i.sum)
-            shifted_carry = adder_i.carry << 1
             # mask out carry bits to prevent carries between partitions
-            add_intermediate_term((adder_i.carry << 1) & part_mask)
+            add_intermediate_term(adder_i.mcarry)
         # handle the remaining inputs.
         if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
             add_intermediate_term(self._resized_inputs[-1])
@@ -565,14 +617,17 @@ class Part(Elaboratable):
         m = Module()
 
         pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
+        # negated-temporary copy of partition bits
+        npbs = Signal.like(pbs, reset_less=True)
+        m.d.comb += npbs.eq(~pbs)
         byte_count = 8 // len(parts)
         for i in range(len(parts)):
             pbl = []
-            pbl.append(~pbs[i * byte_count - 1])
+            pbl.append(npbs[i * byte_count - 1])
             for j in range(i * byte_count, (i + 1) * byte_count - 1):
                 pbl.append(pbs[j])
-            pbl.append(~pbs[(i + 1) * byte_count - 1])
-            value = Signal(len(pbl), reset_less=True)
+            pbl.append(npbs[(i + 1) * byte_count - 1])
+            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
             m.d.comb += value.eq(Cat(*pbl))
             m.d.comb += parts[i].eq(~(value).bool())
             m.d.comb += delayed_parts[0][i].eq(parts[i])
@@ -590,7 +645,7 @@ class Part(Elaboratable):
         for i in range(len(parts)):
             # work out bit-inverted and +1 term for a.
             pa = LSBNegTerm(bit_wid)
-            setattr(m.submodules, "lnt_a_%d" % i, pa)
+            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
             m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
@@ -600,7 +655,7 @@ class Part(Elaboratable):
 
             # work out bit-inverted and +1 term for b
             pb = LSBNegTerm(bit_wid)
-            setattr(m.submodules, "lnt_b_%d" % i, pb)
+            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
             m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
             m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a