whoops no e_start-1 in fpnum decode
[ieee754fpu.git] / src / ieee754 / fpcommon / fpbase.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Signal, Cat, Const, Mux, Module, Elaboratable
6 from math import log
7 from operator import or_
8 from functools import reduce
9
10 from nmutil.singlepipe import PrevControl, NextControl
11 from nmutil.pipeline import ObjectProxy
12
13
class MultiShiftR:
    """ Variable right-shifter with its own input and output signals.

        i is the value to be shifted, s is the shift amount and o
        receives i >> s.  s is int(log2(width)) bits wide.
    """

    def __init__(self, width):
        self.width = width
        # number of bits used for the shift amount
        self.smax = int(log(width) / log(2))
        # attribute names are kept: nmigen names signals after them
        self.i = Signal(width, reset_less=True)      # value to shift
        self.s = Signal(self.smax, reset_less=True)  # shift amount
        self.o = Signal(width, reset_less=True)      # shifted result

    def elaborate(self, platform):
        """ builds the module: o is combinatorially driven by i >> s
        """
        m = Module()
        shifted = self.i >> self.s
        m.d.comb += self.o.eq(shifted)
        return m
27
28
class MultiShift:
    """ Variable-length single-cycle shifter helper.

        lshift and rshift apply the native << and >> operators to the
        operand and then truncate the result back to the operand's
        own length, so the value returned is never wider than op.

        An earlier implementation built the shift from a cascade of
        Muxes, one per bit of the shift amount.  That code sat after
        the return statements and could never run, so it has been
        removed.

        smax is int(log2(width)).  It is kept as an attribute but is
        not used by lshift/rshift themselves.
    """

    def __init__(self, width):
        self.width = width
        self.smax = int(log(width) / log(2))

    def lshift(self, op, s):
        """ returns op shifted left by s, truncated to len(op) bits
        """
        res = op << s
        return res[:len(op)]

    def rshift(self, op, s):
        """ returns op shifted right by s, truncated to len(op) bits
        """
        res = op >> s
        return res[:len(op)]
