From 975a2eb2f4fcdee44e3fc46cc20112f99097ca29 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Sat, 17 Aug 2019 07:56:20 +0100 Subject: [PATCH] move typing to multiplier.pyi --- src/ieee754/part_mul_add/multiply.py | 58 +++++-------- src/ieee754/part_mul_add/multiply.pyi | 86 +++++++++++++++++++ .../part_mul_add/test/test_multiply.py | 7 +- 3 files changed, 111 insertions(+), 40 deletions(-) create mode 100644 src/ieee754/part_mul_add/multiply.pyi diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 5902967c..b62f7373 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -5,14 +5,10 @@ from nmigen import Signal, Module, Value, Elaboratable, Cat, C, Mux, Repl from nmigen.hdl.ast import Assign from abc import ABCMeta, abstractmethod -from typing import Any, NewType, Union, List, Dict, Iterable, Mapping, Optional -from typing_extensions import final from nmigen.cli import main -PartitionPointsIn = Mapping[int, Union[Value, bool, int]] - -class PartitionPoints(Dict[int, Value]): +class PartitionPoints(dict): """Partition points and corresponding ``Value``s. The points at where an ALU is partitioned along with ``Value``s that @@ -38,7 +34,7 @@ class PartitionPoints(Dict[int, Value]): * bits 10 <= ``i`` < 16 """ - def __init__(self, partition_points: Optional[PartitionPointsIn] = None): + def __init__(self, partition_points=None): """Create a new ``PartitionPoints``. :param partition_points: the input partition points to values mapping. @@ -52,9 +48,7 @@ class PartitionPoints(Dict[int, Value]): raise ValueError("point must be a non-negative integer") self[point] = Value.wrap(enabled) - def like(self, - name: Optional[str] = None, - src_loc_at: int = 0) -> 'PartitionPoints': + def like(self, name=None, src_loc_at=0): """Create a new ``PartitionPoints`` with ``Signal``s for all values. :param name: the base name for the new ``Signal``s. @@ -66,14 +60,14 @@ class PartitionPoints(Dict[int, Value]): retval[point] = Signal(enabled.shape(), name=f"{name}_{point}") return retval - def eq(self, rhs: 'PartitionPoints') -> Iterable[Assign]: + def eq(self, rhs): """Assign ``PartitionPoints`` using ``Signal.eq``.""" if set(self.keys()) != set(rhs.keys()): raise ValueError("incompatible point set") for point, enabled in self.items(): yield enabled.eq(rhs[point]) - def as_mask(self, width: int) -> Value: + def as_mask(self, width): """Create a bit-mask from `self`. Each bit in the returned mask is clear only if the partition point at @@ -81,7 +75,6 @@ class PartitionPoints(Dict[int, Value]): :param width: the bit width of the resulting mask """ - bits: List[Union[Value, bool]] bits = [] for i in range(width): if i in self: @@ -90,7 +83,7 @@ class PartitionPoints(Dict[int, Value]): bits.append(True) return Cat(*bits) - def get_max_partition_count(self, width: int) -> int: + def get_max_partition_count(self, width): """Get the maximum number of partitions. Gets the number of partitions when all partition points are enabled. @@ -101,7 +94,7 @@ class PartitionPoints(Dict[int, Value]): retval += 1 return retval - def fits_in_width(self, width: int) -> bool: + def fits_in_width(self, width): """Check if all partition points are smaller than `width`.""" for point in self.keys(): if point >= width: @@ -109,7 +102,6 @@ class PartitionPoints(Dict[int, Value]): return True -@final class FullAdder(Elaboratable): """Full Adder. @@ -120,7 +112,7 @@ class FullAdder(Elaboratable): :attribute carry: the carry output """ - def __init__(self, width: int): + def __init__(self, width): """Create a ``FullAdder``. :param width: the bit width of the input and output @@ -131,7 +123,7 @@ class FullAdder(Elaboratable): self.sum = Signal(width) self.carry = Signal(width) - def elaborate(self, platform: Any) -> Module: + def elaborate(self, platform): """Elaborate this module.""" m = Module() m.d.comb += self.sum.eq(self.in0 ^ self.in1 ^ self.in2) @@ -141,7 +133,6 @@ class FullAdder(Elaboratable): return m -@final class PartitionedAdder(Elaboratable): """Partitioned Adder. @@ -153,7 +144,7 @@ class PartitionedAdder(Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, width: int, partition_points: PartitionPointsIn): + def __init__(self, width, partition_points): """Create a ``PartitionedAdder``. :param width: the bit width of the input and output @@ -176,7 +167,7 @@ class PartitionedAdder(Elaboratable): self._expanded_b = Signal(expanded_width) self._expanded_output = Signal(expanded_width) - def elaborate(self, platform: Any) -> Module: + def elaborate(self, platform): """Elaborate this module.""" m = Module() expanded_index = 0 @@ -203,7 +194,6 @@ class PartitionedAdder(Elaboratable): FULL_ADDER_INPUT_COUNT = 3 -@final class AddReduce(Elaboratable): """Add list of numbers together. @@ -216,11 +206,7 @@ class AddReduce(Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, - inputs: Iterable[Signal], - output_width: int, - register_levels: Iterable[int], - partition_points: PartitionPointsIn): + def __init__(self, inputs, output_width, register_levels, partition_points): """Create an ``AddReduce``. :param inputs: input ``Signal``s to be summed. @@ -246,7 +232,7 @@ class AddReduce(Elaboratable): "not enough adder levels for specified register levels") @staticmethod - def get_max_level(input_count: int) -> int: + def get_max_level(input_count): """Get the maximum level. All ``register_levels`` must be less than or equal to the maximum @@ -261,20 +247,20 @@ class AddReduce(Elaboratable): input_count += 2 * len(groups) retval += 1 - def next_register_levels(self) -> Iterable[int]: + def next_register_levels(self): """``Iterable`` of ``register_levels`` for next recursive level.""" for level in self.register_levels: if level > 0: yield level - 1 @staticmethod - def full_adder_groups(input_count: int) -> range: + def full_adder_groups(input_count): """Get ``inputs`` indices for which a full adder should be built.""" return range(0, input_count - FULL_ADDER_INPUT_COUNT + 1, FULL_ADDER_INPUT_COUNT) - def elaborate(self, platform: Any) -> Module: + def elaborate(self, platform): """Elaborate this module.""" m = Module() @@ -314,7 +300,7 @@ class AddReduce(Elaboratable): intermediate_terms: List[Signal] intermediate_terms = [] - def add_intermediate_term(value: Value) -> None: + def add_intermediate_term(value): intermediate_term = Signal( len(self.output), name=f"intermediate_terms[{len(intermediate_terms)}]") @@ -390,7 +376,7 @@ class Mul8_16_32_64(Elaboratable): instruction. """ - def __init__(self, register_levels: Iterable[int] = ()): + def __init__(self, register_levels= ()): self.part_pts = PartitionPoints() for i in range(8, 64, 8): self.part_pts[i] = Signal(name=f"part_pts_{i}") @@ -446,13 +432,13 @@ class Mul8_16_32_64(Elaboratable): self._not_b_term_64 = Signal(128) self._neg_lsb_b_term_64 = Signal(128) - def _part_byte(self, index: int) -> Value: + def _part_byte(self, index): if index == -1 or index == 7: return C(True, 1) assert index >= 0 and index < 8 return self.part_pts[index * 8 + 8] - def elaborate(self, platform: Any) -> Module: + def elaborate(self, platform): m = Module() for i in range(len(self.part_ops)): @@ -489,9 +475,7 @@ class Mul8_16_32_64(Elaboratable): terms = [] - def add_term(value: Value, - shift: int = 0, - enabled: Optional[Value] = None) -> None: + def add_term(value, shift=0, enabled=None): term = Signal(128) terms.append(term) if enabled is not None: diff --git a/src/ieee754/part_mul_add/multiply.pyi b/src/ieee754/part_mul_add/multiply.pyi new file mode 100644 index 00000000..e96c5974 --- /dev/null +++ b/src/ieee754/part_mul_add/multiply.pyi @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: LGPL-2.1-or-later +# See Notices.txt for copyright information +"""Integer Multiplication.""" + +from typing import Any, NewType, Union, List, Dict, Iterable, Mapping, Optional +from typing_extensions import final + +PartitionPointsIn = Mapping[int, Union[Value, bool, int]] + +class PartitionPoints(Dict[int, Value]): + def __init__(self, partition_points: Optional[PartitionPointsIn] = None): + ... + + def like(self, + name: Optional[str] = None, + src_loc_at: int = 0) -> 'PartitionPoints': + ... + + def eq(self, rhs: 'PartitionPoints') -> Iterable[Assign]: + ... + + def as_mask(self, width: int) -> Value: + bits: List[Union[Value, bool]] + + def get_max_partition_count(self, width: int) -> int: + ... + + def fits_in_width(self, width: int) -> bool: + ... + + +@final +class FullAdder(Elaboratable): + def __init__(self, width: int): + ... + + def elaborate(self, platform: Any) -> Module: + ... + + +@final +class PartitionedAdder(Elaboratable): + def __init__(self, width: int, partition_points: PartitionPointsIn): + ... + + def elaborate(self, platform: Any) -> Module: + ... + + +@final +class AddReduce(Elaboratable): + def __init__(self, + ... + + @staticmethod + def get_max_level(input_count: int) -> int: + ... + + def next_register_levels(self) -> Iterable[int]: + ... + + @staticmethod + def full_adder_groups(input_count: int) -> range: + ... + + def elaborate(self, platform: Any) -> Module: + ... + + def add_intermediate_term(value: Value) -> None: + ... + + +class Mul8_16_32_64(Elaboratable): + def __init__(self, register_levels: Iterable[int] = ()): + ... + + def _part_byte(self, index: int) -> Value: + ... + + def elaborate(self, platform: Any) -> Module: + ... + + def add_term(value: Value, + shift: int = 0, + enabled: Optional[Value] = None) -> None: + ... diff --git a/src/ieee754/part_mul_add/test/test_multiply.py b/src/ieee754/part_mul_add/test/test_multiply.py index ec833073..ef7f5cd7 100644 --- a/src/ieee754/part_mul_add/test/test_multiply.py +++ b/src/ieee754/part_mul_add/test/test_multiply.py @@ -2,9 +2,10 @@ # SPDX-License-Identifier: LGPL-2.1-or-later # See Notices.txt for copyright information -from src.multiply import PartitionPoints, PartitionedAdder, AddReduce, \ - Mul8_16_32_64, OP_MUL_LOW, OP_MUL_SIGNED_HIGH, \ - OP_MUL_SIGNED_UNSIGNED_HIGH, OP_MUL_UNSIGNED_HIGH +from ieee754.part_mul_add.multiply import \ + (PartitionPoints, PartitionedAdder, AddReduce, + Mul8_16_32_64, OP_MUL_LOW, OP_MUL_SIGNED_HIGH, + OP_MUL_SIGNED_UNSIGNED_HIGH, OP_MUL_UNSIGNED_HIGH) from nmigen import Signal, Module from nmigen.back.pysim import Simulator, Delay, Tick, Passive from nmigen.hdl.ast import Assign, Value -- 2.30.2