1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
11 from ieee754
.pipeline
import PipelineSpec
12 from nmutil
.pipemodbase
import PipeModBase
15 class PartitionPoints(dict):
16 """Partition points and corresponding ``Value``s.
18 The points at where an ALU is partitioned along with ``Value``s that
19 specify if the corresponding partition points are enabled.
21 For example: ``{1: True, 5: True, 10: True}`` with
22 ``width == 16`` specifies that the ALU is split into 4 sections:
25 * bits 5 <= ``i`` < 10
26 * bits 10 <= ``i`` < 16
28 If the partition_points were instead ``{1: True, 5: a, 10: True}``
29 where ``a`` is a 1-bit ``Signal``:
30 * If ``a`` is asserted:
33 * bits 5 <= ``i`` < 10
34 * bits 10 <= ``i`` < 16
37 * bits 1 <= ``i`` < 10
38 * bits 10 <= ``i`` < 16
41 def __init__(self
, partition_points
=None):
42 """Create a new ``PartitionPoints``.
44 :param partition_points: the input partition points to values mapping.
47 if partition_points
is not None:
48 for point
, enabled
in partition_points
.items():
49 if not isinstance(point
, int):
50 raise TypeError("point must be a non-negative integer")
52 raise ValueError("point must be a non-negative integer")
53 self
[point
] = Value
.wrap(enabled
)
55 def like(self
, name
=None, src_loc_at
=0, mul
=1):
56 """Create a new ``PartitionPoints`` with ``Signal``s for all values.
58 :param name: the base name for the new ``Signal``s.
59 :param mul: a multiplication factor on the indices
62 name
= Signal(src_loc_at
=1+src_loc_at
).name
# get variable name
63 retval
= PartitionPoints()
64 for point
, enabled
in self
.items():
66 retval
[point
] = Signal(enabled
.shape(), name
=f
"{name}_{point}")
70 """Assign ``PartitionPoints`` using ``Signal.eq``."""
71 if set(self
.keys()) != set(rhs
.keys()):
72 raise ValueError("incompatible point set")
73 for point
, enabled
in self
.items():
74 yield enabled
.eq(rhs
[point
])
76 def as_mask(self
, width
, mul
=1):
77 """Create a bit-mask from `self`.
79 Each bit in the returned mask is clear only if the partition point at
80 the same bit-index is enabled.
82 :param width: the bit width of the resulting mask
83 :param mul: a "multiplier" which in-place expands the partition points
84 typically set to "2" when used for multipliers
87 for i
in range(width
):
89 if i
.is_integer() and int(i
) in self
:
95 def get_max_partition_count(self
, width
):
96 """Get the maximum number of partitions.
98 Gets the number of partitions when all partition points are enabled.
101 for point
in self
.keys():
def fits_in_width(self, width):
    """Check if all partition points are smaller than `width`.

    :param width: exclusive upper bound that every partition point
        (i.e. every key of this mapping) must stay below.
    :returns: ``True`` if every partition point is ``< width``; an empty
        set of partition points trivially fits.
    """
    # iterating the mapping yields its keys, which are the partition points
    return all(point < width for point in self)
113 def part_byte(self
, index
, mfactor
=1): # mfactor used for "expanding"
114 if index
== -1 or index
== 7:
116 assert index
>= 0 and index
< 8
117 return self
[(index
* 8 + 8)*mfactor
]
class FullAdder(Elaboratable):
    """Full Adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Rather than do individual full adders (and have an array of them,
    which would be very slow to simulate), this module can specify the
    bit width of the inputs and outputs: in effect it performs multiple
    Full 3-2 Add operations "in parallel".
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # per-bit sum: XOR of the three inputs
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # per-bit carry: set when at least two of the three inputs are set.
        # NOTE: the carry is produced at the *same* bit position as its
        # inputs; no shift is applied here.
        m.d.comb += self.carry.eq((self.in0 & self.in1)
                                  | (self.in1 & self.in2)
                                  | (self.in2 & self.in0))
        return m
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output.  To keep
    the graphviz "clean", this class performs the masking here rather
    than inside a large for-loop.

    See the following discussion as to why this is no longer derived
    from FullAdder.  Each carry is shifted here *before* being ANDed
    with the mask, so that an AOI cell may be used:

    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        # kept so that elaborate() can size its intermediate signals
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        # s1..s3: the inputs with a zero inserted below bit 0 (i.e. moved up
        # by one bit position); c1..c3: the masked pairwise AND terms.
        s1 = Signal(self.width, reset_less=True)
        s2 = Signal(self.width, reset_less=True)
        s3 = Signal(self.width, reset_less=True)
        c1 = Signal(self.width, reset_less=True)
        c2 = Signal(self.width, reset_less=True)
        c3 = Signal(self.width, reset_less=True)
        # per-bit sum: XOR of the three (unshifted) inputs
        m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2)
        # shift first ...
        m.d.comb += s1.eq(Cat(0, self.in0))
        m.d.comb += s2.eq(Cat(0, self.in1))
        m.d.comb += s3.eq(Cat(0, self.in2))
        # ... then AND each pair together with the mask ...
        m.d.comb += c1.eq(s1 & s2 & self.mask)
        m.d.comb += c2.eq(s2 & s3 & self.mask)
        m.d.comb += c3.eq(s3 & s1 & self.mask)
        # ... and OR the three masked terms to form the masked carry
        m.d.comb += self.mcarry.eq(c1 | c2 | c3)
        return m
