oppc: decouple attribute name
[openpower-isa.git] / src / openpower / oppc / pc_code.py
index b0bfaede9ba1a7946d1046d3ac1cb99eeaa463f1..d7cdeef5ee8fcbc9ee7844215f517f35531f35c9 100644 (file)
@@ -1,7 +1,26 @@
 import collections
+import contextlib
 
 import openpower.oppc.pc_ast as pc_ast
 import openpower.oppc.pc_util as pc_util
+import openpower.oppc.pc_pseudocode as pc_pseudocode
+
+
+class Transient(pc_ast.Node):
+    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
+        self.__value = value
+        self.__bits = bits
+
+        return super().__init__()
+
+    def __str__(self):
+        return f"oppc_transient(&(struct oppc_value){{}}, {self.__value}, {self.__bits})"
+
+
+class Call(pc_ast.Dataclass):
+    name: str
+    code: tuple
+    stmt: bool
 
 
 class CodeVisitor(pc_util.Visitor):
@@ -13,15 +32,16 @@ class CodeVisitor(pc_util.Visitor):
         self.__decls = collections.defaultdict(list)
         self.__regfetch = collections.defaultdict(list)
         self.__regstore = collections.defaultdict(list)
+        self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)
 
         super().__init__(root=root)
 
-        self.__code[self.__header].emit("void")
-        self.__code[self.__header].emit(f"oppc_{name}(void) {{")
+        self.__code[self.__header].emit(stmt="void")
+        self.__code[self.__header].emit(stmt=f"oppc_{name}(void) {{")
         with self.__code[self.__header]:
             for decl in self.__decls:
-                self.__code[self.__header].emit(f"uint64_t {decl};")
-        self.__code[self.__footer].emit(f"}}")
+                self.__code[self.__header].emit(stmt=f"struct oppc_value {decl};")
+        self.__code[self.__footer].emit(stmt=f"}}")
 
     def __iter__(self):
         yield from self.__code[self.__header]
@@ -31,6 +51,53 @@ class CodeVisitor(pc_util.Visitor):
     def __getitem__(self, node):
         return self.__code[node]
 
+    def transient(self, node,
+            value="UINT64_C(0)",
+            bits="(uint8_t)OPPC_XLEN"):
+        transient = Transient(value=value, bits=bits)
+        self.traverse(root=transient)
+        return transient
+
+    def call(self, node, name, code, stmt=False):
+        def validate(item):
+            def validate(item):
+                (level, stmt) = item
+                if not isinstance(level, int):
+                    raise ValueError(level)
+                if not isinstance(stmt, str):
+                    raise ValueError(stmt)
+                return (level, stmt)
+
+            return tuple(map(validate, item))
+
+        code = tuple(map(validate, code))
+        call = Call(name=name, code=code, stmt=stmt)
+        self.traverse(root=call)
+        return call
+
+    def ternary(self, node):
+        self[node].clear()
+        test = self.call(name="oppc_bool", node=node, code=[
+            self[node.test],
+        ])
+        self[node].emit(stmt="(")
+        with self[node]:
+            for (level, stmt) in self[test]:
+                self[node].emit(stmt=stmt, level=level)
+            self[node].emit(stmt="?")
+            for (level, stmt) in self[node.body]:
+                self[node].emit(stmt=stmt, level=level)
+            self[node].emit(stmt=":")
+            for (level, stmt) in self[node.orelse]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt=")")
+
+    @contextlib.contextmanager
+    def pseudocode(self, node):
+        for (level, stmt) in self.__pseudocode[node]:
+            self[node].emit(stmt=f"/* {stmt} */", level=level)
+        yield
+
     @pc_util.Hook(pc_ast.Scope)
     def Scope(self, node):
         yield node
@@ -47,15 +114,30 @@ class CodeVisitor(pc_util.Visitor):
         if isinstance(node.rvalue, (pc_ast.GPR, pc_ast.FPR)):
             self.__regfetch[str(node.rvalue)].append(node.rvalue)
 
