# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
"""Integer Multiplication."""
from abc import ABCMeta, abstractmethod
from functools import reduce
from operator import or_

from nmigen import (Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl)
from nmigen.hdl.ast import Assign
from nmigen.cli import main

from ieee754.pipeline import PipelineSpec
from nmutil.pipemodbase import PipeModBase
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:

    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:

    * If ``a`` is asserted:

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16

    * Otherwise

        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a key is not an int
        :raises ValueError: if a key is a negative int
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            # NOTE(review): reconstructed guard — derive a base name from a
            # throwaway Signal when none is given (gets the variable name)
            name = Signal(src_loc_at=1+src_loc_at).name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul  # optionally "expand" the indices
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``."""
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers
        """
        bits = []
        for i in range(width):
            i /= mul  # scale down so enabled points land on integer indices
            if i.is_integer() and int(i) in self:
                bits.append(~self[i])
            else:
                bits.append(C(True, 1))
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        """
        retval = 1
        for point in self.keys():
            if point < width:
                retval += 1
        return retval

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        for point in self.keys():
            if point >= width:
                return False
        return True

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        # byte-granularity lookup: index -1 and 7 are the outer boundaries,
        # always "enabled"
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # classic 3:2 compressor equations, one per bit lane
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used (which is more
    gate-efficient):
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # carries are shifted left by one (Cat with a leading 0) *before*
        # masking, so an AOI cell may be inferred
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result::

        partition: .... P... P... P... P... (32 bits)
        a        : .... .... .... .... .... (32 bits)
        b        : .... .... .... .... .... (32 bits)
        exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
        exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
        exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
        o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points, partition_step=1):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :param partition_step: a multiplier (typically double) step
            which in-place "expands" the partition points
        """
        self.width = width
        self.pmul = partition_step
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one extra bit is inserted per partition point
        expanded_width = 0
        for i in range(self.width):
            if i in self.partition_points:
                expanded_width += 1
            expanded_width += 1
        self._expanded_width = expanded_width

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        expanded_index = 0
        for i in range(self.width):
            pi = i/self.pmul  # double the range of the partition point test
            if pi.is_integer() and pi in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(expanded_a[expanded_index])
                al.append(~self.partition_points[pi])  # add extra bit in a
                eb.append(expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(expanded_a[expanded_index])
            eb.append(expanded_b[expanded_index])
            eo.append(expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
        return m
# number of inputs a single 3:2 compressor (full adder) consumes
FULL_ADDER_INPUT_COUNT = 3
class AddReduceData:
    """Input/intermediate record for the add-reduction pipeline stages.

    Bundles the partial-product terms, per-partition operation codes and
    partition points carried between reduction levels.
    """

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.terms = [Signal(output_width, name=f"inputs_{i}",
                             reset_less=True)
                      for i in range(n_inputs)]
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        # returns a list of assignments (to be added to m.d.comb/sync)
        return [self.part_pts.eq(part_pts)] + \
               [self.terms[i].eq(inputs[i])
                for i in range(len(self.terms))] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
class FinalReduceData:
    """Output record of the final reduction stage: one summed output
    plus the per-partition operation codes and partition points.
    """

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        # returns a list of assignments (to be added to m.d.comb/sync)
        return [self.part_pts.eq(part_pts)] + \
               [self.output.eq(output)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
class FinalAdd(Elaboratable):
    """ Final stage of add reduce: adds the remaining (at most two)
        terms with a single PartitionedAdder.
    """

    def __init__(self, lidx, n_inputs, output_width, n_parts,
                 partition_points, partition_step=1):
        self.lidx = lidx  # level index (used for naming)
        self.partition_step = partition_step
        self.output_width = output_width
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        return FinalReduceData(self.partition_points,
                               self.output_width, self.n_parts)

    def setup(self, m, i):
        m.submodules.finaladd = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        # NOTE(review): body reconstructed — combinatorial stage,
        # result available on self.o; confirm against upstream
        return self.o

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        output_width = self.output_width
        output = Signal(output_width, reset_less=True)
        if self.n_inputs == 0:
            # use 0 as the default output value
            m.d.comb += output.eq(0)
        elif self.n_inputs == 1:
            # handle single input
            m.d.comb += output.eq(self.i.terms[0])
        else:
            # base case for adding 2 inputs
            assert self.n_inputs == 2
            adder = PartitionedAdder(output_width,
                                     self.i.part_pts, self.partition_step)
            m.submodules.final_adder = adder
            m.d.comb += adder.a.eq(self.i.terms[0])
            m.d.comb += adder.b.eq(self.i.terms[1])
            m.d.comb += output.eq(adder.output)

        # create output
        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                   self.i.part_ops)
        return m
class AddReduceSingle(PipeModBase):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param partition_points: the input partition points.
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)
        super().__init__(pspec, "addreduce_%d" % lidx)

    def ispec(self):
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        return AddReduceData(self.partition_points, self.n_terms,
                             self.output_width, self.n_parts)

    @staticmethod
    def calc_n_inputs(n_inputs, groups):
        """number of terms the *next* level will receive: two per full-adder
        group plus the 0, 1 or 2 leftover inputs passed straight through.
        """
        retval = len(groups)*2
        if n_inputs % FULL_ADDER_INPUT_COUNT == 1:
            retval += 1
        elif n_inputs % FULL_ADDER_INPUT_COUNT == 2:
            retval += 2
        else:
            assert n_inputs % FULL_ADDER_INPUT_COUNT == 0
        return retval

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)

    def create_next_terms(self):
        """ create next intermediate terms, for linking up in elaborate, below
        """
        terms = []
        adders = []
        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            terms.append(adder_i.sum)
            terms.append(adder_i.mcarry)
        # handle the remaining inputs.
        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
            terms.append(self.i.terms[-1])
        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2 terms
            # and just passing the terms to the next layer saves gates.
            terms.append(self.i.terms[-2])
            terms.append(self.i.terms[-1])
        else:
            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0

        return terms, adders

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        terms, adders = self.create_next_terms()

        # copy the intermediate terms to the output
        for i, value in enumerate(terms):
            m.d.comb += self.o.terms[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                     for i in range(len(self.i.part_ops))]

        # set up the partition mask (for the adders)
        part_mask = Signal(self.output_width, reset_less=True)

        # get partition points as a mask
        mask = self.i.part_pts.as_mask(self.output_width,
                                       mul=self.partition_step)
        m.d.comb += part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
            m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
            m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
            m.d.comb += adder_i.mask.eq(part_mask)

        return m
