oppc/code: fix switch statements
authorDmitry Selyutin <ghostmansd@gmail.com>
Tue, 16 Jan 2024 19:09:45 +0000 (22:09 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Tue, 16 Jan 2024 19:10:08 +0000 (22:10 +0300)
src/openpower/oppc/pc_code.py

index c640bdba21ab9583fb14f9c286ec41f91907da47..458428413f660115633b224401d0e652c87a7629 100644 (file)
@@ -585,9 +585,8 @@ class CodeVisitor(pc_util.Visitor):
             for (level, stmt) in self[subject]:
                 self[node].emit(stmt=stmt, level=level)
         self[node].emit(") {")
-        with self[node]:
-            for (level, stmt) in self[node.cases]:
-                self[node].emit(stmt=stmt, level=level)
+        for (level, stmt) in self[node.cases]:
+            self[node].emit(stmt=stmt, level=level)
         self[node].emit(stmt="}")
 
     @pc_util.Hook(pc_ast.Cases)
@@ -600,25 +599,37 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.Case)
     def Case(self, node):
         yield node
-        for (level, stmt) in self[node.labels]:
-            self[node].emit(stmt=stmt, level=level)
-        for (level, stmt) in self[node.body]:
-            self[node].emit(stmt=stmt, level=level)
+        if not ((len(node.body) == 1) and
+                isinstance(node.body[0], pc_ast.Symbol) and
+                str(node.body[0]) == "fallthrough"):
+            for (level, stmt) in self[node.labels]:
+                self[node].emit(stmt=stmt, level=level)
+            for (level, stmt) in self[node.body]:
+                self[node].emit(stmt=stmt, level=level)
+            with self[node]:
+                self[node].emit(stmt="break;")
+        else:
+            for label in node.labels:
+                label = f"case {str(label)}: /* fallthrough */"
+                self[node].emit(stmt=label)
 
     @pc_util.Hook(pc_ast.Labels)
     def Labels(self, node):
         yield node
-        if ((len(node) == 1) and isinstance(node[-1], pc_ast.DefaultLabel)):
-            stmt = "default:"
-        else:
-            labels = ", ".join(map(lambda label: str(self[label]), node))
-            stmt = f"case ({labels}):"
-        self[node].emit(stmt=stmt)
+        for subnode in node:
+            for (level, stmt) in self[subnode]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.DefaultLabel)
+    def DefaultLabel(self, node):
+        yield node
+        self[node].emit(stmt="default:")
 
     @pc_util.Hook(pc_ast.Label)
     def Label(self, node):
         yield node
-        self[node].emit(stmt=str(node))
+        label = f"case {str(node)}:"
+        self[node].emit(stmt=label)
 
     @pc_util.Hook(pc_ast.LeaveKeyword)
     def LeaveKeyword(self, node):