9524399749114d870d91bcc839c559869defb275
[nmigen.git] / nmigen / back / pysim.py
1 import os
2 import tempfile
3 import warnings
4 import inspect
5 from contextlib import contextmanager
6 import itertools
7 from vcd import VCDWriter
8 from vcd.gtkw import GTKWSave
9
10 from .._utils import deprecated
11 from ..hdl.ast import *
12 from ..hdl.cd import *
13 from ..hdl.ir import *
14 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
15
16
class Command:
    """Base class of the commands a simulator process may ``yield``."""
19
20
class Settle(Command):
    """Command: resume the process at the current timestamp, once the simulator has run its
    pending delta cycles to a fixed point."""

    def __repr__(self):
        text = "(settle)"
        return text
24
25
class Delay(Command):
    """Command: suspend the process for ``interval`` seconds of simulation time.

    ``interval`` is converted to ``float``. ``None`` requests an infinitesimal delay (ε): the
    process resumes at the current timestamp, after the pending delta cycles converge.
    """

    def __init__(self, interval=None):
        if interval is not None:
            interval = float(interval)
        self.interval = interval

    def __repr__(self):
        if self.interval is not None:
            # Shown in microseconds with three significant digits.
            return "(delay {:.3}us)".format(self.interval * 1e6)
        return "(delay ε)"
35
36
class Tick(Command):
    """Command: suspend the process until the next active clock edge of ``domain``.

    ``domain`` is either the name of a clock domain of the simulated design or a
    :class:`ClockDomain` object. If that domain has an asynchronous reset, the process is also
    resumed when the reset is asserted.
    """

    def __init__(self, domain="sync"):
        if isinstance(domain, (str, ClockDomain)):
            # The combinatorial "domain" has no clock to wait for.
            assert domain != "comb"
            self.domain = domain
            return
        raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}"
                        .format(domain))

    def __repr__(self):
        return f"(tick {self.domain})"
47
48
class Passive(Command):
    """Command: mark the process as passive; passive processes alone do not keep
    :meth:`Simulator.run` going."""

    def __repr__(self):
        text = "(passive)"
        return text
52
53
class Active(Command):
    """Command: mark the process as active again, so that :meth:`Simulator.run` keeps
    advancing while it exists."""

    def __repr__(self):
        text = "(active)"
        return text
57
58
class _WaveformWriter:
    """Interface for sinks that record committed signal value changes over time."""

    def update(self, timestamp, signal, value):
        """Record that ``signal`` took ``value`` at simulation time ``timestamp``."""
        raise NotImplementedError # :nocov:

    def close(self, timestamp):
        """Finish the waveform at simulation time ``timestamp``."""
        raise NotImplementedError # :nocov:
