Convert add and sub to return PartitionedSignal
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable
6 from nmigen.back.pysim import Simulator, Delay
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import PartitionedSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15 import math
16
def first_zero(x):
    """Return the bit position of the lowest *set* bit among bits 0-15 of x.

    Despite the name, this counts the zeros below the first one bit, i.e.
    it is the index of the least-significant 1.  Returns None when none of
    the low 16 bits of x are set.
    """
    for position in range(16):
        if (x >> position) & 1:
            return position
    return None
23
def count_bits(x):
    """Return how many of the low 16 bits of x are set (16-bit popcount)."""
    low_half = x & 0xFFFF
    return bin(low_half).count("1")
30
31
def perms(k):
    """Lazily produce every k-character string over '0'/'1', in order.

    For k=2 this yields '00', '01', '10', '11'.
    """
    return ("".join(digits) for digits in itertools.product("01", repeat=k))
34
35
def create_ilang(dut, traces, test_name):
    """Convert dut to RTLIL (exposing traces as ports); write '<test_name>.il'."""
    il_path = "%s.il" % test_name
    il_text = rtlil.convert(dut, ports=traces)
    with open(il_path, "w") as il_file:
        il_file.write(il_text)
40
41
def create_simulator(module, traces, test_name):
    """Dump module's RTLIL to '<test_name>.il', then build a Simulator for it."""
    create_ilang(module, traces, test_name)
    simulator = Simulator(module)
    return simulator
45
46
# XXX this is for coriolis2 experimentation
class TestAddMod2(Elaboratable):
    """Registered counterpart of TestAddMod.

    Builds the same PartitionedSignal operators as TestAddMod, but every
    result is assigned in the ``sync`` domain (only ``bsig`` is driven
    combinatorially).  Kept for coriolis2 experimentation.
    """

    def __init__(self, width, partpoints):
        """Create the two partitioned inputs and one output per operator.

        :param width: bit width of the data inputs/outputs.
        :param partpoints: partition points passed to PartitionedSignal; the
            per-partition flag signals are ``len(partpoints) + 1`` bits wide.
        """
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width)  # left shift
        self.ls_scal_output = Signal(width)  # scalar left shift
        self.rs_output = Signal(width)  # right shift
        self.rs_scal_output = Signal(width)  # scalar right shift
        self.sub_output = Signal(width)
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        self.mux_sel = Signal(len(partpoints)+1)
        self.mux_out = Signal(width)
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        sync = m.d.sync
        self.a.set_module(m)
        self.b.set_module(m)
        # compares
        sync += self.lt_output.eq(self.a < self.b)
        sync += self.ne_output.eq(self.a != self.b)
        sync += self.le_output.eq(self.a <= self.b)
        sync += self.gt_output.eq(self.a > self.b)
        sync += self.eq_output.eq(self.a == self.b)
        sync += self.ge_output.eq(self.a >= self.b)
        # add: the result is a PartitionedSignal, so assign its raw .sig
        # (same as TestAddMod)
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        sync += self.add_output.eq(add_out.sig)
        sync += self.add_carry_out.eq(add_carry)
        # sub: likewise returns a PartitionedSignal
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        sync += self.sub_output.eq(sub_out.sig)
        sync += self.sub_carry_out.eq(sub_carry)
        # neg: unary minus also yields a PartitionedSignal
        sync += self.neg_output.eq((-self.a).sig)
        # left shift
        sync += self.ls_output.eq(self.a << self.b)
        # right shift
        sync += self.rs_output.eq(self.a >> self.b)
        # mux
        ppts = self.partpoints
        sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
        # scalar left/right shift by the whole of b as a plain Signal
        comb += self.bsig.eq(self.b.sig)
        sync += self.ls_scal_output.eq(self.a << self.bsig)
        sync += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
