move part modules into FinalOut
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 07:04:06 +0000 (08:04 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Aug 2019 07:04:06 +0000 (08:04 +0100)
src/ieee754/part_mul_add/multiply.py

index 044b92c6883a322873654345ce0eee537d6c4267..09922caae8f4d820c441bad7635fd5b0dc88fdb4 100644 (file)
@@ -939,22 +939,51 @@ class FinalOut(Elaboratable):
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
-    def __init__(self, out_wid):
-        # inputs
-        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
-        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
-        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
-
-        self.i8 = Signal(out_wid, reset_less=True)
-        self.i16 = Signal(out_wid, reset_less=True)
-        self.i32 = Signal(out_wid, reset_less=True)
-        self.i64 = Signal(out_wid, reset_less=True)
-
+    def __init__(self, output_width, n_parts, partition_points):
+        self.expanded_part_points = partition_points
+        self.i = IntermediateData(partition_points, output_width, n_parts)
+        self.out_wid = output_width//2
         # output
-        self.out = Signal(out_wid, reset_less=True)
+        self.out = Signal(self.out_wid, reset_less=True)
+        self.intermediate_output = Signal(output_width, reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
+
+        eps = self.expanded_part_points
+        m.submodules.p_8 = p_8 = Parts(8, eps, 8)
+        m.submodules.p_16 = p_16 = Parts(8, eps, 4)
+        m.submodules.p_32 = p_32 = Parts(8, eps, 2)
+        m.submodules.p_64 = p_64 = Parts(8, eps, 1)
+
+        out_part_pts = self.i.reg_partition_points
+
+        # temporaries
+        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
+        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
+        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]
+
+        i8 = Signal(self.out_wid, reset_less=True)
+        i16 = Signal(self.out_wid, reset_less=True)
+        i32 = Signal(self.out_wid, reset_less=True)
+        i64 = Signal(self.out_wid, reset_less=True)
+
+        m.d.comb += p_8.epps.eq(out_part_pts)
+        m.d.comb += p_16.epps.eq(out_part_pts)
+        m.d.comb += p_32.epps.eq(out_part_pts)
+        m.d.comb += p_64.epps.eq(out_part_pts)
+
+        for i in range(len(p_8.parts)):
+            m.d.comb += d8[i].eq(p_8.parts[i])
+        for i in range(len(p_16.parts)):
+            m.d.comb += d16[i].eq(p_16.parts[i])
+        for i in range(len(p_32.parts)):
+            m.d.comb += d32[i].eq(p_32.parts[i])
+        m.d.comb += i8.eq(self.i.outputs[0])
+        m.d.comb += i16.eq(self.i.outputs[1])
+        m.d.comb += i32.eq(self.i.outputs[2])
+        m.d.comb += i64.eq(self.i.outputs[3])
+
         ol = []
         for i in range(8):
             # select one of the outputs: d8 selects i8, d16 selects i16
@@ -964,13 +993,12 @@ class FinalOut(Elaboratable):
             # if neither d8 nor d16 are set, d32 selects either i32 or i64.
             op = Signal(8, reset_less=True, name="op_%d" % i)
             m.d.comb += op.eq(
-                Mux(self.d8[i] | self.d16[i // 2],
-                    Mux(self.d8[i], self.i8.part(i * 8, 8),
-                                     self.i16.part(i * 8, 8)),
-                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
-                                          self.i64.part(i * 8, 8))))
+                Mux(d8[i] | d16[i // 2],
+                    Mux(d8[i], i8.part(i * 8, 8), i16.part(i * 8, 8)),
+                    Mux(d32[i // 4], i32.part(i * 8, 8), i64.part(i * 8, 8))))
             ol.append(op)
         m.d.comb += self.out.eq(Cat(*ol))
+        m.d.comb += self.intermediate_output.eq(self.i.intermediate_output)
         return m
 
 
@@ -1016,6 +1044,7 @@ class Signs(Elaboratable):
 
         return m
 
+
 class IntermediateData:
 
     def __init__(self, ppoints, output_width, n_parts):
@@ -1031,7 +1060,7 @@ class IntermediateData:
                       part_ops):
         return [self.reg_partition_points.eq(reg_partition_points)] + \
                [self.intermediate_output.eq(intermediate_output)] + \
-               [self.outputs.eq(outputs)
+               [self.outputs[i].eq(outputs[i])
                                      for i in range(4)] + \
                [self.part_ops[i].eq(part_ops[i])
                                      for i in range(len(self.part_ops))]
@@ -1045,8 +1074,7 @@ class Intermediates(Elaboratable):
     """ Intermediate output modules
     """
 
-    def __init__(self, output_width, n_parts, register_levels,
-                       partition_points):
+    def __init__(self, output_width, n_parts, partition_points):
         self.i = FinalReduceData(partition_points, output_width, n_parts)
         self.o = IntermediateData(partition_points, output_width, n_parts)
 
@@ -1226,35 +1254,13 @@ class Mul8_16_32_64(Elaboratable):
         m.submodules.add_reduce = add_reduce
         m.d.comb += self.intermediate_output.eq(add_reduce.o.output)
 
-        interm = Intermediates(128, 8, self.register_levels,
-                               expanded_part_pts)
+        interm = Intermediates(128, 8, expanded_part_pts)
         m.submodules.intermediates = interm
         m.d.comb += interm.i.eq(add_reduce.o)
 
-        m.submodules.p_8 = p_8 = Parts(8, eps, 8)
-        m.submodules.p_16 = p_16 = Parts(8, eps, 4)
-        m.submodules.p_32 = p_32 = Parts(8, eps, 2)
-        m.submodules.p_64 = p_64 = Parts(8, eps, 1)
-
-        out_part_pts = interm.o.reg_partition_points
-
-        m.d.comb += p_8.epps.eq(out_part_pts)
-        m.d.comb += p_16.epps.eq(out_part_pts)
-        m.d.comb += p_32.epps.eq(out_part_pts)
-        m.d.comb += p_64.epps.eq(out_part_pts)
-
         # final output
-        m.submodules.finalout = finalout = FinalOut(64)
-        for i in range(len(p_8.parts)):
-            m.d.comb += finalout.d8[i].eq(p_8.parts[i])
-        for i in range(len(p_16.parts)):
-            m.d.comb += finalout.d16[i].eq(p_16.parts[i])
-        for i in range(len(p_32.parts)):
-            m.d.comb += finalout.d32[i].eq(p_32.parts[i])
-        m.d.comb += finalout.i8.eq(interm.o.outputs[0])
-        m.d.comb += finalout.i16.eq(interm.o.outputs[1])
-        m.d.comb += finalout.i32.eq(interm.o.outputs[2])
-        m.d.comb += finalout.i64.eq(interm.o.outputs[3])
+        m.submodules.finalout = finalout = FinalOut(128, 8, expanded_part_pts)
+        m.d.comb += finalout.i.eq(interm.o)
         m.d.comb += self.output.eq(finalout.out)
 
         return m