57480b2792d9dabc86b2ab273427f5024e36117e
[nmigen.git] / nmigen / back / pysim.py
1 import os
2 import tempfile
3 import warnings
4 import inspect
5 from contextlib import contextmanager
6 import itertools
7 from vcd import VCDWriter
8 from vcd.gtkw import GTKWSave
9
10 from .._utils import deprecated
11 from ..hdl.ast import *
12 from ..hdl.cd import *
13 from ..hdl.ir import *
14 from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter
15 from ..sim._cmds import *
16
17
# Public API of this module: the simulation command types (brought in by the star imports
# above) and the simulator itself.
__all__ = [
    "Settle",
    "Delay",
    "Tick",
    "Passive",
    "Active",
    "Simulator",
]
19
20
21 class _WaveformWriter:
22 def update(self, timestamp, signal, value):
23 raise NotImplementedError # :nocov:
24
25 def close(self, timestamp):
26 raise NotImplementedError # :nocov:
27
28
class _VCDWaveformWriter(_WaveformWriter):
    """Waveform writer producing a Value Change Dump and, optionally, a GTKWave save file.

    ``signal_names`` maps each signal to a set of hierarchical names; each name is a tuple
    whose last element is the variable name and whose other elements form its scope.
    ``vcd_file`` and ``gtkw_file`` are either file names (opened here for writing) or open
    text files; whichever are given are closed by :meth:`close`. ``traces`` lists the
    signals to display in the GTKWave save file.
    """

    @staticmethod
    def timestamp_to_vcd(timestamp):
        # VCD timestamps are integers counted in the "100 ps" timescale declared below.
        # Simulation time is accumulated in floating point (so e.g. 1e-6 may arrive as
        # 9.999999999999999e-07); round to the nearest tick instead of passing a float.
        return round(timestamp * (10 ** 10)) # 1/(100 ps)

    @staticmethod
    def decode_to_vcd(signal, value):
        # Replace whitespace in the decoded text, which would split the VCD value token.
        return signal.decoder(value).expandtabs().replace(" ", "_")

    def __init__(self, signal_names, *, vcd_file, gtkw_file=None, traces=()):
        if isinstance(vcd_file, str):
            vcd_file = open(vcd_file, "wt")
        if isinstance(gtkw_file, str):
            gtkw_file = open(gtkw_file, "wt")

        self.vcd_vars = SignalDict()
        self.vcd_file = vcd_file
        self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
            timescale="100 ps", comment="Generated by nMigen")

        self.gtkw_names = SignalDict()
        self.gtkw_file = gtkw_file
        self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)

        self.traces = []

        # Traced signals that the design does not name are placed directly in the top scope.
        trace_names = SignalDict()
        for trace in traces:
            if trace not in signal_names:
                trace_names[trace] = {("top", trace.name)}
            self.traces.append(trace)

        if self.vcd_writer is None:
            return

        for signal, names in itertools.chain(signal_names.items(), trace_names.items()):
            if signal.decoder:
                # Signals with a decoder are dumped as strings rather than as bit vectors.
                var_type = "string"
                var_size = 1
                var_init = self.decode_to_vcd(signal, signal.reset)
            else:
                var_type = "wire"
                var_size = signal.width
                var_init = signal.reset

            for (*var_scope, var_name) in names:
                # A KeyError from register_var is taken to mean the name is already in use
                # in this scope; retry with "$1", "$2", ... suffixes until registration works.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        vcd_var = self.vcd_writer.register_var(
                            scope=var_scope, name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init)
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

                if signal not in self.vcd_vars:
                    self.vcd_vars[signal] = set()
                self.vcd_vars[signal].add(vcd_var)

                # The first registered name is the one GTKWave traces refer to.
                if signal not in self.gtkw_names:
                    self.gtkw_names[signal] = (*var_scope, var_name_suffix)

    def update(self, timestamp, signal, value):
        vcd_vars = self.vcd_vars.get(signal)
        if vcd_vars is None:
            return

        vcd_timestamp = self.timestamp_to_vcd(timestamp)
        if signal.decoder:
            var_value = self.decode_to_vcd(signal, value)
        else:
            var_value = value
        for vcd_var in vcd_vars:
            self.vcd_writer.change(vcd_var, vcd_timestamp, var_value)

    def close(self, timestamp):
        if self.vcd_writer is not None:
            self.vcd_writer.close(self.timestamp_to_vcd(timestamp))

        # The save file refers to variables in the VCD dump. Without a dump there is nothing
        # to refer to (and `gtkw_names` is empty), so in that case only the files are closed.
        if self.gtkw_save is not None and self.vcd_writer is not None:
            # File-like objects such as io.StringIO have no `name` to point GTKWave at.
            vcd_name = getattr(self.vcd_file, "name", None)
            if vcd_name is not None:
                self.gtkw_save.dumpfile(vcd_name)
                self.gtkw_save.dumpfile_size(self.vcd_file.tell())

            self.gtkw_save.treeopen("top")
            for signal in self.traces:
                if len(signal) > 1 and not signal.decoder:
                    suffix = "[{}:0]".format(len(signal) - 1)
                else:
                    suffix = ""
                self.gtkw_save.trace(".".join(self.gtkw_names[signal]) + suffix)

        if self.vcd_file is not None:
            self.vcd_file.close()
        if self.gtkw_file is not None:
            self.gtkw_file.close()
