dff2d6ec847a7b7538db3b61f087b23d80328247
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable, Mux, Cat, Shape
6 from nmigen.back.pysim import Simulator, Delay, Settle
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import PartitionedSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15 import math
16
17
def first_zero(x):
    """Return the bit position of the lowest set bit in the low 16 bits of x.

    Despite its name, this scans upwards from bit 0 and stops at the first
    bit that is *set*, i.e. it counts the zero bits below the lowest one.

    Returns 16 when none of the low 16 bits is set.  (Previously the
    function fell off the end of the loop and implicitly returned None,
    which cannot be used as a shift amount by callers.)
    """
    for pos in range(16):
        if x & (1 << pos):
            return pos
    return 16
24
def count_bits(x):
    """Return how many of the low 16 bits of x are set.

    Bits at position 16 and above are ignored, matching the 16-bit
    signals used throughout these tests.
    """
    # masking first keeps the 16-bit window (and handles negative ints,
    # whose two's-complement low bits are what "&" sees)
    return bin(x & 0xFFFF).count("1")
31
32
def perms(k):
    """Iterate over every k-character string made of '0' and '1'.

    Strings are produced lazily, in itertools.product order
    ('00', '01', '10', '11' for k == 2).
    """
    return ("".join(bits) for bits in itertools.product("01", repeat=k))
35
36
def create_ilang(dut, traces, test_name):
    """Convert dut to RTLIL and write it to "<test_name>.il".

    traces is passed to rtlil.convert as the list of ports.
    """
    il_text = rtlil.convert(dut, ports=traces)
    il_path = "%s.il" % test_name
    with open(il_path, "w") as il_file:
        il_file.write(il_text)
41
42
def create_simulator(module, traces, test_name):
    """Dump module as "<test_name>.il", then return a Simulator for it."""
    # the RTLIL dump happens first, before the simulator is constructed
    create_ilang(module, traces, test_name)
    simulator = Simulator(module)
    return simulator
46
47
48 # XXX this is for coriolis2 experimentation
class TestAddMod2(Elaboratable):
    """Partitioned-operator test module whose outputs are all registered.

    Two PartitionedSignals (a, b) are fed through compare, add, sub, neg,
    shift and mux operators; each result is assigned in the "sync" domain.
    Only the scalar copy of b (bsig) is driven combinatorially.

    width is the bit-width of a, b and the data outputs; partpoints is the
    partition-point mask handed to every PartitionedSignal.
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width)  # left shift
        self.ls_scal_output = Signal(width)  # left shift (scalar amount)
        self.rs_output = Signal(width)  # right shift
        self.rs_scal_output = Signal(width)  # right shift (scalar amount)
        self.sub_output = Signal(width)
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        # elaborate() reads mux_sel; the original assigned this Signal to
        # mux_sel2 (immediately overwritten below), leaving mux_sel undefined
        self.mux_sel = Signal(len(partpoints)+1)
        # NOTE(review): TestMuxMod uses len(partpoints)+1 for this width;
        # left unchanged here - confirm which is intended
        self.mux_sel2 = PartitionedSignal(partpoints, len(partpoints))
        self.mux_out = Signal(width)
        # elaborate() drives mux_out2; the attribute used to be created
        # only under the name mux2_out, which is kept as an alias
        self.mux_out2 = Signal(width)
        self.mux2_out = self.mux_out2
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        """Build the module: one registered output per tested operator."""
        m = Module()
        comb = m.d.comb
        sync = m.d.sync
        self.a.set_module(m)
        self.b.set_module(m)
        self.mux_sel2.set_module(m)
        # compares
        sync += self.lt_output.eq(self.a < self.b)
        sync += self.ne_output.eq(self.a != self.b)
        sync += self.le_output.eq(self.a <= self.b)
        sync += self.gt_output.eq(self.a > self.b)
        sync += self.eq_output.eq(self.a == self.b)
        sync += self.ge_output.eq(self.a >= self.b)
        # add
        # NOTE(review): TestAddMod assigns add_out.sig / sub_out.sig /
        # (-a).sig rather than the operator results themselves - confirm
        # whether this module needs the same treatment
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        sync += self.add_output.eq(add_out)
        sync += self.add_carry_out.eq(add_carry)
        # sub
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        sync += self.sub_output.eq(sub_out)
        sync += self.sub_carry_out.eq(sub_carry)
        # neg
        sync += self.neg_output.eq(-self.a)
        # left shift
        sync += self.ls_output.eq(self.a << self.b)
        # right shift
        sync += self.rs_output.eq(self.a >> self.b)
        # mux, once with a plain selector and once with a partitioned one
        ppts = self.partpoints
        sync += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
        sync += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))
        # scalar left / right shift: the amount is b flattened to a Signal
        comb += self.bsig.eq(self.b.lower())
        sync += self.ls_scal_output.eq(self.a << self.bsig)
        sync += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
114
115
class TestMuxMod(Elaboratable):
    """Combinatorial test module for partitioned multiplexing.

    mux_out is driven by PMux with the plain selector mux_sel; mux_out2 is
    driven by Mux with the PartitionedSignal selector mux_sel2.  Both
    choose between the PartitionedSignals a and b.
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.mux_sel = Signal(len(partpoints)+1)
        self.mux_sel2 = PartitionedSignal(partpoints, len(partpoints)+1)
        self.mux_out = Signal(width)
        self.mux_out2 = Signal(width)

    def elaborate(self, platform):
        """Build the module; everything is in the comb domain."""
        m = Module()
        comb = m.d.comb
        self.a.set_module(m)
        self.b.set_module(m)
        self.mux_sel2.set_module(m)
        ppts = self.partpoints

        comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
        comb += self.mux_out2.eq(Mux(self.mux_sel2, self.a, self.b))

        return m
