From 9b685e27a9aeebf19c250b411c6e471cc8d4bcd8 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Tue, 20 Aug 2019 10:34:26 +0100 Subject: [PATCH] removing recursion from AddReduce --- src/ieee754/part_mul_add/multiply.py | 108 +++++++++++++++++++-------- 1 file changed, 75 insertions(+), 33 deletions(-) diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 2c2c65a9..f6dc3ba1 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -325,12 +325,17 @@ class AddReduceSingle(Elaboratable): raise ValueError("partition_points doesn't fit in output_width") self._reg_partition_points = self.partition_points.like() - max_level = AddReduce.get_max_level(len(self.inputs)) + max_level = AddReduceSingle.get_max_level(len(self.inputs)) for level in self.register_levels: if level > max_level: raise ValueError( "not enough adder levels for specified register levels") + self.groups = AddReduceSingle.full_adder_groups(len(self.inputs)) + self._intermediate_terms = [] + if len(self.groups) != 0: + self.create_next_terms() + @staticmethod def get_max_level(input_count): """Get the maximum level. @@ -340,12 +345,13 @@ class AddReduceSingle(Elaboratable): """ retval = 0 while True: - groups = AddReduce.full_adder_groups(input_count) + groups = AddReduceSingle.full_adder_groups(input_count) if len(groups) == 0: return retval input_count %= FULL_ADDER_INPUT_COUNT input_count += 2 * len(groups) retval += 1 + @staticmethod def full_adder_groups(input_count): """Get ``inputs`` indices for which a full adder should be built.""" @@ -353,7 +359,7 @@ class AddReduceSingle(Elaboratable): input_count - FULL_ADDER_INPUT_COUNT + 1, FULL_ADDER_INPUT_COUNT) - def _elaborate(self, platform): + def elaborate(self, platform): """Elaborate this module.""" m = Module() @@ -368,10 +374,12 @@ class AddReduceSingle(Elaboratable): m.d.comb += resized_input_assignments m.d.comb += self._reg_partition_points.eq(self.partition_points) - groups = AddReduceSingle.full_adder_groups(len(self.inputs)) + for (value, term) in self._intermediate_terms: + m.d.comb += term.eq(value) + # if there are no full adders to create, then we handle the base cases # and return, otherwise we go on to the recursive case - if len(groups) == 0: + if len(self.groups) == 0: if len(self.inputs) == 0: # use 0 as the default output value m.d.comb += self.output.eq(0) @@ -379,8 +387,7 @@ class AddReduceSingle(Elaboratable): # handle single input m.d.comb += self.output.eq(self._resized_inputs[0]) else: - # base case for adding 2 or more inputs, which get recursively - # reduced to 2 inputs + # base case for adding 2 inputs assert len(self.inputs) == 2 adder = PartitionedAdder(len(self.output), self._reg_partition_points) @@ -388,32 +395,44 @@ class AddReduceSingle(Elaboratable): m.d.comb += adder.a.eq(self._resized_inputs[0]) m.d.comb += adder.b.eq(self._resized_inputs[1]) m.d.comb += self.output.eq(adder.output) - return None, m + return m + + mask = self._reg_partition_points.as_mask(len(self.output)) + m.d.comb += self.part_mask.eq(mask) + + # add and link the intermediate term modules + for i, (iidx, adder_i) in enumerate(self.adders): + setattr(m.submodules, f"adder_{i}", adder_i) + + m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx]) + m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1]) + m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2]) + m.d.comb += adder_i.mask.eq(self.part_mask) + + return m + + def create_next_terms(self): # go on to prepare recursive case intermediate_terms = [] + _intermediate_terms = [] def add_intermediate_term(value): intermediate_term = Signal( len(self.output), name=f"intermediate_terms[{len(intermediate_terms)}]") + _intermediate_terms.append((value, intermediate_term)) intermediate_terms.append(intermediate_term) - m.d.comb += intermediate_term.eq(value) # store mask in intermediary (simplifies graph) - part_mask = Signal(len(self.output), reset_less=True) - mask = self._reg_partition_points.as_mask(len(self.output)) - m.d.comb += part_mask.eq(mask) + self.part_mask = Signal(len(self.output), reset_less=True) # create full adders for this recursive level. # this shrinks N terms to 2 * (N // 3) plus the remainder - for i in groups: + self.adders = [] + for i in self.groups: adder_i = MaskedFullAdder(len(self.output)) - setattr(m.submodules, f"adder_{i}", adder_i) - m.d.comb += adder_i.in0.eq(self._resized_inputs[i]) - m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1]) - m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2]) - m.d.comb += adder_i.mask.eq(part_mask) + self.adders.append((i, adder_i)) # add both the sum and the masked-carry to the next level. # 3 inputs have now been reduced to 2... add_intermediate_term(adder_i.sum) @@ -430,10 +449,11 @@ class AddReduceSingle(Elaboratable): else: assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0 - return intermediate_terms, m + self.intermediate_terms = intermediate_terms + self._intermediate_terms = _intermediate_terms -class AddReduce(AddReduceSingle): +class AddReduce(Elaboratable): """Recursively Add list of numbers together. :attribute inputs: input ``Signal``s to be summed. Modification not @@ -454,28 +474,50 @@ class AddReduce(AddReduceSingle): pipeline registers. :param partition_points: the input partition points. """ - AddReduceSingle.__init__(self, inputs, output_width, register_levels, - partition_points) + self.inputs = inputs + self.output = Signal(output_width) + self.output_width = output_width + self.register_levels = register_levels + self.partition_points = partition_points - def next_register_levels(self): + self.create_levels() + + @staticmethod + def next_register_levels(register_levels): """``Iterable`` of ``register_levels`` for next recursive level.""" - for level in self.register_levels: + for level in register_levels: if level > 0: yield level - 1 + def create_levels(self): + """creates reduction levels""" + + mods = [] + next_levels = self.register_levels + partition_points = self.partition_points + inputs = self.inputs + while True: + next_level = AddReduceSingle(inputs, self.output_width, next_levels, + partition_points) + mods.append(next_level) + if len(next_level.groups) == 0: + break + next_levels = list(AddReduce.next_register_levels(next_levels)) + partition_points = next_level._reg_partition_points + inputs = next_level.intermediate_terms + + self.levels = mods + def elaborate(self, platform): """Elaborate this module.""" - intermediate_terms, m = AddReduceSingle._elaborate(self, platform) - if intermediate_terms is None: - return m + m = Module() + + for i, next_level in enumerate(self.levels): + setattr(m.submodules, "next_level%d" % i, next_level) - # recursive invocation of ``AddReduce`` - next_level = AddReduce(intermediate_terms, - len(self.output), - self.next_register_levels(), - self._reg_partition_points) - m.submodules.next_level = next_level + # output comes from last module m.d.comb += self.output.eq(next_level.output) + return m -- 2.30.2