c0db0271789aff790b93e4f4e4ecb22cb4390435
[nmigen.git] / nmigen / back / pysim.py
1 import os
2 import tempfile
3 import warnings
4 import inspect
5 from contextlib import contextmanager
6 import itertools
7 from vcd import VCDWriter
8 from vcd.gtkw import GTKWSave
9
10 from .._utils import deprecated
11 from ..hdl.ast import *
12 from ..hdl.cd import *
13 from ..hdl.ir import *
14 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
15
16
class Command:
    """Base class for every command a simulator process may yield."""
19
20
class Settle(Command):
    """Command: resume the process at the current timestamp, after pending changes settle."""

    def __repr__(self):
        text = "(settle)"
        return text
24
25
class Delay(Command):
    """Command: suspend the process for ``interval`` seconds.

    With ``interval=None`` the process resumes at the current timestamp, once the
    current delta cycles have converged (same scheduling as :class:`Settle`).
    """

    def __init__(self, interval=None):
        if interval is not None:
            interval = float(interval)
        self.interval = interval

    def __repr__(self):
        if self.interval is None:
            return "(delay ε)"
        microseconds = self.interval * 1e6
        return "(delay {:.3}us)".format(microseconds)
35
36
class Tick(Command):
    """Command: wait for the next active clock edge of ``domain``.

    ``domain`` is either a domain name (looked up by the process) or a
    :class:`ClockDomain`. For domains with an asynchronous reset, a rising reset
    edge also wakes the process.
    """

    def __init__(self, domain="sync"):
        acceptable = isinstance(domain, (str, ClockDomain))
        if not acceptable:
            raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}"
                            .format(domain))
        assert domain != "comb"
        self.domain = domain

    def __repr__(self):
        return "(tick {})".format(self.domain)
47
48
class Passive(Command):
    """Command: mark the issuing process as passive.

    Passive processes do not keep :meth:`Simulator.run` going.
    """

    def __repr__(self):
        text = "(passive)"
        return text
52
53
class Active(Command):
    """Command: mark the issuing process as active again (the opposite of :class:`Passive`)."""

    def __repr__(self):
        text = "(active)"
        return text
