separate common functions into FPBase class
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
# IEEE Floating Point Adder (Single Precision)
# Copyright (C) Jonathan P Dawson 2013
# 2013-12-12

from nmigen import Module, Signal, Cat, Const
from nmigen.cli import main, verilog


class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])
    """
    def __init__(self, width, m_width=None):
        self.width = width
        if m_width is None:
            m_width = width - 5 # mantissa extra bits (top, guard, round, sticky)
        self.v = Signal(width)      # Latched copy of value
        self.m = Signal(m_width)    # Mantissa
        self.e = Signal((10, True)) # Exponent: 10 bits, signed
        self.s = Signal()           # Sign bit

        self.mzero = Const(0, (m_width, False))
        self.m1s = Const(-1, (m_width, False))
        self.P128 = Const(128, (10, True))
        self.P127 = Const(127, (10, True))
        self.N127 = Const(-127, (10, True))
        self.N126 = Const(-126, (10, True))

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            the bias is subtracted from the exponent here.  the exponent
            is extended to 10 bits so that the "subtract 127" is done on
            a 10-bit number
        """
        return [self.m.eq(Cat(0, 0, 0, v[0:23])), # mantissa
                self.e.eq(v[23:31] - self.P127),  # exp (minus bias)
                self.s.eq(v[31]),                 # sign
               ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            the bias is added back onto the exponent here
        """
        return [
            self.v[31].eq(s),                # sign
            self.v[23:31].eq(e + self.P127), # exp (add on bias)
            self.v[0:23].eq(m)               # mantissa
        ]

    def shift_down(self):
        """ shifts the mantissa down by one.  the exponent is increased
            to compensate.

            accuracy is lost in the mantissa as a result, however there
            are 3 extra bits below it (the last of which is the "sticky"
            bit) that absorb the loss
        """
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]
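    # shift_down, worked through (comments only): each step moves every
    # mantissa bit down one place, the old round bit is OR-ed into the new
    # sticky bit (m[0] | m[1]), a zero is pulled in at the top, and e goes
    # up by one, so the represented value is unchanged apart from the
    # precision folded into the sticky bit.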

    def nan(self, s):
        return self.create(s, self.P128, 1<<22)

    def inf(self, s):
        return self.create(s, self.P128, 0)

    def zero(self, s):
        return self.create(s, self.N127, 0)

    def is_nan(self):
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        return (self.e > self.P127)

    def is_denormalised(self):
        return (self.e == self.N126) & (self.m[23] == 0)

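
# Illustrative sketch (plain Python, standard library only, not used by the
# hardware): the same sign / exponent / mantissa split that FPNum.decode
# performs on a latched 32-bit value, with the 127 bias removed from the
# exponent.  The helper name _split_f32 is made up for this example.

def _split_f32(x):
    """ returns (sign, unbiased exponent, 23-bit mantissa) of a python float
        interpreted as an IEEE754 single-precision value
    """
    import struct  # local import keeps the illustrative helper self-contained
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (bits >> 31,                    # v[31]     sign
            ((bits >> 23) & 0xff) - 127,   # v[23:31]  exponent, bias removed
            bits & 0x7fffff)               # v[0:23]   mantissa

# _split_f32(1.5)  == (0, 0, 0x400000)   # 1.5  = 0b1.1  * 2**0
# _split_f32(-2.0) == (1, 1, 0)          # -2.0 = -0b1.0 * 2**1
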
class FPOp:
    def __init__(self, width):
        self.width = width

        self.v = Signal(width)
        self.stb = Signal()
        self.ack = Signal()

    def ports(self):
        return [self.v, self.stb, self.ack]


class Overflow:
    def __init__(self):
        self.guard = Signal()     # tot[2]
        self.round_bit = Signal() # tot[1]
        self.sticky = Signal()    # tot[0]


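# Illustrative sketch (plain Python, not used by the hardware): the
# round-to-nearest-even decision that FPBase.roundz applies, written out on
# integers, using the guard / round / sticky bits that Overflow tracks.
# The helper name _round_nearest_even is made up for this example.

def _round_nearest_even(mantissa, guard, round_bit, sticky):
    """ rounds a mantissa up when the discarded bits are past half-way, or
        exactly half-way and the mantissa is odd (ties-to-even)
    """
    if guard and (round_bit or sticky or (mantissa & 1)):
        mantissa += 1
    return mantissa

# _round_nearest_even(0b1010, 1, 0, 0) == 0b1010  # tie, even LSB: stays
# _round_nearest_even(0b1011, 1, 0, 0) == 0b1100  # tie, odd LSB: rounds up
# _round_nearest_even(0b1010, 1, 1, 0) == 0b1011  # past half-way: rounds up

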
class FPBase:
    """ IEEE754 Floating Point Base Class

        contains functions common to FP manipulation, such as extracting
        and packing operands, normalisation, denormalisation, rounding etc.
    """

    def get_op(self, m, op, v, next_state):
        """ copies (decodes) the operand and moves to the next state
            when both stb and ack are 1.  the transfer is acknowledged
            by dropping ack back to zero.
        """
        with m.If(op.ack & op.stb):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            m.d.sync += op.ack.eq(1)
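        # handshake walkthrough (comments only): the source holds v steady
        # and raises stb; this state drives ack high while idle, and on the
        # cycle where stb and ack are both high it latches v, pulls ack back
        # low and moves on.  put_z (below) mirrors this with the roles of
        # stb and ack swapped.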

    def denormalise(self, m, a):
        """ denormalises a number: if the exponent indicates zero/denormal
            (e == -127), limit the exponent to -126; otherwise set the
            implicit top bit of the mantissa
        """
        with m.If(a.e == a.N127):
            m.d.sync += a.e.eq(-126)  # limit a exponent
        with m.Else():
            m.d.sync += a.m[-1].eq(1) # set top mantissa bit

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync += [
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            NOTE: just like "align", this one keeps going round every clock
            until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
            the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync += [
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state
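        # worked example (comments only): if the subtraction in add_0
        # cancelled the top bits, e.g. z.m = 0b000...0011 with e = 10,
        # normalise_1 loops, shifting the mantissa up and decrementing e by
        # one per clock (pulling the guard bit in at the bottom) until the
        # top bit is set.  normalise_2 is the mirror case, shifting down
        # until the exponent comes back up to -126.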

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1)     # mantissa rounds up
            with m.If(z.m == z.m1s):        # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised: correct the exponent to -127 so that pack encodes
        # a zero exponent field once the bias is added back on
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(0)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
            out_z.stb.eq(1),
            out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state


class FPADD(FPBase):

    def __init__(self, width):
        FPBase.__init__(self)
        self.width = width

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd
        """
        m = Module()

        # Latches
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, 24)

        tot = Signal(28) # 28 bits: carry, top bit, 23 mantissa, guard/round/sticky

        of = Overflow()

        with m.FSM() as fsm:

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b, "special_cases")

            # ******
            # special cases: NaNs, infs, zeros, denormalised
            # NOTE: some of these are unique to add.  see "Special Operations"
            # https://steve.hollasch.net/cgindex/coding/ieeefloat.html

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if b is also inf and the signs differ, inf - inf = NaN
                    with m.If((b.e == b.P128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if both a and b are zero, return zero
                # (negative only when both inputs are negative)
                with m.Elif(a.is_zero() & b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1])

                # if a is zero return b
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1])

                # if b is zero return a
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    self.denormalise(m, a)
                    self.denormalise(m, b)

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            # it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += b.shift_down()
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += a.shift_down()
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"
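                # worked example (comments only): a = 1.0 (e=0), b = 4.0
                # (e=2): this state runs twice, each pass shifting a's
                # mantissa down one place and raising a.e by one, with any
                # bit that drops off the bottom OR-ed into a's sticky bit,
                # until both operands carry the same exponent.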

            # ******
            # First stage of add.  covers the same-sign case (add the
            # mantissas) and, for differing signs, subtracts the smaller
            # mantissa from the larger so the result stays positive and
            # accuracy is preserved.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z.s.eq(a.s)
                    ]
                # a mantissa greater than b, use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z.s.eq(a.s)
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z.s.eq(b.s)
                    ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when the sum is too big (tot[27] is the carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows.  shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        of.guard.eq(tot[3]),
                        of.round_bit.eq(tot[2]),
                        of.sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                    ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        of.guard.eq(tot[2]),
                        of.round_bit.eq(tot[1]),
                        of.sticky.eq(tot[0])
                    ]
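                # worked example (comments only): 1.5 + 1.5 sets tot[27]
                # (the mantissa sum carries out), so the result is shifted
                # down one place and z.e incremented, giving 1.1b * 2**1
                # = 3.0.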

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m


if __name__ == "__main__":
    alu = FPADD(width=32)
    main(alu, ports=alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports())


# works... but don't use, just do "python fname.py convert -t v"
#print (verilog.convert(alu, ports=[
#    ports=alu.in_a.ports() + \
#          alu.in_b.ports() + \
#          alu.out_z.ports())
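

# Minimal simulation sketch (an illustrative assumption, not part of this
# module): drives the stb/ack handshake of FPADD from a generator-style
# testbench.  It assumes the nmigen.compat.sim.run_simulation wrapper that
# nmigen provided at the time; the helper name _example_testbench is made
# up.  1.0 (0x3F800000) + 2.0 (0x40000000) should pack to 3.0 (0x40400000).

def _example_testbench(dut):
    # present operand a; the FSM raises ack while waiting and drops it
    # once the value has been latched
    yield dut.in_a.v.eq(0x3F800000)
    yield dut.in_a.stb.eq(1)
    yield
    yield
    # present operand b
    yield dut.in_b.v.eq(0x40000000)
    yield dut.in_b.stb.eq(1)
    yield
    yield
    # wait for the result strobe, read the value, then acknowledge it
    while not (yield dut.out_z.stb):
        yield
    result = yield dut.out_z.v
    yield dut.out_z.ack.eq(1)
    yield
    yield dut.out_z.ack.eq(0)
    print(hex(result)) # expected: 0x40400000

# to run it (assumption: the compat simulator is importable):
#   from nmigen.compat.sim import run_simulation
#   dut = FPADD(width=32)
#   run_simulation(dut, _example_testbench(dut), vcd_name="fpadd_example.vcd")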