hdl.ast: add an explicit Shape class, included in prelude.
authorwhitequark <whitequark@whitequark.org>
Fri, 11 Oct 2019 12:52:41 +0000 (12:52 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 11 Oct 2019 12:52:41 +0000 (12:52 +0000)
Shapes have long been a part of nMigen, but represented using tuples.
This commit adds a Shape class (using namedtuple for backwards
compatibility), and accepts anything castable to Shape (including
enums, ranges, etc) anywhere a tuple was accepted previously.

In addition, `signed(n)` and `unsigned(n)` are added as aliases for
`Shape(n, signed=True)` and `Shape(n, signed=False)`, transforming
code such as `Signal((8, True))` to `Signal(signed(8))`.
These aliases are also included in prelude.

Preparation for #225.

nmigen/hdl/__init__.py
nmigen/hdl/ast.py
nmigen/hdl/rec.py
nmigen/test/test_hdl_ast.py
nmigen/test/test_hdl_rec.py

index 6f6d6a907ceba41425fa9178f251bbe68b343f72..355b9f8747f956b1ceb263e664054e4b65b0c6e7 100644 (file)
@@ -1,3 +1,4 @@
+from .ast import Shape, unsigned, signed
 from .ast import Value, Const, C, Mux, Cat, Repl, Array, Signal, ClockSignal, ResetSignal
 from .dsl import Module
 from .cd import ClockDomain
index f4528ed0e785b1b7ba0a66a0347d3656c47f98f8..db4c24ea5e02f6271a52d89cc13d773ccc3a1f9b 100644 (file)
@@ -2,6 +2,7 @@ from abc import ABCMeta, abstractmethod
 import builtins
 import traceback
 import warnings
+import typing
 from collections import OrderedDict
 from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence
 from enum import Enum
@@ -11,6 +12,7 @@ from ..tools import *
 
 
 __all__ = [
+    "Shape", "signed", "unsigned",
     "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
     "Array", "ArrayProxy",
     "Signal", "ClockSignal", "ResetSignal",
@@ -23,21 +25,65 @@ __all__ = [
 
 
 class DUID:
-    """Deterministic Unique IDentifier"""
+    """Deterministic Unique IDentifier."""
     __next_uid = 0
     def __init__(self):
         self.duid = DUID.__next_uid
         DUID.__next_uid += 1
 
 
-def _enum_shape(enum_type):
-    min_value = min(member.value for member in enum_type)
-    max_value = max(member.value for member in enum_type)
-    if not isinstance(min_value, int) or not isinstance(max_value, int):
-        raise TypeError("Only enumerations with integer values can be converted to nMigen values")
-    signed = min_value < 0 or max_value < 0
-    width  = max(bits_for(min_value, signed), bits_for(max_value, signed))
-    return (width, signed)
+class Shape(typing.NamedTuple):
+    """Bit width and signedness of a value.
+
+    Attributes
+    ----------
+    width : int
+        The number of bits in the representation, including the sign bit (if any).
+    signed : bool
+        If ``False``, the value is unsigned. If ``True``, the value is signed two's complement.
+    """
+    width:  int  = 1
+    signed: bool = False
+
+    @staticmethod
+    def cast(obj):
+        if isinstance(obj, int):
+            return Shape(obj)
+        if isinstance(obj, tuple):
+            return Shape(*obj)
+        if isinstance(obj, range):
+            if len(obj) == 0:
+                return Shape(0, obj.start < 0)
+            signed = obj.start < 0 or (obj.stop - obj.step) < 0
+            width  = max(bits_for(obj.start, signed),
+                         bits_for(obj.stop - obj.step, signed))
+            return Shape(width, signed)
+        if isinstance(obj, type) and issubclass(obj, Enum):
+            min_value = min(member.value for member in obj)
+            max_value = max(member.value for member in obj)
+            if not isinstance(min_value, int) or not isinstance(max_value, int):
+                raise TypeError("Only enumerations with integer values can be used "
+                                "as value shapes")
+            signed = min_value < 0 or max_value < 0
+            width  = max(bits_for(min_value, signed), bits_for(max_value, signed))
+            return Shape(width, signed)
+        raise TypeError("Object {!r} cannot be used as value shape".format(obj))
+
+
+# TODO: use dataclasses instead of this hack
+def _Shape___init__(self, width=1, signed=False):
+    if not isinstance(width, int) or width < 0:
+        raise TypeError("Width must be a non-negative integer, not {!r}"
+                        .format(width))
+Shape.__init__ = _Shape___init__
+
+
+def unsigned(width):
+    return Shape(width, signed=False)
+
+
+def signed(width):
+    return Shape(width, signed=True)
 
 
 class Value(metaclass=ABCMeta):
@@ -50,12 +96,11 @@ class Value(metaclass=ABCMeta):
         """
         if isinstance(obj, Value):
             return obj
-        elif isinstance(obj, (bool, int)):
+        if isinstance(obj, int):
             return Const(obj)
-        elif isinstance(obj, Enum):
-            return Const(obj.value, _enum_shape(type(obj)))
-        else:
-            raise TypeError("Object {!r} is not an nMigen value".format(obj))
+        if isinstance(obj, Enum):
+            return Const(obj.value, Shape.cast(type(obj)))
+        raise TypeError("Object {!r} cannot be converted to an nMigen value".format(obj))
 
     # TODO(nmigen-0.2): remove this
     @classmethod
@@ -146,7 +191,7 @@ class Value(metaclass=ABCMeta):
         return Operator(">=", [self, other])
 
     def __len__(self):
-        return self.shape()[0]
+        return self.shape().width
 
     def __getitem__(self, key):
         n = len(self)
@@ -329,20 +374,19 @@ class Value(metaclass=ABCMeta):
 
     @abstractmethod
     def shape(self):
-        """Bit length and signedness of a value.
+        """Bit width and signedness of a value.
 
         Returns
         -------
-        int, bool
-            Number of bits required to store `v` or available in `v`, followed by
-            whether `v` has a sign bit (included in the bit count).
+        Shape
+            See :class:`Shape`.
 
         Examples
         --------
-        >>> Value.shape(Signal(8))
-        8, False
-        >>> Value.shape(C(0xaa))
-        8, False
+        >>> Signal(8).shape()
+        Shape(width=8, signed=False)
+        >>> Const(0xaa).shape()
+        Shape(width=8, signed=False)
         """
         pass # :nocov:
 
@@ -391,13 +435,12 @@ class Const(Value):
         # We deliberately do not call Value.__init__ here.
         self.value = int(value)
         if shape is None:
-            shape = bits_for(self.value), self.value < 0
-        if isinstance(shape, int):
-            shape = shape, self.value < 0
+            shape = Shape(bits_for(self.value), signed=self.value < 0)
+        elif isinstance(shape, int):
+            shape = Shape(shape, signed=self.value < 0)
+        else:
+            shape = Shape.cast(shape)
         self.width, self.signed = shape
-        if not isinstance(self.width, int) or self.width < 0:
-            raise TypeError("Width must be a non-negative integer, not {!r}"
-                            .format(self.width))
         self.value = self.normalize(self.value, shape)
 
     # TODO(nmigen-0.2): move this to nmigen.compat and make it a deprecated extension
@@ -407,7 +450,7 @@ class Const(Value):
         return self.width
 
     def shape(self):
-        return self.width, self.signed
+        return Shape(self.width, self.signed)
 
     def _rhs_signals(self):
         return ValueSet()
@@ -425,15 +468,13 @@ C = Const  # shorthand
 class AnyValue(Value, DUID):
     def __init__(self, shape, *, src_loc_at=0):
         super().__init__(src_loc_at=src_loc_at)
-        if isinstance(shape, int):
-            shape = shape, False
-        self.width, self.signed = shape
+        self.width, self.signed = Shape.cast(shape)
         if not isinstance(self.width, int) or self.width < 0:
             raise TypeError("Width must be a non-negative integer, not {!r}"
                             .format(self.width))
 
     def shape(self):
-        return self.width, self.signed
+        return Shape(self.width, self.signed)
 
     def _rhs_signals(self):
         return ValueSet()
@@ -470,41 +511,41 @@ class Operator(Value):
             b_bits, b_sign = b_shape
             if not a_sign and not b_sign:
                 # both operands unsigned
-                return max(a_bits, b_bits), False
+                return Shape(max(a_bits, b_bits), False)
             elif a_sign and b_sign:
                 # both operands signed
-                return max(a_bits, b_bits), True
+                return Shape(max(a_bits, b_bits), True)
             elif not a_sign and b_sign:
                 # first operand unsigned (add sign bit), second operand signed
-                return max(a_bits + 1, b_bits), True
+                return Shape(max(a_bits + 1, b_bits), True)
             else:
                 # first signed, second operand unsigned (add sign bit)
-                return max(a_bits, b_bits + 1), True
+                return Shape(max(a_bits, b_bits + 1), True)
 
         op_shapes = list(map(lambda x: x.shape(), self.operands))
         if len(op_shapes) == 1:
             (a_width, a_signed), = op_shapes
             if self.operator in ("+", "~"):
-                return a_width, a_signed
+                return Shape(a_width, a_signed)
             if self.operator == "-":
                 if not a_signed:
-                    return a_width + 1, True
+                    return Shape(a_width + 1, True)
                 else:
-                    return a_width, a_signed
+                    return Shape(a_width, a_signed)
             if self.operator in ("b", "r|", "r&", "r^"):
-                return 1, False
+                return Shape(1, False)
         elif len(op_shapes) == 2:
             (a_width, a_signed), (b_width, b_signed) = op_shapes
-            if self.operator == "+" or self.operator == "-":
+            if self.operator in ("+", "-"):
                 width, signed = _bitwise_binary_shape(*op_shapes)
-                return width + 1, signed
+                return Shape(width + 1, signed)
             if self.operator == "*":
-                return a_width + b_width, a_signed or b_signed
+                return Shape(a_width + b_width, a_signed or b_signed)
             if self.operator in ("//", "%"):
                 assert not b_signed
-                return a_width, a_signed
+                return Shape(a_width, a_signed)
             if self.operator in ("<", "<=", "==", "!=", ">", ">="):
-                return 1, False
+                return Shape(1, False)
             if self.operator in ("&", "^", "|"):
                 return _bitwise_binary_shape(*op_shapes)
             if self.operator == "<<":
@@ -512,13 +553,13 @@ class Operator(Value):
                     extra = 2 ** (b_width - 1) - 1
                 else:
                     extra = 2 ** (b_width)     - 1
-                return a_width + extra, a_signed
+                return Shape(a_width + extra, a_signed)
             if self.operator == ">>":
                 if b_signed:
                     extra = 2 ** (b_width - 1)
                 else:
                     extra = 0
-                return a_width + extra, a_signed
+                return Shape(a_width + extra, a_signed)
         elif len(op_shapes) == 3:
             if self.operator == "m":
                 s_shape, a_shape, b_shape = op_shapes
@@ -581,7 +622,7 @@ class Slice(Value):
         self.end   = end
 
     def shape(self):
-        return self.end - self.start, False
+        return Shape(self.end - self.start)
 
     def _lhs_signals(self):
         return self.value._lhs_signals()
@@ -608,7 +649,7 @@ class Part(Value):
         self.stride = stride
 
     def shape(self):
-        return self.width, False
+        return Shape(self.width)
 
     def _lhs_signals(self):
         return self.value._lhs_signals()
@@ -651,7 +692,7 @@ class Cat(Value):
         self.parts = [Value.cast(v) for v in flatten(args)]
 
     def shape(self):
-        return sum(len(part) for part in self.parts), False
+        return Shape(sum(len(part) for part in self.parts))
 
     def _lhs_signals(self):
         return union((part._lhs_signals() for part in self.parts), start=ValueSet())
@@ -701,7 +742,7 @@ class Repl(Value):
         self.count = count
 
     def shape(self):
-        return len(self.value) * self.count, False
+        return Shape(len(self.value) * self.count)
 
     def _rhs_signals(self):
         return self.value._rhs_signals()
@@ -792,13 +833,7 @@ class Signal(Value, DUID):
         else:
             if not (min is None and max is None):
                 raise ValueError("Only one of bits/signedness or bounds may be specified")
-            if isinstance(shape, int):
-                self.width, self.signed = shape, False
-            else:
-                self.width, self.signed = shape
-
-        if not isinstance(self.width, int) or self.width < 0:
-            raise TypeError("Width must be a non-negative integer, not {!r}".format(self.width))
+            self.width, self.signed = Shape.cast(shape)
 
         reset_width = bits_for(reset, self.signed)
         if reset != 0 and reset_width > self.width:
@@ -829,14 +864,7 @@ class Signal(Value, DUID):
         That is, for any given ``range(*args)``, ``Signal.range(*args)`` can represent any
         ``x for x in range(*args)``.
         """
-        value_range = range(*args)
-        if len(value_range) > 0:
-            signed = value_range.start < 0 or (value_range.stop - value_range.step) < 0
-        else:
-            signed = value_range.start < 0
-        width = max(bits_for(value_range.start, signed),
-                    bits_for(value_range.stop - value_range.step, signed))
-        return cls((width, signed), src_loc_at=1 + src_loc_at, **kwargs)
+        return cls(Shape.cast(range(*args)), src_loc_at=1 + src_loc_at, **kwargs)
 
     @classmethod
     def enum(cls, enum_type, *, src_loc_at=0, **kwargs):
@@ -849,7 +877,7 @@ class Signal(Value, DUID):
         """
         if not issubclass(enum_type, Enum):
             raise TypeError("Type {!r} is not an enumeration")
-        return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
+        return cls(Shape.cast(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs)
 
     @classmethod
     def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs):
@@ -885,7 +913,7 @@ class Signal(Value, DUID):
         self.width = value
 
     def shape(self):
-        return self.width, self.signed
+        return Shape(self.width, self.signed)
 
     def _lhs_signals(self):
         return ValueSet((self,))
@@ -919,7 +947,7 @@ class ClockSignal(Value):
         self.domain = domain
 
     def shape(self):
-        return 1, False
+        return Shape(1)
 
     def _lhs_signals(self):
         return ValueSet((self,))
@@ -956,7 +984,7 @@ class ResetSignal(Value):
         self.allow_reset_less = allow_reset_less
 
     def shape(self):
-        return 1, False
+        return Shape(1)
 
     def _lhs_signals(self):
         return ValueSet((self,))
@@ -1077,7 +1105,7 @@ class ArrayProxy(Value):
         for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()):
             width  = max(width, elem_width + elem_signed)
             signed = max(signed, elem_signed)
-        return width, signed
+        return Shape(width, signed)
 
     def _lhs_signals(self):
         signals = union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet())
@@ -1195,7 +1223,7 @@ class Initial(Value):
         super().__init__(src_loc_at=1 + src_loc_at)
 
     def shape(self):
-        return (1, False)
+        return Shape(1)
 
     def _rhs_signals(self):
         return ValueSet((self,))
index 9ad5ba79b92d0e281aa2326d0b8bd538592fbbe2..3cd274663d62e44a24b8fcfc61117068364c3aca 100644 (file)
@@ -5,7 +5,6 @@ from functools import reduce
 from .. import tracer
 from ..tools import union
 from .ast import *
-from .ast import _enum_shape
 
 
 __all__ = ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"]
@@ -46,17 +45,16 @@ class Layout:
             if not isinstance(name, str):
                 raise TypeError("Field {!r} has invalid name: should be a string"
                                 .format(field))
-            if isinstance(shape, type) and issubclass(shape, Enum):
-                shape = _enum_shape(shape)
-            if not isinstance(shape, (int, tuple, Layout)):
-                raise TypeError("Field {!r} has invalid shape: should be an int, tuple, Enum, or "
-                                "list of fields of a nested record"
-                                .format(field))
+            if not isinstance(shape, Layout):
+                try:
+                    shape = Shape.cast(shape)
+                except Exception as error:
+                    raise TypeError("Field {!r} has invalid shape: should be castable to Shape "
+                                    "or a list of fields of a nested record"
+                                    .format(field))
             if name in self.fields:
                 raise NameError("Field {!r} has a name that is already present in the layout"
                                 .format(field))
-            if isinstance(shape, int):
-                shape = (shape, False)
             self.fields[name] = (shape, direction)
 
     def __getitem__(self, item):
@@ -159,7 +157,7 @@ class Record(Value):
             return super().__getitem__(item)
 
     def shape(self):
-        return sum(len(f) for f in self.fields.values()), False
+        return Shape(sum(len(f) for f in self.fields.values()))
 
     def _lhs_signals(self):
         return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())
