removing unnecessary type information which makes the code
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 2 Oct 2021 09:39:43 +0000 (10:39 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 2 Oct 2021 09:39:43 +0000 (10:39 +0100)
completely unreadable, longer and more complex

src/ieee754/partitioned_signal_tester.py

index 9d094950b8a05d080f1a3149a2cc737461f3007e..27db86925f9c90d82564bec50ae3359e1a5dd9bc 100644 (file)
@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: LGPL-3-or-later
 # See Notices.txt for copyright information
 from enum import Enum
-from typing import (Any, Callable, Dict, Generator, Iterable, List, Mapping,
-                    Optional, Sequence, Tuple, Union, final, overload)
 import shutil
 from nmigen.hdl.ast import (AnyConst, Assert, Signal, Value, ValueCastable)
 from nmigen.hdl.dsl import Module
@@ -17,17 +15,7 @@ from nmigen.back import rtlil
 from nmutil.get_test_path import get_test_path, _StrPath
 
 
-_PartitionedSignalTestable = Callable[[Tuple[PartitionedSignal, ...]],
-                                      PartitionedSignal]
-
-_WidthCastable = Union["Layout", int]
-_LayoutCastable = Union["Layout", Mapping[int, Any], Iterable[int]]
-_ValueCastableType = Union[Value, int, Enum, ValueCastable]
-_FragmentLike = Union[Elaboratable, Fragment]
-
-
-def formal(test_case: unittest.TestCase, hdl: _FragmentLike, *,
-           base_path: _StrPath = "formal_test_temp"):
+def formal(test_case, hdl, *, base_path="formal_test_temp"):
     hdl = Fragment.get(hdl, platform="formal")
     path = get_test_path(test_case, base_path)
     shutil.rmtree(path, ignore_errors=True)
@@ -65,23 +53,20 @@ def formal(test_case: unittest.TestCase, hdl: _FragmentLike, *,
 
 @final
 class Layout:
-    __lane_starts_for_sizes: Dict[int, Dict[int, None]]
+    __lane_starts_for_sizes
     """keys are in sorted order"""
 
-    part_indexes: Tuple[int, ...]
+    part_indexes
     """bit indexes of partition points in sorted order, always includes
     `0` and `self.width`"""
 
     @staticmethod
-    def cast(layout: _LayoutCastable,
-             width: Optional[_WidthCastable] = None) -> "Layout":
+    def cast(layout, width=None):
         if isinstance(layout, Layout):
             return layout
         return Layout(layout, width)
 
-    def __init__(self,
-                 part_indexes: Union[Mapping[int, Any], Iterable[int]],
-                 width: Optional[_WidthCastable] = None):
+    def __init__(self, part_indexes, width=None):
         part_indexes = set(part_indexes)
         for p in part_indexes:
             assert isinstance(p, int)
@@ -109,64 +94,60 @@ class Layout:
                 self.__lane_starts_for_sizes[end - start][start] = None
 
     @property
-    def width(self) -> int:
+    def width(self):
         return self.part_indexes[-1]
 
     @property
-    def part_signal_count(self) -> int:
+    def part_signal_count(self):
         return max(len(self.part_indexes) - 2, 0)
 
     @staticmethod
-    def get_width(width: _WidthCastable) -> int:
+    def get_width(width):
         if isinstance(width, Layout):
             width = width.width
         assert isinstance(width, int)
         assert width >= 0
         return width
 
-    def partition_points_signals(self, name: Optional[str] = None,
-                                 src_loc_at: int = 0) -> PartitionPoints:
+    def partition_points_signals(self, nameNone,
+                                 src_loc_at=0):
         if name is None:
             name = Signal(src_loc_at=1 + src_loc_at).name
-        return PartitionPoints({
-            i: Signal(name=f"{name}_{i}", src_loc_at=1 + src_loc_at)
-            for i in self.part_indexes[1:-1]
-        })
+        return PartitionPoints({ i for i in self.part_indexes[1:-1] })
 
-    def __repr__(self) -> str:
+    def __repr__(self):
         return f"Layout({self.part_indexes}, width={self.width})"
 
-    def __eq__(self, o: object) -> bool:
+    def __eq__(self, o):
         if isinstance(o, Layout):
             return self.part_indexes == o.part_indexes
         return NotImplemented
 
-    def __hash__(self) -> int:
+    def __hash__(self):
         return hash(self.part_indexes)
 
-    def is_lane_valid(self, start: int, size: int) -> bool:
+    def is_lane_valid(self, start, size):
         return start in self.__lane_starts_for_sizes.get(size, ())
 
-    def lane_sizes(self) -> Iterable[int]:
+    def lane_sizes(self):
         return self.__lane_starts_for_sizes.keys()
 
-    def lane_starts_for_size(self, size: int) -> Iterable[int]:
+    def lane_starts_for_size(self, size):
         return self.__lane_starts_for_sizes[size].keys()
 
-    def lanes_for_size(self, size: int) -> Iterable["Lane"]:
+    def lanes_for_size(self, size):
         for start in self.lane_starts_for_size(size):
             yield Lane(start, size, self)
 
-    def lanes(self) -> Iterable["Lane"]:
+    def lanes(self):
         for size in self.lane_sizes():
             yield from self.lanes_for_size(size)
 
-    def is_compatible(self, other: _LayoutCastable) -> bool:
+    def is_compatible(self, other):
         other = Layout.cast(other)
         return len(self.part_indexes) == len(other.part_indexes)
 
-    def translate_lane_to(self, lane: "Lane",
-                          target_layout: _LayoutCastable) -> "Lane":
+    def translate_lane_to(self, lane, target_layout):
         assert lane.layout == self
         target_layout = Layout.cast(target_layout)
         assert self.is_compatible(target_layout)
@@ -179,51 +160,49 @@ class Layout:
 
 @final
 class Lane:
-    def __init__(self, start: int, size: int, layout: _LayoutCastable):
+    def __init__(self, start, size, layout):
         self.layout = Layout.cast(layout)
         assert self.layout.is_lane_valid(start, size)
         self.start = start
         self.size = size
 
-    def __repr__(self) -> str:
+    def __repr__(self):
         return (f"Lane(start={self.start}, size={self.size}, "
                 f"layout={self.layout})")
 
-    def __eq__(self, o: object) -> bool:
+    def __eq__(self, o):
         if isinstance(o, Lane):
             return self.start == o.start and self.size == o.size \
                 and self.layout == o.layout
         return NotImplemented
 
-    def __hash__(self) -> int:
+    def __hash__(self):
         return hash((self.start, self.size, self.layout))
 
-    def as_slice(self) -> slice:
+    def as_slice(self):
         return slice(self.start, self.end)
 
     @property
-    def end(self) -> int:
+    def end(self):
         return self.start + self.size
 
-    def translate_to(self, target_layout: _LayoutCastable) -> "Lane":
+    def translate_to(self, target_layout):
         return self.layout.translate_lane_to(self, target_layout)
 
     @overload
-    def is_active(self, partition_points: Sequence[bool]) -> bool: ...
+    def is_active(self, partition_points): ...
 
     @overload
-    def is_active(self, partition_points: Sequence[_ValueCastableType]
-                  ) -> Union[Value, bool]: ...
+    def is_active(self, partition_points): ...
 
     @overload
-    def is_active(self, partition_points: Mapping[int, bool]) -> bool: ...
+    def is_active(self, partition_points): ...
 
     @overload
-    def is_active(self, partition_points: Mapping[int, _ValueCastableType]
-                  ) -> Union[Value, bool]: ...
+    def is_active(self, partition_points): ...
 
     def is_active(self, partition_points):
-        def get_partition_point(index: int, invert: bool):
+        def get_partition_point(index, invert):
             if index == 0 or index == len(self.layout.part_indexes) - 1:
                 return True
             if isinstance(partition_points, Sequence):
@@ -238,37 +217,22 @@ class Lane:
             if invert:
                 return ~retval
             return retval
+
         start_index = self.layout.part_indexes.index(self.start)
         end_index = self.layout.part_indexes.index(self.end)
         retval = get_partition_point(start_index, False) \
             & get_partition_point(end_index, False)
         for i in range(start_index + 1, end_index):
             retval &= get_partition_point(i, True)
-        return retval
 
-
-_PartitionedSignalTestReference = Callable[[Lane, Tuple[Value, ...]],
-                                           _ValueCastableType]
-
-_PartitionedSignalTestCasePartMode = Tuple[bool, ...]
-_PartitionedSignalTestCaseInputs = Tuple[int, ...]
-_PartitionedSignalTestCase = Tuple[_PartitionedSignalTestCasePartMode,
-                                   _PartitionedSignalTestCaseInputs]
+        return retval
 
 
 class PartitionedSignalTester:
-    layouts: List[Layout]
-    inputs: List[PartitionedSignal]
-
-    def __init__(self,
-                 m: Module,
-                 operation: _PartitionedSignalTestable,
-                 reference: _PartitionedSignalTestReference,
-                 *layouts: _LayoutCastable,
-                 src_loc_at: int = 0,
-                 additional_case_count: int = 30,
-                 special_cases: Iterable[_PartitionedSignalTestCase] = (),
-                 seed: str = ""):
+
+    def __init__(self, m, operation, reference, *layouts,
+                 src_loc_at=0, additional_case_count=30,
+                 special_cases=(), seed=""):
         self.m = m
         self.operation = operation
         self.reference = reference
@@ -305,33 +269,32 @@ class PartitionedSignalTester:
             self.test_output.partpoints, self.test_output.sig.width)
         assert self.test_output_layout.is_compatible(self.layouts[0])
         self.reference_output_values = {
-            lane: Value.cast(reference(lane, tuple(
+            lane, tuple(
                 inp.sig[lane.translate_to(layout).as_slice()]
-                for inp, layout in zip(self.inputs, self.layouts))))
+                for inp, layout in zip(self.inputs, self.layouts))
             for lane in self.layouts[0].lanes()
         }
         self.reference_outputs = {
-            lane: Signal(value.shape(),
-                         name=f"reference_output_{lane.start}_{lane.size}")
+            lane, name=f"reference_output_{lane.start}_{lane.size}")
             for lane, value in self.reference_output_values.items()
         }
         for lane, value in self.reference_output_values.items():
             m.d.comb += self.reference_outputs[lane].eq(value)
 
