insndb/core: switch to mdis
authorDmitry Selyutin <ghostmansd@gmail.com>
Thu, 22 Jun 2023 15:23:28 +0000 (18:23 +0300)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 22 Dec 2023 19:26:19 +0000 (19:26 +0000)
src/openpower/insndb/core.py
src/openpower/insndb/db.py

index 3d539e06ae2ea8a40b4265ff3690a5efa3d3b591..01a029d12fb0dac9d85f201ed4e7039166f17d25 100644 (file)
@@ -57,169 +57,14 @@ from openpower.decoder.power_fields import (
 from openpower.decoder.pseudo.pagereader import ISA as _ISA
 
 
-class walkmethod:
-    def __init__(self, walk):
-        self.__walk = walk
-        return super().__init__()
-
-    def __get__(self, instance, owner):
-        entity = owner if instance is None else instance
-        return _functools.partial(self.__walk, entity)
-
-
-class Node:
-    @walkmethod
-    def walk(clsself, match=lambda _: True):
-        yield from ()
-
-
 class DataclassMeta(type):
     def __new__(metacls, name, bases, ns):
         cls = super().__new__(metacls, name, bases, ns)
         return _dataclasses.dataclass(cls, eq=True, frozen=True)
 
 
-class Dataclass(Node, metaclass=DataclassMeta):
-    @walkmethod
-    def walk(clsself, match=lambda _: True):
-        def field_type(field):
-            return field.type
-
-        def field_value(field):
-            return getattr(clsself, field.name)
-
-        field_node = (field_type if isinstance(clsself, type) else field_value)
-
-        for field in _dataclasses.fields(clsself):
-            path = field.name
-            node = field_node(field)
-            yield (path, node)
-
-
-class Tuple(Node, tuple):
-    def __init_subclass__(cls, datatype):
-        cls.__datatype = datatype
-        return super().__init_subclass__()
-
-    @walkmethod
-    def walk(clsself, match=lambda _: True):
-        if isinstance(clsself, type):
-            yield ("[]", clsself.__datatype)
-        else:
-            for (index, item) in enumerate(filter(match, clsself)):
-                yield (str(index), item)
-
-
-class Dict(Node, dict):
-    def __init_subclass__(cls, datatype):
-        cls.__datatype = datatype
-        return super().__init_subclass__()
-
-    def __hash__(self):
-        return hash(tuple(sorted(self.items())))
-
-    def clear(self):
-        raise NotImplementedError()
-
-    def __delitem__(self, key):
-        raise NotImplementedError()
-
-    def __setitem__(self, key, value):
-        raise NotImplementedError()
-
-    def popitem(self) -> tuple:
-        raise NotImplementedError()
-
-    def pop(self, key, default=None):
-        raise NotImplementedError()
-
-    def update(self, entry, **kwargs):
-        raise NotImplementedError()
-
-    @walkmethod
-    def walk(clsself, match=lambda _: True):
-        if isinstance(clsself, type):
-            yield ("{}", clsself.__datatype)
-        else:
-            yield from filter(lambda kv: match(kv[0]), clsself.items())
-
-
-class VisitorMethod:
-    def __init__(self, nodecls, method):
-        self.__nodecls = nodecls
-        self.__method = method
-        return super().__init__()
-
-    @property
-    def type(self):
-        return self.__nodecls
-
-    @_contextlib.contextmanager
-    def __call__(self, path, node):
-        return self.__method(self=self, path=path, node=node)
-
-
-class VisitorMeta(type):
-    def __init__(cls, name, bases, ns):
-        cls.__registry = {}
-        for (key, value) in ns.items():
-            if isinstance(value, VisitorMethod):
-                if value.type in cls.__registry:
-                    raise AttributeError(f"overriding visitor method: {key!r}")
-                cls.__registry[value.type] = value
-        return super().__init__(name, bases, ns)
-
-    def __contains__(self, nodecls):
-        return self.__registry.__contains__(nodecls)
-
-    def __getitem__(self, nodecls):
-        return self.__registry.__getitem__(nodecls)
-
-    def __setitem__(self, nodecls, call):
-        return self.__registry.__setitem__(nodecls, call)
-
-    def __iter__(self):
-        yield from self.__registry.items()
-
-
-class Visitor(metaclass=VisitorMeta):
-    @_contextlib.contextmanager
-    def __call__(self, path, node):
-        (visitorcls, nodecls) = map(type, (self, node))
-        if nodecls in visitorcls:
-            handler = visitorcls[nodecls]
-            with handler(path=path, node=node) as ctx:
-                yield ctx
-        else:
-            yield node
-
-
-class visitormethod:
-    def __init__(self, nodecls):
-        if not isinstance(nodecls, type):
-            raise ValueError(nodecls)
-        self.__nodecls = nodecls
-        return super().__init__()
-
-    def __call__(self, method):
-        if not callable(method):
-            raise ValueError(method)
-        return VisitorMethod(nodecls=self.__nodecls, method=method)
-
-
-def walk(root, match=lambda _: True):
-    pairs = _collections.deque([root])
-    while pairs:
-        (path, node) = pairs.popleft()
-        pairs.extend(node.walk(match=match))
-        yield (path, node)
-
-
-def visit(visitor, node, path="/", match=lambda _: True):
-    with visitor(path=path, node=node):
-        if isinstance(node, Node):
-            for (subpath, subnode) in node.walk(match=match):
-                visit(visitor=visitor, path=subpath, node=subnode)
+class Dataclass(metaclass=DataclassMeta):
+    pass
 
 
 @_functools.total_ordering
@@ -396,7 +241,7 @@ class PPCRecord(Dataclass):
                 "sgl pipe",
             )
 
