242f41d814a2b11d0bfb177c13dd95ee02bf88d2
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
7 from operator
import or_
8 from functools
import reduce
12 def __init__(self
, width
):
14 self
.smax
= int(log(width
) / log(2))
15 self
.i
= Signal(width
, reset_less
=True)
16 self
.s
= Signal(self
.smax
, reset_less
=True)
17 self
.o
= Signal(width
, reset_less
=True)
19 def elaborate(self
, platform
):
21 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
26 """ Generates variable-length single-cycle shifter from a series
27 of conditional tests on each bit of the left/right shift operand.
28 Each bit tested produces output shifted by that number of bits,
29 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
30 shifts by 2 bits, each partial result cascading to the next Mux.
32 Could be adapted to do arithmetic shift by taking copies of the
36 def __init__(self
, width
):
38 self
.smax
= int(log(width
) / log(2))
40 def lshift(self
, op
, s
):
44 for i
in range(self
.smax
):
46 res
= Mux(s
& (1<<i
), Cat(zeros
, res
[0:-(1<<i
)]), res
)
49 def rshift(self
, op
, s
):
53 for i
in range(self
.smax
):
55 res
= Mux(s
& (1<<i
), Cat(res
[(1<<i
):], zeros
), res
)
60 """ Floating-point Base Number Class
62 def __init__(self
, width
, m_extra
=True):
64 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
65 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
66 e_max
= 1<<(e_width
-3)
67 self
.rmw
= m_width
# real mantissa width (not including extras)
70 # mantissa extra bits (top,guard,round)
72 m_width
+= self
.m_extra
75 #print (m_width, e_width, e_max, self.rmw, self.m_extra)
76 self
.m_width
= m_width
77 self
.e_width
= e_width
78 self
.e_start
= self
.rmw
- 1
79 self
.e_end
= self
.rmw
+ self
.e_width
- 3 # for decoding
81 self
.v
= Signal(width
, reset_less
=True) # Latched copy of value
82 self
.m
= Signal(m_width
, reset_less
=True) # Mantissa
83 self
.e
= Signal((e_width
, True), reset_less
=True) # Exponent: IEEE754exp+2 bits, signed
84 self
.s
= Signal(reset_less
=True) # Sign bit
86 self
.mzero
= Const(0, (m_width
, False))
87 self
.m1s
= Const(-1, (m_width
, False))
88 self
.P128
= Const(e_max
, (e_width
, True))
89 self
.P127
= Const(e_max
-1, (e_width
, True))
90 self
.N127
= Const(-(e_max
-1), (e_width
, True))
91 self
.N126
= Const(-(e_max
-2), (e_width
, True))
93 self
.is_nan
= Signal(reset_less
=True)
94 self
.is_zero
= Signal(reset_less
=True)
95 self
.is_inf
= Signal(reset_less
=True)
96 self
.is_overflowed
= Signal(reset_less
=True)
97 self
.is_denormalised
= Signal(reset_less
=True)
98 self
.exp_128
= Signal(reset_less
=True)
99 self
.exp_gt127
= Signal(reset_less
=True)
100 self
.exp_n127
= Signal(reset_less
=True)
101 self
.exp_n126
= Signal(reset_less
=True)
102 self
.m_zero
= Signal(reset_less
=True)
103 self
.m_msbzero
= Signal(reset_less
=True)
105 def elaborate(self
, platform
):
107 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
108 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
109 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
110 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
111 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
112 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.P128
)
113 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.P127
)
114 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.N127
)
115 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.N126
)
116 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.mzero
)
117 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.e_start
] == 0)
122 return (self
.exp_128
) & (~self
.m_zero
)
125 return (self
.exp_128
) & (self
.m_zero
)
128 return (self
.exp_n127
) & (self
.m_zero
)
def _is_overflowed(self):
    """Return the overflow condition for this number.

    The result is simply the ``exp_gt127`` flag, which ``elaborate``
    drives with the comparison ``self.e > self.P127``.
    """
    exponent_above_max = self.exp_gt127
    return exponent_above_max
def _is_denormalised(self):
    """Return the denormalised condition for this number.

    True when both of these flags are set:

    * ``exp_n126`` -- driven in ``elaborate`` by ``self.e == self.N126``
    * ``m_msbzero`` -- driven in ``elaborate`` by
      ``self.m[self.e_start] == 0``
    """
    exponent_is_n126 = self.exp_n126
    top_bit_is_clear = self.m_msbzero
    return exponent_is_n126 & top_bit_is_clear
