update comments
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12 import unittest
13 import math
14
15
class FPFormat:
    """ Class describing binary floating-point formats based on IEEE 754.

    :attribute e_width: the number of bits in the exponent field.
    :attribute m_width: the number of bits stored in the mantissa
        field.
    :attribute has_int_bit: if the FP format has an explicit integer bit (like
        the x87 80-bit format). The bit is considered part of the mantissa.
    :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
        Image/Buffer formats are FP numbers without a sign bit.)
    """

    def __init__(self,
                 e_width,
                 m_width,
                 has_int_bit=False,
                 has_sign=True):
        """ Create ``FPFormat`` instance. """
        self.e_width = e_width
        self.m_width = m_width
        self.has_int_bit = has_int_bit
        self.has_sign = has_sign

    def _key(self):
        """ Get the tuple of fields that defines equality and hashing. """
        return (self.e_width, self.m_width, self.has_int_bit, self.has_sign)

    def __eq__(self, other):
        """ Check for equality. """
        if not isinstance(other, FPFormat):
            return NotImplemented
        return (self.e_width == other.e_width
                and self.m_width == other.m_width
                and self.has_int_bit == other.has_int_bit
                and self.has_sign == other.has_sign)

    def __hash__(self):
        """ Hash consistently with ``__eq__``.

        Defining ``__eq__`` alone makes a class unhashable; this restores
        the ability to use formats as dict keys / set members.  Do not
        mutate an instance while it is stored in a hashed container.
        """
        return hash(self._key())

    @staticmethod
    def standard(width):
        """ Get standard IEEE 754-2008 format.

        :param width: bit-width of requested format.
        :returns: the requested ``FPFormat`` instance.
        :raises ValueError: if width is not 16, 32, 64, 128 or a larger
            multiple of 32 (up to an arbitrary limit of 1000000).
        """
        if width == 16:
            return FPFormat(5, 10)
        if width == 32:
            return FPFormat(8, 23)
        if width == 64:
            return FPFormat(11, 52)
        if width == 128:
            return FPFormat(15, 112)
        if width > 128 and width % 32 == 0:
            if width > 1000000:  # arbitrary upper limit
                raise ValueError("width too big")
            e_width = round(4 * math.log2(width)) - 13
            return FPFormat(e_width, width - 1 - e_width)
        raise ValueError("width must be the bit-width of a valid IEEE"
                         " 754-2008 binary format")

    def __repr__(self):
        """ Get repr.

        Standard formats are shown as ``FPFormat.standard(width)``.
        Otherwise the positional constructor form is used; has_int_bit is
        always emitted when has_sign is emitted, so that the positional
        arguments cannot be misread.
        """
        try:
            if self == self.standard(self.width):
                return f"FPFormat.standard({self.width})"
        except ValueError:
            pass
        retval = f"FPFormat({self.e_width}, {self.m_width}"
        sign_is_default = self.has_sign is True
        if self.has_int_bit is not False or not sign_is_default:
            retval += f", {self.has_int_bit}"
        if not sign_is_default:
            retval += f", {self.has_sign}"
        return retval + ")"

    def get_sign_field(self, x):
        """ returns the sign bit of its input number, x
            (assumes FPFormat is set to signed - has_sign=True)
        """
        return x >> (self.e_width + self.m_width)

    def get_exponent_field(self, x):
        """ returns the raw exponent of its input number, x (no bias subtracted)
        """
        return (x >> self.m_width) & self.exponent_inf_nan

    def get_exponent(self, x):
        """ returns the exponent of its input number, x (bias subtracted)
        """
        return self.get_exponent_field(x) - self.exponent_bias

    def get_mantissa_field(self, x):
        """ returns the mantissa of its input number, x
        """
        return x & self.mantissa_mask

    def is_zero(self, x):
        """ returns true if x is +/- zero
        """
        return (self.get_exponent(x) == self.e_sub and
                self.get_mantissa_field(x) == 0)

    def is_subnormal(self, x):
        """ returns true if x is subnormal (exp at minimum)
        """
        return (self.get_exponent(x) == self.e_sub and
                self.get_mantissa_field(x) != 0)

    def is_inf(self, x):
        """ returns true if x is infinite
        """
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) == 0)

    def is_nan(self, x):
        """ returns true if x is a nan (quiet or signalling)
        """
        return (self.get_exponent(x) == self.e_max and
                self.get_mantissa_field(x) != 0)

    def is_quiet_nan(self, x):
        """ returns true if x is a quiet nan
            (exponent all-ones and top mantissa bit set)
        """
        highbit = 1 << (self.m_width - 1)
        mantissa = self.get_mantissa_field(x)
        # a set top bit already implies a non-zero mantissa
        return (self.get_exponent(x) == self.e_max and
                (mantissa & highbit) != 0)

    def is_nan_signaling(self, x):
        """ returns true if x is a signalling nan
            (exponent all-ones, mantissa non-zero, top mantissa bit clear)
        """
        highbit = 1 << (self.m_width - 1)
        mantissa = self.get_mantissa_field(x)
        return (self.get_exponent(x) == self.e_max and
                mantissa != 0 and
                (mantissa & highbit) == 0)

    @property
    def width(self):
        """ Get the total number of bits in the FP format. """
        return self.has_sign + self.e_width + self.m_width

    @property
    def mantissa_mask(self):
        """ Get a mantissa mask based on the mantissa width """
        return (1 << self.m_width) - 1

    @property
    def exponent_inf_nan(self):
        """ Get the value of the exponent field designating infinity/NaN. """
        return (1 << self.e_width) - 1

    @property
    def e_max(self):
        """ get the maximum exponent (minus bias)
        """
        return self.exponent_inf_nan - self.exponent_bias

    @property
    def e_sub(self):
        """ get the exponent (minus bias) designating denormal/zero
        """
        return self.exponent_denormal_zero - self.exponent_bias

    @property
    def exponent_denormal_zero(self):
        """ Get the value of the exponent field designating denormal/zero. """
        return 0

    @property
    def exponent_min_normal(self):
        """ Get the minimum value of the exponent field for normal numbers. """
        return 1

    @property
    def exponent_max_normal(self):
        """ Get the maximum value of the exponent field for normal numbers. """
        return self.exponent_inf_nan - 1

    @property
    def exponent_bias(self):
        """ Get the exponent bias. """
        return (1 << (self.e_width - 1)) - 1

    @property
    def fraction_width(self):
        """ Get the number of mantissa bits that are fraction bits. """
        return self.m_width - self.has_int_bit
