use function "create_z" which... well... creates a result from (s,e,m)
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
1 # IEEE Floating Point Adder (Single Precision)
2 # Copyright (C) Jonathan P Dawson 2013
3 # 2013-12-12
4
5 from nmigen import Module, Signal, Cat
6 from nmigen.cli import main
7
8
class FPADD:
    """IEEE 754 single-precision floating-point adder built as an nmigen FSM.

    Operands arrive on ``in_a`` / ``in_b`` and the sum leaves on ``out_z``.
    Each bus has a strobe/acknowledge pair; a word is transferred on the
    clock edge where both the strobe and the acknowledge are high.  This
    module drives ``in_a_ack``, ``in_b_ack``, ``out_z`` and ``out_z_stb`` and
    reads ``in_a_stb``, ``in_b_stb`` and ``out_z_ack``.

    The state machine is a port of the Verilog reference kept at the bottom
    of this file and steps through:

        get_a -> get_b -> unpack -> special_cases -> align -> add_0 ->
        add_1 -> normalise_1 -> normalise_2 -> round -> pack -> put_z

    NOTE: the field slices used for packing and unpacking are hard-coded for
    a 32-bit word (1 sign, 8 exponent, 23 fraction bits), so ``width`` is
    expected to be 32.
    """

    def __init__(self, width):
        """Create the port signals.

        width -- bit width of the operand and result buses.
        """
        self.width = width

        # Operand a: data, strobe (from producer), acknowledge (from us).
        self.in_a = Signal(width)
        self.in_a_stb = Signal()
        self.in_a_ack = Signal()

        # Operand b: data, strobe (from producer), acknowledge (from us).
        self.in_b = Signal(width)
        self.in_b_stb = Signal()
        self.in_b_ack = Signal()

        # Result z: data and strobe (from us), acknowledge (from consumer).
        self.out_z = Signal(width)
        self.out_z_stb = Signal()
        self.out_z_ack = Signal()

    def create_z(self, z, s, e, m):
        """Return the assignments that pack (s, e, m) into the 32-bit word z.

        z -- destination value; bit 31, bits 23..30 and bits 0..22 are written
        s -- sign bit
        e -- biased 8-bit exponent field
        m -- 23-bit fraction field

        The returned list is meant to be added to a clock domain, e.g.
        ``m.d.sync += self.create_z(...)``.
        """
        return [
            z[31].eq(s),     # sign
            z[23:31].eq(e),  # exponent
            z[0:23].eq(m),   # mantissa (fraction)
        ]

    def get_fragment(self, platform):
        """Build and return the nmigen ``Module`` implementing the adder.

        ``platform`` is not used by this method.
        """
        m = Module()

        # Latches for the two operands and the packed result.
        a = Signal(self.width)
        b = Signal(self.width)
        z = Signal(self.width)

        # Operand mantissas, 27 bits each:
        #   bit 26     - implicit leading one (set in "special_cases")
        #   bits 3..25 - the 23 fraction bits taken from the operand
        #   bits 0..2  - extra low-order bits that collect what is shifted
        #                out during alignment (guard / round / sticky)
        a_m = Signal(27)
        b_m = Signal(27)
        # Result mantissa: leading bit plus 23 fraction bits.
        z_m = Signal(24)

        # Exponent: 10 bits, signed (the exponent bias is subtracted)
        a_e = Signal((10, True))
        b_e = Signal((10, True))
        z_e = Signal((10, True))

        # Sign
        a_s = Signal()
        b_s = Signal()
        z_s = Signal()

        # Rounding bits that sit just below the result mantissa.
        guard = Signal()
        round_bit = Signal()
        sticky = Signal()

        # Sum of two 27-bit mantissas; bit 27 holds the carry out.
        tot = Signal(28)

        with m.FSM() as fsm:

            # ******
            # gets operand a

            with m.State("get_a"):
                with m.If((self.in_a_ack) & (self.in_a_stb)):
                    m.next = "get_b"
                    m.d.sync += [
                        a.eq(self.in_a),
                        self.in_a_ack.eq(0)
                    ]
                with m.Else():
                    m.d.sync += self.in_a_ack.eq(1)

            # ******
            # gets operand b

            with m.State("get_b"):
                with m.If((self.in_b_ack) & (self.in_b_stb)):
                    # Both operands are latched: start processing.  (Going
                    # back to "get_a" here would loop forever without ever
                    # computing a result.)
                    m.next = "unpack"
                    m.d.sync += [
                        b.eq(self.in_b),
                        self.in_b_ack.eq(0)
                    ]
                with m.Else():
                    m.d.sync += self.in_b_ack.eq(1)

            # ******
            # unpacks operands into sign, mantissa and exponent

            with m.State("unpack"):
                m.next = "special_cases"
                m.d.sync += [
                    # mantissa: fraction bits above three zeroed low bits
                    a_m.eq(Cat(0, 0, 0, a[0:23])),
                    b_m.eq(Cat(0, 0, 0, b[0:23])),
                    # exponent (take off exponent bias, here)
                    a_e.eq(Cat(a[23:31]) - 127),
                    b_e.eq(Cat(b[23:31]) - 127),
                    # sign
                    a_s.eq(Cat(a[31])),
                    b_s.eq(Cat(b[31]))
                ]

            # ******
            # special cases: NaNs, infs, zeros, denormalised

            with m.State("special_cases"):

                # Exponent field 255 unbiases to 128, field 0 to -127.
                a_nan = (a_e == 128) & (a_m != 0)
                b_nan = (b_e == 128) & (b_m != 0)
                a_zero = (a_e == -127) & (a_m == 0)
                b_zero = (b_e == -127) & (b_m == 0)

                # if a is NaN or b is NaN return NaN
                with m.If(a_nan | b_nan):
                    m.next = "put_z"
                    m.d.sync += self.create_z(z, 1, 255, 1<<22)

                # if a is inf return inf (or NaN)
                with m.Elif(a_e == 128):
                    m.next = "put_z"
                    m.d.sync += self.create_z(z, a_s, 255, 0)
                    # if b is inf too and signs don't match return NaN;
                    # this later assignment overrides the one above
                    with m.If((b_e == 128) & (a_s != b_s)):
                        m.d.sync += self.create_z(z, b_s, 255, 1<<22)

                # if b is inf return inf
                with m.Elif(b_e == 128):
                    m.next = "put_z"
                    m.d.sync += self.create_z(z, b_s, 255, 0)

                # if a is zero and b zero return signed zero: the result is
                # negative only when both operands are negative
                with m.Elif(a_zero & b_zero):
                    m.next = "put_z"
                    m.d.sync += self.create_z(z, a_s & b_s,
                                              b_e[0:8] + 127,
                                              b_m[3:26])

                # if a is zero return b (re-biased exponent, fraction bits)
                with m.Elif(a_zero):
                    m.next = "put_z"
                    m.d.sync += self.create_z(z, b_s,
                                              b_e[0:8] + 127,
                                              b_m[3:26])

                # if b is zero return a (re-biased exponent, fraction bits)
                with m.Elif(b_zero):
                    m.next = "put_z"
                    m.d.sync += self.create_z(z, a_s,
                                              a_e[0:8] + 127,
                                              a_m[3:26])

                # Denormalised Number checks
                with m.Else():
                    m.next = "align"
                    # denormalise a check
                    with m.If(a_e == -127):
                        m.d.sync += a_e.eq(-126)  # limit a exponent
                    with m.Else():
                        m.d.sync += a_m[26].eq(1)  # set highest mantissa bit
                    # denormalise b check
                    with m.If(b_e == -127):
                        m.d.sync += b_e.eq(-126)  # limit b exponent
                    with m.Else():
                        m.d.sync += b_m[26].eq(1)  # set highest mantissa bit

            # ******
            # align.  NOTE: this does *not* do single-cycle multi-shifting,
            #         it *STAYS* in the align state until the exponents match

            with m.State("align"):
                # exponent of a greater than b: increment b exp, shift b mant
                with m.If(a_e > b_e):
                    m.d.sync += [
                        b_e.eq(b_e + 1),
                        b_m.eq(b_m >> 1),
                        # sticky: the shift alone would drop old bit 0, so
                        # OR it into the new bit 0 (which is old bit 1)
                        b_m[0].eq(b_m[0] | b_m[1])
                    ]
                # exponent of b greater than a: increment a exp, shift a mant
                with m.Elif(a_e < b_e):
                    m.d.sync += [
                        a_e.eq(a_e + 1),
                        a_m.eq(a_m >> 1),
                        # sticky: keep the bit that would be shifted out
                        a_m[0].eq(a_m[0] | a_m[1])
                    ]
                # exponents equal: move to next stage.
                with m.Else():
                    m.next = "add_0"

            # ******
            # First stage of add

            with m.State("add_0"):
                m.next = "add_1"
                m.d.sync += z_e.eq(a_e)
                # same-sign (both negative or both positive) add mantissas
                with m.If(a_s == b_s):
                    m.d.sync += [
                        tot.eq(a_m + b_m),
                        z_s.eq(a_s)
                    ]
                # a mantissa greater than b, use a
                with m.Elif(a_m >= b_m):
                    m.d.sync += [
                        tot.eq(a_m - b_m),
                        z_s.eq(a_s)
                    ]
                # b mantissa greater than a, use b
                with m.Else():
                    m.d.sync += [
                        tot.eq(b_m - a_m),
                        z_s.eq(b_s)
                    ]

            # ******
            # Second stage of add: preparation for normalisation

            with m.State("add_1"):
                m.next = "normalise_1"
                # tot[27] gets set when the sum overflows. shift result down
                with m.If(tot[27]):
                    m.d.sync += [
                        z_m.eq(tot[4:28]),
                        guard.eq(tot[3]),
                        round_bit.eq(tot[2]),
                        sticky.eq(tot[1] | tot[0]),
                        z_e.eq(z_e + 1)
                    ]
                # tot[27] zero case
                with m.Else():
                    m.d.sync += [
                        z_m.eq(tot[3:27]),
                        guard.eq(tot[2]),
                        round_bit.eq(tot[1]),
                        sticky.eq(tot[0])
                    ]

            # ******
            # First stage of normalisation: while the top mantissa bit is
            # clear and the exponent is above the minimum, shift the mantissa
            # up one place per cycle, pulling the guard bit in at the bottom.

            with m.State("normalise_1"):
                with m.If((z_m[23] == 0) & (z_e > -126)):
                    m.d.sync += [
                        z_e.eq(z_e - 1),
                        z_m.eq(z_m << 1),
                        z_m[0].eq(guard),     # overrides bit 0 of the shift
                        guard.eq(round_bit),
                        round_bit.eq(0)
                    ]
                with m.Else():
                    m.next = "normalise_2"

            # ******
            # Second stage of normalisation: while the exponent is below the
            # minimum, shift the mantissa down one place per cycle, pushing
            # the dropped bits through guard / round / sticky.

            with m.State("normalise_2"):
                with m.If(z_e < -126):
                    m.d.sync += [
                        z_e.eq(z_e + 1),
                        z_m.eq(z_m >> 1),
                        guard.eq(z_m[0]),
                        round_bit.eq(guard),
                        sticky.eq(sticky | round_bit)
                    ]
                with m.Else():
                    m.next = "round"

            # ******
            # Rounding: round up when the guard bit is set and either a lower
            # bit is set or the mantissa is odd (round to nearest, ties to
            # even).

            with m.State("round"):
                m.next = "pack"
                with m.If(guard & (round_bit | sticky | z_m[0])):
                    m.d.sync += z_m.eq(z_m + 1)
                    # all-ones mantissa wraps to zero: bump the exponent
                    with m.If(z_m == 0xffffff):
                        m.d.sync += z_e.eq(z_e + 1)

            # ******
            # Pack the result.  The three checks after the default packing
            # are independent and each overrides part of it when it applies.

            with m.State("pack"):
                m.next = "put_z"
                m.d.sync += self.create_z(z, z_s, z_e[0:8] + 127, z_m[0:23])
                # denormalised result: exponent field is zero
                with m.If((z_e == -126) & (z_m[23] == 0)):
                    m.d.sync += z[23:31].eq(0)
                # exact zero result is always +0 (-a + a = +0)
                with m.If((z_e == -126) & (z_m == 0)):
                    m.d.sync += z[31].eq(0)
                # if overflow occurs, return inf
                with m.If(z_e > 127):
                    m.d.sync += self.create_z(z, z_s, 255, 0)

            # ******
            # Present the result and wait for the consumer to acknowledge.

            with m.State("put_z"):
                m.d.sync += [
                    self.out_z_stb.eq(1),
                    self.out_z.eq(z)
                ]
                with m.If(self.out_z_stb & self.out_z_ack):
                    m.next = "get_a"
                    # overrides the strobe assignment above
                    m.d.sync += self.out_z_stb.eq(0)

        return m
