start adding ispec/ospec to multiply.py
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 11:02:07 +0000 (12:02 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 23 Aug 2019 11:02:07 +0000 (12:02 +0100)
src/ieee754/part_mul_add/multiply.py

index 3366a4f95187a491c724cb2a5cd7faa4f2a9ad4a..e81e42a75f1d45e2c7493f580e961b8a7ba31d32 100644 (file)
@@ -987,11 +987,18 @@ class FinalOut(Elaboratable):
     """
     def __init__(self, output_width, n_parts, part_pts):
         self.part_pts = part_pts
-        self.i = IntermediateData(part_pts, output_width, n_parts)
+        self.output_width = output_width
+        self.n_parts = n_parts
         self.out_wid = output_width//2
-        # output
-        self.out = Signal(self.out_wid, reset_less=True)
-        self.intermediate_output = Signal(output_width, reset_less=True)
+
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        return IntermediateData(self.part_pts, self.output_width, self.n_parts)
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
@@ -1045,8 +1052,11 @@ class FinalOut(Elaboratable):
                     Mux(d32[i // 4], i32.bit_select(i * 8, 8),
                                       i64.bit_select(i * 8, 8))))
             ol.append(op)
-        m.d.comb += self.out.eq(Cat(*ol))
-        m.d.comb += self.intermediate_output.eq(self.i.intermediate_output)
+
+        # create outputs
+        m.d.comb += self.o.output.eq(Cat(*ol))
+        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)
+
         return m
 
 
@@ -1138,6 +1148,17 @@ class InputData:
         return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
 
 
+class OutputData:
+
+    def __init__(self):
+        self.intermediate_output = Signal(128) # needed for unit tests
+        self.output = Signal(64)
+
+    def eq(self, rhs):
+        return [self.intermediate_output.eq(rhs.intermediate_output),
+                self.output.eq(rhs.output)]
+
+
 class AllTerms(Elaboratable):
     """Set of terms to be added together
     """
@@ -1151,13 +1172,20 @@ class AllTerms(Elaboratable):
             pipeline registers.
         :param partition_points: the input partition points.
         """
-        self.i = InputData()
         self.register_levels = register_levels
         self.n_inputs = n_inputs
         self.n_parts = n_parts
         self.output_width = output_width
-        self.o = AddReduceData(self.i.part_pts, n_inputs,
-                               output_width, n_parts)
+
+        self.i = self.ispec()
+        self.o = self.ospec()
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return AddReduceData(self.i.part_pts, self.n_inputs,
+                             self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         m = Module()
@@ -1324,18 +1352,24 @@ class Mul8_16_32_64(Elaboratable):
         # parameter(s)
         self.register_levels = list(register_levels)
 
+        self.i = self.ispec()
+        self.o = self.ospec()
+
         # inputs
-        self.i = InputData()
         self.part_pts = self.i.part_pts
         self.part_ops = self.i.part_ops
         self.a = self.i.a
         self.b = self.i.b
 
-        # intermediates (needed for unit tests)
-        self.intermediate_output = Signal(128)
-
         # output
-        self.output = Signal(64)
+        self.intermediate_output = self.o.intermediate_output
+        self.output = self.o.output
+
+    def ispec(self):
+        return InputData()
+
+    def ospec(self):
+        return OutputData()
 
     def elaborate(self, platform):
         m = Module()
@@ -1368,8 +1402,7 @@ class Mul8_16_32_64(Elaboratable):
         # final output
         m.submodules.finalout = finalout = FinalOut(128, 8, part_pts)
         m.d.comb += finalout.i.eq(interm.o)
-        m.d.comb += self.output.eq(finalout.out)
-        m.d.comb += self.intermediate_output.eq(finalout.intermediate_output)
+        m.d.comb += self.o.eq(finalout.o)
 
         return m