129
130
131 class _Process:
132 __slots__ = ("runnable", "passive")
133
134 def reset(self):
135 raise NotImplementedError # :nocov:
136
137 def run(self):
138 raise NotImplementedError # :nocov:
139
140
141 class _Timeline:
142 def __init__(self):
143 self.now = 0.0
144 self.deadlines = dict()
145
146 def reset(self):
147 self.now = 0.0
148 self.deadlines.clear()
149
150 def at(self, run_at, process):
151 assert process not in self.deadlines
152 self.deadlines[process] = run_at
153
154 def delay(self, delay_by, process):
155 if delay_by is None:
156 run_at = self.now
157 else:
158 run_at = self.now + delay_by
159 self.at(run_at, process)
160
161 def advance(self):
162 nearest_processes = set()
163 nearest_deadline = None
164 for process, deadline in self.deadlines.items():
165 if deadline is None:
166 if nearest_deadline is not None:
167 nearest_processes.clear()
168 nearest_processes.add(process)
169 nearest_deadline = self.now
170 break
171 elif nearest_deadline is None or deadline <= nearest_deadline:
172 assert deadline >= self.now
173 if nearest_deadline is not None and deadline < nearest_deadline:
174 nearest_processes.clear()
175 nearest_processes.add(process)
176 nearest_deadline = deadline
177
178 if not nearest_processes:
179 return False
180
181 for process in nearest_processes:
182 process.runnable = True
183 del self.deadlines[process]
184 self.now = nearest_deadline
185
186 return True
187
188
189 class _SignalState:
190 __slots__ = ("signal", "curr", "next", "waiters", "pending")
191
192 def __init__(self, signal, pending):
193 self.signal = signal
194 self.pending = pending
195 self.waiters = dict()
196 self.curr = self.next = signal.reset
197
198 def set(self, value):
199 if self.next == value:
200 return
201 self.next = value
202 self.pending.add(self)
203
204 def commit(self):
205 if self.curr == self.next:
206 return False
207 self.curr = self.next
208
209 awoken_any = False
210 for process, trigger in self.waiters.items():
211 if trigger is None or trigger == self.curr:
212 process.runnable = awoken_any = True
213 return awoken_any
214
215
class _SimulatorState:
    """Mutable state of a simulation: time, per-signal slots, and queued signal changes."""

    def __init__(self):
        self.timeline = _Timeline()
        # Maps each signal to its index in `slots`; generated code addresses slots by index.
        self.signals = SignalDict()
        self.slots = []
        # _SignalState objects with a queued change; this set is shared with every slot.
        self.pending = set()

    def reset(self):
        """Rewind time to zero and restore every signal to its reset value.

        Pending deadlines and queued changes are dropped. Triggers (slot ``waiters``) are
        kept: those of compiled processes are registered only once, at compile time.
        """
        # Without this, deadlines scheduled before the reset remain pending. Restarted
        # processes would then trip the assertion in _Timeline.at, and time would not
        # restart from zero.
        self.timeline.reset()
        for signal, index in self.signals.items():
            self.slots[index].curr = self.slots[index].next = signal.reset
        self.pending.clear()

    def get_signal(self, signal):
        """Return the slot index of ``signal``, allocating a slot on first use."""
        try:
            return self.signals[signal]
        except KeyError:
            index = len(self.slots)
            self.slots.append(_SignalState(signal, self.pending))
            self.signals[signal] = index
            return index

    def add_trigger(self, process, signal, *, trigger=None):
        """Wake ``process`` when ``signal`` changes (to ``trigger``, unless it is None)."""
        index = self.get_signal(signal)
        assert (process not in self.slots[index].waiters or
                self.slots[index].waiters[process] == trigger)
        self.slots[index].waiters[process] = trigger

    def remove_trigger(self, process, signal):
        """Undo a previous :meth:`add_trigger` of ``process`` on ``signal``."""
        index = self.get_signal(signal)
        assert process in self.slots[index].waiters
        del self.slots[index].waiters[process]

    def commit(self):
        """Apply every queued change; return True if no process was woken (converged)."""
        converged = True
        for signal_state in self.pending:
            if signal_state.commit():
                converged = False
        self.pending.clear()
        return converged
