split "actionable" part of AddReduce out from "recursive" part
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 20 Aug 2019 06:13:47 +0000 (07:13 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 20 Aug 2019 06:13:47 +0000 (07:13 +0100)
src/ieee754/part_mul_add/multiply.py

index 4c3a3cf177af557c2d0c1fa8520847ae6fa33cfb..e86f655b53b6ffe5d06086a36ca62f2523a7850f 100644 (file)
@@ -269,7 +269,7 @@ class PartitionedAdder(Elaboratable):
 FULL_ADDER_INPUT_COUNT = 3
 
 
-class AddReduce(Elaboratable):
+class AddReduceSingle(Elaboratable):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -300,6 +300,7 @@ class AddReduce(Elaboratable):
         if not self.partition_points.fits_in_width(output_width):
             raise ValueError("partition_points doesn't fit in output_width")
         self._reg_partition_points = self.partition_points.like()
+
         max_level = AddReduce.get_max_level(len(self.inputs))
         for level in self.register_levels:
             if level > max_level:
@@ -321,13 +322,6 @@ class AddReduce(Elaboratable):
             input_count %= FULL_ADDER_INPUT_COUNT
             input_count += 2 * len(groups)
             retval += 1
-
-    def next_register_levels(self):
-        """``Iterable`` of ``register_levels`` for next recursive level."""
-        for level in self.register_levels:
-            if level > 0:
-                yield level - 1
-
     @staticmethod
     def full_adder_groups(input_count):
         """Get ``inputs`` indices for which a full adder should be built."""
@@ -335,7 +329,7 @@ class AddReduce(Elaboratable):
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
-    def elaborate(self, platform):
+    def _elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
@@ -350,7 +344,7 @@ class AddReduce(Elaboratable):
             m.d.comb += resized_input_assignments
             m.d.comb += self._reg_partition_points.eq(self.partition_points)
 
-        groups = AddReduce.full_adder_groups(len(self.inputs))
+        groups = AddReduceSingle.full_adder_groups(len(self.inputs))
         # if there are no full adders to create, then we handle the base cases
         # and return, otherwise we go on to the recursive case
         if len(groups) == 0:
@@ -370,8 +364,9 @@ class AddReduce(Elaboratable):
                 m.d.comb += adder.a.eq(self._resized_inputs[0])
                 m.d.comb += adder.b.eq(self._resized_inputs[1])
                 m.d.comb += self.output.eq(adder.output)
-            return m
-        # go on to handle recursive case
+            return None, m
+
+        # go on to prepare recursive case
         intermediate_terms = []
 
         def add_intermediate_term(value):
@@ -410,6 +405,46 @@ class AddReduce(Elaboratable):
             add_intermediate_term(self._resized_inputs[-1])
         else:
             assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
+
+        return intermediate_terms, m
+
+
+class AddReduce(AddReduceSingle):
+    """Recursively Add list of numbers together.
+
+    :attribute inputs: input ``Signal``s to be summed. Modification not
+        supported, except for by ``Signal.eq``.
+    :attribute register_levels: List of nesting levels that should have
+        pipeline registers.
+    :attribute output: output sum.
+    :attribute partition_points: the input partition points. Modification not
+        supported, except for by ``Signal.eq``.
+    """
+
+    def __init__(self, inputs, output_width, register_levels, partition_points):
+        """Create an ``AddReduce``.
+
+        :param inputs: input ``Signal``s to be summed.
+        :param output_width: bit-width of ``output``.
+        :param register_levels: List of nesting levels that should have
+            pipeline registers.
+        :param partition_points: the input partition points.
+        """
+        AddReduceSingle.__init__(self, inputs, output_width, register_levels,
+                                 partition_points)
+
+    def next_register_levels(self):
+        """``Iterable`` of ``register_levels`` for next recursive level."""
+        for level in self.register_levels:
+            if level > 0:
+                yield level - 1
+
+    def elaborate(self, platform):
+        """Elaborate this module."""
+        intermediate_terms, m = AddReduceSingle._elaborate(self, platform)
+        if intermediate_terms is None:
+            return m
+
         # recursive invocation of ``AddReduce``
         next_level = AddReduce(intermediate_terms,
                                len(self.output),