65
66
class _VCDWaveformWriter(_WaveformWriter):
    """Waveform writer producing a VCD file and, optionally, a GTKWave save file.

    Arguments
    ---------
    signal_names : SignalDict
        Maps each signal to a set of hierarchical names; each name is a tuple whose last
        element is the variable name and whose preceding elements form its scope.
    vcd_file : str or file-like object or None
        VCD output (a filename is opened for writing). With ``None``, nothing is recorded.
    gtkw_file : str or file-like object or None
        GTKWave save file output (a filename is opened for writing). It is only written when
        a VCD file is written as well, since it refers to that file.
    traces : iterable of Signal
        Signals to display in the GTKWave save file. Traced signals not present in
        ``signal_names`` are recorded under the ``top`` scope.

    Both files, whether opened here or passed in, are closed by :meth:`close`.
    """

    @staticmethod
    def timestamp_to_vcd(timestamp):
        return timestamp * (10 ** 10) # 1/(100 ps)

    @staticmethod
    def decode_to_vcd(signal, value):
        # Decoded values are emitted as VCD strings; whitespace is removed (tabs expanded,
        # spaces replaced), presumably because VCD values are whitespace-delimited.
        return signal.decoder(value).expandtabs().replace(" ", "_")

    def __init__(self, signal_names, *, vcd_file, gtkw_file=None, traces=()):
        if isinstance(vcd_file, str):
            vcd_file = open(vcd_file, "wt")
        if isinstance(gtkw_file, str):
            gtkw_file = open(gtkw_file, "wt")

        # Signal -> set of VCD variables it is recorded as (one per hierarchical name).
        self.vcd_vars = SignalDict()
        self.vcd_file = vcd_file
        self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
            timescale="100 ps", comment="Generated by nMigen")

        # Signal -> hierarchical name GTKWave should use for its trace.
        self.gtkw_names = SignalDict()
        self.gtkw_file = gtkw_file
        self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)

        self.traces = []

        trace_names = SignalDict()
        for trace in traces:
            if trace not in signal_names:
                trace_names[trace] = {("top", trace.name)}
            self.traces.append(trace)

        if self.vcd_writer is None:
            return

        for signal, names in itertools.chain(signal_names.items(), trace_names.items()):
            if signal.decoder:
                var_type = "string"
                var_size = 1
                var_init = self.decode_to_vcd(signal, signal.reset)
            else:
                var_type = "wire"
                var_size = signal.width
                var_init = signal.reset

            for (*var_scope, var_name) in names:
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        vcd_var = self.vcd_writer.register_var(
                            scope=var_scope, name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init)
                        break
                    except KeyError:
                        # Presumably the name is already registered in this scope; retry
                        # with the next numeric suffix.
                        suffix = (suffix or 0) + 1

                if signal not in self.vcd_vars:
                    self.vcd_vars[signal] = set()
                self.vcd_vars[signal].add(vcd_var)

                if signal not in self.gtkw_names:
                    self.gtkw_names[signal] = (*var_scope, var_name_suffix)

    def update(self, timestamp, signal, value):
        vcd_vars = self.vcd_vars.get(signal)
        if vcd_vars is None:
            return

        vcd_timestamp = self.timestamp_to_vcd(timestamp)
        if signal.decoder:
            var_value = self.decode_to_vcd(signal, value)
        else:
            var_value = value
        for vcd_var in vcd_vars:
            self.vcd_writer.change(vcd_var, vcd_timestamp, var_value)

    def close(self, timestamp):
        """Finalize the VCD (and GTKWave) output at ``timestamp`` and close both files.

        The files are closed even if finalizing the output raises.
        """
        try:
            if self.vcd_writer is not None:
                self.vcd_writer.close(self.timestamp_to_vcd(timestamp))

            # The save file refers to the VCD file and uses the names recorded while
            # registering VCD variables, so it can only be produced together with a VCD file.
            if self.gtkw_save is not None and self.vcd_writer is not None:
                self.gtkw_save.dumpfile(self.vcd_file.name)
                self.gtkw_save.dumpfile_size(self.vcd_file.tell())

                self.gtkw_save.treeopen("top")
                for signal in self.traces:
                    if len(signal) > 1 and not signal.decoder:
                        suffix = "[{}:0]".format(len(signal) - 1)
                    else:
                        suffix = ""
                    self.gtkw_save.trace(".".join(self.gtkw_names[signal]) + suffix)
        finally:
            try:
                if self.vcd_file is not None:
                    self.vcd_file.close()
            finally:
                if self.gtkw_file is not None:
                    self.gtkw_file.close()
167
168
class _Process:
    """Interface shared by every process the simulator schedules.

    ``runnable`` is set when the process should run in the next eval phase; ``passive``
    processes do not by themselves keep the simulation running.
    """
    __slots__ = ("runnable", "passive")

    def reset(self):
        """Return the process to its initial state."""
        raise NotImplementedError # :nocov:

    def run(self):
        """Run the process until it suspends."""
        raise NotImplementedError # :nocov:

    @property
    def name(self):
        """Human-readable process name for diagnostics."""
        raise NotImplementedError # :nocov:
181
182
class _SignalState:
    """Simulation state of one signal.

    ``curr`` is the value readers see in the current delta cycle, and ``next`` is the value
    queued to become visible at the next :meth:`commit`. ``waiters`` maps each process waiting
    on this signal to its trigger: ``None`` wakes it on any committed change, otherwise only
    when the committed value equals the trigger. ``pending`` is the shared set this state adds
    itself to when a new value is queued.
    """
    __slots__ = ("signal", "curr", "next", "waiters", "pending")

    def __init__(self, signal, pending):
        self.signal = signal
        self.pending = pending
        self.waiters = {}
        initial = signal.reset
        self.curr = initial
        self.next = initial

    def set(self, value):
        # Only a real change needs to be committed.
        if self.next != value:
            self.next = value
            self.pending.add(self)

    def wait(self, task, *, trigger=None):
        assert task not in self.waiters
        self.waiters[task] = trigger

    def commit(self):
        """Make ``next`` visible; return whether the visible value changed."""
        changed = self.curr != self.next
        if changed:
            self.curr = self.next
        return changed

    def wakeup(self):
        """Mark every waiter whose trigger matches as runnable; return whether any was."""
        woken = [process for process, trigger in self.waiters.items()
                 if trigger is None or trigger == self.curr]
        for process in woken:
            process.runnable = True
        return bool(woken)
