hdl.ast: implement ValueCastable.
[nmigen.git] / nmigen / hdl / ast.py
index d26c3c3d9c154f03dd4a2d0ba941c1fd07f969ae..0b86e7153da0cac1dd5d62c83cca15caaf7357e2 100644 (file)
@@ -1,13 +1,16 @@
 from abc import ABCMeta, abstractmethod
 import traceback
+import sys
 import warnings
 import typing
+import functools
 from collections import OrderedDict
 from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence
 from enum import Enum
 
 from .. import tracer
 from .._utils import *
+from .._unused import *
 
 
 __all__ = [
@@ -15,11 +18,11 @@ __all__ = [
     "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
     "Array", "ArrayProxy",
     "Signal", "ClockSignal", "ResetSignal",
-    "UserValue",
+    "UserValue", "ValueCastable",
     "Sample", "Past", "Stable", "Rose", "Fell", "Initial",
-    "Statement", "Assign", "Assert", "Assume", "Cover", "Switch",
-    "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict",
-    "SignalSet",
+    "Statement", "Switch",
+    "Property", "Assign", "Assert", "Assume", "Cover",
+    "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict", "SignalSet",
 ]
 
 
@@ -31,7 +34,7 @@ class DUID:
         DUID.__next_uid += 1
 
 
-class Shape(typing.NamedTuple):
+class Shape:
     """Bit width and signedness of a value.
 
     A ``Shape`` can be constructed using:
@@ -54,8 +57,15 @@ class Shape(typing.NamedTuple):
     signed : bool
         If ``False``, the value is unsigned. If ``True``, the value is signed two's complement.
     """
-    width:  int  = 1
-    signed: bool = False
+    def __init__(self, width=1, signed=False):
+        if not isinstance(width, int) or width < 0:
+            raise TypeError("Width must be a non-negative integer, not {!r}"
+                            .format(width))
+        self.width = width
+        self.signed = signed
+
+    def __iter__(self):
+        return iter((self.width, self.signed))
 
     @staticmethod
     def cast(obj, *, src_loc_at=0):
@@ -88,13 +98,26 @@ class Shape(typing.NamedTuple):
             return Shape(width, signed)
         raise TypeError("Object {!r} cannot be used as value shape".format(obj))
 
+    def __repr__(self):
+        if self.signed:
+            return "signed({})".format(self.width)
+        else:
+            return "unsigned({})".format(self.width)
 
-# TODO: use dataclasses instead of this hack
-def _Shape___init__(self, width=1, signed=False):
-    if not isinstance(width, int) or width < 0:
-        raise TypeError("Width must be a non-negative integer, not {!r}"
-                        .format(width))
-Shape.__init__ = _Shape___init__
+    def __eq__(self, other):
+        if isinstance(other, tuple) and len(other) == 2:
+            width, signed = other
+            if isinstance(width, int) and isinstance(signed, bool):
+                return self.width == width and self.signed == signed
+            else:
+                raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, "
+                        "not {!r}"
+                        .format(other))
+        if not isinstance(other, Shape):
+            raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, "
+                    "not {!r}"
+                    .format(other))
+        return self.width == other.width and self.signed == other.signed
 
 
 def unsigned(width):
@@ -121,6 +144,8 @@ class Value(metaclass=ABCMeta):
             return Const(obj)
         if isinstance(obj, Enum):
             return Const(obj.value, Shape.cast(type(obj)))
+        if isinstance(obj, ValueCastable):
+            return obj.as_value()
         raise TypeError("Object {!r} cannot be converted to an nMigen value".format(obj))
 
     def __init__(self, *, src_loc_at=0):
@@ -128,7 +153,7 @@ class Value(metaclass=ABCMeta):
         self.src_loc = tracer.get_src_loc(1 + src_loc_at)
 
     def __bool__(self):
-        raise TypeError("Attempted to convert nMigen value to boolean")
+        raise TypeError("Attempted to convert nMigen value to Python boolean")
 
     def __invert__(self):
         return Operator("~", [self])
@@ -171,14 +196,28 @@ class Value(metaclass=ABCMeta):
         self.__check_divisor()
         return Operator("//", [other, self])
 
+    def __check_shamt(self):
+        width, signed = self.shape()
+        if signed:
+            # Neither Python nor HDLs implement shifts by negative values; prohibit any shifts
+            # by a signed value to make sure the shift amount can always be interpreted as
+            # an unsigned value.
+            raise TypeError("Shift amount must be unsigned")
     def __lshift__(self, other):