139
140
class TestCatMod(Elaboratable):
    """Combinatorial test module: cat_out = Cat(a, b).

    a is width bits, b is twice that, and cat_out three times that.
    cat_sel is created but not used by elaborate().
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width*2)
        self.cat_sel = Signal(len(partpoints)+1)
        self.cat_out = Signal(width*3)

    def elaborate(self, platform):
        """Attach both operands to a fresh Module and concatenate them."""
        m = Module()
        for operand in (self.a, self.b):
            operand.set_module(m)

        m.d.comb += self.cat_out.eq(Cat(self.a, self.b))

        return m
159
160
class TestAssMod(Elaboratable):
    """Combinatorial test module: ass_out.eq(a).

    The input a is an ordinary Signal when scalar is true, otherwise a
    PartitionedSignal.  The output ass_out is always a PartitionedSignal
    of shape out_shape.
    """

    def __init__(self, width, out_shape, partpoints, scalar):
        self.partpoints = partpoints
        self.scalar = scalar
        if not scalar:
            self.a = PartitionedSignal(partpoints, width)
        else:
            self.a = Signal(width)
        self.ass_out = PartitionedSignal(partpoints, out_shape)

    def elaborate(self, platform):
        """Build the module; a plain Signal input has no set_module()."""
        m = Module()
        if not self.scalar:
            self.a.set_module(m)
        self.ass_out.set_module(m)

        m.d.comb += self.ass_out.eq(self.a)

        return m
181
182
class TestAddMod(Elaboratable):
    """Combinatorial test module exercising PartitionedSignal operators.

    Two PartitionedSignals (a, b) of the given width are fed through
    compare, add, sub, neg, as_signed, the horizontal reductions
    (xor/bool/all/any) and left/right shifts.  Every result is driven in
    the comb domain onto a "<op>_output" Signal, which is the naming the
    test-cases rely on when looking results up with getattr().
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.bsig = Signal(width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width)  # left shift
        self.ls_scal_output = Signal(width)  # left shift (scalar amount)
        self.rs_output = Signal(width)  # right shift
        self.rs_scal_output = Signal(width)  # right shift (scalar amount)
        self.sub_output = Signal(width)
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)
        self.neg_output = Signal(width)
        self.signed_output = Signal(width)
        self.xor_output = Signal(len(partpoints)+1)
        self.bool_output = Signal(len(partpoints)+1)
        self.all_output = Signal(len(partpoints)+1)
        self.any_output = Signal(len(partpoints)+1)

    def elaborate(self, platform):
        """Build the module; all assignments are combinatorial."""
        m = Module()
        comb = m.d.comb
        self.a.set_module(m)
        self.b.set_module(m)
        # compares
        comb += self.lt_output.eq(self.a < self.b)
        comb += self.ne_output.eq(self.a != self.b)
        comb += self.le_output.eq(self.a <= self.b)
        comb += self.gt_output.eq(self.a > self.b)
        comb += self.eq_output.eq(self.a == self.b)
        comb += self.ge_output.eq(self.a >= self.b)
        # add
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        comb += self.add_output.eq(add_out.sig)
        comb += self.add_carry_out.eq(add_carry)
        # sub
        sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                           self.carry_in)
        comb += self.sub_output.eq(sub_out.sig)
        comb += self.sub_carry_out.eq(sub_carry)
        # neg / signed / unsigned
        comb += self.neg_output.eq((-self.a).sig)
        comb += self.signed_output.eq(self.a.as_signed())
        # horizontal operators
        comb += self.xor_output.eq(self.a.xor())
        comb += self.bool_output.eq(self.a.bool())
        comb += self.all_output.eq(self.a.all())
        comb += self.any_output.eq(self.a.any())
        # left shift
        comb += self.ls_output.eq(self.a << self.b)
        # right shift
        comb += self.rs_output.eq(self.a >> self.b)
        # scalar left shift: the amount is b flattened to a plain Signal
        comb += self.bsig.eq(self.b.lower())
        comb += self.ls_scal_output.eq(self.a << self.bsig)
        # scalar right shift
        comb += self.rs_scal_output.eq(self.a >> self.bsig)

        return m
