oppc/code: handle statements everywhere
authorDmitry Selyutin <ghostmansd@gmail.com>
Tue, 16 Jan 2024 19:09:45 +0000 (22:09 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Tue, 16 Jan 2024 19:10:08 +0000 (22:10 +0300)
src/openpower/oppc/pc_code.py

index 9b92caefda26928b989a9d54764b466e2b78e13c..ff4ebfc00c367ca674faaeb0a7c6672043a797a9 100644 (file)
@@ -25,7 +25,6 @@ class Transient(pc_ast.Node):
 class Call(pc_ast.Dataclass):
     name: str
     code: tuple
-    stmt: bool
 
 
 class Instruction(pc_ast.Node):
@@ -51,12 +50,15 @@ class CodeVisitor(pc_util.Visitor):
 
         operands = tuple(operand.name for operand in insn.dynamic_operands)
         for var in operands:
-            self.__code[self.__header].emit(stmt=f"struct oppc_value {var};")
+            with self.statement(node=self.__header):
+                self[self.__header].emit(stmt=f"struct oppc_value {var}")
         for var in filter(lambda var: var not in operands, sorted(self.__vars)):
-            self.__code[self.__header].emit(stmt=f"struct oppc_value {var};")
+            with self.statement(node=self.__header):
+                self[self.__header].emit(stmt=f"struct oppc_value {var}")
         if not operands:
-            self.__code[self.__header].emit("(void)insn;")
-        self.__code[self.__header].emit()
+            with self.statement(node=self.__header):
+                self[self.__header].emit("(void)insn")
+        self[self.__header].emit()
 
         for operand in insn.dynamic_operands:
             bits = f"{len(operand.span)}"
@@ -64,8 +66,9 @@ class CodeVisitor(pc_util.Visitor):
             symbol = pc_ast.Symbol(operand.name)
             assign = pc_ast.AssignExpr(lvalue=symbol, rvalue=transient)
             self.traverse(root=assign)
-            for (level, stmt) in self[assign]:
-                self[self.__header].emit(stmt=stmt, level=level)
+            with self.statement(node=self.__header):
+                for (level, stmt) in self[assign]:
+                    self[self.__header].emit(stmt=stmt, level=level)
 
             for (lbit, rbit) in enumerate(operand.span):
                 lsymbol = pc_ast.Symbol(operand.name)
@@ -76,9 +79,10 @@ class CodeVisitor(pc_util.Visitor):
                 rvalue = pc_ast.SubscriptExpr(index=rindex, subject=rsymbol)
                 assign = pc_ast.AssignExpr(lvalue=lvalue, rvalue=rvalue)
                 self.traverse(root=assign)
-                for (level, stmt) in self[assign]:
-                    self[self.__header].emit(stmt=stmt, level=level)
-            self.__code[self.__header].emit()
+                with self.statement(node=self.__header):
+                    for (level, stmt) in self[assign]:
+                        self[self.__header].emit(stmt=stmt, level=level)
+            self[self.__header].emit()
 
     def __iter__(self):
         yield from self[self.__header]
@@ -112,7 +116,7 @@ class CodeVisitor(pc_util.Visitor):
             return tuple(map(validate, item))
 
         code = tuple(map(validate, code))
-        call = Call(name=name, code=code, stmt=stmt)
+        call = Call(name=name, code=code)
         self.traverse(root=call)
         return call
 
@@ -172,6 +176,12 @@ class CodeVisitor(pc_util.Visitor):
             code = self[call]
         self[root] = code
 
+    @contextlib.contextmanager
+    def statement(self, node):
+        yield
+        (level, stmt) = self[node][-1]
+        self[node][-1] = (level, f"{stmt};")
+
     @contextlib.contextmanager
     def pseudocode(self, node):
         if node in self.__pseudocode:
@@ -182,12 +192,20 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.Scope)
     def Scope(self, node):
         yield node