+        other = Value.cast(other)
+        other.__check_shamt()
         return Operator("<<", [self, other])
     def __rlshift__(self, other):
+        self.__check_shamt()
         return Operator("<<", [other, self])
     def __rshift__(self, other):
+        other = Value.cast(other)
+        other.__check_shamt()
         return Operator(">>", [self, other])
     def __rrshift__(self, other):
+        self.__check_shamt()
         return Operator(">>", [other, self])
+
     def __and__(self, other):
         return Operator("&", [self, other])
     def __rand__(self, other):
@@ -205,6 +244,13 @@ class Value(metaclass=ABCMeta):
     def __ge__(self, other):
         return Operator(">=", [self, other])
 
+    def __abs__(self):
+        width, signed = self.shape()
+        if signed:
+            return Mux(self >= 0, self, -self)
+        else:
+            return self
+
     def __len__(self):
         return self.shape().width
 
@@ -212,7 +258,7 @@ class Value(metaclass=ABCMeta):
         n = len(self)
         if isinstance(key, int):
             if key not in range(-n, n):
-                raise IndexError("Cannot index {} bits into {}-bit value".format(key, n))
+                raise IndexError(f"Index {key} is out of bounds for a {n}-bit value")
             if key < 0:
                 key += n
             return Slice(self, key, key + 1)
@@ -224,6 +270,26 @@ class Value(metaclass=ABCMeta):
         else:
             raise TypeError("Cannot index value with {}".format(repr(key)))
 
+    def as_unsigned(self):
+        """Conversion to unsigned.
+
+        Returns
+        -------
+        Value, out
+            This ``Value`` reinterpreted as a unsigned integer.
+        """
+        return Operator("u", [self])
+
+    def as_signed(self):
+        """Conversion to signed.
+
+        Returns
+        -------
+        Value, out
+            This ``Value`` reinterpreted as a signed integer.
+        """
+        return Operator("s", [self])
+
     def bool(self):
         """Conversion to boolean.
 
@@ -282,7 +348,7 @@ class Value(metaclass=ABCMeta):
 
         Parameters
         ----------
-        offset : Value, in
+        offset : Value, int
             Index of first selected bit.
         width : int
             Number of selected bits.
@@ -305,7 +371,7 @@ class Value(metaclass=ABCMeta):
 
         Parameters
         ----------
-        offset : Value, in
+        offset : Value, int
             Index of first selected word.
         width : int
             Number of selected bits.
@@ -342,11 +408,12 @@ class Value(metaclass=ABCMeta):
                 raise SyntaxError("Match pattern must be an integer, a string, or an enumeration, "
                                   "not {!r}"
                                   .format(pattern))
-            if isinstance(pattern, str) and any(bit not in "01-" for bit in pattern):
+            if isinstance(pattern, str) and any(bit not in "01- \t" for bit in pattern):
                 raise SyntaxError("Match pattern '{}' must consist of 0, 1, and - (don't care) "
-                                  "bits"
+                                  "bits, and may include whitespace"
                                   .format(pattern))
-            if isinstance(pattern, str) and len(pattern) != len(self):
+            if (isinstance(pattern, str) and
+                    len("".join(pattern.split())) != len(self)):
                 raise SyntaxError("Match pattern '{}' must have the same width as match value "
                                   "(which is {})"
                                   .format(pattern, len(self)))
@@ -357,6 +424,7 @@ class Value(metaclass=ABCMeta):
                               SyntaxWarning, stacklevel=3)
                 continue
             if isinstance(pattern, str):
+                pattern = "".join(pattern.split()) # remove whitespace
                 mask    = int(pattern.replace("0", "1").replace("-", "0"), 2)
                 pattern = int(pattern.replace("-", "0"), 2)
                 matches.append((self & mask) == pattern)
@@ -373,6 +441,86 @@ class Value(metaclass=ABCMeta):
         else:
             return Cat(*matches).any()
 
+    def shift_left(self, amount):
+        """Shift left by constant amount.
+
+        Parameters
+        ----------
+        amount : int
+            Amount to shift by.
+
+        Returns
+        -------
+        Value, out
+            If the amount is positive, the input shifted left. Otherwise, the input shifted right.
+        """
+        if not isinstance(amount, int):
+            raise TypeError("Shift amount must be an integer, not {!r}".format(amount))
+        if amount < 0:
+            return self.shift_right(-amount)
+        if self.shape().signed:
+            return Cat(Const(0, amount), self).as_signed()
+        else:
+            return Cat(Const(0, amount), self) # unsigned
+
+    def shift_right(self, amount):
+        """Shift right by constant amount.
+
+        Parameters
+        ----------
+        amount : int
+            Amount to shift by.
+
+        Returns
+        -------
+        Value, out
+            If the amount is positive, the input shifted right. Otherwise, the input shifted left.
+        """
+        if not isinstance(amount, int):
+            raise TypeError("Shift amount must be an integer, not {!r}".format(amount))
+        if amount < 0:
+            return self.shift_left(-amount)
+        if self.shape().signed:
+            return self[amount:].as_signed()
+        else:
+            return self[amount:] # unsigned
+
+    def rotate_left(self, amount):
+        """Rotate left by constant amount.
+
+        Parameters
+        ----------
+        amount : int
+            Amount to rotate by.
+
+        Returns
+        -------
+        Value, out
+            If the amount is positive, the input rotated left. Otherwise, the input rotated right.
+        """
+        if not isinstance(amount, int):
+            raise TypeError("Rotate amount must be an integer, not {!r}".format(amount))
+        amount %= len(self)
+        return Cat(self[-amount:], self[:-amount]) # meow :3
+
+    def rotate_right(self, amount):
+        """Rotate right by constant amount.
+
+        Parameters
+        ----------
+        amount : int
+            Amount to rotate by.
+
+        Returns
+        -------
+        Value, out
+            If the amount is positive, the input rotated right. Otherwise, the input rotated right.
+        """
+        if not isinstance(amount, int):
+            raise TypeError("Rotate amount must be an integer, not {!r}".format(amount))
+        amount %= len(self)
+        return Cat(self[amount:], self[:amount])
+
     def eq(self, value):
         """Assignment.
 
