back.pysim: naming. NFC.
[nmigen.git] / nmigen / back / pysim.py
1 import math
2 import inspect
3 from contextlib import contextmanager
4 from vcd import VCDWriter
5 from vcd.gtkw import GTKWSave
6
7 from ..tools import flatten
8 from ..hdl.ast import *
9 from ..hdl.xfrm import AbstractValueTransformer, AbstractStatementTransformer
10
11
# Names exported by ``from ... import *``.
__all__ = [
    "Simulator",
    "Delay",
    "Tick",
    "Passive",
    "DeadlineError",
]
13
14
class DeadlineError(Exception):
    """Raised by ``Simulator.step`` when delta cycles run past the closest process deadline."""
17
18
19 class _State:
20 __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
21
22 def __init__(self):
23 self.curr = SignalDict()
24 self.next = SignalDict()
25 self.curr_dirty = SignalSet()
26 self.next_dirty = SignalSet()
27
28 def set(self, signal, value):
29 assert isinstance(value, int)
30 if self.next[signal] != value:
31 self.next_dirty.add(signal)
32 self.next[signal] = value
33
34 def commit(self, signal):
35 old_value = self.curr[signal]
36 new_value = self.next[signal]
37 if old_value != new_value:
38 self.next_dirty.remove(signal)
39 self.curr_dirty.add(signal)
40 self.curr[signal] = new_value
41 return old_value, new_value
42
43
# Module-level shorthand for ``Const.normalize``; called as ``normalize(value, shape)``.
normalize = Const.normalize
45
46
class _RHSValueCompiler(AbstractValueTransformer):
    """Compile a value into a closure ``f(state)`` that returns its integer value.

    Parameters
    ----------
    sensitivity
        Optional set-like collection. When given, every ``Signal`` visited during compilation
        is added to it.
    mode
        ``"rhs"`` makes signal reads use ``state.curr``; ``"lhs"`` makes them use ``state.next``.
    """

    def __init__(self, sensitivity=None, mode="rhs"):
        self.sensitivity = sensitivity
        self.signal_mode = mode

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        if self.sensitivity is not None:
            self.sensitivity.add(value)
        if self.signal_mode == "rhs":
            return lambda state: state.curr[value]
        elif self.signal_mode == "lhs":
            return lambda state: state.next[value]
        else:
            raise ValueError # :nocov:

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        """Compile a unary, binary or ternary operator; unknown operators raise at compile time."""
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.op == "-":
                return lambda state: normalize(-arg(state), shape)
            if value.op == "b":
                return lambda state: normalize(bool(arg(state)), shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.op == "+":
                return lambda state: normalize(lhs(state) + rhs(state), shape)
            if value.op == "-":
                return lambda state: normalize(lhs(state) - rhs(state), shape)
            if value.op == "&":
                return lambda state: normalize(lhs(state) & rhs(state), shape)
            if value.op == "|":
                return lambda state: normalize(lhs(state) | rhs(state), shape)
            if value.op == "^":
                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
            if value.op == "<<":
                # A negative shift amount shifts in the opposite direction.
                def sshl(lhs, rhs):
                    return lhs << rhs if rhs >= 0 else lhs >> -rhs
                return lambda state: normalize(sshl(lhs(state), rhs(state)), shape)
            if value.op == ">>":
                def sshr(lhs, rhs):
                    return lhs >> rhs if rhs >= 0 else lhs << -rhs
                return lambda state: normalize(sshr(lhs(state), rhs(state)), shape)
            if value.op == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
            if value.op == "!=":
                return lambda state: normalize(lhs(state) != rhs(state), shape)
            if value.op == "<":
                return lambda state: normalize(lhs(state) < rhs(state), shape)
            if value.op == "<=":
                return lambda state: normalize(lhs(state) <= rhs(state), shape)
            if value.op == ">":
                return lambda state: normalize(lhs(state) > rhs(state), shape)
            if value.op == ">=":
                return lambda state: normalize(lhs(state) >= rhs(state), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                sel, val1, val0 = map(self, value.operands)
                return lambda state: val1(state) if sel(state) else val0(state)
        # Reached for any operator/arity combination not handled above.
        raise NotImplementedError("Operator '{}' not implemented".format(value.op)) # :nocov:

    def on_Slice(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        shape = value.shape()
        arg = self(value.value)
        # Unlike a slice, the offset is itself an expression evaluated against the state.
        shift = self(value.offset)
        mask = (1 << value.width) - 1
        return lambda state: normalize((arg(state) >> shift(state)) & mask, shape)

    def on_Cat(self, value):
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)
        def evaluate(state):
            result = 0
            for offset, mask, opnd in parts:
                result |= (opnd(state) & mask) << offset
            return normalize(result, shape)
        return evaluate

    def on_Repl(self, value):
        shape = value.shape()
        offset = len(value.value)
        mask = (1 << len(value.value)) - 1
        count = value.count
        opnd = self(value.value)
        def evaluate(state):
            # Mask the operand to its own width before replicating it, as on_Cat does.
            # Without the mask a negative operand value would OR its sign bits over the
            # copies that were already shifted into place.
            chunk = opnd(state) & mask
            result = 0
            for _ in range(count):
                result <<= offset
                result |= chunk
            return normalize(result, shape)
        return evaluate

    def on_ArrayProxy(self, value):
        shape = value.shape()
        elems = list(map(self, value.elems))
        index = self(value.index)
        return lambda state: normalize(elems[index(state)](state), shape)
166
167
class _LHSValueCompiler(AbstractValueTransformer):
    """Compile an assignment target into a closure ``f(state, rhs)`` that stores ``rhs`` into it.

    ``rhs_compiler`` is used to read back the target (for partial updates) and to evaluate
    offsets and indices.
    """

    def __init__(self, rhs_compiler):
        self.rhs_compiler = rhs_compiler

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        target_shape = value.shape()

        def store(state, rhs):
            state.set(value, normalize(rhs, target_shape))

        return store

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        read_whole = self.rhs_compiler(value.value)
        write_whole = self(value.value)
        start = value.start
        field = (1 << (value.end - value.start)) - 1

        def store(state, rhs):
            # Read-modify-write: clear the sliced bits, then insert the new ones.
            merged = read_whole(state)
            merged &= ~(field << start)
            merged |= (rhs & field) << start
            write_whole(state, merged)

        return store

    def on_Part(self, value):
        read_whole = self.rhs_compiler(value.value)
        write_whole = self(value.value)
        read_offset = self.rhs_compiler(value.offset)
        field = (1 << value.width) - 1

        def store(state, rhs):
            merged = read_whole(state)
            # The offset is only known once the state is available.
            start = read_offset(state)
            merged &= ~(field << start)
            merged |= (rhs & field) << start
            write_whole(state, merged)

        return store

    def on_Cat(self, value):
        targets = []
        position = 0
        for operand in value.operands:
            width = len(operand)
            targets.append((position, (1 << width) - 1, self(operand)))
            position += width

        def store(state, rhs):
            # Hand each operand its own bit field of the right-hand side.
            for shift, field, write_part in targets:
                write_part(state, (rhs >> shift) & field)

        return store

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        writers = [self(elem) for elem in value.elems]
        read_index = self.rhs_compiler(value.index)

        def store(state, rhs):
            chosen = writers[read_index(state)]
            chosen(state, rhs)

        return store
235
236
class _StatementCompiler(AbstractStatementTransformer):
    """Compile statements into closures ``f(state)`` and collect the signals they read.

    ``sensitivity`` receives every signal visited by either of the two value compilers.
    """

    def __init__(self):
        self.sensitivity = SignalSet()
        self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs")
        self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs")
        self.lhs_compiler = _LHSValueCompiler(self.lrhs_compiler)

    def on_Assign(self, stmt):
        target_shape = stmt.lhs.shape()
        store = self.lhs_compiler(stmt.lhs)
        compute = self.rrhs_compiler(stmt.rhs)

        def run(state):
            store(state, normalize(compute(state), target_shape))

        return run

    def on_Switch(self, stmt):
        test = self.rrhs_compiler(stmt.test)

        def matcher(mask, value):
            return lambda test: test & mask == value

        cases = []
        for pattern, stmts in stmt.cases.items():
            # A "-" in the pattern is a don't-care bit: it is dropped from the mask and
            # treated as zero in the value being compared against.
            care_bits = int("".join("0" if bit == "-" else "1" for bit in pattern), 2)
            want_bits = int("".join("0" if bit == "-" else bit for bit in pattern), 2)
            cases.append((matcher(care_bits, want_bits), self.on_statements(stmts)))

        def run(state):
            test_value = test(state)
            # Only the first matching case is executed.
            chosen = next((body for check, body in cases if check(test_value)), None)
            if chosen is not None:
                chosen(state)

        return run

    def on_statements(self, stmts):
        compiled = [self.on_statement(stmt) for stmt in stmts]

        def run(state):
            for step in compiled:
                step(state)

        return run
280
281
class Simulator:
    """Simulator for a fragment, driven by generator-based processes.

    Intended to be used as a context manager: ``__enter__`` prepares the fragment, initializes
    signal state and compiles the statements of every fragment in the hierarchy; ``__exit__``
    finalizes the VCD dump, writes the GTKWave save file (if requested) and closes both files.

    Parameters
    ----------
    fragment
        Object providing ``prepare()``, which returns the root fragment to simulate.
    vcd_file
        Optional file object that receives the VCD dump.
    gtkw_file
        Optional file object that receives a GTKWave save file; only written when ``vcd_file``
        is given as well.
    traces
        Iterable of signals to add as traces to the GTKWave save file.
    """

    def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()):
        self._fragment = fragment

        self._domains = dict()                # str/domain -> ClockDomain
        self._domain_triggers = SignalDict()  # Signal -> str/domain
        self._domain_signals = dict()         # str/domain -> {Signal}

        self._signals = SignalSet()           # {Signal}
        self._comb_signals = SignalSet()      # {Signal}
        self._sync_signals = SignalSet()      # {Signal}
        self._user_signals = SignalSet()      # {Signal}

        self._started = False
        self._timestamp = 0.
        self._delta = 0.
        self._epsilon = 1e-10
        self._fastest_clock = self._epsilon
        self._state = _State()

        self._processes = set()               # {process}
        self._process_loc = dict()            # process -> str/loc
        self._passive = set()                 # {process}
        self._suspended = set()               # {process}
        self._wait_deadline = dict()          # process -> float/timestamp
        self._wait_tick = dict()              # process -> str/domain

        self._funclets = SignalDict()         # Signal -> set(lambda)

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = SignalDict()      # signal -> set(vcd_signal)
        self._vcd_names = SignalDict()        # signal -> str/name
        self._gtkw_file = gtkw_file
        self._traces = traces

    @staticmethod
    def _check_process(process):
        """Return ``process`` as a generator, calling it first if it is a generator function.

        Raises ``TypeError`` if the result is not a generator.
        """
        if inspect.isgeneratorfunction(process):
            process = process()
        if not inspect.isgenerator(process):
            # The trailing space keeps "or" and "a" apart across the implicit concatenation.
            raise TypeError("Cannot add a process '{!r}' because it is not a generator or "
                            "a generator function"
                            .format(process))
        return process

    def _name_process(self, process):
        """Return a ``file:line`` location for ``process``, preferring a recorded location."""
        if process in self._process_loc:
            return self._process_loc[process]
        else:
            frame = process.gi_frame
            return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def add_process(self, process):
        """Register a generator (or generator function) as a simulator process."""
        process = self._check_process(process)
        self._processes.add(process)

    def add_sync_process(self, process, domain="sync"):
        """Register a process in which a bare ``yield`` means ``yield Tick(domain)``."""
        process = self._check_process(process)
        def sync_process():
            try:
                result = None
                while True:
                    # Record where the wrapped process currently is, so that messages naming
                    # the wrapper point at the user's code. `sync_process` here is the
                    # generator object bound below, before the first `send`.
                    self._process_loc[sync_process] = self._name_process(process)
                    cmd = process.send(result)
                    if cmd is None:
                        cmd = Tick(domain)
                    result = yield cmd
            except StopIteration:
                pass
        sync_process = sync_process()
        self.add_process(sync_process)

    def add_clock(self, period, phase=None, domain="sync"):
        """Add a passive process toggling the clock of ``domain`` with the given period.

        ``phase`` is the delay before the first rising edge and defaults to half a period.
        The domain is looked up in ``self._domains``, which is populated by ``__enter__``.
        """
        if self._fastest_clock == self._epsilon or period < self._fastest_clock:
            self._fastest_clock = period

        half_period = period / 2
        if phase is None:
            phase = half_period
        clk = self._domains[domain].clk
        def clk_process():
            yield Passive()
            yield Delay(phase)
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process)

    def __enter__(self):
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        root_fragment = self._fragment.prepare()

        self._domains = root_fragment.domains
        for domain, cd in self._domains.items():
            self._domain_triggers[cd.clk] = domain
            if cd.rst is not None:
                self._domain_triggers[cd.rst] = domain
            self._domain_signals[domain] = SignalSet()

        # Flatten the fragment tree into a mapping of fragment -> tuple of scope names.
        hierarchy = {}
        def add_fragment(fragment, scope=()):
            hierarchy[fragment] = scope
            for subfragment, name in fragment.subfragments:
                add_fragment(subfragment, (*scope, name))
        add_fragment(root_fragment)

        def add_funclet(signal, funclet):
            if signal not in self._funclets:
                self._funclets[signal] = set()
            self._funclets[signal].add(funclet)

        for fragment, fragment_scope in hierarchy.items():
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                # Every signal starts at its reset value and is marked dirty, so that the
                # first step evaluates all statements that depend on it.
                self._state.curr[signal] = self._state.next[signal] = \
                    normalize(signal.reset, signal.shape())
                self._state.curr_dirty.add(signal)

                if not self._vcd_writer:
                    continue

                if signal not in self._vcd_signals:
                    self._vcd_signals[signal] = set()

                # A signal that is a port of a subfragment is prefixed with that
                # subfragment's name.
                for subfragment, name in fragment.subfragments:
                    if signal in subfragment.ports:
                        var_name = "{}_{}".format(name, signal.name)
                        break
                else:
                    var_name = signal.name

                if signal.decoder:
                    var_type = "string"
                    var_size = 1
                    var_init = signal.decoder(signal.reset).replace(" ", "_")
                else:
                    var_type = "wire"
                    var_size = signal.nbits
                    var_init = signal.reset

                # Retry with `$1`, `$2`, ... suffixes for as long as registration raises
                # KeyError.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        self._vcd_signals[signal].add(self._vcd_writer.register_var(
                            scope=".".join(fragment_scope), name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init))
                        if signal not in self._vcd_names:
                            self._vcd_names[signal] = \
                                ".".join(fragment_scope + (var_name_suffix,))
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            # Defaults come first so that the fragment's own statements override them:
            # comb signals fall back to their reset value, sync signals hold their value.
            statements = []
            for signal in fragment.iter_comb():
                statements.append(signal.eq(signal.reset))
            for domain, signal in fragment.iter_sync():
                statements.append(signal.eq(signal))
            statements += fragment.statements

            compiler = _StatementCompiler()
            funclet = compiler(statements)

            for signal in compiler.sensitivity:
                add_funclet(signal, funclet)
            for domain, cd in fragment.domains.items():
                add_funclet(cd.clk, funclet)
                if cd.rst is not None:
                    add_funclet(cd.rst, funclet)

        self._user_signals = self._signals - self._comb_signals - self._sync_signals

        return self

    def _update_dirty_signals(self):
        """Perform the statement part of IR processes (aka RTLIL case)."""
        # First, for all dirty signals, use sensitivity lists to determine the set of fragments
        # that need their statements to be reevaluated because the signals changed at the previous
        # delta cycle.
        funclets = set()
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._funclets:
                funclets.update(self._funclets[signal])

        # Second, compute the values of all signals at the start of the next delta cycle, by
        # running precompiled statements.
        for funclet in funclets:
            funclet(self._state)

    def _commit_signal(self, signal, domains):
        """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
        # Take the computed value (at the start of this delta cycle) of a signal (that could have
        # come from an IR process that ran earlier, or modified by a simulator process) and update
        # the value for this delta cycle.
        old, new = self._state.commit(signal)

        # If the signal is a clock that triggers synchronous logic, record that fact.
        if (old, new) == (0, 1) and signal in self._domain_triggers:
            domains.add(self._domain_triggers[signal])

        if self._vcd_writer and old != new:
            # Finally, dump the new value to the VCD file.
            for vcd_signal in self._vcd_signals[signal]:
                if signal.decoder:
                    var_value = signal.decoder(new).replace(" ", "_")
                else:
                    var_value = new
                vcd_timestamp = (self._timestamp + self._delta) / self._epsilon
                self._vcd_writer.change(vcd_signal, vcd_timestamp, var_value)

    def _commit_comb_signals(self, domains):
        """Perform the comb part of IR processes (aka RTLIL always)."""
        # Take the computed value (at the start of this delta cycle) of every comb signal and
        # update the value for this delta cycle.
        #
        # Committing a changed signal removes it from `next_dirty`, so iterate over a snapshot
        # rather than over the collection that is being modified.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals:
                self._commit_signal(signal, domains)

    def _commit_sync_signals(self, domains):
        """Perform the sync part of IR processes (aka RTLIL posedge)."""
        # At entry, `domains` contains a list of every simultaneously triggered sync update.
        while domains:
            # Advance the timeline a bit (purely for observational purposes) and commit all of them
            # at the same timestamp.
            self._delta += self._epsilon
            curr_domains, domains = domains, set()

            while curr_domains:
                domain = curr_domains.pop()

                # Take the computed value (at the start of this delta cycle) of every sync signal
                # in this domain and update the value for this delta cycle. This can trigger more
                # synchronous logic, so record that.
                #
                # As in `_commit_comb_signals`, a snapshot is iterated because committing
                # removes entries from `next_dirty`.
                for signal in list(self._state.next_dirty):
                    if signal in self._domain_signals[domain]:
                        self._commit_signal(signal, domains)

                # Wake up any simulator processes that wait for a domain tick.
                for process, wait_domain in list(self._wait_tick.items()):
                    if domain == wait_domain:
                        del self._wait_tick[process]
                        self._suspended.remove(process)

                        # Immediately run the process. It is important that this happens here,
                        # and not on the next step, when all the processes will run anyway,
                        # because Tick() simulates an edge triggered process. Like DFFs that latch
                        # a value from the previous clock cycle, simulator processes observe signal
                        # values from the previous clock cycle on a tick, too.
                        self._run_process(process)

            # Unless handling synchronous logic above has triggered more synchronous logic (which
            # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
            # Otherwise, do one more round of updates.

    def _run_process(self, process):
        """Resume ``process`` and service its commands until it suspends or finishes.

        ``Delay`` and ``Tick`` suspend the process; ``Passive`` marks it passive; a ``Value`` is
        evaluated and sent back; an ``Assign`` is executed and committed immediately. Any other
        command raises ``TypeError``. Exceptions other than ``StopIteration`` are thrown into
        the process.
        """
        try:
            cmd = process.send(None)
            while True:
                if isinstance(cmd, Delay):
                    if cmd.interval is None:
                        interval = self._epsilon
                    else:
                        interval = cmd.interval
                    self._wait_deadline[process] = self._timestamp + interval
                    self._suspended.add(process)
                    break

                elif isinstance(cmd, Tick):
                    self._wait_tick[process] = cmd.domain
                    self._suspended.add(process)
                    break

                elif isinstance(cmd, Passive):
                    self._passive.add(process)

                elif isinstance(cmd, Value):
                    compiler = _RHSValueCompiler()
                    funclet = compiler(cmd)
                    cmd = process.send(funclet(self._state))
                    continue

                elif isinstance(cmd, Assign):
                    lhs_signals = cmd.lhs._lhs_signals()
                    for signal in lhs_signals:
                        if signal not in self._signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is not a part of simulation"
                                             .format(self._name_process(process), signal))
                        if signal in self._comb_signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is a part of combinatorial assignment in "
                                             "simulation"
                                             .format(self._name_process(process), signal))

                    compiler = _StatementCompiler()
                    funclet = compiler(cmd)
                    funclet(self._state)

                    domains = set()
                    for signal in lhs_signals:
                        self._commit_signal(signal, domains)
                    self._commit_sync_signals(domains)

                else:
                    raise TypeError("Received unsupported command '{!r}' from process '{}'"
                                    .format(cmd, self._name_process(process)))

                cmd = process.send(None)

        except StopIteration:
            self._processes.remove(process)
            self._passive.discard(process)

        except Exception as e:
            process.throw(e)

    def step(self, run_passive=False):
        """Perform one unit of simulation work; return ``False`` when nothing is left to do."""
        # Are there any delta cycles we should run?
        if self._state.curr_dirty:
            # We might run some delta cycles, and we have simulator processes waiting on
            # a deadline. Take care to not exceed the closest deadline.
            if self._wait_deadline and \
                    (self._timestamp + self._delta) >= min(self._wait_deadline.values()):
                # Oops, we blew the deadline. We *could* run the processes now, but this is
                # virtually certainly a logic loop and a design bug, so bail out instead.
                raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")

            domains = set()
            while self._state.curr_dirty:
                self._update_dirty_signals()
                self._commit_comb_signals(domains)
            self._commit_sync_signals(domains)
            return True

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            process = (self._processes - self._suspended).pop()
            self._run_process(process)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[process]
                self._suspended.remove(process)
                self._timestamp = deadline
                self._delta = 0.
                self._run_process(process)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run(self):
        """Call ``step`` until it reports that there is nothing left to do."""
        while self.step():
            pass

    def run_until(self, deadline, run_passive=False):
        """Step until the timestamp reaches ``deadline``.

        Returns ``False`` if the simulation ran out of work first, ``True`` otherwise.
        """
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False

        return True

    def __exit__(self, *args):
        try:
            if self._vcd_writer:
                vcd_timestamp = (self._timestamp + self._delta) / self._epsilon
                self._vcd_writer.close(vcd_timestamp)

            if self._vcd_file and self._gtkw_file:
                gtkw_save = GTKWSave(self._gtkw_file)
                if hasattr(self._vcd_file, "name"):
                    gtkw_save.dumpfile(self._vcd_file.name)
                if hasattr(self._vcd_file, "tell"):
                    gtkw_save.dumpfile_size(self._vcd_file.tell())

                gtkw_save.treeopen("top")
                gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)

                def add_trace(signal, **kwargs):
                    # Signals that were never registered with the VCD writer are skipped.
                    if signal in self._vcd_names:
                        if len(signal) > 1:
                            suffix = "[{}:0]".format(len(signal) - 1)
                        else:
                            suffix = ""
                        gtkw_save.trace(self._vcd_names[signal] + suffix, **kwargs)

                for domain, cd in self._domains.items():
                    with gtkw_save.group("d.{}".format(domain)):
                        if cd.rst is not None:
                            add_trace(cd.rst)
                        add_trace(cd.clk)

                for signal in self._traces:
                    add_trace(signal)
        finally:
            # Close the output files even if finalizing the dump or the save file failed.
            if self._vcd_file:
                self._vcd_file.close()
            if self._gtkw_file:
                self._gtkw_file.close()