switch test_lut to use FHDLTestCase
[nmutil.git] / src / nmutil / test / test_lut.py
1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
3
4 from contextlib import contextmanager
5 import unittest
6 from hashlib import sha256
7 from nmigen.hdl.ast import AnyConst, Assert, Signal
8 from nmigen.hdl.dsl import Module
9 from nmutil.formaltest import FHDLTestCase
10 from nmutil.get_test_path import get_test_path
11 from nmutil.lut import BitwiseMux, BitwiseLut, TreeBitwiseLut
12 from nmigen.sim import Simulator, Delay
13
14
@contextmanager
def do_sim(test_case, dut, traces=()):
    """Yield a ``Simulator`` for ``dut`` that records waveform files.

    The output location is derived from ``get_test_path(test_case,
    "sim_test_out")``; the ``.vcd`` and ``.gtkw`` files are written next to
    each other, and the parent directory is created when missing.

    ``traces`` is forwarded unchanged to ``Simulator.write_vcd``.
    """
    simulator = Simulator(dut)
    out_base = get_test_path(test_case, "sim_test_out")
    out_base.parent.mkdir(parents=True, exist_ok=True)
    # Open the VCD file first, then the GTKWave save file.
    vcd_file = out_base.with_suffix(".vcd").open("wt", encoding="utf-8")
    gtkw_file = out_base.with_suffix(".gtkw").open("wt", encoding="utf-8")
    with simulator.write_vcd(vcd_file, gtkw_file, traces=traces):
        yield simulator
26
27
def hash_256(v):
    """Return the SHA-256 digest of the string ``v`` as an integer.

    ``v`` is encoded as UTF-8 and the 32 digest bytes are interpreted in
    little-endian order, so the result lies in ``range(2 ** 256)``.
    """
    encoded = bytes(v, encoding='utf-8')
    digest = sha256(encoded).digest()
    return int.from_bytes(digest, byteorder='little')
33
34
class TestBitwiseMux(FHDLTestCase):
    """Simulation and formal checks for ``BitwiseMux``."""

    def test(self):
        """Simulate every ``sel``/``t``/``f`` combination of a 2-bit mux."""
        width = 2
        dut = BitwiseMux(width)
        value_count = 2 ** width

        def reference_output(sel, t, f):
            # Each output bit is taken from ``t`` where the matching ``sel``
            # bit is set, and from ``f`` otherwise.
            result = 0
            for bit in (2 ** i for i in range(width)):
                source = t if sel & bit else f
                result |= source & bit
            return result

        def check(sel, t, f):
            expected = reference_output(sel, t, f)
            with self.subTest(sel=bin(sel), t=bin(t), f=bin(f)):
                for port, value in ((dut.sel, sel), (dut.t, t), (dut.f, f)):
                    yield port.eq(value)
                # Advance simulated time before sampling the output.
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output), expected=bin(expected)):
                    self.assertEqual(expected, output)

        def stimulus():
            for sel in range(value_count):
                for t in range(value_count):
                    for f in range(value_count):
                        yield from check(sel, t, f)

        traces = [dut.sel, dut.t, dut.f, dut.output]
        with do_sim(self, dut, traces) as sim:
            sim.add_process(stimulus)
            sim.run()

    def test_formal(self):
        """Formally check each output bit against the selected input bit."""
        width = 2
        dut = BitwiseMux(width)
        m = Module()
        m.submodules.dut = dut
        # Drive every DUT input from an unconstrained constant.
        for port in (dut.sel, dut.f, dut.t):
            m.d.comb += port.eq(AnyConst(width))
        for bit_index in range(width):
            out_bit = dut.output[bit_index]
            with m.If(dut.sel[bit_index]):
                m.d.comb += Assert(dut.t[bit_index] == out_bit)
            with m.Else():
                m.d.comb += Assert(dut.f[bit_index] == out_bit)
        self.assertFormal(m)
