back.pysim: more general clean-up.
[nmigen.git] / nmigen / back / pysim.py
1 import math
2 from vcd import VCDWriter
3 from vcd.gtkw import GTKWSave
4
5 from ..tools import flatten
6 from ..fhdl.ast import *
7 from ..fhdl.xfrm import ValueTransformer, StatementTransformer
8
9
# Public API of this module.
__all__ = [
    "Simulator",
    "Delay",
    "Tick",
    "Passive",
    "DeadlineError",
]
11
12
class DeadlineError(Exception):
    """Raised by ``Simulator.step`` when delta cycles run past a waiting process's deadline.

    Reaching this error virtually always indicates a combinatorial loop in the design.
    """
15
16
17 class _State:
18 __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
19
20 def __init__(self):
21 self.curr = ValueDict()
22 self.next = ValueDict()
23 self.curr_dirty = ValueSet()
24 self.next_dirty = ValueSet()
25
26 def get(self, signal):
27 return self.curr[signal]
28
29 def set_curr(self, signal, value):
30 assert isinstance(value, int)
31 if self.curr[signal] != value:
32 self.curr_dirty.add(signal)
33 self.curr[signal] = value
34
35 def set_next(self, signal, value):
36 assert isinstance(value, int)
37 if self.next[signal] != value:
38 self.next_dirty.add(signal)
39 self.next[signal] = value
40
41 def commit(self, signal):
42 old_value = self.curr[signal]
43 if self.curr[signal] != self.next[signal]:
44 self.next_dirty.remove(signal)
45 self.curr_dirty.add(signal)
46 self.curr[signal] = self.next[signal]
47 new_value = self.curr[signal]
48 return old_value, new_value
49
50 def iter_dirty(self):
51 dirty, self.dirty = self.dirty, ValueSet()
52 for signal in dirty:
53 yield signal, self.curr[signal], self.next[signal]
54
55
# Module-level alias for ``Const.normalize(value, shape)``, used to fit computed integers
# into a value's shape (presumably truncating/sign-wrapping to its width and signedness —
# see ``Const.normalize`` for the exact semantics).
normalize = Const.normalize
57
58
class _RHSValueCompiler(ValueTransformer):
    """Compile right-hand-side IR values into closures ``f(state) -> int``.

    Every signal read by a compiled expression is added to the ``sensitivity`` set supplied
    by the caller. Most results are passed through :func:`normalize` with the shape of the
    value being compiled.
    """

    def __init__(self, sensitivity):
        self.sensitivity = sensitivity

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        # Reading a signal makes the compiled code depend on it.
        self.sensitivity.add(value)
        return lambda state: state.get(value)

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.op == "-":
                return lambda state: normalize(-arg(state), shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.op == "+":
                return lambda state: normalize(lhs(state) + rhs(state), shape)
            if value.op == "-":
                return lambda state: normalize(lhs(state) - rhs(state), shape)
            if value.op == "&":
                return lambda state: normalize(lhs(state) & rhs(state), shape)
            if value.op == "|":
                return lambda state: normalize(lhs(state) | rhs(state), shape)
            if value.op == "^":
                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
            if value.op == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                # Mux: operands are (select, value-if-true, value-if-false).
                sel, val1, val0 = map(self, value.operands)
                return lambda state: val1(state) if sel(state) else val0(state)
        raise NotImplementedError("Operator '{}' not implemented".format(value.op))

    def on_Slice(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        shape = value.shape()
        # (bit offset, width mask, compiled operand), first operand in the least significant bits.
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)

        def eval_cat(state):
            result = 0
            for offset, mask, opnd in parts:
                # Masking strips the infinite sign extension of negative (signed) operands.
                result |= (opnd(state) & mask) << offset
            return normalize(result, shape)
        return eval_cat

    def on_Repl(self, value):
        shape = value.shape()
        width = len(value.value)
        mask = (1 << width) - 1
        count = value.count
        opnd = self(value.value)

        def eval_repl(state):
            # Mask the operand to its own width. A negative (signed) operand is an infinitely
            # sign-extended Python int, and OR-ing it in unmasked would set every higher bit.
            # That would clobber the copies already shifted up. The operand only depends on
            # `state`, so it is evaluated once.
            part = opnd(state) & mask
            result = 0
            for _ in range(count):
                result = (result << width) | part
            return normalize(result, shape)
        return eval_repl
141
142
class _StatementCompiler(StatementTransformer):
    """Compile IR statements into closures ``run(state)`` that write next values into ``state``.

    Every signal read while compiling is collected in ``self.sensitivity``. The caller can
    then re-run the compiled code whenever one of those signals changes.
    """

    def __init__(self):
        self.sensitivity = ValueSet()
        # The RHS compiler records each signal it reads into the shared sensitivity set.
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)

    def lhs_compiler(self, value):
        # TODO: only plain signals are supported as assignment targets so far.
        def write(state, arg):
            return state.set_next(value, arg)
        return write

    def on_Assign(self, stmt):
        assert isinstance(stmt.lhs, Signal)
        target_shape = stmt.lhs.shape()
        store = self.lhs_compiler(stmt.lhs)
        load = self.rhs_compiler(stmt.rhs)

        def run(state):
            store(state, normalize(load(state), target_shape))
        return run

    def on_Switch(self, stmt):
        test = self.rhs_compiler(stmt.test)

        def make_test(mask, value):
            return lambda test_value: test_value & mask == value

        # Each case key is a binary pattern in which "-" marks a don't-care bit. Such bits are
        # cleared in the mask and matched as 0.
        cases = []
        for pattern, stmts in stmt.cases.items():
            mask_bits = "".join("0" if bit == "-" else "1" for bit in pattern)
            value_bits = "".join("0" if bit == "-" else bit for bit in pattern)
            cases.append((make_test(int(mask_bits, 2), int(value_bits, 2)),
                          self.on_statements(stmts)))

        def run(state):
            test_value = test(state)
            # Only the first matching case is executed.
            for matches, body in cases:
                if matches(test_value):
                    body(state)
                    break
        return run

    def on_statements(self, stmts):
        compiled = list(map(self.on_statement, stmts))

        def run(state):
            for compiled_stmt in compiled:
                compiled_stmt(state)
        return run
