oppc: decouple common code
authorDmitry Selyutin <ghostmansd@gmail.com>
Tue, 9 Jan 2024 18:23:44 +0000 (21:23 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Tue, 9 Jan 2024 18:55:25 +0000 (21:55 +0300)
src/openpower/oppc/pc_pseudocode.py
src/openpower/oppc/pc_util.py [new file with mode: 0644]

index f6b5a56c1d0e5626e3636a652a59db1960679121..0d111a81fb8d6d9c744770bc0d32f8e709421b35 100644 (file)
@@ -1,69 +1,23 @@
 import collections
-import contextlib
-import functools
-
-import mdis.dispatcher
-import mdis.visitor
-import mdis.walker
 
 import openpower.oppc.pc_ast as pc_ast
+import openpower.oppc.pc_util as pc_util
 
 
-class Hook(mdis.dispatcher.Hook):
-    def __call__(self, call):
-        hook = super().__call__(call)
-
-        class ConcreteHook(hook.__class__):
-            @functools.wraps(hook.__call__)
-            @contextlib.contextmanager
-            def __call__(self, dispatcher, node, *args, **kwargs):
-                return hook(dispatcher, node, *args, **kwargs)
-
-        return ConcreteHook(*tuple(self))
-
-
-class Code(list):
-    def __init__(self):
-        self.__level = 0
-        return super().__init__()
-
-    def __enter__(self):
-        self.__level += 1
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_traceback):
-        self.__level -= 1
-
-    def __str__(self):
-        if len(self) == 0:
-            raise ValueError("empty code")
-
-        lines = []
-        for (level, stmt) in self:
-            line = ((" " * level * 4) + stmt)
-            lines.append(line)
-
-        return "\n".join(lines)
-
-    def emit(self, stmt, level=0):
-        item = ((level + self.__level), stmt)
-        self.append(item)
-
-
-class PseudocodeVisitor(mdis.visitor.ContextVisitor):
+class PseudocodeVisitor(pc_util.Visitor):
     def __init__(self, root):
+        self.__code = collections.defaultdict(lambda: pc_util.Code())
         self.__root = root
-        self.__code = collections.defaultdict(lambda: Code())
 
-        return super().__init__()
+        return super().__init__(root=root)
 
     def __iter__(self):
-        yield from self.__code.items()
+        yield from self.__code[self.__root]
 
     def __getitem__(self, node):
         return self.__code[node]
 
-    @Hook(pc_ast.Scope)
+    @pc_util.Hook(pc_ast.Scope)
     def Scope(self, node):
         yield node
         if node is not self.__root:
@@ -76,7 +30,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
                 for (level, stmt) in self[subnode]:
                     self[node].emit(stmt=stmt, level=level)
 
-    @Hook(pc_ast.Call)
+    @pc_util.Hook(pc_ast.Call)
     def Call(self, node):
         yield node
         args = []
@@ -88,7 +42,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         stmt = f"{node.name}({args})"
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
+    @pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
     def AssignExpr(self, node):
         mapping = {
             pc_ast.AssignExpr: "<-",
@@ -120,7 +74,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
             ])
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.BinaryExpr)
+    @pc_util.Hook(pc_ast.BinaryExpr)
     def BinaryExpr(self, node):
         yield node
         stmt = " ".join([
@@ -130,7 +84,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         ])
         self[node].emit(stmt=f"({stmt})")
 
