add new Terms class, get part_pts into intermediary
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 09:57:23 +0000 (10:57 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 09:57:23 +0000 (10:57 +0100)
src/ieee754/part_mul_add/multiply.py

index 0e3af6b1a3759819422cf7a4efd5afda745e9436..1788e0a002079904c068092753f929a78656410b 100644 (file)
@@ -365,6 +365,57 @@ OP_MUL_SIGNED_HIGH = 1
 OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
 OP_MUL_UNSIGNED_HIGH = 3
 
+def get_term(value, shift=0, enabled=None):
+    if enabled is not None:
+        value = Mux(enabled, value, 0)
+    if shift > 0:
+        value = Cat(Repl(C(0, 1), shift), value)
+    else:
+        assert shift == 0
+    return value
+
+
+def g_add_term(m, terms, value, shift=0, enabled=None):
+    term = Signal(128, reset_less=True)
+    terms.append(term)
+    m.d.comb += term.eq(get_term(value, shift, enabled))
+
+
+class Term(Elaboratable):
+    def __init__(self, width, twidth, pbwid, a_index, b_index):
+        self.a_index = a_index
+        self.b_index = b_index
+        self.width = width
+        w2 = width * 2
+        self.a = Signal(width, reset_less=True)
+        self.b = Signal(width, reset_less=True)
+        self.product = Signal(w2, reset_less=True)
+        self.term = Signal(twidth, reset_less=True)
+        self.pb_en = Signal(pbwid, reset_less=True)
+
+    def elaborate(self, platform):
+        m.d.comb += self.product.eq(self.a * self.b)
+
+        terms = []
+
+        def add_term(value, shift=0, enabled=None):
+            g_add_term(m, terms, value, shift, enabled)
+
+        tl = []
+        min_index = min(a_index, b_index)
+        max_index = max(a_index, b_index)
+        for i in range(min_index, max_index):
+            m.d.comb += pbs.eq(self._part_byte(i))
+            tl.append(self.pb_en[i])
+        name = "te_%d_%d" % (a_index, b_index)
+        term_enabled = Signal(name=name, reset_less=True)
+        m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
+        shift = 8 * (a_index + b_index),
+        value = products[a_index][b_index],
+        m.d.comb += term.eq(get_term(value, shift, term_enabled))
+
+        return m
+
 
 class Mul8_16_32_64(Elaboratable):
     """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
@@ -502,28 +553,31 @@ class Mul8_16_32_64(Elaboratable):
         terms = []
 
         def add_term(value, shift=0, enabled=None):
-            term = Signal(128, reset_less=True)
-            terms.append(term)
-            if enabled is not None:
-                value = Mux(enabled, value, 0)
-            if shift > 0:
-                value = Cat(Repl(C(0, 1), shift), value)
-            else:
-                assert shift == 0
-            m.d.comb += term.eq(value)
+            g_add_term(m, terms, value, shift, enabled)
+
+        # collect part-bytes
+        pbs = Signal(8, reset_less=True)
+        tl = []
+        for i in range(8):
+            pb = Signal(name="pb%d" % i, reset_less=True)
+            m.d.comb += pb.eq(self._part_byte(i))
+            tl.append(pb)
+        m.d.comb += pbs.eq(Cat(*tl))
 
+        # add terms (if enabled)
         for a_index in range(8):
             for b_index in range(8):
                 tl = []
                 min_index = min(a_index, b_index)
                 max_index = max(a_index, b_index)
                 for i in range(min_index, max_index):
-                    pbs = Signal(reset_less=True)
-                    m.d.comb += pbs.eq(self._part_byte(i))
-                    tl.append(pbs)
+                    tl.append(pbs[i])
                 name = "te_%d_%d" % (a_index, b_index)
-                term_enabled = Signal(name=name, reset_less=True)
-                m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
+                if len(tl) > 0:
+                    term_enabled = Signal(name=name, reset_less=True)
+                    m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
+                else:
+                    term_enabled = None
                 add_term(products[a_index][b_index],
                          8 * (a_index + b_index),
                          term_enabled)