1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
13 class PartitionPoints(dict):
14 """Partition points and corresponding ``Value``s.
16 The points at where an ALU is partitioned along with ``Value``s that
17 specify if the corresponding partition points are enabled.
19 For example: ``{1: True, 5: True, 10: True}`` with
20 ``width == 16`` specifies that the ALU is split into 4 sections:
23 * bits 5 <= ``i`` < 10
24 * bits 10 <= ``i`` < 16
26 If the partition_points were instead ``{1: True, 5: a, 10: True}``
27 where ``a`` is a 1-bit ``Signal``:
28 * If ``a`` is asserted:
31 * bits 5 <= ``i`` < 10
32 * bits 10 <= ``i`` < 16
35 * bits 1 <= ``i`` < 10
36 * bits 10 <= ``i`` < 16
39 def __init__(self
, partition_points
=None):
40 """Create a new ``PartitionPoints``.
42 :param partition_points: the input partition points to values mapping.
45 if partition_points
is not None:
46 for point
, enabled
in partition_points
.items():
47 if not isinstance(point
, int):
48 raise TypeError("point must be a non-negative integer")
50 raise ValueError("point must be a non-negative integer")
51 self
[point
] = Value
.wrap(enabled
)
53 def like(self
, name
=None, src_loc_at
=0):
54 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
56 :param name: the base name for the new ``Signal``s.
59 name
= Signal(src_loc_at
=1+src_loc_at
).name
# get variable name
60 retval
= PartitionPoints()
61 for point
, enabled
in self
.items():
62 retval
[point
] = Signal(enabled
.shape(), name
=f
"{name}_{point}")
66 """Assign ``PartitionPoints`` using ``Signal.eq``."""
67 if set(self
.keys()) != set(rhs
.keys()):
68 raise ValueError("incompatible point set")
69 for point
, enabled
in self
.items():
70 yield enabled
.eq(rhs
[point
])
72 def as_mask(self
, width
):
73 """Create a bit-mask from `self`.
75 Each bit in the returned mask is clear only if the partition point at
76 the same bit-index is enabled.
78 :param width: the bit width of the resulting mask
81 for i
in range(width
):
88 def get_max_partition_count(self
, width
):
89 """Get the maximum number of partitions.
91 Gets the number of partitions when all partition points are enabled.
94 for point
in self
.keys():
99 def fits_in_width(self
, width
):
100 """Check if all partition points are smaller than `width`."""
101 for point
in self
.keys():
107 class FullAdder(Elaboratable
):
110 :attribute in0: the first input
111 :attribute in1: the second input
112 :attribute in2: the third input
113 :attribute sum: the sum output
114 :attribute carry: the carry output
116 Rather than do individual full adders (and have an array of them,
117 which would be very slow to simulate), this module can specify the
118 bit width of the inputs and outputs: in effect it performs multiple
119 Full 3-2 Add operations "in parallel".
122 def __init__(self
, width
):
123 """Create a ``FullAdder``.
125 :param width: the bit width of the input and output
127 self
.in0
= Signal(width
)
128 self
.in1
= Signal(width
)
129 self
.in2
= Signal(width
)
130 self
.sum = Signal(width
)
131 self
.carry
= Signal(width
)
133 def elaborate(self
, platform
):
134 """Elaborate this module."""
136 m
.d
.comb
+= self
.sum.eq(self
.in0 ^ self
.in1 ^ self
.in2
)
137 m
.d
.comb
+= self
.carry
.eq((self
.in0
& self
.in1
)
138 |
(self
.in1
& self
.in2
)
139 |
(self
.in2
& self
.in0
))
143 class PartitionedAdder(Elaboratable
):
144 """Partitioned Adder.
146 Performs the final add. The partition points are included in the
147 actual add (in one of the operands only), which causes a carry over
148 to the next bit. Then the final output *removes* the extra bits from
151 :attribute width: the bit width of the input and output. Read-only.
152 :attribute a: the first input to the adder
153 :attribute b: the second input to the adder
154 :attribute output: the sum output
155 :attribute partition_points: the input partition points. Modification not
156 supported, except for by ``Signal.eq``.
159 def __init__(self
, width
, partition_points
):
160 """Create a ``PartitionedAdder``.
162 :param width: the bit width of the input and output
163 :param partition_points: the input partition points
166 self
.a
= Signal(width
)
167 self
.b
= Signal(width
)
168 self
.output
= Signal(width
)
169 self
.partition_points
= PartitionPoints(partition_points
)
170 if not self
.partition_points
.fits_in_width(width
):
171 raise ValueError("partition_points doesn't fit in width")
173 for i
in range(self
.width
):
174 if i
in self
.partition_points
:
177 self
._expanded
_width
= expanded_width
178 # XXX these have to remain here due to some horrible nmigen
179 # simulation bugs involving sync. it is *not* necessary to
180 # have them here, they should (under normal circumstances)
181 # be moved into elaborate, as they are entirely local
182 self
._expanded
_a
= Signal(expanded_width
)
183 self
._expanded
_b
= Signal(expanded_width
)
184 self
._expanded
_output
= Signal(expanded_width
)
186 def elaborate(self
, platform
):
187 """Elaborate this module."""
190 # store bits in a list, use Cat later. graphviz is much cleaner
191 al
, bl
, ol
, ea
, eb
, eo
= [],[],[],[],[],[]
193 # partition points are "breaks" (extra zeros or 1s) in what would
194 # otherwise be a massive long add. when the "break" points are 0,
195 # whatever is in it (in the output) is discarded. however when
196 # there is a "1", it causes a roll-over carry to the *next* bit.
197 # we still ignore the "break" bit in the [intermediate] output,
198 # however by that time we've got the effect that we wanted: the
199 # carry has been carried *over* the break point.
201 for i
in range(self
.width
):
202 if i
in self
.partition_points
:
203 # add extra bit set to 0 + 0 for enabled partition points
204 # and 1 + 0 for disabled partition points
205 ea
.append(self
._expanded
_a
[expanded_index
])
206 al
.append(~self
.partition_points
[i
]) # add extra bit in a
207 eb
.append(self
._expanded
_b
[expanded_index
])
208 bl
.append(C(0)) # do *not* add extra bit into b.
210 ea
.append(self
._expanded
_a
[expanded_index
])
212 eb
.append(self
._expanded
_b
[expanded_index
])
214 eo
.append(self
._expanded
_output
[expanded_index
])
215 ol
.append(self
.output
[i
])
218 # combine above using Cat
219 m
.d
.comb
+= Cat(*ea
).eq(Cat(*al
))
220 m
.d
.comb
+= Cat(*eb
).eq(Cat(*bl
))
221 m
.d
.comb
+= Cat(*ol
).eq(Cat(*eo
))
223 # use only one addition to take advantage of look-ahead carry and
224 # special hardware on FPGAs
225 m
.d
.comb
+= self
._expanded
_output
.eq(
226 self
._expanded
_a
+ self
._expanded
_b
)
230 FULL_ADDER_INPUT_COUNT
= 3
233 class AddReduce(Elaboratable
):
234 """Add list of numbers together.
236 :attribute inputs: input ``Signal``s to be summed. Modification not
237 supported, except for by ``Signal.eq``.
238 :attribute register_levels: List of nesting levels that should have
240 :attribute output: output sum.
241 :attribute partition_points: the input partition points. Modification not
242 supported, except for by ``Signal.eq``.
245 def __init__(self
, inputs
, output_width
, register_levels
, partition_points
):
246 """Create an ``AddReduce``.
248 :param inputs: input ``Signal``s to be summed.
249 :param output_width: bit-width of ``output``.
250 :param register_levels: List of nesting levels that should have
252 :param partition_points: the input partition points.
254 self
.inputs
= list(inputs
)
255 self
._resized
_inputs
= [
256 Signal(output_width
, name
=f
"resized_inputs[{i}]")
257 for i
in range(len(self
.inputs
))]
258 self
.register_levels
= list(register_levels
)
259 self
.output
= Signal(output_width
)
260 self
.partition_points
= PartitionPoints(partition_points
)
261 if not self
.partition_points
.fits_in_width(output_width
):
262 raise ValueError("partition_points doesn't fit in output_width")
263 self
._reg
_partition
_points
= self
.partition_points
.like()
264 max_level
= AddReduce
.get_max_level(len(self
.inputs
))
265 for level
in self
.register_levels
:
266 if level
> max_level
:
268 "not enough adder levels for specified register levels")
271 def get_max_level(input_count
):
272 """Get the maximum level.
274 All ``register_levels`` must be less than or equal to the maximum
279 groups
= AddReduce
.full_adder_groups(input_count
)
282 input_count
%= FULL_ADDER_INPUT_COUNT
283 input_count
+= 2 * len(groups
)
286 def next_register_levels(self
):
287 """``Iterable`` of ``register_levels`` for next recursive level."""
288 for level
in self
.register_levels
:
293 def full_adder_groups(input_count
):
294 """Get ``inputs`` indices for which a full adder should be built."""
296 input_count
- FULL_ADDER_INPUT_COUNT
+ 1,
297 FULL_ADDER_INPUT_COUNT
)
299 def elaborate(self
, platform
):
300 """Elaborate this module."""
303 # resize inputs to correct bit-width and optionally add in
305 resized_input_assignments
= [self
._resized
_inputs
[i
].eq(self
.inputs
[i
])
306 for i
in range(len(self
.inputs
))]
307 if 0 in self
.register_levels
:
308 m
.d
.sync
+= resized_input_assignments
309 m
.d
.sync
+= self
._reg
_partition
_points
.eq(self
.partition_points
)
311 m
.d
.comb
+= resized_input_assignments
312 m
.d
.comb
+= self
._reg
_partition
_points
.eq(self
.partition_points
)
314 groups
= AddReduce
.full_adder_groups(len(self
.inputs
))
315 # if there are no full adders to create, then we handle the base cases
316 # and return, otherwise we go on to the recursive case
318 if len(self
.inputs
) == 0:
319 # use 0 as the default output value
320 m
.d
.comb
+= self
.output
.eq(0)
321 elif len(self
.inputs
) == 1:
322 # handle single input
323 m
.d
.comb
+= self
.output
.eq(self
._resized
_inputs
[0])
325 # base case for adding 2 or more inputs, which get recursively
326 # reduced to 2 inputs
327 assert len(self
.inputs
) == 2
328 adder
= PartitionedAdder(len(self
.output
),
329 self
._reg
_partition
_points
)
330 m
.submodules
.final_adder
= adder
331 m
.d
.comb
+= adder
.a
.eq(self
._resized
_inputs
[0])
332 m
.d
.comb
+= adder
.b
.eq(self
._resized
_inputs
[1])
333 m
.d
.comb
+= self
.output
.eq(adder
.output
)
335 # go on to handle recursive case
336 intermediate_terms
= []
338 def add_intermediate_term(value
):
339 intermediate_term
= Signal(
341 name
=f
"intermediate_terms[{len(intermediate_terms)}]")
342 intermediate_terms
.append(intermediate_term
)
343 m
.d
.comb
+= intermediate_term
.eq(value
)
345 # store mask in intermediary (simplifies graph)
346 part_mask
= Signal(len(self
.output
), reset_less
=True)
347 mask
= self
._reg
_partition
_points
.as_mask(len(self
.output
))
348 m
.d
.comb
+= part_mask
.eq(mask
)
350 # create full adders for this recursive level.
351 # this shrinks N terms to 2 * (N // 3) plus the remainder
353 adder_i
= FullAdder(len(self
.output
))
354 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
355 m
.d
.comb
+= adder_i
.in0
.eq(self
._resized
_inputs
[i
])
356 m
.d
.comb
+= adder_i
.in1
.eq(self
._resized
_inputs
[i
+ 1])
357 m
.d
.comb
+= adder_i
.in2
.eq(self
._resized
_inputs
[i
+ 2])
358 add_intermediate_term(adder_i
.sum)
359 shifted_carry
= adder_i
.carry
<< 1
360 # mask out carry bits to prevent carries between partitions
361 add_intermediate_term((adder_i
.carry
<< 1) & part_mask
)
362 # handle the remaining inputs.
363 if len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 1:
364 add_intermediate_term(self
._resized
_inputs
[-1])
365 elif len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 2:
366 # Just pass the terms to the next layer, since we wouldn't gain
367 # anything by using a half adder since there would still be 2 terms
368 # and just passing the terms to the next layer saves gates.
369 add_intermediate_term(self
._resized
_inputs
[-2])
370 add_intermediate_term(self
._resized
_inputs
[-1])
372 assert len(self
.inputs
) % FULL_ADDER_INPUT_COUNT
== 0
373 # recursive invocation of ``AddReduce``
374 next_level
= AddReduce(intermediate_terms
,
376 self
.next_register_levels(),
377 self
._reg
_partition
_points
)
378 m
.submodules
.next_level
= next_level
379 m
.d
.comb
+= self
.output
.eq(next_level
.output
)
384 OP_MUL_SIGNED_HIGH
= 1
385 OP_MUL_SIGNED_UNSIGNED_HIGH
= 2 # a is signed, b is unsigned
386 OP_MUL_UNSIGNED_HIGH
= 3
389 def get_term(value
, shift
=0, enabled
=None):
390 if enabled
is not None:
391 value
= Mux(enabled
, value
, 0)
393 value
= Cat(Repl(C(0, 1), shift
), value
)
399 class ProductTerm(Elaboratable
):
400 """ this class creates a single product term (a[..]*b[..]).
401 it has a design flaw in that is the *output* that is selected,
402 where the multiplication(s) are combinatorially generated
406 def __init__(self
, width
, twidth
, pbwid
, a_index
, b_index
):
407 self
.a_index
= a_index
408 self
.b_index
= b_index
409 shift
= 8 * (self
.a_index
+ self
.b_index
)
415 self
.ti
= Signal(self
.width
, reset_less
=True)
416 self
.term
= Signal(twidth
, reset_less
=True)
417 self
.a
= Signal(twidth
//2, reset_less
=True)
418 self
.b
= Signal(twidth
//2, reset_less
=True)
419 self
.pb_en
= Signal(pbwid
, reset_less
=True)
422 min_index
= min(self
.a_index
, self
.b_index
)
423 max_index
= max(self
.a_index
, self
.b_index
)
424 for i
in range(min_index
, max_index
):
425 tl
.append(self
.pb_en
[i
])
426 name
= "te_%d_%d" % (self
.a_index
, self
.b_index
)
428 term_enabled
= Signal(name
=name
, reset_less
=True)
431 self
.enabled
= term_enabled
432 self
.term
.name
= "term_%d_%d" % (a_index
, b_index
) # rename
434 def elaborate(self
, platform
):
437 if self
.enabled
is not None:
438 m
.d
.comb
+= self
.enabled
.eq(~
(Cat(*self
.tl
).bool()))
440 bsa
= Signal(self
.width
, reset_less
=True)
441 bsb
= Signal(self
.width
, reset_less
=True)
442 a_index
, b_index
= self
.a_index
, self
.b_index
444 m
.d
.comb
+= bsa
.eq(self
.a
.bit_select(a_index
* pwidth
, pwidth
))
445 m
.d
.comb
+= bsb
.eq(self
.b
.bit_select(b_index
* pwidth
, pwidth
))
446 m
.d
.comb
+= self
.ti
.eq(bsa
* bsb
)
447 m
.d
.comb
+= self
.term
.eq(get_term(self
.ti
, self
.shift
, self
.enabled
))
449 #TODO: sort out width issues, get inputs a/b switched on/off.
450 #data going into Muxes is 1/2 the required width
454 bsa = Signal(self.twidth//2, reset_less=True)
455 bsb = Signal(self.twidth//2, reset_less=True)
456 asel = Signal(width, reset_less=True)
457 bsel = Signal(width, reset_less=True)
458 a_index, b_index = self.a_index, self.b_index
459 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
460 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
461 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
462 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
463 m.d.comb += self.ti.eq(bsa * bsb)
464 m.d.comb += self.term.eq(self.ti)
470 class ProductTerms(Elaboratable
):
471 """ creates a bank of product terms. also performs the actual bit-selection
472 this class is to be wrapped with a for-loop on the "a" operand.
473 it creates a second-level for-loop on the "b" operand.
475 def __init__(self
, width
, twidth
, pbwid
, a_index
, blen
):
476 self
.a_index
= a_index
481 self
.a
= Signal(twidth
//2, reset_less
=True)
482 self
.b
= Signal(twidth
//2, reset_less
=True)
483 self
.pb_en
= Signal(pbwid
, reset_less
=True)
484 self
.terms
= [Signal(twidth
, name
="term%d"%i, reset_less
=True) \
485 for i
in range(blen
)]
487 def elaborate(self
, platform
):
491 for b_index
in range(self
.blen
):
492 t
= ProductTerm(self
.pwidth
, self
.twidth
, self
.pbwid
,
493 self
.a_index
, b_index
)
494 setattr(m
.submodules
, "term_%d" % b_index
, t
)
496 m
.d
.comb
+= t
.a
.eq(self
.a
)
497 m
.d
.comb
+= t
.b
.eq(self
.b
)
498 m
.d
.comb
+= t
.pb_en
.eq(self
.pb_en
)
500 m
.d
.comb
+= self
.terms
[b_index
].eq(t
.term
)
504 class LSBNegTerm(Elaboratable
):
506 def __init__(self
, bit_width
):
507 self
.bit_width
= bit_width
508 self
.part
= Signal(reset_less
=True)
509 self
.signed
= Signal(reset_less
=True)
510 self
.op
= Signal(bit_width
, reset_less
=True)
511 self
.msb
= Signal(reset_less
=True)
512 self
.nt
= Signal(bit_width
*2, reset_less
=True)
513 self
.nl
= Signal(bit_width
*2, reset_less
=True)
515 def elaborate(self
, platform
):
518 bit_wid
= self
.bit_width
519 ext
= Repl(0, bit_wid
) # extend output to HI part
521 # determine sign of each incoming number *in this partition*
522 enabled
= Signal(reset_less
=True)
523 m
.d
.comb
+= enabled
.eq(self
.part
& self
.msb
& self
.signed
)
525 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
526 # negation operation is split into a bitwise not and a +1.
527 # likewise for 16, 32, and 64-bit values.
529 # width-extended 1s complement if a is signed, otherwise zero
530 comb
+= self
.nt
.eq(Mux(enabled
, Cat(ext
, ~self
.op
), 0))
532 # add 1 if signed, otherwise add zero
533 comb
+= self
.nl
.eq(Cat(ext
, enabled
, Repl(0, bit_wid
-1)))
538 class Part(Elaboratable
):
539 """ a key class which, depending on the partitioning, will determine
540 what action to take when parts of the output are signed or unsigned.
542 this requires 2 pieces of data *per operand, per partition*:
543 whether the MSB is HI/LO (per partition!), and whether a signed
544 or unsigned operation has been *requested*.
546 once that is determined, signed is basically carried out
547 by splitting 2's complement into 1's complement plus one.
548 1's complement is just a bit-inversion.
550 the extra terms - as separate terms - are then thrown at the
551 AddReduce alongside the multiplication part-results.
553 def __init__(self
, width
, n_parts
, n_levels
, pbwid
):
558 self
.a_signed
= [Signal(name
=f
"a_signed_{i}") for i
in range(8)]
559 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}") for i
in range(8)]
560 self
.pbs
= Signal(pbwid
, reset_less
=True)
563 self
.parts
= [Signal(name
=f
"part_{i}") for i
in range(n_parts
)]
564 self
.delayed_parts
= [
565 [Signal(name
=f
"delayed_part_{delay}_{i}")
566 for i
in range(n_parts
)]
567 for delay
in range(n_levels
)]
568 # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts
569 self
.dplast
= [Signal(name
=f
"dplast_{i}")
570 for i
in range(n_parts
)]
572 self
.not_a_term
= Signal(width
)
573 self
.neg_lsb_a_term
= Signal(width
)
574 self
.not_b_term
= Signal(width
)
575 self
.neg_lsb_b_term
= Signal(width
)
577 def elaborate(self
, platform
):
580 pbs
, parts
, delayed_parts
= self
.pbs
, self
.parts
, self
.delayed_parts
581 # negated-temporary copy of partition bits
582 npbs
= Signal
.like(pbs
, reset_less
=True)
583 m
.d
.comb
+= npbs
.eq(~pbs
)
584 byte_count
= 8 // len(parts
)
585 for i
in range(len(parts
)):
587 pbl
.append(npbs
[i
* byte_count
- 1])
588 for j
in range(i
* byte_count
, (i
+ 1) * byte_count
- 1):
590 pbl
.append(npbs
[(i
+ 1) * byte_count
- 1])
591 value
= Signal(len(pbl
), name
="value_%di" % i
, reset_less
=True)
592 m
.d
.comb
+= value
.eq(Cat(*pbl
))
593 m
.d
.comb
+= parts
[i
].eq(~
(value
).bool())
594 m
.d
.comb
+= delayed_parts
[0][i
].eq(parts
[i
])
595 m
.d
.sync
+= [delayed_parts
[j
+ 1][i
].eq(delayed_parts
[j
][i
])
596 for j
in range(len(delayed_parts
)-1)]
597 m
.d
.comb
+= self
.dplast
[i
].eq(delayed_parts
[-1][i
])
599 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= \
600 self
.not_a_term
, self
.neg_lsb_a_term
, \
601 self
.not_b_term
, self
.neg_lsb_b_term
603 byte_width
= 8 // len(parts
) # byte width
604 bit_wid
= 8 * byte_width
# bit width
605 nat
, nbt
, nla
, nlb
= [], [], [], []
606 for i
in range(len(parts
)):
607 # work out bit-inverted and +1 term for a.
608 pa
= LSBNegTerm(bit_wid
)
609 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
610 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
611 m
.d
.comb
+= pa
.op
.eq(self
.a
.bit_select(bit_wid
* i
, bit_wid
))
612 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
613 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
617 # work out bit-inverted and +1 term for b
618 pb
= LSBNegTerm(bit_wid
)
619 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
620 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
621 m
.d
.comb
+= pb
.op
.eq(self
.b
.bit_select(bit_wid
* i
, bit_wid
))
622 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
623 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
627 # concatenate together and return all 4 results.
628 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
629 not_b_term
.eq(Cat(*nbt
)),
630 neg_lsb_a_term
.eq(Cat(*nla
)),
631 neg_lsb_b_term
.eq(Cat(*nlb
)),
637 class IntermediateOut(Elaboratable
):
638 """ selects the HI/LO part of the multiplication, for a given bit-width
639 the output is also reconstructed in its SIMD (partition) lanes.
641 def __init__(self
, width
, out_wid
, n_parts
):
643 self
.n_parts
= n_parts
644 self
.delayed_part_ops
= [Signal(2, name
="dpop%d" % i
, reset_less
=True)
646 self
.intermed
= Signal(out_wid
, reset_less
=True)
647 self
.output
= Signal(out_wid
//2, reset_less
=True)
649 def elaborate(self
, platform
):
655 for i
in range(self
.n_parts
):
656 op
= Signal(w
, reset_less
=True, name
="op%d_%d" % (w
, i
))
658 Mux(self
.delayed_part_ops
[sel
* i
] == OP_MUL_LOW
,
659 self
.intermed
.bit_select(i
* w
*2, w
),
660 self
.intermed
.bit_select(i
* w
*2 + w
, w
)))
662 m
.d
.comb
+= self
.output
.eq(Cat(*ol
))
667 class FinalOut(Elaboratable
):
668 """ selects the final output based on the partitioning.
670 each byte is selectable independently, i.e. it is possible
671 that some partitions requested 8-bit computation whilst others
672 requested 16 or 32 bit.
674 def __init__(self
, out_wid
):
676 self
.d8
= [Signal(name
=f
"d8_{i}", reset_less
=True) for i
in range(8)]
677 self
.d16
= [Signal(name
=f
"d16_{i}", reset_less
=True) for i
in range(4)]
678 self
.d32
= [Signal(name
=f
"d32_{i}", reset_less
=True) for i
in range(2)]
680 self
.i8
= Signal(out_wid
, reset_less
=True)
681 self
.i16
= Signal(out_wid
, reset_less
=True)
682 self
.i32
= Signal(out_wid
, reset_less
=True)
683 self
.i64
= Signal(out_wid
, reset_less
=True)
686 self
.out
= Signal(out_wid
, reset_less
=True)
688 def elaborate(self
, platform
):
692 # select one of the outputs: d8 selects i8, d16 selects i16
693 # d32 selects i32, and the default is i64.
694 # d8 and d16 are ORed together in the first Mux
695 # then the 2nd selects either i8 or i16.
696 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
697 op
= Signal(8, reset_less
=True, name
="op_%d" % i
)
699 Mux(self
.d8
[i
] | self
.d16
[i
// 2],
700 Mux(self
.d8
[i
], self
.i8
.bit_select(i
* 8, 8),
701 self
.i16
.bit_select(i
* 8, 8)),
702 Mux(self
.d32
[i
// 4], self
.i32
.bit_select(i
* 8, 8),
703 self
.i64
.bit_select(i
* 8, 8))))
705 m
.d
.comb
+= self
.out
.eq(Cat(*ol
))
709 class OrMod(Elaboratable
):
710 """ ORs four values together in a hierarchical tree
712 def __init__(self
, wid
):
714 self
.orin
= [Signal(wid
, name
="orin%d" % i
, reset_less
=True)
716 self
.orout
= Signal(wid
, reset_less
=True)
718 def elaborate(self
, platform
):
720 or1
= Signal(self
.wid
, reset_less
=True)
721 or2
= Signal(self
.wid
, reset_less
=True)
722 m
.d
.comb
+= or1
.eq(self
.orin
[0] | self
.orin
[1])
723 m
.d
.comb
+= or2
.eq(self
.orin
[2] | self
.orin
[3])
724 m
.d
.comb
+= self
.orout
.eq(or1 | or2
)
729 class Signs(Elaboratable
):
730 """ determines whether a or b are signed numbers
731 based on the required operation type (OP_MUL_*)
735 self
.part_ops
= Signal(2, reset_less
=True)
736 self
.a_signed
= Signal(reset_less
=True)
737 self
.b_signed
= Signal(reset_less
=True)
739 def elaborate(self
, platform
):
743 asig
= self
.part_ops
!= OP_MUL_UNSIGNED_HIGH
744 bsig
= (self
.part_ops
== OP_MUL_LOW
) \
745 |
(self
.part_ops
== OP_MUL_SIGNED_HIGH
)
746 m
.d
.comb
+= self
.a_signed
.eq(asig
)
747 m
.d
.comb
+= self
.b_signed
.eq(bsig
)
752 class Mul8_16_32_64(Elaboratable
):
753 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
755 Supports partitioning into any combination of 8, 16, 32, and 64-bit
756 partitions on naturally-aligned boundaries. Supports the operation being
757 set for each partition independently.
759 :attribute part_pts: the input partition points. Has a partition point at
760 multiples of 8 in 0 < i < 64. Each partition point's associated
761 ``Value`` is a ``Signal``. Modification not supported, except for by
763 :attribute part_ops: the operation for each byte. The operation for a
764 particular partition is selected by assigning the selected operation
765 code to each byte in the partition. The allowed operation codes are:
767 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
768 RISC-V's `mul` instruction.
769 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
770 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
772 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
773 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
774 `mulhsu` instruction.
775 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
776 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
780 def __init__(self
, register_levels
=()):
781 """ register_levels: specifies the points in the cascade at which
782 flip-flops are to be inserted.
786 self
.register_levels
= list(register_levels
)
789 self
.part_pts
= PartitionPoints()
790 for i
in range(8, 64, 8):
791 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
792 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
796 # intermediates (needed for unit tests)
797 self
._intermediate
_output
= Signal(128)
800 self
.output
= Signal(64)
802 def _part_byte(self
, index
):
803 if index
== -1 or index
== 7:
805 assert index
>= 0 and index
< 8
806 return self
.part_pts
[index
* 8 + 8]
808 def elaborate(self
, platform
):
812 pbs
= Signal(8, reset_less
=True)
815 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
816 m
.d
.comb
+= pb
.eq(self
._part
_byte
(i
))
818 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
825 setattr(m
.submodules
, "signs%d" % i
, s
)
826 m
.d
.comb
+= s
.part_ops
.eq(self
.part_ops
[i
])
829 [Signal(2, name
=f
"_delayed_part_ops_{delay}_{i}")
831 for delay
in range(1 + len(self
.register_levels
))]
832 for i
in range(len(self
.part_ops
)):
833 m
.d
.comb
+= delayed_part_ops
[0][i
].eq(self
.part_ops
[i
])
834 m
.d
.sync
+= [delayed_part_ops
[j
+ 1][i
].eq(delayed_part_ops
[j
][i
])
835 for j
in range(len(self
.register_levels
))]
837 n_levels
= len(self
.register_levels
)+1
838 m
.submodules
.part_8
= part_8
= Part(128, 8, n_levels
, 8)
839 m
.submodules
.part_16
= part_16
= Part(128, 4, n_levels
, 8)
840 m
.submodules
.part_32
= part_32
= Part(128, 2, n_levels
, 8)
841 m
.submodules
.part_64
= part_64
= Part(128, 1, n_levels
, 8)
842 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
843 for mod
in [part_8
, part_16
, part_32
, part_64
]:
844 m
.d
.comb
+= mod
.a
.eq(self
.a
)
845 m
.d
.comb
+= mod
.b
.eq(self
.b
)
846 for i
in range(len(signs
)):
847 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
848 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
849 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
850 nat_l
.append(mod
.not_a_term
)
851 nbt_l
.append(mod
.not_b_term
)
852 nla_l
.append(mod
.neg_lsb_a_term
)
853 nlb_l
.append(mod
.neg_lsb_b_term
)
857 for a_index
in range(8):
858 t
= ProductTerms(8, 128, 8, a_index
, 8)
859 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
861 m
.d
.comb
+= t
.a
.eq(self
.a
)
862 m
.d
.comb
+= t
.b
.eq(self
.b
)
863 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
868 # it's fine to bitwise-or data together since they are never enabled
870 m
.submodules
.nat_or
= nat_or
= OrMod(128)
871 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
872 m
.submodules
.nla_or
= nla_or
= OrMod(128)
873 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
874 for l
, mod
in [(nat_l
, nat_or
),
878 for i
in range(len(l
)):
879 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
880 terms
.append(mod
.orout
)
882 expanded_part_pts
= PartitionPoints()
883 for i
, v
in self
.part_pts
.items():
884 signal
= Signal(name
=f
"expanded_part_pts_{i*2}", reset_less
=True)
885 expanded_part_pts
[i
* 2] = signal
886 m
.d
.comb
+= signal
.eq(v
)
888 add_reduce
= AddReduce(terms
,
890 self
.register_levels
,
892 m
.submodules
.add_reduce
= add_reduce
893 m
.d
.comb
+= self
._intermediate
_output
.eq(add_reduce
.output
)
895 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
896 m
.d
.comb
+= io64
.intermed
.eq(self
._intermediate
_output
)
898 m
.d
.comb
+= io64
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
901 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
902 m
.d
.comb
+= io32
.intermed
.eq(self
._intermediate
_output
)
904 m
.d
.comb
+= io32
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
907 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
908 m
.d
.comb
+= io16
.intermed
.eq(self
._intermediate
_output
)
910 m
.d
.comb
+= io16
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
913 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
914 m
.d
.comb
+= io8
.intermed
.eq(self
._intermediate
_output
)
916 m
.d
.comb
+= io8
.delayed_part_ops
[i
].eq(delayed_part_ops
[-1][i
])
919 m
.submodules
.finalout
= finalout
= FinalOut(64)
920 for i
in range(len(part_8
.delayed_parts
[-1])):
921 m
.d
.comb
+= finalout
.d8
[i
].eq(part_8
.dplast
[i
])
922 for i
in range(len(part_16
.delayed_parts
[-1])):
923 m
.d
.comb
+= finalout
.d16
[i
].eq(part_16
.dplast
[i
])
924 for i
in range(len(part_32
.delayed_parts
[-1])):
925 m
.d
.comb
+= finalout
.d32
[i
].eq(part_32
.dplast
[i
])
926 m
.d
.comb
+= finalout
.i8
.eq(io8
.output
)
927 m
.d
.comb
+= finalout
.i16
.eq(io16
.output
)
928 m
.d
.comb
+= finalout
.i32
.eq(io32
.output
)
929 m
.d
.comb
+= finalout
.i64
.eq(io64
.output
)
930 m
.d
.comb
+= self
.output
.eq(finalout
.out
)
935 if __name__
== "__main__":
939 m
._intermediate
_output
,
942 *m
.part_pts
.values()])