195
196
class TestFPFormat(unittest.TestCase):
    """ very quick test for FPFormat, cross-checked against sfpy bit patterns
    """

    def _show(self, fmt, bits, tag, flag):
        """ print the decoded fields of bits alongside a classification flag
        """
        print(hex(bits), tag, fmt.get_exponent(bits), fmt.e_max,
              fmt.get_mantissa_field(bits), flag)

    def test_fpformat_fp64(self):
        """ exponent, mantissa and sign extraction for FP64
        """
        fmt = FPFormat.standard(64)
        from sfpy import Float64

        one = Float64(1.0).bits
        print(hex(one))
        self.assertEqual(fmt.get_exponent(one), 0)

        two = Float64(2.0).bits
        print(hex(two))
        self.assertEqual(fmt.get_exponent(two), 1)

        pos = Float64(1.5).bits
        mantissa = fmt.get_mantissa_field(pos)
        print(hex(pos), hex(mantissa))
        self.assertEqual(mantissa, 0x8000000000000)

        sign = fmt.get_sign_field(pos)
        print(hex(pos), hex(sign))
        self.assertEqual(sign, 0)

        neg = Float64(-1.5).bits
        sign = fmt.get_sign_field(neg)
        print(hex(neg), hex(sign))
        self.assertEqual(sign, 1)

    def test_fpformat_fp32(self):
        """ field extraction plus NaN / Inf / subnormal / zero tests for FP32
        """
        fmt = FPFormat.standard(32)
        from sfpy import Float32

        one = Float32(1.0).bits
        print(hex(one))
        self.assertEqual(fmt.get_exponent(one), 0)

        two = Float32(2.0).bits
        print(hex(two))
        self.assertEqual(fmt.get_exponent(two), 1)

        bits = Float32(1.5).bits
        mantissa = fmt.get_mantissa_field(bits)
        print(hex(bits), hex(mantissa))
        self.assertEqual(mantissa, 0x400000)

        # NaN test: sqrt(-1)
        bits = Float32(-1.0).sqrt().bits
        flag = fmt.is_nan(bits)
        self._show(fmt, bits, "nan", flag)
        self.assertEqual(flag, True)

        # Inf test: product overflows
        product = Float32(1e36) * Float32(1e36) * Float32(1e36)
        bits = product.bits
        flag = fmt.is_inf(bits)
        self._show(fmt, bits, "inf", flag)
        self.assertEqual(flag, True)

        # subnormal
        bits = Float32(1e-41).bits
        flag = fmt.is_subnormal(bits)
        self._show(fmt, bits, "sub", flag)
        self.assertEqual(flag, True)

        # zero is not subnormal...
        bits = Float32(0.0).bits
        flag = fmt.is_subnormal(bits)
        self._show(fmt, bits, "sub", flag)
        self.assertEqual(flag, False)

        # ...but it is zero
        flag = fmt.is_zero(bits)
        self._show(fmt, bits, "zero", flag)
        self.assertEqual(flag, True)
