add BitwiseLut and tests
[nmutil.git] / src / nmutil / test / test_lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from contextlib import contextmanager
5 import unittest
6
7 from hashlib import sha256
8 from nmutil.get_test_path import get_test_path
9 from nmutil.lut import BitwiseMux, BitwiseLut
10 from nmigen.sim import Simulator, Delay
11
12
@contextmanager
def do_sim(test_case, dut, traces=()):
    """Create a simulator for *dut* that records VCD/GTKWave traces.

    The output files are placed at the path returned by
    ``get_test_path(test_case, "sim_test_out")``: one with a ``.vcd``
    suffix and one with a ``.gtkw`` suffix. Missing parent directories
    are created first.

    :param test_case: the running ``unittest.TestCase``, used to derive
        the output path.
    :param dut: the design under test handed to ``Simulator``.
    :param traces: signals to list in the GTKWave save file.
    :yields: the ``Simulator``; the caller adds processes and runs it
        inside the ``with`` body.
    """
    sim = Simulator(dut)
    path = get_test_path(test_case, "sim_test_out")
    path.parent.mkdir(parents=True, exist_ok=True)
    vcd_path = path.with_suffix(".vcd")
    gtkw_path = path.with_suffix(".gtkw")
    # Each file is opened by its own context manager, so neither handle
    # leaks if the other open (or write_vcd itself) raises. The exits run
    # in LIFO order: write_vcd finishes writing its traces before either
    # file is closed. Closing a file that write_vcd already closed is a
    # harmless no-op.
    with vcd_path.open("wt", encoding="utf-8") as vcd_file:
        with gtkw_path.open("wt", encoding="utf-8") as gtkw_file:
            with sim.write_vcd(vcd_file, gtkw_file, traces=traces):
                yield sim
24
25
def hash_256(v):
    """Map the string *v* to a reproducible non-negative integer.

    The SHA-256 digest of *v*'s UTF-8 encoding is read as a little-endian
    integer, so the result lies in ``[0, 2 ** 256)``. The tests use it as
    a deterministic source of pseudo-random bit patterns.
    """
    encoded = bytes(v, encoding='utf-8')
    digest = sha256(encoded).digest()
    return int.from_bytes(digest, byteorder='little')
31
32
class TestBitwiseMux(unittest.TestCase):
    """Exhaustive check of a 2-bit-wide ``BitwiseMux``."""

    def test(self):
        width = 2
        dut = BitwiseMux(width)

        def case(sel, t, f, expected):
            # Drive the three inputs, wait for combinational settling,
            # then sample the output and compare it to the reference.
            with self.subTest(sel=bin(sel), t=bin(t), f=bin(f)):
                yield dut.sel.eq(sel)
                yield dut.t.eq(t)
                yield dut.f.eq(f)
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output),
                                  expected=bin(expected)):
                    self.assertEqual(expected, output)

        def process():
            all_values = range(2 ** width)
            for sel in all_values:
                for t in all_values:
                    for f in all_values:
                        # Reference model: each output bit is taken from
                        # t where sel has that bit set, otherwise from f.
                        expected = 0
                        for bit in (1 << i for i in range(width)):
                            chosen = t if sel & bit else f
                            if chosen & bit:
                                expected |= bit
                        yield from case(sel, t, f, expected)

        traced = [dut.sel, dut.t, dut.f, dut.output]
        with do_sim(self, dut, traced) as sim:
            sim.add_process(process)
            sim.run()
63
64
class TestBitwiseLut(unittest.TestCase):
    """Check ``BitwiseLut(3, 16)``: its mux-tree signal naming, and its
    per-bit lookup behaviour on reproducible pseudo-random vectors."""

    def test(self):
        dut = BitwiseLut(3, 16)
        mask = 2 ** dut.width - 1
        lut_mask = 2 ** dut.lut.width - 1

        # Keys of the private mux-tree mapping are tuples of booleans.
        # The i-th element fixes bit i of the index spelled in the signal
        # name, and bits not yet decided are shown as 'x'.
        names = {}
        for key, signal in dut._mux_inputs.items():
            names[key] = signal.name
        self.assertEqual(names, {
            (): 'mux_input_0bxxx',
            (False,): 'mux_input_0bxx0',
            (False, False): 'mux_input_0bx00',
            (False, False, False): 'mux_input_0b000',
            (False, False, True): 'mux_input_0b100',
            (False, True): 'mux_input_0bx10',
            (False, True, False): 'mux_input_0b010',
            (False, True, True): 'mux_input_0b110',
            (True,): 'mux_input_0bxx1',
            (True, False): 'mux_input_0bx01',
            (True, False, False): 'mux_input_0b001',
            (True, False, True): 'mux_input_0b101',
            (True, True): 'mux_input_0bx11',
            (True, True, False): 'mux_input_0b011',
            (True, True, True): 'mux_input_0b111'
        })

        def case(in0, in1, in2, lut):
            # Reference model: at each bit position, the bits of in0/in1/in2
            # form a 3-bit index (in0 is the least-significant bit). The LUT
            # entry at that index becomes the output bit.
            expected = 0
            for i in range(dut.width):
                bit = 1 << i
                lut_index = 0
                for weight, operand in enumerate((in0, in1, in2)):
                    if operand & bit:
                        lut_index |= 1 << weight
                if (lut >> lut_index) & 1:
                    expected |= bit
            with self.subTest(in0=bin(in0), in1=bin(in1), in2=bin(in2),
                              lut=bin(lut)):
                for index, value in enumerate((in0, in1, in2)):
                    yield dut.inputs[index].eq(value)
                yield dut.lut.eq(lut)
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output),
                                  expected=bin(expected)):
                    self.assertEqual(expected, output)

        def process():
            for case_index in range(100):
                with self.subTest(case_index=case_index):
                    # Hash-derived vectors are identical on every run.
                    in0 = hash_256(f"{case_index} in0") & mask
                    in1 = hash_256(f"{case_index} in1") & mask
                    in2 = hash_256(f"{case_index} in2") & mask
                    lut = hash_256(f"{case_index} lut") & lut_mask
                    yield from case(in0, in1, in2, lut)

        traced = [*dut.inputs, dut.lut, dut.output]
        with do_sim(self, dut, traced) as sim:
            sim.add_process(process)
            sim.run()
123
124
if __name__ == "__main__":
    # Allow running this file directly; unittest.main() collects and runs
    # the TestCase classes defined in this module.
    unittest.main()