index 653d0f760cffc7373fe862a4351bedc4522c8417..9022ec0942b66eafcf9316f6d86ec272974299c6 100644 (file)
@@ -22,6 +22,105 @@ class StringEnum(Enum):
     BAR = "b"
 
 
+class ShapeTestCase(FHDLTestCase):
+    def test_make(self):
+        s1 = Shape()
+        self.assertEqual(s1.width, 1)
+        self.assertEqual(s1.signed, False)
+        s2 = Shape(signed=True)
+        self.assertEqual(s2.width, 1)
+        self.assertEqual(s2.signed, True)
+        s3 = Shape(3, True)
+        self.assertEqual(s3.width, 3)
+        self.assertEqual(s3.signed, True)
+
+    def test_make_wrong(self):
+        with self.assertRaises(TypeError,
+                msg="Width must be a non-negative integer, not -1"):
+            Shape(-1)
+
+    def test_tuple(self):
+        width, signed = Shape()
+        self.assertEqual(width, 1)
+        self.assertEqual(signed, False)
+
+    def test_unsigned(self):
+        s1 = unsigned(2)
+        self.assertIsInstance(s1, Shape)
+        self.assertEqual(s1.width, 2)
+        self.assertEqual(s1.signed, False)
+
+    def test_signed(self):
+        s1 = signed(2)
+        self.assertIsInstance(s1, Shape)
+        self.assertEqual(s1.width, 2)
+        self.assertEqual(s1.signed, True)
+
+    def test_cast_int(self):
+        s1 = Shape.cast(2)
+        self.assertEqual(s1.width, 2)
+        self.assertEqual(s1.signed, False)
+
+    def test_cast_int_wrong(self):
+        with self.assertRaises(TypeError,
+                msg="Width must be a non-negative integer, not -1"):
+            Shape.cast(-1)
+
+    def test_cast_tuple(self):
+        s1 = Shape.cast((1, False))
+        self.assertEqual(s1.width, 1)
+        self.assertEqual(s1.signed, False)
+        s2 = Shape.cast((3, True))
+        self.assertEqual(s2.width, 3)
+        self.assertEqual(s2.signed, True)
+
+    def test_cast_tuple_wrong(self):
+        with self.assertRaises(TypeError,
+                msg="Width must be a non-negative integer, not -1"):
+            Shape.cast((-1, True))
+
+    def test_cast_range(self):
+        s1 = Shape.cast(range(0, 8))
+        self.assertEqual(s1.width, 3)
+        self.assertEqual(s1.signed, False)
+        s2 = Shape.cast(range(0, 9))
+        self.assertEqual(s2.width, 4)
+        self.assertEqual(s2.signed, False)
+        s3 = Shape.cast(range(-7, 8))
+        self.assertEqual(s3.width, 4)
+        self.assertEqual(s3.signed, True)
+        s4 = Shape.cast(range(0, 1))
+        self.assertEqual(s4.width, 1)
+        self.assertEqual(s4.signed, False)
+        s5 = Shape.cast(range(-1, 0))
+        self.assertEqual(s5.width, 1)
+        self.assertEqual(s5.signed, True)
+        s6 = Shape.cast(range(0, 0))
+        self.assertEqual(s6.width, 0)
+        self.assertEqual(s6.signed, False)
+        s7 = Shape.cast(range(-1, -1))
+        self.assertEqual(s7.width, 0)
+        self.assertEqual(s7.signed, True)
+
+    def test_cast_enum(self):
+        s1 = Shape.cast(UnsignedEnum)
+        self.assertEqual(s1.width, 2)
+        self.assertEqual(s1.signed, False)
+        s2 = Shape.cast(SignedEnum)
+        self.assertEqual(s2.width, 2)
+        self.assertEqual(s2.signed, True)
+
+    def test_cast_enum_bad(self):
+        with self.assertRaises(TypeError,
+                msg="Only enumerations with integer values can be used as value shapes"):
+            Shape.cast(StringEnum)
+
+    def test_cast_bad(self):
+        with self.assertRaises(TypeError,
+                msg="Object 'foo' cannot be used as value shape"):
+            Shape.cast("foo")
+
+
 class ValueTestCase(FHDLTestCase):
     def test_cast(self):
         self.assertIsInstance(Value.cast(0), Const)
