6a1357056f44d40387b3b9c34cf5b15d173f1272
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
7 from operator
import or_
8 from functools
import reduce
12 def __init__(self
, width
):
14 self
.smax
= int(log(width
) / log(2))
15 self
.i
= Signal(width
, reset_less
=True)
16 self
.s
= Signal(self
.smax
, reset_less
=True)
17 self
.o
= Signal(width
, reset_less
=True)
19 def elaborate(self
, platform
):
21 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
26 """ Generates variable-length single-cycle shifter from a series
27 of conditional tests on each bit of the left/right shift operand.
28 Each bit tested produces output shifted by that number of bits,
29 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
30 shifts by 2 bits, each partial result cascading to the next Mux.
32 Could be adapted to do arithmetic shift by taking copies of the
36 def __init__(self
, width
):
38 self
.smax
= int(log(width
) / log(2))
40 def lshift(self
, op
, s
):
44 for i
in range(self
.smax
):
46 res
= Mux(s
& (1<<i
), Cat(zeros
, res
[0:-(1<<i
)]), res
)
49 def rshift(self
, op
, s
):
53 for i
in range(self
.smax
):
55 res
= Mux(s
& (1<<i
), Cat(res
[(1<<i
):], zeros
), res
)
60 """ Floating-point Base Number Class
62 def __init__(self
, width
, m_extra
=True):
64 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
65 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
66 e_max
= 1<<(e_width
-3)
67 self
.rmw
= m_width
# real mantissa width (not including extras)
70 # mantissa extra bits (top,guard,round)
72 m_width
+= self
.m_extra
75 #print (m_width, e_width, e_max, self.rmw, self.m_extra)
76 self
.m_width
= m_width
77 self
.e_width
= e_width
78 self
.e_start
= self
.rmw
- 1
79 self
.e_end
= self
.rmw
+ self
.e_width
- 3 # for decoding
81 self
.v
= Signal(width
, reset_less
=True) # Latched copy of value
82 self
.m
= Signal(m_width
, reset_less
=True) # Mantissa
83 self
.e
= Signal((e_width
, True), reset_less
=True) # Exponent: IEEE754exp+2 bits, signed
84 self
.s
= Signal(reset_less
=True) # Sign bit
86 self
.mzero
= Const(0, (m_width
, False))
87 self
.m1s
= Const(-1, (m_width
, False))
88 self
.P128
= Const(e_max
, (e_width
, True))
89 self
.P127
= Const(e_max
-1, (e_width
, True))
90 self
.N127
= Const(-(e_max
-1), (e_width
, True))
91 self
.N126
= Const(-(e_max
-2), (e_width
, True))
93 self
.is_nan
= Signal(reset_less
=True)
94 self
.is_zero
= Signal(reset_less
=True)
95 self
.is_inf
= Signal(reset_less
=True)
96 self
.is_overflowed
= Signal(reset_less
=True)
97 self
.is_denormalised
= Signal(reset_less
=True)
98 self
.exp_128
= Signal(reset_less
=True)
99 self
.exp_gt127
= Signal(reset_less
=True)
100 self
.exp_n127
= Signal(reset_less
=True)
101 self
.exp_n126
= Signal(reset_less
=True)
102 self
.m_zero
= Signal(reset_less
=True)
103 self
.m_msbzero
= Signal(reset_less
=True)
105 def elaborate(self
, platform
):
107 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
108 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
109 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
110 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
111 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
112 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.P128
)
113 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.P127
)
114 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.N127
)
115 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.N126
)
116 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.mzero
)
117 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.e_start
] == 0)
122 return (self
.exp_128
) & (~self
.m_zero
)
125 return (self
.exp_128
) & (self
.m_zero
)
128 return (self
.exp_n127
) & (self
.m_zero
)
130 def _is_overflowed(self
):
131 return self
.exp_gt127
133 def _is_denormalised(self
):
134 return (self
.exp_n126
) & (self
.m_msbzero
)
137 class FPNumOut(FPNumBase
):
138 """ Floating-point Number Class
140 Contains signals for an incoming copy of the value, decoded into
141 sign / exponent / mantissa.
142 Also contains encoding functions, creation and recognition of
143 zero, NaN and inf (all signed)
145 Four extra bits are included in the mantissa: the top bit
146 (m[-1]) is effectively a carry-overflow. The other three are
147 guard (m[2]), round (m[1]), and sticky (m[0])
149 def __init__(self
, width
, m_extra
=True):
150 FPNumBase
.__init
__(self
, width
, m_extra
)
152 def elaborate(self
, platform
):
153 m
= FPNumBase
.elaborate(self
, platform
)
157 def create(self
, s
, e
, m
):
158 """ creates a value from sign / exponent / mantissa
160 bias is added here, to the exponent
163 self
.v
[-1].eq(s
), # sign
164 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.P127
), # exp (add on bias)
165 self
.v
[0:self
.e_start
].eq(m
) # mantissa
169 return self
.create(s
, self
.P128
, 1<<(self
.e_start
-1))
172 return self
.create(s
, self
.P128
, 0)
175 return self
.create(s
, self
.N127
, 0)
178 class FPNumShift(FPNumBase
):
179 """ Floating-point Number Class for shifting
181 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
182 FPNumBase
.__init
__(self
, width
, m_extra
)
183 self
.latch_in
= Signal()
188 def elaborate(self
, platform
):
189 m
= FPNumBase
.elaborate(self
, platform
)
191 m
.d
.comb
+= self
.s
.eq(op
.s
)
192 m
.d
.comb
+= self
.e
.eq(op
.e
)
193 m
.d
.comb
+= self
.m
.eq(op
.m
)
195 with self
.mainm
.State("align"):
196 with m
.If(self
.e
< self
.inv
.e
):
197 m
.d
.sync
+= self
.shift_down()
201 def shift_down(self
):
202 """ shifts a mantissa down by one. exponent is increased to compensate
204 accuracy is lost as a result in the mantissa however there are 3
205 guard bits (the latter of which is the "sticky" bit)
207 return [self
.e
.eq(self
.e
+ 1),
208 self
.m
.eq(Cat(self
.m
[0] | self
.m
[1], self
.m
[2:], 0))
211 def shift_down_multi(self
, diff
):
212 """ shifts a mantissa down. exponent is increased to compensate
214 accuracy is lost as a result in the mantissa however there are 3
215 guard bits (the latter of which is the "sticky" bit)
217 this code works by variable-shifting the mantissa by up to
218 its maximum bit-length: no point doing more (it'll still be
221 the sticky bit is computed by shifting a batch of 1s by
222 the same amount, which will introduce zeros. it's then
223 inverted and used as a mask to get the LSBs of the mantissa.
224 those are then |'d into the sticky bit.
226 sm
= MultiShift(self
.width
)
227 mw
= Const(self
.m_width
-1, len(diff
))
228 maxslen
= Mux(diff
> mw
, mw
, diff
)
229 rs
= sm
.rshift(self
.m
[1:], maxslen
)
230 maxsleni
= mw
- maxslen
231 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
233 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
234 return [self
.e
.eq(self
.e
+ diff
),
235 self
.m
.eq(Cat(stickybits
, rs
))
238 def shift_up_multi(self
, diff
):
239 """ shifts a mantissa up. exponent is decreased to compensate
241 sm
= MultiShift(self
.width
)
242 mw
= Const(self
.m_width
, len(diff
))
243 maxslen
= Mux(diff
> mw
, mw
, diff
)
245 return [self
.e
.eq(self
.e
- diff
),
246 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
249 class FPNumIn(FPNumBase
):
250 """ Floating-point Number Class
252 Contains signals for an incoming copy of the value, decoded into
253 sign / exponent / mantissa.
254 Also contains encoding functions, creation and recognition of
255 zero, NaN and inf (all signed)
257 Four extra bits are included in the mantissa: the top bit
258 (m[-1]) is effectively a carry-overflow. The other three are
259 guard (m[2]), round (m[1]), and sticky (m[0])
261 def __init__(self
, op
, width
, m_extra
=True):
262 FPNumBase
.__init
__(self
, width
, m_extra
)
263 self
.latch_in
= Signal()
266 def elaborate(self
, platform
):
267 m
= FPNumBase
.elaborate(self
, platform
)
269 #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
270 #with m.If(self.latch_in):
271 # m.d.sync += self.decode(self.v)
276 """ decodes a latched value into sign / exponent / mantissa
278 bias is subtracted here, from the exponent. exponent
279 is extended to 10 bits so that subtract 127 is done on
282 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
283 #print ("decode", self.e_end)
284 return [self
.m
.eq(Cat(*args
)), # mantissa
285 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
286 self
.s
.eq(v
[-1]), # sign
289 def shift_down(self
):
290 """ shifts a mantissa down by one. exponent is increased to compensate
292 accuracy is lost as a result in the mantissa however there are 3
293 guard bits (the latter of which is the "sticky" bit)
295 return [self
.e
.eq(self
.e
+ 1),
296 self
.m
.eq(Cat(self
.m
[0] | self
.m
[1], self
.m
[2:], 0))
299 def shift_down_multi(self
, diff
):
300 """ shifts a mantissa down. exponent is increased to compensate
302 accuracy is lost as a result in the mantissa however there are 3
303 guard bits (the latter of which is the "sticky" bit)
305 this code works by variable-shifting the mantissa by up to
306 its maximum bit-length: no point doing more (it'll still be
309 the sticky bit is computed by shifting a batch of 1s by
310 the same amount, which will introduce zeros. it's then
311 inverted and used as a mask to get the LSBs of the mantissa.
312 those are then |'d into the sticky bit.
314 sm
= MultiShift(self
.width
)
315 mw
= Const(self
.m_width
-1, len(diff
))
316 maxslen
= Mux(diff
> mw
, mw
, diff
)
317 rs
= sm
.rshift(self
.m
[1:], maxslen
)
318 maxsleni
= mw
- maxslen
319 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
321 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
322 return [self
.e
.eq(self
.e
+ diff
),
323 self
.m
.eq(Cat(stickybits
, rs
))
326 def shift_up_multi(self
, diff
):
327 """ shifts a mantissa up. exponent is decreased to compensate
329 sm
= MultiShift(self
.width
)
330 mw
= Const(self
.m_width
, len(diff
))
331 maxslen
= Mux(diff
> mw
, mw
, diff
)
333 return [self
.e
.eq(self
.e
- diff
),
334 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
338 def __init__(self
, width
):
341 self
.v
= Signal(width
)
345 def chain_from(self
, in_op
):
346 return [self
.v
.eq(in_op
.v
), # receive value
347 self
.stb
.eq(in_op
.stb
), # receive STB
348 in_op
.ack
.eq(self
.ack
), # send ACK
352 return [self
.v
, self
.stb
, self
.ack
]
357 self
.guard
= Signal(reset_less
=True) # tot[2]
358 self
.round_bit
= Signal(reset_less
=True) # tot[1]
359 self
.sticky
= Signal(reset_less
=True) # tot[0]
360 self
.m0
= Signal(reset_less
=True) # mantissa zero bit
362 self
.roundz
= Signal(reset_less
=True)
364 def elaborate(self
, platform
):
366 m
.d
.comb
+= self
.roundz
.eq(self
.guard
& \
367 (self
.round_bit | self
.sticky | self
.m0
))
372 """ IEEE754 Floating Point Base Class
374 contains common functions for FP manipulation, such as
375 extracting and packing operands, normalisation, denormalisation,
379 def get_op(self
, m
, op
, v
, next_state
):
380 """ this function moves to the next state and copies the operand
381 when both stb and ack are 1.
382 acknowledgement is sent by setting ack to ZERO.
384 with m
.If((op
.ack
) & (op
.stb
)):
387 # op is latched in from FPNumIn class on same ack/stb
392 m
.d
.sync
+= op
.ack
.eq(1)
394 def denormalise(self
, m
, a
):
395 """ denormalises a number. this is probably the wrong name for
396 this function. for normalised numbers (exponent != minimum)
397 one *extra* bit (the implicit 1) is added *back in*.
398 for denormalised numbers, the mantissa is left alone
399 and the exponent increased by 1.
401 both cases *effectively multiply the number stored by 2*,
402 which has to be taken into account when extracting the result.
404 with m
.If(a
.e
== a
.N127
):
405 m
.d
.sync
+= a
.e
.eq(a
.N126
) # limit a exponent
407 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
409 def op_normalise(self
, m
, op
, next_state
):
410 """ operand normalisation
411 NOTE: just like "align", this one keeps going round every clock
412 until the result's exponent is within acceptable "range"
414 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
416 op
.e
.eq(op
.e
- 1), # DECREASE exponent
417 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
422 def normalise_1(self
, m
, z
, of
, next_state
):
423 """ first stage normalisation
425 NOTE: just like "align", this one keeps going round every clock
426 until the result's exponent is within acceptable "range"
427 NOTE: the weirdness of reassigning guard and round is due to
428 the extra mantissa bits coming from tot[0..2]
430 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.N126
)):
432 z
.e
.eq(z
.e
- 1), # DECREASE exponent
433 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
434 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
435 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
436 of
.round_bit
.eq(0), # reset round bit
442 def normalise_2(self
, m
, z
, of
, next_state
):
443 """ second stage normalisation
445 NOTE: just like "align", this one keeps going round every clock
446 until the result's exponent is within acceptable "range"
447 NOTE: the weirdness of reassigning guard and round is due to
448 the extra mantissa bits coming from tot[0..2]
450 with m
.If(z
.e
< z
.N126
):
452 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
453 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
456 of
.round_bit
.eq(of
.guard
),
457 of
.sticky
.eq(of
.sticky | of
.round_bit
)
462 def roundz(self
, m
, z
, of
, next_state
):
463 """ performs rounding on the output. TODO: different kinds of rounding
466 with m
.If(of
.roundz
):
467 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
468 with m
.If(z
.m
== z
.m1s
): # all 1s
469 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
471 def corrections(self
, m
, z
, next_state
):
472 """ denormalisation and sign-bug corrections
475 # denormalised, correct exponent to zero
476 with m
.If(z
.is_denormalised
):
477 m
.d
.sync
+= z
.e
.eq(z
.N127
)
479 def pack(self
, m
, z
, next_state
):
480 """ packs the result into the output (detects overflow->Inf)
483 # if overflow occurs, return inf
484 with m
.If(z
.is_overflowed
):
485 m
.d
.sync
+= z
.inf(z
.s
)
487 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
489 def put_z(self
, m
, z
, out_z
, next_state
):
490 """ put_z: stores the result in the output. raises stb and waits
491 for ack to be set to 1 before moving to the next state.
492 resets stb back to zero when that occurs, as acknowledgement.
497 with m
.If(out_z
.stb
& out_z
.ack
):
498 m
.d
.sync
+= out_z
.stb
.eq(0)
501 m
.d
.sync
+= out_z
.stb
.eq(1)