whoops, a-enabled and b-enabled swapped
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index f0bb95b5dc5cabae2e3def234f21a7391861cb85..dba87880e6012a3d942443cbf6b260712d5c7c75 100644 (file)
@@ -199,7 +199,7 @@ class PartitionedAdder(Elaboratable):
         # combine above using Cat
         m.d.comb += Cat(*ea).eq(Cat(*al))
         m.d.comb += Cat(*eb).eq(Cat(*bl))
-        m.d.comb += Cat(*eo).eq(Cat(*ol))
+        m.d.comb += Cat(*ol).eq(Cat(*eo))
         # use only one addition to take advantage of look-ahead carry and
         # special hardware on FPGAs
         m.d.comb += self._expanded_output.eq(
@@ -313,7 +313,6 @@ class AddReduce(Elaboratable):
                 m.d.comb += self.output.eq(adder.output)
             return m
         # go on to handle recursive case
-        intermediate_terms: List[Signal]
         intermediate_terms = []
 
         def add_intermediate_term(value):
@@ -323,7 +322,10 @@ class AddReduce(Elaboratable):
             intermediate_terms.append(intermediate_term)
             m.d.comb += intermediate_term.eq(value)
 
-        part_mask = self._reg_partition_points.as_mask(len(self.output))
+        # store mask in intermediary (simplifies graph)
+        part_mask = Signal(len(self.output), reset_less=True)
+        mask = self._reg_partition_points.as_mask(len(self.output))
+        m.d.comb += part_mask.eq(mask)
 
         # create full adders for this recursive level.
         # this shrinks N terms to 2 * (N // 3) plus the remainder
@@ -463,23 +465,31 @@ class Mul8_16_32_64(Elaboratable):
                          .eq(self._delayed_part_ops[j][i])
                          for j in range(len(self.register_levels))]
 
+        def add_intermediate_value(value):
+            intermediate_value = Signal(len(value), reset_less=True)
+            m.d.comb += intermediate_value.eq(value)
+            return intermediate_value
+
         for parts, delayed_parts in [(self._part_64, self._delayed_part_64),
                                      (self._part_32, self._delayed_part_32),
                                      (self._part_16, self._delayed_part_16),
                                      (self._part_8, self._delayed_part_8)]:
             byte_count = 8 // len(parts)
             for i in range(len(parts)):
-                value = self._part_byte(i * byte_count - 1)
+                pb = self._part_byte(i * byte_count - 1)
+                value = add_intermediate_value(pb)
                 for j in range(i * byte_count, (i + 1) * byte_count - 1):
-                    value &= ~self._part_byte(j)
-                value &= self._part_byte((i + 1) * byte_count - 1)
+                    pb = add_intermediate_value(~self._part_byte(j))
+                    value = add_intermediate_value(value & pb)
+                pb = self._part_byte((i + 1) * byte_count - 1)
+                value = add_intermediate_value(value & pb)
                 m.d.comb += parts[i].eq(value)
                 m.d.comb += delayed_parts[0][i].eq(parts[i])
                 m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                              for j in range(len(self.register_levels))]
 
         products = [[
-                Signal(16, name=f"products_{i}_{j}")
+                Signal(16, name=f"products_{i}_{j}", reset_less=True)
                 for j in range(8)]
             for i in range(8)]
 
@@ -492,7 +502,7 @@ class Mul8_16_32_64(Elaboratable):
         terms = []
 
         def add_term(value, shift=0, enabled=None):
-            term = Signal(128)
+            term = Signal(128, reset_less=True)
             terms.append(term)
             if enabled is not None:
                 value = Mux(enabled, value, 0)
@@ -504,11 +514,16 @@ class Mul8_16_32_64(Elaboratable):
 
         for a_index in range(8):
             for b_index in range(8):
-                term_enabled: Value = C(True, 1)
+                tl = []
                 min_index = min(a_index, b_index)
                 max_index = max(a_index, b_index)
                 for i in range(min_index, max_index):
-                    term_enabled &= ~self._part_byte(i)
+                    pbs = Signal(reset_less=True)
+                    m.d.comb += pbs.eq(self._part_byte(i))
+                    tl.append(pbs)
+                name = "te_%d_%d" % (a_index, b_index)
+                term_enabled = Signal(name=name, reset_less=True)
+                m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
                 add_term(products[a_index][b_index],
                          8 * (a_index + b_index),
                          term_enabled)
