47bf570eea017b6cc0858aefea1aedb95bfa9a49
[openpower-isa.git] / src / openpower / decoder / test / _pyrtl.py
1 import os
2 from contextlib import contextmanager
3
4 from nmigen.hdl.ast import SignalSet
5 from nmigen.hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
6 from nmigen.sim._base import BaseProcess
7
8 __all__ = ["PyRTLProcess"]
9
10
class PyRTLProcess(BaseProcess):
    """Simulator process created by the fragment compiler in this module.

    ``is_comb`` is fixed at construction time and decides the value that
    ``runnable`` takes whenever the process is reset.
    """

    __slots__ = ("is_comb", "runnable", "passive", "name", "filename", "crtl", "run")

    def __init__(self, *, is_comb):
        self.is_comb = is_comb
        self.reset()

    def reset(self):
        """Put the scheduling flags back to their initial values."""
        # The two flags do not depend on each other.
        self.passive = True
        self.runnable = self.is_comb
21
22
23 class _PythonEmitter:
24 def __init__(self):
25 self._buffer = []
26 self._suffix = 0
27 self._level = 0
28
29 def append(self, code):
30 self._buffer.append(" " * self._level)
31 self._buffer.append(code)
32 self._buffer.append("\n")
33
34 @contextmanager
35 def indent(self):
36 self._level += 1
37 yield
38 self._level -= 1
39
40 @contextmanager
41 def nest(self):
42 self.append(f"{{")
43 self._level += 1
44 #yield self.indent()
45 yield
46 self._level -= 1
47 self.append(f"}}")
48
49 def flush(self, indent=""):
50 code = "".join(self._buffer)
51 self._buffer.clear()
52 return code
53
54 def gen_var(self, prefix):
55 name = f"{prefix}_{self._suffix}"
56 self._suffix += 1
57 return name
58
59 def def_var(self, prefix, value):
60 name = self.gen_var(prefix)
61 self.append(f"uint64_t {name} = {value};")
62 return name
63
64 def assign(self, lhs, rhs):
65 self.append(f"{lhs} = {rhs}")
66
67 def if_(self, cond):
68 self.append(f"if ({cond})")
69
70 def else_if(self, cond):
71 self.append(f"else if ({cond})")
72
73 def else_(self):
74 self.append(f"else")
75
76
77 class _Compiler:
78 def __init__(self, state, emitter):
79 self.state = state
80 self.emitter = emitter
81
82
class _ValueCompiler(ValueVisitor, _Compiler):
    """Common base of the value compilers.

    Provides the helper table and rejects the value kinds that neither
    compiler in this file handles.
    """

    # Python versions of the helpers whose names (`sign`, `zdiv`, `zmod`)
    # appear in the generated expressions.
    helpers = {
        "sign": lambda value, sign: (value | sign) if (value & sign) else value,
        "zdiv": lambda lhs, rhs: lhs // rhs if rhs != 0 else 0,
        "zmod": lambda lhs, rhs: lhs % rhs if rhs != 0 else 0,
    }

    def on_ClockSignal(self, value):
        """Not supported by this compiler."""
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        """Not supported by this compiler."""
        raise NotImplementedError # :nocov:

    def on_AnyConst(self, value):
        """Not supported by this compiler."""
        raise NotImplementedError # :nocov:

    def on_AnySeq(self, value):
        """Not supported by this compiler."""
        raise NotImplementedError # :nocov:

    def on_Sample(self, value):
        """Not supported by this compiler."""
        raise NotImplementedError # :nocov:

    def on_Initial(self, value):
        """Not supported by this compiler."""
        raise NotImplementedError # :nocov:
