From 5f07a672ec7eb1cf3206f87ddc00ea10f4430a24 Mon Sep 17 00:00:00 2001 From: Dmitry Selyutin Date: Wed, 7 Jun 2023 11:58:36 +0300 Subject: [PATCH] insndb: refactor visitors --- src/openpower/insndb/core.py | 56 +++++++++++------- src/openpower/insndb/db.py | 106 ++++++++++++----------------------- 2 files changed, 71 insertions(+), 91 deletions(-) diff --git a/src/openpower/insndb/core.py b/src/openpower/insndb/core.py index 511d695a..042b938d 100644 --- a/src/openpower/insndb/core.py +++ b/src/openpower/insndb/core.py @@ -56,18 +56,34 @@ from openpower.decoder.power_fields import ( from openpower.decoder.pseudo.pagereader import ISA as _ISA +class Node: + @property + def subnodes(self): + yield from () + + class Visitor: @_contextlib.contextmanager - def db(self, db): - yield db + def Node(self, node, depth): + yield node + for subnode in node.subnodes: + manager = subnode.__class__.__name__ + manager = getattr(self, manager, self.Node) + with manager(node=subnode, depth=(depth + 1)): + pass - @_contextlib.contextmanager - def record(self, record): - yield record + def __getattr__(self, attr): + return self.Node + + def __call__(self, node, depth): + manager = node.__class__.__name__ + manager = getattr(self, manager, self.Node) + return manager(node=node, depth=depth) - @_contextlib.contextmanager - def extra(self, extra): - yield extra + +def visit(visitor, node): + with visitor(node=node, depth=0): + pass @_functools.total_ordering @@ -824,7 +840,7 @@ class MarkdownRecord: @_dataclasses.dataclass(eq=True, frozen=True) -class Extra: +class Extra(Node): name: str sel: _typing.Union[ _In1Sel, _In2Sel, _In3Sel, _CRInSel, _CRIn2Sel, @@ -839,10 +855,9 @@ class Extra: pass - @_functools.total_ordering @_dataclasses.dataclass(eq=True, frozen=True) -class Record: +class Record(Node): name: str section: Section ppc: PPCRecord @@ -850,11 +865,10 @@ class Record: mdwn: MarkdownRecord svp64: SVP64Record = None - def visit(self, visitor): - with visitor.record(record=self) as record: - for (name, fields) in record.extras.items(): - extra = Extra(name=name, **fields) - extra.visit(visitor=visitor) + @property + def subnodes(self): + for (name, fields) in self.extras.items(): + yield Extra(name=name, **fields) @property def extras(self): @@ -3703,7 +3717,7 @@ class SVP64Database: return None -class Database: +class Database(Node): def __init__(self, root): root = _pathlib.Path(root) mdwndb = MarkdownDatabase() @@ -3737,10 +3751,10 @@ class Database: return super().__init__() - def visit(self, visitor): - with visitor.db(db=self) as db: - for record in self.__db: - record.visit(visitor=visitor) + @property + def subnodes(self): + for record in self.__db: + yield record def __repr__(self): return repr(self.__db) diff --git a/src/openpower/insndb/db.py b/src/openpower/insndb/db.py index aadd3b91..73f4abe2 100644 --- a/src/openpower/insndb/db.py +++ b/src/openpower/insndb/db.py @@ -9,6 +9,7 @@ from openpower.decoder.power_enums import ( from openpower.insndb.core import ( Database, Visitor, + visit, ) @@ -38,55 +39,27 @@ class SVP64Instruction(Instruction): class BaseVisitor(Visitor): def __init__(self, **arguments): self.__arguments = types.MappingProxyType(arguments) - self.__current_db = None - self.__current_record = None - self.__current_extra = None return super().__init__() - @property - def arguments(self): - return self.__arguments - - @property - def current_db(self): - return self.__current_db - - @property - def current_record(self): - return self.__current_record - - @property - def current_extra(self): - return self.__current_extra - - @contextlib.contextmanager - def db(self, db): - self.__current_db = db - yield db - self.__current_db = None - - @contextlib.contextmanager - def record(self, record): - self.__current_record = record - yield record - self.__current_record = None - - @contextlib.contextmanager - def extra(self, extra): - self.__current_extra = extra - yield extra - self.__current_extra = None + def __getitem__(self, argument): + return self.__arguments[argument] class ListVisitor(BaseVisitor): @contextlib.contextmanager - def record(self, record): - print(record.name) - yield record + def Record(self, node, depth): + print(node.name) + yield node class InstructionVisitor(BaseVisitor): - pass + @contextlib.contextmanager + def Database(self, node, depth): + yield node + for subnode in node.subnodes: + if subnode.name == self["insn"]: + with self(node=subnode, depth=(depth + 1)): + pass class SVP64InstructionVisitor(InstructionVisitor): @@ -95,48 +68,41 @@ class SVP64InstructionVisitor(InstructionVisitor): class OpcodesVisitor(InstructionVisitor): @contextlib.contextmanager - def record(self, record): - for opcode in record.opcodes: + def Record(self, node, depth): + for opcode in node.opcodes: print(opcode) + yield node class OperandsVisitor(InstructionVisitor): @contextlib.contextmanager - def record(self, record): - with super().record(record=record): - if self.current_record.name == self.arguments["insn"]: - for operand in record.dynamic_operands: - print(operand.name, ",".join(map(str, operand.span))) - for operand in record.static_operands: - if operand.name not in ("PO", "XO"): - desc = f"{operand.name}={operand.value}" - print(desc, ",".join(map(str, operand.span))) - - yield record + def Record(self, node, depth): + for operand in node.dynamic_operands: + print(operand.name, ",".join(map(str, operand.span))) + for operand in node.static_operands: + if operand.name not in ("PO", "XO"): + desc = f"{operand.name}={operand.value}" + print(desc, ",".join(map(str, operand.span))) + yield node class PCodeVisitor(InstructionVisitor): @contextlib.contextmanager - def record(self, record): - with super().record(record=record): - if self.current_record.name == self.arguments["insn"]: - for line in record.pcode: - print(line) + def Record(self, node, depth): + for line in node.pcode: + print(line) + yield node class ExtrasVisitor(SVP64InstructionVisitor): @contextlib.contextmanager - def extra(self, extra): - with super().extra(extra=extra) as extra: - if self.current_record.name == self.arguments["insn"]: - print(extra.name) - print(" sel", extra.sel) - print(" reg", extra.reg) - print(" seltype", extra.seltype) - print(" idx", extra.idx) - pass - - yield extra + def Extra(self, node, depth): + print(node.name) + print(" sel", node.sel) + print(" reg", node.reg) + print(" seltype", node.seltype) + print(" idx", node.idx) + yield node def main(): @@ -188,7 +154,7 @@ def main(): visitor = commands[command][0](**args) db = Database(find_wiki_dir()) - db.visit(visitor=visitor) + visit(visitor=visitor, node=db) if __name__ == "__main__": -- 2.30.2