278
279
class MultiShiftR:
    """ Single-statement variable right-shifter.

    i is the width-bit input, s the shift amount (smax bits wide, where
    smax is the integer part of log2(width)) and o the width-bit output.

    NOTE(review): this class provides elaborate() but does not inherit
    from Elaboratable, unlike the other modules in this file - confirm
    whether that is intentional.
    """

    def __init__(self, width):
        shift_bits = int(log(width) / log(2))
        self.width = width
        self.smax = shift_bits
        self.i = Signal(width, reset_less=True)
        self.s = Signal(shift_bits, reset_less=True)
        self.o = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """ combinatorially drives o with i shifted right by s
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
293
294
class MultiShift:
    """ Generates variable-length single-cycle shifter from a series
        of conditional tests on each bit of the left/right shift operand.
        Each bit tested produces output shifted by that number of bits,
        in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
        shifts by 2 bits, each partial result cascading to the next Mux.

        Could be adapted to do arithmetic shift by taking copies of the
        MSB instead of zeros.
    """

    def __init__(self, width):
        # smax: integer part of log2(width)
        levels = int(log(width) / log(2))
        self.width = width
        self.smax = levels

    def lshift(self, op, s):
        """ shifts op up by s, truncating the result to len(op) items
        """
        return (op << s)[:len(op)]

    def rshift(self, op, s):
        """ shifts op down by s, truncating the result to len(op) items
        """
        return (op >> s)[:len(op)]
317
318
class FPNumBaseRecord:
    """ Floating-point Base Number Class.

    This class is designed to be passed around in other data structures
    (between pipelines and between stages).  Its "friend" is FPNumBase,
    which is a *module*.  The reason for the discernment is because
    nmigen modules that are not added to submodules results in the
    irritating "Elaboration" warning.  Despite not *needing* FPNumBase
    in many cases to be added as a submodule (because it is just data)
    this was not possible to solve without splitting out the data from
    the module.
    """

    def __init__(self, width, m_extra=True, e_extra=False, name=None):
        """ sets up field widths and the v / m / e / s signals

            width: 16, 32 or 64 (anything else raises KeyError)
            m_extra: if set, 3 extra mantissa bits are added
            e_extra: if set, 6 extra exponent bits are added
            name: optional prefix ("<name>_") for the signal names
        """
        # assert false, "missing name"
        prefix = "" if name is None else name + "_"

        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]  # 2 extra bits (overflow)
        self.rmw = m_width - 1  # real mantissa width (not including extras)
        self.e_max = 1 << (e_width-3)  # computed before any e_extra is added

        # mantissa extra bits (top,guard,round)
        self.m_extra = 3 if m_extra else 0
        m_width += self.m_extra

        # enough to cover FP64 when converting to FP16
        self.e_extra = 6 if e_extra else 0
        e_width += self.e_extra

        self.m_width = m_width
        self.e_width = e_width
        self.e_start = self.rmw
        self.e_end = self.rmw + self.e_width - 2  # for decoding

        self.v = Signal(width, reset_less=True,
                        name=prefix+"v")  # Latched copy of value
        self.m = Signal(m_width, reset_less=True,
                        name=prefix+"m")  # Mantissa
        self.e = Signal((e_width, True), reset_less=True,
                        name=prefix+"e")  # exp+2 bits, signed
        self.s = Signal(reset_less=True, name=prefix+"s")  # Sign bit

        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ copies signals and width parameters onto fp, then (re)creates
            the constants (mzero, msb1, m1s, P128, P127, N127, N126).

            note that the constants are stored on *this* record, not on fp.
        """
        shared = ("s", "e", "m", "v", "rmw", "width", "e_width", "e_max",
                  "m_width", "e_start", "e_end", "m_extra")
        for attr in shared:
            setattr(fp, attr, getattr(self, attr))

        m_shape = (self.m_width, False)
        e_shape = (self.e_width, True)
        e_max = self.e_max

        self.mzero = Const(0, m_shape)
        self.msb1 = Const(1 << (self.m_width-2), m_shape)
        self.m1s = Const(-1, m_shape)
        self.P128 = Const(e_max, e_shape)
        self.P127 = Const(e_max-1, e_shape)
        self.N127 = Const(-(e_max-1), e_shape)
        self.N126 = Const(-(e_max-2), e_shape)

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.

            NOTE: order is important, because e_start/e_end can be
            a bit too long (overwriting s).
        """
        mantissa = self.v[0:self.e_start].eq(m)
        exponent = self.v[self.e_start:self.e_end].eq(e + self.fp.P127)
        sign = self.v[-1].eq(s)
        return [mantissa, exponent, sign]

    def _nan(self, s):
        """ (sign, exponent, mantissa) tuple for NaN
        """
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        """ (sign, exponent, mantissa) tuple for infinity
        """
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        """ (sign, exponent, mantissa) tuple for zero
        """
        return (s, self.fp.N127, 0)

    def nan(self, s):
        """ assignments setting v to NaN with sign s
        """
        sign, exp, mant = self._nan(s)
        return self.create(sign, exp, mant)

    def inf(self, s):
        """ assignments setting v to infinity with sign s
        """
        sign, exp, mant = self._inf(s)
        return self.create(sign, exp, mant)

    def zero(self, s):
        """ assignments setting v to zero with sign s
        """
        sign, exp, mant = self._zero(s)
        return self.create(sign, exp, mant)

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  unlike create(), the
            result is a Cat() expression rather than a list of assignments.
        """
        biased = e + self.P127  # exp (add on bias)
        exp_bits = self.e_end - self.e_start
        return Cat(m[0:self.e_start], biased[0:exp_bits], s)

    def nan2(self, s):
        """ Cat() expression for NaN with sign s
        """
        return self.create2(s, self.P128, self.msb1)

    def inf2(self, s):
        """ Cat() expression for infinity with sign s
        """
        return self.create2(s, self.P128, self.mzero)

    def zero2(self, s):
        """ Cat() expression for zero with sign s
        """
        return self.create2(s, self.N127, self.mzero)

    def __iter__(self):
        """ yields sign, exponent, mantissa (in that order)
        """
        yield from (self.s, self.e, self.m)

    def eq(self, inp):
        """ assignments copying sign, exponent and mantissa from inp
        """
        return [self.s.eq(inp.s),
                self.e.eq(inp.e),
                self.m.eq(inp.m)]