255
256
257 class _Emitter:
258 def __init__(self):
259 self._buffer = []
260 self._suffix = 0
261 self._level = 0
262
263 def append(self, code):
264 self._buffer.append(" " * self._level)
265 self._buffer.append(code)
266 self._buffer.append("\n")
267
268 @contextmanager
269 def indent(self):
270 self._level += 1
271 yield
272 self._level -= 1
273
274 def flush(self, indent=""):
275 code = "".join(self._buffer)
276 self._buffer.clear()
277 return code
278
279 def gen_var(self, prefix):
280 name = f"{prefix}_{self._suffix}"
281 self._suffix += 1
282 return name
283
284 def def_var(self, prefix, value):
285 name = self.gen_var(prefix)
286 self.append(f"{name} = {value}")
287 return name
288
289
290 class _Compiler:
291 def __init__(self, state, emitter):
292 self.state = state
293 self.emitter = emitter
294
295
class _ValueCompiler(ValueVisitor, _Compiler):
    """Shared base of the value compilers.

    ``helpers`` holds the functions that generated code refers to by name; it is merged into
    the namespace that generated code is executed in.
    """
    helpers = {
        # `sign` receives -(1 << (width - 1)): the sign bit and every bit above it set. For a
        # value already masked to `width` bits, OR-ing it in sign-extends the value.
        "sign": lambda value, sign: (value | sign) if (value & sign) != 0 else value,
        # Division and remainder by zero produce 0 rather than raising ZeroDivisionError.
        "zdiv": lambda lhs, rhs: lhs // rhs if rhs != 0 else 0,
        "zmod": lambda lhs, rhs: lhs % rhs if rhs != 0 else 0,
    }

    def _unsupported(self, value):
        # Value kinds that this simulator does not handle.
        raise NotImplementedError # :nocov:

    on_ClockSignal = _unsupported
    on_ResetSignal = _unsupported
    on_AnyConst    = _unsupported
    on_AnySeq      = _unsupported
    on_Sample      = _unsupported
    on_Initial     = _unsupported
