1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an integer.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul  # scale the indices by the multiplication factor
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = []
        for i in range(width):
            if i in self:
                bits.append(~self[i])   # clear when the point is enabled
            else:
                bits.append(True)       # no point here: mask bit always set
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the partition-point value guarding byte ``index``.

        Byte -1 (below bit 0) and byte 7 (top) are always "enabled".
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # per-bit sum of a 3:2 compressor: XOR of all three inputs
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # per-bit carry: pairwise majority of the three inputs
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient):
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        # sum is the plain 3-input XOR (same as FullAdder)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # pre-shift each input left by one (Cat with a leading 0) so the
        # carry lands on the correct bit *before* the mask is applied
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        # pairwise-majority carry terms, each masked by the partition mask
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        """
        self.width = width
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # expanded width: one extra bit per enabled partition point
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(expanded_a[expanded_index])
                al.append(~self.partition_points[i])  # add extra bit in a
                eb.append(expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(expanded_a[expanded_index])
            eb.append(expanded_b[expanded_index])
            eo.append(expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
        return m
# number of inputs a single full adder (3:2 compressor) consumes per stage
FULL_ADDER_INPUT_COUNT = 3
def __init__(self, ppoints, n_inputs, output_width, n_parts):
    """Capture the per-stage inputs of an add-reduce layer.

    :param ppoints: partition points, replicated as fresh Signals via ``like``
    :param n_inputs: number of input terms
    :param output_width: bit width of each input term
    :param n_parts: number of SIMD partitions (one 2-bit part_op per part)
    """
    self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                     for i in range(n_parts)]
    self.inputs = [Signal(output_width, name=f"inputs_{i}",
                          reset_less=True)
                   for i in range(n_inputs)]
    self.reg_partition_points = ppoints.like()
def eq_from(self, reg_partition_points, inputs, part_ops):
    """Build ``eq`` assignment statements from separate fields.

    Returns a list assigning, in order: the partition points, each
    input term, and each part_op.
    """
    stmts = [self.reg_partition_points.eq(reg_partition_points)]
    stmts += [self.inputs[i].eq(inputs[i])
              for i in range(len(self.inputs))]
    stmts += [self.part_ops[i].eq(part_ops[i])
              for i in range(len(self.part_ops))]
    return stmts

def eq(self, rhs):
    """Assign from another object with the same fields.

    NOTE(review): the ``def eq`` header line is missing from the mangled
    source; reconstructed from the visible return statement.
    """
    return self.eq_from(rhs.reg_partition_points, rhs.inputs, rhs.part_ops)
class FinalReduceData:
    """Output data of the final add-reduce stage: one summed output
    value plus the (registered) partition points and part_ops.
    """

    def __init__(self, ppoints, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.output = Signal(output_width, reset_less=True)
        self.reg_partition_points = ppoints.like()

    def eq_from(self, reg_partition_points, output, part_ops):
        """Build ``eq`` statements from separate fields (points, output,
        then each part_op)."""
        return [self.reg_partition_points.eq(reg_partition_points)] + \
               [self.output.eq(output)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        """Assign from another ``FinalReduceData``-shaped object.

        NOTE(review): the ``def eq`` header line is missing from the
        mangled source; reconstructed from the visible return statement.
        """
        return self.eq_from(rhs.reg_partition_points, rhs.output, rhs.part_ops)
class FinalAdd(Elaboratable):
    """ Final stage of add reduce: 0, 1 or exactly 2 terms remain, so a
        single PartitionedAdder (or a passthrough) produces the output.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self.o = FinalReduceData(partition_points, output_width, n_parts)
        self.output_width = output_width
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        output_width = self.output_width
        output = Signal(output_width, reset_less=True)
        if self.n_inputs == 0:
            # use 0 as the default output value
            m.d.comb += output.eq(0)
        elif self.n_inputs == 1:
            # handle single input
            m.d.comb += output.eq(self.i.inputs[0])
        else:
            # base case for adding 2 inputs
            assert self.n_inputs == 2
            adder = PartitionedAdder(output_width,
                                     self.i.reg_partition_points)
            m.submodules.final_adder = adder
            m.d.comb += adder.a.eq(self.i.inputs[0])
            m.d.comb += adder.b.eq(self.i.inputs[1])
            m.d.comb += output.eq(adder.output)

        # pass the result, partition points and part ops to the output
        m.d.comb += self.o.eq_from(self.i.reg_partition_points, output,
                                   self.i.part_ops)

        return m
