back.pysim: squash one level of hierarchy.
[nmigen.git] / nmigen / back / pysim.py
1 import math
2 import inspect
3 from contextlib import contextmanager
4 from vcd import VCDWriter
5 from vcd.gtkw import GTKWSave
6
7 from ..tools import flatten
8 from ..fhdl.ast import *
9 from ..fhdl.xfrm import ValueTransformer, StatementTransformer
10
11
# Names exported by ``from ... import *``. Delay/Tick/Passive come in via the
# star import from ..fhdl.ast above; they are the commands simulator processes yield.
__all__ = ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"]
13
14
class DeadlineError(Exception):
    """Raised when delta cycles would advance time past a waiting process's deadline.

    ``Simulator.step`` raises this instead of running the process, since signals that keep
    changing until a deadline almost always indicate a combinatorial loop in the design.
    """
17
18
class _State:
    """Current and next values of every simulated signal, with dirty tracking.

    ``curr`` holds the values visible in the current delta cycle and ``next`` the values
    scheduled for the following one. ``next_dirty`` collects signals whose ``next`` value was
    written with a different value and still await a commit; ``curr_dirty`` collects signals
    whose ``curr`` value actually changed on commit.
    """
    __slots__ = ("curr", "curr_dirty", "next", "next_dirty")

    def __init__(self):
        self.curr = ValueDict()
        self.next = ValueDict()
        self.curr_dirty = ValueSet()
        self.next_dirty = ValueSet()

    def set(self, signal, value):
        """Schedule ``value`` (an ``int``) as the next value of ``signal``."""
        assert isinstance(value, int)
        if self.next[signal] != value:
            self.next_dirty.add(signal)
            self.next[signal] = value

    def commit(self, signal):
        """Make the scheduled value of ``signal`` its current value.

        Returns ``(old_value, new_value)``. ``signal`` is marked in ``curr_dirty`` only if its
        current value actually changed.
        """
        old_value = self.curr[signal]
        new_value = self.next[signal]
        # The signal is no longer pending either way. If its next value was written and then
        # reverted to the current one, keeping it in next_dirty would leave a stale entry.
        self.next_dirty.discard(signal)
        if old_value != new_value:
            self.curr_dirty.add(signal)
            self.curr[signal] = new_value
        return old_value, new_value
