FinalOutput module
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 16:26:01 +0000 (17:26 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 16:26:01 +0000 (17:26 +0100)
src/ieee754/part_mul_add/multiply.py

index 69c84c4816187ff2c851bba4dd1bf2723e5d6d90..ffa28a6dec94341d238efcd9665800c28c709a2a 100644 (file)
@@ -577,6 +577,37 @@ class IntermediateOut(Elaboratable):
 
         return m
 
+
+class FinalOut(Elaboratable):
+    def __init__(self, out_wid):
+        # inputs
+        self.d8 = [Signal(name=f"d8{i}") for i in range(8)]
+        self.d16 = [Signal(name=f"d16{i}") for i in range(4)]
+        self.d32 = [Signal(name=f"d32{i}") for i in range(2)]
+        self.out8 = Signal(out_wid, reset_less=True)
+        self.out16 = Signal(out_wid, reset_less=True)
+        self.out32 = Signal(out_wid, reset_less=True)
+        self.out64 = Signal(out_wid, reset_less=True)
+
+        # output
+        self.output = Signal(out_wid, reset_less=True)
+
+    def elaborate(self, platform):
+        m = Module()
+        ol = []
+        for i in range(8):
+            op = Signal(8, reset_less=True, name="op%d" % i)
+            m.d.comb += op.eq(
+                Mux(self.d8[i] | self.d16[i // 2],
+                    Mux(self.d8[i], self.out8.bit_select(i * 8, 8),
+                                     self.out16.bit_select(i * 8, 8)),
+                    Mux(self.d32[i // 4], self.out32.bit_select(i * 8, 8),
+                                          self.out64.bit_select(i * 8, 8))))
+            ol.append(op)
+        m.d.comb += self.output.eq(Cat(*ol))
+        return m
+
+
 class OrMod(Elaboratable):
     def __init__(self, wid):
         self.wid = wid
@@ -681,10 +712,6 @@ class Mul8_16_32_64(Elaboratable):
         m.d.comb += pbs.eq(Cat(*tl))
 
         # local variables
-        output_64 = Signal(64)
-        output_32 = Signal(64)
-        output_16 = Signal(64)
-        output_8 = Signal(64)
         signs = []
         for i in range(8):
             s = Signs()
@@ -763,44 +790,39 @@ class Mul8_16_32_64(Elaboratable):
         m.d.comb += io64.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += output_64.eq(io64.output)
 
         # create _output_32
         m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
         m.d.comb += io32.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += output_32.eq(io32.output)
 
         # create _output_16
         m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
         m.d.comb += io16.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += output_16.eq(io16.output)
 
         # create _output_8
         m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
         m.d.comb += io8.intermed.eq(self._intermediate_output)
         for i in range(8):
             m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
-        m.d.comb += output_8.eq(io8.output)
 
         # final output
-        ol = []
+        m.submodules.finalout = out = FinalOut(64)
         for i in range(8):
-            op = Signal(8, reset_less=True, name="op%d" % i)
-            m.d.comb += op.eq(
-                Mux(part_8.delayed_parts[-1][i]
-                    | part_16.delayed_parts[-1][i // 2],
-                    Mux(part_8.delayed_parts[-1][i],
-                        output_8.bit_select(i * 8, 8),
-                        output_16.bit_select(i * 8, 8)),
-                    Mux(part_32.delayed_parts[-1][i // 4],
-                        output_32.bit_select(i * 8, 8),
-                        output_64.bit_select(i * 8, 8))))
-            ol.append(op)
-        m.d.comb += self.output.eq(Cat(*ol))
+            m.d.comb += out.d8[i].eq(part_8.delayed_parts[-1][i])
+        for i in range(4):
+            m.d.comb += out.d16[i].eq(part_16.delayed_parts[-1][i])
+        for i in range(2):
+            m.d.comb += out.d32[i].eq(part_32.delayed_parts[-1][i])
+        m.d.comb += out.out8.eq(io8.output)
+        m.d.comb += out.out16.eq(io16.output)
+        m.d.comb += out.out32.eq(io32.output)
+        m.d.comb += out.out64.eq(io64.output)
+        m.d.comb += self.output.eq(out.output)
+
         return m