@@ -29,7 +128,7 @@ class ValueTestCase(FHDLTestCase):
         c = Const(0)
         self.assertIs(Value.cast(c), c)
         with self.assertRaises(TypeError,
-                msg="Object 'str' is not an nMigen value"):
+                msg="Object 'str' cannot be converted to an nMigen value"):
             Value.cast("str")
 
     def test_cast_enum(self):
@@ -42,7 +141,7 @@ class ValueTestCase(FHDLTestCase):
 
     def test_cast_enum_wrong(self):
         with self.assertRaises(TypeError,
-                msg="Only enumerations with integer values can be converted to nMigen values"):
+                msg="Only enumerations with integer values can be used as value shapes"):
             Value.cast(StringEnum.FOO)
 
     def test_bool(self):
@@ -97,11 +196,13 @@ class ValueTestCase(FHDLTestCase):
 class ConstTestCase(FHDLTestCase):
     def test_shape(self):
         self.assertEqual(Const(0).shape(),   (1, False))
+        self.assertIsInstance(Const(0).shape(), Shape)
         self.assertEqual(Const(1).shape(),   (1, False))
         self.assertEqual(Const(10).shape(),  (4, False))
         self.assertEqual(Const(-10).shape(), (5, True))
 
         self.assertEqual(Const(1, 4).shape(),          (4, False))
