# IEEE Floating Point Adder (Single Precision)
# Copyright (C) Jonathan P Dawson 2013
5 from nmigen
import Signal
, Cat
, Const
, Mux
, Module
7 from operator
import or_
8 from functools
import reduce
10 from pipeline
import ObjectProxy
15 def __init__(self
, width
):
17 self
.smax
= int(log(width
) / log(2))
18 self
.i
= Signal(width
, reset_less
=True)
19 self
.s
= Signal(self
.smax
, reset_less
=True)
20 self
.o
= Signal(width
, reset_less
=True)
22 def elaborate(self
, platform
):
24 m
.d
.comb
+= self
.o
.eq(self
.i
>> self
.s
)
29 """ Generates variable-length single-cycle shifter from a series
30 of conditional tests on each bit of the left/right shift operand.
31 Each bit tested produces output shifted by that number of bits,
32 in a binary fashion: bit 1 if set shifts by 1 bit, bit 2 if set
33 shifts by 2 bits, each partial result cascading to the next Mux.
35 Could be adapted to do arithmetic shift by taking copies of the
39 def __init__(self
, width
):
41 self
.smax
= int(log(width
) / log(2))
43 def lshift(self
, op
, s
):
47 for i
in range(self
.smax
):
49 res
= Mux(s
& (1<<i
), Cat(zeros
, res
[0:-(1<<i
)]), res
)
52 def rshift(self
, op
, s
):
56 for i
in range(self
.smax
):
58 res
= Mux(s
& (1<<i
), Cat(res
[(1<<i
):], zeros
), res
)
63 """ Floating-point Base Number Class
65 def __init__(self
, width
, m_extra
=True):
67 m_width
= {16: 11, 32: 24, 64: 53}[width
] # 1 extra bit (overflow)
68 e_width
= {16: 7, 32: 10, 64: 13}[width
] # 2 extra bits (overflow)
69 e_max
= 1<<(e_width
-3)
70 self
.rmw
= m_width
# real mantissa width (not including extras)
73 # mantissa extra bits (top,guard,round)
75 m_width
+= self
.m_extra
78 #print (m_width, e_width, e_max, self.rmw, self.m_extra)
79 self
.m_width
= m_width
80 self
.e_width
= e_width
81 self
.e_start
= self
.rmw
- 1
82 self
.e_end
= self
.rmw
+ self
.e_width
- 3 # for decoding
84 self
.v
= Signal(width
, reset_less
=True) # Latched copy of value
85 self
.m
= Signal(m_width
, reset_less
=True) # Mantissa
86 self
.e
= Signal((e_width
, True), reset_less
=True) # Exponent: IEEE754exp+2 bits, signed
87 self
.s
= Signal(reset_less
=True) # Sign bit
89 self
.mzero
= Const(0, (m_width
, False))
90 m_msb
= 1<<(self
.m_width
-2)
91 self
.msb1
= Const(m_msb
, (m_width
, False))
92 self
.m1s
= Const(-1, (m_width
, False))
93 self
.P128
= Const(e_max
, (e_width
, True))
94 self
.P127
= Const(e_max
-1, (e_width
, True))
95 self
.N127
= Const(-(e_max
-1), (e_width
, True))
96 self
.N126
= Const(-(e_max
-2), (e_width
, True))
98 self
.is_nan
= Signal(reset_less
=True)
99 self
.is_zero
= Signal(reset_less
=True)
100 self
.is_inf
= Signal(reset_less
=True)
101 self
.is_overflowed
= Signal(reset_less
=True)
102 self
.is_denormalised
= Signal(reset_less
=True)
103 self
.exp_128
= Signal(reset_less
=True)
104 self
.exp_sub_n126
= Signal((e_width
, True), reset_less
=True)
105 self
.exp_lt_n126
= Signal(reset_less
=True)
106 self
.exp_gt_n126
= Signal(reset_less
=True)
107 self
.exp_gt127
= Signal(reset_less
=True)
108 self
.exp_n127
= Signal(reset_less
=True)
109 self
.exp_n126
= Signal(reset_less
=True)
110 self
.m_zero
= Signal(reset_less
=True)
111 self
.m_msbzero
= Signal(reset_less
=True)
113 def elaborate(self
, platform
):
115 m
.d
.comb
+= self
.is_nan
.eq(self
._is
_nan
())
116 m
.d
.comb
+= self
.is_zero
.eq(self
._is
_zero
())
117 m
.d
.comb
+= self
.is_inf
.eq(self
._is
_inf
())
118 m
.d
.comb
+= self
.is_overflowed
.eq(self
._is
_overflowed
())
119 m
.d
.comb
+= self
.is_denormalised
.eq(self
._is
_denormalised
())
120 m
.d
.comb
+= self
.exp_128
.eq(self
.e
== self
.P128
)
121 m
.d
.comb
+= self
.exp_sub_n126
.eq(self
.e
- self
.N126
)
122 m
.d
.comb
+= self
.exp_gt_n126
.eq(self
.exp_sub_n126
> 0)
123 m
.d
.comb
+= self
.exp_lt_n126
.eq(self
.exp_sub_n126
< 0)
124 m
.d
.comb
+= self
.exp_gt127
.eq(self
.e
> self
.P127
)
125 m
.d
.comb
+= self
.exp_n127
.eq(self
.e
== self
.N127
)
126 m
.d
.comb
+= self
.exp_n126
.eq(self
.e
== self
.N126
)
127 m
.d
.comb
+= self
.m_zero
.eq(self
.m
== self
.mzero
)
128 m
.d
.comb
+= self
.m_msbzero
.eq(self
.m
[self
.e_start
] == 0)
133 return (self
.exp_128
) & (~self
.m_zero
)
136 return (self
.exp_128
) & (self
.m_zero
)
139 return (self
.exp_n127
) & (self
.m_zero
)
def _is_overflowed(self):
    """ Report exponent overflow.

        Simply hands back the ``exp_gt127`` flag; no new logic is built.
        That flag is driven in elaborate() from ``self.e > self.P127``.
    """
    overflow_flag = self.exp_gt127
    return overflow_flag
