1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with ``width == 16``
    specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param src_loc_at: stack-frame offset used to derive a default name.
        """
        if name is None:
            # get the caller's variable name by creating a throwaway Signal
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                # enabled partition point clears the mask bit
                bits.append(~self[i])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # bitwise sum of a 3:2 compressor is the XOR of all three inputs
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # carry is the majority function of the three inputs
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output. To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder. Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient):
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        # NOTE(review): elaborate() reads self.width — original assignment
        # line was lost in extraction; restored here.
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift each input left by one (Cat with a 0 LSB) so the carry is
        # already in its final position before masking
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        # pairwise AND terms of the majority function, each masked
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add. The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit. Then the final output *removes* the extra bits from
    the result::

        partition: .... P... P... P... P... (32 bits)
        a        : .... .... .... .... .... (32 bits)
        b        : .... .... .... .... .... (32 bits)
        exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
        exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
        exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
        o        : .... N... N... N... N... (32 bits - x ignored, N carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit per partition point, interleaved into the operands
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # store bits in a list, use Cat later. graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add. when the "break" points are 0,
        # whatever is in it (in the output) is discarded. however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        expanded_index = 0
        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point. NOT in the output
            ea.append(self._expanded_a[expanded_index])
            al.append(self.a[i])
            eb.append(self._expanded_b[expanded_index])
            bl.append(self.b[i])
            eo.append(self._expanded_o[expanded_index])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# Number of input terms consumed by one full adder (a 3:2 compressor stage).
FULL_ADDER_INPUT_COUNT = 3
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()

        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 inputs is reduced to 2 (sum + carry)
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def _elaborate(self, platform):
        """Elaborate this module.

        Returns ``(intermediate_terms, m)`` where ``intermediate_terms`` is
        ``None`` when a base case (0, 1 or 2 inputs) was fully handled, and
        otherwise the list of reduced terms for the next recursion level.
        """
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # a register level
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduceSingle.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return None, m

        # go on to prepare recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            # one named intermediary per term keeps the graphviz readable
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = MaskedFullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            m.d.comb += adder_i.mask.eq(part_mask)
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if len(self.inputs) % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif len(self.inputs) % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert len(self.inputs) % FULL_ADDER_INPUT_COUNT == 0

        return intermediate_terms, m
class AddReduce(AddReduceSingle):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        AddReduceSingle.__init__(self, inputs, output_width, register_levels,
                                 partition_points)

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    def elaborate(self, platform):
        """Elaborate this module."""
        intermediate_terms, m = AddReduceSingle._elaborate(self, platform)
        if intermediate_terms is None:
            # base case was fully handled by AddReduceSingle
            return m

        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)

        return m
# Operation codes for each byte lane of the partitioned multiplier.
# NOTE(review): OP_MUL_LOW was lost in extraction but is referenced by
# IntermediateOut and Signs below; restored with the conventional value 0.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` with ``enabled`` and shift it left.

    :param value: the term to process.
    :param shift: number of zero bits to prepend (left-shift amount).
    :param enabled: optional 1-bit enable; when deasserted the term is zero.
    :returns: the (possibly gated and shifted) term.
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        # NOTE(review): pwidth/twidth/width/shift attribute assignments were
        # lost in extraction; restored — elaborate() reads all four.
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # enable is only needed when the term straddles partition points
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when no partition point between the bytes is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """
        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.  this class is to be wrapped with a for-loop on the
        "a" operand.  it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        # NOTE(review): these attribute assignments were lost in extraction;
        # restored — elaborate() reads all of them.
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generate the 1s-complement and +1 correction terms for one partition."""

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        # inputs
        # NOTE(review): self.a / self.b signal creation was lost in
        # extraction; restored (elaborate() reads both) — assumed 64-bit
        # operands as in Mul8_16_32_64; confirm against upstream.
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            pbl = []
            # boundary bits are *negated* partition bits; interior bits
            # must all be clear for this partition to be "active"
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%di" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # bytes per partition lane at this width
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW selects the LSB half of the lane product,
            # everything else selects the MSB half
            m.d.comb += op.eq(
                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)

        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
                        self.i16.bit_select(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
                        self.i64.bit_select(i * 8, 8))))
            ol.append(op)

        m.d.comb += self.out.eq(Cat(*ol))
        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # two-level OR tree: (0|1) | (2|3)
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except mulhu (unsigned-high)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for mul (low) and mulh (signed-high)
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        self.register_levels = list(register_levels)

        # part_pts and part_ops
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]

        # inputs
        # NOTE(review): a/b creation lost in extraction; restored — elaborate
        # drives 64-bit operands into the 128-bit reduction tree.
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        """Return the partition-point enable for byte boundary ``index``."""
        if index == -1 or index == 7:
            # outermost boundaries are always "partitioned"
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        m = Module()

        # collect the partition-point bytes into one signal
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # per-byte sign decoders
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # pipeline-delayed copies of the per-byte operation codes
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(1 + len(self.register_levels))]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # the reduction tree operates on the doubled-width result, so the
        # partition points double too
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # create _output_64
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_32
        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_16
        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # create _output_8
        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.delayed_part_ops[i].eq(delayed_part_ops[-1][i])

        # final output mux: per-byte selection of the 8/16/32/64-bit result
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # generate Verilog / run nmigen CLI for the full partitioned multiplier,
    # exposing inputs, outputs, intermediates and partition points as ports
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])