-    def __hash_256(self, v: str) -> int:
+    def __hash_256(self, v):
         return int.from_bytes(
             sha256(bytes(self.seed + v, encoding='utf-8')).digest(),
             byteorder='little'
         )
 
-    def __hash(self, v: str, bits: int) -> int:
+    def __hash(self, v, bits):
         retval = 0
         for i in range(0, bits, 256):
             retval <<= 256
             retval |= self.__hash_256(f" {v} {i}")
         return retval & ((1 << bits) - 1)
 
-    def __get_case(self, case_number: int) -> _PartitionedSignalTestCase:
+    def __get_case(self, case_number):
         if case_number < len(self.special_cases):
             return self.special_cases[case_number]
         trial = 0
@@ -346,14 +309,12 @@ class PartitionedSignalTester:
                        for i in range(len(self.layouts)))
         return part_starts, inputs
 
-    def __format_case(self, case: _PartitionedSignalTestCase) -> str:
+    def __format_case(self, case):
         part_starts, inputs = case
         str_inputs = [hex(i) for i in inputs]
         return f"part_starts={part_starts}, inputs={str_inputs}"
 
-    def __setup_case(self, case_number: int,
-                     case: Optional[_PartitionedSignalTestCase] = None
-                     ) -> Generator[Any, int, None]:
+    def __setup_case(self, case_number, case=None):
         if case is None:
             case = self.__get_case(case_number)
         yield self.case_number.eq(case_number)