320
321
class _RHSValueCompiler(_ValueCompiler):
    """Compile an nMigen value into a Python expression that evaluates it.

    With ``mode="curr"``, signals are read from ``slots[i].curr``; with ``mode="next"``,
    from the ``next_i`` locals that statement compilation sets up. Values that need
    auxiliary statements (``Repl``, ``ArrayProxy``) emit them into ``emitter`` before the
    returned expression is used.

    A returned expression may have arbitrary bits above the width of the value it computes
    (e.g. ``~`` is not masked). Operands are therefore passed through :meth:`_mask` or
    :meth:`_sign` wherever those upper bits could leak into a result.
    """
    def __init__(self, state, emitter, *, mode, inputs=None):
        super().__init__(state, emitter)
        assert mode in ("curr", "next")
        self.mode = mode
        # If not None, `inputs` gets populated with RHS signals.
        self.inputs = inputs

    def on_Const(self, value):
        return f"{value.value}"

    def on_Signal(self, value):
        if self.inputs is not None:
            self.inputs.add(value)

        if self.mode == "curr":
            return f"slots[{self.state.get_signal(value)}].{self.mode}"
        else:
            return f"next_{self.state.get_signal(value)}"

    def _mask(self, value):
        """Expression for the bits of ``value`` within its width, as a non-negative integer."""
        value_mask = (1 << len(value)) - 1
        return f"({self(value)} & {value_mask})"

    def _sign(self, value):
        """Expression for the numeric value of ``value`` under its own shape.

        The value is masked to its width and, if its shape is signed, sign-extended. Use
        this for any operand that is combined with operands or results of another width.
        """
        if value.shape().signed:
            return f"sign({self._mask(value)}, {-1 << (len(value) - 1)})"
        else: # unsigned
            return self._mask(value)

    def on_Operator(self, value):
        mask, sign = self._mask, self._sign

        if len(value.operands) == 1:
            arg, = value.operands
            if value.operator == "~":
                return f"(~{self(arg)})"
            if value.operator == "-":
                # The result is one bit wider than `arg`, so bits of the operand above its own
                # width must not be negated along with it.
                return f"(-{sign(arg)})"
            if value.operator == "b":
                return f"bool({mask(arg)})"
            if value.operator == "r|":
                return f"({mask(arg)} != 0)"
            if value.operator == "r&":
                return f"({mask(arg)} == {(1 << len(arg)) - 1})"
            if value.operator == "r^":
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return f"(format({mask(arg)}, 'b').count('1') % 2)"
            if value.operator in ("u", "s"):
                # These operators don't change the bit pattern, only its interpretation.
                return self(arg)
        elif len(value.operands) == 2:
            lhs, rhs = value.operands
            if value.operator == "+":
                return f"({sign(lhs)} + {sign(rhs)})"
            if value.operator == "-":
                return f"({sign(lhs)} - {sign(rhs)})"
            if value.operator == "*":
                return f"({sign(lhs)} * {sign(rhs)})"
            if value.operator == "//":
                return f"zdiv({sign(lhs)}, {sign(rhs)})"
            if value.operator == "%":
                return f"zmod({sign(lhs)}, {sign(rhs)})"
            # The operands of bitwise operators may differ in width and signedness. Each is
            # extended from its own shape first, so that stray upper bits of the narrower
            # operand do not end up in the wider result.
            if value.operator == "&":
                return f"({sign(lhs)} & {sign(rhs)})"
            if value.operator == "|":
                return f"({sign(lhs)} | {sign(rhs)})"
            if value.operator == "^":
                return f"({sign(lhs)} ^ {sign(rhs)})"
            if value.operator == "<<":
                return f"({sign(lhs)} << {sign(rhs)})"
            if value.operator == ">>":
                return f"({sign(lhs)} >> {sign(rhs)})"
            if value.operator == "==":
                return f"({sign(lhs)} == {sign(rhs)})"
            if value.operator == "!=":
                return f"({sign(lhs)} != {sign(rhs)})"
            if value.operator == "<":
                return f"({sign(lhs)} < {sign(rhs)})"
            if value.operator == "<=":
                return f"({sign(lhs)} <= {sign(rhs)})"
            if value.operator == ">":
                return f"({sign(lhs)} > {sign(rhs)})"
            if value.operator == ">=":
                return f"({sign(lhs)} >= {sign(rhs)})"
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = value.operands
                # Test only the bits of the select within its width (an unmasked `~x` would
                # always be truthy). Extend each arm from its own shape, because the arms may
                # differ in width and signedness.
                return f"({sign(val1)} if {mask(sel)} else {sign(val0)})"
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})"

    def on_Part(self, value):
        offset_mask = (1 << len(value.offset)) - 1
        offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})"
        return f"({self(value.value)} >> {offset} & " \
               f"{(1 << value.width) - 1})"

    def on_Cat(self, value):
        gen_parts = []
        offset = 0
        for part in value.parts:
            part_mask = (1 << len(part)) - 1
            gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})")
            offset += len(part)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_Repl(self, value):
        part_mask = (1 << len(value.value)) - 1
        # Evaluate the replicated value once into a local, then shift copies of it into place.
        gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}")
        gen_parts = []
        offset = 0
        for _ in range(value.count):
            gen_parts.append(f"({gen_part} << {offset})")
            offset += len(value.value)
        if gen_parts:
            return f"({' | '.join(gen_parts)})"
        return f"0"

    def on_ArrayProxy(self, value):
        index_mask = (1 << len(value.index)) - 1
        gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}")
        gen_value = self.emitter.gen_var("rhs_proxy")
        if value.elems:
            for index, elem in enumerate(value.elems):
                if index == 0:
                    self.emitter.append(f"if {gen_index} == {index}:")
                else:
                    self.emitter.append(f"elif {gen_index} == {index}:")
                with self.emitter.indent():
                    self.emitter.append(f"{gen_value} = {self(elem)}")
            # Out-of-range indexes select the last element.
            self.emitter.append(f"else:")
            with self.emitter.indent():
                self.emitter.append(f"{gen_value} = {self(value.elems[-1])}")
            return gen_value
        else:
            return f"0"

    @classmethod
    def compile(cls, state, value, *, mode, inputs=None):
        """Return source code that stores the value of ``value`` into ``result``."""
        emitter = _Emitter()
        compiler = cls(state, emitter, mode=mode, inputs=inputs)
        emitter.append(f"result = {compiler(value)}")
        return emitter.flush()
