simplify sign/term bits using Cat
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 10:32:15 +0000 (11:32 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 17 Aug 2019 10:32:15 +0000 (11:32 +0100)
src/ieee754/part_mul_add/multiply.py

index 3695f22ce926cb84cb65adf756e84395a6d42d1b..aef8b91816144d0b69c51d72a51dc4d6162b9f92 100644 (file)
@@ -608,6 +608,7 @@ class Mul8_16_32_64(Elaboratable):
                 ]:
             byte_width = 8 // len(parts)
             bit_width = 8 * byte_width
+            nat, nbt, nla, nlb = [], [], [], []
             for i in range(len(parts)):
                 be = parts[i] & self.a[(i + 1) * bit_width - 1] \
                     & self._a_signed[i * byte_width]
@@ -621,24 +622,27 @@ class Mul8_16_32_64(Elaboratable):
                 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
                 # negation operation is split into a bitwise not and a +1.
                 # likewise for 16, 32, and 64-bit values.
-                m.d.comb += [
-                    not_a_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(a_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.a.bit_select(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_a_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), a_enabled)),
-
-                    not_b_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(b_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.b.bit_select(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_b_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), b_enabled))]
+                nat.append(Mux(a_enabled,
+                        Cat(Repl(0, bit_width),
+                            ~self.a.bit_select(bit_width * i, bit_width)),
+                        0))
+
+                nla.append(Cat(Repl(0, bit_width), a_enabled,
+                               Repl(0, bit_width-1)))
+
+                nbt.append(Mux(b_enabled,
+                        Cat(Repl(0, bit_width),
+                            ~self.b.bit_select(bit_width * i, bit_width)),
+                        0))
+
+                nlb.append(Cat(Repl(0, bit_width), b_enabled,
+                               Repl(0, bit_width-1)))
+
+            m.d.comb += [not_a_term.eq(Cat(*nat)),
+                         not_b_term.eq(Cat(*nbt)),
+                         neg_lsb_a_term.eq(Cat(*nla)),
+                         neg_lsb_b_term.eq(Cat(*nlb)),
+                        ]
 
         expanded_part_pts = PartitionPoints()
         for i, v in self.part_pts.items():