From: whitequark Date: Mon, 17 Dec 2018 17:13:08 +0000 (+0000) Subject: hdl.ast: factor out _MappedKeyDict, _MappedKeySet. NFC. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=fcc5aff3b7ee5c9a3695c263fc5f032227ce3e2d;p=nmigen.git hdl.ast: factor out _MappedKeyDict, _MappedKeySet. NFC. --- diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 9ad7b90..f9c111b 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -884,6 +884,92 @@ class Passive(Statement): return "(passive)" +class _MappedKeyCollection(metaclass=ABCMeta): + @abstractmethod + def _map_key(self, key): + pass + + @abstractmethod + def _unmap_key(self, key): + pass + + +class _MappedKeyDict(MutableMapping, _MappedKeyCollection): + def __init__(self, pairs=()): + self._storage = OrderedDict() + for key, value in pairs: + self[key] = value + + def __getitem__(self, key): + key = None if key is None else self._map_key(key) + return self._storage[key] + + def __setitem__(self, key, value): + key = None if key is None else self._map_key(key) + self._storage[key] = value + + def __delitem__(self, key): + key = None if key is None else self._map_key(key) + del self._storage[key] + + def __iter__(self): + for key in self._storage: + if key is None: + yield None + else: + yield self._unmap_key(key) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if len(self) != len(other): + return False + for ak, bk in zip(sorted(self._storage), sorted(other._storage)): + if ak != bk: + return False + if self._storage[ak] != other._storage[bk]: + return False + return True + + def __len__(self): + return len(self._storage) + + def __repr__(self): + pairs = ["({!r}, {!r})".format(k, v) for k, v in self.items()] + return "{}([{}])".format(type(self).__name__, ", ".join(pairs)) + + +class _MappedKeySet(MutableSet, _MappedKeyCollection): + def __init__(self, elements=()): + self._storage = OrderedDict() + for elem in elements: + self.add(elem) + + def add(self, value): + self._storage[self._map_key(value)] = None + + def update(self, values): + for value in values: + self.add(value) + + def discard(self, value): + if value in self: + del self._storage[self._map_key(value)] + + def __contains__(self, value): + return self._map_key(value) in self._storage + + def __iter__(self): + for key in [k for k in self._storage]: + yield self._unmap_key(key) + + def __len__(self): + return len(self._storage) + + def __repr__(self): + return "{}({})".format(type(self).__name__, ", ".join(repr(x) for x in self)) + + class ValueKey: def __init__(self, value): self.value = Value.wrap(value) @@ -966,71 +1052,11 @@ class ValueKey: return "<{}.ValueKey {!r}>".format(__name__, self.value) -class ValueDict(MutableMapping): - def __init__(self, pairs=()): - self._inner = dict() - for key, value in pairs: - self[key] = value - - def __getitem__(self, key): - key = None if key is None else ValueKey(key) - return self._inner[key] - - def __setitem__(self, key, value): - key = None if key is None else ValueKey(key) - self._inner[key] = value - - def __delitem__(self, key): - key = None if key is None else ValueKey(key) - del self._inner[key] +class ValueDict(_MappedKeyDict): + _map_key = ValueKey + _unmap_key = lambda self, key: key.value - def __iter__(self): - return map(lambda x: None if x is None else x.value, sorted(self._inner)) - - def __eq__(self, other): - if not isinstance(other, ValueDict): - return False - if len(self) != len(other): - return False - for ak, bk in zip(self, other): - if ValueKey(ak) != ValueKey(bk): - return False - if self[ak] != other[bk]: - return False - return True - - def __len__(self): - return len(self._inner) - def __repr__(self): - pairs = ["({!r}, {!r})".format(k, v) for k, v in self.items()] - return "ValueDict([{}])".format(", ".join(pairs)) - - -class ValueSet(MutableSet): - def __init__(self, elements=()): - self._inner = set() - for elem in elements: - self.add(elem) - - def add(self, value): - self._inner.add(ValueKey(value)) - - def update(self, values): - for value in values: - self.add(value) - - def discard(self, value): - self._inner.discard(ValueKey(value)) - - def __contains__(self, value): - return ValueKey(value) in self._inner - - def __iter__(self): - return map(lambda x: x.value, sorted(self._inner)) - - def __len__(self): - return len(self._inner) - - def __repr__(self): - return "ValueSet({})".format(", ".join(repr(x) for x in self)) +class ValueSet(_MappedKeySet): + _map_key = ValueKey + _unmap_key = lambda self, key: key.value