format code
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
index c9239fa6cccf970c34cdf2cf2c71339a4d25d8c0..b132de56f0ecc90316b13fd1ee1f8da5b7ea4075 100644 (file)
@@ -17,22 +17,23 @@ from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder
 
 FULL_ADDER_INPUT_COUNT = 3
 
+
 class AddReduceData:
 
     def __init__(self, part_pts, n_inputs, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
-                          for i in range(n_parts)]
+                         for i in range(n_parts)]
         self.terms = [Signal(output_width, name=f"terms_{i}",
-                              reset_less=True)
-                        for i in range(n_inputs)]
+                             reset_less=True)
+                      for i in range(n_inputs)]
         self.part_pts = part_pts.like()
 
     def eq_from(self, part_pts, inputs, part_ops):
         return [self.part_pts.eq(part_pts)] + \
                [self.terms[i].eq(inputs[i])
-                                     for i in range(len(self.terms))] + \
+                for i in range(len(self.terms))] + \
                [self.part_ops[i].eq(part_ops[i])
-                                     for i in range(len(self.part_ops))]
+                for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
         return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
@@ -42,7 +43,7 @@ class FinalReduceData:
 
     def __init__(self, part_pts, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
-                          for i in range(n_parts)]
+                         for i in range(n_parts)]
         self.output = Signal(output_width, reset_less=True)
         self.part_pts = part_pts.like()
 
@@ -50,7 +51,7 @@ class FinalReduceData:
         return [self.part_pts.eq(part_pts)] + \
                [self.output.eq(output)] + \
                [self.part_ops[i].eq(part_ops[i])
-                                     for i in range(len(self.part_ops))]
+                for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
         return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
@@ -61,7 +62,7 @@ class FinalAdd(PipeModBase):
     """
 
     def __init__(self, pspec, lidx, n_inputs, partition_points,
-                       partition_step=1):
+                 partition_step=1):
         self.lidx = lidx
         self.partition_step = partition_step
         self.output_width = pspec.width * 2
@@ -79,7 +80,7 @@ class FinalAdd(PipeModBase):
 
     def ospec(self):
         return FinalReduceData(self.partition_points,
-                                 self.output_width, self.n_parts)
+                               self.output_width, self.n_parts)
 
     def elaborate(self, platform):
         """Elaborate this module."""
@@ -123,7 +124,7 @@ class AddReduceSingle(PipeModBase):
     """
 
     def __init__(self, pspec, lidx, n_inputs, partition_points,
-                       partition_step=1):
+                 partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -228,7 +229,7 @@ class AddReduceSingle(PipeModBase):
         # copy reg part points and part ops to output
         m.d.comb += self.o.part_pts.eq(self.i.part_pts)
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
-                                     for i in range(len(self.i.part_ops))]
+                     for i in range(len(self.i.part_ops))]
 
         # set up the partition mask (for the adders)
         part_mask = Signal(self.output_width, reset_less=True)
@@ -316,7 +317,7 @@ class AddReduce(AddReduceInternal, Elaboratable):
     """
 
     def __init__(self, inputs, output_width, register_levels, part_pts,
-                       part_ops, partition_step=1):
+                 part_ops, partition_step=1):
         """Create an ``AddReduce``.
 
         :param inputs: input ``Signal``s to be summed.
@@ -330,7 +331,7 @@ class AddReduce(AddReduceInternal, Elaboratable):
         self._part_ops = part_ops
         n_parts = len(part_ops)
         self.i = AddReduceData(part_pts, len(inputs),
-                             output_width, n_parts)
+                               output_width, n_parts)
         AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
                                    partition_step)
         self.o = FinalReduceData(part_pts, output_width, n_parts)
@@ -351,7 +352,8 @@ class AddReduce(AddReduceInternal, Elaboratable):
         """Elaborate this module."""
         m = Module()
 
-        m.d.comb += self.i.eq_from(self._part_pts, self._inputs, self._part_ops)
+        m.d.comb += self.i.eq_from(self._part_pts,
+                                   self._inputs, self._part_ops)
 
         for i, next_level in enumerate(self.levels):
             setattr(m.submodules, "next_level%d" % i, next_level)
@@ -363,7 +365,7 @@ class AddReduce(AddReduceInternal, Elaboratable):
                 m.d.sync += mcur.i.eq(i)
             else:
                 m.d.comb += mcur.i.eq(i)
-            i = mcur.o # for next loop
+            i = mcur.o  # for next loop
 
         # output comes from last module
         m.d.comb += self.o.eq(i)
@@ -420,7 +422,7 @@ class ProductTerm(Elaboratable):
         else:
             term_enabled = None
         self.enabled = term_enabled
