hdl.ast: simplify Mux implementation.
[nmigen.git] / nmigen / sim / _pyrtl.py
1 import os
2 import tempfile
3 from contextlib import contextmanager
4
5 from ..hdl import *
6 from ..hdl.ast import SignalSet
7 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
8 from ._base import BaseProcess
9
10
11 __all__ = ["PyRTLProcess"]
12
13
class PyRTLProcess(BaseProcess):
    """Simulator process whose body is a Python function generated from nMigen statements.

    The ``run`` slot is not assigned by the constructor; it is filled in by the code
    that compiles the process body.
    """
    __slots__ = ("is_comb", "runnable", "passive", "run")

    def __init__(self, *, is_comb):
        # `is_comb` is only set here; the scheduling flags are (re)initialized by reset().
        self.is_comb = is_comb
        self.reset()

    def reset(self):
        """Restore the initial scheduling flags."""
        # Only combinational processes start out runnable.
        self.passive = True
        self.runnable = self.is_comb
25
26
27 class _PythonEmitter:
28 def __init__(self):
29 self._buffer = []
30 self._suffix = 0
31 self._level = 0
32
33 def append(self, code):
34 self._buffer.append(" " * self._level)
35 self._buffer.append(code)
36 self._buffer.append("\n")
37
38 @contextmanager
39 def indent(self):
40 self._level += 1
41 yield
42 self._level -= 1
43
44 def flush(self, indent=""):
45 code = "".join(self._buffer)
46 self._buffer.clear()
47 return code
48
49 def gen_var(self, prefix):
50 name = f"{prefix}_{self._suffix}"
51 self._suffix += 1
52 return name
53
54 def def_var(self, prefix, value):
55 name = self.gen_var(prefix)
56 self.append(f"{name} = {value}")
57 return name
58
59
60 class _Compiler:
61 def __init__(self, state, emitter):
62 self.state = state
63 self.emitter = emitter
64
65
class _ValueCompiler(ValueVisitor, _Compiler):
    """Shared base of the rvalue and lvalue translators.

    ``helpers`` is injected into the namespace of the generated code, which calls
    these functions by name.
    """
    helpers = {
        # Sign-extend `value`: `sign` is the negative mask `-1 << (width - 1)`, so if the
        # sign bit is set, OR in every bit from the sign bit upward.
        "sign": lambda value, sign: (value | sign) if (value & sign) else value,
        # Division and modulo that produce 0 for a zero divisor instead of raising.
        "zdiv": lambda lhs, rhs: lhs // rhs if rhs != 0 else 0,
        "zmod": lambda lhs, rhs: lhs % rhs if rhs != 0 else 0,
    }

    # The value kinds below are not translated by this backend; presumably they are
    # lowered or rejected before reaching here (hence excluded from coverage).

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_AnyConst(self, value):
        raise NotImplementedError # :nocov:

    def on_AnySeq(self, value):
        raise NotImplementedError # :nocov:

    def on_Sample(self, value):
        raise NotImplementedError # :nocov:

    def on_Initial(self, value):
        raise NotImplementedError # :nocov:
90
91
class _RHSValueCompiler(_ValueCompiler):
    """Translates an nMigen value used as an rvalue into a Python expression string.

    Signals are read either from the current slot values (``mode="curr"``) or from the
    ``next_<index>`` locals that generated statement code maintains (``mode="next"``).

    The expression produced for a value ``v`` only guarantees the correct low
    ``len(v)`` bits; higher bits are unspecified (e.g. ``~x`` is a negative Python int
    even for unsigned ``x``, and the ``u``/``s`` casts do not re-extend). Any consumer
    that depends on higher bits must go through ``_mask`` or ``_sign`` first.
    """

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def _mask(self, value):
        """Return an expression for the low ``len(value)`` bits of ``value``, zero-extended."""
        return f"({(1 << len(value)) - 1} & {self(value)})"

    def _sign(self, value):
        """Return an expression for ``value`` extended according to its own signedness.

        Unsigned values are zero-extended and signed values are sign-extended, so the
        result is the exact integer the value denotes.
        """
        if value.shape().signed:
            return f"sign({self._mask(value)}, {-1 << (len(value) - 1)})"
        else: # unsigned
            return self._mask(value)

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def on_Operator(self, value):
        mask = self._mask
        sign = self._sign

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                return f"(~{self(arg)})"
            if value.operator == "-":
                return f"(-{sign(arg)})"
            if value.operator == "b":
                return f"bool({mask(arg)})"
            if value.operator == "r|":
                return f"(0 != {mask(arg)})"
            if value.operator == "r&":
                return f"({(1 << len(arg)) - 1} == {mask(arg)})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            # Bitwise operators: the operands may differ in width and signedness and their
            # raw expressions only have well-defined low bits, so extend each operand by
            # its own signedness before combining (e.g. `~x` of a 4-bit `x` must act as a
            # zero-extended 0..15 value, not as a negative Python int).
            if value.operator == "&":
                return f"({sign(lhs)} & {sign(rhs)})"
            if value.operator == "|":
                return f"({sign(lhs)} | {sign(rhs)})"
            if value.operator == "^":
                return f"({sign(lhs)} ^ {sign(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                # As for the bitwise operators, both arms are extended by their own
                # signedness so that arms of differing shapes produce the right value.
                return f"({sign(val1)} if {mask(sel)} else {sign(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
        return f"({(1 << value.width) - 1} & " \
                f"{self(value.value)} >> {offset})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({part_mask} & {self(part)}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        # Evaluate the replicated value once and reuse it for every copy.
        gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {index} == {gen_index}:")
                else:
                    self.emitter.append(f"elif {index} == {gen_index}:")
                with self.emitter.indent():
                    # Elements may differ in width and signedness; extend each one.
                    self.emitter.append(f"{gen_value} = {self._sign(elem)}")
            # Out-of-range indexes select the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self._sign(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, state, value, *, mode):
        """Return Python source that assigns the value of ``value`` to ``result``."""
        emitter = _PythonEmitter()
        compiler = cls(state, emitter, mode=mode)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
