oppc/pseudocode: always use keywords for emit calls
[openpower-isa.git] / src / openpower / oppc / pc_code.py
1 import collections
2 import contextlib
3
4 import openpower.oppc.pc_ast as pc_ast
5 import openpower.oppc.pc_util as pc_util
6 import openpower.oppc.pc_pseudocode as pc_pseudocode
7
8
class Transient(pc_ast.Node):
    """Synthetic AST node standing for a temporary C ``struct oppc_value``.

    ``str()`` renders it as an ``oppc_transient(...)`` call applied to a
    fresh compound literal. *value* and *bits* are C expression strings that
    are spliced verbatim into that call.
    """

    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
        self.__value = value
        self.__bits = bits
        super().__init__()

    def __str__(self):
        # First argument: address of an empty compound literal to initialise.
        arguments = ", ".join((
            "&(struct oppc_value){}",
            f"{self.__value}",
            f"{self.__bits}",
        ))
        return f"oppc_transient({arguments})"
18
19
class CCall(pc_ast.Dataclass):
    """Synthetic AST node describing a call to a C function.

    Built by ``CodeVisitor.ccall`` and rendered by the ``CCall`` hook of the
    code visitor.
    """

    # Name of the C function to call.
    name: str
    # One entry per argument; each entry is a tuple of (level, stmt) lines.
    code: tuple
    # When true the call is emitted as a statement, i.e. terminated by ";".
    stmt: bool
24
25
class CodeVisitor(pc_util.Visitor):
    """Translate an oppc pseudocode AST into lines of C code.

    Every visited node owns a ``pc_util.Code`` buffer (``self[node]``).
    Hooks fill the buffer of the node they handle, mostly by copying the
    ``(level, stmt)`` lines already generated for its children. Iterating
    the visitor yields the lines of the whole C function, in this order:
    - the ``oppc_<name>(void)`` header with local declarations;
    - the code of the root node;
    - the closing brace.
    """

    def __init__(self, name, root):
        self.__root = root
        # Private sentinel keys for the prologue/epilogue code buffers.
        self.__header = object()
        self.__footer = object()
        self.__code = collections.defaultdict(pc_util.Code)
        # Symbol name -> Symbol nodes; every name becomes a local variable.
        self.__decls = collections.defaultdict(list)
        # Register name -> GPR/FPR nodes read (fetch) / written (store).
        self.__regfetch = collections.defaultdict(list)
        self.__regstore = collections.defaultdict(list)
        # Used to annotate the generated C with the original pseudocode.
        self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)

        super().__init__(root=root)

        # NOTE: the header is built after super().__init__(), which presumably
        # traverses root and fills self.__decls through the Symbol hook.
        header = self.__code[self.__header]
        header.emit(stmt="void")
        header.emit(stmt=f"oppc_{name}(void) {{")
        with header:
            for decl in self.__decls:
                header.emit(stmt=f"struct oppc_value {decl};")
        self.__code[self.__footer].emit(stmt="}")

    def __iter__(self):
        """Yield the (level, stmt) lines of the complete C function."""
        yield from self.__code[self.__header]
        yield from self.__code[self.__root]
        yield from self.__code[self.__footer]

    def __getitem__(self, node):
        """Return the code buffer of *node*, creating an empty one on demand."""
        return self.__code[node]

    def regfetch(self, *nodes):
        """Record every GPR/FPR node among *nodes* as a register read."""
        for subnode in nodes:
            if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
                self.__regfetch[str(subnode)].append(subnode)

    def transient(self, node,
            value="UINT64_C(0)",
            bits="(uint8_t)OPPC_XLEN"):
        """Create, visit and return a Transient temporary.

        *value* and *bits* are C expression strings. *node* is unused.
        """
        transient = Transient(value=value, bits=bits)
        self.traverse(root=transient)
        return transient

    def ccall(self, node, name, code, stmt=False):
        """Create, visit and return a CCall to the C function *name*.

        Args:
            node: unused.
            name: C function name.
            code: one code buffer per argument; each is an iterable of
                ``(level, stmt)`` pairs and is frozen into a tuple.
            stmt: emit the call as a ``;``-terminated statement.

        Raises:
            ValueError: if a line's level is not an int or its text is not
                a str.
        """
        def validate_line(line):
            (level, text) = line
            if not isinstance(level, int):
                raise ValueError(level)
            if not isinstance(text, str):
                raise ValueError(text)
            return (level, text)

        def validate_arg(arg):
            return tuple(map(validate_line, arg))

        code = tuple(map(validate_arg, code))
        ccall = CCall(name=name, code=code, stmt=stmt)
        self.traverse(root=ccall)
        return ccall

    def ternary(self, node):
        """Replace the code of IfExpr *node* with ``( test ? body : orelse )``.

        Used when an IfExpr appears in expression position, where the
        ``if (...) {`` statement form emitted by the IfExpr hook is invalid.
        """
        self[node].clear()
        self[node].emit(stmt="(")
        with self[node]:
            for (level, stmt) in self[node.test]:
                self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt="?")
            for (level, stmt) in self[node.body]:
                self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=":")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=")")

    @contextlib.contextmanager
    def pseudocode(self, node):
        """Emit the pseudocode of *node* as C comments, then run the body."""
        for (level, stmt) in self.__pseudocode[node]:
            self[node].emit(stmt=f"/* {stmt} */", level=level)
        yield

    @pc_util.Hook(pc_ast.Scope)
    def Scope(self, node):
        """Concatenate the code of every statement in the scope."""
        yield node
        with self[node]:
            for subnode in node:
                for (level, stmt) in self[subnode]:
                    self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
    def AssignExpr(self, node):
        """Emit an assignment as an oppc_*assign statement call.

        A plain subscript or range subscript on the left selects a
        different C helper. Register operands are recorded: the lvalue as
        a store, the rvalue as a fetch.
        """
        yield node
        if isinstance(node.lvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regstore[str(node.lvalue)].append(node.lvalue)
        self.regfetch(node.rvalue)

        if isinstance(node.rvalue, pc_ast.IfExpr):
            self.ternary(node=node.rvalue)

        if isinstance(node.lvalue, pc_ast.SubscriptExpr):
            ccall = self.ccall(name="oppc_subscript_assign", node=node, stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.index],
                self[node.rvalue],
            ])
        elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
            ccall = self.ccall(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.start],
                self[node.lvalue.end],
                self[node.rvalue],
            ])
        else:
            ccall = self.ccall(name="oppc_assign", stmt=True, node=node, code=[
                self[node.lvalue],
                self[node.rvalue],
            ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[ccall]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.BinaryExpr)
    def BinaryExpr(self, node):
        """Emit a binary operation as a call to the operator's C helper.

        Comparisons take only the two operands. Other operators also
        receive a fresh Transient as their first argument (the result
        slot).
        """
        yield node
        # Each register operand is recorded under its own name (the right
        # operand was previously recorded as node.left).
        self.regfetch(node.left, node.right)

        comparison = (
            pc_ast.Lt, pc_ast.Le,
            pc_ast.Eq, pc_ast.NotEq,
            pc_ast.Ge, pc_ast.Gt,
            pc_ast.LtU, pc_ast.GtU,
        )
        if isinstance(node.left, pc_ast.IfExpr):
            self.ternary(node=node.left)
        if isinstance(node.right, pc_ast.IfExpr):
            self.ternary(node=node.right)

        if isinstance(node.op, comparison):
            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                self[node.left],
                self[node.right],
            ])
        else:
            transient = self.transient(node=node)
            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                self[transient],
                self[node.left],
                self[node.right],
            ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[ccall]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.UnaryExpr)
    def UnaryExpr(self, node):
        """Emit a unary operation as a call to the operator's C helper."""
        yield node
        # Same register bookkeeping as BinaryExpr applies to the operand.
        self.regfetch(node.value)
        if isinstance(node.value, pc_ast.IfExpr):
            self.ternary(node=node.value)
        ccall = self.ccall(name=str(self[node.op]), node=node, code=[
            self[node.value],
        ])
        with self.pseudocode(node=node):
            for (level, stmt) in self[ccall]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(
        pc_ast.Not, pc_ast.Add, pc_ast.Sub,
        pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
        pc_ast.Lt, pc_ast.Le,
        pc_ast.Eq, pc_ast.NotEq,
        pc_ast.Ge, pc_ast.Gt,
        pc_ast.LtU, pc_ast.GtU,
        pc_ast.LShift, pc_ast.RShift,
        pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
        pc_ast.BitConcat,
    )
    def Op(self, node):
        """Emit the name of the C helper implementing operator *node*."""
        yield node
        op = {
            pc_ast.Not: "oppc_not",
            pc_ast.Add: "oppc_add",
            pc_ast.Sub: "oppc_sub",
            pc_ast.Mul: "oppc_mul",
            pc_ast.Div: "oppc_div",
            pc_ast.Mod: "oppc_mod",
            pc_ast.Lt: "oppc_lt",
            pc_ast.Le: "oppc_le",
            pc_ast.Eq: "oppc_eq",
            pc_ast.LtU: "oppc_ltu",
            pc_ast.GtU: "oppc_gtu",
            pc_ast.NotEq: "oppc_noteq",
            pc_ast.Ge: "oppc_ge",
            pc_ast.Gt: "oppc_gt",
            pc_ast.LShift: "oppc_lshift",
            pc_ast.RShift: "oppc_rshift",
            pc_ast.BitAnd: "oppc_and",
            pc_ast.BitOr: "oppc_or",
            pc_ast.BitXor: "oppc_xor",
            pc_ast.BitConcat: "oppc_concat",
        }[node.__class__]
        self[node].emit(stmt=op)

    @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
    def Integer(self, node):
        """Emit an integer literal as a Transient holding its value.

        Width is derived per literal kind:
        - binary literals: 1 bit per digit, value rendered in hex;
        - hex literals: 4 bits per digit, value rendered in hex;
        - decimal literals: XLEN bits, value rendered in decimal.

        Raises:
            NotImplementedError: if the value does not fit in 64 bits.
        """
        yield node
        fmt = hex
        value = str(node)
        # value[2:] skips the two-character radix prefix (presumably 0b/0x).
        if isinstance(node, pc_ast.BinLiteral):
            bits = f"UINT8_C({len(value[2:])})"
            value = int(value, 2)
        elif isinstance(node, pc_ast.HexLiteral):
            bits = f"UINT8_C({len(value[2:]) * 4})"
            value = int(value, 16)
        else:
            bits = "(uint8_t)OPPC_XLEN"
            value = int(value)
            fmt = str
        if (value > ((2**64) - 1)):
            raise NotImplementedError()
        value = f"UINT64_C({fmt(value)})"
        transient = self.transient(node=node, value=value, bits=bits)
        with self.pseudocode(node=node):
            for (level, stmt) in self[transient]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(Transient)
    def Transient(self, node):
        """Emit the C expression of a Transient node."""
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(CCall)
    def CCall(self, node):
        """Render a CCall as ``name(arg, ..., arg)`` over several lines.

        A comma is appended to the last line of every argument but the
        final one. Lines that are empty, comments, or already end in an
        opening/separating token are left unchanged.
        """
        yield node
        end = (";" if node.stmt else "")
        if len(node.code) == 0:
            self[node].emit(stmt=f"{str(node.name)}(){end}")
        else:
            self[node].emit(stmt=f"{str(node.name)}(")
            with self[node]:
                (*head, tail) = node.code
                for code in head:
                    for (level, stmt) in code:
                        self[node].emit(stmt=stmt, level=level)
                    (level, stmt) = self[node][-1]
                    if not (not stmt or
                            stmt.startswith("/*") or
                            stmt.endswith((",", "(", "{", "*/"))):
                        stmt = (stmt + ",")
                    self[node][-1] = (level, stmt)
                for (level, stmt) in tail:
                    self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=f"){end}")

    @pc_util.Hook(pc_ast.GPR)
    def GPR(self, node):
        """Emit a pointer to the named general-purpose register."""
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")

    @pc_util.Hook(pc_ast.FPR)
    def FPR(self, node):
        """Emit a pointer to the named floating-point register."""
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")

    @pc_util.Hook(pc_ast.RepeatExpr)
    def RepeatExpr(self, node):
        """Emit ``oppc_repeat(result, subject, times)``."""
        yield node
        transient = self.transient(node=node)
        ccall = self.ccall(name="oppc_repeat", node=node, code=[
            self[transient],
            self[node.subject],
            self[node.times],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.XLEN)
    def XLEN(self, node):
        """Emit XLEN as an XLEN-wide Transient holding OPPC_XLEN."""
        yield node
        (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
        transient = self.transient(node=node, value=value, bits=bits)
        with self.pseudocode(node=node):
            for (level, stmt) in self[transient]:
                self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.SubscriptExpr)
    def SubscriptExpr(self, node):
        """Emit ``oppc_subscript(subject, index)``."""
        yield node
        ccall = self.ccall(name="oppc_subscript", node=node, code=[
            self[node.subject],
            self[node.index],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.RangeSubscriptExpr)
    def RangeSubscriptExpr(self, node):
        """Emit ``oppc_range_subscript(subject, start, end)``.

        Uses a helper distinct from the two-argument oppc_subscript (C has
        no overloading). This mirrors oppc_range_subscript_assign on the
        assignment side.
        """
        yield node
        ccall = self.ccall(name="oppc_range_subscript", node=node, code=[
            self[node.subject],
            self[node.start],
            self[node.end],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.ForExpr)
    def ForExpr(self, node):
        """Lower a counted loop into a C ``for`` statement.

        The loop header is rebuilt from three synthetic nodes, each
        traversed by both the pseudocode visitor and this visitor:
        - init: ``subject <- start``;
        - condition: ``subject <= end``;
        - step: ``subject <- subject + 1``.
        """
        yield node

        enter = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=node.start.clone(),
        )
        match = pc_ast.BinaryExpr(
            left=node.subject.clone(),
            op=pc_ast.Le("<="),
            right=node.end.clone(),
        )
        # Advance by one per iteration, matching the inclusive <= bound.
        leave = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=pc_ast.BinaryExpr(
                left=node.subject.clone(),
                op=pc_ast.Add("+"),
                right=pc_ast.DecLiteral("1"),
            ),
        )
        with self.pseudocode(node=node):
            # Keep only the first pseudocode line (the loop header).
            (level, stmt) = self[node][0]
        self[node].clear()
        self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="for (")
        with self[node]:
            with self[node]:
                for subnode in (enter, match, leave):
                    self.__pseudocode.traverse(root=subnode)
                    self.traverse(root=subnode)
                    for (level, stmt) in self[subnode]:
                        self[node].emit(stmt=stmt, level=level)
                    (level, stmt) = self[node][-1]
                    if subnode is match:
                        # Comparison calls are not statements: add the ";".
                        stmt = f"{stmt};"
                    elif subnode is leave:
                        # Drop the ";" ending the step assignment statement.
                        stmt = stmt[:-1]
                    self[node][-1] = (level, stmt)
        # NOTE(review): re-emits the leading pseudocode comment right before
        # ") {"; looks redundant — confirm whether this is intended.
        (level, stmt) = self[node][0]
        self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.WhileExpr)
    def WhileExpr(self, node):
        """Emit a C ``while`` loop for *node*."""
        yield node
        self[node].emit(stmt="while (")
        with self[node]:
            with self[node]:
                for (level, stmt) in self[node.test]:
                    self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            # NOTE(review): C has no while/else, so this branch produces
            # invalid C — confirm whether WhileExpr.orelse can be non-empty.
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.IfExpr)
    def IfExpr(self, node):
        """Emit a C ``if``/``else`` statement for *node*."""
        yield node
        self[node].emit(stmt="if (")
        with self[node]:
            for (level, stmt) in self[node.test]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.Call.Name)
    def CallName(self, node):
        """Emit the called function's name verbatim."""
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.Call.Arguments)
    def CallArguments(self, node):
        """Record register arguments as fetches; the Call hook emits code."""
        yield node
        self.regfetch(*node)

    @pc_util.Hook(pc_ast.Call)
    def Call(self, node):
        """Emit a function call with one argument buffer per call argument."""
        yield node
        code = tuple(self[arg] for arg in node.args)
        ccall = self.ccall(name=str(node.name), node=node, code=code)
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Symbol)
    def Symbol(self, node):
        """Emit a pointer to a local variable and register its declaration."""
        yield node
        self.__decls[str(node)].append(node)
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&{str(node)}")

    @pc_util.Hook(pc_ast.Node)
    def Node(self, node):
        """Reject any node type that has no dedicated hook."""
        raise NotImplementedError(type(node))
437
438
def code(name, root):
    """Generate the C translation of the pseudocode AST *root*.

    Yields the ``(level, stmt)`` lines produced by a ``CodeVisitor`` for a
    C function named ``oppc_<name>``.
    """
    visitor = CodeVisitor(name=name, root=root)
    yield from visitor