254
255
class TestMux(unittest.TestCase):
    """Simulate TestMuxMod and check both mux outputs per partitioning."""

    def test(self):
        """Run the 16/8/4-bit partitionings over every selector pattern."""
        width = 16
        part_mask = Signal(3)  # divide into 4-bits
        module = TestMuxMod(width, part_mask)

        test_name = "part_sig_mux"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.mux_out,
                  module.mux_out2]
        sim = create_simulator(module, traces, test_name)

        def async_process():

            def nibble_mask(maskbits):
                """Expand a 4-bit partition selector to a 16-bit mask.

                Each set bit i becomes 0xf at nibble position i.
                """
                v = 0
                for i in range(4):
                    if maskbits & (1 << i):
                        v |= 0xf << (i*4)
                return v

            def test_muxop(msg_prefix, *maskbit_list):
                """Check both mux outputs for each selector combination.

                maskbit_list holds one 4-bit value per partition, marking
                which nibbles belong to that partition.
                """
                # depends only on maskbit_list, so compute it once
                mask_list = [nibble_mask(mb) for mb in maskbit_list]

                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:

                    # every on/off combination of the partitions
                    for p in perms(len(mask_list)):

                        sel = 0
                        selmask = 0
                        for i, v in enumerate(p):
                            if v == '1':
                                sel |= maskbit_list[i]
                                selmask |= mask_list[i]

                        yield module.a.lower().eq(a)
                        yield module.b.lower().eq(b)
                        yield module.mux_sel.eq(sel)
                        yield module.mux_sel2.lower().eq(sel)
                        yield Delay(0.1e-6)

                        # expected: a where the partition is selected,
                        # b everywhere else
                        y = 0
                        for mask in mask_list:
                            if selmask & mask:
                                y |= (a & mask)
                            else:
                                y |= (b & mask)

                        def fail_msg(actual):
                            return (f"{msg_prefix}: mux "
                                    f"0x{sel:X} ? 0x{a:X} : 0x{b:X}"
                                    f" => 0x{y:X} != 0x{actual:X}, "
                                    f"masklist {maskbit_list}")

                        # check the result; each assertion reports the
                        # output it actually compared against
                        outval = (yield module.mux_out)
                        outval2 = (yield module.mux_out2)
                        self.assertEqual(y, outval, fail_msg(outval))
                        self.assertEqual(y, outval2, fail_msg(outval2))

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            # part_mask is only 3 bits wide: all partition points set
            yield part_mask.eq(0b111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
334
335
class TestCat(unittest.TestCase):
    """Simulate TestCatMod and check Cat(a, b) for each partitioning."""

    def test(self):
        """Compare cat_out against a software model of partitioned Cat."""
        width = 16
        part_mask = Signal(3)  # divide into 4-bits
        module = TestCatMod(width, part_mask)

        test_name = "part_sig_cat"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.cat_out]
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_cat.cat import get_runlengths

        def async_process():

            def test_catop(msg_prefix):
                """Drive a/b test pairs and check cat_out for each."""
                # lengths of a/b test input, as declared by TestCatMod
                alen, blen = width, width*2
                # width of one of the four slices of each operand
                ajump, bjump = alen // 4, blen // 4
                # pairs of test values a, b
                for a, b in [(0x0000, 0x00000000),
                             (0xDCBA, 0x12345678),
                             (0xABCD, 0x01234567),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0x1F1F, 0xF1F1F1F1),
                             (0x0000, 0xFFFFFFFF)]:

                    # convert a and b to partitions
                    apart = [(a >> (ajump*i)) & ((1 << ajump)-1)
                             for i in range(4)]
                    bpart = [(b >> (bjump*i)) & ((1 << bjump)-1)
                             for i in range(4)]

                    print("apart bpart", hex(a), hex(b),
                          list(map(hex, apart)), list(map(hex, bpart)))

                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    yield Delay(0.1e-6)

                    y = 0
                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)
                    j = 0   # bit position in y of the next slice
                    ai = 0  # next unused slice of a
                    bi = 0  # next unused slice of b
                    for i in runlengths:
                        # a first
                        for _ in range(i):
                            print("runlength", i,
                                  "ai", ai,
                                  "apart", hex(apart[ai]),
                                  "j", j)
                            y |= apart[ai] << j
                            print("    y", hex(y))
                            j += ajump
                            ai += 1
                        # now b
                        for _ in range(i):
                            print("runlength", i,
                                  "bi", bi,
                                  "bpart", hex(bpart[bi]),
                                  "j", j)
                            y |= bpart[bi] << j
                            print("    y", hex(y))
                            j += bjump
                            bi += 1

                    # check the result
                    outval = (yield module.cat_out)
                    msg = (f"{msg_prefix}: cat "
                           f"0x{mval:X} 0x{a:X} : 0x{b:X}"
                           f" => 0x{y:X} != 0x{outval:X}")
                    self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_catop("16-bit")
            yield part_mask.eq(0b10)
            yield from test_catop("8-bit")
            # part_mask is only 3 bits wide: all partition points set
            yield part_mask.eq(0b111)
            yield from test_catop("4-bit")

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
430
431
class TestAssign(unittest.TestCase):
    """Simulate TestAssMod: assignment into a PartitionedSignal."""

    def run_tst(self, in_width, out_width, out_signed, scalar):
        """Check ass_out.eq(a) for one width/sign/scalar combination.

        in_width / out_width are the input and output bit-widths,
        out_signed selects a signed output shape, and scalar selects a
        plain Signal input instead of a PartitionedSignal.
        """
        part_mask = Signal(3)  # divide into 4-bits
        module = TestAssMod(in_width,
                            Shape(out_width, out_signed),
                            part_mask, scalar)

        test_name = "part_sig_ass_%d_%d_%s_%s" % (
            in_width, out_width,
            "signed" if out_signed else "unsigned",
            "scalar" if scalar else "partitioned")

        traces = [part_mask,
                  module.ass_out.lower()]
        if module.scalar:
            traces.append(module.a)
        else:
            traces.append(module.a.lower())
        sim = create_simulator(module, traces, test_name)

        # annoying recursive import issue
        from ieee754.part_cat.cat import get_runlengths

        def async_process():

            def test_assop(msg_prefix):
                """Drive fixed and random inputs and check ass_out."""
                # define lengths of a test input
                alen = in_width
                # random inputs cover the whole input width
                randomvals = [randint(0, (1 << in_width) - 1)
                              for _ in range(10)]
                # test values a
                for a in [0x0001,
                          0x0010,
                          0x0100,
                          0x1000,
                          0x000c,
                          0x00c0,
                          0x0c00,
                          0xc000,
                          0x1234,
                          0xDCBA,
                          0xABCD,
                          0x0000,
                          0xFFFF,
                          ] + randomvals:
                    # work out the runlengths for this mask.
                    # 0b011 returns [1,1,2] (for a mask of length 3)
                    mval = yield part_mask
                    runlengths = get_runlengths(mval, 3)

                    print("test a", hex(a), "mask", bin(mval), "widths",
                          in_width, out_width,
                          "signed", out_signed,
                          "scalar", scalar)

                    # convert a to runlengths sub-sections
                    apart = []
                    ajump = alen // 4
                    ai = 0
                    for i in runlengths:
                        subpart = (a >> (ajump*ai)) & ((1 << (ajump*i))-1)
                        msb = (subpart >> ((ajump*i)-1))  # the sign bit
                        apart.append((subpart, msb))
                        print("apart", ajump*i, hex(a), hex(subpart), msb)
                        # a scalar input is re-read from bit 0 for every
                        # partition, so only advance for partitioned input
                        if not scalar:
                            ai += i

                    if scalar:
                        yield module.a.eq(a)
                    else:
                        yield module.a.lower().eq(a)
                    yield Delay(0.1e-6)

                    y = 0
                    j = 0  # bit position in y of the next partition
                    ojump = out_width // 4
                    for ai, i in enumerate(runlengths):
                        # get "a" partition value
                        av, amsb = apart[ai]
                        # do sign-extension if needed
                        signext = 0
                        if out_signed and ojump > ajump and amsb:
                            signext = (-1 << ajump*i) & ((1 << (ojump*i))-1)
                            av |= signext
                        # truncate if needed
                        if ojump < ajump:
                            av &= ((1 << (ojump*i))-1)
                        print("runlength", i,
                              "ai", ai,
                              "apart", hex(av), amsb,
                              "signext", hex(signext),
                              "j", j)
                        y |= av << j
                        print("    y", hex(y))
                        j += ojump*i

                    y &= (1 << out_width)-1

                    # check the result
                    outval = (yield module.ass_out.lower())
                    outval &= (1 << out_width)-1
                    msg = (f"{msg_prefix}: assign "
                           f"mask 0x{mval:X} input 0x{a:X}"
                           f" => expected 0x{y:X} != actual 0x{outval:X}")
                    self.assertEqual(y, outval, msg)

            # run the actual tests, here - 16/8/4 bit partitions
            for (mask, name) in ((0, "16-bit"),
                                 (0b10, "8-bit"),
                                 (0b111, "4-bit")):
                with self.subTest(name + " " + test_name):
                    yield part_mask.eq(mask)
                    yield Settle()
                    yield from test_assop(name)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()

    def test(self):
        """Run every output width, signedness and scalar/partitioned mix."""
        for out_width in [16, 24, 8]:
            for sign in [True, False]:
                for scalar in [True, False]:
                    self.run_tst(16, out_width, sign, scalar)
