allow alias SimdSignal<-PartitionedSignal to allow tracking back through
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape, Repl
6 from nmigen.back.pysim import Simulator, Delay, Settle
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import SimdSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15 import math
16
17
def first_zero(x, width=16):
    """Return the bit index of the lowest set bit of ``x``.

    Despite the name, this counts the zero bits *below* the first 1 bit
    (the number of trailing zeros).  The shift tests use it as the amount
    needed to move a partition mask's field down to bit 0.

    Only the low ``width`` bits of ``x`` are examined.

    :param x: integer to scan (in this file, a partition mask such as
        0xFF00).
    :param width: number of low-order bits to examine; defaults to 16,
        the data width used by these tests.
    :returns: index of the lowest set bit, or ``None`` if none of the
        examined bits is set.
    """
    low = x & ((1 << width) - 1)
    if not low:
        return None
    # low & -low isolates the lowest set bit; bit_length()-1 is its index
    return (low & -low).bit_length() - 1
24
def count_bits(x, width=16):
    """Return how many of the low ``width`` bits of ``x`` are set.

    :param x: integer to inspect (in this file, a partition mask).
    :param width: number of low-order bits to count; defaults to 16, the
        data width used by these tests.
    :returns: population count of ``x & ((1 << width) - 1)``.
    """
    return bin(x & ((1 << width) - 1)).count("1")
31
32
def perms(k):
    """Return an iterator over every ``k``-character string of '0'/'1'.

    The strings come out in ascending binary order, e.g. '00', '01',
    '10', '11' for k=2.  TestMux uses each string as a per-partition
    select pattern.
    """
    return (''.join(bits) for bits in itertools.product('01', repeat=k))
35
36
def create_ilang(dut, traces, test_name):
    """Convert ``dut`` to RTLIL and write it to ``<test_name>.il``.

    ``traces`` is passed to ``rtlil.convert`` as the port list.
    """
    il_text = rtlil.convert(dut, ports=traces)
    with open("%s.il" % test_name, "w") as il_file:
        il_file.write(il_text)
41
42
def create_simulator(module, traces, test_name):
    """Dump ``module`` as RTLIL, then return a Simulator for it.

    The RTLIL file is ``<test_name>.il``; ``traces`` is used as its port
    list.
    """
    create_ilang(module, traces, test_name)
    sim = Simulator(module)
    return sim