456
457
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

        module "friend" of FPNumBaseRecord: shares the record's signals
        (via drop_in) and adds combinatorial status flags.
    """

    def __init__(self, fp):
        # share s/e/m/v and width parameters with the record
        fp.drop_in(self)
        self.fp = fp
        e_width = fp.e_width

        # NOTE: each flag is deliberately a plain attribute assignment
        self.is_nan = Signal(reset_less=True)
        self.is_zero = Signal(reset_less=True)
        self.is_inf = Signal(reset_less=True)
        self.is_overflowed = Signal(reset_less=True)
        self.is_denormalised = Signal(reset_less=True)
        self.exp_128 = Signal(reset_less=True)
        self.exp_sub_n126 = Signal((e_width, True), reset_less=True)
        self.exp_lt_n126 = Signal(reset_less=True)
        self.exp_zero = Signal(reset_less=True)
        self.exp_gt_n126 = Signal(reset_less=True)
        self.exp_gt127 = Signal(reset_less=True)
        self.exp_n127 = Signal(reset_less=True)
        self.exp_n126 = Signal(reset_less=True)
        self.m_zero = Signal(reset_less=True)
        self.m_msbzero = Signal(reset_less=True)

    def elaborate(self, platform):
        """ combinatorially computes all status flags from e and m
        """
        m = Module()
        fp = self.fp
        m.d.comb += [
            # classification flags
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            # exponent comparisons
            self.exp_128.eq(self.e == fp.P128),
            self.exp_sub_n126.eq(self.e - fp.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_zero.eq(self.e == 0),
            self.exp_gt127.eq(self.e > fp.P127),
            self.exp_n127.eq(self.e == fp.N127),
            self.exp_n126.eq(self.e == fp.N126),
            # mantissa tests
            self.m_zero.eq(self.m == fp.mzero),
            self.m_msbzero.eq(self.m[fp.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        """ exponent at P128 and mantissa non-zero
        """
        return self.exp_128 & ~self.m_zero

    def _is_inf(self):
        """ exponent at P128 and mantissa zero
        """
        return self.exp_128 & self.m_zero

    def _is_zero(self):
        """ exponent at N127 and mantissa zero
        """
        return self.exp_n127 & self.m_zero

    def _is_overflowed(self):
        """ exponent greater than P127
        """
        return self.exp_gt127

    def _is_denormalised(self):
        """ exponent at N126 and mantissa bit e_start clear
        """
        # XXX NOT to be used for "official" quiet NaN tests!
        # particularly when the MSB has been extended
        return self.exp_n126 & self.m_msbzero
519
520
class FPNumOut(FPNumBase):
    """ Floating-point Number Class

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, fp):
        # nothing beyond the base-class setup
        FPNumBase.__init__(self, fp)

    def elaborate(self, platform):
        """ returns the base-class module unchanged
        """
        return FPNumBase.elaborate(self, platform)
