rename AllTermsData to InputData, use as input to base class Mul8_16_32_64
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 10:25:58 +0000 (11:25 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 10:25:58 +0000 (11:25 +0100)
src/ieee754/part_mul_add/multiply.py

index bebad9bc149c2470175a045dc57c3467351f3ff8..e0531b420e2cc885b2e10c42acec581c8e852c0b 100644 (file)
@@ -1065,12 +1065,14 @@ class IntermediateData:
                             rhs.intermediate_output, rhs.part_ops)
 
 
-class AllTermsData:
+class InputData:
 
-    def __init__(self, partition_points):
+    def __init__(self):
         self.a = Signal(64)
         self.b = Signal(64)
-        self.part_pts = partition_points.like()
+        self.part_pts = PartitionPoints()
+        for i in range(8, 64, 8):
+            self.part_pts[i] = Signal(name=f"part_pts_{i}")
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
 
     def eq_from(self, part_pts, inputs, part_ops):
@@ -1087,8 +1089,7 @@ class AllTerms(Elaboratable):
     """Set of terms to be added together
     """
 
-    def __init__(self, n_inputs, output_width, n_parts, register_levels,
-                       partition_points):
+    def __init__(self, n_inputs, output_width, n_parts, register_levels):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -1097,7 +1098,7 @@ class AllTerms(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        self.i = AllTermsData(partition_points)
+        self.i = InputData()
         self.register_levels = register_levels
         self.n_inputs = n_inputs
         self.n_parts = n_parts
@@ -1271,12 +1272,11 @@ class Mul8_16_32_64(Elaboratable):
         self.register_levels = list(register_levels)
 
         # inputs
-        self.part_pts = PartitionPoints()
-        for i in range(8, 64, 8):
-            self.part_pts[i] = Signal(name=f"part_pts_{i}")
-        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
-        self.a = Signal(64)
-        self.b = Signal(64)
+        self.i = InputData()
+        self.part_pts = self.i.part_pts
+        self.part_ops = self.i.part_ops
+        self.a = self.i.a
+        self.b = self.i.b
 
         # intermediates (needed for unit tests)
         self.intermediate_output = Signal(128)
@@ -1291,7 +1291,7 @@ class Mul8_16_32_64(Elaboratable):
 
         n_inputs = 64 + 4
         n_parts = 8 #len(self.part_pts)
-        t = AllTerms(n_inputs, 128, n_parts, self.register_levels, part_pts)
+        t = AllTerms(n_inputs, 128, n_parts, self.register_levels)
         m.submodules.allterms = t
         m.d.comb += t.i.a.eq(self.a)
         m.d.comb += t.i.b.eq(self.b)