-        self.term.name = "term_%d_%d" % (a_index, b_index) # rename
+        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename
 
     def elaborate(self, platform):
 
@@ -463,6 +465,7 @@ class ProductTerms(Elaboratable):
         this class is to be wrapped with a for-loop on the "a" operand.
         it creates a second-level for-loop on the "b" operand.
     """
+
     def __init__(self, width, twidth, pbwid, a_index, blen):
         self.a_index = a_index
         self.blen = blen
@@ -472,8 +475,8 @@ class ProductTerms(Elaboratable):
         self.a = Signal(twidth//2, reset_less=True)
         self.b = Signal(twidth//2, reset_less=True)
         self.pb_en = Signal(pbwid, reset_less=True)
-        self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
-                            for i in range(blen)]
+        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
+                      for i in range(blen)]
 
     def elaborate(self, platform):
 
@@ -508,7 +511,7 @@ class LSBNegTerm(Elaboratable):
         m = Module()
         comb = m.d.comb
         bit_wid = self.bit_width
-        ext = Repl(0, bit_wid) # extend output to HI part
+        ext = Repl(0, bit_wid)  # extend output to HI part
 
         # determine sign of each incoming number *in this partition*
         enabled = Signal(reset_less=True)
@@ -582,6 +585,7 @@ class Part(Elaboratable):
         the extra terms - as separate terms - are then thrown at the
         AddReduce alongside the multiplication part-results.
     """
+
     def __init__(self, part_pts, width, n_parts, pbwid):
 
         self.pbwid = pbwid
@@ -591,14 +595,14 @@ class Part(Elaboratable):
         self.a = Signal(64, reset_less=True)
         self.b = Signal(64, reset_less=True)
         self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
-                            for i in range(8)]
+                         for i in range(8)]
         self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
-                            for i in range(8)]
+                         for i in range(8)]
         self.pbs = Signal(pbwid, reset_less=True)
 
         # outputs
         self.parts = [Signal(name=f"part_{i}", reset_less=True)
-                            for i in range(n_parts)]
+                      for i in range(n_parts)]
 
         self.not_a_term = Signal(width, reset_less=True)
         self.neg_lsb_a_term = Signal(width, reset_less=True)
@@ -617,10 +621,10 @@ class Part(Elaboratable):
         byte_count = 8 // len(parts)
 
         not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
-                self.not_a_term, self.neg_lsb_a_term,
-                self.not_b_term, self.neg_lsb_b_term)
+            self.not_a_term, self.neg_lsb_a_term,
+            self.not_b_term, self.neg_lsb_b_term)
 
-        byte_width = 8 // len(parts) # byte width
+        byte_width = 8 // len(parts)  # byte width
         bit_wid = 8 * byte_width     # bit width
         nat, nbt, nla, nlb = [], [], [], []
         for i in range(len(parts)):
@@ -629,8 +633,8 @@ class Part(Elaboratable):
             setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
             m.d.comb += pa.part.eq(parts[i])
             m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
-            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
-            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
+            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
+            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
             nat.append(pa.nt)
             nla.append(pa.nl)
 
@@ -639,8 +643,8 @@ class Part(Elaboratable):
             setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
             m.d.comb += pb.part.eq(parts[i])
             m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
-            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
-            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
+            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
+            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
             nbt.append(pb.nt)
             nlb.append(pb.nl)
 
@@ -649,7 +653,7 @@ class Part(Elaboratable):
                      not_b_term.eq(Cat(*nbt)),
                      neg_lsb_a_term.eq(Cat(*nla)),
                      neg_lsb_b_term.eq(Cat(*nlb)),
-                    ]
+                     ]
 
         return m
 
@@ -658,11 +662,12 @@ class IntermediateOut(Elaboratable):
     """ selects the HI/LO part of the multiplication, for a given bit-width
         the output is also reconstructed in its SIMD (partition) lanes.
     """
+
     def __init__(self, width, out_wid, n_parts):
         self.width = width
         self.n_parts = n_parts
         self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