+        self.assertEqual(Const(-1, 4).shape(),         (4, True))
         self.assertEqual(Const(1, (4, True)).shape(),  (4, True))
         self.assertEqual(Const(0, (0, False)).shape(), (0, False))
 
@@ -380,6 +481,7 @@ class SliceTestCase(FHDLTestCase):
     def test_shape(self):
         s1 = Const(10)[2]
         self.assertEqual(s1.shape(), (1, False))
+        self.assertIsInstance(s1.shape(), Shape)
         s2 = Const(-10)[0:2]
         self.assertEqual(s2.shape(), (2, False))
 
@@ -423,6 +525,7 @@ class BitSelectTestCase(FHDLTestCase):
     def test_shape(self):
         s1 = self.c.bit_select(self.s, 2)
         self.assertEqual(s1.shape(), (2, False))
+        self.assertIsInstance(s1.shape(), Shape)
         s2 = self.c.bit_select(self.s, 0)
         self.assertEqual(s2.shape(), (0, False))
 
@@ -447,6 +550,7 @@ class WordSelectTestCase(FHDLTestCase):
     def test_shape(self):
         s1 = self.c.word_select(self.s, 2)
         self.assertEqual(s1.shape(), (2, False))
+        self.assertIsInstance(s1.shape(), Shape)
 
     def test_stride(self):
         s1 = self.c.word_select(self.s, 2)
