add comment on special operations
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat, Const
6 from nmigen.cli import main, verilog
7
8
class FPNum:
    """ Floating-point Number Class, variable-width TODO (currently 32-bit)

        Contains signals for an incoming copy of the value, decoded into
        sign / exponent / mantissa.
        Also contains encoding functions, creation and recognition of
        zero, NaN and inf (all signed)

        Four extra bits are included in the mantissa: the top bit
        (m[-1]) is effectively a carry-overflow.  The other three are
        guard (m[2]), round (m[1]), and sticky (m[0])

        NOTE: decode() and create() still hard-code the single-precision
        field positions (mantissa v[0:23], exponent v[23:31], sign v[31])
        and the bias of 127, so only width=32 is meaningful at present.

        Every method returns nmigen expressions or lists of assignments:
        nothing is added to a Module here, the caller does that.
    """
    def __init__(self, width, m_width=None):
        """ creates the signals and constants for one number

            width:   bit-width of the packed value (v)
            m_width: bit-width of the mantissa signal (m).  defaults to
                     width - 5 (27 bits when width is 32).
        """
        self.width = width
        if m_width is None:
            m_width = width - 5 # mantissa extra bits (top,guard,round)
        self.v = Signal(width)      # Latched copy of value
        self.m = Signal(m_width)    # Mantissa
        self.e = Signal((10, True)) # Exponent: 10 bits, signed
        self.s = Signal()           # Sign bit

        # constants, sized to match the mantissa and exponent signals
        self.mzero = Const(0, (m_width, False))  # mantissa all zeros
        self.m1s = Const(-1, (m_width, False))   # mantissa all ones
        self.P128 = Const(128, (10, True))
        self.P127 = Const(127, (10, True))
        self.N127 = Const(-127, (10, True))
        self.N126 = Const(-126, (10, True))

    def decode(self, v):
        """ decodes a latched value into sign / exponent / mantissa

            bias is subtracted here, from the exponent.  exponent
            is extended to 10 bits so that subtract 127 is done on
            a 10-bit number

            three zero bits are placed below the 23 mantissa bits of v:
            these are the guard, round and sticky positions.

            returns a list of assignments to m, e and s.
        """
        return [self.m.eq(Cat(0, 0, 0, v[0:23])), # mantissa
                self.e.eq(v[23:31] - self.P127),  # exp (minus bias)
                self.s.eq(v[31]),                 # sign
               ]

    def create(self, s, e, m):
        """ creates a value from sign / exponent / mantissa

            bias is added here, to the exponent

            returns a list of assignments to the fields of self.v
        """
        return [
            self.v[31].eq(s),                # sign
            self.v[23:31].eq(e + self.P127), # exp (add on bias)
            self.v[0:23].eq(m)               # mantissa
        ]

    def shift_down(self):
        """ shifts a mantissa down by one. exponent is increased to compensate

            accuracy is lost as a result in the mantissa however there are 3
            guard bits (the latter of which is the "sticky" bit)

            the new bit 0 is the OR of the old bits 0 and 1, so a one that
            is shifted out stays recorded in the sticky bit.  the top bit
            is filled with zero.
        """
        return [self.e.eq(self.e + 1),
                self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
               ]

    def nan(self, s):
        """ assignments that set v to a NaN (exponent field all ones,
            mantissa bit 22 set) with sign s
        """
        return self.create(s, self.P128, 1<<22)

    def inf(self, s):
        """ assignments that set v to infinity (exponent field all ones,
            mantissa zero) with sign s
        """
        return self.create(s, self.P128, 0)

    def zero(self, s):
        """ assignments that set v to zero (exponent field zero,
            mantissa zero) with sign s
        """
        return self.create(s, self.N127, 0)

    def is_nan(self):
        """ true when the exponent is 128 and the mantissa is non-zero
        """
        return (self.e == self.P128) & (self.m != 0)

    def is_inf(self):
        """ true when the exponent is 128 and the mantissa is zero
        """
        return (self.e == self.P128) & (self.m == 0)

    def is_zero(self):
        """ true when the exponent is -127 and the mantissa is zero
        """
        return (self.e == self.N127) & (self.m == self.mzero)

    def is_overflowed(self):
        """ true when the exponent is greater than 127
        """
        return (self.e > self.P127)

    def is_denormalised(self):
        """ true when the exponent is -126 and the top mantissa bit is clear

            the top bit is selected with m[-1] rather than a fixed index,
            so that the check follows whatever m_width was given (for a
            24-bit mantissa m[-1] is m[23]).
        """
        return (self.e == self.N126) & (self.m[-1] == 0)
