From: Dmitry Selyutin Date: Tue, 9 Jan 2024 18:23:44 +0000 (+0300) Subject: oppc: decouple common code X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=62df924fcc1793bb10ca13be8e0fee40c3e8cfab;p=openpower-isa.git oppc: decouple common code --- diff --git a/src/openpower/oppc/pc_pseudocode.py b/src/openpower/oppc/pc_pseudocode.py index f6b5a56c..0d111a81 100644 --- a/src/openpower/oppc/pc_pseudocode.py +++ b/src/openpower/oppc/pc_pseudocode.py @@ -1,69 +1,23 @@ import collections -import contextlib -import functools - -import mdis.dispatcher -import mdis.visitor -import mdis.walker import openpower.oppc.pc_ast as pc_ast +import openpower.oppc.pc_util as pc_util -class Hook(mdis.dispatcher.Hook): - def __call__(self, call): - hook = super().__call__(call) - - class ConcreteHook(hook.__class__): - @functools.wraps(hook.__call__) - @contextlib.contextmanager - def __call__(self, dispatcher, node, *args, **kwargs): - return hook(dispatcher, node, *args, **kwargs) - - return ConcreteHook(*tuple(self)) - - -class Code(list): - def __init__(self): - self.__level = 0 - return super().__init__() - - def __enter__(self): - self.__level += 1 - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.__level -= 1 - - def __str__(self): - if len(self) == 0: - raise ValueError("empty code") - - lines = [] - for (level, stmt) in self: - line = ((" " * level * 4) + stmt) - lines.append(line) - - return "\n".join(lines) - - def emit(self, stmt, level=0): - item = ((level + self.__level), stmt) - self.append(item) - - -class PseudocodeVisitor(mdis.visitor.ContextVisitor): +class PseudocodeVisitor(pc_util.Visitor): def __init__(self, root): + self.__code = collections.defaultdict(lambda: pc_util.Code()) self.__root = root - self.__code = collections.defaultdict(lambda: Code()) - return super().__init__() + return super().__init__(root=root) def __iter__(self): - yield from self.__code.items() + yield from self.__code[self.__root] def __getitem__(self, node): return self.__code[node] - @Hook(pc_ast.Scope) + @pc_util.Hook(pc_ast.Scope) def Scope(self, node): yield node if node is not self.__root: @@ -76,7 +30,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): for (level, stmt) in self[subnode]: self[node].emit(stmt=stmt, level=level) - @Hook(pc_ast.Call) + @pc_util.Hook(pc_ast.Call) def Call(self, node): yield node args = [] @@ -88,7 +42,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): stmt = f"{node.name}({args})" self[node].emit(stmt=stmt) - @Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr) + @pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr) def AssignExpr(self, node): mapping = { pc_ast.AssignExpr: "<-", @@ -120,7 +74,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): ]) self[node].emit(stmt=stmt) - @Hook(pc_ast.BinaryExpr) + @pc_util.Hook(pc_ast.BinaryExpr) def BinaryExpr(self, node): yield node stmt = " ".join([ @@ -130,7 +84,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): ]) self[node].emit(stmt=f"({stmt})") - @Hook(pc_ast.IfExpr) + @pc_util.Hook(pc_ast.IfExpr) def IfExpr(self, node): yield node stmt = " ".join([ @@ -146,7 +100,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): for (level, stmt) in self[node.orelse]: self[node].emit(stmt=stmt, level=level) - @Hook(pc_ast.ForExpr) + @pc_util.Hook(pc_ast.ForExpr) def ForExpr(self, node): yield node stmt = " ".join([ @@ -161,7 +115,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): for (level, stmt) in self[node.body]: self[node].emit(stmt=stmt, level=level) - @Hook(pc_ast.WhileExpr) + @pc_util.Hook(pc_ast.WhileExpr) def WhileExpr(self, node): yield node stmt = " ".join([ @@ -177,7 +131,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): for (level, stmt) in self[node.orelse]: self[node].emit(stmt=stmt, level=level) - @Hook(pc_ast.RepeatExpr) + @pc_util.Hook(pc_ast.RepeatExpr) def RepeatExpr(self, node): yield node stmt = " ".join([ @@ -187,7 +141,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): ]) self[node].emit(stmt=f"({stmt})") - @Hook(pc_ast.SwitchExpr) + @pc_util.Hook(pc_ast.SwitchExpr) def SwitchExpr(self, node): yield node self[node].emit(f"switch({str(self[node.subject])})") @@ -195,14 +149,14 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): for (level, stmt) in self[node.cases]: self[node].emit(stmt=stmt, level=level) - @Hook(pc_ast.Cases) + @pc_util.Hook(pc_ast.Cases) def Cases(self, node): yield node for subnode in node: for (level, stmt) in self[subnode]: self[node].emit(stmt=stmt, level=level) - @Hook(pc_ast.Case) + @pc_util.Hook(pc_ast.Case) def Case(self, node): yield node for (level, stmt) in self[node.labels]: @@ -210,7 +164,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): for (level, stmt) in self[node.body]: self[node].emit(stmt=stmt, level=level) - @Hook(pc_ast.Labels) + @pc_util.Hook(pc_ast.Labels) def Labels(self, node): yield node if ((len(node) == 1) and isinstance(node[-1], pc_ast.DefaultLabel)): @@ -220,17 +174,17 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): stmt = f"case ({labels}):" self[node].emit(stmt=stmt) - @Hook(pc_ast.Label) + @pc_util.Hook(pc_ast.Label) def Label(self, node): yield node self[node].emit(stmt=str(node)) - @Hook(pc_ast.DefaultLabel) + @pc_util.Hook(pc_ast.DefaultLabel) def DefaultLabel(self, node): yield node self[node].emit(stmt="default:") - @Hook(pc_ast.UnaryExpr) + @pc_util.Hook(pc_ast.UnaryExpr) def UnaryExpr(self, node): yield node stmt = "".join([ @@ -239,22 +193,22 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): ]) self[node].emit(stmt=stmt) - @Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral) + @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral) def Integer(self, node): yield node self[node].emit(stmt=str(node)) - @Hook(pc_ast.StringLiteral) + @pc_util.Hook(pc_ast.StringLiteral) def StringLiteral(self, node): yield node self[node].emit(stmt=f"'{str(node)}'") - @Hook(pc_ast.Symbol) + @pc_util.Hook(pc_ast.Symbol) def Symbol(self, node): yield node self[node].emit(stmt=str(node)) - @Hook(pc_ast.Attribute) + @pc_util.Hook(pc_ast.Attribute) def Attribute(self, node): yield node stmt = ".".join([ @@ -263,7 +217,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): ]) self[node].emit(stmt=stmt) - @Hook(pc_ast.Not, pc_ast.Add, pc_ast.Sub, + @pc_util.Hook(pc_ast.Not, pc_ast.Add, pc_ast.Sub, pc_ast.Mul, pc_ast.MulS, pc_ast.MulU, pc_ast.Div, pc_ast.DivT, pc_ast.Mod, pc_ast.Sqrt, @@ -307,7 +261,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): stmt = mapping[node.__class__] self[node].emit(stmt=stmt) - @Hook(pc_ast.LParenthesis, pc_ast.RParenthesis, + @pc_util.Hook(pc_ast.LParenthesis, pc_ast.RParenthesis, pc_ast.LBracket, pc_ast.RBracket) def BracketOrParenthesis(self, node): yield node @@ -320,7 +274,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): stmt = mapping[node.__class__] self[node].emit(stmt=stmt) - @Hook(pc_ast.Subscript) + @pc_util.Hook(pc_ast.Subscript) def Subscript(self, node): yield node stmt = "".join([ @@ -331,7 +285,7 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): ]) self[node].emit(stmt=stmt) - @Hook(pc_ast.RangeSubscript) + @pc_util.Hook(pc_ast.RangeSubscript) def RangeSubscript(self, node): yield node stmt = "".join([ @@ -344,32 +298,32 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): ]) self[node].emit(stmt=stmt) - @Hook(pc_ast.Colon) + @pc_util.Hook(pc_ast.Colon) def Colon(self, node): yield node self[node].emit(stmt=":") - @Hook(pc_ast.Linebreak, pc_ast.Endmarker) + @pc_util.Hook(pc_ast.Linebreak, pc_ast.Endmarker) def Ignore(self, node): yield node - @Hook(pc_ast.Keyword) + @pc_util.Hook(pc_ast.Keyword) def Keyword(self, node): yield node self[node].emit(stmt=node.__doc__) - @Hook(pc_ast.Sequence) + @pc_util.Hook(pc_ast.Sequence) def Sequence(self, node): yield node stmt = ",".join(map(lambda subnode: str(self[subnode]), node)) self[node].emit(stmt=f"({stmt})") - @Hook(pc_ast.Literal) + @pc_util.Hook(pc_ast.Literal) def Literal(self, node): yield node self[node].emit(stmt=str(node)) - @Hook(pc_ast.GPR, pc_ast.FPR, pc_ast.GPRZero) + @pc_util.Hook(pc_ast.GPR, pc_ast.FPR, pc_ast.GPRZero) def Reg(self, node): yield node if isinstance(node, pc_ast.GPRZero): @@ -377,20 +331,10 @@ class PseudocodeVisitor(mdis.visitor.ContextVisitor): else: self[node].emit(stmt=f"({str(node)})") - @Hook(pc_ast.Node) + @pc_util.Hook(pc_ast.Node) def Node(self, node): raise NotImplementedError(type(node)) -def traverse(root, visitor, walker): - with visitor(root): - for node in walker(root): - traverse(root=node, visitor=visitor, walker=walker) - - def pseudocode(root): - walker = mdis.walker.Walker() - visitor = PseudocodeVisitor(root=root) - traverse(root=root, visitor=visitor, walker=walker) - for (level, stmt) in visitor[root]: - yield (level, stmt) + yield from PseudocodeVisitor(root=root) diff --git a/src/openpower/oppc/pc_util.py b/src/openpower/oppc/pc_util.py new file mode 100644 index 00000000..3127f64a --- /dev/null +++ b/src/openpower/oppc/pc_util.py @@ -0,0 +1,62 @@ +import contextlib +import functools + + +import mdis.visitor +import mdis.walker +import mdis.dispatcher + + +class Hook(mdis.dispatcher.Hook): + def __call__(self, call): + hook = super().__call__(call) + + class ConcreteHook(hook.__class__): + @functools.wraps(hook.__call__) + @contextlib.contextmanager + def __call__(self, dispatcher, node, *args, **kwargs): + return hook(dispatcher, node, *args, **kwargs) + + return ConcreteHook(*tuple(self)) + + +class Code(list): + def __init__(self): + self.__level = 0 + return super().__init__() + + def __enter__(self): + self.__level += 1 + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.__level -= 1 + + def __str__(self): + if len(self) == 0: + raise ValueError("empty code") + + lines = [] + for (level, stmt) in self: + line = ((" " * level * 4) + stmt) + lines.append(line) + + return "\n".join(lines) + + def emit(self, stmt, level=0): + item = ((level + self.__level), stmt) + self.append(item) + + +class Visitor(mdis.visitor.ContextVisitor): + def __init__(self, root): + self.__walker = mdis.walker.Walker() + + self.traverse(root=root) + + return super().__init__() + + def traverse(self, root): + with self(root): + for node in self.__walker(root): + self.traverse(root=node)