@@ -560,10 +575,14 @@ class Mul8_16_32_64(Elaboratable):
             byte_width = 8 // len(parts)
             bit_width = 8 * byte_width
             for i in range(len(parts)):
-                b_enabled = parts[i] & self.a[(i + 1) * bit_width - 1] \
+                be = parts[i] & self.a[(i + 1) * bit_width - 1] \
                     & self._a_signed[i * byte_width]
-                a_enabled = parts[i] & self.b[(i + 1) * bit_width - 1] \
+                ae = parts[i] & self.b[(i + 1) * bit_width - 1] \
                     & self._b_signed[i * byte_width]
+                a_enabled = Signal(name="a_en_%d" % i, reset_less=True)
+                b_enabled = Signal(name="b_en_%d" % i, reset_less=True)
+                m.d.comb += a_enabled.eq(ae)
+                m.d.comb += b_enabled.eq(be)
 
                 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
                 # negation operation is split into a bitwise not and a +1.
@@ -589,7 +608,7 @@ class Mul8_16_32_64(Elaboratable):
 
         expanded_part_pts = PartitionPoints()
         for i, v in self.part_pts.items():
-            signal = Signal(name=f"expanded_part_pts_{i*2}")
+            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
             expanded_part_pts[i * 2] = signal
             m.d.comb += signal.eq(v)
 
@@ -603,23 +622,45 @@ class Mul8_16_32_64(Elaboratable):
             Mux(self._delayed_part_ops[-1][0] == OP_MUL_LOW,
                 self._intermediate_output.part(0, 64),
                 self._intermediate_output.part(64, 64)))
+
+        # create _output_32
+        ol = []
         for i in range(2):
-            m.d.comb += self._output_32.part(i * 32, 32).eq(
+            op = Signal(32, reset_less=True, name="op32_%d" % i)
+            m.d.comb += op.eq(
                 Mux(self._delayed_part_ops[-1][4 * i] == OP_MUL_LOW,
                     self._intermediate_output.part(i * 64, 32),
                     self._intermediate_output.part(i * 64 + 32, 32)))
+            ol.append(op)
+        m.d.comb += self._output_32.eq(Cat(*ol))
+
+        # create _output_16
+        ol = []
         for i in range(4):
-            m.d.comb += self._output_16.part(i * 16, 16).eq(
+            op = Signal(16, reset_less=True, name="op16_%d" % i)
+            m.d.comb += op.eq(
                 Mux(self._delayed_part_ops[-1][2 * i] == OP_MUL_LOW,
                     self._intermediate_output.part(i * 32, 16),
                     self._intermediate_output.part(i * 32 + 16, 16)))
+            ol.append(op)
+        m.d.comb += self._output_16.eq(Cat(*ol))
+
+        # create _output_8
+        ol = []
         for i in range(8):
-            m.d.comb += self._output_8.part(i * 8, 8).eq(
+            op = Signal(8, reset_less=True, name="op8_%d" % i)
+            m.d.comb += op.eq(
                 Mux(self._delayed_part_ops[-1][i] == OP_MUL_LOW,
                     self._intermediate_output.part(i * 16, 8),
                     self._intermediate_output.part(i * 16 + 8, 8)))
+            ol.append(op)
+        m.d.comb += self._output_8.eq(Cat(*ol))
+
+        # final output
+        ol = []
         for i in range(8):
-            m.d.comb += self.output.part(i * 8, 8).eq(
+            op = Signal(8, reset_less=True, name="op%d" % i)
+            m.d.comb += op.eq(
                 Mux(self._delayed_part_8[-1][i]
                     | self._delayed_part_16[-1][i // 2],
                     Mux(self._delayed_part_8[-1][i],
@@ -628,6 +669,8 @@ class Mul8_16_32_64(Elaboratable):
                     Mux(self._delayed_part_32[-1][i // 4],
                         self._output_32.part(i * 8, 8),
                         self._output_64.part(i * 8, 8))))
+            ol.append(op)
+        m.d.comb += self.output.eq(Cat(*ol))
         return m