back.pysim: Simulator({gtkw_signals→traces}=).
[nmigen.git] / nmigen / back / pysim.py
1 import math
2 import inspect
3 from vcd import VCDWriter
4 from vcd.gtkw import GTKWSave
5
6 from ..tools import flatten
7 from ..fhdl.ast import *
8 from ..fhdl.xfrm import ValueTransformer, StatementTransformer
9
10
# Public API of this module. "Delay", "Tick" and "Passive" are not defined in this file; they are
# presumably provided by the star import of ..fhdl.ast above and re-exported for simulator users.
__all__ = ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"]
12
13
class DeadlineError(Exception):
    """Raised by ``Simulator.step`` when running delta cycles would reach the deadline of a
    process waiting on a ``Delay`` — in practice, a sign of a combinatorial loop."""
16
17
18 class _State:
19 __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
20
21 def __init__(self):
22 self.curr = ValueDict()
23 self.next = ValueDict()
24 self.curr_dirty = ValueSet()
25 self.next_dirty = ValueSet()
26
27 def get(self, signal):
28 return self.curr[signal]
29
30 def set(self, signal, value):
31 assert isinstance(value, int)
32 if self.next[signal] != value:
33 self.next_dirty.add(signal)
34 self.next[signal] = value
35
36 def commit(self, signal):
37 old_value = self.curr[signal]
38 if self.curr[signal] != self.next[signal]:
39 self.next_dirty.remove(signal)
40 self.curr_dirty.add(signal)
41 self.curr[signal] = self.next[signal]
42 new_value = self.curr[signal]
43 return old_value, new_value
44
45
# Alias for Const.normalize. The compilers below call it as normalize(value, shape) to bring a
# computed Python int into the range of a value's shape (see Const.normalize for the exact
# truncation/sign handling).
normalize = Const.normalize
47
48
class _RHSValueCompiler(ValueTransformer):
    """Compile an nMigen value into a closure that evaluates it against a ``_State``.

    Every ``on_*`` method returns a callable ``f(state) -> int``. Each signal read by a compiled
    expression is added to ``self.sensitivity``, so callers know which signal changes require
    re-evaluation.
    """

    def __init__(self, sensitivity):
        # Set-like collection (a ``ValueSet`` at the call sites in this file) extended with
        # every signal the compiled expressions read.
        self.sensitivity = sensitivity

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        self.sensitivity.add(value)
        return lambda state: state.get(value)

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        """Compile a unary, binary or ternary (mux) operator; the result is normalized to the
        operator's shape (except for the mux, which selects one operand unchanged)."""
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.op == "-":
                return lambda state: normalize(-arg(state), shape)
            if value.op == "b":
                return lambda state: normalize(bool(arg(state)), shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.op == "+":
                return lambda state: normalize(lhs(state) + rhs(state), shape)
            if value.op == "-":
                return lambda state: normalize(lhs(state) - rhs(state), shape)
            if value.op == "&":
                return lambda state: normalize(lhs(state) & rhs(state), shape)
            if value.op == "|":
                return lambda state: normalize(lhs(state) | rhs(state), shape)
            if value.op == "^":
                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
            if value.op == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
            if value.op == "!=":
                return lambda state: normalize(lhs(state) != rhs(state), shape)
            if value.op == "<":
                return lambda state: normalize(lhs(state) < rhs(state), shape)
            if value.op == "<=":
                return lambda state: normalize(lhs(state) <= rhs(state), shape)
            if value.op == ">":
                return lambda state: normalize(lhs(state) > rhs(state), shape)
            if value.op == ">=":
                return lambda state: normalize(lhs(state) >= rhs(state), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                sel, val1, val0 = map(self, value.operands)
                return lambda state: val1(state) if sel(state) else val0(state)
        raise NotImplementedError("Operator '{!r}' not implemented".format(value.op)) # :nocov:

    def on_Slice(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        """Concatenate operands, the first operand occupying the least significant bits."""
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)
        def eval(state):
            result = 0
            for offset, mask, opnd in parts:
                result |= (opnd(state) & mask) << offset
            return normalize(result, shape)
        return eval

    def on_Repl(self, value):
        """Replicate the operand ``value.count`` times side by side."""
        shape = value.shape()
        width = len(value.value)
        mask = (1 << width) - 1
        count = value.count
        opnd = self(value.value)
        def eval(state):
            # Mask the operand to its own width (as ``on_Cat`` does): its value can be a
            # negative Python int (e.g. for a signed shape), whose infinite sign bits would
            # otherwise overwrite the copies already placed in the higher bits.
            part = opnd(state) & mask
            result = 0
            for _ in range(count):
                result = (result << width) | part
            return normalize(result, shape)
        return eval
143
144
class _StatementCompiler(StatementTransformer):
    """Compile nMigen statements into closures that act on a ``_State``.

    Each ``on_*`` method returns a callable ``run(state)``. Every signal read by any compiled
    right-hand side is collected in ``self.sensitivity``.
    """

    def __init__(self):
        self.sensitivity = ValueSet()
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)

    def lhs_compiler(self, value):
        # TODO
        def store(state, arg):
            return state.set(value, arg)
        return store

    def on_Assign(self, stmt):
        # Only whole-signal targets are supported so far.
        assert isinstance(stmt.lhs, Signal)
        lhs_shape = stmt.lhs.shape()
        store = self.lhs_compiler(stmt.lhs)
        load = self.rhs_compiler(stmt.rhs)
        def run(state):
            store(state, normalize(load(state), lhs_shape))
        return run

    def on_Switch(self, stmt):
        test = self.rhs_compiler(stmt.test)
        cases = []
        for pattern, stmts in stmt.cases.items():
            # Case patterns are bit strings in which "-" is a don't-care bit; it is cleared
            # in both the comparison mask and the expected value.
            mask = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
            expected = int(pattern.replace("-", "0"), 2)
            cases.append((mask, expected, self.on_statements(stmts)))
        def run(state):
            test_value = test(state)
            # The first matching case wins; if none matches, nothing runs.
            for mask, expected, body in cases:
                if (test_value & mask) == expected:
                    body(state)
                    return
        return run

    def on_statements(self, stmts):
        compiled = [self.on_statement(stmt) for stmt in stmts]
        def run(state):
            for step in compiled:
                step(state)
        return run
191
192
class Simulator:
    """Event-driven simulator for an nMigen fragment.

    Use it as a context manager. ``__enter__`` prepares the fragment, compiles its statements
    and opens the VCD writer if a ``vcd_file`` was given. ``__exit__`` finalizes the VCD and
    GTKWave outputs and closes the files that were passed in.

    :param fragment: fragment to simulate; ``fragment.prepare()`` is called on entry.
    :param vcd_file: optional writable file receiving a VCD waveform dump.
    :param gtkw_file: optional writable file receiving a GTKWave save file (written only when
        ``vcd_file`` is also given).
    :param traces: signals added to the GTKWave save file after the clock domain signals.
    """

    def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()):
        self._fragment = fragment

        self._domains = {}                  # str/domain -> ClockDomain
        self._domain_triggers = ValueDict() # Signal -> str/domain
        self._domain_signals = {}           # str/domain -> {Signal}

        self._signals = ValueSet()          # {Signal}
        self._comb_signals = ValueSet()     # {Signal}
        self._sync_signals = ValueSet()     # {Signal}
        self._user_signals = ValueSet()     # {Signal}

        self._started = False
        self._timestamp = 0.
        # Length of one delta cycle. Timestamps are divided by it when written to the VCD
        # file, whose timescale is "100 ps" (= 1e-10).
        self._epsilon = 1e-10
        self._fastest_clock = self._epsilon
        self._state = _State()

        self._processes = set()             # {process}
        self._passive = set()               # {process}
        self._suspended = set()             # {process}
        self._wait_deadline = {}            # process -> float/timestamp
        self._wait_tick = {}                # process -> str/domain

        self._funclets = ValueDict()        # Signal -> set(lambda)

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = ValueDict()     # signal -> set(vcd_signal)
        self._vcd_names = ValueDict()       # signal -> str/name
        self._gtkw_file = gtkw_file
        self._traces = traces

    def _check_process(self, process):
        """Return a generator for ``process``, calling it first if it is a generator function.

        Raises ``TypeError`` for anything else.
        """
        if inspect.isgeneratorfunction(process):
            process = process()
        if not inspect.isgenerator(process):
            raise TypeError("Cannot add a process '{!r}' because it is not a generator or "
                            "a generator function"
                            .format(process))
        return process

    def add_process(self, process):
        process = self._check_process(process)
        self._processes.add(process)

    def add_sync_process(self, process, domain="sync"):
        """Add a process in which a bare ``yield`` waits for the next tick of ``domain``."""
        process = self._check_process(process)
        def sync_process():
            try:
                cmd = process.send(None)
                while True:
                    # Only ``None`` (a bare ``yield``) means "wait for a tick". Testing by
                    # truthiness would coerce commands to bool, and e.g. a zero-length Value
                    # (Values support len()) would be silently replaced by a Tick.
                    if cmd is None:
                        cmd = Tick(domain)
                    response = yield cmd
                    cmd = process.send(response)
            except StopIteration:
                pass
        self.add_process(sync_process())

    def add_clock(self, period, domain="sync"):
        """Add a passive process toggling the clock of ``domain`` every ``period / 2``.

        Clock domains are only known after ``__enter__``, so call this inside the ``with``
        block (``self._domains`` is empty before that).
        """
        if self._fastest_clock == self._epsilon or period < self._fastest_clock:
            self._fastest_clock = period

        half_period = period / 2
        clk = self._domains[domain].clk
        def clk_process():
            yield Passive()
            yield Delay(half_period)
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process())

    def __enter__(self):
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        root_fragment = self._fragment.prepare()

        self._domains = root_fragment.domains
        for domain, cd in self._domains.items():
            self._domain_triggers[cd.clk] = domain
            if cd.rst is not None:
                self._domain_triggers[cd.rst] = domain
            self._domain_signals[domain] = ValueSet()

        # Map every fragment in the hierarchy to its scope path: ("top", <subfragment name>, ...).
        hierarchy = {}
        def add_fragment(fragment, scope=("top",)):
            hierarchy[fragment] = scope
            for subfragment, name in fragment.subfragments:
                add_fragment(subfragment, (*scope, name))
        add_fragment(root_fragment)

        def add_funclet(signal, funclet):
            if signal not in self._funclets:
                self._funclets[signal] = set()
            self._funclets[signal].add(funclet)

        for fragment, fragment_name in hierarchy.items():
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    normalize(signal.reset, signal.shape())
                self._state.curr_dirty.add(signal)

                if not self._vcd_writer:
                    continue

                if signal not in self._vcd_signals:
                    self._vcd_signals[signal] = set()

                # A signal that is a port of a subfragment gets that subfragment's name as a
                # prefix in this scope.
                for subfragment, name in fragment.subfragments:
                    if signal in subfragment.ports:
                        var_name = "{}_{}".format(name, signal.name)
                        break
                else:
                    var_name = signal.name

                if signal.decoder:
                    var_type = "string"
                    var_size = 1
                    var_init = signal.decoder(signal.reset).replace(" ", "_")
                else:
                    var_type = "wire"
                    var_size = signal.nbits
                    var_init = signal.reset

                # Retry with "$1", "$2", ... suffixes while registration fails with KeyError
                # (assumed to mean the name is already taken in this scope).
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        self._vcd_signals[signal].add(self._vcd_writer.register_var(
                            scope=".".join(fragment_name), name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init))
                        if signal not in self._vcd_names:
                            self._vcd_names[signal] = \
                                ".".join(fragment_name + (var_name_suffix,))
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            # Comb signals are reset to their reset value before the fragment's statements run.
            statements = []
            for signal in fragment.iter_comb():
                statements.append(signal.eq(signal.reset))
            statements += fragment.statements

            compiler = _StatementCompiler()
            funclet = compiler(statements)
            for signal in compiler.sensitivity:
                add_funclet(signal, funclet)
            for domain, cd in fragment.domains.items():
                add_funclet(cd.clk, funclet)
                if cd.rst is not None:
                    add_funclet(cd.rst, funclet)

        # Signals driven neither combinatorially nor synchronously are only set by processes.
        self._user_signals = self._signals - self._comb_signals - self._sync_signals

        return self

    def _update_dirty_signals(self):
        """Perform the statement part of IR processes (aka RTLIL case)."""
        # First, for all dirty signals, use sensitivity lists to determine the set of fragments
        # that need their statements to be reevaluated because the signals changed at the previous
        # delta cycle.
        funclets = set()
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._funclets:
                funclets.update(self._funclets[signal])

        # Second, compute the values of all signals at the start of the next delta cycle, by
        # running precompiled statements.
        for funclet in funclets:
            funclet(self._state)

    def _commit_signal(self, signal, domains):
        """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
        # Take the computed value (at the start of this delta cycle) of a signal (that could have
        # come from an IR process that ran earlier, or modified by a simulator process) and update
        # the value for this delta cycle.
        old, new = self._state.commit(signal)

        # If the signal is a clock that triggers synchronous logic (a 0 -> 1 edge), record that.
        if (old, new) == (0, 1) and signal in self._domain_triggers:
            domains.add(self._domain_triggers[signal])

        if self._vcd_writer and old != new:
            # Finally, dump the new value to the VCD file.
            for vcd_signal in self._vcd_signals[signal]:
                if signal.decoder:
                    var_value = signal.decoder(new).replace(" ", "_")
                else:
                    var_value = new
                self._vcd_writer.change(vcd_signal, self._timestamp / self._epsilon, var_value)

    def _commit_comb_signals(self, domains):
        """Perform the comb part of IR processes (aka RTLIL always)."""
        # Take the computed value (at the start of this delta cycle) of every comb signal and
        # update the value for this delta cycle. Committing removes signals from `next_dirty`,
        # so iterate over a snapshot of it.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals or signal in self._user_signals:
                self._commit_signal(signal, domains)

    def _commit_sync_signals(self, domains):
        """Perform the sync part of IR processes (aka RTLIL posedge)."""
        # At entry, `domains` contains a list of every simultaneously triggered sync update.
        while domains:
            # Advance the timeline a bit (purely for observational purposes) and commit all of them
            # at the same timestamp.
            self._timestamp += self._epsilon
            curr_domains, domains = domains, set()

            while curr_domains:
                domain = curr_domains.pop()

                # Take the computed value (at the start of this delta cycle) of every sync signal
                # in this domain and update the value for this delta cycle. This can trigger more
                # synchronous logic, so record that. Committing removes signals from
                # `next_dirty`, so iterate over a snapshot of it.
                for signal in list(self._state.next_dirty):
                    if signal in self._domain_signals[domain]:
                        self._commit_signal(signal, domains)

                # Wake up any simulator processes that wait for a domain tick.
                for process, wait_domain in list(self._wait_tick.items()):
                    if domain == wait_domain:
                        del self._wait_tick[process]
                        self._suspended.remove(process)

            # Unless handling synchronous logic above has triggered more synchronous logic (which
            # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
            # Otherwise, do one more round of updates.

    def _run_process(self, process):
        """Resume ``process`` and execute its commands until it suspends or yields a command
        that ends this step (anything other than reading a Value)."""
        def format_process(process):
            frame = process.gi_frame
            return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

        try:
            cmd = process.send(None)
            while True:
                if isinstance(cmd, Delay):
                    if cmd.interval is None:
                        interval = self._epsilon
                    else:
                        interval = cmd.interval
                    self._wait_deadline[process] = self._timestamp + interval
                    self._suspended.add(process)

                elif isinstance(cmd, Tick):
                    self._wait_tick[process] = cmd.domain
                    self._suspended.add(process)

                elif isinstance(cmd, Passive):
                    self._passive.add(process)

                elif isinstance(cmd, Value):
                    # Reading a value: evaluate it now and hand the result straight back.
                    funclet = _RHSValueCompiler(sensitivity=ValueSet())(cmd)
                    cmd = process.send(funclet(self._state))
                    continue

                elif isinstance(cmd, Assign):
                    lhs_signals = cmd.lhs._lhs_signals()
                    for signal in lhs_signals:
                        if signal not in self._signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is not a part of simulation"
                                             .format(format_process(process), signal))
                        if signal in self._comb_signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is a part of combinatorial assignment in "
                                             "simulation"
                                             .format(format_process(process), signal))

                    funclet = _StatementCompiler()(cmd)
                    funclet(self._state)

                    domains = set()
                    for signal in lhs_signals:
                        self._commit_signal(signal, domains)
                    self._commit_sync_signals(domains)

                else:
                    raise TypeError("Received unsupported command '{!r}' from process '{}'"
                                    .format(cmd, format_process(process)))

                break

        except StopIteration:
            self._processes.remove(process)
            self._passive.discard(process)

        except Exception as e:
            # Re-raise inside the process so the traceback points at the offending yield.
            process.throw(e)

    def step(self, run_passive=False):
        """Run pending delta cycles, then resume one process.

        Returns ``True`` if a process was resumed, ``False`` if there was nothing to do (no
        runnable process, or only passive ones and ``run_passive`` is false).
        Raises ``DeadlineError`` if delta cycles reach the earliest process deadline.
        """
        deadline = None
        if self._wait_deadline:
            # We might run some delta cycles, and we have simulator processes waiting on
            # a deadline. Take care to not exceed the closest deadline.
            deadline = min(self._wait_deadline.values())

        # Are there any delta cycles we should run?
        while self._state.curr_dirty:
            self._timestamp += self._epsilon
            if deadline is not None and self._timestamp >= deadline:
                # Oops, we blew the deadline. We *could* run the processes now, but this is
                # virtually certainly a logic loop and a design bug, so bail out instead.
                raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")

            domains = set()
            self._update_dirty_signals()
            self._commit_comb_signals(domains)
            self._commit_sync_signals(domains)

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            process = (self._processes - self._suspended).pop()
            self._run_process(process)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[process]
                self._suspended.remove(process)
                self._timestamp = deadline
                self._run_process(process)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run(self):
        """Step until there is nothing left to do."""
        while self.step():
            pass

    def run_until(self, deadline, run_passive=False):
        """Step until the simulation time reaches ``deadline``.

        Returns ``False`` if the simulation ran out of work first, ``True`` otherwise.
        """
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False

        return True

    def __exit__(self, *args):
        try:
            if self._vcd_writer:
                self._vcd_writer.close(self._timestamp / self._epsilon)

            if self._vcd_file and self._gtkw_file:
                gtkw_save = GTKWSave(self._gtkw_file)
                if hasattr(self._vcd_file, "name"):
                    gtkw_save.dumpfile(self._vcd_file.name)
                if hasattr(self._vcd_file, "tell"):
                    gtkw_save.dumpfile_size(self._vcd_file.tell())

                gtkw_save.treeopen("top")
                gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)

                def add_trace(signal, **kwargs):
                    # Signals that were never registered in the VCD file are skipped.
                    if signal in self._vcd_names:
                        if len(signal) > 1:
                            suffix = "[{}:0]".format(len(signal) - 1)
                        else:
                            suffix = ""
                        gtkw_save.trace(self._vcd_names[signal] + suffix, **kwargs)

                for domain, cd in self._domains.items():
                    with gtkw_save.group("d.{}".format(domain)):
                        if cd.rst is not None:
                            add_trace(cd.rst)
                        add_trace(cd.clk)

                for signal in self._traces:
                    add_trace(signal)
        finally:
            # Close the caller's files even if finalizing the waveform outputs failed.
            if self._vcd_file:
                self._vcd_file.close()
            if self._gtkw_file:
                self._gtkw_file.close()