oppc/code: fix comparisons
[openpower-isa.git] / src / openpower / oppc / pc_code.py
1 import collections
2 import contextlib
3
4 import openpower.oppc.pc_ast as pc_ast
5 import openpower.oppc.pc_util as pc_util
6 import openpower.oppc.pc_pseudocode as pc_pseudocode
7
8
class Transient(pc_ast.Node):
    """Synthetic node standing for a temporary C value.

    ``str()`` renders it as an ``oppc_transient`` call over a fresh
    ``struct oppc_value`` compound literal.  ``value`` and ``bits`` are C
    source fragments inserted verbatim as the initial value and width.
    """

    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
        self.__value = value
        self.__bits = bits
        super().__init__()

    def __repr__(self):
        kind = type(self).__name__
        return f"{hex(id(self))}@{kind}({self.__value}, {self.__bits})"

    def __str__(self):
        return ("oppc_transient(&(struct oppc_value){}, "
                f"{self.__value}, {self.__bits})")
21
22
class Call(pc_ast.Dataclass):
    """Synthetic node describing a C function call to be emitted."""

    # Name of the C function to call.
    name: str
    # Argument code: one tuple of (level, stmt) pairs per argument.
    code: tuple
    # When true the call is terminated with ";" (emitted as a statement).
    stmt: bool
27
28
class Instruction(pc_ast.Node):
    """Placeholder node for the raw instruction word.

    It is rendered as ``insn``, the parameter of the generated C function
    from which instruction fields are extracted bit by bit.
    """
