oppc: support attributes
[openpower-isa.git] / src / openpower / oppc / pc_code.py
1 import collections
2 import contextlib
3
4 import openpower.oppc.pc_ast as pc_ast
5 import openpower.oppc.pc_util as pc_util
6 import openpower.oppc.pc_pseudocode as pc_pseudocode
7
8
class Transient(pc_ast.Node):
    """AST node that renders as an ``oppc_transient(...)`` C expression.

    ``value`` and ``bits`` are interpolated verbatim into the rendered
    text, so callers pass ready-made C snippets.
    """

    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
        # Both pieces are kept private; they are only read by __str__.
        self.__value = value
        self.__bits = bits

        return super().__init__()

    def __str__(self):
        rendered = f"oppc_transient(&(struct oppc_value){{}}, {self.__value}, {self.__bits})"
        return rendered
18
19
class Call(pc_ast.Dataclass):
    # Synthetic node describing one C function call to be emitted.
    # name: identifier of the C function being called.
    name: str
    # code: one entry per argument; CodeVisitor.call() stores each entry
    # as a tuple of (level, stmt) pairs.
    code: tuple
    # stmt: when true, the rendered call is terminated with ";".
    stmt: bool
24
25
26 class CodeVisitor(pc_util.Visitor):
def __init__(self, name, root):
    """Set up per-node storage, run the base visitor, emit header/footer.

    ``name`` is used to form the C function name ``oppc_<name>``;
    ``root`` is handed to the base class constructor and to the
    pseudocode visitor.
    """
    self.__root = root
    self.__attrs = {}
    # Unique sentinels used as keys for the function prologue/epilogue.
    self.__header = object()
    self.__footer = object()
    self.__code = collections.defaultdict(lambda: pc_util.Code())
    self.__decls = collections.defaultdict(list)
    self.__regfetch = collections.defaultdict(list)
    self.__regstore = collections.defaultdict(list)
    self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)

    # NOTE(review): presumably this is where the tree gets traversed and
    # the hooks below are run -- confirm against pc_util.Visitor.
    super().__init__(root=root)

    prologue = self.__code[self.__header]
    prologue.emit(stmt="void")
    prologue.emit(stmt=f"oppc_{name}(void) {{")
    with prologue:
        # One local declaration per symbol name collected in __decls.
        for decl in self.__decls:
            prologue.emit(stmt=f"struct oppc_value {decl};")
    self.__code[self.__footer].emit(stmt=f"}}")
46
def __iter__(self):
    """Yield the header code, then the root node's code, then the footer."""
    for key in (self.__header, self.__root, self.__footer):
        yield from self.__code[key]
51
def __getitem__(self, node):
    """Return the code container stored for ``node``."""
    container = self.__code[node]
    return container
54
def __setitem__(self, node, code):
    """Replace the code container stored for ``node`` with ``code``."""
    self.__code.__setitem__(node, code)
57
def transient(self, node,
        value="UINT64_C(0)",
        bits="(uint8_t)OPPC_XLEN"):
    """Create a Transient node, traverse it, and return it.

    ``node`` is accepted for signature symmetry with ``call`` but is not
    used here.
    """
    result = Transient(value=value, bits=bits)
    self.traverse(root=result)
    return result
64
def call(self, node, name, code, stmt=False):
    """Build a Call node from argument code, traverse it, and return it.

    ``code`` is an iterable of argument chunks; each chunk is an iterable
    of ``(level, stmt)`` pairs. Raises ValueError if a level is not an
    int or a statement is not a str. ``node`` is not used here.
    """
    def check_line(pair):
        (level, text) = pair
        if not isinstance(level, int):
            raise ValueError(level)
        if not isinstance(text, str):
            raise ValueError(text)
        return (level, text)

    def check_chunk(chunk):
        return tuple(map(check_line, chunk))

    # Freeze every argument chunk into a tuple of validated pairs.
    code = tuple(map(check_chunk, code))
    call = Call(name=name, code=code, stmt=stmt)
    self.traverse(root=call)
    return call
