oppc/code: emit only function body
[openpower-isa.git] / src / openpower / oppc / pc_code.py
index d7cdeef5ee8fcbc9ee7844215f517f35531f35c9..d490a85dbc3019d921c7b6307e5b7a1f33e69c92 100644 (file)
@@ -13,6 +13,9 @@ class Transient(pc_ast.Node):
 
         return super().__init__()
 
+    def __repr__(self):
+        return f"{hex(id(self))}@{self.__class__.__name__}({self.__value}, {self.__bits})"
+
     def __str__(self):
         return f"oppc_transient(&(struct oppc_value){{}}, {self.__value}, {self.__bits})"
 
@@ -23,42 +26,75 @@ class Call(pc_ast.Dataclass):
     stmt: bool
 
 
+class Instruction(pc_ast.Node):
+    pass
+
+
 class CodeVisitor(pc_util.Visitor):
-    def __init__(self, name, root):
+    def __init__(self, insn, root):
+        if not isinstance(root, pc_ast.Scope):
+            raise ValueError(root)
+
         self.__root = root
+        self.__insn = insn
+        self.__decls = set()
         self.__header = object()
         self.__footer = object()
         self.__code = collections.defaultdict(lambda: pc_util.Code())
-        self.__decls = collections.defaultdict(list)
         self.__regfetch = collections.defaultdict(list)
         self.__regstore = collections.defaultdict(list)
         self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)
 
         super().__init__(root=root)
 
-        self.__code[self.__header].emit(stmt="void")
-        self.__code[self.__header].emit(stmt=f"oppc_{name}(void) {{")
-        with self.__code[self.__header]:
-            for decl in self.__decls:
-                self.__code[self.__header].emit(stmt=f"struct oppc_value {decl};")
-        self.__code[self.__footer].emit(stmt=f"}}")
+        for decl in self.__decls:
+            self.__code[self.__header].emit(stmt=f"struct oppc_value {decl};")
+        decls = sorted(filter(lambda decl: decl in insn.fields, self.__decls))
+        if decls:
+            self.__code[self.__header].emit()
+        for decl in decls:
+            bits = f"{len(insn.fields[decl])}"
+            transient = Transient(bits=bits)
+            symbol = pc_ast.Symbol(decl)
+            assign = pc_ast.AssignExpr(lvalue=symbol, rvalue=transient)
+            self.traverse(root=assign)
+            for (level, stmt) in self[assign]:
+                self[self.__header].emit(stmt=stmt, level=level)
+            for (lbit, rbit) in enumerate(insn.fields[decl]):
+                lsymbol = pc_ast.Symbol(decl)
+                rsymbol = Instruction()
+                lindex = Transient(value=str(lbit))
+                rindex = Transient(value=str(rbit))
+                lvalue = pc_ast.SubscriptExpr(index=lindex, subject=lsymbol)
+                rvalue = pc_ast.SubscriptExpr(index=rindex, subject=rsymbol)
+                assign = pc_ast.AssignExpr(lvalue=lvalue, rvalue=rvalue)
+                self.traverse(root=assign)
+                for (level, stmt) in self[assign]:
+                    self[self.__header].emit(stmt=stmt, level=level)
+            self.__code[self.__header].emit()
+        if decls:
+            self.__code[self.__header].emit()
 
     def __iter__(self):
-        yield from self.__code[self.__header]
-        yield from self.__code[self.__root]
-        yield from self.__code[self.__footer]
+        yield from self[self.__header]
+        for (level, stmt) in self[self.__root]:
+            yield ((level - 1), stmt)
+        yield from self[self.__footer]
 
     def __getitem__(self, node):
         return self.__code[node]
 
