switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12 import math
13
14
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)

    Instances compare equal (and hash equal) when all four attributes
    match.  Because instances are hashable, do not mutate one while it is
    being used as a dict key or set member.
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def __eq__(self, other):
        """ Check for equality. """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.e_width == other.e_width and
                self.m_width == other.m_width and
                self.has_int_bit == other.has_int_bit and
                self.has_sign == other.has_sign)

    def __hash__(self):
        """ Hash consistently with ``__eq__``.

        Defining ``__eq__`` on its own implicitly sets ``__hash__`` to
        ``None``, which would make instances unusable in sets / as dict keys.
        """
        return hash((self.e_width, self.m_width,
                     self.has_int_bit, self.has_sign))

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises ValueError: if ``width`` is not 16, 32, 64, 128 or a
            multiple of 32 above 128 (up to an arbitrary 1000000 limit).
        """
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            # IEEE 754-2008 interchange-format exponent width for k >= 128
            e_width = round(4 * math.log2(width)) - 13
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Standard formats are shown as ``FPFormat.standard(width)``; others
        use constructor syntax that round-trips through ``eval``.
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            pass
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        if self.has_int_bit is not False:
            retval += f", {self.has_int_bit}"
        if self.has_sign is not True:
            # keyword form: a positional third argument would be read back
            # as has_int_bit whenever has_int_bit was omitted above
            retval += f", has_sign={self.has_sign}"
        return retval + ")"

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit
118
119
class MultiShiftR:
    """ Combinatorial variable right-shifter: ``o = i >> s``.

    :param width: bit-width of input ``i`` and output ``o``.  The shift
        amount ``s`` is floor(log2(width)) bits wide.
    :raises ValueError: if ``width`` is not positive.
    """

    def __init__(self, width):
        if width < 1:
            raise ValueError("width must be a positive integer")
        self.width = width
        # exact integer floor(log2(width)); the previous float ratio
        # log(width)/log(2) is subject to floating-point rounding
        self.smax = int(width).bit_length() - 1
        self.i = Signal(width, reset_less=True)
        self.s = Signal(self.smax, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.o.eq(self.i >> self.s)
        return m
133
134
class MultiShift:
    """ Helper for variable-amount shifts that preserve operand width.

    ``lshift`` / ``rshift`` shift ``op`` by the (possibly variable)
    amount ``s`` using the ``<<`` / ``>>`` operators, then truncate the
    result back to ``len(op)`` bits, so bits shifted out are discarded.
    (This class does not itself build a per-bit Mux cascade.)

    :param width: the operand width this helper is intended for.
        ``smax`` is floor(log2(width)).
    :raises ValueError: if ``width`` is not positive.
    """

    def __init__(self, width):
        if width < 1:
            raise ValueError("width must be a positive integer")
        self.width = width
        # exact integer floor(log2(width)) instead of a float log ratio
        self.smax = int(width).bit_length() - 1

    def lshift(self, op, s):
        """ Returns ``op << s`` truncated to ``len(op)`` bits. """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ Returns ``op >> s`` truncated to ``len(op)`` bits. """
        res = op >> s
        return res[:len(op)]