57
58
59 class _WaveformWriter:
60 def update(self, timestamp, signal, value):
61 raise NotImplementedError # :nocov:
62
63 def close(self, timestamp):
64 raise NotImplementedError # :nocov:
65
66
class _VCDWaveformWriter(_WaveformWriter):
    """Waveform writer that emits a VCD file and, optionally, a GTKWave save file.

    ``signal_names`` maps each signal to a set of hierarchical name tuples
    (``(*scope, name)``); every name gets its own VCD variable. ``traces`` lists the
    signals to show in the GTKWave save file.
    """

    @staticmethod
    def timestamp_to_vcd(timestamp):
        # Simulator timestamps are in seconds; the VCD timescale set below is 100 ps.
        return timestamp * (10 ** 10) # 1/(100 ps)

    @staticmethod
    def decode_to_vcd(signal, value):
        # Decoded values are written as VCD strings; whitespace is replaced, presumably
        # because VCD string values are whitespace-delimited.
        return signal.decoder(value).expandtabs().replace(" ", "_")

    def __init__(self, signal_names, *, vcd_file, gtkw_file=None, traces=()):
        # Filenames are opened here; file-like objects are used as-is. Both are closed
        # in close().
        if isinstance(vcd_file, str):
            vcd_file = open(vcd_file, "wt")
        if isinstance(gtkw_file, str):
            gtkw_file = open(gtkw_file, "wt")

        self.vcd_vars = SignalDict()
        self.vcd_file = vcd_file
        self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
            timescale="100 ps", comment="Generated by nMigen")

        self.gtkw_names = SignalDict()
        self.gtkw_file = gtkw_file
        self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)

        self.traces = []

        # Traced signals that do not appear in the design get a name directly under "top".
        trace_names = SignalDict()
        for trace in traces:
            if trace not in signal_names:
                trace_names[trace] = {("top", trace.name)}
            self.traces.append(trace)

        if self.vcd_writer is None:
            return

        for signal, names in itertools.chain(signal_names.items(), trace_names.items()):
            # Signals with a decoder are dumped as strings of their decoded value.
            if signal.decoder:
                var_type = "string"
                var_size = 1
                var_init = self.decode_to_vcd(signal, signal.reset)
            else:
                var_type = "wire"
                var_size = signal.width
                var_init = signal.reset

            for (*var_scope, var_name) in names:
                # A KeyError from register_var is treated as a name clash in this scope
                # (presumably pyvcd's duplicate-variable error); retry as "name$1", "name$2", ...
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        vcd_var = self.vcd_writer.register_var(
                            scope=var_scope, name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init)
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

                # One signal may map to several VCD variables (one per hierarchical name).
                if signal not in self.vcd_vars:
                    self.vcd_vars[signal] = set()
                self.vcd_vars[signal].add(vcd_var)

                # GTKWave traces refer to the first registered name of each signal.
                if signal not in self.gtkw_names:
                    self.gtkw_names[signal] = (*var_scope, var_name_suffix)

    def update(self, timestamp, signal, value):
        # Signals that were never registered (or no VCD output at all) are ignored.
        vcd_vars = self.vcd_vars.get(signal)
        if vcd_vars is None:
            return

        vcd_timestamp = self.timestamp_to_vcd(timestamp)
        if signal.decoder:
            var_value = self.decode_to_vcd(signal, value)
        else:
            var_value = value
        for vcd_var in vcd_vars:
            self.vcd_writer.change(vcd_var, vcd_timestamp, var_value)

    def close(self, timestamp):
        # NOTE(review): if a gtkw_file was given without a vcd_file, `self.vcd_file.name`
        # below fails and `self.gtkw_names` is empty; callers must pass both.
        if self.vcd_writer is not None:
            self.vcd_writer.close(self.timestamp_to_vcd(timestamp))

        if self.gtkw_save is not None:
            self.gtkw_save.dumpfile(self.vcd_file.name)
            self.gtkw_save.dumpfile_size(self.vcd_file.tell())

            self.gtkw_save.treeopen("top")
            for signal in self.traces:
                # Multi-bit, non-decoded signals are shown as a bus "[N:0]".
                if len(signal) > 1 and not signal.decoder:
                    suffix = "[{}:0]".format(len(signal) - 1)
                else:
                    suffix = ""
                self.gtkw_save.trace(".".join(self.gtkw_names[signal]) + suffix)

        if self.vcd_file is not None:
            self.vcd_file.close()
        if self.gtkw_file is not None:
            self.gtkw_file.close()
