use PipelineSpec and PipeModBase in AddReduce
[ieee754fpu.git] / src / ieee754 / part_mul_add / multiply.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3 """Integer Multiplication."""
4
5 from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl
6 from nmigen.hdl.ast import Assign
7 from abc import ABCMeta, abstractmethod
8 from nmigen.cli import main
9 from functools import reduce
10 from operator import or_
11 from ieee754.pipeline import PipelineSpec
12 from nmutil.pipemodbase import PipeModBase
13
14
class PartitionPoints(dict):
    """Partition points and corresponding ``Value``s.

    The points at where an ALU is partitioned along with ``Value``s that
    specify if the corresponding partition points are enabled.

    For example: ``{1: True, 5: True, 10: True}`` with
    ``width == 16`` specifies that the ALU is split into 4 sections:
    * bits 0 <= ``i`` < 1
    * bits 1 <= ``i`` < 5
    * bits 5 <= ``i`` < 10
    * bits 10 <= ``i`` < 16

    If the partition_points were instead ``{1: True, 5: a, 10: True}``
    where ``a`` is a 1-bit ``Signal``:
    * If ``a`` is asserted:
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 5
        * bits 5 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    * Otherwise
        * bits 0 <= ``i`` < 1
        * bits 1 <= ``i`` < 10
        * bits 10 <= ``i`` < 16
    """

    def __init__(self, partition_points=None):
        """Create a new ``PartitionPoints``.

        :param partition_points: the input partition points to values mapping.
        :raises TypeError: if a point is not an ``int``.
        :raises ValueError: if a point is negative.
        """
        super().__init__()
        if partition_points is not None:
            for point, enabled in partition_points.items():
                if not isinstance(point, int):
                    raise TypeError("point must be a non-negative integer")
                if point < 0:
                    raise ValueError("point must be a non-negative integer")
                self[point] = Value.wrap(enabled)

    def like(self, name=None, src_loc_at=0, mul=1):
        """Create a new ``PartitionPoints`` with ``Signal``s for all values.

        :param name: the base name for the new ``Signal``s.
        :param src_loc_at: extra stack depth added when a throw-away
                           ``Signal`` is created to discover a default name.
        :param mul: a multiplication factor on the indices
        """
        if name is None:
            # must stay directly in this frame: the temporary Signal is only
            # created to obtain a name from the calling code.
            name = Signal(src_loc_at=1+src_loc_at).name  # get variable name
        retval = PartitionPoints()
        for point, enabled in self.items():
            point *= mul
            retval[point] = Signal(enabled.shape(), name=f"{name}_{point}")
        return retval

    def eq(self, rhs):
        """Assign ``PartitionPoints`` using ``Signal.eq``.

        This is a generator: the key-set check (and its ``ValueError``) only
        runs once the result is iterated.
        """
        if set(self.keys()) != set(rhs.keys()):
            raise ValueError("incompatible point set")
        for point, enabled in self.items():
            yield enabled.eq(rhs[point])

    def as_mask(self, width, mul=1):
        """Create a bit-mask from `self`.

        Each bit in the returned mask is clear only if the partition point at
        the same bit-index is enabled.

        :param width: the bit width of the resulting mask
        :param mul: a "multiplier" which in-place expands the partition points
                    typically set to "2" when used for multipliers
        """
        bits = []
        for i in range(width):
            # integer arithmetic: bit i corresponds to point i // mul, and
            # only when i is an exact multiple of mul.
            point, remainder = divmod(i, mul)
            if remainder == 0 and point in self:
                bits.append(~self[point])
            else:
                bits.append(True)
        return Cat(*bits)

    def get_max_partition_count(self, width):
        """Get the maximum number of partitions.

        Gets the number of partitions when all partition points are enabled.
        Only points strictly below `width` are counted.
        """
        return 1 + sum(1 for point in self if point < width)

    def fits_in_width(self, width):
        """Check if all partition points are smaller than `width`."""
        return all(point < width for point in self)

    def part_byte(self, index, mfactor=1):  # mfactor used for "expanding"
        """Return the partition value at the upper edge of byte `index`.

        Indices -1 and 7 have no stored point and yield a constant 1; any
        other index must be in the range 0..6 and is looked up at bit
        position ``(index * 8 + 8) * mfactor``.
        """
        if index == -1 or index == 7:
            return C(True, 1)
        assert index >= 0 and index < 8
        return self[(index * 8 + 8)*mfactor]
118
119
class FullAdder(Elaboratable):
    """Bit-parallel 3-to-2 full adder.

    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute carry: the carry output

    Instead of instantiating one single-bit adder per bit (an array of
    modules, which would be very slow to simulate), every port is `width`
    bits wide and each bit position is an independent full adder.
    """

    def __init__(self, width):
        """Create a ``FullAdder``.

        :param width: the bit width of the input and output
        """
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)
        self.carry = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        x, y, z = self.in0, self.in1, self.in2
        m.d.comb += [
            # per-bit sum is the parity of the three inputs
            self.sum.eq(x ^ y ^ z),
            # per-bit carry is set when at least two inputs are set
            self.carry.eq((x & y) | (y & z) | (z & x)),
        ]
        return m
154
155
class MaskedFullAdder(Elaboratable):
    """Masked Full Adder.

    :attribute mask: the carry partition mask
    :attribute in0: the first input
    :attribute in1: the second input
    :attribute in2: the third input
    :attribute sum: the sum output
    :attribute mcarry: the masked carry output

    FullAdders are always used with a "mask" on the output, so the masking
    is done inside this class (keeping the graphviz output "clean") rather
    than in a large for-loop at the point of use.

    This class is deliberately not derived from FullAdder: each input is
    shifted up by one bit *before* being ANDed with the mask, so that an
    AOI cell may be used (which is more gate-efficient). See:
    https://en.wikipedia.org/wiki/AND-OR-Invert
    https://groups.google.com/d/msg/comp.arch/fcq-GLQqvas/vTxmcA0QAgAJ
    """

    def __init__(self, width):
        """Create a ``MaskedFullAdder``.

        :param width: the bit width of the input and output
        """
        self.width = width
        self.mask = Signal(width, reset_less=True)
        self.mcarry = Signal(width, reset_less=True)
        self.in0 = Signal(width, reset_less=True)
        self.in1 = Signal(width, reset_less=True)
        self.in2 = Signal(width, reset_less=True)
        self.sum = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        width = self.width
        # local names are kept: they become the intermediate signal names
        s1 = Signal(width, reset_less=True)
        s2 = Signal(width, reset_less=True)
        s3 = Signal(width, reset_less=True)
        c1 = Signal(width, reset_less=True)
        c2 = Signal(width, reset_less=True)
        c3 = Signal(width, reset_less=True)
        m.d.comb += [
            self.sum.eq(self.in0 ^ self.in1 ^ self.in2),
            # inputs shifted up one bit (a zero is inserted at bit 0)
            s1.eq(Cat(0, self.in0)),
            s2.eq(Cat(0, self.in1)),
            s3.eq(Cat(0, self.in2)),
            # pairwise carry terms, each gated by the partition mask
            c1.eq(s1 & s2 & self.mask),
            c2.eq(s2 & s3 & self.mask),
            c3.eq(s3 & s1 & self.mask),
            self.mcarry.eq(c1 | c2 | c3),
        ]
        return m
