back.pysim: revert 70ebc6f2.
[nmigen.git] / nmigen / back / pysim.py
1 import math
2 import inspect
3 from contextlib import contextmanager
4 from vcd import VCDWriter
5 from vcd.gtkw import GTKWSave
6
7 from ..tools import flatten
8 from ..fhdl.ast import *
9 from ..fhdl.xfrm import ValueTransformer, StatementTransformer
10
11
# Public names exported by `from ... import *`. Delay, Tick and Passive are not defined in
# this module; presumably they arrive via the `..fhdl.ast` star import and are re-exported.
__all__ = ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"]
13
14
class DeadlineError(Exception):
    """Raised when settling delta cycles would run past a waiting process's deadline."""
17
18
19 class _State:
20 __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
21
22 def __init__(self):
23 self.curr = ValueDict()
24 self.next = ValueDict()
25 self.curr_dirty = ValueSet()
26 self.next_dirty = ValueSet()
27
28 def set(self, signal, value):
29 assert isinstance(value, int)
30 if self.next[signal] != value:
31 self.next_dirty.add(signal)
32 self.next[signal] = value
33
34 def commit(self, signal):
35 old_value = self.curr[signal]
36 if self.curr[signal] != self.next[signal]:
37 self.next_dirty.remove(signal)
38 self.curr_dirty.add(signal)
39 self.curr[signal] = self.next[signal]
40 new_value = self.curr[signal]
41 return old_value, new_value
42
43
# Module-level shorthand for Const.normalize. The compiled closures below use it to fit a
# computed Python int into a value's shape (presumably a (width, signed) pair — confirm
# against fhdl.ast).
normalize = Const.normalize
45
46
class _RHSValueCompiler(ValueTransformer):
    """Compile a right-hand-side value into a closure ``state -> int``.

    If ``sensitivity`` is given, every signal read by a compiled value is added to it, so the
    caller learns which signals the value depends on.
    """
    def __init__(self, sensitivity=None):
        self.sensitivity = sensitivity

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        if self.sensitivity is not None:
            self.sensitivity.add(value)
        return lambda state: state.curr[value]

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        """Compile unary, binary and mux operators; results are normalized to the shape."""
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.op == "-":
                return lambda state: normalize(-arg(state), shape)
            if value.op == "b":
                return lambda state: normalize(bool(arg(state)), shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.op == "+":
                return lambda state: normalize(lhs(state) + rhs(state), shape)
            if value.op == "-":
                return lambda state: normalize(lhs(state) - rhs(state), shape)
            if value.op == "&":
                return lambda state: normalize(lhs(state) & rhs(state), shape)
            if value.op == "|":
                return lambda state: normalize(lhs(state) | rhs(state), shape)
            if value.op == "^":
                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
            if value.op == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
            if value.op == "!=":
                return lambda state: normalize(lhs(state) != rhs(state), shape)
            if value.op == "<":
                return lambda state: normalize(lhs(state) < rhs(state), shape)
            if value.op == "<=":
                return lambda state: normalize(lhs(state) <= rhs(state), shape)
            if value.op == ">":
                return lambda state: normalize(lhs(state) > rhs(state), shape)
            if value.op == ">=":
                return lambda state: normalize(lhs(state) >= rhs(state), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                sel, val1, val0 = map(self, value.operands)
                return lambda state: val1(state) if sel(state) else val0(state)
        raise NotImplementedError("Operator '{!r}' not implemented".format(value.op)) # :nocov:

    def on_Slice(self, value):
        """Extract bits ``[start, end)`` of the operand."""
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        """Concatenate operands, the first operand occupying the least significant bits."""
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)
        def evaluate(state):
            result = 0
            for part_offset, part_mask, part in parts:
                result |= (part(state) & part_mask) << part_offset
            return normalize(result, shape)
        return evaluate

    def on_Repl(self, value):
        """Replicate the operand's bits ``count`` times."""
        shape = value.shape()
        width = len(value.value)
        mask = (1 << width) - 1
        count = value.count
        opnd = self(value.value)
        def evaluate(state):
            result = 0
            for _ in range(count):
                # Mask each copy to the operand's width: a negative (signed) operand would
                # otherwise OR ones into every higher bit, clobbering the earlier copies.
                result = (result << width) | (opnd(state) & mask)
            return normalize(result, shape)
        return evaluate
142
143
class _StatementCompiler(StatementTransformer):
    """Compile statements into closures ``state -> None`` that perform their assignments.

    Every signal read on a right-hand side is collected in ``self.sensitivity``.
    """
    def __init__(self):
        self.sensitivity = ValueSet()
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)

    def lhs_compiler(self, value):
        # TODO: only whole signals are supported as assignment targets for now.
        def assign(state, arg):
            return state.set(value, arg)
        return assign

    def on_Assign(self, stmt):
        assert isinstance(stmt.lhs, Signal)
        shape = stmt.lhs.shape()
        store = self.lhs_compiler(stmt.lhs)
        load = self.rhs_compiler(stmt.rhs)
        def run(state):
            store(state, normalize(load(state), shape))
        return run

    def on_Switch(self, stmt):
        test = self.rhs_compiler(stmt.test)

        def make_check(mask, pattern):
            return lambda test_value: test_value & mask == pattern

        cases = []
        for pattern, stmts in stmt.cases.items():
            # A "-" marks a don't-care bit: it is cleared in both the mask and the pattern.
            mask = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
            bits = int(pattern.replace("-", "0"), 2)
            cases.append((make_check(mask, bits), self.on_statements(stmts)))

        def run(state):
            test_value = test(state)
            # Only the first matching case is executed.
            matched = next((body for check, body in cases if check(test_value)), None)
            if matched is not None:
                matched(state)
        return run

    def on_statements(self, stmts):
        compiled = list(map(self.on_statement, stmts))
        def run(state):
            for compiled_stmt in compiled:
                compiled_stmt(state)
        return run