-        if str(node.lvalue) in self.__decls:
-            stmt = " ".join([
-                str(self[node.lvalue]),
-                "=",
-                str(self[node.rvalue]),
+        if isinstance(node.rvalue, pc_ast.IfExpr):
+            self.ternary(node=node.rvalue)
+
+        if isinstance(node.lvalue, pc_ast.SubscriptExpr):
+            call = self.call(name="oppc_subscript_assign", node=node, stmt=True, code=[
+                self[node.lvalue.subject],
+                self[node.lvalue.index],
+                self[node.rvalue],
+            ])
+        elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
+            call = self.call(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
+                self[node.lvalue.subject],
+                self[node.lvalue.start],
+                self[node.lvalue.end],
+                self[node.rvalue],
             ])
-            self[node].emit(stmt=f"{stmt};")
         else:
-            raise ValueError(node)
+            call = self.call(name="oppc_assign", stmt=True, node=node, code=[
+                self[node.lvalue],
+                self[node.rvalue],
+            ])
+        with self.pseudocode(node=node):
+            for (level, stmt) in self[call]:
+                self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.BinaryExpr)
     def BinaryExpr(self, node):
@@ -64,62 +146,45 @@ class CodeVisitor(pc_util.Visitor):
             self.__regfetch[str(node.left)].append(node.left)
         if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
             self.__regfetch[str(node.right)].append(node.left)
-        if isinstance(node.left, (pc_ast.GPR, pc_ast.FPR)):
-            left = f"oppc_reg_fetch({str(self[node.left])})"
-        else:
-            left = str(self[node.left])
-        if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
-            right = f"oppc_reg_fetch({str(self[node.right])})"
-        else:
-            right = str(self[node.right])
-        if isinstance(node.op, (pc_ast.Add, pc_ast.Sub,
-                    pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
-                    pc_ast.Lt, pc_ast.Le,
-                    pc_ast.Eq, pc_ast.NotEq,
-                    pc_ast.Ge, pc_ast.Gt,
-                    pc_ast.LShift, pc_ast.RShift,
-                    pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
-                )):
-            op = {
-                pc_ast.Not: "~",
-                pc_ast.Add: "+",
-                pc_ast.Sub: "-",
-                pc_ast.Mul: "*",
-                pc_ast.Div: "/",
-                pc_ast.Mod: "%",
-                pc_ast.Lt: "<",
-                pc_ast.Le: "<=",
-                pc_ast.Eq: "==",
-                pc_ast.NotEq: "!=",
-                pc_ast.Ge: ">=",
-                pc_ast.Gt: ">",
-                pc_ast.LShift: "<<",
-                pc_ast.RShift: "<<",
-                pc_ast.BitAnd: "&",
-                pc_ast.BitOr: "|",
-                pc_ast.BitXor: "^",
-            }[node.op.__class__]
-            stmt = " ".join([left, op, right])
-            self[node].emit(stmt=f"({stmt})")
+
+        comparison = (
+            pc_ast.Lt, pc_ast.Le,
+            pc_ast.Eq, pc_ast.NotEq,
+            pc_ast.Ge, pc_ast.Gt,
+            pc_ast.LtU, pc_ast.GtU,
+        )
+        if isinstance(node.left, pc_ast.IfExpr):
+            self.ternary(node=node.left)
+        if isinstance(node.right, pc_ast.IfExpr):
+            self.ternary(node=node.right)
+
+        if isinstance(node.op, comparison):
+            call = self.call(name=str(self[node.op]), node=node, code=[
+                self[node.left],
+                self[node.right],
+            ])
         else:
-            raise ValueError(node)
+            transient = self.transient(node=node)
+            call = self.call(name=str(self[node.op]), node=node, code=[
+                self[transient],
+                self[node.left],
+                self[node.right],
+            ])
+        with self.pseudocode(node=node):
+            for (level, stmt) in self[call]:
+                self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(pc_ast.UnaryExpr)
     def UnaryExpr(self, node):
         yield node