167
168
169 class _Process:
170 __slots__ = ("runnable", "passive")
171
172 def reset(self):
173 raise NotImplementedError # :nocov:
174
175 def run(self):
176 raise NotImplementedError # :nocov:
177
178
179 class _SignalState:
180 __slots__ = ("signal", "curr", "next", "waiters", "pending")
181
182 def __init__(self, signal, pending):
183 self.signal = signal
184 self.pending = pending
185 self.waiters = dict()
186 self.curr = self.next = signal.reset
187
188 def set(self, value):
189 if self.next == value:
190 return
191 self.next = value
192 self.pending.add(self)
193
194 def wait(self, task, *, trigger=None):
195 assert task not in self.waiters
196 self.waiters[task] = trigger
197
198 def commit(self):
199 if self.curr == self.next:
200 return False
201 self.curr = self.next
202
203 awoken_any = False
204 for process, trigger in self.waiters.items():
205 if trigger is None or trigger == self.curr:
206 process.runnable = awoken_any = True
207 return awoken_any
208
209
class _SimulatorState:
    """Shared simulation state: per-signal slots, pending changes, time and deadlines."""

    def __init__(self):
        # Maps each signal to its index in `slots`; generated code refers to `slots[index]`.
        self.signals = SignalDict()
        self.slots = []
        # _SignalState objects whose scheduled value changed; shared with every slot.
        self.pending = set()

        # Current simulation time, in seconds.
        self.timestamp = 0.0
        # Maps a process to the time it should be woken at. None means "at the current
        # timestamp, after the current delta cycles have converged".
        self.deadlines = dict()

    def reset(self):
        """Restore every known signal to its reset value and rewind time to zero."""
        for signal, index in self.signals.items():
            self.slots[index].curr = self.slots[index].next = signal.reset
        self.pending.clear()

        self.timestamp = 0.0
        self.deadlines.clear()

    def get_signal(self, signal):
        """Return the slot index of ``signal``, allocating a new slot on first use."""
        try:
            return self.signals[signal]
        except KeyError:
            index = len(self.slots)
            self.slots.append(_SignalState(signal, self.pending))
            self.signals[signal] = index
            return index

    def get_in_signal(self, signal, *, trigger=None):
        # NOTE(review): this registers the _SimulatorState itself (not a process) as a
        # waiter, so _SignalState.commit() would set `runnable` on this object. It is not
        # called by the code visible here; processes use _SignalState.wait() or
        # _CoroutineProcess.get_in_signal() instead.
        index = self.get_signal(signal)
        self.slots[index].waiters[self] = trigger
        return index

    def get_out_signal(self, signal):
        return self.get_signal(signal)

    def for_signal(self, signal):
        """Return the _SignalState of ``signal``, allocating it if needed."""
        return self.slots[self.get_signal(signal)]

    def commit(self):
        """Commit all pending changes; return True if no waiting process was woken."""
        converged = True
        for signal_state in self.pending:
            if signal_state.commit():
                converged = False
        self.pending.clear()
        return converged

    def advance(self):
        """Wake the process(es) with the nearest deadline and move time forward to it.

        Returns False if no process has a deadline.
        """
        nearest_processes = set()
        nearest_deadline = None
        for process, deadline in self.deadlines.items():
            if deadline is None:
                # A None deadline (Settle(), or Delay() without interval) wins over any timed
                # one; only the first such process found is woken by this call.
                if nearest_deadline is not None:
                    nearest_processes.clear()
                nearest_processes.add(process)
                nearest_deadline = self.timestamp
                break
            elif nearest_deadline is None or deadline <= nearest_deadline:
                assert deadline >= self.timestamp
                # Equal deadlines are woken together; a strictly earlier one replaces the set.
                if nearest_deadline is not None and deadline < nearest_deadline:
                    nearest_processes.clear()
                nearest_processes.add(process)
                nearest_deadline = deadline

        if not nearest_processes:
            return False

        for process in nearest_processes:
            process.runnable = True
            del self.deadlines[process]
        self.timestamp = nearest_deadline

        return True
281
282
283 class _Emitter:
284 def __init__(self):
285 self._buffer = []
286 self._suffix = 0
287 self._level = 0
288
289 def append(self, code):
290 self._buffer.append(" " * self._level)
291 self._buffer.append(code)
292 self._buffer.append("\n")
293
294 @contextmanager
295 def indent(self):
296 self._level += 1
297 yield
298 self._level -= 1
299
300 def flush(self, indent=""):
301 code = "".join(self._buffer)
302 self._buffer.clear()
303 return code
304
305 def gen_var(self, prefix):
306 name = f"{prefix}_{self._suffix}"
307 self._suffix += 1
308 return name
309
310 def def_var(self, prefix, value):
311 name = self.gen_var(prefix)
312 self.append(f"{name} = {value}")
313 return name
314
315
316 class _Compiler:
317 def __init__(self, state, emitter):
318 self.state = state
319 self.emitter = emitter
320
321
class _ValueCompiler(ValueVisitor, _Compiler):
    """Common base of the value compilers.

    Holds the runtime helpers that generated code may call, and rejects value kinds the
    simulator does not support.
    """

    helpers = {
        # Sign-extend `value` (already masked to its width); `sign` is -1 << (width - 1).
        "sign": lambda value, sign: (value | sign) if value & sign else value,
        # Division and remainder that produce 0 when the divisor is 0.
        "zdiv": lambda lhs, rhs: 0 if rhs == 0 else lhs // rhs,
        "zmod": lambda lhs, rhs: 0 if rhs == 0 else lhs % rhs,
    }

    def _unsupported(self, value):
        raise NotImplementedError # :nocov:

    # None of these value kinds can be simulated.
    on_ClockSignal = _unsupported
    on_ResetSignal = _unsupported
    on_AnyConst = _unsupported
    on_AnySeq = _unsupported
    on_Sample = _unsupported
    on_Initial = _unsupported
