rename exponent_width to e_width, mantissa_width to m_width (shorter)
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12 import math
13
14
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def __eq__(self, other):
        """ Check for equality (all four attributes must match). """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.e_width == other.e_width and
                self.m_width == other.m_width and
                self.has_int_bit == other.has_int_bit and
                self.has_sign == other.has_sign)

    def __hash__(self):
        """ Get hash, consistent with ``__eq__``.

        Defining ``__eq__`` alone would make instances unhashable.
        """
        return hash((self.e_width, self.m_width,
                     self.has_int_bit, self.has_sign))

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises TypeError: if ``width`` is not an ``int``.
        :raises ValueError: if ``width`` is not 16, 32, 64, 128 or a
            multiple of 32 above 128 (capped at 1000000).
        """
        if not isinstance(width, int):
            raise TypeError()
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            e_width = round(4 * math.log2(width)) - 13
            # one bit is taken by the sign, the rest is mantissa
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Standard formats are shown as ``FPFormat.standard(width)``; anything
        else is shown as a constructor call that evaluates back to an equal
        instance.
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except (TypeError, ValueError):
            # not a standard width (or non-integer fields): fall through
            pass
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        if self.has_int_bit is not False:
            retval += f", {self.has_int_bit}"
        if self.has_sign is not True:
            # keyword form: a bare positional value here would be read back
            # as has_int_bit whenever has_int_bit itself was omitted.
            retval += f", has_sign={self.has_sign}"
        return retval + ")"

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit
120
121
class MultiShiftR:
    """ Combinatorial variable right-shifter: ``o`` is driven with ``i >> s``.

    ``i`` and ``o`` are ``width`` bits wide; the shift amount ``s`` is
    ``smax`` bits wide, where ``smax`` is ``log(width)/log(2)`` truncated
    to an integer.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))
        # attribute names double as the inferred signal names
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ Build the module containing the single shift assignment. """
        module = Module()
        shifted = self.i >> self.s
        module.d.comb += self.o.eq(shifted)
        return module
135
136
class MultiShift:
    """ Helper producing variable-length single-cycle shift expressions.

        Both shift methods apply the operator to the operand and then
        slice the result back to the operand's own length, so the
        returned expression is never wider than ``op``.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        self.width = width
        # log(width)/log(2), truncated to an integer
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ Return ``op << s`` truncated to ``len(op)`` entries. """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ Return ``op >> s`` truncated to ``len(op)`` entries. """
        return (op >> s)[:len(op)]