471
472
class _LHSValueCompiler(_ValueCompiler):
    """Compile an nMigen lvalue into a generator of assignment code.

    Visiting a value returns a function ``gen(arg)``. Given a Python expression for the
    assigned value, it emits statements that update the ``next_i`` locals of the affected
    signals.
    """
    def __init__(self, state, emitter, *, rhs, outputs=None):
        super().__init__(state, emitter)
        # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
        # the offset of a Part.
        self.rrhs = rhs
        # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
        # update of an lvalue.
        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
        # If not None, `outputs` gets populated with signals on LHS.
        self.outputs = outputs

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        if self.outputs is not None:
            self.outputs.add(value)

        def gen(arg):
            # Truncate to the signal's width, sign-extending if its shape is signed.
            value_mask = (1 << len(value)) - 1
            if value.shape().signed:
                value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
            else: # unsigned
                value_sign = f"{arg} & {value_mask}"
            self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
        return gen

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        def gen(arg):
            # Read-modify-write: clear the slice's bits, then OR in the new ones.
            width_mask = (1 << (value.stop - value.start)) - 1
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"{~(width_mask << value.start)} | " \
                f"(({arg} & {width_mask}) << {value.start}))")
        return gen

    def on_Part(self, value):
        def gen(arg):
            # Same as on_Slice, with the bit offset computed at run time.
            width_mask = (1 << value.width) - 1
            offset_mask = (1 << len(value.offset)) - 1
            offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})"
            self(value.value)(f"({self.lrhs(value.value)} & " \
                f"~({width_mask} << {offset}) | " \
                f"(({arg} & {width_mask}) << {offset}))")
        return gen

    def on_Cat(self, value):
        def gen(arg):
            # Evaluate the assigned value once, then distribute its bits among the parts.
            gen_arg = self.emitter.def_var("cat", arg)
            offset = 0
            for part in value.parts:
                part_mask = (1 << len(part)) - 1
                self(part)(f"(({gen_arg} >> {offset}) & {part_mask})")
                offset += len(part)
        return gen

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        def gen(arg):
            index_mask = (1 << len(value.index)) - 1
            gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}")
            if value.elems:
                for index, elem in enumerate(value.elems):
                    if index == 0:
                        self.emitter.append(f"if {gen_index} == {index}:")
                    else:
                        self.emitter.append(f"elif {gen_index} == {index}:")
                    with self.emitter.indent():
                        self(elem)(arg)
                # Out-of-range indexes assign to the last element.
                self.emitter.append(f"else:")
                with self.emitter.indent():
                    self(value.elems[-1])(arg)
            else:
                self.emitter.append(f"pass")
        return gen

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        emitter = _Emitter()
        # __init__ takes the rvalue compiler rather than `inputs`; build one here, so that
        # `inputs` still collects the signals read by the lvalue (e.g. a Part offset).
        rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        compiler = cls(state, emitter, rhs=rhs, outputs=outputs)
        # NOTE(review): this only visits the lvalue. The returned assignment generator is not
        # invoked, so no assignment code is emitted. No caller is visible in this module.
        compiler(stmt)
        return emitter.flush()
