# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
"""Integer Multiplication."""
5 from nmigen
import Signal
, Module
, Value
, Elaboratable
, Cat
, C
, Mux
, Repl
6 from nmigen
.hdl
.ast
import Assign
7 from abc
import ABCMeta
, abstractmethod
8 from nmigen
.cli
import main
9 from functools
import reduce
10 from operator
import or_
11 from ieee754
.pipeline
import PipelineSpec
12 from nmutil
.pipemodbase
import PipeModBase
14 from ieee754
.part_mul_add
.partpoints
import PartitionPoints
15 from ieee754
.part_mul_add
.adder
import PartitionedAdder
, MaskedFullAdder
# Number of terms consumed by one full adder. The reduction code groups its
# inputs in chunks of this size and tests ``n_inputs % FULL_ADDER_INPUT_COUNT``
# to find the left-over terms.
FULL_ADDER_INPUT_COUNT = 3
def __init__(self, part_pts, n_inputs, output_width, n_parts):
    """Create the signals of an add-reduce record.

    :param part_pts: partition points; a copy is taken via ``.like()``.
    :param n_inputs: number of ``terms`` signals to create.
    :param output_width: bit-width of every ``terms`` signal.
    :param n_parts: number of 2-bit ``part_ops`` signals to create.
    """
    self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                     for i in range(n_parts)]
    # NOTE(review): the source line that closes this Signal(...) call was
    # lost; ``reset_less=True`` is assumed by analogy with ``part_ops``
    # above -- confirm against the upstream file.
    self.terms = [Signal(output_width, name=f"terms_{i}",
                         reset_less=True)
                  for i in range(n_inputs)]
    self.part_pts = part_pts.like()
def eq_from(self, part_pts, inputs, part_ops):
    """Build the assignments that load this record from separate values.

    :param part_pts: value assigned to ``self.part_pts``.
    :param inputs: indexable; ``inputs[i]`` is assigned to ``self.terms[i]``.
    :param part_ops: indexable; ``part_ops[i]`` is assigned to
        ``self.part_ops[i]``.
    :returns: list of the ``.eq()`` results, ordered part_pts, terms,
        part_ops. The lengths are taken from ``self``, so extra entries in
        ``inputs`` / ``part_ops`` are ignored.
    """
    return [self.part_pts.eq(part_pts)] + \
           [term.eq(inputs[i]) for i, term in enumerate(self.terms)] + \
           [op.eq(part_ops[i]) for i, op in enumerate(self.part_ops)]
def eq(self, rhs):
    """Build the assignments that copy another record into this one.

    NOTE(review): the ``def`` line was lost from the source; the name ``eq``
    and the single ``rhs`` parameter are taken from the body and from the
    ``.i.eq(i)`` call sites -- confirm against the upstream file.

    :param rhs: object providing ``part_pts``, ``terms`` and ``part_ops``.
    :returns: whatever ``self.eq_from`` returns for those three values.
    """
    return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