238
239
class _LHSValueCompiler(_ValueCompiler):
    """Translates an nMigen value used as an lvalue into a code-emitting callback.

    Visiting a value returns ``gen(arg)``; calling it with a Python expression string
    emits statements that store ``arg`` into the ``next_<index>`` locals of the
    signals covered by the lvalue.
    """

    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` translates rvalues that appear syntactically inside an lvalue,
        # e.g. the offset of a Part.
        self.rrhs = rhs
        # `lrhs` translates the read half of the read-modify-write cycle needed when
        # only part of an lvalue is updated; it reads the pending `next_*` locals.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # When not None, `outputs` collects every signal that appears on the LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            width = len(value)
            stored = f"{(1 << width) - 1} & {arg}"
            if value.shape().signed:
                stored = f"sign({stored}, {-1 << (width - 1)})"
            self.emitter.append(f"next_{self.state.get_signal(value)} = {stored}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            width_mask = (1 << (value.stop - value.start)) - 1
            keep_mask = ~(width_mask << value.start)
            store = self(value.value)
            # Keep the bits outside the slice, replace the bits inside it.
            store(f"({self.lrhs(value.value)} & {keep_mask} | "
                  f"(({width_mask} & {arg}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"({value.stride} * ({offset_mask} & {self.rrhs(value.offset)}))"
            store = self(value.value)
            # Same read-modify-write as for Slice, but with a run-time offset.
            store(f"({self.lrhs(value.value)} & ~({width_mask} << {offset}) | "
                  f"(({width_mask} & {arg}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the assigned value once, then hand each part its bit field.
            whole = self.emitter.def_var("cat", arg)
            shift = 0
            for part in value.parts:
                self(part)(f"({(1 << len(part)) - 1} & ({whole} >> {shift}))")
                shift += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if not value.elems:
                self.emitter.append(f"pass")
                return
            for position, elem in enumerate(value.elems):
                keyword = "if" if position == 0 else "elif"
                self.emitter.append(f"{keyword} {position} == {gen_index}:")
                with self.emitter.indent():
                    self(elem)(arg)
            # Out-of-range indexes write to the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self(value.elems[-1])(arg)
        return gen
320
321
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translates nMigen statements into Python statements.

    Assignments write to ``next_<index>`` locals; the surrounding code must define
    those locals beforehand and store them back into the slots afterwards, as
    ``compile`` does.
    """

    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        # An empty statement list still needs a body for the enclosing Python block.
        if not stmts:
            self.emitter.append("pass")
            return
        for stmt in stmts:
            self(stmt)

    def on_Assign(self, stmt):
        width = len(stmt.rhs)
        gen_rhs = f"({(1 << width) - 1} & {self.rhs(stmt.rhs)})"
        if stmt.rhs.shape().signed:
            gen_rhs = f"sign({gen_rhs}, {-1 << (width - 1)})"
        return self.lhs(stmt.lhs)(gen_rhs)

    def on_Switch(self, stmt):
        test_mask = (1 << len(stmt.test)) - 1
        gen_test = self.emitter.def_var("test", f"{test_mask} & {self.rhs(stmt.test)}")

        def check(pattern):
            # Patterns are binary strings in which "-" marks a don't-care bit.
            if "-" not in pattern:
                return f"{int(pattern, 2)} == {gen_test}"
            care = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
            want = int("".join("0" if bit == "-" else bit for bit in pattern), 2)
            return f"{want} == ({care} & {gen_test})"

        for position, (patterns, stmts) in enumerate(stmt.cases.items()):
            # A case without patterns always matches.
            conditions = [check(pattern) for pattern in patterns] if patterns else ["True"]
            keyword = "elif" if position > 0 else "if"
            self.emitter.append(f"{keyword} {' or '.join(conditions)}:")
            with self.emitter.indent():
                self(stmts)

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt):
        """Return Python source that executes ``stmt`` against ``slots``.

        Each assigned signal's pending value is loaded into a ``next_<index>`` local,
        the statement runs, and every local is then stored back with ``.set()``.
        """
        indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _PythonEmitter()
        for i in indexes:
            emitter.append(f"next_{i} = slots[{i}].next")
        cls(state, emitter)(stmt)
        for i in indexes:
            emitter.append(f"slots[{i}].set(next_{i})")
        return emitter.flush()
383
384
class _FragmentCompiler:
    """Compiles a fragment hierarchy into ``PyRTLProcess`` objects.

    One process is created for each driven domain of each fragment; the
    combinational domain is the one whose name is ``None``.
    """

    def __init__(self, state):
        self.state = state

    def __call__(self, fragment):
        """Return the set of processes for ``fragment`` and, recursively, its subfragments."""
        processes = set()

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = PyRTLProcess(is_comb=domain_name is None)

            emitter = _PythonEmitter()
            emitter.append(f"def run():")
            with emitter.indent():
                if domain_name is None:
                    # Combinational: driven signals start from their reset value, and a
                    # trigger is registered on every signal the statements read.
                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"next_{signal_index} = {signal.reset}")

                    inputs = SignalSet()
                    _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

                    for input_signal in inputs:
                        self.state.add_trigger(domain_process, input_signal)

                else:
                    # Synchronous: trigger on the clock (1 for a positive edge, 0 otherwise)
                    # and, for an asynchronous reset, on the reset signal; driven signals
                    # start from their pending `.next` value.
                    domain = fragment.domains[domain_name]
                    clk_trigger = 1 if domain.clk_edge == "pos" else 0
                    self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                    if domain.rst is not None and domain.async_reset:
                        rst_trigger = 1
                        self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)

                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                    _StatementCompiler(self.state, emitter)(domain_stmts)

                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            domain_process.run = self._load_run(emitter.flush())
            processes.add(domain_process)

        for subfragment, _subfragment_name in fragment.subfragments:
            processes.update(self(subfragment))

        return processes

    def _load_run(self, code):
        """Compile generated source ``code`` and return the ``run`` function it defines."""
        # There shouldn't be any exceptions raised by the generated code, but if there are
        # (almost certainly due to a bug in the code generator), use this environment variable
        # to make backtraces useful.
        if os.getenv("NMIGEN_pysim_dump"):
            # Close (and thereby flush) the file so the source is actually on disk when a
            # backtrace is rendered; `delete=False` keeps it after closing.
            with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_", delete=False,
                                             encoding="utf-8") as file:
                file.write(code)
                filename = file.name
        else:
            filename = "<string>"

        exec_locals = {"slots": self.state.slots, **_ValueCompiler.helpers}
        exec(compile(code, filename, "exec"), exec_locals)
        return exec_locals["run"]