oppc/code: introduce transient
[openpower-isa.git] / src / openpower / oppc / pc_code.py
1 import collections
2 import contextlib
3
4 import openpower.oppc.pc_ast as pc_ast
5 import openpower.oppc.pc_util as pc_util
6 import openpower.oppc.pc_pseudocode as pc_pseudocode
7
8
class Transient(pc_ast.Node):
    """AST node which renders as an ``oppc_transient(...)`` C expression.

    The node only stores the two strings it is constructed with (``value``
    and ``bits``) and substitutes them into the C text produced by
    ``__str__``.
    """

    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)ctx->XLEN"):
        """Remember the C expressions used for the value and the bit width."""
        self.__value = value
        self.__bits = bits

        super().__init__()

    def __str__(self):
        """Return the C call which initializes an anonymous ``oppc_int``."""
        arguments = (
            "&(struct oppc_int){}",
            f"{self.__value}",
            f"{self.__bits}",
        )
        return "oppc_transient(" + ", ".join(arguments) + ")"
18
19
class CodeVisitor(pc_util.Visitor):
    """Visitor which turns a pseudocode AST into lines of C source.

    Every node owns a ``pc_util.Code`` buffer, reachable as ``self[node]``;
    the hooks below fill the buffer of a node, mostly from the buffers of its
    children.  Iterating the visitor yields the function header, the code
    collected for the root node and the closing brace, in that order.

    NOTE(review): every hook yields the node before emitting anything; this
    presumably lets ``pc_util.Visitor`` traverse the children first, since
    the hooks read the children's buffers afterwards -- confirm in pc_util.
    """

    def __init__(self, name, root):
        """Generate the C function ``oppc_<name>`` for the tree at ``root``.

        name: suffix of the generated C function name.
        root: root node of the AST to be translated.
        """
        self.__root = root
        # Unique keys for the buffers which hold the function prologue and
        # epilogue; plain objects cannot collide with AST nodes.
        self.__header = object()
        self.__footer = object()
        self.__code = collections.defaultdict(pc_util.Code)
        self.__decls = collections.defaultdict(list)
        self.__regfetch = collections.defaultdict(list)
        self.__regstore = collections.defaultdict(list)
        self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)

        # NOTE(review): presumably this call performs the traversal, as the
        # declarations collected by the Symbol hook are consumed right below.
        super().__init__(root=root)

        self.__code[self.__header].emit(stmt="void")
        self.__code[self.__header].emit(stmt=f"oppc_{name}(void) {{")
        with self.__code[self.__header]:
            for decl in self.__decls:
                self.__code[self.__header].emit(stmt=f"struct oppc_int {decl};")
        self.__code[self.__footer].emit(stmt="}")

    def __iter__(self):
        """Yield the header, the root code and the footer entries in order."""
        yield from self.__code[self.__header]
        yield from self.__code[self.__root]
        yield from self.__code[self.__footer]

    def __getitem__(self, node):
        """Return the code buffer of ``node``, creating it on first access."""
        return self.__code[node]

    def transient(self, node, value="UINT64_C(0)", bits="(uint8_t)ctx->XLEN"):
        """Create a Transient, traverse it and return it.

        The pseudocode comments of ``node`` are emitted into ``self[node]``
        before the transient is traversed.
        """
        transient = Transient(value=value, bits=bits)
        with self.pseudocode(node=node):
            self.traverse(root=transient)
        return transient

    @contextlib.contextmanager
    def pseudocode(self, node):
        """Emit the pseudocode of ``node`` as C comments, then run the body.

        All the work happens on entry; nothing is done when the ``with``
        block is left.
        """
        for (level, stmt) in self.__pseudocode[node]:
            self[node].emit(stmt=f"/* {stmt} */", level=level)
        yield

    def call(self, node, code, prefix="", suffix=""):
        """Emit ``prefix(arg, ..., arg)suffix`` into the buffer of ``node``.

        node: node whose buffer receives the call.
        code: sequence of chunks, one per argument; each chunk is an
            iterable of ``(level, stmt)`` pairs.
        prefix: text placed before the opening parenthesis.
        suffix: text placed after the closing parenthesis.

        Every argument but the last one is terminated with a comma.  The
        comma is appended only to the line which ends the argument (the last
        line which is neither empty nor a comment): appending it to every
        line of the chunk would also put a comma after the last argument of
        a nested call, i.e. produce ``f(a, b,)``, which C does not accept.
        """
        terminated = (",", "(", "{", "*/")
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"{prefix}(")
            with self[node]:
                for chunk in code[:-1]:
                    lines = list(chunk)
                    last = max((
                        index for (index, (_, stmt)) in enumerate(lines)
                        if stmt and not stmt.startswith("/*")
                    ), default=None)
                    for (index, (level, stmt)) in enumerate(lines):
                        if index == last and not stmt.endswith(terminated):
                            stmt = (stmt + ",")
                        self[node].emit(stmt=stmt, level=level)
                if len(code) > 0:
                    # The last argument gets no comma; empty lines are
                    # dropped here.
                    for (level, stmt) in code[-1]:
                        if stmt:
                            self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=f"){suffix}")

    @pc_util.Hook(pc_ast.Scope)
    def Scope(self, node):
        """Concatenate the code of all subnodes, one level deeper."""
        yield node
        with self[node]:
            for subnode in node:
                for (level, stmt) in self[subnode]:
                    self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
    def AssignExpr(self, node):
        """Emit the assignment as one of the ``oppc_*assign`` calls."""
        yield node
        # Remember the registers which are written and read directly.
        if isinstance(node.lvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regstore[str(node.lvalue)].append(node.lvalue)
        if isinstance(node.rvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.rvalue)].append(node.rvalue)

        rvalue = self[node.rvalue]
        if isinstance(node.rvalue, pc_ast.IfExpr):
            # An if-expression on the right side becomes a C ternary built
            # from the test and the first statements of both branches.
            rvalue = [(0, " ".join([
                str(self[node.rvalue.test]),
                "?",
                str(self[node.rvalue.body[0]]),
                ":",
                str(self[node.rvalue.orelse[0]]),
            ]))]

        if isinstance(node.lvalue, pc_ast.SubscriptExpr):
            self.call(prefix="oppc_subscript_assign", suffix=";", node=node, code=[
                self[node.lvalue.subject],
                self[node.lvalue.index],
                rvalue,
            ])
        elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
            self.call(prefix="oppc_range_subscript_assign", suffix=";", node=node, code=[
                self[node.lvalue.subject],
                self[node.lvalue.start],
                self[node.lvalue.end],
                rvalue,
            ])
        else:
            self.call(prefix="oppc_assign", suffix=";", node=node, code=[
                self[node.lvalue],
                rvalue,
            ])

    @pc_util.Hook(pc_ast.BinaryExpr)
    def BinaryExpr(self, node):
        """Emit the call for a binary operation.

        Comparisons take only both operands; any other operation receives a
        fresh transient as its first argument.
        """
        yield node
        if isinstance(node.left, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.left)].append(node.left)
        if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
            # The right operand is recorded under its own name; the left
            # operand must not be stored here.
            self.__regfetch[str(node.right)].append(node.right)

        comparison = (
            pc_ast.Lt, pc_ast.Le,
            pc_ast.Eq, pc_ast.NotEq,
            pc_ast.Ge, pc_ast.Gt,
        )
        if isinstance(node.op, comparison):
            self.call(prefix=str(self[node.op]), node=node, code=[
                self[node.left],
                self[node.right],
            ])
        else:
            transient = self.transient(node=node)
            self.call(prefix=str(self[node.op]), node=node, code=[
                self[transient],
                self[node.left],
                self[node.right],
            ])

    @pc_util.Hook(pc_ast.UnaryExpr)
    def UnaryExpr(self, node):
        """Emit the call for a unary operation on ``node.value``."""
        yield node
        self.call(prefix=str(self[node.op]), node=node, code=[
            self[node.value],
        ])

    @pc_util.Hook(
        pc_ast.Not, pc_ast.Add, pc_ast.Sub,
        pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
        pc_ast.Lt, pc_ast.Le,
        pc_ast.Eq, pc_ast.NotEq,
        pc_ast.Ge, pc_ast.Gt,
        pc_ast.LShift, pc_ast.RShift,
        pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
    )
    def Op(self, node):
        """Emit the name of the C function which implements the operator."""
        yield node
        op = {
            pc_ast.Not: "oppc_not",
            pc_ast.Add: "oppc_add",
            pc_ast.Sub: "oppc_sub",
            pc_ast.Mul: "oppc_mul",
            pc_ast.Div: "oppc_div",
            pc_ast.Mod: "oppc_mod",
            pc_ast.Lt: "oppc_lt",
            pc_ast.Le: "oppc_le",
            pc_ast.Eq: "oppc_eq",
            pc_ast.NotEq: "oppc_noteq",
            pc_ast.Ge: "oppc_ge",
            pc_ast.Gt: "oppc_gt",
            pc_ast.LShift: "oppc_lshift",
            pc_ast.RShift: "oppc_rshift",
            pc_ast.BitAnd: "oppc_and",
            pc_ast.BitOr: "oppc_or",
            pc_ast.BitXor: "oppc_xor",
        }[node.__class__]
        self[node].emit(stmt=op)

    @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
    def Integer(self, node):
        """Emit an integer literal as a transient.

        Binary literals are one bit wide per digit and hexadecimal literals
        four bits per digit (the first two characters are skipped as the
        prefix); decimal literals use ``ctx->XLEN`` as their width.

        Raises NotImplementedError for values which exceed 64 bits.
        """
        yield node
        fmt = hex
        value = str(node)
        if isinstance(node, pc_ast.BinLiteral):
            bits = f"UINT8_C({str(len(value[2:]))})"
            value = int(value, 2)
        elif isinstance(node, pc_ast.HexLiteral):
            bits = f"UINT8_C({str(len(value[2:]) * 4)})"
            value = int(value, 16)
        else:
            bits = "ctx->XLEN"
            value = int(value)
            fmt = str
        if (value > ((2**64) - 1)):
            raise NotImplementedError()
        value = f"UINT64_C({fmt(value)})"
        transient = self.transient(node=node, value=value, bits=bits)
        for (level, stmt) in self[transient]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(Transient)
    def Transient(self, node):
        """Emit the C text of the transient itself."""
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.GPR)
    def GPR(self, node):
        """Emit a pointer to the general-purpose register in the context."""
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&ctx->gpr[OPPC_GPR_{str(node)}]")

    @pc_util.Hook(pc_ast.FPR)
    def FPR(self, node):
        """Emit a pointer to the floating-point register in the context."""
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&ctx->fpr[OPPC_FPR_{str(node)}]")

    @pc_util.Hook(pc_ast.RepeatExpr)
    def RepeatExpr(self, node):
        """Emit ``oppc_repeat`` with a transient, the subject and the count."""
        yield node
        transient = self.transient(node=node)
        self.call(prefix="oppc_repeat", node=node, code=[
            self[transient],
            self[node.subject],
            self[node.times],
        ])

    @pc_util.Hook(pc_ast.XLEN)
    def XLEN(self, node):
        """Emit XLEN as a transient whose value and width are ctx->XLEN."""
        yield node
        (value, bits) = ("ctx->XLEN", "(uint8_t)ctx->XLEN")
        transient = self.transient(node=node, value=value, bits=bits)
        for (level, stmt) in self[transient]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.SubscriptExpr)
    def SubscriptExpr(self, node):
        """Emit ``oppc_subscript`` with the subject and the index."""
        yield node
        self.call(prefix="oppc_subscript", node=node, code=[
            self[node.subject],
            self[node.index],
        ])

    @pc_util.Hook(pc_ast.RangeSubscriptExpr)
    def RangeSubscriptExpr(self, node):
        """Emit the call for a range subscript (subject, start, end)."""
        yield node
        # NOTE(review): this reuses the ``oppc_subscript`` name with three
        # arguments although SubscriptExpr emits it with two, and the
        # assignment counterpart is ``oppc_range_subscript_assign``; a
        # dedicated ``oppc_range_subscript`` may be intended -- confirm
        # against the C runtime.
        self.call(prefix="oppc_subscript", node=node, code=[
            self[node.subject],
            self[node.start],
            self[node.end],
        ])

    @pc_util.Hook(pc_ast.ForExpr)
    def ForExpr(self, node):
        """Emit a C ``for`` loop built from three synthesized expressions.

        The initialization, the condition and the step are created as new
        AST nodes from clones of the loop parts, traversed on the spot and
        copied into the loop header.
        """
        yield node

        enter = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=node.start.clone(),
        )
        match = pc_ast.BinaryExpr(
            left=node.subject.clone(),
            op=pc_ast.Le("<="),
            right=node.end.clone(),
        )
        # NOTE(review): the step adds ``node.end`` to the subject; an
        # increment by one looks like the intent -- confirm before changing.
        leave = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=pc_ast.BinaryExpr(
                left=node.subject.clone(),
                op=pc_ast.Add("+"),
                right=node.end.clone(),
            ),
        )
        with self.pseudocode(node=node):
            # Keep only the first pseudocode line of the loop.
            (level, stmt) = self[node][0]
            self[node].clear()
            self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt="for (")
            with self[node]:
                with self[node]:
                    for subnode in (enter, match, leave):
                        self.__pseudocode.traverse(root=subnode)
                        self.traverse(root=subnode)
                        for (level, stmt) in self[subnode]:
                            self[node].emit(stmt=stmt, level=level)
                        # Fix up the separators of the loop header: the
                        # condition lacks a semicolon, the step must lose
                        # its last character (the ";" suffix which the
                        # assignment emits).
                        (level, stmt) = self[node][-1]
                        if subnode is match:
                            stmt = f"{stmt};"
                        elif subnode is leave:
                            stmt = stmt[:-1]
                        self[node][-1] = (level, stmt)
            (level, stmt) = self[node][0]
            self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=") {")
            for (level, stmt) in self[node.body]:
                self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.WhileExpr)
    def WhileExpr(self, node):
        """Emit a C ``while`` loop from the test, the body and the orelse."""
        yield node
        self[node].emit(stmt="while (")
        with self[node]:
            with self[node]:
                for (level, stmt) in self[node.test]:
                    self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            # NOTE(review): C has no ``else`` clause for ``while``; the
            # output of this branch will not compile as is.
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.IfExpr)
    def IfExpr(self, node):
        """Emit a C ``if`` statement with an optional ``else`` branch."""
        yield node
        self[node].emit(stmt="if (")
        with self[node]:
            for (level, stmt) in self[node.test]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.Call.Name)
    def CallName(self, node):
        """Emit the name of the called function verbatim."""
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.Call.Arguments)
    def CallArguments(self, node):
        """Record the registers passed directly as call arguments."""
        yield node
        for subnode in node:
            if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
                self.__regfetch[str(subnode)].append(subnode)

    @pc_util.Hook(pc_ast.Call)
    def Call(self, node):
        """Emit a call of ``node.name`` with the code of each argument."""
        yield node
        code = tuple(map(lambda arg: self[arg], node.args))
        self.call(prefix=str(node.name), node=node, code=code)

    @pc_util.Hook(pc_ast.Symbol)
    def Symbol(self, node):
        """Register the symbol for declaration and emit its address."""
        yield node
        self.__decls[str(node)].append(node)
        self[node].emit(stmt=f"&{str(node)}")

    @pc_util.Hook(pc_ast.Node)
    def Node(self, node):
        """Reject any node type which has no dedicated hook."""
        raise NotImplementedError(type(node))
367
368
def code(name, root):
    """Yield the entries of the C function generated for ``root``.

    name: suffix of the generated ``oppc_<name>`` function.
    root: root node of the AST to be translated.

    The entries are whatever iterating a ``CodeVisitor`` produces: header,
    root code and footer, in that order.
    """
    visitor = CodeVisitor(name=name, root=root)
    yield from visitor