hdl.ast: don't inherit Shape from NamedTuple.
[nmigen.git] / nmigen / hdl / ast.py
index 37a3b46f3847200aa08b24be05631004a78fe686..55b825b2ddcba471fccc8213d8fa0d254aa31609 100644 (file)
@@ -32,7 +32,7 @@ class DUID:
         DUID.__next_uid += 1
 
 
-class Shape(typing.NamedTuple):
+class Shape:
     """Bit width and signedness of a value.
 
     A ``Shape`` can be constructed using:
@@ -55,8 +55,15 @@ class Shape(typing.NamedTuple):
     signed : bool
         If ``False``, the value is unsigned. If ``True``, the value is signed two's complement.
     """
-    width:  int  = 1
-    signed: bool = False
+    def __init__(self, width=1, signed=False):
+        if not isinstance(width, int) or width < 0:
+            raise TypeError("Width must be a non-negative integer, not {!r}"
+                            .format(width))
+        self.width = width
+        self.signed = signed
+
+    def __iter__(self):
+        return iter((self.width, self.signed))
 
     @staticmethod
     def cast(obj, *, src_loc_at=0):
@@ -95,13 +102,20 @@ class Shape(typing.NamedTuple):
         else:
             return "unsigned({})".format(self.width)
 
-
-# TODO: use dataclasses instead of this hack
-def _Shape___init__(self, width=1, signed=False):
-    if not isinstance(width, int) or width < 0:
-        raise TypeError("Width must be a non-negative integer, not {!r}"
-                        .format(width))
-Shape.__init__ = _Shape___init__
+    def __eq__(self, other):
+        if isinstance(other, tuple) and len(other) == 2:
+            width, signed = other
+            if isinstance(width, int) and isinstance(signed, bool):
+                return self.width == width and self.signed == signed
+            else:
+                raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, "
+                        "not {!r}"
+                        .format(other))
+        if not isinstance(other, Shape):
+            raise TypeError("Shapes may be compared with other Shapes and (int, bool) tuples, "
+                    "not {!r}"
+                    .format(other))
+        return self.width == other.width and self.signed == other.signed
 
 
 def unsigned(width):