211 class PartitionedAdder(Elaboratable
):
212 """Partitioned Adder.
214 Performs the final add. The partition points are included in the
215 actual add (in one of the operands only), which causes a carry over
216 to the next bit. Then the final output *removes* the extra bits from
219 partition: .... P... P... P... P... (32 bits)
220 a : .... .... .... .... .... (32 bits)
221 b : .... .... .... .... .... (32 bits)
222 exp-a : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
223 exp-b : ....0....0....0....0.... (32 bits plus 4 zeros)
224 exp-o : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
225 o : .... N... N... N... N... (32 bits - x ignored, N is carry-over)
227 :attribute width: the bit width of the input and output. Read-only.
228 :attribute a: the first input to the adder
229 :attribute b: the second input to the adder
230 :attribute output: the sum output
231 :attribute partition_points: the input partition points. Modification not
232 supported, except for by ``Signal.eq``.
235 def __init__(self
, width
, partition_points
, partition_step
=1):
236 """Create a ``PartitionedAdder``.
238 :param width: the bit width of the input and output
239 :param partition_points: the input partition points
240 :param partition_step: a multiplier (typically double) step
241 which in-place "expands" the partition points
244 self
.pmul
= partition_step
245 self
.a
= Signal(width
, reset_less
=True)
246 self
.b
= Signal(width
, reset_less
=True)
247 self
.output
= Signal(width
, reset_less
=True)
248 self
.partition_points
= PartitionPoints(partition_points
)
249 if not self
.partition_points
.fits_in_width(width
):
250 raise ValueError("partition_points doesn't fit in width")
252 for i
in range(self
.width
):
253 if i
in self
.partition_points
:
256 self
._expanded
_width
= expanded_width
258 def elaborate(self
, platform
):
259 """Elaborate this module."""
261 expanded_a
= Signal(self
._expanded
_width
, reset_less
=True)
262 expanded_b
= Signal(self
._expanded
_width
, reset_less
=True)
263 expanded_o
= Signal(self
._expanded
_width
, reset_less
=True)
266 # store bits in a list, use Cat later. graphviz is much cleaner
267 al
, bl
, ol
, ea
, eb
, eo
= [],[],[],[],[],[]
269 # partition points are "breaks" (extra zeros or 1s) in what would
270 # otherwise be a massive long add. when the "break" points are 0,
271 # whatever is in it (in the output) is discarded. however when
272 # there is a "1", it causes a roll-over carry to the *next* bit.
273 # we still ignore the "break" bit in the [intermediate] output,
274 # however by that time we've got the effect that we wanted: the
275 # carry has been carried *over* the break point.
277 for i
in range(self
.width
):
278 pi
= i
/self
.pmul
# double the range of the partition point test
279 if pi
.is_integer() and pi
in self
.partition_points
:
280 # add extra bit set to 0 + 0 for enabled partition points
281 # and 1 + 0 for disabled partition points
282 ea
.append(expanded_a
[expanded_index
])
283 al
.append(~self
.partition_points
[pi
]) # add extra bit in a
284 eb
.append(expanded_b
[expanded_index
])
285 bl
.append(C(0)) # yes, add a zero
286 expanded_index
+= 1 # skip the extra point. NOT in the output
287 ea
.append(expanded_a
[expanded_index
])
288 eb
.append(expanded_b
[expanded_index
])
289 eo
.append(expanded_o
[expanded_index
])
292 ol
.append(self
.output
[i
])
295 # combine above using Cat
296 m
.d
.comb
+= Cat(*ea
).eq(Cat(*al
))
297 m
.d
.comb
+= Cat(*eb
).eq(Cat(*bl
))
298 m
.d
.comb
+= Cat(*ol
).eq(Cat(*eo
))
300 # use only one addition to take advantage of look-ahead carry and
301 # special hardware on FPGAs
302 m
.d
.comb
+= expanded_o
.eq(expanded_a
+ expanded_b
)
306 FULL_ADDER_INPUT_COUNT
= 3
310 def __init__(self
, part_pts
, n_inputs
, output_width
, n_parts
):
311 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}", reset_less
=True)
312 for i
in range(n_parts
)]
313 self
.terms
= [Signal(output_width
, name
=f
"terms_{i}",
315 for i
in range(n_inputs
)]
316 self
.part_pts
= part_pts
.like()
318 def eq_from(self
, part_pts
, inputs
, part_ops
):
319 return [self
.part_pts
.eq(part_pts
)] + \
320 [self
.terms
[i
].eq(inputs
[i
])
321 for i
in range(len(self
.terms
))] + \
322 [self
.part_ops
[i
].eq(part_ops
[i
])
323 for i
in range(len(self
.part_ops
))]
326 return self
.eq_from(rhs
.part_pts
, rhs
.terms
, rhs
.part_ops
)
class FinalReduceData:
    """Signal bundle holding a reduced sum plus its partition information.

    :attribute part_ops: list of ``n_parts`` 2-bit operation signals
    :attribute output: the ``output_width``-bit sum
    :attribute part_pts: ``Signal`` copies of the supplied partition points
    """

    def __init__(self, part_pts, output_width, n_parts):
        """Create the bundle.

        :param part_pts: partition points to mirror; must provide ``like()``
        :param output_width: bit width of ``output``
        :param n_parts: number of entries to create in ``part_ops``
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Return a list of assignments loading this bundle from its parts.

        :param part_pts: source for ``self.part_pts``
        :param output: source for ``self.output``
        :param part_ops: indexable source with at least
            ``len(self.part_ops)`` entries
        """
        return [self.part_pts.eq(part_pts)] + \
               [self.output.eq(output)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        """Return the assignments copying every field of ``rhs`` into self.

        :param rhs: object exposing ``part_pts``, ``output`` and ``part_ops``
        """
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
347 class FinalAdd(PipeModBase
):
348 """ Final stage of add reduce
351 def __init__(self
, pspec
, lidx
, n_inputs
, partition_points
,
354 self
.partition_step
= partition_step
355 self
.output_width
= pspec
.width
* 2
356 self
.n_inputs
= n_inputs
357 self
.n_parts
= pspec
.n_parts
358 self
.partition_points
= PartitionPoints(partition_points
)
359 if not self
.partition_points
.fits_in_width(self
.output_width
):
360 raise ValueError("partition_points doesn't fit in output_width")
362 super().__init
__(pspec
, "finaladd")
365 return AddReduceData(self
.partition_points
, self
.n_inputs
,
366 self
.output_width
, self
.n_parts
)
369 return FinalReduceData(self
.partition_points
,
370 self
.output_width
, self
.n_parts
)
372 def elaborate(self
, platform
):
373 """Elaborate this module."""
376 output_width
= self
.output_width
377 output
= Signal(output_width
, reset_less
=True)
378 if self
.n_inputs
== 0:
379 # use 0 as the default output value
380 m
.d
.comb
+= output
.eq(0)
381 elif self
.n_inputs
== 1:
382 # handle single input
383 m
.d
.comb
+= output
.eq(self
.i
.terms
[0])
385 # base case for adding 2 inputs
386 assert self
.n_inputs
== 2
387 adder
= PartitionedAdder(output_width
,
388 self
.i
.part_pts
, self
.partition_step
)
389 m
.submodules
.final_adder
= adder
390 m
.d
.comb
+= adder
.a
.eq(self
.i
.terms
[0])
391 m
.d
.comb
+= adder
.b
.eq(self
.i
.terms
[1])
392 m
.d
.comb
+= output
.eq(adder
.output
)
395 m
.d
.comb
+= self
.o
.eq_from(self
.i
.part_pts
, output
,
401 class AddReduceSingle(PipeModBase
):
402 """Add list of numbers together.
404 :attribute inputs: input ``Signal``s to be summed. Modification not
405 supported, except for by ``Signal.eq``.
406 :attribute register_levels: List of nesting levels that should have
408 :attribute output: output sum.
409 :attribute partition_points: the input partition points. Modification not
410 supported, except for by ``Signal.eq``.
413 def __init__(self
, pspec
, lidx
, n_inputs
, partition_points
,
415 """Create an ``AddReduce``.
417 :param inputs: input ``Signal``s to be summed.
418 :param output_width: bit-width of ``output``.
419 :param partition_points: the input partition points.
422 self
.partition_step
= partition_step
423 self
.n_inputs
= n_inputs
424 self
.n_parts
= pspec
.n_parts
425 self
.output_width
= pspec
.width
* 2
426 self
.partition_points
= PartitionPoints(partition_points
)
427 if not self
.partition_points
.fits_in_width(self
.output_width
):
428 raise ValueError("partition_points doesn't fit in output_width")
430 self
.groups
= AddReduceSingle
.full_adder_groups(n_inputs
)
431 self
.n_terms
= AddReduceSingle
.calc_n_inputs(n_inputs
, self
.groups
)
433 super().__init
__(pspec
, "addreduce_%d" % lidx
)
436 return AddReduceData(self
.partition_points
, self
.n_inputs
,
437 self
.output_width
, self
.n_parts
)
440 return AddReduceData(self
.partition_points
, self
.n_terms
,
441 self
.output_width
, self
.n_parts
)
def calc_n_inputs(n_inputs, groups):
    """Compute how many terms remain after one reduction level.

    Called through the class without an instance (there is no ``self``).

    :param n_inputs: number of terms entering the level
    :param groups: the full-adder groups for that level; only its length
        is used
    :returns: two terms (sum and carry) per full-adder group, plus the
        one or two leftover inputs that are passed through unchanged
    """
    retval = len(groups)*2
    if n_inputs % FULL_ADDER_INPUT_COUNT == 1:
        # a single leftover input is forwarded as-is
        retval += 1
    elif n_inputs % FULL_ADDER_INPUT_COUNT == 2:
        # two leftover inputs are both forwarded as-is
        retval += 2
    else:
        assert n_inputs % FULL_ADDER_INPUT_COUNT == 0
    return retval
455 def get_max_level(input_count
):
456 """Get the maximum level.
458 All ``register_levels`` must be less than or equal to the maximum
463 groups
= AddReduceSingle
.full_adder_groups(input_count
)
466 input_count
%= FULL_ADDER_INPUT_COUNT
467 input_count
+= 2 * len(groups
)
def full_adder_groups(input_count):
    """Get ``inputs`` indices for which a full adder should be built.

    Called through the class without an instance (there is no ``self``).

    :param input_count: number of terms available at this level
    :returns: a ``range`` of starting indices, one per complete group of
        ``FULL_ADDER_INPUT_COUNT`` consecutive inputs; it is empty when
        fewer than ``FULL_ADDER_INPUT_COUNT`` inputs are available
    """
    # the stop value ensures a group is only started where a full set of
    # FULL_ADDER_INPUT_COUNT inputs still fits
    return range(0,
                 input_count - FULL_ADDER_INPUT_COUNT + 1,
                 FULL_ADDER_INPUT_COUNT)
477 def create_next_terms(self
):
478 """ create next intermediate terms, for linking up in elaborate, below
483 # create full adders for this recursive level.
484 # this shrinks N terms to 2 * (N // 3) plus the remainder
485 for i
in self
.groups
:
486 adder_i
= MaskedFullAdder(self
.output_width
)
487 adders
.append((i
, adder_i
))
488 # add both the sum and the masked-carry to the next level.
489 # 3 inputs have now been reduced to 2...
490 terms
.append(adder_i
.sum)
491 terms
.append(adder_i
.mcarry
)
492 # handle the remaining inputs.
493 if self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
494 terms
.append(self
.i
.terms
[-1])
495 elif self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
496 # Just pass the terms to the next layer, since we wouldn't gain
497 # anything by using a half adder since there would still be 2 terms
498 # and just passing the terms to the next layer saves gates.
499 terms
.append(self
.i
.terms
[-2])
500 terms
.append(self
.i
.terms
[-1])
502 assert self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
506 def elaborate(self
, platform
):
507 """Elaborate this module."""
510 terms
, adders
= self
.create_next_terms()
512 # copy the intermediate terms to the output
513 for i
, value
in enumerate(terms
):
514 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
516 # copy reg part points and part ops to output
517 m
.d
.comb
+= self
.o
.part_pts
.eq(self
.i
.part_pts
)
518 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
519 for i
in range(len(self
.i
.part_ops
))]
521 # set up the partition mask (for the adders)
522 part_mask
= Signal(self
.output_width
, reset_less
=True)
524 # get partition points as a mask
525 mask
= self
.i
.part_pts
.as_mask(self
.output_width
,
526 mul
=self
.partition_step
)
527 m
.d
.comb
+= part_mask
.eq(mask
)
529 # add and link the intermediate term modules
530 for i
, (iidx
, adder_i
) in enumerate(adders
):
531 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
533 m
.d
.comb
+= adder_i
.in0
.eq(self
.i
.terms
[iidx
])
534 m
.d
.comb
+= adder_i
.in1
.eq(self
.i
.terms
[iidx
+ 1])
535 m
.d
.comb
+= adder_i
.in2
.eq(self
.i
.terms
[iidx
+ 2])
536 m
.d
.comb
+= adder_i
.mask
.eq(part_mask
)
541 class AddReduceInternal
:
542 """Iteratively Add list of numbers together.
544 :attribute inputs: input ``Signal``s to be summed. Modification not
545 supported, except for by ``Signal.eq``.
546 :attribute register_levels: List of nesting levels that should have
548 :attribute output: output sum.
549 :attribute partition_points: the input partition points. Modification not
550 supported, except for by ``Signal.eq``.
553 def __init__(self
, pspec
, n_inputs
, part_pts
, partition_step
=1):
554 """Create an ``AddReduce``.
556 :param inputs: input ``Signal``s to be summed.
557 :param output_width: bit-width of ``output``.
558 :param partition_points: the input partition points.
561 self
.n_inputs
= n_inputs
562 self
.output_width
= pspec
.width
* 2
563 self
.partition_points
= part_pts
564 self
.partition_step
= partition_step
568 def create_levels(self
):
569 """creates reduction levels"""
572 partition_points
= self
.partition_points
575 groups
= AddReduceSingle
.full_adder_groups(ilen
)
579 next_level
= AddReduceSingle(self
.pspec
, lidx
, ilen
,
582 mods
.append(next_level
)
583 partition_points
= next_level
.i
.part_pts
584 ilen
= len(next_level
.o
.terms
)
587 next_level
= FinalAdd(self
.pspec
, lidx
, ilen
,
588 partition_points
, self
.partition_step
)
589 mods
.append(next_level
)
594 class AddReduce(AddReduceInternal
, Elaboratable
):
595 """Recursively Add list of numbers together.
597 :attribute inputs: input ``Signal``s to be summed. Modification not
598 supported, except for by ``Signal.eq``.
599 :attribute register_levels: List of nesting levels that should have
601 :attribute output: output sum.
602 :attribute partition_points: the input partition points. Modification not
603 supported, except for by ``Signal.eq``.
606 def __init__(self
, inputs
, output_width
, register_levels
, part_pts
,
607 part_ops
, partition_step
=1):
608 """Create an ``AddReduce``.
610 :param inputs: input ``Signal``s to be summed.
611 :param output_width: bit-width of ``output``.
612 :param register_levels: List of nesting levels that should have
614 :param partition_points: the input partition points.
616 self
._inputs
= inputs
617 self
._part
_pts
= part_pts
618 self
._part
_ops
= part_ops
619 n_parts
= len(part_ops
)
620 self
.i
= AddReduceData(part_pts
, len(inputs
),
621 output_width
, n_parts
)
622 AddReduceInternal
.__init
__(self
, pspec
, n_inputs
, part_pts
,
624 self
.o
= FinalReduceData(part_pts
, output_width
, n_parts
)
625 self
.register_levels
= register_levels
628 def get_max_level(input_count
):
629 return AddReduceSingle
.get_max_level(input_count
)
632 def next_register_levels(register_levels
):
633 """``Iterable`` of ``register_levels`` for next recursive level."""
634 for level
in register_levels
:
638 def elaborate(self
, platform
):
639 """Elaborate this module."""
642 m
.d
.comb
+= self
.i
.eq_from(self
._part
_pts
, self
._inputs
, self
._part
_ops
)
644 for i
, next_level
in enumerate(self
.levels
):
645 setattr(m
.submodules
, "next_level%d" % i
, next_level
)
648 for idx
in range(len(self
.levels
)):
649 mcur
= self
.levels
[idx
]
650 if idx
in self
.register_levels
:
651 m
.d
.sync
+= mcur
.i
.eq(i
)
653 m
.d
.comb
+= mcur
.i
.eq(i
)
654 i
= mcur
.o
# for next loop
656 # output comes from last module
657 m
.d
.comb
+= self
.o
.eq(i
)
663 OP_MUL_SIGNED_HIGH
= 1
664 OP_MUL_SIGNED_UNSIGNED_HIGH
= 2 # a is signed, b is unsigned
665 OP_MUL_UNSIGNED_HIGH
= 3
668 def get_term(value
, shift
=0, enabled
=None):
669 if enabled
is not None:
670 value
= Mux(enabled
, value
, 0)
672 value
= Cat(Repl(C(0, 1), shift
), value
)
678 class ProductTerm(Elaboratable
):
679 """ this class creates a single product term (a[..]*b[..]).
680 it has a design flaw in that is the *output* that is selected,
681 where the multiplication(s) are combinatorially generated
685 def __init__(self
, width
, twidth
, pbwid
, a_index
, b_index
):
686 self
.a_index
= a_index
687 self
.b_index
= b_index
688 shift
= 8 * (self
.a_index
+ self
.b_index
)
694 self
.ti
= Signal(self
.width
, reset_less
=True)
695 self
.term
= Signal(twidth
, reset_less
=True)
696 self
.a
= Signal(twidth
//2, reset_less
=True)
697 self
.b
= Signal(twidth
//2, reset_less
=True)
698 self
.pb_en
= Signal(pbwid
, reset_less
=True)
701 min_index
= min(self
.a_index
, self
.b_index
)
702 max_index
= max(self
.a_index
, self
.b_index
)
703 for i
in range(min_index
, max_index
):
704 tl
.append(self
.pb_en
[i
])
705 name
= "te_%d_%d" % (self
.a_index
, self
.b_index
)
707 term_enabled
= Signal(name
=name
, reset_less
=True)
710 self
.enabled
= term_enabled
711 self
.term
.name
= "term_%d_%d" % (a_index
, b_index
) # rename
713 def elaborate(self
, platform
):
716 if self
.enabled
is not None:
717 m
.d
.comb
+= self
.enabled
.eq(~
(Cat(*self
.tl
).bool()))
719 bsa
= Signal(self
.width
, reset_less
=True)
720 bsb
= Signal(self
.width
, reset_less
=True)
721 a_index
, b_index
= self
.a_index
, self
.b_index
723 m
.d
.comb
+= bsa
.eq(self
.a
.bit_select(a_index
* pwidth
, pwidth
))
724 m
.d
.comb
+= bsb
.eq(self
.b
.bit_select(b_index
* pwidth
, pwidth
))
725 m
.d
.comb
+= self
.ti
.eq(bsa
* bsb
)
726 m
.d
.comb
+= self
.term
.eq(get_term(self
.ti
, self
.shift
, self
.enabled
))
728 #TODO: sort out width issues, get inputs a/b switched on/off.
729 #data going into Muxes is 1/2 the required width
733 bsa = Signal(self.twidth//2, reset_less=True)
734 bsb = Signal(self.twidth//2, reset_less=True)
735 asel = Signal(width, reset_less=True)
736 bsel = Signal(width, reset_less=True)
737 a_index, b_index = self.a_index, self.b_index
738 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
739 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
740 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
741 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
742 m.d.comb += self.ti.eq(bsa * bsb)
743 m.d.comb += self.term.eq(self.ti)
749 class ProductTerms(Elaboratable
):
750 """ creates a bank of product terms. also performs the actual bit-selection
751 this class is to be wrapped with a for-loop on the "a" operand.
752 it creates a second-level for-loop on the "b" operand.
754 def __init__(self
, width
, twidth
, pbwid
, a_index
, blen
):
755 self
.a_index
= a_index
760 self
.a
= Signal(twidth
//2, reset_less
=True)
761 self
.b
= Signal(twidth
//2, reset_less
=True)
762 self
.pb_en
= Signal(pbwid
, reset_less
=True)
763 self
.terms
= [Signal(twidth
, name
="term%d"%i, reset_less
=True) \
764 for i
in range(blen
)]
766 def elaborate(self
, platform
):
770 for b_index
in range(self
.blen
):
771 t
= ProductTerm(self
.pwidth
, self
.twidth
, self
.pbwid
,
772 self
.a_index
, b_index
)
773 setattr(m
.submodules
, "term_%d" % b_index
, t
)
775 m
.d
.comb
+= t
.a
.eq(self
.a
)
776 m
.d
.comb
+= t
.b
.eq(self
.b
)
777 m
.d
.comb
+= t
.pb_en
.eq(self
.pb_en
)
779 m
.d
.comb
+= self
.terms
[b_index
].eq(t
.term
)
class LSBNegTerm(Elaboratable):
    """Produce the two extra terms used to negate one operand partition.

    :attribute part: 1-bit input, ANDed into the enable
    :attribute signed: 1-bit input, ANDed into the enable
    :attribute msb: 1-bit input, ANDed into the enable
    :attribute op: the ``bit_width``-bit operand
    :attribute nt: ``2*bit_width``-bit output: inverted operand in the
        upper half when enabled, otherwise zero
    :attribute nl: ``2*bit_width``-bit output: the "+1" term, a single
        bit at position ``bit_width`` equal to the enable
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm``.

        :param bit_width: bit width of the operand ``op``
        """
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m
818 class Parts(Elaboratable
):
820 def __init__(self
, pbwid
, part_pts
, n_parts
):
823 self
.part_pts
= PartitionPoints
.like(part_pts
)
825 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
826 for i
in range(n_parts
)]
828 def elaborate(self
, platform
):
831 part_pts
, parts
= self
.part_pts
, self
.parts
832 # collect part-bytes (double factor because the input is extended)
833 pbs
= Signal(self
.pbwid
, reset_less
=True)
835 for i
in range(self
.pbwid
):
836 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
837 m
.d
.comb
+= pb
.eq(part_pts
.part_byte(i
))
839 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
841 # negated-temporary copy of partition bits
842 npbs
= Signal
.like(pbs
, reset_less
=True)
843 m
.d
.comb
+= npbs
.eq(~pbs
)
844 byte_count
= 8 // len(parts
)
845 for i
in range(len(parts
)):
847 pbl
.append(npbs
[i
* byte_count
- 1])
848 for j
in range(i
* byte_count
, (i
+ 1) * byte_count
- 1):
850 pbl
.append(npbs
[(i
+ 1) * byte_count
- 1])
851 value
= Signal(len(pbl
), name
="value_%d" % i
, reset_less
=True)
852 m
.d
.comb
+= value
.eq(Cat(*pbl
))
853 m
.d
.comb
+= parts
[i
].eq(~
(value
).bool())
858 class Part(Elaboratable
):
859 """ a key class which, depending on the partitioning, will determine
860 what action to take when parts of the output are signed or unsigned.
862 this requires 2 pieces of data *per operand, per partition*:
863 whether the MSB is HI/LO (per partition!), and whether a signed
864 or unsigned operation has been *requested*.
866 once that is determined, signed is basically carried out
867 by splitting 2's complement into 1's complement plus one.
868 1's complement is just a bit-inversion.
870 the extra terms - as separate terms - are then thrown at the
871 AddReduce alongside the multiplication part-results.
873 def __init__(self
, part_pts
, width
, n_parts
, pbwid
):
876 self
.part_pts
= part_pts
879 self
.a
= Signal(64, reset_less
=True)
880 self
.b
= Signal(64, reset_less
=True)
881 self
.a_signed
= [Signal(name
=f
"a_signed_{i}", reset_less
=True)
883 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}", reset_less
=True)
885 self
.pbs
= Signal(pbwid
, reset_less
=True)
888 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
889 for i
in range(n_parts
)]
891 self
.not_a_term
= Signal(width
, reset_less
=True)
892 self
.neg_lsb_a_term
= Signal(width
, reset_less
=True)
893 self
.not_b_term
= Signal(width
, reset_less
=True)
894 self
.neg_lsb_b_term
= Signal(width
, reset_less
=True)
896 def elaborate(self
, platform
):
899 pbs
, parts
= self
.pbs
, self
.parts
900 part_pts
= self
.part_pts
901 m
.submodules
.p
= p
= Parts(self
.pbwid
, part_pts
, len(parts
))
902 m
.d
.comb
+= p
.part_pts
.eq(part_pts
)
905 byte_count
= 8 // len(parts
)
907 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= (
908 self
.not_a_term
, self
.neg_lsb_a_term
,
909 self
.not_b_term
, self
.neg_lsb_b_term
)
911 byte_width
= 8 // len(parts
) # byte width
912 bit_wid
= 8 * byte_width
# bit width
913 nat
, nbt
, nla
, nlb
= [], [], [], []
914 for i
in range(len(parts
)):
915 # work out bit-inverted and +1 term for a.
916 pa
= LSBNegTerm(bit_wid
)
917 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
918 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
919 m
.d
.comb
+= pa
.op
.eq(self
.a
.bit_select(bit_wid
* i
, bit_wid
))
920 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
921 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
925 # work out bit-inverted and +1 term for b
926 pb
= LSBNegTerm(bit_wid
)
927 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
928 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
929 m
.d
.comb
+= pb
.op
.eq(self
.b
.bit_select(bit_wid
* i
, bit_wid
))
930 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
931 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
935 # concatenate together and return all 4 results.
936 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
937 not_b_term
.eq(Cat(*nbt
)),
938 neg_lsb_a_term
.eq(Cat(*nla
)),
939 neg_lsb_b_term
.eq(Cat(*nlb
)),
945 class IntermediateOut(Elaboratable
):
946 """ selects the HI/LO part of the multiplication, for a given bit-width
947 the output is also reconstructed in its SIMD (partition) lanes.
949 def __init__(self
, width
, out_wid
, n_parts
):
951 self
.n_parts
= n_parts
952 self
.part_ops
= [Signal(2, name
="dpop%d" % i
, reset_less
=True)
954 self
.intermed
= Signal(out_wid
, reset_less
=True)
955 self
.output
= Signal(out_wid
//2, reset_less
=True)
957 def elaborate(self
, platform
):
963 for i
in range(self
.n_parts
):
964 op
= Signal(w
, reset_less
=True, name
="op%d_%d" % (w
, i
))
966 Mux(self
.part_ops
[sel
* i
] == OP_MUL_LOW
,
967 self
.intermed
.bit_select(i
* w
*2, w
),
968 self
.intermed
.bit_select(i
* w
*2 + w
, w
)))
970 m
.d
.comb
+= self
.output
.eq(Cat(*ol
))
975 class FinalOut(PipeModBase
):
976 """ selects the final output based on the partitioning.
978 each byte is selectable independently, i.e. it is possible
979 that some partitions requested 8-bit computation whilst others
980 requested 16 or 32 bit.
982 def __init__(self
, pspec
, part_pts
):
984 self
.part_pts
= part_pts
985 self
.output_width
= pspec
.width
* 2
986 self
.n_parts
= pspec
.n_parts
987 self
.out_wid
= pspec
.width
989 super().__init
__(pspec
, "finalout")
992 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
def elaborate(self, platform):
    """Build the byte-wise output selector.

    Four ``Parts`` submodules (8 parts, 4 parts, 2 parts, 1 part) are fed
    the incoming partition points.  Their ``parts`` flags are then used to
    pick, for each of the 8 output bytes, the matching byte of one of the
    four candidate results in ``self.i.outputs``.

    :param platform: the nmigen platform (unused).
    :returns: the populated ``Module``.
    """
    m = Module()

    part_pts = self.part_pts
    m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
    m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
    m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
    m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)

    out_part_pts = self.i.part_pts

    # temporaries: one select flag per 8-, 16- and 32-bit lane
    d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
    d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
    d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

    # local copies of the four candidate results.  These are kept as
    # separate assignments so that each Signal keeps its own name.
    i8 = Signal(self.out_wid, reset_less=True)
    i16 = Signal(self.out_wid, reset_less=True)
    i32 = Signal(self.out_wid, reset_less=True)
    i64 = Signal(self.out_wid, reset_less=True)

    for parts_mod in (p_8, p_16, p_32, p_64):
        m.d.comb += parts_mod.part_pts.eq(out_part_pts)

    # p_64 only has its partition points driven; its parts are not read
    for flags, parts_mod in ((d8, p_8), (d16, p_16), (d32, p_32)):
        for flag, part in zip(flags, parts_mod.parts):
            m.d.comb += flag.eq(part)

    m.d.comb += i8.eq(self.i.outputs[0])
    m.d.comb += i16.eq(self.i.outputs[1])
    m.d.comb += i32.eq(self.i.outputs[2])
    m.d.comb += i64.eq(self.i.outputs[3])

    ol = []
    for i in range(8):
        # select one of the outputs: d8 selects i8, d16 selects i16
        # d32 selects i32, and the default is i64.
        # d8 and d16 are ORed together in the first Mux
        # then the 2nd selects either i8 or i16.
        # if neither d8 nor d16 are set, d32 selects either i32 or i64.
        op = Signal(8, reset_less=True, name="op_%d" % i)
        m.d.comb += op.eq(
            Mux(d8[i] | d16[i // 2],
                Mux(d8[i], i8.bit_select(i * 8, 8),
                    i16.bit_select(i * 8, 8)),
                Mux(d32[i // 4], i32.bit_select(i * 8, 8),
                    i64.bit_select(i * 8, 8))))
        ol.append(op)

    # create outputs: byte 0 is the least significant byte of the result
    m.d.comb += self.o.output.eq(Cat(*ol))
    m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)

    return m
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree

    :attribute orin: list of four ``wid``-bit input ``Signal``s.
    :attribute orout: ``wid``-bit ``Signal`` holding the OR of all inputs.
    """

    def __init__(self, wid):
        """Create an ``OrMod``.

        :param wid: bit width of every input and of the output.
        """
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """OR the inputs pairwise, then OR the two partial results."""
        m = Module()

        # first tree level: two independent pairs
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        # second tree level: combine the pairs
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)

    :attribute part_ops: 2-bit input ``Signal`` carrying the operation code.
    :attribute a_signed: output, asserted for every operation except
        ``OP_MUL_UNSIGNED_HIGH``.
    :attribute b_signed: output, asserted for ``OP_MUL_LOW`` and
        ``OP_MUL_SIGNED_HIGH`` only.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Decode ``part_ops`` into the two signedness flags."""
        m = Module()

        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        bsig = ((self.part_ops == OP_MUL_LOW)
                | (self.part_ops == OP_MUL_SIGNED_HIGH))
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class IntermediateData:
    """Bundle of signals holding the four candidate results.

    :attribute part_ops: list of ``n_parts`` 2-bit operation ``Signal``s.
    :attribute part_pts: ``PartitionPoints`` created with ``part_pts.like()``.
    :attribute outputs: list of four ``output_width``-bit ``Signal``s.
    :attribute intermediate_output: ``output_width``-bit ``Signal``
        (needed for unit tests).
    """

    def __init__(self, part_pts, output_width, n_parts):
        """Create an ``IntermediateData``.

        :param part_pts: partition points used as the template for
            ``self.part_pts``.
        :param output_width: bit width of every output ``Signal``.
        :param n_parts: number of per-partition operation ``Signal``s.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.part_pts = part_pts.like()
        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                        for i in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        """Return the assignments that load this bundle from loose values.

        ``outputs`` and ``part_ops`` are indexed rather than zipped so that
        a sequence that is too short raises ``IndexError``.
        """
        return [self.part_pts.eq(part_pts)] + \
               [self.intermediate_output.eq(intermediate_output)] + \
               [sig.eq(outputs[i])
                for i, sig in enumerate(self.outputs)] + \
               [sig.eq(part_ops[i])
                for i, sig in enumerate(self.part_ops)]

    def eq(self, rhs):
        """Return the assignments that copy ``rhs`` into this bundle."""
        return self.eq_from(rhs.part_pts, rhs.outputs,
                            rhs.intermediate_output, rhs.part_ops)
class InputData:
    """Bundle of signals forming the multiplier's input.

    :attribute a: first operand.
    :attribute b: second operand.
    :attribute part_pts: ``PartitionPoints`` with one ``Signal`` at every
        multiple of 8 in 0 < i < 64.
    :attribute part_ops: list of eight 2-bit operation ``Signal``s.
    """

    def __init__(self):
        # NOTE(review): the operand declarations were reconstructed; the
        # 64-bit width is assumed from the 8..64 partition points - confirm.
        self.a = Signal(64)
        self.b = Signal(64)
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]

    def eq_from(self, part_pts, a, b, part_ops):
        """Return the assignments that load this bundle from loose values.

        ``part_ops`` is indexed rather than zipped so that a sequence that
        is too short raises ``IndexError``.
        """
        return [self.part_pts.eq(part_pts)] + \
               [self.a.eq(a), self.b.eq(b)] + \
               [sig.eq(part_ops[i])
                for i, sig in enumerate(self.part_ops)]

    def eq(self, rhs):
        """Return the assignments that copy ``rhs`` into this bundle."""
        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
class OutputData:
    """Bundle of signals forming the multiplier's output.

    :attribute intermediate_output: 128-bit ``Signal`` (needed for unit
        tests).
    :attribute output: 64-bit result ``Signal``.
    """

    def __init__(self):
        self.intermediate_output = Signal(128)  # needed for unit tests
        self.output = Signal(64)

    def eq(self, rhs):
        """Return the assignments that copy ``rhs`` into this bundle."""
        return [self.intermediate_output.eq(rhs.intermediate_output),
                self.output.eq(rhs.output)]
class AllTerms(PipeModBase):
    """Set of terms to be added together
    """

    def __init__(self, pspec, n_inputs):
        """Create an ``AllTerms``.

        :param pspec: pipeline specification; ``width`` and ``n_parts``
            are read from it.
        :param n_inputs: number of terms in the output data.
        """
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        super().__init__(pspec, "allterms")

    def ispec(self):
        """Return a new ``InputData`` describing this stage's input."""
        return InputData()

    def ospec(self):
        """Return a new ``AddReduceData`` describing this stage's output."""
        # reads self.i, so the input spec must already have been created
        return AddReduceData(self.i.part_pts, self.n_inputs,
                             self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Generate every term and copy them to the output.

        The terms are those of eight ``ProductTerms`` submodules followed
        by the ORed correction terms of the four ``Part`` submodules.
        """
        m = Module()

        eps = self.i.part_pts

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(eps.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # one signedness decoder per byte
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.i.part_ops[i])

        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.i.a)
            m.d.comb += mod.b.eq(self.i.b)
            for i, s in enumerate(signs):
                m.d.comb += mod.a_signed[i].eq(s.a_signed)
                m.d.comb += mod.b_signed[i].eq(s.b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.i.a)
            m.d.comb += t.b.eq(self.i.b)
            m.d.comb += t.pb_en.eq(pbs)

            terms.extend(t.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for values, or_mod in [(nat_l, nat_or),
                               (nbt_l, nbt_or),
                               (nla_l, nla_or),
                               (nlb_l, nlb_or)]:
            for i, value in enumerate(values):
                m.d.comb += or_mod.orin[i].eq(value)
            terms.append(or_mod.orout)

        # copy the intermediate terms to the output
        for i, value in enumerate(terms):
            m.d.comb += self.o.terms[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(eps)
        m.d.comb += [self.o.part_ops[i].eq(op)
                     for i, op in enumerate(self.i.part_ops)]

        return m
class Intermediates(PipeModBase):
    """ Intermediate output modules
    """

    def __init__(self, pspec, part_pts):
        """Create an ``Intermediates``.

        :param pspec: pipeline specification; ``width`` and ``n_parts``
            are read from it.
        :param part_pts: partition points passed to the input and output
            data constructors.
        """
        self.part_pts = part_pts
        self.output_width = pspec.width * 2
        self.n_parts = pspec.n_parts

        super().__init__(pspec, "intermediates")

    def ispec(self):
        """Return a new ``FinalReduceData`` describing this stage's input."""
        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        """Return a new ``IntermediateData`` describing this stage's output."""
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Produce one candidate result per partition width.

        Four ``IntermediateOut`` submodules (64, 32, 16 and 8 bit) each
        receive the reduced sum and the per-byte operations; their outputs
        are stored in ``self.o.outputs[3]`` down to ``self.o.outputs[0]``.
        """
        m = Module()

        out_part_ops = self.i.part_ops
        out_part_pts = self.i.part_pts

        # (submodule name, partition width, partition count, output index),
        # widest first so the submodules are created in the same order as
        # before
        variants = (("io64", 64, 1, 3),
                    ("io32", 32, 2, 2),
                    ("io16", 16, 4, 1),
                    ("io8", 8, 8, 0))
        for name, width, count, out_idx in variants:
            io = IntermediateOut(width, 128, count)
            setattr(m.submodules, name, io)
            m.d.comb += io.intermed.eq(self.i.output)
            for i in range(8):
                m.d.comb += io.part_ops[i].eq(out_part_ops[i])
            m.d.comb += self.o.outputs[out_idx].eq(io.output)

        # pass the operations, partition points and raw sum straight through
        for i in range(8):
            m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.part_pts.eq(out_part_pts)
        m.d.comb += self.o.intermediate_output.eq(self.i.output)

        return m
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    XXX NOTE: this class is intended for unit test purposes ONLY.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        self.id_wid = 0  # num_bits(num_rows)
        self.op_wid = 0
        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
        self.pspec.n_parts = 8

        # parameter(s)
        self.register_levels = list(register_levels)

        self.i = self.ispec()
        self.o = self.ospec()

        # inputs
        self.part_pts = self.i.part_pts
        self.part_ops = self.i.part_ops
        self.a = self.i.a
        self.b = self.i.b

        # output
        self.intermediate_output = self.o.intermediate_output
        self.output = self.o.output

    def ispec(self):
        """Return a new ``InputData`` describing the multiplier's input."""
        return InputData()

    def ospec(self):
        """Return a new ``OutputData`` describing the multiplier's output."""
        return OutputData()

    def elaborate(self, platform):
        """Chain term generation, reduction and output selection.

        A reduction level whose index is listed in ``register_levels`` has
        its result assigned in the ``sync`` domain; every other level is
        combinatorial.
        """
        m = Module()

        part_pts = self.part_pts

        # NOTE(review): 64 + 4 is presumably eight ProductTerms of eight
        # terms each plus the four OrMod outputs AllTerms appends - confirm.
        n_inputs = 64 + 4
        t = AllTerms(self.pspec, n_inputs)
        t.setup(m, self.i)

        at = AddReduceInternal(self.pspec, n_inputs, part_pts,
                               partition_step=2)

        i = t.o
        for idx, mcur in enumerate(at.levels):
            mcur.setup(m, i)
            o = mcur.ospec()
            if idx in self.register_levels:
                m.d.sync += o.eq(mcur.process(i))
            else:
                m.d.comb += o.eq(mcur.process(i))
            i = o  # for next loop

        interm = Intermediates(self.pspec, part_pts)
        interm.setup(m, i)
        o = interm.process(interm.i)

        # final output
        finalout = FinalOut(self.pspec, part_pts)
        finalout.setup(m, o)
        m.d.comb += self.o.eq(finalout.process(o))

        return m
if __name__ == "__main__":
    # hand the multiplier and all of its ports to the nmigen command line
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m.intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])