hdl.ast: implement ValueCastable.
[nmigen.git] / nmigen / hdl / ast.py
index 39093e7c2a4794ac0c5ef38c78aceb394d755a8b..0b86e7153da0cac1dd5d62c83cca15caaf7357e2 100644 (file)
@@ -1,7 +1,9 @@
 from abc import ABCMeta, abstractmethod
 import traceback
+import sys
 import warnings
 import typing
+import functools
 from collections import OrderedDict
 from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence
 from enum import Enum
@@ -16,7 +18,7 @@ __all__ = [
     "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
     "Array", "ArrayProxy",
     "Signal", "ClockSignal", "ResetSignal",
-    "UserValue",
+    "UserValue", "ValueCastable",
     "Sample", "Past", "Stable", "Rose", "Fell", "Initial",
     "Statement", "Switch",
     "Property", "Assign", "Assert", "Assume", "Cover",
@@ -32,7 +34,7 @@ class DUID:
         DUID.__next_uid += 1
 
 
-class Shape(typing.NamedTuple):
+class Shape:
     """Bit width and signedness of a value.
 
     A ``Shape`` can be constructed using:
@@ -55,8 +57,15 @@ class Shape(typing.NamedTuple):
     signed : bool
         If ``False``, the value is unsigned. If ``True``, the value is signed two's complement.
     """
-    width:  int  = 1
-    signed: bool = False
+    def __init__(self, width=1, signed=False):
+        if not isinstance(width, int) or width < 0:
+            raise TypeError("Width must be a non-negative integer, not {!r}"
+                            .format(width))
+        self.width = width
+        self.signed = signed
+
+    def __iter__(self):
+        return iter((self.width, self.signed))
 
     @staticmethod
     def cast(obj, *, src_loc_at=0):
@@ -95,13 +104,20 @@ class Shape(typing.NamedTuple):
         else:
             return "unsigned({})".format(self.width)
 
-
-# TODO: use dataclasses instead of this hack
-def _Shape___init__(self, width=1, signed=False):
-    if not isinstance(width, int) or width < 0:
-        raise TypeError("Width must be a non-negative integer, not {!r}"
-                        .format(width))
-Shape.__init__ = _Shape___init__
+    def __eq__(self, other):
+        if isinstance(other, tuple) and len(other) == 2:
+            width, signed = other
+            if isinstance(width, int) and isinstance(signed, bool):
+                return self.width == width and self.signed == signed
+            else:
+                raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, "
+                        "not {!r}"
+                        .format(other))
+        if not isinstance(other, Shape):
+            raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, "
+                    "not {!r}"
+                    .format(other))
+        return self.width == other.width and self.signed == other.signed
 
 
 def unsigned(width):
@@ -128,6 +144,8 @@ class Value(metaclass=ABCMeta):
             return Const(obj)
         if isinstance(obj, Enum):
             return Const(obj.value, Shape.cast(type(obj)))
+        if isinstance(obj, ValueCastable):
+            return obj.as_value()
         raise TypeError("Object {!r} cannot be converted to an nMigen value".format(obj))
 
     def __init__(self, *, src_loc_at=0):
@@ -135,7 +153,7 @@ class Value(metaclass=ABCMeta):
         self.src_loc = tracer.get_src_loc(1 + src_loc_at)
 
     def __bool__(self):
-        raise TypeError("Attempted to convert nMigen value to boolean")
+        raise TypeError("Attempted to convert nMigen value to Python boolean")
 
     def __invert__(self):
         return Operator("~", [self])
@@ -184,7 +202,7 @@ class Value(metaclass=ABCMeta):
             # Neither Python nor HDLs implement shifts by negative values; prohibit any shifts
             # by a signed value to make sure the shift amount can always be interpreted as
             # an unsigned value.
-            raise NotImplementedError("Shift by a signed value is not supported")
+            raise TypeError("Shift amount must be unsigned")
     def __lshift__(self, other):
         other = Value.cast(other)
         other.__check_shamt()
@@ -240,7 +258,7 @@ class Value(metaclass=ABCMeta):
         n = len(self)
         if isinstance(key, int):
             if key not in range(-n, n):