46
47
# XXX this is for coriolis2 experimentation
class TestAddMod2(Elaboratable):
    """Registered variant of TestAddMod, used for coriolis2 experiments.

    Every result is latched in the ``sync`` domain.  Compare and carry
    outputs are ``len(partpoints)+1`` bits wide; arithmetic, shift and mux
    outputs are ``width`` bits.

    :param width: total width of the ``a``/``b`` SimdSignals.
    :param partpoints: partition points shared by every SimdSignal (the
        tests in this file pass a 3-bit Signal).
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, width)
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width) # left shift
        self.ls_scal_output = Signal(width) # left shift by scalar bsig
        self.rs_output = Signal(width) # right shift
        self.rs_scal_output = Signal(width) # right shift by scalar bsig
        self.sub_output = Signal(width)
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        # partitioned Mux select: len(partpoints)+1 bits, the same width
        # as the other per-partition signals and TestMuxMod.mux_sel2
        self.mux_sel2 = SimdSignal(partpoints, len(partpoints)+1)
        # Mux result; elaborate() drives it under this name
        self.mux_out2 = Signal(width)
        # alias for the attribute name previously created here
        self.mux2_out = self.mux_out2
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        sync = m.d.sync
        self.a.set_module(m)
        self.b.set_module(m)
        self.mux_sel2.set_module(m)
        # compares
        sync += self.lt_output.eq(self.a < self.b)
        sync += self.ne_output.eq(self.a != self.b)
        sync += self.le_output.eq(self.a <= self.b)
        sync += self.gt_output.eq(self.a > self.b)
        sync += self.eq_output.eq(self.a == self.b)
        sync += self.ge_output.eq(self.a >= self.b)
        # add: take the plain Signal of the result via .sig, as TestAddMod
        # does for add/sub/neg
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        sync += self.add_output.eq(add_out.sig)
        sync += self.add_carry_out.eq(add_carry)
        # sub
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        sync += self.sub_output.eq(sub_out.sig)
        sync += self.sub_carry_out.eq(sub_carry)
        # neg
        sync += self.neg_output.eq((-self.a).sig)
        # partitioned left / right shift by b
        sync += self.ls_output.eq(self.a << self.b)
        sync += self.rs_output.eq(self.a >> self.b)
        # partitioned mux
        sync += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
        # scalar shifts: bsig is b.lower(), b's plain-Signal view
        comb += self.bsig.eq(self.b.lower())
        sync += self.ls_scal_output.eq(self.a << self.bsig)
        sync += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
112
113
class TestMuxMod(Elaboratable):
    """DUT wrapping a partitioned ``Mux(mux_sel2, a, b)``.

    ``mux_out2`` is driven combinatorially by the Mux.  ``mux_sel`` is a
    plain Signal that the test bench also drives; it is not used by
    elaborate().

    :param width: width of ``a``, ``b`` and ``mux_out2``.
    :param partpoints: partition points shared by the SimdSignals.
    """

    def __init__(self, width, partpoints):
        n_sel = len(partpoints) + 1
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, width)
        self.mux_sel = Signal(n_sel)
        self.mux_sel2 = SimdSignal(partpoints, n_sel)
        self.mux_out2 = Signal(width)

    def elaborate(self, platform):
        m = Module()
        for simd in (self.a, self.b, self.mux_sel2):
            simd.set_module(m)
        m.d.comb += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
        return m
135
136
class TestCatMod(Elaboratable):
    """DUT driving SimdSignal ``o`` with ``Cat(a, b)``.

    ``b`` is twice as wide as ``a`` and ``o`` three times as wide, so ``o``
    exactly holds the concatenation.  ``cat_out`` is ``o``'s underlying
    Signal (``o.sig``), exposed for tracing and simulator reads.

    :param width: width of ``a``.
    :param partpoints: partition points shared by the SimdSignals.
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, width*2)
        self.o = SimdSignal(partpoints, width*3)
        self.cat_out = self.o.sig

    def elaborate(self, platform):
        m = Module()
        for simd in (self.a, self.b, self.o):
            simd.set_module(m)
        m.d.comb += self.o.eq(Cat(self.a, self.b))
        return m