-        if isinstance(node.value, (pc_ast.GPR, pc_ast.FPR)):
-            value = f"oppc_reg_fetch({str(self[node.value])})"
-        else:
-            value = f"({str(self[node.value])})"
-        if isinstance(node.op, (pc_ast.Not, pc_ast.Add, pc_ast.Sub)):
-            op = {
-                pc_ast.Not: "~",
-                pc_ast.Add: "+",
-                pc_ast.Sub: "-",
-            }[node.op.__class__]
-            self[node].emit(stmt="".join([op, value]))
-        else:
-            raise ValueError(node)
+        if isinstance(node.value, pc_ast.IfExpr):
+            self.ternary(node=node.value)
+        call = self.call(name=str(self[node.op]), node=node, code=[
+            self[node.value],
+        ])
+        with self.pseudocode(node=node):
+            for (level, stmt) in self[call]:
+                self[node].emit(stmt=stmt, level=level)
 
     @pc_util.Hook(
             pc_ast.Not, pc_ast.Add, pc_ast.Sub,
@@ -127,28 +192,285 @@ class CodeVisitor(pc_util.Visitor):
             pc_ast.Lt, pc_ast.Le,
             pc_ast.Eq, pc_ast.NotEq,
             pc_ast.Ge, pc_ast.Gt,
+            pc_ast.LtU, pc_ast.GtU,
             pc_ast.LShift, pc_ast.RShift,
             pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
+            pc_ast.BitConcat,
         )
     def Op(self, node):
         yield node
+        op = {
+            pc_ast.Not: "oppc_not",
+            pc_ast.Add: "oppc_add",
+            pc_ast.Sub: "oppc_sub",
+            pc_ast.Mul: "oppc_mul",
+            pc_ast.Div: "oppc_div",
+            pc_ast.Mod: "oppc_mod",
+            pc_ast.Lt: "oppc_lt",
+            pc_ast.Le: "oppc_le",
+            pc_ast.Eq: "oppc_eq",
+            pc_ast.LtU: "oppc_ltu",
+            pc_ast.GtU: "oppc_gtu",
+            pc_ast.NotEq: "oppc_noteq",
+            pc_ast.Ge: "oppc_ge",
+            pc_ast.Gt: "oppc_gt",
+            pc_ast.LShift: "oppc_lshift",
+            pc_ast.RShift: "oppc_rshift",
+            pc_ast.BitAnd: "oppc_and",
+            pc_ast.BitOr: "oppc_or",
+            pc_ast.BitXor: "oppc_xor",
+            pc_ast.BitConcat: "oppc_concat",
+        }[node.__class__]
+        self[node].emit(stmt=op)
 
     @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
     def Integer(self, node):
         yield node
+        fmt = hex
+        value = str(node)
         if isinstance(node, pc_ast.BinLiteral):
-            base = 2
-        elif isinstance(node, pc_ast.DecLiteral):
-            base = 10
+            bits = f"UINT8_C({str(len(value[2:]))})"
+            value = int(value, 2)
         elif isinstance(node, pc_ast.HexLiteral):
