back.pysim: simplify. NFC.
[nmigen.git] / nmigen / back / pysim.py
1 import os
2 import tempfile
3 import warnings
4 import inspect
5 from contextlib import contextmanager
6 import itertools
7 from vcd import VCDWriter
8 from vcd.gtkw import GTKWSave
9
10 from .._utils import deprecated
11 from ..hdl.ast import *
12 from ..hdl.cd import *
13 from ..hdl.ir import *
14 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
15
16
class Command:
    """Base class of every command a simulator process may yield."""
19
20
class Settle(Command):
    """Command: wait until the current delta cycles have settled."""

    def __repr__(self):
        text = "(settle)"
        return text
24
25
class Delay(Command):
    """Command: suspend the process for ``interval`` seconds (or one delta cycle if None)."""

    def __init__(self, interval=None):
        if interval is not None:
            interval = float(interval)
        self.interval = interval

    def __repr__(self):
        if self.interval is not None:
            return "(delay {:.3}us)".format(self.interval * 1e6)
        return "(delay ε)"
35
36
class Tick(Command):
    """Command: wait for the active clock edge of ``domain`` (a name or a ClockDomain)."""

    def __init__(self, domain="sync"):
        is_valid_domain = isinstance(domain, (str, ClockDomain))
        if not is_valid_domain:
            raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}"
                            .format(domain))
        # The combinatorial pseudo-domain has no clock to wait on.
        assert domain != "comb"
        self.domain = domain

    def __repr__(self):
        return "(tick {})".format(self.domain)
47
48
class Passive(Command):
    """Command: mark the yielding process as passive (it no longer keeps the run going)."""

    def __repr__(self):
        text = "(passive)"
        return text
52
53
class Active(Command):
    """Command: mark the yielding process as active again."""

    def __repr__(self):
        text = "(active)"
        return text
57
58
59 class _WaveformWriter:
60 def update(self, timestamp, signal, value):
61 raise NotImplementedError # :nocov:
62
63 def close(self, timestamp):
64 raise NotImplementedError # :nocov:
65
66
class _VCDWaveformWriter(_WaveformWriter):
    """Waveform writer producing a VCD dump and, optionally, a GTKWave save file.

    ``vcd_file`` and ``gtkw_file`` may each be a file name (opened here for writing) or a
    file-like object. Both are closed by :meth:`close`.
    """

    @staticmethod
    def timestamp_to_vcd(timestamp):
        # Convert seconds into units of the "100 ps" timescale declared below.
        return timestamp * (10 ** 10) # 1/(100 ps)

    @staticmethod
    def decode_to_vcd(signal, value):
        # VCD string values cannot contain whitespace.
        return signal.decoder(value).expandtabs().replace(" ", "_")

    def __init__(self, signal_names, *, vcd_file, gtkw_file=None, traces=()):
        if isinstance(vcd_file, str):
            vcd_file = open(vcd_file, "wt")
        if isinstance(gtkw_file, str):
            gtkw_file = open(gtkw_file, "wt")

        # Signal -> set of VCD variables that mirror it (one per hierarchical name).
        self.vcd_vars = SignalDict()
        self.vcd_file = vcd_file
        self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
            timescale="100 ps", comment="Generated by nMigen")

        # Signal -> hierarchical name used when adding a GTKWave trace.
        self.gtkw_names = SignalDict()
        self.gtkw_file = gtkw_file
        self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)

        self.traces = []

        # Traced signals that the design never named are placed in the "top" scope.
        trace_names = SignalDict()
        for trace in traces:
            if trace not in signal_names:
                trace_names[trace] = {("top", trace.name)}
            self.traces.append(trace)

        if self.vcd_writer is None:
            return

        for signal, names in itertools.chain(signal_names.items(), trace_names.items()):
            if signal.decoder:
                var_type = "string"
                var_size = 1
                var_init = self.decode_to_vcd(signal, signal.reset)
            else:
                var_type = "wire"
                var_size = signal.width
                var_init = signal.reset

            for (*var_scope, var_name) in names:
                # Register the variable, appending "$N" until the name is unique in its scope
                # (the writer raises KeyError on a duplicate).
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        vcd_var = self.vcd_writer.register_var(
                            scope=var_scope, name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init)
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

                if signal not in self.vcd_vars:
                    self.vcd_vars[signal] = set()
                self.vcd_vars[signal].add(vcd_var)

                # The first registered name is the one shown in GTKWave.
                if signal not in self.gtkw_names:
                    self.gtkw_names[signal] = (*var_scope, var_name_suffix)

    def update(self, timestamp, signal, value):
        vcd_vars = self.vcd_vars.get(signal)
        if vcd_vars is None:
            return

        vcd_timestamp = self.timestamp_to_vcd(timestamp)
        if signal.decoder:
            var_value = self.decode_to_vcd(signal, value)
        else:
            var_value = value
        for vcd_var in vcd_vars:
            self.vcd_writer.change(vcd_var, vcd_timestamp, var_value)

    def close(self, timestamp):
        """Finish the VCD dump, write the GTKWave save file, and close both files."""
        try:
            if self.vcd_writer is not None:
                self.vcd_writer.close(self.timestamp_to_vcd(timestamp))

            if self.gtkw_save is not None:
                # Only point the save file at a dump if one was actually written; `vcd_file`
                # is None when only a GTKWave file was requested.
                if self.vcd_file is not None:
                    self.gtkw_save.dumpfile(self.vcd_file.name)
                    self.gtkw_save.dumpfile_size(self.vcd_file.tell())

                self.gtkw_save.treeopen("top")
                for signal in self.traces:
                    # Names only exist for signals that got a VCD variable registered.
                    name = self.gtkw_names.get(signal)
                    if name is None:
                        continue
                    if len(signal) > 1 and not signal.decoder:
                        suffix = "[{}:0]".format(len(signal) - 1)
                    else:
                        suffix = ""
                    self.gtkw_save.trace(".".join(name) + suffix)
        finally:
            # Close the files even if finishing the dump failed.
            if self.vcd_file is not None:
                self.vcd_file.close()
            if self.gtkw_file is not None:
                self.gtkw_file.close()
