1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
from abc import ABCMeta, abstractmethod
from functools import reduce
from operator import or_

from nmigen import (Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl)
from nmigen.hdl.ast import Assign
from nmigen.cli import main
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an ``int``
        :raises ValueError: if a point is negative
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            # create a throw-away Signal purely to pick up the caller's
            # variable name for nicer signal naming
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                # mask bit is the inverse of this point's enable
                bits.append(~self[i])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the enable for the partition point at byte boundary *index*.

        Byte boundaries outside the operand (-1 and 7) are always "enabled".
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # sum is the XOR of all three inputs; carry is the pairwise
        # majority, both computed bitwise across the whole width
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient):
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width  # read back by elaborate()
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift each input left by one (Cat with a leading 0) *before*
        # ANDing with the mask, so an AOI cell can implement the carry
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result:

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit per enabled partition point carries the roll-over
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync. it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)  # includes extra part-points
        self._expanded_b = Signal(expanded_width)  # likewise.
        self._expanded_o = Signal(expanded_width)  # likewise.

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(self._expanded_a[expanded_index])
            al.append(self.a[i])
            eb.append(self._expanded_b[expanded_index])
            bl.append(self.b[i])
            eo.append(self._expanded_o[expanded_index])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_o.eq(
            self._expanded_a + self._expanded_b)
        return m
