# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
"""Integer Multiplication."""

from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
from nmigen.hdl.ast import Assign
from abc import ABCMeta, abstractmethod
from nmigen.cli import main
from functools import reduce
from operator import or_


class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at which an ALU is partitioned, along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input mapping of partition points to
            enable ``Value``s.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                bits.append(~self[i])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True


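# Illustrative usage sketch (not used by the module itself; the names
# below are examples only): the partition points for a 16-bit ALU that
# can be split into two 8-bit lanes under the control of one Signal.
def _demo_partition_points():
    split_8 = Signal(name="split_8")
    pp = PartitionPoints({8: split_8})
    assert pp.fits_in_width(16)
    assert pp.get_max_partition_count(16) == 2
    # as_mask(16) is all-ones except bit 8, which is ~split_8, so a
    # carry into bit 8 can be masked out when the split is enabled.
    return pp, pp.as_mask(16)

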
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m


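# Illustrative sketch (not part of the hardware): the sum/carry
# expressions above form a 3:2 carry-save compressor in every bit
# position, so for plain integers in0 + in1 + in2 == sum + (carry << 1).
def _demo_full_adder_identity(width=8):
    import random
    for _ in range(100):
        in0 = random.randrange(1 << width)
        in1 = random.randrange(1 << width)
        in2 = random.randrange(1 << width)
        sum_ = in0 ^ in1 ^ in2
        carry = (in0 & in1) | (in1 & in2) | (in2 & in0)
        assert in0 + in1 + in2 == sum_ + (carry << 1)

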
class MaskedFullAdder(FullAdder):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output. To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        FullAdder.__init__(self, width)
        self.mask = Signal(width)
        self.mcarry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = FullAdder.elaborate(self, platform)
        m.d.comb += self.mcarry.eq((self.carry << 1) & self.mask)
        return m


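# Illustrative sketch (not part of the hardware): masking the shifted
# carry stops a carry from crossing an enabled partition point.  Assumes
# an 8-bit adder with a partition point at bit 4.
def _demo_masked_carry():
    mask = 0b11101111                    # as_mask(8), point at bit 4 enabled
    in0, in1, in2 = 0b00001000, 0b00001000, 0
    carry = (in0 & in1) | (in1 & in2) | (in2 & in0)
    mcarry = (carry << 1) & mask
    assert (carry << 1) & 0x10 != 0      # raw carry would cross into bit 4
    assert mcarry & 0x10 == 0            # masked carry does not
    return mcarry

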
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit. Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits)
    o        : .... N... N... N... N... (32 bits)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later. graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add. when the "break" points are 0,
        # whatever is in it (in the output) is discarded. however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point. NOT in the output
            ea.append(self._expanded_a[expanded_index])
            eb.append(self._expanded_b[expanded_index])
            eo.append(self._expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m


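# Illustrative sketch (not part of the hardware): the "expanded add"
# trick above, done with plain integers for an 8-bit add that can be
# split into two 4-bit lanes at bit 4.  Assumes 0 <= a, b < 256.
def _demo_expanded_add(a, b, split):
    # insert an extra bit at position 4: ~split for a, always 0 for b
    exp_a = (a & 0xF) | ((0 if split else 1) << 4) | ((a >> 4) << 5)
    exp_b = (b & 0xF) | ((b >> 4) << 5)
    exp_o = exp_a + exp_b
    # drop the extra bit (bit 4) again when forming the 8-bit output
    return ((exp_o & 0xF) | (((exp_o >> 5) & 0xF) << 4)) & 0xFF

# For example 0x0F + 0x01 is 0x10 as a single 8-bit add, but 0x00 when
# the two 4-bit lanes are kept separate:
#   _demo_expanded_add(0x0F, 0x01, split=False) == 0x10
#   _demo_expanded_add(0x0F, 0x01, split=True) == 0x00

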
FULL_ADDER_INPUT_COUNT = 3


class AddReduce(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m
        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            add_intermediate_term(adder_i.sum)
            # mask out carry bits to prevent carries between partitions
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer: a half adder would not
            # help, since there would still be 2 terms, and passing them
            # straight through saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m


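# Illustrative sketch (not part of the hardware): each level of the
# AddReduce tree compresses N terms down to 2 * (N // 3) + (N % 3)
# terms using 3:2 full adders, until only two terms remain for the
# final PartitionedAdder.  For example 9 terms reduce as
# 9 -> 6 -> 4 -> 3 -> 2, which is why get_max_level(9) == 4.
def _demo_reduction_schedule(input_count):
    schedule = [input_count]
    while input_count > 2:
        groups = len(AddReduce.full_adder_groups(input_count))
        input_count = 2 * groups + input_count % FULL_ADDER_INPUT_COUNT
        schedule.append(input_count)
    return schedule

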
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3


def get_term(value, shift=0, enabled=None):
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    else:
        assert shift == 0
    return value


class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that it is the *output* that is selected,
        whilst the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        pwidth = self.pwidth
        width = self.width
        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """

        return m


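# Illustrative sketch (not part of the hardware): with no partitioning
# enabled, summing every byte-granular partial product
# a[i] * b[j] << 8*(i+j) reconstructs the full product that the
# AddReduce tree later accumulates.  Assumes 0 <= a, b < 2**(8*nbytes).
def _demo_partial_products(a, b, nbytes=8):
    total = 0
    for i in range(nbytes):
        for j in range(nbytes):
            a_byte = (a >> (8 * i)) & 0xFF
            b_byte = (b >> (8 * j)) & 0xFF
            total += (a_byte * b_byte) << (8 * (i + j))
    assert total == a * b
    return total

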
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.  this class is to be wrapped with a for-loop on the
        "a" operand; it creates a second-level for-loop on the "b" operand.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m


class LSBNegTerm(Elaboratable):
    """ computes the correction terms needed when a partition performs a
        signed multiply: ``nt`` is the bitwise inversion of ``op`` placed
        in the HI half of the result, and ``nl`` is the corresponding "+1",
        both gated on ``part`` (partition active), ``msb`` and ``signed``.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m


class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """
    def __init__(self, width, n_parts, n_levels, pbwid):

        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width  # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m


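# Illustrative sketch (not part of the hardware): for one 8-bit signed x
# signed partition (e.g. OP_MUL_SIGNED_HIGH), the signed product can be
# formed from the unsigned byte product plus the "bitwise-not plus one
# in the HI half" correction terms that LSBNegTerm/Part generate.
# Assumes 0 <= a, b < 256 (raw byte patterns).
def _demo_signed_correction(a, b):
    signed_a = a - 256 if a & 0x80 else a
    signed_b = b - 256 if b & 0x80 else b
    product = a * b                      # unsigned partial product
    if b & 0x80:                         # b negative: add 0x100 * (-a)
        product += ((~a & 0xFF) << 8) + (1 << 8)
    if a & 0x80:                         # a negative: add 0x100 * (-b)
        product += ((~b & 0xFF) << 8) + (1 << 8)
    assert product & 0xFFFF == (signed_a * signed_b) & 0xFFFF
    return product & 0xFFFF

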
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width.
        the output is also reconstructed in its SIMD (partition) lanes.
    """
    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(
                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m


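# Illustrative sketch (not part of the hardware): for one 8-bit lane the
# 16-bit product occupies a 2*width slot of the intermediate result;
# OP_MUL_LOW takes the low half of that slot and the *_HIGH ops take the
# upper half, exactly as the Mux above selects.
def _demo_hi_lo_select(a_byte=0x12, b_byte=0x34):
    product = a_byte * b_byte            # 16-bit lane product
    lo = product & 0xFF                  # what OP_MUL_LOW returns
    hi = (product >> 8) & 0xFF           # what the *_HIGH ops return
    assert (hi << 8) | lo == product
    return lo, hi

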
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """
    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
                        self.i16.bit_select(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
                        self.i64.bit_select(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m


class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """
    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m


class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):

        m = Module()

        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m


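# For reference, the operand-signedness decode implemented by Signs above:
#
#   op code                       a_signed  b_signed
#   OP_MUL_LOW                    1         1
#   OP_MUL_SIGNED_HIGH            1         1
#   OP_MUL_SIGNED_UNSIGNED_HIGH   1         0
#   OP_MUL_UNSIGNED_HIGH          0         0

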
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        m = Module()

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # local variables
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)
        # create _output_64
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_32
        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_16
        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_8
        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m


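# Illustrative usage sketch (an assumption about typical wiring, not part
# of the unit tests): wrap Mul8_16_32_64 so that it performs four
# independent 16-bit signed high multiplies (RISC-V `mulh`) on the
# naturally-aligned 16-bit lanes of `a` and `b`.
def _demo_mul16_signed_high():
    m = Module()
    m.submodules.mul = mul = Mul8_16_32_64()
    for i in range(8, 64, 8):
        # enable only the partition points on 16-bit boundaries
        m.d.comb += mul.part_pts[i].eq(1 if i % 16 == 0 else 0)
    for i in range(8):
        m.d.comb += mul.part_ops[i].eq(OP_MUL_SIGNED_HIGH)
    # mul.a and mul.b are the 64-bit inputs; mul.output is the result
    return m, mul

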
if __name__ == "__main__":
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])