2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
5 from nmigen
import Signal
, Module
, Elaboratable
, Mux
, Cat
, Shape
, Repl
6 from nmigen
.back
.pysim
import Simulator
, Delay
, Settle
7 from nmigen
.cli
import rtlil
9 from ieee754
.part
.partsig
import SimdSignal
10 from ieee754
.part_mux
.part_mux
import PMux
12 from random
import randint
34 return map(''.join
, itertools
.product('01', repeat
=k
))
37 def create_ilang(dut
, traces
, test_name
):
38 vl
= rtlil
.convert(dut
, ports
=traces
)
39 with
open("%s.il" % test_name
, "w") as f
:
def create_simulator(module, traces, test_name):
    """Write ``module`` out as RTLIL and return a Simulator for it.

    The RTLIL is produced by ``create_ilang``, which writes
    ``<test_name>.il`` and uses ``traces`` as the port list.
    """
    # side effect: emits the "<test_name>.il" file before simulating
    create_ilang(module, traces, test_name)
    simulator = Simulator(module)
    return simulator
48 # XXX this is for coriolis2 experimentation
49 class TestAddMod2(Elaboratable
):
def __init__(self, width, partpoints):
    """Declare operands and result signals for the coriolis2 experiment.

    :param width: bit-width of the ``a``/``b`` SimdSignal operands and of
                  every full-width result signal.
    :param partpoints: partition mask handed to each SimdSignal.  The
                       per-partition outputs are ``len(partpoints)+1`` bits
                       wide (presumably one bit per partition -- confirm).
    """
    self.partpoints = partpoints
    self.a = SimdSignal(partpoints, width)
    self.b = SimdSignal(partpoints, width)
    self.bsig = Signal(width)
    self.add_output = Signal(width)
    self.ls_output = Signal(width) # left shift
    self.ls_scal_output = Signal(width) # left shift
    self.rs_output = Signal(width) # right shift
    self.rs_scal_output = Signal(width) # right shift
    self.sub_output = Signal(width)
    self.eq_output = Signal(len(partpoints)+1)
    self.gt_output = Signal(len(partpoints)+1)
    self.ge_output = Signal(len(partpoints)+1)
    self.ne_output = Signal(len(partpoints)+1)
    self.lt_output = Signal(len(partpoints)+1)
    self.le_output = Signal(len(partpoints)+1)
    # Mux() selector.  Sized len(partpoints)+1 to match the SimdSignal
    # selector in TestMuxMod; a single assignment, since elaborate() needs
    # the SimdSignal (it calls set_module on it).
    self.mux_sel2 = SimdSignal(partpoints, len(partpoints)+1)
    self.mux2_out = Signal(width)
    # elaborate() drives ``self.mux_out2``; expose the same Signal under
    # that name so both spellings refer to one output.
    self.mux_out2 = self.mux2_out
    self.carry_in = Signal(len(partpoints)+1)
    self.add_carry_out = Signal(len(partpoints)+1)
    self.sub_carry_out = Signal(len(partpoints)+1)
    self.neg_output = Signal(width)
