1 """IEEE754 Floating Point Library
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2022 Jacob Lifshay <programmerjake@gmail.com>
9 from nmigen
import (Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
, Array
,
10 Value
, Shape
, signed
, unsigned
)
11 from nmigen
.utils
import bits_for
12 from operator
import or_
13 from functools
import reduce
15 from nmutil
.singlepipe
import PrevControl
, NextControl
16 from nmutil
.pipeline
import ObjectProxy
22 from nmigen
.hdl
.smtlib2
import RoundingModeEnum
27 # value so FPRoundingMode.to_smtlib2 can detect when no default is supplied
31 class FPRoundingMode(enum
.Enum
):
32 # matches the FPSCR.RN field values, but includes some extra
33 # values (>= 0b100) used in miscellaneous instructions.
35 # naming matches smtlib2 names, doc strings are the OpenPower ISA
36 # specification's names (v3.1 section 7.3.2.6 --
37 # matches values in section 4.3.6).
39 """Round to Nearest Even
41 Rounds to the nearest representable floating-point number, ties are
42 rounded to the number with the even mantissa. Treats +-Infinity as if
43 it were a normalized floating-point number when deciding which number
44 is closer when rounding. See IEEE754 spec. for details.
47 ROUND_NEAREST_TIES_TO_EVEN
= RNE
53 If the result is exactly representable as a floating-point number, return
54 that, otherwise return the nearest representable floating-point value
55 with magnitude smaller than the exact answer.
58 ROUND_TOWARDS_ZERO
= RTZ
61 """Round towards +Infinity
63 If the result is exactly representable as a floating-point number, return
64 that, otherwise return the nearest representable floating-point value
65 that is numerically greater than the exact answer. This can round up to
69 ROUND_TOWARDS_POSITIVE
= RTP
72 """Round towards -Infinity
74 If the result is exactly representable as a floating-point number, return
75 that, otherwise return the nearest representable floating-point value
76 that is numerically less than the exact answer. This can round down to
80 ROUND_TOWARDS_NEGATIVE
= RTN
83 """Round to Nearest Away
85 Rounds to the nearest representable floating-point number, ties are
86 rounded to the number with the maximum magnitude. Treats +-Infinity as if
87 it were a normalized floating-point number when deciding which number
88 is closer when rounding. See IEEE754 spec. for details.
91 ROUND_NEAREST_TIES_TO_AWAY
= RNA
94 """Round to Odd, unsigned zeros are Positive
98 If the result is exactly representable as a floating-point number, return
99 that, otherwise return the nearest representable floating-point value
100 that has an odd mantissa.
102 If the result is zero but with otherwise undetermined sign
103 (e.g. `1.0 - 1.0`), the sign is positive.
105 This rounding mode is used for instructions with Round To Odd enabled,
106 and `FPSCR.RN != RTN`.
108 This is useful to avoid double-rounding errors when doing arithmetic in a
109 larger type (e.g. f128) but where the answer should be a smaller type
113 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE
= RTOP
116 """Round to Odd, unsigned zeros are Negative
120 If the result is exactly representable as a floating-point number, return
121 that, otherwise return the nearest representable floating-point value
122 that has an odd mantissa.
124 If the result is zero but with otherwise undetermined sign
125 (e.g. `1.0 - 1.0`), the sign is negative.
127 This rounding mode is used for instructions with Round To Odd enabled,
128 and `FPSCR.RN == RTN`.
130 This is useful to avoid double-rounding errors when doing arithmetic in a
131 larger type (e.g. f128) but where the answer should be a smaller type
135 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE
= RTON
139 l
= [None] * len(FPRoundingMode
)
140 for rm
in FPRoundingMode
:
144 def overflow_rounds_to_inf(self
, sign
):
145 """returns true if an overflow should round to `inf`,
146 false if it should round to `max_normal`
148 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
149 if self
is FPRoundingMode
.RNE
:
151 elif self
is FPRoundingMode
.RTZ
:
153 elif self
is FPRoundingMode
.RTP
:
155 elif self
is FPRoundingMode
.RTN
:
157 elif self
is FPRoundingMode
.RNA
:
159 elif self
is FPRoundingMode
.RTOP
:
162 assert self
is FPRoundingMode
.RTON
165 def underflow_rounds_to_zero(self
, sign
):
166 """returns true if an underflow should round to `zero`,
167 false if it should round to `min_denormal`
169 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
170 if self
is FPRoundingMode
.RNE
:
172 elif self
is FPRoundingMode
.RTZ
:
174 elif self
is FPRoundingMode
.RTP
:
176 elif self
is FPRoundingMode
.RTN
:
178 elif self
is FPRoundingMode
.RNA
:
180 elif self
is FPRoundingMode
.RTOP
:
183 assert self
is FPRoundingMode
.RTON
187 """which sign an exact zero result should have when it isn't
188 otherwise determined, e.g. for `1.0 - 1.0`.
190 if self
is FPRoundingMode
.RNE
:
192 elif self
is FPRoundingMode
.RTZ
:
194 elif self
is FPRoundingMode
.RTP
:
196 elif self
is FPRoundingMode
.RTN
:
198 elif self
is FPRoundingMode
.RNA
:
200 elif self
is FPRoundingMode
.RTOP
:
203 assert self
is FPRoundingMode
.RTON
207 def to_smtlib2(self
, default
=_raise_err
):
208 """return the corresponding smtlib2 rounding mode for `self`. If
209 there is no corresponding smtlib2 rounding mode, then return
210 `default` if specified, else raise `ValueError`.
212 if self
is FPRoundingMode
.RNE
:
213 return RoundingModeEnum
.RNE
214 elif self
is FPRoundingMode
.RTZ
:
215 return RoundingModeEnum
.RTZ
216 elif self
is FPRoundingMode
.RTP
:
217 return RoundingModeEnum
.RTP
218 elif self
is FPRoundingMode
.RTN
:
219 return RoundingModeEnum
.RTN
220 elif self
is FPRoundingMode
.RNA
:
221 return RoundingModeEnum
.RNA
223 assert self
in (FPRoundingMode
.RTOP
, FPRoundingMode
.RTON
)
224 if default
is _raise_err
:
226 "no corresponding smtlib2 rounding mode", self
)
233 """ Class describing binary floating-point formats based on IEEE 754.
235 :attribute e_width: the number of bits in the exponent field.
236 :attribute m_width: the number of bits stored in the mantissa
238 :attribute has_int_bit: if the FP format has an explicit integer bit (like
239 the x87 80-bit format). The bit is considered part of the mantissa.
240 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
241 Image/Buffer formats are FP numbers without a sign bit.)
249 """ Create ``FPFormat`` instance. """
250 self
.e_width
= e_width
251 self
.m_width
= m_width
252 self
.has_int_bit
= has_int_bit
253 self
.has_sign
= has_sign
def __eq__(self, other):
    """ Check for equality.

        Two FPFormats are equal when their exponent width, mantissa width,
        explicit-integer-bit flag and sign flag all match.  Comparison
        against anything that is not an FPFormat returns NotImplemented so
        Python can try the reflected operation.
    """
    if isinstance(other, FPFormat):
        ours = (self.e_width, self.m_width, self.has_int_bit, self.has_sign)
        theirs = (other.e_width, other.m_width,
                  other.has_int_bit, other.has_sign)
        return ours == theirs
    return NotImplemented
