1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    Maps each bit index at which an ALU may be partitioned to a ``Value``
    that specifies whether that partition point is enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values
            mapping, or ``None`` for an empty set of points.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is None:
            return
        for point, enabled in partition_points.items():
            if not isinstance(point, int):
                raise TypeError("point must be a non-negative integer")
            if point < 0:
                raise ValueError("point must be a non-negative integer")
            self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s. Each new signal
            is named ``{name}_{point}``.
        :param src_loc_at: extra number of stack frames to skip when the
            base name has to be derived from the caller's variable name.
        """
        if name is None:
            # Must stay directly in this frame: the throw-away Signal is only
            # created to get the variable name from the caller.
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: it yields one assignment per partition point.

        :param rhs: mapping with exactly the same set of points as ``self``.
        :raises ValueError: (on first iteration) if the point sets differ.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        """
        bits = [~self[i] if i in self else True for i in range(width)]
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.

        :param width: only points below this bit width are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)


class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width)
        self.in1 = Signal(width)
        self.in2 = Signal(width)
        self.sum = Signal(width)
        self.carry = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # bitwise sum: XOR of the three inputs
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # bitwise carry: set where at least two of the three inputs are set
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m


class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :raises ValueError: if a partition point is not below ``width``.
        """
        self.width = width
        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit is inserted in front of every partition point
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width
        # XXX these have to remain here due to some horrible nmigen
        # simulation bugs involving sync.  it is *not* necessary to
        # have them here, they should (under normal circumstances)
        # be moved into elaborate, as they are entirely local
        self._expanded_a = Signal(expanded_width)
        self._expanded_b = Signal(expanded_width)
        self._expanded_output = Signal(expanded_width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros) in what would otherwise
        # be a massive long add.
        for i in range(self.width):
            if i in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(self._expanded_a[expanded_index])
                al.append(~self.partition_points[i])
                eb.append(self._expanded_b[expanded_index])
                bl.append(C(0))
                expanded_index += 1
            # the actual operand bits; only these positions feed the output
            ea.append(self._expanded_a[expanded_index])
            al.append(self.a[i])
            eb.append(self._expanded_b[expanded_index])
            bl.append(self.b[i])
            eo.append(self._expanded_output[expanded_index])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))
        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += self._expanded_output.eq(
            self._expanded_a + self._expanded_b)
        return m


# number of inputs consumed by one full adder (3-to-2 compressor)
FULL_ADDER_INPUT_COUNT = 3


class AddReduce(Elaboratable):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels,
                 partition_points):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        :raises ValueError: if a partition point does not fit in
            ``output_width`` or a register level exceeds the maximum level.
        """
        self.inputs = list(inputs)
        self._resized_inputs = [
            Signal(output_width, name=f"resized_inputs[{i}]")
            for i in range(len(self.inputs))]
        self.register_levels = list(register_levels)
        self.output = Signal(output_width)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")
        self._reg_partition_points = self.partition_points.like()
        max_level = AddReduce.get_max_level(len(self.inputs))
        for level in self.register_levels:
            if level > max_level:
                raise ValueError(
                    "not enough adder levels for specified register levels")

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.

        :param input_count: number of terms fed into the first level.
        """
        retval = 0
        while True:
            groups = AddReduce.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            # each full adder turns 3 terms into 2; the remainder passes on
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    def next_register_levels(self):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in self.register_levels:
            if level > 0:
                yield level - 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # resize inputs to correct bit-width and optionally add in
        # pipeline registers
        resized_input_assignments = [self._resized_inputs[i].eq(self.inputs[i])
                                     for i in range(len(self.inputs))]
        if 0 in self.register_levels:
            m.d.sync += resized_input_assignments
            m.d.sync += self._reg_partition_points.eq(self.partition_points)
        else:
            m.d.comb += resized_input_assignments
            m.d.comb += self._reg_partition_points.eq(self.partition_points)

        groups = AddReduce.full_adder_groups(len(self.inputs))
        # if there are no full adders to create, then we handle the base cases
        # and return, otherwise we go on to the recursive case
        if len(groups) == 0:
            if len(self.inputs) == 0:
                # use 0 as the default output value
                m.d.comb += self.output.eq(0)
            elif len(self.inputs) == 1:
                # handle single input
                m.d.comb += self.output.eq(self._resized_inputs[0])
            else:
                # base case for adding 2 or more inputs, which get recursively
                # reduced to 2 inputs
                assert len(self.inputs) == 2
                adder = PartitionedAdder(len(self.output),
                                         self._reg_partition_points)
                m.submodules.final_adder = adder
                m.d.comb += adder.a.eq(self._resized_inputs[0])
                m.d.comb += adder.b.eq(self._resized_inputs[1])
                m.d.comb += self.output.eq(adder.output)
            return m

        # go on to handle recursive case
        intermediate_terms = []

        def add_intermediate_term(value):
            """Latch ``value`` into a new named term for the next level."""
            intermediate_term = Signal(
                len(self.output),
                name=f"intermediate_terms[{len(intermediate_terms)}]")
            intermediate_terms.append(intermediate_term)
            m.d.comb += intermediate_term.eq(value)

        # store mask in intermediary (simplifies graph)
        part_mask = Signal(len(self.output), reset_less=True)
        mask = self._reg_partition_points.as_mask(len(self.output))
        m.d.comb += part_mask.eq(mask)

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in groups:
            adder_i = FullAdder(len(self.output))
            setattr(m.submodules, f"adder_{i}", adder_i)
            m.d.comb += adder_i.in0.eq(self._resized_inputs[i])
            m.d.comb += adder_i.in1.eq(self._resized_inputs[i + 1])
            m.d.comb += adder_i.in2.eq(self._resized_inputs[i + 2])
            add_intermediate_term(adder_i.sum)
            shifted_carry = adder_i.carry << 1
            # mask out carry bits to prevent carries between partitions
            add_intermediate_term(shifted_carry & part_mask)

        # handle the remaining inputs.
        remainder = len(self.inputs) % FULL_ADDER_INPUT_COUNT
        if remainder == 1:
            add_intermediate_term(self._resized_inputs[-1])
        elif remainder == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            add_intermediate_term(self._resized_inputs[-2])
            add_intermediate_term(self._resized_inputs[-1])
        else:
            assert remainder == 0

        # recursive invocation of ``AddReduce``
        next_level = AddReduce(intermediate_terms,
                               len(self.output),
                               self.next_register_levels(),
                               self._reg_partition_points)
        m.submodules.next_level = next_level
        m.d.comb += self.output.eq(next_level.output)
        return m


# 2-bit per-byte operation codes (see ``Mul8_16_32_64.part_ops``)
OP_MUL_LOW = 0  # LSB half of the product
OP_MUL_SIGNED_HIGH = 1  # MSB half, a and b both signed
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # MSB half, a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3  # MSB half, a and b both unsigned


def get_term(value, shift=0, enabled=None):
    """Build one (optionally gated, optionally left-shifted) term.

    :param value: the value to turn into a term.
    :param shift: number of zero bits to prepend below ``value``; must be
        non-negative.
    :param enabled: if not ``None``, the term is ``value`` when ``enabled``
        is asserted and 0 otherwise.
    :returns: the resulting expression; ``value`` itself when no gating or
        shifting is requested.
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        # shift left by concatenating ``shift`` zero bits below the value
        value = Cat(Repl(C(0, 1), shift), value)
    else:
        assert shift == 0
    return value