81
def fixup_ternary(self, node):
    """Re-emit ``node`` as a parenthesised C ``test ? body : orelse``.

    Whatever was previously stored for ``node`` is discarded first.
    """
    self[node].clear()
    test = self.call(name="oppc_bool", node=node, code=[
        self[node.test],
    ])
    self[node].emit(stmt="(")
    with self[node]:
        # Each part is preceded by its separator; the test has none.
        parts = (
            (None, test),
            ("?", node.body),
            (":", node.orelse),
        )
        for (separator, part) in parts:
            if separator is not None:
                self[node].emit(stmt=separator)
            for (depth, line) in self[part]:
                self[node].emit(stmt=line, level=depth)
    self[node].emit(stmt=")")
98
def fixup_attr(self, node, assign=False):
    """Rebuild the code of attribute ``node`` as ``subject, attribute``.

    The innermost subject (found by following ``.subject`` through
    attributes and subscripts) is emitted first, then the code already
    stored for ``node``. Unless ``assign`` is true, the result is wrapped
    in an ``oppc_attr(...)`` call.
    """
    root = node
    # The original takes this snapshot and never reads it; it is kept so
    # the lookup on ``self`` still happens.
    code = tuple(self[root])
    chainable = (
        pc_ast.Attribute,
        pc_ast.SubscriptExpr,
        pc_ast.RangeSubscriptExpr,
    )
    while isinstance(node.subject, chainable):
        node = node.subject

    def with_commas(lines):
        def add_comma(pair):
            (level, stmt) = pair
            # Skip blanks, comments and lines that already end an
            # opening/continuation construct.
            if (stmt and
                    not stmt.startswith("/*") and
                    not stmt.endswith((",", "(", "{", "*/"))):
                stmt = (stmt + ",")
            return (level, stmt)

        return tuple(map(add_comma, lines))

    code = pc_util.Code()
    for source in (node.subject, root):
        for (level, stmt) in with_commas(self[source]):
            code.emit(stmt=stmt, level=level)

    # discard the last comma
    (level, stmt) = code[-1]
    code[-1] = (level, stmt[:-1])

    if not assign:
        call = self.call(name="oppc_attr", node=root, code=[
            code,
        ])
        code = self[call]
    self[root] = code
137
@contextlib.contextmanager
def pseudocode(self, node):
    """Context manager that first emits ``node``'s pseudocode as comments.

    Every line obtained from the pseudocode visitor for ``node`` is
    written as a ``/* ... */`` comment before the managed block runs.
    """
    for (depth, text) in self.__pseudocode[node]:
        self[node].emit(stmt=f"/* {text} */", level=depth)
    yield
143
@pc_util.Hook(pc_ast.Scope)
def Scope(self, node):
    """Concatenate the code of every child inside ``with self[node]``."""
    yield node
    with self[node]:
        for child in node:
            for (depth, line) in self[child]:
                self[node].emit(stmt=line, level=depth)
151
@pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
def AssignExpr(self, node):
    """Emit an assignment as a call to the matching ``oppc_*assign``.

    The helper is chosen from the kind of the left-hand side: subscript,
    range subscript, attribute, or anything else.
    """
    yield node
    registers = (pc_ast.GPR, pc_ast.FPR)
    if isinstance(node.lvalue, registers):
        self.__regstore[str(node.lvalue)].append(node.lvalue)
    if isinstance(node.rvalue, registers):
        self.__regfetch[str(node.rvalue)].append(node.rvalue)

    # Operands that need their already-generated code rewritten.
    if isinstance(node.rvalue, pc_ast.IfExpr):
        self.fixup_ternary(node=node.rvalue)
    if isinstance(node.lvalue, pc_ast.Attribute):
        self.fixup_attr(node=node.lvalue, assign=True)
    if isinstance(node.rvalue, pc_ast.Attribute):
        self.fixup_attr(node=node.rvalue)

    target = node.lvalue
    if isinstance(target, pc_ast.SubscriptExpr):
        name = "oppc_subscript_assign"
        operands = (target.subject, target.index, node.rvalue)
    elif isinstance(target, pc_ast.RangeSubscriptExpr):
        name = "oppc_range_subscript_assign"
        operands = (target.subject, target.start, target.end, node.rvalue)
    elif isinstance(target, pc_ast.Attribute):
        name = "oppc_attr_assign"
        operands = (target, node.rvalue)
    else:
        name = "oppc_assign"
        operands = (target, node.rvalue)

    call = self.call(name=name, node=node, stmt=True, code=[
        self[operand] for operand in operands
    ])
    with self.pseudocode(node=node):
        for (depth, line) in self[call]:
            self[node].emit(stmt=line, level=depth)