346
347
class _RHSValueCompiler(_ValueCompiler):
    """Compiles nMigen values into Python expressions that compute them.

    Generated expressions may evaluate to Python integers with bits set above the width of
    the value (e.g. ``~x`` of an unsigned ``x`` is a negative int). Only the low
    ``len(value)`` bits are meaningful. Any operation whose result depends on higher bits
    obtains the operand's exact value through :meth:`_mask` or :meth:`_sign` first.
    """

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        # "curr" reads committed values (`slots[N].curr`); "next" reads the `next_N` locals
        # of the code currently being generated.
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def _mask(self, value):
        """Expression for the low ``len(value)`` bits of ``value``, as an unsigned int."""
        value_mask = (1 << len(value)) - 1
        return f"({self(value)} & {value_mask})"

    def _sign(self, value):
        """Expression for the exact integer value of ``value`` according to its shape."""
        if value.shape().signed:
            return f"sign({self._mask(value)}, {-1 << (len(value) - 1)})"
        else: # unsigned
            return self._mask(value)

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def on_Operator(self, value):
        mask, sign = self._mask, self._sign

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                # Only the low bits of the result are meaningful; consumers mask as needed.
                return f"(~{self(arg)})"
            if value.operator == "-":
                # Negation depends on every bit of the operand, so use its exact value.
                return f"(-{sign(arg)})"
            if value.operator == "b":
                return f"bool({mask(arg)})"
            if value.operator == "r|":
                return f"({mask(arg)} != 0)"
            if value.operator == "r&":
                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            # Bitwise operators extend each operand according to its own signedness, so the
            # bits above an operand's width matter: use the exact operand values.
            if value.operator == "&":
                return f"({sign(lhs)} & {sign(rhs)})"
            if value.operator == "|":
                return f"({sign(lhs)} | {sign(rhs)})"
            if value.operator == "^":
                return f"({sign(lhs)} ^ {sign(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                # Test only the selector's own bits, and extend each branch per its own shape.
                return f"({sign(val1)} if {mask(sel)} else {sign(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
        return f"({self(value.value)} >> {offset} & " \
               f"{(1 << value.width) - 1})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {gen_index} == {index}:")
                else:
                    self.emitter.append(f"elif {gen_index} == {index}:")
                with self.emitter.indent():
                    # Elements may differ in width/signedness; extend each per its own shape.
                    self.emitter.append(f"{gen_value} = {self._sign(elem)}")
            # Out-of-range indexes select the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self._sign(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, state, value, *, mode, inputs=None):
        """Return source code that stores the value of ``value`` into ``result``."""
        emitter = _Emitter()
        compiler = cls(state, emitter, mode=mode, inputs=inputs)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
