add Array-based version of BitwiseLut, renaming old version to TreeBitwiseLut in...
[nmutil.git] / src / nmutil / test / test_lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from contextlib import contextmanager
5 import unittest
6 from hashlib import sha256
7 import shutil
8 import subprocess
9 from nmigen.back import rtlil
10 import textwrap
11 from nmigen.hdl.ast import AnyConst, Assert, Signal
12 from nmigen.hdl.dsl import Module
13 from nmigen.hdl.ir import Fragment
14 from nmutil.get_test_path import get_test_path
15 from nmutil.lut import BitwiseMux, BitwiseLut, TreeBitwiseLut
16 from nmigen.sim import Simulator, Delay
17
18
@contextmanager
def do_sim(test_case, dut, traces=()):
    """Context manager yielding a `Simulator` for `dut` that records waveforms.

    The `.vcd` and `.gtkw` output files are placed next to the path returned
    by `get_test_path(test_case, "sim_test_out")`; the parent directory is
    created if it does not exist yet.

    :param test_case: the running test case, used to derive the output path.
    :param dut: the design to simulate.
    :param traces: signals passed through to `Simulator.write_vcd`.
    """
    simulator = Simulator(dut)
    out_base = get_test_path(test_case, "sim_test_out")
    out_base.parent.mkdir(parents=True, exist_ok=True)
    vcd_file = out_base.with_suffix(".vcd").open("wt", encoding="utf-8")
    gtkw_file = out_base.with_suffix(".gtkw").open("wt", encoding="utf-8")
    with simulator.write_vcd(vcd_file, gtkw_file, traces=traces):
        yield simulator
30
31
32 # copied from ieee754fpu/src/ieee754/partitioned_signal_tester.py
def formal(test_case, hdl, *, base_path="formal_test_temp"):
    """Formally verify `hdl` with SymbiYosys, failing `test_case` on error.

    The design is converted to RTLIL and embedded in a freshly written
    `config.sby` (prove mode, depth 1, smtbmc engine) inside a clean
    per-test directory, then `sby` is run on that config.

    :param test_case: the running `unittest.TestCase`; used to derive the
        working directory and to report failures.
    :param hdl: the design containing the assertions to prove; it is
        converted with `Fragment.get(hdl, platform="formal")`.
    :param base_path: forwarded to `get_test_path` to pick the directory
        the formal run works in.
    :raises AssertionError: (via `test_case`) if the `sby` executable cannot
        be found or if it exits with a non-zero status.
    """
    hdl = Fragment.get(hdl, platform="formal")
    path = get_test_path(test_case, base_path)
    # start from an empty directory so results of earlier runs can't linger
    shutil.rmtree(path, ignore_errors=True)
    path.mkdir(parents=True)
    sby_name = "config.sby"
    sby_file = path / sby_name

    # Dedent only the fixed header: dedenting after interpolating the
    # multi-line RTLIL text would be defeated by any RTLIL line that
    # starts in column 0.
    sby_header = textwrap.dedent("""\
        [options]
        mode prove
        depth 1
        wait on

        [engines]
        smtbmc

        [script]
        read_rtlil top.il
        prep

        [file top.il]
        """)
    sby_file.write_text(sby_header + rtlil.convert(hdl) + "\n",
                        encoding="utf-8")

    sby = shutil.which('sby')
    # use the test case's own assertion so the check survives `python -O`
    # and reports a readable reason
    test_case.assertIsNotNone(sby, "sby (SymbiYosys) executable not found")
    # merge stderr into stdout so everything sby printed shows up in the
    # failure message below
    result = subprocess.run(
        [sby, sby_name],
        cwd=path, text=True, encoding="utf-8",
        stdin=subprocess.DEVNULL,
        stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
    )
    if result.returncode != 0:
        test_case.fail(f"Formal failed:\n{result.stdout}")
67
68
def hash_256(v):
    """Return the SHA-256 digest of the string `v` as a 256-bit integer.

    `v` is encoded as UTF-8 before hashing and the digest bytes are
    interpreted in little-endian order. Used as a deterministic source of
    pseudo-random test values.
    """
    encoded = bytes(v, 'utf-8')
    digest = sha256(encoded).digest()
    return int.from_bytes(digest, 'little')
74
75
class TestBitwiseMux(unittest.TestCase):
    """Simulation and formal checks for `BitwiseMux`."""

    def test(self):
        """Simulate every sel/t/f combination of a 2-bit mux."""
        width = 2
        dut = BitwiseMux(width)
        value_count = 1 << width

        def case(sel, t, f, expected):
            # drive the inputs, let the combinatorial logic settle, then
            # compare the output against the reference value
            with self.subTest(sel=bin(sel), t=bin(t), f=bin(f)):
                yield dut.sel.eq(sel)
                yield dut.t.eq(t)
                yield dut.f.eq(f)
                yield Delay(1e-6)
                actual = yield dut.output
                with self.subTest(output=bin(actual),
                                  expected=bin(expected)):
                    self.assertEqual(expected, actual)

        def reference(sel, t, f):
            # each output bit comes from `t` where the matching `sel` bit
            # is set, and from `f` otherwise
            result = 0
            for bit_index in range(width):
                bit = 1 << bit_index
                source = t if sel & bit else f
                result |= source & bit
            return result

        def process():
            for sel in range(value_count):
                for t in range(value_count):
                    for f in range(value_count):
                        yield from case(sel, t, f, reference(sel, t, f))

        traces = [dut.sel, dut.t, dut.f, dut.output]
        with do_sim(self, dut, traces) as sim:
            sim.add_process(process)
            sim.run()

    def test_formal(self):
        """Prove that each output bit follows `t` or `f` according to `sel`."""
        width = 2
        dut = BitwiseMux(width)
        m = Module()
        m.submodules.dut = dut
        # leave every input unconstrained so the proof covers all values
        m.d.comb += dut.sel.eq(AnyConst(width))
        m.d.comb += dut.f.eq(AnyConst(width))
        m.d.comb += dut.t.eq(AnyConst(width))
        for bit_index in range(width):
            out_bit = dut.output[bit_index]
            with m.If(dut.sel[bit_index]):
                m.d.comb += Assert(dut.t[bit_index] == out_bit)
            with m.Else():
                m.d.comb += Assert(dut.f[bit_index] == out_bit)
        formal(self, m)