167
168
169 class _Process:
170 __slots__ = ("runnable", "passive")
171
172 def reset(self):
173 raise NotImplementedError # :nocov:
174
175 def run(self):
176 raise NotImplementedError # :nocov:
177
178
179 class _SignalState:
180 __slots__ = ("signal", "curr", "next", "waiters", "pending")
181
182 def __init__(self, signal, pending):
183 self.signal = signal
184 self.pending = pending
185 self.waiters = dict()
186 self.curr = self.next = signal.reset
187
188 def set(self, value):
189 if self.next == value:
190 return
191 self.next = value
192 self.pending.add(self)
193
194 def commit(self):
195 if self.curr == self.next:
196 return False
197 self.curr = self.next
198
199 awoken_any = False
200 for process, trigger in self.waiters.items():
201 if trigger is None or trigger == self.curr:
202 process.runnable = awoken_any = True
203 return awoken_any
204
205
206 class _SimulatorState:
207 def __init__(self):
208 self.signals = SignalDict()
209 self.slots = []
210 self.pending = set()
211
212 self.timestamp = 0.0
213 self.deadlines = dict()
214
215 def reset(self):
216 for signal, index in self.signals.items():
217 self.slots[index].curr = self.slots[index].next = signal.reset
218 self.pending.clear()
219
220 self.timestamp = 0.0
221 self.deadlines.clear()
222
223 def get_signal(self, signal):
224 try:
225 return self.signals[signal]
226 except KeyError:
227 index = len(self.slots)
228 self.slots.append(_SignalState(signal, self.pending))
229 self.signals[signal] = index
230 return index
231
232 def add_trigger(self, process, signal, *, trigger=None):
233 index = self.get_signal(signal)
234 assert (process not in self.slots[index].waiters or
235 self.slots[index].waiters[process] == trigger)
236 self.slots[index].waiters[process] = trigger
237
238 def remove_trigger(self, process, signal):
239 index = self.get_signal(signal)
240 assert process in self.slots[index].waiters
241 del self.slots[index].waiters[process]
242
243 def commit(self):
244 converged = True
245 for signal_state in self.pending:
246 if signal_state.commit():
247 converged = False
248 self.pending.clear()
249 return converged
250
251 def advance(self):
252 nearest_processes = set()
253 nearest_deadline = None
254 for process, deadline in self.deadlines.items():
255 if deadline is None:
256 if nearest_deadline is not None:
257 nearest_processes.clear()
258 nearest_processes.add(process)
259 nearest_deadline = self.timestamp
260 break
261 elif nearest_deadline is None or deadline <= nearest_deadline:
262 assert deadline >= self.timestamp
263 if nearest_deadline is not None and deadline < nearest_deadline:
264 nearest_processes.clear()
265 nearest_processes.add(process)
266 nearest_deadline = deadline
267
268 if not nearest_processes:
269 return False
270
271 for process in nearest_processes:
272 process.runnable = True
273 del self.deadlines[process]
274 self.timestamp = nearest_deadline
275
276 return True
277
278
279 class _Emitter:
280 def __init__(self):
281 self._buffer = []
282 self._suffix = 0
283 self._level = 0
284
285 def append(self, code):
286 self._buffer.append(" " * self._level)
287 self._buffer.append(code)
288 self._buffer.append("\n")
289
290 @contextmanager
291 def indent(self):
292 self._level += 1
293 yield
294 self._level -= 1
295
296 def flush(self, indent=""):
297 code = "".join(self._buffer)
298 self._buffer.clear()
299 return code
300
301 def gen_var(self, prefix):
302 name = f"{prefix}_{self._suffix}"
303 self._suffix += 1
304 return name
305
306 def def_var(self, prefix, value):
307 name = self.gen_var(prefix)
308 self.append(f"{name} = {value}")
309 return name
310
311
312 class _Compiler:
313 def __init__(self, state, emitter):
314 self.state = state
315 self.emitter = emitter
316
317
class _ValueCompiler(ValueVisitor, _Compiler):
    """Shared base of the RHS and LHS value compilers.

    The value kinds routed to ``_unsupported`` are not handled by this compiler.
    """

    # Runtime support functions injected into the namespace of generated code.
    helpers = {
        # Sign-extend a masked value: `sign` is -1 << (width - 1), so `value & sign` is
        # non-zero exactly when the sign bit is set, and OR-ing sets every bit above it.
        "sign": lambda value, sign: value | sign if value & sign else value,
        # Division and modulo by zero evaluate to 0 instead of raising.
        "zdiv": lambda lhs, rhs: 0 if rhs == 0 else lhs // rhs,
        "zmod": lambda lhs, rhs: 0 if rhs == 0 else lhs % rhs,
    }

    def _unsupported(self, value):
        raise NotImplementedError # :nocov:

    on_ClockSignal = _unsupported
    on_ResetSignal = _unsupported
    on_AnyConst = _unsupported
    on_AnySeq = _unsupported
    on_Sample = _unsupported
    on_Initial = _unsupported
