sim._pycoro: avoid spurious wakeups.
[nmigen.git] / nmigen / sim / pysim.py
1 from contextlib import contextmanager
2 import itertools
3 import inspect
4 from vcd import VCDWriter
5 from vcd.gtkw import GTKWSave
6
7 from .._utils import deprecated
8 from ..hdl import *
9 from ..hdl.ast import SignalDict
10 from ._cmds import *
11 from ._core import *
12 from ._pyrtl import _FragmentCompiler
13 from ._pycoro import PyCoroProcess
14
15
# Names exported by ``from ... import *`` on this module.
__all__ = [
    "Settle",
    "Delay",
    "Tick",
    "Passive",
    "Active",
    "Simulator",
]
17
18
19 class _NameExtractor:
20 def __init__(self):
21 self.names = SignalDict()
22
23 def __call__(self, fragment, *, hierarchy=("top",)):
24 def add_signal_name(signal):
25 hierarchical_signal_name = (*hierarchy, signal.name)
26 if signal not in self.names:
27 self.names[signal] = {hierarchical_signal_name}
28 else:
29 self.names[signal].add(hierarchical_signal_name)
30
31 for domain_name, domain_signals in fragment.drivers.items():
32 if domain_name is not None:
33 domain = fragment.domains[domain_name]
34 add_signal_name(domain.clk)
35 if domain.rst is not None:
36 add_signal_name(domain.rst)
37
38 for statement in fragment.statements:
39 for signal in statement._lhs_signals() | statement._rhs_signals():
40 if not isinstance(signal, (ClockSignal, ResetSignal)):
41 add_signal_name(signal)
42
43 for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
44 if subfragment_name is None:
45 subfragment_name = "U${}".format(subfragment_index)
46 self(subfragment, hierarchy=(*hierarchy, subfragment_name))
47
48 return self.names
49
50
51 class _WaveformWriter:
52 def update(self, timestamp, signal, value):
53 raise NotImplementedError # :nocov:
54
55 def close(self, timestamp):
56 raise NotImplementedError # :nocov:
57
58
class _VCDWaveformWriter(_WaveformWriter):
    """Waveform writer producing a VCD file and, optionally, a GTKWave save file.

    ``vcd_file`` and ``gtkw_file`` may each be a filename (opened for writing here) or an
    already-open file object. In both cases :meth:`close` closes them. ``traces`` lists
    signals to show in the GTKWave save file.
    """

    @staticmethod
    def timestamp_to_vcd(timestamp):
        """Convert a simulation time in seconds to an integer number of 100 ps VCD ticks.

        Simulation timestamps are floats and are generally inexact (e.g. ``2.9999999e-9``
        instead of ``3e-9``), so the result is rounded to the nearest tick rather than
        being left as a float that a consumer might truncate to the previous tick.
        """
        return round(timestamp * (10 ** 10)) # 1/(100 ps)

    @staticmethod
    def decode_to_vcd(signal, value):
        """Render ``value`` through ``signal.decoder`` as a string with no whitespace.

        Tabs are expanded and spaces replaced with underscores, since VCD tokens are
        whitespace-separated.
        """
        return signal.decoder(value).expandtabs().replace(" ", "_")

    def __init__(self, fragment, *, vcd_file, gtkw_file=None, traces=()):
        if isinstance(vcd_file, str):
            vcd_file = open(vcd_file, "wt")
        if isinstance(gtkw_file, str):
            gtkw_file = open(gtkw_file, "wt")

        # signal -> VCD variable (the first one registered for that signal)
        self.vcd_vars = SignalDict()
        self.vcd_file = vcd_file
        self.vcd_writer = vcd_file and VCDWriter(self.vcd_file,
            timescale="100 ps", comment="Generated by nMigen")

        # signal -> hierarchical name tuple used to reference it in the GTKWave save file
        self.gtkw_names = SignalDict()
        self.gtkw_file = gtkw_file
        self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file)

        self.traces = []

        signal_names = _NameExtractor()(fragment)

        # Traced signals that do not appear anywhere in the design are still dumped, under a
        # top-level name.
        trace_names = SignalDict()
        for trace in traces:
            if trace not in signal_names:
                trace_names[trace] = {("top", trace.name)}
            self.traces.append(trace)

        if self.vcd_writer is None:
            return

        for signal, names in itertools.chain(signal_names.items(), trace_names.items()):
            if signal.decoder:
                # Decoded signals are dumped as strings rather than bit vectors.
                var_type = "string"
                var_size = 1
                var_init = self.decode_to_vcd(signal, signal.reset)
            else:
                var_type = "wire"
                var_size = signal.width
                var_init = signal.reset

            for (*var_scope, var_name) in names:
                # The first name of a signal registers a VCD variable; every further name
                # becomes an alias of it. A KeyError from the writer is taken as a name clash
                # in that scope, and the name is retried with a "$1", "$2", ... suffix.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        if signal not in self.vcd_vars:
                            vcd_var = self.vcd_writer.register_var(
                                scope=var_scope, name=var_name_suffix,
                                var_type=var_type, size=var_size, init=var_init)
                            self.vcd_vars[signal] = vcd_var
                        else:
                            self.vcd_writer.register_alias(
                                scope=var_scope, name=var_name_suffix,
                                var=self.vcd_vars[signal])
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

                if signal not in self.gtkw_names:
                    self.gtkw_names[signal] = (*var_scope, var_name_suffix)

    def update(self, timestamp, signal, value):
        """Record a value change; signals without a VCD variable are ignored."""
        vcd_var = self.vcd_vars.get(signal)
        if vcd_var is None:
            return

        vcd_timestamp = self.timestamp_to_vcd(timestamp)
        if signal.decoder:
            var_value = self.decode_to_vcd(signal, value)
        else:
            var_value = value
        self.vcd_writer.change(vcd_var, vcd_timestamp, var_value)

    def close(self, timestamp):
        """Finalize the VCD (and GTKWave save) output at ``timestamp`` and close the files.

        The files are closed even if finalizing the output raises, so they are not leaked.
        """
        try:
            if self.vcd_writer is not None:
                self.vcd_writer.close(self.timestamp_to_vcd(timestamp))

            if self.gtkw_save is not None:
                self.gtkw_save.dumpfile(self.vcd_file.name)
                self.gtkw_save.dumpfile_size(self.vcd_file.tell())

                self.gtkw_save.treeopen("top")
                for signal in self.traces:
                    # Multi-bit, non-decoded signals are referenced with an explicit bit range.
                    if len(signal) > 1 and not signal.decoder:
                        suffix = "[{}:0]".format(len(signal) - 1)
                    else:
                        suffix = ""
                    self.gtkw_save.trace(".".join(self.gtkw_names[signal]) + suffix)
        finally:
            if self.vcd_file is not None:
                self.vcd_file.close()
            if self.gtkw_file is not None:
                self.gtkw_file.close()
