vendor.lattice_{ecp5,machxo_2_3l}: remove -forceAll from Diamond scripts.
[nmigen.git] / nmigen / sim / _pyrtl.py
1 import os
2 import tempfile
3 from contextlib import contextmanager
4
5 from ..hdl import *
6 from ..hdl.ast import SignalSet
7 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
8 from ._base import BaseProcess
9
10
11 __all__ = ["PyRTLProcess"]
12
13
class PyRTLProcess(BaseProcess):
    """Simulator process whose behavior is a generated Python function.

    The ``run`` slot is not filled in here; it is assigned the compiled ``run`` function
    after the process is created (see ``_FragmentCompiler``).
    """
    __slots__ = ("is_comb", "runnable", "passive", "run")

    def __init__(self, *, is_comb):
        # `is_comb` decides whether the process starts out runnable (see `reset`).
        self.is_comb = is_comb
        self.reset()

    def reset(self):
        """Return the process to its initial scheduling state."""
        self.passive = True
        # Combinational processes start runnable; clocked ones start idle.
        self.runnable = self.is_comb
25
26
27 class _PythonEmitter:
28 def __init__(self):
29 self._buffer = []
30 self._suffix = 0
31 self._level = 0
32
33 def append(self, code):
34 self._buffer.append(" " * self._level)
35 self._buffer.append(code)
36 self._buffer.append("\n")
37
38 @contextmanager
39 def indent(self):
40 self._level += 1
41 yield
42 self._level -= 1
43
44 def flush(self, indent=""):
45 code = "".join(self._buffer)
46 self._buffer.clear()
47 return code
48
49 def gen_var(self, prefix):
50 name = f"{prefix}_{self._suffix}"
51 self._suffix += 1
52 return name
53
54 def def_var(self, prefix, value):
55 name = self.gen_var(prefix)
56 self.append(f"{name} = {value}")
57 return name
58
59
60 class _Compiler:
61 def __init__(self, state, emitter):
62 self.state = state
63 self.emitter = emitter
64
65
class _ValueCompiler(ValueVisitor, _Compiler):
    """Shared base of the RHS and LHS value compilers.

    ``helpers`` holds functions that the generated code calls by name; they are placed in the
    namespace the generated code is executed in.
    """
    helpers = {
        # Sign-extend `value` (already masked to its width) when its sign bit is set;
        # `sign` is the negative constant -1 << (width - 1).
        "sign": lambda value, sign: value if not value & sign else value | sign,
        # Division and remainder that produce 0 for a zero divisor instead of raising.
        "zdiv": lambda lhs, rhs: lhs // rhs if rhs != 0 else 0,
        "zmod": lambda lhs, rhs: lhs % rhs if rhs != 0 else 0,
    }

    def _unsupported(self, value):
        # These value kinds are not handled by this backend.
        raise NotImplementedError # :nocov:

    on_ClockSignal = _unsupported
    on_ResetSignal = _unsupported
    on_AnyConst = _unsupported
    on_AnySeq = _unsupported
    on_Sample = _unsupported
    on_Initial = _unsupported