80
81
class TestBitwiseLut(FHDLTestCase):
    """Simulation and formal checks shared by ``BitwiseLut`` variants."""

    def tst(self, cls):
        """Simulate ``cls(3, 16)`` on 100 hash-derived input/LUT patterns.

        For ``TreeBitwiseLut`` the names of the internal mux-input signals
        are checked first.
        """
        dut = cls(3, 16)
        if cls is TreeBitwiseLut:
            actual_names = {
                key: sig.name for key, sig in dut._mux_inputs.items()
            }
            expected_names = {
                (): 'mux_input_0bxxx',
                (False,): 'mux_input_0bxx0',
                (False, False): 'mux_input_0bx00',
                (False, False, False): 'mux_input_0b000',
                (False, False, True): 'mux_input_0b100',
                (False, True): 'mux_input_0bx10',
                (False, True, False): 'mux_input_0b010',
                (False, True, True): 'mux_input_0b110',
                (True,): 'mux_input_0bxx1',
                (True, False): 'mux_input_0bx01',
                (True, False, False): 'mux_input_0b001',
                (True, False, True): 'mux_input_0b101',
                (True, True): 'mux_input_0bx11',
                (True, True, False): 'mux_input_0b011',
                (True, True, True): 'mux_input_0b111'
            }
            self.assertEqual(actual_names, expected_names)

        mask = 2 ** dut.width - 1
        lut_mask = 2 ** dut.lut.width - 1

        def reference_output(in0, in1, in2, lut):
            # For every bit position, the three input bits form an index
            # (in0 is the least-significant index bit) into ``lut``.
            result = 0
            for i in range(dut.width):
                lut_index = (((in0 >> i) & 1)
                             | (((in1 >> i) & 1) << 1)
                             | (((in2 >> i) & 1) << 2))
                if (lut >> lut_index) & 1:
                    result |= 1 << i
            return result

        def check(in0, in1, in2, lut):
            expected = reference_output(in0, in1, in2, lut)
            with self.subTest(in0=bin(in0), in1=bin(in1), in2=bin(in2),
                              lut=bin(lut)):
                for port, value in zip(dut.inputs, (in0, in1, in2)):
                    yield port.eq(value)
                yield dut.lut.eq(lut)
                # Advance simulated time before sampling the output.
                yield Delay(1e-6)
                output = yield dut.output
                with self.subTest(output=bin(output), expected=bin(expected)):
                    self.assertEqual(expected, output)

        def stimulus():
            for case_index in range(100):
                with self.subTest(case_index=case_index):
                    # Deterministic pseudo-random patterns from the index.
                    in0 = hash_256(f"{case_index} in0") & mask
                    in1 = hash_256(f"{case_index} in1") & mask
                    in2 = hash_256(f"{case_index} in2") & mask
                    lut = hash_256(f"{case_index} lut") & lut_mask
                    yield from check(in0, in1, in2, lut)

        traces = [*dut.inputs, dut.lut, dut.output]
        with do_sim(self, dut, traces) as sim:
            sim.add_process(stimulus)
            sim.run()

    def tst_formal(self, cls):
        """Formally check every output bit of ``cls(3, 16)`` against the LUT."""
        dut = cls(3, 16)
        m = Module()
        m.submodules.dut = dut
        # Drive the three inputs and the LUT from unconstrained constants.
        for input_index in range(3):
            m.d.comb += dut.inputs[input_index].eq(AnyConst(dut.width))
        m.d.comb += dut.lut.eq(AnyConst(dut.lut.width))
        for bit_index in range(dut.width):
            # Gather bit ``bit_index`` of each input into one index signal.
            lut_index = Signal(dut.input_count, name=f"lut_index_{bit_index}")
            for input_index in range(dut.input_count):
                m.d.comb += lut_index[input_index].eq(
                    dut.inputs[input_index][bit_index])
            for entry in range(dut.lut.width):
                with m.If(lut_index == entry):
                    m.d.comb += Assert(
                        dut.lut[entry] == dut.output[bit_index])
        self.assertFormal(m)

    def test(self):
        """Run the simulation test on ``BitwiseLut``."""
        self.tst(BitwiseLut)

    def test_tree(self):
        """Run the simulation test on ``TreeBitwiseLut``."""
        self.tst(TreeBitwiseLut)

    def test_formal(self):
        """Run the formal test on ``BitwiseLut``."""
        self.tst_formal(BitwiseLut)

    def test_tree_formal(self):
        """Run the formal test on ``TreeBitwiseLut``."""
        self.tst_formal(TreeBitwiseLut)
170
171
if __name__ == "__main__":
    # Run all test cases in this module when it is executed as a script.
    unittest.main()