class AddReduceSingle(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, n_inputs, output_width, n_parts, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.output_width = output_width
        self.i = AddReduceData(partition_points, n_inputs,
                               output_width, n_parts)
        self.register_levels = list(register_levels)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        max_level = AddReduceSingle.get_max_level(n_inputs)
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

        # this is annoying.  we have to create the modules (and terms)
        # because we need to know what they are (in order to set up the
        # interconnects back in AddReduce), but cannot do the m.d.comb +=
        # etc because this is not in elaboratable.
        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self._intermediate_terms = []
        self.adders = []
        if len(self.groups) != 0:
            self.create_next_terms()

        self.o = AddReduceData(partition_points,
                               len(self._intermediate_terms),
                               output_width, n_parts)

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each group of 3 shrinks to 2 terms; the remainder passes through
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # copy the intermediate terms to the output
        for i, value in enumerate(self._intermediate_terms):
            m.d.comb += self.o.inputs[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.reg_partition_points.eq(
                                            self.i.reg_partition_points)
        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                     for i in range(len(self.i.part_ops))]

        # set up the partition mask (for the adders)
        part_mask = Signal(self.output_width, reset_less=True)

        mask = self.i.reg_partition_points.as_mask(self.output_width)
        m.d.comb += part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(self.adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self.i.inputs[iidx])
            m.d.comb += adder_i.in1.eq(self.i.inputs[iidx + 1])
            m.d.comb += adder_i.in2.eq(self.i.inputs[iidx + 2])
            m.d.comb += adder_i.mask.eq(part_mask)

        return m

    def create_next_terms(self):
        """Build the MaskedFullAdders for this level and record which
        terms feed the next level."""
        _intermediate_terms = []

        def add_intermediate_term(value):
            _intermediate_terms.append(value)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            self.adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            add_intermediate_term(adder_i.sum)
            add_intermediate_term(adder_i.mcarry)
        # handle the remaining inputs.
        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
            add_intermediate_term(self.i.inputs[-1])
        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self.i.inputs[-2])
            add_intermediate_term(self.i.inputs[-1])
        else:
            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0

        self._intermediate_terms = _intermediate_terms
