yield from filter(match, map(field, _dataclasses.fields(clsself)))
-class Visitor:
+class VisitorMethod:
+ def __init__(self, nodecls, method):
+ self.__nodecls = nodecls
+ self.__method = method
+ return super().__init__()
+
+ @property
+ def type(self):
+ return self.__nodecls
+
@_contextlib.contextmanager
def __call__(self, node):
- yield node
+ return self.__method(self=self, node=node)
+
+
+class VisitorMeta(type):
+ def __init__(cls, name, bases, ns):
+ cls.__registry = {}
+ for (key, value) in ns.items():
+ if isinstance(value, VisitorMethod):
+ if value.type in cls.__registry:
+ raise AttributeError(f"overriding visitor method: {key!r}")
+ cls.__registry[value.type] = value
+ return super().__init__(name, bases, ns)
+
+ def __contains__(self, nodecls):
+ return self.__registry.__contains__(nodecls)
+
+ def __getitem__(self, nodecls):
+ return self.__registry.__getitem__(nodecls)
+
+ def __setitem__(self, nodecls, call):
+ return self.__registry.__setitem__(nodecls, call)
+
+ def __iter__(self):
+ yield from self.__registry.items()
+
+
+class Visitor(metaclass=VisitorMeta):
+ @_contextlib.contextmanager
+ def __call__(self, node):
+ (visitorcls, nodecls) = map(type, (self, node))
+ if nodecls in visitorcls:
+ handler = visitorcls[nodecls]
+ with handler(node=node) as ctx:
+ yield ctx
+ else:
+ yield node
+
+
+class visitormethod:
+ def __init__(self, nodecls):
+ if not isinstance(nodecls, type):
+ raise ValueError(nodecls)
+ self.__nodecls = nodecls
+ return super().__init__()
+
+ def __call__(self, method):
+ if not callable(method):
+ raise ValueError(method)
+ return VisitorMethod(nodecls=self.__nodecls, method=method)
def walk(root, match=None):