90
91
class _RHSValueCompiler(_ValueCompiler):
    """Compile a value used as an rvalue into a Python expression string.

    The expression returned for a value has correct low ``len(value)`` bits, but its higher
    bits may be arbitrary (e.g. ``~a`` or ``a - b`` can be negative). Therefore every place that
    combines operands at a wider width first normalizes them with :meth:`_mask` or
    :meth:`_sign`.
    """

    # Binary operators emitted verbatim between their sign/zero-extended operands.
    _INFIX_OPERATORS = frozenset((
        "+", "-", "*",
        "&", "|", "^",
        "<<", ">>",
        "==", "!=", "<", "<=", ">", ">=",
    ))

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        # "curr" reads `slots[i].curr`; "next" reads the local `next_<i>` being computed.
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def _mask(self, value):
        """Return code for `value` truncated to its own width (i.e. zero-extended)."""
        return f"({(1 << len(value)) - 1} & {self(value)})"

    def _sign(self, value):
        """Return code for `value` extended to unbounded width according to its signedness."""
        if value.shape().signed:
            return f"sign({self._mask(value)}, {-1 << (len(value) - 1)})"
        else: # unsigned
            return self._mask(value)

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def on_Operator(self, value):
        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                # Only the low bits matter; consumers normalize the result.
                return f"(~{self(arg)})"
            if value.operator == "-":
                return f"(-{self._sign(arg)})"
            if value.operator == "b":
                return f"bool({self._mask(arg)})"
            if value.operator == "r|":
                return f"(0 != {self._mask(arg)})"
            if value.operator == "r&":
                return f"({(1 << len(arg)) - 1} == {self._mask(arg)})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({self._mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "//":
                return f"zdiv({self._sign(lhs)}, {self._sign(rhs)})"
            if value.operator == "%":
                return f"zmod({self._sign(lhs)}, {self._sign(rhs)})"
            if value.operator in self._INFIX_OPERATORS:
                # Bitwise operators need extended operands too: an operand narrower than the
                # result must be zero- or sign-extended rather than carry arbitrary high bits.
                return f"({self._sign(lhs)} {value.operator} {self._sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                # Extend both arms so a narrower arm cannot leak arbitrary high bits.
                return f"({self._sign(val1)} if {self._mask(sel)} else {self._sign(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        # The selected bits lie within the source's width, so its high bits never matter here.
        return f"({(1 << len(value)) - 1} & ({self(value.value)} >> {value.start}))"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"({value.stride} * ({offset_mask} & {self(value.offset)}))"
        # The selection may extend past the end of the source, so extend the source first:
        # bits beyond it then read as zero (unsigned) or as the sign bit (signed).
        return f"({(1 << value.width) - 1} & " \
               f"({self._sign(value.value)} >> {offset}))"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            gen_parts.append(f"({self._mask(part)} << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        # Evaluate the replicated value once and reuse the temporary for every copy.
        gen_part = self.emitter.def_var("repl", f"{part_mask} & {self(value.value)}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{index_mask} & {self(value.index)}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {index} == {gen_index}:")
                else:
                    self.emitter.append(f"elif {index} == {gen_index}:")
                with self.emitter.indent():
                    # Elements may be narrower than the proxy; extend each one so no
                    # arbitrary high bits leak into the wider result.
                    self.emitter.append(f"{gen_value} = {self._sign(elem)}")
            # An out-of-range index selects the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self._sign(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, state, value, *, mode):
        """Return standalone source that assigns the value of `value` to ``result``."""
        emitter = _PythonEmitter()
        compiler = cls(state, emitter, mode=mode)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
240 return emitter.flush()
241
242
class _LHSValueCompiler(_ValueCompiler):
    """Compile a value used as an lvalue.

    Visiting a value returns a ``gen(arg)`` callable; calling it emits statements that assign
    the Python expression `arg` to that value by updating the ``next_<index>`` locals.
    """
    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        # NOTE(review): `lrhs` reads in "next" mode throughout, so an ArrayProxy nested under a
        # Slice/Part would have its index read from `next_*` locals as well — confirm intended.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Store the value normalized to the signal's shape: masked, and sign-extended if
            # the signal is signed.
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({value_mask} & {arg}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{value_mask} & {arg}"
            self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Read-modify-write: clear the sliced bits of the target, then OR in the new bits.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                              f"{~(width_mask << value.start)} | " \
                              f"(({width_mask} & {arg}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            # The offset is an rvalue; evaluate it once since it is used twice below.
            gen_offset = self.emitter.def_var("offset",
                f"{value.stride} * ({offset_mask} & {self.rrhs(value.offset)})")
            self(value.value)(f"({self.lrhs(value.value)} & " \
                              f"~({width_mask} << {gen_offset}) | " \
                              f"(({width_mask} & {arg}) << {gen_offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Snapshot the assigned value so that assigning earlier parts cannot change what
            # later parts read (`arg` may refer to the `next_*` locals being updated).
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"({part_mask} & ({gen_arg} >> {offset}))")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {index} == {gen_index}:")
                    else:
                        self.emitter.append(f"elif {index} == {gen_index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # An out-of-range index assigns to the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen
325
326
class _StatementCompiler(StatementVisitor, _Compiler):
    """Compile statements into Python statements appended to the emitter.

    Assignments update ``next_<index>`` locals. Callers must define those locals beforehand
    and store them back into ``slots`` afterwards; :meth:`compile` shows this protocol.
    """
    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        # RHS values read the current signal state; signals read are recorded in `inputs`.
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        if not stmts:
            # An empty suite is not valid Python, so emit an explicit no-op.
            self.emitter.append("pass")
        for stmt in stmts:
            self(stmt)

    def on_Assign(self, stmt):
        rhs_width = len(stmt.rhs)
        gen_rhs = f"({(1 << rhs_width) - 1} & {self.rhs(stmt.rhs)})"
        if stmt.rhs.shape().signed:
            gen_rhs = f"sign({gen_rhs}, {-1 << (rhs_width - 1)})"
        return self.lhs(stmt.lhs)(gen_rhs)

    def on_Switch(self, stmt):
        gen_test = self.emitter.def_var("test",
            f"{(1 << len(stmt.test)) - 1} & {self.rhs(stmt.test)}")

        def gen_check(pattern):
            # "-" marks a don't-care bit: it is cleared in both the mask and the expected value.
            if "-" not in pattern:
                return f"{int(pattern, 2)} == {gen_test}"
            mask = int("".join("1" if bit != "-" else "0" for bit in pattern), 2)
            value = int("".join(bit if bit != "-" else "0" for bit in pattern), 2)
            return f"{value} == ({mask} & {gen_test})"

        for index, (patterns, stmts) in enumerate(stmt.cases.items()):
            # A case without patterns is the default case and always matches.
            gen_checks = [gen_check(pattern) for pattern in patterns] if patterns else ["True"]
            condition = " or ".join(gen_checks)
            if index == 0:
                self.emitter.append(f"if {condition}:")
            else:
                self.emitter.append(f"elif {condition}:")
            with self.emitter.indent():
                self(stmts)

    def _unsupported(self, stmt):
        # Formal verification statements are not handled by this backend.
        raise NotImplementedError # :nocov:

    on_Assert = _unsupported
    on_Assume = _unsupported
    on_Cover = _unsupported

    @classmethod
    def compile(cls, state, stmt):
        """Return standalone source that executes `stmt` against ``slots``.

        The code loads the ``next`` value of every signal `stmt` assigns, runs the statement,
        then stores each result back with ``slots[i].set(...)``.
        """
        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _PythonEmitter()
        for signal_index in output_indexes:
            emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
        cls(state, emitter)(stmt)
        for signal_index in output_indexes:
            emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
        return emitter.flush()
388
389
class _FragmentCompiler:
    """Compile a fragment hierarchy into a set of :class:`PyRTLProcess` objects.

    For each domain driven by a fragment (``None`` being the combinational domain), one process
    is created whose ``run`` is a Python function generated from that domain's statements.
    """
    def __init__(self, state):
        self.state = state

    def __call__(self, fragment):
        """Compile `fragment` and, recursively, its subfragments; return all processes."""
        processes = set()

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = PyRTLProcess(is_comb=domain_name is None)

            emitter = _PythonEmitter()
            emitter.append(f"def run():")
            with emitter.indent():
                if domain_name is None:
                    # Combinational signals start every evaluation from their reset value.
                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"next_{signal_index} = {signal.reset}")

                    inputs = SignalSet()
                    _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

                    # Every signal read by the statements becomes a trigger of this process.
                    for input_signal in inputs:
                        self.state.add_trigger(domain_process, input_signal)

                else:
                    domain = fragment.domains[domain_name]
                    # Trigger on the clock (value 1 for a positive edge, 0 for a negative one) and,
                    # for asynchronous reset, on reset becoming 1.
                    clk_trigger = 1 if domain.clk_edge == "pos" else 0
                    self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                    if domain.rst is not None and domain.async_reset:
                        rst_trigger = 1
                        self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)

                    # Synchronous signals start from their pending (`next`) value.
                    for signal in domain_signals:
                        signal_index = self.state.get_signal(signal)
                        emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                    _StatementCompiler(self.state, emitter)(domain_stmts)

                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            # There shouldn't be any exceptions raised by the generated code, but if there are
            # (almost certainly due to a bug in the code generator), use this environment variable
            # to make backtraces useful.
            code = emitter.flush()
            if os.getenv("NMIGEN_pysim_dump"):
                # Close the file so the code actually reaches disk (and the handle is not leaked);
                # `delete=False` keeps it around for tracebacks to reference.
                with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_",
                                                 delete=False) as file:
                    file.write(code)
                    filename = file.name
            else:
                filename = "<string>"

            # The generated code sees `slots` and the helper functions as globals.
            exec_locals = {"slots": self.state.slots, **_ValueCompiler.helpers}
            exec(compile(code, filename, "exec"), exec_locals)
            domain_process.run = exec_locals["run"]

            processes.add(domain_process)

        for subfragment, _subfragment_name in fragment.subfragments:
            processes.update(self(subfragment))

        return processes