From 9f1624dae103c45ad2ef7bb235b513cc9130bf6e Mon Sep 17 00:00:00 2001 From: Dmitry Selyutin Date: Sun, 14 Jan 2024 15:24:39 +0300 Subject: [PATCH] oppc/code: introduce ccall --- src/openpower/oppc/pc_code.py | 98 +++++++++++++++++++++++++---------- 1 file changed, 71 insertions(+), 27 deletions(-) diff --git a/src/openpower/oppc/pc_code.py b/src/openpower/oppc/pc_code.py index 7b37a294..f351f5b8 100644 --- a/src/openpower/oppc/pc_code.py +++ b/src/openpower/oppc/pc_code.py @@ -17,6 +17,12 @@ class Transient(pc_ast.Node): return f"oppc_transient(&(struct oppc_int){{}}, {self.__value}, {self.__bits})" +class CCall(pc_ast.Dataclass): + name: str + code: tuple + stmt: bool + + class CodeVisitor(pc_util.Visitor): def __init__(self, name, root): self.__root = root @@ -51,29 +57,30 @@ class CodeVisitor(pc_util.Visitor): self.traverse(root=transient) return transient + def ccall(self, node, name, code, stmt=False): + def validate(item): + def validate(item): + (level, stmt) = item + if not isinstance(level, int): + raise ValueError(level) + if not isinstance(stmt, str): + raise ValueError(stmt) + return (level, stmt) + + return tuple(map(validate, item)) + + code = tuple(map(validate, code)) + ccall = CCall(name=name, code=code, stmt=stmt) + with self.pseudocode(node=node): + self.traverse(root=ccall) + return ccall + @contextlib.contextmanager def pseudocode(self, node): for (level, stmt) in self.__pseudocode[node]: self[node].emit(stmt=f"/* {stmt} */", level=level) yield - def call(self, node, code, prefix="", suffix=""): - with self.pseudocode(node=node): - self[node].emit(stmt=f"{prefix}(") - with self[node]: - for chunk in code[:-1]: - for (level, stmt) in chunk: - if not (not stmt or - stmt.startswith("/*") or - stmt.endswith((",", "(", "{", "*/"))): - stmt = (stmt + ",") - self[node].emit(stmt=stmt, level=level) - if len(code) > 0: - for (level, stmt) in code[-1]: - if stmt: - self[node].emit(stmt=stmt, level=level) - self[node].emit(stmt=f"){suffix}") - @pc_util.Hook(pc_ast.Scope) def Scope(self, node): yield node @@ -101,23 +108,25 @@ class CodeVisitor(pc_util.Visitor): ]))] if isinstance(node.lvalue, pc_ast.SubscriptExpr): - self.call(prefix="oppc_subscript_assign", suffix=";", node=node, code=[ + ccall = self.ccall(name="oppc_subscript_assign", node=node, stmt=True, code=[ self[node.lvalue.subject], self[node.lvalue.index], rvalue, ]) elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr): - self.call(prefix="oppc_range_subscript_assign", suffix=";", node=node, code=[ + ccall = self.ccall(name="oppc_range_subscript_assign", node=node, stmt=True, code=[ self[node.lvalue.subject], self[node.lvalue.start], self[node.lvalue.end], rvalue, ]) else: - self.call(prefix="oppc_assign", suffix=";", node=node, code=[ + ccall = self.ccall(name="oppc_assign", stmt=True, node=node, code=[ self[node.lvalue], rvalue, ]) + for (level, stmt) in self[ccall]: + self[node].emit(stmt=stmt, level=level) @pc_util.Hook(pc_ast.BinaryExpr) def BinaryExpr(self, node): @@ -133,24 +142,28 @@ class CodeVisitor(pc_util.Visitor): pc_ast.Ge, pc_ast.Gt, ) if isinstance(node.op, comparison): - self.call(prefix=str(self[node.op]), node=node, code=[ + ccall = self.ccall(name=str(self[node.op]), node=node, code=[ self[node.left], self[node.right], ]) else: transient = self.transient(node=node) - self.call(prefix=str(self[node.op]), node=node, code=[ + ccall = self.ccall(name=str(self[node.op]), node=node, code=[ self[transient], self[node.left], self[node.right], ]) + for (level, stmt) in self[ccall]: + self[node].emit(stmt=stmt, level=level) @pc_util.Hook(pc_ast.UnaryExpr) def UnaryExpr(self, node): yield node - self.call(prefix=str(self[node.op]), node=node, code=[ + ccall = self.ccall(name=str(self[node.op]), node=node, code=[ self[node.value], ]) + for (level, stmt) in self[ccall]: + self[node].emit(stmt=stmt, level=level) @pc_util.Hook( pc_ast.Not, pc_ast.Add, pc_ast.Sub, @@ -211,6 +224,29 @@ class CodeVisitor(pc_util.Visitor): yield node self[node].emit(stmt=str(node)) + @pc_util.Hook(CCall) + def CCall(self, node): + yield node + end = (";" if node.stmt else "") + if len(node.code) == 0: + self[node].emit(stmt=f"{str(node.name)}(){end}") + else: + self[node].emit(stmt=f"{str(node.name)}(") + with self[node]: + (*head, tail) = node.code + for code in head: + for (level, stmt) in code: + self[node].emit(stmt=stmt, level=level) + (level, stmt) = self[node][-1] + if not (not stmt or + stmt.startswith("/*") or + stmt.endswith((",", "(", "{", "*/"))): + stmt = (stmt + ",") + self[node][-1] = (level, stmt) + for (level, stmt) in tail: + self[node].emit(stmt=stmt, level=level) + self[node].emit(stmt=f"){end}") + @pc_util.Hook(pc_ast.GPR) def GPR(self, node): yield node @@ -227,11 +263,13 @@ class CodeVisitor(pc_util.Visitor): def RepeatExpr(self, node): yield node transient = self.transient(node=node) - self.call(prefix="oppc_repeat", node=node, code=[ + ccall = self.ccall(name="oppc_repeat", node=node, code=[ self[transient], self[node.subject], self[node.times], ]) + for (level, stmt) in self[ccall]: + self[node].emit(stmt=stmt, level=level) @pc_util.Hook(pc_ast.XLEN) def XLEN(self, node): @@ -244,19 +282,23 @@ class CodeVisitor(pc_util.Visitor): @pc_util.Hook(pc_ast.SubscriptExpr) def SubscriptExpr(self, node): yield node - self.call(prefix="oppc_subscript", node=node, code=[ + ccall = self.ccall(name="oppc_subscript", node=node, code=[ self[node.subject], self[node.index], ]) + for (level, stmt) in self[ccall]: + self[node].emit(stmt=stmt, level=level) @pc_util.Hook(pc_ast.RangeSubscriptExpr) def RangeSubscriptExpr(self, node): yield node - self.call(prefix="oppc_subscript", node=node, code=[ + ccall = self.ccall(name="oppc_subscript", node=node, code=[ self[node.subject], self[node.start], self[node.end], ]) + for (level, stmt) in self[ccall]: + self[node].emit(stmt=stmt, level=level) @pc_util.Hook(pc_ast.ForExpr) def ForExpr(self, node): @@ -353,7 +395,9 @@ class CodeVisitor(pc_util.Visitor): def Call(self, node): yield node code = tuple(map(lambda arg: self[arg], node.args)) - self.call(prefix=str(node.name), node=node, code=code) + ccall = self.ccall(name=str(node.name), node=node, code=code) + for (level, stmt) in self[ccall]: + self[node].emit(stmt=stmt, level=level) @pc_util.Hook(pc_ast.Symbol) def Symbol(self, node): -- 2.30.2