hdl.ast: factor out _MappedKeyDict, _MappedKeySet. NFC.
authorwhitequark <cz@m-labs.hk>
Mon, 17 Dec 2018 17:13:08 +0000 (17:13 +0000)
committerwhitequark <cz@m-labs.hk>
Mon, 17 Dec 2018 17:21:29 +0000 (17:21 +0000)
nmigen/hdl/ast.py

index 9ad7b90394ac04aaec637901925dc02e3e7334dc..f9c111b936c4e4cf6544fd2ecf74b6db9bb6771f 100644 (file)
@@ -884,6 +884,92 @@ class Passive(Statement):
         return "(passive)"
 
 
+class _MappedKeyCollection(metaclass=ABCMeta):
+    @abstractmethod
+    def _map_key(self, key):
+        pass
+
+    @abstractmethod
+    def _unmap_key(self, key):
+        pass
+
+
+class _MappedKeyDict(MutableMapping, _MappedKeyCollection):
+    def __init__(self, pairs=()):
+        self._storage = OrderedDict()
+        for key, value in pairs:
+            self[key] = value
+
+    def __getitem__(self, key):
+        key = None if key is None else self._map_key(key)
+        return self._storage[key]
+
+    def __setitem__(self, key, value):
+        key = None if key is None else self._map_key(key)
+        self._storage[key] = value
+
+    def __delitem__(self, key):
+        key = None if key is None else self._map_key(key)
+        del self._storage[key]
+
+    def __iter__(self):
+        for key in self._storage:
+            if key is None:
+                yield None
+            else:
+                yield self._unmap_key(key)
+
+    def __eq__(self, other):
+        if not isinstance(other, type(self)):
+            return False
+        if len(self) != len(other):
+            return False
+        for ak, bk in zip(sorted(self._storage), sorted(other._storage)):
+            if ak != bk:
+                return False
+            if self._storage[ak] != other._storage[bk]:
+                return False
+        return True
+
+    def __len__(self):
+        return len(self._storage)
+
+    def __repr__(self):
+        pairs = ["({!r}, {!r})".format(k, v) for k, v in self.items()]
+        return "{}([{}])".format(type(self).__name__, ", ".join(pairs))
+
+
+class _MappedKeySet(MutableSet, _MappedKeyCollection):
+    def __init__(self, elements=()):
+        self._storage = OrderedDict()
+        for elem in elements:
+            self.add(elem)
+
+    def add(self, value):
+        self._storage[self._map_key(value)] = None
+
+    def update(self, values):
+        for value in values:
+            self.add(value)
+
+    def discard(self, value):
+        if value in self:
+            del self._storage[self._map_key(value)]
+
+    def __contains__(self, value):
+        return self._map_key(value) in self._storage
+
+    def __iter__(self):
+        for key in [k for k in self._storage]:
+            yield self._unmap_key(key)
+
+    def __len__(self):
+        return len(self._storage)
+
+    def __repr__(self):
+        return "{}({})".format(type(self).__name__, ", ".join(repr(x) for x in self))
+
+
 class ValueKey:
     def __init__(self, value):
         self.value = Value.wrap(value)
@@ -966,71 +1052,11 @@ class ValueKey:
         return "<{}.ValueKey {!r}>".format(__name__, self.value)
 
 
-class ValueDict(MutableMapping):
-    def __init__(self, pairs=()):
-        self._inner = dict()
-        for key, value in pairs:
-            self[key] = value
-
-    def __getitem__(self, key):
-        key = None if key is None else ValueKey(key)
-        return self._inner[key]
-
-    def __setitem__(self, key, value):
-        key = None if key is None else ValueKey(key)
-        self._inner[key] = value
-
-    def __delitem__(self, key):
-        key = None if key is None else ValueKey(key)
-        del self._inner[key]
+class ValueDict(_MappedKeyDict):
+    _map_key   = ValueKey
+    _unmap_key = lambda self, key: key.value
 
-    def __iter__(self):
-        return map(lambda x: None if x is None else x.value, sorted(self._inner))
-
-    def __eq__(self, other):
-        if not isinstance(other, ValueDict):
-            return False
-        if len(self) != len(other):
-            return False
-        for ak, bk in zip(self, other):
-            if ValueKey(ak) != ValueKey(bk):
-                return False
-            if self[ak] != other[bk]:
-                return False
-        return True
-
-    def __len__(self):
-        return len(self._inner)
 
-    def __repr__(self):
-        pairs = ["({!r}, {!r})".format(k, v) for k, v in self.items()]
-        return "ValueDict([{}])".format(", ".join(pairs))
-
-
-class ValueSet(MutableSet):
-    def __init__(self, elements=()):
-        self._inner = set()
-        for elem in elements:
-            self.add(elem)
-
-    def add(self, value):
-        self._inner.add(ValueKey(value))
-
-    def update(self, values):
-        for value in values:
-            self.add(value)
-
-    def discard(self, value):
-        self._inner.discard(ValueKey(value))
-
-    def __contains__(self, value):
-        return ValueKey(value) in self._inner
-
-    def __iter__(self):
-        return map(lambda x: x.value, sorted(self._inner))
-
-    def __len__(self):
-        return len(self._inner)
-
-    def __repr__(self):
-        return "ValueSet({})".format(", ".join(repr(x) for x in self))
+class ValueSet(_MappedKeySet):
+    _map_key   = ValueKey
+    _unmap_key = lambda self, key: key.value