109
110
class TestAddMod(Elaboratable):
    """Combinatorial test harness for the PartitionedSignal operators.

    Two PartitionedSignal inputs ``a`` and ``b`` (split at ``partpoints``)
    are combined with every operator under test, and each result is driven
    in the ``comb`` domain onto a plain output Signal for the simulator.
    """

    def __init__(self, width, partpoints):
        """Create the partitioned inputs and one output per operator.

        :param width: bit width of the data inputs/outputs.
        :param partpoints: partition points passed to PartitionedSignal; the
            per-partition flag signals are ``len(partpoints) + 1`` bits wide.
        """
        # NOTE: keep the direct ``self.x = Signal(...)`` form so nmigen's
        # tracer still derives each signal's name from its attribute name.
        nflags = len(partpoints) + 1
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width)       # a << b, per partition
        self.ls_scal_output = Signal(width)  # a << bsig (scalar)
        self.rs_output = Signal(width)       # a >> b, per partition
        self.rs_scal_output = Signal(width)  # a >> bsig (scalar)
        self.sub_output = Signal(width)
        self.eq_output = Signal(nflags)
        self.gt_output = Signal(nflags)
        self.ge_output = Signal(nflags)
        self.ne_output = Signal(nflags)
        self.lt_output = Signal(nflags)
        self.le_output = Signal(nflags)
        self.mux_sel = Signal(nflags)
        self.mux_out = Signal(width)
        self.carry_in = Signal(nflags)
        self.add_carry_out = Signal(nflags)
        self.sub_carry_out = Signal(nflags)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        a, b = self.a, self.b
        a.set_module(m)
        b.set_module(m)

        # Comparisons (operands in the same order as the assertions expect).
        comb += self.lt_output.eq(a < b)
        comb += self.ne_output.eq(a != b)
        comb += self.le_output.eq(a <= b)
        comb += self.gt_output.eq(a > b)
        comb += self.eq_output.eq(a == b)
        comb += self.ge_output.eq(a >= b)

        # add_op/sub_op give (result, carry); the result exposes its raw
        # Signal through .sig.
        add_out, add_carry = a.add_op(a, b, self.carry_in)
        comb += self.add_output.eq(add_out.sig)
        comb += self.add_carry_out.eq(add_carry)

        sub_out, sub_carry = a.sub_op(a, b, self.carry_in)
        comb += self.sub_output.eq(sub_out.sig)
        comb += self.sub_carry_out.eq(sub_carry)

        negated = -a
        comb += self.neg_output.eq(negated.sig)

        # Partition-wise shifts (amount taken from b).
        comb += self.ls_output.eq(a << b)
        comb += self.rs_output.eq(a >> b)

        # Per-partition select between a and b.
        comb += self.mux_out.eq(
            PMux(m, self.partpoints, self.mux_sel, a, b))

        # Scalar shifts: b's underlying Signal used as one shift amount.
        comb += self.bsig.eq(b.sig)
        comb += self.ls_scal_output.eq(a << self.bsig)
        comb += self.rs_scal_output.eq(a >> self.bsig)

        return m
