1 # SPDX-License-Identifier: LGPL-3-or-later
2 # See Notices.txt for copyright information
4 from contextlib
import contextmanager
6 from hashlib
import sha256
9 from nmigen
.back
import rtlil
11 from nmigen
.hdl
.ast
import AnyConst
, Assert
, Signal
12 from nmigen
.hdl
.dsl
import Module
13 from nmigen
.hdl
.ir
import Fragment
14 from nmutil
.get_test_path
import get_test_path
15 from nmutil
.lut
import BitwiseMux
, BitwiseLut
, TreeBitwiseLut
16 from nmigen
.sim
import Simulator
, Delay
20 def do_sim(test_case
, dut
, traces
=()):
22 path
= get_test_path(test_case
, "sim_test_out")
23 path
.parent
.mkdir(parents
=True, exist_ok
=True)
24 vcd_path
= path
.with_suffix(".vcd")
25 gtkw_path
= path
.with_suffix(".gtkw")
26 with sim
.write_vcd(vcd_path
.open("wt", encoding
="utf-8"),
27 gtkw_path
.open("wt", encoding
="utf-8"),
32 # copied from ieee754fpu/src/ieee754/partitioned_signal_tester.py
33 def formal(test_case
, hdl
, *, base_path
="formal_test_temp"):
34 hdl
= Fragment
.get(hdl
, platform
="formal")
35 path
= get_test_path(test_case
, base_path
)
36 shutil
.rmtree(path
, ignore_errors
=True)
37 path
.mkdir(parents
=True)
38 sby_name
= "config.sby"
39 sby_file
= path
/ sby_name
41 sby_file
.write_text(textwrap
.dedent(f
"""\
56 """), encoding
="utf-8")
57 sby
= shutil
.which('sby')
58 assert sby
is not None
59 with subprocess
.Popen(
61 cwd
=path
, text
=True, encoding
="utf-8",
62 stdin
=subprocess
.DEVNULL
, stdout
=subprocess
.PIPE
64 stdout
, stderr
= p
.communicate()
66 test_case
.fail(f
"Formal failed:\n{stdout}")
70 return int.from_bytes(
71 sha256(bytes(v
, encoding
='utf-8')).digest(),
76 class TestBitwiseMux(unittest
.TestCase
):
79 dut
= BitwiseMux(width
)
81 def case(sel
, t
, f
, expected
):
82 with self
.subTest(sel
=bin(sel
), t
=bin(t
), f
=bin(f
)):
87 output
= yield dut
.output
88 with self
.subTest(output
=bin(output
), expected
=bin(expected
)):
89 self
.assertEqual(expected
, output
)
92 for sel
in range(2 ** width
):
93 for t
in range(2 ** width
):
94 for f
in range(2**width
):
96 for i
in range(width
):
102 yield from case(sel
, t
, f
, expected
)
103 with
do_sim(self
, dut
, [dut
.sel
, dut
.t
, dut
.f
, dut
.output
]) as sim
:
104 sim
.add_process(process
)
107 def test_formal(self
):
109 dut
= BitwiseMux(width
)
111 m
.submodules
.dut
= dut
112 m
.d
.comb
+= dut
.sel
.eq(AnyConst(width
))
113 m
.d
.comb
+= dut
.f
.eq(AnyConst(width
))
114 m
.d
.comb
+= dut
.t
.eq(AnyConst(width
))
115 for i
in range(width
):
116 with m
.If(dut
.sel
[i
]):
117 m
.d
.comb
+= Assert(dut
.t
[i
] == dut
.output
[i
])
119 m
.d
.comb
+= Assert(dut
.f
[i
] == dut
.output
[i
])
123 class TestBitwiseLut(unittest
.TestCase
):
126 mask
= 2 ** dut
.width
- 1
127 lut_mask
= 2 ** dut
.lut
.width
- 1
128 if cls
is TreeBitwiseLut
:
129 mux_inputs
= {k
: s
.name
for k
, s
in dut
._mux
_inputs
.items()}
130 self
.assertEqual(mux_inputs
, {
131 (): 'mux_input_0bxxx',
132 (False,): 'mux_input_0bxx0',
133 (False, False): 'mux_input_0bx00',
134 (False, False, False): 'mux_input_0b000',
135 (False, False, True): 'mux_input_0b100',
136 (False, True): 'mux_input_0bx10',
137 (False, True, False): 'mux_input_0b010',
138 (False, True, True): 'mux_input_0b110',
139 (True,): 'mux_input_0bxx1',
140 (True, False): 'mux_input_0bx01',
141 (True, False, False): 'mux_input_0b001',
142 (True, False, True): 'mux_input_0b101',
143 (True, True): 'mux_input_0bx11',
144 (True, True, False): 'mux_input_0b011',
145 (True, True, True): 'mux_input_0b111'
148 def case(in0
, in1
, in2
, lut
):
150 for i
in range(dut
.width
):
158 if lut
& 2 ** lut_index
:
160 with self
.subTest(in0
=bin(in0
), in1
=bin(in1
), in2
=bin(in2
),
162 yield dut
.inputs
[0].eq(in0
)
163 yield dut
.inputs
[1].eq(in1
)
164 yield dut
.inputs
[2].eq(in2
)
165 yield dut
.lut
.eq(lut
)
167 output
= yield dut
.output
168 with self
.subTest(output
=bin(output
), expected
=bin(expected
)):
169 self
.assertEqual(expected
, output
)
172 for case_index
in range(100):
173 with self
.subTest(case_index
=case_index
):
174 in0
= hash_256(f
"{case_index} in0") & mask
175 in1
= hash_256(f
"{case_index} in1") & mask
176 in2
= hash_256(f
"{case_index} in2") & mask
177 lut
= hash_256(f
"{case_index} lut") & lut_mask
178 yield from case(in0
, in1
, in2
, lut
)
179 with
do_sim(self
, dut
, [*dut
.inputs
, dut
.lut
, dut
.output
]) as sim
:
180 sim
.add_process(process
)
183 def tst_formal(self
, cls
):
186 m
.submodules
.dut
= dut
187 m
.d
.comb
+= dut
.inputs
[0].eq(AnyConst(dut
.width
))
188 m
.d
.comb
+= dut
.inputs
[1].eq(AnyConst(dut
.width
))
189 m
.d
.comb
+= dut
.inputs
[2].eq(AnyConst(dut
.width
))
190 m
.d
.comb
+= dut
.lut
.eq(AnyConst(dut
.lut
.width
))
191 for i
in range(dut
.width
):
192 lut_index
= Signal(dut
.input_count
, name
=f
"lut_index_{i}")
193 for j
in range(dut
.input_count
):
194 m
.d
.comb
+= lut_index
[j
].eq(dut
.inputs
[j
][i
])
195 for j
in range(dut
.lut
.width
):
196 with m
.If(lut_index
== j
):
197 m
.d
.comb
+= Assert(dut
.lut
[j
] == dut
.output
[i
])
204 self
.tst(TreeBitwiseLut
)
206 def test_formal(self
):
207 self
.tst_formal(BitwiseLut
)
209 def test_tree_formal(self
):
210 self
.tst_formal(TreeBitwiseLut
)
213 if __name__
== "__main__":