split out "Parts" to separate module
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 12:33:22 +0000 (13:33 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 12:33:22 +0000 (13:33 +0100)
src/ieee754/part_mul_add/multiply.py
src/ieee754/part_mul_add/test/test_multiply.py

index 33232e773827a94c9ef442f32db32e7ae8c97963..189656db1b338440425797464ba8cbe8d3a4d856 100644 (file)
@@ -6,7 +6,8 @@ from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
 from nmigen.hdl.ast import Assign
 from abc import ABCMeta, abstractmethod
 from nmigen.cli import main
-
+from functools import reduce
+from operator import or_
 
 class PartitionPoints(dict):
     """Partition points and corresponding ``Value``s.
@@ -424,6 +425,90 @@ class ProductTerm(Elaboratable):
 
         return m
 
+class Part(Elaboratable):
+    def __init__(self, width, n_parts, n_levels, pbwid):
+
+        # inputs
+        self.a = Signal(64)
+        self.b = Signal(64)
+        self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
+        self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+        self.pbs = Signal(pbwid, reset_less=True)
+
+        # outputs
+        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
+        self.delayed_parts = [
+            [Signal(name=f"delayed_part_8_{delay}_{i}")
+             for i in range(n_parts)]
+                for delay in range(n_levels)]
+
+        self.not_a_term = Signal(width)
+        self.neg_lsb_a_term = Signal(width)
+        self.not_b_term = Signal(width)
+        self.neg_lsb_b_term = Signal(width)
+
+    def elaborate(self, platform):
+        m = Module()
+
+        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
+        byte_count = 8 // len(parts)
+        for i in range(len(parts)):
+            pbl = []
+            pbl.append(~pbs[i * byte_count - 1])
+            for j in range(i * byte_count, (i + 1) * byte_count - 1):
+                pbl.append(pbs[j])
+            pbl.append(~pbs[(i + 1) * byte_count - 1])
+            value = Signal(len(pbl), reset_less=True)
+            m.d.comb += value.eq(Cat(*pbl))
+            m.d.comb += parts[i].eq(~(value).bool())
+            m.d.comb += delayed_parts[0][i].eq(parts[i])
+            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
+                         for j in range(len(delayed_parts)-1)]
+
+        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
+                self.not_a_term, self.neg_lsb_a_term, \
+                self.not_b_term, self.neg_lsb_b_term
+
+        byte_width = 8 // len(parts)
+        bit_width = 8 * byte_width
+        nat, nbt, nla, nlb = [], [], [], []
+        for i in range(len(parts)):
+            be = parts[i] & self.a[(i + 1) * bit_width - 1] \
+                & self._a_signed[i * byte_width]
+            ae = parts[i] & self.b[(i + 1) * bit_width - 1] \
+                & self._b_signed[i * byte_width]
+            a_enabled = Signal(name="a_en_%d" % i, reset_less=True)
+            b_enabled = Signal(name="b_en_%d" % i, reset_less=True)
+            m.d.comb += a_enabled.eq(ae)
+            m.d.comb += b_enabled.eq(be)
+
+            # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
+            # negation operation is split into a bitwise not and a +1.
+            # likewise for 16, 32, and 64-bit values.
+            nat.append(Mux(a_enabled,
+                    Cat(Repl(0, bit_width),
+                        ~self.a.bit_select(bit_width * i, bit_width)),
+                    0))
+
+            nla.append(Cat(Repl(0, bit_width), a_enabled,
+                           Repl(0, bit_width-1)))
+
+            nbt.append(Mux(b_enabled,
+                    Cat(Repl(0, bit_width),
+                        ~self.b.bit_select(bit_width * i, bit_width)),
+                    0))
+
+            nlb.append(Cat(Repl(0, bit_width), b_enabled,
+                           Repl(0, bit_width-1)))
+
+        m.d.comb += [not_a_term.eq(Cat(*nat)),
+                     not_b_term.eq(Cat(*nbt)),
+                     neg_lsb_a_term.eq(Cat(*nla)),
+                     neg_lsb_b_term.eq(Cat(*nlb)),
+                    ]
+
+        return m
+
 
 class Mul8_16_32_64(Elaboratable):
     """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
@@ -467,47 +552,12 @@ class Mul8_16_32_64(Elaboratable):
             [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
              for i in range(8)]
             for delay in range(1 + len(self.register_levels))]
