1 # SPDX-License-Identifier: LGPL-3.0-or-later
2 # See Notices.txt for copyright information
5 from nmigen
.hdl
.ast
import (AnyConst
, Assert
, Signal
, Value
, ValueCastable
)
6 from nmigen
.hdl
.dsl
import Module
7 from nmigen
.hdl
.ir
import Elaboratable
, Fragment
8 from nmigen
.sim
import Simulator
, Delay
9 from ieee754
.part
.partsig
import SimdSignal
, PartitionPoints
13 from hashlib
import sha256
14 from nmigen
.back
import rtlil
15 from nmutil
.get_test_path
import get_test_path
16 from collections
.abc
import Sequence
# formal(test_case, hdl, *, base_path="formal_test_temp")
# Formally verify `hdl` with SymbiYosys and report failure via test_case.
# Visible steps: elaborate `hdl` into a Fragment for the "formal" platform,
# recreate an empty per-test directory from get_test_path(), write a
# "config.sby" file into it, locate the `sby` executable on PATH, run a
# subprocess in that directory and pass its captured stdout to
# test_case.fail().
# NOTE(review): this extraction is missing several original lines: the .sby
# file body, the Popen argument list, the `as p` clause and the condition
# guarding test_case.fail(). Comments below cover only the visible code.
19 def formal(test_case
, hdl
, *, base_path
="formal_test_temp"):
# Convert the elaboratable into a Fragment targeting the "formal" platform.
20 hdl
= Fragment
.get(hdl
, platform
="formal")
21 path
= get_test_path(test_case
, base_path
)
# Start from an empty directory so output from an earlier run cannot leak in.
22 shutil
.rmtree(path
, ignore_errors
=True)
23 path
.mkdir(parents
=True)
24 sby_name
= "config.sby"
25 sby_file
= path
/ sby_name
27 sby_file
.write_text(textwrap
.dedent(f
"""\
42 """), encoding
="utf-8")
# `sby` must be installed and on PATH for formal tests to run at all.
43 sby
= shutil
.which('sby')
44 assert sby
is not None
# stdin is closed and stdout is captured so it can be shown on failure.
45 with subprocess
.Popen(
47 cwd
=path
, text
=True, encoding
="utf-8",
48 stdin
=subprocess
.DEVNULL
, stdout
=subprocess
.PIPE
50 stdout
, stderr
= p
.communicate()
# NOTE(review): the guard for this call (original line 51) is not visible
# here; presumably it checks for a non-zero return code. Confirm.
52 test_case
.fail(f
"Formal failed:\n{stdout}")
# Layout.cast(layout, width=None)
# Coerce `layout` to a Layout. A value that is already a Layout goes through
# the isinstance() branch. That branch's body (original line 59) is missing
# from this extraction; presumably it returns `layout` unchanged. Confirm.
# Any other value is turned into a new Layout(layout, width).
57 def cast(layout
, width
=None):
58 if isinstance(layout
, Layout
):
60 return Layout(layout
, width
)
# Layout.__init__(part_indexes, width=None)
# Normalise a collection of integer partition-point bit indexes.
# Duplicates are removed via set(), and each index must be an int. The width,
# validated by get_width(), is added as a partition point. The result is
# stored as the tuple self.part_indexes, which its docstring describes as
# sorted and always including `0` and `self.width`.
# A table is then built mapping each possible lane size to the bit indexes
# where a lane of that size may start.
# NOTE(review): original lines 66-67, 70, 72, 74, 78-79, 84, 88 and 93 are
# missing from this extraction.
62 def __init__(self
, part_indexes
, width
=None):
63 part_indexes
= set(part_indexes
)
64 for p
in part_indexes
:
65 assert isinstance(p
, int)
# NOTE(review): the guard for this width handling (original lines 66-67) is
# missing; presumably `if width is not None:`. Confirm.
68 width
= Layout
.get_width(width
)
# Each index is checked against the width (the check itself, original
# line 70, is not visible). The width is then added as the final point.
69 for p
in part_indexes
:
71 part_indexes
.add(width
)
# The list is presumably sorted on original line 74 (not visible) before it
# is frozen into a tuple.
73 part_indexes
= list(part_indexes
)
75 self
.part_indexes
= tuple(part_indexes
)
76 """bit indexes of partition points in sorted order, always
77 includes `0` and `self.width`"""
# Collect the size of every span from one partition point to any later one.
# `sizes` is initialised on a line missing from this view.
80 for start_index
in range(len(self
.part_indexes
)):
81 start
= self
.part_indexes
[start_index
]
82 for end
in self
.part_indexes
[start_index
+ 1:]:
83 sizes
.append(end
- start
)
85 # build in sorted order
86 self
.__lane
_starts
_for
_sizes
= {size
: {} for size
in sizes
}
87 """keys are in sorted order"""
# Record every valid start for each lane size. The inner dicts are used as
# ordered sets: only their keys matter, and the values are always None.
89 for start_index
in range(len(self
.part_indexes
)):
90 start
= self
.part_indexes
[start_index
]
91 for end
in self
.part_indexes
[start_index
+ 1:]:
92 self
.__lane
_starts
_for
_sizes
[end
- start
][start
] = None
# Body of the method defined on original line 95. Its `def` line is missing
# here; it is presumably the `width` property, since `self.width` is read
# elsewhere. It returns the last entry of part_indexes, which is the
# largest one because part_indexes is documented as sorted.
96 return self
.part_indexes
[-1]
def part_signal_count(self):
    """Return how many interior partition points this layout has.

    ``part_indexes`` always contains both ``0`` and the width, so the
    interior count is two less than its length. The result is clamped
    so it is never negative.
    """
    interior = len(self.part_indexes) - 2
    return interior if interior > 0 else 0
# Layout.get_width(width)
# Validate a width argument, which may be a Layout or an int. The body of
# the Layout branch (original line 105) is missing from this extraction;
# presumably it replaces `width` with that layout's width. An int is
# asserted afterwards.
# NOTE(review): the rest of this method (original lines 107+) is not visible;
# presumably it returns the validated int. Confirm.
103 def get_width(width
):
104 if isinstance(width
, Layout
):
106 assert isinstance(width
, int)
# Layout.partition_points_signals(name=None, ...)
# Build a PartitionPoints mapping with one new Signal for each interior
# partition point (part_indexes[1:-1]). Each Signal is named
# "{name}_{i}", where i is its bit index.
# The default name is taken from a throwaway Signal, presumably so that
# nmigen's automatic Signal naming (honouring src_loc_at) can supply it.
# NOTE(review): original lines 111-112 and 114 are missing. Presumably they
# finish the signature (declaring `src_loc_at`), guard the default-name
# assignment with `if name is None:`, and initialise `pps`. Confirm.
110 def partition_points_signals(self
, name
=None,
113 name
= Signal(src_loc_at
=1 + src_loc_at
).name
115 for i
in self
.part_indexes
[1:-1]:
116 pps
[i
] = Signal(name
=f
"{name}_{i}", src_loc_at
=1 + src_loc_at
)
117 return PartitionPoints(pps
)
# Body of Layout's repr method. Its `def` line (original line 119) is
# missing here; it is presumably `__repr__`. It shows the partition indexes
# and the width.
120 return f
"Layout({self.part_indexes}, width={self.width})"
# Body of Layout's equality method. Its `def` line is missing here; it is
# presumably `__eq__(self, o)`. Two Layouts are equal when their
# part_indexes tuples are equal. For other types it returns NotImplemented,
# letting Python try the reflected comparison.
123 if isinstance(o
, Layout
):
124 return self
.part_indexes
== o
.part_indexes
125 return NotImplemented
# Body of Layout's hash method. Its `def` line is missing here; it is
# presumably `__hash__`. It hashes part_indexes, the same value the equality
# method compares, so equal Layouts hash equally.
128 return hash(self
.part_indexes
)
def is_lane_valid(self, start, size):
    """Tell whether a lane ``size`` bits wide may begin at bit ``start``.

    The size is looked up in the per-size table of valid lane starts. A
    size with no entry behaves like an empty collection, so the answer is
    ``False`` rather than an error.
    """
    starts_for_size = self.__lane_starts_for_sizes.get(size, ())
    return start in starts_for_size
def lane_sizes(self):
    """Return a live keys view of every lane size this layout supports.

    The sizes come out in the order the internal size table stores them.
    """
    size_table = self.__lane_starts_for_sizes
    return size_table.keys()
def lane_starts_for_size(self, size):
    """Return a keys view of the bit indexes where a ``size``-bit lane may start.

    Raises KeyError when ``size`` is not a lane size of this layout.
    """
    starts = self.__lane_starts_for_sizes[size]
    return starts.keys()
def lanes_for_size(self, size):
    """Lazily yield a Lane for each valid start of a ``size``-bit lane.

    The Lanes are produced in the order given by lane_starts_for_size().
    Each one refers back to this layout.
    """
    yield from (Lane(start, size, self)
                for start in self.lane_starts_for_size(size))
# Body of a generator whose `def` line (original line 143) is missing here.
# It is presumably `lanes(self)`, given the `.lanes()` calls elsewhere. It
# yields every valid Lane, grouped by size in lane_sizes() order.
144 for size
in self
.lane_sizes():
145 yield from self
.lanes_for_size(size
)
def is_compatible(self, other):
    """Tell whether ``other`` has as many partition points as this layout.

    Only the number of points is compared, not their bit positions.
    ``other`` may be anything Layout.cast() accepts.
    """
    other_layout = Layout.cast(other)
    point_count = len(self.part_indexes)
    return point_count == len(other_layout.part_indexes)
def translate_lane_to(self, lane, target_layout):
    """Map ``lane`` from this layout onto the same positions of another.

    ``lane`` must belong to this layout. The positions of its start and end
    bits are found in ``self.part_indexes``. The bit indexes at those same
    positions in ``target_layout`` then bound the returned Lane.
    ``target_layout`` may be anything Layout.cast() accepts. It must be
    compatible with this layout, which is asserted.
    """
    assert lane.layout == self
    target = Layout.cast(target_layout)
    assert self.is_compatible(target)
    source_points = self.part_indexes
    first = source_points.index(lane.start)
    last = source_points.index(lane.end)
    target_points = target.part_indexes
    new_start = target_points[first]
    return Lane(new_start, target_points[last] - new_start, target)
# Constructor, presumably of `Lane`: the class line is missing here, but the
# signature matches the `Lane(start, size, layout)` calls elsewhere.
# A lane is a span of `size` bits starting at bit `start` of a layout.
# `layout` is coerced with Layout.cast(), and (start, size) must be a valid
# lane of it, which is asserted.
# NOTE(review): the following lines (original 166-168) are missing.
# `start` and `size` are presumably stored as self.start and self.size,
# which other methods read. Confirm.
163 def __init__(self
, start
, size
, layout
):
164 self
.layout
= Layout
.cast(layout
)
165 assert self
.layout
.is_lane_valid(start
, size
)
# Body of the Lane repr method. Its `def` line is missing here; it is
# presumably `__repr__`. It renders the lane as
# "Lane(start=..., size=..., layout=...)".
170 return (f
"Lane(start={self.start}, size={self.size}, "
171 f
"layout={self.layout})")
# Body of the Lane equality method. Its `def` line is missing here; it is
# presumably `__eq__(self, o)`. Two lanes are equal when start, size and
# layout all match. For other types it returns NotImplemented.
174 if isinstance(o
, Lane
):
175 return self
.start
== o
.start
and self
.size
== o
.size \
176 and self
.layout
== o
.layout
177 return NotImplemented
# Body of the Lane hash method. Its `def` line is missing here; it is
# presumably `__hash__`. It hashes the same (start, size, layout) triple
# that the equality method compares.
180 return hash((self
.start
, self
.size
, self
.layout
))
# Body of the method whose `def` (original line 182) is missing here. It is
# presumably `as_slice`, given the `.as_slice()` calls elsewhere. It returns
# a Python slice covering bits [start, end) of a signal.
183 return slice(self
.start
, self
.end
)
# Body of the method ending at original line 187. Its decorator and `def`
# lines are missing; it is presumably the `end` property, read as `lane.end`
# elsewhere. It returns the exclusive end bit, i.e. start + size.
187 return self
.start
+ self
.size
def translate_to(self, target_layout):
    """Return the lane equivalent to this one within ``target_layout``.

    The work is delegated to this lane's own layout.
    """
    source_layout = self.layout
    return source_layout.translate_lane_to(self, target_layout)
# is_active(partition_points): presumably a Lane method, since it reads
# self.layout, self.start and self.end.
# Decide whether this lane is active for the given partition-point settings.
# `partition_points` takes one of two forms:
#   - a Sequence indexed by partition *position*, such as the part_starts
#     tuples built by the tester;
#   - a mapping indexed by *bit index*, such as a SimdSignal's `partpoints`.
# The points at the lane's start and end positions are fetched with
# invert=False. The interior points between them are fetched with
# invert=True, presumably requiring them to be clear so the lane is not
# split. All results are combined with `&`.
# NOTE(review): original lines 195, 198, 201-203, 205-208 and the final
# return (215+) are missing from this extraction.
192 def is_active(self
, partition_points
):
193 def get_partition_point(index
, invert
):
# The first and last positions are the layout's outer boundaries. The body
# of this branch (original line 195) is missing here.
194 if index
== 0 or index
== len(self
.layout
.part_indexes
) - 1:
196 if isinstance(partition_points
, Sequence
):
197 retval
= partition_points
[index
]
199 retval
= partition_points
[self
.layout
.part_indexes
[index
]]
# Plain bools take a path on lines missing here (201-203). Other values are
# presumably cast to nmigen Values so they can be combined with `&` below.
200 if isinstance(retval
, bool):
204 retval
= Value
.cast(retval
)
# Positions of this lane's boundaries within the layout.
209 start_index
= self
.layout
.part_indexes
.index(self
.start
)
210 end_index
= self
.layout
.part_indexes
.index(self
.end
)
211 retval
= get_partition_point(start_index
, False) \
212 & get_partition_point(end_index
, False)
213 for i
in range(start_index
+ 1, end_index
):
214 retval
&= get_partition_point(i
, True)
# SimdSignalTester: checks a SimdSignal `operation` against a per-lane
# `reference` model. run_sim uses simulation; run_formal uses formal
# verification.
# NOTE(review): this extraction is missing lines of the class body
# (original 220, 224, 227-228, 235, 238-240, 252, 261). The comments below
# describe only the visible statements.
219 class SimdSignalTester
:
# __init__(m, operation, reference, *layouts, src_loc_at=0,
#          additional_case_count=30, special_cases=(), seed="")
# - m: module that receives combinational assignments via m.d.comb.
# - operation: called with a tuple of the input SimdSignals; must return a
#   SimdSignal (asserted below).
# - reference: called once per lane as reference(lane, tuple_of_slices).
#   Its result goes through Value.cast().
# - layouts: one per input; each must be compatible with the first.
# - special_cases: explicit cases run first. additional_case_count more
#   cases follow them.
# - seed: presumably stored as self.seed for the case hash; the assignment
#   is not visible here.
221 def __init__(self
, m
, operation
, reference
, *layouts
,
222 src_loc_at
=0, additional_case_count
=30,
223 special_cases
=(), seed
=""):
225 self
.operation
= operation
226 self
.reference
= reference
# Create one input per layout, named input_0, input_1, .... The SimdSignal
# `ps` is built on lines missing here (235, 238-240).
229 for layout
in layouts
:
230 layout
= Layout
.cast(layout
)
231 if len(self
.layouts
) > 0:
232 assert self
.layouts
[0].is_compatible(layout
)
233 self
.layouts
.append(layout
)
234 name
= f
"input_{len(self.inputs)}"
236 layout
.partition_points_signals(name
=name
,
237 src_loc_at
=1 + src_loc_at
),
241 self
.inputs
.append(ps
)
242 assert len(self
.layouts
) != 0, "must have at least one input layout"
# Drive every other input's partition points from input 0's, position by
# position, so that all inputs are partitioned identically.
243 for i
in range(1, len(self
.inputs
)):
244 for j
in range(1, len(self
.layouts
[0].part_indexes
) - 1):
245 lhs_part_point
= self
.layouts
[i
].part_indexes
[j
]
246 rhs_part_point
= self
.layouts
[0].part_indexes
[j
]
247 lhs
= self
.inputs
[i
].partpoints
[lhs_part_point
]
248 rhs
= self
.inputs
[0].partpoints
[rhs_part_point
]
249 m
.d
.comb
+= lhs
.eq(rhs
)
# The special cases come first in the case numbering.
250 self
.special_cases
= list(special_cases
)
251 self
.case_count
= additional_case_count
+ len(self
.special_cases
)
# 64-bit signal holding the current case number; run_sim records it in the
# waveform.
253 self
.case_number
= Signal(64)
# Apply the operation under test once, to all inputs together.
254 self
.test_output
= operation(tuple(self
.inputs
))
255 assert isinstance(self
.test_output
, SimdSignal
)
256 self
.test_output_layout
= Layout(
257 self
.test_output
.partpoints
, self
.test_output
.sig
.width
)
258 assert self
.test_output_layout
.is_compatible(self
.layouts
[0])
# Evaluate the reference model for every lane of the first layout. Each
# input is sliced at the equivalent lane of its own layout. `in_t` is reset
# on a line missing here (261).
259 self
.reference_output_values
= {}
260 for lane
in self
.layouts
[0].lanes():
262 for inp
, layout
in zip(self
.inputs
, self
.layouts
):
263 in_t
.append(inp
.sig
[lane
.translate_to(layout
).as_slice()])
264 v
= Value
.cast(reference(lane
, tuple(in_t
)))
265 self
.reference_output_values
[lane
] = v
# Drive each reference value onto its own named Signal, so it can be traced
# and compared against the output.
266 self
.reference_outputs
= {}
267 for lane
, value
in self
.reference_output_values
.items():
268 s
= Signal(value
.shape(),
269 name
=f
"reference_output_{lane.start}_{lane.size}")
270 self
.reference_outputs
[lane
] = s
271 m
.d
.comb
+= s
.eq(value
)
# __hash_256(v): deterministic pseudo-random integer of up to 256 bits.
# It is derived from the SHA-256 digest of self.seed + v, UTF-8 encoded.
# NOTE(review): the remaining int.from_bytes() arguments (original lines
# 276-278, such as the byte order) are missing from this extraction.
273 def __hash_256(self
, v
):
274 return int.from_bytes(
275 sha256(bytes(self
.seed
+ v
, encoding
='utf-8')).digest(),
# __hash(v, bits): deterministic pseudo-random integer masked to `bits` bits.
# It combines one __hash_256() digest for each 256-bit chunk of the width.
# NOTE(review): original lines 280 and 282 are missing. Presumably they
# initialise `retval` and shift it between chunks. Confirm this: without a
# shift, the OR-ed chunks would all overlap in the low 256 bits.
279 def __hash(self
, v
, bits
):
281 for i
in range(0, bits
, 256):
283 retval |
= self
.__hash
_256(f
" {v} {i}")
# Truncate the combined value to exactly `bits` bits.
284 return retval
& ((1 << bits
) - 1)
# __get_case(case_number) -> (part_starts, inputs)
# Special cases are returned as-is. Any other case is derived
# deterministically from case_number via __hash():
# - `part_starts` holds one bool per entry of layouts[0].part_indexes;
# - `inputs` holds one integer per input layout, layouts[i].width bits wide.
# NOTE(review): original lines 289, 293 and 296 are missing here. Presumably
# they define `trial` and initialise `part_starts` and `inputs`. Confirm.
286 def __get_case(self
, case_number
):
287 if case_number
< len(self
.special_cases
):
288 return self
.special_cases
[case_number
]
# Hash one bit for each interior partition point.
290 bits
= self
.__hash
(f
"{case_number} trial {trial}",
291 self
.layouts
[0].part_signal_count
)
# Bit 0 (the boundary at position 0) is forced set, and the hashed bits are
# shifted up to the interior positions.
# NOTE(review): `1 << len(part_indexes)` sets a bit one past the last
# position the loop below reads (0..len-1), so the final boundary is not
# forced set. Also, `|=` keeps the unshifted hash bits, so each interior
# position is set if either of two hash bits is set. Confirm both behaviours
# are intended.
292 bits |
= 1 |
(1 << len(self
.layouts
[0].part_indexes
)) |
(bits
<< 1)
294 for i
in range(len(self
.layouts
[0].part_indexes
)):
295 part_starts
.append((bits
& (1 << i
)) != 0)
297 for i
in range(len(self
.layouts
)):
298 inputs
.append(self
.__hash
(f
"{case_number} input {i}",
299 self
.layouts
[i
].width
))
300 return tuple(part_starts
), tuple(inputs
)
302 def __format_case(self
, case
):
303 part_starts
, inputs
= case
304 str_inputs
= [hex(i
) for i
in inputs
]
305 return f
"part_starts={part_starts}, inputs={str_inputs}"
# __setup_case(case_number, case=None)
# Simulator-process generator. It yields assignments that publish
# case_number and apply `case`'s partition points and input values.
# NOTE(review): original line 308 is missing. It presumably guards the next
# statement with `if case is None:`. As visible, `case` would always be
# recomputed from case_number. Confirm.
307 def __setup_case(self
, case_number
, case
=None):
309 case
= self
.__get
_case
(case_number
)
310 yield self
.case_number
.eq(case_number
)
311 part_starts
, inputs
= case
312 part_indexes
= self
.layouts
[0].part_indexes
313 assert len(part_starts
) == len(part_indexes
)
# Only input 0's interior partition points are driven. The constructor ties
# the other inputs' points to them with comb assignments.
314 for i
in range(1, len(part_starts
) - 1):
315 yield self
.inputs
[0].partpoints
[part_indexes
[i
]].eq(part_starts
[i
])
316 for i
in range(len(self
.inputs
)):
317 yield self
.inputs
[i
].sig
.eq(inputs
[i
])
# run_sim(test_case, *, engine=None, base_path="sim_test_out")
# Simulate every test case. For each case the inputs are driven, then two
# checks are made:
# - the output's interior partition points must equal the case's;
# - each lane's output slice is compared with its reference value
#   (presumably only for active lanes, via a guard on missing line 340).
# Waveform files (.vcd and .gtkw) are opened next to
# get_test_path(test_case, base_path).
# NOTE(review): original lines 320, 322, 324, 330, 340, 342-343, 349 and
# 365+ are missing from this extraction. These presumably hold the engine
# branch, the `process` def, the `if active:` guard, any settling delay,
# and the end of the write_vcd call plus the run. Confirm against the full
# file.
319 def run_sim(self
, test_case
, *, engine
=None, base_path
="sim_test_out"):
# Use the default engine unless one is given. The selecting condition is on
# lines missing here.
321 sim
= Simulator(self
.m
)
323 sim
= Simulator(self
.m
, engine
=engine
)
# Generator helper: reads the lane's reference signal and the matching slice
# of the output, and compares them as hex strings so failures are readable.
325 def check_active_lane(lane
):
326 reference
= yield self
.reference_outputs
[lane
]
327 output
= yield self
.test_output
.sig
[
328 lane
.translate_to(self
.test_output_layout
).as_slice()]
329 test_case
.assertEqual(hex(reference
), hex(output
))
# Generator helper: checks the output partition points, then each lane.
331 def check_case(case
):
332 part_starts
, inputs
= case
333 for i
in range(1, len(self
.layouts
[0].part_indexes
) - 1):
334 part_point
= yield self
.test_output
.partpoints
[
335 self
.test_output_layout
.part_indexes
[i
]]
336 test_case
.assertEqual(part_point
, part_starts
[i
])
337 for lane
in self
.layouts
[0].lanes():
338 with test_case
.subTest(lane
=lane
):
339 active
= lane
.is_active(part_starts
)
341 yield from check_active_lane(lane
)
# Body of the simulator process. Its `def` (presumably original line 343) is
# missing here. There is one subTest per case, labelled by number and by the
# formatted case.
344 for case_number
in range(self
.case_count
):
345 with test_case
.subTest(case_number
=str(case_number
)):
346 case
= self
.__get
_case
(case_number
)
347 with test_case
.subTest(case
=self
.__format
_case
(case
)):
348 yield from self
.__setup
_case
(case_number
, case
)
350 yield from check_case(case
)
351 sim
.add_process(process
)
352 path
= get_test_path(test_case
, base_path
)
353 path
.parent
.mkdir(parents
=True, exist_ok
=True)
354 vcd_path
= path
.with_suffix(".vcd")
355 gtkw_path
= path
.with_suffix(".gtkw")
# Signals listed as traces: the case number, input 0's interior partition
# points, every input, the reference outputs and the actual output.
356 traces
= [self
.case_number
]
357 for i
in self
.layouts
[0].part_indexes
[1:-1]:
358 traces
.append(self
.inputs
[0].partpoints
[i
])
359 for inp
in self
.inputs
:
360 traces
.append(inp
.sig
)
361 traces
.extend(self
.reference_outputs
.values())
362 traces
.append(self
.test_output
.sig
)
363 with sim
.write_vcd(vcd_path
.open("wt", encoding
="utf-8"),
364 gtkw_path
.open("wt", encoding
="utf-8"),
# run_formal(test_case, **kwargs)
# Prove the operation against the reference model using formal():
# - input 0's partition points and every input value are left
#   unconstrained with AnyConst;
# - an Assert requires each output partition point to match input 0's;
# - each lane's output slice must equal its reference value while the lane
#   is active.
# Extra keyword arguments are forwarded to formal().
# NOTE(review): original line 391 (the body of the `If` below) is missing.
# It presumably adds the assertion from `a` to self.m.d.comb. Confirm.
368 def run_formal(self
, test_case
, **kwargs
):
369 for part_point
in self
.inputs
[0].partpoints
.values():
370 self
.m
.d
.comb
+= part_point
.eq(AnyConst(1))
371 for i
in range(len(self
.inputs
)):
372 s
= self
.inputs
[i
].sig
373 self
.m
.d
.comb
+= s
.eq(AnyConst(s
.shape()))
# The operation must propagate the partition points unchanged.
374 for i
in range(1, len(self
.layouts
[0].part_indexes
) - 1):
375 in_part_point
= self
.inputs
[0].partpoints
[
376 self
.layouts
[0].part_indexes
[i
]]
377 out_part_point
= self
.test_output
.partpoints
[
378 self
.test_output_layout
.part_indexes
[i
]]
379 self
.m
.d
.comb
+= Assert(in_part_point
== out_part_point
)
# Generator helper: yields one Assert comparing the lane's reference signal
# with the matching slice of the output. Calling it only creates the
# generator; nothing is added until the generator is consumed.
381 def check_active_lane(lane
):
382 reference
= self
.reference_outputs
[lane
]
383 output
= self
.test_output
.sig
[
384 lane
.translate_to(self
.test_output_layout
).as_slice()]
385 yield Assert(reference
== output
)
# NOTE(review): this subTest only wraps construction of the assertions. If
# formal() below runs once after the loop, as its position suggests, a proof
# failure is not attributed to a particular lane.
387 for lane
in self
.layouts
[0].lanes():
388 with test_case
.subTest(lane
=lane
):
389 a
= check_active_lane(lane
)
390 with self
.m
.If(lane
.is_active(self
.inputs
[0].partpoints
)):
392 formal(test_case
, self
.m
, **kwargs
)