541
542
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        inp: value to shift, diff: shift amount, m: merged result.
    """

    def __init__(self, width, s_max=None):
        # default shift-amount width: integer part of log2(width)
        self.smax = int(log(width) / log(2)) if s_max is None else s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(self.smax, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        """ builds the clamp / shift / mask / sticky-merge logic
        """
        m = Module()

        # NOTE: local names of these intermediates are kept as-is
        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        shifter = MultiShift(self.width-1)
        zeros = Const(0, self.width-1)
        limit = Const(self.width-1, len(self.diff))
        upper = self.inp[1:]  # everything above the existing sticky bit

        m.d.comb += [
            # clamp the shift amount, and compute its complement
            maxslen.eq(Mux(self.diff > limit, limit, self.diff)),
            maxsleni.eq(Mux(self.diff > limit, 0, limit-self.diff)),
            # shift mantissa by maxslen, mask by inverse
            rs.eq(shifter.rshift(upper, maxslen)),
            m_mask.eq(shifter.rshift(~zeros, maxsleni)),
            smask.eq(upper & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
585
586
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

        mainm: object providing State() (used for the "align" state)
        op:    source number (s / e / m are copied from it)
        inv:   number whose exponent self.e is compared against
        width, m_extra: passed to FPNumBaseRecord to create the backing
                        record for this number
    """

    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase takes a record, not (width, m_extra): build one here
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # op is an instance attribute, not a global
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): e and m are assigned from both comb (above) and
        # sync (below) - confirm this is the intended design.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down(self)

        return m

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record (see drop_in)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        # same sticky computation as FPNumIn.shift_down_multi
        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