@@ -365,7 +326,7 @@ class PartitionedSignalTester:
         for i in range(len(self.inputs)):
             yield self.inputs[i].sig.eq(inputs[i])
 
-    def run_sim(self, test_case: unittest.TestCase, *,
+    def run_sim(self, test_case, *,
                 engine: Optional[str] = None,
                 base_path: _StrPath = "sim_test_out"):
         if engine is None:
@@ -373,13 +334,13 @@ class PartitionedSignalTester:
         else:
             sim = Simulator(self.m, engine=engine)
 
-        def check_active_lane(lane: Lane):
+        def check_active_lane(lane):
             reference = yield self.reference_outputs[lane]
             output = yield self.test_output.sig[
                 lane.translate_to(self.test_output_layout).as_slice()]
             test_case.assertEqual(hex(reference), hex(output))
 
-        def check_case(case: _PartitionedSignalTestCase):
+        def check_case(case):
             part_starts, inputs = case
             for i in range(1, len(self.layouts[0].part_indexes) - 1):
                 part_point = yield self.test_output.partpoints[
@@ -416,7 +377,7 @@ class PartitionedSignalTester:
                            traces=traces):
             sim.run()
 
-    def run_formal(self, test_case: unittest.TestCase, **kwargs):
+    def run_formal(self, test_case, **kwargs):
         for part_point in self.inputs[0].partpoints.values():
             self.m.d.comb += part_point.eq(AnyConst(1))
         for i in range(len(self.inputs)):
@@ -429,7 +390,7 @@ class PartitionedSignalTester:
                 self.test_output_layout.part_indexes[i]]
             self.m.d.comb += Assert(in_part_point == out_part_point)
 
-        def check_active_lane(lane: Lane) -> Assert:
+        def check_active_lane(lane):
             reference = self.reference_outputs[lane]
             output = self.test_output.sig[
                 lane.translate_to(self.test_output_layout).as_slice()]