@@ -467,6 +571,7 @@ class CatTestCase(FHDLTestCase):
     def test_shape(self):
         c0 = Cat()
         self.assertEqual(c0.shape(), (0, False))
+        self.assertIsInstance(c0.shape(), Shape)
         c1 = Cat(Const(10))
         self.assertEqual(c1.shape(), (4, False))
         c2 = Cat(Const(10), Const(1))
@@ -483,6 +588,7 @@ class ReplTestCase(FHDLTestCase):
     def test_shape(self):
         s1 = Repl(Const(10), 3)
         self.assertEqual(s1.shape(), (12, False))
+        self.assertIsInstance(s1.shape(), Shape)
         s2 = Repl(Const(10), 0)
         self.assertEqual(s2.shape(), (0, False))
 
@@ -561,6 +667,7 @@ class SignalTestCase(FHDLTestCase):
     def test_shape(self):
         s1 = Signal()
         self.assertEqual(s1.shape(), (1, False))
+        self.assertIsInstance(s1.shape(), Shape)
         s2 = Signal(2)
         self.assertEqual(s2.shape(), (2, False))
         s3 = Signal((2, False))
@@ -578,7 +685,7 @@ class SignalTestCase(FHDLTestCase):
         s9 = Signal.range(-20, 16)
         self.assertEqual(s9.shape(), (6, True))
         s10 = Signal.range(0)