658
659
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.op = op

    def elaborate(self, platform):
        """ base-class flags plus a combinatorial decode of self.v
        """
        m = FPNumBase.elaborate(self, platform)
        decoded = self.decode(self.v)
        m.d.comb += decoded
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number
        """
        # pad with extra zeros below the stored mantissa bits
        padding = [0] * self.m_extra
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa),
                self.e.eq(exponent),
                self.s.eq(sign)]
697
698
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """

    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            returns an ObjectProxy (created on module m) with m, e and s
            attributes, rather than assigning to this object's signals.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                                 # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
        res.s = v[-1]                                      # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            returns a list of assignments to this object's m, e and s.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        # the bias constant lives on the record (self.fp), not on self:
        # drop_in() only copies signals and widths across.
        return [self.m.eq(Cat(*args)),  # mantissa
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),                 # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
                ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            the sticky bit is computed by shifting a batch of 1s by
            the same amount, which will introduce zeros.  it's then
            inverted and used as a mask to get the LSBs of the mantissa.
            those are then |'d into the sticky bit.

            inp defaults to this object when not given.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record (self.fp), not on self
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)  # shift and invert

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
                ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
                ]
797
798
class Trigger(Elaboratable):
    """ strobe / acknowledge pair: trigger is raised when both are set
    """

    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        """ trigger = stb AND ack (combinatorial)
        """
        m = Module()
        both = self.stb & self.ack
        m.d.comb += self.trigger.eq(both)
        return m

    def eq(self, inp):
        """ assignments copying stb and ack from inp
        """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ list of port signals: stb, ack
        """
        return [self.stb, self.ack]
818
819
class FPOpIn(PrevControl):
    """ input operand: a PrevControl whose data_i is exposed as v
    """

    def __init__(self, width):
        PrevControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ alias for data_i
        """
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ connects from in_op, with the acknowledge inverted.
            extra (if given) is ANDed into the strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(strobe),         # receive STB
                in_op.ack.eq(~self.ack),     # send ACK
                ]

    def chain_from(self, in_op, extra=None):
        """ connects from in_op, passing the acknowledge straight through.
            extra (if given) is ANDed into the strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(strobe),         # receive STB
                in_op.ack.eq(self.ack),      # send ACK
                ]
846
847
class FPOpOut(NextControl):
    """ output operand: a NextControl whose data_o is exposed as v
    """

    def __init__(self, width):
        NextControl.__init__(self)
        self.width = width

    @property
    def v(self):
        """ alias for data_o
        """
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ connects from in_op, with the acknowledge inverted.
            extra (if given) is ANDed into the strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(strobe),         # receive STB
                in_op.ack.eq(~self.ack),     # send ACK
                ]

    def chain_from(self, in_op, extra=None):
        """ connects from in_op, passing the acknowledge straight through.
            extra (if given) is ANDed into the strobe.
        """
        strobe = in_op.stb if extra is None else in_op.stb & extra
        return [self.v.eq(in_op.v),          # receive value
                self.stb.eq(strobe),         # receive STB
                in_op.ack.eq(self.ack),      # send ACK
                ]
874
875
class Overflow:
    """ guard / round / sticky bits plus mantissa bit 0, used for rounding
    """

    def __init__(self, name=None):
        prefix = "" if name is None else name
        self.guard = Signal(reset_less=True, name=prefix+"guard")     # tot[2]
        self.round_bit = Signal(reset_less=True, name=prefix+"round")  # tot[1]
        self.sticky = Signal(reset_less=True, name=prefix+"sticky")   # tot[0]
        self.m0 = Signal(reset_less=True, name=prefix+"m0")  # mantissa bit 0

    def __iter__(self):
        """ yields guard, round_bit, sticky, m0 (in that order)
        """
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ assignments copying all four bits from inp
        """
        fields = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, field).eq(getattr(inp, field))
                for field in fields]

    @property
    def roundz(self):
        """ guard AND (round_bit OR sticky OR m0)
        """
        tail = self.round_bit | self.sticky | self.m0
        return self.guard & tail