193
@pc_util.Hook(pc_ast.BinaryExpr)
def BinaryExpr(self, node):
    """Emit a binary expression as a call to the operator's C helper.

    Comparison operators are called with ``(left, right)``; every other
    operator additionally receives a fresh transient as first argument.
    Register operands are recorded in the fetch table, and ternary or
    attribute operands have their code fixed up before use.
    """
    yield node
    operands = (node.left, node.right)

    # Record each register operand under its own name. The original
    # appended node.left for the right operand as well.
    for operand in operands:
        if isinstance(operand, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(operand)].append(operand)

    comparison = (
        pc_ast.Lt, pc_ast.Le,
        pc_ast.Eq, pc_ast.NotEq,
        pc_ast.Ge, pc_ast.Gt,
        pc_ast.LtU, pc_ast.GtU,
    )
    # Same order as before: both ternary fixups, then both attribute ones.
    for operand in operands:
        if isinstance(operand, pc_ast.IfExpr):
            self.fixup_ternary(node=operand)
    for operand in operands:
        if isinstance(operand, pc_ast.Attribute):
            self.fixup_attr(node=operand)

    if isinstance(node.op, comparison):
        call = self.call(name=str(self[node.op]), node=node, code=[
            self[node.left],
            self[node.right],
        ])
    else:
        transient = self.transient(node=node)
        call = self.call(name=str(self[node.op]), node=node, code=[
            self[transient],
            self[node.left],
            self[node.right],
        ])
    with self.pseudocode(node=node):
        for (level, stmt) in self[call]:
            self[node].emit(stmt=stmt, level=level)
232
@pc_util.Hook(pc_ast.UnaryExpr)
def UnaryExpr(self, node):
    """Emit a unary expression as a one-argument call to its C helper."""
    yield node
    operand = node.value
    if isinstance(operand, pc_ast.IfExpr):
        self.fixup_ternary(node=operand)
    helper = str(self[node.op])
    call = self.call(name=helper, node=node, code=[
        self[node.value],
    ])
    with self.pseudocode(node=node):
        for (depth, line) in self[call]:
            self[node].emit(stmt=line, level=depth)
244
@pc_util.Hook(
    pc_ast.Not, pc_ast.Add, pc_ast.Sub,
    pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
    pc_ast.Lt, pc_ast.Le,
    pc_ast.Eq, pc_ast.NotEq,
    pc_ast.Ge, pc_ast.Gt,
    pc_ast.LtU, pc_ast.GtU,
    pc_ast.LShift, pc_ast.RShift,
    pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
    pc_ast.BitConcat,
)
def Op(self, node):
    """Emit the name of the C helper implementing operator ``node``.

    The lookup is by the node's exact class; an unlisted class raises
    KeyError.
    """
    yield node
    helpers = {
        pc_ast.Not: "oppc_not",
        pc_ast.Add: "oppc_add",
        pc_ast.Sub: "oppc_sub",
        pc_ast.Mul: "oppc_mul",
        pc_ast.Div: "oppc_div",
        pc_ast.Mod: "oppc_mod",
        pc_ast.Lt: "oppc_lt",
        pc_ast.Le: "oppc_le",
        pc_ast.Eq: "oppc_eq",
        pc_ast.LtU: "oppc_ltu",
        pc_ast.GtU: "oppc_gtu",
        pc_ast.NotEq: "oppc_noteq",
        pc_ast.Ge: "oppc_ge",
        pc_ast.Gt: "oppc_gt",
        pc_ast.LShift: "oppc_lshift",
        pc_ast.RShift: "oppc_rshift",
        pc_ast.BitAnd: "oppc_and",
        pc_ast.BitOr: "oppc_or",
        pc_ast.BitXor: "oppc_xor",
        pc_ast.BitConcat: "oppc_concat",
    }
    helper = helpers[node.__class__]
    self[node].emit(stmt=helper)