190
191
class Simulator:
    """Event-driven simulator for a fragment, with optional VCD/GTKWave output.

    Use it as a context manager: entering prepares the fragment and compiles its statements;
    exiting finalizes and closes the waveform files. Simulator processes are generators that
    yield commands: ``Delay``, ``Tick``, ``Passive``, a ``Value`` to read, or an ``Assign``
    to perform.
    """

    def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()):
        self._fragment = fragment

        self._domains = {}                  # str/domain -> ClockDomain
        self._domain_triggers = ValueDict() # Signal -> str/domain
        self._domain_signals = {}           # str/domain -> {Signal}

        self._signals = ValueSet()          # {Signal}
        self._comb_signals = ValueSet()     # {Signal}
        self._sync_signals = ValueSet()     # {Signal}
        self._user_signals = ValueSet()     # {Signal}

        self._started = False
        self._timestamp = 0.
        self._epsilon = 1e-10
        self._fastest_clock = self._epsilon
        self._state = _State()

        self._processes = set()             # {process}
        self._passive = set()               # {process}
        self._suspended = set()             # {process}
        self._wait_deadline = {}            # process -> float/timestamp
        self._wait_tick = {}                # process -> str/domain

        self._funclets = ValueDict()        # Signal -> set(lambda)

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = ValueDict()     # signal -> set(vcd_signal)
        self._vcd_names = ValueDict()       # signal -> str/name
        self._gtkw_file = gtkw_file
        self._traces = traces

    def _check_process(self, process):
        """Return a generator for ``process``, calling it first if it is a generator function.

        Raises TypeError if ``process`` is neither a generator nor a generator function.
        """
        if inspect.isgeneratorfunction(process):
            process = process()
        if not inspect.isgenerator(process):
            raise TypeError("Cannot add a process '{!r}' because it is not a generator or "
                            "a generator function"
                            .format(process))
        return process

    def add_process(self, process):
        """Add a generator (or generator function) as a simulator process."""
        process = self._check_process(process)
        self._processes.add(process)

    def add_sync_process(self, process, domain="sync"):
        """Add a process that is driven by the clock of ``domain``.

        Before the wrapped process first runs, and whenever it yields ``None``, the wrapper
        yields ``Tick(domain)`` in its place; any other command is passed through.
        """
        process = self._check_process(process)
        def sync_process():
            try:
                result = None
                while True:
                    if result is None:
                        result = Tick(domain)
                    result = process.send((yield result))
            except StopIteration:
                pass
        self.add_process(sync_process())

    def add_clock(self, period, phase=None, domain="sync"):
        """Add a passive process that toggles the clock of ``domain`` every half ``period``.

        The clock first rises after ``phase`` (half a period by default). Clock domains are
        only collected in ``__enter__``, so this must be called inside the simulator
        context; an unknown domain raises KeyError.
        """
        # Look the clock up first, so that a bad domain leaves the simulator unchanged.
        clk = self._domains[domain].clk

        if self._fastest_clock == self._epsilon or period < self._fastest_clock:
            self._fastest_clock = period

        half_period = period / 2
        if phase is None:
            phase = half_period
        def clk_process():
            yield Passive()
            yield Delay(phase)
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process())

    def __enter__(self):
        """Prepare the fragment, initialize all signals and compile the fragment statements."""
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        root_fragment = self._fragment.prepare()

        self._domains = root_fragment.domains
        for domain, cd in self._domains.items():
            self._domain_triggers[cd.clk] = domain
            if cd.rst is not None:
                self._domain_triggers[cd.rst] = domain
            self._domain_signals[domain] = ValueSet()

        hierarchy = {}
        def add_fragment(fragment, scope=()):
            hierarchy[fragment] = scope
            for subfragment, name in fragment.subfragments:
                add_fragment(subfragment, (*scope, name))
        add_fragment(root_fragment)

        for fragment, fragment_name in hierarchy.items():
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    normalize(signal.reset, signal.shape())
                self._state.curr_dirty.add(signal)

                if self._vcd_writer:
                    self._register_vcd_signal(fragment, fragment_name, signal)

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            self._add_fragment_funclet(fragment)

        self._user_signals = self._signals - self._comb_signals - self._sync_signals

        return self

    def _register_vcd_signal(self, fragment, fragment_name, signal):
        """Register ``signal``, as seen from ``fragment``, with the VCD writer.

        A port of a subfragment is prefixed with that subfragment's name. If the writer
        rejects a name (KeyError, e.g. a duplicate within the scope), a ``$N`` suffix is
        appended and registration is retried.
        """
        if signal not in self._vcd_signals:
            self._vcd_signals[signal] = set()

        for subfragment, name in fragment.subfragments:
            if signal in subfragment.ports:
                var_name = "{}_{}".format(name, signal.name)
                break
        else:
            var_name = signal.name

        if signal.decoder:
            var_type = "string"
            var_size = 1
            var_init = signal.decoder(signal.reset).replace(" ", "_")
        else:
            var_type = "wire"
            var_size = signal.nbits
            var_init = signal.reset

        suffix = None
        while True:
            if suffix is None:
                var_name_suffix = var_name
            else:
                var_name_suffix = "{}${}".format(var_name, suffix)
            try:
                vcd_signal = self._vcd_writer.register_var(
                    scope=".".join(fragment_name), name=var_name_suffix,
                    var_type=var_type, size=var_size, init=var_init)
            except KeyError:
                suffix = (suffix or 0) + 1
                continue
            break

        self._vcd_signals[signal].add(vcd_signal)
        if signal not in self._vcd_names:
            self._vcd_names[signal] = ".".join(fragment_name + (var_name_suffix,))

    def _add_fragment_funclet(self, fragment):
        """Compile ``fragment``'s statements into one funclet and index it by trigger signal.

        The funclet is re-run (see ``_update_dirty_signals``) whenever a signal it reads, or a
        clock/reset of one of the fragment's domains, changes.
        """
        statements = []
        # Defaults, overridden by the fragment's own statements appended afterwards: comb
        # signals fall back to their reset value, sync signals keep their current value.
        for signal in fragment.iter_comb():
            statements.append(signal.eq(signal.reset))
        for domain, signal in fragment.iter_sync():
            statements.append(signal.eq(signal))
        statements += fragment.statements

        compiler = _StatementCompiler()
        funclet = compiler(statements)

        triggers = list(compiler.sensitivity)
        for cd in fragment.domains.values():
            triggers.append(cd.clk)
            if cd.rst is not None:
                triggers.append(cd.rst)

        for signal in triggers:
            if signal not in self._funclets:
                self._funclets[signal] = set()
            self._funclets[signal].add(funclet)

    def _update_dirty_signals(self):
        """Perform the statement part of IR processes (aka RTLIL case)."""
        # First, for all dirty signals, use sensitivity lists to determine the set of fragments
        # that need their statements to be reevaluated because the signals changed at the previous
        # delta cycle.
        funclets = set()
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._funclets:
                funclets.update(self._funclets[signal])

        # Second, compute the values of all signals at the start of the next delta cycle, by
        # running precompiled statements.
        for funclet in funclets:
            funclet(self._state)

    def _commit_signal(self, signal, domains):
        """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
        # Take the computed value (at the start of this delta cycle) of a signal (that could have
        # come from an IR process that ran earlier, or modified by a simulator process) and update
        # the value for this delta cycle.
        old, new = self._state.commit(signal)

        # A 0 -> 1 transition of a clock (or reset) triggers its domain's synchronous logic.
        if (old, new) == (0, 1) and signal in self._domain_triggers:
            domains.add(self._domain_triggers[signal])

        if self._vcd_writer and old != new:
            # Finally, dump the new value to the VCD file. The value and timestamp are the
            # same for every VCD variable of this signal, so compute them once.
            vcd_signals = self._vcd_signals[signal]
            if signal.decoder:
                var_value = signal.decoder(new).replace(" ", "_")
            else:
                var_value = new
            vcd_timestamp = self._timestamp / self._epsilon
            for vcd_signal in vcd_signals:
                self._vcd_writer.change(vcd_signal, vcd_timestamp, var_value)

    def _commit_comb_signals(self, domains):
        """Perform the comb part of IR processes (aka RTLIL always)."""
        # Take the computed value (at the start of this delta cycle) of every comb signal and
        # update the value for this delta cycle. Iterate over a snapshot: committing a signal
        # removes it from `next_dirty`, and mutating a set while iterating it raises.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals or signal in self._user_signals:
                self._commit_signal(signal, domains)

    def _commit_sync_signals(self, domains):
        """Perform the sync part of IR processes (aka RTLIL posedge)."""
        # At entry, `domains` is the set of every simultaneously triggered domain.
        while domains:
            # Advance the timeline a bit (purely for observational purposes) and commit all of them
            # at the same timestamp.
            self._timestamp += self._epsilon
            curr_domains, domains = domains, set()

            while curr_domains:
                domain = curr_domains.pop()

                # Take the computed value (at the start of this delta cycle) of every sync signal
                # in this domain and update the value for this delta cycle. This can trigger more
                # synchronous logic, so record that. Iterate over a snapshot, because committing
                # removes signals from `next_dirty`.
                for signal in list(self._state.next_dirty):
                    if signal in self._domain_signals[domain]:
                        self._commit_signal(signal, domains)

                # Wake up any simulator processes that wait for a domain tick.
                for process, wait_domain in list(self._wait_tick.items()):
                    if domain == wait_domain:
                        del self._wait_tick[process]
                        self._suspended.remove(process)

            # Unless handling synchronous logic above has triggered more synchronous logic (which
            # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
            # Otherwise, do one more round of updates.

    def _run_process(self, process):
        """Resume ``process`` once.

        ``Value`` reads are answered immediately and the process keeps running; any other
        command is handled and control returns to the scheduler. A finished process is
        removed; any other error is thrown into the process.
        """
        def format_process(process):
            frame = process.gi_frame
            return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

        try:
            cmd = process.send(None)
            while True:
                if isinstance(cmd, Delay):
                    if cmd.interval is None:
                        interval = self._epsilon
                    else:
                        interval = cmd.interval
                    self._wait_deadline[process] = self._timestamp + interval
                    self._suspended.add(process)

                elif isinstance(cmd, Tick):
                    self._wait_tick[process] = cmd.domain
                    self._suspended.add(process)

                elif isinstance(cmd, Passive):
                    self._passive.add(process)

                elif isinstance(cmd, Value):
                    compiler = _RHSValueCompiler()
                    funclet = compiler(cmd)
                    cmd = process.send(funclet(self._state))
                    continue

                elif isinstance(cmd, Assign):
                    lhs_signals = cmd.lhs._lhs_signals()
                    for signal in lhs_signals:
                        if signal not in self._signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is not a part of simulation"
                                             .format(format_process(process), signal))
                        if signal in self._comb_signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is a part of combinatorial assignment in "
                                             "simulation"
                                             .format(format_process(process), signal))

                    compiler = _StatementCompiler()
                    funclet = compiler(cmd)
                    funclet(self._state)

                    domains = set()
                    for signal in lhs_signals:
                        self._commit_signal(signal, domains)
                    self._commit_sync_signals(domains)

                else:
                    raise TypeError("Received unsupported command '{!r}' from process '{}'"
                                    .format(cmd, format_process(process)))

                break

        except StopIteration:
            self._processes.remove(process)
            self._passive.discard(process)

        except Exception as e:
            # Deliver the error to the process at its suspended `yield`.
            process.throw(e)

    def step(self, run_passive=False):
        """Advance the simulation by one scheduling step.

        First settles pending delta cycles (raising DeadlineError if that would reach the
        earliest process deadline). Then it resumes one runnable process, or else the
        suspended process with the earliest deadline. Returns True if a process was resumed.
        Returns False otherwise: when only passive processes remain (unless ``run_passive``)
        or when no process waits on a deadline.
        """
        deadline = None
        if self._wait_deadline:
            # We might run some delta cycles, and we have simulator processes waiting on
            # a deadline. Take care to not exceed the closest deadline.
            deadline = min(self._wait_deadline.values())

        # Are there any delta cycles we should run?
        while self._state.curr_dirty:
            self._timestamp += self._epsilon
            if deadline is not None and self._timestamp >= deadline:
                # Oops, we blew the deadline. We *could* run the processes now, but this is
                # virtually certainly a logic loop and a design bug, so bail out instead.
                raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")

            domains = set()
            self._update_dirty_signals()
            self._commit_comb_signals(domains)
            self._commit_sync_signals(domains)

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            process = (self._processes - set(self._suspended)).pop()
            self._run_process(process)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[process]
                self._suspended.remove(process)
                self._timestamp = deadline
                self._run_process(process)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run(self):
        """Step until there is nothing left to do for non-passive processes."""
        while self.step():
            pass

    def run_until(self, deadline, run_passive=False):
        """Step until the timestamp reaches ``deadline``.

        Returns False if the simulation ran out of work first, True otherwise.
        """
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False

        return True

    def __exit__(self, *args):
        """Finish the VCD (and GTKWave save file, if requested) and close the output files."""
        try:
            if self._vcd_writer:
                self._vcd_writer.close(self._timestamp / self._epsilon)

            if self._vcd_file and self._gtkw_file:
                gtkw_save = GTKWSave(self._gtkw_file)
                if hasattr(self._vcd_file, "name"):
                    gtkw_save.dumpfile(self._vcd_file.name)
                if hasattr(self._vcd_file, "tell"):
                    gtkw_save.dumpfile_size(self._vcd_file.tell())

                gtkw_save.treeopen("top")
                gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)

                def add_trace(signal, **kwargs):
                    if signal in self._vcd_names:
                        if len(signal) > 1:
                            suffix = "[{}:0]".format(len(signal) - 1)
                        else:
                            suffix = ""
                        gtkw_save.trace(self._vcd_names[signal] + suffix, **kwargs)

                for domain, cd in self._domains.items():
                    with gtkw_save.group("d.{}".format(domain)):
                        if cd.rst is not None:
                            add_trace(cd.rst)
                        add_trace(cd.clk)

                for signal in self._traces:
                    add_trace(signal)
        finally:
            # Close the output files even if finalizing the waveforms failed.
            try:
                if self._vcd_file:
                    self._vcd_file.close()
            finally:
                if self._gtkw_file:
                    self._gtkw_file.close()