# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information
"""Integer Multiplication."""

from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
from nmigen.hdl.ast import Assign
from abc import ABCMeta, abstractmethod
from nmigen.cli import main
from functools import reduce
from operator import or_
from ieee754.pipeline import PipelineSpec
from nmutil.pipemodbase import PipeModBase

from ieee754.part_mul_add.partpoints import PartitionPoints
from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder


FULL_ADDER_INPUT_COUNT = 3


class AddReduceData:

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.terms = [Signal(output_width, name=f"terms_{i}",
                             reset_less=True)
                      for i in range(n_inputs)]
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        return [self.part_pts.eq(part_pts)] + \
               [self.terms[i].eq(inputs[i])
                for i in range(len(self.terms))] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)


class FinalReduceData:

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        return [self.part_pts.eq(part_pts)] + \
               [self.output.eq(output)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)


class FinalAdd(PipeModBase):
    """ Final stage of add reduce
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        self.lidx = lidx
        self.partition_step = partition_step
        self.output_width = pspec.width * 2
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        super().__init__(pspec, "finaladd")

    def ispec(self):
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        return FinalReduceData(self.partition_points,
                               self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        output_width = self.output_width
        output = Signal(output_width, reset_less=True)
        if self.n_inputs == 0:
            # use 0 as the default output value
            m.d.comb += output.eq(0)
        elif self.n_inputs == 1:
            # handle single input
            m.d.comb += output.eq(self.i.terms[0])
        else:
            # base case for adding 2 inputs
            assert self.n_inputs == 2
            adder = PartitionedAdder(output_width,
                                     self.i.part_pts, self.partition_step)
            m.submodules.final_adder = adder
            m.d.comb += adder.a.eq(self.i.terms[0])
            m.d.comb += adder.b.eq(self.i.terms[1])
            m.d.comb += output.eq(adder.output)

        # create output
        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                   self.i.part_ops)

        return m


class AddReduceSingle(PipeModBase):
    """Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except by ``Signal.eq``.
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param partition_points: the input partition points.
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)

        super().__init__(pspec, "addreduce_%d" % lidx)

    def ispec(self):
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        return AddReduceData(self.partition_points, self.n_terms,
                             self.output_width, self.n_parts)

    @staticmethod
    def calc_n_inputs(n_inputs, groups):
        retval = len(groups)*2
        if n_inputs % FULL_ADDER_INPUT_COUNT == 1:
            retval += 1
        elif n_inputs % FULL_ADDER_INPUT_COUNT == 2:
            retval += 2
        else:
            assert n_inputs % FULL_ADDER_INPUT_COUNT == 0
        return retval

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        retval = 0
        while True:
            groups = AddReduceSingle.full_adder_groups(input_count)
            if len(groups) == 0:
                return retval
            input_count %= FULL_ADDER_INPUT_COUNT
            input_count += 2 * len(groups)
            retval += 1
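
    # worked example (illustrative): get_max_level(7) == 4, since the term
    # count shrinks 7 -> 5 -> 4 -> 3 -> 2 over successive levels, after
    # which FinalAdd completes the sum of the final two terms.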

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        return range(0,
                     input_count - FULL_ADDER_INPUT_COUNT + 1,
                     FULL_ADDER_INPUT_COUNT)
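
    # worked example (illustrative): full_adder_groups(8) == range(0, 6, 3),
    # i.e. full adders are built over terms [0..2] and [3..5]; the remaining
    # 8 % 3 == 2 terms (indices 6 and 7) are passed through unchanged, so
    # one level reduces 8 terms to 2*2 + 2 == 6.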

    def create_next_terms(self):
        """ create next intermediate terms, for linking up in elaborate, below
        """
        terms = []
        adders = []

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for i in self.groups:
            adder_i = MaskedFullAdder(self.output_width)
            adders.append((i, adder_i))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            terms.append(adder_i.sum)
            terms.append(adder_i.mcarry)
        # handle the remaining inputs.
        if self.n_inputs % FULL_ADDER_INPUT_COUNT == 1:
            terms.append(self.i.terms[-1])
        elif self.n_inputs % FULL_ADDER_INPUT_COUNT == 2:
            # just pass the two remaining terms to the next layer: a half
            # adder would not reduce the count (there would still be two
            # terms), and passing them straight through saves gates.
            terms.append(self.i.terms[-2])
            terms.append(self.i.terms[-1])
        else:
            assert self.n_inputs % FULL_ADDER_INPUT_COUNT == 0

        return terms, adders

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        terms, adders = self.create_next_terms()

        # copy the intermediate terms to the output
        for i, value in enumerate(terms):
            m.d.comb += self.o.terms[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                     for i in range(len(self.i.part_ops))]

        # set up the partition mask (for the adders)
        part_mask = Signal(self.output_width, reset_less=True)

        # get partition points as a mask
        mask = self.i.part_pts.as_mask(self.output_width,
                                       mul=self.partition_step)
        m.d.comb += part_mask.eq(mask)

        # add and link the intermediate term modules
        for i, (iidx, adder_i) in enumerate(adders):
            setattr(m.submodules, f"adder_{i}", adder_i)

            m.d.comb += adder_i.in0.eq(self.i.terms[iidx])
            m.d.comb += adder_i.in1.eq(self.i.terms[iidx + 1])
            m.d.comb += adder_i.in2.eq(self.i.terms[iidx + 2])
            m.d.comb += adder_i.mask.eq(part_mask)

        return m


class AddReduceInternal:
    """Iteratively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except by ``Signal.eq``.
    """

    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param partition_points: the input partition points.
        """
        self.pspec = pspec
        self.n_inputs = n_inputs
        self.output_width = pspec.width * 2
        self.partition_points = part_pts
        self.partition_step = partition_step

        self.create_levels()

    def create_levels(self):
        """creates reduction levels"""

        mods = []
        partition_points = self.partition_points
        ilen = self.n_inputs
        while True:
            groups = AddReduceSingle.full_adder_groups(ilen)
            if len(groups) == 0:
                break
            lidx = len(mods)
            next_level = AddReduceSingle(self.pspec, lidx, ilen,
                                         partition_points,
                                         self.partition_step)
            mods.append(next_level)
            partition_points = next_level.i.part_pts
            ilen = len(next_level.o.terms)

        lidx = len(mods)
        next_level = FinalAdd(self.pspec, lidx, ilen,
                              partition_points, self.partition_step)
        mods.append(next_level)

        self.levels = mods
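
    # worked example (illustrative): Mul8_16_32_64 below feeds 68 terms into
    # this reduction, giving per-level term counts of
    # 68 -> 46 -> 31 -> 21 -> 14 -> 10 -> 7 -> 5 -> 4 -> 3 -> 2,
    # i.e. ten AddReduceSingle levels followed by a single FinalAdd.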


class AddReduce(AddReduceInternal, Elaboratable):
    """Recursively Add list of numbers together.

    :attribute inputs: input ``Signal``s to be summed. Modification not
        supported, except by ``Signal.eq``.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute output: output sum.
    :attribute partition_points: the input partition points. Modification not
        supported, except by ``Signal.eq``.
    """

    def __init__(self, inputs, output_width, register_levels, part_pts,
                 part_ops, partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param partition_points: the input partition points.
        """
        self._inputs = inputs
        self._part_pts = part_pts
        self._part_ops = part_ops
        n_parts = len(part_ops)
        self.i = AddReduceData(part_pts, len(inputs),
                               output_width, n_parts)
        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
                                   partition_step)
        self.o = FinalReduceData(part_pts, output_width, n_parts)
        self.register_levels = register_levels

    @staticmethod
    def get_max_level(input_count):
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        m.d.comb += self.i.eq_from(self._part_pts,
                                   self._inputs, self._part_ops)

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        i = self.i
        for idx in range(len(self.levels)):
            mcur = self.levels[idx]
            if idx in self.register_levels:
                m.d.sync += mcur.i.eq(i)
            else:
                m.d.comb += mcur.i.eq(i)
            i = mcur.o  # for next loop

        # output comes from last module
        m.d.comb += self.o.eq(i)

        return m


OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3


def get_term(value, shift=0, enabled=None):
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift > 0:
        value = Cat(Repl(C(0, 1), shift), value)
    else:
        assert shift == 0
    return value
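
# worked example (illustrative): get_term(x, shift=8) returns
# Cat(Repl(C(0, 1), 8), x), i.e. x << 8; this is how ProductTerm below places
# the a[i]*b[j] partial product at byte offset i+j within the full-width term.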


class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that it is the *output* that is selected,
        while the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        shift = 8 * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):

        m = Module()
        if self.enabled is not None:
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        pwidth = self.pwidth
        width = self.width
        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """

        return m


class ProductTerms(Elaboratable):
    """ creates a bank of product terms; also performs the actual bit-selection.
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % i, reset_less=True)
                      for i in range(blen)]

    def elaborate(self, platform):

        m = Module()

        for b_index in range(self.blen):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            m.d.comb += t.a.eq(self.a)
            m.d.comb += t.b.eq(self.b)
            m.d.comb += t.pb_en.eq(self.pb_en)

            m.d.comb += self.terms[b_index].eq(t.term)

        return m


class LSBNegTerm(Elaboratable):

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.
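        #
        # worked example (illustrative), 8-bit case, following the wiring in
        # Part below: op = a = 0x03, and the partner operand b = 0xFF is
        # signed (i.e. -1).  the unsigned partial product is 3*255 = 0x02FD;
        # the correction (-a) << 8 is supplied as (~a) << 8 = 0xFC00 (self.nt)
        # plus 1 << 8 = 0x0100 (self.nl), and 0x02FD + 0xFC00 + 0x0100 =
        # 0xFFFD, which is -3 as a 16-bit two's-complement value.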

        # width-extended 1s complement if a is signed, otherwise zero
        comb += self.nt.eq(Mux(enabled, Cat(ext, ~self.op), 0))

        # add 1 if signed, otherwise add zero
        comb += self.nl.eq(Cat(ext, enabled, Repl(0, bit_wid-1)))

        return m


class Parts(Elaboratable):

    def __init__(self, pbwid, part_pts, n_parts):
        self.pbwid = pbwid
        # inputs
        self.part_pts = PartitionPoints.like(part_pts)
        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

    def elaborate(self, platform):
        m = Module()

        part_pts, parts = self.part_pts, self.parts
        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)
        byte_count = 8 // len(parts)
        for i in range(len(parts)):
            pbl = []
            pbl.append(npbs[i * byte_count - 1])
            for j in range(i * byte_count, (i + 1) * byte_count - 1):
                pbl.append(pbs[j])
            pbl.append(npbs[(i + 1) * byte_count - 1])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            m.d.comb += parts[i].eq(~(value).bool())

        return m


class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed multiplication is carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """

    def __init__(self, part_pts, width, n_parts, pbwid):

        self.pbwid = pbwid
        self.part_pts = part_pts

        # inputs
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

        self.not_a_term = Signal(width, reset_less=True)
        self.neg_lsb_a_term = Signal(width, reset_less=True)
        self.not_b_term = Signal(width, reset_less=True)
        self.neg_lsb_b_term = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        pbs, parts = self.pbs, self.parts
        part_pts = self.part_pts
        m.submodules.p = p = Parts(self.pbwid, part_pts, len(parts))
        m.d.comb += p.part_pts.eq(part_pts)
        parts = p.parts

        byte_count = 8 // len(parts)

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
            self.not_a_term, self.neg_lsb_a_term,
            self.not_b_term, self.neg_lsb_b_term)

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width  # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width])  # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1])  # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width])  # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1])  # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m


class IntermediateOut(Elaboratable):
    """ selects the HI/LO part of the multiplication for a given bit-width;
        the output is also reconstructed in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        ol = []
        w = self.width
        sel = w // 8
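
        # worked example (illustrative): with width == 16 (16-bit lanes),
        # lane i == 1 occupies self.intermed bits [32:64]; OP_MUL_LOW selects
        # bits [32:48] and the three *_HIGH ops select bits [48:64].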
        for i in range(self.n_parts):
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(
                Mux(self.part_ops[sel * i] == OP_MUL_LOW,
                    self.intermed.bit_select(i * w*2, w),
                    self.intermed.bit_select(i * w*2 + w, w)))
            ol.append(op)
        m.d.comb += self.output.eq(Cat(*ol))

        return m


class FinalOut(PipeModBase):
    """ selects the final output based on the partitioning.

        each byte is selectable independently, i.e. it is possible
        that some partitions requested 8-bit computation whilst others
        requested 16 or 32 bit.
    """

    def __init__(self, pspec, part_pts):

        self.part_pts = part_pts
        self.output_width = pspec.width * 2
        self.n_parts = pspec.n_parts
        self.out_wid = pspec.width

        super().__init__(pspec, "finalout")

    def ispec(self):
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        return OutputData()

    def elaborate(self, platform):
        m = Module()

        part_pts = self.part_pts
        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)

        out_part_pts = self.i.part_pts

        # temporaries
        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        i8 = Signal(self.out_wid, reset_less=True)
        i16 = Signal(self.out_wid, reset_less=True)
        i32 = Signal(self.out_wid, reset_less=True)
        i64 = Signal(self.out_wid, reset_less=True)

        m.d.comb += p_8.part_pts.eq(out_part_pts)
        m.d.comb += p_16.part_pts.eq(out_part_pts)
        m.d.comb += p_32.part_pts.eq(out_part_pts)
        m.d.comb += p_64.part_pts.eq(out_part_pts)

        for i in range(len(p_8.parts)):
            m.d.comb += d8[i].eq(p_8.parts[i])
        for i in range(len(p_16.parts)):
            m.d.comb += d16[i].eq(p_16.parts[i])
        for i in range(len(p_32.parts)):
            m.d.comb += d32[i].eq(p_32.parts[i])
        m.d.comb += i8.eq(self.i.outputs[0])
        m.d.comb += i16.eq(self.i.outputs[1])
        m.d.comb += i32.eq(self.i.outputs[2])
        m.d.comb += i64.eq(self.i.outputs[3])

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16,
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux,
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 is set, d32 selects either i32 or i64.
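            # worked example (illustrative): for byte i == 3 inside a 32-bit
            # partition spanning bytes 0-3, d8[3] and d16[1] are clear while
            # d32[0] is set, so this byte is taken from i32.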
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(
                Mux(d8[i] | d16[i // 2],
                    Mux(d8[i], i8.bit_select(i * 8, 8),
                        i16.bit_select(i * 8, 8)),
                    Mux(d32[i // 4], i32.bit_select(i * 8, 8),
                        i64.bit_select(i * 8, 8))))
            ol.append(op)

        # create outputs
        m.d.comb += self.o.output.eq(Cat(*ol))
        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)

        return m


class OrMod(Elaboratable):
    """ ORs four values together in a hierarchical tree
    """

    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += or1.eq(self.orin[0] | self.orin[1])
        m.d.comb += or2.eq(self.orin[2] | self.orin[3])
        m.d.comb += self.orout.eq(or1 | or2)

        return m


class Signs(Elaboratable):
    """ determines whether a or b are signed numbers
        based on the required operation type (OP_MUL_*)
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):

        m = Module()

        asig = self.part_ops != OP_MUL_UNSIGNED_HIGH
        bsig = (self.part_ops == OP_MUL_LOW) \
            | (self.part_ops == OP_MUL_SIGNED_HIGH)
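
        # resulting signedness per op code (illustrative):
        #   OP_MUL_LOW                  -> a signed,   b signed
        #   OP_MUL_SIGNED_HIGH          -> a signed,   b signed
        #   OP_MUL_SIGNED_UNSIGNED_HIGH -> a signed,   b unsigned
        #   OP_MUL_UNSIGNED_HIGH        -> a unsigned, b unsigned
        # (for OP_MUL_LOW the signedness makes no difference to the result,
        # since only the LSB half is returned)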
        m.d.comb += self.a_signed.eq(asig)
        m.d.comb += self.b_signed.eq(bsig)

        return m


class IntermediateData:

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.part_pts = part_pts.like()
        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                        for i in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        return [self.part_pts.eq(part_pts)] + \
               [self.intermediate_output.eq(intermediate_output)] + \
               [self.outputs[i].eq(outputs[i])
                for i in range(4)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        return self.eq_from(rhs.part_pts, rhs.outputs,
                            rhs.intermediate_output, rhs.part_ops)


class InputData:

    def __init__(self):
        self.a = Signal(64)
        self.b = Signal(64)
        self.part_pts = PartitionPoints()
        for i in range(8, 64, 8):
            self.part_pts[i] = Signal(name=f"part_pts_{i}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]

    def eq_from(self, part_pts, a, b, part_ops):
        return [self.part_pts.eq(part_pts)] + \
               [self.a.eq(a), self.b.eq(b)] + \
               [self.part_ops[i].eq(part_ops[i])
                for i in range(len(self.part_ops))]

    def eq(self, rhs):
        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)


class OutputData:

    def __init__(self):
        self.intermediate_output = Signal(128)  # needed for unit tests
        self.output = Signal(64)

    def eq(self, rhs):
        return [self.intermediate_output.eq(rhs.intermediate_output),
                self.output.eq(rhs.output)]


class AllTerms(PipeModBase):
    """Set of terms to be added together
    """

    def __init__(self, pspec, n_inputs):
        """Create an ``AllTerms``.
        """
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        super().__init__(pspec, "allterms")

    def ispec(self):
        return InputData()

    def ospec(self):
        return AddReduceData(self.i.part_pts, self.n_inputs,
                             self.output_width, self.n_parts)

    def elaborate(self, platform):
        m = Module()

        eps = self.i.part_pts

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        tl = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(eps.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # local variables
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.i.part_ops[i])

        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.i.a)
            m.d.comb += mod.b.eq(self.i.b)
            for i in range(len(signs)):
                m.d.comb += mod.a_signed[i].eq(signs[i].a_signed)
                m.d.comb += mod.b_signed[i].eq(signs[i].b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.i.a)
            m.d.comb += t.b.eq(self.i.b)
            m.d.comb += t.pb_en.eq(pbs)

            for term in t.terms:
                terms.append(term)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for l, mod in [(nat_l, nat_or),
                       (nbt_l, nbt_or),
                       (nla_l, nla_or),
                       (nlb_l, nlb_or)]:
            for i in range(len(l)):
                m.d.comb += mod.orin[i].eq(l[i])
            terms.append(mod.orout)

        # copy the intermediate terms to the output
        for i, value in enumerate(terms):
            m.d.comb += self.o.terms[i].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(eps)
        m.d.comb += [self.o.part_ops[i].eq(self.i.part_ops[i])
                     for i in range(len(self.i.part_ops))]

        return m


class Intermediates(PipeModBase):
    """ Intermediate output modules
    """

    def __init__(self, pspec, part_pts):
        self.part_pts = part_pts
        self.output_width = pspec.width * 2
        self.n_parts = pspec.n_parts

        super().__init__(pspec, "intermediates")

    def ispec(self):
        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def elaborate(self, platform):
        m = Module()

        out_part_ops = self.i.part_ops
        out_part_pts = self.i.part_pts

        # create _output_64
        m.submodules.io64 = io64 = IntermediateOut(64, 128, 1)
        m.d.comb += io64.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io64.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[3].eq(io64.output)

        # create _output_32
        m.submodules.io32 = io32 = IntermediateOut(32, 128, 2)
        m.d.comb += io32.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io32.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[2].eq(io32.output)

        # create _output_16
        m.submodules.io16 = io16 = IntermediateOut(16, 128, 4)
        m.d.comb += io16.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io16.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[1].eq(io16.output)

        # create _output_8
        m.submodules.io8 = io8 = IntermediateOut(8, 128, 8)
        m.d.comb += io8.intermed.eq(self.i.output)
        for i in range(8):
            m.d.comb += io8.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.outputs[0].eq(io8.output)

        for i in range(8):
            m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.part_pts.eq(out_part_pts)
        m.d.comb += self.o.intermediate_output.eq(self.i.output)

        return m


class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    XXX NOTE: this class is intended for unit test purposes ONLY.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        self.id_wid = 0  # num_bits(num_rows)
        self.op_wid = 0
        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
        self.pspec.n_parts = 8

        # parameter(s)
        self.register_levels = list(register_levels)

        self.i = self.ispec()
        self.o = self.ospec()

        # inputs
        self.part_pts = self.i.part_pts
        self.part_ops = self.i.part_ops
        self.a = self.i.a
        self.b = self.i.b

        # output
        self.intermediate_output = self.o.intermediate_output
        self.output = self.o.output

    def ispec(self):
        return InputData()

    def ospec(self):
        return OutputData()

    def elaborate(self, platform):
        m = Module()

        part_pts = self.part_pts

        n_inputs = 64 + 4
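        # 64 byte-by-byte partial products (8x8) plus the four OR-combined
        # sign-correction terms (not_a, not_b, neg_lsb_a, neg_lsb_b),
        # all generated inside AllTerms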
        t = AllTerms(self.pspec, n_inputs)
        t.setup(m, self.i)

        terms = t.o.terms

        at = AddReduceInternal(self.pspec, n_inputs,
                               part_pts, partition_step=2)

        i = t.o
        for idx in range(len(at.levels)):
            mcur = at.levels[idx]
            mcur.setup(m, i)
            o = mcur.ospec()
            if idx in self.register_levels:
                m.d.sync += o.eq(mcur.process(i))
            else:
                m.d.comb += o.eq(mcur.process(i))
            i = o  # for next loop

        interm = Intermediates(self.pspec, part_pts)
        interm.setup(m, i)
        o = interm.process(interm.i)

        # final output
        finalout = FinalOut(self.pspec, part_pts)
        finalout.setup(m, o)
        m.d.comb += self.o.eq(finalout.process(o))

        return m


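# running this file directly hands the module to nmigen.cli's main(), whose
# usual "generate" front-end can emit RTLIL or Verilog, e.g. (illustrative):
#     python multiply.py generate -t il > multiply.il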
if __name__ == "__main__":
    m = Mul8_16_32_64()
    main(m, ports=[m.a,
                   m.b,
                   m.intermediate_output,
                   m.output,
                   *m.part_ops,
                   *m.part_pts.values()])