1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
, Elaboratable
7 from operator
import or_
8 from functools
import reduce
10 from nmutil
.singlepipe
import PrevControl
, NextControl
11 from nmutil
.pipeline
import ObjectProxy
17 """ Class describing binary floating-point formats based on IEEE 754.
19 :attribute e_width: the number of bits in the exponent field.
20 :attribute m_width: the number of bits stored in the mantissa
22 :attribute has_int_bit: if the FP format has an explicit integer bit (like
23 the x87 80-bit format). The bit is considered part of the mantissa.
24 :attribute has_sign: if the FP format has a sign bit. (Some Vulkan
25 Image/Buffer formats are FP numbers without a sign bit.)
33 """ Create ``FPFormat`` instance. """
34 self
.e_width
= e_width
35 self
.m_width
= m_width
36 self
.has_int_bit
= has_int_bit
37 self
.has_sign
= has_sign
39 def __eq__(self
, other
):
40 """ Check for equality. """
41 if not isinstance(other
, FPFormat
):
43 return (self
.e_width
== other
.e_width
44 and self
.m_width
== other
.m_width
45 and self
.has_int_bit
== other
.has_int_bit
46 and self
.has_sign
== other
.has_sign
)
50 """ Get standard IEEE 754-2008 format.
52 :param width: bit-width of requested format.
53 :returns: the requested ``FPFormat`` instance.
56 return FPFormat(5, 10)
58 return FPFormat(8, 23)
60 return FPFormat(11, 52)
62 return FPFormat(15, 112)
63 if width
> 128 and width
% 32 == 0:
64 if width
> 1000000: # arbitrary upper limit
65 raise ValueError("width too big")
66 e_width
= round(4 * math
.log2(width
)) - 13
67 return FPFormat(e_width
, width
- 1 - e_width
)
68 raise ValueError("width must be the bit-width of a valid IEEE"
69 " 754-2008 binary format")
74 if self
== self
.standard(self
.width
):
75 return f
"FPFormat.standard({self.width})"
78 retval
= f
"FPFormat({self.e_width}, {self.m_width}"
79 if self
.has_int_bit
is not False:
80 retval
+= f
", {self.has_int_bit}"
81 if self
.has_sign
is not True:
82 retval
+= f
", {self.has_sign}"
85 def get_sign_field(self
, x
):
86 """ returns the sign bit of its input number, x
87 (assumes FPFormat is set to signed - has_sign=True)
89 return x
>> (self
.e_width
+ self
.m_width
)
91 def get_exponent_field(self
, x
):
92 """ returns the raw exponent of its input number, x (no bias subtracted)
94 x
= ((x
>> self
.m_width
) & self
.exponent_inf_nan
)
97 def get_exponent(self
, x
):
98 """ returns the exponent of its input number, x
100 return self
.get_exponent_field(x
) - self
.exponent_bias
102 def get_mantissa_field(self
, x
):
103 """ returns the mantissa of its input number, x
105 return x
& self
.mantissa_mask
107 def is_zero(self
, x
):
108 """ returns true if x is +/- zero
110 return (self
.get_exponent(x
) == self
.e_sub
and
111 self
.get_mantissa_field(x
) == 0)
113 def is_subnormal(self
, x
):
114 """ returns true if x is subnormal (exp at minimum)
116 return (self
.get_exponent(x
) == self
.e_sub
and
117 self
.get_mantissa_field(x
) != 0)
120 """ returns true if x is infinite
122 return (self
.get_exponent(x
) == self
.e_max
and
123 self
.get_mantissa_field(x
) == 0)
126 """ returns true if x is a nan (quiet or signalling)
128 return (self
.get_exponent(x
) == self
.e_max
and
129 self
.get_mantissa_field(x
) != 0)
131 def is_quiet_nan(self
, x
):
132 """ returns true if x is a quiet nan
134 highbit
= 1<<(self
.m_width
-1)
135 return (self
.get_exponent(x
) == self
.e_max
and
136 self
.get_mantissa_field(x
) != 0 and
137 self
.get_mantissa_field(x
) & highbit
!= 0)
139 def is_nan_signaling(self
, x
):
140 """ returns true if x is a signalling nan
142 highbit
= 1<<(self
.m_width
-1)
143 return ((self
.get_exponent(x
) == self
.e_max
) and
144 (self
.get_mantissa_field(x
) != 0) and
145 (self
.get_mantissa_field(x
) & highbit
) == 0)
149 """ Get the total number of bits in the FP format. """
150 return self
.has_sign
+ self
.e_width
+ self
.m_width
def mantissa_mask(self):
    """ Bit mask selecting the stored mantissa field.

    :returns: an int with the low ``self.m_width`` bits set.
    """
    field_bits = self.m_width
    return (1 << field_bits) - 1