@@ -463,7 +611,7 @@ class Const(Value):
         return Shape(self.width, self.signed)
 
     def _rhs_signals(self):
-        return ValueSet()
+        return SignalSet()
 
     def _as_const(self):
         return self.value
@@ -487,19 +635,19 @@ class AnyValue(Value, DUID):
         return Shape(self.width, self.signed)
 
     def _rhs_signals(self):
-        return ValueSet()
+        return SignalSet()
 
 
 @final
 class AnyConst(AnyValue):
     def __repr__(self):
-        return "(anyconst {}'{})".format(self.nbits, "s" if self.signed else "")
+        return "(anyconst {}'{})".format(self.width, "s" if self.signed else "")
 
 
 @final
 class AnySeq(AnyValue):
     def __repr__(self):
-        return "(anyseq {}'{})".format(self.nbits, "s" if self.signed else "")
+        return "(anyseq {}'{})".format(self.width, "s" if self.signed else "")
 
 
 @final
@@ -535,6 +683,10 @@ class Operator(Value):
                 return Shape(a_width + 1, True)
             if self.operator in ("b", "r|", "r&", "r^"):
                 return Shape(1, False)
+            if self.operator == "u":
+                return Shape(a_width, False)
+            if self.operator == "s":
+                return Shape(a_width, True)
         elif len(op_shapes) == 2:
             (a_width, a_signed), (b_width, b_signed) = op_shapes
             if self.operator in ("+", "-"):
@@ -696,10 +848,10 @@ class Cat(Value):
         return Shape(sum(len(part) for part in self.parts))
 
     def _lhs_signals(self):
-        return union((part._lhs_signals() for part in self.parts), start=ValueSet())
+        return union((part._lhs_signals() for part in self.parts), start=SignalSet())
 
     def _rhs_signals(self):
-        return union((part._rhs_signals() for part in self.parts), start=ValueSet())
+        return union((part._rhs_signals() for part in self.parts), start=SignalSet())
 
     def _as_const(self):
         value = 0
@@ -758,15 +910,13 @@ class Signal(Value, DUID):
 
     Parameters
     ----------
-    shape : int or tuple or None
-        Either an integer ``width`` or a tuple ``(width, signed)`` specifying the number of bits
-        in this ``Signal`` and whether it is signed (can represent negative values).
-        ``shape`` defaults to 1-bit and non-signed.
+    shape : ``Shape``-castable object or None
+        Specification for the number of bits in this ``Signal`` and its signedness (whether it
+        can represent negative values). See ``Shape.cast`` for details.
+        If not specified, ``shape`` defaults to 1-bit and non-signed.
     name : str
         Name hint for this signal. If ``None`` (default) the name is inferred from the variable