562
563
class _StatementCompiler(StatementVisitor, _Compiler):
    """Translate nMigen statements into Python statements appended to the emitter.

    Right-hand sides read current values (``slots[i].curr``). Assignments update the local
    ``next_i`` variables, which the surrounding code must initialize and later commit.
    """
    def __init__(self, state, emitter, *, inputs=None, outputs=None):
        super().__init__(state, emitter)
        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)

    def on_statements(self, stmts):
        if not stmts:
            # An empty block still needs one statement to be valid Python.
            self.emitter.append("pass")
            return
        for stmt in stmts:
            self(stmt)

    def on_Assign(self, stmt):
        gen_assign = self.lhs(stmt.lhs)
        return gen_assign(self.rhs(stmt.rhs))

    @staticmethod
    def _gen_pattern_check(gen_test, pattern):
        # Condition matching `gen_test` against one case pattern; "-" bits are don't-cares,
        # cleared from both the mask and the compared value.
        if "-" not in pattern:
            return f"{gen_test} == {int(pattern, 2)}"
        mask = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
        value = int("".join("0" if bit == "-" else bit for bit in pattern), 2)
        return f"({gen_test} & {mask}) == {value}"

    def on_Switch(self, stmt):
        gen_test = self.emitter.def_var("test",
            f"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}")
        for index, (patterns, stmts) in enumerate(stmt.cases.items()):
            if patterns:
                condition = " or ".join(self._gen_pattern_check(gen_test, pattern)
                                        for pattern in patterns)
            else:
                # A case with no patterns is the default case.
                condition = "True"
            keyword = "if" if index == 0 else "elif"
            self.emitter.append(f"{keyword} {condition}:")
            with self.emitter.indent():
                self(stmts)

    def on_Assert(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Assume(self, stmt):
        raise NotImplementedError # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError # :nocov:

    @classmethod
    def compile(cls, state, stmt, *, inputs=None, outputs=None):
        """Return standalone source code that executes ``stmt`` and queues its results."""
        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
        emitter = _Emitter()
        # Start from the values already queued for the driven signals, so that partial
        # assignments (slices, parts) preserve the other bits.
        for signal_index in output_indexes:
            emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
        cls(state, emitter, inputs=inputs, outputs=outputs)(stmt)
        # Queue the results; they become current when the simulator state is committed.
        for signal_index in output_indexes:
            emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
        return emitter.flush()
622
623
class _CompiledProcess(_Process):
    """Process whose ``run`` is a function generated by ``_FragmentCompiler``."""
    __slots__ = ("state", "comb", "run")

    def __init__(self, state, *, comb):
        self.state, self.comb = state, comb
        # Placeholder until _FragmentCompiler stores the generated function here. Being a
        # slot, `run` shadows the _Process.run method.
        self.run = None
        self.reset()

    def reset(self):
        # Combinational processes start out runnable, so they are evaluated once up front;
        # clocked ones wait for their triggers. Neither keeps the simulation running.
        self.passive = True
        self.runnable = self.comb
636
637
class _FragmentCompiler:
    """Compile a fragment hierarchy into ``_CompiledProcess`` objects.

    One process is generated per (fragment, domain) pair. The hierarchical names of the
    signals each fragment reads or drives are recorded into ``signal_names`` (used for
    waveform output).
    """
    def __init__(self, state, signal_names):
        self.state = state
        self.signal_names = signal_names

    def __call__(self, fragment, *, hierarchy=("top",)):
        """Compile ``fragment`` and its subfragments; return the set of processes."""
        processes = set()

        def add_signal_name(signal):
            hierarchical_signal_name = (*hierarchy, signal.name)
            if signal not in self.signal_names:
                self.signal_names[signal] = {hierarchical_signal_name}
            else:
                self.signal_names[signal].add(hierarchical_signal_name)

        for domain_name, domain_signals in fragment.drivers.items():
            domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
            domain_process = _CompiledProcess(self.state, comb=domain_name is None)
            # Signals read by this domain's statements. These are named in this fragment's
            # scope; for the comb domain they are also the process triggers.
            used_signals = SignalSet()

            emitter = _Emitter()
            emitter.append(f"def run():")
            emitter._level += 1

            if domain_name is None:
                # Combinational outputs start from their reset value on every evaluation.
                for signal in domain_signals:
                    signal_index = domain_process.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = {signal.reset}")

                _StatementCompiler(domain_process.state, emitter,
                                   inputs=used_signals)(domain_stmts)

                # Re-evaluate whenever any input changes.
                for input_signal in used_signals:
                    self.state.add_trigger(domain_process, input_signal)

            else:
                domain = fragment.domains[domain_name]
                add_signal_name(domain.clk)
                if domain.rst is not None:
                    add_signal_name(domain.rst)

                clk_trigger = 1 if domain.clk_edge == "pos" else 0
                self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                if domain.rst is not None and domain.async_reset:
                    rst_trigger = 1
                    self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)

                # Sanity check in generated code: the process only runs on its active edge
                # (or on assertion of an asynchronous reset).
                gen_asserts = []
                clk_index = domain_process.state.get_signal(domain.clk)
                gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                if domain.rst is not None and domain.async_reset:
                    rst_index = domain_process.state.get_signal(domain.rst)
                    gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                emitter.append(f"assert {' or '.join(gen_asserts)}")

                # Synchronous outputs start from their currently queued value.
                for signal in domain_signals:
                    signal_index = domain_process.state.get_signal(signal)
                    emitter.append(f"next_{signal_index} = slots[{signal_index}].next")

                _StatementCompiler(domain_process.state, emitter,
                                   inputs=used_signals)(domain_stmts)

            # Both kinds of process finish by queueing their outputs.
            for signal in domain_signals:
                signal_index = domain_process.state.get_signal(signal)
                emitter.append(f"slots[{signal_index}].set(next_{signal_index})")

            # There shouldn't be any exceptions raised by the generated code, but if there are
            # (almost certainly due to a bug in the code generator), use this environment variable
            # to make backtraces useful.
            code = emitter.flush()
            if os.getenv("NMIGEN_pysim_dump"):
                # Close the file so its contents reach the disk and can be shown in
                # backtraces; delete=False keeps it around afterwards.
                with tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_",
                                                 delete=False) as file:
                    file.write(code)
                filename = file.name
            else:
                filename = "<string>"

            exec_locals = {"slots": domain_process.state.slots, **_ValueCompiler.helpers}
            exec(compile(code, filename, "exec"), exec_locals)
            domain_process.run = exec_locals["run"]

            processes.add(domain_process)

            # Name only the signals this fragment uses. The shared state also contains
            # signals of other fragments, which do not belong in this scope.
            for signal in used_signals:
                add_signal_name(signal)
            for signal in domain_signals:
                add_signal_name(signal)

        for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
            if subfragment_name is None:
                subfragment_name = "U${}".format(subfragment_index)
            processes.update(self(subfragment, hierarchy=(*hierarchy, subfragment_name)))

        return processes