75 def elaborate(self
, platform
):
81 self
.mux_sel2
.set_module(m
)
83 sync
+= self
.lt_output
.eq(self
.a
< self
.b
)
84 sync
+= self
.ne_output
.eq(self
.a
!= self
.b
)
85 sync
+= self
.le_output
.eq(self
.a
<= self
.b
)
86 sync
+= self
.gt_output
.eq(self
.a
> self
.b
)
87 sync
+= self
.eq_output
.eq(self
.a
== self
.b
)
88 sync
+= self
.ge_output
.eq(self
.a
>= self
.b
)
90 add_out
, add_carry
= self
.a
.add_op(self
.a
, self
.b
,
92 sync
+= self
.add_output
.eq(add_out
)
93 sync
+= self
.add_carry_out
.eq(add_carry
)
95 sub_out
, sub_carry
= self
.a
.sub_op(self
.a
, self
.b
,
97 sync
+= self
.sub_output
.eq(sub_out
)
98 sync
+= self
.sub_carry_out
.eq(sub_carry
)
100 sync
+= self
.neg_output
.eq(-self
.a
)
102 sync
+= self
.ls_output
.eq(self
.a
<< self
.b
)
103 sync
+= self
.rs_output
.eq(self
.a
>> self
.b
)
104 ppts
= self
.partpoints
105 sync
+= self
.mux_out2
.eq(Mux(self
.mux_sel2
, self
.a
, self
.b
))
107 comb
+= self
.bsig
.eq(self
.b
.lower())
108 sync
+= self
.ls_scal_output
.eq(self
.a
<< self
.bsig
)
109 sync
+= self
.rs_scal_output
.eq(self
.a
>> self
.bsig
)
114 class TestMuxMod(Elaboratable
):
def __init__(self, width, partpoints):
    """Declare operands, selectors and output for the partitioned Mux test.

    ``mux_sel`` is a plain Signal and ``mux_sel2`` a SimdSignal, both
    ``len(partpoints)+1`` bits wide; ``mux_out2`` is ``width`` bits.
    """
    sel_width = len(partpoints) + 1
    self.partpoints = partpoints
    # partitioned data inputs
    self.a = SimdSignal(partpoints, width)
    self.b = SimdSignal(partpoints, width)
    # selectors: plain and partitioned variants
    self.mux_sel = Signal(sel_width)
    self.mux_sel2 = SimdSignal(partpoints, sel_width)
    # result of Mux(mux_sel2, a, b) in elaborate()
    self.mux_out2 = Signal(width)
123 def elaborate(self
, platform
):
129 self
.mux_sel2
.set_module(m
)
130 ppts
= self
.partpoints
132 comb
+= self
.mux_out2
.eq(Mux(self
.mux_sel2
, self
.a
, self
.b
))
137 class TestCatMod(Elaboratable
):
def __init__(self, width, partpoints):
    """Declare inputs ``a``/``b`` and their concatenation ``o``.

    ``a`` is ``width`` bits, ``b`` is twice that, and ``o`` holds
    ``Cat(a, b)``; ``cat_out`` is ``o``'s underlying ``.sig``.
    """
    a_width = width
    b_width = width*2
    self.partpoints = partpoints
    self.a = SimdSignal(partpoints, a_width)
    self.b = SimdSignal(partpoints, b_width)
    # output must hold both inputs side by side
    self.o = SimdSignal(partpoints, a_width + b_width)
    self.cat_out = self.o.sig
145 def elaborate(self
, platform
):
152 comb
+= self
.o
.eq(Cat(self
.a
, self
.b
))
157 class TestReplMod(Elaboratable
):
def __init__(self, width, partpoints):
    """Declare input ``a`` and the output of ``Repl(a, 2)``.

    ``repl_sel`` is ``len(partpoints)+1`` bits; ``repl_out`` is twice
    ``width`` because elaborate() replicates ``a`` two times.
    """
    sel_width = len(partpoints) + 1
    self.partpoints = partpoints
    self.a = SimdSignal(partpoints, width)
    self.repl_sel = Signal(sel_width)
    self.repl_out = Signal(2*width)
164 def elaborate(self
, platform
):
169 comb
+= self
.repl_out
.eq(Repl(self
.a
, 2))
174 class TestAssMod(Elaboratable
):
175 def __init__(self
, width
, out_shape
, partpoints
, scalar
):
176 self
.partpoints
= partpoints
179 self
.a
= Signal(width
)
181 self
.a
= SimdSignal(partpoints
, width
)
182 self
.ass_out
= SimdSignal(partpoints
, out_shape
)
184 def elaborate(self
, platform
):
189 self
.ass_out
.set_module(m
)
191 comb
+= self
.ass_out
.eq(self
.a
)
196 class TestAddMod(Elaboratable
):
def __init__(self, width, partpoints):
    """Declare operands and result signals for the SimdSignal operator tests.

    Full-width (``width``-bit) results: add, sub, shifts, neg and signed.
    ``len(partpoints)+1``-bit results: comparisons, carries and the
    horizontal reductions (xor/bool/all/any).
    """
    narrow = len(partpoints) + 1
    self.partpoints = partpoints
    # partitioned operands
    self.a = SimdSignal(partpoints, width)
    self.b = SimdSignal(partpoints, width)
    # plain copy driven from b.lower() in elaborate(); used as a scalar
    # shift amount
    self.bsig = Signal(width)
    self.add_output = Signal(width)
    self.ls_output = Signal(width) # left shift
    self.ls_scal_output = Signal(width) # left shift
    self.rs_output = Signal(width) # right shift
    self.rs_scal_output = Signal(width) # right shift
    self.sub_output = Signal(width)
    # comparison results
    self.eq_output = Signal(narrow)
    self.gt_output = Signal(narrow)
    self.ge_output = Signal(narrow)
    self.ne_output = Signal(narrow)
    self.lt_output = Signal(narrow)
    self.le_output = Signal(narrow)
    # carries
    self.carry_in = Signal(narrow)
    self.add_carry_out = Signal(narrow)
    self.sub_carry_out = Signal(narrow)
    # unary results
    self.neg_output = Signal(width)
    self.signed_output = Signal(width)
    # horizontal reductions
    self.xor_output = Signal(narrow)
    self.bool_output = Signal(narrow)
    self.all_output = Signal(narrow)
    self.any_output = Signal(narrow)
224 def elaborate(self
, platform
):
231 comb
+= self
.lt_output
.eq(self
.a
< self
.b
)
232 comb
+= self
.ne_output
.eq(self
.a
!= self
.b
)
233 comb
+= self
.le_output
.eq(self
.a
<= self
.b
)
234 comb
+= self
.gt_output
.eq(self
.a
> self
.b
)
235 comb
+= self
.eq_output
.eq(self
.a
== self
.b
)
236 comb
+= self
.ge_output
.eq(self
.a
>= self
.b
)
238 add_out
, add_carry
= self
.a
.add_op(self
.a
, self
.b
,
240 comb
+= self
.add_output
.eq(add_out
.sig
)
241 comb
+= self
.add_carry_out
.eq(add_carry
)
243 sub_out
, sub_carry
= self
.a
.sub_op(self
.a
, self
.b
,
245 comb
+= self
.sub_output
.eq(sub_out
.sig
)
246 comb
+= self
.sub_carry_out
.eq(sub_carry
)
247 # neg / signed / unsigned
248 comb
+= self
.neg_output
.eq((-self
.a
).sig
)
249 comb
+= self
.signed_output
.eq(self
.a
.as_signed())
250 # horizontal operators
251 comb
+= self
.xor_output
.eq(self
.a
.xor())
252 comb
+= self
.bool_output
.eq(self
.a
.bool())
253 comb
+= self
.all_output
.eq(self
.a
.all())
254 comb
+= self
.any_output
.eq(self
.a
.any())
256 comb
+= self
.ls_output
.eq(self
.a
<< self
.b
)
258 comb
+= self
.rs_output
.eq(self
.a
>> self
.b
)
259 ppts
= self
.partpoints
261 comb
+= self
.bsig
.eq(self
.b
.lower())
262 comb
+= self
.ls_scal_output
.eq(self
.a
<< self
.bsig
)
264 comb
+= self
.rs_scal_output
.eq(self
.a
>> self
.bsig
)
269 class TestMux(unittest
.TestCase
):
272 part_mask
= Signal(3) # divide into 4-bits
273 module
= TestMuxMod(width
, part_mask
)
275 test_name
= "part_sig_mux"
280 sim
= create_simulator(module
, traces
, test_name
)
284 def test_muxop(msg_prefix
, *maskbit_list
):
285 for a
, b
in [(0x0000, 0x0000),
292 # convert to mask_list
294 for mb
in maskbit_list
:
301 # TODO: sel needs to go through permutations of mask_list
302 for p
in perms(len(mask_list
)):
306 for i
, v
in enumerate(p
):
308 sel |
= maskbit_list
[i
]
309 selmask |
= mask_list
[i
]
311 yield module
.a
.lower().eq(a
)
312 yield module
.b
.lower().eq(b
)
313 yield module
.mux_sel
.eq(sel
)
314 yield module
.mux_sel2
.lower().eq(sel
)
317 # do the partitioned tests
318 for i
, mask
in enumerate(mask_list
):
324 outval2
= (yield module
.mux_out2
)
325 msg
= f
"{msg_prefix}: mux " + \
326 f
"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
327 f
" => 0x{y:X} != 0x{outval2:X}, masklist %s"
328 # print ((msg % str(maskbit_list)).format(locals()))
329 self
.assertEqual(y
, outval2
, msg
% str(maskbit_list
))
331 yield part_mask
.eq(0)
332 yield from test_muxop("16-bit", 0b1111)
333 yield part_mask
.eq(0b10)
334 yield from test_muxop("8-bit", 0b1100, 0b0011)
335 yield part_mask
.eq(0b1111)
336 yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
338 sim
.add_process(async_process
)
340 vcd_file
=open(test_name
+ ".vcd", "w"),
341 gtkw_file
=open(test_name
+ ".gtkw", "w"),
346 class TestCat(unittest
.TestCase
):
349 part_mask
= Signal(3) # divide into 4-bits
350 module
= TestCatMod(width
, part_mask
)
352 test_name
= "part_sig_cat"
357 sim
= create_simulator(module
, traces
, test_name
)
359 # annoying recursive import issue
360 from ieee754
.part_cat
.cat
import get_runlengths
364 def test_catop(msg_prefix
):
365 # define lengths of a/b test input
367 # pairs of test values a, b
368 for a
, b
in [(0x0000, 0x00000000),
369 (0xDCBA, 0x12345678),
370 (0xABCD, 0x01234567),
373 (0x1F1F, 0xF1F1F1F1),
374 (0x0000, 0xFFFFFFFF)]:
376 # convert a and b to partitions
377 apart
, bpart
= [], []
378 ajump
, bjump
= alen
// 4, blen
// 4
380 apart
.append((a
>> (ajump
*i
) & ((1<<ajump
)-1)))
381 bpart
.append((b
>> (bjump
*i
) & ((1<<bjump
)-1)))
383 print ("apart bpart", hex(a
), hex(b
),
384 list(map(hex, apart
)), list(map(hex, bpart
)))
386 yield module
.a
.lower().eq(a
)
387 yield module
.b
.lower().eq(b
)
391 # work out the runlengths for this mask.
392 # 0b011 returns [1,1,2] (for a mask of length 3)
393 mval
= yield part_mask
394 runlengths
= get_runlengths(mval
, 3)
401 print ("runlength", i
,
403 "apart", hex(apart
[ai
]),
411 print ("runlength", i
,
413 "bpart", hex(bpart
[bi
]),
421 outval
= (yield module
.cat_out
)
422 msg
= f
"{msg_prefix}: cat " + \
423 f
"0x{mval:X} 0x{a:X} : 0x{b:X}" + \
424 f
" => 0x{y:X} != 0x{outval:X}"
425 self
.assertEqual(y
, outval
, msg
)
427 yield part_mask
.eq(0)
428 yield from test_catop("16-bit")
429 yield part_mask
.eq(0b10)
430 yield from test_catop("8-bit")
431 yield part_mask
.eq(0b1111)
432 yield from test_catop("4-bit")
434 sim
.add_process(async_process
)
436 vcd_file
=open(test_name
+ ".vcd", "w"),
437 gtkw_file
=open(test_name
+ ".gtkw", "w"),
442 class TestRepl(unittest
.TestCase
):
445 part_mask
= Signal(3) # divide into 4-bits
446 module
= TestReplMod(width
, part_mask
)
448 test_name
= "part_sig_repl"
452 sim
= create_simulator(module
, traces
, test_name
)
454 # annoying recursive import issue
455 from ieee754
.part_repl
.repl
import get_runlengths
459 def test_replop(msg_prefix
):
460 # define length of a test input
473 # convert a to partitions
477 apart
.append((a
>> (ajump
*i
) & ((1<<ajump
)-1)))
479 print ("apart", hex(a
), list(map(hex, apart
)))
481 yield module
.a
.lower().eq(a
)
485 # work out the runlengths for this mask.
486 # 0b011 returns [1,1,2] (for a mask of length 3)
487 mval
= yield part_mask
488 runlengths
= get_runlengths(mval
, 3)
492 # a twice because the test is Repl(a, 2)
493 for aidx
in range(2):
495 print ("runlength", i
,
497 "apart", hex(apart
[ai
[aidx
]]),
499 y |
= apart
[ai
[aidx
]] << j
505 outval
= (yield module
.repl_out
)
506 msg
= f
"{msg_prefix}: repl " + \
507 f
"0x{mval:X} 0x{a:X}" + \
508 f
" => 0x{y:X} != 0x{outval:X}"
509 self
.assertEqual(y
, outval
, msg
)
511 yield part_mask
.eq(0)
512 yield from test_replop("16-bit")
513 yield part_mask
.eq(0b10)
514 yield from test_replop("8-bit")
515 yield part_mask
.eq(0b1111)
516 yield from test_replop("4-bit")
518 sim
.add_process(async_process
)
520 vcd_file
=open(test_name
+ ".vcd", "w"),
521 gtkw_file
=open(test_name
+ ".gtkw", "w"),
526 class TestAssign(unittest
.TestCase
):
527 def run_tst(self
, in_width
, out_width
, out_signed
, scalar
):
528 part_mask
= Signal(3) # divide into 4-bits
529 module
= TestAssMod(in_width
,
530 Shape(out_width
, out_signed
),
533 test_name
= "part_sig_ass_%d_%d_%s_%s" % (in_width
, out_width
,
534 "signed" if out_signed
else "unsigned",
535 "scalar" if scalar
else "partitioned")
538 module
.ass_out
.lower()]
540 traces
.append(module
.a
)
542 traces
.append(module
.a
.lower())
543 sim
= create_simulator(module
, traces
, test_name
)
545 # annoying recursive import issue
546 from ieee754
.part_cat
.cat
import get_runlengths
550 def test_assop(msg_prefix
):
551 # define lengths of a test input
555 randomvals
.append(randint(0, 65535))
571 # work out the runlengths for this mask.
572 # 0b011 returns [1,1,2] (for a mask of length 3)
573 mval
= yield part_mask
574 runlengths
= get_runlengths(mval
, 3)
576 print ("test a", hex(a
), "mask", bin(mval
), "widths",
578 "signed", out_signed
,
581 # convert a to runlengths sub-sections
586 subpart
= (a
>> (ajump
*ai
) & ((1<<(ajump
*i
))-1))
587 msb
= (subpart
>> ((ajump
*i
)-1)) # will contain the sign
588 apart
.append((subpart
, msb
))
589 print ("apart", ajump
*i
, hex(a
), hex(subpart
), msb
)
596 yield module
.a
.lower().eq(a
)
601 ojump
= out_width
// 4
602 for ai
, i
in enumerate(runlengths
):
603 # get "a" partition value
605 # do sign-extension if needed
607 if out_signed
and ojump
> ajump
:
609 signext
= (-1 << ajump
*i
) & ((1<<(ojump
*i
))-1)
613 av
&= ((1<<(ojump
*i
))-1)
614 print ("runlength", i
,
616 "apart", hex(av
), amsb
,
617 "signext", hex(signext
),
624 y
&= (1<<out_width
)-1
627 outval
= (yield module
.ass_out
.lower())
628 outval
&= (1<<out_width
)-1
629 msg
= f
"{msg_prefix}: assign " + \
630 f
"mask 0x{mval:X} input 0x{a:X}" + \
631 f
" => expected 0x{y:X} != actual 0x{outval:X}"
632 self
.assertEqual(y
, outval
, msg
)
634 # run the actual tests, here - 16/8/4 bit partitions
635 for (mask
, name
) in ((0, "16-bit"),
638 with self
.subTest(name
+ " " + test_name
):
639 yield part_mask
.eq(mask
)
641 yield from test_assop(name
)
643 sim
.add_process(async_process
)
645 vcd_file
=open(test_name
+ ".vcd", "w"),
646 gtkw_file
=open(test_name
+ ".gtkw", "w"),
651 for out_width
in [16, 24, 8]:
652 for sign
in [True, False]:
653 for scalar
in [True, False]:
654 self
.run_tst(16, out_width
, sign
, scalar
)
657 class TestSimdSignal(unittest
.TestCase
):
660 part_mask
= Signal(3) # divide into 4-bits
661 module
= TestAddMod(width
, part_mask
)
663 test_name
= "part_sig_add"
669 sim
= create_simulator(module
, traces
, test_name
)
673 def test_xor_fn(a
, mask
):
682 def test_bool_fn(a
, mask
):
686 def test_all_fn(a
, mask
):
687 # slightly different: all bits masked must be 1
691 def test_horizop(msg_prefix
, test_fn
, mod_attr
, *maskbit_list
):
694 randomvals
.append(randint(0, 65535))
713 with self
.subTest("%s %s %s" % (msg_prefix
,
714 test_fn
.__name
__, hex(a
))):
715 yield module
.a
.lower().eq(a
)
717 # convert to mask_list
719 for mb
in maskbit_list
:
726 # do the partitioned tests
727 for i
, mask
in enumerate(mask_list
):
729 # OR y with the lowest set bit in the mask
732 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
733 msg
= f
"{msg_prefix}: {mod_attr} 0x{a:X} " + \
734 f
" => 0x{y:X} != 0x{outval:X}, masklist %s"
735 print((msg
% str(maskbit_list
)).format(locals()))
736 self
.assertEqual(y
, outval
, msg
% str(maskbit_list
))
738 for (test_fn
, mod_attr
) in ((test_xor_fn
, "xor"),
739 (test_all_fn
, "all"),
740 (test_bool_fn
, "any"), # same as bool
741 (test_bool_fn
, "bool"),
744 yield part_mask
.eq(0)
745 yield from test_horizop("16-bit", test_fn
, mod_attr
, 0b1111)
746 yield part_mask
.eq(0b10)
747 yield from test_horizop("8-bit", test_fn
, mod_attr
,
749 yield part_mask
.eq(0b1111)
750 yield from test_horizop("4-bit", test_fn
, mod_attr
,
751 0b1000, 0b0100, 0b0010, 0b0001)
753 def test_ls_scal_fn(carry_in
, a
, b
, mask
):
755 bits
= count_bits(mask
)
756 newb
= b
& ((bits
-1))
757 print ("%x %x %x bits %d trunc %x" % \
758 (a
, b
, mask
, bits
, newb
))
762 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
763 sum = ((a
& mask
) << b
)
765 carry
= (sum & mask
) != sum
767 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
770 def test_rs_scal_fn(carry_in
, a
, b
, mask
):
772 bits
= count_bits(mask
)
773 newb
= b
& ((bits
-1))
774 print ("%x %x %x bits %d trunc %x" % \
775 (a
, b
, mask
, bits
, newb
))
779 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
780 sum = ((a
& mask
) >> b
)
782 carry
= (sum & mask
) != sum
784 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
787 def test_ls_fn(carry_in
, a
, b
, mask
):
789 bits
= count_bits(mask
)
790 fz
= first_zero(mask
)
791 newb
= b
& ((bits
-1)<<fz
)
792 print ("%x %x %x bits %d zero %d trunc %x" % \
793 (a
, b
, mask
, bits
, fz
, newb
))
797 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
800 sum = ((a
& mask
) << b
)
802 carry
= (sum & mask
) != sum
804 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
807 def test_rs_fn(carry_in
, a
, b
, mask
):
809 bits
= count_bits(mask
)
810 fz
= first_zero(mask
)
811 newb
= b
& ((bits
-1)<<fz
)
812 print ("%x %x %x bits %d zero %d trunc %x" % \
813 (a
, b
, mask
, bits
, fz
, newb
))
817 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
820 sum = ((a
& mask
) >> b
)
822 carry
= (sum & mask
) != sum
824 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
827 def test_add_fn(carry_in
, a
, b
, mask
):
828 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
829 sum = (a
& mask
) + (b
& mask
) + lsb
831 carry
= (sum & mask
) != sum
832 print(a
, b
, sum, mask
)
835 def test_sub_fn(carry_in
, a
, b
, mask
):
836 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
837 sum = (a
& mask
) + (~b
& mask
) + lsb
839 carry
= (sum & mask
) != sum
842 def test_neg_fn(carry_in
, a
, b
, mask
):
843 lsb
= mask
& ~
(mask
- 1) # has only LSB of mask set
844 pos
= lsb
.bit_length() - 1 # find bit position
845 a
= (a
& mask
) >> pos
# shift it to the beginning
846 return ((-a
) << pos
) & mask
, 0 # negate and shift it back
848 def test_signed_fn(carry_in
, a
, b
, mask
):
851 def test_op(msg_prefix
, carry
, test_fn
, mod_attr
, *mask_list
):
854 a
, b
= randint(0, 1 << 16), randint(0, 1 << 16)
855 rand_data
.append((a
, b
))
856 for a
, b
in [(0x0000, 0x0000),
862 (0x0000, 0xFFFF)] + rand_data
:
863 yield module
.a
.lower().eq(a
)
864 yield module
.b
.lower().eq(b
)
865 carry_sig
= 0xf if carry
else 0
866 yield module
.carry_in
.eq(carry_sig
)
870 for i
, mask
in enumerate(mask_list
):
871 print ("i/mask", i
, hex(mask
))
872 res
, c
= test_fn(carry
, a
, b
, mask
)
874 lsb
= mask
& ~
(mask
- 1)
875 bit_set
= int(math
.log2(lsb
))
876 carry_result |
= c
<< int(bit_set
/4)
877 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
878 # TODO: get (and test) carry output as well
879 print(a
, b
, outval
, carry
)
880 msg
= f
"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
881 f
" => 0x{y:X} != 0x{outval:X}"
882 self
.assertEqual(y
, outval
, msg
)
883 if hasattr(module
, "%s_carry_out" % mod_attr
):
884 c_outval
= (yield getattr(module
,
885 "%s_carry_out" % mod_attr
))
886 msg
= f
"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
887 f
" => 0x{carry_result:X} != 0x{c_outval:X}"
888 self
.assertEqual(carry_result
, c_outval
, msg
)
890 # run through series of operations with corresponding
891 # "helper" routines to reproduce the result (test_fn). the same
892 # a/b input is passed to *all* outputs, where the name of the
893 # output attribute (mod_attr) will contain the result to be
894 # compared against the expected output from test_fn
895 for (test_fn
, mod_attr
) in (
896 (test_ls_scal_fn
, "ls_scal"),
898 (test_rs_scal_fn
, "rs_scal"),
900 (test_add_fn
, "add"),
901 (test_sub_fn
, "sub"),
902 (test_neg_fn
, "neg"),
903 (test_signed_fn
, "signed"),
905 yield part_mask
.eq(0)
906 yield from test_op("16-bit", 1, test_fn
, mod_attr
, 0xFFFF)
907 yield from test_op("16-bit", 0, test_fn
, mod_attr
, 0xFFFF)
908 yield part_mask
.eq(0b10)
909 yield from test_op("8-bit", 0, test_fn
, mod_attr
,
911 yield from test_op("8-bit", 1, test_fn
, mod_attr
,
913 yield part_mask
.eq(0b1111)
914 yield from test_op("4-bit", 0, test_fn
, mod_attr
,
915 0xF000, 0x0F00, 0x00F0, 0x000F)
916 yield from test_op("4-bit", 1, test_fn
, mod_attr
,
917 0xF000, 0x0F00, 0x00F0, 0x000F)
919 def test_ne_fn(a
, b
, mask
):
920 return (a
& mask
) != (b
& mask
)
922 def test_lt_fn(a
, b
, mask
):
923 return (a
& mask
) < (b
& mask
)
925 def test_le_fn(a
, b
, mask
):
926 return (a
& mask
) <= (b
& mask
)
928 def test_eq_fn(a
, b
, mask
):
929 return (a
& mask
) == (b
& mask
)
931 def test_gt_fn(a
, b
, mask
):
932 return (a
& mask
) > (b
& mask
)
934 def test_ge_fn(a
, b
, mask
):
935 return (a
& mask
) >= (b
& mask
)
937 def test_binop(msg_prefix
, test_fn
, mod_attr
, *maskbit_list
):
938 for a
, b
in [(0x0000, 0x0000),
948 yield module
.a
.lower().eq(a
)
949 yield module
.b
.lower().eq(b
)
951 # convert to mask_list
953 for mb
in maskbit_list
:
960 # do the partitioned tests
961 for i
, mask
in enumerate(mask_list
):
962 if test_fn(a
, b
, mask
):
963 # OR y with the lowest set bit in the mask
966 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
967 msg
= f
"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
968 f
" => 0x{y:X} != 0x{outval:X}, masklist %s"
969 print((msg
% str(maskbit_list
)).format(locals()))
970 self
.assertEqual(y
, outval
, msg
% str(maskbit_list
))
972 for (test_fn
, mod_attr
) in ((test_eq_fn
, "eq"),
979 yield part_mask
.eq(0)
980 yield from test_binop("16-bit", test_fn
, mod_attr
, 0b1111)
981 yield part_mask
.eq(0b10)
982 yield from test_binop("8-bit", test_fn
, mod_attr
,
984 yield part_mask
.eq(0b1111)
985 yield from test_binop("4-bit", test_fn
, mod_attr
,
986 0b1000, 0b0100, 0b0010, 0b0001)
988 sim
.add_process(async_process
)
990 vcd_file
=open(test_name
+ ".vcd", "w"),
991 gtkw_file
=open(test_name
+ ".gtkw", "w"),
996 # TODO: adapt to SimdSignal. perhaps a different style?
998 from nmigen.tests.test_hdl_ast import SignedEnum
999 def test_matches(self)
1001 self.assertRepr(s.matches(), "(const 1'd0)")
1002 self.assertRepr(s.matches(1), """
1003 (== (sig s) (const 1'd1))
1005 self.assertRepr(s.matches(0, 1), """
1006 (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
1008 self.assertRepr(s.matches("10--"), """
1009 (== (& (sig s) (const 4'd12)) (const 4'd8))
1011 self.assertRepr(s.matches("1 0--"), """
1012 (== (& (sig s) (const 4'd12)) (const 4'd8))
1015 def test_matches_enum(self):
1016 s = Signal(SignedEnum)
1017 self.assertRepr(s.matches(SignedEnum.FOO), """
1018 (== (sig s) (const 1'sd-1))
1021 def test_matches_width_wrong(self):
1023 with self.assertRaisesRegex(SyntaxError,
1024 r"^Match pattern '--' must have the same width as "
1025 r"match value \(which is 4\)$"):
1027 with self.assertWarnsRegex(SyntaxWarning,
1028 (r"^Match pattern '10110' is wider than match value "
1029 r"\(which has width 4\); "
1030 r"comparison will never be true$")):
1033 def test_matches_bits_wrong(self):
1035 with self.assertRaisesRegex(SyntaxError,
1036 (r"^Match pattern 'abc' must consist of 0, 1, "
1037 r"and - \(don't care\) bits, "
1038 r"and may include whitespace$")):
1041 def test_matches_pattern_wrong(self):
1043 with self.assertRaisesRegex(SyntaxError,
1044 r"^Match pattern must be an integer, a string, "
1045 r"or an enumeration, not 1\.0$"):
1049 if __name__
== '__main__':