big rename PartitionedSignal to SimdSignal (shorter)
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
6 from nmigen.back.pysim import Simulator, Delay, Settle
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import SimdSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15 import math
16
17
def first_zero(x, width=16):
    """Return the bit index of the lowest set bit of *x*.

    Despite the name this is the count of trailing zero bits, e.g.
    ``first_zero(0xFF00) == 8``.  Only the low *width* bits are examined
    (16 by default, matching the 16-bit partitioned tests).  ``None`` is
    returned when none of those bits is set.
    """
    masked = x & ((1 << width) - 1)
    if not masked:
        return None
    # masked & -masked isolates the lowest set bit; bit_length gives
    # its position plus one
    return (masked & -masked).bit_length() - 1
24
def count_bits(x, width=16):
    """Return the number of set bits among the low *width* bits of *x*.

    *width* defaults to 16, matching the 16-bit partitioned tests; bits
    above it are ignored, and a negative *x* is viewed as two's complement.
    """
    return bin(x & ((1 << width) - 1)).count("1")
31
32
def perms(k):
    """Iterate over every k-character string of '0'/'1' characters.

    Strings come out in lexicographic order: '00', '01', '10', '11' for k=2.
    """
    return ("".join(bits) for bits in itertools.product("01", repeat=k))
35
36
def create_ilang(dut, traces, test_name):
    """Convert *dut* to RTLIL (with *traces* as ports) and write it to
    the file ``<test_name>.il``."""
    il_text = rtlil.convert(dut, ports=traces)
    il_path = "%s.il" % test_name
    with open(il_path, "w") as il_file:
        il_file.write(il_text)
41
42
def create_simulator(module, traces, test_name):
    """Dump *module* to ``<test_name>.il`` and return a Simulator for it."""
    create_ilang(module, traces, test_name)
    sim = Simulator(module)
    return sim
46
47
48 # XXX this is for coriolis2 experimentation
class TestAddMod2(Elaboratable):
    """Registered variant of TestAddMod: operator results go through sync.

    ``a`` and ``b`` are SimdSignal operands partitioned by *partpoints*.
    Each operator result is driven onto its own plain Signal.  This class
    is not referenced by the unit tests elsewhere in this file.
    """
    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, width)
        # b.lower() as a plain Signal, used as a scalar shift amount
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width) # left shift
        self.ls_scal_output = Signal(width) # left shift by scalar
        self.rs_output = Signal(width) # right shift
        self.rs_scal_output = Signal(width) # right shift by scalar
        self.sub_output = Signal(width)
        # comparison results: len(partpoints)+1 bits, one per possible
        # partition
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        # mux select: one bit per possible partition, matching TestMuxMod
        self.mux_sel2 = SimdSignal(partpoints, len(partpoints)+1)
        self.mux2_out = Signal(width)
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        sync = m.d.sync
        self.a.set_module(m)
        self.b.set_module(m)
        self.mux_sel2.set_module(m)
        # compares
        sync += self.lt_output.eq(self.a < self.b)
        sync += self.ne_output.eq(self.a != self.b)
        sync += self.le_output.eq(self.a <= self.b)
        sync += self.gt_output.eq(self.a > self.b)
        sync += self.eq_output.eq(self.a == self.b)
        sync += self.ge_output.eq(self.a >= self.b)
        # add
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        sync += self.add_output.eq(add_out)
        sync += self.add_carry_out.eq(add_carry)
        # sub
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        sync += self.sub_output.eq(sub_out)
        sync += self.sub_carry_out.eq(sub_carry)
        # neg
        sync += self.neg_output.eq(-self.a)
        # shifts where the shift amount is the partitioned b
        sync += self.ls_output.eq(self.a << self.b)
        sync += self.rs_output.eq(self.a >> self.b)
        # partitioned mux: the output attribute is mux2_out
        sync += self.mux2_out.eq(Mux(self.mux_sel2, self.a, self.b))
        # scalar shifts: bsig carries the whole of b as a plain Signal
        comb += self.bsig.eq(self.b.lower())
        sync += self.ls_scal_output.eq(self.a << self.bsig)
        sync += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