61
62
class FPNumBaseRecord:
    """ Floating-point Base Number Record

        Holds the signals of one floating-point number (v: the packed
        value, s: sign, e: exponent, m: mantissa) together with the
        width parameters and the constants used to encode and decode
        it.

        width must be 16, 32 or 64 (anything else raises KeyError).
        When m_extra is True the mantissa is widened by 3 extra bits.
    """
    def __init__(self, width, m_extra=True):
        self.width = width
        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
        e_width = {16: 7, 32: 10, 64: 13}[width]   # 2 extra bits (overflow)
        e_max = 1 << (e_width-3)
        self.rmw = m_width  # real mantissa width (not including extras)
        self.e_max = e_max
        if m_extra:
            # mantissa extra bits (top,guard,round)
            self.m_extra = 3
            m_width += self.m_extra
        else:
            self.m_extra = 0
        self.m_width = m_width
        self.e_width = e_width
        # v[e_start:e_end] is the exponent field of the packed value
        self.e_start = self.rmw - 1
        self.e_end = self.rmw + self.e_width - 3  # for decoding

        self.v = Signal(width, reset_less=True)    # Latched copy of value
        self.m = Signal(m_width, reset_less=True)  # Mantissa
        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
        self.s = Signal(reset_less=True)           # Sign bit

        # a record is its own "fp", so self.fp.<constant> works both
        # here and in classes that wrap a record
        self.fp = self
        self.drop_in(self)

    def drop_in(self, fp):
        """ copies this record's signals and width parameters onto fp.

            the constants (mzero, msb1, m1s, P128, P127, N127, N126) are
            (re)created on *this* record, not on fp: objects that had a
            record dropped into them must reach the constants through
            their own "fp" attribute.
        """
        fp.s = self.s
        fp.e = self.e
        fp.m = self.m
        fp.v = self.v
        fp.rmw = self.rmw
        fp.width = self.width
        fp.e_width = self.e_width
        fp.e_max = self.e_max
        fp.m_width = self.m_width
        fp.e_start = self.e_start
        fp.e_end = self.e_end
        fp.m_extra = self.m_extra

        m_width = self.m_width
        e_max = self.e_max
        e_width = self.e_width

        self.mzero = Const(0, (m_width, False))
        m_msb = 1 << (self.m_width-2)
        self.msb1 = Const(m_msb, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(e_max, (e_width, True))
        self.P127 = Const(e_max-1, (e_width, True))
        self.N127 = Const(-(e_max-1), (e_width, True))
        self.N126 = Const(-(e_max-2), (e_width, True))

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  returns a list of
            assignments to the sign, exponent and mantissa fields of v.
        """
        return [
          self.v[-1].eq(s),          # sign
          self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add on bias)
          self.v[0:self.e_start].eq(m)         # mantissa
        ]

    def _nan(self, s):
        """ (sign, exponent, mantissa) for NaN: top mantissa-field bit set
        """
        return (s, self.fp.P128, 1 << (self.e_start-1))

    def _inf(self, s):
        """ (sign, exponent, mantissa) for infinity
        """
        return (s, self.fp.P128, 0)

    def _zero(self, s):
        """ (sign, exponent, mantissa) for zero
        """
        return (s, self.fp.N127, 0)

    def nan(self, s):
        """ assignments that set v to NaN with sign s
        """
        return self.create(*self._nan(s))

    def inf(self, s):
        """ assignments that set v to infinity with sign s
        """
        return self.create(*self._inf(s))

    def zero(self, s):
        """ assignments that set v to zero with sign s
        """
        return self.create(*self._zero(s))

    def create2(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent.  unlike create, the
            result is returned as a single concatenated expression
            rather than as assignments to v.

            the constants are fetched through self.fp (as create does)
            because drop_in only stores them on the record.
        """
        e = e + self.fp.P127  # exp (add on bias)
        return Cat(m[0:self.e_start],
                   e[0:self.e_end-self.e_start],
                   s)

    def nan2(self, s):
        """ NaN with sign s, as an expression (see create2)

            NOTE(review): msb1 has bit m_width-2 set while create2 only
            keeps m[0:e_start].  with m_extra=3 that bit lies outside the
            slice - confirm this is only used when m_extra is False.
        """
        return self.create2(s, self.fp.P128, self.fp.msb1)

    def inf2(self, s):
        """ infinity with sign s, as an expression (see create2)
        """
        return self.create2(s, self.fp.P128, self.fp.mzero)

    def zero2(self, s):
        """ zero with sign s, as an expression (see create2)
        """
        return self.create2(s, self.fp.N127, self.fp.mzero)

    def __iter__(self):
        yield self.s
        yield self.e
        yield self.m

    def eq(self, inp):
        """ assignments copying sign, exponent and mantissa from inp
        """
        return [self.s.eq(inp.s), self.e.eq(inp.e), self.m.eq(inp.m)]
175
176
class FPNumBase(FPNumBaseRecord, Elaboratable):
    """ Floating-point Base Number Class

        Wraps a record (fp): the record's signals and width parameters
        are dropped into this object, and a set of one-bit status
        signals (is_nan, is_zero, is_inf, ...) is derived from the
        exponent and mantissa in elaborate().
    """
    def __init__(self, fp):
        fp.drop_in(self)
        self.fp = fp

        def flag():
            return Signal(reset_less=True)

        # classification of the number as a whole
        self.is_nan = flag()
        self.is_zero = flag()
        self.is_inf = flag()
        self.is_overflowed = flag()
        self.is_denormalised = flag()
        # comparisons of the exponent against the record's constants
        self.exp_128 = flag()
        self.exp_sub_n126 = Signal((fp.e_width, True), reset_less=True)
        self.exp_lt_n126 = flag()
        self.exp_gt_n126 = flag()
        self.exp_gt127 = flag()
        self.exp_n127 = flag()
        self.exp_n126 = flag()
        # mantissa tests
        self.m_zero = flag()
        self.m_msbzero = flag()

    def elaborate(self, platform):
        """ builds the combinatorial logic behind every status signal
        """
        m = Module()
        consts = self.fp
        m.d.comb += [
            self.is_nan.eq(self._is_nan()),
            self.is_zero.eq(self._is_zero()),
            self.is_inf.eq(self._is_inf()),
            self.is_overflowed.eq(self._is_overflowed()),
            self.is_denormalised.eq(self._is_denormalised()),
            self.exp_128.eq(self.e == consts.P128),
            self.exp_sub_n126.eq(self.e - consts.N126),
            self.exp_gt_n126.eq(self.exp_sub_n126 > 0),
            self.exp_lt_n126.eq(self.exp_sub_n126 < 0),
            self.exp_gt127.eq(self.e > consts.P127),
            self.exp_n127.eq(self.e == consts.N127),
            self.exp_n126.eq(self.e == consts.N126),
            self.m_zero.eq(self.m == consts.mzero),
            self.m_msbzero.eq(self.m[consts.e_start] == 0),
        ]
        return m

    def _is_nan(self):
        # exponent is P128 and the mantissa is non-zero
        return (self.exp_128) & (~self.m_zero)

    def _is_inf(self):
        # exponent is P128 and the mantissa is zero
        return (self.exp_128) & (self.m_zero)

    def _is_zero(self):
        # exponent is N127 and the mantissa is zero
        return (self.exp_n127) & (self.m_zero)

    def _is_overflowed(self):
        # exponent is above P127
        return self.exp_gt127

    def _is_denormalised(self):
        # exponent is N126 and mantissa bit e_start is clear
        return (self.exp_n126) & (self.m_msbzero)
233
234
class FPNumOut(FPNumBase):
    """ Floating-point Number Class (output side)

        Carries the value as sign / exponent / mantissa signals and
        inherits the encoding functions plus the creation and
        recognition of zero, NaN and inf (all signed).

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])

        Nothing is added on top of FPNumBase: both methods delegate.
    """
    def __init__(self, fp):
        super().__init__(fp)

    def elaborate(self, platform):
        return super().elaborate(platform)
254
255
class MultiShiftRMerge(Elaboratable):
    """ shifts down (right) and merges lower bits into m[0].
        m[0] is the "sticky" bit, basically

        inp is the value to shift, diff the shift amount and m the
        result.  diff is s_max bits wide; when s_max is not given it
        defaults to int(log2(width)).
    """
    def __init__(self, width, s_max=None):
        self.smax = int(log(width) / log(2)) if s_max is None else s_max
        self.m = Signal(width, reset_less=True)
        self.inp = Signal(width, reset_less=True)
        self.diff = Signal(self.smax, reset_less=True)
        self.width = width

    def elaborate(self, platform):
        m = Module()

        # local signal names are left alone: nmigen names signals
        # after the variable they are assigned to
        rs = Signal(self.width, reset_less=True)
        m_mask = Signal(self.width, reset_less=True)
        smask = Signal(self.width, reset_less=True)
        stickybit = Signal(reset_less=True)
        maxslen = Signal(self.smax, reset_less=True)
        maxsleni = Signal(self.smax, reset_less=True)

        shifter = MultiShift(self.width-1)
        zeros = Const(0, self.width-1)
        limit = Const(self.width-1, len(self.diff))
        upper = self.inp[1:]          # everything above the sticky bit
        too_far = self.diff > limit

        # clamp the shift amount, and work out its complement
        m.d.comb += [
            maxslen.eq(Mux(too_far, limit, self.diff)),
            maxsleni.eq(Mux(too_far, 0, limit-self.diff)),
        ]

        m.d.comb += [
            # shift mantissa by maxslen, mask by inverse
            rs.eq(shifter.rshift(upper, maxslen)),
            m_mask.eq(shifter.rshift(~zeros, maxsleni)),
            smask.eq(upper & m_mask),
            # sticky bit combines all mask (and mantissa low bit)
            stickybit.eq(smask.bool() | self.inp[0]),
            # mantissa result contains m[0] already.
            self.m.eq(Cat(stickybit, rs)),
        ]
        return m
297
298
class FPNumShift(FPNumBase, Elaboratable):
    """ Floating-point Number Class for shifting

        mainm is the module holding the FSM with the "align" state,
        op is the number copied into this one and inv the number whose
        exponent is compared against.  width and m_extra describe
        the record created for this number.
    """
    def __init__(self, mainm, op, inv, width, m_extra=True):
        # FPNumBase takes a record, not (width, m_extra): build one here
        FPNumBase.__init__(self, FPNumBaseRecord(width, m_extra))
        self.latch_in = Signal()
        self.mainm = mainm
        self.inv = inv
        self.op = op

    def elaborate(self, platform):
        m = FPNumBase.elaborate(self, platform)

        # op is an attribute, not a global
        m.d.comb += self.s.eq(self.op.s)
        m.d.comb += self.e.eq(self.op.e)
        m.d.comb += self.m.eq(self.op.m)

        # NOTE(review): e and m are driven combinatorially above and
        # synchronously below - looks like a driver conflict, confirm
        # how this class is meant to be used.
        with self.mainm.State("align"):
            with m.If(self.e < self.inv.e):
                m.d.sync += self.shift_down()

        return m

    def shift_down(self, inp=None):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            inp defaults to this number itself.
        """
        if inp is None:
            inp = self
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            the shift amount is clamped to m_width-1.  a batch of 1s is
            shifted by the complementary amount to give a mask, and the
            mantissa bits selected by the mask are |'d into the sticky
            bit together with m[0].
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(self.m[1:], maxslen)
        maxsleni = mw - maxslen
        # the all-ones constant lives on the record (see drop_in)
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybits = (self.m[1:] & m_mask).bool() | self.m[0]
        return [self.e.eq(self.e + diff),
                self.m.eq(Cat(stickybits, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is clamped to m_width.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
369
370
class FPNumDecode(FPNumBase):
    """ Floating-point Number Class (decoding)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, op, fp):
        super().__init__(fp)
        self.op = op

    def elaborate(self, platform):
        """ status logic from FPNumBase, plus the decode of v
        """
        m = super().elaborate(platform)
        m.d.comb += self.decode(self.v)
        return m

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  the mantissa
            field is padded at the bottom with m_extra zero bits.
            returns the list of assignments to m, e and s.
        """
        padding = [0] * self.m_extra              # pad with extra zeros
        mantissa = Cat(*padding, v[0:self.e_start])
        exponent = v[self.e_start:self.e_end] - self.fp.P127
        sign = v[-1]
        return [self.m.eq(mantissa), self.e.eq(exponent), self.s.eq(sign)]
407
class FPNumIn(FPNumBase):
    """ Floating-point Number Class

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])

        Constants such as P127 and m1s are stored on the record only,
        so every method here reaches them through self.fp.
    """
    def __init__(self, op, fp):
        FPNumBase.__init__(self, fp)
        self.latch_in = Signal()
        self.op = op

    def decode2(self, m):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  instead of
            assigning to this number's own signals, the decoded fields
            are returned as attributes (m, e, s) of an ObjectProxy
            created on module m.
        """
        v = self.v
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        res = ObjectProxy(m, pipemode=False)
        res.m = Cat(*args)                                  # mantissa
        res.e = v[self.e_start:self.e_end] - self.fp.P127   # exp
        res.s = v[-1]                                       # sign
        return res

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  returns the
            list of assignments to m, e and s.
        """
        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
        return [self.m.eq(Cat(*args)),  # mantissa
                # P127 is only available on the record
                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                self.s.eq(v[-1]),       # sign
                ]

    def shift_down(self, inp):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)
        """
        return [self.e.eq(inp.e + 1),
                self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
               ]

    def shift_down_multi(self, diff, inp=None):
        """ shifts a mantissa down. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            this code works by variable-shifting the mantissa by up to
            its maximum bit-length: no point doing more (it'll still be
            zero).

            a batch of 1s is shifted by the complementary amount to
            give a mask; the mantissa bits selected by the mask are
            |'d into the sticky bit together with m[0].

            inp defaults to this number itself.
        """
        if inp is None:
            inp = self
        sm = MultiShift(self.width)
        mw = Const(self.m_width-1, len(diff))
        maxslen = Mux(diff > mw, mw, diff)
        rs = sm.rshift(inp.m[1:], maxslen)
        maxsleni = mw - maxslen
        # m1s is only available on the record
        m_mask = sm.rshift(self.fp.m1s[1:], maxsleni)

        stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
        return [self.e.eq(inp.e + diff),
                self.m.eq(Cat(stickybit, rs))
               ]

    def shift_up_multi(self, diff):
        """ shifts a mantissa up. exponent is decreased to compensate

            the shift amount is clamped to m_width.
        """
        sm = MultiShift(self.width)
        mw = Const(self.m_width, len(diff))
        maxslen = Mux(diff > mw, mw, diff)

        return [self.e.eq(self.e - diff),
                self.m.eq(sm.lshift(self.m, maxslen))
               ]
