oppc/code: drop explicit ctx argument
[openpower-isa.git] / src / openpower / oppc / pc_code.py
1 import collections
2 import contextlib
3
4 import openpower.oppc.pc_ast as pc_ast
5 import openpower.oppc.pc_util as pc_util
6 import openpower.oppc.pc_pseudocode as pc_pseudocode
7
8
class Transient(pc_ast.Node):
    """A temporary ``struct oppc_int`` in the generated C code.

    Renders as a call to ``oppc_transient`` on a C compound literal,
    passing the C expression strings ``value`` and ``bits``.
    """

    def __init__(self, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
        """Remember the C expressions used for the initial value and width.

        :param value: C expression for the initial value.
        :param bits: C expression for the bit width.
        """
        self.__value = value
        self.__bits = bits
        super().__init__()

    def __str__(self):
        """Return the C expression that materialises this transient."""
        template = "oppc_transient(&(struct oppc_int){{}}, {}, {})"
        return template.format(self.__value, self.__bits)
18
19
class CCall(pc_ast.Dataclass):
    """A C function call to be rendered by ``CodeVisitor``."""

    name: str  # name of the C function to call
    code: tuple  # one tuple of (level, stmt) pairs per argument
    stmt: bool  # if true the call is emitted as a statement (trailing ';')
24
25
class CodeVisitor(pc_util.Visitor):
    """Translate an oppc pseudocode AST into lines of C code.

    Every node gets its own ``pc_util.Code`` buffer, reachable through
    ``self[node]``, holding ``(level, stmt)`` pairs.  Each hook below first
    does ``yield node`` and afterwards reads its children's buffers, so
    (presumably) ``pc_util.Visitor`` translates the children while the hook
    is suspended at the ``yield``.

    Iterating the visitor yields the ``(level, stmt)`` pairs of the final C
    function: ``void`` / ``oppc_<name>(void) {``, one
    ``struct oppc_int <symbol>;`` per symbol seen, the code for the root
    node, and the closing ``}``.
    """

    def __init__(self, name, root):
        """Translate ``root`` and prepare the ``oppc_<name>`` wrapper.

        :param name: suffix of the generated C function name.
        :param root: AST node to translate.
        """
        self.__root = root
        # Unique sentinel keys for the header/footer buffers; they can never
        # collide with a real AST node.
        self.__header = object()
        self.__footer = object()
        self.__code = collections.defaultdict(pc_util.Code)
        self.__decls = collections.defaultdict(list)
        # Register reads/writes seen during translation (not read within
        # this class).
        self.__regfetch = collections.defaultdict(list)
        self.__regstore = collections.defaultdict(list)
        self.__pseudocode = pc_pseudocode.PseudocodeVisitor(root=root)

        super().__init__(root=root)

        # __decls is filled by the Symbol hook, so the header that declares
        # the symbols is only built once super().__init__ has run.
        header = self.__code[self.__header]
        header.emit(stmt="void")
        header.emit(stmt=f"oppc_{name}(void) {{")
        with header:
            for decl in self.__decls:
                header.emit(stmt=f"struct oppc_int {decl};")
        self.__code[self.__footer].emit(stmt="}")

    def __iter__(self):
        """Yield the ``(level, stmt)`` pairs of the complete C function."""
        yield from self.__code[self.__header]
        yield from self.__code[self.__root]
        yield from self.__code[self.__footer]

    def __getitem__(self, node):
        """Return the code buffer for ``node``, creating it on first use."""
        return self.__code[node]

    def transient(self, node, value="UINT64_C(0)", bits="(uint8_t)OPPC_XLEN"):
        """Create and translate a ``Transient`` on behalf of ``node``.

        The pseudocode comments for ``node`` are emitted into ``self[node]``
        first.  The transient's own C text ends up in ``self[<result>]``.

        :returns: the new ``Transient`` node.
        """
        transient = Transient(value=value, bits=bits)
        with self.pseudocode(node=node):
            self.traverse(root=transient)
        return transient

    def ccall(self, node, name, code, stmt=False):
        """Create and translate a call to the C function ``name``.

        :param node: node whose pseudocode comments precede the call.
        :param name: C function name.
        :param code: iterable of arguments, each an iterable of
            ``(level, stmt)`` pairs.
        :param stmt: if true the call is emitted as a statement (``;``).
        :returns: the new ``CCall`` node; its C text is in ``self[<result>]``.
        :raises ValueError: if a level is not an ``int`` or a line is not a
            ``str``.
        """
        def validate_line(line):
            (level, text) = line
            if not isinstance(level, int):
                raise ValueError(level)
            if not isinstance(text, str):
                raise ValueError(text)
            return (level, text)

        def validate_arg(arg):
            return tuple(map(validate_line, arg))

        code = tuple(map(validate_arg, code))
        ccall = CCall(name=name, code=code, stmt=stmt)
        with self.pseudocode(node=node):
            self.traverse(root=ccall)
        return ccall

    @contextlib.contextmanager
    def pseudocode(self, node):
        """Emit ``node``'s pseudocode as C comments into ``self[node]``."""
        for (level, stmt) in self.__pseudocode[node]:
            self[node].emit(stmt=f"/* {stmt} */", level=level)
        yield

    @pc_util.Hook(pc_ast.Scope)
    def Scope(self, node):
        """Concatenate the code of every statement in the scope, indented."""
        yield node
        with self[node]:
            for subnode in node:
                for (level, stmt) in self[subnode]:
                    self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.AssignExpr, pc_ast.AssignIEAExpr)
    def AssignExpr(self, node):
        """Emit an assignment as an ``oppc_*assign`` statement.

        Subscript and range-subscript targets use the dedicated
        ``oppc_subscript_assign`` / ``oppc_range_subscript_assign`` calls.
        """
        yield node
        if isinstance(node.lvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regstore[str(node.lvalue)].append(node.lvalue)
        if isinstance(node.rvalue, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.rvalue)].append(node.rvalue)

        rvalue = self[node.rvalue]
        if isinstance(node.rvalue, pc_ast.IfExpr):
            # A conditional rvalue becomes a one-line C ternary built from
            # the test and the first element of each branch.
            rvalue = [(0, " ".join([
                str(self[node.rvalue.test]),
                "?",
                str(self[node.rvalue.body[0]]),
                ":",
                str(self[node.rvalue.orelse[0]]),
            ]))]

        if isinstance(node.lvalue, pc_ast.SubscriptExpr):
            ccall = self.ccall(name="oppc_subscript_assign", node=node, stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.index],
                rvalue,
            ])
        elif isinstance(node.lvalue, pc_ast.RangeSubscriptExpr):
            ccall = self.ccall(name="oppc_range_subscript_assign", node=node, stmt=True, code=[
                self[node.lvalue.subject],
                self[node.lvalue.start],
                self[node.lvalue.end],
                rvalue,
            ])
        else:
            ccall = self.ccall(name="oppc_assign", stmt=True, node=node, code=[
                self[node.lvalue],
                rvalue,
            ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.BinaryExpr)
    def BinaryExpr(self, node):
        """Emit a binary operation as a call to its ``oppc_*`` function.

        Comparisons are called with the two operands only; every other
        operator additionally gets a fresh transient as first argument
        (presumably the destination of the result).
        """
        yield node
        if isinstance(node.left, (pc_ast.GPR, pc_ast.FPR)):
            self.__regfetch[str(node.left)].append(node.left)
        if isinstance(node.right, (pc_ast.GPR, pc_ast.FPR)):
            # Record the right operand itself under its own key.
            self.__regfetch[str(node.right)].append(node.right)

        comparison = (
            pc_ast.Lt, pc_ast.Le,
            pc_ast.Eq, pc_ast.NotEq,
            pc_ast.Ge, pc_ast.Gt,
        )
        if isinstance(node.op, comparison):
            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                self[node.left],
                self[node.right],
            ])
        else:
            transient = self.transient(node=node)
            ccall = self.ccall(name=str(self[node.op]), node=node, code=[
                self[transient],
                self[node.left],
                self[node.right],
            ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.UnaryExpr)
    def UnaryExpr(self, node):
        """Emit a unary operation as ``op(value)``."""
        yield node
        ccall = self.ccall(name=str(self[node.op]), node=node, code=[
            self[node.value],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(
        pc_ast.Not, pc_ast.Add, pc_ast.Sub,
        pc_ast.Mul, pc_ast.Div, pc_ast.Mod,
        pc_ast.Lt, pc_ast.Le,
        pc_ast.Eq, pc_ast.NotEq,
        pc_ast.Ge, pc_ast.Gt,
        pc_ast.LShift, pc_ast.RShift,
        pc_ast.BitAnd, pc_ast.BitOr, pc_ast.BitXor,
    )
    def Op(self, node):
        """Emit the name of the C runtime function implementing an operator."""
        yield node
        op = {
            pc_ast.Not: "oppc_not",
            pc_ast.Add: "oppc_add",
            pc_ast.Sub: "oppc_sub",
            pc_ast.Mul: "oppc_mul",
            pc_ast.Div: "oppc_div",
            pc_ast.Mod: "oppc_mod",
            pc_ast.Lt: "oppc_lt",
            pc_ast.Le: "oppc_le",
            pc_ast.Eq: "oppc_eq",
            pc_ast.NotEq: "oppc_noteq",
            pc_ast.Ge: "oppc_ge",
            pc_ast.Gt: "oppc_gt",
            pc_ast.LShift: "oppc_lshift",
            pc_ast.RShift: "oppc_rshift",
            pc_ast.BitAnd: "oppc_and",
            pc_ast.BitOr: "oppc_or",
            pc_ast.BitXor: "oppc_xor",
        }[node.__class__]
        self[node].emit(stmt=op)

    @pc_util.Hook(pc_ast.BinLiteral, pc_ast.DecLiteral, pc_ast.HexLiteral)
    def Integer(self, node):
        """Emit an integer literal as a transient.

        Binary and hex literals are as wide as the digits written (1 or 4
        bits per digit after the two-character prefix, presumably
        ``0b``/``0x``); decimal literals are ``OPPC_XLEN`` bits wide.

        :raises NotImplementedError: if the value does not fit in 64 bits.
        """
        yield node
        fmt = hex
        value = str(node)
        if isinstance(node, pc_ast.BinLiteral):
            bits = f"UINT8_C({str(len(value[2:]))})"
            value = int(value, 2)
        elif isinstance(node, pc_ast.HexLiteral):
            bits = f"UINT8_C({str(len(value[2:]) * 4)})"
            value = int(value, 16)
        else:
            bits = "(uint8_t)OPPC_XLEN"
            value = int(value)
            fmt = str
        if (value > ((2**64) - 1)):
            raise NotImplementedError()
        value = f"UINT64_C({fmt(value)})"
        transient = self.transient(node=node, value=value, bits=bits)
        for (level, stmt) in self[transient]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(Transient)
    def Transient(self, node):
        """Emit the C expression of a transient."""
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(CCall)
    def CCall(self, node):
        """Render a call as ``name(``, one argument after another, ``)``.

        A ``,`` is appended to the last line of every argument but the
        last, unless that line is empty, a comment, or already ends with
        ``,``, ``(``, ``{`` or ``*/``.
        """
        yield node
        end = (";" if node.stmt else "")
        if len(node.code) == 0:
            self[node].emit(stmt=f"{str(node.name)}(){end}")
        else:
            self[node].emit(stmt=f"{str(node.name)}(")
            with self[node]:
                (*head, tail) = node.code
                for code in head:
                    for (level, stmt) in code:
                        self[node].emit(stmt=stmt, level=level)
                    (level, stmt) = self[node][-1]
                    if not (not stmt or
                            stmt.startswith("/*") or
                            stmt.endswith((",", "(", "{", "*/"))):
                        stmt = (stmt + ",")
                    self[node][-1] = (level, stmt)
                for (level, stmt) in tail:
                    self[node].emit(stmt=stmt, level=level)
            self[node].emit(stmt=f"){end}")

    @pc_util.Hook(pc_ast.GPR)
    def GPR(self, node):
        """Emit a pointer to the GPR, preceded by its pseudocode comment."""
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_GPR[OPPC_GPR_{str(node)}]")

    @pc_util.Hook(pc_ast.FPR)
    def FPR(self, node):
        """Emit a pointer to the FPR, preceded by its pseudocode comment."""
        yield node
        with self.pseudocode(node=node):
            self[node].emit(stmt=f"&OPPC_FPR[OPPC_FPR_{str(node)}]")

    @pc_util.Hook(pc_ast.RepeatExpr)
    def RepeatExpr(self, node):
        """Emit ``oppc_repeat(transient, subject, times)``."""
        yield node
        transient = self.transient(node=node)
        ccall = self.ccall(name="oppc_repeat", node=node, code=[
            self[transient],
            self[node.subject],
            self[node.times],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.XLEN)
    def XLEN(self, node):
        """Emit XLEN as a transient whose value and width are OPPC_XLEN."""
        yield node
        (value, bits) = ("OPPC_XLEN", "(uint8_t)OPPC_XLEN")
        transient = self.transient(node=node, value=value, bits=bits)
        for (level, stmt) in self[transient]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.SubscriptExpr)
    def SubscriptExpr(self, node):
        """Emit ``subject[index]`` as ``oppc_subscript(subject, index)``."""
        yield node
        ccall = self.ccall(name="oppc_subscript", node=node, code=[
            self[node.subject],
            self[node.index],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.RangeSubscriptExpr)
    def RangeSubscriptExpr(self, node):
        """Emit ``subject[start:end]`` as ``oppc_range_subscript(...)``.

        Named to match ``oppc_range_subscript_assign`` (used by AssignExpr);
        ``oppc_subscript`` is the two-argument form used by SubscriptExpr.
        """
        yield node
        ccall = self.ccall(name="oppc_range_subscript", node=node, code=[
            self[node.subject],
            self[node.start],
            self[node.end],
        ])
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.ForExpr)
    def ForExpr(self, node):
        """Emit a counting loop as a C ``for`` statement.

        The three clauses are synthesised as AST nodes
        (``subject <- start``, ``subject <= end``,
        ``subject <- subject + 1``) and translated through this visitor.
        """
        yield node

        enter = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=node.start.clone(),
        )
        match = pc_ast.BinaryExpr(
            left=node.subject.clone(),
            op=pc_ast.Le("<="),
            right=node.end.clone(),
        )
        # Advance the loop variable by one per iteration (adding node.end
        # would skip iterations).
        leave = pc_ast.AssignExpr(
            lvalue=node.subject.clone(),
            rvalue=pc_ast.BinaryExpr(
                left=node.subject.clone(),
                op=pc_ast.Add("+"),
                right=pc_ast.DecLiteral("1"),
            ),
        )
        with self.pseudocode(node=node):
            # Keep only the first pseudocode comment line for this node.
            (level, stmt) = self[node][0]
            self[node].clear()
            self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="for (")
        with self[node]:
            with self[node]:
                for subnode in (enter, match, leave):
                    self.__pseudocode.traverse(root=subnode)
                    self.traverse(root=subnode)
                    for (level, stmt) in self[subnode]:
                        self[node].emit(stmt=stmt, level=level)
                    # Fix the clause terminator on the last emitted line:
                    # the condition needs a ';', while the increment (an
                    # assignment statement ending in ';') must drop it.
                    (level, stmt) = self[node][-1]
                    if subnode is match:
                        stmt = f"{stmt};"
                    elif subnode is leave:
                        stmt = stmt[:-1]
                    self[node][-1] = (level, stmt)
        # Re-emit the first line kept above (the pseudocode comment).
        (level, stmt) = self[node][0]
        self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.WhileExpr)
    def WhileExpr(self, node):
        """Emit a C ``while`` loop."""
        yield node
        self[node].emit(stmt="while (")
        with self[node]:
            with self[node]:
                for (level, stmt) in self[node.test]:
                    self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            # NOTE(review): C has no while/else, so this branch produces
            # code that will not compile; confirm whether orelse can be
            # non-empty for a while loop.
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.IfExpr)
    def IfExpr(self, node):
        """Emit a C ``if``/``else`` statement."""
        yield node
        self[node].emit(stmt="if (")
        with self[node]:
            for (level, stmt) in self[node.test]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt=") {")
        for (level, stmt) in self[node.body]:
            self[node].emit(stmt=stmt, level=level)
        if node.orelse:
            self[node].emit(stmt="} else {")
            for (level, stmt) in self[node.orelse]:
                self[node].emit(stmt=stmt, level=level)
        self[node].emit(stmt="}")

    @pc_util.Hook(pc_ast.Call.Name)
    def CallName(self, node):
        """Emit a called function's name verbatim."""
        yield node
        self[node].emit(stmt=str(node))

    @pc_util.Hook(pc_ast.Call.Arguments)
    def CallArguments(self, node):
        """Record register reads among call arguments; emits no code."""
        yield node
        for subnode in node:
            if isinstance(subnode, (pc_ast.GPR, pc_ast.FPR)):
                self.__regfetch[str(subnode)].append(subnode)

    @pc_util.Hook(pc_ast.Call)
    def Call(self, node):
        """Emit a pseudocode function call as a C call of the same name."""
        yield node
        code = tuple(map(lambda arg: self[arg], node.args))
        ccall = self.ccall(name=str(node.name), node=node, code=code)
        for (level, stmt) in self[ccall]:
            self[node].emit(stmt=stmt, level=level)

    @pc_util.Hook(pc_ast.Symbol)
    def Symbol(self, node):
        """Emit ``&name`` and record the symbol for the header declarations."""
        yield node
        self.__decls[str(node)].append(node)
        self[node].emit(stmt=f"&{str(node)}")

    @pc_util.Hook(pc_ast.Node)
    def Node(self, node):
        """Reject node types that have no dedicated hook (presumably the
        most specific hook is chosen, making this the fallback).

        :raises NotImplementedError: always, with the node's type.
        """
        raise NotImplementedError(type(node))
411
412
def code(name, root):
    """Generate the C translation of ``root`` as ``(level, stmt)`` pairs.

    The translation is wrapped in a C function named ``oppc_<name>``.
    Because this is a generator, the visitor is only built (and the AST
    only translated) when the first pair is requested.
    """
    visitor = CodeVisitor(name=name, root=root)
    yield from visitor