31
32
class CodeVisitor(pc_util.Visitor):
    """Lower an oppc pseudocode AST into C source for one instruction.

    Every visited node gets its own ``pc_util.Code`` buffer of
    ``(level, stmt)`` pairs, reachable as ``self[node]``.  Iterating the
    visitor yields, in order: the function header (signature, local
    declarations, instruction-field extraction), the code of ``root``, and
    the closing brace.
    """

    def __init__(self, insn, root):
        self.__root = root
        self.__insn = insn
        self.__attrs = {}
        self.__decls = set()
        # Private keys for the synthetic header and footer buffers.
        self.__header = object()
        self.__footer = object()
        self.__code = collections.defaultdict(lambda: pc_util.Code())
        self.__regfetch = collections.defaultdict(list)
        self.__regstore = collections.defaultdict(list)
        self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)

        # Visiting ``root`` (done by the base-class constructor) fills the
        # per-node buffers and, through the Symbol hook, ``__decls``.
        super().__init__(root=root)

        header = self.__code[self.__header]
        header.emit(stmt="void")
        header.emit(stmt=f"oppc_{insn.name}(struct oppc_value const *insn) {{")
        with header:
            # Sorted so the generated C does not depend on set iteration
            # order, which changes between runs with string hashing.
            for decl in sorted(self.__decls):
                header.emit(stmt=f"struct oppc_value {decl};")
        decls = sorted(filter(lambda decl: decl in insn.fields, self.__decls))
        if decls:
            header.emit()
        for decl in decls:
            # Initialize the variable as a transient as wide as the field...
            bits = f"{len(insn.fields[decl])}"
            transient = Transient(bits=bits)
            symbol = pc_ast.Symbol(decl)
            assign = pc_ast.AssignExpr(lvalue=symbol, rvalue=transient)
            self.traverse(root=assign)
            with header:
                for (level, stmt) in self[assign]:
                    header.emit(stmt=stmt, level=level)
            # ...then copy the field bit by bit out of the instruction word.
            for (lbit, rbit) in enumerate(insn.fields[decl]):
                lsymbol = pc_ast.Symbol(decl)
                rsymbol = Instruction()
                lindex = Transient(value=str(lbit))
                rindex = Transient(value=str(rbit))
                lvalue = pc_ast.SubscriptExpr(index=lindex, subject=lsymbol)
                rvalue = pc_ast.SubscriptExpr(index=rindex, subject=rsymbol)
                assign = pc_ast.AssignExpr(lvalue=lvalue, rvalue=rvalue)
                self.traverse(root=assign)
                with header:
                    for (level, stmt) in self[assign]:
                        header.emit(stmt=stmt, level=level)
            header.emit()
        if decls:
            header.emit()
        self.__code[self.__footer].emit(stmt=f"}}")

    def __iter__(self):
        yield from self.__code[self.__header]
        yield from self.__code[self.__root]
        yield from self.__code[self.__footer]

    def __getitem__(self, node):
        return self.__code[node]

    def __setitem__(self, node, code):
        self.__code[node] = code

    def transient(self,
            value="UINT64_C(0)",
            bits="(uint8_t)OPPC_XLEN"):
        """Create and visit a Transient node with the given C value/width."""
        transient = Transient(value=value, bits=bits)
        self.traverse(root=transient)
        return transient

    def call(self, name, code, stmt=False):
        """Create and visit a synthetic C call node; return it.

        ``code`` holds one entry per argument, each an iterable of
        ``(level, stmt)`` pairs (typically ``self[some_node]``).  ``stmt``
        selects whether the call is terminated with ``;``.

        Raises ValueError if a level is not an int or a statement is not
        a str.
        """
        def validate_line(line):
            (level, text) = line
            if not isinstance(level, int):
                raise ValueError(level)
            if not isinstance(text, str):
                raise ValueError(text)
            return (level, text)

        def validate_argument(argument):
            return tuple(map(validate_line, argument))

        code = tuple(map(validate_argument, code))
        call = Call(name=name, code=code, stmt=stmt)
        self.traverse(root=call)
        return call

    def fixup_ternary(self, node):
        """Re-emit an already visited IfExpr as a C conditional expression.

        The buffer of ``node`` is replaced by
        ``( oppc_bool(test) ? body : orelse )`` so that it can appear where
        an expression is expected.
        """
        self[node].clear()
        test = self.call(name="oppc_bool", code=[
            self[node.test],
        ])
        self[node].emit(stmt="(")
        with self[node]:
            for (level, stmt) in self[test]:
                self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt="?")
            for (level, stmt) in self[node.body]:
                self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=":")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=")")

    def fixup_attr(self, node, assign=False):
        """Rewrite an already visited Attribute node so it carries its subject.

        The code of the first subject in the chain that is neither an
        attribute nor a subscript is joined with the code of ``node`` into
        one comma-separated argument list.  Unless ``assign`` is true, the
        list is wrapped in an ``oppc_attr(...)`` call.  With ``assign``
        true it is left bare for ``oppc_attr_assign``.  The result replaces
        ``self[node]``.
        """
        root = node
        attribute_or_subscript = (
            pc_ast.Attribute,
            pc_ast.SubscriptExpr,
            pc_ast.RangeSubscriptExpr,
        )
        while isinstance(node.subject, attribute_or_subscript):
            node = node.subject

        def wrap(code):
            # Append "," to every line that is not empty, not a comment and
            # does not already end an argument/open a bracket.
            def wrap_line(item):
                (level, stmt) = item
                if not (not stmt or
                        stmt.startswith("/*") or
                        stmt.endswith((",", "(", "{", "*/"))):
                    stmt = (stmt + ",")
                return (level, stmt)

            return tuple(map(wrap_line, code))

        code = pc_util.Code()
        for (level, stmt) in wrap(self[node.subject]):
            code.emit(stmt=stmt, level=level)
        for (level, stmt) in wrap(self[root]):
            code.emit(stmt=stmt, level=level)

        # Discard the trailing comma of the last argument (only if present,
        # so a final comment line is never truncated).
        (level, stmt) = code[-1]
        if stmt.endswith(","):
            code[-1] = (level, stmt[:-1])

        if not assign:
            call = self.call(name="oppc_attr", code=[
                code,
            ])
            code = self[call]
        self[root] = code

    @contextlib.contextmanager
    def pseudocode(self, node):
        """Emit the pseudocode of ``node`` as C comments, then run the body."""
        for (level, stmt) in self.__pseudocode[node]:
            self[node].emit(stmt=f"/* {stmt} */", level=level)
        yield

    @pc_util.Hook(pc_ast.Scope)
    def Scope(self, node):
        yield node
        with self[node]:
            for subnode in node:
                for (level, stmt) in self[subnode]:
                    self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
    def AssignExpr(self, node):
        yield node
        # Record register targets in __regstore and sources in __regfetch.
        if isinstance(node.lvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regstore[str(node.lvalue)].append(node.lvalue)
        if isinstance(node.rvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.rvalue)].append(node.rvalue)

        if isinstance(node.rvalue, pc_ast.IfExpr):
            self.fixup_ternary(node=node.rvalue)
        if isinstance(node.lvalue, pc_ast.Attribute):
            self.fixup_attr(node=node.lvalue, assign=True)
        if isinstance(node.rvalue, pc_ast.Attribute):
            self.fixup_attr(node=node.rvalue)

        if isinstance(node.lvalue, pc_ast.Sequence):
            # Sequence assignment: lower into one assignment per element.
            if not isinstance(node.rvalue, pc_ast.Sequence):
                raise ValueError(node.rvalue)
            if len(node.lvalue) != len(node.rvalue):
                raise ValueError(node)
            for (lvalue, rvalue) in zip(node.lvalue, node.rvalue):
                assign = node.__class__(
                    lvalue=lvalue.clone(),
                    rvalue=rvalue.clone(),
                )
                self.traverse(root=assign)
                for (level, stmt) in self[assign]:
                    self[node].emit(stmt=stmt, level=level)
            return

        if isinstance(node.lvalue, pc_ast.SubscriptExpr):
            call = self.call(name="oppc_subscript_assign", stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.index],
                self[node.rvalue],
            ])
        elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
            call = self.call(name="oppc_range_subscript_assign", stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.start],
                self[node.lvalue.end],
                self[node.rvalue],
            ])
        elif isinstance(node.lvalue, pc_ast.Attribute):
            call = self.call(name="oppc_attr_assign", stmt=True, code=[
                self[node.lvalue],
                self[node.rvalue],
            ])
        else:
            call = self.call(name="oppc_assign", stmt=True, code=[
                self[node.lvalue],
                self[node.rvalue],
            ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[call]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.BinaryExpr)
    def BinaryExpr(self, node):
        yield node
        # Record each register operand (left and right) as fetched.
        for operand in (node.left, node.right):
            if isinstance(operand, (pc_ast.GPR, pc_ast.FPR)):
                self.__regfetch[str(operand)].append(operand)

        if isinstance(node.left, pc_ast.IfExpr):
            self.fixup_ternary(node=node.left)
        if isinstance(node.right, pc_ast.IfExpr):
            self.fixup_ternary(node=node.right)
        if isinstance(node.left, pc_ast.Attribute):
            self.fixup_attr(node=node.left)
        if isinstance(node.right, pc_ast.Attribute):
            self.fixup_attr(node=node.right)

        comparison = (
            pc_ast.Lt, pc_ast.Le,
            pc_ast.Eq, pc_ast.NotEq,
            pc_ast.Ge, pc_ast.Gt,
            pc_ast.LtU, pc_ast.GtU,
        )
        # Comparisons produce a 1-bit result; everything else is XLEN wide.
        if isinstance(node.op, comparison):
            transient = self.transient(bits="UINT8_C(1)")
        else:
            transient = self.transient()
        call = self.call(name=str(self[node.op]), code=[
            self[transient],
            self[node.left],
            self[node.right],
        ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[call]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.UnaryExpr)
    def UnaryExpr(self, node):
        yield node
        # Same operand handling as BinaryExpr.
        if isinstance(node.value, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.value)].append(node.value)
        if isinstance(node.value, pc_ast.IfExpr):
            self.fixup_ternary(node=node.value)
        if isinstance(node.value, pc_ast.Attribute):
            self.fixup_attr(node=node.value)
        # NOTE(review): unlike BinaryExpr no destination transient is passed
        # here; confirm this matches the runtime's unary operator signatures.
        call = self.call(name=str(self[node.op]), code=[
            self[node.value],
        ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[call]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(
        pc_ast.Not, pc_ast.Add, pc_ast.Sub,
        pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
        pc_ast.Lt, pc_ast.Le,
        pc_ast.Eq, pc_ast.NotEq,
        pc_ast.Ge, pc_ast.Gt,
        pc_ast.LtU, pc_ast.GtU,
        pc_ast.LShift, pc_ast.RShift,
        pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
        pc_ast.BitConcat,
    )
    def Op(self, node):
        yield node
        # Operator node -> name of the C runtime function implementing it.
        op = {
            pc_ast.Not: "oppc_not",
            pc_ast.Add: "oppc_add",
            pc_ast.Sub: "oppc_sub",
            pc_ast.Mul: "oppc_mul",
            pc_ast.Div: "oppc_div",
            pc_ast.Mod: "oppc_mod",
            pc_ast.Lt: "oppc_lt",
            pc_ast.Le: "oppc_le",
            pc_ast.Eq: "oppc_eq",
            pc_ast.LtU: "oppc_ltu",
            pc_ast.GtU: "oppc_gtu",
            pc_ast.NotEq: "oppc_noteq",
            pc_ast.Ge: "oppc_ge",
            pc_ast.Gt: "oppc_gt",
            pc_ast.LShift: "oppc_lshift",
            pc_ast.RShift: "oppc_rshift",
            pc_ast.BitAnd: "oppc_and",
            pc_ast.BitOr: "oppc_or",
            pc_ast.BitXor: "oppc_xor",
            pc_ast.BitConcat: "oppc_concat",
        }[node.__class__]
        self[node].emit(stmt=op)

    @pc_util.Hook(pc_ast.StringLiteral)
    def StringLiteral(self, node):
        yield node
        # repr() leaves '"' unescaped when it picks single quotes as
        # delimiters, which would terminate the C string early.
        escaped = repr(str(node))[1:-1].replace('"', '\\"')
        self[node].emit(stmt=f"\"{escaped}\"")

    @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
    def Integer(self, node):
        yield node
        # Binary/hex literals take their width from the number of digits;
        # decimal literals are XLEN wide.
        fmt = hex
        value = str(node)
        if isinstance(node, pc_ast.BinLiteral):
            bits = f"UINT8_C({str(len(value[2:]))})"
            value = int(value, 2)
        elif isinstance(node, pc_ast.HexLiteral):
            bits = f"UINT8_C({str(len(value[2:]) * 4)})"
            value = int(value, 16)
        else:
            bits = "(uint8_t)OPPC_XLEN"
            value = int(value)
            fmt = str
        if (value > ((2**64) - 1)):
            raise NotImplementedError()
        value = f"UINT64_C({fmt(value)})"
        transient = self.transient(value=value, bits=bits)
        with self.pseudocode(node=node):
            for (level, stmt) in self[transient]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(Transient)
    def Transient(self, node):
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(Call)
    def CCall(self, node):
        yield node
        end = (";" if node.stmt else "")
        if len(node.code) == 0:
            self[node].emit(stmt=f"{str(node.name)}(){end}")
        else:
            self[node].emit(stmt=f"{str(node.name)}(")
            with self[node]:
                (*head, tail) = node.code
                for code in head:
                    for (level, stmt) in code:
                        self[node].emit(stmt=stmt, level=level)
                    # Separate this argument from the next one.
                    (level, stmt) = self[node][-1]
                    if not (not stmt or
                            stmt.startswith("/*") or
                            stmt.endswith((",", "(", "{", "*/"))):
                        stmt = (stmt + ",")
                    self[node][-1] = (level, stmt)
                for (level, stmt) in tail:
                    self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=f"){end}")

    @pc_util.Hook(pc_ast.GPR)
    def GPR(self, node):
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")

    @pc_util.Hook(pc_ast.GPRZero)
    def GPRZero(self, node):
        yield node
        # Lowered to: field ? GPR[field] : 0 (a default Transient).
        name = str(node)
        test = pc_ast.Symbol(name)
        body = pc_ast.Scope([pc_ast.GPR(name)])
        orelse = pc_ast.Scope([Transient()])
        ifexpr = pc_ast.IfExpr(test=test, body=body, orelse=orelse)
        self.traverse(root=ifexpr)
        self.fixup_ternary(node=ifexpr)
        for (level, stmt) in self[ifexpr]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.FPR)
    def FPR(self, node):
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")

    @pc_util.Hook(pc_ast.RepeatExpr)
    def RepeatExpr(self, node):
        yield node
        transient = self.transient()
        call = self.call(name="oppc_repeat", code=[
            self[transient],
            self[node.subject],
            self[node.times],
        ])
        for (level, stmt) in self[call]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.XLEN)
    def XLEN(self, node):
        yield node
        (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
        transient = self.transient(value=value, bits=bits)
        with self.pseudocode(node=node):
            for (level, stmt) in self[transient]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Overflow, pc_ast.CR3, pc_ast.CR5,
            pc_ast.XER, pc_ast.Reserve, pc_ast.Special)
    def Special(self, node):
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_{str(node).upper()}")

    @pc_util.Hook(pc_ast.SubscriptExpr)
    def SubscriptExpr(self, node):
        yield node
        transient = self.transient()
        call = self.call(name="oppc_subscript", code=[
            self[transient],
            self[node.subject],
            self[node.index],
        ])
        for (level, stmt) in self[call]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.RangeSubscriptExpr)
    def RangeSubscriptExpr(self, node):
        yield node
        transient = self.transient()
        call = self.call(name="oppc_range_subscript", code=[
            self[transient],
            self[node.subject],
            self[node.start],
            self[node.end],
        ])
        for (level, stmt) in self[call]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.ForExpr)
    def ForExpr(self, node):
        yield node

        # for subject = start to end  ==>
        #   for (subject <- start; subject <= end; subject <- subject + 1)
        enter = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=node.start.clone(),
        )
        match = pc_ast.BinaryExpr(
            left=node.subject.clone(),
            op=pc_ast.Le("<="),
            right=node.end.clone(),
        )
        leave = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=pc_ast.BinaryExpr(
                left=node.subject.clone(),
                op=pc_ast.Add("+"),
                right=pc_ast.DecLiteral("1"),
            ),
        )
        # Keep only the first pseudocode comment line for the loop header.
        with self.pseudocode(node=node):
            (level, stmt) = self[node][0]
        self[node].clear()
        self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="for (")
        with self[node]:
            with self[node]:
                for subnode in (enter, match, leave):
                    self.__pseudocode.traverse(root=subnode)
                    self.traverse(root=subnode)
                    code = self[subnode]
                    if subnode is match:
                        # Wrap the condition in oppc_bool, as IfExpr and
                        # fixup_ternary do for their tests.
                        code = self[self.call(name="oppc_bool", code=[code])]
                    for (level, stmt) in code:
                        self[node].emit(stmt=stmt, level=level)
                    (level, stmt) = self[node][-1]
                    if subnode is match:
                        # The condition is an expression: add the separator.
                        stmt = f"{stmt};"
                    elif subnode is leave:
                        # The step precedes ")": drop its trailing ";".
                        stmt = stmt[:-1]
                    self[node][-1] = (level, stmt)
        # Repeat the header comment before the opening brace.
        (level, stmt) = self[node][0]
        self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.WhileExpr)
    def WhileExpr(self, node):
        yield node
        # Wrap the condition in oppc_bool, as IfExpr does for its test.
        test = self.call(name="oppc_bool", code=[
            self[node.test],
        ])
        self[node].emit(stmt="while (")
        with self[node]:
            with self[node]:
                for (level, stmt) in self[test]:
                    self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            # NOTE(review): C has no while/else construct, so the "} else {"
            # emitted here will not compile; while/else needs a different
            # lowering before it can be supported.
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.IfExpr)
    def IfExpr(self, node):
        yield node
        test = self.call(name="oppc_bool", code=[
            self[node.test],
        ])
        self[node].emit(stmt="if (")
        with self[node]:
            for (level, stmt) in self[test]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.SwitchExpr)
    def SwitchExpr(self, node):
        yield node
        subject = self.call(name="oppc_int64", code=[
            self[node.subject],
        ])
        self[node].emit(stmt="switch (")
        with self[node]:
            for (level, stmt) in self[subject]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        with self[node]:
            for (level, stmt) in self[node.cases]:
                self[node].emit(stmt=stmt, level=level)
        # Close the switch body opened above.
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.Cases)
    def Cases(self, node):
        yield node
        for subnode in node:
            for (level, stmt) in self[subnode]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Case)
    def Case(self, node):
        yield node
        for (level, stmt) in self[node.labels]:
            self[node].emit(stmt=stmt, level=level)
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Labels)
    def Labels(self, node):
        yield node
        # One C label per pseudocode label: "case (a, b):" would be a comma
        # expression (matching only b), not a set of labels.
        for label in node:
            if isinstance(label, pc_ast.DefaultLabel):
                stmt = "default:"
            else:
                stmt = f"case ({str(self[label])}):"
            self[node].emit(stmt=stmt)

    @pc_util.Hook(pc_ast.Label)
    def Label(self, node):
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.LeaveKeyword)
    def LeaveKeyword(self, node):
        yield node
        self[node].emit(stmt="break;")

    @pc_util.Hook(pc_ast.Call.Name)
    def CallName(self, node):
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.Call.Arguments)
    def CallArguments(self, node):
        yield node
        for subnode in node:
            if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
                self.__regfetch[str(subnode)].append(subnode)

    @pc_util.Hook(pc_ast.Call)
    def Call(self, node):
        yield node
        code = tuple(map(lambda arg: self[arg], node.args))
        call = self.call(name=str(node.name), code=code)
        for (level, stmt) in self[call]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Attribute.Name)
    def AttributeName(self, node):
        yield node

    @pc_util.Hook(pc_ast.Sequence)
    def Sequence(self, node):
        yield node

    @pc_util.Hook(pc_ast.Attribute)
    def Attribute(self, node):
        yield node
        attr = str(self.__pseudocode[node])
        symbol = f"OPPC_ATTR_{attr.replace('.', '_')}"
        self[node].emit(stmt=f"/* {attr} */")
        self[node].emit(stmt=symbol)
        self.__attrs[node] = symbol

    @pc_util.Hook(pc_ast.Symbol)
    def Symbol(self, node):
        yield node
        with self.pseudocode(node=node):
            decl = str(node)
            if decl not in ("fallthrough",):
                if decl in ("TRAP",):
                    self[node].emit(stmt=f"{decl}();")
                else:
                    # Every other symbol becomes a local declared in the header.
                    self.__decls.add(decl)
                    self[node].emit(stmt=f"&{decl}")

    @pc_util.Hook(Instruction)
    def Instruction(self, node):
        yield node
        self[node].emit(stmt="insn")

    @pc_util.Hook(pc_ast.Node)
    def Node(self, node):
        raise NotImplementedError(type(node))
653
654
def code(insn, root):
    """Generate C code for ``insn`` from the pseudocode AST ``root``.

    Yields the ``(level, stmt)`` pairs produced by ``CodeVisitor``.
    """
    visitor = CodeVisitor(insn=insn, root=root)
    yield from visitor