def exponent_inf_nan(self):
    """ Raw exponent-field value designating infinity / NaN.

    This is the all-ones pattern of an ``self.e_width``-bit field.
    """
    all_ones = (1 << self.e_width) - 1
    return all_ones
164 """ get the maximum exponent (minus bias)
166 return self
.exponent_inf_nan
- self
.exponent_bias
170 return self
.exponent_denormal_zero
- self
.exponent_bias
172 def exponent_denormal_zero(self
):
173 """ Get the value of the exponent field designating denormal/zero. """
177 def exponent_min_normal(self
):
178 """ Get the minimum value of the exponent field for normal numbers. """
def exponent_max_normal(self):
    """ Largest raw exponent-field value usable by a normal number.

    One below the infinity/NaN pattern held in ``self.exponent_inf_nan``.
    """
    reserved = self.exponent_inf_nan
    return reserved - 1
def exponent_bias(self):
    """ Exponent bias of the format: ``2**(e_width - 1) - 1``. """
    half_range = 1 << (self.e_width - 1)
    return half_range - 1
def fraction_width(self):
    """ Number of mantissa bits that are fraction bits.

    When ``has_int_bit`` is true, one of the ``m_width`` mantissa bits is
    an explicit integer bit and is therefore not counted as fraction.
    """
    explicit_int_bits = self.has_int_bit
    return self.m_width - explicit_int_bits
197 class TestFPFormat(unittest
.TestCase
):
198 """ very quick test for FPFormat
201 def test_fpformat_fp64(self
):
202 f64
= FPFormat
.standard(64)
203 from sfpy
import Float64
204 x
= Float64(1.0).bits
207 self
.assertEqual(f64
.get_exponent(x
), 0)
208 x
= Float64(2.0).bits
210 self
.assertEqual(f64
.get_exponent(x
), 1)
212 x
= Float64(1.5).bits
213 m
= f64
.get_mantissa_field(x
)
214 print (hex(x
), hex(m
))
215 self
.assertEqual(m
, 0x8000000000000)
217 s
= f64
.get_sign_field(x
)
218 print (hex(x
), hex(s
))
219 self
.assertEqual(s
, 0)
221 x
= Float64(-1.5).bits
222 s
= f64
.get_sign_field(x
)
223 print (hex(x
), hex(s
))
224 self
.assertEqual(s
, 1)
226 def test_fpformat_fp32(self
):
227 f32
= FPFormat
.standard(32)
228 from sfpy
import Float32
229 x
= Float32(1.0).bits
232 self
.assertEqual(f32
.get_exponent(x
), 0)
233 x
= Float32(2.0).bits
235 self
.assertEqual(f32
.get_exponent(x
), 1)
237 x
= Float32(1.5).bits
238 m
= f32
.get_mantissa_field(x
)
239 print (hex(x
), hex(m
))
240 self
.assertEqual(m
, 0x400000)
243 x
= Float32(-1.0).sqrt()
246 print (hex(x
), "nan", f32
.get_exponent(x
), f32
.e_max
,
247 f32
.get_mantissa_field(x
), i
)
248 self
.assertEqual(i
, True)
251 x
= Float32(1e36
) * Float32(1e36
) * Float32(1e36
)
254 print (hex(x
), "inf", f32
.get_exponent(x
), f32
.e_max
,
255 f32
.get_mantissa_field(x
), i
)
256 self
.assertEqual(i
, True)
261 i
= f32
.is_subnormal(x
)
262 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
263 f32
.get_mantissa_field(x
), i
)
264 self
.assertEqual(i
, True)
268 i
= f32
.is_subnormal(x
)
269 print (hex(x
), "sub", f32
.get_exponent(x
), f32
.e_max
,
270 f32
.get_mantissa_field(x
), i
)
271 self
.assertEqual(i
, False)
275 print (hex(x
), "zero", f32
.get_exponent(x
), f32
.e_max
,
276 f32
.get_mantissa_field(x
), i
)
277 self
.assertEqual(i
, True)
282 def __init__(self
, width
):
284 self
.smax
= int(log(width
) / log(2))
285 self
.i
= Signal(width
, reset_less
=True)
286 self
.s
= Signal(self
.smax
, reset_less
=True)
287 self
.o
= Signal(width
, reset_less
=True)
289 def elaborate(self
, platform
):
291 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
296 """ Generates variable-length single-cycle shifter from a series
297 of conditional tests on each bit of the left/right shift operand.
298 Each bit tested produces output shifted by that number of bits,
299 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
300 shifts by 2 bits, each partial result cascading to the next Mux.
302 Could be adapted to do arithmetic shift by taking copies of the
303 MSB instead of zeros.
306 def __init__(self
, width
):
308 self
.smax
= int(log(width
) / log(2))
310 def lshift(self
, op
, s
):
314 def rshift(self
, op
, s
):
319 class FPNumBaseRecord
:
320 """ Floating-point Base Number Class.
322 This class is designed to be passed around in other data structures
323 (between pipelines and between stages). Its "friend" is FPNumBase,
324 which is a *module*. The reason for the discernment is because
325 nmigen modules that are not added to submodules results in the
326 irritating "Elaboration" warning. Despite not *needing* FPNumBase
327 in many cases to be added as a submodule (because it is just data)
328 this was not possible to solve without splitting out the data from
332 def __init__(self
, width
, m_extra
=True, e_extra
=False, name
=None):
335 # assert false, "missing name"
339 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
340 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
341 e_max
= 1 << (e_width
-3)
342 self
.rmw
= m_width
- 1 # real mantissa width (not including extras)
345 # mantissa extra bits (top,guard,round)
347 m_width
+= self
.m_extra
351 self
.e_extra
= 6 # enough to cover FP64 when converting to FP16
352 e_width
+= self
.e_extra
355 # print (m_width, e_width, e_max, self.rmw, self.m_extra)
356 self
.m_width
= m_width
357 self
.e_width
= e_width
358 self
.e_start
= self
.rmw
359 self
.e_end
= self
.rmw
+ self
.e_width
- 2 # for decoding
361 self
.v
= Signal(width
, reset_less
=True,
362 name
=name
+"v") # Latched copy of value
363 self
.m
= Signal(m_width
, reset_less
=True, name
=name
+"m") # Mantissa
364 self
.e
= Signal((e_width
, True),
365 reset_less
=True, name
=name
+"e") # exp+2 bits, signed
366 self
.s
= Signal(reset_less
=True, name
=name
+"s") # Sign bit
371 def drop_in(self
, fp
):
377 fp
.width
= self
.width
378 fp
.e_width
= self
.e_width
379 fp
.e_max
= self
.e_max
380 fp
.m_width
= self
.m_width
381 fp
.e_start
= self
.e_start
382 fp
.e_end
= self
.e_end
383 fp
.m_extra
= self
.m_extra
385 m_width
= self
.m_width
387 e_width
= self
.e_width
389 self
.mzero
= Const(0, (m_width
, False))
390 m_msb
= 1 << (self
.m_width
-2)
391 self
.msb1
= Const(m_msb
, (m_width
, False))
392 self
.m1s
= Const(-1, (m_width
, False))
393 self
.P128
= Const(e_max
, (e_width
, True))
394 self
.P127
= Const(e_max
-1, (e_width
, True))
395 self
.N127
= Const(-(e_max
-1), (e_width
, True))
396 self
.N126
= Const(-(e_max
-2), (e_width
, True))
398 def create(self
, s
, e
, m
):
399 """ creates a value from sign / exponent / mantissa
401 bias is added here, to the exponent.
403 NOTE: order is important, because e_start/e_end can be
404 a bit too long (overwriting s).
407 self
.v
[0:self
.e_start
].eq(m
), # mantissa
408 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.fp
.P127
), # (add bias)
409 self
.v
[-1].eq(s
), # sign
413 return (s
, self
.fp
.P128
, 1 << (self
.e_start
-1))
416 return (s
, self
.fp
.P128
, 0)
419 return (s
, self
.fp
.N127
, 0)
422 return self
.create(*self
._nan
(s
))
425 return self
.create(*self
._inf
(s
))
428 return self
.create(*self
._zero
(s
))
430 def create2(self
, s
, e
, m
):
431 """ creates a value from sign / exponent / mantissa
433 bias is added here, to the exponent
435 e
= e
+ self
.P127
# exp (add on bias)
436 return Cat(m
[0:self
.e_start
],
437 e
[0:self
.e_end
-self
.e_start
],
441 return self
.create2(s
, self
.P128
, self
.msb1
)
444 return self
.create2(s
, self
.P128
, self
.mzero
)
447 return self
.create2(s
, self
.N127
, self
.mzero
)
455 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
458 class FPNumBase(FPNumBaseRecord
, Elaboratable
):
459 """ Floating-point Base Number Class
462 def __init__(self
, fp
):
467 self
.is_nan
= Signal(reset_less
=True)
468 self
.is_zero
= Signal(reset_less
=True)
469 self
.is_inf
= Signal(reset_less
=True)
470 self
.is_overflowed
= Signal(reset_less
=True)
471 self
.is_denormalised
= Signal(reset_less
=True)
472 self
.exp_128
= Signal(reset_less
=True)
473 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
474 self
.exp_lt_n126
= Signal(reset_less
=True)
475 self
.exp_zero
= Signal(reset_less
=True)
476 self
.exp_gt_n126
= Signal(reset_less
=True)
477 self
.exp_gt127
= Signal(reset_less
=True)
478 self
.exp_n127
= Signal(reset_less
=True)
479 self
.exp_n126
= Signal(reset_less
=True)
480 self
.m_zero
= Signal(reset_less
=True)
481 self
.m_msbzero
= Signal(reset_less
=True)
483 def elaborate(self
, platform
):
485 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
486 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
487 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
488 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
489 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
490 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.fp
.P128
)
491 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.fp
.N126
)
492 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
493 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
494 m
.d
.comb
+= self
.exp_zero
.eq(self
.e
== 0)
495 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.fp
.P127
)
496 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.fp
.N127
)
497 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.fp
.N126
)
498 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.fp
.mzero
)
499 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.fp
.e_start
] == 0)
504 return (self
.exp_128
) & (~self
.m_zero
)
507 return (self
.exp_128
) & (self
.m_zero
)
510 return (self
.exp_n127
) & (self
.m_zero
)
def _is_overflowed(self):
    """ Overflow indicator: the ``exp_gt127`` flag (exponent above P127). """
    overflowed = self.exp_gt127
    return overflowed
def _is_denormalised(self):
    """ Denormal indicator: AND of the ``exp_n126`` and ``m_msbzero`` flags.

    ``exp_n126`` flags an exponent equal to N126; ``m_msbzero`` flags that
    the mantissa bit at ``fp.e_start`` is clear.
    """
    # XXX NOT to be used for "official" quiet NaN tests!
    # particularly when the MSB has been extended
    at_min_exponent = self.exp_n126
    leading_bit_clear = self.m_msbzero
    return at_min_exponent & leading_bit_clear
521 class FPNumOut(FPNumBase
):
522 """ Floating-point Number Class
524 Contains signals for an incoming copy of the value, decoded into
525 sign / exponent / mantissa.
526 Also contains encoding functions, creation and recognition of
527 zero, NaN and inf (all signed)
529 Four extra bits are included in the mantissa: the top bit
530 (m[-1]) is effectively a carry-overflow. The other three are
531 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, fp):
    """ Initialise the output number; all setup is delegated to
    ``FPNumBase.__init__`` with the same ``fp`` argument.
    """
    base_init = FPNumBase.__init__
    base_init(self, fp)
