From 0cdf4be4df5c0fbae476442c1a91b0e8140e2104 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Sat, 2 Oct 2021 10:39:43 +0100 Subject: [PATCH] removing unnecessary type information which makes the code completely unreadable, longer and more complex --- src/ieee754/partitioned_signal_tester.py | 143 +++++++++-------------- 1 file changed, 52 insertions(+), 91 deletions(-) diff --git a/src/ieee754/partitioned_signal_tester.py b/src/ieee754/partitioned_signal_tester.py index 9d094950..27db8692 100644 --- a/src/ieee754/partitioned_signal_tester.py +++ b/src/ieee754/partitioned_signal_tester.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: LGPL-3-or-later # See Notices.txt for copyright information from enum import Enum -from typing import (Any, Callable, Dict, Generator, Iterable, List, Mapping, - Optional, Sequence, Tuple, Union, final, overload) import shutil from nmigen.hdl.ast import (AnyConst, Assert, Signal, Value, ValueCastable) from nmigen.hdl.dsl import Module @@ -17,17 +15,7 @@ from nmigen.back import rtlil from nmutil.get_test_path import get_test_path, _StrPath -_PartitionedSignalTestable = Callable[[Tuple[PartitionedSignal, ...]], - PartitionedSignal] - -_WidthCastable = Union["Layout", int] -_LayoutCastable = Union["Layout", Mapping[int, Any], Iterable[int]] -_ValueCastableType = Union[Value, int, Enum, ValueCastable] -_FragmentLike = Union[Elaboratable, Fragment] - - -def formal(test_case: unittest.TestCase, hdl: _FragmentLike, *, - base_path: _StrPath = "formal_test_temp"): +def formal(test_case, hdl, *, base_path="formal_test_temp"): hdl = Fragment.get(hdl, platform="formal") path = get_test_path(test_case, base_path) shutil.rmtree(path, ignore_errors=True) @@ -65,23 +53,20 @@ def formal(test_case: unittest.TestCase, hdl: _FragmentLike, *, @final class Layout: - __lane_starts_for_sizes: Dict[int, Dict[int, None]] + __lane_starts_for_sizes """keys are in sorted order""" - part_indexes: Tuple[int, ...] + part_indexes """bit indexes of partition points in sorted order, always includes `0` and `self.width`""" @staticmethod - def cast(layout: _LayoutCastable, - width: Optional[_WidthCastable] = None) -> "Layout": + def cast(layout, width=None): if isinstance(layout, Layout): return layout return Layout(layout, width) - def __init__(self, - part_indexes: Union[Mapping[int, Any], Iterable[int]], - width: Optional[_WidthCastable] = None): + def __init__(self, part_indexes, width=None): part_indexes = set(part_indexes) for p in part_indexes: assert isinstance(p, int) @@ -109,64 +94,60 @@ class Layout: self.__lane_starts_for_sizes[end - start][start] = None @property - def width(self) -> int: + def width(self): return self.part_indexes[-1] @property - def part_signal_count(self) -> int: + def part_signal_count(self): return max(len(self.part_indexes) - 2, 0) @staticmethod - def get_width(width: _WidthCastable) -> int: + def get_width(width): if isinstance(width, Layout): width = width.width assert isinstance(width, int) assert width >= 0 return width - def partition_points_signals(self, name: Optional[str] = None, - src_loc_at: int = 0) -> PartitionPoints: + def partition_points_signals(self, nameNone, + src_loc_at=0): if name is None: name = Signal(src_loc_at=1 + src_loc_at).name - return PartitionPoints({ - i: Signal(name=f"{name}_{i}", src_loc_at=1 + src_loc_at) - for i in self.part_indexes[1:-1] - }) + return PartitionPoints({ i for i in self.part_indexes[1:-1] }) - def __repr__(self) -> str: + def __repr__(self): return f"Layout({self.part_indexes}, width={self.width})" - def __eq__(self, o: object) -> bool: + def __eq__(self, o): if isinstance(o, Layout): return self.part_indexes == o.part_indexes return NotImplemented - def __hash__(self) -> int: + def __hash__(self): return hash(self.part_indexes) - def is_lane_valid(self, start: int, size: int) -> bool: + def is_lane_valid(self, start, size): return start in self.__lane_starts_for_sizes.get(size, ()) - def lane_sizes(self) -> Iterable[int]: + def lane_sizes(self): return self.__lane_starts_for_sizes.keys() - def lane_starts_for_size(self, size: int) -> Iterable[int]: + def lane_starts_for_size(self, size): return self.__lane_starts_for_sizes[size].keys() - def lanes_for_size(self, size: int) -> Iterable["Lane"]: + def lanes_for_size(self, size): for start in self.lane_starts_for_size(size): yield Lane(start, size, self) - def lanes(self) -> Iterable["Lane"]: + def lanes(self): for size in self.lane_sizes(): yield from self.lanes_for_size(size) - def is_compatible(self, other: _LayoutCastable) -> bool: + def is_compatible(self, other): other = Layout.cast(other) return len(self.part_indexes) == len(other.part_indexes) - def translate_lane_to(self, lane: "Lane", - target_layout: _LayoutCastable) -> "Lane": + def translate_lane_to(self, lane, target_layout): assert lane.layout == self target_layout = Layout.cast(target_layout) assert self.is_compatible(target_layout) @@ -179,51 +160,49 @@ class Layout: @final class Lane: - def __init__(self, start: int, size: int, layout: _LayoutCastable): + def __init__(self, start, size, layout): self.layout = Layout.cast(layout) assert self.layout.is_lane_valid(start, size) self.start = start self.size = size - def __repr__(self) -> str: + def __repr__(self): return (f"Lane(start={self.start}, size={self.size}, " f"layout={self.layout})") - def __eq__(self, o: object) -> bool: + def __eq__(self, o): if isinstance(o, Lane): return self.start == o.start and self.size == o.size \ and self.layout == o.layout return NotImplemented - def __hash__(self) -> int: + def __hash__(self): return hash((self.start, self.size, self.layout)) - def as_slice(self) -> slice: + def as_slice(self): return slice(self.start, self.end) @property - def end(self) -> int: + def end(self): return self.start + self.size - def translate_to(self, target_layout: _LayoutCastable) -> "Lane": + def translate_to(self, target_layout): return self.layout.translate_lane_to(self, target_layout) @overload - def is_active(self, partition_points: Sequence[bool]) -> bool: ... + def is_active(self, partition_points): ... @overload - def is_active(self, partition_points: Sequence[_ValueCastableType] - ) -> Union[Value, bool]: ... + def is_active(self, partition_points): ... @overload - def is_active(self, partition_points: Mapping[int, bool]) -> bool: ... + def is_active(self, partition_points): ... @overload - def is_active(self, partition_points: Mapping[int, _ValueCastableType] - ) -> Union[Value, bool]: ... + def is_active(self, partition_points): ... def is_active(self, partition_points): - def get_partition_point(index: int, invert: bool): + def get_partition_point(index, invert): if index == 0 or index == len(self.layout.part_indexes) - 1: return True if isinstance(partition_points, Sequence): @@ -238,37 +217,22 @@ class Lane: if invert: return ~retval return retval + start_index = self.layout.part_indexes.index(self.start) end_index = self.layout.part_indexes.index(self.end) retval = get_partition_point(start_index, False) \ & get_partition_point(end_index, False) for i in range(start_index + 1, end_index): retval &= get_partition_point(i, True) - return retval - -_PartitionedSignalTestReference = Callable[[Lane, Tuple[Value, ...]], - _ValueCastableType] - -_PartitionedSignalTestCasePartMode = Tuple[bool, ...] -_PartitionedSignalTestCaseInputs = Tuple[int, ...] -_PartitionedSignalTestCase = Tuple[_PartitionedSignalTestCasePartMode, - _PartitionedSignalTestCaseInputs] + return retval class PartitionedSignalTester: - layouts: List[Layout] - inputs: List[PartitionedSignal] - - def __init__(self, - m: Module, - operation: _PartitionedSignalTestable, - reference: _PartitionedSignalTestReference, - *layouts: _LayoutCastable, - src_loc_at: int = 0, - additional_case_count: int = 30, - special_cases: Iterable[_PartitionedSignalTestCase] = (), - seed: str = ""): + + def __init__(self, m, operation, reference, *layouts, + src_loc_at=0, additional_case_count=30, + special_cases=(), seed=""): self.m = m self.operation = operation self.reference = reference @@ -305,33 +269,32 @@ class PartitionedSignalTester: self.test_output.partpoints, self.test_output.sig.width) assert self.test_output_layout.is_compatible(self.layouts[0]) self.reference_output_values = { - lane: Value.cast(reference(lane, tuple( + lane, tuple( inp.sig[lane.translate_to(layout).as_slice()] - for inp, layout in zip(self.inputs, self.layouts)))) + for inp, layout in zip(self.inputs, self.layouts)) for lane in self.layouts[0].lanes() } self.reference_outputs = { - lane: Signal(value.shape(), - name=f"reference_output_{lane.start}_{lane.size}") + lane, name=f"reference_output_{lane.start}_{lane.size}") for lane, value in self.reference_output_values.items() } for lane, value in self.reference_output_values.items(): m.d.comb += self.reference_outputs[lane].eq(value) - def __hash_256(self, v: str) -> int: + def __hash_256(self, v): return int.from_bytes( sha256(bytes(self.seed + v, encoding='utf-8')).digest(), byteorder='little' ) - def __hash(self, v: str, bits: int) -> int: + def __hash(self, v, bits): retval = 0 for i in range(0, bits, 256): retval <<= 256 retval |= self.__hash_256(f" {v} {i}") return retval & ((1 << bits) - 1) - def __get_case(self, case_number: int) -> _PartitionedSignalTestCase: + def __get_case(self, case_number): if case_number < len(self.special_cases): return self.special_cases[case_number] trial = 0 @@ -346,14 +309,12 @@ class PartitionedSignalTester: for i in range(len(self.layouts))) return part_starts, inputs - def __format_case(self, case: _PartitionedSignalTestCase) -> str: + def __format_case(self, case): part_starts, inputs = case str_inputs = [hex(i) for i in inputs] return f"part_starts={part_starts}, inputs={str_inputs}" - def __setup_case(self, case_number: int, - case: Optional[_PartitionedSignalTestCase] = None - ) -> Generator[Any, int, None]: + def __setup_case(self, case_number, case=None): if case is None: case = self.__get_case(case_number) yield self.case_number.eq(case_number) @@ -365,7 +326,7 @@ class PartitionedSignalTester: for i in range(len(self.inputs)): yield self.inputs[i].sig.eq(inputs[i]) - def run_sim(self, test_case: unittest.TestCase, *, + def run_sim(self, test_case, *, engine: Optional[str] = None, base_path: _StrPath = "sim_test_out"): if engine is None: @@ -373,13 +334,13 @@ class PartitionedSignalTester: else: sim = Simulator(self.m, engine=engine) - def check_active_lane(lane: Lane): + def check_active_lane(lane): reference = yield self.reference_outputs[lane] output = yield self.test_output.sig[ lane.translate_to(self.test_output_layout).as_slice()] test_case.assertEqual(hex(reference), hex(output)) - def check_case(case: _PartitionedSignalTestCase): + def check_case(case): part_starts, inputs = case for i in range(1, len(self.layouts[0].part_indexes) - 1): part_point = yield self.test_output.partpoints[ @@ -416,7 +377,7 @@ class PartitionedSignalTester: traces=traces): sim.run() - def run_formal(self, test_case: unittest.TestCase, **kwargs): + def run_formal(self, test_case, **kwargs): for part_point in self.inputs[0].partpoints.values(): self.m.d.comb += part_point.eq(AnyConst(1)) for i in range(len(self.inputs)): @@ -429,7 +390,7 @@ class PartitionedSignalTester: self.test_output_layout.part_indexes[i]] self.m.d.comb += Assert(in_part_point == out_part_point) - def check_active_lane(lane: Lane) -> Assert: + def check_active_lane(lane): reference = self.reference_outputs[lane] output = self.test_output.sig[ lane.translate_to(self.test_output_layout).as_slice()] -- 2.30.2