112
113
class TestMuxMod(Elaboratable):
    """Harness for a partitioned Mux: ``mux_out2 = Mux(mux_sel2, a, b)``.

    ``mux_sel`` is a plain Signal of the same width as ``mux_sel2``.  It is
    not used by elaborate; the testbench drives it alongside ``mux_sel2``.
    """
    def __init__(self, width, partpoints):
        n_parts = len(partpoints) + 1
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, width)
        self.mux_sel = Signal(n_parts)
        self.mux_sel2 = SimdSignal(partpoints, n_parts)
        self.mux_out2 = Signal(width)

    def elaborate(self, platform):
        m = Module()
        for simd in (self.a, self.b, self.mux_sel2):
            simd.set_module(m)

        m.d.comb += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))

        return m
135
136
class TestCatMod(Elaboratable):
    """Harness for a partitioned Cat of a width-bit ``a`` and a
    (2*width)-bit ``b`` into the (3*width)-bit ``cat_out``."""
    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, 2 * width)
        self.cat_out = Signal(3 * width)

    def elaborate(self, platform):
        m = Module()
        for operand in (self.a, self.b):
            operand.set_module(m)

        m.d.comb += self.cat_out.eq(Cat(self.a, self.b))

        return m
153
154
class TestReplMod(Elaboratable):
    """Harness for a partitioned ``Repl(a, 2)`` into ``repl_out``."""
    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        # not used by elaborate
        self.repl_sel = Signal(len(partpoints)+1)
        self.repl_out = Signal(2 * width)

    def elaborate(self, platform):
        m = Module()
        self.a.set_module(m)

        replicated = Repl(self.a, 2)
        m.d.comb += self.repl_out.eq(replicated)

        return m
170
171
class TestAssMod(Elaboratable):
    """Harness assigning ``a`` into the SimdSignal ``ass_out``.

    ``a`` is a plain Signal when *scalar* is true, otherwise a SimdSignal
    with the same partition points.  ``ass_out`` has shape *out_shape*, so
    the assignment may truncate or extend.
    """
    def __init__(self, width, out_shape, partpoints, scalar):
        self.partpoints = partpoints
        self.scalar = scalar
        # kept as separate direct assignments (not a conditional
        # expression) so each constructor call is immediately stored
        if not scalar:
            self.a = SimdSignal(partpoints, width)
        else:
            self.a = Signal(width)
        self.ass_out = SimdSignal(partpoints, out_shape)

    def elaborate(self, platform):
        m = Module()
        # only SimdSignal operands need to know their module
        if not self.scalar:
            self.a.set_module(m)
        self.ass_out.set_module(m)

        m.d.comb += self.ass_out.eq(self.a)

        return m