class AddReduceInternal:
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, i, pspec, partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param partition_points: the input partition points.
        """
        self.i = i
        self.pspec = pspec
        self.inputs = i.terms
        self.part_ops = i.part_ops
        self.output_width = pspec.width * 2
        self.partition_points = i.part_pts
        self.partition_step = partition_step

        self.create_levels()

    def create_levels(self):
        """creates reduction levels"""
        # NOTE(review): loop control reconstructed from a mangled source —
        # verify against upstream ieee754.part_mul_add.multiply
        mods = []
        partition_points = self.partition_points
        part_ops = self.part_ops
        n_parts = len(part_ops)
        inputs = self.inputs
        ilen = len(inputs)
        while True:
            groups = AddReduceSingle.full_adder_groups(len(inputs))
            if len(groups) == 0:
                break
            lidx = len(mods)
            next_level = AddReduceSingle(self.pspec, lidx, ilen,
                                         partition_points,
                                         self.partition_step)
            mods.append(next_level)
            partition_points = next_level.i.part_pts
            inputs = next_level.o.terms
            ilen = len(inputs)
            part_ops = next_level.i.part_ops

        lidx = len(mods)
        next_level = FinalAdd(lidx, ilen, self.output_width, n_parts,
                              partition_points, self.partition_step)
        mods.append(next_level)

        self.levels = mods
class AddReduce(AddReduceInternal, Elaboratable):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except for by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, part_pts,
                 part_ops, partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self._inputs = inputs
        self._part_pts = part_pts
        self._part_ops = part_ops
        n_parts = len(part_ops)
        self.i = AddReduceData(part_pts, len(inputs),
                               output_width, n_parts)
        # NOTE(review): output_width is passed where AddReduceInternal
        # expects pspec (which it dereferences as pspec.width) — preserved
        # exactly as in the source; confirm against upstream before use
        AddReduceInternal.__init__(self, self.i, output_width, partition_step)
        self.o = FinalReduceData(part_pts, output_width, n_parts)
        self.register_levels = register_levels

    @staticmethod
    def get_max_level(input_count):
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        m.d.comb += self.i.eq_from(self._part_pts, self._inputs,
                                   self._part_ops)

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # chain the levels: registered levels go through sync, others comb
        i = self.i
        for idx in range(len(self.levels)):
            mcur = self.levels[idx]
            if idx in self.register_levels:
                m.d.sync += mcur.i.eq(i)
            else:
                m.d.comb += mcur.i.eq(i)
            i = mcur.o  # for next loop

        # output comes from last module
        m.d.comb += self.o.eq(i)

        return m
# multiply operation codes, as carried per-partition in part_ops.
# OP_MUL_LOW reconstructed: it is referenced below (IntermediateOut, Signs)
# but its defining line was lost in extraction.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
def get_term(value, shift=0, enabled=None):
    """Optionally gate ``value`` with ``enabled`` and shift it left.

    :param value: the term to process
    :param shift: number of zero bits to prepend (left-shift)
    :param enabled: optional qualifier; when de-asserted the term is zeroed
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    return value
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width        # partial (byte) width
        self.twidth = twidth       # total term width
        self.width = width*2       # product width
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # collect the partition-enable bits spanned by this term
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            # term is enabled only when no spanned partition bit is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # NOTE(review): dead alternative implementation preserved from the
        # original as an inert string (it was disabled upstream)
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """
        return m
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual
        bit-selection.
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m
class LSBNegTerm(Elaboratable):
    """Generates the 1's-complement and +1 correction terms needed to
    turn an unsigned partial product into a signed one, per partition.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)    # partition enable
        self.signed = Signal(reset_less=True)  # signed op requested
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)  # 1's-complement term
        self.nl = Signal(bit_width*2, reset_less=True)  # the "+1" term

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
class Parts(Elaboratable):
    """Decodes the partition-point bytes into one "this lane is a single
    partition" flag per part, for a given lane width.
    """

    def __init__(self, pbwid, part_pts, n_parts):
        self.pbwid = pbwid
        # inputs
        self.part_pts = PartitionPoints.like(part_pts)
        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

    def elaborate(self, platform):
        m = Module()

        part_pts, parts = self.part_pts, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            # a part is "whole" when its boundary bits are clear (negated
            # bits set) and no internal partition bit is set
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())

        return m
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, part_pts, width, n_parts, pbwid):

        self.pbwid = pbwid
        self.part_pts = part_pts

        # inputs
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

        self.not_a_term = Signal(width, reset_less=True)
        self.neg_lsb_a_term = Signal(width, reset_less=True)
        self.not_b_term = Signal(width, reset_less=True)
        self.neg_lsb_b_term = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        pbs, parts = self.pbs, self.parts
        part_pts = self.part_pts
        m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
        m.d.comb += p.part_pts.eq(part_pts)
        m.d.comb += parts[0].eq(p.parts[0]) if False else []  # (no-op)

        byte_count = 8 // len(parts)

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
            self.not_a_term, self.neg_lsb_a_term,
            self.not_b_term, self.neg_lsb_b_term)

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication, for a given bit-width
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            # OP_MUL_LOW takes the low half of each 2w product lane,
            # everything else takes the high half
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)

        m.d.comb += self.output.eq(Cat(*ol))

        return m
