]:
             byte_width = 8 // len(parts)
             bit_width = 8 * byte_width
+            nat, nbt, nla, nlb = [], [], [], []
             for i in range(len(parts)):
                 be = parts[i] & self.a[(i + 1) * bit_width - 1] \
                     & self._a_signed[i * byte_width]
                 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
                 # negation operation is split into a bitwise not and a +1.
                 # likewise for 16, 32, and 64-bit values.
-                m.d.comb += [
-                    not_a_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(a_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.a.bit_select(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_a_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), a_enabled)),
-
-                    not_b_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Mux(b_enabled,
-                            Cat(Repl(0, bit_width),
-                                ~self.b.bit_select(bit_width * i, bit_width)),
-                            0)),
-
-                    neg_lsb_b_term.bit_select(bit_width * 2 * i, bit_width * 2)
-                    .eq(Cat(Repl(0, bit_width), b_enabled))]
+                nat.append(Mux(a_enabled,
+                        Cat(Repl(0, bit_width),
+                            ~self.a.bit_select(bit_width * i, bit_width)),
+                        0))
+
+                nla.append(Cat(Repl(0, bit_width), a_enabled,
+                               Repl(0, bit_width-1)))
+
+                nbt.append(Mux(b_enabled,
+                        Cat(Repl(0, bit_width),
+                            ~self.b.bit_select(bit_width * i, bit_width)),
+                        0))
+
+                nlb.append(Cat(Repl(0, bit_width), b_enabled,
+                               Repl(0, bit_width-1)))
+
+            m.d.comb += [not_a_term.eq(Cat(*nat)),
+                         not_b_term.eq(Cat(*nbt)),
+                         neg_lsb_a_term.eq(Cat(*nla)),
+                         neg_lsb_b_term.eq(Cat(*nlb)),
+                        ]
 
         expanded_part_pts = PartitionPoints()
         for i, v in self.part_pts.items():