121
122
class TestBitwiseLut(unittest.TestCase):
    """Simulation and formal checks run against both LUT implementations.

    The shared bodies are named `tst`/`tst_formal` rather than `test...` so
    unittest does not collect them directly; the `test*` methods at the
    bottom call them once per implementation class.
    """

    def tst(self, cls):
        """Simulate `cls(3, 16)` with 100 hash-derived input/LUT values."""
        # NOTE(review): presumably 3 inputs of 16 bits each -- confirm
        # against the constructor signature in nmutil.lut
        dut = cls(3, 16)
        mask = (1 << dut.width) - 1
        lut_mask = (1 << dut.lut.width) - 1
        if cls is TreeBitwiseLut:
            # the tree version exposes its intermediate mux signals; pin
            # down their names
            actual_names = {}
            for key, sig in dut._mux_inputs.items():
                actual_names[key] = sig.name
            self.assertEqual(actual_names, {
                (): 'mux_input_0bxxx',
                (False,): 'mux_input_0bxx0',
                (False, False): 'mux_input_0bx00',
                (False, False, False): 'mux_input_0b000',
                (False, False, True): 'mux_input_0b100',
                (False, True): 'mux_input_0bx10',
                (False, True, False): 'mux_input_0b010',
                (False, True, True): 'mux_input_0b110',
                (True,): 'mux_input_0bxx1',
                (True, False): 'mux_input_0bx01',
                (True, False, False): 'mux_input_0b001',
                (True, False, True): 'mux_input_0b101',
                (True, True): 'mux_input_0bx11',
                (True, True, False): 'mux_input_0b011',
                (True, True, True): 'mux_input_0b111'
            })

        def case(in0, in1, in2, lut):
            # reference model: for every bit position, the three input
            # bits form an index (in0 is the least significant bit) that
            # selects which bit of `lut` drives the output bit
            expected = 0
            for bit in range(dut.width):
                lut_index = (((in0 >> bit) & 1)
                             | (((in1 >> bit) & 1) << 1)
                             | (((in2 >> bit) & 1) << 2))
                if (lut >> lut_index) & 1:
                    expected |= 1 << bit
            with self.subTest(in0=bin(in0), in1=bin(in1), in2=bin(in2),
                              lut=bin(lut)):
                yield dut.inputs[0].eq(in0)
                yield dut.inputs[1].eq(in1)
                yield dut.inputs[2].eq(in2)
                yield dut.lut.eq(lut)
                yield Delay(1e-6)
                actual = yield dut.output
                with self.subTest(output=bin(actual),
                                  expected=bin(expected)):
                    self.assertEqual(expected, actual)

        def process():
            # hash-derived values give reproducible pseudo-random stimulus
            for case_index in range(100):
                with self.subTest(case_index=case_index):
                    in0 = mask & hash_256(f"{case_index} in0")
                    in1 = mask & hash_256(f"{case_index} in1")
                    in2 = mask & hash_256(f"{case_index} in2")
                    lut = lut_mask & hash_256(f"{case_index} lut")
                    yield from case(in0, in1, in2, lut)

        traces = [*dut.inputs, dut.lut, dut.output]
        with do_sim(self, dut, traces) as sim:
            sim.add_process(process)
            sim.run()

    def tst_formal(self, cls):
        """Prove that every output bit of `cls(3, 16)` equals the LUT bit
        selected by the corresponding input bits."""
        dut = cls(3, 16)
        m = Module()
        m.submodules.dut = dut
        # leave all inputs and the LUT unconstrained for the proof
        for input_index in (0, 1, 2):
            m.d.comb += dut.inputs[input_index].eq(AnyConst(dut.width))
        m.d.comb += dut.lut.eq(AnyConst(dut.lut.width))
        for bit in range(dut.width):
            # gather bit `bit` of every input into one index signal
            lut_index = Signal(dut.input_count, name=f"lut_index_{bit}")
            for input_index in range(dut.input_count):
                m.d.comb += lut_index[input_index].eq(
                    dut.inputs[input_index][bit])
            for entry in range(dut.lut.width):
                with m.If(lut_index == entry):
                    m.d.comb += Assert(dut.lut[entry] == dut.output[bit])
        formal(self, m)

    def test(self):
        self.tst(BitwiseLut)

    def test_tree(self):
        self.tst(TreeBitwiseLut)

    def test_formal(self):
        self.tst_formal(BitwiseLut)

    def test_tree_formal(self):
        self.tst_formal(TreeBitwiseLut)
211
212
# allow running this test module directly as a script
if __name__ == "__main__":
    unittest.main()