209
210
class PartitionedAdder(Elaboratable):
    """Partitioned Adder.

    Performs the final add.  The partition points are included in the
    actual add (in one of the operands only), which causes a carry over
    to the next bit.  Then the final output *removes* the extra bits from
    the result.

    partition: .... P... P... P... P... (32 bits)
    a        : .... .... .... .... .... (32 bits)
    b        : .... .... .... .... .... (32 bits)
    exp-a    : ....P....P....P....P.... (32+4 bits, P=1 if no partition)
    exp-b    : ....0....0....0....0.... (32 bits plus 4 zeros)
    exp-o    : ....xN...xN...xN...xN... (32+4 bits - x to be discarded)
    o        : .... N... N... N... N... (32 bits - x ignored, N is carry-over)

    :attribute width: the bit width of the input and output. Read-only.
    :attribute a: the first input to the adder
    :attribute b: the second input to the adder
    :attribute output: the sum output
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, width, partition_points, partition_step=1):
        """Create a ``PartitionedAdder``.

        :param width: the bit width of the input and output
        :param partition_points: the input partition points
        :param partition_step: a multiplier (typically double) step
                               which in-place "expands" the partition points
        :raises ValueError: if any partition point is >= `width`.
        """
        self.width = width
        self.pmul = partition_step
        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(width):
            raise ValueError("partition_points doesn't fit in width")
        # one bit per data bit, plus one extra bit for every partition
        # point index that lies inside range(width)
        n_breaks = sum(1 for i in range(self.width)
                       if i in self.partition_points)
        self._expanded_width = self.width + n_breaks

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        expanded_a = Signal(self._expanded_width, reset_less=True)
        expanded_b = Signal(self._expanded_width, reset_less=True)
        expanded_o = Signal(self._expanded_width, reset_less=True)

        expanded_index = 0
        # store bits in a list, use Cat later.  graphviz is much cleaner
        al, bl, ol, ea, eb, eo = [], [], [], [], [], []

        # partition points are "breaks" (extra zeros or 1s) in what would
        # otherwise be a massive long add.  when the "break" points are 0,
        # whatever is in it (in the output) is discarded.  however when
        # there is a "1", it causes a roll-over carry to the *next* bit.
        # we still ignore the "break" bit in the [intermediate] output,
        # however by that time we've got the effect that we wanted: the
        # carry has been carried *over* the break point.

        for i in range(self.width):
            # bit i maps onto partition point i // pmul, and only when i is
            # an exact multiple of pmul (integer arithmetic, no float keys)
            point, remainder = divmod(i, self.pmul)
            if remainder == 0 and point in self.partition_points:
                # add extra bit set to 0 + 0 for enabled partition points
                # and 1 + 0 for disabled partition points
                ea.append(expanded_a[expanded_index])
                al.append(~self.partition_points[point])  # extra bit in a
                eb.append(expanded_b[expanded_index])
                bl.append(C(0))  # yes, add a zero
                expanded_index += 1  # skip the extra point.  NOT in the output
            ea.append(expanded_a[expanded_index])
            eb.append(expanded_b[expanded_index])
            eo.append(expanded_o[expanded_index])
            al.append(self.a[i])
            bl.append(self.b[i])
            ol.append(self.output[i])
            expanded_index += 1

        # combine above using Cat
        m.d.comb += Cat(*ea).eq(Cat(*al))
        m.d.comb += Cat(*eb).eq(Cat(*bl))
        m.d.comb += Cat(*ol).eq(Cat(*eo))

        # use only one addition to take advantage of look-ahead carry and
        # special hardware on FPGAs
        m.d.comb += expanded_o.eq(expanded_a + expanded_b)
        return m
304
305
# Number of terms consumed by one full adder; AddReduceSingle groups its
# input terms in runs of this size.
FULL_ADDER_INPUT_COUNT = 3
307
class AddReduceData:
    """Bundle of signals passed between add-reduce levels.

    :attribute part_ops: list of `n_parts` 2-bit ``Signal``s.
    :attribute terms: list of `n_inputs` ``Signal``s, each `output_width`
        bits wide.
    :attribute part_pts: a ``like()`` copy of the given partition points.
    """

    def __init__(self, part_pts, n_inputs, output_width, n_parts):
        self.part_ops = [
            Signal(2, name=f"part_ops_{i}", reset_less=True)
            for i in range(n_parts)
        ]
        self.terms = [
            Signal(output_width, name=f"inputs_{i}", reset_less=True)
            for i in range(n_inputs)
        ]
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, inputs, part_ops):
        """Return the list of assignments that load this bundle.

        The first entry is whatever ``self.part_pts.eq(part_pts)`` returns;
        it is followed by one assignment per term and one per part-op, each
        taken from the same index of `inputs` / `part_ops`.
        """
        assigns = [self.part_pts.eq(part_pts)]
        assigns.extend(term.eq(inputs[idx])
                       for idx, term in enumerate(self.terms))
        assigns.extend(op.eq(part_ops[idx])
                       for idx, op in enumerate(self.part_ops))
        return assigns

    def eq(self, rhs):
        """Return the assignments copying another bundle `rhs` into this one."""
        return self.eq_from(rhs.part_pts, rhs.terms, rhs.part_ops)
327
328
class FinalReduceData:
    """Bundle of signals produced by the final add-reduce stage.

    :attribute part_ops: list of `n_parts` 2-bit ``Signal``s.
    :attribute output: a single ``Signal`` of `output_width` bits.
    :attribute part_pts: a ``like()`` copy of the given partition points.
    """

    def __init__(self, part_pts, output_width, n_parts):
        self.part_ops = [
            Signal(2, name=f"part_ops_{i}", reset_less=True)
            for i in range(n_parts)
        ]
        self.output = Signal(output_width, reset_less=True)
        self.part_pts = part_pts.like()

    def eq_from(self, part_pts, output, part_ops):
        """Return the list of assignments that load this bundle.

        Order: the partition-point assignment(s), the output assignment,
        then one assignment per part-op taken from the same index of
        `part_ops`.
        """
        assigns = [self.part_pts.eq(part_pts), self.output.eq(output)]
        assigns.extend(op.eq(part_ops[idx])
                       for idx, op in enumerate(self.part_ops))
        return assigns

    def eq(self, rhs):
        """Return the assignments copying another bundle `rhs` into this one."""
        return self.eq_from(rhs.part_pts, rhs.output, rhs.part_ops)
345
346
class FinalAdd(Elaboratable):
    """Final stage of add reduce.

    Takes an ``AddReduceData`` holding 0, 1 or 2 terms and produces a
    ``FinalReduceData`` whose ``output`` is, respectively, zero, the single
    term, or the partitioned sum of the two terms.
    """

    def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points,
                 partition_step=1):
        """Create a ``FinalAdd``.

        :param lidx: level index (stored, not otherwise used in this class)
        :param n_inputs: number of input terms; ``elaborate`` handles 0, 1
                         and 2
        :param output_width: bit width of each term and of the output
        :param n_parts: number of part-op signals in the input/output data
        :param partition_points: the input partition points
        :param partition_step: passed through to ``PartitionedAdder``
        :raises ValueError: if a partition point is >= `output_width`.
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.output_width = output_width
        self.n_inputs = n_inputs
        self.n_parts = n_parts
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Create the input data bundle for this stage."""
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        """Create the output data bundle for this stage."""
        return FinalReduceData(self.partition_points,
                               self.output_width, self.n_parts)

    def setup(self, m, i):
        """Register this stage in module `m` and drive its input from `i`."""
        m.submodules.finaladd = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """Return this stage's output bundle (`i` is ignored)."""
        return self.o

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        output_width = self.output_width
        output = Signal(output_width, reset_less=True)
        terms = self.i.terms
        if self.n_inputs == 0:
            # nothing to add: use 0 as the default output value
            result = 0
        elif self.n_inputs == 1:
            # a single input is passed straight through
            result = terms[0]
        else:
            # base case for adding 2 inputs
            assert self.n_inputs == 2
            adder = PartitionedAdder(output_width,
                                     self.i.part_pts, self.partition_step)
            m.submodules.final_adder = adder
            m.d.comb += [adder.a.eq(terms[0]),
                         adder.b.eq(terms[1])]
            result = adder.output
        m.d.comb += output.eq(result)

        # create output
        m.d.comb += self.o.eq_from(self.i.part_pts, output,
                                   self.i.part_ops)

        return m