157
158
class FPNumBaseRecord:
    """ Floating-point Base Number Class

    Holds the sign (``s``), exponent (``e``) and mantissa (``m``) signals
    of a binary16/32/64 number, the packed value ``v``, and the widths and
    constants needed to encode / decode it.

    The exponent is signed and two bits wider than the IEEE field
    (e.g. 10 bits for FP32); the mantissa is one bit wider than the stored
    fraction (``rmw``).

    :param width: packed bit-width: 16, 32 or 64 (other values raise
        ``KeyError``).
    :param m_extra: if True, 3 extra low mantissa bits are added.
    :param e_extra: if True, 6 extra exponent bits are added (enough to
        cover FP64 when converting to FP16).
    """

    def __init__(self, width, m_extra=True, e_extra=False):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        if e_extra:
            self.e_extra = 6  # enough to cover FP64 when converting to FP16
            e_width += self.e_extra
        else:
            self.e_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True)  # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
        self.s = Signal(reset_less=True)  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ Shares this record's signals, widths and constants with ``fp``.

        The constants (``mzero``, ``msb1``, ``m1s``, ``P128``, ``P127``,
        ``N127``, ``N126``) are (re)built on this record and then also
        attached to ``fp``, so that methods inherited from this class
        (``create2``, ``nan2``, ``inf2``, ``zero2``), which read them via
        ``self``, also work on the object being dropped into.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra
        fp.e_extra = self.e_extra  # previously not propagated

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

        for name in ("mzero", "msb1", "m1s", "P128", "P127", "N127", "N126"):
            setattr(fp, name, getattr(self, name))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).

            :returns: list of assignments into ``self.v``.
        """
        return [
            self.v[0:self.e_start].eq(m),  # mantissa
            self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
            self.v[-1].eq(s),  # sign
        ]

    def _nan(self, s):
        # (sign, exponent, mantissa) with the top stored mantissa bit set
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        return (s, self.fp.N127, 0)

    def nan(self, s):
        """ assignments encoding a NaN with sign ``s`` into ``v`` """
        return self.create(*self._nan(s))

    def inf(self, s):
        """ assignments encoding infinity with sign ``s`` into ``v`` """
        return self.create(*self._inf(s))

    def zero(self, s):
        """ assignments encoding zero with sign ``s`` into ``v`` """
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            :returns: a ``Cat`` expression of the packed value.
        """
        e = e + self.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        return self.create2(s, self.P128, self.msb1)

    def inf2(self, s):
        return self.create2(s, self.P128, self.mzero)

    def zero2(self, s):
        return self.create2(s, self.N127, self.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ assignments copying sign, exponent and mantissa from ``inp`` """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
280
281
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

    Wraps an ``FPNumBaseRecord`` ``fp``.  ``fp.drop_in(self)`` copies the
    record's signals (``s``, ``e``, ``m``, ``v``) and widths onto this
    object.  ``elaborate`` then derives combinatorial status flags
    (NaN / Inf / zero / overflow / denormal and exponent comparisons)
    from ``e`` and ``m``.

    ``FPNumBaseRecord.__init__`` is not called here.  The code below reads
    the encoding constants (``P128``, ``N126``, ...) through ``self.fp``.
    """

    def __init__(self, fp):
        """ :param fp: the ``FPNumBaseRecord`` whose signals are shared. """
        fp.drop_in(self)
        self.fp = fp
        e_width = fp.e_width

        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ Builds a Module that drives all status flags combinatorially. """
        m = Module()
        m.d.comb += self.is_nan.eq(self._is_nan())
        m.d.comb += self.is_zero.eq(self._is_zero())
        m.d.comb += self.is_inf.eq(self._is_inf())
        m.d.comb += self.is_overflowed.eq(self._is_overflowed())
        m.d.comb += self.is_denormalised.eq(self._is_denormalised())
        # exponent at the Inf/NaN value
        m.d.comb += self.exp_128.eq(self.e == self.fp.P128)
        # signed distance from the minimum normal exponent
        m.d.comb += self.exp_sub_n126.eq(self.e - self.fp.N126)
        m.d.comb += self.exp_gt_n126.eq(self.exp_sub_n126 > 0)
        m.d.comb += self.exp_lt_n126.eq(self.exp_sub_n126 < 0)
        m.d.comb += self.exp_zero.eq(self.e == 0)
        m.d.comb += self.exp_gt127.eq(self.e > self.fp.P127)
        # exponent at the zero/denormal value
        m.d.comb += self.exp_n127.eq(self.e == self.fp.N127)
        m.d.comb += self.exp_n126.eq(self.e == self.fp.N126)
        m.d.comb += self.m_zero.eq(self.m == self.fp.mzero)
        # NOTE(review): this tests bit e_start (== rmw) of m.  When the
        # record has m_extra bits, the fraction is placed above them, so
        # this is not the top mantissa bit -- confirm the intended bit.
        m.d.comb += self.m_msbzero.eq(self.m[self.fp.e_start] == 0)

        return m

    def _is_nan(self):
        return (self.exp_128) & (~self.m_zero)

    def _is_inf(self):
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        return self.exp_gt127

    def _is_denormalised(self):
        return (self.exp_n126) & (self.m_msbzero)
341
342
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

    Shares the sign / exponent / mantissa signals and the packed value
    of an ``FPNumBaseRecord``.  Encoding helpers (creating signed zero,
    NaN and infinity) and the decode-status flags come from the base
    classes; this class adds no logic of its own.

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow.  The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, fp):
        FPNumBase.__init__(self, fp)

    def elaborate(self, platform):
        # nothing output-specific to add: hand back the base module as-is
        return FPNumBase.elaborate(self, platform)
363
364
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        Combinatorial: ``m`` is ``inp[1:]`` shifted right by ``diff``
        (saturated at ``width-1``), with bit 0 set to the OR of ``inp[0]``
        and every bit of ``inp[1:]`` that was shifted out.

        :param width: bit-width of ``inp`` and ``m``.
        :param s_max: bit-width of ``diff``; defaults to floor(log2(width))
            (computed with floating-point logs).
    """

    def __init__(self, width, s_max=None):
        if s_max is None:
            s_max = int(log(width) / log(2))
        self.smax = s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(s_max, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        sm = MultiShift(self.width-1)
        m0s = Const(0, self.width-1)
        mw = Const(self.width-1, len(self.diff))
        # maxslen: shift amount clamped to width-1 (everything shifted out)
        # maxsleni: its complement, used to build the mask of lost bits
        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
                     ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(sm.rshift(self.inp[1:], maxslen)),
            # all-ones shifted right by maxsleni: ones in exactly the
            # low bit positions that the main shift discards
            m_mask.eq(sm.rshift(~m0s, maxsleni)),
            smask.eq(self.inp[1:] & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs))
        ]
        return m