192
193
class TestAddMod(Elaboratable):
    """Combinatorial harness for the SimdSignal operators.

    ``a`` and ``b`` are SimdSignal operands partitioned by *partpoints*.
    Every operator result is driven onto its own plain Signal
    (``<op>_output``, plus ``add_carry_out`` / ``sub_carry_out``) so that a
    testbench can read it back.
    """
    def __init__(self, width, partpoints):
        nparts = len(partpoints) + 1
        self.partpoints = partpoints
        # operands
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, width)
        # b.lower() as a plain Signal, used as a scalar shift amount
        self.bsig = Signal(width)
        # width-bit arithmetic and shift results
        self.add_output = Signal(width)
        self.ls_output = Signal(width) # left shift
        self.ls_scal_output = Signal(width) # left shift
        self.rs_output = Signal(width) # right shift
        self.rs_scal_output = Signal(width) # right shift
        self.sub_output = Signal(width)
        # comparison results, len(partpoints)+1 bits each
        self.eq_output = Signal(nparts)
        self.gt_output = Signal(nparts)
        self.ge_output = Signal(nparts)
        self.ne_output = Signal(nparts)
        self.lt_output = Signal(nparts)
        self.le_output = Signal(nparts)
        # carry in / carry outs
        self.carry_in = Signal(nparts)
        self.add_carry_out = Signal(nparts)
        self.sub_carry_out = Signal(nparts)
        self.neg_output = Signal(width)
        self.signed_output = Signal(width)
        # horizontal reduction results
        self.xor_output = Signal(nparts)
        self.bool_output = Signal(nparts)
        self.all_output = Signal(nparts)
        self.any_output = Signal(nparts)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        a, b = self.a, self.b
        a.set_module(m)
        b.set_module(m)

        # comparisons
        comb += self.lt_output.eq(a < b)
        comb += self.ne_output.eq(a != b)
        comb += self.le_output.eq(a <= b)
        comb += self.gt_output.eq(a > b)
        comb += self.eq_output.eq(a == b)
        comb += self.ge_output.eq(a >= b)

        # add, then sub, each with carry in and carry out
        add_res, add_cout = a.add_op(a, b, self.carry_in)
        comb += self.add_output.eq(add_res.sig)
        comb += self.add_carry_out.eq(add_cout)
        sub_res, sub_cout = a.sub_op(a, b, self.carry_in)
        comb += self.sub_output.eq(sub_res.sig)
        comb += self.sub_carry_out.eq(sub_cout)

        # unary: negate, reinterpret as signed
        comb += self.neg_output.eq((-a).sig)
        comb += self.signed_output.eq(a.as_signed())

        # horizontal reductions
        comb += self.xor_output.eq(a.xor())
        comb += self.bool_output.eq(a.bool())
        comb += self.all_output.eq(a.all())
        comb += self.any_output.eq(a.any())

        # shifts by the partitioned b
        comb += self.ls_output.eq(a << b)
        comb += self.rs_output.eq(a >> b)

        # shifts by the whole of b as a plain (scalar) Signal
        comb += self.bsig.eq(b.lower())
        comb += self.ls_scal_output.eq(a << self.bsig)
        comb += self.rs_scal_output.eq(a >> self.bsig)

        return m
