insndb: holy-cow visitors
authorDmitry Selyutin <ghostmansd@gmail.com>
Sat, 10 Jun 2023 15:40:47 +0000 (18:40 +0300)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 22 Dec 2023 19:26:19 +0000 (19:26 +0000)
src/openpower/insndb/core.py

index 05042d2d2ed39163444c0533ac63c2acc4c12973..a451fc38c905c9924b5e6923839fc621e301e0e3 100644 (file)
@@ -106,10 +106,67 @@ class Dataclass(metaclass=DataclassMeta):
         yield from filter(match, map(field, _dataclasses.fields(clsself)))
 
 
-class Visitor:
+class VisitorMethod:
+    def __init__(self, nodecls, method):
+        self.__nodecls = nodecls
+        self.__method = method
+        return super().__init__()
+
+    @property
+    def type(self):
+        return self.__nodecls
+
     @_contextlib.contextmanager
     def __call__(self, node):
-        yield node
+        return self.__method(self=self, node=node)
+
+
+class VisitorMeta(type):
+    def __init__(cls, name, bases, ns):
+        cls.__registry = {}
+        for (key, value) in ns.items():
+            if isinstance(value, VisitorMethod):
+                if value.type in cls.__registry:
+                    raise AttributeError(f"overriding visitor method: {key!r}")
+                cls.__registry[value.type] = value
+        return super().__init__(name, bases, ns)
+
+    def __contains__(self, nodecls):
+        return self.__registry.__contains__(nodecls)
+
+    def __getitem__(self, nodecls):
+        return self.__registry.__getitem__(nodecls)
+
+    def __setitem__(self, nodecls, call):
+        return self.__registry.__setitem__(nodecls, call)
+
+    def __iter__(self):
+        yield from self.__registry.items()
+
+
+class Visitor(metaclass=VisitorMeta):
+    @_contextlib.contextmanager
+    def __call__(self, node):
+        (visitorcls, nodecls) = map(type, (self, node))
+        if nodecls in visitorcls:
+            handler = visitorcls[nodecls]
+            with handler(node=node) as ctx:
+                yield ctx
+        else:
+            yield node
+
+
+class visitormethod:
+    def __init__(self, nodecls):
+        if not isinstance(nodecls, type):
+            raise ValueError(nodecls)
+        self.__nodecls = nodecls
+        return super().__init__()
+
+    def __call__(self, method):
+        if not callable(method):
+            raise ValueError(method)
+        return VisitorMethod(nodecls=self.__nodecls, method=method)
 
 
 def walk(root, match=None):