def _is_denormalised(self):
    """ Report whether the number is denormalised.

        Combines (bitwise AND, in this operand order) two flags driven in
        elaborate(): ``exp_n126`` (exponent equals N126) and ``m_msbzero``
        (mantissa bit at index ``e_start`` is zero).
    """
    at_min_exponent = self.exp_n126
    msb_clear = self.m_msbzero
    return at_min_exponent & msb_clear
148 return [self
.s
.eq(inp
.s
), self
.e
.eq(inp
.e
), self
.m
.eq(inp
.m
)]
151 class FPNumOut(FPNumBase
):
152 """ Floating-point Number Class
154 Contains signals for an incoming copy of the value, decoded into
155 sign / exponent / mantissa.
156 Also contains encoding functions, creation and recognition of
157 zero, NaN and inf (all signed)
159 Four extra bits are included in the mantissa: the top bit
160 (m[-1]) is effectively a carry-overflow. The other three are
161 guard (m[2]), round (m[1]), and sticky (m[0])
def __init__(self, width, m_extra=True):
    """ Set up the output-side number.

        Every signal and constant comes from FPNumBase.__init__, which
        receives ``width`` and ``m_extra`` unchanged; this constructor adds
        no attributes of its own.
    """
    FPNumBase.__init__(self, width=width, m_extra=m_extra)
166 def elaborate(self
, platform
):
167 m
= FPNumBase
.elaborate(self
, platform
)
171 def create(self
, s
, e
, m
):
172 """ creates a value from sign / exponent / mantissa
174 bias is added here, to the exponent
177 self
.v
[-1].eq(s
), # sign
178 self
.v
[self
.e_start
:self
.e_end
].eq(e
+ self
.P127
), # exp (add on bias)
179 self
.v
[0:self
.e_start
].eq(m
) # mantissa
183 return self
.create(s
, self
.P128
, 1<<(self
.e_start
-1))
186 return self
.create(s
, self
.P128
, 0)
189 return self
.create(s
, self
.N127
, 0)
191 def create2(self
, s
, e
, m
):
192 """ creates a value from sign / exponent / mantissa
194 bias is added here, to the exponent
196 e
= e
+ self
.P127
# exp (add on bias)
197 return Cat(m
[0:self
.e_start
],
198 e
[0:self
.e_end
-self
.e_start
],
202 return self
.create2(s
, self
.P128
, self
.msb1
)
205 return self
.create2(s
, self
.P128
, self
.mzero
)
208 return self
.create2(s
, self
.N127
, self
.mzero
)
211 class MultiShiftRMerge
:
212 """ shifts down (right) and merges lower bits into m[0].
213 m[0] is the "sticky" bit, basically
215 def __init__(self
, width
, s_max
=None):
217 s_max
= int(log(width
) / log(2))
219 self
.m
= Signal(width
, reset_less
=True)
220 self
.inp
= Signal(width
, reset_less
=True)
221 self
.diff
= Signal(s_max
, reset_less
=True)
224 def elaborate(self
, platform
):
227 rs
= Signal(self
.width
, reset_less
=True)
228 m_mask
= Signal(self
.width
, reset_less
=True)
229 smask
= Signal(self
.width
, reset_less
=True)
230 stickybit
= Signal(reset_less
=True)
231 maxslen
= Signal(self
.smax
, reset_less
=True)
232 maxsleni
= Signal(self
.smax
, reset_less
=True)
234 sm
= MultiShift(self
.width
-1)
235 m0s
= Const(0, self
.width
-1)
236 mw
= Const(self
.width
-1, len(self
.diff
))
237 m
.d
.comb
+= [maxslen
.eq(Mux(self
.diff
> mw
, mw
, self
.diff
)),
238 maxsleni
.eq(Mux(self
.diff
> mw
, 0, mw
-self
.diff
)),
242 # shift mantissa by maxslen, mask by inverse
243 rs
.eq(sm
.rshift(self
.inp
[1:], maxslen
)),
244 m_mask
.eq(sm
.rshift(~m0s
, maxsleni
)),
245 smask
.eq(self
.inp
[1:] & m_mask
),
246 # sticky bit combines all mask (and mantissa low bit)
247 stickybit
.eq(smask
.bool() | self
.inp
[0]),
248 # mantissa result contains m[0] already.
249 self
.m
.eq(Cat(stickybit
, rs
))
254 class FPNumShift(FPNumBase
):
255 """ Floating-point Number Class for shifting
257 def __init__(self
, mainm
, op
, inv
, width
, m_extra
=True):
258 FPNumBase
.__init
__(self
, width
, m_extra
)
259 self
.latch_in
= Signal()
264 def elaborate(self
, platform
):
265 m
= FPNumBase
.elaborate(self
, platform
)
267 m
.d
.comb
+= self
.s
.eq(op
.s
)
268 m
.d
.comb
+= self
.e
.eq(op
.e
)
269 m
.d
.comb
+= self
.m
.eq(op
.m
)
271 with self
.mainm
.State("align"):
272 with m
.If(self
.e
< self
.inv
.e
):
273 m
.d
.sync
+= self
.shift_down()
277 def shift_down(self
, inp
):
278 """ shifts a mantissa down by one. exponent is increased to compensate
280 accuracy is lost as a result in the mantissa however there are 3
281 guard bits (the latter of which is the "sticky" bit)
283 return [self
.e
.eq(inp
.e
+ 1),
284 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
287 def shift_down_multi(self
, diff
):
288 """ shifts a mantissa down. exponent is increased to compensate
290 accuracy is lost as a result in the mantissa however there are 3
291 guard bits (the latter of which is the "sticky" bit)
293 this code works by variable-shifting the mantissa by up to
294 its maximum bit-length: no point doing more (it'll still be
297 the sticky bit is computed by shifting a batch of 1s by
298 the same amount, which will introduce zeros. it's then
299 inverted and used as a mask to get the LSBs of the mantissa.
300 those are then |'d into the sticky bit.
302 sm
= MultiShift(self
.width
)
303 mw
= Const(self
.m_width
-1, len(diff
))
304 maxslen
= Mux(diff
> mw
, mw
, diff
)
305 rs
= sm
.rshift(self
.m
[1:], maxslen
)
306 maxsleni
= mw
- maxslen
307 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
309 stickybits
= reduce(or_
, self
.m
[1:] & m_mask
) | self
.m
[0]
310 return [self
.e
.eq(self
.e
+ diff
),
311 self
.m
.eq(Cat(stickybits
, rs
))
314 def shift_up_multi(self
, diff
):
315 """ shifts a mantissa up. exponent is decreased to compensate
317 sm
= MultiShift(self
.width
)
318 mw
= Const(self
.m_width
, len(diff
))
319 maxslen
= Mux(diff
> mw
, mw
, diff
)
321 return [self
.e
.eq(self
.e
- diff
),
322 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
326 class FPNumDecode(FPNumBase
):
327 """ Floating-point Number Class
329 Contains signals for an incoming copy of the value, decoded into
330 sign / exponent / mantissa.
331 Also contains encoding functions, creation and recognition of
332 zero, NaN and inf (all signed)
334 Four extra bits are included in the mantissa: the top bit
335 (m[-1]) is effectively a carry-overflow. The other three are
336 guard (m[2]), round (m[1]), and sticky (m[0])
338 def __init__(self
, op
, width
, m_extra
=True):
339 FPNumBase
.__init
__(self
, width
, m_extra
)
342 def elaborate(self
, platform
):
343 m
= FPNumBase
.elaborate(self
, platform
)
345 m
.d
.comb
+= self
.decode(self
.v
)
350 """ decodes a latched value into sign / exponent / mantissa
352 bias is subtracted here, from the exponent. exponent
353 is extended to 10 bits so that subtract 127 is done on
356 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
357 #print ("decode", self.e_end)
358 return [self
.m
.eq(Cat(*args
)), # mantissa
359 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
360 self
.s
.eq(v
[-1]), # sign
363 class FPNumIn(FPNumBase
):
364 """ Floating-point Number Class
366 Contains signals for an incoming copy of the value, decoded into
367 sign / exponent / mantissa.
368 Also contains encoding functions, creation and recognition of
369 zero, NaN and inf (all signed)
371 Four extra bits are included in the mantissa: the top bit
372 (m[-1]) is effectively a carry-overflow. The other three are
373 guard (m[2]), round (m[1]), and sticky (m[0])
375 def __init__(self
, op
, width
, m_extra
=True):
376 FPNumBase
.__init
__(self
, width
, m_extra
)
377 self
.latch_in
= Signal()
380 def decode2(self
, m
):
381 """ decodes a latched value into sign / exponent / mantissa
383 bias is subtracted here, from the exponent. exponent
384 is extended to 10 bits so that subtract 127 is done on
388 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
389 #print ("decode", self.e_end)
390 res
= ObjectProxy(m
, pipemode
=False)
391 res
.m
= Cat(*args
) # mantissa
392 res
.e
= v
[self
.e_start
:self
.e_end
] - self
.P127
# exp
397 """ decodes a latched value into sign / exponent / mantissa
399 bias is subtracted here, from the exponent. exponent
400 is extended to 10 bits so that subtract 127 is done on
403 args
= [0] * self
.m_extra
+ [v
[0:self
.e_start
]] # pad with extra zeros
404 #print ("decode", self.e_end)
405 return [self
.m
.eq(Cat(*args
)), # mantissa
406 self
.e
.eq(v
[self
.e_start
:self
.e_end
] - self
.P127
), # exp
407 self
.s
.eq(v
[-1]), # sign
410 def shift_down(self
, inp
):
411 """ shifts a mantissa down by one. exponent is increased to compensate
413 accuracy is lost as a result in the mantissa however there are 3
414 guard bits (the latter of which is the "sticky" bit)
416 return [self
.e
.eq(inp
.e
+ 1),
417 self
.m
.eq(Cat(inp
.m
[0] | inp
.m
[1], inp
.m
[2:], 0))
420 def shift_down_multi(self
, diff
, inp
=None):
421 """ shifts a mantissa down. exponent is increased to compensate
423 accuracy is lost as a result in the mantissa however there are 3
424 guard bits (the latter of which is the "sticky" bit)
426 this code works by variable-shifting the mantissa by up to
427 its maximum bit-length: no point doing more (it'll still be
430 the sticky bit is computed by shifting a batch of 1s by
431 the same amount, which will introduce zeros. it's then
432 inverted and used as a mask to get the LSBs of the mantissa.
433 those are then |'d into the sticky bit.
437 sm
= MultiShift(self
.width
)
438 mw
= Const(self
.m_width
-1, len(diff
))
439 maxslen
= Mux(diff
> mw
, mw
, diff
)
440 rs
= sm
.rshift(inp
.m
[1:], maxslen
)
441 maxsleni
= mw
- maxslen
442 m_mask
= sm
.rshift(self
.m1s
[1:], maxsleni
) # shift and invert
444 #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
445 stickybit
= (inp
.m
[1:] & m_mask
).bool() | inp
.m
[0]
446 return [self
.e
.eq(inp
.e
+ diff
),
447 self
.m
.eq(Cat(stickybit
, rs
))
450 def shift_up_multi(self
, diff
):
451 """ shifts a mantissa up. exponent is decreased to compensate
453 sm
= MultiShift(self
.width
)
454 mw
= Const(self
.m_width
, len(diff
))
455 maxslen
= Mux(diff
> mw
, mw
, diff
)
457 return [self
.e
.eq(self
.e
- diff
),
458 self
.m
.eq(sm
.lshift(self
.m
, maxslen
))
464 self
.stb
= Signal(reset
=0)
466 self
.trigger
= Signal(reset_less
=True)
468 def elaborate(self
, platform
):
470 m
.d
.comb
+= self
.trigger
.eq(self
.stb
& self
.ack
)
474 return [self
.stb
.eq(inp
.stb
),
479 return [self
.stb
, self
.ack
]
483 def __init__(self
, width
):
484 Trigger
.__init
__(self
)
487 self
.v
= Signal(width
)
489 def chain_inv(self
, in_op
, extra
=None):
491 if extra
is not None:
493 return [self
.v
.eq(in_op
.v
), # receive value
494 self
.stb
.eq(stb
), # receive STB
495 in_op
.ack
.eq(~self
.ack
), # send ACK
498 def chain_from(self
, in_op
, extra
=None):
500 if extra
is not None:
502 return [self
.v
.eq(in_op
.v
), # receive value
503 self
.stb
.eq(stb
), # receive STB
504 in_op
.ack
.eq(self
.ack
), # send ACK
508 return [self
.v
.eq(inp
.v
),
509 self
.stb
.eq(inp
.stb
),
514 return [self
.v
, self
.stb
, self
.ack
]
519 self
.guard
= Signal(reset_less
=True) # tot[2]
520 self
.round_bit
= Signal(reset_less
=True) # tot[1]
521 self
.sticky
= Signal(reset_less
=True) # tot[0]
522 self
.m0
= Signal(reset_less
=True) # mantissa zero bit
524 self
.roundz
= Signal(reset_less
=True)
527 return [self
.guard
.eq(inp
.guard
),
528 self
.round_bit
.eq(inp
.round_bit
),
529 self
.sticky
.eq(inp
.sticky
),
532 def elaborate(self
, platform
):
534 m
.d
.comb
+= self
.roundz
.eq(self
.guard
& \
535 (self
.round_bit | self
.sticky | self
.m0
))
540 """ IEEE754 Floating Point Base Class
542 contains common functions for FP manipulation, such as
543 extracting and packing operands, normalisation, denormalisation,
547 def get_op(self
, m
, op
, v
, next_state
):
548 """ this function moves to the next state and copies the operand
549 when both stb and ack are 1.
550 acknowledgement is sent by setting ack to ZERO.
554 with m
.If((op
.ack
) & (op
.stb
)):
556 # op is latched in from FPNumIn class on same ack/stb
557 m
.d
.comb
+= ack
.eq(0)
559 m
.d
.comb
+= ack
.eq(1)
562 def denormalise(self
, m
, a
):
563 """ denormalises a number. this is probably the wrong name for
564 this function. for normalised numbers (exponent != minimum)
565 one *extra* bit (the implicit 1) is added *back in*.
566 for denormalised numbers, the mantissa is left alone
567 and the exponent increased by 1.
569 both cases *effectively multiply the number stored by 2*,
570 which has to be taken into account when extracting the result.
572 with m
.If(a
.exp_n127
):
573 m
.d
.sync
+= a
.e
.eq(a
.N126
) # limit a exponent
575 m
.d
.sync
+= a
.m
[-1].eq(1) # set top mantissa bit
577 def op_normalise(self
, m
, op
, next_state
):
578 """ operand normalisation
579 NOTE: just like "align", this one keeps going round every clock
580 until the result's exponent is within acceptable "range"
582 with m
.If((op
.m
[-1] == 0)): # check last bit of mantissa
584 op
.e
.eq(op
.e
- 1), # DECREASE exponent
585 op
.m
.eq(op
.m
<< 1), # shift mantissa UP
590 def normalise_1(self
, m
, z
, of
, next_state
):
591 """ first stage normalisation
593 NOTE: just like "align", this one keeps going round every clock
594 until the result's exponent is within acceptable "range"
595 NOTE: the weirdness of reassigning guard and round is due to
596 the extra mantissa bits coming from tot[0..2]
598 with m
.If((z
.m
[-1] == 0) & (z
.e
> z
.N126
)):
600 z
.e
.eq(z
.e
- 1), # DECREASE exponent
601 z
.m
.eq(z
.m
<< 1), # shift mantissa UP
602 z
.m
[0].eq(of
.guard
), # steal guard bit (was tot[2])
603 of
.guard
.eq(of
.round_bit
), # steal round_bit (was tot[1])
604 of
.round_bit
.eq(0), # reset round bit
610 def normalise_2(self
, m
, z
, of
, next_state
):
611 """ second stage normalisation
613 NOTE: just like "align", this one keeps going round every clock
614 until the result's exponent is within acceptable "range"
615 NOTE: the weirdness of reassigning guard and round is due to
616 the extra mantissa bits coming from tot[0..2]
618 with m
.If(z
.e
< z
.N126
):
620 z
.e
.eq(z
.e
+ 1), # INCREASE exponent
621 z
.m
.eq(z
.m
>> 1), # shift mantissa DOWN
624 of
.round_bit
.eq(of
.guard
),
625 of
.sticky
.eq(of
.sticky | of
.round_bit
)
630 def roundz(self
, m
, z
, roundz
):
631 """ performs rounding on the output. TODO: different kinds of rounding
634 m
.d
.sync
+= z
.m
.eq(z
.m
+ 1) # mantissa rounds up
635 with m
.If(z
.m
== z
.m1s
): # all 1s
636 m
.d
.sync
+= z
.e
.eq(z
.e
+ 1) # exponent rounds up
638 def corrections(self
, m
, z
, next_state
):
639 """ denormalisation and sign-bug corrections
642 # denormalised, correct exponent to zero
643 with m
.If(z
.is_denormalised
):
644 m
.d
.sync
+= z
.e
.eq(z
.N127
)
646 def pack(self
, m
, z
, next_state
):
647 """ packs the result into the output (detects overflow->Inf)
650 # if overflow occurs, return inf
651 with m
.If(z
.is_overflowed
):
652 m
.d
.sync
+= z
.inf(z
.s
)
654 m
.d
.sync
+= z
.create(z
.s
, z
.e
, z
.m
)
656 def put_z(self
, m
, z
, out_z
, next_state
):
657 """ put_z: stores the result in the output. raises stb and waits
658 for ack to be set to 1 before moving to the next state.
659 resets stb back to zero when that occurs, as acknowledgement.
664 with m
.If(out_z
.stb
& out_z
.ack
):
665 m
.d
.sync
+= out_z
.stb
.eq(0)
668 m
.d
.sync
+= out_z
.stb
.eq(1)
671 class FPState(FPBase
):
def __init__(self, state_from):
    """ Remember ``state_from`` on the instance.

        NOTE(review): the name suggests it identifies the FSM state this
        one is entered from -- confirm against callers. No base-class
        __init__ is invoked here.
    """
    self.state_from = state_from
675 def set_inputs(self
, inputs
):
677 for k
,v
in inputs
.items():
680 def set_outputs(self
, outputs
):
681 self
.outputs
= outputs
682 for k
,v
in outputs
.items():
687 def __init__(self
, id_wid
):
690 self
.in_mid
= Signal(id_wid
, reset_less
=True)
691 self
.out_mid
= Signal(id_wid
, reset_less
=True)
697 if self
.id_wid
is not None:
698 m
.d
.sync
+= self
.out_mid
.eq(self
.in_mid
)