oppc/code: emit only function body
[openpower-isa.git] / src / openpower / oppc / pc_code.py
index 62063a1fb0452691363938113da14c5fd30b3938..d490a85dbc3019d921c7b6307e5b7a1f33e69c92 100644 (file)
@@ -13,52 +13,88 @@ class Transient(pc_ast.Node):
 
         return super().__init__()
 
+    def __repr__(self):
+        return f"{hex(id(self))}@{self.__class__.__name__}({self.__value}, {self.__bits})"
+
     def __str__(self):
         return f"oppc_transient(&(struct oppc_value){{}}, {self.__value}, {self.__bits})"
 
 
-class CCall(pc_ast.Dataclass):
+class Call(pc_ast.Dataclass):
     name: str
     code: tuple
     stmt: bool
 
 
+class Instruction(pc_ast.Node):
+    pass
+
+
 class CodeVisitor(pc_util.Visitor):
-    def __init__(self, name, root):
+    def __init__(self, insn, root):
+        if not isinstance(root, pc_ast.Scope):
+            raise ValueError(root)
+
         self.__root = root
+        self.__insn = insn
+        self.__decls = set()
         self.__header = object()
         self.__footer = object()
         self.__code = collections.defaultdict(lambda: pc_util.Code())
-        self.__decls = collections.defaultdict(list)
         self.__regfetch = collections.defaultdict(list)
         self.__regstore = collections.defaultdict(list)
         self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)
 
         super().__init__(root=root)
 
-        self.__code[self.__header].emit(stmt="void")
-        self.__code[self.__header].emit(stmt=f"oppc_{name}(void) {{")
-        with self.__code[self.__header]:
-            for decl in self.__decls:
-                self.__code[self.__header].emit(stmt=f"struct oppc_value {decl};")
-        self.__code[self.__footer].emit(stmt=f"}}")
+        for decl in self.__decls:
+            self.__code[self.__header].emit(stmt=f"struct oppc_value {decl};")
+        decls = sorted(filter(lambda decl: decl in insn.fields, self.__decls))
+        if decls:
+            self.__code[self.__header].emit()
+        for decl in decls:
+            bits = f"{len(insn.fields[decl])}"
+            transient = Transient(bits=bits)
+            symbol = pc_ast.Symbol(decl)
+            assign = pc_ast.AssignExpr(lvalue=symbol, rvalue=transient)
+            self.traverse(root=assign)
+            for (level, stmt) in self[assign]:
+                self[self.__header].emit(stmt=stmt, level=level)
+            for (lbit, rbit) in enumerate(insn.fields[decl]):
+                lsymbol = pc_ast.Symbol(decl)
+                rsymbol = Instruction()
+                lindex = Transient(value=str(lbit))
+                rindex = Transient(value=str(rbit))
+                lvalue = pc_ast.SubscriptExpr(index=lindex, subject=lsymbol)
+                rvalue = pc_ast.SubscriptExpr(index=rindex, subject=rsymbol)
+                assign = pc_ast.AssignExpr(lvalue=lvalue, rvalue=rvalue)
+                self.traverse(root=assign)
+                for (level, stmt) in self[assign]:
+                    self[self.__header].emit(stmt=stmt, level=level)
+            self.__code[self.__header].emit()
+        if decls:
+            self.__code[self.__header].emit()
 
     def __iter__(self):
-        yield from self.__code[self.__header]
-        yield from self.__code[self.__root]
-        yield from self.__code[self.__footer]
+        yield from self[self.__header]
+        for (level, stmt) in self[self.__root]:
+            yield ((level - 1), stmt)
+        yield from self[self.__footer]
 
     def __getitem__(self, node):
         return self.__code[node]
 