-    def transient(self, node,
+    def __setitem__(self, node, code):
+        self.__code[node] = code
+
+    def transient(self,
             value="UINT64_C(0)",
             bits="(uint8_t)OPPC_XLEN"):
         transient = Transient(value=value, bits=bits)
         self.traverse(root=transient)
         return transient
 
-    def call(self, node, name, code, stmt=False):
+    def call(self, name, code, stmt=False):
         def validate(item):
             def validate(item):
                 (level, stmt) = item
@@ -75,9 +111,9 @@ class CodeVisitor(pc_util.Visitor):
         self.traverse(root=call)
         return call
 
-    def ternary(self, node):
+    def fixup_ternary(self, node):
         self[node].clear()
-        test = self.call(name="oppc_bool", node=node, code=[
+        test = self.call(name="oppc_cast_bool", code=[
             self[node.test],
         ])
         self[node].emit(stmt="(")
@@ -92,10 +128,50 @@ class CodeVisitor(pc_util.Visitor):
                 self[node].emit(stmt=stmt, level=level)
         self[node].emit(stmt=")")
 
+    def fixup_attr(self, node, assign=False):
+        root = node
+        code = tuple(self[root])
+        attribute_or_subscript = (
+            pc_ast.Attribute,
+            pc_ast.SubscriptExpr,
+            pc_ast.RangeSubscriptExpr,
+        )
+        while isinstance(node.subject, attribute_or_subscript):
+            node = node.subject
+
+        def wrap(code):
+            def wrap(item):
+                (level, stmt) = item
+                if not (not stmt or
+                        stmt.startswith("/*") or
+                        stmt.endswith((",", "(", "{", "*/"))):
+                    stmt = (stmt + ",")
+                return (level, stmt)
+
+            return tuple(map(wrap, code))
+
+        code = pc_util.Code()
+        for (level, stmt) in wrap(self[node.subject]):
+            code.emit(stmt=stmt, level=level)
+        for (level, stmt) in wrap(self[root]):
+            code.emit(stmt=stmt, level=level)
+
+        # discard the last comma
+        (level, stmt) = code[-1]
+        code[-1] = (level, stmt[:-1])
+
+        if not assign:
+            call = self.call(name="oppc_attr", code=[
+                code,
+            ])
+            code = self[call]
+        self[root] = code
+
     @contextlib.contextmanager
     def pseudocode(self, node):
-        for (level, stmt) in self.__pseudocode[node]:
-            self[node].emit(stmt=f"/* {stmt} */", level=level)
+        if node in self.__pseudocode:
+            for (level, stmt) in self.__pseudocode[node]:
+                self[node].emit(stmt=f"/* {stmt} */", level=level)
         yield
 
     @pc_util.Hook(pc_ast.Scope)
@@ -115,23 +191,47 @@ class CodeVisitor(pc_util.Visitor):
             self.__regfetch[str(node.rvalue)].append(node.rvalue)
 
         if isinstance(node.rvalue, pc_ast.IfExpr):
-            self.ternary(node=node.rvalue)
+            self.fixup_ternary(node=node.rvalue)
+        if isinstance(node.lvalue, pc_ast.Attribute):
+            self.fixup_attr(node=node.lvalue, assign=True)
+        if isinstance(node.rvalue, pc_ast.Attribute):
+            self.fixup_attr(node=node.rvalue)
+
+        if isinstance(node.lvalue, pc_ast.Sequence):
+            if not isinstance(node.rvalue, pc_ast.Sequence):
+                raise ValueError(node.rvalue)
+            if len(node.lvalue) != len(node.rvalue):
+                raise ValueError(node)
+            for (lvalue, rvalue) in zip(node.lvalue, node.rvalue):
+                assign = node.__class__(
+                    lvalue=lvalue.clone(),
+                    rvalue=rvalue.clone(),
+                )
+                self.traverse(root=assign)
+                for (level, stmt) in self[assign]:
+                    self[node].emit(stmt=stmt, level=level)
+            return
 
         if isinstance(node.lvalue, pc_ast.SubscriptExpr):
-            call = self.call(name="oppc_subscript_assign", node=node, stmt=True, code=[
+            call = self.call(name="oppc_subscript_assign", stmt=True, code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.index],
                 self[node.rvalue],
             ])
         elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
-            call = self.call(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
+            call = self.call(name="oppc_range_subscript_assign", stmt=True, code=[
                 self[node.lvalue.subject],
                 self[node.lvalue.start],
                 self[node.lvalue.end],
                 self[node.rvalue],
             ])
+        elif isinstance(node.lvalue, pc_ast.Attribute):
+            call = self.call(name="oppc_attr_assign", stmt=True, code=[
+                self[node.lvalue],
+                self[node.rvalue],
+            ])
         else:
-            call = self.call(name="oppc_assign", stmt=True, node=node, code=[
+            call = self.call(name="oppc_assign", stmt=True, code=[
                 self[node.lvalue],
                 self[node.rvalue],
             ])
@@ -147,29 +247,30 @@ class CodeVisitor(pc_util.Visitor):
         if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
             self.__regfetch[str(node.right)].append(node.left)
 
+        if isinstance(node.left, pc_ast.IfExpr):
+            self.fixup_ternary(node=node.left)
+        if isinstance(node.right, pc_ast.IfExpr):
+            self.fixup_ternary(node=node.right)
+        if isinstance(node.left, pc_ast.Attribute):
+            self.fixup_attr(node=node.left)
+        if isinstance(node.right, pc_ast.Attribute):
+            self.fixup_attr(node=node.right)
+
         comparison = (
             pc_ast.Lt, pc_ast.Le,
-            pc_ast.Eq, pc_ast.NotEq,
+            pc_ast.Eq, pc_ast.Ne,
             pc_ast.Ge, pc_ast.Gt,
             pc_ast.LtU, pc_ast.GtU,
         )
-        if isinstance(node.left, pc_ast.IfExpr):
-            self.ternary(node=node.left)
-        if isinstance(node.right, pc_ast.IfExpr):
-            self.ternary(node=node.right)
-
         if isinstance(node.op, comparison):
-            call = self.call(name=str(self[node.op]), node=node, code=[
-                self[node.left],
-                self[node.right],
-            ])
+            transient = self.transient(bits="UINT8_C(1)")
         else:
-            transient = self.transient(node=node)
-            call = self.call(name=str(self[node.op]), node=node, code=[
-                self[transient],
-                self[node.left],
-                self[node.right],
-            ])
+            transient = self.transient()
+        call = self.call(name=str(self[node.op]), code=[
+            self[transient],
+            self[node.left],
+            self[node.right],
+        ])
         with self.pseudocode(node=node):
             for (level, stmt) in self[call]:
                 self[node].emit(stmt=stmt, level=level)
@@ -178,8 +279,10 @@ class CodeVisitor(pc_util.Visitor):
     def UnaryExpr(self, node):
         yield node
         if isinstance(node.value, pc_ast.IfExpr):
-            self.ternary(node=node.value)
-        call = self.call(name=str(self[node.op]), node=node, code=[
+            self.fixup_ternary(node=node.value)
+        transient = self.transient()
+        call = self.call(name=str(self[node.op]), code=[
+            self[transient],
             self[node.value],
         ])
         with self.pseudocode(node=node):
@@ -190,7 +293,7 @@ class CodeVisitor(pc_util.Visitor):
             pc_ast.Not, pc_ast.Add, pc_ast.Sub,
             pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
             pc_ast.Lt, pc_ast.Le,
-            pc_ast.Eq, pc_ast.NotEq,
+            pc_ast.Eq, pc_ast.Ne,
             pc_ast.Ge, pc_ast.Gt,
             pc_ast.LtU, pc_ast.GtU,
             pc_ast.LShift, pc_ast.RShift,
@@ -209,9 +312,9 @@ class CodeVisitor(pc_util.Visitor):
             pc_ast.Lt: "oppc_lt",
             pc_ast.Le: "oppc_le",
             pc_ast.Eq: "oppc_eq",
+            pc_ast.Ne: "oppc_ne",
             pc_ast.LtU: "oppc_ltu",
             pc_ast.GtU: "oppc_gtu",
-            pc_ast.NotEq: "oppc_noteq",
             pc_ast.Ge: "oppc_ge",
             pc_ast.Gt: "oppc_gt",
             pc_ast.LShift: "oppc_lshift",
@@ -223,6 +326,12 @@ class CodeVisitor(pc_util.Visitor):
         }[node.__class__]
         self[node].emit(stmt=op)
 
+    @pc_util.Hook(pc_ast.StringLiteral)
+    def StringLiteral(self, node):
+        yield node
+        escaped = repr(str(node))[1:-1]
+        self[node].emit(stmt=f"\"{escaped}\"")
+
     @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
     def Integer(self, node):
         yield node
@@ -241,7 +350,7 @@ class CodeVisitor(pc_util.Visitor):
         if (value > ((2**64) - 1)):
             raise NotImplementedError()
         value = f"UINT64_C({fmt(value)})"
-        transient = self.transient(node=node, value=value, bits=bits)
+        transient = self.transient(value=value, bits=bits)
         with self.pseudocode(node=node):
             for (level, stmt) in self[transient]:
                 self[node].emit(stmt=stmt, level=level)
@@ -278,19 +387,32 @@ class CodeVisitor(pc_util.Visitor):
     def GPR(self, node):
         yield node
         with self.pseudocode(node=node):
-            self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")
+            self[node].emit(stmt=f"OPPC_GPR_{str(node)}")
+
+    @pc_util.Hook(pc_ast.GPRZero)
+    def GPRZero(self, node):
+        yield node
+        name = str(node)
+        test = pc_ast.Symbol(name)
+        body = pc_ast.Scope([pc_ast.GPR(name)])
+        orelse = pc_ast.Scope([Transient()])
+        ifexpr = pc_ast.IfExpr(test=test, body=body, orelse=orelse)
+        self.traverse(root=ifexpr)
+        self.fixup_ternary(node=ifexpr)
+        for (level, stmt) in self[ifexpr]:
+            self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.FPR)
     def FPR(self, node):
         yield node
         with self.pseudocode(node=node):
-            self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")
+            self[node].emit(stmt=f"OPPC_FPR_{str(node)}")
 
     @pc_util.Hook(pc_ast.RepeatExpr)
     def RepeatExpr(self, node):
         yield node
-        transient = self.transient(node=node)
-        call = self.call(name="oppc_repeat", node=node, code=[
+        transient = self.transient()
+        call = self.call(name="oppc_repeat", code=[
             self[transient],
             self[node.subject],
             self[node.times],
@@ -302,7 +424,7 @@ class CodeVisitor(pc_util.Visitor):
     def XLEN(self, node):
         yield node
         (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
-        transient = self.transient(node=node, value=value, bits=bits)
+        transient = self.transient(value=value, bits=bits)
         with self.pseudocode(node=node):
             for (level, stmt) in self[transient]:
                 self[node].emit(stmt=stmt, level=level)
@@ -312,12 +434,14 @@ class CodeVisitor(pc_util.Visitor):
     def Special(self, node):
         yield node
         with self.pseudocode(node=node):
-            self[node].emit(stmt=f"&OPPC_{str(node).upper()}")
+            self[node].emit(stmt=f"OPPC_{str(node).upper()}")
 
     @pc_util.Hook(pc_ast.SubscriptExpr)
     def SubscriptExpr(self, node):
         yield node
-        call = self.call(name="oppc_subscript", node=node, code=[
+        transient = self.transient(bits="UINT8_C(1)")
+        call = self.call(name="oppc_subscript", code=[
+            self[transient],
             self[node.subject],
             self[node.index],
         ])
@@ -327,7 +451,9 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.RangeSubscriptExpr)
     def RangeSubscriptExpr(self, node):
         yield node
-        call = self.call(name="oppc_subscript", node=node, code=[
+        transient = self.transient()
+        call = self.call(name="oppc_range_subscript", code=[
+            self[transient],
             self[node.subject],
             self[node.start],
             self[node.end],
@@ -401,7 +527,7 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.IfExpr)
     def IfExpr(self, node):
         yield node
-        test = self.call(name="oppc_bool", node=node, code=[
+        test = self.call(name="oppc_cast_bool", code=[
             self[node.test],
         ])
         self[node].emit(stmt="if (")
@@ -420,7 +546,7 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.SwitchExpr)
     def SwitchExpr(self, node):
         yield node
-        subject = self.call(name="oppc_int64", node=node, code=[
+        subject = self.call(name="oppc_cast_int64", code=[
             self[node.subject],
         ])
         self[node].emit(stmt="switch (")
@@ -482,23 +608,56 @@ class CodeVisitor(pc_util.Visitor):
     @pc_util.Hook(pc_ast.Call)
     def Call(self, node):
         yield node
-        code = tuple(map(lambda arg: self[arg], node.args))
-        call = self.call(name=str(node.name), node=node, code=code)
-        for (level, stmt) in self[call]:
-            self[node].emit(stmt=stmt, level=level)
+        if node.args:
+            transient = self.transient()
+            code = [self[transient]]
+        else:
+            code = []
+        code.extend(map(lambda arg: self[arg], node.args))
+        name = f"OPPC_CALL_{str(node.name)}"
+        with self.pseudocode(node=node):
+            call = self.call(name=name, code=code)
+            for (level, stmt) in self[call]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Attribute.Name)
+    def AttributeName(self, node):
+        yield node
+
+    @pc_util.Hook(pc_ast.Sequence)
+    def Sequence(self, node):
+        yield node
+
+    @pc_util.Hook(pc_ast.Attribute)
+    def Attribute(self, node):
+        yield node
+        attr = str(self.__pseudocode[node])
+        symbol = f"OPPC_ATTR_{attr.replace('.', '_')}"
+        self[node].emit(f"/* {attr} */")
+        self[node].emit(stmt=symbol)
 
     @pc_util.Hook(pc_ast.Symbol)
     def Symbol(self, node):
         yield node
         with self.pseudocode(node=node):
-            if str(node) not in ("fallthrough",):
-                self.__decls[str(node)].append(node)
-                self[node].emit(stmt=f"&{str(node)}")
+            decl = str(node)
+            if decl not in ("fallthrough",):
+                if decl in ("TRAP",):
+                    self[node].emit(stmt=f"OPPC_CALL_{decl}();")
+                else:
+                    if node in self.__pseudocode:
+                        self.__decls.add(decl)
+                    self[node].emit(stmt=f"&{decl}")
+
+    @pc_util.Hook(Instruction)
+    def Instruction(self, node):
+        yield node
+        self[node].emit("insn")
 
     @pc_util.Hook(pc_ast.Node)
     def Node(self, node):
         raise NotImplementedError(type(node))
 
 
-def code(name, root):
-    yield from CodeVisitor(name=name, root=root)
+def code(insn, root):
+    yield from CodeVisitor(insn=insn, root=root)