107
108
class _RHSValueCompiler(_ValueCompiler):
    """Translate values that are being read into C expression strings.

    Every ``on_*`` method returns the text of an expression.  Some of them
    also append statements to the emitter (temporaries, ``if`` chains) that
    must precede the statement using the returned expression.

    ``mode`` selects how a signal is read: ``"curr"`` reads
    ``slots[i].curr``, ``"next"`` reads the local variable ``next_i``.
    """

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def on_Operator(self, value):
        def mask(value):
            # Truncate the operand to its declared width.
            value_mask = (1 << len(value)) - 1
            return f"({value_mask} & {self(value)})"

        def sign(value):
            # Like mask(), but additionally sign-extends signed operands.
            if value.shape().signed:
                return f"sign({mask(value)}, {-1 << (len(value) - 1)})"
            else: # unsigned
                return mask(value)

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                return f"(~{self(arg)})"
            if value.operator == "-":
                return f"(-{sign(arg)})"
            if value.operator == "b":
                # C spelling of "is any bit set"; Python's bool() is not
                # available in the generated code.
                return f"(0 != {mask(arg)})"
            if value.operator == "r|":
                return f"(0 != {mask(arg)})"
            if value.operator == "r&":
                return f"({(1 << len(arg)) - 1} == {mask(arg)})"
            if value.operator == "r^":
                # Sideways XOR: fold the masked operand onto itself until its
                # parity ends up in bit 0.  The shifts cover the 64 bits of
                # the uint64_t temporary.
                gen_xor = self.emitter.def_var("xor", mask(arg))
                for shift in (32, 16, 8, 4, 2, 1):
                    self.emitter.append(f"{gen_xor} ^= {gen_xor} >> {shift};")
                return f"(1 & {gen_xor})"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            if value.operator == "&":
                return f"({self(lhs)} & {self(rhs)})"
            if value.operator == "|":
                return f"({self(lhs)} | {self(rhs)})"
            if value.operator == "^":
                return f"({self(lhs)} ^ {self(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                return f"(({mask(sel)}) ? ({self(val1)}) : ({self(val0)}))"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
        # `>>` binds tighter than `&`, so the shift is applied before masking.
        return f"({(1 << value.width) - 1} & " \
               f"{self(value.value)} >> {offset})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({part_mask} & {self(part)}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return "0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        # Evaluate the replicated value once, into a temporary.
        gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return "0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            # Declare the result before the chain: it is assigned inside the
            # braces of each branch but read after them.
            self.emitter.append(f"uint64_t {gen_value};")
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.if_(f"{index} == {gen_index}")
                else:
                    self.emitter.else_if(f"{index} == {gen_index}")
                with self.emitter.nest():
                    # Compile the element first so that any temporaries it
                    # needs are emitted inside this branch.
                    gen_elem = self(elem)
                    self.emitter.append(f"{gen_value} = {gen_elem};")
            # Out-of-range indexes select the last element.
            self.emitter.else_()
            with self.emitter.nest():
                gen_elem = self(value.elems[-1])
                self.emitter.append(f"{gen_value} = {gen_elem};")
            return gen_value
        else:
            return "0"

    @classmethod
    def compile(cls, state, value, *, mode):
        """Return code that evaluates ``value`` and assigns it to ``result``."""
        emitter = _PythonEmitter()
        compiler = cls(state, emitter, mode=mode)
        emitter.assign("result", f"{compiler(value)}")
        return emitter.flush()
255
256
class _LHSValueCompiler(_ValueCompiler):
    """Translate values that are being assigned to.

    Every supported ``on_*`` method returns a function ``gen(arg)``.
    Calling it with the text of an expression emits the statements that
    store that expression into the value.
    """

    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            value_mask = (1 << len(value)) - 1
            name = ''
            # TODO: useful trick, actually put the name into the c code
            # but this has to be done consistently right across the board.
            # all occurrences of next_{....} have to use the same trick
            # but at least then the names in the auto-generated c-code
            # are readable...
            #if hasattr(value, "name") and value.name is not None:
            #    name = value.name
            if value.shape().signed:
                value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{value_mask} & {arg}"
            self.emitter.append(f"next_{name}{self.state.get_signal(value)} = {value_sign};")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Read-modify-write: clear the sliced bits, then OR in the new ones.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"{~(width_mask << value.start)} | " \
                f"(({width_mask} & {arg}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            # Same as on_Slice, but the bit offset is computed at run time.
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"~({width_mask} << {offset}) | " \
                f"(({width_mask} & {arg}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the source once, then hand each part its own bit range.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.if_(f"{index} == {gen_index}")
                    else:
                        # Every branch after the first must be an `else if`;
                        # a bare comparison is not a branch header.
                        self.emitter.else_if(f"{index} == {gen_index}")
                    with self.emitter.nest():
                        self(elem)(arg)
                # Out-of-range indexes write to the last element.  The method
                # has to be called for the `else` to be emitted at all.
                self.emitter.else_()
                with self.emitter.nest():
                    self(value.elems[-1])(arg)
        return gen