407
408
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

    :param mainm: object providing ``State(name)`` (an FSM context)
        -- presumably the parent module; confirm at callers.
    :param op: source number whose ``s``/``e``/``m`` are copied in.
    :param inv: the other operand; this number is shifted down while its
        exponent is below ``inv.e``.
    :param width: packed bit-width (16, 32 or 64).
    :param m_extra: passed through to ``FPNumBaseRecord``.
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase.__init__ takes a record, not (width, m_extra): the
        # original call raised TypeError.  Build the record here.
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # `op` was an undefined name here (NameError); it is self.op
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                # shift_down() requires its source argument.
                # NOTE(review): self.e / self.m are also driven from
                # self.op in the comb domain above; driving them from the
                # sync domain here too will be rejected by nmigen.
                # Confirm the intended dataflow before using this class.
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros. it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # m1s lives on the record (self.fp), not on this FPNumBase
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
480
481
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (combinatorial decoder)

    Splits the packed value ``v`` into sign / exponent / mantissa every
    cycle, on top of the base class's status flags (zero, NaN, inf ...).

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow.  The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ Returns assignments splitting ``v`` into ``s`` / ``e`` / ``m``.

            The exponent bias (``fp.P127``) is subtracted, using the
            widened signed exponent so the subtraction cannot wrap.
            The stored fraction is placed above ``m_extra`` zero bits.
        """
        padding = [0] * self.m_extra
        fraction = v[0:self.e_start]
        biased_exp = v[self.e_start:self.e_end]
        return [self.m.eq(Cat(*padding, fraction)),
                self.e.eq(biased_exp - self.fp.P127),
                self.s.eq(v[-1]),
                ]
519
520
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

    Contains signals for an incoming copy of the value, decoded into
    sign / exponent / mantissa.
    Also contains encoding functions, creation and recognition of
    zero, NaN and inf (all signed)

    Four extra bits are included in the mantissa: the top bit
    (m[-1]) is effectively a carry-overflow. The other three are
    guard (m[2]), round (m[1]), and sticky (m[0])

    Encoding constants (``P127``, ``m1s`` ...) are read through
    ``self.fp`` (the wrapped ``FPNumBaseRecord``), because
    ``FPNumBase.__init__`` does not define them on this object.
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent. exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            :returns: an ``ObjectProxy`` on module ``m`` whose ``m``,
                ``e`` and ``s`` are the decoded expressions.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)  # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]  # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent. exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                # was self.P127, which does not exist on FPNumBase objects
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),  # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros. it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            :param inp: source number; defaults to ``self``.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # was self.m1s, which does not exist on FPNumBase objects
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
619
620
class Trigger(Elaboratable):
    """ Strobe/acknowledge pair whose ``trigger`` output is high
    combinatorially whenever both ``stb`` and ``ack`` are high. """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both_high = self.stb & self.ack
        m.d.comb += self.trigger.eq(both_high)
        return m

    def eq(self, inp):
        """ Assignments copying ``stb`` then ``ack`` from ``inp``. """
        return [getattr(self, name).eq(getattr(inp, name))
                for name in ("stb", "ack")]

    def ports(self):
        """ The handshake signals (``stb``, ``ack``). """
        return [self.stb, self.ack]
640
641
class FPOpIn(PrevControl):
    """ Input-side operand port: a ``PrevControl`` whose ``data_i`` is
    exposed as ``v``, plus helpers to chain from another operand port. """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ Operand value (alias of ``data_i``). """
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ Take value and STB from ``in_op`` (STB optionally gated by
        ``extra``); send back the *inverted* ACK. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v), self.stb.eq(stb),
                in_op.ack.eq(~self.ack)]

    def chain_from(self, in_op, extra=None):
        """ Take value and STB from ``in_op`` (STB optionally gated by
        ``extra``); send back ACK unchanged. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v), self.stb.eq(stb),
                in_op.ack.eq(self.ack)]
667 ]
668
669
class FPOpOut(NextControl):
    """ Output-side operand port: a ``NextControl`` whose ``data_o`` is
    exposed as ``v``, plus helpers to chain from another operand port. """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ Operand value (alias of ``data_o``). """
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ Take value and STB from ``in_op`` (STB optionally gated by
        ``extra``); send back the *inverted* ACK. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v), self.stb.eq(stb),
                in_op.ack.eq(~self.ack)]

    def chain_from(self, in_op, extra=None):
        """ Take value and STB from ``in_op`` (STB optionally gated by
        ``extra``); send back ACK unchanged. """
        stb = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v), self.stb.eq(stb),
                in_op.ack.eq(self.ack)]
