create and use decode function
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
from nmigen import Module, Signal, Cat, Const
from nmigen.cli import main
7
class FPNum:
    """Latched IEEE 754 value together with its unpacked fields.

    NOTE(review): the field positions (sign at bit 31, exponent at bits
    23..30, 23-bit fraction) and the exponent bias of 127 are hard-coded for
    single precision; ``width`` only sizes the signals.

    :param width: width of the packed value ``v``.
    :param m_width: width of the unpacked mantissa ``m``; defaults to
        ``width + 3``.
    """

    def __init__(self, width, m_width=None):
        self.width = width
        if m_width is None:
            m_width = width + 3
        self.v = Signal(width)  # Latched copy of the packed value
        # Mantissa. decode() puts the 23 fraction bits at m[3:26] above three
        # zero bits (guard/round/sticky room), and the adder sets the implicit
        # leading one at m[26], so operands use 27 bits; the default
        # width + 3 is larger than strictly needed.
        self.m = Signal(m_width)
        self.e = Signal((10, True))  # Exponent: 10 bits, signed, bias removed
        self.s = Signal()  # Sign bit

    def decode(self):
        """Return statements that unpack ``v`` into ``m``, ``e`` and ``s``."""
        v = self.v
        # The bias must be a *signed* constant.  ``v[23:31] - 127`` is
        # unsigned minus unsigned, which nmigen evaluates as an unsigned
        # 9-bit value: a biased exponent below 127 would wrap to a large
        # positive number instead of becoming negative in the signed ``e``.
        bias = Const(127, (10, True))
        return [
            self.m.eq(Cat(Const(0, 3), v[0:23])),  # mantissa, 3 zero LSBs
            self.e.eq(v[23:31] - bias),  # exponent (take off bias)
            self.s.eq(v[31]),  # sign
        ]

    def create(self, s, e, m):
        """Return statements packing sign ``s``, biased exponent ``e`` and
        23-bit fraction ``m`` into ``v``."""
        return [
            self.v[31].eq(s),  # sign
            self.v[23:31].eq(e),  # exp
            self.v[0:23].eq(m),  # mantissa
        ]

    def nan(self, s):
        """Return statements setting ``v`` to a quiet NaN with sign ``s``."""
        return self.create(s, 0xff, 1 << 22)

    def inf(self, s):
        """Return statements setting ``v`` to infinity with sign ``s``."""
        return self.create(s, 0xff, 0)