class FinalOut(Elaboratable):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, output_width, n_parts, part_pts):
        self.part_pts = part_pts
        self.output_width = output_width
        self.n_parts = n_parts
        self.out_wid = output_width//2

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        return IntermediateData(self.part_pts, self.output_width,
                                self.n_parts)

    def ospec(self):
        # NOTE(review): body reconstructed — returns the output record
        # declared elsewhere in this file; confirm against upstream
        return OutputData()

    def setup(self, m, i):
        m.submodules.finalout = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        # NOTE(review): body reconstructed — combinatorial stage,
        # result available on self.o; confirm against upstream
        return self.o

    def elaborate(self, platform):
        m = Module()

        part_pts = self.part_pts
        # one Parts decoder per supported element width (8/16/32/64)
        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)

        out_part_pts = self.i.part_pts

        # per-width partition flags
        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        i8 = Signal(self.out_wid, reset_less=True)
        i16 = Signal(self.out_wid, reset_less=True)
        i32 = Signal(self.out_wid, reset_less=True)
        i64 = Signal(self.out_wid, reset_less=True)

        m.d.comb += p_8.part_pts.eq(out_part_pts)
        m.d.comb += p_16.part_pts.eq(out_part_pts)
        m.d.comb += p_32.part_pts.eq(out_part_pts)
        m.d.comb += p_64.part_pts.eq(out_part_pts)

        for i in range(len(p_8.parts)):
            m.d.comb += d8[i].eq(p_8.parts[i])
        for i in range(len(p_16.parts)):
            m.d.comb += d16[i].eq(p_16.parts[i])
        for i in range(len(p_32.parts)):
            m.d.comb += d32[i].eq(p_32.parts[i])
        m.d.comb += i8.eq(self.i.outputs[0])
        m.d.comb += i16.eq(self.i.outputs[1])
        m.d.comb += i32.eq(self.i.outputs[2])
        m.d.comb += i64.eq(self.i.outputs[3])

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(d8[i] | d16[i // 2],
                    Mux(d8[i], i8.bit_select(i * 8, 8),
                        i16.bit_select(i * 8, 8)),
                    Mux(d32[i // 4], i32.bit_select(i * 8, 8),
                        i64.bit_select(i * 8, 8))))
            ol.append(op)

        # create outputs
        m.d.comb += self.o.output.eq(Cat(*ol))
        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)

        return m