class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.

    :attribute a: first operand (full width); input.
    :attribute b: second operand (full width); input.
    :attribute pb_en: partition bits; input.
    :attribute ti: the raw product of the two selected operand slices.
    :attribute term: the shifted (and, if applicable, gated) product; output.
    :attribute enabled: term-enable ``Signal``, or ``None`` when
        ``a_index == b_index`` (no partition bits lie between the indices).
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of each operand slice.
        :param twidth: bit width of ``term``; ``a``/``b`` are half of this.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: index of the slice taken from ``a``.
        :param b_index: index of the slice taken from ``b``.
        """
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # partition bits lying between the two slice indices
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the in-between partition bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))

        # TODO: sort out width issues, get inputs a/b switched on/off.
        # data going into Muxes is 1/2 the required width.
        # (a disabled alternative gated the selected a/b slices through
        # get_term() *before* multiplying, instead of gating the product.)

        return m


class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.

    :attribute a: first operand; input, forwarded to every ``ProductTerm``.
    :attribute b: second operand; input, forwarded to every ``ProductTerm``.
    :attribute pb_en: partition bits; input.
    :attribute terms: one output ``Signal`` per ``b`` index.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a ``ProductTerms`` bank.

        :param width: bit width of each operand slice.
        :param twidth: bit width of each term; ``a``/``b`` are half of this.
        :param pbwid: bit width of ``pb_en``.
        :param a_index: index of the slice taken from ``a``.
        :param blen: number of ``b`` slices (and therefore of terms).
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m


class LSBNegTerm(Elaboratable):
    """Produce the "bitwise not" and "+1" correction terms for one partition.

    :attribute part: asserted when this partition is selected; input.
    :attribute signed: asserted when a signed operation is requested; input.
    :attribute op: the operand slice to be negated; input.
    :attribute msb: the sign bit that decides whether to negate; input.
    :attribute nt: width-extended 1s complement of ``op`` (or zero); output.
    :attribute nl: the matching "+1" term (or zero); output.
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit width of ``op``; outputs are twice as wide.
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m