-    def transient(self, node,
+    def __setitem__(self, node, code):
+        self.__code[node] = code
+
+    def transient(self,
             value="UINT64_C(0)",
             bits="(uint8_t)OPPC_XLEN"):
         transient = Transient(value=value, bits=bits)
         self.traverse(root=transient)
         return transient
 
-    def ccall(self, node, name, code, stmt=False):
+    def call(self, name, code, stmt=False):
         def validate(item):
             def validate(item):
                 (level, stmt) = item
@@ -71,15 +107,18 @@ class CodeVisitor(pc_util.Visitor):
             return tuple(map(validate, item))
 
         code = tuple(map(validate, code))
-        ccall = CCall(name=name, code=code, stmt=stmt)
-        self.traverse(root=ccall)
-        return ccall
+        call = Call(name=name, code=code, stmt=stmt)
+        self.traverse(root=call)
+        return call
 
-    def ternary(self, node):
+    def fixup_ternary(self, node):
         self[node].clear()
+        test = self.call(name="oppc_cast_bool", code=[
+            self[node.test],
+        ])
         self[node].emit(stmt="(")
         with self[node]:
-            for (level, stmt) in self[node.test]:
+            for (level, stmt) in self[test]:
                 self[node].emit(stmt=stmt, level=level)
             self[node].emit(stmt="?")
             for (level, stmt) in self[node.body]:
@@ -89,10 +128,50 @@ class CodeVisitor(pc_util.Visitor):
                 self[node].emit(stmt=stmt, level=level)
         self[node].emit(stmt=")")
 
+    def fixup_attr(self, node, assign=False):
+        root = node
+        code = tuple(self[root])
+        attribute_or_subscript = (
+            pc_ast.Attribute,
+            pc_ast.SubscriptExpr,
+            pc_ast.RangeSubscriptExpr,
+        )
+        while isinstance(node.subject, attribute_or_subscript):
+            node = node.subject
+
+        def wrap(code):
+            def wrap(item):
+                (level, stmt) = item
+                if not (not stmt or
+                        stmt.startswith("/*") or
+                        stmt.endswith((",", "(", "{", "*/"))):
+                    stmt = (stmt + ",")
+                return (level, stmt)
+
+            return tuple(map(wrap, code))
+
+        code = pc_util.Code()
+        for (level, stmt) in wrap(self[node.subject]):
+            code.emit(stmt=stmt, level=level)
+        for (level, stmt) in wrap(self[root]):
+            code.emit(stmt=stmt, level=level)
+
+        # discard the last comma
+        (level, stmt) = code[-1]
+        code[-1] = (level, stmt[:-1])
+
+        if not assign:
+            call = self.call(name="oppc_attr", code=[
+                code,
+            ])
+            code = self[call]
+        self[root] = code
+
     @contextlib.contextmanager
     def pseudocode(self, node):
-        for (level, stmt) in self.__pseudocode[node]:
-            self[node].emit(stmt=f"/* {stmt} */", level=level)
+        if node in self.__pseudocode:
+            for (level, stmt) in self.__pseudocode[node]:
+                self[node].emit(stmt=f"/* {stmt} */", level=level)
         yield
 
     @pc_util.Hook(pc_ast.Scope)
@@ -112,28 +191,52 @@ class CodeVisitor(pc_util.Visitor):
             self.__regfetch[str(node.rvalue)].append(node.rvalue)
 
         if isinstance(node.rvalue, pc_ast.IfExpr):
-            self.ternary(node=node.rvalue)
+            self.fixup_ternary(node=node.rvalue)
+        if isinstance(node.lvalue, pc_ast.Attribute):
+            self.fixup_attr(node=node.lvalue, assign=True)
+        if isinstance(node.rvalue, pc_ast.Attribute):
+            self.fixup_attr(node=node.rvalue)
+
+        if isinstance(node.lvalue, pc_ast.Sequence):
+            if not isinstance(node.rvalue, pc_ast.Sequence):
+                raise ValueError(node.rvalue)
+            if len(node.lvalue) != len(node.rvalue):
+                raise ValueError(node)
+            for (lvalue, rvalue) in zip(node.lvalue, node.rvalue):
+                assign = node.__class__(
+                    lvalue=lvalue.clone(),
+                    rvalue=rvalue.clone(),
+                )
+                self.traverse(root=assign)
+                for (level, stmt) in self[assign]:
+                    self[node].emit(stmt=stmt, level=level)
+            return
 
         if isinstance(node.lvalue, pc_ast.SubscriptExpr):
-            ccall = self.ccall(name="oppc_subscript_assign", node=node, stmt=True, code=[
+            call = self.call(name="oppc_subscript_assign", stmt=True, code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.index],
                 self[node.rvalue],
             ])
         elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
-            ccall = self.ccall(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
+            call = self.call(name="oppc_range_subscript_assign", stmt=True, code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.start],
                 self[node.lvalue.end],
                 self[node.rvalue],
             ])