281
@pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
def Integer(self, node):
    """Emit an integer literal as a transient with a value and a width.

    Binary literals get one bit per digit and hexadecimal literals four
    bits per digit, counted after the first two characters of the text;
    decimal literals use ``OPPC_XLEN``. Raises NotImplementedError when
    the value exceeds 2**64 - 1.
    """
    yield node
    text = str(node)
    if isinstance(node, pc_ast.BinLiteral):
        bits = f"UINT8_C({str(len(text[2:]))})"
        number = int(text, 2)
        render = hex
    elif isinstance(node, pc_ast.HexLiteral):
        bits = f"UINT8_C({str(len(text[2:]) * 4)})"
        number = int(text, 16)
        render = hex
    else:
        bits = "(uint8_t)OPPC_XLEN"
        number = int(text)
        # Decimal literals are printed in decimal, the others in hex.
        render = str
    if (number > ((2**64) - 1)):
        raise NotImplementedError()
    value = f"UINT64_C({render(number)})"
    transient = self.transient(node=node, value=value, bits=bits)
    with self.pseudocode(node=node):
        for (depth, line) in self[transient]:
            self[node].emit(stmt=line, level=depth)
304
@pc_util.Hook(Transient)
def Transient(self, node):
    """Emit the textual form of a Transient node as a single line."""
    yield node
    text = str(node)
    self[node].emit(stmt=text)
309
@pc_util.Hook(Call)
def CCall(self, node):
    """Render a Call node as ``name(arg, ..., arg)`` over several lines.

    A call without arguments is emitted on one line. Otherwise every
    argument except the last gets a trailing comma on its final line,
    unless that line is empty, starts a comment, or already ends with
    ``,``, ``(``, ``{`` or ``*/``. ``;`` is appended when ``node.stmt``.
    """
    yield node
    end = (";" if node.stmt else "")
    if len(node.code) == 0:
        self[node].emit(stmt=f"{str(node.name)}(){end}")
        return

    self[node].emit(stmt=f"{str(node.name)}(")
    with self[node]:
        (*leading, last) = node.code
        for chunk in leading:
            for (level, stmt) in chunk:
                self[node].emit(stmt=stmt, level=level)
            # Patch the most recently emitted line with a separator.
            (level, stmt) = self[node][-1]
            if (stmt and
                    not stmt.startswith("/*") and
                    not stmt.endswith((",", "(", "{", "*/"))):
                stmt = (stmt + ",")
            self[node][-1] = (level, stmt)
        for (level, stmt) in last:
            self[node].emit(stmt=stmt, level=level)
    self[node].emit(stmt=f"){end}")
332
@pc_util.Hook(pc_ast.GPR)
def GPR(self, node):
    """Emit the address of the ``OPPC_GPR`` entry named after ``node``."""
    yield node
    with self.pseudocode(node=node):
        reference = f"&OPPC_GPR[OPPC_GPR_{str(node)}]"
        self[node].emit(stmt=reference)
338
@pc_util.Hook(pc_ast.FPR)
def FPR(self, node):
    """Emit the address of the ``OPPC_FPR`` entry named after ``node``."""
    yield node
    with self.pseudocode(node=node):
        reference = f"&OPPC_FPR[OPPC_FPR_{str(node)}]"
        self[node].emit(stmt=reference)