42
43
# Module-level alias for Const.normalize, called throughout this file as
# normalize(int_value, shape); presumably wraps the integer to the shape's width and
# signedness (see Const.normalize to confirm).
normalize = Const.normalize
45
46
class _RHSValueCompiler(ValueTransformer):
    """Compile right-hand-side values into closures ``f(state) -> int``.

    Signals are read from ``state.next`` while ``signal_mode`` is ``"next"`` (the default) and
    from ``state.curr`` while it is ``"curr"``. When a ``sensitivity`` set is supplied, every
    signal encountered during compilation is added to it.
    """

    def __init__(self, sensitivity=None):
        self.sensitivity = sensitivity
        self.signal_mode = "next"

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        # Compare against None rather than testing truthiness: an empty sensitivity set is
        # falsy, so a truthiness test would never record any signal.
        if self.sensitivity is not None:
            self.sensitivity.add(value)
        if self.signal_mode == "curr":
            return lambda state: state.curr[value]
        if self.signal_mode == "next":
            return lambda state: state.next[value]
        raise NotImplementedError # :nocov:

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        """Compile unary ``~ - b``, binary arithmetic/bitwise/comparison ops, and mux ``m``."""
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.op == "-":
                return lambda state: normalize(-arg(state), shape)
            if value.op == "b":
                return lambda state: normalize(bool(arg(state)), shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.op == "+":
                return lambda state: normalize(lhs(state) + rhs(state), shape)
            if value.op == "-":
                return lambda state: normalize(lhs(state) - rhs(state), shape)
            if value.op == "&":
                return lambda state: normalize(lhs(state) & rhs(state), shape)
            if value.op == "|":
                return lambda state: normalize(lhs(state) | rhs(state), shape)
            if value.op == "^":
                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
            if value.op == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
            if value.op == "!=":
                return lambda state: normalize(lhs(state) != rhs(state), shape)
            if value.op == "<":
                return lambda state: normalize(lhs(state) < rhs(state), shape)
            if value.op == "<=":
                return lambda state: normalize(lhs(state) <= rhs(state), shape)
            if value.op == ">":
                return lambda state: normalize(lhs(state) > rhs(state), shape)
            if value.op == ">=":
                return lambda state: normalize(lhs(state) >= rhs(state), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                sel, val1, val0 = map(self, value.operands)
                return lambda state: val1(state) if sel(state) else val0(state)
        raise NotImplementedError("Operator '{!r}' not implemented".format(value.op)) # :nocov:

    def on_Slice(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        """Concatenate operands, the first operand occupying the least significant bits."""
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)

        def run(state):
            result = 0
            for part_offset, part_mask, part in parts:
                result |= (part(state) & part_mask) << part_offset
            return normalize(result, shape)
        return run

    def on_Repl(self, value):
        """Replicate the operand ``count`` times side by side."""
        shape = value.shape()
        width = len(value.value)
        mask = (1 << width) - 1
        count = value.count
        opnd = self(value.value)

        def run(state):
            # Mask the operand so a negative (signed) value cannot smear its sign bits over
            # the copies already placed in higher positions.
            part = opnd(state) & mask
            result = 0
            for _ in range(count):
                result = (result << width) | part
            return normalize(result, shape)
        return run
147
148
class _StatementCompiler(StatementTransformer):
    """Compile statements into closures ``run(state)``.

    ``self.sensitivity`` is handed to the underlying ``_RHSValueCompiler``, which is meant to
    record there every signal read by a compiled right-hand side.
    """

    def __init__(self):
        self.sensitivity = ValueSet()
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)

    @contextmanager
    def initial(self):
        """While active, right-hand sides compile to read ``state.curr`` instead of ``state.next``."""
        try:
            self.rhs_compiler.signal_mode = "curr"
            yield
        finally:
            self.rhs_compiler.signal_mode = "next"

    def lhs_compiler(self, value):
        # TODO
        def store(state, arg):
            return state.set(value, arg)
        return store

    def on_Assign(self, stmt):
        # Only plain signals are supported as assignment targets so far.
        assert isinstance(stmt.lhs, Signal)
        lhs_shape = stmt.lhs.shape()
        store = self.lhs_compiler(stmt.lhs)
        compute = self.rhs_compiler(stmt.rhs)

        def run(state):
            store(state, normalize(compute(state), lhs_shape))
        return run

    def on_Switch(self, stmt):
        test = self.rhs_compiler(stmt.test)

        def make_check(mask, expected):
            # Bind mask/expected per case; note "&" binds tighter than "==" in Python.
            return lambda test_value: (test_value & mask) == expected

        cases = []
        for pattern, case_stmts in stmt.cases.items():
            # A "-" in the pattern is a don't-care bit: it is cleared in both the mask and
            # the expected value.
            mask = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
            expected = int(pattern.replace("-", "0"), 2)
            cases.append((make_check(mask, expected), self.on_statements(case_stmts)))

        def run(state):
            test_value = test(state)
            # Only the first matching case is executed.
            for check, body in cases:
                if check(test_value):
                    body(state)
                    break
        return run

    def on_statements(self, stmts):
        compiled = list(map(self.on_statement, stmts))

        def run(state):
            for step in compiled:
                step(state)
        return run
203
204
class Simulator:
    """Simulate a fragment hierarchy together with user-supplied generator processes.

    Use as a context manager. ``__enter__`` prepares the fragment, compiles its statements
    and, if ``vcd_file`` is given, opens a VCD writer. ``__exit__`` closes the VCD writer,
    writes a GTKWave save file if ``gtkw_file`` is also given, and closes both files.
    """

    def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()):
        self._fragment = fragment

        self._domains = {}                  # str/domain -> ClockDomain
        self._domain_triggers = ValueDict() # Signal -> str/domain
        self._domain_signals = {}           # str/domain -> {Signal}

        self._signals = ValueSet()          # {Signal}
        self._comb_signals = ValueSet()     # {Signal}
        self._sync_signals = ValueSet()     # {Signal}
        self._user_signals = ValueSet()     # {Signal}

        self._started = False
        self._timestamp = 0.
        # One delta cycle; also the VCD time unit (the VCD timescale is "100 ps").
        self._epsilon = 1e-10
        self._fastest_clock = self._epsilon
        self._state = _State()

        self._processes = set()             # {process}
        self._passive = set()               # {process}
        self._suspended = set()             # {process}
        self._wait_deadline = {}            # process -> float/timestamp
        self._wait_tick = {}                # process -> str/domain

        self._funclets = ValueDict()        # Signal -> set(lambda)

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = ValueDict()     # signal -> set(vcd_signal)
        self._vcd_names = ValueDict()       # signal -> str/name
        self._gtkw_file = gtkw_file
        self._traces = traces

    def _check_process(self, process):
        """Return a generator for ``process``, calling it first if it is a generator function.

        Raises ``TypeError`` if ``process`` is neither a generator nor a generator function.
        """
        if inspect.isgeneratorfunction(process):
            process = process()
        if not inspect.isgenerator(process):
            raise TypeError("Cannot add a process '{!r}' because it is not a generator or "
                            "a generator function"
                            .format(process))
        return process

    def add_process(self, process):
        """Add a simulator process (a generator or a generator function)."""
        process = self._check_process(process)
        self._processes.add(process)

    def add_sync_process(self, process, domain="sync"):
        """Add a process that waits for a tick of ``domain`` whenever it yields ``None``.

        Any other command it yields is forwarded as-is, and the value the simulator sends back
        for that command is passed on to ``process``.
        """
        process = self._check_process(process)
        def sync_process():
            try:
                result = None
                while True:
                    if result is None:
                        result = Tick(domain)
                    result = process.send((yield result))
            except StopIteration:
                pass
        self.add_process(sync_process())

    def add_clock(self, period, phase=None, domain="sync"):
        """Add a passive process that toggles the clock of ``domain`` every half ``period``.

        ``phase`` is the delay before the first rising edge and defaults to half a period.
        ``self._domains`` is only populated by ``__enter__``, so call this inside the ``with``.
        """
        if self._fastest_clock == self._epsilon or period < self._fastest_clock:
            self._fastest_clock = period

        half_period = period / 2
        if phase is None:
            phase = half_period
        clk = self._domains[domain].clk
        def clk_process():
            yield Passive()
            yield Delay(phase)
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process())

    def __enter__(self):
        """Prepare and compile the fragment hierarchy, set up VCD output; return ``self``."""
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        root_fragment = self._fragment.prepare()

        self._domains = root_fragment.domains
        for domain, cd in self._domains.items():
            self._domain_triggers[cd.clk] = domain
            if cd.rst is not None:
                self._domain_triggers[cd.rst] = domain
            self._domain_signals[domain] = ValueSet()

        # Map every fragment to its scope: the tuple of subfragment names leading to it.
        hierarchy = {}
        def add_fragment(fragment, scope=()):
            hierarchy[fragment] = scope
            for subfragment, name in fragment.subfragments:
                add_fragment(subfragment, (*scope, name))
        add_fragment(root_fragment)

        def add_funclet(signal, funclet):
            if signal not in self._funclets:
                self._funclets[signal] = set()
            self._funclets[signal].add(funclet)

        for fragment, fragment_name in hierarchy.items():
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    normalize(signal.reset, signal.shape())
                self._state.curr_dirty.add(signal)

                if not self._vcd_writer:
                    continue

                if signal not in self._vcd_signals:
                    self._vcd_signals[signal] = set()

                # A port of a subfragment is named with that subfragment's name as a prefix.
                for subfragment, name in fragment.subfragments:
                    if signal in subfragment.ports:
                        var_name = "{}_{}".format(name, signal.name)
                        break
                else:
                    var_name = signal.name

                if signal.decoder:
                    var_type = "string"
                    var_size = 1
                    var_init = signal.decoder(signal.reset).replace(" ", "_")
                else:
                    var_type = "wire"
                    var_size = signal.nbits
                    var_init = signal.reset

                # Retry with "$1", "$2", ... suffixes; this assumes register_var raises
                # KeyError when the name is already taken in the scope.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        self._vcd_signals[signal].add(self._vcd_writer.register_var(
                            scope=".".join(fragment_name), name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init))
                        if signal not in self._vcd_names:
                            self._vcd_names[signal] = \
                                ".".join(fragment_name + (var_name_suffix,))
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            # Statements run before the fragment's own statements on every evaluation:
            # comb signals start from their reset value and sync signals hold their value.
            initial_stmts = []
            for signal in fragment.iter_comb():
                initial_stmts.append(signal.eq(signal.reset))
            for _domain, signal in fragment.iter_sync():
                initial_stmts.append(signal.eq(signal))

            compiler = _StatementCompiler()
            def make_funclet():
                with compiler.initial():
                    funclet_init = compiler(initial_stmts)
                funclet_frag = compiler(fragment.statements)
                def funclet(state):
                    funclet_init(state)
                    funclet_frag(state)
                return funclet
            funclet = make_funclet()

            # Re-run this fragment's statements whenever a signal they read, or one of the
            # fragment's clocks/resets, changes.
            for signal in compiler.sensitivity:
                add_funclet(signal, funclet)
            for domain, cd in fragment.domains.items():
                add_funclet(cd.clk, funclet)
                if cd.rst is not None:
                    add_funclet(cd.rst, funclet)

        self._user_signals = self._signals - self._comb_signals - self._sync_signals

        return self

    def _update_dirty_signals(self):
        """Perform the statement part of IR processes (aka RTLIL case)."""
        # First, for all dirty signals, use sensitivity lists to determine the set of fragments
        # that need their statements to be reevaluated because the signals changed at the previous
        # delta cycle.
        funclets = set()
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._funclets:
                funclets.update(self._funclets[signal])

        # Second, compute the values of all signals at the start of the next delta cycle, by
        # running precompiled statements.
        for funclet in funclets:
            funclet(self._state)

    def _commit_signal(self, signal, domains):
        """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
        # Take the computed value (at the start of this delta cycle) of a signal (that could have
        # come from an IR process that ran earlier, or modified by a simulator process) and update
        # the value for this delta cycle.
        old, new = self._state.commit(signal)

        # If the signal is a clock (or reset) of a domain and rose, record that domain as triggered.
        if (old, new) == (0, 1) and signal in self._domain_triggers:
            domains.add(self._domain_triggers[signal])

        if self._vcd_writer and old != new:
            # Finally, dump the new value to the VCD file. The value and timestamp are the same
            # for every VCD variable of this signal, so compute them once.
            if signal.decoder:
                var_value = signal.decoder(new).replace(" ", "_")
            else:
                var_value = new
            vcd_timestamp = self._timestamp / self._epsilon
            for vcd_signal in self._vcd_signals[signal]:
                self._vcd_writer.change(vcd_signal, vcd_timestamp, var_value)

    def _commit_comb_signals(self, domains):
        """Perform the comb part of IR processes (aka RTLIL always)."""
        # Take the computed value (at the start of this delta cycle) of every comb signal and
        # update the value for this delta cycle. Iterate over a snapshot: committing removes
        # signals from next_dirty, and mutating a set while iterating over it is unsafe.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals or signal in self._user_signals:
                self._commit_signal(signal, domains)

    def _commit_sync_signals(self, domains):
        """Perform the sync part of IR processes (aka RTLIL posedge)."""
        # At entry, `domains` contains a list of every simultaneously triggered sync update.
        while domains:
            # Advance the timeline a bit (purely for observational purposes) and commit all of them
            # at the same timestamp.
            self._timestamp += self._epsilon
            curr_domains, domains = domains, set()

            while curr_domains:
                domain = curr_domains.pop()

                # Take the computed value (at the start of this delta cycle) of every sync signal
                # in this domain and update the value for this delta cycle. This can trigger more
                # synchronous logic, so record that. Iterate over a snapshot because committing
                # removes signals from next_dirty.
                for signal in list(self._state.next_dirty):
                    if signal in self._domain_signals[domain]:
                        self._commit_signal(signal, domains)

                # Wake up any simulator processes that wait for a domain tick.
                for process, wait_domain in list(self._wait_tick.items()):
                    if domain == wait_domain:
                        del self._wait_tick[process]
                        self._suspended.remove(process)

            # Unless handling synchronous logic above has triggered more synchronous logic (which
            # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
            # Otherwise, do one more round of updates.

    def _run_process(self, process):
        """Resume ``process`` and execute the commands it yields for one turn.

        ``Value`` commands are evaluated and their result is sent straight back. Every other
        command ends the turn; ``Delay`` and ``Tick`` also suspend the process.
        """
        def format_process(process):
            frame = process.gi_frame
            return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

        try:
            cmd = process.send(None)
            while True:
                if isinstance(cmd, Delay):
                    if cmd.interval is None:
                        interval = self._epsilon
                    else:
                        interval = cmd.interval
                    self._wait_deadline[process] = self._timestamp + interval
                    self._suspended.add(process)

                elif isinstance(cmd, Tick):
                    self._wait_tick[process] = cmd.domain
                    self._suspended.add(process)

                elif isinstance(cmd, Passive):
                    self._passive.add(process)

                elif isinstance(cmd, Value):
                    funclet = _RHSValueCompiler()(cmd)
                    cmd = process.send(funclet(self._state))
                    continue

                elif isinstance(cmd, Assign):
                    lhs_signals = cmd.lhs._lhs_signals()
                    for signal in lhs_signals:
                        if not signal in self._signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is not a part of simulation"
                                             .format(format_process(process), signal))
                        if signal in self._comb_signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is a part of combinatorial assignment in "
                                             "simulation"
                                             .format(format_process(process), signal))

                    funclet = _StatementCompiler()(cmd)
                    funclet(self._state)

                    domains = set()
                    for signal in lhs_signals:
                        self._commit_signal(signal, domains)
                    self._commit_sync_signals(domains)

                else:
                    raise TypeError("Received unsupported command '{!r}' from process '{}'"
                                    .format(cmd, format_process(process)))

                break

        except StopIteration:
            self._processes.remove(process)
            self._passive.discard(process)

        except Exception as e:
            # Re-raise inside the process so the traceback points at the offending yield.
            process.throw(e)

    def step(self, run_passive=False):
        """Settle all pending delta cycles, then run one process.

        Returns ``True`` if a process was run. Returns ``False`` if no process could be
        scheduled, e.g. because none remain or only passive ones do (unless ``run_passive``).
        Raises ``DeadlineError`` if delta cycles reach the earliest pending ``Delay`` deadline.
        """
        deadline = None
        if self._wait_deadline:
            # We might run some delta cycles, and we have simulator processes waiting on
            # a deadline. Take care to not exceed the closest deadline.
            deadline = min(self._wait_deadline.values())

        # Are there any delta cycles we should run?
        while self._state.curr_dirty:
            self._timestamp += self._epsilon
            if deadline is not None and self._timestamp >= deadline:
                # Oops, we blew the deadline. We *could* run the processes now, but this is
                # virtually certainly a logic loop and a design bug, so bail out instead.
                raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")

            domains = set()
            self._update_dirty_signals()
            self._commit_comb_signals(domains)
            self._commit_sync_signals(domains)

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            process = (self._processes - set(self._suspended)).pop()
            self._run_process(process)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[process]
                self._suspended.remove(process)
                self._timestamp = deadline
                self._run_process(process)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run(self):
        """Step until ``step()`` reports that nothing is left to do."""
        while self.step():
            pass

    def run_until(self, deadline, run_passive=False):
        """Step until the timestamp reaches ``deadline``.

        Returns ``False`` if ``step()`` ran out of work first, ``True`` otherwise.
        """
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False

        return True

    def __exit__(self, *args):
        if self._vcd_writer:
            self._vcd_writer.close(self._timestamp / self._epsilon)

        if self._vcd_file and self._gtkw_file:
            gtkw_save = GTKWSave(self._gtkw_file)
            if hasattr(self._vcd_file, "name"):
                gtkw_save.dumpfile(self._vcd_file.name)
            if hasattr(self._vcd_file, "tell"):
                gtkw_save.dumpfile_size(self._vcd_file.tell())

            gtkw_save.treeopen("top")
            gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)

            def add_trace(signal, **kwargs):
                if signal in self._vcd_names:
                    if len(signal) > 1:
                        suffix = "[{}:0]".format(len(signal) - 1)
                    else:
                        suffix = ""
                    gtkw_save.trace(self._vcd_names[signal] + suffix, **kwargs)

            for domain, cd in self._domains.items():
                with gtkw_save.group("d.{}".format(domain)):
                    if cd.rst is not None:
                        add_trace(cd.rst)
                    add_trace(cd.clk)

            for signal in self._traces:
                add_trace(signal)

        if self._vcd_file:
            self._vcd_file.close()
        if self._gtkw_file:
            self._gtkw_file.close()
611 if self._gtkw_file:
612 self._gtkw_file.close()