93
94
class FPOp:
    """ One operand (or result) port: a value plus stb / ack handshake

        v is the value, width bits wide.  stb and ack are single-bit
        signals.
    """
    def __init__(self, width):
        """ creates the three signals of the port

            width: bit-width of the value signal
        """
        # NOTE: the signals are assigned straight to attributes so that
        # nmigen can pick up their names from the assignment target.
        self.v = Signal(width)
        self.stb = Signal()
        self.ack = Signal()

        self.width = width

    def ports(self):
        """ returns a new list of the port's signals: value, stb, ack
        """
        signals = (self.v, self.stb, self.ack)
        return list(signals)
105
106
class Overflow:
    """ Holds the three single-bit signals that sit below the result
        mantissa: guard, round and sticky.

        The tot[...] notes on each signal give the bit of the adder's
        total that the signal is loaded from when the total has not
        carried out of its top bit.
    """
    def __init__(self):
        # each signal is one bit wide (Signal() with no arguments)
        self.guard = Signal()     # tot[2]
        self.round_bit = Signal() # tot[1]
        self.sticky = Signal()    # tot[0]
112
113
class FPADD:
    """ IEEE754 floating-point adder, built as a multi-cycle state machine

        Operands arrive on in_a and in_b, the result leaves on out_z.
        All three are FPOp ports (value + stb/ack handshake).

        The states are visited in this order:
        get_a, get_b, special_cases, align, add_0, add_1, normalise_1,
        normalise_2, round, corrections, pack, put_z, then back to get_a.
        special_cases jumps straight to put_z when the answer is already
        known (NaN, inf or zero operands).  align, normalise_1 and
        normalise_2 each stay in their state for as many clocks as
        they need, moving one bit per clock.
    """
    def __init__(self, width):
        """ creates the operand and result ports

            width: bit-width of the operands and of the result
        """
        self.width = width

        self.in_a = FPOp(width)
        self.in_b = FPOp(width)
        self.out_z = FPOp(width)

    def get_op(self, m, op, v, next_state):
        """ this function moves to the next state and copies the operand
            when both stb and ack are 1.
            acknowledgement is sent by setting ack to ZERO.

            m:          the Module being built (inside an FSM state)
            op:         the FPOp port to read from
            v:          the FPNum that op.v is decoded into
            next_state: name of the state to move to once the operand
                        has been taken
        """
        with m.If((op.ack) & (op.stb)):
            m.next = next_state
            m.d.sync += [
                v.decode(op.v),
                op.ack.eq(0)
            ]
        with m.Else():
            # not transferred yet: keep (or start) offering ack
            m.d.sync += op.ack.eq(1)

    def normalise_1(self, m, z, of, next_state):
        """ first stage normalisation

            shifts the mantissa UP one bit per clock, for as long as the
            top mantissa bit is clear and the exponent is above -126.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If((z.m[-1] == 0) & (z.e > z.N126)):
            m.d.sync +=[
                z.e.eq(z.e - 1),            # DECREASE exponent
                z.m.eq(z.m << 1),           # shift mantissa UP
                z.m[0].eq(of.guard),        # steal guard bit (was tot[2])
                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                of.round_bit.eq(0),         # reset round bit
            ]
        with m.Else():
            m.next = next_state

    def normalise_2(self, m, z, of, next_state):
        """ second stage normalisation

            shifts the mantissa DOWN one bit per clock, for as long as
            the exponent is below -126.  the bit shifted out goes into
            guard, guard moves to round, and round is ORed into sticky.

            NOTE: just like "align", this one keeps going round every clock
                  until the result's exponent is within acceptable "range"
            NOTE: the weirdness of reassigning guard and round is due to
                  the extra mantissa bits coming from tot[0..2]
        """
        with m.If(z.e < z.N126):
            m.d.sync +=[
                z.e.eq(z.e + 1),  # INCREASE exponent
                z.m.eq(z.m >> 1), # shift mantissa DOWN
                of.guard.eq(z.m[0]),
                of.round_bit.eq(of.guard),
                of.sticky.eq(of.sticky | of.round_bit)
            ]
        with m.Else():
            m.next = next_state

    def roundz(self, m, z, of, next_state):
        """ performs rounding on the output.  TODO: different kinds of rounding

            the mantissa is incremented when guard is set and any of
            round, sticky or the mantissa's lowest bit is also set.
            if the mantissa was all ones beforehand the exponent is
            incremented as well.
        """
        m.next = next_state
        with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])):
            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
            with m.If(z.m == z.m1s): # all 1s
                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up

    def corrections(self, m, z, next_state):
        """ denormalisation and sign-bug corrections
        """
        m.next = next_state
        # denormalised, correct exponent to zero: -127 becomes an exponent
        # field of zero once pack adds the bias back on.  it is the
        # EXPONENT that is set here: the mantissa must be left alone.
        with m.If(z.is_denormalised()):
            m.d.sync += z.e.eq(z.N127)
        # FIX SIGN BUG: -a + a = +0.
        with m.If((z.e == z.N126) & (z.m[0:] == 0)):
            m.d.sync += z.s.eq(0)

    def pack(self, m, z, next_state):
        """ packs the result into the output (detects overflow->Inf)
        """
        m.next = next_state
        # if overflow occurs, return inf.  the sign of the result is
        # kept, so that a negative overflow gives -inf
        with m.If(z.is_overflowed()):
            m.d.sync += z.inf(z.s)
        with m.Else():
            m.d.sync += z.create(z.s, z.e, z.m)

    def put_z(self, m, z, out_z, next_state):
        """ put_z: stores the result in the output.  raises stb and waits
            for ack to be set to 1 before moving to the next state.
            resets stb back to zero when that occurs, as acknowledgement.
        """
        m.d.sync += [
          out_z.stb.eq(1),
          out_z.v.eq(z.v)
        ]
        with m.If(out_z.stb & out_z.ack):
            # this later assignment takes priority over stb.eq(1) above
            m.d.sync += out_z.stb.eq(0)
            m.next = next_state

    def get_fragment(self, platform=None):
        """ creates the HDL code-fragment for FPAdd

            platform is not used.  returns the Module containing the
            state machine.
        """
        m = Module()

        # Latches
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, 24)

        tot = Signal(28)     # sticky/round/guard bits, 23 result, 1 overflow

        of = Overflow()

        with m.FSM() as fsm:

            # ******
            # gets operand a

            with m.State("get_a"):
                self.get_op(m, self.in_a, a, "get_b")

            # ******
            # gets operand b

            with m.State("get_b"):
                self.get_op(m, self.in_b, b, "special_cases")

            # ******
            # special cases: NaNs, infs, zeros, denormalised
            # NOTE: some of these are unique to add.  see "Special Operations"
            # https://steve.hollasch.net/cgindex/coding/ieeefloat.html

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(a.is_nan() | b.is_nan()):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if a is inf and signs don't match return NaN
                    # (b cannot be NaN here, so exponent 128 means inf)
                    with m.If((b.e == b.P128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.is_inf()):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if a is zero and b zero return signed-a/b
                with m.Elif(a.is_zero() & b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1])

                # if a is zero return b
                with m.Elif(a.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1])

                # if b is zero return a
                with m.Elif(b.is_zero()):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    # denormalise a check
                    with m.If(a.e == a.N127):
                        m.d.sync += a.e.eq(-126) # limit a exponent
                    with m.Else():
                        m.d.sync += a.m[-1].eq(1) # set top mantissa bit
                    # denormalise b check
                    with m.If(b.e == b.N127):
                        m.d.sync += b.e.eq(-126) # limit b exponent
                    with m.Else():
                        m.d.sync += b.m[-1].eq(1) # set top mantissa bit

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            #         it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += b.shift_down()
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += a.shift_down()
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add.  covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z.s.eq(a.s)
                    ]
                # a mantissa greater than b, use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z.s.eq(a.s)
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z.s.eq(b.s)
                ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when tot sum is too big (tot[27] is kinda a carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows. shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        of.guard.eq(tot[3]),
                        of.round_bit.eq(tot[2]),
                        of.sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1)
                ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        of.guard.eq(tot[2]),
                        of.round_bit.eq(tot[1]),
                        of.sticky.eq(tot[0])
                ]

            # ******
            # First stage of normalisation.

            with m.State("normalise_1"):
                self.normalise_1(m, z, of, "normalise_2")

            # ******
            # Second stage of normalisation.

            with m.State("normalise_2"):
                self.normalise_2(m, z, of, "round")

            # ******
            # rounding stage

            with m.State("round"):
                self.roundz(m, z, of, "corrections")

            # ******
            # correction stage

            with m.State("corrections"):
                self.corrections(m, z, "pack")

            # ******
            # pack stage

            with m.State("pack"):
                self.pack(m, z, "put_z")

            # ******
            # put_z stage

            with m.State("put_z"):
                self.put_z(m, z, self.out_z, "get_a")

        return m
402
403
if __name__ == "__main__":
    # build a 32-bit adder and hand it to the nmigen command-line driver,
    # exposing the signals of both operand ports and of the result port
    alu = FPADD(width=32)
    port_list = []
    for fp_op in (alu.in_a, alu.in_b, alu.out_z):
        port_list = port_list + fp_op.ports()
    main(alu, ports=port_list)
407
408
409 # works... but don't use, just do "python fname.py convert -t v"
410 #print (verilog.convert(alu, ports=[
411 # ports=alu.in_a.ports() + \
412 # alu.in_b.ports() + \
413 # alu.out_z.ports())