oppc/code: introduce ccall
authorDmitry Selyutin <ghostmansd@gmail.com>
Sun, 14 Jan 2024 12:24:39 +0000 (15:24 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Tue, 16 Jan 2024 19:10:07 +0000 (22:10 +0300)
src/openpower/oppc/pc_code.py

index 7b37a2944a13ed3fdeace1e3571225c9510684f2..f351f5b825678fc4632807e01c13ad0abdf19cff 100644 (file)
@@ -17,6 +17,12 @@ class Transient(pc_ast.Node):
         return f"oppc_transient(&(struct oppc_int){{}}, {self.__value}, {self.__bits})"
 
 
+class CCall(pc_ast.Dataclass):
+    name: str
+    code: tuple
+    stmt: bool
+
+
 class CodeVisitor(pc_util.Visitor):
     def __init__(self, name, root):
         self.__root = root
@@ -51,29 +57,30 @@ class CodeVisitor(pc_util.Visitor):
             self.traverse(root=transient)
         return transient
 
+    def ccall(self, node, name, code, stmt=False):
+        def validate(item):
+            def validate(item):
+                (level, stmt) = item
+                if not isinstance(level, int):
+                    raise ValueError(level)
+                if not isinstance(stmt, str):
+                    raise ValueError(stmt)
+                return (level, stmt)
+
+            return tuple(map(validate, item))
+
+        code = tuple(map(validate, code))
+        ccall = CCall(name=name, code=code, stmt=stmt)
+        with self.pseudocode(node=node):
+            self.traverse(root=ccall)
+        return ccall
+
     @contextlib.contextmanager
     def pseudocode(self, node):
         for (level, stmt) in self.__pseudocode[node]:
             self[node].emit(stmt=f"/* {stmt} */", level=level)
         yield
 
-    def call(self, node, code, prefix="", suffix=""):
-        with self.pseudocode(node=node):
-            self[node].emit(stmt=f"{prefix}(")
-            with self[node]:
-                for chunk in code[:-1]:
-                    for (level, stmt) in chunk:
-                        if not (not stmt or
-                                stmt.startswith("/*") or
-                                stmt.endswith((",", "(", "{", "*/"))):
-                            stmt = (stmt + ",")
-                        self[node].emit(stmt=stmt, level=level)
-                if len(code) > 0:
-                    for (level, stmt) in code[-1]:
-                        if stmt:
-                            self[node].emit(stmt=stmt, level=level)
-            self[node].emit(stmt=f"){suffix}")
-
     @pc_util.Hook(pc_ast.Scope)
     def Scope(self, node):
         yield node
@@ -101,23 +108,25 @@ class CodeVisitor(pc_util.Visitor):
             ]))]
 
         if isinstance(node.lvalue, pc_ast.SubscriptExpr):
-            self.call(prefix="oppc_subscript_assign", suffix=";", node=node, code=[
+            ccall = self.ccall(name="oppc_subscript_assign", node=node, stmt=True, code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.index],
                 rvalue,
             ])
         elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
-            self.call(prefix="oppc_range_subscript_assign", suffix=";", node=node, code=[
+            ccall = self.ccall(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.start],
                 self[node.lvalue.end],
                 rvalue,
             ])
         else:
-            self.call(prefix="oppc_assign", suffix=";", node=node, code=[
+            ccall = self.ccall(name="oppc_assign", stmt=True, node=node, code=[
                 self[node.lvalue],
                 rvalue,
             ])
+        for (level, stmt) in self[ccall]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.BinaryExpr)
     def BinaryExpr(self, node):
@@ -133,24 +142,28 @@ class CodeVisitor(pc_util.Visitor):
             pc_ast.Ge, pc_ast.Gt,
         )
         if isinstance(node.op, comparison):
-            self.call(prefix=str(self[node.op]), node=node, code=[
+            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                 self[node.left],
                 self[node.right],
             ])
         else:
             transient = self.transient(node=node)
-            self.call(prefix=str(self[node.op]), node=node, code=[
+            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                 self[transient],
                 self[node.left],
                 self[node.right],
             ])
+        for (level, stmt) in self[ccall]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.UnaryExpr)
     def UnaryExpr(self, node):
         yield node
-        self.call(prefix=str(self[node.op]), node=node, code=[
+        ccall = self.ccall(name=str(self[node.op]), node=node, code=[
             self[node.value],
         ])
+        for (level, stmt) in self[ccall]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(
             pc_ast.Not, pc_ast.Add, pc_ast.Sub,
@@ -211,6 +224,29 @@ class CodeVisitor(pc_util.Visitor):
         yield node
         self[node].emit(stmt=str(node))
 
+    @pc_util.Hook(CCall)
+    def CCall(self, node):
+        yield node
+        end = (";" if node.stmt else "")
+        if len(node.code) == 0:
+            self[node].emit(stmt=f"{str(node.name)}(){end}")
+        else:
+            self[node].emit(stmt=f"{str(node.name)}(")
+            with self[node]:
+                (*head, tail) = node.code
+                for code in head:
+                    for (level, stmt) in code:
+                        self[node].emit(stmt=stmt, level=level)
+                    (level, stmt) = self[node][-1]
+                    if not (not stmt or
+                            stmt.startswith("/*") or
+                            stmt.endswith((",", "(", "{", "*/"))):
+                        stmt = (stmt + ",")
+                    self[node][-1] = (level, stmt)
+                for (level, stmt) in tail:
+                    self[node].emit(stmt=stmt, level=level)
+            self[node].emit(stmt=f"){end}")
+
     @pc_util.Hook(pc_ast.GPR)
     def GPR(self, node):
         yield node
@@ -227,11 +263,13 @@ class CodeVisitor(pc_util.Visitor):
     def RepeatExpr(self, node):
         yield node
         transient = self.transient(node=node)
-        self.call(prefix="oppc_repeat", node=node, code=[
+        ccall = self.ccall(name="oppc_repeat", node=node, code=[
             self[transient],
             self[node.subject],
             self[node.times],
         ])
+        for (level, stmt) in self[ccall]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.XLEN)
     def XLEN(self, node):
@@ -244,19 +282,23 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.SubscriptExpr)
     def SubscriptExpr(self, node):
         yield node
-        self.call(prefix="oppc_subscript", node=node, code=[
+        ccall = self.ccall(name="oppc_subscript", node=node, code=[
             self[node.subject],
             self[node.index],
         ])
+        for (level, stmt) in self[ccall]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.RangeSubscriptExpr)
     def RangeSubscriptExpr(self, node):
         yield node
-        self.call(prefix="oppc_subscript", node=node, code=[
+        ccall = self.ccall(name="oppc_subscript", node=node, code=[
             self[node.subject],
             self[node.start],
             self[node.end],
         ])
+        for (level, stmt) in self[ccall]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.ForExpr)
     def ForExpr(self, node):
@@ -353,7 +395,9 @@ class CodeVisitor(pc_util.Visitor):
     def Call(self, node):
         yield node
         code = tuple(map(lambda arg: self[arg], node.args))
-        self.call(prefix=str(node.name), node=node, code=code)
+        ccall = self.ccall(name=str(node.name), node=node, code=code)
+        for (level, stmt) in self[ccall]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.Symbol)
     def Symbol(self, node):