266 """ Get standard IEEE 754-2008 format.
268 :param width: bit-width of requested format.
269 :returns: the requested ``FPFormat`` instance.
272 return FPFormat(5, 10)
274 return FPFormat(8, 23)
276 return FPFormat(11, 52)
278 return FPFormat(15, 112)
279 if width
> 128 and width
% 32 == 0:
280 if width
> 1000000: # arbitrary upper limit
281 raise ValueError("width too big")
282 e_width
= round(4 * math
.log2(width
)) - 13
283 return FPFormat(e_width
, width
- 1 - e_width
)
284 raise ValueError("width must be the bit-width of a valid IEEE"
285 " 754-2008 binary format")
290 if self
== self
.standard(self
.width
):
291 return f
"FPFormat.standard({self.width})"
294 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
295 if self
.has_int_bit
is not False:
296 retval
+= f
", {self.has_int_bit}"
297 if self
.has_sign
is not True:
298 retval
+= f
", {self.has_sign}"
301 def get_sign_field(self
, x
):
302 """ returns the sign bit of its input number, x
303 (assumes FPFormat is set to signed - has_sign=True)
305 return x
>> (self
.e_width
+ self
.m_width
)
307 def get_exponent_field(self
, x
):
308 """ returns the raw exponent of its input number, x (no bias subtracted)
310 x
= ((x
>> self
.m_width
) & self
.exponent_inf_nan
)
313 def get_exponent(self
, x
):
314 """ returns the exponent of its input number, x
316 x
= self
.get_exponent_field(x
)
317 if isinstance(x
, Value
) and not x
.shape().signed
:
318 # convert x to signed without changing its value,
319 # since exponents can be negative
320 x |
= Const(0, signed(1))
321 return x
- self
.exponent_bias
323 def get_exponent_value(self
, x
):
324 """ returns the exponent of its input number, x, adjusted for the
325 mathematically correct subnormal exponent.
327 x
= self
.get_exponent_field(x
)
328 if isinstance(x
, Value
) and not x
.shape().signed
:
329 # convert x to signed without changing its value,
330 # since exponents can be negative
331 x |
= Const(0, signed(1))
332 return x
+ (x
== self
.exponent_denormal_zero
) - self
.exponent_bias
334 def get_mantissa_field(self
, x
):
335 """ returns the mantissa of its input number, x
337 return x
& self
.mantissa_mask
339 def get_mantissa_value(self
, x
):
340 """ returns the mantissa of its input number, x, but with the
341 implicit bit, if any, made explicit.
344 return self
.get_mantissa_field(x
)
345 exponent_field
= self
.get_exponent_field(x
)
346 mantissa_field
= self
.get_mantissa_field(x
)
347 implicit_bit
= exponent_field
!= self
.exponent_denormal_zero
348 return (implicit_bit
<< self
.fraction_width
) | mantissa_field
350 def is_zero(self
, x
):
351 """ returns true if x is +/- zero
353 return (self
.get_exponent(x
) == self
.e_sub
) & \
354 (self
.get_mantissa_field(x
) == 0)
356 def is_subnormal(self
, x
):
357 """ returns true if x is subnormal (exp at minimum)
359 return (self
.get_exponent(x
) == self
.e_sub
) & \
360 (self
.get_mantissa_field(x
) != 0)
363 """ returns true if x is infinite
365 return (self
.get_exponent(x
) == self
.e_max
) & \
366 (self
.get_mantissa_field(x
) == 0)
369 """ returns true if x is a nan (quiet or signalling)
371 return (self
.get_exponent(x
) == self
.e_max
) & \
372 (self
.get_mantissa_field(x
) != 0)
374 def is_quiet_nan(self
, x
):
375 """ returns true if x is a quiet nan
377 highbit
= 1 << (self
.m_width
- 1)
378 return (self
.get_exponent(x
) == self
.e_max
) & \
379 (self
.get_mantissa_field(x
) != 0) & \
380 (self
.get_mantissa_field(x
) & highbit
!= 0)
def to_quiet_nan(self, x):
    """ converts `x` to a quiet NaN

        Forces every exponent bit and the most-significant mantissa bit
        to one, leaving the sign and the remaining mantissa bits of `x`
        untouched.
    """
    quiet_bit = 1 << (self.m_width - 1)
    return x | quiet_bit | self.exponent_mask
