whitespace
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index 81c5a8ac0f3f3d535bcfe98a8a7431ba1c8df024..517e6cf9e2835b712b95c43e7832f04283a5bd3d 100644 (file)
@@ -140,7 +140,7 @@ class FullAdder(Elaboratable):
         return m
 
 
-class MaskedFullAdder(FullAdder):
+class MaskedFullAdder(Elaboratable):
     """Masked Full Adder.
 
     :attribute mask: the carry partition mask
@@ -153,6 +153,13 @@ class MaskedFullAdder(FullAdder):
     FullAdders are always used with a "mask" on the output.  To keep
     the graphviz "clean", this class performs the masking here rather
     than inside a large for-loop.
+
+    See the following discussion as to why this is no longer derived
+    from FullAdder.  Each carry is shifted here *before* being ANDed
+    with the mask, so that an AOI cell may be used (which is more
+    gate-efficient)
+    https://en.wikipedia.org/wiki/AND-OR-Invert
+    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
     """
 
     def __init__(self, width):
@@ -160,14 +167,31 @@ class MaskedFullAdder(FullAdder):
 
         :param width: the bit width of the input and output
         """
-        FullAdder.__init__(self, width)
-        self.mask = Signal(width)
-        self.mcarry = Signal(width)
+        self.width = width
+        self.mask = Signal(width, reset_less=True)
+        self.mcarry = Signal(width, reset_less=True)
+        self.in0 = Signal(width, reset_less=True)
+        self.in1 = Signal(width, reset_less=True)
+        self.in2 = Signal(width, reset_less=True)
+        self.sum = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
         """Elaborate this module."""
-        m = FullAdder.elaborate(self, platform)
-        m.d.comb += self.mcarry.eq((self.carry << 1) & self.mask)
+        m = Module()
+        s1 = Signal(self.width, reset_less=True)
+        s2 = Signal(self.width, reset_less=True)
+        s3 = Signal(self.width, reset_less=True)
+        c1 = Signal(self.width, reset_less=True)
+        c2 = Signal(self.width, reset_less=True)
+        c3 = Signal(self.width, reset_less=True)
+        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
+        m.d.comb += s1.eq(Cat(0, self.in0))
+        m.d.comb += s2.eq(Cat(0, self.in1))
+        m.d.comb += s3.eq(Cat(0, self.in2))
+        m.d.comb += c1.eq(s1 & s2 & self.mask)
+        m.d.comb += c2.eq(s2 & s3 & self.mask)
+        m.d.comb += c3.eq(s3 & s1 & self.mask)
+        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
         return m
 
 
@@ -179,6 +203,14 @@ class PartitionedAdder(Elaboratable):
     to the next bit.  Then the final output *removes* the extra bits from
     the result.
 
+    partition: .... P... P... P... P... (32 bits)
+    a        : .... .... .... .... .... (32 bits)
+    b        : .... .... .... .... .... (32 bits)
+    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
+    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
+    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
+    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
+
     :attribute width: the bit width of the input and output. Read-only.
     :attribute a: the first input to the adder
     :attribute b: the second input to the adder
@@ -210,9 +242,9 @@ class PartitionedAdder(Elaboratable):
         # simulation bugs involving sync.  it is *not* necessary to
         # have them here, they should (under normal circumstances)
         # be moved into elaborate, as they are entirely local
-        self._expanded_a = Signal(expanded_width)
-        self._expanded_b = Signal(expanded_width)
-        self._expanded_output = Signal(expanded_width)
+        self._expanded_a = Signal(expanded_width) # includes extra part-points
+        self._expanded_b = Signal(expanded_width) # likewise.
+        self._expanded_o = Signal(expanded_width) # likewise.
 
     def elaborate(self, platform):
         """Elaborate this module."""
@@ -236,13 +268,13 @@ class PartitionedAdder(Elaboratable):
                 ea.append(self._expanded_a[expanded_index])
                 al.append(~self.partition_points[i]) # add extra bit in a
                 eb.append(self._expanded_b[expanded_index])
-                bl.append(C(0)) # do *not* add extra bit into b.
-                expanded_index += 1
+                bl.append(C(0)) # yes, add a zero
+                expanded_index += 1 # skip the extra point.  NOT in the output
             ea.append(self._expanded_a[expanded_index])
-            al.append(self.a[i])
             eb.append(self._expanded_b[expanded_index])
+            eo.append(self._expanded_o[expanded_index])
+            al.append(self.a[i])
             bl.append(self.b[i])
-            eo.append(self._expanded_output[expanded_index])
             ol.append(self.output[i])
             expanded_index += 1
 
@@ -253,7 +285,7 @@ class PartitionedAdder(Elaboratable):
 
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
-        m.d.comb += self._expanded_output.eq(
+        m.d.comb += self._expanded_o.eq(
             self._expanded_a + self._expanded_b)
         return m
 
@@ -261,7 +293,7 @@ class PartitionedAdder(Elaboratable):
 FULL_ADDER_INPUT_COUNT = 3
 
 
-class AddReduce(Elaboratable):
+class AddReduceSingle(Elaboratable):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -292,12 +324,18 @@ class AddReduce(Elaboratable):
         if not self.partition_points.fits_in_width(output_width):
             raise ValueError("partition_points doesn't fit in output_width")
         self._reg_partition_points = self.partition_points.like()
-        max_level = AddReduce.get_max_level(len(self.inputs))
+
+        max_level = AddReduceSingle.get_max_level(len(self.inputs))
         for level in self.register_levels:
             if level > max_level:
                 raise ValueError(
                     "not enough adder levels for specified register levels")
 
+        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
+        self._intermediate_terms = []
+        if len(self.groups) != 0:
+            self.create_next_terms()
+
     @staticmethod
     def get_max_level(input_count):
         """Get the maximum level.
