From 3f78f9d559d21d7c92eb0fe0e851b9d585765f1a Mon Sep 17 00:00:00 2001 From: Dmitry Selyutin Date: Sat, 10 Jun 2023 18:40:47 +0300 Subject: [PATCH] insndb: holy-cow visitors --- src/openpower/insndb/core.py | 61 ++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/src/openpower/insndb/core.py b/src/openpower/insndb/core.py index 05042d2d..a451fc38 100644 --- a/src/openpower/insndb/core.py +++ b/src/openpower/insndb/core.py @@ -106,10 +106,67 @@ class Dataclass(metaclass=DataclassMeta): yield from filter(match, map(field, _dataclasses.fields(clsself))) -class Visitor: +class VisitorMethod: + def __init__(self, nodecls, method): + self.__nodecls = nodecls + self.__method = method + return super().__init__() + + @property + def type(self): + return self.__nodecls + @_contextlib.contextmanager def __call__(self, node): - yield node + return self.__method(self=self, node=node) + + +class VisitorMeta(type): + def __init__(cls, name, bases, ns): + cls.__registry = {} + for (key, value) in ns.items(): + if isinstance(value, VisitorMethod): + if value.type in cls.__registry: + raise AttributeError(f"overriding visitor method: {key!r}") + cls.__registry[value.type] = value + return super().__init__(name, bases, ns) + + def __contains__(self, nodecls): + return self.__registry.__contains__(nodecls) + + def __getitem__(self, nodecls): + return self.__registry.__getitem__(nodecls) + + def __setitem__(self, nodecls, call): + return self.__registry.__setitem__(nodecls, call) + + def __iter__(self): + yield from self.__registry.items() + + +class Visitor(metaclass=VisitorMeta): + @_contextlib.contextmanager + def __call__(self, node): + (visitorcls, nodecls) = map(type, (self, node)) + if nodecls in visitorcls: + handler = visitorcls[nodecls] + with handler(node=node) as ctx: + yield ctx + else: + yield node + + +class visitormethod: + def __init__(self, nodecls): + if not isinstance(nodecls, type): + raise ValueError(nodecls) + self.__nodecls = nodecls + return super().__init__() + + def __call__(self, method): + if not callable(method): + raise ValueError(method) + return VisitorMethod(nodecls=self.__nodecls, method=method) def walk(root, match=None): -- 2.30.2