242
243 """
244 always @(posedge clk)
245 begin
246
247 case(state)
248
249 get_a:
250 begin
251 s_in_a_ack <= 1;
252 if (s_in_a_ack && in_a_stb) begin
253 a <= in_a;
254 s_in_a_ack <= 0;
255 state <= get_b;
256 end
257 end
258
259 get_b:
260 begin
261 s_in_b_ack <= 1;
262 if (s_in_b_ack && in_b_stb) begin
263 b <= in_b;
264 s_in_b_ack <= 0;
265 state <= unpack;
266 end
267 end
268
269 unpack:
270 begin
271 a_m <= {a[22 : 0], 3'd0};
272 b_m <= {b[22 : 0], 3'd0};
273 a_e <= a[30 : 23] - 127;
274 b_e <= b[30 : 23] - 127;
275 a_s <= a[31];
276 b_s <= b[31];
277 state <= special_cases;
278 end
279
280 special_cases:
281 begin
282 //if a is NaN or b is NaN return NaN
283 if ((a_e == 128 && a_m != 0) || (b_e == 128 && b_m != 0)) begin
284 z[31] <= 1;
285 z[30:23] <= 255;
286 z[22] <= 1;
287 z[21:0] <= 0;
288 state <= put_z;
289 //if a is inf return inf
290 end else if (a_e == 128) begin
291 z[31] <= a_s;
292 z[30:23] <= 255;
293 z[22:0] <= 0;
294 //if a is inf and signs don't match return nan
295 if ((b_e == 128) && (a_s != b_s)) begin
296 z[31] <= b_s;
297 z[30:23] <= 255;
298 z[22] <= 1;
299 z[21:0] <= 0;
300 end
301 state <= put_z;
302 //if b is inf return inf
303 end else if (b_e == 128) begin
304 z[31] <= b_s;
305 z[30:23] <= 255;
306 z[22:0] <= 0;
307 state <= put_z;
308 //if a and b are both zero return signed zero (sign = a_s & b_s)
309 end else if ((($signed(a_e) == -127) && (a_m == 0)) && (($signed(b_e) == -127) && (b_m == 0))) begin
310 z[31] <= a_s & b_s;
311 z[30:23] <= b_e[7:0] + 127;
312 z[22:0] <= b_m[26:3];
313 state <= put_z;
314 //if a is zero return b
315 end else if (($signed(a_e) == -127) && (a_m == 0)) begin
316 z[31] <= b_s;
317 z[30:23] <= b_e[7:0] + 127;
318 z[22:0] <= b_m[26:3];
319 state <= put_z;
320 //if b is zero return a
321 end else if (($signed(b_e) == -127) && (b_m == 0)) begin
322 z[31] <= a_s;
323 z[30:23] <= a_e[7:0] + 127;
324 z[22:0] <= a_m[26:3];
325 state <= put_z;
326 end else begin
327 //Denormalised Number
328 if ($signed(a_e) == -127) begin
329 a_e <= -126;
330 end else begin
331 a_m[26] <= 1;
332 end
333 //Denormalised Number
334 if ($signed(b_e) == -127) begin
335 b_e <= -126;
336 end else begin
337 b_m[26] <= 1;
338 end
339 state <= align;
340 end
341 end
342
343 align:
344 begin
345 if ($signed(a_e) > $signed(b_e)) begin
346 b_e <= b_e + 1;
347 b_m <= b_m >> 1;
348 b_m[0] <= b_m[0] | b_m[1];
349 end else if ($signed(a_e) < $signed(b_e)) begin
350 a_e <= a_e + 1;
351 a_m <= a_m >> 1;
352 a_m[0] <= a_m[0] | a_m[1];
353 end else begin
354 state <= add_0;
355 end
356 end
357
358 add_0:
359 begin
360 z_e <= a_e;
361 if (a_s == b_s) begin
362 tot <= a_m + b_m;
363 z_s <= a_s;
364 end else begin
365 if (a_m >= b_m) begin
366 tot <= a_m - b_m;
367 z_s <= a_s;
368 end else begin
369 tot <= b_m - a_m;
370 z_s <= b_s;
371 end
372 end
373 state <= add_1;
374 end
375
376 add_1:
377 begin
378 if (tot[27]) begin
379 z_m <= tot[27:4];
380 guard <= tot[3];
381 round_bit <= tot[2];
382 sticky <= tot[1] | tot[0];
383 z_e <= z_e + 1;
384 end else begin
385 z_m <= tot[26:3];
386 guard <= tot[2];
387 round_bit <= tot[1];
388 sticky <= tot[0];
389 end
390 state <= normalise_1;
391 end
392
393 normalise_1:
394 begin
395 if (z_m[23] == 0 && $signed(z_e) > -126) begin
396 z_e <= z_e - 1;
397 z_m <= z_m << 1;
398 z_m[0] <= guard;
399 guard <= round_bit;
400 round_bit <= 0;
401 end else begin
402 state <= normalise_2;
403 end
404 end
405
406 normalise_2:
407 begin
408 if ($signed(z_e) < -126) begin
409 z_e <= z_e + 1;
410 z_m <= z_m >> 1;
411 guard <= z_m[0];
412 round_bit <= guard;
413 sticky <= sticky | round_bit;
414 end else begin
415 state <= round;
416 end
417 end
418
419 round:
420 begin
421 if (guard && (round_bit | sticky | z_m[0])) begin
422 z_m <= z_m + 1;
423 if (z_m == 24'hffffff) begin
424 z_e <=z_e + 1;
425 end
426 end
427 state <= pack;
428 end
429
430 pack:
431 begin
432 z[22 : 0] <= z_m[22:0];
433 z[30 : 23] <= z_e[7:0] + 127;
434 z[31] <= z_s;
435 if ($signed(z_e) == -126 && z_m[23] == 0) begin
436 z[30 : 23] <= 0;
437 end
438 if ($signed(z_e) == -126 && z_m[23:0] == 24'h0) begin
439 z[31] <= 1'b0; // FIX SIGN BUG: -a + a = +0.
440 end
441 //if overflow occurs, return inf
442 if ($signed(z_e) > 127) begin
443 z[22 : 0] <= 0;
444 z[30 : 23] <= 255;
445 z[31] <= z_s;
446 end
447 state <= put_z;
448 end
449
450 put_z:
451 begin
452 s_out_z_stb <= 1;
453 s_out_z <= z;
454 if (s_out_z_stb && out_z_ack) begin
455 s_out_z_stb <= 0;
456 state <= get_a;
457 end
458 end
459
460 endcase
461
462 if (rst == 1) begin
463 state <= get_a;
464 s_in_a_ack <= 0;
465 s_in_b_ack <= 0;
466 s_out_z_stb <= 0;
467 end
468
469 end
470 assign in_a_ack = s_in_a_ack;
471 assign in_b_ack = s_in_b_ack;
472 assign out_z_stb = s_out_z_stb;
473 assign out_z = s_out_z;
474
475 endmodule
476 """
477
if __name__ == "__main__":
    # Instantiate the 32-bit adder and hand it to the nmigen command-line
    # entry point, exposing the three data buses with their strobe and
    # acknowledge lines as ports.
    adder = FPADD(width=32)
    exposed_ports = [
        adder.in_a, adder.in_a_stb, adder.in_a_ack,
        adder.in_b, adder.in_b_stb, adder.in_b_ack,
        adder.out_z, adder.out_z_stb, adder.out_z_ack,
    ]
    main(adder, ports=exposed_ports)