265
266
class TestMux(unittest.TestCase):
    """Check a partitioned Mux over 16-, 8- and 4-bit partition layouts."""

    def test(self):
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestMuxMod(width, part_mask)

        test_name = "part_sig_mux"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.mux_out2]
        sim = create_simulator(module, traces, test_name)

        def async_process():

            def test_muxop(msg_prefix, *maskbit_list):
                """Try every per-partition select combination.

                Each entry of maskbit_list has one bit per 4-bit nibble that
                the partition covers (e.g. 0b1100 = the upper byte).
                """
                # expand each entry into a 16-bit mask (0xf per set bit).
                # this depends only on maskbit_list, so compute it once
                # rather than once per test vector.
                mask_list = []
                for mb in maskbit_list:
                    v = 0
                    for i in range(4):
                        if mb & (1 << i):
                            v |= 0xf << (i*4)
                    mask_list.append(v)

                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:
                    # walk every combination of per-partition selects
                    for p in perms(len(mask_list)):
                        sel = 0
                        selmask = 0
                        for i, v in enumerate(p):
                            if v == '1':
                                sel |= maskbit_list[i]
                                selmask |= mask_list[i]

                        yield module.a.lower().eq(a)
                        yield module.b.lower().eq(b)
                        yield module.mux_sel.eq(sel)
                        yield module.mux_sel2.lower().eq(sel)
                        yield Delay(0.1e-6)
                        # expected: partitions whose select is set take a,
                        # all other partitions take b
                        y = 0
                        for mask in mask_list:
                            if selmask & mask:
                                y |= (a & mask)
                            else:
                                y |= (b & mask)
                        # check the result
                        outval2 = (yield module.mux_out2)
                        msg = f"{msg_prefix}: mux " + \
                              f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
                              f" => 0x{y:X} != 0x{outval2:X}, masklist %s"
                        self.assertEqual(y, outval2, msg % str(maskbit_list))

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            # part_mask is 3 bits wide: 0b111 sets every partition point
            yield part_mask.eq(0b111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
342
343
class TestCat(unittest.TestCase):
    """Check partitioned Cat(a, b) with a 16-bit a and a 32-bit b."""

    def test(self):
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestCatMod(width, part_mask)

        test_name = "part_sig_cat"
        traces = [part_mask, module.a.sig, module.b.sig, module.cat_out]
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_cat.cat import get_runlengths

        # pairs of test values a, b
        vectors = [(0x0000, 0x00000000),
                   (0xDCBA, 0x12345678),
                   (0xABCD, 0x01234567),
                   (0xFFFF, 0x0000),
                   (0x0000, 0x0000),
                   (0x1F1F, 0xF1F1F1F1),
                   (0x0000, 0xFFFFFFFF)]

        def async_process():

            def test_catop(msg_prefix):
                # widths of a and b; each is split into 4 equal chunks
                alen, blen = 16, 32
                ajump, bjump = alen // 4, blen // 4
                amask, bmask = (1 << ajump) - 1, (1 << bjump) - 1
                for a, b in vectors:
                    apart = [(a >> (ajump*n)) & amask for n in range(4)]
                    bpart = [(b >> (bjump*n)) & bmask for n in range(4)]

                    print ("apart bpart", hex(a), hex(b),
                           list(map(hex, apart)), list(map(hex, bpart)))

                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    yield Delay(0.1e-6)

                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)

                    expected = 0
                    pos = 0            # next output bit position
                    anext = bnext = 0  # next unconsumed chunk of a / of b
                    for run in runlengths:
                        # each run emits its chunks of a, then those of b
                        for _ in range(run):
                            print ("runlength", run,
                                   "ai", anext,
                                   "apart", hex(apart[anext]),
                                   "j", pos)
                            expected |= apart[anext] << pos
                            print (" y", hex(expected))
                            pos += ajump
                            anext += 1
                        for _ in range(run):
                            print ("runlength", run,
                                   "bi", bnext,
                                   "bpart", hex(bpart[bnext]),
                                   "j", pos)
                            expected |= bpart[bnext] << pos
                            print (" y", hex(expected))
                            pos += bjump
                            bnext += 1

                    # compare against the hardware
                    outval = (yield module.cat_out)
                    msg = f"{msg_prefix}: cat " + \
                          f"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
                          f" => 0x{expected:X} != 0x{outval:X}"
                    self.assertEqual(expected, outval, msg)

            for mask, name in ((0, "16-bit"),
                               (0b10, "8-bit"),
                               (0b1111, "4-bit")):
                yield part_mask.eq(mask)
                yield from test_catop(name)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