561
562
class TestPartitionedSignal(unittest.TestCase):
    """Simulate TestAddMod and check each operator against a model.

    Three groups are run in sequence: horizontal reductions
    (xor/all/any/bool), arithmetic and shifts, and comparisons.  Each
    group is repeated for the 16-, 8- and 4-bit partitionings.
    """

    def test(self):
        """Drive a and b, then compare "<op>_output" with the model."""
        width = 16
        part_mask = Signal(3)  # divide into 4-bits
        module = TestAddMod(width, part_mask)

        test_name = "part_sig_add"
        traces = [part_mask,
                  module.a.sig,
                  module.b.sig,
                  module.add_output,
                  module.eq_output]
        sim = create_simulator(module, traces, test_name)

        def async_process():

            def nibble_masks(maskbit_list):
                """Expand 4-bit partition selectors to 16-bit masks.

                Each set bit i of a selector becomes 0xf at nibble i.
                """
                mask_list = []
                for mb in maskbit_list:
                    v = 0
                    for i in range(4):
                        if mb & (1 << i):
                            v |= 0xf << (i*4)
                    mask_list.append(v)
                return mask_list

            # (part_mask value, label, per-partition nibble selectors).
            # part_mask is 3 bits wide, so 0b111 sets every point.
            maskbit_cases = (
                (0, "16-bit", (0b1111,)),
                (0b10, "8-bit", (0b1100, 0b0011)),
                (0b111, "4-bit", (0b1000, 0b0100, 0b0010, 0b0001)),
            )

            # ---- horizontal (reduction) operators ----

            def test_xor_fn(a, mask):
                """Parity (0 or 1) of the bits of a selected by mask."""
                test = (a & mask)
                result = 0
                while test != 0:
                    result ^= (test & 1)
                    test >>= 1
                return result

            def test_bool_fn(a, mask):
                """True when any bit of a selected by mask is set."""
                return (a & mask) != 0

            def test_all_fn(a, mask):
                """True when every bit of a selected by mask is set."""
                # slightly different: all bits masked must be 1
                return (a & mask) == mask

            def test_horizop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                """Check one reduction output over fixed and random a."""
                mask_list = nibble_masks(maskbit_list)
                randomvals = [randint(0, 65535) for _ in range(100)]
                for a in [0x0000,
                          0x1111,
                          0x0001,
                          0x0010,
                          0x0100,
                          0x1000,
                          0x000F,
                          0x00F0,
                          0x0F00,
                          0xF000,
                          0x00FF,
                          0xFF00,
                          0x1234,
                          0xABCD,
                          0xFFFF,
                          0x8000,
                          0xBEEF, 0xFEED,
                          ] + randomvals:
                    with self.subTest("%s %s %s" % (msg_prefix,
                                                    test_fn.__name__,
                                                    hex(a))):
                        yield module.a.lower().eq(a)
                        yield Delay(0.1e-6)
                        y = 0
                        # do the partitioned tests
                        for i, mask in enumerate(mask_list):
                            if test_fn(a, mask):
                                # OR y with this partition's selector bits
                                y |= maskbit_list[i]
                        # check the result
                        outval = (yield getattr(module,
                                                "%s_output" % mod_attr))
                        msg = (f"{msg_prefix}: {mod_attr} 0x{a:X} "
                               f" => 0x{y:X} != 0x{outval:X}, "
                               f"masklist {maskbit_list}")
                        print(msg)
                        self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_xor_fn, "xor"),
                                        (test_all_fn, "all"),
                                        (test_bool_fn, "any"),  # same as bool
                                        (test_bool_fn, "bool"),
                                        ):
                for pmask, label, maskbits in maskbit_cases:
                    yield part_mask.eq(pmask)
                    yield from test_horizop(label, test_fn, mod_attr,
                                            *maskbits)

            # ---- arithmetic and shift operators ----
            # every model takes (carry_in, a, b, mask) and returns
            # (result, carry); the shift models ignore carry_in and
            # always report a carry of 0 (TODO: carry)

            def test_ls_scal_fn(carry_in, a, b, mask):
                """Model of a partition shifted left by scalar b."""
                # reduce range of b
                bits = count_bits(mask)
                newb = b & ((bits-1))
                print("%x %x %x bits %d trunc %x" %
                      (a, b, mask, bits, newb))
                shifted = ((a & mask) << newb)
                result = mask & shifted
                print("res", hex(a), hex(newb), hex(shifted),
                      hex(mask), hex(result))
                return result, 0

            def test_rs_scal_fn(carry_in, a, b, mask):
                """Model of a partition shifted right by scalar b."""
                # reduce range of b
                bits = count_bits(mask)
                newb = b & ((bits-1))
                print("%x %x %x bits %d trunc %x" %
                      (a, b, mask, bits, newb))
                shifted = ((a & mask) >> newb)
                result = mask & shifted
                print("res", hex(a), hex(newb), hex(shifted),
                      hex(mask), hex(result))
                return result, 0

            def test_ls_fn(carry_in, a, b, mask):
                """Model of a partition shifted left by b's partition."""
                # reduce range of b: keep only the shift-amount bits
                # that sit at the bottom of this partition
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits-1) << fz)
                print("%x %x %x bits %d zero %d trunc %x" %
                      (a, b, mask, bits, fz, newb))
                shift = (newb & mask) >> fz
                shifted = ((a & mask) << shift)
                result = mask & shifted
                print("res", hex(a), hex(shift), hex(shifted),
                      hex(mask), hex(result))
                return result, 0

            def test_rs_fn(carry_in, a, b, mask):
                """Model of a partition shifted right by b's partition."""
                # reduce range of b: keep only the shift-amount bits
                # that sit at the bottom of this partition
                bits = count_bits(mask)
                fz = first_zero(mask)
                newb = b & ((bits-1) << fz)
                print("%x %x %x bits %d zero %d trunc %x" %
                      (a, b, mask, bits, fz, newb))
                shift = (newb & mask) >> fz
                shifted = ((a & mask) >> shift)
                result = mask & shifted
                print("res", hex(a), hex(shift), hex(shifted),
                      hex(mask), hex(result))
                return result, 0

            def test_add_fn(carry_in, a, b, mask):
                """Model of a + b (+ carry-in) within one partition."""
                # carry-in is injected at the partition's lowest bit
                lsb = mask & ~(mask-1) if carry_in else 0
                total = (a & mask) + (b & mask) + lsb
                result = mask & total
                # anything outside the mask is the carry out
                carry = (total & mask) != total
                print(a, b, total, mask)
                return result, carry

            def test_sub_fn(carry_in, a, b, mask):
                """Model of a + ~b (+ carry-in) within one partition."""
                lsb = mask & ~(mask-1) if carry_in else 0
                total = (a & mask) + (~b & mask) + lsb
                result = mask & total
                carry = (total & mask) != total
                return result, carry

            def test_neg_fn(carry_in, a, b, mask):
                """Model of -a within one partition."""
                lsb = mask & ~(mask - 1)  # has only LSB of mask set
                pos = lsb.bit_length() - 1  # find bit position
                a = (a & mask) >> pos  # shift it to the beginning
                return ((-a) << pos) & mask, 0  # negate and shift it back

            def test_signed_fn(carry_in, a, b, mask):
                """Model of as_signed(): the bits pass through as-is."""
                return a & mask, 0

            def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                """Check "<mod_attr>_output" and, when the module has it,
                "<mod_attr>_carry_out", over fixed and random a/b pairs.
                """
                # randint is inclusive at both ends: stay within 16 bits
                rand_data = [(randint(0, (1 << 16) - 1),
                              randint(0, (1 << 16) - 1))
                             for _ in range(100)]
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)] + rand_data:
                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    carry_sig = 0xf if carry else 0
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    y = 0
                    carry_result = 0
                    for i, mask in enumerate(mask_list):
                        print("i/mask", i, hex(mask))
                        res, c = test_fn(carry, a, b, mask)
                        y |= res
                        # the expected carry bit sits at the index of the
                        # partition's lowest nibble
                        lsb = mask & ~(mask - 1)
                        lsb_pos = lsb.bit_length() - 1
                        carry_result |= c << (lsb_pos // 4)
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    print(a, b, outval, carry)
                    msg = (f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}"
                           f" => 0x{y:X} != 0x{outval:X}")
                    self.assertEqual(y, outval, msg)
                    if hasattr(module, "%s_carry_out" % mod_attr):
                        c_outval = (yield getattr(module,
                                                  "%s_carry_out" % mod_attr))
                        msg = (f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}"
                               f" => 0x{carry_result:X} != 0x{c_outval:X}")
                        self.assertEqual(carry_result, c_outval, msg)

            # (part_mask value, label, carry-in values in run order,
            #  one 16-bit mask per partition)
            op_cases = (
                (0, "16-bit", (1, 0), (0xFFFF,)),
                (0b10, "8-bit", (0, 1), (0xFF00, 0x00FF)),
                (0b111, "4-bit", (0, 1), (0xF000, 0x0F00, 0x00F0, 0x000F)),
            )

            # run through series of operations with corresponding
            # "helper" routines to reproduce the result (test_fn).  the same
            # a/b input is passed to *all* outputs, where the name of the
            # output attribute (mod_attr) will contain the result to be
            # compared against the expected output from test_fn
            for (test_fn, mod_attr) in (
                    (test_ls_scal_fn, "ls_scal"),
                    (test_ls_fn, "ls"),
                    (test_rs_scal_fn, "rs_scal"),
                    (test_rs_fn, "rs"),
                    (test_add_fn, "add"),
                    (test_sub_fn, "sub"),
                    (test_neg_fn, "neg"),
                    (test_signed_fn, "signed"),
                    ):
                for pmask, label, carries, masks in op_cases:
                    yield part_mask.eq(pmask)
                    for carry in carries:
                        yield from test_op(label, carry, test_fn, mod_attr,
                                           *masks)

            # ---- comparison operators ----

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                """Check one comparison output over fixed a/b pairs."""
                mask_list = nibble_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.lower().eq(a)
                    yield module.b.lower().eq(b)
                    yield Delay(0.1e-6)
                    y = 0
                    # do the partitioned tests
                    for i, mask in enumerate(mask_list):
                        if test_fn(a, b, mask):
                            # OR y with this partition's selector bits
                            y |= maskbit_list[i]
                    # check the result
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    msg = (f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}"
                           f" => 0x{y:X} != 0x{outval:X}, "
                           f"masklist {maskbit_list}")
                    print(msg)
                    self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                for pmask, label, maskbits in maskbit_cases:
                    yield part_mask.eq(pmask)
                    yield from test_binop(label, test_fn, mod_attr,
                                          *maskbits)

        sim.add_process(async_process)
        with sim.write_vcd(
                vcd_file=open(test_name + ".vcd", "w"),
                gtkw_file=open(test_name + ".gtkw", "w"),
                traces=traces):
            sim.run()