340 def on_Initial(self, value):
341 raise NotImplementedError # :nocov:
342
343
class _RHSValueCompiler(_ValueCompiler):
    """Compile an nMigen value into a Python expression that computes it.

    Generated expressions are only guaranteed correct in the low ``len(value)`` bits (e.g.
    ``~`` sets every bit above the width). Any consumer whose result depends on more than
    those bits therefore normalizes its operands first, with ``_gen_mask`` (zero-extend) or
    ``_gen_sign`` (extend according to the operand's signedness).
    """

    # Binary operators emitted as the Python infix operator with the same spelling.
    _INFIX_OPERATORS = frozenset({
        "+", "-", "*", "&", "|", "^", "<<", ">>", "==", "!=", "<", "<=", ">", ">=",
    })

    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        # "curr" reads committed values from `slots`; "next" reads the `next_N` locals of
        # the statement being compiled.
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def _gen_mask(self, value):
        """Expression for the bits of ``value``, zero-extended (never negative)."""
        value_mask = (1 << len(value)) - 1
        return f"({self(value)} & {value_mask})"

    def _gen_sign(self, value):
        """Expression for the exact integer value of ``value`` given its signedness."""
        if value.shape().signed:
            return f"sign({self._gen_mask(value)}, {-1 << (len(value) - 1)})"
        else: # unsigned
            return self._gen_mask(value)

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def on_Operator(self, value):
        op = value.operator
        if len(value.operands) == 1:
            arg, = value.operands
            if op == "~":
                # Low bits are correct; consumers normalize the rest.
                return f"(~{self(arg)})"
            if op == "-":
                # Negation widens the result, so the operand must be exact.
                return f"(-{self._gen_sign(arg)})"
            if op == "b":
                return f"bool({self._gen_mask(arg)})"
            if op == "r|":
                return f"({self._gen_mask(arg)} != 0)"
            if op == "r&":
                return f"({self._gen_mask(arg)} == {(1 << len(arg)) - 1})"
            if op == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({self._gen_mask(arg)}, 'b').count('1') % 2)"
            if op in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if op in self._INFIX_OPERATORS:
                # Operands are extended to the result width per their own signedness,
                # which is exactly what Python's unbounded integers do on exact values.
                return f"({self._gen_sign(lhs)} {op} {self._gen_sign(rhs)})"
            if op == "//":
                return f"zdiv({self._gen_sign(lhs)}, {self._gen_sign(rhs)})"
            if op == "%":
                return f"zmod({self._gen_sign(lhs)}, {self._gen_sign(rhs)})"
        elif len(value.operands) == 3:
            if op == "m":
                sel, val1, val0 = value.operands
                # `sel` must be masked: e.g. `~s` of a 1-bit `s` == 1 evaluates to -2, which
                # Python would treat as true.
                return (f"({self._gen_sign(val1)} if {self._gen_mask(sel)} "
                        f"else {self._gen_sign(val0)})")
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
        return f"({self(value.value)} >> {offset} & " \
               f"{(1 << value.width) - 1})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if not value.elems:
            return f"0"
        # Elements may be narrower than the proxy, so store their exact values.
        for index, elem in enumerate(value.elems):
            if index == 0:
                self.emitter.append(f"if {gen_index} == {index}:")
            else:
                self.emitter.append(f"elif {gen_index} == {index}:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self._gen_sign(elem)}")
        # Out-of-range indexes select the last element.
        self.emitter.append(f"else:")
        with self.emitter.indent():
            self.emitter.append(f"{gen_value} = {self._gen_sign(value.elems[-1])}")
        return gen_value

    @classmethod
    def compile(cls, state, value, *, mode, inputs=None):
        """Return source that evaluates ``value`` into the local variable ``result``."""
        emitter = _Emitter()
        compiler = cls(state, emitter, mode=mode, inputs=inputs)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
