1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
13 class PartitionPoints(dict):
14 """Partition points and corresponding ``Value``s.
16 The points at where an ALU is partitioned along with ``Value``s that
17 specify if the corresponding partition points are enabled.
19 For example: ``{1: True, 5: True, 10: True}`` with
20 ``width == 16`` specifies that the ALU is split into 4 sections:
23 * bits 5 <= ``i`` < 10
24 * bits 10 <= ``i`` < 16
26 If the partition_points were instead ``{1: True, 5: a, 10: True}``
27 where ``a`` is a 1-bit ``Signal``:
28 * If ``a`` is asserted:
31 * bits 5 <= ``i`` < 10
32 * bits 10 <= ``i`` < 16
35 * bits 1 <= ``i`` < 10
36 * bits 10 <= ``i`` < 16
39 def __init__(self
, partition_points
=None):
40 """Create a new ``PartitionPoints``.
42 :param partition_points: the input partition points to values mapping.
45 if partition_points
is not None:
46 for point
, enabled
in partition_points
.items():
47 if not isinstance(point
, int):
48 raise TypeError("point must be a non-negative integer")
50 raise ValueError("point must be a non-negative integer")
51 self
[point
] = Value
.wrap(enabled
)
53 def like(self
, name
=None, src_loc_at
=0, mul
=1):
54 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
56 :param name: the base name for the new ``Signal``s.
57 :param mul: a multiplication factor on the indices
60 name
= Signal(src_loc_at
=1+src_loc_at
).name
# get variable name
61 retval
= PartitionPoints()
62 for point
, enabled
in self
.items():
64 retval
[point
] = Signal(enabled
.shape(), name
=f
"{name}_{point}")
68 """Assign ``PartitionPoints`` using ``Signal.eq``."""
69 if set(self
.keys()) != set(rhs
.keys()):
70 raise ValueError("incompatible point set")
71 for point
, enabled
in self
.items():
72 yield enabled
.eq(rhs
[point
])
74 def as_mask(self
, width
, mul
=1):
75 """Create a bit-mask from `self`.
77 Each bit in the returned mask is clear only if the partition point at
78 the same bit-index is enabled.
80 :param width: the bit width of the resulting mask
81 :param mul: a "multiplier" which in-place expands the partition points
82 typically set to "2" when used for multipliers
85 for i
in range(width
):
87 if i
.is_integer() and int(i
) in self
:
93 def get_max_partition_count(self
, width
):
94 """Get the maximum number of partitions.
96 Gets the number of partitions when all partition points are enabled.
99 for point
in self
.keys():
104 def fits_in_width(self
, width
):
105 """Check if all partition points are smaller than `width`."""
106 for point
in self
.keys():
111 def part_byte(self
, index
, mfactor
=1): # mfactor used for "expanding"
112 if index
== -1 or index
== 7:
114 assert index
>= 0 and index
< 8
115 return self
[(index
* 8 + 8)*mfactor
]
118 class FullAdder(Elaboratable
):
121 :attribute in0: the first input
122 :attribute in1: the second input
123 :attribute in2: the third input
124 :attribute sum: the sum output
125 :attribute carry: the carry output
127 Rather than do individual full adders (and have an array of them,
128 which would be very slow to simulate), this module can specify the
129 bit width of the inputs and outputs: in effect it performs multiple
130 Full 3-2 Add operations "in parallel".
133 def __init__(self
, width
):
134 """Create a ``FullAdder``.
136 :param width: the bit width of the input and output
138 self
.in0
= Signal(width
, reset_less
=True)
139 self
.in1
= Signal(width
, reset_less
=True)
140 self
.in2
= Signal(width
, reset_less
=True)
141 self
.sum = Signal(width
, reset_less
=True)
142 self
.carry
= Signal(width
, reset_less
=True)
144 def elaborate(self
, platform
):
145 """Elaborate this module."""
147 m
.d
.comb
+= self
.sum.eq(self
.in0 ^ self
.in1 ^ self
.in2
)
148 m
.d
.comb
+= self
.carry
.eq((self
.in0
& self
.in1
)
149 |
(self
.in1
& self
.in2
)
150 |
(self
.in2
& self
.in0
))
154 class MaskedFullAdder(Elaboratable
):
155 """Masked Full Adder.
157 :attribute mask: the carry partition mask
158 :attribute in0: the first input
159 :attribute in1: the second input
160 :attribute in2: the third input
161 :attribute sum: the sum output
162 :attribute mcarry: the masked carry output
164 FullAdders are always used with a "mask" on the output. To keep
165 the graphviz "clean", this class performs the masking here rather
166 than inside a large for-loop.
168 See the following discussion as to why this is no longer derived
169 from FullAdder. Each carry is shifted here *before* being ANDed
170 with the mask, so that an AOI cell may be used (which is more
172 https://en.wikipedia.org/wiki/AND-OR-Invert
173 https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
176 def __init__(self
, width
):
177 """Create a ``MaskedFullAdder``.
179 :param width: the bit width of the input and output
182 self
.mask
= Signal(width
, reset_less
=True)
183 self
.mcarry
= Signal(width
, reset_less
=True)
184 self
.in0
= Signal(width
, reset_less
=True)
185 self
.in1
= Signal(width
, reset_less
=True)
186 self
.in2
= Signal(width
, reset_less
=True)
187 self
.sum = Signal(width
, reset_less
=True)
189 def elaborate(self
, platform
):
190 """Elaborate this module."""
192 s1
= Signal(self
.width
, reset_less
=True)
193 s2
= Signal(self
.width
, reset_less
=True)
194 s3
= Signal(self
.width
, reset_less
=True)
195 c1
= Signal(self
.width
, reset_less
=True)
196 c2
= Signal(self
.width
, reset_less
=True)
197 c3
= Signal(self
.width
, reset_less
=True)
198 m
.d
.comb
+= self
.sum.eq(self
.in0 ^ self
.in1 ^ self
.in2
)
199 m
.d
.comb
+= s1
.eq(Cat(0, self
.in0
))
200 m
.d
.comb
+= s2
.eq(Cat(0, self
.in1
))
201 m
.d
.comb
+= s3
.eq(Cat(0, self
.in2
))
202 m
.d
.comb
+= c1
.eq(s1
& s2
& self
.mask
)
203 m
.d
.comb
+= c2
.eq(s2
& s3
& self
.mask
)
204 m
.d
.comb
+= c3
.eq(s3
& s1
& self
.mask
)
205 m
.d
.comb
+= self
.mcarry
.eq(c1 | c2 | c3
)
209 class PartitionedAdder(Elaboratable
):
210 """Partitioned Adder.
212 Performs the final add. The partition points are included in the
213 actual add (in one of the operands only), which causes a carry over
214 to the next bit. Then the final output *removes* the extra bits from
217 partition: .... P... P... P... P... (32 bits)
218 a : .... .... .... .... .... (32 bits)
219 b : .... .... .... .... .... (32 bits)
220 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
221 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
222 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
223 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
225 :attribute width: the bit width of the input and output. Read-only.
226 :attribute a: the first input to the adder
227 :attribute b: the second input to the adder
228 :attribute output: the sum output
229 :attribute partition_points: the input partition points. Modification not
230 supported, except for by ``Signal.eq``.
233 def __init__(self
, width
, partition_points
, partition_step
=1):
234 """Create a ``PartitionedAdder``.
236 :param width: the bit width of the input and output
237 :param partition_points: the input partition points
238 :param partition_step: a multiplier (typically double) step
239 which in-place "expands" the partition points
242 self
.pmul
= partition_step
243 self
.a
= Signal(width
, reset_less
=True)
244 self
.b
= Signal(width
, reset_less
=True)
245 self
.output
= Signal(width
, reset_less
=True)
246 self
.partition_points
= PartitionPoints(partition_points
)
247 if not self
.partition_points
.fits_in_width(width
):
248 raise ValueError("partition_points doesn't fit in width")
250 for i
in range(self
.width
):
251 if i
in self
.partition_points
:
254 self
._expanded
_width
= expanded_width
256 def elaborate(self
, platform
):
257 """Elaborate this module."""
259 expanded_a
= Signal(self
._expanded
_width
, reset_less
=True)
260 expanded_b
= Signal(self
._expanded
_width
, reset_less
=True)
261 expanded_o
= Signal(self
._expanded
_width
, reset_less
=True)
264 # store bits in a list, use Cat later. graphviz is much cleaner
265 al
, bl
, ol
, ea
, eb
, eo
= [],[],[],[],[],[]
267 # partition points are "breaks" (extra zeros or 1s) in what would
268 # otherwise be a massive long add. when the "break" points are 0,
269 # whatever is in it (in the output) is discarded. however when
270 # there is a "1", it causes a roll-over carry to the *next* bit.
271 # we still ignore the "break" bit in the [intermediate] output,
272 # however by that time we've got the effect that we wanted: the
273 # carry has been carried *over* the break point.
275 for i
in range(self
.width
):
276 pi
= i
/self
.pmul
# double the range of the partition point test
277 if pi
.is_integer() and pi
in self
.partition_points
:
278 # add extra bit set to 0 + 0 for enabled partition points
279 # and 1 + 0 for disabled partition points
280 ea
.append(expanded_a
[expanded_index
])
281 al
.append(~self
.partition_points
[pi
]) # add extra bit in a
282 eb
.append(expanded_b
[expanded_index
])
283 bl
.append(C(0)) # yes, add a zero
284 expanded_index
+= 1 # skip the extra point. NOT in the output
285 ea
.append(expanded_a
[expanded_index
])
286 eb
.append(expanded_b
[expanded_index
])
287 eo
.append(expanded_o
[expanded_index
])
290 ol
.append(self
.output
[i
])
293 # combine above using Cat
294 m
.d
.comb
+= Cat(*ea
).eq(Cat(*al
))
295 m
.d
.comb
+= Cat(*eb
).eq(Cat(*bl
))
296 m
.d
.comb
+= Cat(*ol
).eq(Cat(*eo
))
298 # use only one addition to take advantage of look-ahead carry and
299 # special hardware on FPGAs
300 m
.d
.comb
+= expanded_o
.eq(expanded_a
+ expanded_b
)
304 FULL_ADDER_INPUT_COUNT
= 3
308 def __init__(self
, part_pts
, n_inputs
, output_width
, n_parts
):
309 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}", reset_less
=True)
310 for i
in range(n_parts
)]
311 self
.terms
= [Signal(output_width
, name
=f
"inputs_{i}",
313 for i
in range(n_inputs
)]
314 self
.part_pts
= part_pts
.like()
def eq_from(self, part_pts, inputs, part_ops):
    """Return assignment statements copying *part_pts*, *inputs* and
    *part_ops* into this record's signals.

    The result is a flat list suitable for ``m.d.comb +=``: first the
    partition-point copy, then one assignment per term, then one per
    part-op entry.
    """
    stmts = [self.part_pts.eq(part_pts)]
    for idx, term in enumerate(self.terms):
        stmts.append(term.eq(inputs[idx]))
    for idx, op_sig in enumerate(self.part_ops):
        stmts.append(op_sig.eq(part_ops[idx]))
    return stmts
324 return self
.eq_from(rhs
.part_pts
, rhs
.terms
, rhs
.part_ops
)
327 class FinalReduceData
:
def __init__(self, part_pts, output_width, n_parts):
    """Output record of the final reduction stage.

    :param part_pts: ``PartitionPoints`` to mirror (a ``like()`` copy
        with fresh ``Signal``s is stored)
    :param output_width: bit width of the ``output`` sum signal
    :param n_parts: number of 2-bit ``part_ops`` entries
    """
    ops = []
    for idx in range(n_parts):
        ops.append(Signal(2, name=f"part_ops_{idx}", reset_less=True))
    self.part_ops = ops
    self.output = Signal(output_width, reset_less=True)
    self.part_pts = part_pts.like()
def eq_from(self, part_pts, output, part_ops):
    """Return assignment statements copying the given values into this
    record's ``part_pts``, ``output`` and ``part_ops`` signals, in that
    order (a flat list suitable for ``m.d.comb +=``).
    """
    stmts = [self.part_pts.eq(part_pts), self.output.eq(output)]
    for idx, op_sig in enumerate(self.part_ops):
        stmts.append(op_sig.eq(part_ops[idx]))
    return stmts
342 return self
.eq_from(rhs
.part_pts
, rhs
.output
, rhs
.part_ops
)
345 class FinalAdd(Elaboratable
):
346 """ Final stage of add reduce
349 def __init__(self
, lidx
, n_inputs
, output_width
, n_parts
, partition_points
,
352 self
.partition_step
= partition_step
353 self
.output_width
= output_width
354 self
.n_inputs
= n_inputs
355 self
.n_parts
= n_parts
356 self
.partition_points
= PartitionPoints(partition_points
)
357 if not self
.partition_points
.fits_in_width(output_width
):
358 raise ValueError("partition_points doesn't fit in output_width")
360 self
.i
= self
.ispec()
361 self
.o
= self
.ospec()
364 return AddReduceData(self
.partition_points
, self
.n_inputs
,
365 self
.output_width
, self
.n_parts
)
368 return FinalReduceData(self
.partition_points
,
369 self
.output_width
, self
.n_parts
)
371 def setup(self
, m
, i
):
372 m
.submodules
.finaladd
= self
373 m
.d
.comb
+= self
.i
.eq(i
)
375 def process(self
, i
):
378 def elaborate(self
, platform
):
379 """Elaborate this module."""
382 output_width
= self
.output_width
383 output
= Signal(output_width
, reset_less
=True)
384 if self
.n_inputs
== 0:
385 # use 0 as the default output value
386 m
.d
.comb
+= output
.eq(0)
387 elif self
.n_inputs
== 1:
388 # handle single input
389 m
.d
.comb
+= output
.eq(self
.i
.terms
[0])
391 # base case for adding 2 inputs
392 assert self
.n_inputs
== 2
393 adder
= PartitionedAdder(output_width
,
394 self
.i
.part_pts
, self
.partition_step
)
395 m
.submodules
.final_adder
= adder
396 m
.d
.comb
+= adder
.a
.eq(self
.i
.terms
[0])
397 m
.d
.comb
+= adder
.b
.eq(self
.i
.terms
[1])
398 m
.d
.comb
+= output
.eq(adder
.output
)
401 m
.d
.comb
+= self
.o
.eq_from(self
.i
.part_pts
, output
,
407 class AddReduceSingle(Elaboratable
):
408 """Add list of numbers together.
410 :attribute inputs: input ``Signal``s to be summed. Modification not
411 supported, except for by ``Signal.eq``.
412 :attribute register_levels: List of nesting levels that should have
414 :attribute output: output sum.
415 :attribute partition_points: the input partition points. Modification not
416 supported, except for by ``Signal.eq``.
419 def __init__(self
, lidx
, n_inputs
, output_width
, n_parts
, partition_points
,
421 """Create an ``AddReduce``.
423 :param inputs: input ``Signal``s to be summed.
424 :param output_width: bit-width of ``output``.
425 :param partition_points: the input partition points.
428 self
.partition_step
= partition_step
429 self
.n_inputs
= n_inputs
430 self
.n_parts
= n_parts
431 self
.output_width
= output_width
432 self
.partition_points
= PartitionPoints(partition_points
)
433 if not self
.partition_points
.fits_in_width(output_width
):
434 raise ValueError("partition_points doesn't fit in output_width")
436 self
.groups
= AddReduceSingle
.full_adder_groups(n_inputs
)
437 self
.n_terms
= AddReduceSingle
.calc_n_inputs(n_inputs
, self
.groups
)
439 self
.i
= self
.ispec()
440 self
.o
= self
.ospec()
443 return AddReduceData(self
.partition_points
, self
.n_inputs
,
444 self
.output_width
, self
.n_parts
)
447 return AddReduceData(self
.partition_points
, self
.n_terms
,
448 self
.output_width
, self
.n_parts
)
450 def setup(self
, m
, i
):
451 setattr(m
.submodules
, "addreduce_%d" % self
.lidx
, self
)
452 m
.d
.comb
+= self
.i
.eq(i
)
454 def process(self
, i
):
458 def calc_n_inputs(n_inputs
, groups
):
459 retval
= len(groups
)*2
460 if n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
462 elif n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
465 assert n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
469 def get_max_level(input_count
):
470 """Get the maximum level.
472 All ``register_levels`` must be less than or equal to the maximum
477 groups
= AddReduceSingle
.full_adder_groups(input_count
)
480 input_count
%= FULL_ADDER_INPUT_COUNT
481 input_count
+= 2 * len(groups
)
485 def full_adder_groups(input_count
):
486 """Get ``inputs`` indices for which a full adder should be built."""
488 input_count
- FULL_ADDER_INPUT_COUNT
+ 1,
489 FULL_ADDER_INPUT_COUNT
)
491 def create_next_terms(self
):
492 """ create next intermediate terms, for linking up in elaborate, below
497 # create full adders for this recursive level.
498 # this shrinks N terms to 2 * (N // 3) plus the remainder
499 for i
in self
.groups
:
500 adder_i
= MaskedFullAdder(self
.output_width
)
501 adders
.append((i
, adder_i
))
502 # add both the sum and the masked-carry to the next level.
503 # 3 inputs have now been reduced to 2...
504 terms
.append(adder_i
.sum)
505 terms
.append(adder_i
.mcarry
)
506 # handle the remaining inputs.
507 if self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
508 terms
.append(self
.i
.terms
[-1])
509 elif self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
510 # Just pass the terms to the next layer, since we wouldn't gain
511 # anything by using a half adder since there would still be 2 terms
512 # and just passing the terms to the next layer saves gates.
513 terms
.append(self
.i
.terms
[-2])
514 terms
.append(self
.i
.terms
[-1])
516 assert self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
520 def elaborate(self
, platform
):
521 """Elaborate this module."""
524 terms
, adders
= self
.create_next_terms()
526 # copy the intermediate terms to the output
527 for i
, value
in enumerate(terms
):
528 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
530 # copy reg part points and part ops to output
531 m
.d
.comb
+= self
.o
.part_pts
.eq(self
.i
.part_pts
)
532 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
533 for i
in range(len(self
.i
.part_ops
))]
535 # set up the partition mask (for the adders)
536 part_mask
= Signal(self
.output_width
, reset_less
=True)
538 # get partition points as a mask
539 mask
= self
.i
.part_pts
.as_mask(self
.output_width
,
540 mul
=self
.partition_step
)
541 m
.d
.comb
+= part_mask
.eq(mask
)
543 # add and link the intermediate term modules
544 for i
, (iidx
, adder_i
) in enumerate(adders
):
545 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
547 m
.d
.comb
+= adder_i
.in0
.eq(self
.i
.terms
[iidx
])
548 m
.d
.comb
+= adder_i
.in1
.eq(self
.i
.terms
[iidx
+ 1])
549 m
.d
.comb
+= adder_i
.in2
.eq(self
.i
.terms
[iidx
+ 2])
550 m
.d
.comb
+= adder_i
.mask
.eq(part_mask
)
555 class AddReduceInternal
:
556 """Recursively Add list of numbers together.
558 :attribute inputs: input ``Signal``s to be summed. Modification not
559 supported, except for by ``Signal.eq``.
560 :attribute register_levels: List of nesting levels that should have
562 :attribute output: output sum.
563 :attribute partition_points: the input partition points. Modification not
564 supported, except for by ``Signal.eq``.
567 def __init__(self
, i
, output_width
, partition_step
=1):
568 """Create an ``AddReduce``.
570 :param inputs: input ``Signal``s to be summed.
571 :param output_width: bit-width of ``output``.
572 :param partition_points: the input partition points.
575 self
.inputs
= i
.terms
576 self
.part_ops
= i
.part_ops
577 self
.output_width
= output_width
578 self
.partition_points
= i
.part_pts
579 self
.partition_step
= partition_step
583 def create_levels(self
):
584 """creates reduction levels"""
587 partition_points
= self
.partition_points
588 part_ops
= self
.part_ops
589 n_parts
= len(part_ops
)
593 groups
= AddReduceSingle
.full_adder_groups(len(inputs
))
597 next_level
= AddReduceSingle(lidx
, ilen
, self
.output_width
, n_parts
,
600 mods
.append(next_level
)
601 partition_points
= next_level
.i
.part_pts
602 inputs
= next_level
.o
.terms
604 part_ops
= next_level
.i
.part_ops
607 next_level
= FinalAdd(lidx
, ilen
, self
.output_width
, n_parts
,
608 partition_points
, self
.partition_step
)
609 mods
.append(next_level
)
614 class AddReduce(AddReduceInternal
, Elaboratable
):
615 """Recursively Add list of numbers together.
617 :attribute inputs: input ``Signal``s to be summed. Modification not
618 supported, except for by ``Signal.eq``.
619 :attribute register_levels: List of nesting levels that should have
621 :attribute output: output sum.
622 :attribute partition_points: the input partition points. Modification not
623 supported, except for by ``Signal.eq``.
626 def __init__(self
, inputs
, output_width
, register_levels
, part_pts
,
627 part_ops
, partition_step
=1):
628 """Create an ``AddReduce``.
630 :param inputs: input ``Signal``s to be summed.
631 :param output_width: bit-width of ``output``.
632 :param register_levels: List of nesting levels that should have
634 :param partition_points: the input partition points.
636 self
._inputs
= inputs
637 self
._part
_pts
= part_pts
638 self
._part
_ops
= part_ops
639 n_parts
= len(part_ops
)
640 self
.i
= AddReduceData(part_pts
, len(inputs
),
641 output_width
, n_parts
)
642 AddReduceInternal
.__init
__(self
, self
.i
, output_width
, partition_step
)
643 self
.o
= FinalReduceData(part_pts
, output_width
, n_parts
)
644 self
.register_levels
= register_levels
647 def get_max_level(input_count
):
648 return AddReduceSingle
.get_max_level(input_count
)
651 def next_register_levels(register_levels
):
652 """``Iterable`` of ``register_levels`` for next recursive level."""
653 for level
in register_levels
:
657 def elaborate(self
, platform
):
658 """Elaborate this module."""
661 m
.d
.comb
+= self
.i
.eq_from(self
._part
_pts
, self
._inputs
, self
._part
_ops
)
663 for i
, next_level
in enumerate(self
.levels
):
664 setattr(m
.submodules
, "next_level%d" % i
, next_level
)
667 for idx
in range(len(self
.levels
)):
668 mcur
= self
.levels
[idx
]
669 if idx
in self
.register_levels
:
670 m
.d
.sync
+= mcur
.i
.eq(i
)
672 m
.d
.comb
+= mcur
.i
.eq(i
)
673 i
= mcur
.o
# for next loop
675 # output comes from last module
676 m
.d
.comb
+= self
.o
.eq(i
)
682 OP_MUL_SIGNED_HIGH
= 1
683 OP_MUL_SIGNED_UNSIGNED_HIGH
= 2 # a is signed, b is unsigned
684 OP_MUL_UNSIGNED_HIGH
= 3
687 def get_term(value
, shift
=0, enabled
=None):
688 if enabled
is not None:
689 value
= Mux(enabled
, value
, 0)
691 value
= Cat(Repl(C(0, 1), shift
), value
)
697 class ProductTerm(Elaboratable
):
698 """ this class creates a single product term (a[..]*b[..]).
699 it has a design flaw in that is the *output* that is selected,
700 where the multiplication(s) are combinatorially generated
704 def __init__(self
, width
, twidth
, pbwid
, a_index
, b_index
):
705 self
.a_index
= a_index
706 self
.b_index
= b_index
707 shift
= 8 * (self
.a_index
+ self
.b_index
)
713 self
.ti
= Signal(self
.width
, reset_less
=True)
714 self
.term
= Signal(twidth
, reset_less
=True)
715 self
.a
= Signal(twidth
//2, reset_less
=True)
716 self
.b
= Signal(twidth
//2, reset_less
=True)
717 self
.pb_en
= Signal(pbwid
, reset_less
=True)
720 min_index
= min(self
.a_index
, self
.b_index
)
721 max_index
= max(self
.a_index
, self
.b_index
)
722 for i
in range(min_index
, max_index
):
723 tl
.append(self
.pb_en
[i
])
724 name
= "te_%d_%d" % (self
.a_index
, self
.b_index
)
726 term_enabled
= Signal(name
=name
, reset_less
=True)
729 self
.enabled
= term_enabled
730 self
.term
.name
= "term_%d_%d" % (a_index
, b_index
) # rename
732 def elaborate(self
, platform
):
735 if self
.enabled
is not None:
736 m
.d
.comb
+= self
.enabled
.eq(~
(Cat(*self
.tl
).bool()))
738 bsa
= Signal(self
.width
, reset_less
=True)
739 bsb
= Signal(self
.width
, reset_less
=True)
740 a_index
, b_index
= self
.a_index
, self
.b_index
742 m
.d
.comb
+= bsa
.eq(self
.a
.bit_select(a_index
* pwidth
, pwidth
))
743 m
.d
.comb
+= bsb
.eq(self
.b
.bit_select(b_index
* pwidth
, pwidth
))
744 m
.d
.comb
+= self
.ti
.eq(bsa
* bsb
)
745 m
.d
.comb
+= self
.term
.eq(get_term(self
.ti
, self
.shift
, self
.enabled
))
747 #TODO: sort out width issues, get inputs a/b switched on/off.
748 #data going into Muxes is 1/2 the required width
752 bsa = Signal(self.twidth//2, reset_less=True)
753 bsb = Signal(self.twidth//2, reset_less=True)
754 asel = Signal(width, reset_less=True)
755 bsel = Signal(width, reset_less=True)
756 a_index, b_index = self.a_index, self.b_index
757 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
758 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
759 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
760 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
761 m.d.comb += self.ti.eq(bsa * bsb)
762 m.d.comb += self.term.eq(self.ti)
768 class ProductTerms(Elaboratable
):
769 """ creates a bank of product terms. also performs the actual bit-selection
770 this class is to be wrapped with a for-loop on the "a" operand.
771 it creates a second-level for-loop on the "b" operand.
773 def __init__(self
, width
, twidth
, pbwid
, a_index
, blen
):
774 self
.a_index
= a_index
779 self
.a
= Signal(twidth
//2, reset_less
=True)
780 self
.b
= Signal(twidth
//2, reset_less
=True)
781 self
.pb_en
= Signal(pbwid
, reset_less
=True)
782 self
.terms
= [Signal(twidth
, name
="term%d"%i, reset_less
=True) \
783 for i
in range(blen
)]
785 def elaborate(self
, platform
):
789 for b_index
in range(self
.blen
):
790 t
= ProductTerm(self
.pwidth
, self
.twidth
, self
.pbwid
,
791 self
.a_index
, b_index
)
792 setattr(m
.submodules
, "term_%d" % b_index
, t
)
794 m
.d
.comb
+= t
.a
.eq(self
.a
)
795 m
.d
.comb
+= t
.b
.eq(self
.b
)
796 m
.d
.comb
+= t
.pb_en
.eq(self
.pb_en
)
798 m
.d
.comb
+= self
.terms
[b_index
].eq(t
.term
)
803 class LSBNegTerm(Elaboratable
):
def __init__(self, bit_width):
    """Signals for the negated-LSB term of one partition.

    ``nt`` carries the width-extended 1's-complement term and ``nl``
    the matching +1 correction (computed in ``elaborate``); both are
    twice *bit_width* wide to cover the HI half of the product.
    """
    self.bit_width = bit_width
    double_width = bit_width * 2
    self.part = Signal(reset_less=True)           # partition-enable bit
    self.signed = Signal(reset_less=True)         # operation is signed
    self.op = Signal(bit_width, reset_less=True)  # operand slice
    self.msb = Signal(reset_less=True)            # operand MSB in partition
    self.nt = Signal(double_width, reset_less=True)
    self.nl = Signal(double_width, reset_less=True)
, platform
):
817 bit_wid
= self
.bit_width
818 ext
= Repl(0, bit_wid
) # extend output to HI part
820 # determine sign of each incoming number *in this partition*
821 enabled
= Signal(reset_less
=True)
822 m
.d
.comb
+= enabled
.eq(self
.part
& self
.msb
& self
.signed
)
824 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
825 # negation operation is split into a bitwise not and a +1.
826 # likewise for 16, 32, and 64-bit values.
828 # width-extended 1s complement if a is signed, otherwise zero
829 comb
+= self
.nt
.eq(Mux(enabled
, Cat(ext
, ~self
.op
), 0))
831 # add 1 if signed, otherwise add zero
832 comb
+= self
.nl
.eq(Cat(ext
, enabled
, Repl(0, bit_wid
-1)))
837 class Parts(Elaboratable
):
839 def __init__(self
, pbwid
, part_pts
, n_parts
):
842 self
.part_pts
= PartitionPoints
.like(part_pts
)
844 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
845 for i
in range(n_parts
)]
847 def elaborate(self
, platform
):
850 part_pts
, parts
= self
.part_pts
, self
.parts
851 # collect part-bytes (double factor because the input is extended)
852 pbs
= Signal(self
.pbwid
, reset_less
=True)
854 for i
in range(self
.pbwid
):
855 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
856 m
.d
.comb
+= pb
.eq(part_pts
.part_byte(i
))
858 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
860 # negated-temporary copy of partition bits
861 npbs
= Signal
.like(pbs
, reset_less
=True)
862 m
.d
.comb
+= npbs
.eq(~pbs
)
863 byte_count
= 8 // len(parts
)
864 for i
in range(len(parts
)):
866 pbl
.append(npbs
[i
* byte_count
- 1])
867 for j
in range(i
* byte_count
, (i
+ 1) * byte_count
- 1):
869 pbl
.append(npbs
[(i
+ 1) * byte_count
- 1])
870 value
= Signal(len(pbl
), name
="value_%d" % i
, reset_less
=True)
871 m
.d
.comb
+= value
.eq(Cat(*pbl
))
872 m
.d
.comb
+= parts
[i
].eq(~
(value
).bool())
877 class Part(Elaboratable
):
878 """ a key class which, depending on the partitioning, will determine
879 what action to take when parts of the output are signed or unsigned.
881 this requires 2 pieces of data *per operand, per partition*:
882 whether the MSB is HI/LO (per partition!), and whether a signed
883 or unsigned operation has been *requested*.
885 once that is determined, signed is basically carried out
886 by splitting 2's complement into 1's complement plus one.
887 1's complement is just a bit-inversion.
889 the extra terms - as separate terms - are then thrown at the
890 AddReduce alongside the multiplication part-results.
892 def __init__(self
, part_pts
, width
, n_parts
, n_levels
, pbwid
):
895 self
.part_pts
= part_pts
898 self
.a
= Signal(64, reset_less
=True)
899 self
.b
= Signal(64, reset_less
=True)
900 self
.a_signed
= [Signal(name
=f
"a_signed_{i}", reset_less
=True)
902 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}", reset_less
=True)
904 self
.pbs
= Signal(pbwid
, reset_less
=True)
907 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
908 for i
in range(n_parts
)]
910 self
.not_a_term
= Signal(width
, reset_less
=True)
911 self
.neg_lsb_a_term
= Signal(width
, reset_less
=True)
912 self
.not_b_term
= Signal(width
, reset_less
=True)
913 self
.neg_lsb_b_term
= Signal(width
, reset_less
=True)
915 def elaborate(self
, platform
):
918 pbs
, parts
= self
.pbs
, self
.parts
919 part_pts
= self
.part_pts
920 m
.submodules
.p
= p
= Parts(self
.pbwid
, part_pts
, len(parts
))
921 m
.d
.comb
+= p
.part_pts
.eq(part_pts
)
924 byte_count
= 8 // len(parts
)
926 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= (
927 self
.not_a_term
, self
.neg_lsb_a_term
,
928 self
.not_b_term
, self
.neg_lsb_b_term
)
930 byte_width
= 8 // len(parts
) # byte width
931 bit_wid
= 8 * byte_width
# bit width
932 nat
, nbt
, nla
, nlb
= [], [], [], []
933 for i
in range(len(parts
)):
934 # work out bit-inverted and +1 term for a.
935 pa
= LSBNegTerm(bit_wid
)
936 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
937 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
938 m
.d
.comb
+= pa
.op
.eq(self
.a
.bit_select(bit_wid
* i
, bit_wid
))
939 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
940 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
944 # work out bit-inverted and +1 term for b
945 pb
= LSBNegTerm(bit_wid
)
946 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
947 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
948 m
.d
.comb
+= pb
.op
.eq(self
.b
.bit_select(bit_wid
* i
, bit_wid
))
949 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
950 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
954 # concatenate together and return all 4 results.
955 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
956 not_b_term
.eq(Cat(*nbt
)),
957 neg_lsb_a_term
.eq(Cat(*nla
)),
958 neg_lsb_b_term
.eq(Cat(*nlb
)),
964 class IntermediateOut(Elaboratable
):
965 """ selects the HI/LO part of the multiplication, for a given bit-width
966 the output is also reconstructed in its SIMD (partition) lanes.
968 def __init__(self
, width
, out_wid
, n_parts
):
970 self
.n_parts
= n_parts
971 self
.part_ops
= [Signal(2, name
="dpop%d" % i
, reset_less
=True)
973 self
.intermed
= Signal(out_wid
, reset_less
=True)
974 self
.output
= Signal(out_wid
//2, reset_less
=True)
976 def elaborate(self
, platform
):
982 for i
in range(self
.n_parts
):
983 op
= Signal(w
, reset_less
=True, name
="op%d_%d" % (w
, i
))
985 Mux(self
.part_ops
[sel
* i
] == OP_MUL_LOW
,
986 self
.intermed
.bit_select(i
* w
*2, w
),
987 self
.intermed
.bit_select(i
* w
*2 + w
, w
)))
989 m
.d
.comb
+= self
.output
.eq(Cat(*ol
))
994 class FinalOut(Elaboratable
):
995 """ selects the final output based on the partitioning.
997 each byte is selectable independently, i.e. it is possible
998 that some partitions requested 8-bit computation whilst others
999 requested 16 or 32 bit.
1001 def __init__(self
, output_width
, n_parts
, part_pts
):
1002 self
.part_pts
= part_pts
1003 self
.output_width
= output_width
1004 self
.n_parts
= n_parts
1005 self
.out_wid
= output_width
//2
1007 self
.i
= self
.ispec()
1008 self
.o
= self
.ospec()
1011 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
1016 def setup(self
, m
, i
):
1017 m
.submodules
.finalout
= self
1018 m
.d
.comb
+= self
.i
.eq(i
)
1020 def process(self
, i
):
1023 def elaborate(self
, platform
):
1026 part_pts
= self
.part_pts
1027 m
.submodules
.p_8
= p_8
= Parts(8, part_pts
, 8)
1028 m
.submodules
.p_16
= p_16
= Parts(8, part_pts
, 4)
1029 m
.submodules
.p_32
= p_32
= Parts(8, part_pts
, 2)
1030 m
.submodules
.p_64
= p_64
= Parts(8, part_pts
, 1)
1032 out_part_pts
= self
.i
.part_pts
1035 d8
= [Signal(name
=f
"d8_{i}", reset_less
=True) for i
in range(8)]
1036 d16
= [Signal(name
=f
"d16_{i}", reset_less
=True) for i
in range(4)]
1037 d32
= [Signal(name
=f
"d32_{i}", reset_less
=True) for i
in range(2)]
1039 i8
= Signal(self
.out_wid
, reset_less
=True)
1040 i16
= Signal(self
.out_wid
, reset_less
=True)
1041 i32
= Signal(self
.out_wid
, reset_less
=True)
1042 i64
= Signal(self
.out_wid
, reset_less
=True)
1044 m
.d
.comb
+= p_8
.part_pts
.eq(out_part_pts
)
1045 m
.d
.comb
+= p_16
.part_pts
.eq(out_part_pts
)
1046 m
.d
.comb
+= p_32
.part_pts
.eq(out_part_pts
)
1047 m
.d
.comb
+= p_64
.part_pts
.eq(out_part_pts
)
1049 for i
in range(len(p_8
.parts
)):
1050 m
.d
.comb
+= d8
[i
].eq(p_8
.parts
[i
])
1051 for i
in range(len(p_16
.parts
)):
1052 m
.d
.comb
+= d16
[i
].eq(p_16
.parts
[i
])
1053 for i
in range(len(p_32
.parts
)):
1054 m
.d
.comb
+= d32
[i
].eq(p_32
.parts
[i
])
1055 m
.d
.comb
+= i8
.eq(self
.i
.outputs
[0])
1056 m
.d
.comb
+= i16
.eq(self
.i
.outputs
[1])
1057 m
.d
.comb
+= i32
.eq(self
.i
.outputs
[2])
1058 m
.d
.comb
+= i64
.eq(self
.i
.outputs
[3])
1062 # select one of the outputs: d8 selects i8, d16 selects i16
1063 # d32 selects i32, and the default is i64.
1064 # d8 and d16 are ORed together in the first Mux
1065 # then the 2nd selects either i8 or i16.
1066 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
1067 op
= Signal(8, reset_less
=True, name
="op_%d" % i
)
1069 Mux(d8
[i
] | d16
[i
// 2],
1070 Mux(d8
[i
], i8
.bit_select(i
* 8, 8),
1071 i16
.bit_select(i
* 8, 8)),
1072 Mux(d32
[i
// 4], i32
.bit_select(i
* 8, 8),
1073 i64
.bit_select(i
* 8, 8))))
1077 m
.d
.comb
+= self
.o
.output
.eq(Cat(*ol
))
1078 m
.d
.comb
+= self
.o
.intermediate_output
.eq(self
.i
.intermediate_output
)
class OrMod(Elaboratable):
    """ORs four values together in a hierarchical tree."""

    def __init__(self, wid):
        """Create an ``OrMod``.

        :param wid: bit-width of each of the four inputs and of the output.
        """
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        # first tree level: OR the four inputs pairwise
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [or1.eq(self.orin[0] | self.orin[1]),
                     or2.eq(self.orin[2] | self.orin[3])]
        # second tree level: combine the two pair-results
        m.d.comb += self.orout.eq(or1 | or2)
        return m
1103 class Signs(Elaboratable
):
1104 """ determines whether a or b are signed numbers
1105 based on the required operation type (OP_MUL_*)
1109 self
.part_ops
= Signal(2, reset_less
=True)
1110 self
.a_signed
= Signal(reset_less
=True)
1111 self
.b_signed
= Signal(reset_less
=True)
1113 def elaborate(self
, platform
):
1117 asig
= self
.part_ops
!= OP_MUL_UNSIGNED_HIGH
1118 bsig
= (self
.part_ops
== OP_MUL_LOW
) \
1119 |
(self
.part_ops
== OP_MUL_SIGNED_HIGH
)
1120 m
.d
.comb
+= self
.a_signed
.eq(asig
)
1121 m
.d
.comb
+= self
.b_signed
.eq(bsig
)
1126 class IntermediateData
:
1128 def __init__(self
, part_pts
, output_width
, n_parts
):
1129 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}", reset_less
=True)
1130 for i
in range(n_parts
)]
1131 self
.part_pts
= part_pts
.like()
1132 self
.outputs
= [Signal(output_width
, name
="io%d" % i
, reset_less
=True)
1134 # intermediates (needed for unit tests)
1135 self
.intermediate_output
= Signal(output_width
)
1137 def eq_from(self
, part_pts
, outputs
, intermediate_output
,
1139 return [self
.part_pts
.eq(part_pts
)] + \
1140 [self
.intermediate_output
.eq(intermediate_output
)] + \
1141 [self
.outputs
[i
].eq(outputs
[i
])
1142 for i
in range(4)] + \
1143 [self
.part_ops
[i
].eq(part_ops
[i
])
1144 for i
in range(len(self
.part_ops
))]
1147 return self
.eq_from(rhs
.part_pts
, rhs
.outputs
,
1148 rhs
.intermediate_output
, rhs
.part_ops
)
1156 self
.part_pts
= PartitionPoints()
1157 for i
in range(8, 64, 8):
1158 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
1159 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
def eq_from(self, part_pts, a, b, part_ops):
    """Build the assignment list copying the given fields into this record.

    Returns a list of assignments (suitable for ``m.d.comb +=``) that set
    ``self.part_pts``, ``self.a``, ``self.b`` and each element of
    ``self.part_ops`` from the corresponding arguments, in that order.
    """
    assigns = [self.part_pts.eq(part_pts),
               self.a.eq(a),
               self.b.eq(b)]
    for idx in range(len(self.part_ops)):
        assigns.append(self.part_ops[idx].eq(part_ops[idx]))
    return assigns
1168 return self
.eq_from(rhs
.part_pts
, rhs
.a
, rhs
.b
, rhs
.part_ops
)
1174 self
.intermediate_output
= Signal(128) # needed for unit tests
1175 self
.output
= Signal(64)
1178 return [self
.intermediate_output
.eq(rhs
.intermediate_output
),
1179 self
.output
.eq(rhs
.output
)]
1182 class AllTerms(Elaboratable
):
1183 """Set of terms to be added together
def __init__(self, n_inputs, output_width, n_parts, register_levels):
    """Create an ``AllTerms``.

    Fix: the previous docstring was copy-pasted from ``AddReduce`` and
    documented parameters (``inputs``, ``partition_points``) that this
    constructor does not take.

    :param n_inputs: number of term inputs, forwarded to the
        ``AddReduceData`` output spec.
    :param output_width: bit-width of the output.
    :param n_parts: number of partitions.
    :param register_levels: list of nesting levels that should have
        pipeline registers; stored for use during elaboration.
    """
    self.register_levels = register_levels
    self.n_inputs = n_inputs
    self.n_parts = n_parts
    self.output_width = output_width

    # build the input/output specs up-front so `setup` can wire `self.i`
    # and downstream stages can read `self.o`
    self.i = self.ispec()
    self.o = self.ospec()
def setup(self, m, i):
    """Attach this stage to *m* and combinatorially drive its input.

    :param m: parent ``Module`` this stage is registered with (as the
        ``allterms`` submodule).
    :param i: input data record, assigned to ``self.i`` via ``m.d.comb``.
    """
    m.d.comb += self.i.eq(i)
    m.submodules.allterms = self
1207 def process(self
, i
):
1214 return AddReduceData(self
.i
.part_pts
, self
.n_inputs
,
1215 self
.output_width
, self
.n_parts
)
1217 def elaborate(self
, platform
):
1220 eps
= self
.i
.part_pts
1222 # collect part-bytes
1223 pbs
= Signal(8, reset_less
=True)
1226 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
1227 m
.d
.comb
+= pb
.eq(eps
.part_byte(i
))
1229 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
1236 setattr(m
.submodules
, "signs%d" % i
, s
)
1237 m
.d
.comb
+= s
.part_ops
.eq(self
.i
.part_ops
[i
])
1239 n_levels
= len(self
.register_levels
)+1
1240 m
.submodules
.part_8
= part_8
= Part(eps
, 128, 8, n_levels
, 8)
1241 m
.submodules
.part_16
= part_16
= Part(eps
, 128, 4, n_levels
, 8)
1242 m
.submodules
.part_32
= part_32
= Part(eps
, 128, 2, n_levels
, 8)
1243 m
.submodules
.part_64
= part_64
= Part(eps
, 128, 1, n_levels
, 8)
1244 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
1245 for mod
in [part_8
, part_16
, part_32
, part_64
]:
1246 m
.d
.comb
+= mod
.a
.eq(self
.i
.a
)
1247 m
.d
.comb
+= mod
.b
.eq(self
.i
.b
)
1248 for i
in range(len(signs
)):
1249 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
1250 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
1251 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
1252 nat_l
.append(mod
.not_a_term
)
1253 nbt_l
.append(mod
.not_b_term
)
1254 nla_l
.append(mod
.neg_lsb_a_term
)
1255 nlb_l
.append(mod
.neg_lsb_b_term
)
1259 for a_index
in range(8):
1260 t
= ProductTerms(8, 128, 8, a_index
, 8)
1261 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
1263 m
.d
.comb
+= t
.a
.eq(self
.i
.a
)
1264 m
.d
.comb
+= t
.b
.eq(self
.i
.b
)
1265 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
1267 for term
in t
.terms
:
1270 # it's fine to bitwise-or data together since they are never enabled
1272 m
.submodules
.nat_or
= nat_or
= OrMod(128)
1273 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
1274 m
.submodules
.nla_or
= nla_or
= OrMod(128)
1275 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
1276 for l
, mod
in [(nat_l
, nat_or
),
1280 for i
in range(len(l
)):
1281 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
1282 terms
.append(mod
.orout
)
1284 # copy the intermediate terms to the output
1285 for i
, value
in enumerate(terms
):
1286 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
1288 # copy reg part points and part ops to output
1289 m
.d
.comb
+= self
.o
.part_pts
.eq(eps
)
1290 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
1291 for i
in range(len(self
.i
.part_ops
))]
1296 class Intermediates(Elaboratable
):
1297 """ Intermediate output modules
def __init__(self, output_width, n_parts, part_pts):
    """Store the stage configuration and build the input/output specs.

    :param output_width: bit-width forwarded to the i/o data specs.
    :param n_parts: number of partitions, forwarded to the i/o data specs.
    :param part_pts: the partition points, forwarded to the i/o data specs.
    """
    self.n_parts = n_parts
    self.part_pts = part_pts
    self.output_width = output_width

    # specs are built last: ispec()/ospec() read the attributes above
    self.i = self.ispec()
    self.o = self.ospec()
1309 return FinalReduceData(self
.part_pts
, self
.output_width
, self
.n_parts
)
1312 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
def setup(self, m, i):
    """Attach this stage to *m* and combinatorially drive its input.

    :param m: parent ``Module`` this stage is registered with (as the
        ``intermediates`` submodule).
    :param i: input data record, assigned to ``self.i`` via ``m.d.comb``.
    """
    m.d.comb += self.i.eq(i)
    m.submodules.intermediates = self
1318 def process(self
, i
):
1321 def elaborate(self
, platform
):
1324 out_part_ops
= self
.i
.part_ops
1325 out_part_pts
= self
.i
.part_pts
1328 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
1329 m
.d
.comb
+= io64
.intermed
.eq(self
.i
.output
)
1331 m
.d
.comb
+= io64
.part_ops
[i
].eq(out_part_ops
[i
])
1332 m
.d
.comb
+= self
.o
.outputs
[3].eq(io64
.output
)
1335 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
1336 m
.d
.comb
+= io32
.intermed
.eq(self
.i
.output
)
1338 m
.d
.comb
+= io32
.part_ops
[i
].eq(out_part_ops
[i
])
1339 m
.d
.comb
+= self
.o
.outputs
[2].eq(io32
.output
)
1342 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
1343 m
.d
.comb
+= io16
.intermed
.eq(self
.i
.output
)
1345 m
.d
.comb
+= io16
.part_ops
[i
].eq(out_part_ops
[i
])
1346 m
.d
.comb
+= self
.o
.outputs
[1].eq(io16
.output
)
1349 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
1350 m
.d
.comb
+= io8
.intermed
.eq(self
.i
.output
)
1352 m
.d
.comb
+= io8
.part_ops
[i
].eq(out_part_ops
[i
])
1353 m
.d
.comb
+= self
.o
.outputs
[0].eq(io8
.output
)
1356 m
.d
.comb
+= self
.o
.part_ops
[i
].eq(out_part_ops
[i
])
1357 m
.d
.comb
+= self
.o
.part_pts
.eq(out_part_pts
)
1358 m
.d
.comb
+= self
.o
.intermediate_output
.eq(self
.i
.output
)
1363 class Mul8_16_32_64(Elaboratable
):
1364 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
1366 Supports partitioning into any combination of 8, 16, 32, and 64-bit
1367 partitions on naturally-aligned boundaries. Supports the operation being
1368 set for each partition independently.
1370 :attribute part_pts: the input partition points. Has a partition point at
1371 multiples of 8 in 0 < i < 64. Each partition point's associated
1372 ``Value`` is a ``Signal``. Modification not supported, except for by
1374 :attribute part_ops: the operation for each byte. The operation for a
1375 particular partition is selected by assigning the selected operation
1376 code to each byte in the partition. The allowed operation codes are:
1378 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
1379 RISC-V's `mul` instruction.
1380 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
1381 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
1383 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
1384 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
1385 `mulhsu` instruction.
1386 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
1387 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
1391 def __init__(self
, register_levels
=()):
1392 """ register_levels: specifies the points in the cascade at which
1393 flip-flops are to be inserted.
1397 self
.register_levels
= list(register_levels
)
1399 self
.i
= self
.ispec()
1400 self
.o
= self
.ospec()
1403 self
.part_pts
= self
.i
.part_pts
1404 self
.part_ops
= self
.i
.part_ops
1409 self
.intermediate_output
= self
.o
.intermediate_output
1410 self
.output
= self
.o
.output
1418 def elaborate(self
, platform
):
1421 part_pts
= self
.part_pts
1425 t
= AllTerms(n_inputs
, 128, n_parts
, self
.register_levels
)
1430 at
= AddReduceInternal(t
.process(self
.i
), 128, partition_step
=2)
1433 for idx
in range(len(at
.levels
)):
1434 mcur
= at
.levels
[idx
]
1437 if idx
in self
.register_levels
:
1438 m
.d
.sync
+= o
.eq(mcur
.process(i
))
1440 m
.d
.comb
+= o
.eq(mcur
.process(i
))
1441 i
= o
# for next loop
1443 interm
= Intermediates(128, 8, part_pts
)
1445 o
= interm
.process(interm
.i
)
1448 finalout
= FinalOut(128, 8, part_pts
)
1449 finalout
.setup(m
, o
)
1450 m
.d
.comb
+= self
.o
.eq(finalout
.process(o
))
1455 if __name__
== "__main__":
1459 m
.intermediate_output
,
1462 *m
.part_pts
.values()])