343
344
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translate statements into C statements appended to the emitter.

    Right-hand sides are read in ``"curr"`` mode.  Left-hand sides are
    written to the ``next_<index>`` local variables.
    """

    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        for stmt in stmts:
            self(stmt)
        if not stmts:
            # Empty statement; the C counterpart of Python's `pass`.
            self.emitter.append(";")

    def on_Assign(self, stmt):
        # Truncate the source to its own width (sign-extending if signed)
        # before handing it to the left-hand side.
        gen_rhs = f"({(1 << len(stmt.rhs)) - 1} & {self.rhs(stmt.rhs)})"
        if stmt.rhs.shape().signed:
            gen_rhs = f"sign({gen_rhs}, {-1 << (len(stmt.rhs) - 1)})"
        return self.lhs(stmt.lhs)(gen_rhs)

    def on_Switch(self, stmt):
        gen_test = self.emitter.def_var("test",
            f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")
        for index, (patterns, stmts) in enumerate(stmt.cases.items()):
            gen_checks = []
            if not patterns:
                # A case without patterns always matches.
                gen_checks.append("1")
            else:
                for pattern in patterns:
                    if "-" in pattern:
                        # `-` marks a don't-care bit: compare only the other bits.
                        mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
                        value = int("".join("0" if b == "-" else b for b in pattern), 2)
                        gen_checks.append(f"{value} == ({mask} & {gen_test})")
                    else:
                        value = int(pattern, 2)
                        gen_checks.append(f"{value} == {gen_test}")
            # Alternatives are combined with the C logical-or operator.
            gen_cond = " || ".join(gen_checks)
            if index == 0:
                self.emitter.if_(gen_cond)
            else:
                self.emitter.else_if(gen_cond)
            with self.emitter.nest():
                self(stmts)

    def on_Display(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt):
        """Return code that executes ``stmt`` and writes back the signals it drives.

        The declarations and write-backs use the same statement forms as the
        fragment compiler in this file emits for a clocked domain.
        """
        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _PythonEmitter()
        for signal_index in output_indexes:
            emitter.append(f"uint64_t next_{signal_index} = slots[{signal_index}].next;")
        compiler = cls(state, emitter)
        compiler(stmt)
        for signal_index in output_indexes:
            emitter.append(f"set({signal_index}, next_{signal_index});")
        return emitter.flush()
409
410
class _FragmentCompiler:
    """Compile a fragment hierarchy into one C source file per driven domain."""

    def __init__(self, state):
        self.state = state

    def __call__(self, fragment):
        """Compile ``fragment`` and all of its subfragments.

        For every entry of ``fragment.drivers`` a ``PyRTLProcess`` is created,
        its triggers are registered with the simulator state, and a C function
        ``run_<name>`` is written to ``crtl/<name>.c``.

        Returns the set of processes created for the fragment and,
        recursively, for its subfragments.
        """
        processes = set()

        for index, (domain_name, domain_signals) in enumerate(fragment.drivers.items()):
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = PyRTLProcess(is_comb=domain_name is None)

            # id() keeps the names of processes from different fragments apart.
            domain_process.name = f"{id(fragment)}_{domain_name or ''}_{index}"

            emitter = _PythonEmitter()
            emitter.append(f"void run_{domain_process.name}(void)")
            with emitter.nest():
                if domain_name is None:
                    # Combinatorial logic: start every driven signal from its
                    # reset value and re-run whenever any signal read changes.
                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"uint64_t next_{signal_index} = {signal.reset};")

                    inputs = SignalSet()
                    _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

                    for input_signal in inputs:
                        self.state.add_trigger(domain_process, input_signal)

                else:
                    # Clocked logic: triggered by the active clock edge and,
                    # if the domain has one, by an asynchronous reset.
                    domain = fragment.domains[domain_name]
                    clk_trigger = 1 if domain.clk_edge == "pos" else 0
                    self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                    if domain.rst is not None and domain.async_reset:
                        rst_trigger = 1
                        self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)

                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"uint64_t next_{signal_index} = slots[{signal_index}].next;")

                    _StatementCompiler(self.state, emitter)(domain_stmts)

                # Write the computed values back, for both kinds of domain.
                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"set({signal_index}, next_{signal_index});")

            code = "#include <stdint.h>\n"
            code += "#include \"common.h\"\n"
            code += emitter.flush()

            # The context manager closes the file even if the write fails.
            with open(f"crtl/{domain_process.name}.c", "w") as file:
                file.write(code)

            processes.add(domain_process)

        for subfragment, _subfragment_name in fragment.subfragments:
            processes.update(self(subfragment))

        return processes