438
439
class TestRepl(unittest.TestCase):
    """Check partitioned Repl(a, 2) against a per-partition model."""

    def test(self):
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestReplMod(width, part_mask)

        test_name = "part_sig_repl"
        traces = [part_mask,
                  module.a.sig,
                  module.repl_out]
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_repl.repl import get_runlengths

        def async_process():

            def test_replop(msg_prefix):
                # length of the "a" test input, split into 4 equal chunks
                alen = 16
                ajump = alen // 4
                for a in [0x0000,
                          0xDCBA,
                          0xABCD,
                          0xFFFF,
                          0x1F1F,
                          0x1234,
                          0x8001]:

                    # convert a to partition-sized chunks, lowest first
                    apart = [(a >> (ajump*i)) & ((1 << ajump) - 1)
                             for i in range(4)]
                    print ("apart", hex(a), list(map(hex, apart)))

                    yield module.a.lower().eq(a)
                    yield Delay(0.1e-6)

                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)

                    # Repl(a, 2) is Cat(a, a) applied per partition: each
                    # run of i chunks of a is emitted twice, back to back.
                    # (Repl only sees a; there is no second operand.)
                    y = 0
                    j = 0
                    ai = 0
                    for i in runlengths:
                        run = apart[ai:ai+i]
                        for _ in range(2):
                            for chunk in run:
                                y |= chunk << j
                                j += ajump
                        ai += i
                    print (" y", hex(y))

                    # check the result
                    outval = (yield module.repl_out)
                    msg = f"{msg_prefix}: repl " + \
                          f"0x{mval:X} 0x{a:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_replop("16-bit")
            yield part_mask.eq(0b10)
            yield from test_replop("8-bit")
            # part_mask is 3 bits wide: 0b111 sets every partition point
            yield part_mask.eq(0b111)
            yield from test_replop("4-bit")

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
532
533
class TestAssign(unittest.TestCase):
    """Check assigning a scalar or SimdSignal into a SimdSignal.

    The output may have a different width and signedness, so the check
    covers truncation and sign/zero extension per partition.
    """

    def run_tst(self, in_width, out_width, out_signed, scalar):
        """Simulate one (in_width, out_width, signedness, scalar) combination."""
        part_mask = Signal(3) # divide into 4-bits
        module = TestAssMod(in_width,
                            Shape(out_width, out_signed),
                            part_mask, scalar)

        test_name = "part_sig_ass_%d_%d_%s_%s" % (in_width, out_width,
                     "signed" if out_signed else "unsigned",
                     "scalar" if scalar else "partitioned")

        traces = [part_mask,
                  module.ass_out.lower()]
        if module.scalar:
            traces.append(module.a)
        else:
            traces.append(module.a.lower())
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_cat.cat import get_runlengths

        def async_process():

            def test_assop(msg_prefix):
                # define lengths of a test input
                alen = in_width
                # random inputs span the full input width
                randomvals = [randint(0, (1 << alen) - 1) for _ in range(10)]
                # test values a
                for a in [0x0001,
                          0x0010,
                          0x0100,
                          0x1000,
                          0x000c,
                          0x00c0,
                          0x0c00,
                          0xc000,
                          0x1234,
                          0xDCBA,
                          0xABCD,
                          0x0000,
                          0xFFFF,
                          ] + randomvals:
                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)

                    print ("test a", hex(a), "mask", bin(mval), "widths",
                           in_width, out_width,
                           "signed", out_signed,
                           "scalar", scalar)

                    # convert a to (value, msb) pairs, one per run
                    apart = []
                    ajump = alen // 4
                    ai = 0
                    for i in runlengths:
                        subpart = (a >> (ajump*ai) & ((1<<(ajump*i))-1))
                        msb = (subpart >> ((ajump*i)-1)) # will contain the sign
                        apart.append((subpart, msb))
                        print ("apart", ajump*i, hex(a), hex(subpart), msb)
                        # a scalar is not advanced: every run reads
                        # a's low bits
                        if not scalar:
                            ai += i

                    if scalar:
                        yield module.a.eq(a)
                    else:
                        yield module.a.lower().eq(a)
                    yield Delay(0.1e-6)

                    y = 0
                    j = 0
                    ojump = out_width // 4
                    for pi, i in enumerate(runlengths):
                        # get "a" partition value
                        av, amsb = apart[pi]
                        # do sign-extension if needed
                        signext = 0
                        if out_signed and ojump > ajump:
                            if amsb:
                                signext = (-1 << ajump*i) & ((1<<(ojump*i))-1)
                                av |= signext
                        # truncate if needed
                        if ojump < ajump:
                            av &= ((1<<(ojump*i))-1)
                        print ("runlength", i,
                               "ai", pi,
                               "apart", hex(av), amsb,
                               "signext", hex(signext),
                               "j", j)
                        y |= av << j
                        print (" y", hex(y))
                        j += ojump*i

                    y &= (1<<out_width)-1

                    # check the result
                    outval = (yield module.ass_out.lower())
                    outval &= (1<<out_width)-1
                    msg = f"{msg_prefix}: assign " + \
                          f"mask 0x{mval:X} input 0x{a:X}" + \
                          f" => expected 0x{y:X} != actual 0x{outval:X}"
                    self.assertEqual(y, outval, msg)

            # run the actual tests, here - 16/8/4 bit partitions
            for (mask, name) in ((0, "16-bit"),
                                 (0b10, "8-bit"),
                                 (0b111, "4-bit")):
                with self.subTest(name + " " + test_name):
                    yield part_mask.eq(mask)
                    yield Settle()
                    yield from test_assop(name)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()

    def test(self):
        for out_width in [16, 24, 8]:
            for sign in [True, False]:
                for scalar in [True, False]:
                    self.run_tst(16, out_width, sign, scalar)