344
@pc_util.Hook(pc_ast.RepeatExpr)
def RepeatExpr(self, node):
    """Emit ``oppc_repeat(transient, subject, times)`` for ``node``."""
    yield node
    scratch = self.transient(node=node)
    arguments = [
        self[scratch],
        self[node.subject],
        self[node.times],
    ]
    call = self.call(name="oppc_repeat", node=node, code=arguments)
    for (depth, line) in self[call]:
        self[node].emit(stmt=line, level=depth)
356
@pc_util.Hook(pc_ast.XLEN)
def XLEN(self, node):
    """Emit a transient whose value is ``OPPC_XLEN``."""
    yield node
    value = "OPPC_XLEN"
    bits = "(uint8_t)OPPC_XLEN"
    transient = self.transient(node=node, value=value, bits=bits)
    with self.pseudocode(node=node):
        for (depth, line) in self[transient]:
            self[node].emit(stmt=line, level=depth)
365
@pc_util.Hook(pc_ast.Overflow, pc_ast.CR3, pc_ast.CR5,
    pc_ast.XER, pc_ast.Reserve, pc_ast.Special)
def Special(self, node):
    """Emit ``&OPPC_<NAME>`` where NAME is the upper-cased node text."""
    yield node
    with self.pseudocode(node=node):
        reference = f"&OPPC_{str(node).upper()}"
        self[node].emit(stmt=reference)
372
@pc_util.Hook(pc_ast.SubscriptExpr)
def SubscriptExpr(self, node):
    """Emit ``oppc_subscript(subject, index)`` for ``node``."""
    yield node
    arguments = [
        self[node.subject],
        self[node.index],
    ]
    call = self.call(name="oppc_subscript", node=node, code=arguments)
    for (depth, line) in self[call]:
        self[node].emit(stmt=line, level=depth)
382
@pc_util.Hook(pc_ast.RangeSubscriptExpr)
def RangeSubscriptExpr(self, node):
    """Emit ``oppc_range_subscript(subject, start, end)`` for ``node``.

    The original emitted ``oppc_subscript`` here, the same C name that
    SubscriptExpr calls with two arguments. The assignment path already
    uses separate ``oppc_subscript_assign`` and
    ``oppc_range_subscript_assign`` helpers, so the read path follows the
    same naming.
    NOTE(review): assumes the C runtime provides ``oppc_range_subscript``
    -- confirm against the runtime sources.
    """
    yield node
    call = self.call(name="oppc_range_subscript", node=node, code=[
        self[node.subject],
        self[node.start],
        self[node.end],
    ])
    for (level, stmt) in self[call]:
        self[node].emit(stmt=stmt, level=level)
393
@pc_util.Hook(pc_ast.ForExpr)
def ForExpr(self, node):
    """Emit a counted loop as a C ``for (init; test; step) { body }``.

    Three synthetic nodes are built from clones of the loop's subject,
    start and end: an initial assignment, a ``<=`` test against the end,
    and a step assignment. Each is traversed by both visitors and its
    code is spliced into the ``for (...)`` header.
    """
    yield node

    enter = pc_ast.AssignExpr(
        lvalue=node.subject.clone(),
        rvalue=node.start.clone(),
    )
    match = pc_ast.BinaryExpr(
        left=node.subject.clone(),
        op=pc_ast.Le("<="),
        right=node.end.clone(),
    )
    # Advance the loop variable by one. The original added node.end
    # here, which would step by the upper bound instead.
    # NOTE(review): assumes DecLiteral is built from its text like the
    # operator tokens above (pc_ast.Le("<=")) -- confirm in pc_ast.
    leave = pc_ast.AssignExpr(
        lvalue=node.subject.clone(),
        rvalue=pc_ast.BinaryExpr(
            left=node.subject.clone(),
            op=pc_ast.Add("+"),
            right=pc_ast.DecLiteral("1"),
        ),
    )
    # Keep only the first pseudocode comment line for the loop itself.
    with self.pseudocode(node=node):
        (level, stmt) = self[node][0]
    self[node].clear()
    self[node].emit(stmt=stmt, level=level)
    self[node].emit(stmt="for (")
    with self[node]:
        with self[node]:
            for subnode in (enter, match, leave):
                # Synthetic nodes were never visited, so run both
                # visitors on them before reading their code.
                self.__pseudocode.traverse(root=subnode)
                self.traverse(root=subnode)
                for (level, stmt) in self[subnode]:
                    self[node].emit(stmt=stmt, level=level)
                (level, stmt) = self[node][-1]
                if subnode is match:
                    # The test is an expression; terminate the clause.
                    stmt = f"{stmt};"
                elif subnode is leave:
                    # Drop the last character of the step statement so
                    # the header can be closed with ")".
                    stmt = stmt[:-1]
                self[node][-1] = (level, stmt)
    # NOTE(review): re-emits the first stored line (the comment) ahead of
    # the closing parenthesis, as the original did; intent unclear.
    (level, stmt) = self[node][0]
    self[node].emit(stmt=stmt, level=level)
    self[node].emit(stmt=") {")
    for (level, stmt) in self[node.body]:
        self[node].emit(stmt=stmt, level=level)
    self[node].emit(stmt="}")