-        self.assertEqual(s10.shape(), (1, False))
+        self.assertEqual(s10.shape(), (0, False))
         s11 = Signal.range(1)
         self.assertEqual(s11.shape(), (1, False))
         # deprecated
@@ -692,7 +799,9 @@ class ClockSignalTestCase(FHDLTestCase):
             ClockSignal(1)
 
     def test_shape(self):
-        self.assertEqual(ClockSignal().shape(), (1, False))
+        s1 = ClockSignal()
+        self.assertEqual(s1.shape(), (1, False))
+        self.assertIsInstance(s1.shape(), Shape)
 
     def test_repr(self):
         s1 = ClockSignal()
@@ -716,7 +825,9 @@ class ResetSignalTestCase(FHDLTestCase):
             ResetSignal(1)
 
     def test_shape(self):
-        self.assertEqual(ResetSignal().shape(), (1, False))
+        s1 = ResetSignal()
+        self.assertEqual(s1.shape(), (1, False))
+        self.assertIsInstance(s1.shape(), Shape)
 
     def test_repr(self):
         s1 = ResetSignal()
@@ -743,6 +854,7 @@ class UserValueTestCase(FHDLTestCase):
     def test_shape(self):
         uv = MockUserValue(1)
         self.assertEqual(uv.shape(), (1, False))
