1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
7 from operator
import or_
8 from functools
import reduce
12 def __init__(self
, width
):
14 self
.smax
= int(log(width
) / log(2))
15 self
.i
= Signal(width
, reset_less
=True)
16 self
.s
= Signal(self
.smax
, reset_less
=True)
17 self
.o
= Signal(width
, reset_less
=True)
19 def elaborate(self
, platform
):
21 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
26 """ Generates variable-length single-cycle shifter from a series
27 of conditional tests on each bit of the left/right shift operand.
28 Each bit tested produces output shifted by that number of bits,
29 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
30 shifts by 2 bits, each partial result cascading to the next Mux.
32 Could be adapted to do arithmetic shift by taking copies of the
36 def __init__(self
, width
):
38 self
.smax
= int(log(width
) / log(2))
40 def lshift(self
, op
, s
):
44 for i
in range(self
.smax
):
46 res
= Mux(s
& (1<<i
), Cat(zeros
, res
[0:-(1<<i
)]), res
)
49 def rshift(self
, op
, s
):
53 for i
in range(self
.smax
):
55 res
= Mux(s
& (1<<i
), Cat(res
[(1<<i
):], zeros
), res
)
60 """ Floating-point Base Number Class
62 def __init__(self
, width
, m_extra
=True):
64 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
65 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
66 e_max
= 1<<(e_width
-3)
67 self
.rmw
= m_width
# real mantissa width (not including extras)
70 # mantissa extra bits (top,guard,round)
72 m_width
+= self
.m_extra
75 #print (m_width, e_width, e_max, self.rmw, self.m_extra)
76 self
.m_width
= m_width
77 self
.e_width
= e_width
78 self
.e_start
= self
.rmw
- 1
79 self
.e_end
= self
.rmw
+ self
.e_width
- 3 # for decoding
81 self
.v
= Signal(width
, reset_less
=True) # Latched copy of value
82 self
.m
= Signal(m_width
, reset_less
=True) # Mantissa
83 self
.e
= Signal((e_width
, True), reset_less
=True) # Exponent: IEEE754exp+2 bits, signed
84 self
.s
= Signal(reset_less
=True) # Sign bit
86 self
.mzero
= Const(0, (m_width
, False))
87 self
.m1s
= Const(-1, (m_width
, False))
88 self
.P128
= Const(e_max
, (e_width
, True))
89 self
.P127
= Const(e_max
-1, (e_width
, True))
90 self
.N127
= Const(-(e_max
-1), (e_width
, True))
91 self
.N126
= Const(-(e_max
-2), (e_width
, True))
93 self
.is_nan
= Signal(reset_less
=True)
94 self
.is_zero
= Signal(reset_less
=True)
95 self
.is_inf
= Signal(reset_less
=True)
96 self
.is_overflowed
= Signal(reset_less
=True)
97 self
.is_denormalised
= Signal(reset_less
=True)
98 self
.exp_128
= Signal(reset_less
=True)
99 self
.exp_lt_n126
= Signal(reset_less
=True)
100 self
.exp_gt_n126
= Signal(reset_less
=True)
101 self
.exp_gt127
= Signal(reset_less
=True)
102 self
.exp_n127
= Signal(reset_less
=True)
103 self
.exp_n126
= Signal(reset_less
=True)
104 self
.m_zero
= Signal(reset_less
=True)
105 self
.m_msbzero
= Signal(reset_less
=True)
107 def elaborate(self
, platform
):
109 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
110 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
111 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
112 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
113 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
114 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.P128
)
115 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.e
> self
.N126
)
116 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.e
< self
.N126
)
117 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.P127
)
118 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.N127
)
119 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.N126
)
120 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.mzero
)
121 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.e_start
] == 0)
126 return (self
.exp_128
) & (~self
.m_zero
)
129 return (self
.exp_128
) & (self
.m_zero
)
132 return (self
.exp_n127
) & (self
.m_zero
)
134 def _is_overflowed(self
):
135 return self
.exp_gt127
137 def _is_denormalised(self
):
138 return (self
.exp_n126
) & (self
.m_msbzero
)
141 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
144 class FPNumOut(FPNumBase
):
145 """ Floating-point Number Class
147 Contains signals for an incoming copy of the value, decoded into
148 sign / exponent / mantissa.
149 Also contains encoding functions, creation and recognition of
150 zero, NaN and inf (all signed)
152 Four extra bits are included in the mantissa: the top bit
153 (m[-1]) is effectively a carry-overflow. The other three are
154 guard (m[2]), round (m[1]), and sticky (m[0])
156 def __init__(self
, width
, m_extra
=True):
157 FPNumBase
.__init
__(self
, width
, m_extra
)
159 def elaborate(self
, platform
):
160 m
= FPNumBase
.elaborate(self
, platform
)
164 def create(self
, s
, e
, m
):
165 """ creates a value from sign / exponent / mantissa
167 bias is added here, to the exponent
170 self
.v
[-1].eq(s
), # sign
171 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.P127
), # exp (add on bias)
172 self
.v
[0:self
.e_start
].eq(m
) # mantissa
176 return self
.create(s
, self
.P128
, 1<<(self
.e_start
-1))
179 return self
.create(s
, self
.P128
, 0)
182 return self
.create(s
, self
.N127
, 0)
185 class FPNumShift(FPNumBase
):
186 """ Floating-point Number Class for shifting
188 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
189 FPNumBase
.__init
__(self
, width
, m_extra
)
190 self
.latch_in
= Signal()
195 def elaborate(self
, platform
):
196 m
= FPNumBase
.elaborate(self
, platform
)
198 m
.d
.comb
+= self
.s
.eq(op
.s
)
199 m
.d
.comb
+= self
.e
.eq(op
.e
)
200 m
.d
.comb
+= self
.m
.eq(op
.m
)
202 with self
.mainm
.State("align"):
203 with m
.If(self
.e
< self
.inv
.e
):
204 m
.d
.sync
+= self
.shift_down()
208 def shift_down(self
, inp
):
209 """ shifts a mantissa down by one. exponent is increased to compensate
211 accuracy is lost as a result in the mantissa however there are 3
212 guard bits (the latter of which is the "sticky" bit)
214 return [self
.e
.eq(inp
.e
+ 1),
215 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
218 def shift_down_multi(self
, diff
):
219 """ shifts a mantissa down. exponent is increased to compensate
221 accuracy is lost as a result in the mantissa however there are 3
222 guard bits (the latter of which is the "sticky" bit)
224 this code works by variable-shifting the mantissa by up to
225 its maximum bit-length: no point doing more (it'll still be
228 the sticky bit is computed by shifting a batch of 1s by
229 the same amount, which will introduce zeros. it's then
230 inverted and used as a mask to get the LSBs of the mantissa.
231 those are then |'d into the sticky bit.
233 sm
= MultiShift(self
.width
)
234 mw
= Const(self
.m_width
-1, len(diff
))
235 maxslen
= Mux(diff
> mw
, mw
, diff
)
236 rs
= sm
.rshift(self
.m
[1:], maxslen
)
237 maxsleni
= mw
- maxslen
238 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
240 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
241 return [self
.e
.eq(self
.e
+ diff
),
242 self
.m
.eq(Cat(stickybits
, rs
))
245 def shift_up_multi(self
, diff
):
246 """ shifts a mantissa up. exponent is decreased to compensate
248 sm
= MultiShift(self
.width
)
249 mw
= Const(self
.m_width
, len(diff
))
250 maxslen
= Mux(diff
> mw
, mw
, diff
)
252 return [self
.e
.eq(self
.e
- diff
),
253 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
256 class FPNumIn(FPNumBase
):
257 """ Floating-point Number Class
259 Contains signals for an incoming copy of the value, decoded into
260 sign / exponent / mantissa.
261 Also contains encoding functions, creation and recognition of
262 zero, NaN and inf (all signed)
264 Four extra bits are included in the mantissa: the top bit
265 (m[-1]) is effectively a carry-overflow. The other three are
266 guard (m[2]), round (m[1]), and sticky (m[0])
268 def __init__(self
, op
, width
, m_extra
=True):
269 FPNumBase
.__init
__(self
, width
, m_extra
)
270 self
.latch_in
= Signal()
273 def elaborate(self
, platform
):
274 m
= FPNumBase
.elaborate(self
, platform
)
276 #m.d.comb += self.latch_in.eq(self.op.ack & self.op.stb)
277 #with m.If(self.latch_in):
278 # m.d.sync += self.decode(self.v)
283 """ decodes a latched value into sign / exponent / mantissa
285 bias is subtracted here, from the exponent. exponent
286 is extended to 10 bits so that subtract 127 is done on
289 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
290 #print ("decode", self.e_end)
291 return [self
.m
.eq(Cat(*args
)), # mantissa
292 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
293 self
.s
.eq(v
[-1]), # sign
296 def shift_down(self
, inp
):
297 """ shifts a mantissa down by one. exponent is increased to compensate
299 accuracy is lost as a result in the mantissa however there are 3
300 guard bits (the latter of which is the "sticky" bit)
302 return [self
.e
.eq(inp
.e
+ 1),
303 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
306 def shift_down_multi(self
, diff
):
307 """ shifts a mantissa down. exponent is increased to compensate
309 accuracy is lost as a result in the mantissa however there are 3
310 guard bits (the latter of which is the "sticky" bit)
312 this code works by variable-shifting the mantissa by up to
313 its maximum bit-length: no point doing more (it'll still be
316 the sticky bit is computed by shifting a batch of 1s by
317 the same amount, which will introduce zeros. it's then
318 inverted and used as a mask to get the LSBs of the mantissa.
319 those are then |'d into the sticky bit.
321 sm
= MultiShift(self
.width
)
322 mw
= Const(self
.m_width
-1, len(diff
))
323 maxslen
= Mux(diff
> mw
, mw
, diff
)
324 rs
= sm
.rshift(self
.m
[1:], maxslen
)
325 maxsleni
= mw
- maxslen
326 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
328 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
329 return [self
.e
.eq(self
.e
+ diff
),
330 self
.m
.eq(Cat(stickybits
, rs
))
333 def shift_up_multi(self
, diff
):
334 """ shifts a mantissa up. exponent is decreased to compensate
336 sm
= MultiShift(self
.width
)
337 mw
= Const(self
.m_width
, len(diff
))
338 maxslen
= Mux(diff
> mw
, mw
, diff
)
340 return [self
.e
.eq(self
.e
- diff
),
341 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
345 def __init__(self
, width
):
348 self
.v
= Signal(width
)
349 self
.stb
= Signal(reset
=0)
351 self
.trigger
= Signal(reset_less
=True)
353 def elaborate(self
, platform
):
355 m
.d
.sync
+= self
.trigger
.eq(self
.stb
& self
.ack
)
358 def chain_inv(self
, in_op
, extra
=None):
360 if extra
is not None:
362 return [self
.v
.eq(in_op
.v
), # receive value
363 self
.stb
.eq(stb
), # receive STB
364 in_op
.ack
.eq(~self
.ack
), # send ACK
367 def chain_from(self
, in_op
, extra
=None):
369 if extra
is not None:
371 return [self
.v
.eq(in_op
.v
), # receive value
372 self
.stb
.eq(stb
), # receive STB
373 in_op
.ack
.eq(self
.ack
), # send ACK
377 return [self
.v
.eq(inp
.v
),
378 self
.stb
.eq(inp
.stb
),
383 return [self
.v
, self
.stb
, self
.ack
]
388 self
.guard
= Signal(reset_less
=True) # tot[2]
389 self
.round_bit
= Signal(reset_less
=True) # tot[1]
390 self
.sticky
= Signal(reset_less
=True) # tot[0]
391 self
.m0
= Signal(reset_less
=True) # mantissa zero bit
393 self
.roundz
= Signal(reset_less
=True)
396 return [self
.guard
.eq(inp
.guard
),
397 self
.round_bit
.eq(inp
.round_bit
),
398 self
.sticky
.eq(inp
.sticky
),
401 def elaborate(self
, platform
):
403 m
.d
.comb
+= self
.roundz
.eq(self
.guard
& \
404 (self
.round_bit | self
.sticky | self
.m0
))
409 """ IEEE754 Floating Point Base Class
411 contains common functions for FP manipulation, such as
412 extracting and packing operands, normalisation, denormalisation,
416 def get_op(self
, m
, op
, v
, next_state
):
417 """ this function moves to the next state and copies the operand
418 when both stb and ack are 1.
419 acknowledgement is sent by setting ack to ZERO.
421 with m
.If((op
.ack
) & (op
.stb
)):
424 # op is latched in from FPNumIn class on same ack/stb
429 m
.d
.sync
+= op
.ack
.eq(1)
431 def denormalise(self
, m
, a
):
432 """ denormalises a number. this is probably the wrong name for
433 this function. for normalised numbers (exponent != minimum)
434 one *extra* bit (the implicit 1) is added *back in*.
435 for denormalised numbers, the mantissa is left alone
436 and the exponent increased by 1.
438 both cases *effectively multiply the number stored by 2*,
439 which has to be taken into account when extracting the result.
441 with m
.If(a
.exp_n127
):
442 m
.d
.sync
+= a
.e
.eq(a
.N126
) # limit a exponent
444 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
446 def op_normalise(self
, m
, op
, next_state
):
447 """ operand normalisation
448 NOTE: just like "align", this one keeps going round every clock
449 until the result's exponent is within acceptable "range"
451 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
453 op
.e
.eq(op
.e
- 1), # DECREASE exponent
454 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
459 def normalise_1(self
, m
, z
, of
, next_state
):
460 """ first stage normalisation
462 NOTE: just like "align", this one keeps going round every clock
463 until the result's exponent is within acceptable "range"
464 NOTE: the weirdness of reassigning guard and round is due to
465 the extra mantissa bits coming from tot[0..2]
467 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.N126
)):
469 z
.e
.eq(z
.e
- 1), # DECREASE exponent
470 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
471 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
472 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
473 of
.round_bit
.eq(0), # reset round bit
479 def normalise_2(self
, m
, z
, of
, next_state
):
480 """ second stage normalisation
482 NOTE: just like "align", this one keeps going round every clock
483 until the result's exponent is within acceptable "range"
484 NOTE: the weirdness of reassigning guard and round is due to
485 the extra mantissa bits coming from tot[0..2]
487 with m
.If(z
.e
< z
.N126
):
489 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
490 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
493 of
.round_bit
.eq(of
.guard
),
494 of
.sticky
.eq(of
.sticky | of
.round_bit
)
499 def roundz(self
, m
, z
, out_z
, roundz
):
500 """ performs rounding on the output. TODO: different kinds of rounding
502 m
.d
.comb
+= out_z
.copy(z
) # copies input to output first
504 m
.d
.comb
+= out_z
.m
.eq(z
.m
+ 1) # mantissa rounds up
505 with m
.If(z
.m
== z
.m1s
): # all 1s
506 m
.d
.comb
+= out_z
.e
.eq(z
.e
+ 1) # exponent rounds up
508 def corrections(self
, m
, z
, next_state
):
509 """ denormalisation and sign-bug corrections
512 # denormalised, correct exponent to zero
513 with m
.If(z
.is_denormalised
):
514 m
.d
.sync
+= z
.e
.eq(z
.N127
)
516 def pack(self
, m
, z
, next_state
):
517 """ packs the result into the output (detects overflow->Inf)
520 # if overflow occurs, return inf
521 with m
.If(z
.is_overflowed
):
522 m
.d
.sync
+= z
.inf(z
.s
)
524 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
526 def put_z(self
, m
, z
, out_z
, next_state
):
527 """ put_z: stores the result in the output. raises stb and waits
528 for ack to be set to 1 before moving to the next state.
529 resets stb back to zero when that occurs, as acknowledgement.
534 with m
.If(out_z
.stb
& out_z
.ack
):
535 m
.d
.sync
+= out_z
.stb
.eq(0)
538 m
.d
.sync
+= out_z
.stb
.eq(1)