class AddReduce(Elaboratable):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points, part_ops):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self.inputs = inputs
        self.part_ops = part_ops
        n_parts = len(part_ops)
        self.o = FinalReduceData(partition_points, output_width, n_parts)
        self.output_width = output_width
        self.register_levels = register_levels
        self.partition_points = partition_points

        self.create_levels()

    @staticmethod
    def get_max_level(input_count):
        """Delegate to AddReduceSingle (same reduction schedule)."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def create_levels(self):
        """creates reduction levels"""

        mods = []
        next_levels = self.register_levels
        partition_points = self.partition_points
        part_ops = self.part_ops
        n_parts = len(part_ops)
        inputs = self.inputs
        ilen = len(inputs)
        while True:
            groups = AddReduceSingle.full_adder_groups(len(inputs))
            if len(groups) == 0:
                break
            next_level = AddReduceSingle(ilen, self.output_width, n_parts,
                                         next_levels, partition_points)
            mods.append(next_level)
            next_levels = list(AddReduce.next_register_levels(next_levels))
            partition_points = next_level.i.reg_partition_points
            inputs = next_level.o.inputs
            ilen = len(inputs)
            part_ops = next_level.i.part_ops

        next_level = FinalAdd(ilen, self.output_width, n_parts,
                              next_levels, partition_points)
        mods.append(next_level)

        self.levels = mods

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        partition_points = self.partition_points
        inputs = self.inputs
        part_ops = self.part_ops
        n_parts = len(part_ops)
        n_inputs = len(inputs)
        output_width = self.output_width
        i = AddReduceData(partition_points, n_inputs, output_width, n_parts)
        m.d.comb += i.eq_from(partition_points, inputs, part_ops)
        for idx in range(len(self.levels)):
            mcur = self.levels[idx]
            if 0 in mcur.register_levels:
                # pipeline register at this level
                m.d.sync += mcur.i.eq(i)
            else:
                m.d.comb += mcur.i.eq(i)
            i = mcur.o  # for next loop

        # output comes from last module
        m.d.comb += self.o.eq(i)

        return m
# operation codes, one 2-bit code per byte of the partitioned operands.
# OP_MUL_LOW is referenced below (IntermediateOut, Signs) but its
# definition line was lost in the mangled source: restored here.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` with ``enabled`` and shift it left.

    :param value: the term to process
    :param shift: number of zero bits to prepend (left-shift amount)
    :param enabled: if not None, a 1-bit condition; the term is zeroed
        when it is deasserted
    :return: the gated, shifted term
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        # NOTE(review): the following width bookkeeping lines were lost
        # in the mangled source; reconstructed from their uses in
        # elaborate() (self.pwidth, self.width, self.shift, self.twidth).
        self.pwidth = width       # partial (byte-group) width
        self.twidth = twidth      # total term width
        self.width = width * 2    # product of two pwidth operands
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition-point enables spanning this term's byte range
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            # no partition points in range: term is unconditionally enabled
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            # enabled only when *no* partition bit in the range is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.part(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.part(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # --- dead-code alternative kept from the original for reference ---
        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width
        #
        # pwidth = self.pwidth
        # width = self.width
        # bsa = Signal(self.twidth//2, reset_less=True)
        # bsb = Signal(self.twidth//2, reset_less=True)
        # asel = Signal(width, reset_less=True)
        # bsel = Signal(width, reset_less=True)
        # a_index, b_index = self.a_index, self.b_index
        # m.d.comb += asel.eq(self.a.part(a_index * pwidth, pwidth))
        # m.d.comb += bsel.eq(self.b.part(b_index * pwidth, pwidth))
        # m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        # m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        # m.d.comb += self.ti.eq(bsa * bsb)
        # m.d.comb += self.term.eq(self.ti)
        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        # NOTE(review): these stores were lost in the mangled source;
        # reconstructed from their uses in elaborate().
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generates the two correction terms needed to turn an unsigned
    partial product into a signed one: a width-extended one's-complement
    term (``nt``) and a "+1" LSB term (``nl``).
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Parts(Elaboratable):
    """Decodes the (expanded) partition points into one "this byte-range
    is a single partition" flag per part.
    """

    def __init__(self, pbwid, epps, n_parts):
        self.pbwid = pbwid
        # inputs
        self.epps = PartitionPoints.like(epps, name="epps")  # expanded points
        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

    def elaborate(self, platform):
        m = Module()

        epps, parts = self.epps, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(epps.part_byte(i, mfactor=2))  # double
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # a part is "active" when its boundary points are set and no
            # interior partition point within the byte range is set
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, epps, width, n_parts, n_levels, pbwid):
        # NOTE(review): attribute stores here were lost in the mangled
        # source; reconstructed from their uses in elaborate().
        self.pbwid = pbwid
        self.epps = epps

        # inputs
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

        self.not_a_term = Signal(width, reset_less=True)
        self.neg_lsb_a_term = Signal(width, reset_less=True)
        self.not_b_term = Signal(width, reset_less=True)
        self.neg_lsb_b_term = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        pbs, parts = self.pbs, self.parts
        epps = self.epps
        m.submodules.p = p = Parts(self.pbwid, epps, len(parts))
        m.d.comb += p.epps.eq(epps)
        parts = p.parts

        byte_count = 8 // len(parts)

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
                self.not_a_term, self.neg_lsb_a_term,
                self.not_b_term, self.neg_lsb_b_term)

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.part(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.part(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # one part_op per byte: stride through them
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW selects the low half of the 2w-wide product,
            # any other op selects the high half
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.part(i * w*2, w),
                    self.intermed.part(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, out_wid):
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True)
                   for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True)
                    for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True)
                    for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.part(i * 8, 8),
                                    self.i16.part(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.part(i * 8, 8),
                                          self.i64.part(i * 8, 8))))
            ol.append(op)

        # create outputs
        m.d.comb += self.out.eq(Cat(*ol))

        return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """
    def __init__(self, wid):
        # width is needed again in elaborate() for the intermediate wires
        self.wid = wid
        # inputs: the four values to be ORed together
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        # output: bitwise OR of all four inputs
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # two-level tree: (0|1) and (2|3) first, then combine
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        # input: the 2-bit operation code for one byte-partition
        self.part_ops = Signal(2, reset_less=True)
        # outputs: whether operand a / operand b is to be treated as signed
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # a is signed for every op except unsigned-high (mulhu);
        # b is signed only for mul (low) and signed-high (mulh)
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
                    | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to
            RISC-V's `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where
            both ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        # NOTE(review): operand declarations reconstructed — elaborate()
        # reads self.a and self.b; 64-bit width is implied by the 64-bit
        # output and 128-bit intermediate.  Confirm against upstream.
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def elaborate(self, platform):
        m = Module()

        # collect part-bytes: one enable bit per byte boundary
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # create (doubled) PartitionPoints (output is double input width)
        expanded_part_pts = eps = PartitionPoints()
        for i, v in self.part_pts.items():
            ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = ep
            m.d.comb += ep.eq(v)

        # per-byte sign-detection, one Signs submodule per operation byte
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # partial-product preparation for each supported partition width
        n_levels = len(self.register_levels)+1
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # gather the byte-by-byte product terms, one row per a-byte
        terms = []
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # NOTE(review): trailing AddReduce arguments reconstructed (lost in
        # extraction) — confirm against the AddReduce definition earlier in
        # this file.
        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts,
                               self.part_ops)

        out_part_ops = add_reduce.o.part_ops
        out_part_pts = add_reduce.o.reg_partition_points

        m.submodules.add_reduce = add_reduce
        m.d.comb += self.intermediate_output.eq(add_reduce.o.output)

        # select the HI/LO halves of the 128-bit intermediate at each width
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self.intermediate_output)
        for i in range(8):
            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])

        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self.intermediate_output)
        for i in range(8):
            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])

        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self.intermediate_output)
        for i in range(8):
            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])

        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self.intermediate_output)
        for i in range(8):
            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])

        # partition decoders, driven by the (registered) partition points
        m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts))
        m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts))
        m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts))
        m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts))

        m.d.comb += p_8.epps.eq(out_part_pts)
        m.d.comb += p_16.epps.eq(out_part_pts)
        m.d.comb += p_32.epps.eq(out_part_pts)
        m.d.comb += p_64.epps.eq(out_part_pts)

        # final output: per-byte mux between the 8/16/32/64-bit results
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.parts)):
            m.d.comb += finalout.d8[i].eq(p_8.parts[i])
        for i in range(len(part_16.parts)):
            m.d.comb += finalout.d16[i].eq(p_16.parts[i])
        for i in range(len(part_32.parts)):
            m.d.comb += finalout.d32[i].eq(p_32.parts[i])
        m.d.comb += finalout.i8.eq(io8.output)
        m.d.comb += finalout.i16.eq(io16.output)
        m.d.comb += finalout.i32.eq(io32.output)
        m.d.comb += finalout.i64.eq(io64.output)
        m.d.comb += self.output.eq(finalout.out)

        return m
if __name__ == "__main__":
    # NOTE(review): instantiation reconstructed — the original constructor
    # call was lost in extraction; the port list below matches the visible
    # fragments and Mul8_16_32_64's public attributes.
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m.intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])