derive new class Term and ProductTerm
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 11:05:04 +0000 (12:05 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 11:05:04 +0000 (12:05 +0100)
src/ieee754/part_mul_add/multiply.py

index b22a69c15c9ebe27b127c71ff4a0113a67086f7d..74d8ca4633b3fdc81225fc9e451777ec3b735566 100644 (file)
@@ -383,22 +383,32 @@ def g_add_term(m, terms, value, shift=0, enabled=None):
 
 
 class Term(Elaboratable):
+    def __init__(self, width, twidth, shift=0, enabled=None):
+        self.width = width
+        self.shift = shift
+        self.enabled = enabled
+        self.t_in = Signal(width, reset_less=True)
+        self.term = Signal(twidth, reset_less=True)
+
+    def elaborate(self, platform):
+
+        m = Module()
+        m.d.comb += self.term.eq(get_term(self.t_in, self.shift, self.enabled))
+
+        return m
+
+
+class ProductTerm(Elaboratable):
     def __init__(self, width, twidth, pbwid, a_index, b_index):
         self.a_index = a_index
         self.b_index = b_index
+        shift = 8 * (self.a_index + self.b_index)
         self.width = width
         self.a = Signal(width, reset_less=True)
         self.b = Signal(width, reset_less=True)
-        self.term = Signal(twidth, reset_less=True)
         self.pb_en = Signal(pbwid, reset_less=True)
 
-    def elaborate(self, platform):
-
-        m = Module()
-        product = Signal(self.width*2, reset_less=True)
-        m.d.comb += product.eq(self.a * self.b)
-
-        tl = []
+        self.tl = tl = []
         min_index = min(self.a_index, self.b_index)
         max_index = max(self.a_index, self.b_index)
         for i in range(min_index, max_index):
@@ -406,11 +416,17 @@ class Term(Elaboratable):
         name = "te_%d_%d" % (self.a_index, self.b_index)
         if len(tl) > 0:
             term_enabled = Signal(name=name, reset_less=True)
-            m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
         else:
             term_enabled = None
-        shift = 8 * (self.a_index + self.b_index)
-        m.d.comb += self.term.eq(get_term(product, shift, term_enabled))
+
+        Term.__init__(self, width*2, twidth, shift, term_enabled)
+
+    def elaborate(self, platform):
+
+        m = Term.elaborate(self, platform)
+        if self.enabled is not None:
+            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))
+        m.d.comb += self.t_in.eq(self.a * self.b)
 
         return m
 
@@ -545,7 +561,7 @@ class Mul8_16_32_64(Elaboratable):
 
         for a_index in range(8):
             for b_index in range(8):
-                t = Term(8, 128, 8, a_index, b_index)
+                t = ProductTerm(8, 128, 8, a_index, b_index)
                 setattr(m.submodules, "term_%d_%d" % (a_index, b_index), t)
 
                 m.d.comb += t.a.eq(self.a.bit_select(a_index * 8, 8))