class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, width, n_parts, n_levels, pbwid):
        """Create a ``Part``.

        :param width: bit width of the four correction-term outputs.
        :param n_parts: number of partitions the operands are split into.
        :param n_levels: number of ``delayed_parts`` stages.
        :param pbwid: bit width of ``pbs``.
        """
        # inputs
        self.a = Signal(64)
        self.b = Signal(64)
        self.a_signed = [Signal(name=f"a_signed_{i}") for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
        self.delayed_parts = [
            [Signal(name=f"delayed_part_{delay}_{i}")
             for i in range(n_parts)]
            for delay in range(n_levels)]
        # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
        self.dplast = [Signal(name=f"dplast_{i}")
                       for i in range(n_parts)]

        self.not_a_term = Signal(width)
        self.neg_lsb_a_term = Signal(width)
        self.not_b_term = Signal(width)
        self.neg_lsb_b_term = Signal(width)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)

        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # a part is active when both of its boundary partition bits are
            # set and none of the partition bits inside it are set.
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            # was "value_$i" % i, which raises TypeError (no %-specifier)
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())
            m.d.comb += delayed_parts[0][i].eq(parts[i])
            m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
                         for j in range(len(delayed_parts)-1)]
            m.d.comb += self.dplast[i].eq(delayed_parts[-1][i])

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
            self.not_a_term, self.neg_lsb_a_term, \
            self.not_b_term, self.neg_lsb_b_term

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m


