import collections
+import contextlib
import openpower.oppc.pc_ast as pc_ast
import openpower.oppc.pc_util as pc_util
+import openpower.oppc.pc_pseudocode as pc_pseudocode
+
+
+class Transient(pc_ast.Node):
+ def __init__(self, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
+ self.__value = value
+ self.__bits = bits
+
+ return super().__init__()
+
+ def __str__(self):
+ return f"oppc_transient(&(struct oppc_value){{}}, {self.__value}, {self.__bits})"
+
+
+class CCall(pc_ast.Dataclass):
+ name: str
+ code: tuple
+ stmt: bool
class CodeVisitor(pc_util.Visitor):
def __init__(self, name, root):
- self.__name = name
- self.__code = pc_util.Code()
+ self.__root = root
+ self.__header = object()
+ self.__footer = object()
+ self.__code = collections.defaultdict(lambda: pc_util.Code())
self.__decls = collections.defaultdict(list)
self.__regfetch = collections.defaultdict(list)
self.__regstore = collections.defaultdict(list)
+ self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)
super().__init__(root=root)
- self.__code.emit("void")
- self.__code.emit(f"oppc_{name}(void) {{")
- with self.__code:
+ self.__code[self.__header].emit(stmt="void")
+ self.__code[self.__header].emit(stmt=f"oppc_{name}(void) {{")
+ with self.__code[self.__header]:
for decl in self.__decls:
- self.__code.emit(f"uint64_t {decl};")
- self.__code.emit(f"}}")
+ self.__code[self.__header].emit(stmt=f"struct oppc_value {decl};")
+ self.__code[self.__footer].emit(stmt=f"}}")
def __iter__(self):
- yield from self.__code
+ yield from self.__code[self.__header]
+ yield from self.__code[self.__root]
+ yield from self.__code[self.__footer]
+
+ def __getitem__(self, node):
+ return self.__code[node]
+
+ def transient(self, node,
+ value="UINT64_C(0)",
+ bits="(uint8_t)OPPC_XLEN"):
+ transient = Transient(value=value, bits=bits)
+ self.traverse(root=transient)
+ return transient
+
+ def ccall(self, node, name, code, stmt=False):
+ def validate(item):
+ def validate(item):
+ (level, stmt) = item
+ if not isinstance(level, int):
+ raise ValueError(level)
+ if not isinstance(stmt, str):
+ raise ValueError(stmt)
+ return (level, stmt)
+
+ return tuple(map(validate, item))
+
+ code = tuple(map(validate, code))
+ ccall = CCall(name=name, code=code, stmt=stmt)
+ self.traverse(root=ccall)
+ return ccall
+
+ def ternary(self, node):
+ self[node].clear()
+ self[node].emit(stmt="(")
+ with self[node]:
+ for (level, stmt) in self[node.test]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt="?")
+ for (level, stmt) in self[node.body]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt=":")
+ for (level, stmt) in self[node.orelse]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt=")")
+
+ @contextlib.contextmanager
+ def pseudocode(self, node):
+ for (level, stmt) in self.__pseudocode[node]:
+ self[node].emit(stmt=f"/* {stmt} */", level=level)
+ yield
+
+ @pc_util.Hook(pc_ast.Scope)
+ def Scope(self, node):
+ yield node
+ with self[node]:
+ for subnode in node:
+ for (level, stmt) in self[subnode]:
+ self[node].emit(stmt=stmt, level=level)
@pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
def AssignExpr(self, node):
if isinstance(node.rvalue, (pc_ast.GPR, pc_ast.FPR)):
self.__regfetch[str(node.rvalue)].append(node.rvalue)
+ if isinstance(node.rvalue, pc_ast.IfExpr):
+ self.ternary(node=node.rvalue)
+
+ if isinstance(node.lvalue, pc_ast.SubscriptExpr):
+ ccall = self.ccall(name="oppc_subscript_assign", node=node, stmt=True, code=[
+ self[node.lvalue.subject],
+ self[node.lvalue.index],
+ self[node.rvalue],
+ ])
+ elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
+ ccall = self.ccall(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
+ self[node.lvalue.subject],
+ self[node.lvalue.start],
+ self[node.lvalue.end],
+ self[node.rvalue],
+ ])
+ else:
+ ccall = self.ccall(name="oppc_assign", stmt=True, node=node, code=[
+ self[node.lvalue],
+ self[node.rvalue],
+ ])
+ with self.pseudocode(node=node):
+ for (level, stmt) in self[ccall]:
+ self[node].emit(stmt=stmt, level=level)
+
@pc_util.Hook(pc_ast.BinaryExpr)
def BinaryExpr(self, node):
yield node
if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
self.__regfetch[str(node.right)].append(node.left)
+ comparison = (
+ pc_ast.Lt, pc_ast.Le,
+ pc_ast.Eq, pc_ast.NotEq,
+ pc_ast.Ge, pc_ast.Gt,
+ pc_ast.LtU, pc_ast.GtU,
+ )
+ if isinstance(node.left, pc_ast.IfExpr):
+ self.ternary(node=node.left)
+ if isinstance(node.right, pc_ast.IfExpr):
+ self.ternary(node=node.right)
+
+ if isinstance(node.op, comparison):
+ ccall = self.ccall(name=str(self[node.op]), node=node, code=[
+ self[node.left],
+ self[node.right],
+ ])
+ else:
+ transient = self.transient(node=node)
+ ccall = self.ccall(name=str(self[node.op]), node=node, code=[
+ self[transient],
+ self[node.left],
+ self[node.right],
+ ])
+ with self.pseudocode(node=node):
+ for (level, stmt) in self[ccall]:
+ self[node].emit(stmt=stmt, level=level)
+
@pc_util.Hook(pc_ast.UnaryExpr)
def UnaryExpr(self, node):
yield node
- if isinstance(node.value, (pc_ast.GPR, pc_ast.FPR)):
- self.__regfetch[str(node.value)].append(node.value)
+ if isinstance(node.value, pc_ast.IfExpr):
+ self.ternary(node=node.value)
+ ccall = self.ccall(name=str(self[node.op]), node=node, code=[
+ self[node.value],
+ ])
+ with self.pseudocode(node=node):
+ for (level, stmt) in self[ccall]:
+ self[node].emit(stmt=stmt, level=level)
+
+ @pc_util.Hook(
+ pc_ast.Not, pc_ast.Add, pc_ast.Sub,
+ pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
+ pc_ast.Lt, pc_ast.Le,
+ pc_ast.Eq, pc_ast.NotEq,
+ pc_ast.Ge, pc_ast.Gt,
+ pc_ast.LtU, pc_ast.GtU,
+ pc_ast.LShift, pc_ast.RShift,
+ pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
+ pc_ast.BitConcat,
+ )
+ def Op(self, node):
+ yield node
+ op = {
+ pc_ast.Not: "oppc_not",
+ pc_ast.Add: "oppc_add",
+ pc_ast.Sub: "oppc_sub",
+ pc_ast.Mul: "oppc_mul",
+ pc_ast.Div: "oppc_div",
+ pc_ast.Mod: "oppc_mod",
+ pc_ast.Lt: "oppc_lt",
+ pc_ast.Le: "oppc_le",
+ pc_ast.Eq: "oppc_eq",
+ pc_ast.LtU: "oppc_ltu",
+ pc_ast.GtU: "oppc_gtu",
+ pc_ast.NotEq: "oppc_noteq",
+ pc_ast.Ge: "oppc_ge",
+ pc_ast.Gt: "oppc_gt",
+ pc_ast.LShift: "oppc_lshift",
+ pc_ast.RShift: "oppc_rshift",
+ pc_ast.BitAnd: "oppc_and",
+ pc_ast.BitOr: "oppc_or",
+ pc_ast.BitXor: "oppc_xor",
+ pc_ast.BitConcat: "oppc_concat",
+ }[node.__class__]
+ self[node].emit(stmt=op)
+
+ @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
+ def Integer(self, node):
+ yield node
+ fmt = hex
+ value = str(node)
+ if isinstance(node, pc_ast.BinLiteral):
+ bits = f"UINT8_C({str(len(value[2:]))})"
+ value = int(value, 2)
+ elif isinstance(node, pc_ast.HexLiteral):
+ bits = f"UINT8_C({str(len(value[2:]) * 4)})"
+ value = int(value, 16)
+ else:
+ bits = "(uint8_t)OPPC_XLEN"
+ value = int(value)
+ fmt = str
+ if (value > ((2**64) - 1)):
+ raise NotImplementedError()
+ value = f"UINT64_C({fmt(value)})"
+ transient = self.transient(node=node, value=value, bits=bits)
+ with self.pseudocode(node=node):
+ for (level, stmt) in self[transient]:
+ self[node].emit(stmt=stmt, level=level)
+
+ @pc_util.Hook(Transient)
+ def Transient(self, node):
+ yield node
+ self[node].emit(stmt=str(node))
+
+ @pc_util.Hook(CCall)
+ def CCall(self, node):
+ yield node
+ end = (";" if node.stmt else "")
+ if len(node.code) == 0:
+ self[node].emit(stmt=f"{str(node.name)}(){end}")
+ else:
+ self[node].emit(stmt=f"{str(node.name)}(")
+ with self[node]:
+ (*head, tail) = node.code
+ for code in head:
+ for (level, stmt) in code:
+ self[node].emit(stmt=stmt, level=level)
+ (level, stmt) = self[node][-1]
+ if not (not stmt or
+ stmt.startswith("/*") or
+ stmt.endswith((",", "(", "{", "*/"))):
+ stmt = (stmt + ",")
+ self[node][-1] = (level, stmt)
+ for (level, stmt) in tail:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt=f"){end}")
+
+ @pc_util.Hook(pc_ast.GPR)
+ def GPR(self, node):
+ yield node
+ with self.pseudocode(node=node):
+ self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")
+
+ @pc_util.Hook(pc_ast.FPR)
+ def FPR(self, node):
+ yield node
+ with self.pseudocode(node=node):
+ self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")
+
+ @pc_util.Hook(pc_ast.RepeatExpr)
+ def RepeatExpr(self, node):
+ yield node
+ transient = self.transient(node=node)
+ ccall = self.ccall(name="oppc_repeat", node=node, code=[
+ self[transient],
+ self[node.subject],
+ self[node.times],
+ ])
+ for (level, stmt) in self[ccall]:
+ self[node].emit(stmt=stmt, level=level)
+
+ @pc_util.Hook(pc_ast.XLEN)
+ def XLEN(self, node):
+ yield node
+ (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
+ transient = self.transient(node=node, value=value, bits=bits)
+ with self.pseudocode(node=node):
+ for (level, stmt) in self[transient]:
+ self[node].emit(stmt=stmt, level=level)
+
+ @pc_util.Hook(pc_ast.SubscriptExpr)
+ def SubscriptExpr(self, node):
+ yield node
+ ccall = self.ccall(name="oppc_subscript", node=node, code=[
+ self[node.subject],
+ self[node.index],
+ ])
+ for (level, stmt) in self[ccall]:
+ self[node].emit(stmt=stmt, level=level)
+
+ @pc_util.Hook(pc_ast.RangeSubscriptExpr)
+ def RangeSubscriptExpr(self, node):
+ yield node
+ ccall = self.ccall(name="oppc_subscript", node=node, code=[
+ self[node.subject],
+ self[node.start],
+ self[node.end],
+ ])
+ for (level, stmt) in self[ccall]:
+ self[node].emit(stmt=stmt, level=level)
+
+ @pc_util.Hook(pc_ast.ForExpr)
+ def ForExpr(self, node):
+ yield node
+
+ enter = pc_ast.AssignExpr(
+ lvalue=node.subject.clone(),
+ rvalue=node.start.clone(),
+ )
+ match = pc_ast.BinaryExpr(
+ left=node.subject.clone(),
+ op=pc_ast.Le("<="),
+ right=node.end.clone(),
+ )
+ leave = pc_ast.AssignExpr(
+ lvalue=node.subject.clone(),
+ rvalue=pc_ast.BinaryExpr(
+ left=node.subject.clone(),
+ op=pc_ast.Add("+"),
+ right=node.end.clone(),
+ ),
+ )
+ with self.pseudocode(node=node):
+ (level, stmt) = self[node][0]
+ self[node].clear()
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt="for (")
+ with self[node]:
+ with self[node]:
+ for subnode in (enter, match, leave):
+ self.__pseudocode.traverse(root=subnode)
+ self.traverse(root=subnode)
+ for (level, stmt) in self[subnode]:
+ self[node].emit(stmt=stmt, level=level)
+ (level, stmt) = self[node][-1]
+ if subnode is match:
+ stmt = f"{stmt};"
+ elif subnode is leave:
+ stmt = stmt[:-1]
+ self[node][-1] = (level, stmt)
+ (level, stmt) = self[node][0]
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt=") {")
+ for (level, stmt) in self[node.body]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt="}")
+
+ @pc_util.Hook(pc_ast.WhileExpr)
+ def WhileExpr(self, node):
+ yield node
+ self[node].emit(stmt="while (")
+ with self[node]:
+ with self[node]:
+ for (level, stmt) in self[node.test]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(") {")
+ for (level, stmt) in self[node.body]:
+ self[node].emit(stmt=stmt, level=level)
+ if node.orelse:
+ self[node].emit(stmt="} else {")
+ for (level, stmt) in self[node.orelse]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt="}")
+
+ @pc_util.Hook(pc_ast.IfExpr)
+ def IfExpr(self, node):
+ yield node
+ self[node].emit(stmt="if (")
+ with self[node]:
+ for (level, stmt) in self[node.test]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt=") {")
+ for (level, stmt) in self[node.body]:
+ self[node].emit(stmt=stmt, level=level)
+ if node.orelse:
+ self[node].emit(stmt="} else {")
+ for (level, stmt) in self[node.orelse]:
+ self[node].emit(stmt=stmt, level=level)
+ self[node].emit(stmt="}")
@pc_util.Hook(pc_ast.Call.Name)
def CallName(self, node):
yield node
+ self[node].emit(stmt=str(node))
@pc_util.Hook(pc_ast.Call.Arguments)
def CallArguments(self, node):
if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
self.__regfetch[str(subnode)].append(subnode)
+ @pc_util.Hook(pc_ast.Call)
+ def Call(self, node):
+ yield node
+ code = tuple(map(lambda arg: self[arg], node.args))
+ ccall = self.ccall(name=str(node.name), node=node, code=code)
+ for (level, stmt) in self[ccall]:
+ self[node].emit(stmt=stmt, level=level)
+
@pc_util.Hook(pc_ast.Symbol)
def Symbol(self, node):
yield node
self.__decls[str(node)].append(node)
+ with self.pseudocode(node=node):
+ self[node].emit(stmt=f"&{str(node)}")
+
+ @pc_util.Hook(pc_ast.Node)
+ def Node(self, node):
+ raise NotImplementedError(type(node))
def code(name, root):