use PipelineSpec and PipeModBase in AddReduce
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 26 Aug 2019 08:48:11 +0000 (09:48 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 26 Aug 2019 08:48:11 +0000 (09:48 +0100)
src/ieee754/part_mul_add/multiply.py

index 7202c9cf653d99c30ded41442eab875dcfd12532..51f994d25c4359cec5a246fe81c0f719b8411a86 100644 (file)
@@ -406,7 +406,7 @@ class FinalAdd(Elaboratable):
         return m
 
 
-class AddReduceSingle(Elaboratable):
+class AddReduceSingle(PipeModBase):
     """Add list of numbers together.
 
     :attribute inputs: input ``Signal``s to be summed. Modification not
@@ -418,7 +418,7 @@ class AddReduceSingle(Elaboratable):
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
+    def __init__(self, pspec, lidx, n_inputs, partition_points,
                        partition_step=1):
         """Create an ``AddReduce``.
 
@@ -429,17 +429,16 @@ class AddReduceSingle(Elaboratable):
         self.lidx = lidx
         self.partition_step = partition_step
         self.n_inputs = n_inputs
-        self.n_parts = n_parts
-        self.output_width = output_width
+        self.n_parts = pspec.n_parts
+        self.output_width = pspec.width * 2
         self.partition_points = PartitionPoints(partition_points)
-        if not self.partition_points.fits_in_width(output_width):
+        if not self.partition_points.fits_in_width(self.output_width):
             raise ValueError("partition_points doesn't fit in output_width")
 
         self.groups = AddReduceSingle.full_adder_groups(n_inputs)
         self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
 
-        self.i = self.ispec()
-        self.o = self.ospec()
+        super().__init__(pspec, "addreduce_%d" % lidx)
 
     def ispec(self):
         return AddReduceData(self.partition_points, self.n_inputs,
@@ -449,13 +448,6 @@ class AddReduceSingle(Elaboratable):
         return AddReduceData(self.partition_points, self.n_terms,
                              self.output_width, self.n_parts)
 
-    def setup(self, m, i):
-        setattr(m.submodules, "addreduce_%d" % self.lidx, self)
-        m.d.comb += self.i.eq(i)
-
-    def process(self, i):
-        return self.o
-
     @staticmethod
     def calc_n_inputs(n_inputs, groups):
         retval = len(groups)*2
@@ -566,7 +558,7 @@ class AddReduceInternal:
         supported, except for by ``Signal.eq``.
     """
 
-    def __init__(self, i, output_width, partition_step=1):
+    def __init__(self, i, pspec, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -574,9 +566,10 @@ class AddReduceInternal:
         :param partition_points: the input partition points.
         """
         self.i = i
+        self.pspec = pspec
         self.inputs = i.terms
         self.part_ops = i.part_ops
-        self.output_width = output_width
+        self.output_width = pspec.width * 2
         self.partition_points = i.part_pts
         self.partition_step = partition_step
 
@@ -596,7 +589,7 @@ class AddReduceInternal:
             if len(groups) == 0:
                 break
             lidx = len(mods)
-            next_level = AddReduceSingle(lidx, ilen, self.output_width, n_parts,
+            next_level = AddReduceSingle(self.pspec, lidx, ilen,
                                          partition_points,
                                          self.partition_step)
             mods.append(next_level)
@@ -1185,12 +1178,12 @@ class AllTerms(PipeModBase):
     """Set of terms to be added together
     """
 
-    def __init__(self, pspec):
+    def __init__(self, pspec, n_inputs):
         """Create an ``AllTerms``.
         """
-        self.n_inputs = pspec.n_inputs
+        self.n_inputs = n_inputs
         self.n_parts = pspec.n_parts
-        self.output_width = pspec.width
+        self.output_width = pspec.width * 2
         super().__init__(pspec, "allterms")
 
     def ispec(self):
@@ -1380,8 +1373,7 @@ class Mul8_16_32_64(Elaboratable):
 
         self.id_wid = 0 # num_bits(num_rows)
         self.op_wid = 0
-        self.pspec = PipelineSpec(128, self.id_wid, self.op_wid, n_ops=3)
-        self.pspec.n_inputs = 64 + 4
+        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
         self.pspec.n_parts = 8
 
         # parameter(s)
@@ -1412,14 +1404,14 @@ class Mul8_16_32_64(Elaboratable):
         part_pts = self.part_pts
 
         n_parts = self.pspec.n_parts
-        n_inputs = self.pspec.n_inputs
-        output_width = self.pspec.width
-        t = AllTerms(self.pspec)
+        n_inputs = 64 + 4
+        output_width = self.pspec.width * 2
+        t = AllTerms(self.pspec, n_inputs)
         t.setup(m, self.i)
 
         terms = t.o.terms
 
-        at = AddReduceInternal(t.process(self.i), 128, partition_step=2)
+        at = AddReduceInternal(t.process(self.i), self.pspec, partition_step=2)
 
         i = at.i
         for idx in range(len(at.levels)):