oppc/code: fix if exprs check
[openpower-isa.git] / src / openpower / oppc / pc_code.py
1 import collections
2 import contextlib
3
4 import openpower.oppc.pc_ast as pc_ast
5 import openpower.oppc.pc_util as pc_util
6 import openpower.oppc.pc_pseudocode as pc_pseudocode
7
8
class Transient(pc_ast.Node):
    """Pseudo-node standing for a temporary C ``struct oppc_value``.

    ``str()`` renders it as an ``oppc_transient(...)`` call whose first
    argument is an anonymous compound literal, followed by the C expressions
    for the initial value and the bit width supplied to the constructor.
    """

    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
        # Both arguments are C source text, inserted verbatim by __str__.
        self.__value = value
        self.__bits = bits

        super().__init__()

    def __str__(self):
        template = "oppc_transient(&(struct oppc_value){{}}, {}, {})"
        return template.format(self.__value, self.__bits)
18
19
class CCall(pc_ast.Dataclass):
    """Pseudo-node describing a C function call to be rendered.

    Instances are created by ``CodeVisitor.ccall`` and turned into code
    lines by the visitor's ``CCall`` hook.
    """
    name: str    # C function name emitted before the argument list
    code: tuple  # one tuple of (level, stmt) lines per argument
    stmt: bool   # True: the call is terminated with ';' (a statement)