137 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
140 class FPNumOut(FPNumBase
):
141 """ Floating-point Number Class
143 Contains signals for an incoming copy of the value, decoded into
144 sign / exponent / mantissa.
145 Also contains encoding functions, creation and recognition of
146 zero, NaN and inf (all signed)
148 Four extra bits are included in the mantissa: the top bit
149 (m[-1]) is effectively a carry-overflow. The other three are
150 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, width, m_extra=True):
    """Initialise the output number by delegating to ``FPNumBase``.

    Both arguments are forwarded unchanged; this constructor adds no
    state of its own.

    width   -- passed straight through to ``FPNumBase.__init__``
    m_extra -- passed straight through to ``FPNumBase.__init__``
    """
    FPNumBase.__init__(self, width=width, m_extra=m_extra)
155 def elaborate(self
, platform
):
156 m
= FPNumBase
.elaborate(self
, platform
)
160 def create(self
, s
, e
, m
):
161 """ creates a value from sign / exponent / mantissa
163 bias is added here, to the exponent
166 self
.v
[-1].eq(s
), # sign
167 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.P127
), # exp (add on bias)
168 self
.v
[0:self
.e_start
].eq(m
) # mantissa
172 return self
.create(s
, self
.P128
, 1<<(self
.e_start
-1))
175 return self
.create(s
, self
.P128
, 0)
178 return self
.create(s
, self
.N127
, 0)
181 class FPNumShift(FPNumBase
):
182 """ Floating-point Number Class for shifting
184 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
185 FPNumBase
.__init
__(self
, width
, m_extra
)
186 self
.latch_in
= Signal()
191 def elaborate(self
, platform
):
192 m
= FPNumBase
.elaborate(self
, platform
)
194 m
.d
.comb
+= self
.s
.eq(op
.s
)
195 m
.d
.comb
+= self
.e
.eq(op
.e
)
196 m
.d
.comb
+= self
.m
.eq(op
.m
)
198 with self
.mainm
.State("align"):
199 with m
.If(self
.e
< self
.inv
.e
):
200 m
.d
.sync
+= self
.shift_down()
def shift_down(self):
    """Build the assignments that shift the mantissa down by one bit.

    The exponent is incremented by one to compensate. Precision is lost
    from the mantissa, but bit 1 is OR-ed into bit 0 so that the lowest
    bit (the "sticky" bit) remembers that something was shifted out.

    Returns a list of two assignments: exponent first, then mantissa.
    """
    bumped_exponent = self.e + 1
    # New mantissa, LSB first: sticky (old m[0] | m[1]), then the old
    # bits from position 2 upwards, then a constant 0.
    sticky = self.m[0] | self.m[1]
    shifted_mantissa = Cat(sticky, self.m[2:], 0)
    return [
        self.e.eq(bumped_exponent),
        self.m.eq(shifted_mantissa),
    ]
214 def shift_down_multi(self
, diff
):
215 """ shifts a mantissa down. exponent is increased to compensate
217 accuracy is lost as a result in the mantissa however there are 3
218 guard bits (the latter of which is the "sticky" bit)
220 this code works by variable-shifting the mantissa by up to
221 its maximum bit-length: no point doing more (it'll still be
224 the sticky bit is computed by shifting a batch of 1s by
225 the same amount, which will introduce zeros. it's then
226 inverted and used as a mask to get the LSBs of the mantissa.
227 those are then |'d into the sticky bit.
229 sm
= MultiShift(self
.width
)
230 mw
= Const(self
.m_width
-1, len(diff
))
231 maxslen
= Mux(diff
> mw
, mw
, diff
)
232 rs
= sm
.rshift(self
.m
[1:], maxslen
)
233 maxsleni
= mw
- maxslen
234 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
236 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
237 return [self
.e
.eq(self
.e
+ diff
),
238 self
.m
.eq(Cat(stickybits
, rs
))
241 def shift_up_multi(self
, diff
):
242 """ shifts a mantissa up. exponent is decreased to compensate
244 sm
= MultiShift(self
.width
)
245 mw
= Const(self
.m_width
, len(diff
))
246 maxslen
= Mux(diff
> mw
, mw
, diff
)
248 return [self
.e
.eq(self
.e
- diff
),
249 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
252 class FPNumIn(FPNumBase
):
253 """ Floating-point Number Class
255 Contains signals for an incoming copy of the value, decoded into
256 sign / exponent / mantissa.
257 Also contains encoding functions, creation and recognition of
258 zero, NaN and inf (all signed)
260 Four extra bits are included in the mantissa: the top bit
261 (m[-1]) is effectively a carry-overflow. The other three are
262 guard (m[2]), round (m[1]), and sticky (m[0])
264 def __init__(self
, op
, width
, m_extra
=True):
265 FPNumBase
.__init
__(self
, width
, m_extra
)
266 self
.latch_in
= Signal()
269 def elaborate(self
, platform
):
270 m
= FPNumBase
.elaborate(self
, platform
)
272 #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
273 #with m.If(self.latch_in):
274 # m.d.sync += self.decode(self.v)
279 """ decodes a latched value into sign / exponent / mantissa
281 bias is subtracted here, from the exponent. exponent
282 is extended to 10 bits so that subtract 127 is done on
285 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
286 #print ("decode", self.e_end)
287 return [self
.m
.eq(Cat(*args
)), # mantissa
288 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
289 self
.s
.eq(v
[-1]), # sign
def shift_down(self):
    """Return the statements for a single-bit downward mantissa shift.

    The exponent goes up by one so the represented value is preserved
    apart from the lost precision. The old bit 1 is folded into bit 0
    with an OR, keeping bit 0 as the "sticky" bit.

    The returned list holds the exponent assignment followed by the
    mantissa assignment.
    """
    exponent_update = self.e.eq(self.e + 1)
    # Concatenation is LSB first: merged sticky bit, the remaining
    # mantissa bits from index 2 up, and a constant 0.
    merged_sticky = self.m[0] | self.m[1]
    mantissa_update = self.m.eq(Cat(merged_sticky, self.m[2:], 0))
    return [exponent_update, mantissa_update]
302 def shift_down_multi(self
, diff
):
303 """ shifts a mantissa down. exponent is increased to compensate
305 accuracy is lost as a result in the mantissa however there are 3
306 guard bits (the latter of which is the "sticky" bit)
308 this code works by variable-shifting the mantissa by up to
309 its maximum bit-length: no point doing more (it'll still be
312 the sticky bit is computed by shifting a batch of 1s by
313 the same amount, which will introduce zeros. it's then
314 inverted and used as a mask to get the LSBs of the mantissa.
315 those are then |'d into the sticky bit.
317 sm
= MultiShift(self
.width
)
318 mw
= Const(self
.m_width
-1, len(diff
))
319 maxslen
= Mux(diff
> mw
, mw
, diff
)
320 rs
= sm
.rshift(self
.m
[1:], maxslen
)
321 maxsleni
= mw
- maxslen
322 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
324 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
325 return [self
.e
.eq(self
.e
+ diff
),
326 self
.m
.eq(Cat(stickybits
, rs
))
329 def shift_up_multi(self
, diff
):
330 """ shifts a mantissa up. exponent is decreased to compensate
332 sm
= MultiShift(self
.width
)
333 mw
= Const(self
.m_width
, len(diff
))
334 maxslen
= Mux(diff
> mw
, mw
, diff
)
336 return [self
.e
.eq(self
.e
- diff
),
337 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
341 def __init__(self
, width
):
344 self
.v
= Signal(width
)
345 self
.stb
= Signal(reset
=0)
348 def chain_inv(self
, in_op
, extra
=None):
350 if extra
is not None:
352 return [self
.v
.eq(in_op
.v
), # receive value
353 self
.stb
.eq(stb
), # receive STB
354 in_op
.ack
.eq(~self
.ack
), # send ACK
357 def chain_from(self
, in_op
, extra
=None):
359 if extra
is not None:
361 return [self
.v
.eq(in_op
.v
), # receive value
362 self
.stb
.eq(stb
), # receive STB
363 in_op
.ack
.eq(self
.ack
), # send ACK
367 return [self
.v
, self
.stb
, self
.ack
]
372 self
.guard
= Signal(reset_less
=True) # tot[2]
373 self
.round_bit
= Signal(reset_less
=True) # tot[1]
374 self
.sticky
= Signal(reset_less
=True) # tot[0]
375 self
.m0
= Signal(reset_less
=True) # mantissa zero bit
377 self
.roundz
= Signal(reset_less
=True)
380 return [self
.guard
.eq(inp
.guard
),
381 self
.round_bit
.eq(inp
.round_bit
),
382 self
.sticky
.eq(inp
.sticky
),
385 def elaborate(self
, platform
):
387 m
.d
.comb
+= self
.roundz
.eq(self
.guard
& \
388 (self
.round_bit | self
.sticky | self
.m0
))
393 """ IEEE754 Floating Point Base Class
395 contains common functions for FP manipulation, such as
396 extracting and packing operands, normalisation, denormalisation,
400 def get_op(self
, m
, op
, v
, next_state
):
401 """ this function moves to the next state and copies the operand
402 when both stb and ack are 1.
403 acknowledgement is sent by setting ack to ZERO.
405 with m
.If((op
.ack
) & (op
.stb
)):
408 # op is latched in from FPNumIn class on same ack/stb
413 m
.d
.sync
+= op
.ack
.eq(1)
415 def denormalise(self
, m
, a
):
416 """ denormalises a number. this is probably the wrong name for
417 this function. for normalised numbers (exponent != minimum)
418 one *extra* bit (the implicit 1) is added *back in*.
419 for denormalised numbers, the mantissa is left alone
420 and the exponent increased by 1.
422 both cases *effectively multiply the number stored by 2*,
423 which has to be taken into account when extracting the result.
425 with m
.If(a
.e
== a
.N127
):
426 m
.d
.sync
+= a
.e
.eq(a
.N126
) # limit a exponent
428 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
430 def op_normalise(self
, m
, op
, next_state
):
431 """ operand normalisation
432 NOTE: just like "align", this one keeps going round every clock
433 until the result's exponent is within acceptable "range"
435 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
437 op
.e
.eq(op
.e
- 1), # DECREASE exponent
438 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
443 def normalise_1(self
, m
, z
, of
, next_state
):
444 """ first stage normalisation
446 NOTE: just like "align", this one keeps going round every clock
447 until the result's exponent is within acceptable "range"
448 NOTE: the weirdness of reassigning guard and round is due to
449 the extra mantissa bits coming from tot[0..2]
451 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.N126
)):
453 z
.e
.eq(z
.e
- 1), # DECREASE exponent
454 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
455 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
456 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
457 of
.round_bit
.eq(0), # reset round bit
463 def normalise_2(self
, m
, z
, of
, next_state
):
464 """ second stage normalisation
466 NOTE: just like "align", this one keeps going round every clock
467 until the result's exponent is within acceptable "range"
468 NOTE: the weirdness of reassigning guard and round is due to
469 the extra mantissa bits coming from tot[0..2]
471 with m
.If(z
.e
< z
.N126
):
473 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
474 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
477 of
.round_bit
.eq(of
.guard
),
478 of
.sticky
.eq(of
.sticky | of
.round_bit
)
483 def roundz(self
, m
, z
, out_z
, roundz
):
484 """ performs rounding on the output. TODO: different kinds of rounding
486 m
.d
.comb
+= out_z
.copy(z
) # copies input to output first
488 m
.d
.comb
+= out_z
.m
.eq(z
.m
+ 1) # mantissa rounds up
489 with m
.If(z
.m
== z
.m1s
): # all 1s
490 m
.d
.comb
+= out_z
.e
.eq(z
.e
+ 1) # exponent rounds up
492 def corrections(self
, m
, z
, next_state
):
493 """ denormalisation and sign-bug corrections
496 # denormalised, correct exponent to zero
497 with m
.If(z
.is_denormalised
):
498 m
.d
.sync
+= z
.e
.eq(z
.N127
)
500 def pack(self
, m
, z
, next_state
):
501 """ packs the result into the output (detects overflow->Inf)
504 # if overflow occurs, return inf
505 with m
.If(z
.is_overflowed
):
506 m
.d
.sync
+= z
.inf(z
.s
)
508 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
510 def put_z(self
, m
, z
, out_z
, next_state
):
511 """ put_z: stores the result in the output. raises stb and waits
512 for ack to be set to 1 before moving to the next state.
513 resets stb back to zero when that occurs, as acknowledgement.
518 with m
.If(out_z
.stb
& out_z
.ack
):
519 m
.d
.sync
+= out_z
.stb
.eq(0)
522 m
.d
.sync
+= out_z
.stb
.eq(1)