-    @Hook(pc_ast.IfExpr)
+    @pc_util.Hook(pc_ast.IfExpr)
     def IfExpr(self, node):
         yield node
         stmt = " ".join([
@@ -146,7 +100,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
             for (level, stmt) in self[node.orelse]:
                 self[node].emit(stmt=stmt, level=level)
 
-    @Hook(pc_ast.ForExpr)
+    @pc_util.Hook(pc_ast.ForExpr)
     def ForExpr(self, node):
         yield node
         stmt = " ".join([
@@ -161,7 +115,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         for (level, stmt) in self[node.body]:
             self[node].emit(stmt=stmt, level=level)
 
-    @Hook(pc_ast.WhileExpr)
+    @pc_util.Hook(pc_ast.WhileExpr)
     def WhileExpr(self, node):
         yield node
         stmt = " ".join([
@@ -177,7 +131,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
             for (level, stmt) in self[node.orelse]:
                 self[node].emit(stmt=stmt, level=level)
 
-    @Hook(pc_ast.RepeatExpr)
+    @pc_util.Hook(pc_ast.RepeatExpr)
     def RepeatExpr(self, node):
         yield node
         stmt = " ".join([
@@ -187,7 +141,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         ])
         self[node].emit(stmt=f"({stmt})")
 
-    @Hook(pc_ast.SwitchExpr)
+    @pc_util.Hook(pc_ast.SwitchExpr)
     def SwitchExpr(self, node):
         yield node
         self[node].emit(f"switch({str(self[node.subject])})")
@@ -195,14 +149,14 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
             for (level, stmt) in self[node.cases]:
                 self[node].emit(stmt=stmt, level=level)
 
-    @Hook(pc_ast.Cases)
+    @pc_util.Hook(pc_ast.Cases)
     def Cases(self, node):
         yield node
         for subnode in node:
             for (level, stmt) in self[subnode]:
                 self[node].emit(stmt=stmt, level=level)
 
-    @Hook(pc_ast.Case)
+    @pc_util.Hook(pc_ast.Case)
     def Case(self, node):
         yield node
         for (level, stmt) in self[node.labels]:
@@ -210,7 +164,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         for (level, stmt) in self[node.body]:
             self[node].emit(stmt=stmt, level=level)
 
-    @Hook(pc_ast.Labels)
+    @pc_util.Hook(pc_ast.Labels)
     def Labels(self, node):
         yield node
         if ((len(node) == 1) and isinstance(node[-1], pc_ast.DefaultLabel)):
@@ -220,17 +174,17 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
             stmt = f"case ({labels}):"
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.Label)
+    @pc_util.Hook(pc_ast.Label)
     def Label(self, node):
         yield node
         self[node].emit(stmt=str(node))
 
-    @Hook(pc_ast.DefaultLabel)
+    @pc_util.Hook(pc_ast.DefaultLabel)
     def DefaultLabel(self, node):
         yield node
         self[node].emit(stmt="default:")
 
-    @Hook(pc_ast.UnaryExpr)
+    @pc_util.Hook(pc_ast.UnaryExpr)
     def UnaryExpr(self, node):
         yield node
         stmt = "".join([
@@ -239,22 +193,22 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         ])
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
+    @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
     def Integer(self, node):
         yield node
         self[node].emit(stmt=str(node))
 
-    @Hook(pc_ast.StringLiteral)
+    @pc_util.Hook(pc_ast.StringLiteral)
     def StringLiteral(self, node):
         yield node
         self[node].emit(stmt=f"'{str(node)}'")
 
-    @Hook(pc_ast.Symbol)
+    @pc_util.Hook(pc_ast.Symbol)
     def Symbol(self, node):
         yield node
         self[node].emit(stmt=str(node))
 
-    @Hook(pc_ast.Attribute)
+    @pc_util.Hook(pc_ast.Attribute)
     def Attribute(self, node):
         yield node
         stmt = ".".join([
@@ -263,7 +217,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         ])
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.Not, pc_ast.Add, pc_ast.Sub,
+    @pc_util.Hook(pc_ast.Not, pc_ast.Add, pc_ast.Sub,
             pc_ast.Mul, pc_ast.MulS, pc_ast.MulU,
             pc_ast.Div, pc_ast.DivT, pc_ast.Mod,
             pc_ast.Sqrt,
@@ -307,7 +261,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         stmt = mapping[node.__class__]
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.LParenthesis, pc_ast.RParenthesis,
+    @pc_util.Hook(pc_ast.LParenthesis, pc_ast.RParenthesis,
             pc_ast.LBracket, pc_ast.RBracket)
     def BracketOrParenthesis(self, node):
         yield node
@@ -320,7 +274,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         stmt = mapping[node.__class__]
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.Subscript)
+    @pc_util.Hook(pc_ast.Subscript)
     def Subscript(self, node):
         yield node
         stmt = "".join([
@@ -331,7 +285,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         ])
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.RangeSubscript)
+    @pc_util.Hook(pc_ast.RangeSubscript)
     def RangeSubscript(self, node):
         yield node
         stmt = "".join([
@@ -344,32 +298,32 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         ])
         self[node].emit(stmt=stmt)
 
-    @Hook(pc_ast.Colon)
+    @pc_util.Hook(pc_ast.Colon)
     def Colon(self, node):
         yield node
         self[node].emit(stmt=":")
 
-    @Hook(pc_ast.Linebreak, pc_ast.Endmarker)
+    @pc_util.Hook(pc_ast.Linebreak, pc_ast.Endmarker)
     def Ignore(self, node):
         yield node
 
-    @Hook(pc_ast.Keyword)
+    @pc_util.Hook(pc_ast.Keyword)
     def Keyword(self, node):
         yield node
         self[node].emit(stmt=node.__doc__)
 
-    @Hook(pc_ast.Sequence)
+    @pc_util.Hook(pc_ast.Sequence)
     def Sequence(self, node):
         yield node
         stmt = ",".join(map(lambda subnode: str(self[subnode]), node))
         self[node].emit(stmt=f"({stmt})")
 
-    @Hook(pc_ast.Literal)
+    @pc_util.Hook(pc_ast.Literal)
     def Literal(self, node):
         yield node
         self[node].emit(stmt=str(node))
 
-    @Hook(pc_ast.GPR, pc_ast.FPR, pc_ast.GPRZero)
+    @pc_util.Hook(pc_ast.GPR, pc_ast.FPR, pc_ast.GPRZero)
     def Reg(self, node):
         yield node
         if isinstance(node, pc_ast.GPRZero):
@@ -377,20 +331,10 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor):
         else:
             self[node].emit(stmt=f"({str(node)})")
 
-    @Hook(pc_ast.Node)
+    @pc_util.Hook(pc_ast.Node)
     def Node(self, node):
         raise NotImplementedError(type(node))
 
 
-def traverse(root, visitor, walker):
-    with visitor(root):
-        for node in walker(root):
-            traverse(root=node, visitor=visitor, walker=walker)
-
-
 def pseudocode(root):
-    walker = mdis.walker.Walker()
-    visitor = PseudocodeVisitor(root=root)
-    traverse(root=root, visitor=visitor, walker=walker)
-    for (level, stmt) in visitor[root]:
-        yield (level, stmt)
+    yield from PseudocodeVisitor(root=root)
diff --git a/src/openpower/oppc/pc_util.py b/src/openpower/oppc/pc_util.py
new file mode 100644 (file)
index 0000000..3127f64
--- /dev/null
@@ -0,0 +1,62 @@
+import contextlib
+import functools
+
+
+import mdis.visitor
+import mdis.walker
+import mdis.dispatcher
+
+
+class Hook(mdis.dispatcher.Hook):
+    def __call__(self, call):
+        hook = super().__call__(call)
+
+        class ConcreteHook(hook.__class__):
+            @functools.wraps(hook.__call__)
+            @contextlib.contextmanager
+            def __call__(self, dispatcher, node, *args, **kwargs):
+                return hook(dispatcher, node, *args, **kwargs)
+
+        return ConcreteHook(*tuple(self))
+
+
+class Code(list):
+    def __init__(self):
+        self.__level = 0
+        return super().__init__()
+
+    def __enter__(self):
+        self.__level += 1
+        return self
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        self.__level -= 1
+
+    def __str__(self):
+        if len(self) == 0:
+            raise ValueError("empty code")
+
+        lines = []
+        for (level, stmt) in self:
+            line = ((" " * level * 4) + stmt)
+            lines.append(line)
+
+        return "\n".join(lines)
+
+    def emit(self, stmt, level=0):
+        item = ((level + self.__level), stmt)
+        self.append(item)
+
+
+class Visitor(mdis.visitor.ContextVisitor):
+    def __init__(self, root):
+        self.__walker = mdis.walker.Walker()
+
+        self.traverse(root=root)
+
+        return super().__init__()
+
+    def traverse(self, root):
+        with self(root):
+            for node in self.__walker(root):
+                self.traverse(root=node)