505
class Trigger(Elaboratable):
    """ strobe/acknowledge pair with a derived "trigger" signal,
        which is high when stb and ack are both high.
    """
    def __init__(self):
        self.stb = Signal(reset=0)
        self.ack = Signal()
        self.trigger = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()
        both = self.stb & self.ack
        m.d.comb += self.trigger.eq(both)
        return m

    def eq(self, inp):
        """ assignments copying stb and ack from inp
        """
        copy_stb = self.stb.eq(inp.stb)
        copy_ack = self.ack.eq(inp.ack)
        return [copy_stb, copy_ack]

    def ports(self):
        """ the externally visible signals: stb and ack
        """
        return [self.stb, self.ack]
525
526
class FPOpIn(PrevControl):
    """ input operand: a PrevControl whose data_i is exposed as "v"
    """
    def __init__(self, width):
        super().__init__()
        self.width = width

    @property
    def v(self):
        # alias for the incoming data
        return self.data_i

    def chain_inv(self, in_op, extra=None):
        """ as chain_from, except that the inverse of ack is sent back
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ assignments connecting in_op to this operand: its value and
            stb are received, ack is sent back.  when extra is given,
            the received stb is gated (ANDed) with it.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
553
554
class FPOpOut(NextControl):
    """ output operand: a NextControl whose data_o is exposed as "v"
    """
    def __init__(self, width):
        super().__init__()
        self.width = width

    @property
    def v(self):
        # alias for the outgoing data
        return self.data_o

    def chain_inv(self, in_op, extra=None):
        """ as chain_from, except that the inverse of ack is sent back
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(~self.ack)
        return [receive_value, receive_stb, send_ack]

    def chain_from(self, in_op, extra=None):
        """ assignments connecting in_op to this operand: its value and
            stb are received, ack is sent back.  when extra is given,
            the received stb is gated (ANDed) with it.
        """
        stb = in_op.stb if extra is None else in_op.stb & extra
        receive_value = self.v.eq(in_op.v)
        receive_stb = self.stb.eq(stb)
        send_ack = in_op.ack.eq(self.ack)
        return [receive_value, receive_stb, send_ack]