-        name this ``Signal`` is assigned to. Name collisions are automatically resolved by
-        prepending names of objects that contain this ``Signal`` and by appending integer
-        sequences.
+        name this ``Signal`` is assigned to.
     reset : int or integral Enum
         Reset (synchronous) or default (combinatorial) value.
         When this ``Signal`` is assigned to in synchronous context and the corresponding clock
@@ -777,11 +927,6 @@ class Signal(Value, DUID):
         If ``True``, do not generate reset logic for this ``Signal`` in synchronous statements.
         The ``reset`` value is only used as a combinatorial default or as the initial value.
         Defaults to ``False``.
-    min : int or None
-    max : int or None
-        If ``shape`` is ``None``, the signal bit width and signedness are
-        determined by the integer range given by ``min`` (inclusive,
-        defaults to 0) and ``max`` (exclusive, defaults to 2).
     attrs : dict
         Dictionary of synthesis attributes.
     decoder : function or Enum
@@ -798,6 +943,7 @@ class Signal(Value, DUID):
     reset : int
     reset_less : bool
     attrs : dict
+    decoder : function
     """
 
     def __init__(self, shape=None, *, name=None, reset=0, reset_less=False,
@@ -838,8 +984,10 @@ class Signal(Value, DUID):
                 except ValueError:
                     return str(value)
             self.decoder = enum_decoder
+            self._enum_class = decoder
         else:
             self.decoder = decoder
+            self._enum_class = None
 
     # Not a @classmethod because nmigen.compat requires it.
     @staticmethod
@@ -868,10 +1016,10 @@ class Signal(Value, DUID):
         return Shape(self.width, self.signed)
 
     def _lhs_signals(self):
-        return ValueSet((self,))
+        return SignalSet((self,))
 
     def _rhs_signals(self):
-        return ValueSet((self,))
+        return SignalSet((self,))
 
     def __repr__(self):
         return "(sig {})".format(self.name)
@@ -902,7 +1050,7 @@ class ClockSignal(Value):
         return Shape(1)
 
     def _lhs_signals(self):
-        return ValueSet((self,))
+        return SignalSet((self,))
 
     def _rhs_signals(self):
         raise NotImplementedError("ClockSignal must be lowered to a concrete signal") # :nocov:
@@ -939,7 +1087,7 @@ class ResetSignal(Value):
         return Shape(1)
 
     def _lhs_signals(self):
-        return ValueSet((self,))
+        return SignalSet((self,))
 
     def _rhs_signals(self):
         raise NotImplementedError("ResetSignal must be lowered to a concrete signal") # :nocov:
@@ -1053,18 +1201,36 @@ class ArrayProxy(Value):
         return (Value.cast(elem) for elem in self.elems)
 
     def shape(self):
-        width, signed = 0, False
+        unsigned_width = signed_width = 0
+        has_unsigned = has_signed = False
         for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()):
-            width  = max(width, elem_width + elem_signed)
-            signed = max(signed, elem_signed)
-        return Shape(width, signed)
+            if elem_signed:
+                has_signed = True
+                signed_width = max(signed_width, elem_width)
+            else:
+                has_unsigned = True
+                unsigned_width = max(unsigned_width, elem_width)
+        # The shape of the proxy must be such that it preserves the mathematical value of the array
+        # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
+        # To ensure this holds, if the array contains both signed and unsigned values, make sure
+        # that every unsigned value is zero-extended by at least one bit.
+        if has_signed and has_unsigned and unsigned_width >= signed_width:
+            # Array contains both signed and unsigned values, and at least one of the unsigned
+            # values won't be zero-extended otherwise.
+            return signed(unsigned_width + 1)
+        else:
+            # Array contains values of the same signedness, or else all of the unsigned values
+            # are zero-extended.
+            return Shape(max(unsigned_width, signed_width), has_signed)
 
     def _lhs_signals(self):
-        signals = union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet())
+        signals = union((elem._lhs_signals() for elem in self._iter_as_values()),
+                        start=SignalSet())
         return signals
 
     def _rhs_signals(self):
-        signals = union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet())
+        signals = union((elem._rhs_signals() for elem in self._iter_as_values()),
+                        start=SignalSet())
         return self.index._rhs_signals() | signals
 
     def __repr__(self):
@@ -1102,7 +1268,10 @@ class UserValue(Value):
 
     def _lazy_lower(self):
         if self.__lowered is None:
-            self.__lowered = Value.cast(self.lower())
+            lowered = self.lower()
+            if isinstance(lowered, UserValue):
+                lowered = lowered._lazy_lower()
+            self.__lowered = Value.cast(lowered)
         return self.__lowered
 
     def shape(self):
@@ -1115,6 +1284,51 @@ class UserValue(Value):
         return self._lazy_lower()._rhs_signals()
 
 
+class ValueCastable:
+    """Base class for classes which can be cast to Values.
+
+    A ``ValueCastable`` can be cast to ``Value``, meaning its precise representation does not have
+    to be immediately known. This is useful in certain metaprogramming scenarios. Instead of
+    providing fixed semantics upfront, it is kept abstract for as long as possible, only being
+    cast to a concrete nMigen value when required.
+
+    Note that it is necessary to ensure that nMigen's view of representation of all values stays 
+    internally consistent. The class deriving from ``ValueCastable`` must decorate the ``as_value``
+    method with the ``lowermethod`` decorator, which ensures that all calls to ``as_value``return the
+    same ``Value`` representation. If the class deriving from ``ValueCastable`` is mutable, it is
+    up to the user to ensure that it is not mutated in a way that changes its representation after
+    the first call to ``as_value``.
+    """
+    def __new__(cls, *args, **kwargs):
+        self = super().__new__(cls)
+        if not hasattr(self, "as_value"):
+            raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override the `as_value` method")
+
+        if not hasattr(self.as_value, "_ValueCastable__memoized"):
+            raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate the `as_value` "
+                            "method with the `ValueCastable.lowermethod` decorator")
+        return self
+
+    @staticmethod
+    def lowermethod(func):
+        """Decorator to memoize lowering methods.
+
+        Ensures the decorated method is called only once, with subsequent method calls returning the
+        object returned by the first first method call.
+
+        This decorator is required to decorate the ``as_value`` method of ``ValueCastable`` subclasses.
+        This is to ensure that nMigen's view of representation of all values stays internally
+        consistent.
+        """
+        @functools.wraps(func)
+        def wrapper_memoized(self, *args, **kwargs):
+            if not hasattr(self, "_ValueCastable__lowered_to"):
+                self.__lowered_to = func(self, *args, **kwargs)
+            return self.__lowered_to
+        wrapper_memoized.__memoized = True
+        return wrapper_memoized
+
+
 @final
 class Sample(Value):
     """Value from the past.
