1 """IEEE754 Floating Point Library
3 Copyright (C) 2019 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
4 Copyright (C) 2019,2021 Jake Lifshay
9 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
, Array
, Value
11 from operator
import or_
12 from functools
import reduce
14 from nmutil
.singlepipe
import PrevControl
, NextControl
15 from nmutil
.pipeline
import ObjectProxy
21 from nmigen
.hdl
.smtlib2
import RoundingModeEnum
26 # value so FPRoundingMode.to_smtlib2 can detect when no default is supplied
30 class FPRoundingMode(enum
.Enum
):
31 # matches the FPSCR.RN field values, but includes some extra
32 # values (>= 0b100) used in miscellaneous instructions.
34 # naming matches smtlib2 names, doc strings are the OpenPower ISA
35 # specification's names (v3.1 section 7.3.2.6 --
36 # matches values in section 4.3.6).
38 """Round to Nearest Even
40 Rounds to the nearest representable floating-point number, ties are
41 rounded to the number with the even mantissa. Treats +-Infinity as if
42 it were a normalized floating-point number when deciding which number
43 is closer when rounding. See IEEE754 spec. for details.
46 ROUND_NEAREST_TIES_TO_EVEN
= RNE
52 If the result is exactly representable as a floating-point number, return
53 that, otherwise return the nearest representable floating-point value
54 with magnitude smaller than the exact answer.
57 ROUND_TOWARDS_ZERO
= RTZ
60 """Round towards +Infinity
62 If the result is exactly representable as a floating-point number, return
63 that, otherwise return the nearest representable floating-point value
64 that is numerically greater than the exact answer. This can round up to
68 ROUND_TOWARDS_POSITIVE
= RTP
71 """Round towards -Infinity
73 If the result is exactly representable as a floating-point number, return
74 that, otherwise return the nearest representable floating-point value
75 that is numerically less than the exact answer. This can round down to
79 ROUND_TOWARDS_NEGATIVE
= RTN
82 """Round to Nearest Away
84 Rounds to the nearest representable floating-point number, ties are
85 rounded to the number with the maximum magnitude. Treats +-Infinity as if
86 it were a normalized floating-point number when deciding which number
87 is closer when rounding. See IEEE754 spec. for details.
90 ROUND_NEAREST_TIES_TO_AWAY
= RNA
93 """Round to Odd, unsigned zeros are Positive
97 If the result is exactly representable as a floating-point number, return
98 that, otherwise return the nearest representable floating-point value
99 that has an odd mantissa.
101 If the result is zero but with otherwise undetermined sign
102 (e.g. `1.0 - 1.0`), the sign is positive.
104 This rounding mode is used for instructions with Round To Odd enabled,
105 and `FPSCR.RN != RTN`.
107 This is useful to avoid double-rounding errors when doing arithmetic in a
108 larger type (e.g. f128) but where the answer should be a smaller type
112 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE
= RTOP
115 """Round to Odd, unsigned zeros are Negative
119 If the result is exactly representable as a floating-point number, return
120 that, otherwise return the nearest representable floating-point value
121 that has an odd mantissa.
123 If the result is zero but with otherwise undetermined sign
124 (e.g. `1.0 - 1.0`), the sign is negative.
126 This rounding mode is used for instructions with Round To Odd enabled,
127 and `FPSCR.RN == RTN`.
129 This is useful to avoid double-rounding errors when doing arithmetic in a
130 larger type (e.g. f128) but where the answer should be a smaller type
134 ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE
= RTON
138 l
= [None] * len(FPRoundingMode
)
139 for rm
in FPRoundingMode
:
143 def overflow_rounds_to_inf(self
, sign
):
144 """returns true if an overflow should round to `inf`,
145 false if it should round to `max_normal`
147 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
148 if self
is FPRoundingMode
.RNE
:
150 elif self
is FPRoundingMode
.RTZ
:
152 elif self
is FPRoundingMode
.RTP
:
154 elif self
is FPRoundingMode
.RTN
:
156 elif self
is FPRoundingMode
.RNA
:
158 elif self
is FPRoundingMode
.RTOP
:
161 assert self
is FPRoundingMode
.RTON
164 def underflow_rounds_to_zero(self
, sign
):
165 """returns true if an underflow should round to `zero`,
166 false if it should round to `min_denormal`
168 not_sign
= ~sign
if isinstance(sign
, Value
) else not sign
169 if self
is FPRoundingMode
.RNE
:
171 elif self
is FPRoundingMode
.RTZ
:
173 elif self
is FPRoundingMode
.RTP
:
175 elif self
is FPRoundingMode
.RTN
:
177 elif self
is FPRoundingMode
.RNA
:
179 elif self
is FPRoundingMode
.RTOP
:
182 assert self
is FPRoundingMode
.RTON
186 """which sign an exact zero result should have when it isn't
187 otherwise determined, e.g. for `1.0 - 1.0`.
189 if self
is FPRoundingMode
.RNE
:
191 elif self
is FPRoundingMode
.RTZ
:
193 elif self
is FPRoundingMode
.RTP
:
195 elif self
is FPRoundingMode
.RTN
:
197 elif self
is FPRoundingMode
.RNA
:
199 elif self
is FPRoundingMode
.RTOP
:
202 assert self
is FPRoundingMode
.RTON
206 def to_smtlib2(self
, default
=_raise_err
):
207 """return the corresponding smtlib2 rounding mode for `self`. If
208 there is no corresponding smtlib2 rounding mode, then return
209 `default` if specified, else raise `ValueError`.
211 if self
is FPRoundingMode
.RNE
:
212 return RoundingModeEnum
.RNE
213 elif self
is FPRoundingMode
.RTZ
:
214 return RoundingModeEnum
.RTZ
215 elif self
is FPRoundingMode
.RTP
:
216 return RoundingModeEnum
.RTP
217 elif self
is FPRoundingMode
.RTN
:
218 return RoundingModeEnum
.RTN
219 elif self
is FPRoundingMode
.RNA
:
220 return RoundingModeEnum
.RNA
222 assert self
in (FPRoundingMode
.RTOP
, FPRoundingMode
.RTON
)
223 if default
is _raise_err
:
225 "no corresponding smtlib2 rounding mode", self
)
232 """ Class describing binary floating-point formats based on IEEE 754.
234 :attribute e_width: the number of bits in the exponent field.
235 :attribute m_width: the number of bits stored in the mantissa
237 :attribute has_int_bit: if the FP format has an explicit integer bit (like
238 the x87 80-bit format). The bit is considered part of the mantissa.
239 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
240 Image/Buffer formats are FP numbers without a sign bit.)
248 """ Create ``FPFormat`` instance. """
249 self
.e_width
= e_width
250 self
.m_width
= m_width
251 self
.has_int_bit
= has_int_bit
252 self
.has_sign
= has_sign
def __eq__(self, other):
    """ Compare two ``FPFormat`` instances field by field.

    :param other: the object to compare against.
    :returns: ``NotImplemented`` when ``other`` is not an ``FPFormat``,
        otherwise the result of comparing ``e_width``, ``m_width``,
        ``has_int_bit`` and ``has_sign`` (in that order).
    """
    if isinstance(other, FPFormat):
        same_widths = (self.e_width == other.e_width
                       and self.m_width == other.m_width)
        return (same_widths
                and self.has_int_bit == other.has_int_bit
                and self.has_sign == other.has_sign)
    # let Python try the reflected comparison on the other operand
    return NotImplemented