24
25
class CodeVisitor(pc_util.Visitor):
    """Lower an oppc pseudocode AST into lines of C code.

    Generated code is accumulated per AST node in ``pc_util.Code`` buffers,
    each iterable as ``(level, stmt)`` pairs.  Iterating the visitor yields
    the complete C function ``oppc_<name>``: a header declaring one
    ``struct oppc_value`` per referenced symbol, the code generated for
    *root*, and the closing brace.
    """

    def __init__(self, name, root):
        self.__root = root
        # Private keys for the header/footer buffers; they cannot collide
        # with any AST node used as a key in self.__code.
        self.__header = object()
        self.__footer = object()
        self.__code = collections.defaultdict(lambda: pc_util.Code())
        self.__decls = collections.defaultdict(list)
        self.__regfetch = collections.defaultdict(list)
        self.__regstore = collections.defaultdict(list)
        self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)

        # NOTE(review): assumes the base constructor traverses *root*; the
        # declarations emitted below rely on __decls being filled by then.
        super().__init__(root=root)

        header = self.__code[self.__header]
        header.emit(stmt="void")
        header.emit(stmt=f"oppc_{name}(void) {{")
        with header:
            for decl in self.__decls:
                header.emit(stmt=f"struct oppc_value {decl};")
        self.__code[self.__footer].emit(stmt="}")

    def __iter__(self):
        """Yield ``(level, stmt)`` pairs: header, body, then footer."""
        yield from self.__code[self.__header]
        yield from self.__code[self.__root]
        yield from self.__code[self.__footer]

    def __getitem__(self, node):
        """Return the code buffer for *node*, creating an empty one if needed."""
        return self.__code[node]

    def transient(self, node,
            value="UINT64_C(0)",
            bits="(uint8_t)OPPC_XLEN"):
        """Create and visit a ``Transient`` temporary; return the new node.

        *node* is accepted for call-site symmetry with ``ccall`` but unused.
        """
        transient = Transient(value=value, bits=bits)
        self.traverse(root=transient)
        return transient

    def ccall(self, node, name, code, stmt=False):
        """Create and visit a ``CCall`` to *name*; return the new node.

        *code* is a sequence of argument code buffers, each an iterable of
        ``(level, stmt)`` pairs.  Raises ValueError if a level is not an int
        or a statement is not a str.  *stmt* selects statement form
        (terminated with ';').  *node* is accepted but unused.
        """
        def validate_line(line):
            (level, text) = line
            if not isinstance(level, int):
                raise ValueError(level)
            if not isinstance(text, str):
                raise ValueError(text)
            return (level, text)

        def validate_argument(argument):
            return tuple(map(validate_line, argument))

        code = tuple(map(validate_argument, code))
        ccall = CCall(name=name, code=code, stmt=stmt)
        self.traverse(root=ccall)
        return ccall

    def ternary(self, node):
        """Re-render IfExpr *node* as ``(oppc_bool(test) ? body : orelse)``.

        The statement form produced by the IfExpr hook is discarded.
        """
        self[node].clear()
        test = self.ccall(name="oppc_bool", node=node, code=[
            self[node.test],
        ])
        self[node].emit(stmt="(")
        with self[node]:
            for (level, stmt) in self[test]:
                self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt="?")
            for (level, stmt) in self[node.body]:
                self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=":")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=")")

    @contextlib.contextmanager
    def pseudocode(self, node):
        """Emit the pseudocode of *node* as C comments, then run the body."""
        for (level, stmt) in self.__pseudocode[node]:
            self[node].emit(stmt=f"/* {stmt} */", level=level)
        yield

    @pc_util.Hook(pc_ast.Scope)
    def Scope(self, node):
        yield node
        # Concatenate the children's code one level deeper than the scope.
        with self[node]:
            for subnode in node:
                for (level, stmt) in self[subnode]:
                    self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
    def AssignExpr(self, node):
        yield node
        if isinstance(node.lvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regstore[str(node.lvalue)].append(node.lvalue)
        if isinstance(node.rvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.rvalue)].append(node.rvalue)

        # An IfExpr used as a value must be an expression, not a statement.
        if isinstance(node.rvalue, pc_ast.IfExpr):
            self.ternary(node=node.rvalue)

        if isinstance(node.lvalue, pc_ast.SubscriptExpr):
            ccall = self.ccall(name="oppc_subscript_assign", node=node, stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.index],
                self[node.rvalue],
            ])
        elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
            ccall = self.ccall(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.start],
                self[node.lvalue.end],
                self[node.rvalue],
            ])
        else:
            ccall = self.ccall(name="oppc_assign", stmt=True, node=node, code=[
                self[node.lvalue],
                self[node.rvalue],
            ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[ccall]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.BinaryExpr)
    def BinaryExpr(self, node):
        yield node
        if isinstance(node.left, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.left)].append(node.left)
        if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
            # Bug fix: this used to record node.left under the right
            # operand's register name.
            self.__regfetch[str(node.right)].append(node.right)

        comparison = (
            pc_ast.Lt, pc_ast.Le,
            pc_ast.Eq, pc_ast.NotEq,
            pc_ast.Ge, pc_ast.Gt,
            pc_ast.LtU, pc_ast.GtU,
        )
        if isinstance(node.left, pc_ast.IfExpr):
            self.ternary(node=node.left)
        if isinstance(node.right, pc_ast.IfExpr):
            self.ternary(node=node.right)

        if isinstance(node.op, comparison):
            # Comparisons take no result temporary.
            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                self[node.left],
                self[node.right],
            ])
        else:
            # Other operators write into a fresh temporary passed first.
            transient = self.transient(node=node)
            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                self[transient],
                self[node.left],
                self[node.right],
            ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[ccall]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.UnaryExpr)
    def UnaryExpr(self, node):
        yield node
        # Record register reads, consistently with BinaryExpr.
        if isinstance(node.value, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.value)].append(node.value)
        if isinstance(node.value, pc_ast.IfExpr):
            self.ternary(node=node.value)
        ccall = self.ccall(name=str(self[node.op]), node=node, code=[
            self[node.value],
        ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[ccall]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(
        pc_ast.Not, pc_ast.Add, pc_ast.Sub,
        pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
        pc_ast.Lt, pc_ast.Le,
        pc_ast.Eq, pc_ast.NotEq,
        pc_ast.Ge, pc_ast.Gt,
        pc_ast.LtU, pc_ast.GtU,
        pc_ast.LShift, pc_ast.RShift,
        pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
        pc_ast.BitConcat,
    )
    def Op(self, node):
        yield node
        # Operator node -> name of the C runtime helper implementing it.
        op = {
            pc_ast.Not: "oppc_not",
            pc_ast.Add: "oppc_add",
            pc_ast.Sub: "oppc_sub",
            pc_ast.Mul: "oppc_mul",
            pc_ast.Div: "oppc_div",
            pc_ast.Mod: "oppc_mod",
            pc_ast.Lt: "oppc_lt",
            pc_ast.Le: "oppc_le",
            pc_ast.Eq: "oppc_eq",
            pc_ast.LtU: "oppc_ltu",
            pc_ast.GtU: "oppc_gtu",
            pc_ast.NotEq: "oppc_noteq",
            pc_ast.Ge: "oppc_ge",
            pc_ast.Gt: "oppc_gt",
            pc_ast.LShift: "oppc_lshift",
            pc_ast.RShift: "oppc_rshift",
            pc_ast.BitAnd: "oppc_and",
            pc_ast.BitOr: "oppc_or",
            pc_ast.BitXor: "oppc_xor",
            pc_ast.BitConcat: "oppc_concat",
        }[node.__class__]
        self[node].emit(stmt=op)

    @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
    def Integer(self, node):
        yield node
        fmt = hex
        value = str(node)
        # Width counts digits only: '_' separators (if a literal has any)
        # must not contribute to the bit width.
        digits = value[2:].replace("_", "")
        if isinstance(node, pc_ast.BinLiteral):
            bits = f"UINT8_C({len(digits)})"
            value = int(value, 2)
        elif isinstance(node, pc_ast.HexLiteral):
            bits = f"UINT8_C({len(digits) * 4})"
            value = int(value, 16)
        else:
            # Decimal literals have no inherent width; print them in decimal.
            bits = "(uint8_t)OPPC_XLEN"
            value = int(value)
            fmt = str
        if (value > ((2**64) - 1)):
            raise NotImplementedError()
        value = f"UINT64_C({fmt(value)})"
        transient = self.transient(node=node, value=value, bits=bits)
        with self.pseudocode(node=node):
            for (level, stmt) in self[transient]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(Transient)
    def Transient(self, node):
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(CCall)
    def CCall(self, node):
        yield node
        end = (";" if node.stmt else "")
        if len(node.code) == 0:
            self[node].emit(stmt=f"{str(node.name)}(){end}")
        else:
            self[node].emit(stmt=f"{str(node.name)}(")
            with self[node]:
                (*head, tail) = node.code
                for code in head:
                    for (level, stmt) in code:
                        self[node].emit(stmt=stmt, level=level)
                    # Append the argument separator to the last line of this
                    # argument unless it is empty, a comment, or already ends
                    # in a separator/opening token.
                    (level, stmt) = self[node][-1]
                    if not (not stmt or
                            stmt.startswith("/*") or
                            stmt.endswith((",", "(", "{", "*/"))):
                        stmt = (stmt + ",")
                    self[node][-1] = (level, stmt)
                for (level, stmt) in tail:
                    self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=f"){end}")

    @pc_util.Hook(pc_ast.GPR)
    def GPR(self, node):
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")

    @pc_util.Hook(pc_ast.FPR)
    def FPR(self, node):
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")

    @pc_util.Hook(pc_ast.RepeatExpr)
    def RepeatExpr(self, node):
        yield node
        transient = self.transient(node=node)
        ccall = self.ccall(name="oppc_repeat", node=node, code=[
            self[transient],
            self[node.subject],
            self[node.times],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.XLEN)
    def XLEN(self, node):
        yield node
        (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
        transient = self.transient(node=node, value=value, bits=bits)
        with self.pseudocode(node=node):
            for (level, stmt) in self[transient]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Overflow)
    def Overflow(self, node):
        yield node
        (value, bits) = ("OPPC_OVERFLOW", "(uint8_t)OPPC_OVERFLOW")
        transient = self.transient(node=node, value=value, bits=bits)
        with self.pseudocode(node=node):
            for (level, stmt) in self[transient]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.SubscriptExpr)
    def SubscriptExpr(self, node):
        yield node
        ccall = self.ccall(name="oppc_subscript", node=node, code=[
            self[node.subject],
            self[node.index],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.RangeSubscriptExpr)
    def RangeSubscriptExpr(self, node):
        yield node
        # Bug fix: this used the two-argument "oppc_subscript" name with three
        # arguments; the range form mirrors "oppc_range_subscript_assign".
        ccall = self.ccall(name="oppc_range_subscript", node=node, code=[
            self[node.subject],
            self[node.start],
            self[node.end],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.ForExpr)
    def ForExpr(self, node):
        yield node

        # Lower the loop into the three clauses of a C ``for``: initialise
        # the subject, test ``subject <= end``, and step the subject by one.
        enter = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=node.start.clone(),
        )
        match = pc_ast.BinaryExpr(
            left=node.subject.clone(),
            op=pc_ast.Le("<="),
            right=node.end.clone(),
        )
        # Bug fix: the step used to be ``subject + end``; the loop variable
        # must advance by one per iteration.
        leave = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=pc_ast.BinaryExpr(
                left=node.subject.clone(),
                op=pc_ast.Add("+"),
                right=pc_ast.DecLiteral("1"),
            ),
        )
        with self.pseudocode(node=node):
            # Keep only the first pseudocode comment line.
            (level, stmt) = self[node][0]
        self[node].clear()
        self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="for (")
        with self[node]:
            with self[node]:
                for subnode in (enter, match, leave):
                    # The synthesized nodes were never visited: generate
                    # their pseudocode and C code now.
                    self.__pseudocode.traverse(root=subnode)
                    self.traverse(root=subnode)
                    for (level, stmt) in self[subnode]:
                        self[node].emit(stmt=stmt, level=level)
                    (level, stmt) = self[node][-1]
                    if subnode is match:
                        # The test is an expression; terminate its clause.
                        stmt = f"{stmt};"
                    elif subnode is leave:
                        # Drop the statement ';' before the closing ')'.
                        stmt = stmt[:-1]
                    self[node][-1] = (level, stmt)
        # NOTE(review): this repeats the first line (the pseudocode header
        # comment) before ") {" — looks redundant; confirm it is intended.
        (level, stmt) = self[node][0]
        self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.WhileExpr)
    def WhileExpr(self, node):
        yield node
        # NOTE(review): unlike IfExpr, the test is not wrapped in oppc_bool;
        # confirm the test code always yields a C truth value here.
        self[node].emit(stmt="while (")
        with self[node]:
            with self[node]:
                for (level, stmt) in self[node.test]:
                    self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            # NOTE(review): C has no while/else; this emits invalid C if taken.
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.IfExpr)
    def IfExpr(self, node):
        yield node
        # Statement form; ternary() rewrites this when used as a value.
        test = self.ccall(name="oppc_bool", node=node, code=[
            self[node.test],
        ])
        self[node].emit(stmt="if (")
        with self[node]:
            for (level, stmt) in self[test]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.SwitchExpr)
    def SwitchExpr(self, node):
        yield node
        subject = self.ccall(name="oppc_int64", node=node, code=[
            self[node.subject],
        ])
        self[node].emit(stmt="switch (")
        with self[node]:
            for (level, stmt) in self[subject]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        with self[node]:
            for (level, stmt) in self[node.cases]:
                self[node].emit(stmt=stmt, level=level)
        # Bug fix: the switch body was never closed.
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.Cases)
    def Cases(self, node):
        yield node
        for subnode in node:
            for (level, stmt) in self[subnode]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Case)
    def Case(self, node):
        yield node
        for (level, stmt) in self[node.labels]:
            self[node].emit(stmt=stmt, level=level)
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Labels)
    def Labels(self, node):
        yield node

        def render(label):
            if isinstance(label, pc_ast.DefaultLabel):
                return "default:"
            return f"case ({str(self[label])}):"

        # Bug fix: several labels used to be joined into ``case (a, b):``,
        # a C comma expression rather than a list of case constants; emit
        # one C label per pseudocode label instead.
        self[node].emit(stmt=" ".join(map(render, node)))

    @pc_util.Hook(pc_ast.Label)
    def Label(self, node):
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.Call.Name)
    def CallName(self, node):
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.Call.Arguments)
    def CallArguments(self, node):
        yield node
        # Only records register reads; argument code is assembled by Call.
        for subnode in node:
            if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
                self.__regfetch[str(subnode)].append(subnode)

    @pc_util.Hook(pc_ast.Call)
    def Call(self, node):
        yield node
        code = tuple(map(lambda arg: self[arg], node.args))
        ccall = self.ccall(name=str(node.name), node=node, code=code)
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Symbol)
    def Symbol(self, node):
        yield node
        with self.pseudocode(node=node):
            # "fallthrough" is a control keyword, not a variable.
            if str(node) not in ("fallthrough",):
                self.__decls[str(node)].append(node)
                self[node].emit(stmt=f"&{str(node)}")

    @pc_util.Hook(pc_ast.Node)
    def Node(self, node):
        # Fallback: any node type without a dedicated hook is unsupported.
        raise NotImplementedError(type(node))
498
499
def code(name, root):
    """Generate the C code for *root* as ``(level, stmt)`` pairs.

    The lines come from iterating a ``CodeVisitor`` built for *name* and
    *root*; the visitor is only constructed once iteration starts.
    """
    visitor = CodeVisitor(name=name, root=root)
    yield from visitor