-            base = 16
+            bits = f"UINT8_C({str(len(value[2:]) * 4)})"
+            value = int(value, 16)
+        else:
+            bits = "(uint8_t)OPPC_XLEN"
+            value = int(value)
+            fmt = str
+        if (value > ((2**64) - 1)):
+            raise NotImplementedError()
+        value = f"UINT64_C({fmt(value)})"
+        transient = self.transient(node=node, value=value, bits=bits)
+        with self.pseudocode(node=node):
+            for (level, stmt) in self[transient]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(Transient)
+    def Transient(self, node):
+        yield node
+        self[node].emit(stmt=str(node))
+
+    @pc_util.Hook(Call)
+    def CCall(self, node):
+        yield node
+        end = (";" if node.stmt else "")
+        if len(node.code) == 0:
+            self[node].emit(stmt=f"{str(node.name)}(){end}")
+        else:
+            self[node].emit(stmt=f"{str(node.name)}(")
+            with self[node]:
+                (*head, tail) = node.code
+                for code in head:
+                    for (level, stmt) in code:
+                        self[node].emit(stmt=stmt, level=level)
+                    (level, stmt) = self[node][-1]
+                    if not (not stmt or
+                            stmt.startswith("/*") or
+                            stmt.endswith((",", "(", "{", "*/"))):
+                        stmt = (stmt + ",")
+                    self[node][-1] = (level, stmt)
+                for (level, stmt) in tail:
+                    self[node].emit(stmt=stmt, level=level)
+            self[node].emit(stmt=f"){end}")
+
+    @pc_util.Hook(pc_ast.GPR)
+    def GPR(self, node):
+        yield node
+        with self.pseudocode(node=node):
+            self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")
+
+    @pc_util.Hook(pc_ast.FPR)
+    def FPR(self, node):
+        yield node
+        with self.pseudocode(node=node):
+            self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")
+
+    @pc_util.Hook(pc_ast.RepeatExpr)
+    def RepeatExpr(self, node):
+        yield node
+        transient = self.transient(node=node)
+        call = self.call(name="oppc_repeat", node=node, code=[
+            self[transient],
+            self[node.subject],
+            self[node.times],
+        ])
+        for (level, stmt) in self[call]:
+            self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.XLEN)
+    def XLEN(self, node):
+        yield node
+        (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
+        transient = self.transient(node=node, value=value, bits=bits)
+        with self.pseudocode(node=node):
+            for (level, stmt) in self[transient]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Overflow, pc_ast.CR3, pc_ast.CR5,
+            pc_ast.XER, pc_ast.Reserve, pc_ast.Special)
+    def Special(self, node):
+        yield node
+        with self.pseudocode(node=node):
+            self[node].emit(stmt=f"&OPPC_{str(node).upper()}")
+
+    @pc_util.Hook(pc_ast.SubscriptExpr)
+    def SubscriptExpr(self, node):
+        yield node
+        call = self.call(name="oppc_subscript", node=node, code=[
+            self[node.subject],
+            self[node.index],
+        ])
+        for (level, stmt) in self[call]:
+            self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.RangeSubscriptExpr)
+    def RangeSubscriptExpr(self, node):
+        yield node
+        call = self.call(name="oppc_subscript", node=node, code=[
+            self[node.subject],
+            self[node.start],
+            self[node.end],
+        ])
+        for (level, stmt) in self[call]:
+            self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.ForExpr)
+    def ForExpr(self, node):
+        yield node
+
+        enter = pc_ast.AssignExpr(
+            lvalue=node.subject.clone(),
+            rvalue=node.start.clone(),
+        )
+        match = pc_ast.BinaryExpr(
+            left=node.subject.clone(),
+            op=pc_ast.Le("<="),
+            right=node.end.clone(),
+        )
+        leave = pc_ast.AssignExpr(
+            lvalue=node.subject.clone(),
+            rvalue=pc_ast.BinaryExpr(
+                left=node.subject.clone(),
+                op=pc_ast.Add("+"),
+                right=node.end.clone(),
+            ),
+        )
+        with self.pseudocode(node=node):
+            (level, stmt) = self[node][0]
+        self[node].clear()
+        self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt="for (")
+        with self[node]:
+            with self[node]:
+                for subnode in (enter, match, leave):
+                    self.__pseudocode.traverse(root=subnode)
+                    self.traverse(root=subnode)
+                    for (level, stmt) in self[subnode]:
+                        self[node].emit(stmt=stmt, level=level)
+                    (level, stmt) = self[node][-1]
+                    if subnode is match:
+                        stmt = f"{stmt};"
+                    elif subnode is leave:
+                        stmt = stmt[:-1]
+                    self[node][-1] = (level, stmt)
+        (level, stmt) = self[node][0]
+        self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt=") {")
+        for (level, stmt) in self[node.body]:
+            self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt="}")
+
+    @pc_util.Hook(pc_ast.WhileExpr)
+    def WhileExpr(self, node):
+        yield node
+        self[node].emit(stmt="while (")
+        with self[node]:
+            with self[node]:
+                for (level, stmt) in self[node.test]:
+                    self[node].emit(stmt=stmt, level=level)
+        self[node].emit(") {")
+        for (level, stmt) in self[node.body]:
+            self[node].emit(stmt=stmt, level=level)
+        if node.orelse:
+            self[node].emit(stmt="} else {")
+            for (level, stmt) in self[node.orelse]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt="}")
+
+    @pc_util.Hook(pc_ast.IfExpr)
+    def IfExpr(self, node):
+        yield node
+        test = self.call(name="oppc_bool", node=node, code=[
+            self[node.test],
+        ])
+        self[node].emit(stmt="if (")
+        with self[node]:
+            for (level, stmt) in self[test]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt=") {")
+        for (level, stmt) in self[node.body]:
+            self[node].emit(stmt=stmt, level=level)
+        if node.orelse:
+            self[node].emit(stmt="} else {")
+            for (level, stmt) in self[node.orelse]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(stmt="}")
+
+    @pc_util.Hook(pc_ast.SwitchExpr)
+    def SwitchExpr(self, node):
+        yield node
+        subject = self.call(name="oppc_int64", node=node, code=[
+            self[node.subject],
+        ])
+        self[node].emit(stmt="switch (")
+        with self[node]:
+            for (level, stmt) in self[subject]:
+                self[node].emit(stmt=stmt, level=level)
+        self[node].emit(") {")
+        with self[node]:
+            for (level, stmt) in self[node.cases]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Cases)
+    def Cases(self, node):
+        yield node
+        for subnode in node:
+            for (level, stmt) in self[subnode]:
+                self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Case)
+    def Case(self, node):
+        yield node
+        for (level, stmt) in self[node.labels]:
+            self[node].emit(stmt=stmt, level=level)
+        for (level, stmt) in self[node.body]:
+            self[node].emit(stmt=stmt, level=level)
+
+    @pc_util.Hook(pc_ast.Labels)
+    def Labels(self, node):
+        yield node
+        if ((len(node) == 1) and isinstance(node[-1], pc_ast.DefaultLabel)):
+            stmt = "default:"
         else:
-            raise ValueError(node)
-        self[node].emit(stmt=f"UINT64_C({hex(int(node, base))})")
+            labels = ", ".join(map(lambda label: str(self[label]), node))
+            stmt = f"case ({labels}):"
+        self[node].emit(stmt=stmt)
+
+    @pc_util.Hook(pc_ast.Label)
+    def Label(self, node):
+        yield node
+        self[node].emit(stmt=str(node))
+
+    @pc_util.Hook(pc_ast.LeaveKeyword)
+    def LeaveKeyword(self, node):
+        yield node
+        self[node].emit(stmt="break;")
 
     @pc_util.Hook(pc_ast.Call.Name)
     def CallName(self, node):
         yield node
+        self[node].emit(stmt=str(node))
 
     @pc_util.Hook(pc_ast.Call.Arguments)
     def CallArguments(self, node):
@@ -157,11 +479,21 @@ class CodeVisitor(pc_util.Visitor):
             if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
                 self.__regfetch[str(subnode)].append(subnode)
 
+    @pc_util.Hook(pc_ast.Call)
+    def Call(self, node):
+        yield node
+        code = tuple(map(lambda arg: self[arg], node.args))
+        call = self.call(name=str(node.name), node=node, code=code)
+        for (level, stmt) in self[call]:
+            self[node].emit(stmt=stmt, level=level)
+
     @pc_util.Hook(pc_ast.Symbol)
     def Symbol(self, node):
         yield node
-        self.__decls[str(node)].append(node)
-        self[node].emit(stmt=str(node))
+        with self.pseudocode(node=node):
+            if str(node) not in ("fallthrough",):
+                self.__decls[str(node)].append(node)
+                self[node].emit(stmt=f"&{str(node)}")
 
     @pc_util.Hook(pc_ast.Node)
     def Node(self, node):