189
190
class Simulator:
    """Event-driven simulator for a fragment.

    Use it as a context manager. ``__enter__`` prepares the fragment, initializes signal
    state and compiles the statements. ``__exit__`` closes the optional VCD dump and writes
    the optional GTKWave save file. Testbench processes are generators registered with
    :meth:`add_process` or :meth:`add_sync_process`. They may yield ``Delay``, ``Tick``,
    ``Passive`` or assignment statements (see :meth:`_run_process`).
    """

    def __init__(self, fragment, vcd_file=None, gtkw_file=None, gtkw_signals=()):
        self._fragment = fragment

        self._domains = {}                  # str/domain -> ClockDomain
        self._domain_triggers = ValueDict() # Signal -> str/domain
        self._domain_signals = {}           # str/domain -> {Signal}

        self._signals = ValueSet()          # {Signal}
        self._comb_signals = ValueSet()     # {Signal}
        self._sync_signals = ValueSet()     # {Signal}
        self._user_signals = ValueSet()     # {Signal}

        self._started = False
        self._timestamp = 0.
        # One VCD time unit. VCD times are written as `timestamp / epsilon` with a
        # "100 ps" timescale.
        self._epsilon = 1e-10
        self._fastest_clock = self._epsilon
        self._state = _State()

        self._processes = set()             # {process}
        self._passive = set()               # {process}
        self._suspended = set()             # {process}
        self._wait_deadline = {}            # process -> float/timestamp
        self._wait_tick = {}                # process -> str/domain

        self._funclets = ValueDict()        # Signal -> set(lambda)

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = ValueDict()     # signal -> set(vcd_signal)
        self._vcd_names = ValueDict()       # signal -> str/name

        self._gtkw_file = gtkw_file
        self._gtkw_signals = gtkw_signals

    def add_process(self, process):
        """Register a generator ``process`` to be driven by the simulator."""
        self._processes.add(process)

    def add_clock(self, period, domain="sync"):
        """Add a passive process toggling the clock of ``domain`` with the given ``period``.

        The domain is looked up in ``self._domains``, which is filled by ``__enter__``. This
        must therefore be called inside the ``with`` block.
        """
        if self._fastest_clock == self._epsilon or period < self._fastest_clock:
            self._fastest_clock = period

        half_period = period / 2
        clk = self._domains[domain].clk
        def clk_process():
            yield Passive()
            yield Delay(half_period)
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process())

    def add_sync_process(self, process, domain="sync"):
        """Register ``process`` so that every falsy value it yields means ``Tick(domain)``."""
        def sync_process():
            try:
                result = process.send(None)
                while True:
                    result = process.send((yield (result or Tick(domain))))
            except StopIteration:
                pass
        self.add_process(sync_process())

    def __enter__(self):
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        root_fragment = self._fragment.prepare()

        # Both the clock and (if present) the reset of a domain trigger its sync logic.
        self._domains = root_fragment.domains
        for domain, cd in self._domains.items():
            self._domain_triggers[cd.clk] = domain
            if cd.rst is not None:
                self._domain_triggers[cd.rst] = domain
            self._domain_signals[domain] = ValueSet()

        fragment_names = {}
        def add_fragment(fragment, hierarchy=("top",)):
            fragment_names[fragment] = hierarchy
            for subfragment, name in fragment.subfragments:
                add_fragment(subfragment, (*hierarchy, name))
        add_fragment(root_fragment)

        def add_funclet(signal, funclet):
            if signal not in self._funclets:
                self._funclets[signal] = set()
            self._funclets[signal].add(funclet)

        for fragment, fragment_name in fragment_names.items():
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    normalize(signal.reset, signal.shape())
                self._state.curr_dirty.add(signal)

                if signal not in self._vcd_signals:
                    self._vcd_signals[signal] = set()

                if not self._vcd_writer:
                    # No VCD file was requested, so there is no writer to register with.
                    continue

                for subfragment, name in fragment.subfragments:
                    if signal in subfragment.ports:
                        var_name = "{}_{}".format(name, signal.name)
                        break
                else:
                    var_name = signal.name

                if signal.decoder:
                    var_type = "string"
                    var_size = 1
                    var_init = signal.decoder(signal.reset).replace(" ", "_")
                else:
                    var_type = "wire"
                    var_size = signal.nbits
                    var_init = signal.reset

                # The writer raises KeyError on a name clash. Retry with "$1", "$2", ... suffixes.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        self._vcd_signals[signal].add(self._vcd_writer.register_var(
                            scope=".".join(fragment_name), name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init))
                        if signal not in self._vcd_names:
                            self._vcd_names[signal] = \
                                ".".join(fragment_name + (var_name_suffix,))
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            # Comb signals are first reset, so that any path not assigning them yields the reset value.
            statements = []
            for signal in fragment.iter_comb():
                statements.append(signal.eq(signal.reset))
            statements += fragment.statements

            compiler = _StatementCompiler()
            funclet = compiler(statements)
            for signal in compiler.sensitivity:
                add_funclet(signal, funclet)
            for domain, cd in fragment.domains.items():
                add_funclet(cd.clk, funclet)
                if cd.rst is not None:
                    add_funclet(cd.rst, funclet)

        self._user_signals = self._signals - self._comb_signals - self._sync_signals

        return self

    def _update_dirty_signals(self):
        """Perform the statement part of IR processes (aka RTLIL case)."""
        # First, for all dirty signals, use sensitivity lists to determine the set of fragments
        # that need their statements to be reevaluated because the signals changed at the previous
        # delta cycle.
        funclets = set()
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._funclets:
                funclets.update(self._funclets[signal])

        # Second, compute the values of all signals at the start of the next delta cycle, by
        # running precompiled statements.
        for funclet in funclets:
            funclet(self._state)

    def _commit_signal(self, signal, domains):
        """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
        # Take the computed value (at the start of this delta cycle) of a signal (that could have
        # come from an IR process that ran earlier, or modified by a simulator process) and update
        # the value for this delta cycle.
        old, new = self._state.commit(signal)

        # If the signal is a clock (or reset) that triggers synchronous logic, record that fact.
        if (old, new) == (0, 1) and signal in self._domain_triggers:
            domains.add(self._domain_triggers[signal])

        if self._vcd_writer and old != new:
            # Finally, dump the new value to the VCD file.
            for vcd_signal in self._vcd_signals[signal]:
                if signal.decoder:
                    var_value = signal.decoder(new).replace(" ", "_")
                else:
                    var_value = new
                self._vcd_writer.change(vcd_signal, self._timestamp / self._epsilon, var_value)

    def _commit_comb_signals(self, domains):
        """Perform the comb part of IR processes (aka RTLIL always)."""
        # Take the computed value (at the start of this delta cycle) of every comb signal and
        # update the value for this delta cycle. Committing removes signals from `next_dirty`,
        # so iterate over a snapshot to avoid mutating the set while iterating over it.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals or signal in self._user_signals:
                self._commit_signal(signal, domains)

    def _commit_sync_signals(self, domains):
        """Perform the sync part of IR processes (aka RTLIL posedge)."""
        # At entry, `domains` contains a list of every simultaneously triggered sync update.
        while domains:
            # Advance the timeline a bit (purely for observational purposes) and commit all of them
            # at the same timestamp.
            self._timestamp += self._epsilon
            curr_domains, domains = domains, set()

            while curr_domains:
                domain = curr_domains.pop()

                # Take the computed value (at the start of this delta cycle) of every sync signal
                # in this domain and update the value for this delta cycle. This can trigger more
                # synchronous logic, so record that. Iterate over a snapshot because committing
                # removes signals from `next_dirty`.
                for signal in list(self._state.next_dirty):
                    if signal in self._domain_signals[domain]:
                        self._commit_signal(signal, domains)

                # Wake up any simulator processes that wait for a domain tick.
                for process, wait_domain in list(self._wait_tick.items()):
                    if domain == wait_domain:
                        del self._wait_tick[process]
                        self._suspended.remove(process)

            # Unless handling synchronous logic above has triggered more synchronous logic (which
            # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
            # Otherwise, do one more round of updates.

    def _run_process(self, process):
        """Advance ``process`` by one yield and act on the command it produced.

        Finished processes are unregistered. ``Delay`` and ``Tick`` suspend the process.
        ``Passive`` marks it passive. ``Assign`` is applied and committed immediately.

        Raises:
            ValueError: an assignment targets a signal outside the simulation or a
                signal driven combinatorially.
            TypeError: the process yielded an unsupported command.
        """
        try:
            stmt = process.send(None)
        except StopIteration:
            self._processes.remove(process)
            self._passive.discard(process)
            return

        if isinstance(stmt, Delay):
            self._wait_deadline[process] = self._timestamp + stmt.interval
            self._suspended.add(process)

        elif isinstance(stmt, Tick):
            self._wait_tick[process] = stmt.domain
            self._suspended.add(process)

        elif isinstance(stmt, Passive):
            self._passive.add(process)

        elif isinstance(stmt, Assign):
            lhs_signals = stmt.lhs._lhs_signals()
            for signal in lhs_signals:
                if signal not in self._signals:
                    raise ValueError("Process {!r} sent a request to set signal '{!r}', "
                                     "which is not a part of simulation"
                                     .format(process, signal))
                if signal in self._comb_signals:
                    raise ValueError("Process {!r} sent a request to set signal '{!r}', "
                                     "which is a part of combinatorial assignment in simulation"
                                     .format(process, signal))

            funclet = _StatementCompiler()(stmt)
            funclet(self._state)

            domains = set()
            for signal in lhs_signals:
                self._commit_signal(signal, domains)
            self._commit_sync_signals(domains)

        else:
            raise TypeError("Received unsupported statement '{!r}' from process {!r}"
                            .format(stmt, process))

    def step(self, run_passive=False):
        """Settle pending delta cycles, then run at most one process step.

        Returns True if a process was run. Returns False when there is nothing left to do:
        no processes, or only passive ones (unless ``run_passive`` is set).

        Raises:
            DeadlineError: delta cycles reached the earliest ``Delay`` deadline.
        """
        deadline = None
        if self._wait_deadline:
            # We might run some delta cycles, and we have simulator processes waiting on
            # a deadline. Take care to not exceed the closest deadline.
            deadline = min(self._wait_deadline.values())

        # Are there any delta cycles we should run?
        while self._state.curr_dirty:
            self._timestamp += self._epsilon
            if deadline is not None and self._timestamp >= deadline:
                # Oops, we blew the deadline. We *could* run the processes now, but this is
                # virtually certainly a logic loop and a design bug, so bail out instead.
                raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")

            domains = set()
            self._update_dirty_signals()
            self._commit_comb_signals(domains)
            self._commit_sync_signals(domains)

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            process = (self._processes - set(self._suspended)).pop()
            self._run_process(process)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[process]
                self._suspended.remove(process)
                self._timestamp = deadline
                self._run_process(process)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run(self):
        """Step until :meth:`step` reports there is nothing left to do."""
        while self.step():
            pass

    def run_until(self, deadline, run_passive=False):
        """Step until the timestamp reaches ``deadline``.

        Returns False if the simulation ran out of work first, True otherwise.
        """
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False

        return True

    def __exit__(self, *args):
        if self._vcd_writer:
            self._vcd_writer.close(self._timestamp / self._epsilon)

        if self._vcd_file and self._gtkw_file:
            gtkw_save = GTKWSave(self._gtkw_file)
            if hasattr(self._vcd_file, "name"):
                gtkw_save.dumpfile(self._vcd_file.name)
            if hasattr(self._vcd_file, "tell"):
                gtkw_save.dumpfile_size(self._vcd_file.tell())

            gtkw_save.treeopen("top")
            gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)

            for domain, cd in self._domains.items():
                with gtkw_save.group("d.{}".format(domain)):
                    if cd.rst is not None:
                        gtkw_save.trace(self._vcd_names[cd.rst])
                    gtkw_save.trace(self._vcd_names[cd.clk])

            for signal in self._gtkw_signals:
                if signal in self._vcd_names:
                    if len(signal) > 1:
                        suffix = "[{}:0]".format(len(signal) - 1)
                    else:
                        suffix = ""
                    gtkw_save.trace(self._vcd_names[signal] + suffix)