439
@pc_util.Hook(pc_ast.WhileExpr)
def WhileExpr(self, node):
    """Emit ``while (test) { body }``, plus an else branch when present.

    NOTE(review): when ``node.orelse`` is set this produces
    ``} else {`` after a ``while``, which C does not accept; the test is
    also not wrapped in ``oppc_bool`` the way IfExpr wraps its test.
    """
    yield node
    self[node].emit(stmt="while (")
    with self[node]:
        with self[node]:
            for (depth, line) in self[node.test]:
                self[node].emit(stmt=line, level=depth)
    self[node].emit(") {")
    sections = [(None, node.body)]
    if node.orelse:
        sections.append(("} else {", node.orelse))
    for (opener, section) in sections:
        if opener is not None:
            self[node].emit(stmt=opener)
        for (depth, line) in self[section]:
            self[node].emit(stmt=line, level=depth)
    self[node].emit(stmt="}")
456
@pc_util.Hook(pc_ast.IfExpr)
def IfExpr(self, node):
    """Emit ``if (oppc_bool(test)) { body } else { orelse }``.

    The else branch is emitted only when ``node.orelse`` is truthy.
    """
    yield node
    test = self.call(name="oppc_bool", node=node, code=[
        self[node.test],
    ])
    self[node].emit(stmt="if (")
    with self[node]:
        for (depth, line) in self[test]:
            self[node].emit(stmt=line, level=depth)
    self[node].emit(stmt=") {")
    sections = [(None, node.body)]
    if node.orelse:
        sections.append(("} else {", node.orelse))
    for (opener, section) in sections:
        if opener is not None:
            self[node].emit(stmt=opener)
        for (depth, line) in self[section]:
            self[node].emit(stmt=line, level=depth)
    self[node].emit(stmt="}")
475
@pc_util.Hook(pc_ast.SwitchExpr)
def SwitchExpr(self, node):
    """Emit ``switch (oppc_int64(subject)) { cases }``.

    The original never closed the block it opened with ``) {``; a
    closing ``}`` is now emitted after the cases, as the if/while/for
    hooks do for their bodies.
    """
    yield node
    subject = self.call(name="oppc_int64", node=node, code=[
        self[node.subject],
    ])
    self[node].emit(stmt="switch (")
    with self[node]:
        for (level, stmt) in self[subject]:
            self[node].emit(stmt=stmt, level=level)
    self[node].emit(") {")
    with self[node]:
        for (level, stmt) in self[node.cases]:
            self[node].emit(stmt=stmt, level=level)
    # Close the switch body.
    self[node].emit(stmt="}")
490
@pc_util.Hook(pc_ast.Cases)
def Cases(self, node):
    """Concatenate the code of every case in ``node``, in order."""
    yield node
    for case in node:
        for (depth, line) in self[case]:
            self[node].emit(stmt=line, level=depth)
