insndb: decouple visitors and walking
authorDmitry Selyutin <ghostmansd@gmail.com>
Thu, 8 Jun 2023 13:29:18 +0000 (16:29 +0300)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 22 Dec 2023 19:26:19 +0000 (19:26 +0000)
src/openpower/insndb/core.py
src/openpower/insndb/db.py

index 2876e787ec851fee5af3446cbf8284f7b0cf03e1..7c0fad5389b8f38635658455addbf5f463b2dd64 100644 (file)
@@ -56,34 +56,43 @@ from openpower.decoder.power_fields import (
 from openpower.decoder.pseudo.pagereader import ISA as _ISA
 
 
-class Visitor:
-    def __call__(self, node):
-        method = node.__class__.__name__
-        method = getattr(self, method, self.Node)
-        return method(node=node)
+class Node:
+    def subnodes(self, match=None):
+        return ()
 
-    @_contextlib.contextmanager
-    def Node(self, node):
-        for subnode in node.subnodes:
-            with self(subnode):
-                pass
-        yield node
 
+@_dataclasses.dataclass(eq=True, frozen=True)
+class Dataclass:
+    def subnodes(self, match=None):
+        if match is None:
+            match = lambda subnode: True
+
+        def subnode(field):
+            return getattr(self, field.name)
+
+        yield from filter(match, map(subnode, _dataclasses.fields()))
 
-class Node:
-    @property
-    def subnodes(self):
-        yield from ()
+
+class Visitor:
+    @_contextlib.contextmanager
+    def __call__(self, node):
+        yield node
 
 
-def walk(root):
+def walk(root, match=None):
     nodes = _collections.deque([root])
     while nodes:
         node = nodes.popleft()
-        nodes.extend(node.subnodes)
+        nodes.extend(node.subnodes(match=match))
         yield node
 
 
+def visit(visitor, node):
+    with visitor(node=node):
+        for subnode in node.subnodes():
+            visit(visitor=visitor, node=subnode)
+
+
 @_functools.total_ordering
 class Style(_enum.Enum):
     LEGACY = _enum.auto()
@@ -859,10 +868,11 @@ class Record(Node):
     mdwn: MarkdownRecord
     svp64: SVP64Record = None
 
-    @property
-    def subnodes(self):
+    def subnodes(self, match=None):
+        extras = []
         for (name, fields) in self.extras.items():
-            yield Extra(name=name, **fields)
+            extras.append(Extra(name=name, **fields))
+        yield from filter(match, extras)
 
     @property
     def extras(self):
@@ -3745,9 +3755,11 @@ class Database(Node):
 
         return super().__init__()
 
-    @property
-    def subnodes(self):
-        yield from self
+    def subnodes(self, match=None):
+        if match is None:
+            match = lambda subnode: True
+
+        yield from filter(match, self)
 
     def __repr__(self):
         return repr(self.__db)
index 748f60abb139f108f59110e9dc2f05868751ac74..2fce21b158efcc88b1ad990b167a9f551160beef 100644 (file)
@@ -8,7 +8,10 @@ from openpower.decoder.power_enums import (
 )
 from openpower.insndb.core import (
     Database,
+    Extra,
+    Record,
     Visitor,
+    visit,
 )
 
 
@@ -35,65 +38,54 @@ class SVP64Instruction(Instruction):
         return self
 
 
-class RecordNameVisitor(Visitor):
-    def __init__(self, name):
-        self.__name = name
-        self.__records = set()
-        return super().__init__()
-
-    @contextlib.contextmanager
-    def Record(self, node):
-        if node.name == self.__name:
-            self.__records.add(node)
-        yield node
-
-    def __iter__(self):
-        yield from self.__records
-
-
 class ListVisitor(Visitor):
     @contextlib.contextmanager
-    def Record(self, node):
-        print(node.name)
+    def __call__(self, node):
+        if isinstance(node, Record):
+            print(node.name)
         yield node
 
 
 class OpcodesVisitor(Visitor):
     @contextlib.contextmanager
-    def Record(self, node):
-        for opcode in node.opcodes:
-            print(opcode)
+    def __call__(self, node):
+        if isinstance(node, Record):
+            for opcode in node.opcodes:
+                print(opcode)
         yield node
 
 
 class OperandsVisitor(Visitor):
     @contextlib.contextmanager
-    def Record(self, node):
-        for operand in node.dynamic_operands:
-            print(operand.name, ",".join(map(str, operand.span)))
-        for operand in node.static_operands:
-            if operand.name not in ("PO", "XO"):
-                desc = f"{operand.name}={operand.value}"
-                print(desc, ",".join(map(str, operand.span)))
+    def __call__(self, node):
+        if isinstance(node, Record):
+            for operand in node.dynamic_operands:
+                print(operand.name, ",".join(map(str, operand.span)))
+            for operand in node.static_operands:
+                if operand.name not in ("PO", "XO"):
+                    desc = f"{operand.name}={operand.value}"
+                    print(desc, ",".join(map(str, operand.span)))
         yield node
 
 
 class PCodeVisitor(Visitor):
     @contextlib.contextmanager
-    def Record(self, node):
-        for line in node.pcode:
-            print(line)
+    def __call__(self, node):
+        if isinstance(node, Record):
+            for line in node.pcode:
+                print(line)
         yield node
 
 
 class ExtrasVisitor(Visitor):
     @contextlib.contextmanager
-    def Extra(self, node):
-        print(node.name)
-        print("    sel", node.sel)
-        print("    reg", node.reg)
-        print("    seltype", node.seltype)
-        print("    idx", node.idx)
+    def __call__(self, node):
+        if isinstance(node, Extra):
+            print(node.name)
+            print("    sel", node.sel)
+            print("    reg", node.reg)
+            print("    seltype", node.seltype)
+            print("    idx", node.idx)
         yield node
 
 
@@ -147,15 +139,14 @@ def main():
 
     db = Database(find_wiki_dir())
     if command in ("list",):
-        nodes = (db,)
+        match = None
     else:
-        match = RecordNameVisitor(name=args["insn"])
-        with match(node=db):
-            nodes = frozenset(match)
+        insn = args.pop("insn")
+        def match(record):
+            return (isinstance(record, Record) and (record.name == insn))
 
-    for node in nodes:
-        with visitor(node=node):
-            pass
+    for node in db.subnodes(match=match):
+        visit(visitor=visitor, node=node)
 
 
 if __name__ == "__main__":