class FinalReduceData:
    """Bundle of signals: per-partition ops, one output value, partition points.

    :attribute part_ops: list of ``n_parts`` 2-bit signals.
    :attribute output: a single ``output_width``-bit signal.
    :attribute part_pts: copy (``part_pts.like()``) of the partition points.
    """

    def __init__(self, part_pts, output_width, n_parts):
        """Create the signals.

        :param part_pts: partition points; a copy is taken via ``.like()``.
        :param output_width: bit-width of ``output``.
        :param n_parts: number of ``part_ops`` signals to create.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Build the assignments that load this record from separate values.

        :returns: list of the ``.eq()`` results, ordered part_pts, output,
            part_ops (one per entry of ``self.part_ops``).
        """
        return [self.part_pts.eq(part_pts)] + \
               [self.output.eq(output)] + \
               [op.eq(part_ops[i]) for i, op in enumerate(self.part_ops)]

    def eq(self, rhs):
        """Build the assignments that copy ``rhs`` into this record.

        NOTE(review): this ``def`` line was lost from the source; the name
        and signature are taken from the body and from ``self.o.eq(...)``
        call sites -- confirm against the upstream file.
        """
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
60 class FinalAdd(PipeModBase
):
61 """ Final stage of add reduce
64 def __init__(self
, pspec
, lidx
, n_inputs
, partition_points
,
67 self
.partition_step
= partition_step
68 self
.output_width
= pspec
.width
* 2
69 self
.n_inputs
= n_inputs
70 self
.n_parts
= pspec
.n_parts
71 self
.partition_points
= PartitionPoints(partition_points
)
72 if not self
.partition_points
.fits_in_width(self
.output_width
):
73 raise ValueError("partition_points doesn't fit in output_width")
75 super().__init
__(pspec
, "finaladd")
78 return AddReduceData(self
.partition_points
, self
.n_inputs
,
79 self
.output_width
, self
.n_parts
)
82 return FinalReduceData(self
.partition_points
,
83 self
.output_width
, self
.n_parts
)
85 def elaborate(self
, platform
):
86 """Elaborate this module."""
89 output_width
= self
.output_width
90 output
= Signal(output_width
, reset_less
=True)
91 if self
.n_inputs
== 0:
92 # use 0 as the default output value
93 m
.d
.comb
+= output
.eq(0)
94 elif self
.n_inputs
== 1:
96 m
.d
.comb
+= output
.eq(self
.i
.terms
[0])
98 # base case for adding 2 inputs
99 assert self
.n_inputs
== 2
100 adder
= PartitionedAdder(output_width
,
101 self
.i
.part_pts
, self
.partition_step
)
102 m
.submodules
.final_adder
= adder
103 m
.d
.comb
+= adder
.a
.eq(self
.i
.terms
[0])
104 m
.d
.comb
+= adder
.b
.eq(self
.i
.terms
[1])
105 m
.d
.comb
+= output
.eq(adder
.output
)
108 m
.d
.comb
+= self
.o
.eq_from(self
.i
.part_pts
, output
,
114 class AddReduceSingle(PipeModBase
):
115 """Add list of numbers together.
117 :attribute inputs: input ``Signal``s to be summed. Modification not
118 supported, except for by ``Signal.eq``.
119 :attribute register_levels: List of nesting levels that should have
121 :attribute output: output sum.
122 :attribute partition_points: the input partition points. Modification not
123 supported, except for by ``Signal.eq``.
126 def __init__(self
, pspec
, lidx
, n_inputs
, partition_points
,
128 """Create an ``AddReduce``.
130 :param inputs: input ``Signal``s to be summed.
131 :param output_width: bit-width of ``output``.
132 :param partition_points: the input partition points.
135 self
.partition_step
= partition_step
136 self
.n_inputs
= n_inputs
137 self
.n_parts
= pspec
.n_parts
138 self
.output_width
= pspec
.width
* 2
139 self
.partition_points
= PartitionPoints(partition_points
)
140 if not self
.partition_points
.fits_in_width(self
.output_width
):
141 raise ValueError("partition_points doesn't fit in output_width")
143 self
.groups
= AddReduceSingle
.full_adder_groups(n_inputs
)
144 self
.n_terms
= AddReduceSingle
.calc_n_inputs(n_inputs
, self
.groups
)
146 super().__init
__(pspec
, "addreduce_%d" % lidx
)
149 return AddReduceData(self
.partition_points
, self
.n_inputs
,
150 self
.output_width
, self
.n_parts
)
153 return AddReduceData(self
.partition_points
, self
.n_terms
,
154 self
.output_width
, self
.n_parts
)
157 def calc_n_inputs(n_inputs
, groups
):
158 retval
= len(groups
)*2
159 if n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
161 elif n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
164 assert n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
168 def get_max_level(input_count
):
169 """Get the maximum level.
171 All ``register_levels`` must be less than or equal to the maximum
176 groups
= AddReduceSingle
.full_adder_groups(input_count
)
179 input_count
%= FULL_ADDER_INPUT_COUNT
180 input_count
+= 2 * len(groups
)
184 def full_adder_groups(input_count
):
185 """Get ``inputs`` indices for which a full adder should be built."""
187 input_count
- FULL_ADDER_INPUT_COUNT
+ 1,
188 FULL_ADDER_INPUT_COUNT
)
190 def create_next_terms(self
):
191 """ create next intermediate terms, for linking up in elaborate, below
196 # create full adders for this recursive level.
197 # this shrinks N terms to 2 * (N // 3) plus the remainder
198 for i
in self
.groups
:
199 adder_i
= MaskedFullAdder(self
.output_width
)
200 adders
.append((i
, adder_i
))
201 # add both the sum and the masked-carry to the next level.
202 # 3 inputs have now been reduced to 2...
203 terms
.append(adder_i
.sum)
204 terms
.append(adder_i
.mcarry
)
205 # handle the remaining inputs.
206 if self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 1:
207 terms
.append(self
.i
.terms
[-1])
208 elif self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 2:
209 # Just pass the terms to the next layer, since we wouldn't gain
210 # anything by using a half adder since there would still be 2 terms
211 # and just passing the terms to the next layer saves gates.
212 terms
.append(self
.i
.terms
[-2])
213 terms
.append(self
.i
.terms
[-1])
215 assert self
.n_inputs
% FULL_ADDER_INPUT_COUNT
== 0
219 def elaborate(self
, platform
):
220 """Elaborate this module."""
223 terms
, adders
= self
.create_next_terms()
225 # copy the intermediate terms to the output
226 for i
, value
in enumerate(terms
):
227 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
229 # copy reg part points and part ops to output
230 m
.d
.comb
+= self
.o
.part_pts
.eq(self
.i
.part_pts
)
231 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
232 for i
in range(len(self
.i
.part_ops
))]
234 # set up the partition mask (for the adders)
235 part_mask
= Signal(self
.output_width
, reset_less
=True)
237 # get partition points as a mask
238 mask
= self
.i
.part_pts
.as_mask(self
.output_width
,
239 mul
=self
.partition_step
)
240 m
.d
.comb
+= part_mask
.eq(mask
)
242 # add and link the intermediate term modules
243 for i
, (iidx
, adder_i
) in enumerate(adders
):
244 setattr(m
.submodules
, f
"adder_{i}", adder_i
)
246 m
.d
.comb
+= adder_i
.in0
.eq(self
.i
.terms
[iidx
])
247 m
.d
.comb
+= adder_i
.in1
.eq(self
.i
.terms
[iidx
+ 1])
248 m
.d
.comb
+= adder_i
.in2
.eq(self
.i
.terms
[iidx
+ 2])
249 m
.d
.comb
+= adder_i
.mask
.eq(part_mask
)
254 class AddReduceInternal
:
255 """Iteratively Add list of numbers together.
257 :attribute inputs: input ``Signal``s to be summed. Modification not
258 supported, except for by ``Signal.eq``.
259 :attribute register_levels: List of nesting levels that should have
261 :attribute output: output sum.
262 :attribute partition_points: the input partition points. Modification not
263 supported, except for by ``Signal.eq``.
266 def __init__(self
, pspec
, n_inputs
, part_pts
, partition_step
=1):
267 """Create an ``AddReduce``.
269 :param inputs: input ``Signal``s to be summed.
270 :param output_width: bit-width of ``output``.
271 :param partition_points: the input partition points.
274 self
.n_inputs
= n_inputs
275 self
.output_width
= pspec
.width
* 2
276 self
.partition_points
= part_pts
277 self
.partition_step
= partition_step
281 def create_levels(self
):
282 """creates reduction levels"""
285 partition_points
= self
.partition_points
288 groups
= AddReduceSingle
.full_adder_groups(ilen
)
292 next_level
= AddReduceSingle(self
.pspec
, lidx
, ilen
,
295 mods
.append(next_level
)
296 partition_points
= next_level
.i
.part_pts
297 ilen
= len(next_level
.o
.terms
)
300 next_level
= FinalAdd(self
.pspec
, lidx
, ilen
,
301 partition_points
, self
.partition_step
)
302 mods
.append(next_level
)
307 class AddReduce(AddReduceInternal
, Elaboratable
):
308 """Recursively Add list of numbers together.
310 :attribute inputs: input ``Signal``s to be summed. Modification not
311 supported, except for by ``Signal.eq``.
312 :attribute register_levels: List of nesting levels that should have
314 :attribute output: output sum.
315 :attribute partition_points: the input partition points. Modification not
316 supported, except for by ``Signal.eq``.
319 def __init__(self
, inputs
, output_width
, register_levels
, part_pts
,
320 part_ops
, partition_step
=1):
321 """Create an ``AddReduce``.
323 :param inputs: input ``Signal``s to be summed.
324 :param output_width: bit-width of ``output``.
325 :param register_levels: List of nesting levels that should have
327 :param partition_points: the input partition points.
329 self
._inputs
= inputs
330 self
._part
_pts
= part_pts
331 self
._part
_ops
= part_ops
332 n_parts
= len(part_ops
)
333 self
.i
= AddReduceData(part_pts
, len(inputs
),
334 output_width
, n_parts
)
335 AddReduceInternal
.__init
__(self
, pspec
, n_inputs
, part_pts
,
337 self
.o
= FinalReduceData(part_pts
, output_width
, n_parts
)
338 self
.register_levels
= register_levels
341 def get_max_level(input_count
):
342 return AddReduceSingle
.get_max_level(input_count
)
345 def next_register_levels(register_levels
):
346 """``Iterable`` of ``register_levels`` for next recursive level."""
347 for level
in register_levels
:
351 def elaborate(self
, platform
):
352 """Elaborate this module."""
355 m
.d
.comb
+= self
.i
.eq_from(self
._part
_pts
,
356 self
._inputs
, self
._part
_ops
)
358 for i
, next_level
in enumerate(self
.levels
):
359 setattr(m
.submodules
, "next_level%d" % i
, next_level
)
362 for idx
in range(len(self
.levels
)):
363 mcur
= self
.levels
[idx
]
364 if idx
in self
.register_levels
:
365 m
.d
.sync
+= mcur
.i
.eq(i
)
367 m
.d
.comb
+= mcur
.i
.eq(i
)
368 i
= mcur
.o
# for next loop
370 # output comes from last module
371 m
.d
.comb
+= self
.o
.eq(i
)
# Operation codes compared against the 2-bit ``part_ops`` signals.
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
382 def get_term(value
, shift
=0, enabled
=None):
383 if enabled
is not None:
384 value
= Mux(enabled
, value
, 0)
386 value
= Cat(Repl(C(0, 1), shift
), value
)
392 class ProductTerm(Elaboratable
):
393 """ this class creates a single product term (a[..]*b[..]).
394 it has a design flaw in that is the *output* that is selected,
395 where the multiplication(s) are combinatorially generated
399 def __init__(self
, width
, twidth
, pbwid
, a_index
, b_index
):
400 self
.a_index
= a_index
401 self
.b_index
= b_index
402 shift
= 8 * (self
.a_index
+ self
.b_index
)
408 self
.ti
= Signal(self
.width
, reset_less
=True)
409 self
.term
= Signal(twidth
, reset_less
=True)
410 self
.a
= Signal(twidth
//2, reset_less
=True)
411 self
.b
= Signal(twidth
//2, reset_less
=True)
412 self
.pb_en
= Signal(pbwid
, reset_less
=True)
415 min_index
= min(self
.a_index
, self
.b_index
)
416 max_index
= max(self
.a_index
, self
.b_index
)
417 for i
in range(min_index
, max_index
):
418 tl
.append(self
.pb_en
[i
])
419 name
= "te_%d_%d" % (self
.a_index
, self
.b_index
)
421 term_enabled
= Signal(name
=name
, reset_less
=True)
424 self
.enabled
= term_enabled
425 self
.term
.name
= "term_%d_%d" % (a_index
, b_index
) # rename
427 def elaborate(self
, platform
):
430 if self
.enabled
is not None:
431 m
.d
.comb
+= self
.enabled
.eq(~
(Cat(*self
.tl
).bool()))
433 bsa
= Signal(self
.width
, reset_less
=True)
434 bsb
= Signal(self
.width
, reset_less
=True)
435 a_index
, b_index
= self
.a_index
, self
.b_index
437 m
.d
.comb
+= bsa
.eq(self
.a
.bit_select(a_index
* pwidth
, pwidth
))
438 m
.d
.comb
+= bsb
.eq(self
.b
.bit_select(b_index
* pwidth
, pwidth
))
439 m
.d
.comb
+= self
.ti
.eq(bsa
* bsb
)
440 m
.d
.comb
+= self
.term
.eq(get_term(self
.ti
, self
.shift
, self
.enabled
))
442 #TODO: sort out width issues, get inputs a/b switched on/off.
443 #data going into Muxes is 1/2 the required width
447 bsa = Signal(self.twidth//2, reset_less=True)
448 bsb = Signal(self.twidth//2, reset_less=True)
449 asel = Signal(width, reset_less=True)
450 bsel = Signal(width, reset_less=True)
451 a_index, b_index = self.a_index, self.b_index
452 m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
453 m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
454 m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
455 m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
456 m.d.comb += self.ti.eq(bsa * bsb)
457 m.d.comb += self.term.eq(self.ti)
463 class ProductTerms(Elaboratable
):
464 """ creates a bank of product terms. also performs the actual bit-selection
465 this class is to be wrapped with a for-loop on the "a" operand.
466 it creates a second-level for-loop on the "b" operand.
469 def __init__(self
, width
, twidth
, pbwid
, a_index
, blen
):
470 self
.a_index
= a_index
475 self
.a
= Signal(twidth
//2, reset_less
=True)
476 self
.b
= Signal(twidth
//2, reset_less
=True)
477 self
.pb_en
= Signal(pbwid
, reset_less
=True)
478 self
.terms
= [Signal(twidth
, name
="term%d" % i
, reset_less
=True)
479 for i
in range(blen
)]
481 def elaborate(self
, platform
):
485 for b_index
in range(self
.blen
):
486 t
= ProductTerm(self
.pwidth
, self
.twidth
, self
.pbwid
,
487 self
.a_index
, b_index
)
488 setattr(m
.submodules
, "term_%d" % b_index
, t
)
490 m
.d
.comb
+= t
.a
.eq(self
.a
)
491 m
.d
.comb
+= t
.b
.eq(self
.b
)
492 m
.d
.comb
+= t
.pb_en
.eq(self
.pb_en
)
494 m
.d
.comb
+= self
.terms
[b_index
].eq(t
.term
)
499 class LSBNegTerm(Elaboratable
):
501 def __init__(self
, bit_width
):
502 self
.bit_width
= bit_width
503 self
.part
= Signal(reset_less
=True)
504 self
.signed
= Signal(reset_less
=True)
505 self
.op
= Signal(bit_width
, reset_less
=True)
506 self
.msb
= Signal(reset_less
=True)
507 self
.nt
= Signal(bit_width
*2, reset_less
=True)
508 self
.nl
= Signal(bit_width
*2, reset_less
=True)
510 def elaborate(self
, platform
):
513 bit_wid
= self
.bit_width
514 ext
= Repl(0, bit_wid
) # extend output to HI part
516 # determine sign of each incoming number *in this partition*
517 enabled
= Signal(reset_less
=True)
518 m
.d
.comb
+= enabled
.eq(self
.part
& self
.msb
& self
.signed
)
520 # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
521 # negation operation is split into a bitwise not and a +1.
522 # likewise for 16, 32, and 64-bit values.
524 # width-extended 1s complement if a is signed, otherwise zero
525 comb
+= self
.nt
.eq(Mux(enabled
, Cat(ext
, ~self
.op
), 0))
527 # add 1 if signed, otherwise add zero
528 comb
+= self
.nl
.eq(Cat(ext
, enabled
, Repl(0, bit_wid
-1)))
533 class Parts(Elaboratable
):
535 def __init__(self
, pbwid
, part_pts
, n_parts
):
538 self
.part_pts
= PartitionPoints
.like(part_pts
)
540 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
541 for i
in range(n_parts
)]
543 def elaborate(self
, platform
):
546 part_pts
, parts
= self
.part_pts
, self
.parts
547 # collect part-bytes (double factor because the input is extended)
548 pbs
= Signal(self
.pbwid
, reset_less
=True)
550 for i
in range(self
.pbwid
):
551 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
552 m
.d
.comb
+= pb
.eq(part_pts
.part_byte(i
))
554 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
556 # negated-temporary copy of partition bits
557 npbs
= Signal
.like(pbs
, reset_less
=True)
558 m
.d
.comb
+= npbs
.eq(~pbs
)
559 byte_count
= 8 // len(parts
)
560 for i
in range(len(parts
)):
562 pbl
.append(npbs
[i
* byte_count
- 1])
563 for j
in range(i
* byte_count
, (i
+ 1) * byte_count
- 1):
565 pbl
.append(npbs
[(i
+ 1) * byte_count
- 1])
566 value
= Signal(len(pbl
), name
="value_%d" % i
, reset_less
=True)
567 m
.d
.comb
+= value
.eq(Cat(*pbl
))
568 m
.d
.comb
+= parts
[i
].eq(~
(value
).bool())
573 class Part(Elaboratable
):
574 """ a key class which, depending on the partitioning, will determine
575 what action to take when parts of the output are signed or unsigned.
577 this requires 2 pieces of data *per operand, per partition*:
578 whether the MSB is HI/LO (per partition!), and whether a signed
579 or unsigned operation has been *requested*.
581 once that is determined, signed is basically carried out
582 by splitting 2's complement into 1's complement plus one.
583 1's complement is just a bit-inversion.
585 the extra terms - as separate terms - are then thrown at the
586 AddReduce alongside the multiplication part-results.
589 def __init__(self
, part_pts
, width
, n_parts
, pbwid
):
592 self
.part_pts
= part_pts
595 self
.a
= Signal(64, reset_less
=True)
596 self
.b
= Signal(64, reset_less
=True)
597 self
.a_signed
= [Signal(name
=f
"a_signed_{i}", reset_less
=True)
599 self
.b_signed
= [Signal(name
=f
"_b_signed_{i}", reset_less
=True)
601 self
.pbs
= Signal(pbwid
, reset_less
=True)
604 self
.parts
= [Signal(name
=f
"part_{i}", reset_less
=True)
605 for i
in range(n_parts
)]
607 self
.not_a_term
= Signal(width
, reset_less
=True)
608 self
.neg_lsb_a_term
= Signal(width
, reset_less
=True)
609 self
.not_b_term
= Signal(width
, reset_less
=True)
610 self
.neg_lsb_b_term
= Signal(width
, reset_less
=True)
612 def elaborate(self
, platform
):
615 pbs
, parts
= self
.pbs
, self
.parts
616 part_pts
= self
.part_pts
617 m
.submodules
.p
= p
= Parts(self
.pbwid
, part_pts
, len(parts
))
618 m
.d
.comb
+= p
.part_pts
.eq(part_pts
)
621 byte_count
= 8 // len(parts
)
623 not_a_term
, neg_lsb_a_term
, not_b_term
, neg_lsb_b_term
= (
624 self
.not_a_term
, self
.neg_lsb_a_term
,
625 self
.not_b_term
, self
.neg_lsb_b_term
)
627 byte_width
= 8 // len(parts
) # byte width
628 bit_wid
= 8 * byte_width
# bit width
629 nat
, nbt
, nla
, nlb
= [], [], [], []
630 for i
in range(len(parts
)):
631 # work out bit-inverted and +1 term for a.
632 pa
= LSBNegTerm(bit_wid
)
633 setattr(m
.submodules
, "lnt_%d_a_%d" % (bit_wid
, i
), pa
)
634 m
.d
.comb
+= pa
.part
.eq(parts
[i
])
635 m
.d
.comb
+= pa
.op
.eq(self
.a
.bit_select(bit_wid
* i
, bit_wid
))
636 m
.d
.comb
+= pa
.signed
.eq(self
.b_signed
[i
* byte_width
]) # yes b
637 m
.d
.comb
+= pa
.msb
.eq(self
.b
[(i
+ 1) * bit_wid
- 1]) # really, b
641 # work out bit-inverted and +1 term for b
642 pb
= LSBNegTerm(bit_wid
)
643 setattr(m
.submodules
, "lnt_%d_b_%d" % (bit_wid
, i
), pb
)
644 m
.d
.comb
+= pb
.part
.eq(parts
[i
])
645 m
.d
.comb
+= pb
.op
.eq(self
.b
.bit_select(bit_wid
* i
, bit_wid
))
646 m
.d
.comb
+= pb
.signed
.eq(self
.a_signed
[i
* byte_width
]) # yes a
647 m
.d
.comb
+= pb
.msb
.eq(self
.a
[(i
+ 1) * bit_wid
- 1]) # really, a
651 # concatenate together and return all 4 results.
652 m
.d
.comb
+= [not_a_term
.eq(Cat(*nat
)),
653 not_b_term
.eq(Cat(*nbt
)),
654 neg_lsb_a_term
.eq(Cat(*nla
)),
655 neg_lsb_b_term
.eq(Cat(*nlb
)),
661 class IntermediateOut(Elaboratable
):
662 """ selects the HI/LO part of the multiplication, for a given bit-width
663 the output is also reconstructed in its SIMD (partition) lanes.
666 def __init__(self
, width
, out_wid
, n_parts
):
668 self
.n_parts
= n_parts
669 self
.part_ops
= [Signal(2, name
="dpop%d" % i
, reset_less
=True)
671 self
.intermed
= Signal(out_wid
, reset_less
=True)
672 self
.output
= Signal(out_wid
//2, reset_less
=True)
674 def elaborate(self
, platform
):
680 for i
in range(self
.n_parts
):
681 op
= Signal(w
, reset_less
=True, name
="op%d_%d" % (w
, i
))
683 Mux(self
.part_ops
[sel
* i
] == OP_MUL_LOW
,
684 self
.intermed
.bit_select(i
* w
*2, w
),
685 self
.intermed
.bit_select(i
* w
*2 + w
, w
)))
687 m
.d
.comb
+= self
.output
.eq(Cat(*ol
))
692 class FinalOut(PipeModBase
):
693 """ selects the final output based on the partitioning.
695 each byte is selectable independently, i.e. it is possible
696 that some partitions requested 8-bit computation whilst others
697 requested 16 or 32 bit.
700 def __init__(self
, pspec
, part_pts
):
702 self
.part_pts
= part_pts
703 self
.output_width
= pspec
.width
* 2
704 self
.n_parts
= pspec
.n_parts
705 self
.out_wid
= pspec
.width
707 super().__init
__(pspec
, "finalout")
710 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
715 def elaborate(self
, platform
):
718 part_pts
= self
.part_pts
719 m
.submodules
.p_8
= p_8
= Parts(8, part_pts
, 8)
720 m
.submodules
.p_16
= p_16
= Parts(8, part_pts
, 4)
721 m
.submodules
.p_32
= p_32
= Parts(8, part_pts
, 2)
722 m
.submodules
.p_64
= p_64
= Parts(8, part_pts
, 1)
724 out_part_pts
= self
.i
.part_pts
727 d8
= [Signal(name
=f
"d8_{i}", reset_less
=True) for i
in range(8)]
728 d16
= [Signal(name
=f
"d16_{i}", reset_less
=True) for i
in range(4)]
729 d32
= [Signal(name
=f
"d32_{i}", reset_less
=True) for i
in range(2)]
731 i8
= Signal(self
.out_wid
, reset_less
=True)
732 i16
= Signal(self
.out_wid
, reset_less
=True)
733 i32
= Signal(self
.out_wid
, reset_less
=True)
734 i64
= Signal(self
.out_wid
, reset_less
=True)
736 m
.d
.comb
+= p_8
.part_pts
.eq(out_part_pts
)
737 m
.d
.comb
+= p_16
.part_pts
.eq(out_part_pts
)
738 m
.d
.comb
+= p_32
.part_pts
.eq(out_part_pts
)
739 m
.d
.comb
+= p_64
.part_pts
.eq(out_part_pts
)
741 for i
in range(len(p_8
.parts
)):
742 m
.d
.comb
+= d8
[i
].eq(p_8
.parts
[i
])
743 for i
in range(len(p_16
.parts
)):
744 m
.d
.comb
+= d16
[i
].eq(p_16
.parts
[i
])
745 for i
in range(len(p_32
.parts
)):
746 m
.d
.comb
+= d32
[i
].eq(p_32
.parts
[i
])
747 m
.d
.comb
+= i8
.eq(self
.i
.outputs
[0])
748 m
.d
.comb
+= i16
.eq(self
.i
.outputs
[1])
749 m
.d
.comb
+= i32
.eq(self
.i
.outputs
[2])
750 m
.d
.comb
+= i64
.eq(self
.i
.outputs
[3])
754 # select one of the outputs: d8 selects i8, d16 selects i16
755 # d32 selects i32, and the default is i64.
756 # d8 and d16 are ORed together in the first Mux
757 # then the 2nd selects either i8 or i16.
758 # if neither d8 nor d16 are set, d32 selects either i32 or i64.
759 op
= Signal(8, reset_less
=True, name
="op_%d" % i
)
761 Mux(d8
[i
] | d16
[i
// 2],
762 Mux(d8
[i
], i8
.bit_select(i
* 8, 8),
763 i16
.bit_select(i
* 8, 8)),
764 Mux(d32
[i
// 4], i32
.bit_select(i
* 8, 8),
765 i64
.bit_select(i
* 8, 8))))
769 m
.d
.comb
+= self
.o
.output
.eq(Cat(*ol
))
770 m
.d
.comb
+= self
.o
.intermediate_output
.eq(self
.i
.intermediate_output
)
class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree

    :attribute orin: list of the four ``wid``-bit inputs.
    :attribute orout: ``wid``-bit bitwise OR of all four inputs.
    """

    def __init__(self, wid):
        """Create the input and output signals.

        :param wid: bit-width of every input and of the output.
        """
        # NOTE(review): ``self.wid = wid``, the ``range(4)`` tail of the
        # ``orin`` list, ``m = Module()`` and ``return m`` were lost from the
        # source. They are restored from ``self.wid`` and ``orin[0..3]`` as
        # used in ``elaborate`` -- confirm against the upstream file.
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module: (in0 | in1) | (in2 | in3)."""
        m = Module()
        # two first-level ORs feeding one second-level OR
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m
class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)

    :attribute part_ops: 2-bit operation code input.
    :attribute a_signed: output, set when operand ``a`` is to be treated
        as signed.
    :attribute b_signed: output, set when operand ``b`` is to be treated
        as signed.
    """

    # NOTE(review): the ``def __init__`` line, ``m = Module()`` and
    # ``return m`` were lost from the source. A no-argument constructor is
    # assumed because the body uses no parameters -- confirm upstream.
    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # a is signed for every operation except OP_MUL_UNSIGNED_HIGH
        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        # b is signed only for OP_MUL_LOW and OP_MUL_SIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m
class IntermediateData:
    """Bundle of signals: per-partition ops, partition points, four outputs
    and one intermediate output.

    :attribute part_ops: list of ``n_parts`` 2-bit signals.
    :attribute part_pts: copy (``part_pts.like()``) of the partition points.
    :attribute outputs: list of four ``output_width``-bit signals.
    :attribute intermediate_output: ``output_width``-bit signal.
    """

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.part_pts = part_pts.like()
        # NOTE(review): the tail of this comprehension was lost from the
        # source; four entries are assumed because ``eq_from`` copies
        # exactly ``range(4)`` outputs -- confirm against the upstream file.
        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                        for i in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        """Build the assignments that load this record from separate values.

        :returns: list of the ``.eq()`` results, ordered part_pts,
            intermediate_output, the four outputs, then part_ops.
        """
        return [self.part_pts.eq(part_pts)] + \
               [self.intermediate_output.eq(intermediate_output)] + \
               [self.outputs[i].eq(outputs[i])
                for i in range(4)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        """Build the assignments that copy ``rhs`` into this record.

        NOTE(review): this ``def`` line was lost from the source; the name
        and signature are taken from the body -- confirm upstream.
        """
        return self.eq_from(rhs.part_pts, rhs.outputs,
                            rhs.intermediate_output, rhs.part_ops)
849 self
.part_pts
= PartitionPoints()
850 for i
in range(8, 64, 8):
851 self
.part_pts
[i
] = Signal(name
=f
"part_pts_{i}")
852 self
.part_ops
= [Signal(2, name
=f
"part_ops_{i}") for i
in range(8)]
def eq_from(self, part_pts, a, b, part_ops):
    """Build the assignments that load this record from separate values.

    :param part_pts: value assigned to ``self.part_pts``.
    :param a: value assigned to ``self.a``.
    :param b: value assigned to ``self.b``.
    :param part_ops: indexable; ``part_ops[i]`` is assigned to
        ``self.part_ops[i]``.
    :returns: list of the ``.eq()`` results, ordered part_pts, a, b,
        part_ops (one per entry of ``self.part_ops``).
    """
    return [self.part_pts.eq(part_pts)] + \
           [self.a.eq(a), self.b.eq(b)] + \
           [op.eq(part_ops[i]) for i, op in enumerate(self.part_ops)]
861 return self
.eq_from(rhs
.part_pts
, rhs
.a
, rhs
.b
, rhs
.part_ops
)
867 self
.intermediate_output
= Signal(128) # needed for unit tests
868 self
.output
= Signal(64)
871 return [self
.intermediate_output
.eq(rhs
.intermediate_output
),
872 self
.output
.eq(rhs
.output
)]
875 class AllTerms(PipeModBase
):
876 """Set of terms to be added together
879 def __init__(self
, pspec
, n_inputs
):
880 """Create an ``AllTerms``.
882 self
.n_inputs
= n_inputs
883 self
.n_parts
= pspec
.n_parts
884 self
.output_width
= pspec
.width
* 2
885 super().__init
__(pspec
, "allterms")
891 return AddReduceData(self
.i
.part_pts
, self
.n_inputs
,
892 self
.output_width
, self
.n_parts
)
894 def elaborate(self
, platform
):
897 eps
= self
.i
.part_pts
900 pbs
= Signal(8, reset_less
=True)
903 pb
= Signal(name
="pb%d" % i
, reset_less
=True)
904 m
.d
.comb
+= pb
.eq(eps
.part_byte(i
))
906 m
.d
.comb
+= pbs
.eq(Cat(*tl
))
913 setattr(m
.submodules
, "signs%d" % i
, s
)
914 m
.d
.comb
+= s
.part_ops
.eq(self
.i
.part_ops
[i
])
916 m
.submodules
.part_8
= part_8
= Part(eps
, 128, 8, 8)
917 m
.submodules
.part_16
= part_16
= Part(eps
, 128, 4, 8)
918 m
.submodules
.part_32
= part_32
= Part(eps
, 128, 2, 8)
919 m
.submodules
.part_64
= part_64
= Part(eps
, 128, 1, 8)
920 nat_l
, nbt_l
, nla_l
, nlb_l
= [], [], [], []
921 for mod
in [part_8
, part_16
, part_32
, part_64
]:
922 m
.d
.comb
+= mod
.a
.eq(self
.i
.a
)
923 m
.d
.comb
+= mod
.b
.eq(self
.i
.b
)
924 for i
in range(len(signs
)):
925 m
.d
.comb
+= mod
.a_signed
[i
].eq(signs
[i
].a_signed
)
926 m
.d
.comb
+= mod
.b_signed
[i
].eq(signs
[i
].b_signed
)
927 m
.d
.comb
+= mod
.pbs
.eq(pbs
)
928 nat_l
.append(mod
.not_a_term
)
929 nbt_l
.append(mod
.not_b_term
)
930 nla_l
.append(mod
.neg_lsb_a_term
)
931 nlb_l
.append(mod
.neg_lsb_b_term
)
935 for a_index
in range(8):
936 t
= ProductTerms(8, 128, 8, a_index
, 8)
937 setattr(m
.submodules
, "terms_%d" % a_index
, t
)
939 m
.d
.comb
+= t
.a
.eq(self
.i
.a
)
940 m
.d
.comb
+= t
.b
.eq(self
.i
.b
)
941 m
.d
.comb
+= t
.pb_en
.eq(pbs
)
946 # it's fine to bitwise-or data together since they are never enabled
948 m
.submodules
.nat_or
= nat_or
= OrMod(128)
949 m
.submodules
.nbt_or
= nbt_or
= OrMod(128)
950 m
.submodules
.nla_or
= nla_or
= OrMod(128)
951 m
.submodules
.nlb_or
= nlb_or
= OrMod(128)
952 for l
, mod
in [(nat_l
, nat_or
),
956 for i
in range(len(l
)):
957 m
.d
.comb
+= mod
.orin
[i
].eq(l
[i
])
958 terms
.append(mod
.orout
)
960 # copy the intermediate terms to the output
961 for i
, value
in enumerate(terms
):
962 m
.d
.comb
+= self
.o
.terms
[i
].eq(value
)
964 # copy reg part points and part ops to output
965 m
.d
.comb
+= self
.o
.part_pts
.eq(eps
)
966 m
.d
.comb
+= [self
.o
.part_ops
[i
].eq(self
.i
.part_ops
[i
])
967 for i
in range(len(self
.i
.part_ops
))]
972 class Intermediates(PipeModBase
):
973 """ Intermediate output modules
976 def __init__(self
, pspec
, part_pts
):
977 self
.part_pts
= part_pts
978 self
.output_width
= pspec
.width
* 2
979 self
.n_parts
= pspec
.n_parts
981 super().__init
__(pspec
, "intermediates")
984 return FinalReduceData(self
.part_pts
, self
.output_width
, self
.n_parts
)
987 return IntermediateData(self
.part_pts
, self
.output_width
, self
.n_parts
)
989 def elaborate(self
, platform
):
992 out_part_ops
= self
.i
.part_ops
993 out_part_pts
= self
.i
.part_pts
996 m
.submodules
.io64
= io64
= IntermediateOut(64, 128, 1)
997 m
.d
.comb
+= io64
.intermed
.eq(self
.i
.output
)
999 m
.d
.comb
+= io64
.part_ops
[i
].eq(out_part_ops
[i
])
1000 m
.d
.comb
+= self
.o
.outputs
[3].eq(io64
.output
)
1003 m
.submodules
.io32
= io32
= IntermediateOut(32, 128, 2)
1004 m
.d
.comb
+= io32
.intermed
.eq(self
.i
.output
)
1006 m
.d
.comb
+= io32
.part_ops
[i
].eq(out_part_ops
[i
])
1007 m
.d
.comb
+= self
.o
.outputs
[2].eq(io32
.output
)
1010 m
.submodules
.io16
= io16
= IntermediateOut(16, 128, 4)
1011 m
.d
.comb
+= io16
.intermed
.eq(self
.i
.output
)
1013 m
.d
.comb
+= io16
.part_ops
[i
].eq(out_part_ops
[i
])
1014 m
.d
.comb
+= self
.o
.outputs
[1].eq(io16
.output
)
1017 m
.submodules
.io8
= io8
= IntermediateOut(8, 128, 8)
1018 m
.d
.comb
+= io8
.intermed
.eq(self
.i
.output
)
1020 m
.d
.comb
+= io8
.part_ops
[i
].eq(out_part_ops
[i
])
1021 m
.d
.comb
+= self
.o
.outputs
[0].eq(io8
.output
)
1024 m
.d
.comb
+= self
.o
.part_ops
[i
].eq(out_part_ops
[i
])
1025 m
.d
.comb
+= self
.o
.part_pts
.eq(out_part_pts
)
1026 m
.d
.comb
+= self
.o
.intermediate_output
.eq(self
.i
.output
)
1031 class Mul8_16_32_64(Elaboratable
):
1032 """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
1034 XXX NOTE: this class is intended for unit test purposes ONLY.
1036 Supports partitioning into any combination of 8, 16, 32, and 64-bit
1037 partitions on naturally-aligned boundaries. Supports the operation being
1038 set for each partition independently.
1040 :attribute part_pts: the input partition points. Has a partition point at
1041 multiples of 8 in 0 < i < 64. Each partition point's associated
1042 ``Value`` is a ``Signal``. Modification not supported, except for by
1044 :attribute part_ops: the operation for each byte. The operation for a
1045 particular partition is selected by assigning the selected operation
1046 code to each byte in the partition. The allowed operation codes are:
1048 :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
1049 RISC-V's `mul` instruction.
1050 :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
1051 ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
1053 :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
1054 where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
1055 `mulhsu` instruction.
1056 :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
1057 ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
1061 def __init__(self
, register_levels
=()):
1062 """ register_levels: specifies the points in the cascade at which
1063 flip-flops are to be inserted.
1066 self
.id_wid
= 0 # num_bits(num_rows)
1068 self
.pspec
= PipelineSpec(64, self
.id_wid
, self
.op_wid
, n_ops
=3)
1069 self
.pspec
.n_parts
= 8
1072 self
.register_levels
= list(register_levels
)
1074 self
.i
= self
.ispec()
1075 self
.o
= self
.ospec()
1078 self
.part_pts
= self
.i
.part_pts
1079 self
.part_ops
= self
.i
.part_ops
1084 self
.intermediate_output
= self
.o
.intermediate_output
1085 self
.output
= self
.o
.output
1093 def elaborate(self
, platform
):
1096 part_pts
= self
.part_pts
1099 t
= AllTerms(self
.pspec
, n_inputs
)
1104 at
= AddReduceInternal(self
.pspec
, n_inputs
,
1105 part_pts
, partition_step
=2)
1108 for idx
in range(len(at
.levels
)):
1109 mcur
= at
.levels
[idx
]
1112 if idx
in self
.register_levels
:
1113 m
.d
.sync
+= o
.eq(mcur
.process(i
))
1115 m
.d
.comb
+= o
.eq(mcur
.process(i
))
1116 i
= o
# for next loop
1118 interm
= Intermediates(self
.pspec
, part_pts
)
1120 o
= interm
.process(interm
.i
)
1123 finalout
= FinalOut(self
.pspec
, part_pts
)
1124 finalout
.setup(m
, o
)
1125 m
.d
.comb
+= self
.o
.eq(finalout
.process(o
))
1130 if __name__
== "__main__":
1134 m
.intermediate_output
,
1137 *m
.part_pts
.values()])