37
38
class FPADD:
    """IEEE 754 single-precision adder as a multi-cycle nmigen FSM.

    A port of Jonathan P Dawson's Verilog adder. Each operand is captured
    through a stb/ack handshake, then unpacked and special-cased (NaN, inf,
    zero, denormal). The mantissas are aligned, added, normalised and
    rounded. The packed result is offered on ``out_z`` (with ``out_z_stb``)
    until ``out_z_ack`` is seen.

    NOTE(review): ``width`` only sizes the ports; the datapath hard-codes the
    32-bit layout (exponent at bits 23..30, bias 127).
    """

    def __init__(self, width):
        self.width = width

        # operand a: value and strobe are read, ack is driven by the FSM
        self.in_a = Signal(width)
        self.in_a_stb = Signal()
        self.in_a_ack = Signal()

        # operand b: value and strobe are read, ack is driven by the FSM
        self.in_b = Signal(width)
        self.in_b_stb = Signal()
        self.in_b_ack = Signal()

        # result z: value and strobe are driven by the FSM, ack is read
        self.out_z = Signal(width)
        self.out_z_stb = Signal()
        self.out_z_ack = Signal()

    def _get_operand(self, m, state, data, stb, ack, dest, next_state):
        """Emit FSM state ``state``, which latches ``data`` into ``dest``.

        ``ack`` is raised while waiting. Once ``ack`` and ``stb`` are both
        high, the value is captured, ``ack`` is dropped and the FSM moves
        to ``next_state``.
        """
        with m.State(state):
            with m.If(ack & stb):
                m.next = next_state
                m.d.sync += [
                    dest.eq(data),
                    ack.eq(0),
                ]
            with m.Else():
                m.d.sync += ack.eq(1)

    def get_fragment(self, platform):
        m = Module()

        # Latches
        a = FPNum(self.width)
        b = FPNum(self.width)
        z = FPNum(self.width, 24)

        z_s = Signal()  # result sign on the normal (non-special-case) path

        guard = Signal()
        round_bit = Signal()
        sticky = Signal()

        tot = Signal(28)  # mantissa sum/difference; bit 27 is the carry-out

        with m.FSM():

            # ******
            # gets operands a then b (get_a is declared first: reset state)

            self._get_operand(m, "get_a", self.in_a, self.in_a_stb,
                              self.in_a_ack, a.v, "get_b")
            self._get_operand(m, "get_b", self.in_b, self.in_b_stb,
                              self.in_b_ack, b.v, "unpack")

            # ******
            # unpacks operands into sign, mantissa and exponent

            with m.State("unpack"):
                m.next = "special_cases"
                m.d.sync += a.decode()
                m.d.sync += b.decode()

            # ******
            # special cases: NaNs, infs, zeros, denormalised

            with m.State("special_cases"):

                # if a is NaN or b is NaN return NaN
                with m.If(((a.e == 128) & (a.m != 0)) |
                          ((b.e == 128) & (b.m != 0))):
                    m.next = "put_z"
                    m.d.sync += z.nan(1)

                # if a is inf return inf (or NaN)
                with m.Elif(a.e == 128):
                    m.next = "put_z"
                    m.d.sync += z.inf(a.s)
                    # if a is inf and signs don't match return NaN
                    # (this later assignment overrides the inf above)
                    with m.If((b.e == 128) & (a.s != b.s)):
                        m.d.sync += z.nan(b.s)

                # if b is inf return inf
                with m.Elif(b.e == 128):
                    m.next = "put_z"
                    m.d.sync += z.inf(b.s)

                # if a is zero and b zero return signed-a/b
                with m.Elif(((a.e == -127) & (a.m == 0)) &
                            ((b.e == -127) & (b.m == 0))):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s & b.s, b.e[0:8] + 127,
                                         b.m[3:26])

                # if a is zero return b
                with m.Elif((a.e == -127) & (a.m == 0)):
                    m.next = "put_z"
                    m.d.sync += z.create(b.s, b.e[0:8] + 127, b.m[3:26])

                # if b is zero return a
                with m.Elif((b.e == -127) & (b.m == 0)):
                    m.next = "put_z"
                    m.d.sync += z.create(a.s, a.e[0:8] + 127, a.m[3:26])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    # denormalise a check
                    with m.If(a.e == -127):
                        m.d.sync += a.e.eq(-126)  # limit a exponent
                    with m.Else():
                        m.d.sync += a.m[26].eq(1)  # set implicit leading 1
                    # denormalise b check
                    with m.If(b.e == -127):
                        m.d.sync += b.e.eq(-126)  # limit b exponent
                    with m.Else():
                        m.d.sync += b.m[26].eq(1)  # set implicit leading 1

            # ******
            # align. NOTE: this does *not* do single-cycle multi-shifting,
            # it *STAYS* in the align state until the exponents match.
            # Bit 0 ORs in the bit shifted out so it acts as a sticky bit
            # (the later assignment to m[0] overrides the shift for bit 0).

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a.e > b.e):
                    m.d.sync += [
                        b.e.eq(b.e + 1),
                        b.m.eq(b.m >> 1),
                        b.m[0].eq(b.m[0] | b.m[1]),
                    ]
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a.e < b.e):
                    m.d.sync += [
                        a.e.eq(a.e + 1),
                        a.m.eq(a.m >> 1),
                        a.m[0].eq(a.m[0] | a.m[1]),
                    ]
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add. covers same-sign (add) and subtract
            # special-casing when mantissas are greater or equal, to
            # give greatest accuracy.

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z.e.eq(a.e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a.s == b.s):
                    m.d.sync += [
                        tot.eq(a.m + b.m),
                        z_s.eq(a.s),
                    ]
                # a mantissa greater than or equal to b, use a
                with m.Elif(a.m >= b.m):
                    m.d.sync += [
                        tot.eq(a.m - b.m),
                        z_s.eq(a.s),
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b.m - a.m),
                        z_s.eq(b.s),
                    ]

            # ******
            # Second stage of add: preparation for normalisation.
            # detects when tot sum is too big (tot[27] is the carry bit)

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows. shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z.m.eq(tot[4:28]),
                        guard.eq(tot[3]),
                        round_bit.eq(tot[2]),
                        sticky.eq(tot[1] | tot[0]),
                        z.e.eq(z.e + 1),
                    ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z.m.eq(tot[3:27]),
                        guard.eq(tot[2]),
                        round_bit.eq(tot[1]),
                        sticky.eq(tot[0]),
                    ]

            # ******
            # First stage of normalisation.
            # NOTE: just like "align", this one keeps going round every clock
            # until the result's exponent is within acceptable "range"
            # NOTE: the weirdness of reassigning guard and round is due to
            # the extra mantissa bits coming from tot[0..2]

            with m.State("normalise_1"):
                with m.If((z.m[23] == 0) & (z.e > -126)):
                    m.d.sync += [
                        z.e.eq(z.e - 1),  # DECREASE exponent
                        z.m.eq(z.m << 1),  # shift mantissa UP
                        z.m[0].eq(guard),  # steal guard bit (was tot[2])
                        guard.eq(round_bit),  # steal round_bit (was tot[1])
                        round_bit.eq(0),  # nothing left below round_bit
                    ]
                with m.Else():
                    m.next = "normalise_2"

            # ******
            # Second stage of normalisation.
            # NOTE: just like "align", this one keeps going round every clock
            # until the result's exponent is within acceptable "range"

            with m.State("normalise_2"):
                with m.If(z.e < -126):
                    m.d.sync += [
                        z.e.eq(z.e + 1),  # INCREASE exponent
                        z.m.eq(z.m >> 1),  # shift mantissa DOWN
                        guard.eq(z.m[0]),
                        round_bit.eq(guard),
                        sticky.eq(sticky | round_bit),
                    ]
                with m.Else():
                    m.next = "round"

            # ******
            # rounding stage (round-to-nearest, ties-to-even)

            with m.State("round"):
                m.next = "pack"
                with m.If(guard & (round_bit | sticky | z.m[0])):
                    m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
                    with m.If(z.m == 0xffffff):  # all 1s
                        m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up

            # ******
            # pack sign, exponent and mantissa into z.v
            # (later assignments override earlier ones)

            with m.State("pack"):
                m.next = "put_z"
                m.d.sync += z.create(z_s, z.e[0:8] + 127, z.m[0:23])
                # denormalised result: exponent field is zero
                with m.If((z.e == -126) & (z.m[23] == 0)):
                    m.d.sync += z.v[23:31].eq(0)
                # exact zero result is +0 (so -a + a = +0)
                with m.If((z.e == -126) & (z.m == 0)):
                    m.d.sync += z.v[31].eq(0)
                # if overflow occurs, return inf
                with m.If(z.e > 127):
                    m.d.sync += z.inf(z_s)

            # ******
            # offer the result until the consumer acknowledges it

            with m.State("put_z"):
                m.d.sync += [
                    self.out_z_stb.eq(1),
                    self.out_z.eq(z.v),
                ]
                with m.If(self.out_z_stb & self.out_z_ack):
                    m.d.sync += self.out_z_stb.eq(0)
                    m.next = "get_a"

        return m