155
156
class TestReplMod(Elaboratable):
    """DUT driving ``repl_out`` with ``Repl(a, 2)`` for a SimdSignal ``a``.

    ``repl_out`` is ``width*2`` bits so it holds both copies.
    ``repl_sel`` (``len(partpoints)+1`` bits) is created but not used by
    elaborate().

    :param width: width of ``a``.
    :param partpoints: partition points of ``a``.
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.repl_sel = Signal(len(partpoints)+1)
        self.repl_out = Signal(width*2)

    def elaborate(self, platform):
        m = Module()
        self.a.set_module(m)
        doubled = Repl(self.a, 2)
        m.d.comb += self.repl_out.eq(doubled)
        return m
172
173
class TestAssMod(Elaboratable):
    """DUT performing ``ass_out.eq(a)`` into a SimdSignal of ``out_shape``.

    :param width: width of the input ``a``.
    :param out_shape: Shape (width and signedness) of ``ass_out``.
    :param partpoints: partition points shared by the SimdSignals.
    :param scalar: if true, ``a`` is a plain Signal; otherwise ``a`` is a
        SimdSignal with the same partition points as ``ass_out``.
    """

    def __init__(self, width, out_shape, partpoints, scalar):
        self.partpoints = partpoints
        self.scalar = scalar
        if not scalar:
            self.a = SimdSignal(partpoints, width)
        else:
            self.a = Signal(width)
        self.ass_out = SimdSignal(partpoints, out_shape)

    def elaborate(self, platform):
        m = Module()
        # set_module is only called on the SimdSignal inputs/outputs
        if not self.scalar:
            self.a.set_module(m)
        self.ass_out.set_module(m)
        m.d.comb += self.ass_out.eq(self.a)
        return m
194
195
class TestAddMod(Elaboratable):
    """DUT exposing SimdSignal operators as combinatorial outputs.

    ``a`` and ``b`` are ``width``-bit SimdSignals partitioned by
    ``partpoints``.
      * Per-partition results are ``len(partpoints)+1`` bits wide: the
        compares, carries and horizontal reductions.
      * Value results are ``width`` bits wide.
    TestSimdSignal reads the results by name (``<op>_output`` and
    ``<op>_carry_out``).
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = SimdSignal(partpoints, width)
        self.b = SimdSignal(partpoints, width)
        # plain copy of b, used as a scalar (non-partitioned) shift amount
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width) # left shift
        self.ls_scal_output = Signal(width) # left shift by scalar bsig
        self.rs_output = Signal(width) # right shift
        self.rs_scal_output = Signal(width) # right shift by scalar bsig
        self.sub_output = Signal(width)
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)
        self.signed_output = Signal(width)
        # horizontal (per-partition reduction) results
        self.xor_output = Signal(len(partpoints)+1)
        self.bool_output = Signal(len(partpoints)+1)
        self.all_output = Signal(len(partpoints)+1)
        self.any_output = Signal(len(partpoints)+1)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        sync = m.d.sync
        self.a.set_module(m)
        self.b.set_module(m)
        # compares
        comb += self.lt_output.eq(self.a < self.b)
        comb += self.ne_output.eq(self.a != self.b)
        comb += self.le_output.eq(self.a <= self.b)
        comb += self.gt_output.eq(self.a > self.b)
        comb += self.eq_output.eq(self.a == self.b)
        comb += self.ge_output.eq(self.a >= self.b)
        # add: add_op returns (result, carry); the result's plain Signal
        # is taken via .sig
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        comb += self.add_output.eq(add_out.sig)
        comb += self.add_carry_out.eq(add_carry)
        # sub
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        comb += self.sub_output.eq(sub_out.sig)
        comb += self.sub_carry_out.eq(sub_carry)
        # neg / signed / unsigned
        comb += self.neg_output.eq((-self.a).sig)
        comb += self.signed_output.eq(self.a.as_signed())
        # horizontal operators
        comb += self.xor_output.eq(self.a.xor())
        comb += self.bool_output.eq(self.a.bool())
        comb += self.all_output.eq(self.a.all())
        comb += self.any_output.eq(self.a.any())
        # left shift
        comb += self.ls_output.eq(self.a << self.b)
        # right shift
        comb += self.rs_output.eq(self.a >> self.b)
        ppts = self.partpoints
        # scalar left shift: bsig is b.lower(), b's plain-Signal view
        comb += self.bsig.eq(self.b.lower())
        comb += self.ls_scal_output.eq(self.a << self.bsig)
        # scalar right shift
        comb += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