def quiet_nan(self, sign=0):
    """ return the default quiet NaN with sign `sign`

        Built by quieting a signed zero, so only the sign, exponent and
        quiet (top mantissa) bits end up set.
    """
    signed_zero = self.zero(sign)
    return self.to_quiet_nan(signed_zero)
def zero(self, sign=0):
    """ return zero with sign `sign`

        Any non-zero `sign` sets the sign bit, which sits directly above
        the exponent and mantissa fields; every other bit is clear.
    """
    sign_bit_index = self.e_width + self.m_width
    return (sign != 0) << sign_bit_index
def inf(self, sign=0):
    """ return infinity with sign `sign`

        Infinity is a signed zero with every exponent bit set.
    """
    signed_zero = self.zero(sign)
    return signed_zero | self.exponent_mask
def is_nan_signaling(self, x):
    """ returns true if x is a signalling nan

        A signalling NaN has the maximum exponent (`e_max`), a non-zero
        mantissa field, and the mantissa's most-significant bit clear
        (a set top bit would make it a quiet NaN).

        Only `==`, `!=` and `&` are used, so `x` may be a plain Python int
        or an nmigen Value.
    """
    highbit = 1 << (self.m_width - 1)
    mantissa = self.get_mantissa_field(x)
    # `&` binds tighter than `==` in Python, so the high-bit test needs
    # its own parentheses.  Without them the trailing `== 0` compared the
    # whole conjunction with zero, which was true for almost every input.
    return ((self.get_exponent(x) == self.e_max) &
            (mantissa != 0) &
            ((mantissa & highbit) == 0))
409 """ Get the total number of bits in the FP format. """
410 return self
.has_sign
+ self
.e_width
+ self
.m_width
def mantissa_mask(self):
    """ Get a mantissa mask based on the mantissa width

        The result has the low `m_width` bits set.
    """
    field_span = 1 << self.m_width
    return field_span - 1
def exponent_mask(self):
    """ Get an exponent mask

        The all-ones exponent field value shifted up past the `m_width`
        mantissa bits.
    """
    all_ones_exponent = self.exponent_inf_nan
    return all_ones_exponent << self.m_width
def exponent_inf_nan(self):
    """ Get the value of the exponent field designating infinity/NaN.

        This is the exponent field with all `e_width` bits set.
    """
    field_span = 1 << self.e_width
    return field_span - 1