+        elif isinstance(node.lvalue, pc_ast.Attribute):
+            call = self.call(name="oppc_attr_assign", stmt=True, code=[
+                self[node.lvalue],
+                self[node.rvalue],
+            ])
         else:
-            ccall = self.ccall(name="oppc_assign", stmt=True, node=node, code=[
+            call = self.call(name="oppc_assign", stmt=True, code=[
                 self[node.lvalue],
                 self[node.rvalue],
             ])
         with self.pseudocode(node=node):
-            for (level, stmt) in self[ccall]:
+            for (level, stmt) in self[call]:
                 self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.BinaryExpr)
@@ -144,50 +247,53 @@ class CodeVisitor(pc_util.Visitor):
         if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
             self.__regfetch[str(node.right)].append(node.left)
 
+        if isinstance(node.left, pc_ast.IfExpr):
+            self.fixup_ternary(node=node.left)
+        if isinstance(node.right, pc_ast.IfExpr):
+            self.fixup_ternary(node=node.right)
+        if isinstance(node.left, pc_ast.Attribute):
+            self.fixup_attr(node=node.left)
+        if isinstance(node.right, pc_ast.Attribute):
+            self.fixup_attr(node=node.right)
+
         comparison = (
             pc_ast.Lt, pc_ast.Le,
-            pc_ast.Eq, pc_ast.NotEq,
+            pc_ast.Eq, pc_ast.Ne,
             pc_ast.Ge, pc_ast.Gt,
             pc_ast.LtU, pc_ast.GtU,
         )
-        if isinstance(node.left, pc_ast.IfExpr):
-            self.ternary(node=node.left)
-        if isinstance(node.right, pc_ast.IfExpr):
-            self.ternary(node=node.right)
-
         if isinstance(node.op, comparison):
-            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
-                self[node.left],
-                self[node.right],
-            ])
+            transient = self.transient(bits="UINT8_C(1)")
         else:
-            transient = self.transient(node=node)
-            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
-                self[transient],
-                self[node.left],
-                self[node.right],
-            ])
+            transient = self.transient()
+        call = self.call(name=str(self[node.op]), code=[
+            self[transient],
+            self[node.left],
+            self[node.right],
+        ])
         with self.pseudocode(node=node):
-            for (level, stmt) in self[ccall]:
+            for (level, stmt) in self[call]:
                 self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.UnaryExpr)
     def UnaryExpr(self, node):
         yield node
         if isinstance(node.value, pc_ast.IfExpr):