728
729
class _CoroutineProcess(_Process):
    """Run a user-supplied generator as a simulation process.

    ``constructor`` is called on creation and on every reset to obtain the generator. The
    generator yields commands:

    * a ``Value``, whose current value is sent back;
    * a ``Statement``, executed immediately;
    * ``Tick``, ``Settle``, ``Delay``, ``Passive`` or ``Active``.

    Yielding ``None`` substitutes ``default_cmd``.
    """
    def __init__(self, state, domains, constructor, *, default_cmd=None):
        self.state = state
        self.domains = domains
        self.constructor = constructor
        self.default_cmd = default_cmd
        # Signals this process has wakeup triggers registered on; see `_add_trigger`.
        self.waits_on = SignalSet()
        self.reset()

    def reset(self):
        self.runnable = True
        self.passive = False
        self.coroutine = self.constructor()
        self.exec_locals = {
            "slots": self.state.slots,
            "result": None,
            **_ValueCompiler.helpers
        }
        # Triggers registered for the previous generator would otherwise wake the new one.
        self._clear_triggers()

    def _add_trigger(self, signal, trigger=None):
        """Register a wakeup trigger on ``signal`` and remember it for `_clear_triggers`."""
        self.state.add_trigger(self, signal, trigger=trigger)
        self.waits_on.add(signal)

    def _clear_triggers(self):
        """Unregister every wakeup trigger added with `_add_trigger`."""
        for signal in self.waits_on:
            self.state.remove_trigger(self, signal)
        self.waits_on = SignalSet()

    def src_loc(self):
        """Return "file:line" of the innermost generator frame that is currently suspended."""
        coroutine = self.coroutine
        while True:
            if inspect.isgenerator(coroutine):
                inner = coroutine.gi_yieldfrom
            elif inspect.iscoroutine(coroutine):
                inner = coroutine.cr_await
            else:
                break
            # Stop at delegates that have no frame of their own, e.g. `yield from iter(...)`.
            if not (inspect.isgenerator(inner) or inspect.iscoroutine(inner)):
                break
            coroutine = inner
        if inspect.isgenerator(coroutine):
            frame = coroutine.gi_frame
        else:
            frame = coroutine.cr_frame
        return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def run(self):
        if self.coroutine is None:
            return

        # Whatever woke the process ends its previous wait. Stop listening to the triggers
        # registered for that wait, or later clock edges would wake the process spuriously
        # (e.g. while it waits on a Delay).
        self._clear_triggers()

        response = None
        while True:
            try:
                command = self.coroutine.send(response)
                if command is None:
                    command = self.default_cmd
                response = None

                if isinstance(command, Value):
                    exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
                        self.exec_locals)
                    response = Const.normalize(self.exec_locals["result"], command.shape())

                elif isinstance(command, Statement):
                    exec(_StatementCompiler.compile(self.state, command),
                        self.exec_locals)

                elif type(command) is Tick:
                    domain = command.domain
                    if isinstance(domain, ClockDomain):
                        pass
                    elif domain in self.domains:
                        domain = self.domains[domain]
                    else:
                        raise NameError("Received command {!r} that refers to a nonexistent "
                                        "domain {!r} from process {!r}"
                                        .format(command, command.domain, self.src_loc()))
                    self._add_trigger(domain.clk,
                        trigger=1 if domain.clk_edge == "pos" else 0)
                    if domain.rst is not None and domain.async_reset:
                        self._add_trigger(domain.rst, trigger=1)
                    return

                elif type(command) is Settle:
                    self.state.timeline.delay(None, self)
                    return

                elif type(command) is Delay:
                    self.state.timeline.delay(command.interval, self)
                    return

                elif type(command) is Passive:
                    self.passive = True

                elif type(command) is Active:
                    self.passive = False

                elif command is None: # only possible if self.default_cmd is None
                    raise TypeError("Received default command from process {!r} that was added "
                                    "with add_process(); did you mean to add this process with "
                                    "add_sync_process() instead?"
                                    .format(self.src_loc()))

                else:
                    raise TypeError("Received unsupported command {!r} from process {!r}"
                                    .format(command, self.src_loc()))

            except StopIteration:
                self.passive = True
                self.coroutine = None
                return

            except Exception as exn:
                # Re-raise the error inside the generator, so that the user's traceback
                # points at the yield that caused it.
                self.coroutine.throw(exn)