+        stmts = (
+            pc_ast.AssignExpr,
+            pc_ast.AssignIEAExpr,
+            pc_ast.SubscriptExpr,
+            pc_ast.RangeSubscriptExpr,
+            pc_ast.Call,
+        )
         with self[node]:
             for subnode in node:
-                if isinstance(subnode, pc_ast.Call):
+                if isinstance(subnode, stmts):
                     (level, stmt) = self[subnode][-1]
-                    assert stmt[-1] == ")"
-                    self[subnode][-1] = (level, f"{stmt};")
+                    if stmt[-1] != ";":
+                        stmt = f"{stmt};"
+                    self[subnode][-1] = (level, stmt)
                 for (level, stmt) in self[subnode]:
                     self[node].emit(stmt=stmt, level=level)
 
@@ -217,30 +235,31 @@ class CodeVisitor(pc_util.Visitor):
                     rvalue=rvalue.clone(),
                 )
                 self.traverse(root=assign)
-                for (level, stmt) in self[assign]:
-                    self[node].emit(stmt=stmt, level=level)
+                with self.statement(node=node):
+                    for (level, stmt) in self[assign]:
+                        self[node].emit(stmt=stmt, level=level)
             return
 
         if isinstance(node.lvalue, pc_ast.SubscriptExpr):
-            call = self.call(name="oppc_subscript_assign", stmt=True, code=[
+            call = self.call(name="oppc_subscript_assign", code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.index],
                 self[node.rvalue],
             ])
         elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
-            call = self.call(name="oppc_range_subscript_assign", stmt=True, code=[
+            call = self.call(name="oppc_range_subscript_assign", code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.start],
                 self[node.lvalue.end],
                 self[node.rvalue],
             ])
         elif isinstance(node.lvalue, pc_ast.Attribute):
-            call = self.call(name="oppc_attr_assign", stmt=True, code=[
+            call = self.call(name="oppc_attr_assign", code=[
                 self[node.lvalue],
                 self[node.rvalue],
             ])
         else:
-            call = self.call(name="oppc_assign", stmt=True, code=[
+            call = self.call(name="oppc_assign", code=[
                 self[node.lvalue],
                 self[node.rvalue],
             ])
@@ -375,9 +394,8 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(Call)
     def CCall(self, node):
         yield node
-        end = (";" if node.stmt else "")
         if len(node.code) == 0:
-            self[node].emit(stmt=f"{str(node.name)}(){end}")
+            self[node].emit(stmt=f"{str(node.name)}()")
         else:
             self[node].emit(stmt=f"{str(node.name)}(")
             with self[node]:
@@ -393,7 +411,7 @@ class CodeVisitor(pc_util.Visitor):
                     self[node][-1] = (level, stmt)
                 for (level, stmt) in tail:
                     self[node].emit(stmt=stmt, level=level)
-            self[node].emit(stmt=f"){end}")
+            self[node].emit(stmt=f")")
 
     @pc_util.Hook(pc_ast.GPR)
     def GPR(self, node):
@@ -504,14 +522,13 @@ class CodeVisitor(pc_util.Visitor):
                 for subnode in (enter, match, leave):
                     self.__pseudocode.traverse(root=subnode)
                     self.traverse(root=subnode)
-                    for (level, stmt) in self[subnode]:
-                        self[node].emit(stmt=stmt, level=level)
-                    (level, stmt) = self[node][-1]
-                    if subnode is match:
-                        stmt = f"{stmt};"
-                    elif subnode is leave:
-                        stmt = stmt[:-1]
-                    self[node][-1] = (level, stmt)
+                    if subnode is not leave:
+                        with self.statement(node=node):
+                            for (level, stmt) in self[subnode]:
+                                self[node].emit(stmt=stmt, level=level)
+                    else:
+                        for (level, stmt) in self[subnode]:
+                            self[node].emit(stmt=stmt, level=level)
         (level, stmt) = self[node][0]
         self[node].emit(stmt=stmt, level=level)
         self[node].emit(stmt=") {")