-            self.ternary(node=node.value)
-        ccall = self.ccall(name=str(self[node.op]), node=node, code=[
+            self.fixup_ternary(node=node.value)
+        transient = self.transient()
+        call = self.call(name=str(self[node.op]), code=[
+            self[transient],
             self[node.value],
         ])
         with self.pseudocode(node=node):
-            for (level, stmt) in self[ccall]:
+            for (level, stmt) in self[call]:
                 self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(
             pc_ast.Not, pc_ast.Add, pc_ast.Sub,
             pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
             pc_ast.Lt, pc_ast.Le,
-            pc_ast.Eq, pc_ast.NotEq,
+            pc_ast.Eq, pc_ast.Ne,
             pc_ast.Ge, pc_ast.Gt,
             pc_ast.LtU, pc_ast.GtU,
             pc_ast.LShift, pc_ast.RShift,
@@ -206,9 +312,9 @@ class CodeVisitor(pc_util.Visitor):
             pc_ast.Lt: "oppc_lt",
             pc_ast.Le: "oppc_le",
             pc_ast.Eq: "oppc_eq",
+            pc_ast.Ne: "oppc_ne",
             pc_ast.LtU: "oppc_ltu",
             pc_ast.GtU: "oppc_gtu",
-            pc_ast.NotEq: "oppc_noteq",
             pc_ast.Ge: "oppc_ge",
             pc_ast.Gt: "oppc_gt",
             pc_ast.LShift: "oppc_lshift",
@@ -220,6 +326,12 @@ class CodeVisitor(pc_util.Visitor):
         }[node.__class__]
         self[node].emit(stmt=op)
 
+    @pc_util.Hook(pc_ast.StringLiteral)
+    def StringLiteral(self, node):
+        yield node
+        escaped = repr(str(node))[1:-1]
+        self[node].emit(stmt=f"\"{escaped}\"")
+
     @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
     def Integer(self, node):
         yield node
@@ -238,7 +350,7 @@ class CodeVisitor(pc_util.Visitor):
         if (value > ((2**64) - 1)):
             raise NotImplementedError()
         value = f"UINT64_C({fmt(value)})"
-        transient = self.transient(node=node, value=value, bits=bits)
+        transient = self.transient(value=value, bits=bits)
         with self.pseudocode(node=node):
             for (level, stmt) in self[transient]:
                 self[node].emit(stmt=stmt, level=level)
@@ -248,7 +360,7 @@ class CodeVisitor(pc_util.Visitor):
         yield node
         self[node].emit(stmt=str(node))
 
-    @pc_util.Hook(CCall)
+    @pc_util.Hook(Call)
     def CCall(self, node):
         yield node
         end = (";" if node.stmt else "")
@@ -275,54 +387,78 @@ class CodeVisitor(pc_util.Visitor):
     def GPR(self, node):
         yield node
         with self.pseudocode(node=node):
-            self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")
+            self[node].emit(stmt=f"OPPC_GPR_{str(node)}")
+
+    @pc_util.Hook(pc_ast.GPRZero)
+    def GPRZero(self, node):
+        yield node
+        name = str(node)
+        test = pc_ast.Symbol(name)
+        body = pc_ast.Scope([pc_ast.GPR(name)])
+        orelse = pc_ast.Scope([Transient()])
+        ifexpr = pc_ast.IfExpr(test=test, body=body, orelse=orelse)
+        self.traverse(root=ifexpr)
+        self.fixup_ternary(node=ifexpr)
+        for (level, stmt) in self[ifexpr]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.FPR)
     def FPR(self, node):
         yield node
         with self.pseudocode(node=node):
-            self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")
+            self[node].emit(stmt=f"OPPC_FPR_{str(node)}")
 
     @pc_util.Hook(pc_ast.RepeatExpr)
     def RepeatExpr(self, node):
         yield node
-        transient = self.transient(node=node)
-        ccall = self.ccall(name="oppc_repeat", node=node, code=[
+        transient = self.transient()
+        call = self.call(name="oppc_repeat", code=[
             self[transient],
             self[node.subject],
             self[node.times],
         ])
-        for (level, stmt) in self[ccall]:
+        for (level, stmt) in self[call]:
             self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.XLEN)
     def XLEN(self, node):
         yield node
         (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
-        transient = self.transient(node=node, value=value, bits=bits)
+        transient = self.transient(value=value, bits=bits)
         with self.pseudocode(node=node):
             for (level, stmt) in self[transient]:
                 self[node].emit(stmt=stmt, level=level)
 
+    @pc_util.Hook(pc_ast.Overflow, pc_ast.CR3, pc_ast.CR5,
+            pc_ast.XER, pc_ast.Reserve, pc_ast.Special)
+    def Special(self, node):
+        yield node
+        with self.pseudocode(node=node):
+            self[node].emit(stmt=f"OPPC_{str(node).upper()}")
+
     @pc_util.Hook(pc_ast.SubscriptExpr)
     def SubscriptExpr(self, node):
         yield node
-        ccall = self.ccall(name="oppc_subscript", node=node, code=[
+        transient = self.transient(bits="UINT8_C(1)")
+        call = self.call(name="oppc_subscript", code=[
+            self[transient],
             self[node.subject],
             self[node.index],
         ])
-        for (level, stmt) in self[ccall]:
+        for (level, stmt) in self[call]:
             self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.RangeSubscriptExpr)
     def RangeSubscriptExpr(self, node):
         yield node
-        ccall = self.ccall(name="oppc_subscript", node=node, code=[
+        transient = self.transient()
+        call = self.call(name="oppc_range_subscript", code=[
+            self[transient],
             self[node.subject],
             self[node.start],
             self[node.end],
         ])
-        for (level, stmt) in self[ccall]:
+        for (level, stmt) in self[call]:
             self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.ForExpr)
@@ -391,9 +527,12 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.IfExpr)
     def IfExpr(self, node):
         yield node
+        test = self.call(name="oppc_cast_bool", code=[
+            self[node.test],
+        ])
         self[node].emit(stmt="if (")
         with self[node]:
-            for (level, stmt) in self[node.test]:
+            for (level, stmt) in self[test]:
                 self[node].emit(stmt=stmt, level=level)
         self[node].emit(stmt=") {")
         for (level, stmt) in self[node.body]:
@@ -404,6 +543,56 @@ class CodeVisitor(pc_util.Visitor):
                 self[node].emit(stmt=stmt, level=level)
         self[node].emit(stmt="}")
 
+    @pc_util.Hook(pc_ast.SwitchExpr)
+    def SwitchExpr(self, node):
+        yield node
+        subject = self.call(name="oppc_cast_int64", code=[
+            self[node.subject],
+        ])
+        self[node].emit(stmt="switch (")
+        with self[node]:
+            for (level, stmt) in self[subject]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(") {")
+        with self[node]:
+            for (level, stmt) in self[node.cases]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Cases)
+    def Cases(self, node):
+        yield node
+        for subnode in node:
+            for (level, stmt) in self[subnode]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Case)
+    def Case(self, node):
+        yield node
+        for (level, stmt) in self[node.labels]:
+            self[node].emit(stmt=stmt, level=level)
+        for (level, stmt) in self[node.body]:
+            self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Labels)
+    def Labels(self, node):
+        yield node
+        if ((len(node) == 1) and isinstance(node[-1], pc_ast.DefaultLabel)):
+            stmt = "default:"
+        else:
+            labels = ", ".join(map(lambda label: str(self[label]), node))
+            stmt = f"case ({labels}):"
+        self[node].emit(stmt=stmt)
+
+    @pc_util.Hook(pc_ast.Label)
+    def Label(self, node):
+        yield node
+        self[node].emit(stmt=str(node))
+
+    @pc_util.Hook(pc_ast.LeaveKeyword)
+    def LeaveKeyword(self, node):
+        yield node
+        self[node].emit(stmt="break;")
+
     @pc_util.Hook(pc_ast.Call.Name)
     def CallName(self, node):
         yield node
@@ -419,22 +608,56 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.Call)
     def Call(self, node):
         yield node