1078 class OrMod(Elaboratable
):
1079 """ ORs four values together in a hierarchical tree
1081 def __init__(self
, wid
):
1083 self
.orin
= [Signal(wid
, name
="orin%d" % i
, reset_less
=True)
1085 self
.orout
= Signal(wid
, reset_less
=True)
1087 def elaborate(self
, platform
):
1089 or1
= Signal(self
.wid
, reset_less
=True)
1090 or2
= Signal(self
.wid
, reset_less
=True)
1091 m
.d
.comb
+= or1
.eq(self
.orin
[0] | self
.orin
[1])
1092 m
.d
.comb
+= or2
.eq(self
.orin
[2] | self
.orin
[3])
1093 m
.d
.comb
+= self
.orout
.eq(or1 | or2
)
1098 class Signs(Elaboratable
):
1099 """ determines whether a or b are signed numbers
1100 based on the required operation type (OP_MUL_*)
1104 self
.part_ops
= Signal(2, reset_less
=True)
1105 self
.a_signed
= Signal(reset_less
=True)
1106 self
.b_signed
= Signal(reset_less
=True)
1108 def elaborate(self
, platform
):
1112 asig
= self
.part_ops
!= OP_MUL_UNSIGNED_HIGH
1113 bsig
= (self
.part_ops
== OP_MUL_LOW
) \
1114 |
(self
.part_ops
== OP_MUL_SIGNED_HIGH
)
1115 m
.d
.comb
+= self
.a_signed
.eq(asig
)
1116 m
.d
.comb
+= self
.b_signed
.eq(bsig
)
1121 class IntermediateData
:
1123 def __init__(self
, part_pts
, output_width
, n_parts
):
1124 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}", reset_less
=True)
1125 for i
in range(n_parts
)]
1126 self
.part_pts
= part_pts
.like()
1127 self
.outputs
= [Signal(output_width
, name
="io%d" % i
, reset_less
=True)
1129 # intermediates (needed for unit tests)
1130 self
.intermediate_output
= Signal(output_width
)
1132 def eq_from(self
, part_pts
, outputs
, intermediate_output
,
1134 return [self
.part_pts
.eq(part_pts
)] + \
1135 [self
.intermediate_output
.eq(intermediate_output
)] + \
1136 [self
.outputs
[i
].eq(outputs
[i
])
1137 for i
in range(4)] + \
1138 [self
.part_ops
[i
].eq(part_ops
[i
])
1139 for i
in range(len(self
.part_ops
))]
1142 return self
.eq_from(rhs
.part_pts
, rhs
.outputs
,
1143 rhs
.intermediate_output
, rhs
.part_ops
)
1151 self
.part_pts
= PartitionPoints()
1152 for i
in range(8, 64, 8):
1153 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
1154 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
def eq_from(self, part_pts, a, b, part_ops):
    """Return assignments copying the given values into this record.

    :param part_pts: partition points to assign to ``self.part_pts``.
    :param a: value for the ``a`` operand signal.
    :param b: value for the ``b`` operand signal.
    :param part_ops: per-byte operation codes, one per ``self.part_ops``
        entry.
    :return: a list of nmigen assignment statements.
    """
    fixed = [self.part_pts.eq(part_pts)]
    fixed += [self.a.eq(a), self.b.eq(b)]
    # pair each local part_ops signal with the corresponding input by index
    ops = [po.eq(part_ops[idx]) for idx, po in enumerate(self.part_ops)]
    return fixed + ops
1163 return self
.eq_from(rhs
.part_pts
, rhs
.a
, rhs
.b
, rhs
.part_ops
)
1169 self
.intermediate_output
= Signal(128) # needed for unit tests
1170 self
.output
= Signal(64)
1173 return [self
.intermediate_output
.eq(rhs
.intermediate_output
),
1174 self
.output
.eq(rhs
.output
)]
1177 class AllTerms(PipeModBase
):
1178 """Set of terms to be added together
1181 def __init__(self
, pspec
, n_inputs
):
1182 """Create an ``AllTerms``.
1184 self
.n_inputs
= n_inputs
1185 self
.n_parts
= pspec
.n_parts
1186 self
.output_width
= pspec
.width
* 2
1187 super().__init
__(pspec
, "allterms")
1193 return AddReduceData(self
.i
.part_pts
, self
.n_inputs
,
1194 self
.output_width
, self
.n_parts
)
1196 def elaborate(self
, platform
):
1199 eps
= self
.i
.part_pts
1201 # collect part-bytes
1202 pbs
= Signal(8, reset_less
=True)
1205 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
1206 m
.d
.comb
+= pb
.eq(eps
.part_byte(i
))
1208 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
1215 setattr(m
.submodules
, "signs%d" % i
, s
)
1216 m
.d
.comb
+= s
.part_ops
.eq(self
.i
.part_ops
[i
])
1218 m
.submodules
.part_8
= part_8
= Part(eps
, 128, 8, 8)
1219 m
.submodules
.part_16
= part_16
= Part(eps
, 128, 4, 8)
1220 m
.submodules
.part_32
= part_32
= Part(eps
, 128, 2, 8)
1221 m
.submodules
.part_64
= part_64
= Part(eps
, 128, 1, 8)
1222 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
1223 for mod
in [part_8
, part_16
, part_32
, part_64
]:
1224 m
.d
.comb
+= mod
.a
.eq(self
.i
.a
)
1225 m
.d
.comb
+= mod
.b
.eq(self
.i
.b
)
1226 for i
in range(len(signs
)):
1227 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
1228 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
1229 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
1230 nat_l
.append(mod
.not_a_term
)
1231 nbt_l
.append(mod
.not_b_term
)
1232 nla_l
.append(mod
.neg_lsb_a_term
)
1233 nlb_l
.append(mod
.neg_lsb_b_term
)
1237 for a_index
in range(8):
1238 t
= ProductTerms(8, 128, 8, a_index
, 8)
1239 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
1241 m
.d
.comb
+= t
.a
.eq(self
.i
.a
)
1242 m
.d
.comb
+= t
.b
.eq(self
.i
.b
)
1243 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
1245 for term
in t
.terms
:
1248 # it's fine to bitwise-or data together since they are never enabled
1250 m
.submodules
.nat_or
= nat_or
= OrMod(128)
1251 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
1252 m
.submodules
.nla_or
= nla_or
= OrMod(128)
1253 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
1254 for l
, mod
in [(nat_l
, nat_or
),
1258 for i
in range(len(l
)):
1259 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
1260 terms
.append(mod
.orout
)
1262 # copy the intermediate terms to the output
1263 for i
, value
in enumerate(terms
):
1264 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
1266 # copy reg part points and part ops to output
1267 m
.d
.comb
+= self
.o
.part_pts
.eq(eps
)
1268 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
1269 for i
in range(len(self
.i
.part_ops
))]
1274 class Intermediates(Elaboratable
):
1275 """ Intermediate output modules
1278 def __init__(self
, output_width
, n_parts
, part_pts
):
1279 self
.part_pts
= part_pts
1280 self
.output_width
= output_width
1281 self
.n_parts
= n_parts
1283 self
.i
= self
.ispec()
1284 self
.o
= self
.ospec()
1287 return FinalReduceData(self
.part_pts
, self
.output_width
, self
.n_parts
)
1290 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
1292 def setup(self
, m
, i
):
1293 m
.submodules
.intermediates
= self
1294 m
.d
.comb
+= self
.i
.eq(i
)
1296 def process(self
, i
):
1299 def elaborate(self
, platform
):
1302 out_part_ops
= self
.i
.part_ops
1303 out_part_pts
= self
.i
.part_pts
1306 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
1307 m
.d
.comb
+= io64
.intermed
.eq(self
.i
.output
)
1309 m
.d
.comb
+= io64
.part_ops
[i
].eq(out_part_ops
[i
])
1310 m
.d
.comb
+= self
.o
.outputs
[3].eq(io64
.output
)
1313 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
1314 m
.d
.comb
+= io32
.intermed
.eq(self
.i
.output
)
1316 m
.d
.comb
+= io32
.part_ops
[i
].eq(out_part_ops
[i
])
1317 m
.d
.comb
+= self
.o
.outputs
[2].eq(io32
.output
)
1320 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
1321 m
.d
.comb
+= io16
.intermed
.eq(self
.i
.output
)
1323 m
.d
.comb
+= io16
.part_ops
[i
].eq(out_part_ops
[i
])
1324 m
.d
.comb
+= self
.o
.outputs
[1].eq(io16
.output
)
1327 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
1328 m
.d
.comb
+= io8
.intermed
.eq(self
.i
.output
)
1330 m
.d
.comb
+= io8
.part_ops
[i
].eq(out_part_ops
[i
])
1331 m
.d
.comb
+= self
.o
.outputs
[0].eq(io8
.output
)
1334 m
.d
.comb
+= self
.o
.part_ops
[i
].eq(out_part_ops
[i
])
1335 m
.d
.comb
+= self
.o
.part_pts
.eq(out_part_pts
)
1336 m
.d
.comb
+= self
.o
.intermediate_output
.eq(self
.i
.output
)
1341 class Mul8_16_32_64(Elaboratable
):
1342 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
1344 Supports partitioning into any combination of 8, 16, 32, and 64-bit
1345 partitions on naturally-aligned boundaries. Supports the operation being
1346 set for each partition independently.
1348 :attribute part_pts: the input partition points. Has a partition point at
1349 multiples of 8 in 0 < i < 64. Each partition point's associated
1350 ``Value`` is a ``Signal``. Modification not supported, except for by
1352 :attribute part_ops: the operation for each byte. The operation for a
1353 particular partition is selected by assigning the selected operation
1354 code to each byte in the partition. The allowed operation codes are:
1356 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
1357 RISC-V's `mul` instruction.
1358 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
1359 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
1361 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
1362 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
1363 `mulhsu` instruction.
1364 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
1365 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
1369 def __init__(self
, register_levels
=()):
1370 """ register_levels: specifies the points in the cascade at which
1371 flip-flops are to be inserted.
1374 self
.id_wid
= 0 # num_bits(num_rows)
1376 self
.pspec
= PipelineSpec(64, self
.id_wid
, self
.op_wid
, n_ops
=3)
1377 self
.pspec
.n_parts
= 8
1380 self
.register_levels
= list(register_levels
)
1382 self
.i
= self
.ispec()
1383 self
.o
= self
.ospec()
1386 self
.part_pts
= self
.i
.part_pts
1387 self
.part_ops
= self
.i
.part_ops
1392 self
.intermediate_output
= self
.o
.intermediate_output
1393 self
.output
= self
.o
.output
1401 def elaborate(self
, platform
):
1404 part_pts
= self
.part_pts
1406 n_parts
= self
.pspec
.n_parts
1408 output_width
= self
.pspec
.width
* 2
1409 t
= AllTerms(self
.pspec
, n_inputs
)
1414 at
= AddReduceInternal(t
.process(self
.i
), self
.pspec
, partition_step
=2)
1417 for idx
in range(len(at
.levels
)):
1418 mcur
= at
.levels
[idx
]
1421 if idx
in self
.register_levels
:
1422 m
.d
.sync
+= o
.eq(mcur
.process(i
))
1424 m
.d
.comb
+= o
.eq(mcur
.process(i
))
1425 i
= o
# for next loop
1427 interm
= Intermediates(128, 8, part_pts
)
1429 o
= interm
.process(interm
.i
)
1432 finalout
= FinalOut(128, 8, part_pts
)
1433 finalout
.setup(m
, o
)
1434 m
.d
.comb
+= self
.o
.eq(finalout
.process(o
))
1439 if __name__
== "__main__":
1443 m
.intermediate_output
,
1446 *m
.part_pts
.values()])