-                raise IndexError("Cannot index {} bits into {}-bit value".format(key, n))
+                raise IndexError(f"Index {key} is out of bounds for a {n}-bit value")
             if key < 0:
                 key += n
             return Slice(self, key, key + 1)
@@ -330,7 +348,7 @@ class Value(metaclass=ABCMeta):
 
         Parameters
         ----------
-        offset : Value, in
+        offset : Value, int
             Index of first selected bit.
         width : int
             Number of selected bits.
@@ -353,7 +371,7 @@ class Value(metaclass=ABCMeta):
 
         Parameters
         ----------
-        offset : Value, in
+        offset : Value, int
             Index of first selected word.
         width : int
             Number of selected bits.
@@ -1183,11 +1201,27 @@ class ArrayProxy(Value):
         return (Value.cast(elem) for elem in self.elems)
 
     def shape(self):
-        width, signed = 0, False
+        unsigned_width = signed_width = 0
+        has_unsigned = has_signed = False
         for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()):
-            width  = max(width, elem_width + elem_signed)
-            signed = max(signed, elem_signed)
-        return Shape(width, signed)
+            if elem_signed:
+                has_signed = True
+                signed_width = max(signed_width, elem_width)
+            else:
+                has_unsigned = True
+                unsigned_width = max(unsigned_width, elem_width)
+        # The shape of the proxy must be such that it preserves the mathematical value of the array
+        # elements. I.e., shape-wise, an array proxy must be identical to an equivalent mux tree.
+        # To ensure this holds, if the array contains both signed and unsigned values, make sure
+        # that every unsigned value is zero-extended by at least one bit.
+        if has_signed and has_unsigned and unsigned_width >= signed_width:
+            # Array contains both signed and unsigned values, and at least one of the unsigned
+            # values won't be zero-extended otherwise.
+            return signed(unsigned_width + 1)
+        else:
+            # Array contains values of the same signedness, or else all of the unsigned values
+            # are zero-extended.
+            return Shape(max(unsigned_width, signed_width), has_signed)
 
     def _lhs_signals(self):
         signals = union((elem._lhs_signals() for elem in self._iter_as_values()),
@@ -1250,6 +1284,51 @@ class UserValue(Value):
         return self._lazy_lower()._rhs_signals()
 
 
+class ValueCastable:
+    """Base class for classes which can be cast to Values.
+
+    A ``ValueCastable`` can be cast to ``Value``, meaning its precise representation does not have
+    to be immediately known. This is useful in certain metaprogramming scenarios. Instead of
+    providing fixed semantics upfront, it is kept abstract for as long as possible, only being
+    cast to a concrete nMigen value when required.
+
+    Note that it is necessary to ensure that nMigen's view of representation of all values stays 
+    internally consistent. The class deriving from ``ValueCastable`` must decorate the ``as_value``
+    method with the ``lowermethod`` decorator, which ensures that all calls to ``as_value``return the
+    same ``Value`` representation. If the class deriving from ``ValueCastable`` is mutable, it is
+    up to the user to ensure that it is not mutated in a way that changes its representation after
+    the first call to ``as_value``.
+    """
+    def __new__(cls, *args, **kwargs):
+        self = super().__new__(cls)
+        if not hasattr(self, "as_value"):
+            raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must override the `as_value` method")
+
+        if not hasattr(self.as_value, "_ValueCastable__memoized"):
+            raise TypeError(f"Class '{cls.__name__}' deriving from `ValueCastable` must decorate the `as_value` "
+                            "method with the `ValueCastable.lowermethod` decorator")
+        return self
+
+    @staticmethod
+    def lowermethod(func):
+        """Decorator to memoize lowering methods.
+
+        Ensures the decorated method is called only once, with subsequent method calls returning the
+        object returned by the first first method call.
+
+        This decorator is required to decorate the ``as_value`` method of ``ValueCastable`` subclasses.
+        This is to ensure that nMigen's view of representation of all values stays internally
+        consistent.
+        """
+        @functools.wraps(func)
+        def wrapper_memoized(self, *args, **kwargs):
+            if not hasattr(self, "_ValueCastable__lowered_to"):
+                self.__lowered_to = func(self, *args, **kwargs)
+            return self.__lowered_to
+        wrapper_memoized.__memoized = True
+        return wrapper_memoized
+
+
 @final
 class Sample(Value):
     """Value from the past.