move product terms to separate module (Term)
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 10:13:59 +0000 (11:13 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 10:13:59 +0000 (11:13 +0100)
src/ieee754/part_mul_add/multiply.py

index 1788e0a002079904c068092753f929a78656410b..3695f22ce926cb84cb65adf756e84395a6d42d1b 100644 (file)
@@ -365,6 +365,7 @@ OP_MUL_SIGNED_HIGH = 1
 OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
 OP_MUL_UNSIGNED_HIGH = 3
 
+
 def get_term(value, shift=0, enabled=None):
     if enabled is not None:
         value = Mux(enabled, value, 0)
@@ -386,33 +387,30 @@ class Term(Elaboratable):
         self.a_index = a_index
         self.b_index = b_index
         self.width = width
-        w2 = width * 2
         self.a = Signal(width, reset_less=True)
         self.b = Signal(width, reset_less=True)
-        self.product = Signal(w2, reset_less=True)
         self.term = Signal(twidth, reset_less=True)
         self.pb_en = Signal(pbwid, reset_less=True)
 
     def elaborate(self, platform):
-        m.d.comb += self.product.eq(self.a * self.b)
-
-        terms = []
 
-        def add_term(value, shift=0, enabled=None):
-            g_add_term(m, terms, value, shift, enabled)
+        m = Module()
+        product = Signal(self.width*2, reset_less=True)
+        m.d.comb += product.eq(self.a * self.b)
 
         tl = []
-        min_index = min(a_index, b_index)
-        max_index = max(a_index, b_index)
+        min_index = min(self.a_index, self.b_index)
+        max_index = max(self.a_index, self.b_index)
         for i in range(min_index, max_index):
-            m.d.comb += pbs.eq(self._part_byte(i))
             tl.append(self.pb_en[i])
-        name = "te_%d_%d" % (a_index, b_index)
-        term_enabled = Signal(name=name, reset_less=True)
-        m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
-        shift = 8 * (a_index + b_index),
-        value = products[a_index][b_index],
-        m.d.comb += term.eq(get_term(value, shift, term_enabled))
+        name = "te_%d_%d" % (self.a_index, self.b_index)
+        if len(tl) > 0:
+            term_enabled = Signal(name=name, reset_less=True)
+            m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
+        else:
+            term_enabled = None
+        shift = 8 * (self.a_index + self.b_index)
+        m.d.comb += self.term.eq(get_term(product, shift, term_enabled))
 
         return m
 
@@ -510,6 +508,15 @@ class Mul8_16_32_64(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
+        # collect part-bytes
+        pbs = Signal(8, reset_less=True)
+        tl = []
+        for i in range(8):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(self._part_byte(i))
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
+
         for i in range(len(self.part_ops)):
             m.d.comb += self._delayed_part_ops[0][i].eq(self.part_ops[i])
             m.d.sync += [self._delayed_part_ops[j + 1][i]
@@ -527,65 +534,38 @@ class Mul8_16_32_64(Elaboratable):
                                      (self._part_8, self._delayed_part_8)]:
             byte_count = 8 // len(parts)
             for i in range(len(parts)):
-                pb = self._part_byte(i * byte_count - 1)
+                pb = pbs[i * byte_count - 1]
                 value = add_intermediate_value(pb)
                 for j in range(i * byte_count, (i + 1) * byte_count - 1):
-                    pb = add_intermediate_value(~self._part_byte(j))
+                    pb = add_intermediate_value(~pbs[j])
                     value = add_intermediate_value(value & pb)
-                pb = self._part_byte((i + 1) * byte_count - 1)
+                pb = pbs[(i + 1) * byte_count - 1]
                 value = add_intermediate_value(value & pb)
                 m.d.comb += parts[i].eq(value)
                 m.d.comb += delayed_parts[0][i].eq(parts[i])
                 m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                              for j in range(len(self.register_levels))]
 
-        products = [[
-                Signal(16, name=f"products_{i}_{j}", reset_less=True)
-                for j in range(8)]
-            for i in range(8)]
+        terms = []
 
         for a_index in range(8):
             for b_index in range(8):
-                a = self.a.bit_select(a_index * 8, 8)
-                b = self.b.bit_select(b_index * 8, 8)
-                m.d.comb += products[a_index][b_index].eq(a * b)
+                t = Term(8, 128, 8, a_index, b_index)
+                setattr(m.submodules, "term_%d_%d" % (a_index, b_index), t)
 
-        terms = []
+                m.d.comb += t.a.eq(self.a.bit_select(a_index * 8, 8))
+                m.d.comb += t.b.eq(self.b.bit_select(b_index * 8, 8))
+                m.d.comb += t.pb_en.eq(pbs)
+
+                terms.append(t.term)
 
         def add_term(value, shift=0, enabled=None):
             g_add_term(m, terms, value, shift, enabled)
 
-        # collect part-bytes
-        pbs = Signal(8, reset_less=True)
-        tl = []
-        for i in range(8):
-            pb = Signal(name="pb%d" % i, reset_less=True)
-            m.d.comb += pb.eq(self._part_byte(i))
-            tl.append(pb)
-        m.d.comb += pbs.eq(Cat(*tl))
-
-        # add terms (if enabled)
-        for a_index in range(8):
-            for b_index in range(8):
-                tl = []
-                min_index = min(a_index, b_index)
-                max_index = max(a_index, b_index)
-                for i in range(min_index, max_index):
-                    tl.append(pbs[i])
-                name = "te_%d_%d" % (a_index, b_index)
-                if len(tl) > 0:
-                    term_enabled = Signal(name=name, reset_less=True)
-                    m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
-                else:
-                    term_enabled = None
-                add_term(products[a_index][b_index],
-                         8 * (a_index + b_index),
-                         term_enabled)
-
         for i in range(8):
             a_signed = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH
             b_signed = (self.part_ops[i] == OP_MUL_LOW) \
-                | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
+                        | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
             m.d.comb += self._a_signed[i].eq(a_signed)
             m.d.comb += self._b_signed[i].eq(b_signed)