900
901
# TODO: adapt to PartitionedSignal. perhaps a different style?
#
# The tests below are disabled.  They are kept as comments rather than
# inside a bare triple-quoted string: the patterns contain "\(" and "\.",
# which are invalid escape sequences in a non-raw string literal.
#
#    from nmigen.tests.test_hdl_ast import SignedEnum
#    def test_matches(self):
#        s = Signal(4)
#        self.assertRepr(s.matches(), "(const 1'd0)")
#        self.assertRepr(s.matches(1), """
#        (== (sig s) (const 1'd1))
#        """)
#        self.assertRepr(s.matches(0, 1), """
#        (r| (cat (== (sig s) (const 1'd0)) (== (sig s) (const 1'd1))))
#        """)
#        self.assertRepr(s.matches("10--"), """
#        (== (& (sig s) (const 4'd12)) (const 4'd8))
#        """)
#        self.assertRepr(s.matches("1 0--"), """
#        (== (& (sig s) (const 4'd12)) (const 4'd8))
#        """)
#
#    def test_matches_enum(self):
#        s = Signal(SignedEnum)
#        self.assertRepr(s.matches(SignedEnum.FOO), """
#        (== (sig s) (const 1'sd-1))
#        """)
#
#    def test_matches_width_wrong(self):
#        s = Signal(4)
#        with self.assertRaisesRegex(SyntaxError,
#                r"^Match pattern '--' must have the same width as "
#                r"match value \(which is 4\)$"):
#            s.matches("--")
#        with self.assertWarnsRegex(SyntaxWarning,
#                (r"^Match pattern '10110' is wider than match value "
#                 r"\(which has width 4\); "
#                 r"comparison will never be true$")):
#            s.matches(0b10110)
#
#    def test_matches_bits_wrong(self):
#        s = Signal(4)
#        with self.assertRaisesRegex(SyntaxError,
#                (r"^Match pattern 'abc' must consist of 0, 1, "
#                 r"and - \(don't care\) bits, "
#                 r"and may include whitespace$")):
#            s.matches("abc")
#
#    def test_matches_pattern_wrong(self):
#        s = Signal(4)
#        with self.assertRaisesRegex(SyntaxError,
#                r"^Match pattern must be an integer, a string, "
#                r"or an enumeration, not 1\.0$"):
#            s.matches(1.0)

if __name__ == '__main__':
    unittest.main()