497
498
class _LHSValueCompiler(_ValueCompiler):
    """Compiles nMigen lvalues into setter callbacks.

    Visiting an lvalue returns a function `gen(arg)`. Given a Python expression `arg` for
    the value to store, `gen(arg)` emits statements that update the affected `next_N` locals.
    """

    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Truncate to the signal's width (sign-extending signed signals) before storing.
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{arg} & {value_mask}"
            self.emitter.append(f"next_{self.state.get_out_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Replace bits [start, stop) of the underlying value, keeping all other bits.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"{~(width_mask << value.start)} | " \
                f"(({arg} & {width_mask}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            # Like on_Slice, but the bit offset is computed at run time.
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"~({width_mask} << {offset}) | " \
                f"(({arg} & {width_mask}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate `arg` once, then distribute its bit fields to the parts.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {gen_index} == {index}:")
                    else:
                        self.emitter.append(f"elif {gen_index} == {index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # Out-of-range indexes write the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        # NOTE(review): __init__ requires a `rhs` keyword and accepts no `inputs` keyword,
        # so this call raises TypeError as written. It is not used by the code visible here.
        emitter = _Emitter()
        compiler = cls(state, emitter, inputs=inputs, outputs=outputs)
        compiler(stmt)
        return emitter.flush()
588
589
class _StatementCompiler(StatementVisitor, _Compiler):
    """Compiles nMigen statements into Python statements that update `next_N` locals."""

    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        if not stmts:
            # An empty Python block is a syntax error.
            self.emitter.append("pass")
            return
        for stmt in stmts:
            self(stmt)

    def on_Assign(self, stmt):
        store = self.lhs(stmt.lhs)
        return store(self.rhs(stmt.rhs))

    @staticmethod
    def _pattern_check(gen_test, pattern):
        """Return a condition matching ``gen_test`` against a "0"/"1"/"-" bit pattern."""
        if "-" not in pattern:
            return f"{gen_test} == {int(pattern, 2)}"
        care_mask = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
        care_bits = int("".join("0" if bit == "-" else bit for bit in pattern), 2)
        return f"({gen_test} & {care_mask}) == {care_bits}"

    def on_Switch(self, stmt):
        test_mask = (1 << len(stmt.test)) - 1
        gen_test = self.emitter.def_var("test", f"{self.rhs(stmt.test)} & {test_mask}")
        keyword = "if"
        for patterns, stmts in stmt.cases.items():
            if patterns:
                conditions = [self._pattern_check(gen_test, pattern) for pattern in patterns]
            else:
                # A case without patterns is the default case.
                conditions = ["True"]
            self.emitter.append(f"{keyword} {' or '.join(conditions)}:")
            keyword = "elif"
            with self.emitter.indent():
                self(stmts)

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        """Return standalone source that executes ``stmt`` and schedules its results."""
        emitter = _Emitter()
        indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        # Start from the already-scheduled values so that partial updates keep other bits.
        for index in indexes:
            emitter.append(f"next_{index} = slots[{index}].next")
        cls(state, emitter, inputs=inputs, outputs=outputs)(stmt)
        for index in indexes:
            emitter.append(f"slots[{index}].set(next_{index})")
        return emitter.flush()
648
649
class _CompiledProcess(_Process):
    """Process whose ``run`` is a function generated from one domain of a fragment.

    Combinational processes (``comb=True``) start out runnable so that they evaluate once
    after reset. Compiled processes are always passive.
    """

    __slots__ = ("state", "comb", "run")

    def __init__(self, state, *, comb):
        self.state = state
        self.comb = comb
        # Filled in by _FragmentCompiler once the generated code has been executed.
        self.run = None
        self.reset()

    def reset(self):
        self.passive = True
        self.runnable = self.comb
662
663
class _FragmentCompiler:
    """Compiles a fragment hierarchy into _CompiledProcess objects.

    One process is produced per driver domain of each fragment. Hierarchical names of the
    signals used are recorded into ``signal_names``.
    """

    def __init__(self, state, signal_names):
        self.state = state
        self.signal_names = signal_names

    def __call__(self, fragment, *, hierarchy=("top",)):
        processes = set()

        def add_signal_name(signal):
            hierarchical_signal_name = (*hierarchy, signal.name)
            if signal not in self.signal_names:
                self.signal_names[signal] = {hierarchical_signal_name}
            else:
                self.signal_names[signal].add(hierarchical_signal_name)

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = _CompiledProcess(self.state, comb=domain_name is None)

            emitter = _Emitter()
            emitter.append(f"def run():")
            emitter._level += 1

            if domain_name is None:
                # Combinational: driven signals start from their reset value, and the process
                # re-runs whenever any signal it reads changes.
                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = {signal.reset}")

                inputs = SignalSet()
                _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

                for input_signal in inputs:
                    self.state.for_signal(input_signal).wait(domain_process)

            else:
                # Synchronous: the process runs on the active clock edge (and on the rising
                # edge of an asynchronous reset).
                domain = fragment.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

                clk_trigger = 1 if domain.clk_edge == "pos" else 0
                self.state.for_signal(domain.clk).wait(domain_process, trigger=clk_trigger)
                if domain.rst is not None and domain.async_reset:
                    rst_trigger = 1
                    self.state.for_signal(domain.rst).wait(domain_process, trigger=rst_trigger)

                gen_asserts = []
                clk_index = self.state.get_signal(domain.clk)
                gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                if domain.rst is not None and domain.async_reset:
                    rst_index = self.state.get_signal(domain.rst)
                    gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                emitter.append(f"assert {' or '.join(gen_asserts)}")

                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                _StatementCompiler(self.state, emitter)(domain_stmts)

            for signal in domain_signals:
                signal_index = self.state.get_signal(signal)
                emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            # There shouldn't be any exceptions raised by the generated code, but if there are
            # (almost certainly due to a bug in the code generator), use this environment variable
            # to make backtraces useful.
            code = emitter.flush()
            if os.getenv("NMIGEN_pysim_dump"):
                # Close the file (it is kept on disk) so the dumped source is actually flushed
                # and the handle does not leak.
                with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_",
                                                 delete=False) as file:
                    file.write(code)
                    filename = file.name
            else:
                filename = "<string>"

            exec_locals = {"slots": self.state.slots, **_ValueCompiler.helpers}
            exec(compile(code, filename, "exec"), exec_locals)
            domain_process.run = exec_locals["run"]

            processes.add(domain_process)

        # Use the shared state rather than a loop variable, which would be unbound for a
        # fragment without drivers.
        for used_signal in self.state.signals:
            add_signal_name(used_signal)

        for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
            if subfragment_name is None:
                subfragment_name = "U${}".format(subfragment_index)
            processes.update(self(subfragment, hierarchy=(*hierarchy, subfragment_name)))

        return processes
754
755
class _CoroutineProcess(_Process):
    """Process driven by a user-supplied generator that yields simulator commands.

    ``constructor`` is called on every reset to create a fresh generator. ``default_cmd``
    is used when the generator yields ``None``.
    """

    def __init__(self, state, domains, constructor, *, default_cmd=None):
        self.state = state
        self.domains = domains
        self.constructor = constructor
        self.default_cmd = default_cmd
        self.reset()

    def reset(self):
        # Drop wait registrations left over from before the reset. Otherwise the old signal
        # states would keep waking this process, and its next Tick would trip the
        # `self not in signal_state.waiters` assertion in get_in_signal().
        for signal_state in getattr(self, "waits_on", ()):
            signal_state.waiters.pop(self, None)
        self.runnable = True
        self.passive = False
        self.coroutine = self.constructor()
        self.exec_locals = {
            "slots": self.state.slots,
            "result": None,
            **_ValueCompiler.helpers
        }
        self.waits_on = set()

    def src_loc(self):
        """Return ``"file:line"`` of the innermost frame the process is suspended in."""
        coroutine = self.coroutine
        frame = None
        # Follow `yield from` / `await` delegation down to the innermost frame.
        while coroutine is not None:
            if inspect.isgenerator(coroutine):
                frame = coroutine.gi_frame
                coroutine = coroutine.gi_yieldfrom
            elif inspect.iscoroutine(coroutine):
                frame = coroutine.cr_frame
                coroutine = coroutine.cr_await
            else:
                # Delegating to a plain iterator (e.g. `yield from iter(...)`); it has no frame.
                break
        if frame is None:
            return "<unknown>"
        return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def get_in_signal(self, signal, *, trigger=None):
        """Wait on ``signal`` (optionally only for value ``trigger``) until the next run."""
        signal_state = self.state.for_signal(signal)
        assert self not in signal_state.waiters
        signal_state.waiters[self] = trigger
        self.waits_on.add(signal_state)
        return signal_state

    def run(self):
        """Resume the generator and execute its commands until it has to wait."""
        if self.coroutine is None:
            return

        for signal_state in self.waits_on:
            del signal_state.waiters[self]
        self.waits_on.clear()

        response = None
        exception = None
        while True:
            # Resume the process, either with the response to its last command or by
            # raising the error that command caused at the process's `yield`.
            try:
                if exception is None:
                    command = self.coroutine.send(response)
                else:
                    command = self.coroutine.throw(exception)
            except StopIteration:
                self.passive = True
                self.coroutine = None
                return

            try:
                if command is None:
                    command = self.default_cmd
                response = None
                exception = None

                if isinstance(command, Value):
                    exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
                        self.exec_locals)
                    response = Const.normalize(self.exec_locals["result"], command.shape())

                elif isinstance(command, Statement):
                    exec(_StatementCompiler.compile(self.state, command),
                        self.exec_locals)

                elif type(command) is Tick:
                    domain = command.domain
                    if isinstance(domain, ClockDomain):
                        pass
                    elif domain in self.domains:
                        domain = self.domains[domain]
                    else:
                        raise NameError("Received command {!r} that refers to a nonexistent "
                                        "domain {!r} from process {!r}"
                                        .format(command, command.domain, self.src_loc()))
                    self.get_in_signal(domain.clk, trigger=1 if domain.clk_edge == "pos" else 0)
                    if domain.rst is not None and domain.async_reset:
                        self.get_in_signal(domain.rst, trigger=1)
                    return

                elif type(command) is Settle:
                    self.state.deadlines[self] = None
                    return

                elif type(command) is Delay:
                    if command.interval is None:
                        self.state.deadlines[self] = None
                    else:
                        self.state.deadlines[self] = self.state.timestamp + command.interval
                    return

                elif type(command) is Passive:
                    self.passive = True

                elif type(command) is Active:
                    self.passive = False

                elif command is None: # only possible if self.default_cmd is None
                    raise TypeError("Received default command from process {!r} that was added "
                                    "with add_process(); did you mean to add this process with "
                                    "add_sync_process() instead?"
                                    .format(self.src_loc()))

                else:
                    raise TypeError("Received unsupported command {!r} from process {!r}"
                                    .format(command, self.src_loc()))

            except Exception as exn:
                # Deliver the error inside the process so the traceback points at user code.
                # If the process handles it, its next command is processed normally.
                exception = exn
867
868
class Simulator:
    """Simulator for an nMigen fragment.

    The fragment's logic is compiled into passive processes. User processes are added with
    :meth:`add_process`, :meth:`add_sync_process` and :meth:`add_clock`, and the
    simulation is driven with :meth:`advance`, :meth:`run` or :meth:`run_until`.
    """

    def __init__(self, fragment):
        self._state = _SimulatorState()
        self._signal_names = SignalDict()
        self._fragment = Fragment.get(fragment, platform=None).prepare()
        self._processes = _FragmentCompiler(self._state, self._signal_names)(self._fragment)
        self._clocked = set()
        self._waveform_writers = []

    def _check_process(self, process):
        if not (inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process)):
            raise TypeError("Cannot add a process {!r} because it is not a generator function"
                            .format(process))
        return process

    def _add_coroutine_process(self, process, *, default_cmd):
        self._processes.add(_CoroutineProcess(self._state, self._fragment.domains, process,
                                              default_cmd=default_cmd))

    def add_process(self, process):
        process = self._check_process(process)
        def wrapper():
            # Only start a bench process after comb settling, so that the reset values are correct.
            yield Settle()
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=None)

    def add_sync_process(self, process, *, domain="sync"):
        process = self._check_process(process)
        def wrapper():
            # Only start a sync process after the first clock edge (or reset edge, if the domain
            # uses an asynchronous reset). This matches the behavior of synchronous FFs.
            yield Tick(domain)
            yield from process()
        return self._add_coroutine_process(wrapper, default_cmd=Tick(domain))

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a clock process.

        Adds a process that drives the clock signal of ``domain`` at a 50% duty cycle.

        Arguments
        ---------
        period : float
            Clock period. The process will toggle the ``domain`` clock signal every ``period / 2``
            seconds.
        phase : None or float
            Clock phase. The process will wait ``phase`` seconds before the first clock transition.
            If not specified, defaults to ``period / 2``.
        domain : str or ClockDomain
            Driven clock domain. If specified as a string, the domain with that name is looked up
            in the root fragment of the simulation.
        if_exists : bool
            If ``False`` (the default), raise an error if the driven domain is specified as
            a string and the root fragment does not have such a domain. If ``True``, do nothing
            in this case.
        """
        if isinstance(domain, ClockDomain):
            pass
        elif domain in self._fragment.domains:
            domain = self._fragment.domains[domain]
        elif if_exists:
            return
        else:
            raise ValueError("Domain {!r} is not present in simulation"
                             .format(domain))
        if domain in self._clocked:
            raise ValueError("Domain {!r} already has a clock driving it"
                             .format(domain.name))

        half_period = period / 2
        if phase is None:
            # By default, delay the first edge by half period. This causes any synchronous activity
            # to happen at a non-zero time, distinguishing it from the reset values in the waveform
            # viewer.
            phase = half_period
        def clk_process():
            yield Passive()
            yield Delay(phase)
            # Behave correctly if the process is added after the clock signal is manipulated, or if
            # its reset state is high.
            initial = (yield domain.clk)
            steps = (
                domain.clk.eq(~initial),
                Delay(half_period),
                domain.clk.eq(initial),
                Delay(half_period),
            )
            while True:
                yield from iter(steps)
        self._add_coroutine_process(clk_process, default_cmd=None)
        self._clocked.add(domain)

    def reset(self):
        """Reset the simulation.

        Assign the reset value to every signal in the simulation, and restart every user process.
        """
        self._state.reset()
        for process in self._processes:
            process.reset()

    def _real_step(self):
        """Step the simulation.

        Run every process and commit changes until a fixed point is reached. If there is
        an unstable combinatorial loop, this function will never return.
        """
        # Performs the two phases of a delta cycle in a loop:
        converged = False
        while not converged:
            # 1. eval: run and suspend every non-waiting process once, queueing signal changes
            for process in self._processes:
                if process.runnable:
                    process.runnable = False
                    process.run()

            # Record the values about to be committed. `curr` still holds the old value at
            # this point; the new value is in `next`.
            for waveform_writer in self._waveform_writers:
                for signal_state in self._state.pending:
                    waveform_writer.update(self._state.timestamp,
                        signal_state.signal, signal_state.next)

            # 2. commit: apply every queued signal change, waking up any waiting processes
            converged = self._state.commit()

    # TODO(nmigen-0.4): replace with _real_step
    @deprecated("instead of `sim.step()`, use `sim.advance()`")
    def step(self):
        return self.advance()

    def advance(self):
        """Advance the simulation.

        Run every process and commit changes until a fixed point is reached, then advance time
        to the closest deadline (if any). If there is an unstable combinatorial loop,
        this function will never return.

        Returns ``True`` if there are any active processes, ``False`` otherwise.
        """
        self._real_step()
        self._state.advance()
        return any(not process.passive for process in self._processes)

    def run(self):
        """Run the simulation while any processes are active.

        Processes added with :meth:`add_process` and :meth:`add_sync_process` are initially active,
        and may change their status using the ``yield Passive()`` and ``yield Active()`` commands.
        Processes compiled from HDL and added with :meth:`add_clock` are always passive.
        """
        while self.advance():
            pass

    def run_until(self, deadline, *, run_passive=False):
        """Run the simulation until it advances to ``deadline``.

        If ``run_passive`` is ``False``, the simulation also stops when there are no active
        processes, similar to :meth:`run`. Otherwise, the simulation will stop only after it
        advances to or past ``deadline``.

        If the simulation stops advancing, this function will never return.
        """
        assert self._state.timestamp <= deadline
        while (self.advance() or run_passive) and self._state.timestamp < deadline:
            pass

    @contextmanager
    def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
        """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file.

        This method returns a context manager. It can be used as: ::

            sim = Simulator(frag)
            sim.add_clock(1e-6)
            with sim.write_vcd("dump.vcd", "dump.gtkw"):
                sim.run_until(1e-3)

        Arguments
        ---------
        vcd_file : str or file-like object
            Verilog Value Change Dump file or filename.
        gtkw_file : str or file-like object
            GTKWave save file or filename.
        traces : iterable of Signal
            Signals to display traces for.
        """
        if self._state.timestamp != 0.0:
            raise ValueError("Cannot start writing waveforms after advancing simulation time")
        waveform_writer = _VCDWaveformWriter(self._signal_names,
            vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
        self._waveform_writers.append(waveform_writer)
        try:
            yield
        finally:
            # Unregister and close the writer even if the simulation raised, so that the
            # files are not left open and later steps do not keep writing to them.
            self._waveform_writers.remove(waveform_writer)
            waveform_writer.close(self._state.timestamp)