-        self._part_8 = [Signal(name=f"_part_8_{i}") for i in range(8)]
-        self._part_16 = [Signal(name=f"_part_16_{i}") for i in range(4)]
-        self._part_32 = [Signal(name=f"_part_32_{i}") for i in range(2)]
-        self._part_64 = [Signal(name=f"_part_64")]
-        self._delayed_part_8 = [
-            [Signal(name=f"_delayed_part_8_{delay}_{i}")
-             for i in range(8)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_16 = [
-            [Signal(name=f"_delayed_part_16_{delay}_{i}")
-             for i in range(4)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_32 = [
-            [Signal(name=f"_delayed_part_32_{delay}_{i}")
-             for i in range(2)]
-            for delay in range(1 + len(self.register_levels))]
-        self._delayed_part_64 = [
-            [Signal(name=f"_delayed_part_64_{delay}")]
-            for delay in range(1 + len(self.register_levels))]
         self._output_64 = Signal(64)
         self._output_32 = Signal(64)
         self._output_16 = Signal(64)
         self._output_8 = Signal(64)
         self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
         self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
-        self._not_a_term_8 = Signal(128)
-        self._neg_lsb_a_term_8 = Signal(128)
-        self._not_b_term_8 = Signal(128)
-        self._neg_lsb_b_term_8 = Signal(128)
-        self._not_a_term_16 = Signal(128)
-        self._neg_lsb_a_term_16 = Signal(128)
-        self._not_b_term_16 = Signal(128)
-        self._neg_lsb_b_term_16 = Signal(128)
-        self._not_a_term_32 = Signal(128)
-        self._neg_lsb_a_term_32 = Signal(128)
-        self._not_b_term_32 = Signal(128)
-        self._neg_lsb_b_term_32 = Signal(128)
-        self._not_a_term_64 = Signal(128)
-        self._neg_lsb_a_term_64 = Signal(128)
-        self._not_b_term_64 = Signal(128)
-        self._neg_lsb_b_term_64 = Signal(128)
 
     def _part_byte(self, index):
         if index == -1 or index == 7:
@@ -533,23 +583,24 @@ class Mul8_16_32_64(Elaboratable):
                          .eq(self._delayed_part_ops[j][i])
                          for j in range(len(self.register_levels))]
 
-        for parts, delayed_parts in [(self._part_64, self._delayed_part_64),
-                                     (self._part_32, self._delayed_part_32),
-                                     (self._part_16, self._delayed_part_16),
-                                     (self._part_8, self._delayed_part_8)]:
-            byte_count = 8 // len(parts)
-            for i in range(len(parts)):
-                pbl = []
-                pbl.append(~pbs[i * byte_count - 1])
-                for j in range(i * byte_count, (i + 1) * byte_count - 1):
-                    pbl.append(pbs[j])
-                pbl.append(~pbs[(i + 1) * byte_count - 1])
-                value = Signal(len(pbl), reset_less=True)
-                m.d.comb += value.eq(Cat(*pbl))
-                m.d.comb += parts[i].eq(~(value).bool())
-                m.d.comb += delayed_parts[0][i].eq(parts[i])
-                m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
-                             for j in range(len(self.register_levels))]
+        n_levels = len(self.register_levels)+1
+        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
+        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
+        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
+        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
+        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
+        for mod in [part_8, part_16, part_32, part_64]:
+            m.d.comb += mod.a.eq(self.a)
+            m.d.comb += mod.b.eq(self.b)
+            for i in range(len(self._a_signed)):
+                m.d.comb += mod._a_signed[i].eq(self._a_signed[i])
+            for i in range(len(self._b_signed)):
+                m.d.comb += mod._b_signed[i].eq(self._b_signed[i])
+            m.d.comb += mod.pbs.eq(pbs)
+            nat_l.append(mod.not_a_term)
+            nbt_l.append(mod.not_b_term)
+            nla_l.append(mod.neg_lsb_a_term)
+            nlb_l.append(mod.neg_lsb_b_term)
 
         terms = []
 
@@ -573,87 +624,23 @@ class Mul8_16_32_64(Elaboratable):
 
         # it's fine to bitwise-or these together since they are never enabled
         # at the same time
+        nat_l = reduce(or_, nat_l)
+        nbt_l = reduce(or_, nbt_l)
+        nla_l = reduce(or_, nla_l)
+        nlb_l = reduce(or_, nlb_l)
         m.submodules.nat = nat = Term(128, 128)
         m.submodules.nla = nla = Term(128, 128)
         m.submodules.nbt = nbt = Term(128, 128)
         m.submodules.nlb = nlb = Term(128, 128)
-        m.d.comb += nat.ti.eq(self._not_a_term_8 | self._not_a_term_16
-                           | self._not_a_term_32 | self._not_a_term_64)
-        m.d.comb += nbt.ti.eq(self._not_b_term_8 | self._not_b_term_16
-                           | self._not_b_term_32 | self._not_b_term_64)
-        m.d.comb += nla.ti.eq(self._neg_lsb_a_term_8 | self._neg_lsb_a_term_16
-                           | self._neg_lsb_a_term_32 | self._neg_lsb_a_term_64)
-        m.d.comb += nlb.ti.eq(self._neg_lsb_b_term_8 | self._neg_lsb_b_term_16
-                           | self._neg_lsb_b_term_32 | self._neg_lsb_b_term_64)
+        m.d.comb += nat.ti.eq(nat_l)
+        m.d.comb += nbt.ti.eq(nbt_l)
+        m.d.comb += nla.ti.eq(nla_l)
+        m.d.comb += nlb.ti.eq(nlb_l)
         terms.append(nat.term)
         terms.append(nla.term)
         terms.append(nbt.term)
         terms.append(nlb.term)
 
-        for not_a_term, \
-            neg_lsb_a_term, \
-            not_b_term, \
-            neg_lsb_b_term, \
-            parts in [
-                (self._not_a_term_8,
-                 self._neg_lsb_a_term_8,
-                 self._not_b_term_8,
-                 self._neg_lsb_b_term_8,
-                 self._part_8),
-                (self._not_a_term_16,
-                 self._neg_lsb_a_term_16,
-                 self._not_b_term_16,
-                 self._neg_lsb_b_term_16,
-                 self._part_16),
-                (self._not_a_term_32,
-                 self._neg_lsb_a_term_32,
-                 self._not_b_term_32,
-                 self._neg_lsb_b_term_32,
-                 self._part_32),
-                (self._not_a_term_64,
-                 self._neg_lsb_a_term_64,
-                 self._not_b_term_64,
-                 self._neg_lsb_b_term_64,
-                 self._part_64),
-                ]:
-            byte_width = 8 // len(parts)
-            bit_width = 8 * byte_width
-            nat, nbt, nla, nlb = [], [], [], []
-            for i in range(len(parts)):
-                be = parts[i] & self.a[(i + 1) * bit_width - 1] \
-                    & self._a_signed[i * byte_width]
-                ae = parts[i] & self.b[(i + 1) * bit_width - 1] \
-                    & self._b_signed[i * byte_width]
-                a_enabled = Signal(name="a_en_%d" % i, reset_less=True)
-                b_enabled = Signal(name="b_en_%d" % i, reset_less=True)
-                m.d.comb += a_enabled.eq(ae)
-                m.d.comb += b_enabled.eq(be)
-
-                # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
-                # negation operation is split into a bitwise not and a +1.
-                # likewise for 16, 32, and 64-bit values.
-                nat.append(Mux(a_enabled,
-                        Cat(Repl(0, bit_width),
-                            ~self.a.bit_select(bit_width * i, bit_width)),
-                        0))
-
-                nla.append(Cat(Repl(0, bit_width), a_enabled,
-                               Repl(0, bit_width-1)))
-
-                nbt.append(Mux(b_enabled,
-                        Cat(Repl(0, bit_width),
-                            ~self.b.bit_select(bit_width * i, bit_width)),
-                        0))
-
-                nlb.append(Cat(Repl(0, bit_width), b_enabled,
-                               Repl(0, bit_width-1)))
-
-            m.d.comb += [not_a_term.eq(Cat(*nat)),
-                         not_b_term.eq(Cat(*nbt)),
-                         neg_lsb_a_term.eq(Cat(*nla)),
-                         neg_lsb_b_term.eq(Cat(*nlb)),
-                        ]
-
         expanded_part_pts = PartitionPoints()
         for i, v in self.part_pts.items():
             signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
@@ -709,12 +696,12 @@ class Mul8_16_32_64(Elaboratable):
         for i in range(8):
             op = Signal(8, reset_less=True, name="op%d" % i)
             m.d.comb += op.eq(
-                Mux(self._delayed_part_8[-1][i]
-                    | self._delayed_part_16[-1][i // 2],
-                    Mux(self._delayed_part_8[-1][i],
+                Mux(part_8.delayed_parts[-1][i]
+                    | part_16.delayed_parts[-1][i // 2],
+                    Mux(part_8.delayed_parts[-1][i],
                         self._output_8.bit_select(i * 8, 8),
                         self._output_16.bit_select(i * 8, 8)),
-                    Mux(self._delayed_part_32[-1][i // 4],
+                    Mux(part_32.delayed_parts[-1][i // 4],
                         self._output_32.bit_select(i * 8, 8),
                         self._output_64.bit_select(i * 8, 8))))
             ol.append(op)
index ef7f5cd722cd71ab57b3a59099e929896e40ba7f..d96d45c138d97e8f88edc39bb4e49d202f7bb8dd 100644 (file)
@@ -527,40 +527,12 @@ class TestMul8_16_32_64(unittest.TestCase):
         ports.extend(module.part_pts.values())
         for signals in module._delayed_part_ops:
             ports.extend(signals)
-        ports.extend(module._part_8)
-        ports.extend(module._part_16)
-        ports.extend(module._part_32)
-        ports.extend(module._part_64)
-        for signals in module._delayed_part_8:
-            ports.extend(signals)
-        for signals in module._delayed_part_16:
-            ports.extend(signals)
-        for signals in module._delayed_part_32:
-            ports.extend(signals)
-        for signals in module._delayed_part_64:
-            ports.extend(signals)
         ports += [module._output_64,
                   module._output_32,
                   module._output_16,
                   module._output_8]
         ports.extend(module._a_signed)
         ports.extend(module._b_signed)
-        ports += [module._not_a_term_8,
-                  module._neg_lsb_a_term_8,
-                  module._not_b_term_8,
-                  module._neg_lsb_b_term_8,
-                  module._not_a_term_16,
-                  module._neg_lsb_a_term_16,
-                  module._not_b_term_16,
-                  module._neg_lsb_b_term_16,
-                  module._not_a_term_32,
-                  module._neg_lsb_a_term_32,
-                  module._not_b_term_32,
-                  module._neg_lsb_b_term_32,
-                  module._not_a_term_64,
-                  module._neg_lsb_a_term_64,
-                  module._not_b_term_64,
-                  module._neg_lsb_b_term_64]
         with create_simulator(module, ports, file_name) as sim:
             def process(gen_or_check: GenOrCheck) -> AsyncProcessGenerator:
                 for a_signed in False, True: