power_insn: refactor opcode matching
[openpower-isa.git] / src / openpower / decoder / test / _pyrtl.py
1 import os
2 from contextlib import contextmanager
3
4 from nmigen.hdl.ast import SignalSet
5 from nmigen.hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
6 from nmigen.sim._base import BaseProcess
7 from openpower.decoder.test.crtl_path import get_crtl_path
8
# Names exported by `from ... import *`: only the process class is public.
__all__ = [
    "PyRTLProcess",
]
10
11
class PyRTLProcess(BaseProcess):
    """Simulator process whose behaviour lives in generated C code.

    ``is_comb`` marks combinational processes (as opposed to clocked ones).
    ``name`` is assigned by the fragment compiler. ``crtl`` and ``run`` are
    not assigned in this section of the file; presumably the loader fills
    them in (TODO confirm). Declaring ``run`` as a slot lets an
    instance-level callable take the place of ``BaseProcess.run``.
    """
    __slots__ = ("is_comb", "runnable", "passive", "name", "crtl", "run")

    def __init__(self, *, is_comb):
        self.is_comb = is_comb
        self.reset()

    def reset(self):
        """Put the process back into its initial scheduling state."""
        # Every process starts passive. Only combinational processes start
        # out runnable; clocked ones are not runnable initially.
        self.passive = True
        self.runnable = self.is_comb
22
23
24 class _PythonEmitter:
25 def __init__(self):
26 self._buffer = []
27 self._suffix = 0
28 self._level = 0
29
30 def append(self, code):
31 self._buffer.append(" " * self._level)
32 self._buffer.append(code)
33 self._buffer.append("\n")
34
35 @contextmanager
36 def indent(self):
37 self._level += 1
38 yield
39 self._level -= 1
40
41 @contextmanager
42 def nest(self):
43 self.append(f"{{")
44 self._level += 1
45 #yield self.indent()
46 yield
47 self._level -= 1
48 self.append(f"}}")
49
50 def flush(self, indent=""):
51 code = "".join(self._buffer)
52 self._buffer.clear()
53 return code
54
55 def gen_var(self, prefix):
56 name = f"{prefix}_{self._suffix}"
57 self._suffix += 1
58 return name
59
60 def def_var(self, prefix, value):
61 name = self.gen_var(prefix)
62 self.append(f"uint64_t {name} = {value};")
63 return name
64
65 def assign(self, lhs, rhs):
66 self.append(f"{lhs} = {rhs}")
67
68 def if_(self, cond):
69 self.append(f"if ({cond})")
70
71 def else_if(self, cond):
72 self.append(f"else if ({cond})")
73
74 def else_(self):
75 self.append(f"else")
76
77
78 class _Compiler:
79 def __init__(self, state, emitter):
80 self.state = state
81 self.emitter = emitter
82
83
class _ValueCompiler(ValueVisitor, _Compiler):
    """Shared base for the rvalue and lvalue value translators.

    Value kinds this backend does not support raise ``NotImplementedError``.
    """

    # NOTE(review): these Python callables look like leftovers from nMigen's
    # Python backend. The generated code calls `sign`/`zdiv`/`zmod` by name,
    # presumably C helpers from common.h. Confirm before relying on these.
    helpers = {
        "sign": lambda value, sign: (value | sign) if (value & sign) else value,
        "zdiv": lambda lhs, rhs: lhs // rhs if rhs != 0 else 0,
        "zmod": lambda lhs, rhs: lhs % rhs if rhs != 0 else 0,
    }

    def _not_implemented(self, value):
        raise NotImplementedError  # :nocov:

    on_ClockSignal = _not_implemented
    on_ResetSignal = _not_implemented
    on_AnyConst = _not_implemented
    on_AnySeq = _not_implemented
    on_Sample = _not_implemented
    on_Initial = _not_implemented
108
109
class _RHSValueCompiler(_ValueCompiler):
    """Translate an nMigen value used as an rvalue into a C expression string.

    Some values (``Repl``, ``ArrayProxy``, the ``r^`` reduction) need helper
    statements. Those are emitted through ``self.emitter`` before the
    returned expression is used.
    """

    # Binary operators whose operands are sign-extended first, so arithmetic,
    # shifts and comparisons respect signed shapes.
    _SIGNED_BINARY_OPS = frozenset(
        {"+", "-", "*", "<<", ">>", "==", "!=", "<", "<=", ">", ">="})
    # Bitwise binary operators, which act on the raw bit patterns.
    _BITWISE_BINARY_OPS = frozenset({"&", "|", "^"})

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        assert mode in ("curr", "next")
        # "curr" reads slots[X].curr; "next" reads the next_X local.
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def on_SmtExpr(self, value):
        raise NotImplementedError

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        macro = self.state.get_signal_macro(value)
        if self.mode == "curr":
            return f"slots[{macro}].{self.mode}"
        else:
            return f"next_{macro}"

    def _parity(self, gen_arg, width):
        """Emit C statements XOR-folding the low ``width`` bits of ``gen_arg``.

        Returns an expression that is 1 when an odd number of those bits is set.
        """
        gen_fold = self.emitter.def_var("parity", gen_arg)
        # Fold by halves, starting at the largest power of two below `width`.
        shift = 1
        while shift < width:
            shift <<= 1
        shift >>= 1
        while shift:
            self.emitter.append(f"{gen_fold} ^= {gen_fold} >> {shift};")
            shift >>= 1
        return f"(1 & {gen_fold})"

    def on_Operator(self, value):
        def mask(value):
            value_mask = (1 << len(value)) - 1
            return f"({value_mask} & {self(value)})"

        def sign(value):
            if value.shape().signed:
                return f"sign({mask(value)}, {-1 << (len(value) - 1)})"
            return mask(value)

        operator = value.operator
        if len(value.operands) == 1:
            arg, = value.operands
            if operator == "~":
                return f"(~{self(arg)})"
            if operator == "-":
                return f"(-{sign(arg)})"
            if operator == "b":
                return f"!!({mask(arg)})"
            if operator == "r|":
                return f"(0 != {mask(arg)})"
            if operator == "r&":
                return f"({(1 << len(arg)) - 1} == {mask(arg)})"
            if operator == "r^":
                # Computed in C; the generated code is C, not Python.
                return self._parity(mask(arg), len(arg))
            if operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if operator in self._SIGNED_BINARY_OPS:
                return f"({sign(lhs)} {operator} {sign(rhs)})"
            if operator in self._BITWISE_BINARY_OPS:
                return f"({self(lhs)} {operator} {self(rhs)})"
            if operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
        elif len(value.operands) == 3:
            if operator == "m":
                sel, val1, val0 = value.operands
                return f"(({mask(sel)}) ? ({self(val1)}) : ({self(val0)}))"
        raise NotImplementedError("Operator '{}' not implemented".format(operator)) # :nocov:

    def on_Slice(self, value):
        return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
        return f"({(1 << value.width) - 1} & " \
               f"({self(value.value)} >> {offset}))"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            # The ULL suffix keeps the shift 64-bit even when `part` compiles to
            # a plain `int` constant; shifting an int by >= 32 is undefined in C.
            gen_parts.append(f"(({part_mask}ULL & {self(part)}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return "0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return "0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
        if not value.elems:
            return "0"
        # Declare the result before the if/else chain so it is in scope (and
        # initialised) after it; C requires a declaration.
        gen_value = self.emitter.def_var("rhs_proxy", "0")
        for index, elem in enumerate(value.elems):
            if index == 0:
                self.emitter.if_(f"{index} == {gen_index}")
            else:
                self.emitter.else_if(f"{index} == {gen_index}")
            with self.emitter.nest():
                self.emitter.append(f"{gen_value} = {self(elem)};")
        # Out-of-range indices select the last element.
        self.emitter.else_()
        with self.emitter.nest():
            self.emitter.append(f"{gen_value} = {self(value.elems[-1])};")
        return gen_value

    @classmethod
    def compile(cls, state, value, *, mode):
        """Compile ``value`` into code that assigns it to ``result``."""
        emitter = _PythonEmitter()
        compiler = cls(state, emitter, mode=mode)
        emitter.assign(f"result", f"{compiler(value)}")
        return emitter.flush()
260
261
class _LHSValueCompiler(_ValueCompiler):
    """Translate an nMigen value used as an lvalue into a C code generator.

    Each ``on_*`` method returns a callable ``gen(arg)``. Given a C expression
    ``arg`` holding the value to store, it emits the statements that write
    that value into the ``next_<macro>`` locals of the affected signals.
    """

    # All bits of a uint64_t next_<macro> local.
    _WORD_MASK = (1 << 64) - 1

    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_SmtExpr(self, value):
        raise NotImplementedError

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{value_mask} & {arg}"

            macro = self.state.get_signal_macro(value)
            self.emitter.append(f"next_{macro} = {value_sign};")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            width_mask = (1 << (value.stop - value.start)) - 1
            # Bits of the target outside the slice, written as an unsigned hex
            # literal to avoid large negative decimal constants in C.
            keep_mask = ~(width_mask << value.start) & self._WORD_MASK
            # ULL keeps the shift 64-bit even when `arg` is an `int` constant
            # expression; shifting an int by >= 32 is undefined in C.
            self(value.value)(f"(({self.lrhs(value.value)} & {keep_mask:#x}ULL) | "
                              f"(({width_mask}ULL & {arg}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
            # ULL makes both shifted operands 64-bit; the offset is a runtime
            # value and may be 32 or more.
            self(value.value)(f"(({self.lrhs(value.value)} & ~({width_mask}ULL << {offset})) | "
                              f"(({width_mask}ULL & {arg}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate `arg` once, then scatter its bit ranges into each part.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if not value.elems:
                return
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.if_(f"{index} == {gen_index}")
                else:
                    self.emitter.else_if(f"{index} == {gen_index}")
                with self.emitter.nest():
                    self(elem)(arg)
            # Out-of-range indices write the last element. `else_` must be
            # *called*; otherwise the block below runs unconditionally.
            self.emitter.else_()
            with self.emitter.nest():
                self(value.elems[-1])(arg)
        return gen
345
346
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translate nMigen statements into C statements via the emitter.

    Reads go through a "curr"-mode rvalue compiler (``slots[X].curr``).
    Writes go through the lvalue compiler into ``next_X`` locals.
    """

    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        for stmt in stmts:
            self(stmt)
        if not stmts:
            # Emit an explicit empty statement so the block is never blank.
            self.emitter.append("/* pass */;")

    def on_Assign(self, stmt):
        # Truncate the RHS to its own width (and sign-extend if signed) before
        # handing it to the LHS writer.
        gen_rhs = f"({(1 << len(stmt.rhs)) - 1} & {self.rhs(stmt.rhs)})"
        if stmt.rhs.shape().signed:
            gen_rhs = f"sign({gen_rhs}, {-1 << (len(stmt.rhs) - 1)})"
        return self.lhs(stmt.lhs)(gen_rhs)

    @staticmethod
    def _gen_pattern_check(pattern, gen_test):
        """Return a C condition matching ``gen_test`` against a '0'/'1'/'-' pattern."""
        if "-" in pattern:
            # '-' bits are don't-cares: mask them out before comparing.
            mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
            value = int("".join("0" if b == "-" else b for b in pattern), 2)
            return f"{value} == ({mask} & {gen_test})"
        return f"{int(pattern, 2)} == {gen_test}"

    def on_Switch(self, stmt):
        gen_test = self.emitter.def_var("test",
            f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
        for index, (patterns, stmts) in enumerate(stmt.cases.items()):
            if not patterns:
                # No patterns: the case matches unconditionally.
                gen_checks = ["1 /* True */"]
            else:
                gen_checks = [self._gen_pattern_check(pattern, gen_test)
                              for pattern in patterns]
            gen_cond = " || ".join(gen_checks)
            if index == 0:
                self.emitter.if_(gen_cond)
            else:
                self.emitter.else_if(gen_cond)
            with self.emitter.nest():
                self(stmts)

    def on_Display(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt):
        """Compile ``stmt`` standalone: load next_X locals, run it, store them back."""
        output_macros = \
            [state.get_signal_macro(signal) for signal in stmt._lhs_signals()]
        emitter = _PythonEmitter()
        for macro in output_macros:
            # Statements must be ';'-terminated to be valid C.
            emitter.append(f"uint64_t next_{macro} = slots[{macro}].next;")
        compiler = cls(state, emitter)
        compiler(stmt)
        for macro in output_macros:
            emitter.append(f"set({macro}, next_{macro});")
        return emitter.flush()
412
413
class _FragmentCompiler:
    """Compile a fragment hierarchy into per-domain C source files.

    For each driver domain of each fragment, this creates one ``PyRTLProcess``
    and writes ``<name>.c`` into the directory returned by ``get_crtl_path()``.
    The file defines ``void run_<name>(void)``.
    """

    def __init__(self, state):
        self.state = state

    def __call__(self, fragment, fragment_name):
        """Compile ``fragment`` and, recursively, its subfragments.

        Returns the set of ``PyRTLProcess`` objects created.
        """
        processes = set()

        for index, (domain_name, domain_signals) in enumerate(fragment.drivers.items()):
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = PyRTLProcess(is_comb=domain_name is None)
            domain_process.name = \
                f"{fragment_name}__{domain_name or ''}" \
                f"_{id(fragment)}_{index}"

            emitter = _PythonEmitter()
            emitter.append(f"void run_{domain_process.name}(void)")
            with emitter.nest():
                if domain_name is None:
                    # Combinational domain: driven signals start from their
                    # reset value. Every signal collected in `inputs` while
                    # compiling is registered as a trigger.
                    for signal in domain_signals:
                        macro = self.state.get_signal_macro(signal)
                        emitter.append(
                            f"uint64_t next_{macro} = {signal.reset};")

                    inputs = SignalSet()
                    _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

                    for input_signal in inputs:
                        self.state.add_trigger(domain_process, input_signal)

                else:
                    # Clocked domain: triggered by the clock edge (and by an
                    # asynchronous reset if there is one); driven signals start
                    # from their current `next` value.
                    domain = fragment.domains[domain_name]
                    clk_trigger = 1 if domain.clk_edge == "pos" else 0
                    self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                    if domain.rst is not None and domain.async_reset:
                        rst_trigger = 1
                        self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)

                    for signal in domain_signals:
                        macro = self.state.get_signal_macro(signal)
                        emitter.append(
                            f"uint64_t next_{macro} = slots[{macro}].next;")

                    _StatementCompiler(self.state, emitter)(domain_stmts)

                for signal in domain_signals:
                    macro = self.state.get_signal_macro(signal)
                    emitter.append(f"set({macro}, next_{macro});")

            code = "#include <stdint.h>\n"
            code += "#include \"common.h\"\n"
            code += emitter.flush()

            crtl = get_crtl_path()
            # `with` guarantees the file is closed even if the write fails.
            with open(os.path.join(crtl, f"{domain_process.name}.c"), "w") as file:
                file.write(code)

            processes.add(domain_process)

        for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
            if subfragment_name is None:
                # NOTE(review): '$' ends up in C identifiers and file names; GCC
                # accepts it as an extension. Confirm the toolchain does too.
                subfragment_name = f"U${subfragment_index}"
            processes.update(self(subfragment, subfragment_name))

        return processes