491 emitter.append(f"result = {compiler(value)}")
492 return emitter.flush()
493
494
class _LHSValueCompiler(_ValueCompiler):
    """Compile an nMigen lvalue into a generator of assignment code.

    Visiting a value returns a callable; calling it with a Python expression emits code
    that stores that expression into the `next_N` locals of the signals being assigned.
    """

    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Normalize the stored value to the signal's width and signedness.
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{arg} & {value_mask}"
            self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Replace bits [start, stop) of the target, keeping the others.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"{~(width_mask << value.start)} | " \
                f"(({arg} & {width_mask}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"~({width_mask} << {offset}) | " \
                f"(({arg} & {width_mask}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the RHS once, then distribute its bit fields to the parts.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {gen_index} == {index}:")
                    else:
                        self.emitter.append(f"elif {gen_index} == {index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # Out-of-range indexes assign the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        emitter = _Emitter()
        # __init__ takes an RHS compiler (for rvalues embedded in the lvalue) rather than
        # `inputs`; build one here and forward `inputs` to it.
        rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        compiler = cls(state, emitter, rhs=rhs, outputs=outputs)
        # NOTE(review): visiting only builds the assignment generator; no code is emitted
        # unless the returned callable is invoked with an RHS expression.
        compiler(stmt)
        return emitter.flush()
582 compiler(stmt)
583 return emitter.flush()
584
585
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translate nMigen statements into Python source appended to an emitter."""

    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        if not stmts:
            # An empty suite is not valid Python.
            self.emitter.append("pass")
            return
        for stmt in stmts:
            self(stmt)

    def on_Assign(self, stmt):
        store = self.lhs(stmt.lhs)
        return store(self.rhs(stmt.rhs))

    @staticmethod
    def _gen_pattern_check(gen_test, pattern):
        """Return a condition matching ``gen_test`` against a 0/1/- ``pattern``."""
        if "-" not in pattern:
            return f"{gen_test} == {int(pattern, 2)}"
        care_mask = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
        care_value = int("".join("0" if bit == "-" else bit for bit in pattern), 2)
        return f"({gen_test} & {care_mask}) == {care_value}"

    def on_Switch(self, stmt):
        test_mask = (1 << len(stmt.test)) - 1
        gen_test = self.emitter.def_var("test", f"{self.rhs(stmt.test)} & {test_mask}")
        for position, (patterns, body) in enumerate(stmt.cases.items()):
            # A case without patterns is the default case.
            conditions = [self._gen_pattern_check(gen_test, pattern) for pattern in patterns]
            conditions = conditions or ["True"]
            keyword = "if" if position == 0 else "elif"
            self.emitter.append(f"{keyword} {' or '.join(conditions)}:")
            with self.emitter.indent():
                self(body)

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        """Return standalone source executing ``stmt`` against ``slots``."""
        indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _Emitter()
        # Start from each assigned signal's scheduled value, then write the results back.
        for index in indexes:
            emitter.append(f"next_{index} = slots[{index}].next")
        cls(state, emitter, inputs=inputs, outputs=outputs)(stmt)
        for index in indexes:
            emitter.append(f"slots[{index}].set(next_{index})")
        return emitter.flush()
642 emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
643 return emitter.flush()
644
645
class _CompiledProcess(_Process):
    """Process whose ``run`` is Python code generated from one domain of a fragment."""

    __slots__ = ("state", "comb", "run")

    def __init__(self, state, *, comb):
        self.comb = comb
        self.state = state
        # Installed by _FragmentCompiler once the code has been generated.
        self.run = None
        self.reset()

    def reset(self):
        # Compiled processes never keep the simulation running; only combinatorial ones
        # start out runnable.
        self.passive = True
        self.runnable = self.comb
656 self.runnable = self.comb
657 self.passive = True
658
659
class _FragmentCompiler:
    """Compile a fragment hierarchy into _CompiledProcess objects, one per driven domain.

    Also records hierarchical names for the signals each fragment uses into
    ``signal_names``.
    """

    def __init__(self, state, signal_names):
        self.state = state
        self.signal_names = signal_names

    def __call__(self, fragment, *, hierarchy=("top",)):
        processes = set()

        def add_signal_name(signal):
            hierarchical_signal_name = (*hierarchy, signal.name)
            if signal not in self.signal_names:
                self.signal_names[signal] = {hierarchical_signal_name}
            else:
                self.signal_names[signal].add(hierarchical_signal_name)

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = _CompiledProcess(self.state, comb=domain_name is None)

            emitter = _Emitter()
            emitter.append(f"def run():")
            emitter._level += 1

            # Signals read by this domain's statements: used to name signals in this scope,
            # and (for the comb domain) as wake-up triggers.
            inputs = SignalSet()

            if domain_name is None:
                # Combinatorial outputs are recomputed from their reset values on every run.
                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = {signal.reset}")

                _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

                for input in inputs:
                    self.state.add_trigger(domain_process, input)

            else:
                domain = fragment.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

                clk_trigger = 1 if domain.clk_edge == "pos" else 0
                self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                if domain.rst is not None and domain.async_reset:
                    rst_trigger = 1
                    self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)

                # Sanity check: the process must only run on its active clock/reset edge.
                gen_asserts = []
                clk_index = self.state.get_signal(domain.clk)
                gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                if domain.rst is not None and domain.async_reset:
                    rst_index = self.state.get_signal(domain.rst)
                    gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                emitter.append(f"assert {' or '.join(gen_asserts)}")

                for signal in domain_signals:
                    signal_index = self.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                _StatementCompiler(self.state, emitter, inputs=inputs)(domain_stmts)

            for signal in domain_signals:
                signal_index = self.state.get_signal(signal)
                emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            # There shouldn't be any exceptions raised by the generated code, but if there are
            # (almost certainly due to a bug in the code generator), use this environment variable
            # to make backtraces useful.
            code = emitter.flush()
            if os.getenv("NMIGEN_pysim_dump"):
                # Close (and so flush) the file so tracebacks can actually read the source.
                with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_",
                                                 delete=False) as file:
                    file.write(code)
                    filename = file.name
            else:
                filename = "<string>"

            exec_locals = {"slots": self.state.slots, **_ValueCompiler.helpers}
            exec(compile(code, filename, "exec"), exec_locals)
            domain_process.run = exec_locals["run"]

            processes.add(domain_process)

            # Name only the signals this fragment uses; the shared state also holds
            # signals of every previously compiled fragment.
            for used_signal in itertools.chain(domain_signals, inputs):
                add_signal_name(used_signal)

        for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
            if subfragment_name is None:
                subfragment_name = "U${}".format(subfragment_index)
            processes.update(self(subfragment, hierarchy=(*hierarchy, subfragment_name)))

        return processes
748
749 return processes
750
751
class _CoroutineProcess(_Process):
    """Process driven by a user generator that yields commands and values."""

    def __init__(self, state, domains, constructor, *, default_cmd=None):
        self.state = state
        self.domains = domains
        self.constructor = constructor
        self.default_cmd = default_cmd
        # Signals this process is currently registered as a waiter on. Signals are not
        # hashable, hence SignalSet rather than set.
        self.waits_on = SignalSet()
        self.reset()

    def _add_trigger(self, signal, *, trigger=None):
        self.state.add_trigger(self, signal, trigger=trigger)
        self.waits_on.add(signal)

    def _clear_triggers(self):
        for signal in self.waits_on:
            self.state.remove_trigger(self, signal)
        self.waits_on = SignalSet()

    def reset(self):
        # Drop triggers left over from a wait that was interrupted by the reset.
        self._clear_triggers()
        self.runnable = True
        self.passive = False
        self.coroutine = self.constructor()
        self.exec_locals = {
            "slots": self.state.slots,
            "result": None,
            **_ValueCompiler.helpers
        }

    def src_loc(self):
        """Return "file:line" of the innermost generator frame, for error messages."""
        coroutine = self.coroutine
        frame = None
        # Follow `yield from`/`await` delegation; stop at anything that is not a
        # generator or coroutine (e.g. `yield from` over a plain iterator).
        while coroutine is not None:
            if inspect.isgenerator(coroutine):
                frame, coroutine = coroutine.gi_frame, coroutine.gi_yieldfrom
            elif inspect.iscoroutine(coroutine):
                frame, coroutine = coroutine.cr_frame, coroutine.cr_await
            else:
                break
        if frame is None:
            return "<unknown>"
        return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def run(self):
        if self.coroutine is None:
            return

        # Whatever woke us up, the previous wait is over.
        self._clear_triggers()

        response = None
        exception = None
        while True:
            try:
                if exception is None:
                    command = self.coroutine.send(response)
                else:
                    # Re-raise a command-processing error inside the coroutine so that the
                    # traceback points at the offending `yield`; if the coroutine handles it,
                    # continue with whatever it yields next.
                    command = self.coroutine.throw(exception)
            except StopIteration:
                self.passive = True
                self.coroutine = None
                return

            response = None
            exception = None
            try:
                if command is None:
                    command = self.default_cmd

                if isinstance(command, Value):
                    exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
                         self.exec_locals)
                    response = Const.normalize(self.exec_locals["result"], command.shape())

                elif isinstance(command, Statement):
                    exec(_StatementCompiler.compile(self.state, command),
                         self.exec_locals)

                elif type(command) is Tick:
                    domain = command.domain
                    if isinstance(domain, ClockDomain):
                        pass
                    elif domain in self.domains:
                        domain = self.domains[domain]
                    else:
                        raise NameError("Received command {!r} that refers to a nonexistent "
                                        "domain {!r} from process {!r}"
                                        .format(command, command.domain, self.src_loc()))
                    self._add_trigger(domain.clk,
                                      trigger=1 if domain.clk_edge == "pos" else 0)
                    if domain.rst is not None and domain.async_reset:
                        self._add_trigger(domain.rst, trigger=1)
                    return

                elif type(command) is Settle:
                    self.state.deadlines[self] = None
                    return

                elif type(command) is Delay:
                    if command.interval is None:
                        self.state.deadlines[self] = None
                    else:
                        self.state.deadlines[self] = self.state.timestamp + command.interval
                    return

                elif type(command) is Passive:
                    self.passive = True

                elif type(command) is Active:
                    self.passive = False

                elif command is None: # only possible if self.default_cmd is None
                    raise TypeError("Received default command from process {!r} that was added "
                                    "with add_process(); did you mean to add this process with "
                                    "add_sync_process() instead?"
                                    .format(self.src_loc()))

                else:
                    raise TypeError("Received unsupported command {!r} from process {!r}"
                                    .format(command, self.src_loc()))

            except Exception as exn:
                exception = exn
855 except Exception as exn:
856 self.coroutine.throw(exn)
857
858
class Simulator:
    """Python-based simulator for nMigen designs."""

    def __init__(self, fragment):
        self._state = _SimulatorState()
        self._signal_names = SignalDict()
        self._fragment = Fragment.get(fragment, platform=None).prepare()
        self._processes = _FragmentCompiler(self._state, self._signal_names)(self._fragment)
        self._clocked = set()
        self._waveform_writers = []

    def _check_process(self, process):
        if not (inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process)):
            raise TypeError("Cannot add a process {!r} because it is not a generator function"
                            .format(process))
        return process

    def _add_coroutine_process(self, process, *, default_cmd):
        self._processes.add(_CoroutineProcess(self._state, self._fragment.domains, process,
                                              default_cmd=default_cmd))

    def add_process(self, process):
        process = self._check_process(process)
        def wrapper():
            # Only start a bench process after comb settling, so that the reset values are correct.
            yield Settle()
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=None)

    def add_sync_process(self, process, *, domain="sync"):
        process = self._check_process(process)
        def wrapper():
            # Only start a sync process after the first clock edge (or reset edge, if the domain
            # uses an asynchronous reset). This matches the behavior of synchronous FFs.
            yield Tick(domain)
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=Tick(domain))

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a clock process.

        Adds a process that drives the clock signal of ``domain`` at a 50% duty cycle.

        Arguments
        ---------
        period : float
            Clock period. The process will toggle the ``domain`` clock signal every ``period / 2``
            seconds.
        phase : None or float
            Clock phase. The process will wait ``phase`` seconds before the first clock transition.
            If not specified, defaults to ``period / 2``.
        domain : str or ClockDomain
            Driven clock domain. If specified as a string, the domain with that name is looked up
            in the root fragment of the simulation.
        if_exists : bool
            If ``False`` (the default), raise an error if the driven domain is specified as
            a string and the root fragment does not have such a domain. If ``True``, do nothing
            in this case.
        """
        if isinstance(domain, ClockDomain):
            pass
        elif domain in self._fragment.domains:
            domain = self._fragment.domains[domain]
        elif if_exists:
            return
        else:
            raise ValueError("Domain {!r} is not present in simulation"
                             .format(domain))
        if domain in self._clocked:
            raise ValueError("Domain {!r} already has a clock driving it"
                             .format(domain.name))

        half_period = period / 2
        if phase is None:
            # By default, delay the first edge by half period. This causes any synchronous activity
            # to happen at a non-zero time, distinguishing it from the reset values in the waveform
            # viewer.
            phase = half_period
        def clk_process():
            yield Passive()
            yield Delay(phase)
            # Behave correctly if the process is added after the clock signal is manipulated, or if
            # its reset state is high.
            initial = (yield domain.clk)
            steps = (
                domain.clk.eq(~initial),
                Delay(half_period),
                domain.clk.eq(initial),
                Delay(half_period),
            )
            while True:
                yield from iter(steps)
        self._add_coroutine_process(clk_process, default_cmd=None)
        self._clocked.add(domain)

    def reset(self):
        """Reset the simulation.

        Assign the reset value to every signal in the simulation, and restart every user process.
        """
        self._state.reset()
        for process in self._processes:
            process.reset()

    def _real_step(self):
        """Step the simulation.

        Run every process and commit changes until a fixed point is reached. If there is
        an unstable combinatorial loop, this function will never return.
        """
        # Performs the two phases of a delta cycle in a loop:
        converged = False
        while not converged:
            # 1. eval: run and suspend every non-waiting process once, queueing signal changes
            for process in self._processes:
                if process.runnable:
                    process.runnable = False
                    process.run()

            # Record the values about to be committed; `curr` still holds the old ones.
            for waveform_writer in self._waveform_writers:
                for signal_state in self._state.pending:
                    waveform_writer.update(self._state.timestamp,
                                           signal_state.signal, signal_state.next)

            # 2. commit: apply every queued signal change, waking up any waiting processes
            converged = self._state.commit()

    # TODO(nmigen-0.4): replace with _real_step
    @deprecated("instead of `sim.step()`, use `sim.advance()`")
    def step(self):
        return self.advance()

    def advance(self):
        """Advance the simulation.

        Run every process and commit changes until a fixed point is reached, then advance time
        to the closest deadline (if any). If there is an unstable combinatorial loop,
        this function will never return.

        Returns ``True`` if there are any active processes, ``False`` otherwise.
        """
        self._real_step()
        self._state.advance()
        return any(not process.passive for process in self._processes)

    def run(self):
        """Run the simulation while any processes are active.

        Processes added with :meth:`add_process` and :meth:`add_sync_process` are initially active,
        and may change their status using the ``yield Passive()`` and ``yield Active()`` commands.
        Processes compiled from HDL and added with :meth:`add_clock` are always passive.
        """
        while self.advance():
            pass

    def run_until(self, deadline, *, run_passive=False):
        """Run the simulation until it advances to ``deadline``.

        If ``run_passive`` is ``False``, the simulation also stops when there are no active
        processes, similar to :meth:`run`. Otherwise, the simulation will stop only after it
        advances to or past ``deadline``.

        If the simulation stops advancing, this function will never return.
        """
        assert self._state.timestamp <= deadline
        while (self.advance() or run_passive) and self._state.timestamp < deadline:
            pass

    @contextmanager
    def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
        """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file.

        This method returns a context manager. It can be used as: ::

            sim = Simulator(frag)
            sim.add_clock(1e-6)
            with sim.write_vcd("dump.vcd", "dump.gtkw"):
                sim.run_until(1e-3)

        Arguments
        ---------
        vcd_file : str or file-like object
            Verilog Value Change Dump file or filename.
        gtkw_file : str or file-like object
            GTKWave save file or filename.
        traces : iterable of Signal
            Signals to display traces for.
        """
        if self._state.timestamp != 0.0:
            raise ValueError("Cannot start writing waveforms after advancing simulation time")
        waveform_writer = _VCDWaveformWriter(self._signal_names,
            vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
        self._waveform_writers.append(waveform_writer)
        try:
            yield
        finally:
            # Detach and close the writer even if the simulation inside the block raised.
            self._waveform_writers.remove(waveform_writer)
            waveform_writer.close(self._state.timestamp)
1051 waveform_writer.close(self._state.timestamp)
1052 self._waveform_writers.remove(waveform_writer)