497
@pc_util.Hook(pc_ast.Case)
def Case(self, node):
    """Emit the code of the case's labels followed by its body."""
    yield node
    for part in (node.labels, node.body):
        for (depth, line) in self[part]:
            self[node].emit(stmt=line, level=depth)
505
@pc_util.Hook(pc_ast.Labels)
def Labels(self, node):
    """Emit ``default:`` or a single ``case (...)`` line for the labels.

    NOTE(review): several labels are joined with ", " inside one
    ``case (...)``, which C reads as a comma expression; this may need
    one ``case`` line per label -- confirm the intended output.
    """
    yield node
    only_default = ((len(node) == 1) and
        isinstance(node[-1], pc_ast.DefaultLabel))
    if only_default:
        stmt = "default:"
    else:
        labels = ", ".join(map(lambda label: str(self[label]), node))
        stmt = f"case ({labels}):"
    self[node].emit(stmt=stmt)
515
@pc_util.Hook(pc_ast.Label)
def Label(self, node):
    """Emit the label's own text as a single line."""
    yield node
    text = str(node)
    self[node].emit(stmt=text)
520
@pc_util.Hook(pc_ast.LeaveKeyword)
def LeaveKeyword(self, node):
    """Emit a C ``break;`` statement for the leave keyword."""
    yield node
    statement = "break;"
    self[node].emit(stmt=statement)
525
@pc_util.Hook(pc_ast.Call.Name)
def CallName(self, node):
    """Emit the called function's name as a single line."""
    yield node
    text = str(node)
    self[node].emit(stmt=text)
530
@pc_util.Hook(pc_ast.Call.Arguments)
def CallArguments(self, node):
    """Record register arguments in the fetch table; emits no code."""
    yield node
    registers = (pc_ast.GPR, pc_ast.FPR)
    for argument in node:
        if not isinstance(argument, registers):
            continue
        self.__regfetch[str(argument)].append(argument)
537
@pc_util.Hook(pc_ast.Call)
def Call(self, node):
    """Emit a pseudocode call as a C call with one chunk per argument."""
    yield node
    code = tuple(self[arg] for arg in node.args)
    call = self.call(name=str(node.name), node=node, code=code)
    for (depth, line) in self[call]:
        self[node].emit(stmt=line, level=depth)
545
@pc_util.Hook(pc_ast.Attribute.Name)
def AttributeName(self, node):
    """Visit an attribute name; no code is emitted for it here."""
    yield node
549
@pc_util.Hook(pc_ast.Attribute)
def Attribute(self, node):
    """Emit a comment with the attribute text and its ``OPPC_ATTR_`` symbol.

    The symbol is the attribute's pseudocode text with dots replaced by
    underscores; it is also remembered in the attribute table.
    """
    yield node
    attr = str(self.__pseudocode[node])
    suffix = attr.replace('.', '_')
    symbol = f"OPPC_ATTR_{suffix}"
    self[node].emit(f"/* {attr} */")
    self[node].emit(stmt=symbol)
    self.__attrs[node] = symbol
558
@pc_util.Hook(pc_ast.Symbol)
def Symbol(self, node):
    """Emit ``&<symbol>`` and register the symbol for declaration.

    The name ``fallthrough`` is not registered as a declaration.
    """
    yield node
    with self.pseudocode(node=node):
        name = str(node)
        if name not in ("fallthrough",):
            self.__decls[name].append(node)
        self[node].emit(stmt=f"&{str(node)}")
566
@pc_util.Hook(pc_ast.Node)
def Node(self, node):
    """Hook registered for the base node class; always raises.

    Raises NotImplementedError carrying the node's type.
    """
    kind = type(node)
    raise NotImplementedError(kind)
570
571
def code(name, root):
    """Yield every item produced by iterating ``CodeVisitor(name, root)``."""
    visitor = CodeVisitor(name=name, root=root)
    yield from visitor