265 """ Get standard IEEE 754-2008 format.
267 :param width: bit-width of requested format.
268 :returns: the requested ``FPFormat`` instance.
271 return FPFormat(5, 10)
273 return FPFormat(8, 23)
275 return FPFormat(11, 52)
277 return FPFormat(15, 112)
278 if width
> 128 and width
% 32 == 0:
279 if width
> 1000000: # arbitrary upper limit
280 raise ValueError("width too big")
281 e_width
= round(4 * math
.log2(width
)) - 13
282 return FPFormat(e_width
, width
- 1 - e_width
)
283 raise ValueError("width must be the bit-width of a valid IEEE"
284 " 754-2008 binary format")
289 if self
== self
.standard(self
.width
):
290 return f
"FPFormat.standard({self.width})"
293 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
294 if self
.has_int_bit
is not False:
295 retval
+= f
", {self.has_int_bit}"
296 if self
.has_sign
is not True:
297 retval
+= f
", {self.has_sign}"
def get_sign_field(self, x):
    """ returns the sign bit of its input number, x
        (assumes FPFormat is set to signed - has_sign=True)

        x is shifted right by e_width + m_width bits; nothing is
        masked off, so any bits above the sign position are retained.
    """
    sign_position = self.e_width + self.m_width
    return x >> sign_position
306 def get_exponent_field(self
, x
):
307 """ returns the raw exponent of its input number, x (no bias subtracted)
309 x
= ((x
>> self
.m_width
) & self
.exponent_inf_nan
)
def get_exponent(self, x):
    """ returns the exponent of its input number, x

        this is the raw exponent field (from get_exponent_field)
        with exponent_bias subtracted.
    """
    raw_exponent = self.get_exponent_field(x)
    return raw_exponent - self.exponent_bias
def get_mantissa_field(self, x):
    """ returns the mantissa of its input number, x

        the result is x ANDed with mantissa_mask, i.e. only the
        low m_width bits are kept.
    """
    mask = self.mantissa_mask
    return x & mask
def is_zero(self, x):
    """ returns true if x is +/- zero

        that is: the exponent equals e_sub and the mantissa field is 0.
        the sign bit is not examined.
    """
    exp_is_minimum = self.get_exponent(x) == self.e_sub
    mantissa_is_clear = self.get_mantissa_field(x) == 0
    return exp_is_minimum & mantissa_is_clear
def is_subnormal(self, x):
    """ returns true if x is subnormal (exp at minimum)

        that is: the exponent equals e_sub and the mantissa field
        is non-zero.
    """
    exp_is_minimum = self.get_exponent(x) == self.e_sub
    mantissa_is_set = self.get_mantissa_field(x) != 0
    return exp_is_minimum & mantissa_is_set
