move part-bytes to AllTerms
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 23:39:13 +0000 (00:39 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 23:39:13 +0000 (00:39 +0100)
src/ieee754/part_mul_add/multiply.py

index 7b8782e220797f43db32f2fdae442c90bf464980..0d65e30fd79ed49ab53d2b077bff5b6156c5bd51 100644 (file)
@@ -1074,7 +1074,7 @@ class AllTerms(Elaboratable):
     """Set of terms to be added together
     """
 
-    def __init__(self, pbwid, n_inputs, output_width, n_parts, register_levels,
+    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                        partition_points):
         """Create an ``AddReduce``.
 
@@ -1086,7 +1086,6 @@ class AllTerms(Elaboratable):
         """
         self.epps = partition_points.like()
         self.register_levels = register_levels
-        self.pbwid = pbwid
         self.n_inputs = n_inputs
         self.n_parts = n_parts
         self.output_width = output_width
@@ -1096,15 +1095,22 @@ class AllTerms(Elaboratable):
         self.a = Signal(64)
         self.b = Signal(64)
 
-        self.pbs = Signal(pbwid, reset_less=True)
         self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
 
     def elaborate(self, platform):
         m = Module()
 
-        pbs = self.pbs
         eps = self.epps
 
+        # collect part-bytes
+        pbs = Signal(8, reset_less=True)
+        tl = []
+        for i in range(8):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(eps.part_byte(i, mfactor=2))
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
+
         # local variables
         signs = []
         for i in range(8):
@@ -1273,15 +1279,6 @@ class Mul8_16_32_64(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        # collect part-bytes
-        pbs = Signal(8, reset_less=True)
-        tl = []
-        for i in range(8):
-            pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(self.part_pts.part_byte(i))
-            tl.append(pb)
-        m.d.comb += pbs.eq(Cat(*tl))
-
         # create (doubled) PartitionPoints (output is double input width)
         expanded_part_pts = eps = PartitionPoints()
         for i, v in self.part_pts.items():
@@ -1291,12 +1288,11 @@ class Mul8_16_32_64(Elaboratable):
 
         n_inputs = 64 + 4
         n_parts = 8 #len(self.part_pts)
-        t = AllTerms(8, n_inputs, 128, n_parts, self.register_levels,
+        t = AllTerms(n_inputs, 128, n_parts, self.register_levels,
                        eps)
         m.submodules.allterms = t
         m.d.comb += t.a.eq(self.a)
         m.d.comb += t.b.eq(self.b)
-        m.d.comb += t.pbs.eq(pbs)
         m.d.comb += t.epps.eq(eps)
         for i in range(8):
             m.d.comb += t.part_ops[i].eq(self.part_ops[i])