Convert add and sub to return PartitionedSignal
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11 from ieee754.pipeline import PipelineSpec
12 from nmutil.pipemodbase import PipeModBase
13
14 from ieee754.part_mul_add.partpoints import PartitionPoints
15 from ieee754.part_mul_add.adder import PartitionedAdder, MaskedFullAdder
16
17
# Number of terms consumed by one MaskedFullAdder (in0, in1, in2).  Each
# adder emits two terms (sum and masked carry), so every reduction level
# groups its inputs in threes.
FULL_ADDER_INPUT_COUNT: int = 3
19
class AddReduceData:
    """Data passed between add-reduce stages.

    Carries the partition points, one 2-bit op code per partition and the
    list of terms still waiting to be summed.

    :param part_pts: partition points whose shape is copied via ``like()``.
    :param n_inputs: number of terms.
    :param output_width: bit-width of every term.
    :param n_parts: number of ``part_ops`` entries.
    """

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        self.part_ops = []
        for idx in range(n_parts):
            self.part_ops.append(Signal(2, name=f"part_ops_{idx}",
                                        reset_less=True))
        self.terms = []
        for idx in range(n_inputs):
            self.terms.append(Signal(output_width, name=f"terms_{idx}",
                                     reset_less=True))
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        """Return the assignments copying the given values into this object.

        ``inputs`` and ``part_ops`` are indexed once per local term / op
        code, so each must be at least as long as its local counterpart;
        any extra entries are ignored.
        """
        assignments = [self.part_pts.eq(part_pts)]
        for idx, term in enumerate(self.terms):
            assignments.append(term.eq(inputs[idx]))
        for idx, op in enumerate(self.part_ops):
            assignments.append(op.eq(part_ops[idx]))
        return assignments

    def eq(self, rhs):
        """Return the assignments copying every field of *rhs*."""
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
39
40
class FinalReduceData:
    """Data produced by the final add stage.

    Holds the single summed ``output`` together with the partition points
    and the per-partition op codes.

    :param part_pts: partition points whose shape is copied via ``like()``.
    :param output_width: bit-width of ``output``.
    :param n_parts: number of ``part_ops`` entries.
    """

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{idx}", reset_less=True)
                         for idx in range(n_parts)]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Return the assignments copying the given values into this object."""
        stmts = [self.part_pts.eq(part_pts), self.output.eq(output)]
        stmts.extend(op.eq(part_ops[idx])
                     for idx, op in enumerate(self.part_ops))
        return stmts

    def eq(self, rhs):
        """Return the assignments copying every field of *rhs*."""
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
57
58
class FinalAdd(PipeModBase):
    """Last stage of the add-reduce tree.

    Sums the remaining terms, of which there are at most two, using a
    ``PartitionedAdder``.  The partition points and op codes are forwarded
    unchanged.
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        """Create the final add stage.

        :param pspec: pipeline spec; ``pspec.width * 2`` is the term width.
        :param lidx: index of this stage within the reduction chain.
        :param n_inputs: number of input terms (0, 1 or 2).
        :param partition_points: the input partition points.
        :param partition_step: forwarded to ``PartitionedAdder``.
        :raises ValueError: if the partition points do not fit the width.
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.output_width = pspec.width * 2
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        super().__init__(pspec, "finaladd")

    def ispec(self):
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        return FinalReduceData(self.partition_points,
                               self.output_width, self.n_parts)

    def elaborate(self, platform):
        """Sum the remaining terms and forward the metadata."""
        m = Module()
        comb = m.d.comb
        n_in = self.n_inputs

        output = Signal(self.output_width, reset_less=True)
        if n_in == 0:
            # nothing to add: the sum is zero
            comb += output.eq(0)
        elif n_in == 1:
            # a single term passes straight through
            comb += output.eq(self.i.terms[0])
        else:
            assert n_in == 2
            adder = PartitionedAdder(self.output_width, self.i.part_pts,
                                     self.partition_step)
            m.submodules.final_adder = adder
            comb += [adder.a.eq(self.i.terms[0]),
                     adder.b.eq(self.i.terms[1]),
                     output.eq(adder.output)]

        comb += self.o.eq_from(self.i.part_pts, output, self.i.part_ops)
        return m
111
112
class AddReduceSingle(PipeModBase):
    """One level of the add-reduce tree.

    Each complete group of ``FULL_ADDER_INPUT_COUNT`` input terms feeds a
    ``MaskedFullAdder``, and the adder's sum and masked carry become two
    output terms.  Any one or two leftover inputs are passed straight
    through.  Partition points and op codes are copied to the output.

    :attribute i: ``AddReduceData`` holding ``n_inputs`` terms.
    :attribute o: ``AddReduceData`` holding ``n_terms`` terms.
    :attribute groups: start index of each full-adder group.
    :attribute partition_points: the input partition points.
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        """Create one reduction level.

        :param pspec: pipeline spec; ``pspec.width * 2`` is the term width
            and ``pspec.n_parts`` is the number of op codes.
        :param lidx: level index, used in the module name.
        :param n_inputs: number of input terms.
        :param partition_points: the input partition points.
        :param partition_step: multiplier used when building the adder mask.
        :raises ValueError: if the partition points do not fit the width.
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)

        super().__init__(pspec, "addreduce_%d" % lidx)

    def ispec(self):
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        return AddReduceData(self.partition_points, self.n_terms,
                             self.output_width, self.n_parts)

    @staticmethod
    def calc_n_inputs(n_inputs, groups):
        """Return the number of terms this level emits.

        Each full-adder group yields two terms, and the leftover inputs
        (``n_inputs`` modulo the group size) are passed through unchanged.
        """
        leftover = n_inputs % FULL_ADDER_INPUT_COUNT
        assert leftover in (0, 1, 2)
        return 2 * len(groups) + leftover

    @staticmethod
    def get_max_level(input_count):
        """Return how many reduction levels *input_count* terms require.

        All ``register_levels`` must be less than or equal to this value.
        """
        level = 0
        groups = AddReduceSingle.full_adder_groups(input_count)
        while len(groups) > 0:
            input_count = (input_count % FULL_ADDER_INPUT_COUNT
                           + 2 * len(groups))
            level += 1
            groups = AddReduceSingle.full_adder_groups(input_count)
        return level

    @staticmethod
    def full_adder_groups(input_count):
        """Return the start index of every complete full-adder group."""
        stop = input_count - FULL_ADDER_INPUT_COUNT + 1
        return range(0, stop, FULL_ADDER_INPUT_COUNT)

    def create_next_terms(self):
        """Build the full adders and list the terms for the next level.

        :returns: ``(terms, adders)``.  ``terms`` lists the adder outputs
            (sum, masked carry) followed by any pass-through inputs.
            ``adders`` holds ``(start_index, adder)`` pairs.
        """
        terms, adders = [], []
        for start in self.groups:
            fa = MaskedFullAdder(self.output_width)
            adders.append((start, fa))
            # three inputs collapse into two outputs
            terms += [fa.sum, fa.mcarry]

        leftover = self.n_inputs % FULL_ADDER_INPUT_COUNT
        if leftover == 1:
            terms.append(self.i.terms[-1])
        elif leftover == 2:
            # A half adder would still leave two terms, so passing both
            # through unchanged saves gates.
            terms.extend(self.i.terms[-2:])
        else:
            assert leftover == 0

        return terms, adders

    def elaborate(self, platform):
        """Wire up the full adders and the pass-through terms."""
        m = Module()
        comb = m.d.comb

        terms, adders = self.create_next_terms()

        for idx, value in enumerate(terms):
            comb += self.o.terms[idx].eq(value)

        # partition points and op codes are forwarded unchanged
        comb += self.o.part_pts.eq(self.i.part_pts)
        comb += [self.o.part_ops[idx].eq(src)
                 for idx, src in enumerate(self.i.part_ops)]

        # partition mask shared by every adder at this level
        part_mask = Signal(self.output_width, reset_less=True)
        comb += part_mask.eq(self.i.part_pts.as_mask(self.output_width,
                                                     mul=self.partition_step))

        for n, (start, fa) in enumerate(adders):
            setattr(m.submodules, f"adder_{n}", fa)
            comb += fa.in0.eq(self.i.terms[start])
            comb += fa.in1.eq(self.i.terms[start + 1])
            comb += fa.in2.eq(self.i.terms[start + 2])
            comb += fa.mask.eq(part_mask)

        return m
251
252
class AddReduceInternal:
    """Builds the chain of add-reduce pipeline stages.

    ``create_levels`` fills ``self.levels`` with zero or more
    ``AddReduceSingle`` stages followed by exactly one ``FinalAdd``.  This
    class only constructs the stages; connecting them is left to the
    subclass (e.g. ``AddReduce.elaborate``).
    """

    def __init__(self, pspec, n_inputs, part_pts, partition_step=1):
        """Create the stage chain.

        :param pspec: pipeline spec; ``pspec.width * 2`` is the term width.
        :param n_inputs: number of terms to be summed.
        :param part_pts: the input partition points.
        :param partition_step: forwarded to every stage.
        """
        self.pspec = pspec
        self.n_inputs = n_inputs
        self.output_width = pspec.width * 2
        self.partition_points = part_pts
        self.partition_step = partition_step

        self.create_levels()

    def create_levels(self):
        """Populate ``self.levels`` with the reduction stages."""
        levels = []
        pts = self.partition_points
        n_terms = self.n_inputs
        # keep adding levels while at least one full-adder group fits
        while len(AddReduceSingle.full_adder_groups(n_terms)) > 0:
            stage = AddReduceSingle(self.pspec, len(levels), n_terms, pts,
                                    self.partition_step)
            levels.append(stage)
            pts = stage.i.part_pts
            n_terms = len(stage.o.terms)

        levels.append(FinalAdd(self.pspec, len(levels), n_terms, pts,
                               self.partition_step))
        self.levels = levels
304
305
class AddReduce(AddReduceInternal, Elaboratable):
    """Add a list of numbers together using a tree of reduction stages.

    :attribute i: ``AddReduceData`` driven from the constructor arguments.
    :attribute o: ``FinalReduceData`` holding the sum in ``o.output``.
    :attribute levels: the reduction stages built by ``AddReduceInternal``.
    :attribute register_levels: indices of stages whose inputs are latched
        in the ``sync`` domain; all other stages are combinatorial.
    """

    def __init__(self, inputs, output_width, register_levels, part_pts,
                 part_ops, partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of the output.
        :param register_levels: list of stage indices that should have
            pipeline registers.
        :param part_pts: the input partition points.
        :param part_ops: per-partition operation codes.
        :param partition_step: multiplier applied to the partition points
            when building adder masks.
        """
        # local import: only needed to build the lightweight stage spec
        from types import SimpleNamespace

        self._inputs = inputs
        self._part_pts = part_pts
        self._part_ops = part_ops
        n_inputs = len(inputs)
        n_parts = len(part_ops)
        self.i = AddReduceData(part_pts, n_inputs, output_width, n_parts)

        # The visible stage code (AddReduceInternal, AddReduceSingle,
        # FinalAdd) reads only ``pspec.width`` (half the term width) and
        # ``pspec.n_parts``.  PipeModBase also receives the spec and is
        # assumed to just store it (confirm).  Rounding the width up keeps
        # odd output widths correct: the internal sum is one bit wider, and
        # ``self.o.eq`` truncates it back to ``output_width``.
        pspec = SimpleNamespace(width=(output_width + 1) // 2,
                                n_parts=n_parts)
        AddReduceInternal.__init__(self, pspec, n_inputs, part_pts,
                                   partition_step)
        self.o = FinalReduceData(part_pts, output_width, n_parts)
        self.register_levels = register_levels

    @staticmethod
    def get_max_level(input_count):
        """Return the number of reduction levels for *input_count* terms."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        m.d.comb += self.i.eq_from(self._part_pts, self._inputs,
                                   self._part_ops)

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # Chain the stages.  Stages listed in register_levels get a
        # registered input; all others are combinatorial.
        i = self.i
        for idx, mcur in enumerate(self.levels):
            if idx in self.register_levels:
                m.d.sync += mcur.i.eq(i)
            else:
                m.d.comb += mcur.i.eq(i)
            i = mcur.o  # for next loop

        # output comes from last module
        m.d.comb += self.o.eq(i)

        return m
372
373
# Per-partition multiply operation codes (2 bits each):
#   OP_MUL_LOW                  - low half of the product
#   OP_MUL_SIGNED_HIGH          - high half, a and b both signed
#   OP_MUL_SIGNED_UNSIGNED_HIGH - high half, a is signed, b is unsigned
#   OP_MUL_UNSIGNED_HIGH        - high half, a and b both unsigned
(OP_MUL_LOW,
 OP_MUL_SIGNED_HIGH,
 OP_MUL_SIGNED_UNSIGNED_HIGH,
 OP_MUL_UNSIGNED_HIGH) = range(4)
378
379
def get_term(value, shift=0, enabled=None):
    """Optionally gate *value* and shift it left by *shift* zero bits.

    :param value: the term to place.
    :param shift: number of zero bits prepended at the LSB end; must be
        non-negative.
    :param enabled: if not ``None``, the term is replaced by 0 when this
        condition is false.
    :returns: the gated and shifted expression (or *value* unchanged).
    """
    if enabled is not None:
        value = Mux(enabled, value, 0)
    if shift == 0:
        return value
    assert shift > 0
    return Cat(Repl(C(0, 1), shift), value)
388
389
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
    it has a design flaw in that it is the *output* that is selected,
    whereas the multiplication(s) are combinatorially generated
    all the time.

    :param width: lane width in bits (``pwidth``).  Lane ``a_index`` of
        ``a`` and lane ``b_index`` of ``b`` are ``width``-bit slices.
    :param twidth: width of the output ``term``; ``a`` and ``b`` are
        ``twidth // 2`` bits wide.
    :param pbwid: width of the ``pb_en`` partition-bit input.
    :param a_index: lane index into ``a``.
    :param b_index: lane index into ``b``.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        self.a_index = a_index
        self.b_index = b_index
        # The partial product of lanes a_index and b_index carries weight
        # 2 ** (width * (a_index + b_index)).  Deriving the shift from the
        # lane width, rather than a fixed 8, keeps non-8-bit lanes correct.
        shift = width * (self.a_index + self.b_index)
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = shift

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # pb_en bits lying between the two lane indices; the term is enabled
        # only when all of them are clear (presumably "no partition boundary
        # between lanes a and b" -- confirm against PartitionPoints).
        self.tl = tl = []
        min_index = min(self.a_index, self.b_index)
        max_index = max(self.a_index, self.b_index)
        for i in range(min_index, max_index):
            tl.append(self.pb_en[i])
        name = "te_%d_%d" % (self.a_index, self.b_index)
        if len(tl) > 0:
            term_enabled = Signal(name=name, reset_less=True)
        else:
            term_enabled = None
        self.enabled = term_enabled
        self.term.name = "term_%d_%d" % (a_index, b_index) # rename

    def elaborate(self, platform):
        m = Module()
        if self.enabled is not None:
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        pwidth = self.pwidth
        m.d.comb += bsa.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsb.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(get_term(self.ti, self.shift, self.enabled))
        # TODO: gate the a/b inputs instead of the product (the design flaw
        # noted in the class docstring).  An earlier attempt stalled because
        # the data going into the Muxes was only half the required width.

        return m
459
460
class ProductTerms(Elaboratable):
    """A bank of product terms for one lane of ``a``.

    One ``ProductTerm`` is created for each of the ``blen`` lanes of ``b``;
    the bit-selection happens inside each term.  Callers wrap this class
    in a for-loop over the lanes of ``a``.
    """

    def __init__(self, width, twidth, pbwid, a_index, blen):
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [Signal(twidth, name="term%d" % n, reset_less=True)
                      for n in range(blen)]

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        for b_index in range(self.blen):
            pt = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                             self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, pt)
            comb += [pt.a.eq(self.a),
                     pt.b.eq(self.b),
                     pt.pb_en.eq(self.pb_en),
                     self.terms[b_index].eq(pt.term)]

        return m
494
495
class LSBNegTerm(Elaboratable):
    """Sign-correction terms for one operand lane.

    Two's-complement negation is split into a bit inversion (``nt``) plus
    a ``+1`` (``nl``).  Both outputs are ``2 * bit_width`` wide and placed
    in the upper half.  Both are zero unless ``part``, ``msb`` and
    ``signed`` are all set.
    """

    def __init__(self, bit_width):
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        width = self.bit_width
        # zero-filled low half: pushes the correction into the HI part
        zeros = Repl(0, width)

        # is the incoming number negative *within this partition*?
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # e.g. for 8-bit lanes, a * 0xFF00 is formed as -a * 0x100, with
        # -a = ~a + 1; the same scheme is used for 16, 32 and 64 bits.
        m.d.comb += [
            # width-extended 1's complement when signed, otherwise zero
            self.nt.eq(Mux(enabled, Cat(zeros, ~self.op), 0)),
            # the "+1" at the bottom of the HI part (zero when not signed)
            self.nl.eq(Cat(zeros, enabled, Repl(0, width - 1))),
        ]

        return m
528
529
class Parts(Elaboratable):
    """Decode partition points into one flag per partition.

    The ``pbwid`` bits from ``part_pts.part_byte(i)`` are split into
    ``len(parts)`` equal groups.  ``parts[i]`` is set when:

    * the bit just before group ``i`` is set;
    * the bit at the end of group ``i`` is set; and
    * no other bit inside group ``i`` is set.
    """

    def __init__(self, pbwid, part_pts, n_parts):
        self.pbwid = pbwid
        # inputs
        self.part_pts = PartitionPoints.like(part_pts)
        # outputs
        self.parts = [Signal(name=f"part_{n}", reset_less=True)
                      for n in range(n_parts)]

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        part_pts, parts = self.part_pts, self.parts

        # gather one partition bit per byte position
        pbs = Signal(self.pbwid, reset_less=True)
        bits = []
        for idx in range(self.pbwid):
            pb = Signal(name="pb%d" % idx, reset_less=True)
            comb += pb.eq(part_pts.part_byte(idx))
            bits.append(pb)
        comb += pbs.eq(Cat(*bits))

        # inverted copy: a 0 here marks a set partition bit
        npbs = Signal.like(pbs, reset_less=True)
        comb += npbs.eq(~pbs)

        byte_count = 8 // len(parts)
        for i, part in enumerate(parts):
            first = i * byte_count
            last = (i + 1) * byte_count - 1
            # For i == 0, ``first - 1`` is -1, which indexes the top bit
            # of npbs.
            pbl = [npbs[first - 1]]
            pbl += [pbs[j] for j in range(first, last)]
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            comb += value.eq(Cat(*pbl))
            comb += part.eq(~(value).bool())

        return m
568
569
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
    what action to take when parts of the output are signed or unsigned.

    this requires 2 pieces of data *per operand, per partition*:
    whether the MSB is HI/LO (per partition!), and whether a signed
    or unsigned operation has been *requested*.

    once that is determined, signed is basically carried out
    by splitting 2's complement into 1's complement plus one.
    1's complement is just a bit-inversion.

    the extra terms - as separate terms - are then thrown at the
    AddReduce alongside the multiplication part-results.

    :param part_pts: partition points, decoded by an internal ``Parts``.
    :param width: width of the four correction-term outputs.
    :param n_parts: number of partitions (lanes) at this lane width.
    :param pbwid: number of partition bits handed to ``Parts``.
    """
    def __init__(self, part_pts, width, n_parts, pbwid):

        self.pbwid = pbwid
        self.part_pts = part_pts

        # inputs
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
                         for i in range(8)]
        # NOTE(review): elaborate() does not read this input; the partition
        # bits are re-derived from part_pts by the internal Parts decoder.
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

        self.not_a_term = Signal(width, reset_less=True)
        self.neg_lsb_a_term = Signal(width, reset_less=True)
        self.not_b_term = Signal(width, reset_less=True)
        self.neg_lsb_b_term = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        part_pts = self.part_pts
        m.submodules.p = p = Parts(self.pbwid, part_pts, len(self.parts))
        m.d.comb += p.part_pts.eq(part_pts)
        parts = p.parts
        # drive the declared ``parts`` outputs from the decoder
        m.d.comb += [out.eq(dec) for out, dec in zip(self.parts, parts)]

        not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = (
                self.not_a_term, self.neg_lsb_a_term,
                self.not_b_term, self.neg_lsb_b_term)

        byte_width = 8 // len(parts) # byte width
        bit_wid = 8 * byte_width     # bit width
        nat, nbt, nla, nlb = [], [], [], []
        for i in range(len(parts)):
            # work out bit-inverted and +1 term for a.
            pa = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_a_%d" % (bit_wid, i), pa)
            m.d.comb += pa.part.eq(parts[i])
            m.d.comb += pa.op.eq(self.a.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pa.signed.eq(self.b_signed[i * byte_width]) # yes b
            m.d.comb += pa.msb.eq(self.b[(i + 1) * bit_wid - 1]) # really, b
            nat.append(pa.nt)
            nla.append(pa.nl)

            # work out bit-inverted and +1 term for b
            pb = LSBNegTerm(bit_wid)
            setattr(m.submodules, "lnt_%d_b_%d" % (bit_wid, i), pb)
            m.d.comb += pb.part.eq(parts[i])
            m.d.comb += pb.op.eq(self.b.bit_select(bit_wid * i, bit_wid))
            m.d.comb += pb.signed.eq(self.a_signed[i * byte_width]) # yes a
            m.d.comb += pb.msb.eq(self.a[(i + 1) * bit_wid - 1]) # really, a
            nbt.append(pb.nt)
            nlb.append(pb.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [not_a_term.eq(Cat(*nat)),
                     not_b_term.eq(Cat(*nbt)),
                     neg_lsb_a_term.eq(Cat(*nla)),
                     neg_lsb_b_term.eq(Cat(*nlb)),
                    ]

        return m
655
656
class IntermediateOut(Elaboratable):
    """Select the HI or LO half of each lane's double-width product.

    ``intermed`` holds ``n_parts`` lanes, each ``2 * width`` bits wide.  For
    each lane, the op code of the lane's first byte decides whether the low
    half (``OP_MUL_LOW``) or the high half (any other op) is kept.  The
    selected halves are packed back into SIMD lanes in ``output``.
    """
    def __init__(self, width, out_wid, n_parts):
        self.width = width
        self.n_parts = n_parts
        self.part_ops = [Signal(2, name="dpop%d" % n, reset_less=True)
                         for n in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        w = self.width
        stride = w // 8  # bytes per lane: op code index of each lane start
        lanes = []
        for lane in range(self.n_parts):
            base = lane * w * 2
            lo = self.intermed.bit_select(base, w)
            hi = self.intermed.bit_select(base + w, w)
            op = Signal(w, reset_less=True, name="op%d_%d" % (w, lane))
            is_low = self.part_ops[stride * lane] == OP_MUL_LOW
            m.d.comb += op.eq(Mux(is_low, lo, hi))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
685
686
class FinalOut(PipeModBase):
    """Assemble the final output byte-by-byte according to the partitioning.

    Each output byte is chosen independently from the 8-, 16-, 32- or
    64-bit intermediate result, so different partitions may have requested
    different lane widths.
    """
    def __init__(self, pspec, part_pts):

        self.part_pts = part_pts
        self.output_width = pspec.width * 2
        self.n_parts = pspec.n_parts
        self.out_wid = pspec.width

        super().__init__(pspec, "finalout")

    def ispec(self):
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        return OutputData()

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        part_pts = self.part_pts

        # one partition decoder per lane width: (lane bits, lane count)
        decoders = {}
        for lane_bits, n_lanes in ((8, 8), (16, 4), (32, 2), (64, 1)):
            dec = Parts(8, part_pts, n_lanes)
            setattr(m.submodules, "p_%d" % lane_bits, dec)
            decoders[lane_bits] = dec

        out_part_pts = self.i.part_pts

        # temporaries
        d8 = [Signal(name=f"d8_{n}", reset_less=True) for n in range(8)]
        d16 = [Signal(name=f"d16_{n}", reset_less=True) for n in range(4)]
        d32 = [Signal(name=f"d32_{n}", reset_less=True) for n in range(2)]

        i8 = Signal(self.out_wid, reset_less=True)
        i16 = Signal(self.out_wid, reset_less=True)
        i32 = Signal(self.out_wid, reset_less=True)
        i64 = Signal(self.out_wid, reset_less=True)

        for dec in decoders.values():
            comb += dec.part_pts.eq(out_part_pts)
        for flags, lane_bits in ((d8, 8), (d16, 16), (d32, 32)):
            for dst, src in zip(flags, decoders[lane_bits].parts):
                comb += dst.eq(src)
        for dst, src in zip((i8, i16, i32, i64), self.i.outputs):
            comb += dst.eq(src)

        ol = []
        for n in range(8):
            # d8/d16 (ORed) choose between i8 and i16; otherwise d32
            # chooses between i32 and the default i64.
            small = Mux(d8[n], i8.bit_select(n * 8, 8),
                        i16.bit_select(n * 8, 8))
            large = Mux(d32[n // 4], i32.bit_select(n * 8, 8),
                        i64.bit_select(n * 8, 8))
            op = Signal(8, reset_less=True, name="op_%d" % n)
            comb += op.eq(Mux(d8[n] | d16[n // 2], small, large))
            ol.append(op)

        # create outputs
        comb += self.o.output.eq(Cat(*ol))
        comb += self.o.intermediate_output.eq(self.i.intermediate_output)

        return m
767
768
class OrMod(Elaboratable):
    """OR four equal-width inputs together as a two-level tree.

    ``orout = (orin[0] | orin[1]) | (orin[2] | orin[3])``
    """
    def __init__(self, wid):
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % n, reset_less=True)
                     for n in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        w, x, y, z = self.orin
        m.d.comb += [or1.eq(w | x),
                     or2.eq(y | z),
                     self.orout.eq(or1 | or2)]

        return m
787
788
class Signs(Elaboratable):
    """Decide whether operands a and b are signed from an ``OP_MUL_*`` code.

    * ``a_signed`` is set for every op except ``OP_MUL_UNSIGNED_HIGH``.
    * ``b_signed`` is set only for ``OP_MUL_LOW`` and ``OP_MUL_SIGNED_HIGH``.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        op = self.part_ops
        m.d.comb += [
            self.a_signed.eq(op != OP_MUL_UNSIGNED_HIGH),
            self.b_signed.eq((op == OP_MUL_LOW) | (op == OP_MUL_SIGNED_HIGH)),
        ]
        return m
810
811
class IntermediateData:
    """Data passed from ``Intermediates`` to ``FinalOut``.

    Holds four lane-width candidate outputs (``outputs[0..3]``), the raw
    product (``intermediate_output``), the partition points and the
    per-partition op codes.
    """

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [Signal(2, name=f"part_ops_{n}", reset_less=True)
                         for n in range(n_parts)]
        self.part_pts = part_pts.like()
        self.outputs = [Signal(output_width, name="io%d" % n, reset_less=True)
                        for n in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        """Return the assignments copying the given values into this object."""
        result = [self.part_pts.eq(part_pts),
                  self.intermediate_output.eq(intermediate_output)]
        result.extend(self.outputs[n].eq(outputs[n]) for n in range(4))
        result.extend(dst.eq(part_ops[n])
                      for n, dst in enumerate(self.part_ops))
        return result

    def eq(self, rhs):
        """Return the assignments copying every field of *rhs*."""
        return self.eq_from(rhs.part_pts, rhs.outputs,
                            rhs.intermediate_output, rhs.part_ops)
835
836
class InputData:
    """Multiplier inputs.

    Two 64-bit operands ``a`` and ``b``, a partition point at every byte
    boundary (bits 8, 16, ..., 56) and one 2-bit op code per byte.
    """

    def __init__(self):
        self.a = Signal(64)
        self.b = Signal(64)
        self.part_pts = PartitionPoints()
        for bit in range(8, 64, 8):
            self.part_pts[bit] = Signal(name=f"part_pts_{bit}")
        self.part_ops = [Signal(2, name=f"part_ops_{n}") for n in range(8)]

    def eq_from(self, part_pts, a, b, part_ops):
        """Return the assignments copying the given values into this object."""
        stmts = [self.part_pts.eq(part_pts), self.a.eq(a), self.b.eq(b)]
        stmts += [op.eq(part_ops[n]) for n, op in enumerate(self.part_ops)]
        return stmts

    def eq(self, rhs):
        """Return the assignments copying every field of *rhs*."""
        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
855
856
class OutputData:
    """Multiplier outputs: the 64-bit result and the raw 128-bit product."""

    def __init__(self):
        self.intermediate_output = Signal(128) # needed for unit tests
        self.output = Signal(64)

    def eq(self, rhs):
        """Return the assignments copying both fields of *rhs*."""
        pairs = ((self.intermediate_output, rhs.intermediate_output),
                 (self.output, rhs.output))
        return [dst.eq(src) for dst, src in pairs]
866
867
class AllTerms(PipeModBase):
    """Set of terms to be added together.

    Produces 64 partial products (8 lanes of ``a`` by 8 lanes of ``b``),
    followed by four sign-correction terms (not-a, not-b, +1 for a, +1 for
    b).  Each correction term ORs together the contributions for the 8-,
    16-, 32- and 64-bit partitionings.  ``n_inputs`` is presumably 68 to
    match; confirm with the caller.
    """

    def __init__(self, pspec, n_inputs):
        """Create an ``AllTerms``.
        """
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        super().__init__(pspec, "allterms")

    def ispec(self):
        return InputData()

    def ospec(self):
        return AddReduceData(self.i.part_pts, self.n_inputs,
                             self.output_width, self.n_parts)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        eps = self.i.part_pts

        # collect one partition bit per byte
        pbs = Signal(8, reset_less=True)
        byte_bits = []
        for n in range(8):
            pb = Signal(name="pb%d" % n, reset_less=True)
            comb += pb.eq(eps.part_byte(n))
            byte_bits.append(pb)
        comb += pbs.eq(Cat(*byte_bits))

        # per-byte signedness decoders
        signs = []
        for n in range(8):
            s = Signs()
            setattr(m.submodules, "signs%d" % n, s)
            comb += s.part_ops.eq(self.i.part_ops[n])
            signs.append(s)

        # sign-correction generators, one per lane width
        part_mods = []
        for lane_bits, n_lanes in ((8, 8), (16, 4), (32, 2), (64, 1)):
            mod = Part(eps, 128, n_lanes, 8)
            setattr(m.submodules, "part_%d" % lane_bits, mod)
            part_mods.append(mod)

        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in part_mods:
            comb += mod.a.eq(self.i.a)
            comb += mod.b.eq(self.i.b)
            for n, s in enumerate(signs):
                comb += [mod.a_signed[n].eq(s.a_signed),
                         mod.b_signed[n].eq(s.b_signed)]
            comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        # partial products: 8 lanes of a, each against 8 lanes of b
        terms = []
        for a_index in range(8):
            bank = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, bank)
            comb += [bank.a.eq(self.i.a),
                     bank.b.eq(self.i.b),
                     bank.pb_en.eq(pbs)]
            terms.extend(bank.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        for name, sources in (("nat_or", nat_l), ("nbt_or", nbt_l),
                              ("nla_or", nla_l), ("nlb_or", nlb_l)):
            orm = OrMod(128)
            setattr(m.submodules, name, orm)
            for n, src in enumerate(sources):
                comb += orm.orin[n].eq(src)
            terms.append(orm.orout)

        # copy the intermediate terms to the output
        for n, value in enumerate(terms):
            comb += self.o.terms[n].eq(value)

        # copy reg part points and part ops to output
        comb += self.o.part_pts.eq(eps)
        comb += [self.o.part_ops[n].eq(src)
                 for n, src in enumerate(self.i.part_ops)]

        return m
963
964
class Intermediates(PipeModBase):
    """Split the summed product into per-lane-width candidate outputs.

    Produces ``outputs[0..3]`` for 8-, 16-, 32- and 64-bit lanes, each
    taking the HI or LO half of every lane.  Partition points, op codes and
    the raw product are forwarded unchanged.
    """

    def __init__(self, pspec, part_pts):
        self.part_pts = part_pts
        self.output_width = pspec.width * 2
        self.n_parts = pspec.n_parts

        super().__init__(pspec, "intermediates")

    def ispec(self):
        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        out_part_ops = self.i.part_ops
        out_part_pts = self.i.part_pts

        # (lane width in bits, number of lanes, slot in self.o.outputs)
        for lane_bits, n_lanes, slot in ((64, 1, 3), (32, 2, 2),
                                         (16, 4, 1), (8, 8, 0)):
            io = IntermediateOut(lane_bits, 128, n_lanes)
            setattr(m.submodules, "io%d" % lane_bits, io)
            comb += io.intermed.eq(self.i.output)
            comb += [io.part_ops[n].eq(out_part_ops[n]) for n in range(8)]
            comb += self.o.outputs[slot].eq(io.output)

        comb += [self.o.part_ops[n].eq(out_part_ops[n]) for n in range(8)]
        comb += self.o.part_pts.eq(out_part_pts)
        comb += self.o.intermediate_output.eq(self.i.output)

        return m
1022
1023
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    XXX NOTE: this class is intended for unit test purposes ONLY.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    The design is assembled from the pipeline stages ``AllTerms`` ->
    ``AddReduceInternal`` levels -> ``Intermediates`` -> ``FinalOut``.
    Optional registers can be placed between adder-tree levels (see
    ``register_levels``).

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except for by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    :attribute a: first multiplicand (alias of ``self.i.a``).
    :attribute b: second multiplicand (alias of ``self.i.b``).
    :attribute intermediate_output: alias of ``self.o.intermediate_output``.
    :attribute output: the final result (alias of ``self.o.output``).
    :attribute register_levels: list of adder-tree level indices whose
        outputs are registered on the ``sync`` domain; all other levels are
        combinatorial.
    """

    def __init__(self, register_levels=()):
        """ register_levels: specifies the points in the cascade at which
            flip-flops are to be inserted.
        """

        self.id_wid = 0 # num_bits(num_rows)
        self.op_wid = 0
        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
        self.pspec.n_parts = 8

        # parameter(s)
        self.register_levels = list(register_levels)

        self.i = self.ispec()
        self.o = self.ospec()

        # inputs
        self.part_pts = self.i.part_pts
        self.part_ops = self.i.part_ops
        self.a = self.i.a
        self.b = self.i.b

        # output
        self.intermediate_output = self.o.intermediate_output
        self.output = self.o.output

    def ispec(self):
        return InputData()

    def ospec(self):
        return OutputData()

    def elaborate(self, platform):
        m = Module()

        part_pts = self.part_pts

        # Number of terms handed to the adder tree.  Presumably 64 partial
        # products from the ProductTerms instances plus the four OR-merged
        # terms that AllTerms appends -- TODO confirm against AllTerms.
        n_inputs = 64 + 4
        t = AllTerms(self.pspec, n_inputs)
        t.setup(m, self.i)

        at = AddReduceInternal(self.pspec, n_inputs, part_pts, partition_step=2)

        # Chain the adder-tree levels together.  A level whose index is in
        # register_levels has its outputs latched on the sync domain; every
        # other level is wired combinatorially.
        i = t.o
        for idx, mcur in enumerate(at.levels):
            mcur.setup(m, i)
            o = mcur.ospec()
            if idx in self.register_levels:
                m.d.sync += o.eq(mcur.process(i))
            else:
                m.d.comb += o.eq(mcur.process(i))
            i = o # for next loop

        interm = Intermediates(self.pspec, part_pts)
        interm.setup(m, i)
        o = interm.process(interm.i)

        # final output
        finalout = FinalOut(self.pspec, part_pts)
        finalout.setup(m, o)
        m.d.comb += self.o.eq(finalout.process(o))

        return m
1120
1121
if __name__ == "__main__":
    # Build the test multiplier and hand it to nmigen's CLI.  The exposed
    # top-level ports are the operands, both result buses, the per-byte
    # operation codes and every partition-point signal, in that order.
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m.intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)
1128 *m.part_ops,
1129 *m.part_pts.values()])