@@ -1142,7 +1356,7 @@ class Sample(Value):
         return self.value.shape()
 
     def _rhs_signals(self):
-        return ValueSet((self,))
+        return SignalSet((self,))
 
     def __repr__(self):
         return "(sample {!r} @ {}[{}])".format(
@@ -1172,13 +1386,13 @@ class Initial(Value):
     An ``Initial`` signal is ``1`` at the first cycle of model checking, and ``0`` at any other.
     """
     def __init__(self, *, src_loc_at=0):
-        super().__init__(src_loc_at=1 + src_loc_at)
+        super().__init__(src_loc_at=src_loc_at)
 
     def shape(self):
         return Shape(1)
 
     def _rhs_signals(self):
-        return ValueSet((self,))
+        return SignalSet((self,))
 
     def __repr__(self):
         return "(initial)"
@@ -1221,7 +1435,13 @@ class Assign(Statement):
         return "(eq {!r} {!r})".format(self.lhs, self.rhs)
 
 
-class Property(Statement):
+class UnusedProperty(UnusedMustUse):
+    pass
+
+
+class Property(Statement, MustUse):
+    _MustUse__warning = UnusedProperty
+
     def __init__(self, test, *, _check=None, _en=None, src_loc_at=0):
         super().__init__(src_loc_at=src_loc_at)
         self.test   = Value.cast(test)
@@ -1235,7 +1455,7 @@ class Property(Statement):
             self._en.src_loc = self.src_loc
 
     def _lhs_signals(self):
-        return ValueSet((self._en, self._check))
+        return SignalSet((self._en, self._check))
 
     def _rhs_signals(self):
         return self.test._rhs_signals()
@@ -1285,7 +1505,7 @@ class Switch(Statement):
             new_keys = ()
             for key in keys:
                 if isinstance(key, str):
-                    pass
+                    key = "".join(key.split()) # remove whitespace
                 elif isinstance(key, int):
                     key = format(key, "b").rjust(len(self.test), "0")
                 elif isinstance(key, Enum):
@@ -1303,12 +1523,12 @@ class Switch(Statement):
 
     def _lhs_signals(self):
         signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
-                        start=ValueSet())
+                        start=SignalSet())
         return signals
 
     def _rhs_signals(self):
         signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss),
-                        start=ValueSet())
+                        start=SignalSet())
         return self.test._rhs_signals() | signals
 
     def __repr__(self):