+        self.assertIsInstance(uv.shape(), Shape)
         uv.lowered = 2
         self.assertEqual(uv.shape(), (1, False))
         self.assertEqual(uv.lower_count, 1)
index bae4fb1b18c29c1903e9e592ce927b80f735172e..e3721ef828b781be8176cd6bcb26caa82c4443da 100644 (file)
@@ -41,6 +41,12 @@ class LayoutTestCase(FHDLTestCase):
         self.assertEqual(layout["enum"], ((2, False), DIR_NONE))
         self.assertEqual(layout["enum_dir"], ((2, False), DIR_FANOUT))
 
+    def test_range_field(self):
+        layout = Layout.wrap([
+            ("range", range(0, 7)),
+        ])
+        self.assertEqual(layout["range"], ((3, False), DIR_NONE))
+
     def test_slice_tuple(self):
         layout = Layout.wrap([
             ("a", 1),
@@ -77,8 +83,8 @@ class LayoutTestCase(FHDLTestCase):
 
     def test_wrong_shape(self):
         with self.assertRaises(TypeError,
-                msg="Field ('a', 'x') has invalid shape: should be an int, tuple, Enum, or "
-                    "list of fields of a nested record"):
+                msg="Field ('a', 'x') has invalid shape: should be castable to Shape or "
+                    "list of fields of a nested record"):
             Layout.wrap([("a", "x")])