175
176
class TestPartitionPoints(unittest.TestCase):
    """Simulate TestAddMod and check every output against a Python model.

    A single simulation process switches ``part_mask`` between 0 (one
    16-bit partition), 0b10 (two 8-bit partitions) and 0b1111 (four 4-bit
    partitions), drives ``a``/``b`` with fixed and random 16-bit vectors,
    and compares each output with a reference computed independently for
    every partition mask.
    """

    def test(self):
        width = 16
        part_mask = Signal(4)  # divide into 4-bits
        module = TestAddMod(width, part_mask)

        test_name = "part_sig_add"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.add_output,
                  module.eq_output]
        sim = create_simulator(module, traces, test_name)

        # Fixed (a, b) corner cases; the compare tests add a few more.
        corner_cases = [(0x0000, 0x0000),
                        (0x1234, 0x1234),
                        (0xABCD, 0xABCD),
                        (0xFFFF, 0x0000),
                        (0x0000, 0x0000),
                        (0xFFFF, 0xFFFF),
                        (0x0000, 0xFFFF)]
        binop_cases = corner_cases + [(0xABCD, 0xABCE),
                                      (0x8000, 0x0000),
                                      (0xBEEF, 0xFEED)]
        # randint() includes its upper bound, so stop at the largest value
        # that fits in `width` bits (previously 1 << 16 could be drawn).
        max_value = (1 << width) - 1

        # ------------------------------------------------------------------
        # Reference model helpers (pure Python, one partition mask at a time)
        # ------------------------------------------------------------------

        def masks_from_bits(maskbit_list):
            """Expand 4-bit selectors (bit i = nibble i) into 16-bit masks.

            e.g. 0b1100 -> 0xFF00, 0b0001 -> 0x000F.
            """
            mask_list = []
            for mb in maskbit_list:
                v = 0
                for i in range(4):
                    if mb & (1 << i):
                        v |= 0xf << (i * 4)
                mask_list.append(v)
            return mask_list

        def lshift(value, amount):
            return value << amount

        def rshift(value, amount):
            return value >> amount

        def scalar_shift_ref(shift, a, b, mask):
            """Shift partition `mask` of a by the scalar b (reduced range).

            b is truncated to ``b & (bits - 1)`` where bits is the number of
            bits in mask.  Carry is not modelled yet and is always 0.
            """
            bits = count_bits(mask)
            newb = b & (bits - 1)
            print("%x %x %x bits %d trunc %x" %
                  (a, b, mask, bits, newb))
            b = newb
            # TODO: carry
            total = shift(a & mask, b)
            result = mask & total
            print("res", hex(a), hex(b), hex(total), hex(mask), hex(result))
            return result, 0

        def partitioned_shift_ref(shift, a, b, mask):
            """Shift partition `mask` of a by the same partition of b.

            The shift amount is ``b & ((bits - 1) << fz)`` moved down to bit
            0, where fz is the position of mask's lowest set bit.  Carry is
            not modelled yet and is always 0.
            """
            bits = count_bits(mask)
            fz = first_zero(mask)  # index of the partition's lowest bit
            newb = b & ((bits - 1) << fz)
            print("%x %x %x bits %d zero %d trunc %x" %
                  (a, b, mask, bits, fz, newb))
            b = (newb & mask) >> fz
            # TODO: carry
            total = shift(a & mask, b)
            result = mask & total
            print("res", hex(a), hex(b), hex(total), hex(mask), hex(result))
            return result, 0

        # Reference functions: (carry_in, a, b, mask) -> (result, carry)

        def test_ls_scal_fn(carry_in, a, b, mask):
            return scalar_shift_ref(lshift, a, b, mask)

        def test_rs_scal_fn(carry_in, a, b, mask):
            return scalar_shift_ref(rshift, a, b, mask)

        def test_ls_fn(carry_in, a, b, mask):
            return partitioned_shift_ref(lshift, a, b, mask)

        def test_rs_fn(carry_in, a, b, mask):
            return partitioned_shift_ref(rshift, a, b, mask)

        def test_add_fn(carry_in, a, b, mask):
            # carry-in enters at the partition's lowest bit
            lsb = mask & ~(mask - 1) if carry_in else 0
            total = (a & mask) + (b & mask) + lsb
            result = mask & total
            carry = (total & mask) != total
            print(a, b, total, mask)
            return result, carry

        def test_sub_fn(carry_in, a, b, mask):
            # a + ~b + carry_in, restricted to the partition
            lsb = mask & ~(mask - 1) if carry_in else 0
            total = (a & mask) + (~b & mask) + lsb
            result = mask & total
            carry = (total & mask) != total
            return result, carry

        def test_neg_fn(carry_in, a, b, mask):
            lsb = mask & ~(mask - 1)  # has only LSB of mask set
            pos = lsb.bit_length() - 1  # find bit position
            a = (a & mask) >> pos  # shift it to the beginning
            return ((-a) << pos) & mask, 0  # negate and shift it back

        # Comparison references: (a, b, mask) -> bool

        def test_ne_fn(a, b, mask):
            return (a & mask) != (b & mask)

        def test_lt_fn(a, b, mask):
            return (a & mask) < (b & mask)

        def test_le_fn(a, b, mask):
            return (a & mask) <= (b & mask)

        def test_eq_fn(a, b, mask):
            return (a & mask) == (b & mask)

        def test_gt_fn(a, b, mask):
            return (a & mask) > (b & mask)

        def test_ge_fn(a, b, mask):
            return (a & mask) >= (b & mask)

        # ------------------------------------------------------------------
        # Simulation drivers (generators, used with ``yield from``)
        # ------------------------------------------------------------------

        def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
            """Check `<mod_attr>_output` (and carry, if any) for each vector."""
            rand_data = [(randint(0, max_value), randint(0, max_value))
                         for _ in range(100)]
            carry_sig = 0xf if carry else 0
            for a, b in corner_cases + rand_data:
                yield module.a.eq(a)
                yield module.b.eq(b)
                yield module.carry_in.eq(carry_sig)
                yield Delay(0.1e-6)
                y = 0
                carry_result = 0
                for i, mask in enumerate(mask_list):
                    print("i/mask", i, hex(mask))
                    res, c = test_fn(carry, a, b, mask)
                    y |= res
                    lsb = mask & ~(mask - 1)
                    bit_set = lsb.bit_length() - 1
                    # expected carry bit index = partition LSB position // 4
                    carry_result |= c << (bit_set // 4)
                outval = (yield getattr(module, "%s_output" % mod_attr))
                print(a, b, outval, carry)
                msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                    f" => 0x{y:X} != 0x{outval:X}"
                self.assertEqual(y, outval, msg)
                carry_attr = "%s_carry_out" % mod_attr
                if hasattr(module, carry_attr):
                    c_outval = (yield getattr(module, carry_attr))
                    msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                        f" => 0x{carry_result:X} != 0x{c_outval:X}"
                    self.assertEqual(carry_result, c_outval, msg)

        def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
            """Check a comparison output: one result bit per partition nibble."""
            mask_list = masks_from_bits(maskbit_list)
            for a, b in binop_cases:
                yield module.a.eq(a)
                yield module.b.eq(b)
                yield Delay(0.1e-6)
                y = 0
                # do the partitioned tests
                for maskbits, mask in zip(maskbit_list, mask_list):
                    if test_fn(a, b, mask):
                        # set every flag bit belonging to this partition
                        y |= maskbits
                # check the result
                outval = (yield getattr(module, "%s_output" % mod_attr))
                msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                    f" => 0x{y:X} != 0x{outval:X}, masklist {maskbit_list}"
                print(msg)
                self.assertEqual(y, outval, msg)

        def test_muxop(msg_prefix, *maskbit_list):
            """Check PMux for every combination of per-partition selects."""
            mask_list = masks_from_bits(maskbit_list)
            for a, b in corner_cases:
                for p in perms(len(mask_list)):
                    sel = 0
                    selmask = 0
                    for i, v in enumerate(p):
                        if v == '1':
                            sel |= maskbit_list[i]
                            selmask |= mask_list[i]

                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield module.mux_sel.eq(sel)
                    yield Delay(0.1e-6)
                    # a where the partition is selected, otherwise b
                    y = 0
                    for mask in mask_list:
                        y |= (a & mask) if (selmask & mask) else (b & mask)
                    # check the result
                    outval = (yield module.mux_out)
                    msg = f"{msg_prefix}: mux " + \
                        f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
                        f" => 0x{y:X} != 0x{outval:X}, masklist " + \
                        f"{maskbit_list}"
                    self.assertEqual(y, outval, msg)

        def async_process():
            for (test_fn, mod_attr) in (
                    (test_ls_scal_fn, "ls_scal"),
                    (test_ls_fn, "ls"),
                    (test_rs_scal_fn, "rs_scal"),
                    (test_rs_fn, "rs"),
                    (test_add_fn, "add"),
                    (test_sub_fn, "sub"),
                    (test_neg_fn, "neg"),
                    ):
                yield part_mask.eq(0)
                yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
                yield part_mask.eq(0b10)
                yield from test_op("8-bit", 0, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield from test_op("8-bit", 1, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield part_mask.eq(0b1111)
                yield from test_op("4-bit", 0, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)
                yield from test_op("4-bit", 1, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            yield part_mask.eq(0b1111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        # Own the waveform files so they are closed even if the run fails.
        with open(test_name + ".vcd", "w") as vcd_file, \
                open(test_name + ".gtkw", "w") as gtkw_file:
            with sim.write_vcd(vcd_file=vcd_file,
                               gtkw_file=gtkw_file,
                               traces=traces):
                sim.run()
478
479
if __name__ == '__main__':
    # Run the unittest test cases defined in this module when executed directly.
    unittest.main()