-        code = tuple(map(lambda arg: self[arg], node.args))
-        ccall = self.ccall(name=str(node.name), node=node, code=code)
-        for (level, stmt) in self[ccall]:
-            self[node].emit(stmt=stmt, level=level)
+        if node.args:
+            transient = self.transient()
+            code = [self[transient]]
+        else:
+            code = []
+        code.extend(map(lambda arg: self[arg], node.args))
+        name = f"OPPC_CALL_{str(node.name)}"
+        with self.pseudocode(node=node):
+            call = self.call(name=name, code=code)
+            for (level, stmt) in self[call]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Attribute.Name)
+    def AttributeName(self, node):
+        yield node
+
+    @pc_util.Hook(pc_ast.Sequence)
+    def Sequence(self, node):
+        yield node
+
+    @pc_util.Hook(pc_ast.Attribute)
+    def Attribute(self, node):
+        yield node
+        attr = str(self.__pseudocode[node])
+        symbol = f"OPPC_ATTR_{attr.replace('.', '_')}"
+        self[node].emit(f"/* {attr} */")
+        self[node].emit(stmt=symbol)
 
     @pc_util.Hook(pc_ast.Symbol)
     def Symbol(self, node):
         yield node
-        self.__decls[str(node)].append(node)
         with self.pseudocode(node=node):
-            self[node].emit(stmt=f"&{str(node)}")
+            decl = str(node)
+            if decl not in ("fallthrough",):
+                if decl in ("TRAP",):
+                    self[node].emit(stmt=f"OPPC_CALL_{decl}();")
+                else:
+                    if node in self.__pseudocode:
+                        self.__decls.add(decl)
+                    self[node].emit(stmt=f"&{decl}")
+
+    @pc_util.Hook(Instruction)
+    def Instruction(self, node):
+        yield node
+        self[node].emit("insn")
 
     @pc_util.Hook(pc_ast.Node)
     def Node(self, node):
         raise NotImplementedError(type(node))
 
 
-def code(name, root):
-    yield from CodeVisitor(name=name, root=root)
+def code(insn, root):
+    yield from CodeVisitor(insn=insn, root=root)