d36cb4e9af40b980ddf9dffb6684ed92668fd057
[nmutil.git] / src / nmutil / test / test_lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from contextlib import contextmanager
5 import unittest
6 from hashlib import sha256
7 import shutil
8 import subprocess
9 from nmigen.back import rtlil
10 import textwrap
11 from nmigen.hdl.ast import AnyConst, Assert, Signal
12 from nmigen.hdl.dsl import Module
13 from nmigen.hdl.ir import Fragment
14 from nmutil.get_test_path import get_test_path
15 from nmutil.lut import BitwiseMux, BitwiseLut
16 from nmigen.sim import Simulator, Delay
17
18
@contextmanager
def do_sim(test_case, dut, traces=()):
    """Yield a ``Simulator`` for ``dut`` that records waveform output.

    The output files are placed at the path returned by
    ``get_test_path(test_case, "sim_test_out")`` with ``.vcd`` and ``.gtkw``
    suffixes; the parent directory is created if it does not exist.

    :param test_case: the running ``unittest.TestCase``; used to derive the
        output path.
    :param dut: the design handed to ``Simulator``.
    :param traces: passed through unchanged as ``traces`` to
        ``Simulator.write_vcd``.
    """
    simulator = Simulator(dut)
    out_base = get_test_path(test_case, "sim_test_out")
    out_base.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): the open file objects are handed to write_vcd; presumably
    # it closes them when its context exits — confirm against nmigen.
    vcd_file = out_base.with_suffix(".vcd").open("wt", encoding="utf-8")
    gtkw_file = out_base.with_suffix(".gtkw").open("wt", encoding="utf-8")
    # callers run the simulation inside this ``with`` so it gets recorded
    with simulator.write_vcd(vcd_file, gtkw_file, traces=traces):
        yield simulator
30
31
# copied from ieee754fpu/src/ieee754/partitioned_signal_tester.py
def formal(test_case, hdl, *, base_path="formal_test_temp"):
    """Run a SymbiYosys ``prove`` of ``hdl`` and fail ``test_case`` on error.

    ``hdl`` is elaborated for the ``"formal"`` platform and converted to
    RTLIL. The RTLIL is written, embedded in a ``config.sby``, into a fresh
    directory derived from ``get_test_path(test_case, base_path)``; any
    previous contents of that directory are deleted first. ``sby`` is then
    run in that directory.

    :param test_case: the running ``unittest.TestCase``; used to derive the
        output directory and to report failures.
    :param hdl: the elaboratable containing the assertions to prove.
    :param base_path: base directory name passed to ``get_test_path``.
    :raises test_case.failureException: (``AssertionError`` by default) if
        ``sby`` is not on ``PATH`` or exits with a non-zero status; in the
        latter case the message includes sby's stdout.
    """
    hdl = Fragment.get(hdl, platform="formal")
    path = get_test_path(test_case, base_path)
    shutil.rmtree(path, ignore_errors=True)
    path.mkdir(parents=True)
    sby_name = "config.sby"
    sby_file = path / sby_name

    # Dedent the static part of the config *before* appending the RTLIL.
    # Dedenting after interpolation is a no-op, because the RTLIL text has
    # lines starting at column 0, so the config would keep its indentation.
    config = textwrap.dedent("""\
        [options]
        mode prove
        depth 1
        wait on

        [engines]
        smtbmc

        [script]
        read_rtlil top.il
        prep

        [file top.il]
        """)
    sby_file.write_text(config + rtlil.convert(hdl) + "\n", encoding="utf-8")

    sby = shutil.which('sby')
    if sby is None:
        # explicit failure instead of ``assert`` (stripped under ``-O``)
        test_case.fail("sby (SymbiYosys) not found on PATH")
    # stderr is not captured: it goes straight to the test's stderr
    result = subprocess.run(
        [sby, sby_name],
        cwd=path, text=True, encoding="utf-8",
        stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,
    )
    if result.returncode != 0:
        test_case.fail(f"Formal failed:\n{result.stdout}")
67
68
def hash_256(v):
    """Hash the string ``v`` with SHA-256 and return the digest as an int.

    The 32-byte digest (of ``v`` encoded as UTF-8) is interpreted as a
    little-endian unsigned integer, so the result is a deterministic value
    in ``range(2 ** 256)``.
    """
    digest = sha256(bytes(v, "utf-8")).digest()
    return int.from_bytes(digest, "little")