-                                     for i in range(8)]
+                         for i in range(8)]
         self.intermed = Signal(out_wid, reset_less=True)
         self.output = Signal(out_wid//2, reset_less=True)
 
@@ -691,6 +696,7 @@ class FinalOut(PipeModBase):
         that some partitions requested 8-bit computation whilst others
         requested 16 or 32 bit.
     """
+
     def __init__(self, pspec, part_pts):
 
         self.part_pts = part_pts
@@ -754,9 +760,9 @@ class FinalOut(PipeModBase):
             m.d.comb += op.eq(
                 Mux(d8[i] | d16[i // 2],
                     Mux(d8[i], i8.bit_select(i * 8, 8),
-                               i16.bit_select(i * 8, 8)),
+                        i16.bit_select(i * 8, 8)),
                     Mux(d32[i // 4], i32.bit_select(i * 8, 8),
-                                      i64.bit_select(i * 8, 8))))
+                        i64.bit_select(i * 8, 8))))
             ol.append(op)
 
         # create outputs
@@ -769,6 +775,7 @@ class FinalOut(PipeModBase):
 class OrMod(Elaboratable):
     """ ORs four values together in a hierarchical tree
     """
+
     def __init__(self, wid):
         self.wid = wid
         self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
@@ -802,7 +809,7 @@ class Signs(Elaboratable):
 
         asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
         bsig = (self.part_ops == OP_MUL_LOW) \
-                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
+            | (self.part_ops == OP_MUL_SIGNED_HIGH)
         m.d.comb += self.a_signed.eq(asig)
         m.d.comb += self.b_signed.eq(bsig)
 
@@ -813,21 +820,21 @@ class IntermediateData:
 
     def __init__(self, part_pts, output_width, n_parts):
         self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
-                          for i in range(n_parts)]
+                         for i in range(n_parts)]
         self.part_pts = part_pts.like()
         self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
-                          for i in range(4)]
+                        for i in range(4)]
         # intermediates (needed for unit tests)
         self.intermediate_output = Signal(output_width)
 
     def eq_from(self, part_pts, outputs, intermediate_output,
-                      part_ops):
+                part_ops):
         return [self.part_pts.eq(part_pts)] + \
                [self.intermediate_output.eq(intermediate_output)] + \
                [self.outputs[i].eq(outputs[i])
-                                     for i in range(4)] + \
+                for i in range(4)] + \
                [self.part_ops[i].eq(part_ops[i])
-                                     for i in range(len(self.part_ops))]
+                for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
         return self.eq_from(rhs.part_pts, rhs.outputs,
@@ -848,7 +855,7 @@ class InputData:
         return [self.part_pts.eq(part_pts)] + \
                [self.a.eq(a), self.b.eq(b)] + \
                [self.part_ops[i].eq(part_ops[i])
-                                     for i in range(len(self.part_ops))]
+                for i in range(len(self.part_ops))]
 
     def eq(self, rhs):
         return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
@@ -857,7 +864,7 @@ class InputData:
 class OutputData:
 
     def __init__(self):
-        self.intermediate_output = Signal(128) # needed for unit tests
+        self.intermediate_output = Signal(128)  # needed for unit tests
         self.output = Signal(64)
 
     def eq(self, rhs):
@@ -943,9 +950,9 @@ class AllTerms(PipeModBase):
         m.submodules.nla_or = nla_or = OrMod(128)
         m.submodules.nlb_or = nlb_or = OrMod(128)
         for l, mod in [(nat_l, nat_or),
-                             (nbt_l, nbt_or),
-                             (nla_l, nla_or),
-                             (nlb_l, nlb_or)]:
+                       (nbt_l, nbt_or),
+                       (nla_l, nla_or),
+                       (nlb_l, nlb_or)]:
             for i in range(len(l)):
                 m.d.comb += mod.orin[i].eq(l[i])
             terms.append(mod.orout)
@@ -957,7 +964,7 @@ class AllTerms(PipeModBase):
         # copy reg part points and part ops to output
         m.d.comb += self.o.part_pts.eq(eps)
         m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
-                                     for i in range(len(self.i.part_ops))]
+                     for i in range(len(self.i.part_ops))]
 
         return m
 
@@ -1056,7 +1063,7 @@ class Mul8_16_32_64(Elaboratable):
             flip-flops are to be inserted.
         """
 
-        self.id_wid = 0 # num_bits(num_rows)
+        self.id_wid = 0  # num_bits(num_rows)
         self.op_wid = 0
         self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
         self.pspec.n_parts = 8
@@ -1094,7 +1101,8 @@ class Mul8_16_32_64(Elaboratable):
 
         terms = t.o.terms
 
-        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)
+        at = AddReduceInternal(self.pspec, n_inputs,
+                               part_pts, partition_step=2)
 
         i = t.o
         for idx in range(len(at.levels)):
@@ -1105,7 +1113,7 @@ class Mul8_16_32_64(Elaboratable):
                 m.d.sync += o.eq(mcur.process(i))
             else:
                 m.d.comb += o.eq(mcur.process(i))
-            i = o # for next loop
+            i = o  # for next loop
 
         interm = Intermediates(self.pspec, part_pts)
         interm.setup(m, i)