335 """ returns true if x is infinite
337 return (self
.get_exponent(x
) == self
.e_max
) & \
338 (self
.get_mantissa_field(x
) == 0)
341 """ returns true if x is a nan (quiet or signalling)
343 return (self
.get_exponent(x
) == self
.e_max
) & \
344 (self
.get_mantissa_field(x
) != 0)
def is_quiet_nan(self, x):
    """ returns true if x is a quiet nan

        that is: the exponent equals e_max, the mantissa field is
        non-zero, and bit (m_width - 1) of the mantissa field is set.
    """
    top_bit = 1 << (self.m_width - 1)
    exp_is_max = self.get_exponent(x) == self.e_max
    mantissa_is_set = self.get_mantissa_field(x) != 0
    top_bit_is_set = (self.get_mantissa_field(x) & top_bit) != 0
    return exp_is_max & mantissa_is_set & top_bit_is_set
def is_nan_signaling(self, x):
    """ returns true if x is a signalling nan

        that is: the exponent equals e_max, the mantissa field is
        non-zero, and bit (m_width - 1) of the mantissa field is clear.

        :param x: the raw bits of the number (a plain int, or anything
            supporting the same shift / mask / compare operators).
        :returns: the three conditions above combined with ``&``.
    """
    highbit = 1 << (self.m_width - 1)
    mantissa = self.get_mantissa_field(x)
    # the top-bit test needs its own parentheses: `&` binds tighter than
    # `==`, so a trailing unparenthesised `== 0` would be applied to the
    # whole AND-chain and report true for almost every input
    top_bit_is_clear = (mantissa & highbit) == 0
    return ((self.get_exponent(x) == self.e_max) &
            (mantissa != 0) &
            top_bit_is_clear)
364 """ Get the total number of bits in the FP format. """
365 return self
.has_sign
+ self
.e_width
+ self
.m_width
def mantissa_mask(self):
    """ Get a mask with the low m_width bits set (one per mantissa bit). """
    one_past_top = 1 << self.m_width
    return one_past_top - 1
def exponent_inf_nan(self):
    """ Get the exponent-field value that designates infinity/NaN.

    This is the all-ones pattern over e_width bits.
    """
    one_past_top = 1 << self.e_width
    return one_past_top - 1