902
903
class OverflowMod(Elaboratable, Overflow):
    """ module version of Overflow: additionally drives roundz_out
        with the combinatorial roundz expression.
    """

    def __init__(self, name=None):
        Overflow.__init__(self, name)
        if name is None:
            name = ""
        self.roundz_out = Signal(reset_less=True, name=name+"roundz_out")

    def __iter__(self):
        """ yields the Overflow signals followed by roundz_out
        """
        yield from Overflow.__iter__(self)
        yield self.roundz_out

    def eq(self, inp):
        """ assignments copying roundz_out and the Overflow bits from inp
        """
        # inp must be forwarded: Overflow.eq requires it
        return [self.roundz_out.eq(inp.roundz_out)] + Overflow.eq(self, inp)

    def elaborate(self, platform):
        m = Module()
        m.d.comb += self.roundz_out.eq(self.roundz)
        return m
922
923
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains common functions for FP manipulation, such as
        extracting and packing operands, normalisation, denormalisation,
        rounding etc.

        every method takes the module being built as "m".  most of them
        assign m.next, so presumably they are called from inside an FSM
        state of that module - TODO confirm against callers.
    """

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            op: object providing ready_o and valid_i_test
            v: number providing decode2()
            returns [decoded result of v.decode2(m), ack signal]
        """
        res = v.decode2(m)
        ack = Signal()
        with m.If((op.ready_o) & (op.valid_i_test)):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.  for normalised numbers (exponent != minimum)
            one *extra* bit (the implicit 1) is added *back in*.
            for denormalised numbers, the mantissa is left alone
            and the exponent increased by 1.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation
            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"

            moves to next_state once the top mantissa bit is set.
        """
        with m.If((op.m[-1] == 0)):  # check last bit of mantissa
            m.d.sync += [
                op.e.eq(op.e - 1),  # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            z: the number being normalised, of: its overflow bits
        """
        with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
            # statement order matters: z.m[0] is assigned after the
            # whole-of-z.m shift
            m.d.sync += [
                z.e.eq(z.e - 1),  # DECREASE exponent
                z.m.eq(z.m << 1),  # shift mantissa UP
                z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),        # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]

            z: the number being normalised, of: its overflow bits
        """
        with m.If(z.e < z.fp.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            roundz: condition under which the mantissa is incremented
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
            with m.If(z.m == z.fp.m1s):  # all 1s
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves to next_state.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves to next_state.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.

            out_z: object providing v, valid_o and ready_i_test
        """
        m.d.sync += [
            out_z.v.eq(z.v)
        ]
        with m.If(out_z.valid_o & out_z.ready_i_test):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
1054
1055
class FPState(FPBase):
    """ FPBase with a record of the originating state, plus named
        inputs and outputs exposed as attributes.
    """

    def __init__(self, state_from):
        self.state_from = state_from

    def set_inputs(self, inputs):
        """ stores the inputs dict and sets each entry as an attribute
        """
        self.inputs = inputs
        for attr_name, value in inputs.items():
            setattr(self, attr_name, value)

    def set_outputs(self, outputs):
        """ stores the outputs dict and sets each entry as an attribute
        """
        self.outputs = outputs
        for attr_name, value in outputs.items():
            setattr(self, attr_name, value)
1069
1070
class FPID:
    """ optional in/out identifier pair.

        id_wid: width of the identifier signals.  when falsy (None or 0)
        no signals are created and in_mid / out_mid are None.
    """

    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ copies in_mid to out_mid in the sync domain of m.
            does nothing when no identifier signals exist.
        """
        # same truthiness test as __init__: id_wid == 0 creates no signals,
        # so an "is not None" test would call .eq() on None
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
1084
1085
# run the unittest test cases in this file when executed as a script
if __name__ == '__main__':
    unittest.main()