429 """ get the maximum exponent (minus bias)
431 return self
.exponent_inf_nan
- self
.exponent_bias
435 return self
.exponent_denormal_zero
- self
.exponent_bias
437 def exponent_denormal_zero(self
):
438 """ Get the value of the exponent field designating denormal/zero. """
442 def exponent_min_normal(self
):
443 """ Get the minimum value of the exponent field for normal numbers. """
def exponent_max_normal(self):
    """ Get the maximum value of the exponent field for normal numbers.

        One below the all-ones field value reserved for infinity/NaN.
    """
    inf_nan_field = self.exponent_inf_nan
    return inf_nan_field - 1
def exponent_bias(self):
    """ Get the exponent bias.

        Computed as 2**(e_width - 1) - 1, e.g. 127 for an 8-bit exponent.
    """
    half_range = 1 << (self.e_width - 1)
    return half_range - 1
def fraction_width(self):
    """ Get the number of mantissa bits that are fraction bits.

        The mantissa width minus the explicit integer bit, if the format
        has one (`has_int_bit` counts as 0 or 1).
    """
    int_bit_count = self.has_int_bit
    return self.m_width - int_bit_count
462 def from_pspec(pspec
):
463 width
= getattr(pspec
, "width", None)
464 assert width
is None or isinstance(width
, int)
465 fpformat
= getattr(pspec
, "fpformat", None)
467 assert width
is not None, \
468 "neither pspec.width nor pspec.fpformat were set"
469 fpformat
= FPFormat
.standard(width
)
471 assert isinstance(fpformat
, FPFormat
)
472 assert width
== fpformat
.width
476 class TestFPFormat(unittest
.TestCase
):
477 """ very quick test for FPFormat
480 def test_fpformat_fp64(self
):
481 f64
= FPFormat
.standard(64)
482 from sfpy
import Float64
483 x
= Float64(1.0).bits
486 self
.assertEqual(f64
.get_exponent(x
), 0)
487 x
= Float64(2.0).bits
489 self
.assertEqual(f64
.get_exponent(x
), 1)
491 x
= Float64(1.5).bits
492 m
= f64
.get_mantissa_field(x
)
493 print (hex(x
), hex(m
))
494 self
.assertEqual(m
, 0x8000000000000)
496 s
= f64
.get_sign_field(x
)
497 print (hex(x
), hex(s
))
498 self
.assertEqual(s
, 0)
500 x
= Float64(-1.5).bits
501 s
= f64
.get_sign_field(x
)
502 print (hex(x
), hex(s
))
503 self
.assertEqual(s
, 1)
505 def test_fpformat_fp32(self
):
506 f32
= FPFormat
.standard(32)
507 from sfpy
import Float32
508 x
= Float32(1.0).bits
511 self
.assertEqual(f32
.get_exponent(x
), 0)
512 x
= Float32(2.0).bits
514 self
.assertEqual(f32
.get_exponent(x
), 1)
516 x
= Float32(1.5).bits
517 m
= f32
.get_mantissa_field(x
)
518 print (hex(x
), hex(m
))
519 self
.assertEqual(m
, 0x400000)
522 x
= Float32(-1.0).sqrt()
525 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
526 f32
.get_mantissa_field(x
), i
)
527 self
.assertEqual(i
, True)
530 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
533 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
534 f32
.get_mantissa_field(x
), i
)
535 self
.assertEqual(i
, True)
540 i
= f32
.is_subnormal(x
)
541 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
542 f32
.get_mantissa_field(x
), i
)
543 self
.assertEqual(i
, True)
547 i
= f32
.is_subnormal(x
)
548 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
549 f32
.get_mantissa_field(x
), i
)
550 self
.assertEqual(i
, False)
554 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
555 f32
.get_mantissa_field(x
), i
)
556 self
.assertEqual(i
, True)
559 class MultiShiftR(Elaboratable
):
561 def __init__(self
, width
):
563 self
.smax
= bits_for(width
- 1)
564 self
.i
= Signal(width
, reset_less
=True)
565 self
.s
= Signal(self
.smax
, reset_less
=True)
566 self
.o
= Signal(width
, reset_less
=True)
568 def elaborate(self
, platform
):
570 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
575 """ Generates variable-length single-cycle shifter from a series
576 of conditional tests on each bit of the left/right shift operand.
577 Each bit tested produces output shifted by that number of bits,
578 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
579 shifts by 2 bits, each partial result cascading to the next Mux.
581 Could be adapted to do arithmetic shift by taking copies of the
582 MSB instead of zeros.
585 def __init__(self
, width
):
587 self
.smax
= bits_for(width
- 1)
589 def lshift(self
, op
, s
):
593 def rshift(self
, op
, s
):
598 class FPNumBaseRecord
:
599 """ Floating-point Base Number Class.
601 This class is designed to be passed around in other data structures
602 (between pipelines and between stages). Its "friend" is FPNumBase,
603 which is a *module*. The reason for the discernment is because
604 nmigen modules that are not added to submodules results in the
605 irritating "Elaboration" warning. Despite not *needing* FPNumBase
606 in many cases to be added as a submodule (because it is just data)
607 this was not possible to solve without splitting out the data from
611 def __init__(self
, width
=None, m_extra
=True, e_extra
=False, name
=None,
615 # assert false, "missing name"
619 assert isinstance(width
, int)
620 fpformat
= FPFormat
.standard(width
)
622 assert isinstance(fpformat
, FPFormat
)
624 width
= fpformat
.width
625 assert isinstance(width
, int)
626 assert width
== fpformat
.width
628 self
.fpformat
= fpformat
629 assert not fpformat
.has_int_bit
630 assert fpformat
.has_sign
631 m_width
= fpformat
.m_width
+ 1 # 1 extra bit (overflow)
632 e_width
= fpformat
.e_width
+ 2 # 2 extra bits (overflow)
633 e_max
= 1 << (e_width
-3)
634 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
637 # mantissa extra bits (top,guard,round)
639 m_width
+= self
.m_extra
643 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
644 e_width
+= self
.e_extra
647 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
648 self
.m_width
= m_width
649 self
.e_width
= e_width
650 self
.e_start
= self
.rmw
651 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
653 self
.v
= Signal(width
, reset_less
=True,
654 name
=name
+"v") # Latched copy of value
655 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
656 self
.e
= Signal(signed(e_width
),
657 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
658 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
663 def drop_in(self
, fp
):
669 fp
.width
= self
.width
670 fp
.e_width
= self
.e_width
671 fp
.e_max
= self
.e_max
672 fp
.m_width
= self
.m_width
673 fp
.e_start
= self
.e_start
674 fp
.e_end
= self
.e_end
675 fp
.m_extra
= self
.m_extra
677 m_width
= self
.m_width
679 e_width
= self
.e_width
681 self
.mzero
= Const(0, unsigned(m_width
))
682 m_msb
= 1 << (self
.m_width
-2)
683 self
.msb1
= Const(m_msb
, unsigned(m_width
))
684 self
.m1s
= Const(-1, unsigned(m_width
))
685 self
.P128
= Const(e_max
, signed(e_width
))
686 self
.P127
= Const(e_max
-1, signed(e_width
))
687 self
.N127
= Const(-(e_max
-1), signed(e_width
))
688 self
.N126
= Const(-(e_max
-2), signed(e_width
))
690 def create(self
, s
, e
, m
):
691 """ creates a value from sign / exponent / mantissa
693 bias is added here, to the exponent.
695 NOTE: order is important, because e_start/e_end can be
696 a bit too long (overwriting s).
699 self
.v
[0:self
.e_start
].eq(m
), # mantissa
700 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
701 self
.v
[-1].eq(s
), # sign
705 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
708 return (s
, self
.fp
.P128
, 0)
711 return (s
, self
.fp
.N127
, 0)
714 return self
.create(*self
._nan
(s
))
716 def quieted_nan(self
, other
):
717 assert isinstance(other
, FPNumBaseRecord
)
718 assert self
.width
== other
.width
719 return self
.create(other
.s
, self
.fp
.P128
,
720 other
.v
[0:self
.e_start
] |
(1 << (self
.e_start
- 1)))
723 return self
.create(*self
._inf
(s
))
725 def max_normal(self
, s
):
726 return self
.create(s
, self
.fp
.P127
, ~
0)
728 def min_denormal(self
, s
):
729 return self
.create(s
, self
.fp
.N127
, 1)
732 return self
.create(*self
._zero
(s
))
734 def create2(self
, s
, e
, m
):
735 """ creates a value from sign / exponent / mantissa
737 bias is added here, to the exponent
739 e
= e
+ self
.P127
# exp (add on bias)
740 return Cat(m
[0:self
.e_start
],
741 e
[0:self
.e_end
-self
.e_start
],
745 return self
.create2(s
, self
.P128
, self
.msb1
)
748 return self
.create2(s
, self
.P128
, self
.mzero
)
751 return self
.create2(s
, self
.N127
, self
.mzero
)
759 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
762 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
763 """ Floating-point Base Number Class
766 def __init__(self
, fp
):
771 self
.is_nan
= Signal(reset_less
=True)
772 self
.is_zero
= Signal(reset_less
=True)
773 self
.is_inf
= Signal(reset_less
=True)
774 self
.is_overflowed
= Signal(reset_less
=True)
775 self
.is_denormalised
= Signal(reset_less
=True)
776 self
.exp_128
= Signal(reset_less
=True)
777 self
.exp_sub_n126
= Signal(signed(e_width
), reset_less
=True)
778 self
.exp_lt_n126
= Signal(reset_less
=True)
779 self
.exp_zero
= Signal(reset_less
=True)
780 self
.exp_gt_n126
= Signal(reset_less
=True)
781 self
.exp_gt127
= Signal(reset_less
=True)
782 self
.exp_n127
= Signal(reset_less
=True)
783 self
.exp_n126
= Signal(reset_less
=True)
784 self
.m_zero
= Signal(reset_less
=True)
785 self
.m_msbzero
= Signal(reset_less
=True)
787 def elaborate(self
, platform
):
789 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
790 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
791 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
792 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
793 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
794 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
795 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
796 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
797 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
798 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
799 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
800 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
801 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
802 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
803 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
808 return (self
.exp_128
) & (~self
.m_zero
)
811 return (self
.exp_128
) & (self
.m_zero
)
814 return (self
.exp_n127
) & (self
.m_zero
)
def _is_overflowed(self):
    """ true when the decoded exponent is greater than P127.

        Simply forwards the `exp_gt127` signal.
    """
    overflowed = self.exp_gt127
    return overflowed
def _is_denormalised(self):
    """ true when the exponent equals N126 (`exp_n126`) and the mantissa
        bit at `e_start` is clear (`m_msbzero`).
    """
    # XXX NOT to be used for "official" quiet NaN tests!
    # particularly when the MSB has been extended
    at_min_exponent = self.exp_n126
    top_bit_clear = self.m_msbzero
    return at_min_exponent & top_bit_clear
825 class FPNumOut(FPNumBase
):
826 """ Floating-point Number Class
828 Contains signals for an incoming copy of the value, decoded into
829 sign / exponent / mantissa.
830 Also contains encoding functions, creation and recognition of
831 zero, NaN and inf (all signed)
833 Four extra bits are included in the mantissa: the top bit
834 (m[-1]) is effectively a carry-overflow. The other three are
835 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, fp):
    """ Create an FPNumOut.

        Adds no state of its own: all initialisation is delegated to
        FPNumBase.__init__.

        :param fp: forwarded unchanged to FPNumBase.__init__ (presumably
            the FPNumBaseRecord / format description -- TODO confirm
            against FPNumBase's signature).
    """
    # explicit base-class call (not super()) -- keep as-is so the MRO
    # behaviour of any multiply-inheriting subclasses is unchanged
    FPNumBase.__init__(self, fp)
841 def elaborate(self
, platform
):
842 m
= FPNumBase
.elaborate(self
, platform
)
847 class MultiShiftRMerge(Elaboratable
):
848 """ shifts down (right) and merges lower bits into m[0].
849 m[0] is the "sticky" bit, basically
852 def __init__(self
, width
, s_max
=None):
854 s_max
= bits_for(width
- 1)
855 self
.smax
= Shape
.cast(s_max
)
856 self
.m
= Signal(width
, reset_less
=True)
857 self
.inp
= Signal(width
, reset_less
=True)
858 self
.diff
= Signal(s_max
, reset_less
=True)
861 def elaborate(self
, platform
):
864 rs
= Signal(self
.width
, reset_less
=True)
865 m_mask
= Signal(self
.width
, reset_less
=True)
866 smask
= Signal(self
.width
, reset_less
=True)
867 stickybit
= Signal(reset_less
=True)
868 # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
869 maxslen
= Signal(self
.smax
.width
, reset_less
=True)
870 maxsleni
= Signal(self
.smax
.width
, reset_less
=True)
872 sm
= MultiShift(self
.width
-1)
873 m0s
= Const(0, self
.width
-1)
874 mw
= Const(self
.width
-1, len(self
.diff
))
875 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
876 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
880 # shift mantissa by maxslen, mask by inverse
881 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
882 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
883 smask
.eq(self
.inp
[1:] & m_mask
),
884 # sticky bit combines all mask (and mantissa low bit)
885 stickybit
.eq(smask
.bool() | self
.inp
[0]),
886 # mantissa result contains m[0] already.
887 self
.m
.eq(Cat(stickybit
, rs
))
892 class FPNumShift(FPNumBase
, Elaboratable
):
893 """ Floating-point Number Class for shifting
896 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
897 FPNumBase
.__init
__(self
, width
, m_extra
)
898 self
.latch_in
= Signal()
903 def elaborate(self
, platform
):
904 m
= FPNumBase
.elaborate(self
, platform
)
906 m
.d
.comb
+= self
.s
.eq(op
.s
)
907 m
.d
.comb
+= self
.e
.eq(op
.e
)
908 m
.d
.comb
+= self
.m
.eq(op
.m
)
910 with self
.mainm
.State("align"):
911 with m
.If(self
.e
< self
.inv
.e
):
912 m
.d
.sync
+= self
.shift_down()
916 def shift_down(self
, inp
):
917 """ shifts a mantissa down by one. exponent is increased to compensate
919 accuracy is lost as a result in the mantissa however there are 3
920 guard bits (the latter of which is the "sticky" bit)
922 return [self
.e
.eq(inp
.e
+ 1),
923 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
926 def shift_down_multi(self
, diff
):
927 """ shifts a mantissa down. exponent is increased to compensate
929 accuracy is lost as a result in the mantissa however there are 3
930 guard bits (the latter of which is the "sticky" bit)
932 this code works by variable-shifting the mantissa by up to
933 its maximum bit-length: no point doing more (it'll still be
936 the sticky bit is computed by shifting a batch of 1s by
937 the same amount, which will introduce zeros. it's then
938 inverted and used as a mask to get the LSBs of the mantissa.
939 those are then |'d into the sticky bit.
941 sm
= MultiShift(self
.width
)
942 mw
= Const(self
.m_width
-1, len(diff
))
943 maxslen
= Mux(diff
> mw
, mw
, diff
)
944 rs
= sm
.rshift(self
.m
[1:], maxslen
)
945 maxsleni
= mw
- maxslen
946 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
948 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
949 return [self
.e
.eq(self
.e
+ diff
),
950 self
.m
.eq(Cat(stickybits
, rs
))
953 def shift_up_multi(self
, diff
):
954 """ shifts a mantissa up. exponent is decreased to compensate
956 sm
= MultiShift(self
.width
)
957 mw
= Const(self
.m_width
, len(diff
))
958 maxslen
= Mux(diff
> mw
, mw
, diff
)
960 return [self
.e
.eq(self
.e
- diff
),
961 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
965 class FPNumDecode(FPNumBase
):
966 """ Floating-point Number Class
968 Contains signals for an incoming copy of the value, decoded into
969 sign / exponent / mantissa.
970 Also contains encoding functions, creation and recognition of
971 zero, NaN and inf (all signed)
973 Four extra bits are included in the mantissa: the top bit
974 (m[-1]) is effectively a carry-overflow. The other three are
975 guard (m[2]), round (m[1]), and sticky (m[0])
978 def __init__(self
, op
, fp
):
979 FPNumBase
.__init
__(self
, fp
)
982 def elaborate(self
, platform
):
983 m
= FPNumBase
.elaborate(self
, platform
)
985 m
.d
.comb
+= self
.decode(self
.v
)
990 """ decodes a latched value into sign / exponent / mantissa
992 bias is subtracted here, from the exponent. exponent
993 is extended to 10 bits so that subtract 127 is done on
996 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
997 #print ("decode", self.e_end)
998 return [self
.m
.eq(Cat(*args
)), # mantissa
999 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
1000 self
.s
.eq(v
[-1]), # sign
1004 class FPNumIn(FPNumBase
):
1005 """ Floating-point Number Class
1007 Contains signals for an incoming copy of the value, decoded into
1008 sign / exponent / mantissa.
1009 Also contains encoding functions, creation and recognition of
1010 zero, NaN and inf (all signed)
1012 Four extra bits are included in the mantissa: the top bit
1013 (m[-1]) is effectively a carry-overflow. The other three are
1014 guard (m[2]), round (m[1]), and sticky (m[0])
1017 def __init__(self
, op
, fp
):
1018 FPNumBase
.__init
__(self
, fp
)
1019 self
.latch_in
= Signal()
1022 def decode2(self
, m
):
1023 """ decodes a latched value into sign / exponent / mantissa
1025 bias is subtracted here, from the exponent. exponent
1026 is extended to 10 bits so that subtract 127 is done on
1030 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
1031 #print ("decode", self.e_end)
1032 res
= ObjectProxy(m
, pipemode
=False)
1033 res
.m
= Cat(*args
) # mantissa
1034 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
1035 res
.s
= v
[-1] # sign
1038 def decode(self
, v
):
1039 """ decodes a latched value into sign / exponent / mantissa
1041 bias is subtracted here, from the exponent. exponent
1042 is extended to 10 bits so that subtract 127 is done on
1045 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
1046 #print ("decode", self.e_end)
1047 return [self
.m
.eq(Cat(*args
)), # mantissa
1048 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
1049 self
.s
.eq(v
[-1]), # sign
1052 def shift_down(self
, inp
):
1053 """ shifts a mantissa down by one. exponent is increased to compensate
1055 accuracy is lost as a result in the mantissa however there are 3
1056 guard bits (the latter of which is the "sticky" bit)
1058 return [self
.e
.eq(inp
.e
+ 1),
1059 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
1062 def shift_down_multi(self
, diff
, inp
=None):
1063 """ shifts a mantissa down. exponent is increased to compensate
1065 accuracy is lost as a result in the mantissa however there are 3
1066 guard bits (the latter of which is the "sticky" bit)
1068 this code works by variable-shifting the mantissa by up to
1069 its maximum bit-length: no point doing more (it'll still be
1072 the sticky bit is computed by shifting a batch of 1s by
1073 the same amount, which will introduce zeros. it's then
1074 inverted and used as a mask to get the LSBs of the mantissa.
1075 those are then |'d into the sticky bit.
1079 sm
= MultiShift(self
.width
)
1080 mw
= Const(self
.m_width
-1, len(diff
))
1081 maxslen
= Mux(diff
> mw
, mw
, diff
)
1082 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
1083 maxsleni
= mw
- maxslen
1084 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
1086 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
1087 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
1088 return [self
.e
.eq(inp
.e
+ diff
),
1089 self
.m
.eq(Cat(stickybit
, rs
))
1092 def shift_up_multi(self
, diff
):
1093 """ shifts a mantissa up. exponent is decreased to compensate
1095 sm
= MultiShift(self
.width
)
1096 mw
= Const(self
.m_width
, len(diff
))
1097 maxslen
= Mux(diff
> mw
, mw
, diff
)
1099 return [self
.e
.eq(self
.e
- diff
),
1100 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
1104 class Trigger(Elaboratable
):
1107 self
.stb
= Signal(reset
=0)
1109 self
.trigger
= Signal(reset_less
=True)
1111 def elaborate(self
, platform
):
1113 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
1117 return [self
.stb
.eq(inp
.stb
),
1118 self
.ack
.eq(inp
.ack
)
1122 return [self
.stb
, self
.ack
]
1125 class FPOpIn(PrevControl
):
1126 def __init__(self
, width
):
1127 PrevControl
.__init
__(self
)
1134 def chain_inv(self
, in_op
, extra
=None):
1136 if extra
is not None:
1138 return [self
.v
.eq(in_op
.v
), # receive value
1139 self
.stb
.eq(stb
), # receive STB
1140 in_op
.ack
.eq(~self
.ack
), # send ACK
1143 def chain_from(self
, in_op
, extra
=None):
1145 if extra
is not None:
1147 return [self
.v
.eq(in_op
.v
), # receive value
1148 self
.stb
.eq(stb
), # receive STB
1149 in_op
.ack
.eq(self
.ack
), # send ACK
1153 class FPOpOut(NextControl
):
1154 def __init__(self
, width
):
1155 NextControl
.__init
__(self
)
1162 def chain_inv(self
, in_op
, extra
=None):
1164 if extra
is not None:
1166 return [self
.v
.eq(in_op
.v
), # receive value
1167 self
.stb
.eq(stb
), # receive STB
1168 in_op
.ack
.eq(~self
.ack
), # send ACK
1171 def chain_from(self
, in_op
, extra
=None):
1173 if extra
is not None:
1175 return [self
.v
.eq(in_op
.v
), # receive value
1176 self
.stb
.eq(stb
), # receive STB
1177 in_op
.ack
.eq(self
.ack
), # send ACK
1182 # TODO: change FFLAGS to be FPSCR's status flags
1183 FFLAGS_NV
= Const(1<<4, 5) # invalid operation
1184 FFLAGS_DZ
= Const(1<<3, 5) # divide by zero
1185 FFLAGS_OF
= Const(1<<2, 5) # overflow
1186 FFLAGS_UF
= Const(1<<1, 5) # underflow
1187 FFLAGS_NX
= Const(1<<0, 5) # inexact
1188 def __init__(self
, name
=None):
1191 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
1192 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
1193 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
1194 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
1195 self
.fpflags
= Signal(5, reset_less
=True, name
=name
+"fflags")
1197 self
.sign
= Signal(reset_less
=True, name
=name
+"sign")
1198 """sign bit -- 1 means negative, 0 means positive"""
1200 self
.rm
= Signal(FPRoundingMode
, name
=name
+"rm",
1201 reset
=FPRoundingMode
.DEFAULT
)
1204 #self.roundz = Signal(reset_less=True)
1208 yield self
.round_bit
1216 return [self
.guard
.eq(inp
.guard
),
1217 self
.round_bit
.eq(inp
.round_bit
),
1218 self
.sticky
.eq(inp
.sticky
),
1220 self
.fpflags
.eq(inp
.fpflags
),
1221 self
.sign
.eq(inp
.sign
),
1225 def roundz_rne(self
):
1226 """true if the mantissa should be rounded up for `rm == RNE`
1228 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_EVEN`
1230 return self
.guard
& (self
.round_bit | self
.sticky | self
.m0
)
1233 def roundz_rna(self
):
1234 """true if the mantissa should be rounded up for `rm == RNA`
1236 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_AWAY`
1241 def roundz_rtn(self
):
1242 """true if the mantissa should be rounded up for `rm == RTN`
1244 assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE`
1246 return self
.sign
& (self
.guard | self
.round_bit | self
.sticky
)
1249 def roundz_rto(self
):
1250 """true if the mantissa should be rounded up for `rm in (RTOP, RTON)`
1252 assumes the rounding mode is `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`
1253 or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE`
1255 return ~self
.m0
& (self
.guard | self
.round_bit | self
.sticky
)
1258 def roundz_rtp(self
):
1259 """true if the mantissa should be rounded up for `rm == RTP`
1261 assumes the rounding mode is `ROUND_TOWARDS_POSITIVE`
1263 return ~self
.sign
& (self
.guard | self
.round_bit | self
.sticky
)
1266 def roundz_rtz(self
):
1267 """true if the mantissa should be rounded up for `rm == RTZ`
1269 assumes the rounding mode is `ROUND_TOWARDS_ZERO`
1275 """true if the mantissa should be rounded up for the current rounding
1279 FPRoundingMode
.RNA
: self
.roundz_rna
,
1280 FPRoundingMode
.RNE
: self
.roundz_rne
,
1281 FPRoundingMode
.RTN
: self
.roundz_rtn
,
1282 FPRoundingMode
.RTOP
: self
.roundz_rto
,
1283 FPRoundingMode
.RTON
: self
.roundz_rto
,
1284 FPRoundingMode
.RTP
: self
.roundz_rtp
,
1285 FPRoundingMode
.RTZ
: self
.roundz_rtz
,
1287 return FPRoundingMode
.make_array(lambda rm
: d
[rm
])[self
.rm
]
1290 class OverflowMod(Elaboratable
, Overflow
):
1291 def __init__(self
, name
=None):
1292 Overflow
.__init
__(self
, name
)
1295 self
.roundz_out
= Signal(reset_less
=True, name
=name
+"roundz_out")
1298 yield from Overflow
.__iter
__(self
)
1299 yield self
.roundz_out
1302 return [self
.roundz_out
.eq(inp
.roundz_out
)] + Overflow
.eq(self
)
1304 def elaborate(self
, platform
):
1306 m
.d
.comb
+= self
.roundz_out
.eq(self
.roundz
) # roundz is a property
1311 """ IEEE754 Floating Point Base Class
1313 contains common functions for FP manipulation, such as
1314 extracting and packing operands, normalisation, denormalisation,
1318 def get_op(self
, m
, op
, v
, next_state
):
1319 """ this function moves to the next state and copies the operand
1320 when both stb and ack are 1.
1321 acknowledgement is sent by setting ack to ZERO.
1325 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
1327 # op is latched in from FPNumIn class on same ack/stb
1328 m
.d
.comb
+= ack
.eq(0)
1330 m
.d
.comb
+= ack
.eq(1)
1333 def denormalise(self
, m
, a
):
1334 """ denormalises a number. this is probably the wrong name for
1335 this function. for normalised numbers (exponent != minimum)
1336 one *extra* bit (the implicit 1) is added *back in*.
1337 for denormalised numbers, the mantissa is left alone
1338 and the exponent increased by 1.
1340 both cases *effectively multiply the number stored by 2*,
1341 which has to be taken into account when extracting the result.
1343 with m
.If(a
.exp_n127
):
1344 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
1346 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
1348 def op_normalise(self
, m
, op
, next_state
):
1349 """ operand normalisation
1350 NOTE: just like "align", this one keeps going round every clock
1351 until the result's exponent is within acceptable "range"
1353 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
1355 op
.e
.eq(op
.e
- 1), # DECREASE exponent
1356 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
1361 def normalise_1(self
, m
, z
, of
, next_state
):
1362 """ first stage normalisation
1364 NOTE: just like "align", this one keeps going round every clock
1365 until the result's exponent is within acceptable "range"
1366 NOTE: the weirdness of reassigning guard and round is due to
1367 the extra mantissa bits coming from tot[0..2]
1369 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
1371 z
.e
.eq(z
.e
- 1), # DECREASE exponent
1372 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
1373 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
1374 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
1375 of
.round_bit
.eq(0), # reset round bit
1381 def normalise_2(self
, m
, z
, of
, next_state
):
1382 """ second stage normalisation
1384 NOTE: just like "align", this one keeps going round every clock
1385 until the result's exponent is within acceptable "range"
1386 NOTE: the weirdness of reassigning guard and round is due to
1387 the extra mantissa bits coming from tot[0..2]
1389 with m
.If(z
.e
< z
.fp
.N126
):
1391 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1392 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1393 of
.guard
.eq(z
.m
[0]),
1395 of
.round_bit
.eq(of
.guard
),
1396 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1401 def roundz(self
, m
, z
, roundz
):
1402 """ performs rounding on the output. TODO: different kinds of rounding
1405 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1406 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1407 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1409 def corrections(self
, m
, z
, next_state
):
1410 """ denormalisation and sign-bug corrections
1413 # denormalised, correct exponent to zero
1414 with m
.If(z
.is_denormalised
):
1415 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1417 def pack(self
, m
, z
, next_state
):
1418 """ packs the result into the output (detects overflow->Inf)
1421 # if overflow occurs, return inf
1422 with m
.If(z
.is_overflowed
):
1423 m
.d
.sync
+= z
.inf(z
.s
)
1425 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1427 def put_z(self
, m
, z
, out_z
, next_state
):
1428 """ put_z: stores the result in the output. raises stb and waits
1429 for ack to be set to 1 before moving to the next state.
1430 resets stb back to zero when that occurs, as acknowledgement.
1435 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1436 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1439 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1442 class FPState(FPBase
):
1443 def __init__(self
, state_from
):
1444 self
.state_from
= state_from
1446 def set_inputs(self
, inputs
):
1447 self
.inputs
= inputs
1448 for k
, v
in inputs
.items():
1451 def set_outputs(self
, outputs
):
1452 self
.outputs
= outputs
1453 for k
, v
in outputs
.items():
1458 def __init__(self
, id_wid
):
1459 self
.id_wid
= id_wid
1461 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1462 self
.out_mid
= Signal(id_wid
, reset_less
=True)
1467 def idsync(self
, m
):
1468 if self
.id_wid
is not None:
1469 m
.d
.sync
+= self
.out_mid
.eq(self
.in_mid
)
1472 if __name__
== '__main__':