537 def elaborate(self
, platform
):
538 m
= FPNumBase
.elaborate(self
, platform
)
543 class MultiShiftRMerge(Elaboratable
):
544 """ shifts down (right) and merges lower bits into m[0].
545 m[0] is the "sticky" bit, basically
548 def __init__(self
, width
, s_max
=None):
550 s_max
= int(log(width
) / log(2))
552 self
.m
= Signal(width
, reset_less
=True)
553 self
.inp
= Signal(width
, reset_less
=True)
554 self
.diff
= Signal(s_max
, reset_less
=True)
557 def elaborate(self
, platform
):
560 rs
= Signal(self
.width
, reset_less
=True)
561 m_mask
= Signal(self
.width
, reset_less
=True)
562 smask
= Signal(self
.width
, reset_less
=True)
563 stickybit
= Signal(reset_less
=True)
564 maxslen
= Signal(self
.smax
, reset_less
=True)
565 maxsleni
= Signal(self
.smax
, reset_less
=True)
567 sm
= MultiShift(self
.width
-1)
568 m0s
= Const(0, self
.width
-1)
569 mw
= Const(self
.width
-1, len(self
.diff
))
570 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
571 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
575 # shift mantissa by maxslen, mask by inverse
576 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
577 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
578 smask
.eq(self
.inp
[1:] & m_mask
),
579 # sticky bit combines all mask (and mantissa low bit)
580 stickybit
.eq(smask
.bool() | self
.inp
[0]),
581 # mantissa result contains m[0] already.
582 self
.m
.eq(Cat(stickybit
, rs
))
587 class FPNumShift(FPNumBase
, Elaboratable
):
588 """ Floating-point Number Class for shifting
591 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
592 FPNumBase
.__init
__(self
, width
, m_extra
)
593 self
.latch_in
= Signal()
598 def elaborate(self
, platform
):
599 m
= FPNumBase
.elaborate(self
, platform
)
601 m
.d
.comb
+= self
.s
.eq(op
.s
)
602 m
.d
.comb
+= self
.e
.eq(op
.e
)
603 m
.d
.comb
+= self
.m
.eq(op
.m
)
605 with self
.mainm
.State("align"):
606 with m
.If(self
.e
< self
.inv
.e
):
607 m
.d
.sync
+= self
.shift_down()
611 def shift_down(self
, inp
):
612 """ shifts a mantissa down by one. exponent is increased to compensate
614 accuracy is lost as a result in the mantissa however there are 3
615 guard bits (the latter of which is the "sticky" bit)
617 return [self
.e
.eq(inp
.e
+ 1),
618 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
621 def shift_down_multi(self
, diff
):
622 """ shifts a mantissa down. exponent is increased to compensate
624 accuracy is lost as a result in the mantissa however there are 3
625 guard bits (the latter of which is the "sticky" bit)
627 this code works by variable-shifting the mantissa by up to
628 its maximum bit-length: no point doing more (it'll still be
631 the sticky bit is computed by shifting a batch of 1s by
632 the same amount, which will introduce zeros. it's then
633 inverted and used as a mask to get the LSBs of the mantissa.
634 those are then |'d into the sticky bit.
636 sm
= MultiShift(self
.width
)
637 mw
= Const(self
.m_width
-1, len(diff
))
638 maxslen
= Mux(diff
> mw
, mw
, diff
)
639 rs
= sm
.rshift(self
.m
[1:], maxslen
)
640 maxsleni
= mw
- maxslen
641 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
643 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
644 return [self
.e
.eq(self
.e
+ diff
),
645 self
.m
.eq(Cat(stickybits
, rs
))
648 def shift_up_multi(self
, diff
):
649 """ shifts a mantissa up. exponent is decreased to compensate
651 sm
= MultiShift(self
.width
)
652 mw
= Const(self
.m_width
, len(diff
))
653 maxslen
= Mux(diff
> mw
, mw
, diff
)
655 return [self
.e
.eq(self
.e
- diff
),
656 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
660 class FPNumDecode(FPNumBase
):
661 """ Floating-point Number Class
663 Contains signals for an incoming copy of the value, decoded into
664 sign / exponent / mantissa.
665 Also contains encoding functions, creation and recognition of
666 zero, NaN and inf (all signed)
668 Four extra bits are included in the mantissa: the top bit
669 (m[-1]) is effectively a carry-overflow. The other three are
670 guard (m[2]), round (m[1]), and sticky (m[0])
673 def __init__(self
, op
, fp
):
674 FPNumBase
.__init
__(self
, fp
)
677 def elaborate(self
, platform
):
678 m
= FPNumBase
.elaborate(self
, platform
)
680 m
.d
.comb
+= self
.decode(self
.v
)
685 """ decodes a latched value into sign / exponent / mantissa
687 bias is subtracted here, from the exponent. exponent
688 is extended to 10 bits so that subtract 127 is done on
691 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
692 #print ("decode", self.e_end)
693 return [self
.m
.eq(Cat(*args
)), # mantissa
694 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
), # exp
695 self
.s
.eq(v
[-1]), # sign
699 class FPNumIn(FPNumBase
):
700 """ Floating-point Number Class
702 Contains signals for an incoming copy of the value, decoded into
703 sign / exponent / mantissa.
704 Also contains encoding functions, creation and recognition of
705 zero, NaN and inf (all signed)
707 Four extra bits are included in the mantissa: the top bit
708 (m[-1]) is effectively a carry-overflow. The other three are
709 guard (m[2]), round (m[1]), and sticky (m[0])
712 def __init__(self
, op
, fp
):
713 FPNumBase
.__init
__(self
, fp
)
714 self
.latch_in
= Signal()
717 def decode2(self
, m
):
718 """ decodes a latched value into sign / exponent / mantissa
720 bias is subtracted here, from the exponent. exponent
721 is extended to 10 bits so that subtract 127 is done on
725 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
726 #print ("decode", self.e_end)
727 res
= ObjectProxy(m
, pipemode
=False)
728 res
.m
= Cat(*args
) # mantissa
729 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.fp
.P127
# exp
734 """ decodes a latched value into sign / exponent / mantissa
736 bias is subtracted here, from the exponent. exponent
737 is extended to 10 bits so that subtract 127 is done on
740 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
741 #print ("decode", self.e_end)
742 return [self
.m
.eq(Cat(*args
)), # mantissa
743 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
744 self
.s
.eq(v
[-1]), # sign
747 def shift_down(self
, inp
):
748 """ shifts a mantissa down by one. exponent is increased to compensate
750 accuracy is lost as a result in the mantissa however there are 3
751 guard bits (the latter of which is the "sticky" bit)
753 return [self
.e
.eq(inp
.e
+ 1),
754 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
757 def shift_down_multi(self
, diff
, inp
=None):
758 """ shifts a mantissa down. exponent is increased to compensate
760 accuracy is lost as a result in the mantissa however there are 3
761 guard bits (the latter of which is the "sticky" bit)
763 this code works by variable-shifting the mantissa by up to
764 its maximum bit-length: no point doing more (it'll still be
767 the sticky bit is computed by shifting a batch of 1s by
768 the same amount, which will introduce zeros. it's then
769 inverted and used as a mask to get the LSBs of the mantissa.
770 those are then |'d into the sticky bit.
774 sm
= MultiShift(self
.width
)
775 mw
= Const(self
.m_width
-1, len(diff
))
776 maxslen
= Mux(diff
> mw
, mw
, diff
)
777 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
778 maxsleni
= mw
- maxslen
779 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
781 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
782 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
783 return [self
.e
.eq(inp
.e
+ diff
),
784 self
.m
.eq(Cat(stickybit
, rs
))
787 def shift_up_multi(self
, diff
):
788 """ shifts a mantissa up. exponent is decreased to compensate
790 sm
= MultiShift(self
.width
)
791 mw
= Const(self
.m_width
, len(diff
))
792 maxslen
= Mux(diff
> mw
, mw
, diff
)
794 return [self
.e
.eq(self
.e
- diff
),
795 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
799 class Trigger(Elaboratable
):
802 self
.stb
= Signal(reset
=0)
804 self
.trigger
= Signal(reset_less
=True)
806 def elaborate(self
, platform
):
808 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
812 return [self
.stb
.eq(inp
.stb
),
817 return [self
.stb
, self
.ack
]
820 class FPOpIn(PrevControl
):
821 def __init__(self
, width
):
822 PrevControl
.__init
__(self
)
829 def chain_inv(self
, in_op
, extra
=None):
831 if extra
is not None:
833 return [self
.v
.eq(in_op
.v
), # receive value
834 self
.stb
.eq(stb
), # receive STB
835 in_op
.ack
.eq(~self
.ack
), # send ACK
838 def chain_from(self
, in_op
, extra
=None):
840 if extra
is not None:
842 return [self
.v
.eq(in_op
.v
), # receive value
843 self
.stb
.eq(stb
), # receive STB
844 in_op
.ack
.eq(self
.ack
), # send ACK
848 class FPOpOut(NextControl
):
849 def __init__(self
, width
):
850 NextControl
.__init
__(self
)
857 def chain_inv(self
, in_op
, extra
=None):
859 if extra
is not None:
861 return [self
.v
.eq(in_op
.v
), # receive value
862 self
.stb
.eq(stb
), # receive STB
863 in_op
.ack
.eq(~self
.ack
), # send ACK
866 def chain_from(self
, in_op
, extra
=None):
868 if extra
is not None:
870 return [self
.v
.eq(in_op
.v
), # receive value
871 self
.stb
.eq(stb
), # receive STB
872 in_op
.ack
.eq(self
.ack
), # send ACK
877 def __init__(self
, name
=None):
880 self
.guard
= Signal(reset_less
=True, name
=name
+"guard") # tot[2]
881 self
.round_bit
= Signal(reset_less
=True, name
=name
+"round") # tot[1]
882 self
.sticky
= Signal(reset_less
=True, name
=name
+"sticky") # tot[0]
883 self
.m0
= Signal(reset_less
=True, name
=name
+"m0") # mantissa bit 0
885 #self.roundz = Signal(reset_less=True)
894 return [self
.guard
.eq(inp
.guard
),
895 self
.round_bit
.eq(inp
.round_bit
),
896 self
.sticky
.eq(inp
.sticky
),
901 return self
.guard
& (self
.round_bit | self
.sticky | self
.m0
)
904 class OverflowMod(Elaboratable
, Overflow
):
905 def __init__(self
, name
=None):
906 Overflow
.__init
__(self
, name
)
909 self
.roundz_out
= Signal(reset_less
=True, name
=name
+"roundz_out")
912 yield from Overflow
.__iter
__(self
)
913 yield self
.roundz_out
916 return [self
.roundz_out
.eq(inp
.roundz_out
)] + Overflow
.eq(self
)
918 def elaborate(self
, platform
):
920 m
.d
.comb
+= self
.roundz_out
.eq(self
.roundz
)
925 """ IEEE754 Floating Point Base Class
927 contains common functions for FP manipulation, such as
928 extracting and packing operands, normalisation, denormalisation,
932 def get_op(self
, m
, op
, v
, next_state
):
933 """ this function moves to the next state and copies the operand
934 when both stb and ack are 1.
935 acknowledgement is sent by setting ack to ZERO.
939 with m
.If((op
.ready_o
) & (op
.valid_i_test
)):
941 # op is latched in from FPNumIn class on same ack/stb
942 m
.d
.comb
+= ack
.eq(0)
944 m
.d
.comb
+= ack
.eq(1)
947 def denormalise(self
, m
, a
):
948 """ denormalises a number. this is probably the wrong name for
949 this function. for normalised numbers (exponent != minimum)
950 one *extra* bit (the implicit 1) is added *back in*.
951 for denormalised numbers, the mantissa is left alone
952 and the exponent increased by 1.
954 both cases *effectively multiply the number stored by 2*,
955 which has to be taken into account when extracting the result.
957 with m
.If(a
.exp_n127
):
958 m
.d
.sync
+= a
.e
.eq(a
.fp
.N126
) # limit a exponent
960 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
962 def op_normalise(self
, m
, op
, next_state
):
963 """ operand normalisation
964 NOTE: just like "align", this one keeps going round every clock
965 until the result's exponent is within acceptable "range"
967 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
969 op
.e
.eq(op
.e
- 1), # DECREASE exponent
970 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
975 def normalise_1(self
, m
, z
, of
, next_state
):
976 """ first stage normalisation
978 NOTE: just like "align", this one keeps going round every clock
979 until the result's exponent is within acceptable "range"
980 NOTE: the weirdness of reassigning guard and round is due to
981 the extra mantissa bits coming from tot[0..2]
983 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.fp
.N126
)):
985 z
.e
.eq(z
.e
- 1), # DECREASE exponent
986 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
987 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
988 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
989 of
.round_bit
.eq(0), # reset round bit
995 def normalise_2(self
, m
, z
, of
, next_state
):
996 """ second stage normalisation
998 NOTE: just like "align", this one keeps going round every clock
999 until the result's exponent is within acceptable "range"
1000 NOTE: the weirdness of reassigning guard and round is due to
1001 the extra mantissa bits coming from tot[0..2]
1003 with m
.If(z
.e
< z
.fp
.N126
):
1005 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
1006 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
1007 of
.guard
.eq(z
.m
[0]),
1009 of
.round_bit
.eq(of
.guard
),
1010 of
.sticky
.eq(of
.sticky | of
.round_bit
)
1015 def roundz(self
, m
, z
, roundz
):
1016 """ performs rounding on the output. TODO: different kinds of rounding
1019 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
1020 with m
.If(z
.m
== z
.fp
.m1s
): # all 1s
1021 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
1023 def corrections(self
, m
, z
, next_state
):
1024 """ denormalisation and sign-bug corrections
1027 # denormalised, correct exponent to zero
1028 with m
.If(z
.is_denormalised
):
1029 m
.d
.sync
+= z
.e
.eq(z
.fp
.N127
)
1031 def pack(self
, m
, z
, next_state
):
1032 """ packs the result into the output (detects overflow->Inf)
1035 # if overflow occurs, return inf
1036 with m
.If(z
.is_overflowed
):
1037 m
.d
.sync
+= z
.inf(z
.s
)
1039 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
1041 def put_z(self
, m
, z
, out_z
, next_state
):
1042 """ put_z: stores the result in the output. raises stb and waits
1043 for ack to be set to 1 before moving to the next state.
1044 resets stb back to zero when that occurs, as acknowledgement.
1049 with m
.If(out_z
.valid_o
& out_z
.ready_i_test
):
1050 m
.d
.sync
+= out_z
.valid_o
.eq(0)
1053 m
.d
.sync
+= out_z
.valid_o
.eq(1)
1056 class FPState(FPBase
):
def __init__(self, state_from):
    """ Store ``state_from`` on the instance as ``self.state_from``. """
    origin = state_from
    self.state_from = origin
1060 def set_inputs(self
, inputs
):
1061 self
.inputs
= inputs
1062 for k
, v
in inputs
.items():
1065 def set_outputs(self
, outputs
):
1066 self
.outputs
= outputs
1067 for k
, v
in outputs
.items():
1072 def __init__(self
, id_wid
):
1073 self
.id_wid
= id_wid
1075 self
.in_mid
= Signal(id_wid
, reset_less
=True)
1076 self
.out_mid
= Signal(id_wid
, reset_less
=True)
def idsync(self, m):
    """ Add ``out_mid.eq(in_mid)`` to the module's sync domain.

    Nothing is added when ``self.id_wid`` is None.
    """
    if self.id_wid is None:
        return
    m.d.sync += self.out_mid.eq(self.in_mid)
1086 if __name__
== '__main__':