# number of operands one FullAdder stage compresses (3:2 reduction)
FULL_ADDER_INPUT_COUNT = 3
class AddReduceData:
    """Bundle of signals passed between add-reduction stages.

    NOTE(review): the class header line was lost in the source mangling;
    the name ``AddReduceData`` is grounded by the constructor call sites
    in ``FinalAdd`` and ``AddReduceSingle`` — confirm against upstream.
    """

    def __init__(self, ppoints, n_inputs, output_width, n_parts):
        """Create the per-stage signal bundle.

        :param ppoints: ``PartitionPoints`` to clone (as ``Signal``s)
        :param n_inputs: number of partial-product inputs
        :param output_width: bit width of each input term
        :param n_parts: number of partitions (per-partition op codes)
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}")
                         for i in range(n_parts)]
        self.inputs = [Signal(output_width, name=f"inputs[{i}]")
                       for i in range(n_inputs)]
        self.reg_partition_points = ppoints.like()

    def eq(self, rhs):
        """Return assignments copying *rhs* into *self* (for comb/sync)."""
        return [self.reg_partition_points.eq(rhs.reg_partition_points)] + \
               [self.inputs[i].eq(rhs.inputs[i])
                for i in range(len(self.inputs))] + \
               [self.part_ops[i].eq(rhs.part_ops[i])
                for i in range(len(self.part_ops))]
class FinalAdd(Elaboratable):
    """ Final stage of add reduce
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create the final (0/1/2-input) add stage.

        :param n_inputs: number of remaining terms (must be <= 2)
        :param output_width: bit width of the result
        :param n_parts: number of partitions
        :param register_levels: list of nesting levels to be registered
        :param partition_points: the input partition points
        """
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.out_part_ops = self.i.part_ops
        self._resized_inputs = self.i.inputs
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.i.reg_partition_points
        self.intermediate_terms = []

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        if self.n_inputs == 0:
            # use 0 as the default output value
            m.d.comb += self.output.eq(0)
        elif self.n_inputs == 1:
            # handle single input
            m.d.comb += self.output.eq(self._resized_inputs[0])
        else:
            # base case for adding 2 inputs
            assert self.n_inputs == 2
            adder = PartitionedAdder(len(self.output),
                                     self._reg_partition_points)
            m.submodules.final_adder = adder
            m.d.comb += adder.a.eq(self._resized_inputs[0])
            m.d.comb += adder.b.eq(self._resized_inputs[1])
            m.d.comb += self.output.eq(adder.output)
        return m
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.output_width = output_width
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self.out_part_ops = self.i.part_ops
        self._resized_inputs = self.i.inputs
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.i.reg_partition_points

        max_level = AddReduceSingle.get_max_level(n_inputs)
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self._intermediate_terms = []
        if len(self.groups) != 0:
            self.create_next_terms()

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        level = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return level
            # each group of 3 collapses to 2 terms (sum + carry)
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            level += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # NOTE(review): part_mask and adders are created by
        # create_next_terms, i.e. only when self.groups is non-empty —
        # confirm this module is never elaborated with empty groups.
        for (value, term) in self._intermediate_terms:
            m.d.comb += term.eq(value)

        mask = self._reg_partition_points.as_mask(self.output_width)
        m.d.comb += self.part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[iidx])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(self.part_mask)

        return m

    def create_next_terms(self):
        """Build the masked full adders and the next level's input terms."""
        # go on to prepare recursive case
        intermediate_terms = []
        _intermediate_terms = []

        def add_intermediate_term(value):
            # one named Signal per term, wired up later in elaborate()
            intermediate_term = Signal(
                self.output_width,
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            _intermediate_terms.append((value, intermediate_term))
            intermediate_terms.append(intermediate_term)

        # store mask in intermediary (simplifies graph)
        self.part_mask = Signal(self.output_width, reset_less=True)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        self.adders = []
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0

        self.intermediate_terms = intermediate_terms
        self._intermediate_terms = _intermediate_terms
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, partition_points,
                 part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.part_ops = part_ops
        self.out_part_ops = [Signal(2, name=f"out_part_ops_{i}")
                             for i in range(len(part_ops))]
        # NOTE(review): the lines storing the input list and invoking
        # create_levels were lost in the source mangling; reconstructed
        # from elaborate()'s use of self.inputs / self.levels — confirm.
        self.inputs = inputs
        self.output = Signal(output_width)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points
        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Delegate to ``AddReduceSingle.get_max_level``."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """creates reduction levels"""
        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        n_parts = len(part_ops)
        while True:
            ilen = len(inputs)
            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
                                         next_levels, partition_points)
            mods.append(next_level)
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level._reg_partition_points
            inputs = next_level.intermediate_terms
            ilen = len(inputs)
            part_ops = next_level.out_part_ops
            groups = AddReduceSingle.full_adder_groups(len(inputs))
            if len(groups) == 0:
                break

        # down to 2 or fewer terms: hand over to the final adder
        next_level = FinalAdd(ilen, self.output_width, n_parts,
                              next_levels, partition_points)
        mods.append(next_level)

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        for i in range(len(self.levels)):
            mcur = self.levels[i]
            inassign = [mcur._resized_inputs[i].eq(inputs[i])
                        for i in range(len(inputs))]
            copy_part_ops = [mcur.out_part_ops[i].eq(part_ops[i])
                             for i in range(len(part_ops))]
            if 0 in mcur.register_levels:
                # this level is registered: latch inputs on the clock
                m.d.sync += copy_part_ops
                m.d.sync += inassign
                m.d.sync += mcur._reg_partition_points.eq(partition_points)
            else:
                m.d.comb += copy_part_ops
                m.d.comb += inassign
                m.d.comb += mcur._reg_partition_points.eq(partition_points)
            partition_points = mcur._reg_partition_points
            inputs = mcur.intermediate_terms
            part_ops = mcur.out_part_ops

        # output comes from last module
        m.d.comb += self.output.eq(next_level.output)
        copy_part_ops = [self.out_part_ops[i].eq(next_level.out_part_ops[i])
                         for i in range(len(self.part_ops))]
        m.d.comb += copy_part_ops
        return m
# operation codes, selected per partition (see Mul8_16_32_64.part_ops).
# NOTE(review): OP_MUL_LOW's definition line was lost in the mangling;
# value 0 reconstructed from the sequence 1..3 below and its later uses.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Build a partial-product term, optionally gated and shifted.

    :param value: the term's value
    :param shift: number of zero bits to prepend (a left shift)
    :param enabled: optional enable; when de-asserted the term becomes 0
    :return: the (possibly muxed and shifted) term
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        # prepend `shift` zero bits == multiply by 2**shift
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create one byte-product term.

        :param width: the byte width of the sub-operands
        :param twidth: the full term width
        :param pbwid: number of partition-enable bits
        :param a_index: byte index into operand a
        :param b_index: byte index into operand b
        """
        self.a_index = a_index
        self.b_index = b_index
        # product of byte i and byte j lands 8*(i+j) bits up
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # the partition bits *between* the two byte indices decide
        # whether this cross-product term is enabled at all
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when no partition point lies between the bytes
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        pwidth = self.pwidth
        width = self.width
        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """
        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a bank of ``ProductTerm``s for one "a" byte index.

        :param width: the byte width of the sub-operands
        :param twidth: the full term width
        :param pbwid: number of partition-enable bits
        :param a_index: byte index into operand a
        :param blen: number of "b" byte indices to generate
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d"%i, reset_less=True) \
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generates the two extra terms needed to sign-correct one partition."""

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit width of the partition being corrected
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)     # partition-active flag
        self.signed = Signal(reset_less=True)   # signed operation requested
        self.op = Signal(bit_width, reset_less=True)  # the operand slice
        self.msb = Signal(reset_less=True)      # MSB of the *other* operand
        self.nt = Signal(bit_width*2, reset_less=True)  # 1s-complement term
        self.nl = Signal(bit_width*2, reset_less=True)  # the "+1" term

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Parts(Elaboratable):
    """Derives the per-partition "active" flags from the partition points."""

    def __init__(self, pbwid, epps, n_parts):
        """Create a ``Parts`` module.

        :param pbwid: number of partition-byte bits
        :param epps: expanded partition points to mirror
        :param n_parts: number of partition flags to produce
        """
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # a partition is "active" when its boundary points are set
            # and no interior partition point splits it
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        NOTE(review): several __init__ lines were lost in the source
        mangling; the operand widths (64) and stored attributes are
        reconstructed from elaborate()'s uses — confirm against upstream.
        n_levels is kept for caller compatibility.
        """
        self.epps = epps
        self.pbwid = pbwid
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        pbs, parts = self.pbs, self.parts
        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
        m.d.comb += p.epps.eq(epps)
        # NOTE(review): connection of p.parts to self.parts reconstructed
        # (original line lost) — confirm against upstream.
        m.d.comb += [parts[i].eq(p.parts[i]) for i in range(len(parts))]

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
            self.not_a_term, self.neg_lsb_a_term,
            self.not_b_term, self.neg_lsb_b_term)

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one partition lane
        :param out_wid: width of the full (double-wide) intermediate result
        :param n_parts: number of partition lanes
        """
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        w = self.width
        # select the op code of the first byte of each lane
        sel = w // 8
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.part(i * w*2, w),
                    self.intermed.part(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit width of the final output
        """
        # inputs: per-byte width-selection flags...
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # ...and the four pre-computed candidate results
        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.part(i * 8, 8),
                        self.i16.part(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
                        self.i64.part(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))

        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """
    def __init__(self, wid):
        """ :param wid: bit-width of each input and of the output """
        # NOTE(review): this assignment was missing from the garbled
        # extraction but elaborate() reads self.wid — confirm upstream.
        self.wid = wid
        # four input values to be ORed together
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        # output: bitwise OR of all four inputs
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # first tree level: pairwise ORs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        # second tree level: combine the two pairs
        m.d.comb += self.orout.eq(or1 | or2)
        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        # input: 2-bit operation code (one of the OP_MUL_* constants)
        self.part_ops = Signal(2, reset_less=True)
        # outputs: whether operand a / operand b is to be treated as signed
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        # NOTE(review): Module creation and return were missing from the
        # garbled extraction; reconstructed — confirm upstream.
        m = Module()

        # a is signed for every op except unsigned-high (mulhu-style);
        # b is signed only for mul-low and signed-high (mul/mulh-style)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)
        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """
        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs: partition points at every byte boundary, plus a 2-bit
        # operation code per byte
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        # NOTE(review): operand declarations were missing from the garbled
        # extraction but elaborate() reads self.a and self.b; 64-bit assumed
        # from the class contract — confirm against upstream.
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        # NOTE(review): several structural lines (Module creation, loop
        # scaffolding, AddReduce argument list, return) were missing from
        # the garbled extraction and have been reconstructed from the
        # surviving fragments — confirm against upstream.
        m = Module()

        # collect part-bytes: one enable bit per byte boundary
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # create (doubled) PartitionPoints (output is double input width)
        expanded_part_pts = eps = PartitionPoints()
        for i, v in self.part_pts.items():
            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = ep
            m.d.comb += ep.eq(v)

        # determine a/b signedness from each byte's operation code
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # sign-extension/correction terms at each partition granularity
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # byte-by-byte partial products of a and b
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # sum all terms in a (optionally registered) reduction cascade
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)

        out_part_ops = add_reduce.levels[-1].out_part_ops
        out_part_pts = add_reduce.levels[-1]._reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # HI/LO selection of the 128-bit intermediate at each width
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])

        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])

        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])

        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self._intermediate_output)
        for i in range(8):
            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])

        # per-width partition-match detectors
        m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
        m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
        m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
        m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))

        m.d.comb += p_8.epps.eq(out_part_pts)
        m.d.comb += p_16.epps.eq(out_part_pts)
        m.d.comb += p_32.epps.eq(out_part_pts)
        m.d.comb += p_64.epps.eq(out_part_pts)

        # final byte-wise output selection according to the partitioning
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.parts)):
            m.d.comb += finalout.d8[i].eq(p_8.parts[i])
        for i in range(len(part_16.parts)):
            m.d.comb += finalout.d16[i].eq(p_16.parts[i])
        for i in range(len(part_32.parts)):
            m.d.comb += finalout.d32[i].eq(p_32.parts[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)
        return m
if __name__ == "__main__":
    # NOTE(review): the instantiation and the head of the main(...) call were
    # missing from the garbled extraction; reconstructed — confirm upstream.
    # Build the multiplier and expose its ports to nmigen's CLI driver.
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])