@@ -307,19 +345,13 @@ class AddReduce(Elaboratable):
         """
         retval = 0
         while True:
-            groups = AddReduce.full_adder_groups(input_count)
+            groups = AddReduceSingle.full_adder_groups(input_count)
             if len(groups) == 0:
                 return retval
             input_count %= FULL_ADDER_INPUT_COUNT
             input_count += 2 * len(groups)
             retval += 1
 
-    def next_register_levels(self):
-        """``Iterable`` of ``register_levels`` for next recursive level."""
-        for level in self.register_levels:
-            if level > 0:
-                yield level - 1
-
     @staticmethod
     def full_adder_groups(input_count):
         """Get ``inputs`` indices for which a full adder should be built."""
@@ -342,10 +374,12 @@ class AddReduce(Elaboratable):
             m.d.comb += resized_input_assignments
             m.d.comb += self._reg_partition_points.eq(self.partition_points)
 
-        groups = AddReduce.full_adder_groups(len(self.inputs))
+        for (value, term) in self._intermediate_terms:
+            m.d.comb += term.eq(value)
+
         # if there are no full adders to create, then we handle the base cases
         # and return, otherwise we go on to the recursive case
-        if len(groups) == 0:
+        if len(self.groups) == 0:
             if len(self.inputs) == 0:
                 # use 0 as the default output value
                 m.d.comb += self.output.eq(0)
@@ -353,8 +387,7 @@ class AddReduce(Elaboratable):
                 # handle single input
                 m.d.comb += self.output.eq(self._resized_inputs[0])
             else:
-                # base case for adding 2 or more inputs, which get recursively
-                # reduced to 2 inputs
+                # base case for adding 2 inputs
                 assert len(self.inputs) == 2
                 adder = PartitionedAdder(len(self.output),
                                          self._reg_partition_points)
@@ -363,32 +396,46 @@ class AddReduce(Elaboratable):
                 m.d.comb += adder.b.eq(self._resized_inputs[1])
                 m.d.comb += self.output.eq(adder.output)
             return m
-        # go on to handle recursive case
+
+        mask = self._reg_partition_points.as_mask(len(self.output))
+        m.d.comb += self.part_mask.eq(mask)
+
+        # add and link the intermediate term modules
+        for i, (iidx, adder_i) in enumerate(self.adders):
+            setattr(m.submodules, f"adder_{i}", adder_i)
+
+            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
+            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
+            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
+            m.d.comb += adder_i.mask.eq(self.part_mask)
+
+        return m
+
+    def create_next_terms(self):
+
+        # go on to prepare recursive case
         intermediate_terms = []
+        _intermediate_terms = []
 
         def add_intermediate_term(value):
             intermediate_term = Signal(
                 len(self.output),
                 name=f"intermediate_terms[{len(intermediate_terms)}]")
+            _intermediate_terms.append((value, intermediate_term))
             intermediate_terms.append(intermediate_term)
-            m.d.comb += intermediate_term.eq(value)
 
         # store mask in intermediary (simplifies graph)
-        part_mask = Signal(len(self.output), reset_less=True)
-        mask = self._reg_partition_points.as_mask(len(self.output))
-        m.d.comb += part_mask.eq(mask)
+        self.part_mask = Signal(len(self.output), reset_less=True)
 
         # create full adders for this recursive level.
         # this shrinks N terms to 2 * (N // 3) plus the remainder
-        for i in groups:
+        self.adders = []
+        for i in self.groups:
             adder_i = MaskedFullAdder(len(self.output))
-            setattr(m.submodules, f"adder_{i}", adder_i)
-            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
-            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
-            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
-            m.d.comb += adder_i.mask.eq(part_mask)
+            self.adders.append((i, adder_i))
+            # add both the sum and the masked-carry to the next level.
+            # 3 inputs have now been reduced to 2...
             add_intermediate_term(adder_i.sum)
-            # mask out carry bits to prevent carries between partitions
             add_intermediate_term(adder_i.mcarry)
         # handle the remaining inputs.
         if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
@@ -401,13 +448,76 @@ class AddReduce(Elaboratable):
             add_intermediate_term(self._resized_inputs[-1])
         else:
             assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
-        # recursive invocation of ``AddReduce``
-        next_level = AddReduce(intermediate_terms,
-                               len(self.output),
-                               self.next_register_levels(),
-                               self._reg_partition_points)
-        m.submodules.next_level = next_level
+
+        self.intermediate_terms = intermediate_terms
+        self._intermediate_terms = _intermediate_terms
+
+
+class AddReduce(Elaboratable):
+    """Recursively Add list of numbers together.
+
+    :attribute inputs: input ``Signal``s to be summed. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute register_levels: List of nesting levels that should have
+        pipeline registers.
+    :attribute output: output sum.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
+
+    def __init__(self, inputs, output_width, register_levels, partition_points):
+        """Create an ``AddReduce``.
+
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param register_levels: List of nesting levels that should have
+            pipeline registers.
+        :param partition_points: the input partition points.
+        """
+        self.inputs = inputs
+        self.output = Signal(output_width)
+        self.output_width = output_width
+        self.register_levels = register_levels
+        self.partition_points = partition_points
+
+        self.create_levels()
+
+    @staticmethod
+    def next_register_levels(register_levels):
+        """``Iterable`` of ``register_levels`` for next recursive level."""
+        for level in register_levels:
+            if level > 0:
+                yield level - 1
+
+    def create_levels(self):
+        """creates reduction levels"""
+
+        mods = []
+        next_levels = self.register_levels
+        partition_points = self.partition_points
+        inputs = self.inputs
+        while True:
+            next_level = AddReduceSingle(inputs, self.output_width, next_levels,
+                                 partition_points)
+            mods.append(next_level)
+            if len(next_level.groups) == 0:
+                break
+            next_levels = list(AddReduce.next_register_levels(next_levels))
+            partition_points = next_level._reg_partition_points
+            inputs = next_level.intermediate_terms
+
+        self.levels = mods
+
+    def elaborate(self, platform):
+        """Elaborate this module."""
+        m = Module()
+
+        for i, next_level in enumerate(self.levels):
+            setattr(m.submodules, "next_level%d" % i, next_level)
+
+        # output comes from last module
         m.d.comb += self.output.eq(next_level.output)
+
         return m
 
 
@@ -532,6 +642,7 @@ class ProductTerms(Elaboratable):
 
         return m
 
+
 class LSBNegTerm(Elaboratable):
 
     def __init__(self, bit_width):