407
408
class AddReduceSingle(PipeModBase):
    """One level of add-reduction.

    Groups the input terms in threes, feeds each group to a
    ``MaskedFullAdder`` and emits its sum and masked carry; any one or two
    left-over terms are passed through unchanged.

    :attribute n_inputs: number of input terms.
    :attribute n_terms: number of output terms.
    :attribute groups: start indices of the input groups given a full adder.
    :attribute output_width: bit width of every term (``pspec.width * 2``).
    :attribute partition_points: the input partition points. Modification not
        supported, except for by ``Signal.eq``.
    """

    def __init__(self, pspec, lidx, n_inputs, partition_points,
                 partition_step=1):
        """Create an ``AddReduceSingle``.

        :param pspec: pipeline specification; ``width`` and ``n_parts`` are
                      read from it.
        :param lidx: level index, used in the module name.
        :param n_inputs: number of input terms to this level.
        :param partition_points: the input partition points.
        :param partition_step: multiplier applied to the partition points
                               when the carry mask is built.
        :raises ValueError: if a partition point is >= the output width.
        """
        self.lidx = lidx
        self.partition_step = partition_step
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        self.partition_points = PartitionPoints(partition_points)
        if not self.partition_points.fits_in_width(self.output_width):
            raise ValueError("partition_points doesn't fit in output_width")

        self.groups = AddReduceSingle.full_adder_groups(n_inputs)
        self.n_terms = AddReduceSingle.calc_n_inputs(n_inputs, self.groups)

        # last: ispec()/ospec() below read the attributes set above
        super().__init__(pspec, "addreduce_%d" % lidx)

    def ispec(self):
        """Create the input data bundle (`n_inputs` terms)."""
        return AddReduceData(self.partition_points, self.n_inputs,
                             self.output_width, self.n_parts)

    def ospec(self):
        """Create the output data bundle (`n_terms` terms)."""
        return AddReduceData(self.partition_points, self.n_terms,
                             self.output_width, self.n_parts)

    @staticmethod
    def calc_n_inputs(n_inputs, groups):
        """Number of terms left after one level: two per group, plus the
        one or two inputs that did not fill a group."""
        leftover = n_inputs % FULL_ADDER_INPUT_COUNT
        assert leftover in (0, 1, 2)
        return 2 * len(groups) + leftover

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level.

        All ``register_levels`` must be less than or equal to the maximum
        level.
        """
        level = 0
        groups = AddReduceSingle.full_adder_groups(input_count)
        while len(groups) != 0:
            # each group of 3 becomes 2 terms; the remainder passes through
            input_count = (input_count % FULL_ADDER_INPUT_COUNT
                           + 2 * len(groups))
            level += 1
            groups = AddReduceSingle.full_adder_groups(input_count)
        return level

    @staticmethod
    def full_adder_groups(input_count):
        """Get ``inputs`` indices for which a full adder should be built."""
        stop = input_count - FULL_ADDER_INPUT_COUNT + 1
        return range(0, stop, FULL_ADDER_INPUT_COUNT)

    def create_next_terms(self):
        """Create next intermediate terms, for linking up in elaborate.

        Returns ``(terms, adders)`` where `adders` is a list of
        ``(input_index, MaskedFullAdder)`` pairs.
        """
        terms = []
        adders = []

        # create full adders for this recursive level.
        # this shrinks N terms to 2 * (N // 3) plus the remainder
        for start in self.groups:
            full_adder = MaskedFullAdder(self.output_width)
            adders.append((start, full_adder))
            # add both the sum and the masked-carry to the next level.
            # 3 inputs have now been reduced to 2...
            terms += [full_adder.sum, full_adder.mcarry]

        # handle the remaining inputs.
        leftover = self.n_inputs % FULL_ADDER_INPUT_COUNT
        if leftover == 1:
            terms.append(self.i.terms[-1])
        elif leftover == 2:
            # Just pass the terms to the next layer, since we wouldn't gain
            # anything by using a half adder since there would still be 2
            # terms and just passing the terms to the next layer saves gates.
            terms += [self.i.terms[-2], self.i.terms[-1]]
        else:
            assert leftover == 0

        return terms, adders

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        terms, adders = self.create_next_terms()

        # copy the intermediate terms to the output
        for idx, value in enumerate(terms):
            m.d.comb += self.o.terms[idx].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(self.i.part_pts)
        m.d.comb += [self.o.part_ops[idx].eq(op)
                     for idx, op in enumerate(self.i.part_ops)]

        # set up the partition mask (for the adders)
        part_mask = Signal(self.output_width, reset_less=True)

        # get partition points as a mask
        mask = self.i.part_pts.as_mask(self.output_width,
                                       mul=self.partition_step)
        m.d.comb += part_mask.eq(mask)

        # add and link the intermediate term modules
        in_terms = self.i.terms
        for idx, (iidx, adder_i) in enumerate(adders):
            setattr(m.submodules, f"adder_{idx}", adder_i)

            m.d.comb += [adder_i.in0.eq(in_terms[iidx]),
                         adder_i.in1.eq(in_terms[iidx + 1]),
                         adder_i.in2.eq(in_terms[iidx + 2]),
                         adder_i.mask.eq(part_mask)]

        return m
547
548
class AddReduceInternal:
    """Build the chain of add-reduce levels for a set of terms.

    :attribute i: the input ``AddReduceData``-like bundle.
    :attribute inputs: the input terms (``i.terms``).
    :attribute part_ops: the input part-ops (``i.part_ops``).
    :attribute partition_points: the input partition points
        (``i.part_pts``).
    :attribute output_width: ``pspec.width * 2``.
    :attribute levels: list of ``AddReduceSingle`` levels followed by one
        ``FinalAdd``.
    """

    def __init__(self, i, pspec, partition_step=1):
        """Create an ``AddReduceInternal``.

        :param i: input bundle providing ``terms``, ``part_ops`` and
                  ``part_pts``.
        :param pspec: pipeline specification; ``width`` is read here and
                      the object is handed to every ``AddReduceSingle``.
        :param partition_step: passed through to every level.
        """
        self.i = i
        self.pspec = pspec
        self.inputs = i.terms
        self.part_ops = i.part_ops
        self.output_width = pspec.width * 2
        self.partition_points = i.part_pts
        self.partition_step = partition_step

        self.create_levels()

    def create_levels(self):
        """creates reduction levels"""

        mods = []
        partition_points = self.partition_points
        n_parts = len(self.part_ops)
        ilen = len(self.inputs)

        # keep adding levels while at least one full group of terms remains
        while len(AddReduceSingle.full_adder_groups(ilen)) != 0:
            level = AddReduceSingle(self.pspec, len(mods), ilen,
                                    partition_points,
                                    self.partition_step)
            mods.append(level)
            partition_points = level.i.part_pts
            ilen = len(level.o.terms)

        # the remaining (fewer than three) terms go to the final adder
        final = FinalAdd(len(mods), ilen, self.output_width, n_parts,
                         partition_points, self.partition_step)
        mods.append(final)

        self.levels = mods
607
608
class AddReduce(AddReduceInternal, Elaboratable):
    """Recursively Add list of numbers together.

    :attribute i: input data bundle, loaded from the constructor arguments.
    :attribute o: output data bundle; ``o.output`` is the sum.
    :attribute register_levels: List of nesting levels that should have
        pipeline registers.
    :attribute levels: the chain of reduction stages (see
        ``AddReduceInternal``).
    """

    def __init__(self, inputs, output_width, register_levels, part_pts,
                 part_ops, partition_step=1):
        """Create an ``AddReduce``.

        :param inputs: input ``Signal``s to be summed.
        :param output_width: bit-width of ``output``.  Must be even, because
            the levels derive their width as ``pspec.width * 2``.
        :param register_levels: List of nesting levels that should have
            pipeline registers.
        :param part_pts: the input partition points.
        :param part_ops: the per-partition operation signals.
        :param partition_step: passed through to every reduction level.
        :raises ValueError: if `output_width` is odd.
        """
        if output_width % 2 != 0:
            raise ValueError("output_width must be even")
        self._inputs = inputs
        self._part_pts = part_pts
        self._part_ops = part_ops
        n_parts = len(part_ops)
        self.i = AddReduceData(part_pts, len(inputs),
                               output_width, n_parts)
        # AddReduceInternal expects a pipeline spec (it reads ``width`` and
        # the levels read ``n_parts``), not a bare width.
        # NOTE(review): assumes PipelineSpec(width, id_width) -- confirm
        # against ieee754.pipeline.
        pspec = PipelineSpec(output_width // 2, 0)
        pspec.n_parts = n_parts
        AddReduceInternal.__init__(self, self.i, pspec, partition_step)
        self.o = FinalReduceData(part_pts, output_width, n_parts)
        self.register_levels = register_levels

    @staticmethod
    def get_max_level(input_count):
        """Get the maximum level (see ``AddReduceSingle.get_max_level``)."""
        return AddReduceSingle.get_max_level(input_count)

    @staticmethod
    def next_register_levels(register_levels):
        """``Iterable`` of ``register_levels`` for next recursive level."""
        for level in register_levels:
            if level > 0:
                yield level - 1

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        m.d.comb += self.i.eq_from(self._part_pts, self._inputs,
                                   self._part_ops)

        for i, next_level in enumerate(self.levels):
            setattr(m.submodules, "next_level%d" % i, next_level)

        # chain the levels: each level's input is the previous one's output,
        # registered (sync) when its index is listed in register_levels
        i = self.i
        for idx, mcur in enumerate(self.levels):
            if idx in self.register_levels:
                m.d.sync += mcur.i.eq(i)
            else:
                m.d.comb += mcur.i.eq(i)
            i = mcur.o  # for next loop

        # output comes from last module
        m.d.comb += self.o.eq(i)

        return m
674
675
# Operation codes carried in the 2-bit ``part_ops`` signals.
# IntermediateOut selects the lower half of a product for OP_MUL_LOW and
# the upper half for any other code.
OP_MUL_LOW = 0
OP_MUL_SIGNED_HIGH = 1
OP_MUL_SIGNED_UNSIGNED_HIGH = 2  # a is signed, b is unsigned
OP_MUL_UNSIGNED_HIGH = 3
680
681
def get_term(value, shift=0, enabled=None):
    """Optionally gate `value` and shift it left by `shift` bits.

    :param value: the term to process.
    :param shift: number of zero bits to place below `value`; must be >= 0
                  (checked with an ``assert``).
    :param enabled: if not ``None``, `value` is replaced by 0 whenever
                    `enabled` is false.
    """
    term = value if enabled is None else Mux(enabled, value, 0)
    if shift <= 0:
        assert shift == 0
        return term
    # prepend `shift` zero bits at the low end
    return Cat(Repl(C(0, 1), shift), term)
690
691
class ProductTerm(Elaboratable):
    """ this class creates a single product term (a[..]*b[..]).
        it has a design flaw in that is the *output* that is selected,
        where the multiplication(s) are combinatorially generated
        all the time.
    """

    def __init__(self, width, twidth, pbwid, a_index, b_index):
        """Create a ``ProductTerm``.

        :param width: bit width of each slice taken from ``a`` and ``b``
        :param twidth: bit width of the ``term`` output; ``a`` and ``b``
                       are each ``twidth // 2`` bits wide
        :param pbwid: bit width of the ``pb_en`` input
        :param a_index: index of the slice taken from ``a``
        :param b_index: index of the slice taken from ``b``
        """
        self.a_index = a_index
        self.b_index = b_index
        self.pwidth = width
        self.twidth = twidth
        self.width = width*2
        self.shift = 8 * (self.a_index + self.b_index)

        self.ti = Signal(self.width, reset_less=True)
        self.term = Signal(twidth, reset_less=True)
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)

        # pb_en bits lying between the two slice indices
        lo = min(self.a_index, self.b_index)
        hi = max(self.a_index, self.b_index)
        self.tl = [self.pb_en[i] for i in range(lo, hi)]
        name = "te_%d_%d" % (self.a_index, self.b_index)
        # an enable signal only exists when there are bits to test
        if len(self.tl) > 0:
            self.enabled = Signal(name=name, reset_less=True)
        else:
            self.enabled = None
        self.term.name = "term_%d_%d" % (a_index, b_index)  # rename

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        if self.enabled is not None:
            # enabled only when none of the collected pb_en bits is set
            m.d.comb += self.enabled.eq(~(Cat(*self.tl).bool()))

        bsa = Signal(self.width, reset_less=True)
        bsb = Signal(self.width, reset_less=True)
        pwidth = self.pwidth
        a_slice = self.a.bit_select(self.a_index * pwidth, pwidth)
        b_slice = self.b.bit_select(self.b_index * pwidth, pwidth)
        m.d.comb += [
            bsa.eq(a_slice),
            bsb.eq(b_slice),
            self.ti.eq(bsa * bsb),
            self.term.eq(get_term(self.ti, self.shift, self.enabled)),
        ]
        """
        #TODO: sort out width issues, get inputs a/b switched on/off.
        #data going into Muxes is 1/2 the required width

        pwidth = self.pwidth
        width = self.width
        bsa = Signal(self.twidth//2, reset_less=True)
        bsb = Signal(self.twidth//2, reset_less=True)
        asel = Signal(width, reset_less=True)
        bsel = Signal(width, reset_less=True)
        a_index, b_index = self.a_index, self.b_index
        m.d.comb += asel.eq(self.a.bit_select(a_index * pwidth, pwidth))
        m.d.comb += bsel.eq(self.b.bit_select(b_index * pwidth, pwidth))
        m.d.comb += bsa.eq(get_term(asel, self.shift, self.enabled))
        m.d.comb += bsb.eq(get_term(bsel, self.shift, self.enabled))
        m.d.comb += self.ti.eq(bsa * bsb)
        m.d.comb += self.term.eq(self.ti)
        """

        return m
761
762
class ProductTerms(Elaboratable):
    """ creates a bank of product terms.  also performs the actual bit-selection
        this class is to be wrapped with a for-loop on the "a" operand.
        it creates a second-level for-loop on the "b" operand.
    """
    def __init__(self, width, twidth, pbwid, a_index, blen):
        """Create a bank of `blen` product terms for one ``a`` slice.

        :param width: bit width of each operand slice
        :param twidth: bit width of each output term
        :param pbwid: bit width of the ``pb_en`` input
        :param a_index: index of the ``a`` slice shared by every term
        :param blen: number of ``b`` slices, i.e. number of terms
        """
        self.a_index = a_index
        self.blen = blen
        self.pwidth = width
        self.twidth = twidth
        self.pbwid = pbwid
        self.a = Signal(twidth//2, reset_less=True)
        self.b = Signal(twidth//2, reset_less=True)
        self.pb_en = Signal(pbwid, reset_less=True)
        self.terms = [
            Signal(twidth, name="term%d"%i, reset_less=True)
            for i in range(blen)
        ]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        for b_index, out_term in enumerate(self.terms):
            t = ProductTerm(self.pwidth, self.twidth, self.pbwid,
                            self.a_index, b_index)
            setattr(m.submodules, "term_%d" % b_index, t)

            # every term sees the full operands; it picks its own slices
            m.d.comb += [
                t.a.eq(self.a),
                t.b.eq(self.b),
                t.pb_en.eq(self.pb_en),
                out_term.eq(t.term),
            ]

        return m
796
797
class LSBNegTerm(Elaboratable):
    """Produce the two correction terms used for a signed operand.

    When ``part``, ``msb`` and ``signed`` are all set, ``nt`` carries the
    bit-inverted ``op`` in its upper half and ``nl`` carries a single 1 at
    bit ``bit_width``; otherwise both are zero.
    """

    def __init__(self, bit_width):
        """Create an ``LSBNegTerm`` for operands of `bit_width` bits."""
        self.bit_width = bit_width
        self.part = Signal(reset_less=True)
        self.signed = Signal(reset_less=True)
        self.op = Signal(bit_width, reset_less=True)
        self.msb = Signal(reset_less=True)
        self.nt = Signal(bit_width*2, reset_less=True)
        self.nl = Signal(bit_width*2, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()
        bit_wid = self.bit_width
        ext = Repl(0, bit_wid)  # extend output to HI part

        # determine sign of each incoming number *in this partition*
        enabled = Signal(reset_less=True)
        m.d.comb += enabled.eq(self.part & self.msb & self.signed)

        # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
        # negation operation is split into a bitwise not and a +1.
        # likewise for 16, 32, and 64-bit values.

        # width-extended 1s complement if a is signed, otherwise zero
        ones_complement = Cat(ext, ~self.op)
        # add 1 if signed, otherwise add zero
        plus_one = Cat(ext, enabled, Repl(0, bit_wid-1))
        m.d.comb += [
            self.nt.eq(Mux(enabled, ones_complement, 0)),
            self.nl.eq(plus_one),
        ]

        return m
830
831
class Parts(Elaboratable):
    """Derive one "partition active" flag per part from the partition points.

    :attribute part_pts: input partition points (a ``like()`` copy).
    :attribute parts: output list of `n_parts` single-bit ``Signal``s.
    """

    def __init__(self, pbwid, part_pts, n_parts):
        """Create a ``Parts``.

        :param pbwid: number of partition bytes collected in ``elaborate``
        :param part_pts: partition points to copy the shape/names from
        :param n_parts: number of output flags
        """
        self.pbwid = pbwid
        # inputs
        self.part_pts = PartitionPoints.like(part_pts)
        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        # collect part-bytes (double factor because the input is extended)
        pbs = Signal(self.pbwid, reset_less=True)
        tl = []
        for i in range(self.pbwid):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(self.part_pts.part_byte(i))
            tl.append(pb)
        m.d.comb += pbs.eq(Cat(*tl))

        # negated-temporary copy of partition bits
        npbs = Signal.like(pbs, reset_less=True)
        m.d.comb += npbs.eq(~pbs)

        byte_count = 8 // len(self.parts)
        for i, part in enumerate(self.parts):
            first = i * byte_count
            last = first + byte_count - 1
            # negated bit just below this part (index -1 wraps to the top
            # bit for the first part), the plain bits inside it, and the
            # negated bit at its upper edge
            pbl = [npbs[first - 1]]
            pbl.extend(pbs[j] for j in range(first, last))
            pbl.append(npbs[last])
            value = Signal(len(pbl), name="value_%d" % i, reset_less=True)
            m.d.comb += value.eq(Cat(*pbl))
            # the part flag is set only when every collected bit is zero
            m.d.comb += part.eq(~(value).bool())

        return m
870
871
class Part(Elaboratable):
    """ a key class which, depending on the partitioning, will determine
        what action to take when parts of the output are signed or unsigned.

        this requires 2 pieces of data *per operand, per partition*:
        whether the MSB is HI/LO (per partition!), and whether a signed
        or unsigned operation has been *requested*.

        once that is determined, signed is basically carried out
        by splitting 2's complement into 1's complement plus one.
        1's complement is just a bit-inversion.

        the extra terms - as separate terms - are then thrown at the
        AddReduce alongside the multiplication part-results.
    """
    def __init__(self, part_pts, width, n_parts, pbwid):
        """Create a ``Part``.

        :param part_pts: the partition points
        :param width: bit width of each of the four output terms
        :param n_parts: number of partitions
        :param pbwid: bit width of ``pbs``, also handed to ``Parts``
        """
        self.pbwid = pbwid
        self.part_pts = part_pts

        # inputs
        self.a = Signal(64, reset_less=True)
        self.b = Signal(64, reset_less=True)
        self.a_signed = [Signal(name=f"a_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.b_signed = [Signal(name=f"_b_signed_{i}", reset_less=True)
                         for i in range(8)]
        self.pbs = Signal(pbwid, reset_less=True)

        # outputs
        self.parts = [Signal(name=f"part_{i}", reset_less=True)
                      for i in range(n_parts)]

        self.not_a_term = Signal(width, reset_less=True)
        self.neg_lsb_a_term = Signal(width, reset_less=True)
        self.not_b_term = Signal(width, reset_less=True)
        self.neg_lsb_b_term = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Elaborate this module."""
        m = Module()

        n_parts = len(self.parts)
        m.submodules.p = p = Parts(self.pbwid, self.part_pts, n_parts)
        m.d.comb += p.part_pts.eq(self.part_pts)
        parts = p.parts

        byte_width = 8 // len(parts)  # byte width
        bit_wid = 8 * byte_width      # bit width
        nat, nbt, nla, nlb = [], [], [], []

        # for each operand: (submodule name format, operand to invert,
        # sign flags and value of the *other* operand, result lists).
        # the sign/msb inputs deliberately come from the other operand.
        operands = (
            ("lnt_%d_a_%d", self.a, self.b_signed, self.b, nat, nla),
            ("lnt_%d_b_%d", self.b, self.a_signed, self.a, nbt, nlb),
        )
        for i in range(len(parts)):
            msb_index = (i + 1) * bit_wid - 1
            for fmt, op, other_signed, other, nt_terms, nl_terms in operands:
                # work out bit-inverted and +1 term for this operand
                lnt = LSBNegTerm(bit_wid)
                setattr(m.submodules, fmt % (bit_wid, i), lnt)
                m.d.comb += [
                    lnt.part.eq(parts[i]),
                    lnt.op.eq(op.bit_select(bit_wid * i, bit_wid)),
                    lnt.signed.eq(other_signed[i * byte_width]),
                    lnt.msb.eq(other[msb_index]),
                ]
                nt_terms.append(lnt.nt)
                nl_terms.append(lnt.nl)

        # concatenate together and return all 4 results.
        m.d.comb += [self.not_a_term.eq(Cat(*nat)),
                     self.not_b_term.eq(Cat(*nbt)),
                     self.neg_lsb_a_term.eq(Cat(*nla)),
                     self.neg_lsb_b_term.eq(Cat(*nlb)),
                     ]

        return m
957
958
class IntermediateOut(Elaboratable):
    """Select the HI or LO half of each lane's product.

    ``intermed`` holds ``n_parts`` lanes, each ``2 * width`` bits wide.
    For every lane either the lower or the upper ``width`` bits are
    picked, and the picks are concatenated (lane 0 first) into
    ``output``, reconstructing the result in its SIMD (partition) lanes.
    """

    def __init__(self, width, out_wid, n_parts):
        """
        :param width: width in bits of one selected lane result.
        :param out_wid: width of ``intermed``; ``output`` is half of this.
        :param n_parts: number of lanes to select.
        """
        self.width = width
        self.n_parts = n_parts
        # one 2-bit operation code per byte
        self.part_ops = [Signal(2, name="dpop%d" % i, reset_less=True)
                         for i in range(8)]
        self.intermed = Signal(out_wid, reset_less=True)
        self.output = Signal(out_wid//2, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        w = self.width
        sel = w // 8  # bytes per lane: stride into the per-byte op list
        lanes = []
        for i in range(self.n_parts):
            base = i * w*2  # first bit of this lane's double-width product
            want_low = self.part_ops[sel * i] == OP_MUL_LOW
            low_half = self.intermed.bit_select(base, w)
            high_half = self.intermed.bit_select(base + w, w)

            op = Signal(w, reset_less=True, name="op%d_%d" % (w, i))
            m.d.comb += op.eq(Mux(want_low, low_half, high_half))
            lanes.append(op)
        m.d.comb += self.output.eq(Cat(*lanes))

        return m
987
988
class FinalOut(Elaboratable):
    """Select the final output based on the partitioning.

    Every output byte is chosen independently, so some partitions may
    have requested an 8-bit computation whilst others requested 16 or
    32 bit.
    """

    def __init__(self, output_width, n_parts, part_pts):
        """
        :param output_width: width handed to ``IntermediateData``; the
            selected output is half this wide.
        :param n_parts: number of partitions, handed to ``IntermediateData``.
        :param part_pts: the partition points.
        """
        self.part_pts = part_pts
        self.output_width = output_width
        self.n_parts = n_parts
        self.out_wid = output_width//2

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Return a fresh input record for this stage."""
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        """Return a fresh output record for this stage."""
        return OutputData()

    def setup(self, m, i):
        """Register this stage in ``m`` and connect ``i`` to its input."""
        m.submodules.finalout = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """Return the output record; ``i`` itself is not consulted."""
        return self.o

    def elaborate(self, platform):
        m = Module()

        part_pts = self.part_pts
        m.submodules.p_8 = p_8 = Parts(8, part_pts, 8)
        m.submodules.p_16 = p_16 = Parts(8, part_pts, 4)
        m.submodules.p_32 = p_32 = Parts(8, part_pts, 2)
        m.submodules.p_64 = p_64 = Parts(8, part_pts, 1)

        out_part_pts = self.i.part_pts

        # temporaries: per-lane enable flags for each lane size
        d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)]
        d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)]
        d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)]

        # temporaries: the four candidate results
        i8 = Signal(self.out_wid, reset_less=True)
        i16 = Signal(self.out_wid, reset_less=True)
        i32 = Signal(self.out_wid, reset_less=True)
        i64 = Signal(self.out_wid, reset_less=True)

        # every Parts decoder sees the same incoming partition points
        for decoder in (p_8, p_16, p_32, p_64):
            m.d.comb += decoder.part_pts.eq(out_part_pts)

        # latch the decoded lane enables (the 64-bit case is the default
        # in the selection below, so it needs no flag)
        for flags, decoder in ((d8, p_8), (d16, p_16), (d32, p_32)):
            for idx, part in enumerate(decoder.parts):
                m.d.comb += flags[idx].eq(part)

        for idx, candidate in enumerate((i8, i16, i32, i64)):
            m.d.comb += candidate.eq(self.i.outputs[idx])

        ol = []
        for i in range(8):
            # select one of the outputs: d8 selects i8, d16 selects i16
            # d32 selects i32, and the default is i64.
            # d8 and d16 are ORed together in the first Mux
            # then the 2nd selects either i8 or i16.
            # if neither d8 nor d16 are set, d32 selects either i32 or i64.
            lsb = i * 8
            narrow = Mux(d8[i], i8.bit_select(lsb, 8),
                         i16.bit_select(lsb, 8))
            wide = Mux(d32[i // 4], i32.bit_select(lsb, 8),
                       i64.bit_select(lsb, 8))
            op = Signal(8, reset_less=True, name="op_%d" % i)
            m.d.comb += op.eq(Mux(d8[i] | d16[i // 2], narrow, wide))
            ol.append(op)

        # create outputs
        m.d.comb += self.o.output.eq(Cat(*ol))
        m.d.comb += self.o.intermediate_output.eq(self.i.intermediate_output)

        return m
1076
1077
class OrMod(Elaboratable):
    """OR four values together in a two-level tree.

    ``orin[0] | orin[1]`` and ``orin[2] | orin[3]`` are formed first and
    then combined into ``orout``.
    """

    def __init__(self, wid):
        """
        :param wid: bit width of each input and of the output.
        """
        self.wid = wid
        self.orin = [Signal(wid, name="orin%d" % i, reset_less=True)
                     for i in range(4)]
        self.orout = Signal(wid, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        in0, in1, in2, in3 = self.orin
        # first tree level
        or1 = Signal(self.wid, reset_less=True)
        or2 = Signal(self.wid, reset_less=True)
        m.d.comb += [
            or1.eq(in0 | in1),
            or2.eq(in2 | in3),
            # second tree level
            self.orout.eq(or1 | or2),
        ]

        return m
1096
1097
class Signs(Elaboratable):
    """Work out whether ``a`` and ``b`` are to be treated as signed.

    The decision is made purely from the 2-bit operation code
    (``OP_MUL_*``) presented on ``part_ops``.
    """

    def __init__(self):
        self.part_ops = Signal(2, reset_less=True)
        self.a_signed = Signal(reset_less=True)
        self.b_signed = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        op = self.part_ops
        # a is signed for every operation except the unsigned-high one
        a_is_signed = op != OP_MUL_UNSIGNED_HIGH
        # b is signed only for the low and the signed-high operations
        b_is_signed = (op == OP_MUL_LOW) | (op == OP_MUL_SIGNED_HIGH)

        m.d.comb += self.a_signed.eq(a_is_signed)
        m.d.comb += self.b_signed.eq(b_is_signed)

        return m
1119
1120
class IntermediateData:
    """Record passed from ``Intermediates`` to ``FinalOut``.

    Carries the per-partition operation codes, a copy of the partition
    points, the four candidate outputs and the raw intermediate output.
    """

    def __init__(self, part_pts, output_width, n_parts):
        """
        :param part_pts: partition points to mirror (via ``like()``).
        :param output_width: width of each of the four outputs and of
            ``intermediate_output``.
        :param n_parts: number of ``part_ops`` entries.
        """
        self.part_ops = [Signal(2, name=f"part_ops_{i}", reset_less=True)
                         for i in range(n_parts)]
        self.part_pts = part_pts.like()
        self.outputs = [Signal(output_width, name="io%d" % i, reset_less=True)
                        for i in range(4)]
        # intermediates (needed for unit tests)
        self.intermediate_output = Signal(output_width)

    def eq_from(self, part_pts, outputs, intermediate_output,
                part_ops):
        """Return the list of assignments loading this record.

        The order is: partition points, intermediate output, the four
        outputs, then one assignment per entry of ``self.part_ops``.
        """
        stmts = [self.part_pts.eq(part_pts),
                 self.intermediate_output.eq(intermediate_output)]
        stmts += [self.outputs[idx].eq(outputs[idx]) for idx in range(4)]
        stmts += [dest.eq(part_ops[idx])
                  for idx, dest in enumerate(self.part_ops)]
        return stmts

    def eq(self, rhs):
        """Return the assignments copying every field of ``rhs``."""
        return self.eq_from(rhs.part_pts, rhs.outputs,
                            rhs.intermediate_output, rhs.part_ops)
1144
1145
class InputData:
    """Input record of the multiplier.

    Holds the two 64-bit operands, one partition-point ``Signal`` at
    every multiple of 8 in ``8 <= i < 64``, and eight 2-bit operation
    codes.
    """

    def __init__(self):
        self.a = Signal(64)
        self.b = Signal(64)
        self.part_pts = PartitionPoints()
        for point in range(8, 64, 8):
            self.part_pts[point] = Signal(name=f"part_pts_{point}")
        self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)]

    def eq_from(self, part_pts, a, b, part_ops):
        """Return the list of assignments loading this record.

        The order is: partition points, ``a``, ``b``, then one
        assignment per entry of ``self.part_ops``.
        """
        stmts = [self.part_pts.eq(part_pts), self.a.eq(a), self.b.eq(b)]
        stmts += [dest.eq(part_ops[idx])
                  for idx, dest in enumerate(self.part_ops)]
        return stmts

    def eq(self, rhs):
        """Return the assignments copying every field of ``rhs``."""
        return self.eq_from(rhs.part_pts, rhs.a, rhs.b, rhs.part_ops)
1164
1165
class OutputData:
    """Output record of the multiplier: a 64-bit result plus the 128-bit
    intermediate value it was selected from.
    """

    def __init__(self):
        self.intermediate_output = Signal(128)  # needed for unit tests
        self.output = Signal(64)

    def eq(self, rhs):
        """Return the assignments copying both fields of ``rhs``."""
        copy_intermediate = self.intermediate_output.eq(rhs.intermediate_output)
        copy_output = self.output.eq(rhs.output)
        return [copy_intermediate, copy_output]
1175
1176
class AllTerms(PipeModBase):
    """Set of terms to be added together.

    Produces the product terms from eight ``ProductTerms`` submodules
    followed by four OR-combined negation terms, and forwards the
    partition points and operation codes alongside them.
    """

    def __init__(self, pspec, n_inputs):
        """Create an ``AllTerms``.

        :param pspec: pipeline specification; ``n_parts`` and ``width``
            are read from it.
        :param n_inputs: number of terms, handed to ``AddReduceData``.
        """
        # these are assigned before the base-class constructor runs;
        # ospec() reads them -- presumably the base constructor calls it.
        self.n_inputs = n_inputs
        self.n_parts = pspec.n_parts
        self.output_width = pspec.width * 2
        super().__init__(pspec, "allterms")

    def ispec(self):
        """Return a fresh input record."""
        return InputData()

    def ospec(self):
        """Return a fresh output record sized for all the terms."""
        return AddReduceData(self.i.part_pts, self.n_inputs,
                             self.output_width, self.n_parts)

    def elaborate(self, platform):
        m = Module()

        eps = self.i.part_pts

        # collect part-bytes
        pbs = Signal(8, reset_less=True)
        part_bytes = []
        for i in range(8):
            pb = Signal(name="pb%d" % i, reset_less=True)
            m.d.comb += pb.eq(eps.part_byte(i))
            part_bytes.append(pb)
        m.d.comb += pbs.eq(Cat(*part_bytes))

        # one sign decoder per byte's operation code
        signs = []
        for i in range(8):
            s = Signs()
            signs.append(s)
            setattr(m.submodules, "signs%d" % i, s)
            m.d.comb += s.part_ops.eq(self.i.part_ops[i])

        # negation terms for each of the four lane sizes
        m.submodules.part_8 = part_8 = Part(eps, 128, 8, 8)
        m.submodules.part_16 = part_16 = Part(eps, 128, 4, 8)
        m.submodules.part_32 = part_32 = Part(eps, 128, 2, 8)
        m.submodules.part_64 = part_64 = Part(eps, 128, 1, 8)
        nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
        for mod in [part_8, part_16, part_32, part_64]:
            m.d.comb += mod.a.eq(self.i.a)
            m.d.comb += mod.b.eq(self.i.b)
            for idx, sign in enumerate(signs):
                m.d.comb += mod.a_signed[idx].eq(sign.a_signed)
                m.d.comb += mod.b_signed[idx].eq(sign.b_signed)
            m.d.comb += mod.pbs.eq(pbs)
            nat_l.append(mod.not_a_term)
            nbt_l.append(mod.not_b_term)
            nla_l.append(mod.neg_lsb_a_term)
            nlb_l.append(mod.neg_lsb_b_term)

        terms = []

        for a_index in range(8):
            t = ProductTerms(8, 128, 8, a_index, 8)
            setattr(m.submodules, "terms_%d" % a_index, t)

            m.d.comb += t.a.eq(self.i.a)
            m.d.comb += t.b.eq(self.i.b)
            m.d.comb += t.pb_en.eq(pbs)

            terms.extend(t.terms)

        # it's fine to bitwise-or data together since they are never enabled
        # at the same time
        m.submodules.nat_or = nat_or = OrMod(128)
        m.submodules.nbt_or = nbt_or = OrMod(128)
        m.submodules.nla_or = nla_or = OrMod(128)
        m.submodules.nlb_or = nlb_or = OrMod(128)
        for sources, or_mod in [(nat_l, nat_or),
                                (nbt_l, nbt_or),
                                (nla_l, nla_or),
                                (nlb_l, nlb_or)]:
            for idx, source in enumerate(sources):
                m.d.comb += or_mod.orin[idx].eq(source)
            terms.append(or_mod.orout)

        # copy the intermediate terms to the output
        for idx, value in enumerate(terms):
            m.d.comb += self.o.terms[idx].eq(value)

        # copy reg part points and part ops to output
        m.d.comb += self.o.part_pts.eq(eps)
        m.d.comb += [self.o.part_ops[idx].eq(op)
                     for idx, op in enumerate(self.i.part_ops)]

        return m
1272
1273
class Intermediates(Elaboratable):
    """Intermediate output modules.

    Feeds the reduced sum into one ``IntermediateOut`` per lane size
    (64, 32, 16 and 8 bit) and collects their results, together with the
    operation codes, partition points and raw sum, in the output record.
    """

    def __init__(self, output_width, n_parts, part_pts):
        """
        :param output_width: width handed to the input/output records.
        :param n_parts: number of partitions, handed to the records.
        :param part_pts: the partition points.
        """
        self.part_pts = part_pts
        self.output_width = output_width
        self.n_parts = n_parts

        self.i = self.ispec()
        self.o = self.ospec()

    def ispec(self):
        """Return a fresh input record for this stage."""
        return FinalReduceData(self.part_pts, self.output_width, self.n_parts)

    def ospec(self):
        """Return a fresh output record for this stage."""
        return IntermediateData(self.part_pts, self.output_width, self.n_parts)

    def setup(self, m, i):
        """Register this stage in ``m`` and connect ``i`` to its input."""
        m.submodules.intermediates = self
        m.d.comb += self.i.eq(i)

    def process(self, i):
        """Return the output record; ``i`` itself is not consulted."""
        return self.o

    def elaborate(self, platform):
        m = Module()

        out_part_ops = self.i.part_ops
        out_part_pts = self.i.part_pts

        # (submodule name, lane width, lane count, index into o.outputs),
        # widest first: creates _output_64, _output_32, _output_16, _output_8
        lane_table = (("io64", 64, 1, 3),
                      ("io32", 32, 2, 2),
                      ("io16", 16, 4, 1),
                      ("io8", 8, 8, 0))
        for sub_name, lane_wid, lane_count, out_idx in lane_table:
            io = IntermediateOut(lane_wid, 128, lane_count)
            setattr(m.submodules, sub_name, io)
            m.d.comb += io.intermed.eq(self.i.output)
            for i in range(8):
                m.d.comb += io.part_ops[i].eq(out_part_ops[i])
            m.d.comb += self.o.outputs[out_idx].eq(io.output)

        # pass the operation codes, partition points and raw sum through
        for i in range(8):
            m.d.comb += self.o.part_ops[i].eq(out_part_ops[i])
        m.d.comb += self.o.part_pts.eq(out_part_pts)
        m.d.comb += self.o.intermediate_output.eq(self.i.output)

        return m
1339
1340
class Mul8_16_32_64(Elaboratable):
    """Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.

    Supports partitioning into any combination of 8, 16, 32, and 64-bit
    partitions on naturally-aligned boundaries. Supports the operation being
    set for each partition independently.

    :attribute part_pts: the input partition points. Has a partition point at
        multiples of 8 in 0 < i < 64. Each partition point's associated
        ``Value`` is a ``Signal``. Modification not supported, except by
        ``Signal.eq``.
    :attribute part_ops: the operation for each byte. The operation for a
        particular partition is selected by assigning the selected operation
        code to each byte in the partition. The allowed operation codes are:

        :attribute OP_MUL_LOW: the LSB half of the product. Equivalent to
            RISC-V's `mul` instruction.
        :attribute OP_MUL_SIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are signed. Equivalent to RISC-V's `mulh`
            instruction.
        :attribute OP_MUL_SIGNED_UNSIGNED_HIGH: the MSB half of the product
            where ``a`` is signed and ``b`` is unsigned. Equivalent to RISC-V's
            `mulhsu` instruction.
        :attribute OP_MUL_UNSIGNED_HIGH: the MSB half of the product where both
            ``a`` and ``b`` are unsigned. Equivalent to RISC-V's `mulhu`
            instruction.
    """

    def __init__(self, register_levels=()):
        """Create the multiplier.

        :param register_levels: specifies the points in the cascade at
            which flip-flops are to be inserted: an add-reduce level whose
            index appears here has its output assigned in the ``sync``
            domain instead of ``comb``.
        """

        self.id_wid = 0  # num_bits(num_rows)
        self.op_wid = 0
        self.pspec = PipelineSpec(64, self.id_wid, self.op_wid, n_ops=3)
        self.pspec.n_parts = 8

        # parameter(s)
        self.register_levels = list(register_levels)

        self.i = self.ispec()
        self.o = self.ospec()

        # inputs (aliases into the input record)
        self.part_pts = self.i.part_pts
        self.part_ops = self.i.part_ops
        self.a = self.i.a
        self.b = self.i.b

        # output (aliases into the output record)
        self.intermediate_output = self.o.intermediate_output
        self.output = self.o.output

    def ispec(self):
        """Return a fresh input record."""
        return InputData()

    def ospec(self):
        """Return a fresh output record."""
        return OutputData()

    def elaborate(self, platform):
        """Chain the stages: terms, add-reduce levels, intermediates and
        final output selection.

        :param platform: not used by this method.
        :returns: the populated ``Module``.
        """
        m = Module()

        # the 4 extra inputs match the four OR-combined terms appended by
        # AllTerms; presumably the 64 are its product terms -- TODO confirm
        n_inputs = 64 + 4
        t = AllTerms(self.pspec, n_inputs)
        t.setup(m, self.i)

        at = AddReduceInternal(t.process(self.i), self.pspec, partition_step=2)

        # thread the data through every add-reduce level, registering the
        # output of the levels listed in register_levels
        i = at.i
        for idx, mcur in enumerate(at.levels):
            mcur.setup(m, i)
            o = mcur.ospec()
            if idx in self.register_levels:
                m.d.sync += o.eq(mcur.process(i))
            else:
                m.d.comb += o.eq(mcur.process(i))
            i = o  # for next loop

        interm = Intermediates(128, 8, self.part_pts)
        interm.setup(m, i)
        o = interm.process(interm.i)

        # final output
        finalout = FinalOut(128, 8, self.part_pts)
        finalout.setup(m, o)
        m.d.comb += self.o.eq(finalout.process(o))

        return m
1437
1438
if __name__ == "__main__":
    # hand the multiplier to the nmigen command-line driver, exposing the
    # operands, both outputs, every operation code and every partition point
    m = Mul8_16_32_64()
    ports = [m.a, m.b, m.intermediate_output, m.output]
    ports.extend(m.part_ops)
    ports.extend(m.part_pts.values())
    main(m, ports=ports)