267
268
class TestMux(unittest.TestCase):
    """Simulation test for TestMuxMod (partitioned Mux)."""

    def test(self):
        """Check ``mux_out2`` for 16-, 8- and 4-bit partitionings.

        For each (a, b) pair every combination of per-partition select
        bits is applied.  The expected value takes each partition from
        ``a`` when that partition's select bit is set, otherwise from ``b``.
        """
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestMuxMod(width, part_mask)

        test_name = "part_sig_mux"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.mux_out2]
        # create_simulator also writes <test_name>.il before simulating
        sim = create_simulator(module, traces, test_name)

        def async_process():

            def test_muxop(msg_prefix, *maskbit_list):
                # maskbit_list has one entry per partition; bit i of an
                # entry stands for nibble i (bits 4*i..4*i+3) of the value
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:
                    # convert to mask_list (each nibble bit -> 0xF nibble)
                    mask_list = []
                    for mb in maskbit_list:
                        v = 0
                        for i in range(4):
                            if mb & (1 << i):
                                v |= 0xf << (i*4)
                        mask_list.append(v)

                    # TODO: sel needs to go through permutations of mask_list
                    # p is a '0'/'1' string: per partition, '1' selects a
                    # and '0' selects b
                    for p in perms(len(mask_list)):

                        sel = 0
                        selmask = 0
                        for i, v in enumerate(p):
                            if v == '1':
                                sel |= maskbit_list[i]
                                selmask |= mask_list[i]

                        yield module.a.lower().eq(a)
                        yield module.b.lower().eq(b)
                        yield module.mux_sel.eq(sel)
                        yield module.mux_sel2.lower().eq(sel)
                        yield Delay(0.1e-6)
                        y = 0
                        # do the partitioned tests
                        for i, mask in enumerate(mask_list):
                            if (selmask & mask):
                                y |= (a & mask)
                            else:
                                y |= (b & mask)
                        # check the result
                        outval2 = (yield module.mux_out2)
                        msg = f"{msg_prefix}: mux " + \
                              f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
                              f" => 0x{y:X} != 0x{outval2:X}, masklist %s"
                        # print ((msg % str(maskbit_list)).format(locals()))
                        self.assertEqual(y, outval2, msg % str(maskbit_list))

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            # part_mask is only 3 bits wide, so only the low 3 bits of
            # 0b1111 take effect (all partition points enabled)
            yield part_mask.eq(0b1111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
344
345
class TestCat(unittest.TestCase):
    """Simulation test for TestCatMod (partitioned Cat)."""

    def test(self):
        """Check ``cat_out`` for 16-, 8- and 4-bit partitionings.

        The expected value is built run by run, in the order given by
        get_runlengths.  Each run contributes its nibbles of ``a`` and then
        its 8-bit chunks of ``b``, packed upward from bit 0.
        """
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestCatMod(width, part_mask)

        test_name = "part_sig_cat"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.cat_out]
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_cat.cat import get_runlengths

        def async_process():

            def test_catop(msg_prefix):
                # define lengths of a/b test input
                alen, blen = 16, 32
                # pairs of test values a, b
                for a, b in [(0x0000, 0x00000000),
                             (0xDCBA, 0x12345678),
                             (0xABCD, 0x01234567),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0x1F1F, 0xF1F1F1F1),
                             (0x0000, 0xFFFFFFFF)]:

                    # convert a and b to partitions: four equal slices
                    # each (4 bits of a, 8 bits of b), lowest first
                    apart, bpart = [], []
                    ajump, bjump = alen // 4, blen // 4
                    for i in range(4):
                        apart.append((a >> (ajump*i) & ((1<<ajump)-1)))
                        bpart.append((b >> (bjump*i) & ((1<<bjump)-1)))

                    print ("apart bpart", hex(a), hex(b),
                           list(map(hex, apart)), list(map(hex, bpart)))

                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    yield Delay(0.1e-6)

                    y = 0
                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)
                    j = 0   # next output bit position
                    ai = 0  # next slice of a to consume
                    bi = 0  # next slice of b to consume
                    for i in runlengths:
                        # a first
                        for _ in range(i):
                            print ("runlength", i,
                                   "ai", ai,
                                   "apart", hex(apart[ai]),
                                   "j", j)
                            y |= apart[ai] << j
                            print (" y", hex(y))
                            j += ajump
                            ai += 1
                        # now b
                        for _ in range(i):
                            print ("runlength", i,
                                   "bi", bi,
                                   "bpart", hex(bpart[bi]),
                                   "j", j)
                            y |= bpart[bi] << j
                            print (" y", hex(y))
                            j += bjump
                            bi += 1

                    # check the result
                    outval = (yield module.cat_out)
                    msg = f"{msg_prefix}: cat " + \
                          f"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_catop("16-bit")
            yield part_mask.eq(0b10)
            yield from test_catop("8-bit")
            # part_mask is 3 bits wide; only the low 3 bits of 0b1111 apply
            yield part_mask.eq(0b1111)
            yield from test_catop("4-bit")

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
440
441
class TestRepl(unittest.TestCase):
    """Simulation test for TestReplMod (partitioned Repl(a, 2))."""

    def test(self):
        """Check ``repl_out`` for 16-, 8- and 4-bit partitionings.

        The expected value is built run by run: within each run, that
        run's nibbles of ``a`` are emitted twice in succession.
        """
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestReplMod(width, part_mask)

        test_name = "part_sig_repl"
        traces = [part_mask,
                  module.a.sig,
                  module.repl_out]
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_repl.repl import get_runlengths

        def async_process():

            def test_replop(msg_prefix):
                # define length of a test input
                alen = 16
                # test values a
                for a in [0x0000,
                          0xDCBA,
                          0x1234,
                          0xABCD,
                          0xFFFF,
                          0x0000,
                          0x1F1F,
                          0xF1F1,
                          ]:

                    # convert a to partitions (four nibbles, lowest first)
                    apart = []
                    ajump = alen // 4
                    for i in range(4):
                        apart.append((a >> (ajump*i) & ((1<<ajump)-1)))

                    print ("apart", hex(a), list(map(hex, apart)))

                    yield module.a.lower().eq(a)
                    yield Delay(0.1e-6)

                    y = 0
                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)
                    j = 0
                    # separate read cursors into apart for each of the two
                    # copies, so both copies walk the same nibbles
                    ai = [0, 0]
                    for i in runlengths:
                        # a twice because the test is Repl(a, 2)
                        for aidx in range(2):
                            for _ in range(i):
                                print ("runlength", i,
                                       "ai", ai,
                                       "apart", hex(apart[ai[aidx]]),
                                       "j", j)
                                y |= apart[ai[aidx]] << j
                                print (" y", hex(y))
                                j += ajump
                                ai[aidx] += 1

                    # check the result
                    outval = (yield module.repl_out)
                    msg = f"{msg_prefix}: repl " + \
                          f"0x{mval:X} 0x{a:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_replop("16-bit")
            yield part_mask.eq(0b10)
            yield from test_replop("8-bit")
            # part_mask is 3 bits wide; only the low 3 bits of 0b1111 apply
            yield part_mask.eq(0b1111)
            yield from test_replop("4-bit")

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
524
525
class TestAssign(unittest.TestCase):
    """Simulation tests for TestAssMod (assignment into a SimdSignal)."""

    def run_tst(self, in_width, out_width, out_signed, scalar):
        """Simulate one TestAssMod configuration and check ``ass_out``.

        :param in_width: width of the input ``a``.
        :param out_width: width of ``ass_out``.
        :param out_signed: whether ``ass_out`` has a signed Shape; when it
            does and the output is wider, each partition is sign-extended.
        :param scalar: if true ``a`` is a plain Signal and every output
            partition is modelled from the low bits of ``a``; otherwise
            ``a`` is a SimdSignal and is split partition by partition.
        """
        part_mask = Signal(3) # divide into 4-bits
        module = TestAssMod(in_width,
                            Shape(out_width, out_signed),
                            part_mask, scalar)

        test_name = "part_sig_ass_%d_%d_%s_%s" % (in_width, out_width,
                     "signed" if out_signed else "unsigned",
                     "scalar" if scalar else "partitioned")

        traces = [part_mask,
                  module.ass_out.lower()]
        if module.scalar:
            traces.append(module.a)
        else:
            traces.append(module.a.lower())
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_cat.cat import get_runlengths

        def async_process():

            def test_assop(msg_prefix):
                # define lengths of a test input
                alen = in_width
                randomvals = []
                for i in range(10):
                    randomvals.append(randint(0, 65535))
                # test values a
                for a in [0x0001,
                          0x0010,
                          0x0100,
                          0x1000,
                          0x000c,
                          0x00c0,
                          0x0c00,
                          0xc000,
                          0x1234,
                          0xDCBA,
                          0xABCD,
                          0x0000,
                          0xFFFF,
                          ] + randomvals:
                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)

                    print ("test a", hex(a), "mask", bin(mval), "widths",
                           in_width, out_width,
                           "signed", out_signed,
                           "scalar", scalar)

                    # convert a to runlengths sub-sections
                    apart = []
                    ajump = alen // 4
                    ai = 0
                    for i in runlengths:
                        subpart = (a >> (ajump*ai) & ((1<<(ajump*i))-1))
                        msb = (subpart >> ((ajump*i)-1)) # will contain the sign
                        apart.append((subpart, msb))
                        print ("apart", ajump*i, hex(a), hex(subpart), msb)
                        # scalar input: ai stays 0, so every partition is
                        # taken from the low bits of a
                        if not scalar:
                            ai += i

                    if scalar:
                        yield module.a.eq(a)
                    else:
                        yield module.a.lower().eq(a)
                    yield Delay(0.1e-6)

                    y = 0
                    j = 0
                    ojump = out_width // 4
                    for ai, i in enumerate(runlengths):
                        # get "a" partition value
                        av, amsb = apart[ai]
                        # do sign-extension if needed: set the bits from
                        # the input partition width up to the output width
                        signext = 0
                        if out_signed and ojump > ajump:
                            if amsb:
                                signext = (-1 << ajump*i) & ((1<<(ojump*i))-1)
                                av |= signext
                        # truncate if needed
                        if ojump < ajump:
                            av &= ((1<<(ojump*i))-1)
                        print ("runlength", i,
                               "ai", ai,
                               "apart", hex(av), amsb,
                               "signext", hex(signext),
                               "j", j)
                        y |= av << j
                        print (" y", hex(y))
                        j += ojump*i
                        ai += 1

                    y &= (1<<out_width)-1

                    # check the result
                    outval = (yield module.ass_out.lower())
                    outval &= (1<<out_width)-1
                    msg = f"{msg_prefix}: assign " + \
                          f"mask 0x{mval:X} input 0x{a:X}" + \
                          f" => expected 0x{y:X} != actual 0x{outval:X}"
                    self.assertEqual(y, outval, msg)

            # run the actual tests, here - 16/8/4 bit partitions
            for (mask, name) in ((0, "16-bit"),
                                 (0b10, "8-bit"),
                                 (0b111, "4-bit")):
                with self.subTest(name + " " + test_name):
                    yield part_mask.eq(mask)
                    yield Settle()
                    yield from test_assop(name)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()

    def test(self):
        """Run run_tst over 16-bit input for each output width/sign/mode."""
        for out_width in [16, 24, 8]:
            for sign in [True, False]:
                for scalar in [True, False]:
                    self.run_tst(16, out_width, sign, scalar)