379 """ get the maximum exponent (minus bias)
381 return self
.exponent_inf_nan
- self
.exponent_bias
385 return self
.exponent_denormal_zero
- self
.exponent_bias
387 def exponent_denormal_zero(self
):
388 """ Get the value of the exponent field designating denormal/zero. """
392 def exponent_min_normal(self
):
393 """ Get the minimum value of the exponent field for normal numbers. """
def exponent_max_normal(self):
    """ Get the largest exponent-field value used by normal numbers.

    This is one less than the infinity/NaN exponent (exponent_inf_nan).
    """
    inf_nan_exponent = self.exponent_inf_nan
    return inf_nan_exponent - 1
def exponent_bias(self):
    """ Get the exponent bias: 2**(e_width - 1) - 1. """
    half_range = 1 << (self.e_width - 1)
    return half_range - 1
def fraction_width(self):
    """ Get the number of mantissa bits that are fraction bits.

    has_int_bit is subtracted from m_width, so an explicit integer
    bit (has_int_bit true) reduces the count by one.
    """
    int_bits = self.has_int_bit
    return self.m_width - int_bits
412 class TestFPFormat(unittest
.TestCase
):
413 """ very quick test for FPFormat
416 def test_fpformat_fp64(self
):
417 f64
= FPFormat
.standard(64)
418 from sfpy
import Float64
419 x
= Float64(1.0).bits
422 self
.assertEqual(f64
.get_exponent(x
), 0)
423 x
= Float64(2.0).bits
425 self
.assertEqual(f64
.get_exponent(x
), 1)
427 x
= Float64(1.5).bits
428 m
= f64
.get_mantissa_field(x
)
429 print (hex(x
), hex(m
))
430 self
.assertEqual(m
, 0x8000000000000)
432 s
= f64
.get_sign_field(x
)
433 print (hex(x
), hex(s
))
434 self
.assertEqual(s
, 0)
436 x
= Float64(-1.5).bits
437 s
= f64
.get_sign_field(x
)
438 print (hex(x
), hex(s
))
439 self
.assertEqual(s
, 1)
441 def test_fpformat_fp32(self
):
442 f32
= FPFormat
.standard(32)
443 from sfpy
import Float32
444 x
= Float32(1.0).bits
447 self
.assertEqual(f32
.get_exponent(x
), 0)
448 x
= Float32(2.0).bits
450 self
.assertEqual(f32
.get_exponent(x
), 1)
452 x
= Float32(1.5).bits
453 m
= f32
.get_mantissa_field(x
)
454 print (hex(x
), hex(m
))
455 self
.assertEqual(m
, 0x400000)
458 x
= Float32(-1.0).sqrt()
461 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
462 f32
.get_mantissa_field(x
), i
)
463 self
.assertEqual(i
, True)
466 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
469 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
470 f32
.get_mantissa_field(x
), i
)
471 self
.assertEqual(i
, True)
476 i
= f32
.is_subnormal(x
)
477 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
478 f32
.get_mantissa_field(x
), i
)
479 self
.assertEqual(i
, True)
483 i
= f32
.is_subnormal(x
)
484 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
485 f32
.get_mantissa_field(x
), i
)
486 self
.assertEqual(i
, False)
490 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
491 f32
.get_mantissa_field(x
), i
)
492 self
.assertEqual(i
, True)
497 def __init__(self
, width
):
499 self
.smax
= int(log(width
) / log(2))
500 self
.i
= Signal(width
, reset_less
=True)
501 self
.s
= Signal(self
.smax
, reset_less
=True)
502 self
.o
= Signal(width
, reset_less
=True)
504 def elaborate(self
, platform
):
506 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
511 """ Generates variable-length single-cycle shifter from a series
512 of conditional tests on each bit of the left/right shift operand.
513 Each bit tested produces output shifted by that number of bits,
514 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
515 shifts by 2 bits, each partial result cascading to the next Mux.
517 Could be adapted to do arithmetic shift by taking copies of the
518 MSB instead of zeros.
521 def __init__(self
, width
):
523 self
.smax
= int(log(width
) / log(2))
525 def lshift(self
, op
, s
):
529 def rshift(self
, op
, s
):
534 class FPNumBaseRecord
:
535 """ Floating-point Base Number Class.
537 This class is designed to be passed around in other data structures
538 (between pipelines and between stages). Its "friend" is FPNumBase,
539 which is a *module*. The reason for the discernment is because
540 nmigen modules that are not added to submodules results in the
541 irritating "Elaboration" warning. Despite not *needing* FPNumBase
542 in many cases to be added as a submodule (because it is just data)
543 this was not possible to solve without splitting out the data from
547 def __init__(self
, width
, m_extra
=True, e_extra
=False, name
=None):
550 # assert false, "missing name"
554 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
555 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
556 e_max
= 1 << (e_width
-3)
557 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
560 # mantissa extra bits (top,guard,round)
562 m_width
+= self
.m_extra
566 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
567 e_width
+= self
.e_extra
570 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
571 self
.m_width
= m_width
572 self
.e_width
= e_width
573 self
.e_start
= self
.rmw
574 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
576 self
.v
= Signal(width
, reset_less
=True,
577 name
=name
+"v") # Latched copy of value
578 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
579 self
.e
= Signal((e_width
, True),
580 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
581 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
586 def drop_in(self
, fp
):
592 fp
.width
= self
.width
593 fp
.e_width
= self
.e_width
594 fp
.e_max
= self
.e_max
595 fp
.m_width
= self
.m_width
596 fp
.e_start
= self
.e_start
597 fp
.e_end
= self
.e_end
598 fp
.m_extra
= self
.m_extra
600 m_width
= self
.m_width
602 e_width
= self
.e_width
604 self
.mzero
= Const(0, (m_width
, False))
605 m_msb
= 1 << (self
.m_width
-2)
606 self
.msb1
= Const(m_msb
, (m_width
, False))
607 self
.m1s
= Const(-1, (m_width
, False))
608 self
.P128
= Const(e_max
, (e_width
, True))
609 self
.P127
= Const(e_max
-1, (e_width
, True))
610 self
.N127
= Const(-(e_max
-1), (e_width
, True))
611 self
.N126
= Const(-(e_max
-2), (e_width
, True))
613 def create(self
, s
, e
, m
):
614 """ creates a value from sign / exponent / mantissa
616 bias is added here, to the exponent.
618 NOTE: order is important, because e_start/e_end can be
619 a bit too long (overwriting s).
622 self
.v
[0:self
.e_start
].eq(m
), # mantissa
623 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
624 self
.v
[-1].eq(s
), # sign
628 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
631 return (s
, self
.fp
.P128
, 0)
634 return (s
, self
.fp
.N127
, 0)
637 return self
.create(*self
._nan
(s
))
def quieted_nan(self, other):
    """ builds (via create) a NaN that carries other's sign and the low
        e_start bits of other.v, with bit (e_start - 1) forced on.

        other must be an FPNumBaseRecord of the same width as self.
    """
    assert isinstance(other, FPNumBaseRecord)
    assert self.width == other.width
    forced_bit = 1 << (self.e_start - 1)
    payload = other.v[0:self.e_start] | forced_bit
    return self.create(other.s, self.fp.P128, payload)
646 return self
.create(*self
._inf
(s
))
def max_normal(self, s):
    """ builds (via create) a value with sign s, exponent fp.P127
        and an all-ones mantissa argument (~0).
    """
    all_ones = ~0
    return self.create(s, self.fp.P127, all_ones)
def min_denormal(self, s):
    """ builds (via create) a value with sign s, exponent fp.N127
        and a mantissa of 1 (only the lowest bit set).
    """
    lowest_bit_only = 1
    return self.create(s, self.fp.N127, lowest_bit_only)
655 return self
.create(*self
._zero
(s
))
657 def create2(self
, s
, e
, m
):
658 """ creates a value from sign / exponent / mantissa
660 bias is added here, to the exponent
662 e
= e
+ self
.P127
# exp (add on bias)
663 return Cat(m
[0:self
.e_start
],
664 e
[0:self
.e_end
-self
.e_start
],
668 return self
.create2(s
, self
.P128
, self
.msb1
)
671 return self
.create2(s
, self
.P128
, self
.mzero
)
674 return self
.create2(s
, self
.N127
, self
.mzero
)
682 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
685 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
686 """ Floating-point Base Number Class
689 def __init__(self
, fp
):
694 self
.is_nan
= Signal(reset_less
=True)
695 self
.is_zero
= Signal(reset_less
=True)
696 self
.is_inf
= Signal(reset_less
=True)
697 self
.is_overflowed
= Signal(reset_less
=True)
698 self
.is_denormalised
= Signal(reset_less
=True)
699 self
.exp_128
= Signal(reset_less
=True)
700 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
701 self
.exp_lt_n126
= Signal(reset_less
=True)
702 self
.exp_zero
= Signal(reset_less
=True)
703 self
.exp_gt_n126
= Signal(reset_less
=True)
704 self
.exp_gt127
= Signal(reset_less
=True)
705 self
.exp_n127
= Signal(reset_less
=True)
706 self
.exp_n126
= Signal(reset_less
=True)
707 self
.m_zero
= Signal(reset_less
=True)
708 self
.m_msbzero
= Signal(reset_less
=True)
710 def elaborate(self
, platform
):
712 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
713 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
714 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
715 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
716 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
717 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
718 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
719 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
720 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
721 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
722 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
723 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
724 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
725 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
726 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
731 return (self
.exp_128
) & (~self
.m_zero
)
734 return (self
.exp_128
) & (self
.m_zero
)
737 return (self
.exp_n127
) & (self
.m_zero
)
def _is_overflowed(self):
    """ returns the exp_gt127 signal, which elaborate() drives with
        `self.e > self.fp.P127`.
    """
    overflowed = self.exp_gt127
    return overflowed
def _is_denormalised(self):
    """ returns exp_n126 AND m_msbzero: the exponent equals fp.N126
        and mantissa bit fp.e_start is zero (both signals are driven
        in elaborate()).
    """
    # XXX NOT to be used for "official" quiet NaN tests!
    # particularly when the MSB has been extended
    exp_at_n126 = self.exp_n126
    msb_is_clear = self.m_msbzero
    return (exp_at_n126) & (msb_is_clear)
748 class FPNumOut(FPNumBase
):
749 """ Floating-point Number Class
751 Contains signals for an incoming copy of the value, decoded into
752 sign / exponent / mantissa.
753 Also contains encoding functions, creation and recognition of
754 zero, NaN and inf (all signed)
756 Four extra bits are included in the mantissa: the top bit
757 (m[-1]) is effectively a carry-overflow. The other three are
758 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, fp):
    # no extra state is set up here: construction is delegated entirely
    # to FPNumBase, with `fp` passed straight through
    FPNumBase.__init__(self, fp)
764 def elaborate(self
, platform
):
765 m
= FPNumBase
.elaborate(self
, platform
)
770 class MultiShiftRMerge(Elaboratable
):
771 """ shifts down (right) and merges lower bits into m[0].
772 m[0] is the "sticky" bit, basically
775 def __init__(self
, width
, s_max
=None):
777 s_max
= int(log(width
) / log(2))
779 self
.m
= Signal(width
, reset_less
=True)
780 self
.inp
= Signal(width
, reset_less
=True)
781 self
.diff
= Signal(s_max
, reset_less
=True)
784 def elaborate(self
, platform
):
787 rs
= Signal(self
.width
, reset_less
=True)
788 m_mask
= Signal(self
.width
, reset_less
=True)
789 smask
= Signal(self
.width
, reset_less
=True)
790 stickybit
= Signal(reset_less
=True)
791 # XXX GRR frickin nuisance https://github.com/nmigen/nmigen/issues/302
792 maxslen
= Signal(self
.smax
[0], reset_less
=True)
793 maxsleni
= Signal(self
.smax
[0], reset_less
=True)
795 sm
= MultiShift(self
.width
-1)
796 m0s
= Const(0, self
.width
-1)
797 mw
= Const(self
.width
-1, len(self
.diff
))
798 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
799 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
803 # shift mantissa by maxslen, mask by inverse
804 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
805 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
806 smask
.eq(self
.inp
[1:] & m_mask
),
807 # sticky bit combines all mask (and mantissa low bit)
808 stickybit
.eq(smask
.bool() | self
.inp
[0]),
809 # mantissa result contains m[0] already.
810 self
.m
.eq(Cat(stickybit
, rs
))
815 class FPNumShift(FPNumBase
, Elaboratable
):
816 """ Floating-point Number Class for shifting
819 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
820 FPNumBase
.__init
__(self
, width
, m_extra
)
821 self
.latch_in
= Signal()
826 def elaborate(self
, platform
):
827 m
= FPNumBase
.elaborate(self
, platform
)
829 m
.d
.comb
+= self
.s
.eq(op
.s
)
830 m
.d
.comb
+= self
.e
.eq(op
.e
)
831 m
.d
.comb
+= self
.m
.eq(op
.m
)
833 with self
.mainm
.State("align"):
834 with m
.If(self
.e
< self
.inv
.e
):
835 m
.d
.sync
+= self
.shift_down()
839 def shift_down(self
, inp
):
840 """ shifts a mantissa down by one. exponent is increased to compensate
842 accuracy is lost as a result in the mantissa however there are 3
843 guard bits (the latter of which is the "sticky" bit)
845 return [self
.e
.eq(inp
.e
+ 1),
846 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
849 def shift_down_multi(self
, diff
):
850 """ shifts a mantissa down. exponent is increased to compensate
852 accuracy is lost as a result in the mantissa however there are 3
853 guard bits (the latter of which is the "sticky" bit)
855 this code works by variable-shifting the mantissa by up to
856 its maximum bit-length: no point doing more (it'll still be
859 the sticky bit is computed by shifting a batch of 1s by
860 the same amount, which will introduce zeros. it's then
861 inverted and used as a mask to get the LSBs of the mantissa.
862 those are then |'d into the sticky bit.
864 sm
= MultiShift(self
.width
)
865 mw
= Const(self
.m_width
-1, len(diff
))
866 maxslen
= Mux(diff
> mw
, mw
, diff
)
867 rs
= sm
.rshift(self
.m
[1:], maxslen
)
868 maxsleni
= mw
- maxslen
869 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
871 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
872 return [self
.e
.eq(self
.e
+ diff
),
873 self
.m
.eq(Cat(stickybits
, rs
))
876 def shift_up_multi(self
, diff
):
877 """ shifts a mantissa up. exponent is decreased to compensate
879 sm
= MultiShift(self
.width
)
880 mw
= Const(self
.m_width
, len(diff
))
881 maxslen
= Mux(diff
> mw
, mw
, diff
)
883 return [self
.e
.eq(self
.e
- diff
),
884 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
888 class FPNumDecode(FPNumBase
):
889 """ Floating-point Number Class
891 Contains signals for an incoming copy of the value, decoded into
892 sign / exponent / mantissa.
893 Also contains encoding functions, creation and recognition of
894 zero, NaN and inf (all signed)
896 Four extra bits are included in the mantissa: the top bit
897 (m[-1]) is effectively a carry-overflow. The other three are
898 guard (m[2]), round (m[1]), and sticky (m[0])
901 def __init__(self
, op
, fp
):
902 FPNumBase
.__init
__(self
, fp
)
905 def elaborate(self
, platform
):
906 m
= FPNumBase
.elaborate(self
, platform
)
908 m
.d
.comb
+= self
.decode(self
.v
)
913 """ decodes a latched value into sign / exponent / mantissa
915 bias is subtracted here, from the exponent. exponent
916 is extended to 10 bits so that subtract 127 is done on
919 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
920 #print ("decode", self.e_end)
921 return [self
.m
.eq(Cat(*args
)), # mantissa
922 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
923 self
.s
.eq(v
[-1]), # sign
927 class FPNumIn(FPNumBase
):
928 """ Floating-point Number Class
930 Contains signals for an incoming copy of the value, decoded into
931 sign / exponent / mantissa.
932 Also contains encoding functions, creation and recognition of
933 zero, NaN and inf (all signed)
935 Four extra bits are included in the mantissa: the top bit
936 (m[-1]) is effectively a carry-overflow. The other three are
937 guard (m[2]), round (m[1]), and sticky (m[0])
940 def __init__(self
, op
, fp
):
941 FPNumBase
.__init
__(self
, fp
)
942 self
.latch_in
= Signal()
945 def decode2(self
, m
):
946 """ decodes a latched value into sign / exponent / mantissa
948 bias is subtracted here, from the exponent. exponent
949 is extended to 10 bits so that subtract 127 is done on
953 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
954 #print ("decode", self.e_end)
955 res
= ObjectProxy(m
, pipemode
=False)
956 res
.m
= Cat(*args
) # mantissa
957 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
962 """ decodes a latched value into sign / exponent / mantissa
964 bias is subtracted here, from the exponent. exponent
965 is extended to 10 bits so that subtract 127 is done on
968 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
969 #print ("decode", self.e_end)
970 return [self
.m
.eq(Cat(*args
)), # mantissa
971 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
972 self
.s
.eq(v
[-1]), # sign
975 def shift_down(self
, inp
):
976 """ shifts a mantissa down by one. exponent is increased to compensate
978 accuracy is lost as a result in the mantissa however there are 3
979 guard bits (the latter of which is the "sticky" bit)
981 return [self
.e
.eq(inp
.e
+ 1),
982 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
985 def shift_down_multi(self
, diff
, inp
=None):
986 """ shifts a mantissa down. exponent is increased to compensate
988 accuracy is lost as a result in the mantissa however there are 3
989 guard bits (the latter of which is the "sticky" bit)
991 this code works by variable-shifting the mantissa by up to
992 its maximum bit-length: no point doing more (it'll still be
995 the sticky bit is computed by shifting a batch of 1s by
996 the same amount, which will introduce zeros. it's then
997 inverted and used as a mask to get the LSBs of the mantissa.
998 those are then |'d into the sticky bit.
1002 sm
= MultiShift(self
.width
)
1003 mw
= Const(self
.m_width
-1, len(diff
))
1004 maxslen
= Mux(diff
> mw
, mw
, diff
)
1005 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
1006 maxsleni
= mw
- maxslen
1007 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
1009 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
1010 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
1011 return [self
.e
.eq(inp
.e
+ diff
),
1012 self
.m
.eq(Cat(stickybit
, rs
))
1015 def shift_up_multi(self
, diff
):
1016 """ shifts a mantissa up. exponent is decreased to compensate
1018 sm
= MultiShift(self
.width
)
1019 mw
= Const(self
.m_width
, len(diff
))
1020 maxslen
= Mux(diff
> mw
, mw
, diff
)
1022 return [self
.e
.eq(self
.e
- diff
),
1023 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
1027 class Trigger(Elaboratable
):
1030 self
.stb
= Signal(reset
=0)
1032 self
.trigger
= Signal(reset_less
=True)
1034 def elaborate(self
, platform
):
1036 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
1040 return [self
.stb
.eq(inp
.stb
),
1041 self
.ack
.eq(inp
.ack
)
1045 return [self
.stb
, self
.ack
]
1048 class FPOpIn(PrevControl
):
1049 def __init__(self
, width
):
1050 PrevControl
.__init
__(self
)
1057 def chain_inv(self
, in_op
, extra
=None):
1059 if extra
is not None:
1061 return [self
.v
.eq(in_op
.v
), # receive value
1062 self
.stb
.eq(stb
), # receive STB
1063 in_op
.ack
.eq(~self
.ack
), # send ACK
1066 def chain_from(self
, in_op
, extra
=None):
1068 if extra
is not None:
1070 return [self
.v
.eq(in_op
.v
), # receive value
1071 self
.stb
.eq(stb
), # receive STB
1072 in_op
.ack
.eq(self
.ack
), # send ACK
1076 class FPOpOut(NextControl
):
1077 def __init__(self
, width
):
1078 NextControl
.__init
__(self
)
1085 def chain_inv(self
, in_op
, extra
=None):
1087 if extra
is not None:
1089 return [self
.v
.eq(in_op
.v
), # receive value
1090 self
.stb
.eq(stb
), # receive STB
1091 in_op
.ack
.eq(~self
.ack
), # send ACK
1094 def chain_from(self
, in_op
, extra
=None):
1096 if extra
is not None:
1098 return [self
.v
.eq(in_op
.v
), # receive value
1099 self
.stb
.eq(stb
), # receive STB
1100 in_op
.ack
.eq(self
.ack
), # send ACK
1105 # TODO: change FFLAGS to be FPSCR's status flags
1106 FFLAGS_NV
= Const(1<<4, 5) # invalid operation
1107 FFLAGS_DZ
= Const(1<<3, 5) # divide by zero
1108 FFLAGS_OF
= Const(1<<2, 5) # overflow
1109 FFLAGS_UF
= Const(1<<1, 5) # underflow
1110 FFLAGS_NX
= Const(1<<0, 5) # inexact
1111 def __init__(self
, name
=None):
1114 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
1115 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
1116 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
1117 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
1118 self
.fpflags
= Signal(5, reset_less
=True, name
=name
+"fflags")
1120 self
.sign
= Signal(reset_less
=True, name
=name
+"sign")
1121 """sign bit -- 1 means negative, 0 means positive"""
1123 self
.rm
= Signal(FPRoundingMode
, name
=name
+"rm",
1124 reset
=FPRoundingMode
.DEFAULT
)
1127 #self.roundz = Signal(reset_less=True)
1131 yield self
.round_bit
1139 return [self
.guard
.eq(inp
.guard
),
1140 self
.round_bit
.eq(inp
.round_bit
),
1141 self
.sticky
.eq(inp
.sticky
),
1143 self
.fpflags
.eq(inp
.fpflags
),
1144 self
.sign
.eq(inp
.sign
),
def roundz_rne(self):
    """true if the mantissa should be rounded up for `rm == RNE`

    assumes the rounding mode is `ROUND_NEAREST_TIES_TO_EVEN`

    computed as: guard AND (round_bit OR sticky OR m0)
    """
    below_or_odd = self.round_bit | self.sticky | self.m0
    return self.guard & below_or_odd
1156 def roundz_rna(self
):
1157 """true if the mantissa should be rounded up for `rm == RNA`
1159 assumes the rounding mode is `ROUND_NEAREST_TIES_TO_AWAY`
def roundz_rtn(self):
    """true if the mantissa should be rounded up for `rm == RTN`

    assumes the rounding mode is `ROUND_TOWARDS_NEGATIVE`

    computed as: sign AND (guard OR round_bit OR sticky)
    """
    any_discarded = self.guard | self.round_bit | self.sticky
    return self.sign & any_discarded
def roundz_rto(self):
    """true if the mantissa should be rounded up for `rm in (RTOP, RTON)`

    assumes the rounding mode is `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_POSITIVE`
    or `ROUND_TO_ODD_UNSIGNED_ZEROS_ARE_NEGATIVE`

    computed as: (NOT m0) AND (guard OR round_bit OR sticky)
    """
    lsb_is_even = ~self.m0
    any_discarded = self.guard | self.round_bit | self.sticky
    return lsb_is_even & any_discarded
def roundz_rtp(self):
    """true if the mantissa should be rounded up for `rm == RTP`

    assumes the rounding mode is `ROUND_TOWARDS_POSITIVE`

    computed as: (NOT sign) AND (guard OR round_bit OR sticky)
    """
    is_positive = ~self.sign
    any_discarded = self.guard | self.round_bit | self.sticky
    return is_positive & any_discarded
1189 def roundz_rtz(self
):
1190 """true if the mantissa should be rounded up for `rm == RTZ`
1192 assumes the rounding mode is `ROUND_TOWARDS_ZERO`
1198 """true if the mantissa should be rounded up for the current rounding
1202 FPRoundingMode
.RNA
: self
.roundz_rna
,
1203 FPRoundingMode
.RNE
: self
.roundz_rne
,
1204 FPRoundingMode
.RTN
: self
.roundz_rtn
,
1205 FPRoundingMode
.RTOP
: self
.roundz_rto
,
1206 FPRoundingMode
.RTON
: self
.roundz_rto
,
1207 FPRoundingMode
.RTP
: self
.roundz_rtp
,
1208 FPRoundingMode
.RTZ
: self
.roundz_rtz
,
1210 return FPRoundingMode
.make_array(lambda rm
: d
[rm
])[self
.rm
]
1213 class OverflowMod(Elaboratable
, Overflow
):
1214 def __init__(self
, name
=None):
1215 Overflow
.__init
__(self
, name
)
1218 self
.roundz_out
= Signal(reset_less
=True, name
=name
+"roundz_out")
1221 yield from Overflow
.__iter
__(self
)
1222 yield self
.roundz_out
1225 return [self
.roundz_out
.eq(inp
.roundz_out
)] + Overflow
.eq(self
)
1227 def elaborate(self
, platform
):
1229 m
.d
.comb
+= self
.roundz_out
.eq(self
.roundz
) # roundz is a property
1234 """ IEEE754 Floating Point Base Class
1236 contains common functions for FP manipulation, such as
1237 extracting and packing operands, normalisation, denormalisation,
1241 def get_op(self
, m
, op
, v
, next_state
):
1242 """ this function moves to the next state and copies the operand
1243 when both stb and ack are 1.
1244 acknowledgement is sent by setting ack to ZERO.
1248 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
1250 # op is latched in from FPNumIn class on same ack/stb
1251 m
.d
.comb
+= ack
.eq(0)
1253 m
.d
.comb
+= ack
.eq(1)
1256 def denormalise(self
, m
, a
):
1257 """ denormalises a number. this is probably the wrong name for
1258 this function. for normalised numbers (exponent != minimum)
1259 one *extra* bit (the implicit 1) is added *back in*.
1260 for denormalised numbers, the mantissa is left alone
1261 and the exponent increased by 1.
1263 both cases *effectively multiply the number stored by 2*,
1264 which has to be taken into account when extracting the result.
1266 with m
.If(a
.exp_n127
):
1267 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
1269 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
1271 def op_normalise(self
, m
, op
, next_state
):
1272 """ operand normalisation
1273 NOTE: just like "align", this one keeps going round every clock
1274 until the result's exponent is within acceptable "range"
1276 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
1278 op
.e
.eq(op
.e
- 1), # DECREASE exponent
1279 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
1284 def normalise_1(self
, m
, z
, of
, next_state
):
1285 """ first stage normalisation
1287 NOTE: just like "align", this one keeps going round every clock
1288 until the result's exponent is within acceptable "range"
1289 NOTE: the weirdness of reassigning guard and round is due to
1290 the extra mantissa bits coming from tot[0..2]
1292 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
1294 z
.e
.eq(z
.e
- 1), # DECREASE exponent
1295 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
1296 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
1297 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
1298 of
.round_bit
.eq(0), # reset round bit
1304 def normalise_2(self
, m
, z
, of
, next_state
):
1305 """ second stage normalisation
1307 NOTE: just like "align", this one keeps going round every clock
1308 until the result's exponent is within acceptable "range"
1309 NOTE: the weirdness of reassigning guard and round is due to
1310 the extra mantissa bits coming from tot[0..2]
1312 with m
.If(z
.e
< z
.fp
.N126
):
1314 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1315 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1316 of
.guard
.eq(z
.m
[0]),
1318 of
.round_bit
.eq(of
.guard
),
1319 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1324 def roundz(self
, m
, z
, roundz
):
1325 """ performs rounding on the output. TODO: different kinds of rounding
1328 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1329 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1330 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1332 def corrections(self
, m
, z
, next_state
):
1333 """ denormalisation and sign-bug corrections
1336 # denormalised, correct exponent to zero
1337 with m
.If(z
.is_denormalised
):
1338 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1340 def pack(self
, m
, z
, next_state
):
1341 """ packs the result into the output (detects overflow->Inf)
1344 # if overflow occurs, return inf
1345 with m
.If(z
.is_overflowed
):
1346 m
.d
.sync
+= z
.inf(z
.s
)
1348 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1350 def put_z(self
, m
, z
, out_z
, next_state
):
1351 """ put_z: stores the result in the output. raises stb and waits
1352 for ack to be set to 1 before moving to the next state.
1353 resets stb back to zero when that occurs, as acknowledgement.
1358 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1359 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1362 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1365 class FPState(FPBase
):
1366 def __init__(self
, state_from
):
1367 self
.state_from
= state_from
1369 def set_inputs(self
, inputs
):
1370 self
.inputs
= inputs
1371 for k
, v
in inputs
.items():
1374 def set_outputs(self
, outputs
):
1375 self
.outputs
= outputs
1376 for k
, v
in outputs
.items():
1381 def __init__(self
, id_wid
):
1382 self
.id_wid
= id_wid
1384 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1385 self
.out_mid
= Signal(id_wid
, reset_less
=True)
def idsync(self, m):
    """copy `in_mid` to `out_mid` in the sync domain of module `m`

    Does nothing when `id_wid` is None; otherwise appends the
    `out_mid.eq(in_mid)` assignment to `m.d.sync`.
    """
    if self.id_wid is None:
        return
    m.d.sync += self.out_mid.eq(self.in_mid)
1395 if __name__
== '__main__':