74
75
class TestBitwiseMux(unittest.TestCase):
    """Tests for ``BitwiseMux``: each output bit is ``t`` where ``sel`` is
    set and ``f`` where it is clear."""

    def test(self):
        """Simulate every combination of 2-bit ``sel``, ``t`` and ``f``."""
        width = 2
        dut = BitwiseMux(width)

        def case(sel, t, f, expected):
            # drive the inputs, let the combinational logic settle, then
            # compare the output with the reference value
            with self.subTest(sel=bin(sel), t=bin(t), f=bin(f)):
                for port, value in ((dut.sel, sel), (dut.t, t), (dut.f, f)):
                    yield port.eq(value)
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output), expected=bin(expected)):
                    self.assertEqual(expected, output)

        def process():
            count = 2 ** width
            for sel in range(count):
                for t in range(count):
                    for f in range(count):
                        # reference mux: bits of t where sel is 1, else f
                        expected = (sel & t) | (~sel & f)
                        yield from case(sel, t, f, expected)

        with do_sim(self, dut, [dut.sel, dut.t, dut.f, dut.output]) as sim:
            sim.add_process(process)
            sim.run()

    def test_formal(self):
        """Prove each output bit equals ``t`` when ``sel`` is set, else
        ``f``."""
        width = 2
        dut = BitwiseMux(width)
        m = Module()
        m.submodules.dut = dut
        # drive every input from an arbitrary constant chosen by the solver
        for port in (dut.sel, dut.f, dut.t):
            m.d.comb += port.eq(AnyConst(width))
        for bit in range(width):
            with m.If(dut.sel[bit]):
                m.d.comb += Assert(dut.t[bit] == dut.output[bit])
            with m.Else():
                m.d.comb += Assert(dut.f[bit] == dut.output[bit])
        formal(self, m)
121
122
class TestBitwiseLut(unittest.TestCase):
    """Tests for ``BitwiseLut``, instantiated as ``BitwiseLut(3, 16)``."""

    def test(self):
        """Check the internal mux-input names, then simulate 100
        hash-derived pseudo-random input/LUT combinations."""
        dut = BitwiseLut(3, 16)
        mask = 2 ** dut.width - 1
        lut_mask = 2 ** dut.lut.width - 1
        # names of the internal mux-input signals; each key lists selector
        # bit values starting from bit 0 (see the names below)
        mux_inputs = {key: sig.name for key, sig in dut._mux_inputs.items()}
        self.assertEqual(mux_inputs, {
            (): 'mux_input_0bxxx',
            (False,): 'mux_input_0bxx0',
            (False, False): 'mux_input_0bx00',
            (False, False, False): 'mux_input_0b000',
            (False, False, True): 'mux_input_0b100',
            (False, True): 'mux_input_0bx10',
            (False, True, False): 'mux_input_0b010',
            (False, True, True): 'mux_input_0b110',
            (True,): 'mux_input_0bxx1',
            (True, False): 'mux_input_0bx01',
            (True, False, False): 'mux_input_0b001',
            (True, False, True): 'mux_input_0b101',
            (True, True): 'mux_input_0bx11',
            (True, True, False): 'mux_input_0b011',
            (True, True, True): 'mux_input_0b111',
        })

        def case(in0, in1, in2, lut):
            operands = (in0, in1, in2)
            expected = 0
            for bit in range(dut.width):
                # bit j of the LUT index is bit `bit` of input j
                lut_index = sum(2 ** j
                                for j, value in enumerate(operands)
                                if value & 2 ** bit)
                if lut & 2 ** lut_index:
                    expected |= 2 ** bit
            with self.subTest(in0=bin(in0), in1=bin(in1), in2=bin(in2),
                              lut=bin(lut)):
                for j, value in enumerate(operands):
                    yield dut.inputs[j].eq(value)
                yield dut.lut.eq(lut)
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output), expected=bin(expected)):
                    self.assertEqual(expected, output)

        def process():
            for case_index in range(100):
                with self.subTest(case_index=case_index):
                    in0 = hash_256(f"{case_index} in0") & mask
                    in1 = hash_256(f"{case_index} in1") & mask
                    in2 = hash_256(f"{case_index} in2") & mask
                    lut = hash_256(f"{case_index} lut") & lut_mask
                    yield from case(in0, in1, in2, lut)

        with do_sim(self, dut, [*dut.inputs, dut.lut, dut.output]) as sim:
            sim.add_process(process)
            sim.run()

    def test_formal(self):
        """Prove each output bit equals the LUT entry selected by the
        corresponding bits of the inputs."""
        dut = BitwiseLut(3, 16)
        m = Module()
        m.submodules.dut = dut
        # arbitrary solver-chosen constants for all inputs and the LUT
        for j in range(3):
            m.d.comb += dut.inputs[j].eq(AnyConst(dut.width))
        m.d.comb += dut.lut.eq(AnyConst(dut.lut.width))
        for bit in range(dut.width):
            lut_index = Signal(dut.input_count, name=f"lut_index_{bit}")
            for j in range(dut.input_count):
                m.d.comb += lut_index[j].eq(dut.inputs[j][bit])
            for entry in range(dut.lut.width):
                with m.If(lut_index == entry):
                    m.d.comb += Assert(dut.lut[entry] == dut.output[bit])
        formal(self, m)
198
199
if __name__ == "__main__":
    # run every TestCase in this module when executed as a script
    unittest.main(module="__main__")