1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16

    * Otherwise

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                # the mask bit is the inverse of the (enabled) break signal
                bits.append(~self[i])
            else:
                # no partition point at this index: mask bit always set
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # sum is the XOR of all three inputs; carry is the majority vote
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(FullAdder):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        FullAdder.__init__(self, width)
        self.mask = Signal(width)
        self.mcarry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = FullAdder.elaborate(self, platform)
        # the carry is shifted up one bit then gated by the partition mask
        m.d.comb += self.mcarry.eq((self.carry << 1) & self.mask)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits)
    o        : .... N... N... N... N... (32 bits)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit is inserted at every partition point
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later. graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add. when the "break" points are 0,
        # whatever is in it (in the output) is discarded. however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point. NOT in the output
            ea.append(self._expanded_a[expanded_index])
            eb.append(self._expanded_b[expanded_index])
            eo.append(self._expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# number of inputs each FullAdder stage compresses (a 3:2 compressor)
FULL_ADDER_INPUT_COUNT = 3
class AddReduce(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 inputs is compressed to 2 outputs
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # a pipeline register stage at level 0
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m
        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            add_intermediate_term(adder_i.sum)
            # mask out carry bits to prevent carries between partitions
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0
        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m
# operation codes assigned per partitioned byte lane.
# OP_MUL_LOW is restored here: it is referenced below (IntermediateOut
# compares ``delayed_part_ops[...] == OP_MUL_LOW`` and Signs uses it)
# but its assignment was missing from this span.
OP_MUL_LOW = 0                   # LSB half of the product
OP_MUL_SIGNED_HIGH = 1           # MSB half, both a and b signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3         # MSB half, both a and b unsigned
def get_term(value, shift=0, enabled=None):
    """Return ``value`` optionally gated by ``enabled`` and shifted left.

    :param value: the term to be included
    :param shift: number of zero LSBs to prepend (0 means no shift)
    :param enabled: optional gate: when deasserted the term becomes zero
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width       # width of one partition lane
        self.twidth = twidth      # width of the full term
        self.width = width * 2    # width of a single lane product
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition bits between the two indices decide term enablement
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            # enabled only when no partition break lies between the indices
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        return m

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        # (unreachable alternative implementation, deliberately kept)

        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
                            for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """ forms the two extra terms (width-extended inverted value and
        a +1 LSB) needed to negate a partial product when the partition
        is signed and its MSB is set.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):

        self.pbwid = pbwid
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # partition i of this size exists when its boundary partition
            # bits are set (npbs clear) and its interior bits are clear
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
                self.not_a_term, self.neg_lsb_a_term, \
                self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                    ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW takes the LSB half of the lane product,
            # all other ops take the MSB half
            m.d.comb += op.eq(
                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
                                    self.i16.bit_select(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
                                          self.i64.bit_select(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))

        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # two-level OR tree: (0|1) | (2|3)
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except MUL_UNSIGNED_HIGH;
        # b is signed only for MUL_LOW and MUL_SIGNED_HIGH
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
791 class Mul8_16_32_64(Elaboratable
):
792 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
794 Supports partitioning into any combination of 8, 16, 32, and 64-bit
795 partitions on naturally-aligned boundaries. Supports the operation being
796 set for each partition independently.
798 :attribute part_pts: the input partition points. Has a partition point at
799 multiples of 8 in 0 < i < 64. Each partition point's associated
800 ``Value`` is a ``Signal``. Modification not supported, except for by
802 :attribute part_ops: the operation for each byte. The operation for a
803 particular partition is selected by assigning the selected operation
804 code to each byte in the partition. The allowed operation codes are:
806 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
807 RISC-V's `mul` instruction.
808 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
809 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
811 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
812 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
813 `mulhsu` instruction.
814 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
815 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
819 def __init__(self
, register_levels
=()):
820 """ register_levels: specifies the points in the cascade at which
821 flip-flops are to be inserted.
825 self
.register_levels
= list(register_levels
)
828 self
.part_pts
= PartitionPoints()
829 for i
in range(8, 64, 8):
830 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
831 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
835 # intermediates (needed for unit tests)
836 self
._intermediate
_output
= Signal(128)
839 self
.output
= Signal(64)
841 def _part_byte(self
, index
):
842 if index
== -1 or index
== 7:
844 assert index
>= 0 and index
< 8
845 return self
.part_pts
[index
* 8 + 8]
847 def elaborate(self
, platform
):
851 pbs
= Signal(8, reset_less
=True)
854 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
855 m
.d
.comb
+= pb
.eq(self
._part
_byte
(i
))
857 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
864 setattr(m
.submodules
, "signs%d" % i
, s
)
865 m
.d
.comb
+= s
.part_ops
.eq(self
.part_ops
[i
])
868 [Signal(2, name
=f
"_delayed_part_ops_{delay}_{i}")
870 for delay
in range(1 + len(self
.register_levels
))]
871 for i
in range(len(self
.part_ops
)):
872 m
.d
.comb
+= delayed_part_ops
[0][i
].eq(self
.part_ops
[i
])
873 m
.d
.sync
+= [delayed_part_ops
[j
+ 1][i
].eq(delayed_part_ops
[j
][i
])
874 for j
in range(len(self
.register_levels
))]
876 n_levels
= len(self
.register_levels
)+1
877 m
.submodules
.part_8
= part_8
= Part(128, 8, n_levels
, 8)
878 m
.submodules
.part_16
= part_16
= Part(128, 4, n_levels
, 8)
879 m
.submodules
.part_32
= part_32
= Part(128, 2, n_levels
, 8)
880 m
.submodules
.part_64
= part_64
= Part(128, 1, n_levels
, 8)
881 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
882 for mod
in [part_8
, part_16
, part_32
, part_64
]:
883 m
.d
.comb
+= mod
.a
.eq(self
.a
)
884 m
.d
.comb
+= mod
.b
.eq(self
.b
)
885 for i
in range(len(signs
)):
886 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
887 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
888 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
889 nat_l
.append(mod
.not_a_term
)
890 nbt_l
.append(mod
.not_b_term
)
891 nla_l
.append(mod
.neg_lsb_a_term
)
892 nlb_l
.append(mod
.neg_lsb_b_term
)
896 for a_index
in range(8):
897 t
= ProductTerms(8, 128, 8, a_index
, 8)
898 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
900 m
.d
.comb
+= t
.a
.eq(self
.a
)
901 m
.d
.comb
+= t
.b
.eq(self
.b
)
902 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
907 # it's fine to bitwise-or data together since they are never enabled
909 m
.submodules
.nat_or
= nat_or
= OrMod(128)
910 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
911 m
.submodules
.nla_or
= nla_or
= OrMod(128)
912 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
913 for l
, mod
in [(nat_l
, nat_or
),
917 for i
in range(len(l
)):
918 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
919 terms
.append(mod
.orout
)
921 expanded_part_pts
= PartitionPoints()
922 for i
, v
in self
.part_pts
.items():
923 signal
= Signal(name
=f
"expanded_part_pts_{i*2}", reset_less
=True)
924 expanded_part_pts
[i
* 2] = signal
925 m
.d
.comb
+= signal
.eq(v
)
927 add_reduce
= AddReduce(terms
,
929 self
.register_levels
,
931 m
.submodules
.add_reduce
= add_reduce
932 m
.d
.comb
+= self
._intermediate
_output
.eq(add_reduce
.output
)
934 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
935 m
.d
.comb
+= io64
.intermed
.eq(self
._intermediate
_output
)
937 m
.d
.comb
+= io64
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
940 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
941 m
.d
.comb
+= io32
.intermed
.eq(self
._intermediate
_output
)
943 m
.d
.comb
+= io32
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
946 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
947 m
.d
.comb
+= io16
.intermed
.eq(self
._intermediate
_output
)
949 m
.d
.comb
+= io16
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
952 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
953 m
.d
.comb
+= io8
.intermed
.eq(self
._intermediate
_output
)
955 m
.d
.comb
+= io8
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
958 m
.submodules
.finalout
= finalout
= FinalOut(64)
959 for i
in range(len(part_8
.delayed_parts
[-1])):
960 m
.d
.comb
+= finalout
.d8
[i
].eq(part_8
.dplast
[i
])
961 for i
in range(len(part_16
.delayed_parts
[-1])):
962 m
.d
.comb
+= finalout
.d16
[i
].eq(part_16
.dplast
[i
])
963 for i
in range(len(part_32
.delayed_parts
[-1])):
964 m
.d
.comb
+= finalout
.d32
[i
].eq(part_32
.dplast
[i
])
965 m
.d
.comb
+= finalout
.i8
.eq(io8
.output
)
966 m
.d
.comb
+= finalout
.i16
.eq(io16
.output
)
967 m
.d
.comb
+= finalout
.i32
.eq(io32
.output
)
968 m
.d
.comb
+= finalout
.i64
.eq(io64
.output
)
969 m
.d
.comb
+= self
.output
.eq(finalout
.out
)
974 if __name__
== "__main__":
978 m
._intermediate
_output
,
981 *m
.part_pts
.values()])