663
664
class TestSimdSignal(unittest.TestCase):
    """Check the SimdSignal operators wired up by TestAddMod.

    Each operator output is compared against a pure-Python model applied
    separately to every partition.  The partition layouts are 16-bit
    (part_mask 0), 8-bit (0b10) and 4-bit (0b111).
    """

    def test(self):
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestAddMod(width, part_mask)

        test_name = "part_sig_add"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.add_output,
                  module.eq_output]
        sim = create_simulator(module, traces, test_name)

        def nibble_masks(maskbit_list):
            """Expand nibble-select entries into 16-bit masks.

            Bit i of each entry selects nibble i, e.g. 0b1100 -> 0xFF00.
            """
            mask_list = []
            for mb in maskbit_list:
                v = 0
                for i in range(4):
                    if mb & (1 << i):
                        v |= 0xf << (i*4)
                mask_list.append(v)
            return mask_list

        def async_process():

            # --- models for the horizontal (reduction) outputs ---

            def test_xor_fn(a, mask):
                # parity of the bits of a selected by mask
                test = (a & mask)
                result = 0
                while test != 0:
                    bit = (test & 1)
                    result ^= bit
                    test >>= 1
                return result

            def test_bool_fn(a, mask):
                # true when any selected bit is set
                test = (a & mask)
                return test != 0

            def test_all_fn(a, mask):
                # slightly different: all bits masked must be 1
                test = (a & mask)
                return test == mask

            def test_horizop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                randomvals = [randint(0, 65535) for _ in range(100)]
                mask_list = nibble_masks(maskbit_list)
                for a in [0x0000,
                          0x1111,
                          0x0001,
                          0x0010,
                          0x0100,
                          0x1000,
                          0x000F,
                          0x00F0,
                          0x0F00,
                          0xF000,
                          0x00FF,
                          0xFF00,
                          0x1234,
                          0xABCD,
                          0xFFFF,
                          0x8000,
                          0xBEEF, 0xFEED,
                          ]+randomvals:
                    with self.subTest("%s %s %s" % (msg_prefix,
                                      test_fn.__name__, hex(a))):
                        yield module.a.lower().eq(a)
                        yield Delay(0.1e-6)
                        y = 0
                        # do the partitioned tests: when the model holds for
                        # a partition, set all of that partition's nibble bits
                        for maskbits, mask in zip(maskbit_list, mask_list):
                            if test_fn(a, mask):
                                y |= maskbits
                        # check the result
                        outval = (yield getattr(module, "%s_output" % mod_attr))
                        msg = f"{msg_prefix}: {mod_attr} 0x{a:X} " + \
                              f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                        print(msg % str(maskbit_list))
                        self.assertEqual(y, outval, msg % str(maskbit_list))

            for (test_fn, mod_attr) in ((test_xor_fn, "xor"),
                                        (test_all_fn, "all"),
                                        (test_bool_fn, "any"), # same as bool
                                        (test_bool_fn, "bool"),
                                        ):
                yield part_mask.eq(0)
                yield from test_horizop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_horizop("8-bit", test_fn, mod_attr,
                                        0b1100, 0b0011)
                # part_mask is 3 bits wide: 0b111 sets every partition point
                yield part_mask.eq(0b111)
                yield from test_horizop("4-bit", test_fn, mod_attr,
                                        0b1000, 0b0100, 0b0010, 0b0001)

            # --- models for the arithmetic / shift outputs ---
            # each returns (result, carry) for the partition selected by
            # mask, with the result bits left in place.

            def test_ls_scal_fn(carry_in, a, b, mask):
                # reduce range of b to the partition width (partition
                # widths here are powers of two)
                bits = count_bits(mask)
                newb = b & ((bits-1))
                print ("%x %x %x bits %d trunc %x" % \
                        (a, b, mask, bits, newb))
                b = newb
                # TODO: carry (carry_in is currently ignored)
                shifted = ((a & mask) << b)
                result = mask & shifted
                print("res", hex(a), hex(b), hex(shifted), hex(mask), hex(result))
                return result, 0

            def test_rs_scal_fn(carry_in, a, b, mask):
                # reduce range of b to the partition width
                bits = count_bits(mask)
                newb = b & ((bits-1))
                print ("%x %x %x bits %d trunc %x" % \
                        (a, b, mask, bits, newb))
                b = newb
                # TODO: carry (carry_in is currently ignored)
                shifted = ((a & mask) >> b)
                result = mask & shifted
                print("res", hex(a), hex(b), hex(shifted), hex(mask), hex(result))
                return result, 0

            def test_ls_fn(carry_in, a, b, mask):
                # take this partition's own field of b, limited to the
                # partition width, and move it down to bit 0
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits-1)<<fz)
                print ("%x %x %x bits %d zero %d trunc %x" % \
                        (a, b, mask, bits, fz, newb))
                b = newb
                # TODO: carry (carry_in is currently ignored)
                b = (b & mask) >> fz
                shifted = ((a & mask) << b)
                result = mask & shifted
                print("res", hex(a), hex(b), hex(shifted), hex(mask), hex(result))
                return result, 0

            def test_rs_fn(carry_in, a, b, mask):
                # take this partition's own field of b, limited to the
                # partition width, and move it down to bit 0
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits-1)<<fz)
                print ("%x %x %x bits %d zero %d trunc %x" % \
                        (a, b, mask, bits, fz, newb))
                b = newb
                # TODO: carry (carry_in is currently ignored)
                b = (b & mask) >> fz
                shifted = ((a & mask) >> b)
                result = mask & shifted
                print("res", hex(a), hex(b), hex(shifted), hex(mask), hex(result))
                return result, 0

            def test_add_fn(carry_in, a, b, mask):
                # carry in enters at the partition's lowest bit
                lsb = mask & ~(mask-1) if carry_in else 0
                total = (a & mask) + (b & mask) + lsb
                result = mask & total
                carry = (total & mask) != total
                print(a, b, total, mask)
                return result, carry

            def test_sub_fn(carry_in, a, b, mask):
                # a + ~b + carry_in, within the partition
                lsb = mask & ~(mask-1) if carry_in else 0
                total = (a & mask) + (~b & mask) + lsb
                result = mask & total
                carry = (total & mask) != total
                return result, carry

            def test_neg_fn(carry_in, a, b, mask):
                lsb = mask & ~(mask - 1) # has only LSB of mask set
                pos = lsb.bit_length() - 1 # find bit position
                a = (a & mask) >> pos # shift it to the beginning
                return ((-a) << pos) & mask, 0 # negate and shift it back

            def test_signed_fn(carry_in, a, b, mask):
                return a & mask, 0

            def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                # randint's upper bound is inclusive: keep operands within
                # the 16-bit inputs
                rand_data = [(randint(0, 0xFFFF), randint(0, 0xFFFF))
                             for _ in range(100)]
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)] + rand_data:
                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    carry_sig = 0xf if carry else 0
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    y = 0
                    carry_result = 0
                    for i, mask in enumerate(mask_list):
                        print ("i/mask", i, hex(mask))
                        res, c = test_fn(carry, a, b, mask)
                        y |= res
                        # the carry bit index is the nibble index of the
                        # partition's lowest bit
                        lsb = mask & ~(mask - 1)
                        bit_set = lsb.bit_length() - 1
                        carry_result |= c << (bit_set // 4)
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    print(a, b, outval, carry)
                    msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)
                    if hasattr(module, "%s_carry_out" % mod_attr):
                        c_outval = (yield getattr(module,
                                                  "%s_carry_out" % mod_attr))
                        msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                              f" => 0x{carry_result:X} != 0x{c_outval:X}"
                        self.assertEqual(carry_result, c_outval, msg)

            # run through series of operations with corresponding
            # "helper" routines to reproduce the result (test_fn). the same
            # a/b input is passed to *all* outputs, where the name of the
            # output attribute (mod_attr) will contain the result to be
            # compared against the expected output from test_fn
            for (test_fn, mod_attr) in (
                    (test_ls_scal_fn, "ls_scal"),
                    (test_ls_fn, "ls"),
                    (test_rs_scal_fn, "rs_scal"),
                    (test_rs_fn, "rs"),
                    (test_add_fn, "add"),
                    (test_sub_fn, "sub"),
                    (test_neg_fn, "neg"),
                    (test_signed_fn, "signed"),
                    ):
                yield part_mask.eq(0)
                yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
                yield part_mask.eq(0b10)
                yield from test_op("8-bit", 0, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield from test_op("8-bit", 1, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                # part_mask is 3 bits wide: 0b111 sets every partition point
                yield part_mask.eq(0b111)
                yield from test_op("4-bit", 0, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)
                yield from test_op("4-bit", 1, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)

            # --- models for the comparison outputs ---

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                mask_list = nibble_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    yield Delay(0.1e-6)
                    y = 0
                    # do the partitioned tests: when the comparison holds
                    # for a partition, set all of its nibble bits
                    for maskbits, mask in zip(maskbit_list, mask_list):
                        if test_fn(a, b, mask):
                            y |= maskbits
                    # check the result
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                    print(msg % str(maskbit_list))
                    self.assertEqual(y, outval, msg % str(maskbit_list))

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
1002
1003
# TODO: adapt to SimdSignal. perhaps a different style?
# NOTE(review): the string below holds disabled Signal.matches() tests that
# appear to be adapted from nmigen's test_hdl_ast (they import SignedEnum
# from there).  It is a bare module-level string expression, so none of it
# is executed or collected by unittest.
'''
from nmigen.tests.test_hdl_ast import SignedEnum
def test_matches(self)
    s = Signal(4)
    self.assertRepr(s.matches(), "(const 1'd0)")
    self.assertRepr(s.matches(1), """
    (== (sig s) (const 1'd1))
    """)
    self.assertRepr(s.matches(0, 1), """
    (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
    """)
    self.assertRepr(s.matches("10--"), """
    (== (& (sig s) (const 4'd12)) (const 4'd8))
    """)
    self.assertRepr(s.matches("1 0--"), """
    (== (& (sig s) (const 4'd12)) (const 4'd8))
    """)

def test_matches_enum(self):
    s = Signal(SignedEnum)
    self.assertRepr(s.matches(SignedEnum.FOO), """
    (== (sig s) (const 1'sd-1))
    """)

def test_matches_width_wrong(self):
    s = Signal(4)
    with self.assertRaisesRegex(SyntaxError,
            r"^Match pattern '--' must have the same width as "
            r"match value \(which is 4\)$"):
        s.matches("--")
    with self.assertWarnsRegex(SyntaxWarning,
            (r"^Match pattern '10110' is wider than match value "
             r"\(which has width 4\); "
             r"comparison will never be true$")):
        s.matches(0b10110)

def test_matches_bits_wrong(self):
    s = Signal(4)
    with self.assertRaisesRegex(SyntaxError,
            (r"^Match pattern 'abc' must consist of 0, 1, "
             r"and - \(don't care\) bits, "
             r"and may include whitespace$")):
        s.matches("abc")

def test_matches_pattern_wrong(self):
    s = Signal(4)
    with self.assertRaisesRegex(SyntaxError,
            r"^Match pattern must be an integer, a string, "
            r"or an enumeration, not 1\.0$"):
        s.matches(1.0)
'''
1056
if __name__ == "__main__":
    # run every TestCase defined in this module
    unittest.main()