832
833
class Simulator:
    """Simulate an nMigen design together with user-supplied processes.

    The design is elaborated, prepared and compiled into Python code on construction.
    """
    def __init__(self, fragment):
        self._state = _SimulatorState()
        self._signal_names = SignalDict()
        self._fragment = Fragment.get(fragment, platform=None).prepare()
        self._processes = _FragmentCompiler(self._state, self._signal_names)(self._fragment)
        self._clocked = set()
        self._waveform_writers = []

    def _check_process(self, process):
        if not (inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process)):
            raise TypeError("Cannot add a process {!r} because it is not a generator function"
                            .format(process))
        return process

    def _add_coroutine_process(self, process, *, default_cmd):
        self._processes.add(_CoroutineProcess(self._state, self._fragment.domains, process,
                                              default_cmd=default_cmd))

    def add_process(self, process):
        process = self._check_process(process)
        def wrapper():
            # Only start a bench process after comb settling, so that the reset values are correct.
            yield Settle()
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=None)

    def add_sync_process(self, process, *, domain="sync"):
        process = self._check_process(process)
        def wrapper():
            # Only start a sync process after the first clock edge (or reset edge, if the domain
            # uses an asynchronous reset). This matches the behavior of synchronous FFs.
            yield Tick(domain)
            yield from process()
        return self._add_coroutine_process(wrapper, default_cmd=Tick(domain))

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a clock process.

        Adds a process that drives the clock signal of ``domain`` at a 50% duty cycle.

        Arguments
        ---------
        period : float
            Clock period. The process will toggle the ``domain`` clock signal every ``period / 2``
            seconds.
        phase : None or float
            Clock phase. The process will wait ``phase`` seconds before the first clock transition.
            If not specified, defaults to ``period / 2``.
        domain : str or ClockDomain
            Driven clock domain. If specified as a string, the domain with that name is looked up
            in the root fragment of the simulation.
        if_exists : bool
            If ``False`` (the default), raise an error if the driven domain is specified as
            a string and the root fragment does not have such a domain. If ``True``, do nothing
            in this case.
        """
        if isinstance(domain, ClockDomain):
            pass
        elif domain in self._fragment.domains:
            domain = self._fragment.domains[domain]
        elif if_exists:
            return
        else:
            raise ValueError("Domain {!r} is not present in simulation"
                             .format(domain))
        if domain in self._clocked:
            raise ValueError("Domain {!r} already has a clock driving it"
                             .format(domain.name))

        half_period = period / 2
        if phase is None:
            # By default, delay the first edge by half period. This causes any synchronous activity
            # to happen at a non-zero time, distinguishing it from the reset values in the waveform
            # viewer.
            phase = half_period
        def clk_process():
            yield Passive()
            yield Delay(phase)
            # Behave correctly if the process is added after the clock signal is manipulated, or if
            # its reset state is high.
            initial = (yield domain.clk)
            steps = (
                domain.clk.eq(~initial),
                Delay(half_period),
                domain.clk.eq(initial),
                Delay(half_period),
            )
            while True:
                yield from iter(steps)
        self._add_coroutine_process(clk_process, default_cmd=None)
        self._clocked.add(domain)

    def reset(self):
        """Reset the simulation.

        Assign the reset value to every signal in the simulation, and restart every user process.
        """
        self._state.reset()
        for process in self._processes:
            process.reset()

    def _real_step(self):
        """Step the simulation.

        Run every process and commit changes until a fixed point is reached. If there is
        an unstable combinatorial loop, this function will never return.
        """
        # Performs the two phases of a delta cycle in a loop:
        converged = False
        while not converged:
            # 1. eval: run and suspend every non-waiting process once, queueing signal changes
            for process in self._processes:
                if process.runnable:
                    process.runnable = False
                    process.run()

            # Record the values about to be committed. `curr` still holds the previous value
            # at this point, which would make the waveform lag one change behind.
            for waveform_writer in self._waveform_writers:
                for signal_state in self._state.pending:
                    waveform_writer.update(self._state.timeline.now,
                        signal_state.signal, signal_state.next)

            # 2. commit: apply every queued signal change, waking up any waiting processes
            converged = self._state.commit()

    # TODO(nmigen-0.4): replace with _real_step
    @deprecated("instead of `sim.step()`, use `sim.advance()`")
    def step(self):
        return self.advance()

    def advance(self):
        """Advance the simulation.

        Run every process and commit changes until a fixed point is reached, then advance time
        to the closest deadline (if any). If there is an unstable combinatorial loop,
        this function will never return.

        Returns ``True`` if there are any active processes, ``False`` otherwise.
        """
        self._real_step()
        self._state.timeline.advance()
        return any(not process.passive for process in self._processes)

    def run(self):
        """Run the simulation while any processes are active.

        Processes added with :meth:`add_process` and :meth:`add_sync_process` are initially active,
        and may change their status using the ``yield Passive()`` and ``yield Active()`` commands.
        Processes compiled from HDL and added with :meth:`add_clock` are always passive.
        """
        while self.advance():
            pass

    def run_until(self, deadline, *, run_passive=False):
        """Run the simulation until it advances to ``deadline``.

        If ``run_passive`` is ``False``, the simulation also stops when there are no active
        processes, similar to :meth:`run`. Otherwise, the simulation will stop only after it
        advances to or past ``deadline``.

        If the simulation stops advancing, this function will never return.
        """
        assert self._state.timeline.now <= deadline
        while (self.advance() or run_passive) and self._state.timeline.now < deadline:
            pass

    @contextmanager
    def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
        """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file.

        This method returns a context manager. It can be used as: ::

            sim = Simulator(frag)
            sim.add_clock(1e-6)
            with sim.write_vcd("dump.vcd", "dump.gtkw"):
                sim.run_until(1e-3)

        Arguments
        ---------
        vcd_file : str or file-like object
            Verilog Value Change Dump file or filename.
        gtkw_file : str or file-like object
            GTKWave save file or filename.
        traces : iterable of Signal
            Signals to display traces for.
        """
        if self._state.timeline.now != 0.0:
            raise ValueError("Cannot start writing waveforms after advancing simulation time")
        waveform_writer = _VCDWaveformWriter(self._signal_names,
            vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
        self._waveform_writers.append(waveform_writer)
        try:
            yield
        finally:
            # Stop recording and close the files even if the body of the `with` raised.
            self._waveform_writers.remove(waveform_writer)
            waveform_writer.close(self._state.timeline.now)