696
697
class Overflow:  # (Elaboratable):
    """ Rounding bits of a result: guard, round, sticky and mantissa bit 0.

    :param name: optional prefix for the generated signal names.
    """

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix + "guard")  # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix + "round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix + "sticky")  # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix + "m0")  # mantissa bit 0

    def __iter__(self):
        return iter((self.guard, self.round_bit, self.sticky, self.m0))

    def eq(self, inp):
        """ Assignments copying all four bits from ``inp``, in order. """
        names = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, n).eq(getattr(inp, n)) for n in names]

    @property
    def roundz(self):
        """ Round-up decision: guard set and either a non-zero remainder
        (round/sticky) or an odd mantissa LSB (m0). """
        tie_breakers = self.round_bit | self.sticky | self.m0
        return self.guard & tie_breakers
724
725
class OverflowMod(Elaboratable, Overflow):
    """ ``Overflow`` plus a ``roundz_out`` signal driven combinatorially
    from the ``roundz`` decision.

    :param name: optional prefix for the generated signal names.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ Assignments copying ``roundz_out`` and the four rounding bits
        from ``inp``. """
        # Overflow.eq needs the source too; it was called as
        # Overflow.eq(self), which raised TypeError.
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)
        return m
744
745
class FPBase:
    """ IEEE754 Floating Point Base Class

    contains common functions for FP manipulation, such as
    extracting and packing operands, normalisation, denormalisation,
    rounding etc.

    Every method adds statements to the nmigen Module ``m`` passed in.
    Methods taking ``next_state`` assign ``m.next``, so they must be called
    inside an FSM state.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            The handshake tested is ``op.ready_o & op.valid_i_test``.

            :param v: object providing ``decode2(m)`` (e.g. ``FPNumIn``).
            :returns: ``[res, ack]``: the decoded-operand proxy and a new
                comb signal ``ack`` (0 on handshake, 1 otherwise).
        """
        res = v.decode2(m)
        ack = Signal()
        with m.If((op.ready_o) & (op.valid_i_test)):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number. this is probably the wrong name for
            this function. for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"

            Shifts the mantissa up one bit per clock until its top bit is 1,
            then moves to ``next_state``.
        """
        with m.If((op.m[-1] == 0)):  # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            Shifts up while the top mantissa bit is 0 and the exponent is
            above N126.
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1),  # shift mantissa UP
                # the later assignment to bit 0 overrides the 0 shifted in
                z.m[0].eq(of.guard),  # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),  # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            Shifts down while the exponent is below N126, moving the bits
            shifted out into guard / round / sticky.
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            # compares the pre-increment value: all 1s means the increment
            # wraps the mantissa, so the exponent is bumped as well
            with m.If(z.m == z.fp.m1s):  # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            (The handshake signals used are ``out_z.valid_o`` and
            ``out_z.ready_i_test``.)
        """
        m.d.sync += [
            out_z.v.eq(z.v)
        ]
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
876
877
class FPState(FPBase):
    """ An ``FPBase`` that records ``state_from`` and can expose named
    inputs / outputs as attributes.

    ``state_from`` is stored as-is (presumably the name of the preceding
    FSM state -- confirm at callers).
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ Keep the ``inputs`` mapping and bind every entry as an
        attribute of the same name. """
        self.inputs = inputs
        for attr_name, value in inputs.items():
            setattr(self, attr_name, value)

    def set_outputs(self, outputs):
        """ Keep the ``outputs`` mapping and bind every entry as an
        attribute of the same name. """
        self.outputs = outputs
        for attr_name, value in outputs.items():
            setattr(self, attr_name, value)
891
892
class FPID:
    """ Optional id ("mid") carried alongside an operation.

    :param id_wid: bit-width of the id.  If it is falsy (``None`` or 0), no
        signals are created and ``in_mid`` / ``out_mid`` are ``None``.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ Registers ``in_mid`` into ``out_mid`` in the sync domain of
        ``m``.  Does nothing when no id signals exist. """
        # Check the signals themselves: __init__ only creates them for a
        # truthy id_wid, so the previous `id_wid is not None` test crashed
        # on None.eq() when id_wid == 0.
        if self.out_mid is not None:
            m.d.sync += self.out_mid.eq(self.in_mid)
904 if self.id_wid is not None:
905 m.d.sync += self.out_mid.eq(self.in_mid)