-    class Flags(Tuple, datatype=str, metaclass=FlagsMeta):
+    class Flags(tuple, metaclass=FlagsMeta):
         def __new__(cls, flags=frozenset()):
             flags = frozenset(flags)
             diff = (flags - frozenset(cls))
@@ -468,7 +313,7 @@ class PPCRecord(Dataclass):
         return frozenset(self.comment.split("=")[-1].split("/"))
 
 
-class PPCMultiRecord(Tuple, datatype=PPCRecord):
+class PPCMultiRecord(tuple):
     def __getattr__(self, attr):
         if attr == "opcode":
             if len(self) != 1:
@@ -750,7 +595,7 @@ class Section(Dataclass):
         return dataclass(cls, record, typemap=typemap, keymap=keymap)
 
 
-class Fields(Dict, datatype=type("Bits", (Tuple,), {}, datatype=int)):
+class Fields(dict):
     def __init__(self, items):
         if isinstance(items, dict):
             items = items.items()
@@ -763,11 +608,14 @@ class Fields(Dict, datatype=type("Bits", (Tuple,), {}, datatype=int)):
 
         return super().__init__(mapping)
 
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
+
     def __iter__(self):
         yield from self.__mapping.items()
 
 
-class Operands:
+class Operands(dict):
     __GPR_PAIRS = (
         _Reg.RTp,
         _Reg.RSp,
@@ -825,10 +673,10 @@ class Operands:
 
             if "=" in operand:
                 (name, value) = operand.split("=")
-                mapping[name] = (StaticOperand, {
-                    "name": name,
-                    "value": int(value),
-                })
+                mapping[name] = (StaticOperand, (
+                    ("name", name),
+                    ("value", int(value)),
+                ))
             else:
                 name = operand
                 if name.endswith(")"):
@@ -862,48 +710,29 @@ class Operands:
                             cls = CR5Operand
 
                 if imm_name is not None:
-                    mapping[imm_name] = (imm_cls, {"name": imm_name})
-                mapping[name] = (cls, {"name": name})
-
-        static = []
-        dynamic = []
-        for (name, (cls, kwargs)) in mapping.items():
-            kwargs = dict(kwargs)
-            kwargs["name"] = name
-            if issubclass(cls, StaticOperand):
-                static.append((cls, kwargs))
-            elif issubclass(cls, DynamicOperand):
-                dynamic.append((cls, kwargs))
-            else:
-                raise ValueError(name)
+                    mapping[imm_name] = (imm_cls, (
+                        ("name", imm_name),
+                    ))
+                mapping[name] = (cls, (
+                    ("name", name),
+                ))
 
-        self.__mapping = mapping
-        self.__static = tuple(static)
-        self.__dynamic = tuple(dynamic)
-
-        return super().__init__()
+        return super().__init__(mapping)
 
     def __iter__(self):
-        for (_, items) in self.__mapping.items():
-            (cls, kwargs) = items
-            yield (cls, kwargs)
-
-    def __repr__(self):
-        return self.__mapping.__repr__()
-
-    def __contains__(self, key):
-        return self.__mapping.__contains__(key)
+        for (cls, kwargs) in self.values():
+            yield (cls, dict(kwargs))
 
-    def __getitem__(self, key):
-        return self.__mapping.__getitem__(key)
+    def __hash__(self):
+        return hash(tuple(sorted(self.items())))
 
-    @property
+    @cached_property
     def static(self):
-        return self.__static
+        return tuple(filter(lambda pair: issubclass(pair[0], StaticOperand), self))
 
-    @property
+    @cached_property
     def dynamic(self):
-        return self.__dynamic
+        return tuple(filter(lambda pair: issubclass(pair[0], DynamicOperand), self))
 
 
 class Arguments(tuple):
@@ -943,7 +772,7 @@ class Arguments(tuple):
         return super().__new__(cls, items)
 
 
-class PCode(Tuple, datatype=str):
+class PCode(tuple):
     pass
 
 
@@ -1127,7 +956,7 @@ class Record(Dataclass):
 
     def __getitem__(self, key):
         (cls, kwargs) = self.mdwn.operands.__getitem__(key)
-        return cls(record=self, **kwargs)
+        return cls(record=self, **dict(kwargs))
 
     @cached_property
     def Rc(self):
@@ -3646,7 +3475,7 @@ class MarkdownDatabase:
                 (dynamic, *static) = desc.regs
                 operands.extend(dynamic)
                 operands.extend(static)
-            pcode = PCode(iterable=filter(str.strip, desc.pcode))
+            pcode = PCode(filter(str.strip, desc.pcode))
             operands = Operands(insn=name, operands=operands)
             db[name] = MarkdownRecord(pcode=pcode, operands=operands)
 
@@ -3808,12 +3637,12 @@ class SVP64Database:
         return None
 
 
-class Records(Tuple, datatype=Record):
+class Records(tuple):
     def __new__(cls, records):
         return super().__new__(cls, sorted(records))
 
 
-class Database(Node):
+class Database:
     def __init__(self, root):
         root = _pathlib.Path(root)
         mdwndb = MarkdownDatabase()
@@ -3847,14 +3676,6 @@ class Database(Node):
 
         return super().__init__()
 
-    @walkmethod
-    def walk(clsself, match=lambda _: True):
-        if isinstance(clsself, type):
-            yield ("records", Records)
-        else:
-            if match(clsself.__db):
-                yield ("records", clsself.__db)
-
     def __repr__(self):
         return repr(self.__db)
 
index fa5eb267ea7ae14f442183b260c2580b16d56a80..d18df31de78ec019bef230f61b0bd3d0a531d59e 100644 (file)
@@ -3,19 +3,19 @@ import contextlib
 import os
 import types
 
+import mdis.dispatcher
+import mdis.visitor
+import mdis.walker
+
 from openpower.decoder.power_enums import (
     find_wiki_dir,
 )
 from openpower.insndb.core import (
     Database,
-    Dataclass,
-    Dict,
+    PCode,
+    Operands,
     Record,
-    Records,
-    Tuple,
-    Visitor,
-    visit,
-    visitormethod,
+    Walker,
 )
 
 
@@ -42,34 +42,16 @@ class SVP64Instruction(Instruction):
         return self
 
 
-class TreeVisitor(Visitor):
-    def __init__(self):
-        self.__depth = 0
-        self.__path = [""]
-        return super().__init__()
-
+class ListVisitor(mdis.visitor.ContextVisitor):
+    @mdis.dispatcher.Hook(Record)
     @contextlib.contextmanager
-    def __call__(self, path, node):
-        with super().__call__(path=path, node=node):
-            self.__path.append(path)
-            print("/".join(self.__path))
-            if not isinstance(node, (Dataclass, Tuple, Dict)):
-                print("    ", repr(node), sep="")
-            self.__depth += 1
-            yield node
-            self.__path.pop(-1)
-            self.__depth -= 1
-
-
-class ListVisitor(Visitor):
-    @visitormethod(Record)
-    def Record(self, path, node):
+    def dispatch_record(self, node):
         print(node.name)
         yield node
 
 
 # No use other than checking issubclass and adding an argument.
-class InstructionVisitor(Visitor):
+class InstructionVisitor(mdis.visitor.ContextVisitor):
     pass
 
 class SVP64InstructionVisitor(InstructionVisitor):
@@ -77,38 +59,47 @@ class SVP64InstructionVisitor(InstructionVisitor):
 
 
 class OpcodesVisitor(InstructionVisitor):
-    @visitormethod(Record)
-    def Record(self, path, node):
+    @mdis.dispatcher.Hook(Record)
+    @contextlib.contextmanager
+    def dispatch_record(self, node):
         for opcode in node.opcodes:
             print(opcode)
         yield node
 
 
 class OperandsVisitor(InstructionVisitor):
-    @visitormethod(Record)
-    def Record(self, path, node):
-        if isinstance(node, Record):
-            for operand in node.dynamic_operands:
-                print(operand.name, ",".join(map(str, operand.span)))
-            for operand in node.static_operands:
-                if operand.name not in ("PO", "XO"):
-                    desc = f"{operand.name}={operand.value}"
-                    print(desc, ",".join(map(str, operand.span)))
+    def __init__(self):
+        self.__record = None
+        return super().__init__()
+
+    @mdis.dispatcher.Hook(Record)
+    @contextlib.contextmanager
+    def dispatch_record(self, node):
+        self.__record = node
+        yield node
+
+    @mdis.dispatcher.Hook(Operands)
+    @contextlib.contextmanager
+    def dispatch_operands(self, node):
+        for (cls, kwargs) in node:
+            operand = cls(record=self.__record, **kwargs)
+            print(operand.name, ", ".join(map(str, operand.span)))
         yield node
 
 
 class PCodeVisitor(InstructionVisitor):
-    @visitormethod(Record)
-    def Record(self, path, node):
-        if isinstance(node, Record):
-            for line in node.pcode:
-                print(line)
+    @mdis.dispatcher.Hook(PCode)
+    @contextlib.contextmanager
+    def dispatch_record(self, node):
+        for line in node:
+            print(line)
         yield node
 
 
 class ExtrasVisitor(SVP64InstructionVisitor):
-    @visitormethod(Record)
-    def Record(self, path, node):
+    @mdis.dispatcher.Hook(Record)
+    @contextlib.contextmanager
+    def dispatch_record(self, node):
         for (name, extra) in node.extras.items():
             print(name)
             print("    sel", extra["sel"])
@@ -120,10 +111,6 @@ class ExtrasVisitor(SVP64InstructionVisitor):
 
 def main():
     commands = {
-        "tree": (
-            TreeVisitor,
-            "list all records",
-        ),
         "list": (
             ListVisitor,
             "list available instructions",
@@ -171,16 +158,15 @@ def main():
     visitor = commands[command][0]()
 
     db = Database(find_wiki_dir())
-    (path, records) = next(db.walk(match=lambda pair: isinstance(pair, Records)))
     if not isinstance(visitor, InstructionVisitor):
-        match = lambda _: True
+        root = db
     else:
-        insn = args.pop("insn")
-        def match(record):
-            return (isinstance(record, Record) and (record.name == insn))
+        root = [db[args.pop("insn")]]
 
-    for (subpath, node) in records.walk(match=match):
-        visit(visitor=visitor, node=node, path=subpath)
+    walker = Walker()
+    for (node, *_) in walker(root):
+        with visitor(node):
+            pass
 
 
 if __name__ == "__main__":