581
582
class Overflow:
    """ the rounding bits that accompany a mantissa: guard, round and
        sticky, plus a copy of the mantissa's bit zero (m0).
    """
    def __init__(self):
        self.guard = Signal(reset_less=True)      # tot[2]
        self.round_bit = Signal(reset_less=True)  # tot[1]
        self.sticky = Signal(reset_less=True)     # tot[0]
        self.m0 = Signal(reset_less=True)         # mantissa zero bit

    def __iter__(self):
        yield from (self.guard, self.round_bit, self.sticky, self.m0)

    def eq(self, inp):
        """ assignments copying all four bits from inp
        """
        names = ("guard", "round_bit", "sticky", "m0")
        return [getattr(self, name).eq(getattr(inp, name)) for name in names]

    @property
    def roundz(self):
        """ guard bit set, together with any of round, sticky or m0
        """
        lower = self.round_bit | self.sticky | self.m0
        return self.guard & lower
607
608
class FPBase:
    """ IEEE754 Floating Point Base Class

        a collection of FSM building blocks for FP manipulation:
        fetching operands, normalisation, denormalisation, rounding,
        corrections, packing and handing over the result.  in every
        method "m" is the module being built.
    """

    def get_op(self, m, op, v, next_state):
        """ decodes the operand and moves to next_state once the
            handshake (op.ready_o and op.valid_i_test) succeeds.

            returns [decoded value, ack], where ack is driven to ZERO
            on a successful handshake and to one otherwise.
        """
        res = v.decode2(m)
        ack = Signal()
        handshake = (op.ready_o) & (op.valid_i_test)
        with m.If(handshake):
            m.next = next_state
            # op is latched in from FPNumIn class on same ack/stb
            m.d.comb += ack.eq(0)
        with m.Else():
            m.d.comb += ack.eq(1)
        return [res, ack]

    def denormalise(self, m, a):
        """ denormalises a number.  this is probably the wrong name for
            this function.

            when the exponent is N127 it is limited to N126 and the
            mantissa is left alone; otherwise the top mantissa bit
            (the implicit 1) is set.

            both cases *effectively multiply the number stored by 2*,
            which has to be taken into account when extracting the result.
        """
        with m.If(a.exp_n127):
            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1)      # set top mantissa bit

    def op_normalise(self, m, op, next_state):
        """ operand normalisation

            NOTE: just like "align", this one keeps going round every clock
                  until the top mantissa bit is set
        """
        top_clear = (op.m[-1] == 0)  # check last bit of mantissa
        with m.If(top_clear):
            m.d.sync += [
                op.e.eq(op.e - 1),   # DECREASE exponent
                op.m.eq(op.m << 1),  # shift mantissa UP
            ]
        with m.Else():
            m.next = next_state

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            while the top mantissa bit is clear and the exponent is
            above N126, shift the mantissa up by one each clock.

            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        needs_shift = (z.m[-1] == 0) & (z.e > z.fp.N126)
        with m.If(needs_shift):
            # order matters: the z.m[0] assignment overrides bit zero
            # of the shifted mantissa assigned just before it
            m.d.sync += [
                z.e.eq(z.e - 1),             # DECREASE exponent
                z.m.eq(z.m << 1),            # shift mantissa UP
                z.m[0].eq(of.guard),         # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),   # steal round_bit (was tot[1])
                of.round_bit.eq(0),          # reset round bit
                of.m0.eq(of.guard),
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            while the exponent is below N126, shift the mantissa down
            by one each clock, feeding the bits that fall off into
            guard / round / sticky.

            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        too_small = z.e < z.fp.N126
        with m.If(too_small):
            m.d.sync += [
                z.e.eq(z.e + 1),   # INCREASE exponent
                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.m0.eq(z.m[1]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit),
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, roundz):
        """ performs rounding on the output.  TODO: different kinds of rounding

            when roundz is set the mantissa is incremented; if it was
            all 1s the exponent is incremented as well.
        """
        with m.If(roundz):
            m.d.sync += z.m.eq(z.m + 1)      # mantissa rounds up
            all_ones = z.m == z.fp.m1s       # all 1s
            with m.If(all_ones):
                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections

            always moves on to next_state.
        """
        m.next = next_state
        # denormalised, correct exponent to zero
        with m.If(z.is_denormalised):
            m.d.sync += z.e.eq(z.fp.N127)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)

            always moves on to next_state.
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises valid_o and,
            once valid_o and ready_i_test are both set, drops valid_o
            back to zero and moves to the next state.
        """
        m.d.sync += [out_z.v.eq(z.v)]
        accepted = out_z.valid_o & out_z.ready_i_test
        with m.If(accepted):
            m.d.sync += out_z.valid_o.eq(0)
            m.next = next_state
        with m.Else():
            m.d.sync += out_z.valid_o.eq(1)
739
740
class FPState(FPBase):
    """ one state of the FSM, remembering the name it was given
        (state_from) and the inputs / outputs attached to it.
    """
    def __init__(self, state_from):
        self.state_from = state_from

    def _adopt(self, mapping):
        # every entry of the mapping becomes an attribute of the same name
        for name, value in mapping.items():
            setattr(self, name, value)

    def set_inputs(self, inputs):
        """ stores the inputs dictionary and exposes each entry as an
            attribute of this object
        """
        self.inputs = inputs
        self._adopt(inputs)

    def set_outputs(self, outputs):
        """ stores the outputs dictionary and exposes each entry as an
            attribute of this object
        """
        self.outputs = outputs
        self._adopt(outputs)
754
755
class FPID:
    """ optional identifier that travels alongside an operation.

        when id_wid is non-zero, in_mid and out_mid are signals of that
        width; when it is None or zero both are None and idsync does
        nothing.
    """
    def __init__(self, id_wid):
        self.id_wid = id_wid
        if self.id_wid:
            self.in_mid = Signal(id_wid, reset_less=True)
            self.out_mid = Signal(id_wid, reset_less=True)
        else:
            self.in_mid = None
            self.out_mid = None

    def idsync(self, m):
        """ adds a sync assignment copying in_mid to out_mid on module m.

            uses the same truthiness test as __init__: with id_wid == 0
            there are no signals to copy, so nothing is added.
        """
        if self.id_wid:
            m.d.sync += self.out_mid.eq(self.in_mid)
769
770