From: whitequark Date: Fri, 11 Oct 2019 12:52:41 +0000 (+0000) Subject: hdl.ast: add an explicit Shape class, included in prelude. X-Git-Tag: v0.1rc1~23 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=6aabdc0a73dac75e42e0f0923f66497827cd8b17;p=nmigen.git hdl.ast: add an explicit Shape class, included in prelude. Shapes have long been a part of nMigen, but represented using tuples. This commit adds a Shape class (using namedtuple for backwards compatibility), and accepts anything castable to Shape (including enums, ranges, etc) anywhere a tuple was accepted previously. In addition, `signed(n)` and `unsigned(n)` are added as aliases for `Shape(n, signed=True)` and `Shape(n, signed=False)`, transforming code such as `Signal((8, True))` to `Signal(signed(8))`. These aliases are also included in prelude. Preparation for #225. --- diff --git a/nmigen/hdl/__init__.py b/nmigen/hdl/__init__.py index 6f6d6a9..355b9f8 100644 --- a/nmigen/hdl/__init__.py +++ b/nmigen/hdl/__init__.py @@ -1,3 +1,4 @@ +from .ast import Shape, unsigned, signed from .ast import Value, Const, C, Mux, Cat, Repl, Array, Signal, ClockSignal, ResetSignal from .dsl import Module from .cd import ClockDomain diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index f4528ed..db4c24e 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -2,6 +2,7 @@ from abc import ABCMeta, abstractmethod import builtins import traceback import warnings +import typing from collections import OrderedDict from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence from enum import Enum @@ -11,6 +12,7 @@ from ..tools import * __all__ = [ + "Shape", "signed", "unsigned", "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Array", "ArrayProxy", "Signal", "ClockSignal", "ResetSignal", @@ -23,21 +25,65 @@ __all__ = [ class DUID: - """Deterministic Unique IDentifier""" + """Deterministic Unique IDentifier.""" __next_uid = 0 def __init__(self): self.duid = DUID.__next_uid DUID.__next_uid += 1 -def _enum_shape(enum_type): - min_value = min(member.value for member in enum_type) - max_value = max(member.value for member in enum_type) - if not isinstance(min_value, int) or not isinstance(max_value, int): - raise TypeError("Only enumerations with integer values can be converted to nMigen values") - signed = min_value < 0 or max_value < 0 - width = max(bits_for(min_value, signed), bits_for(max_value, signed)) - return (width, signed) +class Shape(typing.NamedTuple): + """Bit width and signedness of a value. + + Attributes + ---------- + width : int + The number of bits in the representation, including the sign bit (if any). + signed : bool + If ``False``, the value is unsigned. If ``True``, the value is signed two's complement. + """ + width: int = 1 + signed: bool = False + + @staticmethod + def cast(obj): + if isinstance(obj, int): + return Shape(obj) + if isinstance(obj, tuple): + return Shape(*obj) + if isinstance(obj, range): + if len(obj) == 0: + return Shape(0, obj.start < 0) + signed = obj.start < 0 or (obj.stop - obj.step) < 0 + width = max(bits_for(obj.start, signed), + bits_for(obj.stop - obj.step, signed)) + return Shape(width, signed) + if isinstance(obj, type) and issubclass(obj, Enum): + min_value = min(member.value for member in obj) + max_value = max(member.value for member in obj) + if not isinstance(min_value, int) or not isinstance(max_value, int): + raise TypeError("Only enumerations with integer values can be used " + "as value shapes") + signed = min_value < 0 or max_value < 0 + width = max(bits_for(min_value, signed), bits_for(max_value, signed)) + return Shape(width, signed) + raise TypeError("Object {!r} cannot be used as value shape".format(obj)) + + +# TODO: use dataclasses instead of this hack +def _Shape___init__(self, width=1, signed=False): + if not isinstance(width, int) or width < 0: + raise TypeError("Width must be a non-negative integer, not {!r}" + .format(width)) +Shape.__init__ = _Shape___init__ + + +def unsigned(width): + return Shape(width, signed=False) + + +def signed(width): + return Shape(width, signed=True) class Value(metaclass=ABCMeta): @@ -50,12 +96,11 @@ class Value(metaclass=ABCMeta): """ if isinstance(obj, Value): return obj - elif isinstance(obj, (bool, int)): + if isinstance(obj, int): return Const(obj) - elif isinstance(obj, Enum): - return Const(obj.value, _enum_shape(type(obj))) - else: - raise TypeError("Object {!r} is not an nMigen value".format(obj)) + if isinstance(obj, Enum): + return Const(obj.value, Shape.cast(type(obj))) + raise TypeError("Object {!r} cannot be converted to an nMigen value".format(obj)) # TODO(nmigen-0.2): remove this @classmethod @@ -146,7 +191,7 @@ class Value(metaclass=ABCMeta): return Operator(">=", [self, other]) def __len__(self): - return self.shape()[0] + return self.shape().width def __getitem__(self, key): n = len(self) @@ -329,20 +374,19 @@ class Value(metaclass=ABCMeta): @abstractmethod def shape(self): - """Bit length and signedness of a value. + """Bit width and signedness of a value. Returns ------- - int, bool - Number of bits required to store `v` or available in `v`, followed by - whether `v` has a sign bit (included in the bit count). + Shape + See :class:`Shape`. Examples -------- - >>> Value.shape(Signal(8)) - 8, False - >>> Value.shape(C(0xaa)) - 8, False + >>> Signal(8).shape() + Shape(width=8, signed=False) + >>> Const(0xaa).shape() + Shape(width=8, signed=False) """ pass # :nocov: @@ -391,13 +435,12 @@ class Const(Value): # We deliberately do not call Value.__init__ here. self.value = int(value) if shape is None: - shape = bits_for(self.value), self.value < 0 - if isinstance(shape, int): - shape = shape, self.value < 0 + shape = Shape(bits_for(self.value), signed=self.value < 0) + elif isinstance(shape, int): + shape = Shape(shape, signed=self.value < 0) + else: + shape = Shape.cast(shape) self.width, self.signed = shape - if not isinstance(self.width, int) or self.width < 0: - raise TypeError("Width must be a non-negative integer, not {!r}" - .format(self.width)) self.value = self.normalize(self.value, shape) # TODO(nmigen-0.2): move this to nmigen.compat and make it a deprecated extension @@ -407,7 +450,7 @@ class Const(Value): return self.width def shape(self): - return self.width, self.signed + return Shape(self.width, self.signed) def _rhs_signals(self): return ValueSet() @@ -425,15 +468,13 @@ C = Const # shorthand class AnyValue(Value, DUID): def __init__(self, shape, *, src_loc_at=0): super().__init__(src_loc_at=src_loc_at) - if isinstance(shape, int): - shape = shape, False - self.width, self.signed = shape + self.width, self.signed = Shape.cast(shape) if not isinstance(self.width, int) or self.width < 0: raise TypeError("Width must be a non-negative integer, not {!r}" .format(self.width)) def shape(self): - return self.width, self.signed + return Shape(self.width, self.signed) def _rhs_signals(self): return ValueSet() @@ -470,41 +511,41 @@ class Operator(Value): b_bits, b_sign = b_shape if not a_sign and not b_sign: # both operands unsigned - return max(a_bits, b_bits), False + return Shape(max(a_bits, b_bits), False) elif a_sign and b_sign: # both operands signed - return max(a_bits, b_bits), True + return Shape(max(a_bits, b_bits), True) elif not a_sign and b_sign: # first operand unsigned (add sign bit), second operand signed - return max(a_bits + 1, b_bits), True + return Shape(max(a_bits + 1, b_bits), True) else: # first signed, second operand unsigned (add sign bit) - return max(a_bits, b_bits + 1), True + return Shape(max(a_bits, b_bits + 1), True) op_shapes = list(map(lambda x: x.shape(), self.operands)) if len(op_shapes) == 1: (a_width, a_signed), = op_shapes if self.operator in ("+", "~"): - return a_width, a_signed + return Shape(a_width, a_signed) if self.operator == "-": if not a_signed: - return a_width + 1, True + return Shape(a_width + 1, True) else: - return a_width, a_signed + return Shape(a_width, a_signed) if self.operator in ("b", "r|", "r&", "r^"): - return 1, False + return Shape(1, False) elif len(op_shapes) == 2: (a_width, a_signed), (b_width, b_signed) = op_shapes - if self.operator == "+" or self.operator == "-": + if self.operator in ("+", "-"): width, signed = _bitwise_binary_shape(*op_shapes) - return width + 1, signed + return Shape(width + 1, signed) if self.operator == "*": - return a_width + b_width, a_signed or b_signed + return Shape(a_width + b_width, a_signed or b_signed) if self.operator in ("//", "%"): assert not b_signed - return a_width, a_signed + return Shape(a_width, a_signed) if self.operator in ("<", "<=", "==", "!=", ">", ">="): - return 1, False + return Shape(1, False) if self.operator in ("&", "^", "|"): return _bitwise_binary_shape(*op_shapes) if self.operator == "<<": @@ -512,13 +553,13 @@ class Operator(Value): extra = 2 ** (b_width - 1) - 1 else: extra = 2 ** (b_width) - 1 - return a_width + extra, a_signed + return Shape(a_width + extra, a_signed) if self.operator == ">>": if b_signed: extra = 2 ** (b_width - 1) else: extra = 0 - return a_width + extra, a_signed + return Shape(a_width + extra, a_signed) elif len(op_shapes) == 3: if self.operator == "m": s_shape, a_shape, b_shape = op_shapes @@ -581,7 +622,7 @@ class Slice(Value): self.end = end def shape(self): - return self.end - self.start, False + return Shape(self.end - self.start) def _lhs_signals(self): return self.value._lhs_signals() @@ -608,7 +649,7 @@ class Part(Value): self.stride = stride def shape(self): - return self.width, False + return Shape(self.width) def _lhs_signals(self): return self.value._lhs_signals() @@ -651,7 +692,7 @@ class Cat(Value): self.parts = [Value.cast(v) for v in flatten(args)] def shape(self): - return sum(len(part) for part in self.parts), False + return Shape(sum(len(part) for part in self.parts)) def _lhs_signals(self): return union((part._lhs_signals() for part in self.parts), start=ValueSet()) @@ -701,7 +742,7 @@ class Repl(Value): self.count = count def shape(self): - return len(self.value) * self.count, False + return Shape(len(self.value) * self.count) def _rhs_signals(self): return self.value._rhs_signals() @@ -792,13 +833,7 @@ class Signal(Value, DUID): else: if not (min is None and max is None): raise ValueError("Only one of bits/signedness or bounds may be specified") - if isinstance(shape, int): - self.width, self.signed = shape, False - else: - self.width, self.signed = shape - - if not isinstance(self.width, int) or self.width < 0: - raise TypeError("Width must be a non-negative integer, not {!r}".format(self.width)) + self.width, self.signed = Shape.cast(shape) reset_width = bits_for(reset, self.signed) if reset != 0 and reset_width > self.width: @@ -829,14 +864,7 @@ class Signal(Value, DUID): That is, for any given ``range(*args)``, ``Signal.range(*args)`` can represent any ``x for x in range(*args)``. """ - value_range = range(*args) - if len(value_range) > 0: - signed = value_range.start < 0 or (value_range.stop - value_range.step) < 0 - else: - signed = value_range.start < 0 - width = max(bits_for(value_range.start, signed), - bits_for(value_range.stop - value_range.step, signed)) - return cls((width, signed), src_loc_at=1 + src_loc_at, **kwargs) + return cls(Shape.cast(range(*args)), src_loc_at=1 + src_loc_at, **kwargs) @classmethod def enum(cls, enum_type, *, src_loc_at=0, **kwargs): @@ -849,7 +877,7 @@ class Signal(Value, DUID): """ if not issubclass(enum_type, Enum): raise TypeError("Type {!r} is not an enumeration") - return cls(_enum_shape(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs) + return cls(Shape.cast(enum_type), src_loc_at=1 + src_loc_at, decoder=enum_type, **kwargs) @classmethod def like(cls, other, *, name=None, name_suffix=None, src_loc_at=0, **kwargs): @@ -885,7 +913,7 @@ class Signal(Value, DUID): self.width = value def shape(self): - return self.width, self.signed + return Shape(self.width, self.signed) def _lhs_signals(self): return ValueSet((self,)) @@ -919,7 +947,7 @@ class ClockSignal(Value): self.domain = domain def shape(self): - return 1, False + return Shape(1) def _lhs_signals(self): return ValueSet((self,)) @@ -956,7 +984,7 @@ class ResetSignal(Value): self.allow_reset_less = allow_reset_less def shape(self): - return 1, False + return Shape(1) def _lhs_signals(self): return ValueSet((self,)) @@ -1077,7 +1105,7 @@ class ArrayProxy(Value): for elem_width, elem_signed in (elem.shape() for elem in self._iter_as_values()): width = max(width, elem_width + elem_signed) signed = max(signed, elem_signed) - return width, signed + return Shape(width, signed) def _lhs_signals(self): signals = union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet()) @@ -1195,7 +1223,7 @@ class Initial(Value): super().__init__(src_loc_at=1 + src_loc_at) def shape(self): - return (1, False) + return Shape(1) def _rhs_signals(self): return ValueSet((self,)) diff --git a/nmigen/hdl/rec.py b/nmigen/hdl/rec.py index 9ad5ba7..3cd2746 100644 --- a/nmigen/hdl/rec.py +++ b/nmigen/hdl/rec.py @@ -5,7 +5,6 @@ from functools import reduce from .. import tracer from ..tools import union from .ast import * -from .ast import _enum_shape __all__ = ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"] @@ -46,17 +45,16 @@ class Layout: if not isinstance(name, str): raise TypeError("Field {!r} has invalid name: should be a string" .format(field)) - if isinstance(shape, type) and issubclass(shape, Enum): - shape = _enum_shape(shape) - if not isinstance(shape, (int, tuple, Layout)): - raise TypeError("Field {!r} has invalid shape: should be an int, tuple, Enum, or " - "list of fields of a nested record" - .format(field)) + if not isinstance(shape, Layout): + try: + shape = Shape.cast(shape) + except Exception as error: + raise TypeError("Field {!r} has invalid shape: should be castable to Shape " + "or a list of fields of a nested record" + .format(field)) if name in self.fields: raise NameError("Field {!r} has a name that is already present in the layout" .format(field)) - if isinstance(shape, int): - shape = (shape, False) self.fields[name] = (shape, direction) def __getitem__(self, item): @@ -159,7 +157,7 @@ class Record(Value): return super().__getitem__(item) def shape(self): - return sum(len(f) for f in self.fields.values()), False + return Shape(sum(len(f) for f in self.fields.values())) def _lhs_signals(self): return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet()) diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index 653d0f7..9022ec0 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -22,6 +22,105 @@ class StringEnum(Enum): BAR = "b" +class ShapeTestCase(FHDLTestCase): + def test_make(self): + s1 = Shape() + self.assertEqual(s1.width, 1) + self.assertEqual(s1.signed, False) + s2 = Shape(signed=True) + self.assertEqual(s2.width, 1) + self.assertEqual(s2.signed, True) + s3 = Shape(3, True) + self.assertEqual(s3.width, 3) + self.assertEqual(s3.signed, True) + + def test_make_wrong(self): + with self.assertRaises(TypeError, + msg="Width must be a non-negative integer, not -1"): + Shape(-1) + + def test_tuple(self): + width, signed = Shape() + self.assertEqual(width, 1) + self.assertEqual(signed, False) + + def test_unsigned(self): + s1 = unsigned(2) + self.assertIsInstance(s1, Shape) + self.assertEqual(s1.width, 2) + self.assertEqual(s1.signed, False) + + def test_signed(self): + s1 = signed(2) + self.assertIsInstance(s1, Shape) + self.assertEqual(s1.width, 2) + self.assertEqual(s1.signed, True) + + def test_cast_int(self): + s1 = Shape.cast(2) + self.assertEqual(s1.width, 2) + self.assertEqual(s1.signed, False) + + def test_cast_int_wrong(self): + with self.assertRaises(TypeError, + msg="Width must be a non-negative integer, not -1"): + Shape.cast(-1) + + def test_cast_tuple(self): + s1 = Shape.cast((1, False)) + self.assertEqual(s1.width, 1) + self.assertEqual(s1.signed, False) + s2 = Shape.cast((3, True)) + self.assertEqual(s2.width, 3) + self.assertEqual(s2.signed, True) + + def test_cast_tuple_wrong(self): + with self.assertRaises(TypeError, + msg="Width must be a non-negative integer, not -1"): + Shape.cast((-1, True)) + + def test_cast_range(self): + s1 = Shape.cast(range(0, 8)) + self.assertEqual(s1.width, 3) + self.assertEqual(s1.signed, False) + s2 = Shape.cast(range(0, 9)) + self.assertEqual(s2.width, 4) + self.assertEqual(s2.signed, False) + s3 = Shape.cast(range(-7, 8)) + self.assertEqual(s3.width, 4) + self.assertEqual(s3.signed, True) + s4 = Shape.cast(range(0, 1)) + self.assertEqual(s4.width, 1) + self.assertEqual(s4.signed, False) + s5 = Shape.cast(range(-1, 0)) + self.assertEqual(s5.width, 1) + self.assertEqual(s5.signed, True) + s6 = Shape.cast(range(0, 0)) + self.assertEqual(s6.width, 0) + self.assertEqual(s6.signed, False) + s7 = Shape.cast(range(-1, -1)) + self.assertEqual(s7.width, 0) + self.assertEqual(s7.signed, True) + + def test_cast_enum(self): + s1 = Shape.cast(UnsignedEnum) + self.assertEqual(s1.width, 2) + self.assertEqual(s1.signed, False) + s2 = Shape.cast(SignedEnum) + self.assertEqual(s2.width, 2) + self.assertEqual(s2.signed, True) + + def test_cast_enum_bad(self): + with self.assertRaises(TypeError, + msg="Only enumerations with integer values can be used as value shapes"): + Shape.cast(StringEnum) + + def test_cast_bad(self): + with self.assertRaises(TypeError, + msg="Object 'foo' cannot be used as value shape"): + Shape.cast("foo") + + class ValueTestCase(FHDLTestCase): def test_cast(self): self.assertIsInstance(Value.cast(0), Const) @@ -29,7 +128,7 @@ class ValueTestCase(FHDLTestCase): c = Const(0) self.assertIs(Value.cast(c), c) with self.assertRaises(TypeError, - msg="Object 'str' is not an nMigen value"): + msg="Object 'str' cannot be converted to an nMigen value"): Value.cast("str") def test_cast_enum(self): @@ -42,7 +141,7 @@ class ValueTestCase(FHDLTestCase): def test_cast_enum_wrong(self): with self.assertRaises(TypeError, - msg="Only enumerations with integer values can be converted to nMigen values"): + msg="Only enumerations with integer values can be used as value shapes"): Value.cast(StringEnum.FOO) def test_bool(self): @@ -97,11 +196,13 @@ class ValueTestCase(FHDLTestCase): class ConstTestCase(FHDLTestCase): def test_shape(self): self.assertEqual(Const(0).shape(), (1, False)) + self.assertIsInstance(Const(0).shape(), Shape) self.assertEqual(Const(1).shape(), (1, False)) self.assertEqual(Const(10).shape(), (4, False)) self.assertEqual(Const(-10).shape(), (5, True)) self.assertEqual(Const(1, 4).shape(), (4, False)) + self.assertEqual(Const(-1, 4).shape(), (4, True)) self.assertEqual(Const(1, (4, True)).shape(), (4, True)) self.assertEqual(Const(0, (0, False)).shape(), (0, False)) @@ -380,6 +481,7 @@ class SliceTestCase(FHDLTestCase): def test_shape(self): s1 = Const(10)[2] self.assertEqual(s1.shape(), (1, False)) + self.assertIsInstance(s1.shape(), Shape) s2 = Const(-10)[0:2] self.assertEqual(s2.shape(), (2, False)) @@ -423,6 +525,7 @@ class BitSelectTestCase(FHDLTestCase): def test_shape(self): s1 = self.c.bit_select(self.s, 2) self.assertEqual(s1.shape(), (2, False)) + self.assertIsInstance(s1.shape(), Shape) s2 = self.c.bit_select(self.s, 0) self.assertEqual(s2.shape(), (0, False)) @@ -447,6 +550,7 @@ class WordSelectTestCase(FHDLTestCase): def test_shape(self): s1 = self.c.word_select(self.s, 2) self.assertEqual(s1.shape(), (2, False)) + self.assertIsInstance(s1.shape(), Shape) def test_stride(self): s1 = self.c.word_select(self.s, 2) @@ -467,6 +571,7 @@ class CatTestCase(FHDLTestCase): def test_shape(self): c0 = Cat() self.assertEqual(c0.shape(), (0, False)) + self.assertIsInstance(c0.shape(), Shape) c1 = Cat(Const(10)) self.assertEqual(c1.shape(), (4, False)) c2 = Cat(Const(10), Const(1)) @@ -483,6 +588,7 @@ class ReplTestCase(FHDLTestCase): def test_shape(self): s1 = Repl(Const(10), 3) self.assertEqual(s1.shape(), (12, False)) + self.assertIsInstance(s1.shape(), Shape) s2 = Repl(Const(10), 0) self.assertEqual(s2.shape(), (0, False)) @@ -561,6 +667,7 @@ class SignalTestCase(FHDLTestCase): def test_shape(self): s1 = Signal() self.assertEqual(s1.shape(), (1, False)) + self.assertIsInstance(s1.shape(), Shape) s2 = Signal(2) self.assertEqual(s2.shape(), (2, False)) s3 = Signal((2, False)) @@ -578,7 +685,7 @@ class SignalTestCase(FHDLTestCase): s9 = Signal.range(-20, 16) self.assertEqual(s9.shape(), (6, True)) s10 = Signal.range(0) - self.assertEqual(s10.shape(), (1, False)) + self.assertEqual(s10.shape(), (0, False)) s11 = Signal.range(1) self.assertEqual(s11.shape(), (1, False)) # deprecated @@ -692,7 +799,9 @@ class ClockSignalTestCase(FHDLTestCase): ClockSignal(1) def test_shape(self): - self.assertEqual(ClockSignal().shape(), (1, False)) + s1 = ClockSignal() + self.assertEqual(s1.shape(), (1, False)) + self.assertIsInstance(s1.shape(), Shape) def test_repr(self): s1 = ClockSignal() @@ -716,7 +825,9 @@ class ResetSignalTestCase(FHDLTestCase): ResetSignal(1) def test_shape(self): - self.assertEqual(ResetSignal().shape(), (1, False)) + s1 = ResetSignal() + self.assertEqual(s1.shape(), (1, False)) + self.assertIsInstance(s1.shape(), Shape) def test_repr(self): s1 = ResetSignal() @@ -743,6 +854,7 @@ class UserValueTestCase(FHDLTestCase): def test_shape(self): uv = MockUserValue(1) self.assertEqual(uv.shape(), (1, False)) + self.assertIsInstance(uv.shape(), Shape) uv.lowered = 2 self.assertEqual(uv.shape(), (1, False)) self.assertEqual(uv.lower_count, 1) diff --git a/nmigen/test/test_hdl_rec.py b/nmigen/test/test_hdl_rec.py index bae4fb1..e3721ef 100644 --- a/nmigen/test/test_hdl_rec.py +++ b/nmigen/test/test_hdl_rec.py @@ -41,6 +41,12 @@ class LayoutTestCase(FHDLTestCase): self.assertEqual(layout["enum"], ((2, False), DIR_NONE)) self.assertEqual(layout["enum_dir"], ((2, False), DIR_FANOUT)) + def test_range_field(self): + layout = Layout.wrap([ + ("range", range(0, 7)), + ]) + self.assertEqual(layout["range"], ((3, False), DIR_NONE)) + def test_slice_tuple(self): layout = Layout.wrap([ ("a", 1), @@ -77,8 +83,8 @@ class LayoutTestCase(FHDLTestCase): def test_wrong_shape(self): with self.assertRaises(TypeError, - msg="Field ('a', 'x') has invalid shape: should be an int, tuple, Enum, or " - "list of fields of a nested record"): + msg="Field ('a', 'x') has invalid shape: should be castable to Shape or " + "a list of fields of a nested record"): Layout.wrap([("a", "x")])