159
160
class FPNumBaseRecord:
    """ Floating-point Base Number Class

    Owns the signals of a decoded FP number (``v`` raw value, ``s`` sign,
    ``e`` signed exponent, ``m`` mantissa) together with the width
    parameters and the constants (``mzero``, ``msb1``, ``m1s``, ``P128``,
    ``P127``, ``N127``, ``N126``) derived from them.

    ``self.fp`` always refers to the object holding the constants: for a
    record that is the record itself.  All methods reach the constants
    through ``self.fp`` so that they also work when inherited by a class
    whose ``fp`` is a separate record.
    """

    def __init__(self, width, m_extra=True, e_extra=False):
        """ Create the record.

        :param width: total FP width; must be 16, 32 or 64 (any other
            value raises ``KeyError`` from the lookup tables).
        :param m_extra: if true, add 3 extra low mantissa bits.
        :param e_extra: if true, add 6 extra exponent bits.
        """
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True)      # Latched copy of value
        self.m = Signal(m_width, reset_less=True)    # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
        self.s = Signal(reset_less=True)             # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ Share this record's signals and width parameters with ``fp``.

        The signals are assigned by reference, so ``fp`` and this record
        refer to the very same ``s``/``e``/``m``/``v``.  The derived
        constants are (re)created on this record only, not on ``fp``.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        """ Assignments storing a NaN with sign ``s`` into ``v``. """
        return self.create(*self._nan(s))

    def inf(self, s):
        """ Assignments storing an infinity with sign ``s`` into ``v``. """
        return self.create(*self._inf(s))

    def zero(self, s):
        """ Assignments storing a zero with sign ``s`` into ``v``. """
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  Unlike ``create`` this
            returns the concatenated expression rather than assignments
            to ``v``.
        """
        # constants live on self.fp (see class docstring)
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        """ NaN expression with sign ``s`` (see ``create2``). """
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        """ Infinity expression with sign ``s`` (see ``create2``). """
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        """ Zero expression with sign ``s`` (see ``create2``). """
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ Assignments copying sign, exponent and mantissa from ``inp``. """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
282
283
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    Wraps an existing record ``fp``: the record's signals and width
    parameters are dropped into this object, and a set of single-bit
    status flags describing the exponent and mantissa is added.
    """

    def __init__(self, fp):
        # share s/e/m/v and the width parameters with the record
        fp.drop_in(self)
        self.fp = fp

        # NOTE: each Signal is assigned straight to its attribute so that
        # the inferred signal name matches the attribute name.
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((fp.e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ Drive every status flag combinatorially from ``e`` and ``m``. """
        m = Module()
        consts = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == consts.P128),
            self.exp_sub_n126.eq(self.e - consts.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_zero.eq(self.e == 0),
            self.exp_gt127.eq(self.e > consts.P127),
            self.exp_n127.eq(self.e == consts.N127),
            self.exp_n126.eq(self.e == consts.N126),
            self.m_zero.eq(self.m == consts.mzero),
            self.m_msbzero.eq(self.m[consts.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # maximum exponent with a non-zero mantissa
        not_mzero = ~self.m_zero
        return (self.exp_128) & (not_mzero)

    def _is_inf(self):
        # maximum exponent with an all-zero mantissa
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        # minimum exponent with an all-zero mantissa
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        return (self.exp_n126) & (self.m_msbzero)
343
344
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

    Output-side number: adds nothing to ``FPNumBase`` beyond a distinct
    type, so the sign / exponent / mantissa signals, the status flags
    and the encoding helpers (zero, NaN and inf, all signed) are those
    of the base class.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow.  The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        return super().elaborate(platform)
365
366
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        Ports: ``inp`` (input, ``width`` bits), ``diff`` (shift amount,
        ``s_max`` bits) and ``m`` (output, ``width`` bits).  Bit 0 of
        ``inp`` is treated as an existing sticky bit and is OR-ed into
        the new one; bits 1 and above are shifted right.
    """

    def __init__(self, width, s_max=None):
        # default shift-amount width: log(width)/log(2), truncated
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        # NOTE: local names below double as the inferred signal names
        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        # NOTE(review): mw is sized to len(diff); presumably width-1 is
        # expected to fit in that many bits - confirm against callers.
        mw = Const(self.width-1, len(self.diff))
        # clamp the shift to width-1; maxsleni is the complementary amount
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                     ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(self.inp[1:], maxslen)),
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(self.inp[1:] & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs))
        ]
        return m
409
410
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    NOTE(review): ``elaborate`` drives ``self.e``/``self.m`` from both the
    comb and the sync domain, which looks unintended; the structure is
    kept as found and only the outright Python errors are corrected
    (constructor arguments, undefined ``op``, missing ``shift_down``
    argument, ``m1s`` lookup).
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        """ Create the shifter.

        :param mainm: object providing ``State(...)`` (used in elaborate).
        :param op: number whose s/e/m are copied in.
        :param inv: number whose exponent is compared against.
        :param width: FP width, passed on to ``FPNumBaseRecord``.
        :param m_extra: passed on to ``FPNumBaseRecord``.
        """
        # FPNumBase.__init__ takes a record, not (width, m_extra)
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift our own mantissa down by one
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record (self.fp)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
482
483
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class

    Decodes the latched value ``v`` combinatorially into
    sign / exponent / mantissa.  Status flags and the encoding helpers
    (zero, NaN and inf, all signed) come from ``FPNumBase``.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow.  The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, op, fp):
        super().__init__(fp)
        self.op = op

    def elaborate(self, platform):
        m = super().elaborate(platform)
        decoded = self.decode(self.v)
        m.d.comb += decoded
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        # low m_extra bits are zero-padding, stored mantissa sits above
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign),
                ]
521
522
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow.  The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    The derived constants (``P127``, ``m1s``, ...) live on the record
    ``self.fp``, not on this object, so they are always reached via
    ``self.fp``.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            Returns an ``ObjectProxy`` carrying the decoded ``m``, ``e``
            and ``s`` expressions rather than assignments.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                                  # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127   # exp
        res.s = v[-1]                                       # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                # bias constant is held by the record, as in decode2
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),  # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            ``inp`` is the number to shift; it defaults to ``self``.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant is held by the record (self.fp)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
621
622
class Trigger(Elaboratable):
    """ Strobe/acknowledge pair; ``trigger`` is high when both are high. """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        module = Module()
        both = self.stb & self.ack
        module.d.comb += self.trigger.eq(both)
        return module

    def eq(self, inp):
        """ Assignments copying ``stb`` and ``ack`` from ``inp``. """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ The externally driven signals: ``stb`` then ``ack``. """
        return [self.stb, self.ack]
642
643
class FPOpIn(PrevControl):
    """ Input-side operand port: a ``PrevControl`` tagged with a width. """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ The operand value: alias for ``data_i``. """
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ Chain from ``in_op``, sending back the *inverted* ack.

            ``extra``, when given, is AND-ed into the received strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ Chain from ``in_op``, sending back the ack unmodified.

            ``extra``, when given, is AND-ed into the received strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
670
671
class FPOpOut(NextControl):
    """ Output-side operand port: a ``NextControl`` tagged with a width. """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ The operand value: alias for ``data_o``. """
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ Chain from ``in_op``, sending back the *inverted* ack.

            ``extra``, when given, is AND-ed into the received strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ Chain from ``in_op``, sending back the ack unmodified.

            ``extra``, when given, is AND-ed into the received strobe.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
698
699
class Overflow:  # (Elaboratable):
    """ Bundle of the rounding bits: guard, round, sticky and mantissa bit 0.

    ``name`` is an optional prefix for the four signal names.
    """

    _FIELDS = ("guard", "round_bit", "sticky", "m0")

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")      # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")    # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0

    def __iter__(self):
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ Assignments copying all four bits from ``inp``. """
        return [getattr(self, field).eq(getattr(inp, field))
                for field in self._FIELDS]

    @property
    def roundz(self):
        """ Round-up condition: guard set, and any of round/sticky/m0 set. """
        lower_bits = self.round_bit | self.sticky | self.m0
        return self.guard & (lower_bits)
726
727
class OverflowMod(Elaboratable, Overflow):
    """ ``Overflow`` as a module: adds ``roundz_out``, a signal driven
        combinatorially from the ``roundz`` expression.

        ``name`` is an optional prefix for all signal names.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ Assignments copying ``roundz_out`` and the four base bits
            from ``inp``.
        """
        # the base-class eq needs the source object too
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)
        return m
746
747
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        In every method ``m`` is the module being built (statements are
        added to ``m.d.sync`` / ``m.d.comb``) and ``next_state``, where
        present, is assigned to ``m.next`` to advance the state machine.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state when both op.ready_o
            and op.valid_i_test are 1.
            acknowledgement is sent by setting ack to ZERO.

            returns [res, ack]: res is the result of v.decode2(m), ack is
            a newly created Signal driven to 0 on the handshake cycle and
            to 1 otherwise.
        """
        res = v.decode2(m)
        # NOTE: the local name is also the inferred signal name
        ack = Signal()
        with m.If((op.ready_o) & (op.valid_i_test)):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"

            one step per clock: while the top mantissa bit is clear the
            mantissa is shifted up and the exponent decremented; once it
            is set the state machine advances.
        """
        with m.If((op.m[-1] == 0)):  # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            # order matters: the z.m[0] assignment must follow the
            # whole-mantissa shift so that it overrides bit 0.
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1),  # shift mantissa UP
                z.m[0].eq(of.guard),  # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),  # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            when roundz is set the mantissa is incremented; if the mantissa
            was all 1s the exponent is incremented as well.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            with m.If(z.m == z.fp.m1s):  # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises valid_o and
            waits for ready_i_test to be 1 before moving to the next state.
            resets valid_o back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
878
879
class FPState(FPBase):
    """ A named state: records which state it is entered from and exposes
        its inputs/outputs both as dicts and as attributes.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ Store ``inputs`` and mirror each entry as an attribute. """
        self.inputs = inputs
        for entry in inputs.items():
            setattr(self, *entry)

    def set_outputs(self, outputs):
        """ Store ``outputs`` and mirror each entry as an attribute. """
        self.outputs = outputs
        for entry in outputs.items():
            setattr(self, *entry)
893
894
class FPID:
    """ Optional pass-through identifier ("mid") pair.

    When ``id_wid`` is truthy, ``in_mid`` and ``out_mid`` are signals of
    that width; otherwise (``None`` or ``0``) both are ``None`` and
    ``idsync`` does nothing.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ Add the sync assignment ``out_mid <= in_mid`` to module ``m``.

        Uses the same truthiness test as ``__init__`` so that a width of 0
        (for which no signals were created) is skipped rather than
        dereferencing ``None``.
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)