""" Final stage of add reduce
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points,
+    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
                        partition_step=1):
+        self.lidx = lidx
         self.partition_step = partition_step
         self.output_width = output_width
         self.n_inputs = n_inputs
         return FinalReduceData(self.partition_points,
                                  self.output_width, self.n_parts)
 
+    def setup(self, m, i):
+        """Attach this stage to ``m`` and drive its input from ``i``.
+
+        :param m: the module to register with; this instance is added as
+                  the submodule named ``finaladd``.
+        :param i: the value assigned to ``self.i`` in the ``comb`` domain.
+        """
+        m.submodules.finaladd = self
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """Return this stage's output, ``self.o``.
+
+        :param i: not read here; ``setup`` is what assigns the input to
+                  ``self.i``.
+        """
+        return self.o
+
     def elaborate(self, platform):
         """Elaborate this module."""
         m = Module()
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, partition_points,
+    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
                        partition_step=1):
         """Create an ``AddReduce``.
 
         :param output_width: bit-width of ``output``.
         :param partition_points: the input partition points.
         """
+        self.lidx = lidx
         self.partition_step = partition_step
         self.n_inputs = n_inputs
         self.n_parts = n_parts
         return AddReduceData(self.partition_points, self.n_terms,
                              self.output_width, self.n_parts)
 
+    def setup(self, m, i):
+        """Attach this level to ``m`` and drive its input from ``i``.
+
+        :param m: the module to register with; this instance is added as
+                  the submodule named ``addreduce_<lidx>``, where
+                  ``<lidx>`` is ``self.lidx``.
+        :param i: the value assigned to ``self.i`` in the ``comb`` domain.
+        """
+        # setattr is used because the submodule name is built at runtime
+        setattr(m.submodules, "addreduce_%d" % self.lidx, self)
+        m.d.comb += self.i.eq(i)
+
+    def process(self, i):
+        """Return this level's output, ``self.o``.
+
+        :param i: not read here; ``setup`` is what assigns the input to
+                  ``self.i``.
+        """
+        return self.o
+
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
         retval = len(groups)*2
             groups = AddReduceSingle.full_adder_groups(len(inputs))
             if len(groups) == 0:
                 break
-            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
+            lidx = len(mods)
+            next_level = AddReduceSingle(lidx, ilen, self.output_width, n_parts,
                                          partition_points,
                                          self.partition_step)
             mods.append(next_level)
             ilen = len(inputs)
             part_ops = next_level.i.part_ops
 
-        next_level = FinalAdd(ilen, self.output_width, n_parts,
+        lidx = len(mods)
+        next_level = FinalAdd(lidx, ilen, self.output_width, n_parts,
                               partition_points, self.partition_step)
         mods.append(next_level)
 
         i = at.i
         for idx in range(len(at.levels)):
             mcur = at.levels[idx]
-            setattr(m.submodules, "addreduce_%d" % idx, mcur)
+            mcur.setup(m, i)
+            o = mcur.ospec()
             if idx in self.register_levels:
-                m.d.sync += mcur.i.eq(i)
+                m.d.sync += o.eq(mcur.process(i))
             else:
-                m.d.comb += mcur.i.eq(i)
-            i = mcur.o # for next loop
+                m.d.comb += o.eq(mcur.process(i))
+            i = o # for next loop
 
         interm = Intermediates(128, 8, part_pts)
         interm.setup(m, i)