# SPDX-License-Identifier: LGPL-3-or-later
# See Notices.txt for copyright information
from enum import Enum
-from typing import (Any, Callable, Dict, Generator, Iterable, List, Mapping,
- Optional, Sequence, Tuple, Union, final, overload)
import shutil
from nmigen.hdl.ast import (AnyConst, Assert, Signal, Value, ValueCastable)
from nmigen.hdl.dsl import Module
from nmutil.get_test_path import get_test_path, _StrPath
-_PartitionedSignalTestable = Callable[[Tuple[PartitionedSignal, ...]],
- PartitionedSignal]
-
-_WidthCastable = Union["Layout", int]
-_LayoutCastable = Union["Layout", Mapping[int, Any], Iterable[int]]
-_ValueCastableType = Union[Value, int, Enum, ValueCastable]
-_FragmentLike = Union[Elaboratable, Fragment]
-
-
-def formal(test_case: unittest.TestCase, hdl: _FragmentLike, *,
- base_path: _StrPath = "formal_test_temp"):
+def formal(test_case, hdl, *, base_path="formal_test_temp"):
hdl = Fragment.get(hdl, platform="formal")
path = get_test_path(test_case, base_path)
shutil.rmtree(path, ignore_errors=True)
@final
class Layout:
- __lane_starts_for_sizes: Dict[int, Dict[int, None]]
+ __lane_starts_for_sizes
"""keys are in sorted order"""
- part_indexes: Tuple[int, ...]
+ part_indexes
"""bit indexes of partition points in sorted order, always includes
`0` and `self.width`"""
@staticmethod
- def cast(layout: _LayoutCastable,
- width: Optional[_WidthCastable] = None) -> "Layout":
+ def cast(layout, width=None):
if isinstance(layout, Layout):
return layout
return Layout(layout, width)
- def __init__(self,
- part_indexes: Union[Mapping[int, Any], Iterable[int]],
- width: Optional[_WidthCastable] = None):
+ def __init__(self, part_indexes, width=None):
part_indexes = set(part_indexes)
for p in part_indexes:
assert isinstance(p, int)
self.__lane_starts_for_sizes[end - start][start] = None
@property
- def width(self) -> int:
+ def width(self):
return self.part_indexes[-1]
@property
- def part_signal_count(self) -> int:
+ def part_signal_count(self):
return max(len(self.part_indexes) - 2, 0)
@staticmethod
- def get_width(width: _WidthCastable) -> int:
+ def get_width(width):
if isinstance(width, Layout):
width = width.width
assert isinstance(width, int)
assert width >= 0
return width
- def partition_points_signals(self, name: Optional[str] = None,
- src_loc_at: int = 0) -> PartitionPoints:
+ def partition_points_signals(self, nameNone,
+ src_loc_at=0):
if name is None:
name = Signal(src_loc_at=1 + src_loc_at).name
- return PartitionPoints({
- i: Signal(name=f"{name}_{i}", src_loc_at=1 + src_loc_at)
- for i in self.part_indexes[1:-1]
- })
+ return PartitionPoints({ i for i in self.part_indexes[1:-1] })
- def __repr__(self) -> str:
+ def __repr__(self):
return f"Layout({self.part_indexes}, width={self.width})"
- def __eq__(self, o: object) -> bool:
+ def __eq__(self, o):
if isinstance(o, Layout):
return self.part_indexes == o.part_indexes
return NotImplemented
- def __hash__(self) -> int:
+ def __hash__(self):
return hash(self.part_indexes)
- def is_lane_valid(self, start: int, size: int) -> bool:
+ def is_lane_valid(self, start, size):
return start in self.__lane_starts_for_sizes.get(size, ())
- def lane_sizes(self) -> Iterable[int]:
+ def lane_sizes(self):
return self.__lane_starts_for_sizes.keys()
- def lane_starts_for_size(self, size: int) -> Iterable[int]:
+ def lane_starts_for_size(self, size):
return self.__lane_starts_for_sizes[size].keys()
- def lanes_for_size(self, size: int) -> Iterable["Lane"]:
+ def lanes_for_size(self, size):
for start in self.lane_starts_for_size(size):
yield Lane(start, size, self)
- def lanes(self) -> Iterable["Lane"]:
+ def lanes(self):
for size in self.lane_sizes():
yield from self.lanes_for_size(size)
- def is_compatible(self, other: _LayoutCastable) -> bool:
+ def is_compatible(self, other):
other = Layout.cast(other)
return len(self.part_indexes) == len(other.part_indexes)
- def translate_lane_to(self, lane: "Lane",
- target_layout: _LayoutCastable) -> "Lane":
+ def translate_lane_to(self, lane, target_layout):
assert lane.layout == self
target_layout = Layout.cast(target_layout)
assert self.is_compatible(target_layout)
@final
class Lane:
- def __init__(self, start: int, size: int, layout: _LayoutCastable):
+ def __init__(self, start, size, layout):
self.layout = Layout.cast(layout)
assert self.layout.is_lane_valid(start, size)
self.start = start
self.size = size
- def __repr__(self) -> str:
+ def __repr__(self):
return (f"Lane(start={self.start}, size={self.size}, "
f"layout={self.layout})")
- def __eq__(self, o: object) -> bool:
+ def __eq__(self, o):
if isinstance(o, Lane):
return self.start == o.start and self.size == o.size \
and self.layout == o.layout
return NotImplemented
- def __hash__(self) -> int:
+ def __hash__(self):
return hash((self.start, self.size, self.layout))
- def as_slice(self) -> slice:
+ def as_slice(self):
return slice(self.start, self.end)
@property
- def end(self) -> int:
+ def end(self):
return self.start + self.size
- def translate_to(self, target_layout: _LayoutCastable) -> "Lane":
+ def translate_to(self, target_layout):
return self.layout.translate_lane_to(self, target_layout)
@overload
- def is_active(self, partition_points: Sequence[bool]) -> bool: ...
+ def is_active(self, partition_points): ...
@overload
- def is_active(self, partition_points: Sequence[_ValueCastableType]
- ) -> Union[Value, bool]: ...
+ def is_active(self, partition_points): ...
@overload
- def is_active(self, partition_points: Mapping[int, bool]) -> bool: ...
+ def is_active(self, partition_points): ...
@overload
- def is_active(self, partition_points: Mapping[int, _ValueCastableType]
- ) -> Union[Value, bool]: ...
+ def is_active(self, partition_points): ...
def is_active(self, partition_points):
- def get_partition_point(index: int, invert: bool):
+ def get_partition_point(index, invert):
if index == 0 or index == len(self.layout.part_indexes) - 1:
return True
if isinstance(partition_points, Sequence):
if invert:
return ~retval
return retval
+
start_index = self.layout.part_indexes.index(self.start)
end_index = self.layout.part_indexes.index(self.end)
retval = get_partition_point(start_index, False) \
& get_partition_point(end_index, False)
for i in range(start_index + 1, end_index):
retval &= get_partition_point(i, True)
- return retval
-
-_PartitionedSignalTestReference = Callable[[Lane, Tuple[Value, ...]],
- _ValueCastableType]
-
-_PartitionedSignalTestCasePartMode = Tuple[bool, ...]
-_PartitionedSignalTestCaseInputs = Tuple[int, ...]
-_PartitionedSignalTestCase = Tuple[_PartitionedSignalTestCasePartMode,
- _PartitionedSignalTestCaseInputs]
+ return retval
class PartitionedSignalTester:
- layouts: List[Layout]
- inputs: List[PartitionedSignal]
-
- def __init__(self,
- m: Module,
- operation: _PartitionedSignalTestable,
- reference: _PartitionedSignalTestReference,
- *layouts: _LayoutCastable,
- src_loc_at: int = 0,
- additional_case_count: int = 30,
- special_cases: Iterable[_PartitionedSignalTestCase] = (),
- seed: str = ""):
+
+ def __init__(self, m, operation, reference, *layouts,
+ src_loc_at=0, additional_case_count=30,
+ special_cases=(), seed=""):
self.m = m
self.operation = operation
self.reference = reference
self.test_output.partpoints, self.test_output.sig.width)
assert self.test_output_layout.is_compatible(self.layouts[0])
self.reference_output_values = {
- lane: Value.cast(reference(lane, tuple(
+ lane, tuple(
inp.sig[lane.translate_to(layout).as_slice()]
- for inp, layout in zip(self.inputs, self.layouts))))
+ for inp, layout in zip(self.inputs, self.layouts))
for lane in self.layouts[0].lanes()
}
self.reference_outputs = {
- lane: Signal(value.shape(),
- name=f"reference_output_{lane.start}_{lane.size}")
+ lane, name=f"reference_output_{lane.start}_{lane.size}")
for lane, value in self.reference_output_values.items()
}
for lane, value in self.reference_output_values.items():
m.d.comb += self.reference_outputs[lane].eq(value)
- def __hash_256(self, v: str) -> int:
+ def __hash_256(self, v):
return int.from_bytes(
sha256(bytes(self.seed + v, encoding='utf-8')).digest(),
byteorder='little'
)
- def __hash(self, v: str, bits: int) -> int:
+ def __hash(self, v, bits):
retval = 0
for i in range(0, bits, 256):
retval <<= 256
retval |= self.__hash_256(f" {v} {i}")
return retval & ((1 << bits) - 1)
- def __get_case(self, case_number: int) -> _PartitionedSignalTestCase:
+ def __get_case(self, case_number):
if case_number < len(self.special_cases):
return self.special_cases[case_number]
trial = 0
for i in range(len(self.layouts)))
return part_starts, inputs
- def __format_case(self, case: _PartitionedSignalTestCase) -> str:
+ def __format_case(self, case):
part_starts, inputs = case
str_inputs = [hex(i) for i in inputs]
return f"part_starts={part_starts}, inputs={str_inputs}"
- def __setup_case(self, case_number: int,
- case: Optional[_PartitionedSignalTestCase] = None
- ) -> Generator[Any, int, None]:
+ def __setup_case(self, case_number, case=None):
if case is None:
case = self.__get_case(case_number)
yield self.case_number.eq(case_number)
for i in range(len(self.inputs)):
yield self.inputs[i].sig.eq(inputs[i])
- def run_sim(self, test_case: unittest.TestCase, *,
+ def run_sim(self, test_case, *,
engine: Optional[str] = None,
base_path: _StrPath = "sim_test_out"):
if engine is None:
else:
sim = Simulator(self.m, engine=engine)
- def check_active_lane(lane: Lane):
+ def check_active_lane(lane):
reference = yield self.reference_outputs[lane]
output = yield self.test_output.sig[
lane.translate_to(self.test_output_layout).as_slice()]
test_case.assertEqual(hex(reference), hex(output))
- def check_case(case: _PartitionedSignalTestCase):
+ def check_case(case):
part_starts, inputs = case
for i in range(1, len(self.layouts[0].part_indexes) - 1):
part_point = yield self.test_output.partpoints[
traces=traces):
sim.run()
- def run_formal(self, test_case: unittest.TestCase, **kwargs):
+ def run_formal(self, test_case, **kwargs):
for part_point in self.inputs[0].partpoints.values():
self.m.d.comb += part_point.eq(AnyConst(1))
for i in range(len(self.inputs)):
self.test_output_layout.part_indexes[i]]
self.m.d.comb += Assert(in_part_point == out_part_point)
- def check_active_lane(lane: Lane) -> Assert:
+ def check_active_lane(lane):
reference = self.reference_outputs[lane]
output = self.test_output.sig[
lane.translate_to(self.test_output_layout).as_slice()]