655
656
class TestSimdSignal(unittest.TestCase):
    """Simulation test for TestAddMod (SimdSignal operators)."""

    def test(self):
        """Check every TestAddMod output against a software model.

        Each operator is exercised with 16-, 8- and 4-bit partitionings.
        The model functions below compute the expected result for a single
        partition, selected by ``mask``.  The per-partition results are
        OR-ed together and compared with ``<mod_attr>_output``, and with
        ``<mod_attr>_carry_out`` when the module has one.
        """
        width = 16
        part_mask = Signal(3) # divide into 4-bits
        module = TestAddMod(width, part_mask)

        test_name = "part_sig_add"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.add_output,
                  module.eq_output]
        sim = create_simulator(module, traces, test_name)

        def async_process():

            # ---- horizontal (per-partition reduction) models ----

            def test_xor_fn(a, mask):
                # parity of the bits of a inside mask
                test = (a & mask)
                result = 0
                while test != 0:
                    bit = (test & 1)
                    result ^= bit
                    test >>= 1
                return result

            def test_bool_fn(a, mask):
                # true if any bit of a inside mask is set
                test = (a & mask)
                return test != 0

            def test_all_fn(a, mask):
                # slightly different: all bits masked must be 1
                test = (a & mask)
                return test == mask

            def test_horizop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                # maskbit_list: one entry per partition, bit i = nibble i
                randomvals = []
                for i in range(100):
                    randomvals.append(randint(0, 65535))
                for a in [0x0000,
                          0x1111,
                          0x0001,
                          0x0010,
                          0x0100,
                          0x1000,
                          0x000F,
                          0x00F0,
                          0x0F00,
                          0xF000,
                          0x00FF,
                          0xFF00,
                          0x1234,
                          0xABCD,
                          0xFFFF,
                          0x8000,
                          0xBEEF, 0xFEED,
                          ]+randomvals:
                    with self.subTest("%s %s %s" % (msg_prefix,
                                      test_fn.__name__, hex(a))):
                        yield module.a.lower().eq(a)
                        yield Delay(0.1e-6)
                        # convert to mask_list
                        mask_list = []
                        for mb in maskbit_list:
                            v = 0
                            for i in range(4):
                                if mb & (1 << i):
                                    v |= 0xf << (i*4)
                            mask_list.append(v)
                        y = 0
                        # do the partitioned tests
                        for i, mask in enumerate(mask_list):
                            if test_fn(a, mask):
                                # OR y with the lowest set bit in the mask
                                y |= maskbit_list[i]
                        # check the result
                        outval = (yield getattr(module, "%s_output" % mod_attr))
                        msg = f"{msg_prefix}: {mod_attr} 0x{a:X} " + \
                              f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                        print((msg % str(maskbit_list)).format(locals()))
                        self.assertEqual(y, outval, msg % str(maskbit_list))

            for (test_fn, mod_attr) in ((test_xor_fn, "xor"),
                                        (test_all_fn, "all"),
                                        (test_bool_fn, "any"), # same as bool
                                        (test_bool_fn, "bool"),
                                        #(test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_horizop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_horizop("8-bit", test_fn, mod_attr,
                                        0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_horizop("4-bit", test_fn, mod_attr,
                                        0b1000, 0b0100, 0b0010, 0b0001)

            # ---- arithmetic / shift models: return (result, carry) ----

            def test_ls_scal_fn(carry_in, a, b, mask):
                # reduce range of b: the scalar shift amount is the whole
                # of b limited to 0..bits-1 (bits is a power of two here)
                bits = count_bits(mask)
                newb = b & ((bits-1))
                print ("%x %x %x bits %d trunc %x" % \
                       (a, b, mask, bits, newb))
                b = newb
                # TODO: carry
                carry_in = 0
                lsb = mask & ~(mask-1) if carry_in else 0
                sum = ((a & mask) << b)
                result = mask & sum
                carry = (sum & mask) != sum
                carry = 0
                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
                return result, carry

            def test_rs_scal_fn(carry_in, a, b, mask):
                # reduce range of b (see test_ls_scal_fn)
                bits = count_bits(mask)
                newb = b & ((bits-1))
                print ("%x %x %x bits %d trunc %x" % \
                       (a, b, mask, bits, newb))
                b = newb
                # TODO: carry
                carry_in = 0
                lsb = mask & ~(mask-1) if carry_in else 0
                sum = ((a & mask) >> b)
                result = mask & sum
                carry = (sum & mask) != sum
                carry = 0
                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
                return result, carry

            def test_ls_fn(carry_in, a, b, mask):
                # reduce range of b: the shift amount is this partition's
                # slice of b (starting at bit fz), limited to 0..bits-1
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits-1)<<fz)
                print ("%x %x %x bits %d zero %d trunc %x" % \
                       (a, b, mask, bits, fz, newb))
                b = newb
                # TODO: carry
                carry_in = 0
                lsb = mask & ~(mask-1) if carry_in else 0
                b = (b & mask)
                b = b >>fz
                sum = ((a & mask) << b)
                result = mask & sum
                carry = (sum & mask) != sum
                carry = 0
                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
                return result, carry

            def test_rs_fn(carry_in, a, b, mask):
                # reduce range of b (see test_ls_fn)
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits-1)<<fz)
                print ("%x %x %x bits %d zero %d trunc %x" % \
                       (a, b, mask, bits, fz, newb))
                b = newb
                # TODO: carry
                carry_in = 0
                lsb = mask & ~(mask-1) if carry_in else 0
                b = (b & mask)
                b = b >>fz
                sum = ((a & mask) >> b)
                result = mask & sum
                carry = (sum & mask) != sum
                carry = 0
                print("res", hex(a), hex(b), hex(sum), hex(mask), hex(result))
                return result, carry

            def test_add_fn(carry_in, a, b, mask):
                # carry_in adds 1 at the partition's lowest bit; carry-out
                # is any sum bit that spills outside mask
                lsb = mask & ~(mask-1) if carry_in else 0
                sum = (a & mask) + (b & mask) + lsb
                result = mask & sum
                carry = (sum & mask) != sum
                print(a, b, sum, mask)
                return result, carry

            def test_sub_fn(carry_in, a, b, mask):
                # a - b modelled as a + ~b + carry_in within the partition
                lsb = mask & ~(mask-1) if carry_in else 0
                sum = (a & mask) + (~b & mask) + lsb
                result = mask & sum
                carry = (sum & mask) != sum
                return result, carry

            def test_neg_fn(carry_in, a, b, mask):
                lsb = mask & ~(mask - 1) # has only LSB of mask set
                pos = lsb.bit_length() - 1 # find bit position
                a = (a & mask) >> pos # shift it to the beginning
                return ((-a) << pos) & mask, 0 # negate and shift it back

            def test_signed_fn(carry_in, a, b, mask):
                return a & mask, 0

            def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                # random 16-bit operands.  randint() includes its upper
                # bound, so the limit is 0xFFFF: 1 << 16 could yield a
                # 17-bit value that the 16-bit inputs silently truncate.
                rand_data = []
                for i in range(100):
                    a, b = randint(0, (1 << 16) - 1), randint(0, (1 << 16) - 1)
                    rand_data.append((a, b))
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)] + rand_data:
                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    # carry-in set for every partition, or for none
                    carry_sig = 0xf if carry else 0
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    y = 0
                    carry_result = 0
                    for i, mask in enumerate(mask_list):
                        print ("i/mask", i, hex(mask))
                        res, c = test_fn(carry, a, b, mask)
                        y |= res
                        # a partition's carry bit sits at (lowest bit of its
                        # mask) / 4, i.e. one carry bit per nibble
                        lsb = mask & ~(mask - 1)
                        bit_set = int(math.log2(lsb))
                        carry_result |= c << int(bit_set/4)
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    # TODO: get (and test) carry output as well
                    print(a, b, outval, carry)
                    msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)
                    if hasattr(module, "%s_carry_out" % mod_attr):
                        c_outval = (yield getattr(module,
                                                  "%s_carry_out" % mod_attr))
                        msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                              f" => 0x{carry_result:X} != 0x{c_outval:X}"
                        self.assertEqual(carry_result, c_outval, msg)

            # run through series of operations with corresponding
            # "helper" routines to reproduce the result (test_fn). the same
            # a/b input is passed to *all* outputs, where the name of the
            # output attribute (mod_attr) will contain the result to be
            # compared against the expected output from test_fn
            for (test_fn, mod_attr) in (
                                        (test_ls_scal_fn, "ls_scal"),
                                        (test_ls_fn, "ls"),
                                        (test_rs_scal_fn, "rs_scal"),
                                        (test_rs_fn, "rs"),
                                        (test_add_fn, "add"),
                                        (test_sub_fn, "sub"),
                                        (test_neg_fn, "neg"),
                                        (test_signed_fn, "signed"),
                                        ):
                yield part_mask.eq(0)
                yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
                yield part_mask.eq(0b10)
                yield from test_op("8-bit", 0, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield from test_op("8-bit", 1, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield part_mask.eq(0b1111)
                yield from test_op("4-bit", 0, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)
                yield from test_op("4-bit", 1, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)

            # ---- comparison models (unsigned, per partition) ----

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    yield Delay(0.1e-6)
                    # convert to mask_list
                    mask_list = []
                    for mb in maskbit_list:
                        v = 0
                        for i in range(4):
                            if mb & (1 << i):
                                v |= 0xf << (i*4)
                        mask_list.append(v)
                    y = 0
                    # do the partitioned tests
                    for i, mask in enumerate(mask_list):
                        if test_fn(a, b, mask):
                            # OR y with the lowest set bit in the mask
                            y |= maskbit_list[i]
                    # check the result
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    msg = f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                    print((msg % str(maskbit_list)).format(locals()))
                    self.assertEqual(y, outval, msg % str(maskbit_list))

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
994
995
# TODO: adapt to SimdSignal. perhaps a different style?
# The string below is disabled test code for Value.matches(), left here
# unexecuted.  It cannot be re-enabled as-is:
#   * it imports SignedEnum from nmigen.tests.test_hdl_ast;
#   * it calls self.assertRepr, which no TestCase in this module defines;
#   * the "def test_matches(self)" line is missing its colon.
'''
    from nmigen.tests.test_hdl_ast import SignedEnum
    def test_matches(self)
        s = Signal(4)
        self.assertRepr(s.matches(), "(const 1'd0)")
        self.assertRepr(s.matches(1), """
        (== (sig s) (const 1'd1))
        """)
        self.assertRepr(s.matches(0, 1), """
        (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
        """)
        self.assertRepr(s.matches("10--"), """
        (== (& (sig s) (const 4'd12)) (const 4'd8))
        """)
        self.assertRepr(s.matches("1 0--"), """
        (== (& (sig s) (const 4'd12)) (const 4'd8))
        """)

    def test_matches_enum(self):
        s = Signal(SignedEnum)
        self.assertRepr(s.matches(SignedEnum.FOO), """
        (== (sig s) (const 1'sd-1))
        """)

    def test_matches_width_wrong(self):
        s = Signal(4)
        with self.assertRaisesRegex(SyntaxError,
                r"^Match pattern '--' must have the same width as "
                r"match value \(which is 4\)$"):
            s.matches("--")
        with self.assertWarnsRegex(SyntaxWarning,
                (r"^Match pattern '10110' is wider than match value "
                 r"\(which has width 4\); "
                 r"comparison will never be true$")):
            s.matches(0b10110)

    def test_matches_bits_wrong(self):
        s = Signal(4)
        with self.assertRaisesRegex(SyntaxError,
                (r"^Match pattern 'abc' must consist of 0, 1, "
                 r"and - \(don't care\) bits, "
                 r"and may include whitespace$")):
            s.matches("abc")

    def test_matches_pattern_wrong(self):
        s = Signal(4)
        with self.assertRaisesRegex(SyntaxError,
                r"^Match pattern must be an integer, a string, "
                r"or an enumeration, not 1\.0$"):
            s.matches(1.0)
'''
1048
if __name__ == "__main__":
    # run every TestCase defined in this module via the unittest CLI
    unittest.main()