removing recursion from AddReduce
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 20 Aug 2019 09:34:26 +0000 (10:34 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 20 Aug 2019 09:34:26 +0000 (10:34 +0100)
src/ieee754/part_mul_add/multiply.py

index 2c2c65a9b59eed533f6bb30005e5876e0079bd9f..f6dc3ba177daa13854a1cfe663db13fb384a4920 100644 (file)
@@ -325,12 +325,17 @@ class AddReduceSingle(Elaboratable):
             raise ValueError("partition_points doesn't fit in output_width")
         self._reg_partition_points = self.partition_points.like()
 
-        max_level = AddReduce.get_max_level(len(self.inputs))
+        max_level = AddReduceSingle.get_max_level(len(self.inputs))
         for level in self.register_levels:
             if level > max_level:
                 raise ValueError(
                     "not enough adder levels for specified register levels")
 
+        self.groups = AddReduceSingle.full_adder_groups(len(self.inputs))
+        self._intermediate_terms = []
+        if len(self.groups) != 0:
+            self.create_next_terms()
+
     @staticmethod
     def get_max_level(input_count):
         """Get the maximum level.
@@ -340,12 +345,13 @@ class AddReduceSingle(Elaboratable):
         """
         retval = 0
         while True:
-            groups = AddReduce.full_adder_groups(input_count)
+            groups = AddReduceSingle.full_adder_groups(input_count)
             if len(groups) == 0:
                 return retval
             input_count %= FULL_ADDER_INPUT_COUNT
             input_count += 2 * len(groups)
             retval += 1
+
     @staticmethod
     def full_adder_groups(input_count):
         """Get ``inputs`` indices for which a full adder should be built."""
@@ -353,7 +359,7 @@ class AddReduceSingle(Elaboratable):
                      input_count - FULL_ADDER_INPUT_COUNT + 1,
                      FULL_ADDER_INPUT_COUNT)
 
-    def _elaborate(self, platform):
+    def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
 
@@ -368,10 +374,12 @@ class AddReduceSingle(Elaboratable):
             m.d.comb += resized_input_assignments
             m.d.comb += self._reg_partition_points.eq(self.partition_points)
 
-        groups = AddReduceSingle.full_adder_groups(len(self.inputs))
+        for (value, term) in self._intermediate_terms:
+            m.d.comb += term.eq(value)
+
         # if there are no full adders to create, then we handle the base cases
         # and return, otherwise we go on to the recursive case
-        if len(groups) == 0:
+        if len(self.groups) == 0:
             if len(self.inputs) == 0:
                 # use 0 as the default output value
                 m.d.comb += self.output.eq(0)
@@ -379,8 +387,7 @@ class AddReduceSingle(Elaboratable):
                 # handle single input
                 m.d.comb += self.output.eq(self._resized_inputs[0])
             else:
-                # base case for adding 2 or more inputs, which get recursively
-                # reduced to 2 inputs
+                # base case for adding 2 inputs
                 assert len(self.inputs) == 2
                 adder = PartitionedAdder(len(self.output),
                                          self._reg_partition_points)
@@ -388,32 +395,44 @@ class AddReduceSingle(Elaboratable):
                 m.d.comb += adder.a.eq(self._resized_inputs[0])
                 m.d.comb += adder.b.eq(self._resized_inputs[1])
                 m.d.comb += self.output.eq(adder.output)
-            return None, m
+            return m
+
+        mask = self._reg_partition_points.as_mask(len(self.output))
+        m.d.comb += self.part_mask.eq(mask)
+
+        # add and link the intermediate term modules
+        for i, (iidx, adder_i) in enumerate(self.adders):
+            setattr(m.submodules, f"adder_{i}", adder_i)
+
+            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
+            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
+            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
+            m.d.comb += adder_i.mask.eq(self.part_mask)
+
+        return m
+
+    def create_next_terms(self):
 
         # go on to prepare recursive case
         intermediate_terms = []
+        _intermediate_terms = []
 
         def add_intermediate_term(value):
             intermediate_term = Signal(
                 len(self.output),
                 name=f"intermediate_terms[{len(intermediate_terms)}]")
+            _intermediate_terms.append((value, intermediate_term))
             intermediate_terms.append(intermediate_term)
-            m.d.comb += intermediate_term.eq(value)
 
         # store mask in intermediary (simplifies graph)
-        part_mask = Signal(len(self.output), reset_less=True)
-        mask = self._reg_partition_points.as_mask(len(self.output))
-        m.d.comb += part_mask.eq(mask)
+        self.part_mask = Signal(len(self.output), reset_less=True)
 
         # create full adders for this recursive level.
         # this shrinks N terms to 2 * (N // 3) plus the remainder
-        for i in groups:
+        self.adders = []
+        for i in self.groups:
             adder_i = MaskedFullAdder(len(self.output))
-            setattr(m.submodules, f"adder_{i}", adder_i)
-            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
-            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
-            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
-            m.d.comb += adder_i.mask.eq(part_mask)
+            self.adders.append((i, adder_i))
             # add both the sum and the masked-carry to the next level.
             # 3 inputs have now been reduced to 2...
             add_intermediate_term(adder_i.sum)
@@ -430,10 +449,11 @@ class AddReduceSingle(Elaboratable):
         else:
             assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
 
-        return intermediate_terms, m
+        self.intermediate_terms = intermediate_terms
+        self._intermediate_terms = _intermediate_terms
 
 
-class AddReduce(AddReduceSingle):
+class AddReduce(Elaboratable):
     """Recursively Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -454,28 +474,50 @@ class AddReduce(AddReduceSingle):
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        AddReduceSingle.__init__(self, inputs, output_width, register_levels,
-                                 partition_points)
+        self.inputs = inputs
+        self.output = Signal(output_width)
+        self.output_width = output_width
+        self.register_levels = register_levels
+        self.partition_points = partition_points
 
-    def next_register_levels(self):
+        self.create_levels()
+
+    @staticmethod
+    def next_register_levels(register_levels):
         """``Iterable`` of ``register_levels`` for next recursive level."""
-        for level in self.register_levels:
+        for level in register_levels:
             if level > 0:
                 yield level - 1
 
+    def create_levels(self):
+        """creates reduction levels"""
+
+        mods = []
+        next_levels = self.register_levels
+        partition_points = self.partition_points
+        inputs = self.inputs
+        while True:
+            next_level = AddReduceSingle(inputs, self.output_width, next_levels,
+                                 partition_points)
+            mods.append(next_level)
+            if len(next_level.groups) == 0:
+                break
+            next_levels = list(AddReduce.next_register_levels(next_levels))
+            partition_points = next_level._reg_partition_points
+            inputs = next_level.intermediate_terms
+
+        self.levels = mods
+
     def elaborate(self, platform):
         """Elaborate this module."""
-        intermediate_terms, m = AddReduceSingle._elaborate(self, platform)
-        if intermediate_terms is None:
-            return m
+        m = Module()
+
+        for i, next_level in enumerate(self.levels):
+            setattr(m.submodules, "next_level%d" % i, next_level)
 
-        # recursive invocation of ``AddReduce``
-        next_level = AddReduce(intermediate_terms,
-                               len(self.output),
-                               self.next_register_levels(),
-                               self._reg_partition_points)
-        m.submodules.next_level = next_level
+        # output comes from last module
         m.d.comb += self.output.eq(next_level.output)
+
         return m