162
163
164 class _SignalState:
165 __slots__ = ("signal", "curr", "next", "waiters", "pending")
166
167 def __init__(self, signal, pending):
168 self.signal = signal
169 self.pending = pending
170 self.waiters = dict()
171 self.curr = self.next = signal.reset
172
173 def set(self, value):
174 if self.next == value:
175 return
176 self.next = value
177 self.pending.add(self)
178
179 def commit(self):
180 if self.curr == self.next:
181 return False
182 self.curr = self.next
183
184 awoken_any = False
185 for process, trigger in self.waiters.items():
186 if trigger is None or trigger == self.curr:
187 process.runnable = awoken_any = True
188 return awoken_any
189
190
191 class _SimulatorState:
192 def __init__(self):
193 self.timeline = Timeline()
194 self.signals = SignalDict()
195 self.slots = []
196 self.pending = set()
197
198 def reset(self):
199 self.timeline.reset()
200 for signal, index in self.signals.items():
201 self.slots[index].curr = self.slots[index].next = signal.reset
202 self.pending.clear()
203
204 def get_signal(self, signal):
205 try:
206 return self.signals[signal]
207 except KeyError:
208 index = len(self.slots)
209 self.slots.append(_SignalState(signal, self.pending))
210 self.signals[signal] = index
211 return index
212
213 def add_trigger(self, process, signal, *, trigger=None):
214 index = self.get_signal(signal)
215 assert (process not in self.slots[index].waiters or
216 self.slots[index].waiters[process] == trigger)
217 self.slots[index].waiters[process] = trigger
218
219 def remove_trigger(self, process, signal):
220 index = self.get_signal(signal)
221 assert process in self.slots[index].waiters
222 del self.slots[index].waiters[process]
223
224 def commit(self):
225 converged = True
226 for signal_state in self.pending:
227 if signal_state.commit():
228 converged = False
229 self.pending.clear()
230 return converged
231
232
class Simulator:
    """Simulate a design together with user-supplied processes.

    ``fragment`` is converted with ``Fragment.get(fragment, platform=None).prepare()`` and
    compiled into processes. These share one simulator state with any processes added through
    :meth:`add_process`, :meth:`add_sync_process` and :meth:`add_clock`.
    """

    def __init__(self, fragment):
        self._state = _SimulatorState()
        self._fragment = Fragment.get(fragment, platform=None).prepare()
        self._processes = _FragmentCompiler(self._state)(self._fragment)
        # Clock domains that already have a clock process (see add_clock).
        self._clocked = set()
        # Writers that receive every pending signal change (see write_vcd).
        self._waveform_writers = []

    def _check_process(self, process):
        # Processes are passed as generator (or coroutine) functions, not as already
        # created generator objects.
        if not (inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process)):
            raise TypeError("Cannot add a process {!r} because it is not a generator function"
                            .format(process))
        return process

    def _add_coroutine_process(self, process, *, default_cmd):
        self._processes.add(PyCoroProcess(self._state, self._fragment.domains, process,
                                          default_cmd=default_cmd))

    def add_process(self, process):
        """Add a testbench process, started only after the initial combinational settle.

        Raises ``TypeError`` if ``process`` is not a generator (or coroutine) function.
        """
        process = self._check_process(process)
        def wrapper():
            # Only start a bench process after comb settling, so that the reset values are correct.
            yield Settle()
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=None)

    def add_sync_process(self, process, *, domain="sync"):
        """Add a process synchronized to ``domain``.

        The process starts after the first ``Tick(domain)`` and is registered with
        ``default_cmd=Tick(domain)``. Raises ``TypeError`` if ``process`` is not a generator
        (or coroutine) function.
        """
        process = self._check_process(process)
        def wrapper():
            # Only start a sync process after the first clock edge (or reset edge, if the domain
            # uses an asynchronous reset). This matches the behavior of synchronous FFs.
            yield Tick(domain)
            yield from process()
        self._add_coroutine_process(wrapper, default_cmd=Tick(domain))

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a clock process.

        Adds a process that drives the clock signal of ``domain`` at a 50% duty cycle.

        Arguments
        ---------
        period : float
            Clock period. The process will toggle the ``domain`` clock signal every ``period / 2``
            seconds.
        phase : None or float
            Clock phase. The process will wait ``phase`` seconds before the first clock transition.
            If not specified, defaults to ``period / 2``.
        domain : str or ClockDomain
            Driven clock domain. If specified as a string, the domain with that name is looked up
            in the root fragment of the simulation.
        if_exists : bool
            If ``False`` (the default), raise an error if the driven domain is specified as
            a string and the root fragment does not have such a domain. If ``True``, do nothing
            in this case.
        """
        if isinstance(domain, ClockDomain):
            pass
        elif domain in self._fragment.domains:
            domain = self._fragment.domains[domain]
        elif if_exists:
            return
        else:
            raise ValueError("Domain {!r} is not present in simulation"
                             .format(domain))
        if domain in self._clocked:
            raise ValueError("Domain {!r} already has a clock driving it"
                             .format(domain.name))

        half_period = period / 2
        if phase is None:
            # By default, delay the first edge by half period. This causes any synchronous activity
            # to happen at a non-zero time, distinguishing it from the reset values in the waveform
            # viewer.
            phase = half_period
        def clk_process():
            yield Passive()
            yield Delay(phase)
            # Behave correctly if the process is added after the clock signal is manipulated, or if
            # its reset state is high.
            initial = (yield domain.clk)
            steps = (
                domain.clk.eq(~initial),
                Delay(half_period),
                domain.clk.eq(initial),
                Delay(half_period),
            )
            while True:
                yield from iter(steps)
        self._add_coroutine_process(clk_process, default_cmd=None)
        self._clocked.add(domain)

    def reset(self):
        """Reset the simulation.

        Assign the reset value to every signal in the simulation, and restart every user process.
        """
        self._state.reset()
        for process in self._processes:
            process.reset()

    def _real_step(self):
        """Step the simulation.

        Run every process and commit changes until a fixed point is reached. If there is
        an unstable combinatorial loop, this function will never return.
        """
        # Performs the two phases of a delta cycle in a loop:
        converged = False
        while not converged:
            # 1. eval: run and suspend every non-waiting process once, queueing signal changes
            for process in self._processes:
                if process.runnable:
                    process.runnable = False
                    process.run()

            # Waveform writers see the queued values before they are committed.
            for waveform_writer in self._waveform_writers:
                for signal_state in self._state.pending:
                    waveform_writer.update(self._state.timeline.now,
                        signal_state.signal, signal_state.next)

            # 2. commit: apply every queued signal change, waking up any waiting processes
            converged = self._state.commit()

    # TODO(nmigen-0.4): replace with _real_step
    @deprecated("instead of `sim.step()`, use `sim.advance()`")
    def step(self):
        return self.advance()

    def advance(self):
        """Advance the simulation.

        Run every process and commit changes until a fixed point is reached, then advance time
        to the closest deadline (if any). If there is an unstable combinatorial loop,
        this function will never return.

        Returns ``True`` if there are any active processes, ``False`` otherwise.
        """
        self._real_step()
        self._state.timeline.advance()
        return any(not process.passive for process in self._processes)

    def run(self):
        """Run the simulation while any processes are active.

        Processes added with :meth:`add_process` and :meth:`add_sync_process` are initially active,
        and may change their status using the ``yield Passive()`` and ``yield Active()`` commands.
        Processes compiled from HDL and added with :meth:`add_clock` are always passive.
        """
        while self.advance():
            pass

    def run_until(self, deadline, *, run_passive=False):
        """Run the simulation until it advances to ``deadline``.

        If ``run_passive`` is ``False``, the simulation also stops when there are no active
        processes, similar to :meth:`run`. Otherwise, the simulation will stop only after it
        advances to or past ``deadline``.

        If the simulation stops advancing, this function will never return.
        """
        assert self._state.timeline.now <= deadline
        while (self.advance() or run_passive) and self._state.timeline.now < deadline:
            pass

    @contextmanager
    def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
        """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file.

        This method returns a context manager. It can be used as: ::

            sim = Simulator(frag)
            sim.add_clock(1e-6)
            with sim.write_vcd("dump.vcd", "dump.gtkw"):
                sim.run_until(1e-3)

        The waveform files are finalized and closed when the ``with`` block exits, including
        when it exits with an exception.

        Arguments
        ---------
        vcd_file : str or file-like object
            Verilog Value Change Dump file or filename.
        gtkw_file : str or file-like object
            GTKWave save file or filename.
        traces : iterable of Signal
            Signals to display traces for.

        Raises ``ValueError`` if simulation time has already advanced.
        """
        if self._state.timeline.now != 0.0:
            raise ValueError("Cannot start writing waveforms after advancing simulation time")
        waveform_writer = _VCDWaveformWriter(self._fragment,
            vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
        self._waveform_writers.append(waveform_writer)
        try:
            yield
        finally:
            # Unregister first so the writer never receives further updates, then finalize it;
            # doing this in ``finally`` keeps the files from leaking if the body raises.
            self._waveform_writers.remove(waveform_writer)
            waveform_writer.close(self._state.timeline.now)