291
# Reference implementation: the body of the original Verilog single-precision
# adder (its "always @(posedge clk)" process and port assignments) that the
# nmigen classes above are being ported from. It is a bare string expression,
# so Python evaluates and discards it; it is kept only for side-by-side
# comparison while porting.
# NOTE(review): inside the Verilog, the comment above the first zero check
# reads "if a is zero return b". That branch actually handles a == 0 AND
# b == 0 (returning a signed zero). The Verilog text is left untouched here.
"""
always @(posedge clk)
begin

  case(state)

    get_a:
    begin
      s_in_a_ack <= 1;
      if (s_in_a_ack && in_a_stb) begin
        a <= in_a;
        s_in_a_ack <= 0;
        state <= get_b;
      end
    end

    get_b:
    begin
      s_in_b_ack <= 1;
      if (s_in_b_ack && in_b_stb) begin
        b <= in_b;
        s_in_b_ack <= 0;
        state <= unpack;
      end
    end

    unpack:
    begin
      a_m <= {a[22 : 0], 3'd0};
      b_m <= {b[22 : 0], 3'd0};
      a_e <= a[30 : 23] - 127;
      b_e <= b[30 : 23] - 127;
      a_s <= a[31];
      b_s <= b[31];
      state <= special_cases;
    end

    special_cases:
    begin
      //if a is NaN or b is NaN return NaN
      if ((a_e == 128 && a_m != 0) || (b_e == 128 && b_m != 0)) begin
        z[31] <= 1;
        z[30:23] <= 255;
        z[22] <= 1;
        z[21:0] <= 0;
        state <= put_z;
      //if a is inf return inf
      end else if (a_e == 128) begin
        z[31] <= a_s;
        z[30:23] <= 255;
        z[22:0] <= 0;
        //if a is inf and signs don't match return nan
        if ((b_e == 128) && (a_s != b_s)) begin
          z[31] <= b_s;
          z[30:23] <= 255;
          z[22] <= 1;
          z[21:0] <= 0;
        end
        state <= put_z;
      //if b is inf return inf
      end else if (b_e == 128) begin
        z[31] <= b_s;
        z[30:23] <= 255;
        z[22:0] <= 0;
        state <= put_z;
      //if a is zero return b
      end else if ((($signed(a_e) == -127) && (a_m == 0)) && (($signed(b_e) == -127) && (b_m == 0))) begin
        z[31] <= a_s & b_s;
        z[30:23] <= b_e[7:0] + 127;
        z[22:0] <= b_m[26:3];
        state <= put_z;
      //if a is zero return b
      end else if (($signed(a_e) == -127) && (a_m == 0)) begin
        z[31] <= b_s;
        z[30:23] <= b_e[7:0] + 127;
        z[22:0] <= b_m[26:3];
        state <= put_z;
      //if b is zero return a
      end else if (($signed(b_e) == -127) && (b_m == 0)) begin
        z[31] <= a_s;
        z[30:23] <= a_e[7:0] + 127;
        z[22:0] <= a_m[26:3];
        state <= put_z;
      end else begin
        //Denormalised Number
        if ($signed(a_e) == -127) begin
          a_e <= -126;
        end else begin
          a_m[26] <= 1;
        end
        //Denormalised Number
        if ($signed(b_e) == -127) begin
          b_e <= -126;
        end else begin
          b_m[26] <= 1;
        end
        state <= align;
      end
    end

    align:
    begin
      if ($signed(a_e) > $signed(b_e)) begin
        b_e <= b_e + 1;
        b_m <= b_m >> 1;
        b_m[0] <= b_m[0] | b_m[1];
      end else if ($signed(a_e) < $signed(b_e)) begin
        a_e <= a_e + 1;
        a_m <= a_m >> 1;
        a_m[0] <= a_m[0] | a_m[1];
      end else begin
        state <= add_0;
      end
    end

    add_0:
    begin
      z_e <= a_e;
      if (a_s == b_s) begin
        tot <= a_m + b_m;
        z_s <= a_s;
      end else begin
        if (a_m >= b_m) begin
          tot <= a_m - b_m;
          z_s <= a_s;
        end else begin
          tot <= b_m - a_m;
          z_s <= b_s;
        end
      end
      state <= add_1;
    end

    add_1:
    begin
      if (tot[27]) begin
        z_m <= tot[27:4];
        guard <= tot[3];
        round_bit <= tot[2];
        sticky <= tot[1] | tot[0];
        z_e <= z_e + 1;
      end else begin
        z_m <= tot[26:3];
        guard <= tot[2];
        round_bit <= tot[1];
        sticky <= tot[0];
      end
      state <= normalise_1;
    end

    normalise_1:
    begin
      if (z_m[23] == 0 && $signed(z_e) > -126) begin
        z_e <= z_e - 1;
        z_m <= z_m << 1;
        z_m[0] <= guard;
        guard <= round_bit;
        round_bit <= 0;
      end else begin
        state <= normalise_2;
      end
    end

    normalise_2:
    begin
      if ($signed(z_e) < -126) begin
        z_e <= z_e + 1;
        z_m <= z_m >> 1;
        guard <= z_m[0];
        round_bit <= guard;
        sticky <= sticky | round_bit;
      end else begin
        state <= round;
      end
    end

    round:
    begin
      if (guard && (round_bit | sticky | z_m[0])) begin
        z_m <= z_m + 1;
        if (z_m == 24'hffffff) begin
          z_e <=z_e + 1;
        end
      end
      state <= pack;
    end

    pack:
    begin
      z[22 : 0] <= z_m[22:0];
      z[30 : 23] <= z_e[7:0] + 127;
      z[31] <= z_s;
      if ($signed(z_e) == -126 && z_m[23] == 0) begin
        z[30 : 23] <= 0;
      end
      if ($signed(z_e) == -126 && z_m[23:0] == 24'h0) begin
        z[31] <= 1'b0; // FIX SIGN BUG: -a + a = +0.
      end
      //if overflow occurs, return inf
      if ($signed(z_e) > 127) begin
        z[22 : 0] <= 0;
        z[30 : 23] <= 255;
        z[31] <= z_s;
      end
      state <= put_z;
    end

    put_z:
    begin
      s_out_z_stb <= 1;
      s_out_z <= z;
      if (s_out_z_stb && out_z_ack) begin
        s_out_z_stb <= 0;
        state <= get_a;
      end
    end

  endcase

  if (rst == 1) begin
    state <= get_a;
    s_in_a_ack <= 0;
    s_in_b_ack <= 0;
    s_out_z_stb <= 0;
  end

end
assign in_a_ack = s_in_a_ack;
assign in_b_ack = s_in_b_ack;
assign out_z_stb = s_out_z_stb;
assign out_z = s_out_z;

endmodule
"""
526
if __name__ == "__main__":
    # Build a 32-bit adder and hand it to nmigen's command-line front end,
    # exposing the data, strobe and acknowledge ports of all three channels.
    adder = FPADD(width=32)
    a_ports = [adder.in_a, adder.in_a_stb, adder.in_a_ack]
    b_ports = [adder.in_b, adder.in_b_stb, adder.in_b_ack]
    z_ports = [adder.out_z, adder.out_z_stb, adder.out_z_ack]
    main(adder, ports=a_ports + b_ports + z_ports)