class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.

    :attribute delayed_part_ops: per-byte operation codes; input.
    :attribute intermed: the double-width intermediate product; input.
    :attribute output: the selected halves, one per lane; output.
    """

    def __init__(self, width, out_wid, n_parts):
        """Create an ``IntermediateOut``.

        :param width: bit width of one lane of the result.
        :param out_wid: bit width of ``intermed``; ``output`` is half of it.
        :param n_parts: number of lanes.
        """
        self.width = width
        self.n_parts = n_parts
        self.delayed_part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                                 for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        ol = []
        w = self.width
        sel = w // 8  # number of bytes (and so of part_ops entries) per lane
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # each lane occupies 2*w bits of intermed: low half, then high
            m.d.comb += op.eq(
                Mux(self.delayed_part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m


class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.

    :attribute d8: per-byte "8-bit partition" selectors; input.
    :attribute d16: per-halfword "16-bit partition" selectors; input.
    :attribute d32: per-word "32-bit partition" selectors; input.
    :attribute i8: result computed for 8-bit lanes; input.
    :attribute i16: result computed for 16-bit lanes; input.
    :attribute i32: result computed for 32-bit lanes; input.
    :attribute i64: result computed for a single 64-bit lane; input.
    :attribute out: the byte-wise selected result; output.
    """

    def __init__(self, out_wid):
        """Create a ``FinalOut``.

        :param out_wid: bit width of the candidate inputs and of ``out``.
        """
        # inputs
        self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        self.i8 = Signal(out_wid, reset_less=True)
        self.i16 = Signal(out_wid, reset_less=True)
        self.i32 = Signal(out_wid, reset_less=True)
        self.i64 = Signal(out_wid, reset_less=True)

        # output
        self.out = Signal(out_wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(self.d8[i] | self.d16[i // 2],
                    Mux(self.d8[i], self.i8.bit_select(i * 8, 8),
                        self.i16.bit_select(i * 8, 8)),
                    Mux(self.d32[i // 4], self.i32.bit_select(i * 8, 8),
                        self.i64.bit_select(i * 8, 8))))
            ol.append(op)
        m.d.comb += self.out.eq(Cat(*ol))
        return m


class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree

    :attribute orin: list of the four inputs.
    :attribute orout: bitwise OR of all four inputs; output.
    """

    def __init__(self, wid):
        """Create an ``OrMod``.

        :param wid: bit width of every input and of the output.
        """
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # two-level tree: OR the inputs pairwise, then OR the two results
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m


class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)

    :attribute part_ops: the 2-bit operation code; input.
    :attribute a_signed: asserted when ``a`` is to be treated as signed.
    :attribute b_signed: asserted when ``b`` is to be treated as signed.
    """

    def __init__(self):
        """Create a ``Signs`` decoder."""
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # a is signed for every operation except OP_MUL_UNSIGNED_HIGH
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for OP_MUL_LOW and OP_MUL_SIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m


class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    :attribute a: the first 64-bit operand; input.
    :attribute b: the second 64-bit operand; input.
    :attribute output: the 64-bit result; output.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        # parameter(s)
        self.register_levels = list(register_levels)

        # inputs
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]
        self.a = Signal(64)
        self.b = Signal(64)

        # intermediates (needed for unit tests)
        self._intermediate_output = Signal(128)

        # output
        self.output = Signal(64)

    def _part_byte(self, index):
        """Return the partition bit at the top of byte ``index``.

        Indices -1 and 7 are the outer edges of the 64-bit word, which have
        no entry in ``part_pts``; they are treated as always enabled.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self.part_pts[index * 8 + 8]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self._part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # per-byte signedness decoders
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.part_ops[i])

        # part_ops delayed by one sync stage per register level, so that the
        # last entry can be used when selecting the final output
        n_levels = len(self.register_levels)+1
        delayed_part_ops = [
            [Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
             for i in range(8)]
            for delay in range(n_levels)]
        for i in range(len(self.part_ops)):
            m.d.comb += delayed_part_ops[0][i].eq(self.part_ops[i])
            m.d.sync += [delayed_part_ops[j + 1][i].eq(delayed_part_ops[j][i])
                         for j in range(len(self.register_levels))]

        m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
        m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
        m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
        m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.a)
            m.d.comb += mod.b.eq(self.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        # one bank of byte-by-byte product terms per byte of a
        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for term_list, mod in [(nat_l, nat_or),
                               (nbt_l, nbt_or),
                               (nla_l, nla_or),
                               (nlb_l, nlb_or)]:
            for i in range(len(term_list)):
                m.d.comb += mod.orin[i].eq(term_list[i])
            terms.append(mod.orout)

        # the intermediate product is twice as wide as the operands, so the
        # partition points are placed at twice the bit index
        expanded_part_pts = PartitionPoints()
        for i, v in self.part_pts.items():
            signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
            expanded_part_pts[i * 2] = signal
            m.d.comb += signal.eq(v)

        add_reduce = AddReduce(terms,
                               128,
                               self.register_levels,
                               expanded_part_pts)
        m.submodules.add_reduce = add_reduce
        m.d.comb += self._intermediate_output.eq(add_reduce.output)

        # create the per-lane-width outputs: io64, io32, io16, io8
        ios = {}
        for lane_width, n_parts in ((64, 1), (32, 2), (16, 4), (8, 8)):
            io = IntermediateOut(lane_width, 128, n_parts)
            setattr(m.submodules, "io%d" % lane_width, io)
            m.d.comb += io.intermed.eq(self._intermediate_output)
            for i in range(8):
                m.d.comb += io.delayed_part_ops[i].eq(delayed_part_ops[-1][i])
            ios[lane_width] = io

        # final output
        m.submodules.finalout = finalout = FinalOut(64)
        for i in range(len(part_8.delayed_parts[-1])):
            m.d.comb += finalout.d8[i].eq(part_8.dplast[i])
        for i in range(len(part_16.delayed_parts[-1])):
            m.d.comb += finalout.d16[i].eq(part_16.dplast[i])
        for i in range(len(part_32.delayed_parts[-1])):
            m.d.comb += finalout.d32[i].eq(part_32.dplast[i])
        m.d.comb += finalout.i8.eq(ios[8].output)
        m.d.comb += finalout.i16.eq(ios[16].output)
        m.d.comb += finalout.i32.eq(ios[32].output)
        m.d.comb += finalout.i64.eq(ios[64].output)
        m.d.comb += self.output.eq(finalout.out)

        return m


if __name__ == "__main__":
    # hand the multiplier and its ports to nmigen's command-line driver
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m._intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])