214
215
class _SimulatorState:
    """State shared by all processes of one simulation.

    Holds the per-signal :class:`_SignalState` slots, the slots with queued changes, the
    current simulation time, per-process wake-up deadlines, and the optional waveform writer
    that is told about every committed change.
    """

    def __init__(self):
        self.signals = SignalDict()  # Signal -> index into `slots`
        self.slots = []              # _SignalState objects
        self.pending = set()         # slots with a queued, not yet committed value

        self.timestamp = 0.0
        self.deadlines = dict()      # process -> absolute wake-up time, or None for "now"

        self.waveform_writer = None

    def reset(self):
        for signal, index in self.signals.items():
            slot = self.slots[index]
            slot.curr = slot.next = signal.reset
        self.pending.clear()

        self.timestamp = 0.0
        self.deadlines.clear()

    def get_signal(self, signal):
        """Return the slot index of ``signal``, allocating a slot on first use."""
        try:
            return self.signals[signal]
        except KeyError:
            pass
        new_index = len(self.slots)
        self.slots.append(_SignalState(signal, self.pending))
        self.signals[signal] = new_index
        return new_index

    def get_in_signal(self, signal, *, trigger=None):
        # NOTE(review): this registers the state object itself (not a process) as a waiter,
        # which looks unintended; coroutine processes use their own `get_in_signal` instead.
        index = self.get_signal(signal)
        self.slots[index].waiters[self] = trigger
        return index

    def get_out_signal(self, signal):
        return self.get_signal(signal)

    def for_signal(self, signal):
        return self.slots[self.get_signal(signal)]

    def commit(self):
        """Commit every queued change; return ``True`` if no process was woken up."""
        converged = True
        writer = self.waveform_writer
        for signal_state in self.pending:
            if not signal_state.commit():
                continue
            if signal_state.wakeup():
                converged = False
            if writer is not None:
                writer.update(self.timestamp, signal_state.signal, signal_state.curr)
        self.pending.clear()
        return converged

    def advance(self):
        """Wake the process(es) with the nearest deadline and move time to that deadline.

        A ``None`` deadline means "at the current time" and is picked first. Returns ``False``
        if there are no deadlines at all.
        """
        chosen = set()
        earliest = None
        for process, deadline in self.deadlines.items():
            if deadline is None:
                chosen = {process}
                earliest = self.timestamp
                break
            if earliest is not None and deadline > earliest:
                continue
            assert deadline >= self.timestamp
            if earliest is not None and deadline < earliest:
                chosen.clear()
            chosen.add(process)
            earliest = deadline

        if not chosen:
            return False

        for process in chosen:
            process.runnable = True
            del self.deadlines[process]
        self.timestamp = earliest
        return True
293
294
class _Emitter:
    """Accumulates lines of generated Python source, tracking the current indentation level."""

    def __init__(self):
        self._buffer = []
        self._suffix = 0   # counter making generated variable names unique
        self._level = 0    # current indentation level

    def append(self, code):
        # One space per level is enough: Python only requires consistent indentation.
        self._buffer.append(" " * self._level)
        self._buffer.append(code)
        self._buffer.append("\n")

    @contextmanager
    def indent(self):
        """Indent lines appended inside the ``with`` body by one more level.

        The level is restored even if the body raises, so the emitter stays usable.
        """
        self._level += 1
        try:
            yield
        finally:
            self._level -= 1

    def flush(self, indent=""):
        """Return the accumulated code and empty the buffer.

        ``indent`` is accepted for compatibility but is not used.
        """
        code = "".join(self._buffer)
        self._buffer.clear()
        return code

    def gen_var(self, prefix):
        """Return a fresh variable name of the form ``<prefix>_<n>``."""
        name = f"{prefix}_{self._suffix}"
        self._suffix += 1
        return name

    def def_var(self, prefix, value):
        """Emit an assignment of ``value`` to a fresh variable and return its name."""
        name = self.gen_var(prefix)
        self.append(f"{name} = {value}")
        return name
326
327
class _Compiler:
    """Common base of the code generators: the simulator state and the output emitter."""

    def __init__(self, state, emitter):
        self.emitter = emitter
        self.state = state
331 self.emitter = emitter
332
333
class _ValueCompiler(ValueVisitor, _Compiler):
    """Base of the value compilers; rejects value kinds the simulator does not evaluate."""

    # Runtime helpers called by generated code; they are placed in the namespace the
    # generated code is executed in.
    helpers = {
        # Sign-extend `v`: `m` is a mask of the sign bit and every bit above it, so if the
        # sign bit is set, all the upper bits get set as well.
        "sign": lambda v, m: (v | m) if (v & m) else v,
        # Division and remainder by zero produce 0 instead of raising.
        "zdiv": lambda a, b: 0 if b == 0 else a // b,
        "zmod": lambda a, b: 0 if b == 0 else a % b,
    }

    # None of these have a value of their own at simulation time in this backend.

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_AnyConst(self, value):
        raise NotImplementedError # :nocov:

    def on_AnySeq(self, value):
        raise NotImplementedError # :nocov:

    def on_Sample(self, value):
        raise NotImplementedError # :nocov:

    def on_Initial(self, value):
        raise NotImplementedError # :nocov:
358
359
class _RHSValueCompiler(_ValueCompiler):
    """Compiles an nMigen value into a Python expression computing it.

    Arguments
    ---------
    mode : "curr" or "next"
        With ``"curr"``, signals are read from ``slots[i].curr``. With ``"next"``, they are
        read from the local variable ``next_i``, which holds the not-yet-committed value.
    inputs : SignalSet or None
        If not ``None``, every signal the expression reads is added to it.

    Generated expressions are not necessarily truncated to the width of the value (e.g. ``~``
    and ``&`` are emitted as-is). Consumers apply ``mask``/``sign`` wherever the numeric value,
    rather than just the low bits, matters.
    """

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def on_Operator(self, value):
        def mask(value):
            # Truncate to the operand's width (the result is non-negative).
            value_mask = (1 << len(value)) - 1
            return f"({self(value)} & {value_mask})"

        def sign(value):
            # Normalize to the operand's numeric value: truncate, then sign-extend if signed.
            if value.shape().signed:
                return f"sign({mask(value)}, {-1 << (len(value) - 1)})"
            else: # unsigned
                return mask(value)

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                return f"(~{self(arg)})"
            if value.operator == "-":
                # Negation depends on the operand's numeric value, not only on its low bits,
                # so the operand is normalized first (as for the binary arithmetic operators).
                # Otherwise e.g. `-(~x)` on an unsigned `x` would negate a negative Python int.
                return f"(-{sign(arg)})"
            if value.operator == "b":
                return f"bool({mask(arg)})"
            if value.operator == "r|":
                return f"({mask(arg)} != 0)"
            if value.operator == "r&":
                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            if value.operator == "&":
                return f"({self(lhs)} & {self(rhs)})"
            if value.operator == "|":
                return f"({self(lhs)} | {self(rhs)})"
            if value.operator == "^":
                return f"({self(lhs)} ^ {self(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                return f"({self(val1)} if {self(sel)} else {self(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
        return (f"({self(value.value)} >> {offset} & "
                f"{(1 << value.width) - 1})")

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        # Evaluate the replicated value once into a temporary.
        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {gen_index} == {index}:")
                else:
                    self.emitter.append(f"elif {gen_index} == {index}:")
                with self.emitter.indent():
                    self.emitter.append(f"{gen_value} = {self(elem)}")
            # An out-of-range index selects the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, state, value, *, mode, inputs=None):
        """Return code that assigns the value of ``value`` to the local variable ``result``."""
        emitter = _Emitter()
        compiler = cls(state, emitter, mode=mode, inputs=inputs)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
509
510
class _LHSValueCompiler(_ValueCompiler):
    """Compiles an lvalue into a callable that emits code assigning to it.

    Visiting a value returns ``gen(arg)``. Calling it with a Python expression ``arg`` emits
    statements that update the ``next_i`` local variables of the affected signals. Partial
    updates (slices, parts) are done as read-modify-write on ``next_i``.
    """

    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Store the value truncated to the signal's width (sign-extended if signed).
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{arg} & {value_mask}"
            self.emitter.append(f"next_{self.state.get_out_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & "
                              f"{~(width_mask << value.start)} | "
                              f"(({arg} & {width_mask}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
            self(value.value)(f"({self.lrhs(value.value)} & "
                              f"~({width_mask} << {offset}) | "
                              f"(({arg} & {width_mask}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the assigned value once, then distribute its bits to the parts.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {gen_index} == {index}:")
                    else:
                        self.emitter.append(f"elif {gen_index} == {index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # An out-of-range index writes the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        """Return code that stores the local variable ``result`` into the lvalue ``stmt``.

        The code loads ``next_i`` for every signal of the lvalue, performs the (possibly
        partial) update, and queues the new values with ``slots[i].set(...)``. This mirrors
        :meth:`_StatementCompiler.compile`. Signals read by rvalue parts of the lvalue (e.g. a
        ``Part`` offset) are added to ``inputs``, and written signals are added to ``outputs``,
        when those are not ``None``.
        """
        # The previous version passed `inputs=` to `__init__`, which does not accept it, so
        # every call raised TypeError; the rvalue compiler is built here instead.
        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _Emitter()
        for signal_index in output_indexes:
            emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
        rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        compiler = cls(state, emitter, rhs=rhs, outputs=outputs)
        compiler(stmt)("result")
        for signal_index in output_indexes:
            emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
        return emitter.flush()
600
601
class _StatementCompiler(StatementVisitor, _Compiler):
    """Compiles nMigen statements into Python statements that update ``next_i`` locals.

    Right-hand sides read current values (``slots[i].curr``). The caller defines the
    ``next_i`` locals beforehand and queues them with ``slots[i].set`` afterwards (see
    :meth:`compile`).
    """

    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        if not stmts:
            # A Python block cannot be empty.
            self.emitter.append("pass")
            return
        for stmt in stmts:
            self(stmt)

    def on_Assign(self, stmt):
        write = self.lhs(stmt.lhs)
        return write(self.rhs(stmt.rhs))

    @staticmethod
    def _gen_pattern_check(gen_test, pattern):
        # Build a Python condition matching `gen_test` against a binary pattern, where "-"
        # marks a don't-care bit.
        if "-" not in pattern:
            return f"{gen_test} == {int(pattern, 2)}"
        mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
        value = int("".join("0" if b == "-" else b for b in pattern), 2)
        return f"({gen_test} & {mask}) == {value}"

    def on_Switch(self, stmt):
        test_mask = (1 << len(stmt.test)) - 1
        gen_test = self.emitter.def_var("test", f"{self.rhs(stmt.test)} & {test_mask}")
        for index, (patterns, stmts) in enumerate(stmt.cases.items()):
            if patterns:
                gen_checks = [self._gen_pattern_check(gen_test, pattern)
                              for pattern in patterns]
            else:
                # A case without patterns is the default case.
                gen_checks = ["True"]
            keyword = "if" if index == 0 else "elif"
            self.emitter.append(f"{keyword} {' or '.join(gen_checks)}:")
            with self.emitter.indent():
                self(stmts)

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        """Return standalone code executing ``stmt`` and queueing the resulting values."""
        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _Emitter()
        # Start each written signal from its currently queued value...
        for signal_index in output_indexes:
            emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
        cls(state, emitter, inputs=inputs, outputs=outputs)(stmt)
        # ...and queue whatever the statement left there.
        for signal_index in output_indexes:
            emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
        return emitter.flush()
660
661
class _CompiledProcess(_Process):
    """A process whose ``run`` is Python code generated from the HDL of one fragment domain."""
    __slots__ = ("state", "comb", "name", "run")

    def __init__(self, state, *, comb, name):
        self.name = name
        self.state = state
        self.comb = comb
        self.run = None # set by _FragmentCompiler
        self.reset()

    def reset(self):
        # Combinatorial processes run right away so their outputs get computed; clocked ones
        # wait until woken by their clock (or asynchronous reset).
        self.runnable = self.comb
        # Processes compiled from HDL never keep the simulation running by themselves.
        self.passive = True
674 self.passive = True
675
676
class _FragmentCompiler:
    """Compiles a fragment hierarchy into :class:`_CompiledProcess` objects.

    Each (fragment, domain) pair with drivers becomes one process whose ``run`` function is
    generated Python code. The hierarchical names of the signals each fragment uses are
    recorded in ``signal_names`` (signal -> set of name tuples), e.g. for waveform output.
    """

    def __init__(self, state, signal_names):
        self.state = state
        self.signal_names = signal_names

    def __call__(self, fragment, *, hierarchy=("top",)):
        """Compile ``fragment`` and its subfragments; return the set of processes."""
        processes = set()

        def add_signal_name(signal):
            hierarchical_signal_name = (*hierarchy, signal.name)
            if signal not in self.signal_names:
                self.signal_names[signal] = {hierarchical_signal_name}
            else:
                self.signal_names[signal].add(hierarchical_signal_name)

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = _CompiledProcess(self.state, comb=domain_name is None,
                name=".".join((*hierarchy, "<{}>".format(domain_name or "comb"))))

            emitter = _Emitter()
            emitter.append(f"def run():")
            emitter._level += 1

            # Signals read and written by this process' statements. The simulator state is
            # shared by all fragments, so its signal table cannot tell which signals belong
            # to this fragment; these sets can.
            inputs = SignalSet()
            outputs = SignalSet()

            if domain_name is None:
                # Combinatorial signals are recomputed from their reset values on every run.
                for signal in domain_signals:
                    signal_index = domain_process.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = {signal.reset}")

                _StatementCompiler(domain_process.state, emitter,
                                   inputs=inputs, outputs=outputs)(domain_stmts)

                # Re-run whenever any signal read by the statements changes.
                for input in inputs:
                    self.state.for_signal(input).wait(domain_process)

            else:
                domain = fragment.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

                clk_trigger = 1 if domain.clk_edge == "pos" else 0
                self.state.for_signal(domain.clk).wait(domain_process, trigger=clk_trigger)
                if domain.rst is not None and domain.async_reset:
                    rst_trigger = 1
                    self.state.for_signal(domain.rst).wait(domain_process, trigger=rst_trigger)

                # The process must only run on an active clock edge (or asserted async reset).
                gen_asserts = []
                clk_index = domain_process.state.get_signal(domain.clk)
                gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                if domain.rst is not None and domain.async_reset:
                    rst_index = domain_process.state.get_signal(domain.rst)
                    gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                emitter.append(f"assert {' or '.join(gen_asserts)}")

                # Synchronous signals keep their value unless assigned.
                for signal in domain_signals:
                    signal_index = domain_process.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                _StatementCompiler(domain_process.state, emitter,
                                   inputs=inputs, outputs=outputs)(domain_stmts)

            for signal in domain_signals:
                signal_index = domain_process.state.get_signal(signal)
                emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            # There shouldn't be any exceptions raised by the generated code, but if there are
            # (almost certainly due to a bug in the code generator), use this environment variable
            # to make backtraces useful.
            code = emitter.flush()
            if os.getenv("NMIGEN_pysim_dump"):
                # Close the file so its contents are flushed to disk before anything reads it.
                with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_",
                                                 delete=False) as file:
                    file.write(code)
                    filename = file.name
            else:
                filename = "<string>"

            exec_locals = {"slots": domain_process.state.slots, **_ValueCompiler.helpers}
            exec(compile(code, filename, "exec"), exec_locals)
            domain_process.run = exec_locals["run"]

            processes.add(domain_process)

            # Name only the signals this fragment actually refers to.
            for used_signals in (domain_signals, inputs, outputs):
                for used_signal in used_signals:
                    add_signal_name(used_signal)

        for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
            if subfragment_name is None:
                subfragment_name = "U${}".format(subfragment_index)
            processes.update(self(subfragment, hierarchy=(*hierarchy, subfragment_name)))

        return processes
768
769
class _CoroutineProcess(_Process):
    """A user process: a generator (created by ``constructor``) yielding simulator commands.

    Yielding a :class:`Value` returns its current value, and yielding a :class:`Statement`
    executes it. :class:`Tick`, :class:`Settle` and :class:`Delay` suspend the process.
    :class:`Passive`/:class:`Active` change its status. A bare ``yield`` is replaced by
    ``default_cmd``. If handling a command fails, the error is raised inside the generator at
    the ``yield`` that produced the command.
    """

    def __init__(self, state, domains, constructor, *, default_cmd=None):
        self.state = state
        self.domains = domains
        self.constructor = constructor
        self.default_cmd = default_cmd
        self.reset()

    def reset(self):
        # Drop waiter registrations left behind by the previous coroutine (e.g. when the
        # simulation is reset while this process waits on a `Tick`); otherwise the new
        # coroutine would trip the assertion in `get_in_signal` on the same signal.
        for signal_state in getattr(self, "waits_on", ()):
            signal_state.waiters.pop(self, None)
        self.runnable = True
        self.passive = False
        self.coroutine = self.constructor()
        # Namespace in which code compiled for Value/Statement commands is executed.
        self.exec_locals = {
            "slots": self.state.slots,
            "result": None,
            **_ValueCompiler.helpers
        }
        self.waits_on = set()

    @property
    def name(self):
        """``file:line`` of the innermost generator frame, for diagnostics."""
        coroutine = self.coroutine
        if coroutine is None:
            # The process has finished; there is no frame left to point at.
            return repr(self.constructor)
        while coroutine.gi_yieldfrom is not None:
            coroutine = coroutine.gi_yieldfrom
        if inspect.isgenerator(coroutine):
            frame = coroutine.gi_frame
        if inspect.iscoroutine(coroutine):
            frame = coroutine.cr_frame
        return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def get_in_signal(self, signal, *, trigger=None):
        """Wake this process when ``signal`` changes (to ``trigger``, if given)."""
        signal_state = self.state.for_signal(signal)
        assert self not in signal_state.waiters
        signal_state.waiters[self] = trigger
        self.waits_on.add(signal_state)
        return signal_state

    def run(self):
        """Resume the coroutine and execute its commands until it suspends or finishes."""
        if self.coroutine is None:
            return

        if self.waits_on:
            for signal_state in self.waits_on:
                del signal_state.waiters[self]
            self.waits_on.clear()

        response = None
        exception = None
        while True:
            # Resume the coroutine with the result of the previous command, or raise the
            # error that command caused at the point where it was yielded. Whatever the
            # coroutine yields next (even after handling such an error) is processed below.
            try:
                if exception is None:
                    command = self.coroutine.send(response)
                else:
                    command = self.coroutine.throw(exception)
            except StopIteration:
                self.passive = True
                self.coroutine = None
                return

            response = None
            exception = None
            try:
                if command is None:
                    command = self.default_cmd

                if isinstance(command, Value):
                    exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
                         self.exec_locals)
                    response = Const.normalize(self.exec_locals["result"], command.shape())

                elif isinstance(command, Statement):
                    exec(_StatementCompiler.compile(self.state, command),
                         self.exec_locals)

                elif type(command) is Tick:
                    domain = command.domain
                    if isinstance(domain, ClockDomain):
                        pass
                    elif domain in self.domains:
                        domain = self.domains[domain]
                    else:
                        raise NameError("Received command {!r} that refers to a nonexistent "
                                        "domain {!r} from process {!r}"
                                        .format(command, command.domain, self.name))
                    self.get_in_signal(domain.clk, trigger=1 if domain.clk_edge == "pos" else 0)
                    if domain.rst is not None and domain.async_reset:
                        self.get_in_signal(domain.rst, trigger=1)
                    return

                elif type(command) is Settle:
                    self.state.deadlines[self] = None
                    return

                elif type(command) is Delay:
                    if command.interval is None:
                        self.state.deadlines[self] = None
                    else:
                        self.state.deadlines[self] = self.state.timestamp + command.interval
                    return

                elif type(command) is Passive:
                    self.passive = True

                elif type(command) is Active:
                    self.passive = False

                elif command is None: # only possible if self.default_cmd is None
                    raise TypeError("Received default command from process {!r} that was added "
                                    "with add_process(); did you mean to add this process with "
                                    "add_sync_process() instead?"
                                    .format(self.name))

                else:
                    raise TypeError("Received unsupported command {!r} from process {!r}"
                                    .format(command, self.name))

            except Exception as exn:
                exception = exn
882
883
class _WaveformContextManager:
    """Attaches a waveform writer to the simulator state for the duration of a ``with`` block.

    The writer can only be attached before simulation time has advanced and while no other
    writer is attached; if attaching fails, the writer is closed at time 0. On exit, the
    writer is closed at the current simulation time and detached.
    """

    def __init__(self, state, waveform_writer):
        self._state = state
        self._waveform_writer = waveform_writer

    def __enter__(self):
        try:
            state = self._state
            if state.timestamp != 0.0:
                raise ValueError("Cannot start writing waveforms after advancing simulation time")
            if state.waveform_writer is not None:
                raise ValueError("Already writing waveforms to {!r}"
                                 .format(state.waveform_writer))
        except BaseException:
            # The writer will never be attached, so release its resources now.
            self._waveform_writer.close(0)
            raise
        state.waveform_writer = self._waveform_writer

    def __exit__(self, *args):
        writer = self._state.waveform_writer
        if writer is not None:
            writer.close(self._state.timestamp)
            self._state.waveform_writer = None
906
907
class Simulator:
    """Python-based simulator for an nMigen design.

    The design is compiled into processes when the simulator is created. Testbench and clock
    processes are added with :meth:`add_process`, :meth:`add_sync_process` and
    :meth:`add_clock`.
    """

    def __init__(self, fragment):
        self._state = _SimulatorState()
        self._signal_names = SignalDict()
        self._fragment = Fragment.get(fragment, platform=None).prepare()
        compile_fragment = _FragmentCompiler(self._state, self._signal_names)
        self._processes = compile_fragment(self._fragment)
        self._clocked = set()

    def _check_process(self, process):
        is_generator = inspect.isgeneratorfunction(process)
        if is_generator or inspect.iscoroutinefunction(process):
            return process
        raise TypeError("Cannot add a process {!r} because it is not a generator function"
                        .format(process))

    def _add_coroutine_process(self, process, *, default_cmd):
        coroutine_process = _CoroutineProcess(self._state, self._fragment.domains, process,
                                              default_cmd=default_cmd)
        self._processes.add(coroutine_process)

    def add_process(self, process):
        """Add a testbench process.

        The process starts once the initial combinatorial values have settled. A bare
        ``yield`` is not allowed in it (there is no default command).
        """
        process = self._check_process(process)

        def wrapper():
            # Only start a bench process after comb settling, so that the reset values are correct.
            yield Settle()
            yield from process()

        self._add_coroutine_process(wrapper, default_cmd=None)

    def add_sync_process(self, process, *, domain="sync"):
        """Add a process synchronous to ``domain``.

        The process starts after the first clock edge of ``domain``, and a bare ``yield``
        waits for the next one.
        """
        process = self._check_process(process)

        def wrapper():
            # Only start a sync process after the first clock edge (or reset edge, if the domain
            # uses an asynchronous reset). This matches the behavior of synchronous FFs.
            yield Tick(domain)
            yield from process()

        return self._add_coroutine_process(wrapper, default_cmd=Tick(domain))

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a clock process.

        Adds a process that drives the clock signal of ``domain`` at a 50% duty cycle.

        Arguments
        ---------
        period : float
            Clock period. The process will toggle the ``domain`` clock signal every ``period / 2``
            seconds.
        phase : None or float
            Clock phase. The process will wait ``phase`` seconds before the first clock transition.
            If not specified, defaults to ``period / 2``.
        domain : str or ClockDomain
            Driven clock domain. If specified as a string, the domain with that name is looked up
            in the root fragment of the simulation.
        if_exists : bool
            If ``False`` (the default), raise an error if the driven domain is specified as
            a string and the root fragment does not have such a domain. If ``True``, do nothing
            in this case.
        """
        if not isinstance(domain, ClockDomain):
            if domain in self._fragment.domains:
                domain = self._fragment.domains[domain]
            elif if_exists:
                return
            else:
                raise ValueError("Domain {!r} is not present in simulation"
                                 .format(domain))
        if domain in self._clocked:
            raise ValueError("Domain {!r} already has a clock driving it"
                             .format(domain.name))

        half_period = period / 2
        if phase is None:
            # By default, delay the first edge by half period. This causes any synchronous activity
            # to happen at a non-zero time, distinguishing it from the reset values in the waveform
            # viewer.
            phase = half_period

        def clk_process():
            yield Passive()
            yield Delay(phase)
            # Behave correctly if the process is added after the clock signal is manipulated, or if
            # its reset state is high.
            initial = (yield domain.clk)
            steps = (
                domain.clk.eq(~initial),
                Delay(half_period),
                domain.clk.eq(initial),
                Delay(half_period),
            )
            while True:
                yield from iter(steps)

        self._add_coroutine_process(clk_process, default_cmd=None)
        self._clocked.add(domain)

    def reset(self):
        """Reset the simulation.

        Assign the reset value to every signal in the simulation, and restart every user process.
        """
        self._state.reset()
        for process in self._processes:
            process.reset()

    def _real_step(self):
        """Step the simulation.

        Run every process and commit changes until a fixed point is reached. If there is
        an unstable combinatorial loop, this function will never return.
        """
        # Each iteration is one delta cycle.
        while True:
            # 1. eval: run every runnable process once, queueing signal changes.
            for process in self._processes:
                if not process.runnable:
                    continue
                process.runnable = False
                process.run()

            # 2. commit: apply the queued changes; stop once no process was woken up.
            if self._state.commit():
                break

    # TODO(nmigen-0.4): replace with _real_step
    @deprecated("instead of `sim.step()`, use `sim.advance()`")
    def step(self):
        return self.advance()

    def advance(self):
        """Advance the simulation.

        Run every process and commit changes until a fixed point is reached, then advance time
        to the closest deadline (if any). If there is an unstable combinatorial loop,
        this function will never return.

        Returns ``True`` if there are any active processes, ``False`` otherwise.
        """
        self._real_step()
        self._state.advance()
        return not all(process.passive for process in self._processes)

    def run(self):
        """Run the simulation while any processes are active.

        Processes added with :meth:`add_process` and :meth:`add_sync_process` are initially active,
        and may change their status using the ``yield Passive()`` and ``yield Active()`` commands.
        Processes compiled from HDL and added with :meth:`add_clock` are always passive.
        """
        active = True
        while active:
            active = self.advance()

    def run_until(self, deadline, *, run_passive=False):
        """Run the simulation until it advances to ``deadline``.

        If ``run_passive`` is ``False``, the simulation also stops when there are no active
        processes, similar to :meth:`run`. Otherwise, the simulation will stop only after it
        advances to or past ``deadline``.

        If the simulation stops advancing, this function will never return.
        """
        assert self._state.timestamp <= deadline
        while self.advance() or run_passive:
            if self._state.timestamp >= deadline:
                break

    def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
        """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file.

        This method returns a context manager. It can be used as: ::

            sim = Simulator(frag)
            sim.add_clock(1e-6)
            with sim.write_vcd("dump.vcd", "dump.gtkw"):
                sim.run_until(1e-3)

        Arguments
        ---------
        vcd_file : str or file-like object
            Verilog Value Change Dump file or filename.
        gtkw_file : str or file-like object
            GTKWave save file or filename.
        traces : iterable of Signal
            Signals to display traces for.
        """
        writer = _VCDWaveformWriter(self._signal_names,
                                    vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
        return _WaveformContextManager(self._state, writer)