_tools: extract most utility methods to a private package.
[nmigen.git] / nmigen / back / pysim.py
1 import math
2 import inspect
3 import warnings
4 from contextlib import contextmanager
5 from bitarray import bitarray
6 from vcd import VCDWriter
7 from vcd.gtkw import GTKWSave
8
9 from .._tools import flatten
10 from ..hdl.ast import *
11 from ..hdl.ir import *
12 from ..hdl.xfrm import ValueVisitor, StatementVisitor
13
14
# Public names of this module. Only ``Simulator`` and ``DeadlineError`` are defined in this
# file; ``Delay``, ``Tick`` and ``Passive`` presumably come from the star imports above and
# are re-exported here -- TODO confirm against ``hdl.ast``.
__all__ = ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"]
16
17
class DeadlineError(Exception):
    """Raised by ``Simulator.step`` when delta cycles run past the closest process deadline."""
20
21
class _State:
    """Double-buffered storage for simulated signal values, indexed by integer slot.

    ``curr`` holds the committed value of every slot and ``next`` the pending one.
    ``curr_dirty[slot]`` is set when a commit changed ``curr[slot]`` (and for every freshly
    added slot); ``next_dirty[slot]`` is set when ``set()`` changed ``next[slot]`` and is
    cleared again by ``commit()``.
    """

    __slots__ = ("curr", "curr_dirty", "next", "next_dirty")

    def __init__(self):
        self.curr = []
        self.next = []
        self.curr_dirty = bitarray()
        self.next_dirty = bitarray()

    def add(self, value):
        """Allocate a new slot initialized to ``value`` and return its index.

        The new slot starts out ``curr_dirty`` so that it is picked up by the first
        ``flush_curr_dirty()``.
        """
        slot = len(self.curr)
        self.curr.append(value)
        self.next.append(value)
        self.curr_dirty.append(True)
        self.next_dirty.append(False)
        return slot

    def set(self, slot, value):
        """Stage ``value`` as the pending value of ``slot``, marking it dirty if it changed."""
        if self.next[slot] != value:
            self.next_dirty[slot] = True
            self.next[slot] = value

    def commit(self, slot):
        """Make the pending value of ``slot`` current; return ``(old_value, new_value)``.

        ``curr_dirty`` is raised only when the value actually changed.
        """
        old_value = self.curr[slot]
        new_value = self.next[slot]
        # Whatever was pending has been looked at, so the slot is no longer "next dirty".
        # Clearing the flag unconditionally matters when a value was staged and then staged
        # back to the current one: otherwise the flag would stay set forever and the slot
        # would be revisited by every later `iter_next_dirty()`.
        self.next_dirty[slot] = False
        if old_value != new_value:
            self.curr_dirty[slot] = True
            self.curr[slot] = new_value
        return old_value, new_value

    def flush_curr_dirty(self):
        """Yield every ``curr_dirty`` slot, clearing its flag just before yielding it."""
        while True:
            try:
                # `index` raises ValueError once no set bit is left.
                slot = self.curr_dirty.index(True)
            except ValueError:
                break
            self.curr_dirty[slot] = False
            yield slot

    def iter_next_dirty(self):
        """Yield every ``next_dirty`` slot in ascending order without clearing any flag."""
        start = 0
        while True:
            try:
                slot = self.next_dirty.index(True, start)
                start = slot + 1
            except ValueError:
                break
            yield slot
71
72
# Short module-level alias; called throughout this file as `normalize(value, shape)`.
normalize = Const.normalize
74
75
class _ValueCompiler(ValueVisitor):
    """Common base of the value compilers in this file.

    Rejects the node kinds the simulator does not handle and compiles a record as the
    concatenation of its fields.
    """

    def on_AnyConst(self, value):
        raise NotImplementedError # :nocov:

    def on_AnySeq(self, value):
        raise NotImplementedError # :nocov:

    def on_Sample(self, value):
        raise NotImplementedError # :nocov:

    def on_Initial(self, value):
        raise NotImplementedError # :nocov:

    def on_Record(self, value):
        # Compile the record through whatever handler the subclass has for `Cat`.
        record_fields = value.fields.values()
        flattened = Cat(record_fields)
        return self(flattened)
91
92
class _RHSValueCompiler(_ValueCompiler):
    """Compile a value expression into a closure ``f(state) -> int`` that evaluates it.

    Parameters
    ----------
    signal_slots
        Mapping from signals to their slot index in the state.
    sensitivity
        Optional set-like object; every signal read by a compiled expression is added to it.
    mode
        ``"rhs"`` makes signals read from ``state.curr``, ``"lhs"`` from ``state.next``.
    """

    def __init__(self, signal_slots, sensitivity=None, mode="rhs"):
        self.signal_slots = signal_slots
        self.sensitivity  = sensitivity
        self.signal_mode  = mode

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        if self.sensitivity is not None:
            self.sensitivity.add(value)
        if value not in self.signal_slots:
            # A signal that is neither driven nor a port always remains at its reset state.
            return lambda state: value.reset
        value_slot = self.signal_slots[value]
        if self.signal_mode == "rhs":
            return lambda state: state.curr[value_slot]
        elif self.signal_mode == "lhs":
            return lambda state: state.next[value_slot]
        else:
            raise ValueError # :nocov:

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        """Compile a unary, binary or ternary operator; unknown operators raise
        ``NotImplementedError``."""
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.operator == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.operator == "-":
                return lambda state: normalize(-arg(state), shape)
            if value.operator == "b":
                return lambda state: normalize(bool(arg(state)), shape)
            if value.operator == "r|":
                return lambda state: normalize(arg(state) != 0, shape)
            if value.operator in ("r&", "r^"):
                # Both reductions look at the individual bits of the operand. The evaluated
                # operand may be a negative Python int, so it is first reduced to its
                # bit pattern at the operand's width; otherwise `== mask` can never hold
                # for a negative value and `format(..., "b")` would describe the magnitude
                # (with a leading "-") instead of the bits.
                val, = value.operands
                mask = (1 << len(val)) - 1
                if value.operator == "r&":
                    return lambda state: normalize((arg(state) & mask) == mask, shape)
                # Believe it or not, this is the fastest way to compute a sideways XOR in Python.
                return lambda state: normalize(
                    format(arg(state) & mask, "b").count("1") % 2, shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.operator == "+":
                return lambda state: normalize(lhs(state) +  rhs(state), shape)
            if value.operator == "-":
                return lambda state: normalize(lhs(state) -  rhs(state), shape)
            if value.operator == "*":
                return lambda state: normalize(lhs(state) *  rhs(state), shape)
            if value.operator == "//":
                def floordiv(lhs, rhs):
                    # Division by zero evaluates to zero instead of raising.
                    return 0 if rhs == 0 else lhs // rhs
                return lambda state: normalize(floordiv(lhs(state), rhs(state)), shape)
            if value.operator == "&":
                return lambda state: normalize(lhs(state) &  rhs(state), shape)
            if value.operator == "|":
                return lambda state: normalize(lhs(state) |  rhs(state), shape)
            if value.operator == "^":
                return lambda state: normalize(lhs(state) ^  rhs(state), shape)
            if value.operator == "<<":
                def sshl(lhs, rhs):
                    # A negative shift amount shifts in the opposite direction.
                    return lhs << rhs if rhs >= 0 else lhs >> -rhs
                return lambda state: normalize(sshl(lhs(state), rhs(state)), shape)
            if value.operator == ">>":
                def sshr(lhs, rhs):
                    return lhs >> rhs if rhs >= 0 else lhs << -rhs
                return lambda state: normalize(sshr(lhs(state), rhs(state)), shape)
            if value.operator == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
            if value.operator == "!=":
                return lambda state: normalize(lhs(state) != rhs(state), shape)
            if value.operator == "<":
                return lambda state: normalize(lhs(state) <  rhs(state), shape)
            if value.operator == "<=":
                return lambda state: normalize(lhs(state) <= rhs(state), shape)
            if value.operator == ">":
                return lambda state: normalize(lhs(state) >  rhs(state), shape)
            if value.operator == ">=":
                return lambda state: normalize(lhs(state) >= rhs(state), shape)
        elif len(value.operands) == 3:
            if value.operator == "m":
                sel, val1, val0 = map(self, value.operands)
                return lambda state: val1(state) if sel(state) else val0(state)
        raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov:

    def on_Slice(self, value):
        shape = value.shape()
        arg   = self(value.value)
        shift = value.start
        mask  = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        shape  = value.shape()
        arg    = self(value.value)
        shift  = self(value.offset)
        mask   = (1 << value.width) - 1
        stride = value.stride
        # `*` binds tighter than `>>`: the shift amount is `offset * stride`.
        return lambda state: normalize((arg(state) >> shift(state) * stride) & mask, shape)

    def on_Cat(self, value):
        shape  = value.shape()
        parts  = []
        offset = 0
        for opnd in value.parts:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)
        def eval(state):
            result = 0
            for offset, mask, opnd in parts:
                result |= (opnd(state) & mask) << offset
            return normalize(result, shape)
        return eval

    def on_Repl(self, value):
        shape  = value.shape()
        offset = len(value.value)
        mask   = (1 << len(value.value)) - 1
        count  = value.count
        opnd   = self(value.value)
        def eval(state):
            # Reduce the operand to its own width before replicating it: OR-ing in an
            # unmasked negative value would set every higher bit of the result and
            # overwrite the copies already placed there.
            part   = opnd(state) & mask
            result = 0
            for _ in range(count):
                result <<= offset
                result  |= part
            return normalize(result, shape)
        return eval

    def on_ArrayProxy(self, value):
        shape  = value.shape()
        elems  = list(map(self, value.elems))
        index  = self(value.index)
        def eval(state):
            index_value = index(state)
            # An out-of-range index selects the last element.
            if index_value >= len(elems):
                index_value = len(elems) - 1
            return normalize(elems[index_value](state), shape)
        return eval
238
239
class _LHSValueCompiler(_ValueCompiler):
    """Compile an assignment target into a closure ``f(state, rhs)`` that stores ``rhs``.

    ``rhs_compiler`` is used to read back the current contents of a partially assigned
    target and to evaluate dynamic offsets and indices.
    """

    def __init__(self, signal_slots, rhs_compiler):
        self.signal_slots = signal_slots
        self.rhs_compiler = rhs_compiler

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Signal(self, value):
        slot = self.signal_slots[value]
        signal_shape = value.shape()
        def store(state, rhs):
            state.set(slot, normalize(rhs, signal_shape))
        return store

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Slice(self, value):
        read_target  = self.rhs_compiler(value.value)
        write_target = self(value.value)
        low_bit    = value.start
        field_mask = (1 << (value.end - value.start)) - 1
        def store(state, rhs):
            # Read-modify-write: clear the field, then insert the new bits.
            kept     = read_target(state) & ~(field_mask << low_bit)
            inserted = (rhs & field_mask) << low_bit
            write_target(state, kept | inserted)
        return store

    def on_Part(self, value):
        read_target  = self.rhs_compiler(value.value)
        write_target = self(value.value)
        read_offset  = self.rhs_compiler(value.offset)
        field_mask   = (1 << value.width) - 1
        stride       = value.stride
        def store(state, rhs):
            current  = read_target(state)
            low_bit  = read_offset(state) * stride
            kept     = current & ~(field_mask << low_bit)
            inserted = (rhs & field_mask) << low_bit
            write_target(state, kept | inserted)
        return store

    def on_Cat(self, value):
        # One (bit position, width mask, store closure) triple per concatenated part.
        targets  = []
        position = 0
        for part in value.parts:
            width = len(part)
            targets.append((position, (1 << width) - 1, self(part)))
            position += width
        def store(state, rhs):
            for position, part_mask, part_store in targets:
                part_store(state, (rhs >> position) & part_mask)
        return store

    def on_Repl(self, value):
        raise TypeError # :nocov:

    def on_ArrayProxy(self, value):
        elem_stores = [self(elem) for elem in value.elems]
        read_index  = self.rhs_compiler(value.index)
        def store(state, rhs):
            # An index past the end selects the last element.
            selected = min(read_index(state), len(elem_stores) - 1)
            elem_stores[selected](state, rhs)
        return store
313
314
class _StatementCompiler(StatementVisitor):
    """Compile statements into closures ``run(state)`` that stage the assigned values.

    Every signal read while compiling is collected in ``self.sensitivity``.
    """

    def __init__(self, signal_slots):
        self.sensitivity   = SignalSet()
        # Right-hand sides read committed values (`state.curr`) ...
        self.rrhs_compiler = _RHSValueCompiler(signal_slots, self.sensitivity, mode="rhs")
        # ... while read-modify-write of a partially assigned target reads `state.next`.
        self.lrhs_compiler = _RHSValueCompiler(signal_slots, self.sensitivity, mode="lhs")
        self.lhs_compiler  = _LHSValueCompiler(signal_slots, self.lrhs_compiler)

    def on_Assign(self, stmt):
        shape = stmt.lhs.shape()
        lhs   = self.lhs_compiler(stmt.lhs)
        rhs   = self.rrhs_compiler(stmt.rhs)
        def run(state):
            lhs(state, normalize(rhs(state), shape))
        return run

    def on_Assert(self, stmt):
        raise NotImplementedError("Asserts not yet implemented for Simulator backend.") # :nocov:

    def on_Assume(self, stmt):
        # Assumptions have no effect in simulation. A callable no-op is returned rather
        # than `None` because `on_statements` below calls every compiled statement.
        # NOTE(review): this presumes the base visitor hands back the handler's return
        # value unchanged -- confirm against `StatementVisitor`.
        return lambda state: None # :nocov:

    def on_Cover(self, stmt):
        raise NotImplementedError("Covers not yet implemented for Simulator backend.") # :nocov:

    def on_Switch(self, stmt):
        """Compile a switch: the body of the first matching case runs, later ones do not."""
        test  = self.rrhs_compiler(stmt.test)
        cases = []
        for values, stmts in stmt.cases.items():
            if values == ():
                # A case with no patterns always matches.
                check = lambda test: True
            else:
                # Each pattern is a string of "0", "1" and "-"; "-" is a don't-care bit,
                # expressed as a zero in both the mask and the compared value.
                patterns = []
                for value in values:
                    if "-" in value:
                        mask  = "".join("0" if b == "-" else "1" for b in value)
                        value = "".join("0" if b == "-" else  b  for b in value)
                    else:
                        mask  = "1" * len(value)
                    patterns.append((int(mask, 2), int(value, 2)))
                # `patterns` is bound as a default to capture this iteration's list.
                def check(test, patterns=patterns):
                    return any(test & mask == value for mask, value in patterns)
            cases.append((check, self.on_statements(stmts)))
        def run(state):
            test_value = test(state)
            for check, body in cases:
                if check(test_value):
                    body(state)
                    return
        return run

    def on_statements(self, stmts):
        stmts = [self.on_statement(stmt) for stmt in stmts]
        def run(state):
            for stmt in stmts:
                stmt(state)
        return run
373
374
class Simulator:
    """Event-driven simulator for a fragment, driven by generator/coroutine processes.

    Used as a context manager: ``__enter__`` prepares the fragment and compiles its
    statements, ``__exit__`` finalizes the optional VCD/GTKW output and closes the files.
    Processes yield commands (``Delay``, ``Tick``, ``Passive``, assignments or values to
    read) that ``_run_process`` interprets.

    Parameters
    ----------
    fragment
        Design to simulate; converted with ``Fragment.get(fragment, platform=None)``.
    vcd_file
        Optional file object that receives a VCD dump. Closed on ``__exit__``.
    gtkw_file
        Optional file object that receives a GTKWave save file (only written when
        ``vcd_file`` is given as well). Closed on ``__exit__``.
    traces
        Signals to add as traces to the GTKWave save file.
    """

    def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()):
        self._fragment        = Fragment.get(fragment, platform=None)

        self._signal_slots    = SignalDict()  # Signal -> int/slot
        self._slot_signals    = list()        # int/slot -> Signal

        self._domains         = list()        # [ClockDomain]
        self._clk_edges       = dict()        # ClockDomain -> int/edge
        self._domain_triggers = list()        # int/slot -> ClockDomain

        self._signals         = SignalSet()   # {Signal}
        self._comb_signals    = bitarray()    # {Signal}
        self._sync_signals    = bitarray()    # {Signal}
        self._user_signals    = bitarray()    # {Signal}
        self._domain_signals  = dict()        # ClockDomain -> {Signal}

        self._started         = False
        self._timestamp       = 0.
        self._delta           = 0.
        self._epsilon         = 1e-10
        self._fastest_clock   = self._epsilon
        self._all_clocks      = set()         # {str/domain}
        self._state           = _State()

        self._processes       = set()         # {process}
        self._process_loc     = dict()        # process -> str/loc
        self._passive         = set()         # {process}
        self._suspended       = set()         # {process}
        self._wait_deadline   = dict()        # process -> float/timestamp
        self._wait_tick       = dict()        # process -> str/domain

        self._funclets        = list()        # int/slot -> set(lambda)

        self._vcd_file        = vcd_file
        self._vcd_writer      = None
        self._vcd_signals     = list()        # int/slot -> set(vcd_signal)
        self._vcd_names       = list()        # int/slot -> str/name
        self._gtkw_file       = gtkw_file
        self._traces          = traces

        self._run_called      = False

    @staticmethod
    def _check_process(process):
        """Return ``process`` as a generator/coroutine object.

        A generator function is called (without arguments) to obtain its generator.
        Raises ``TypeError`` for anything that is neither a generator nor a coroutine.
        """
        if inspect.isgeneratorfunction(process):
            process = process()
        if not (inspect.isgenerator(process) or inspect.iscoroutine(process)):
            raise TypeError("Cannot add a process {!r} because it is not a generator or "
                            "a generator function"
                            .format(process))
        return process

    def _name_process(self, process):
        """Return a ``file:line`` description of where ``process`` currently is.

        A location recorded in ``_process_loc`` takes precedence over the live frame.
        """
        if process in self._process_loc:
            return self._process_loc[process]
        else:
            if inspect.isgenerator(process):
                frame = process.gi_frame
            if inspect.iscoroutine(process):
                frame = process.cr_frame
            return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def add_process(self, process):
        """Register a generator, coroutine or generator function as a simulator process."""
        process = self._check_process(process)
        self._processes.add(process)

    def add_sync_process(self, process, domain="sync"):
        """Register ``process`` so that every ``yield None`` of it waits for a tick of
        ``domain``; it also waits for one tick before it first runs."""
        process = self._check_process(process)
        def sync_process():
            try:
                cmd = None
                while True:
                    if cmd is None:
                        cmd = Tick(domain)
                    result = yield cmd
                    # `sync_process` is rebound below to the generator object, so this
                    # records the wrapped process' location under the wrapper's key.
                    self._process_loc[sync_process] = self._name_process(process)
                    cmd = process.send(result)
            except StopIteration:
                pass
        sync_process = sync_process()
        self.add_process(sync_process)

    def add_clock(self, period, *, phase=None, domain="sync", if_exists=False):
        """Add a passive process toggling the clock of ``domain`` with the given ``period``.

        ``phase`` is the delay before the first rising edge and defaults to half a period.
        Domains are looked up in ``self._domains``, which is filled by ``__enter__``.

        Raises ``ValueError`` if the domain already has a clock, or if it is not part of
        the simulation and ``if_exists`` is false; with ``if_exists`` true a missing
        domain is silently ignored.
        """
        if domain in self._all_clocks:
            raise ValueError("Domain '{}' already has a clock driving it"
                             .format(domain))

        half_period = period / 2
        if phase is None:
            phase = half_period
        for domain_obj in self._domains:
            if not domain_obj.local and domain_obj.name == domain:
                clk = domain_obj.clk
                break
        else:
            if if_exists:
                return
            else:
                raise ValueError("Domain '{}' is not present in simulation"
                                 .format(domain))

        # Only a clock that is actually added may influence the fastest-clock bookkeeping
        # (used for the GTKWave zoom level); `_epsilon` marks "no clock seen yet".
        if self._fastest_clock == self._epsilon or period < self._fastest_clock:
            self._fastest_clock = period

        def clk_process():
            yield Passive()
            yield Delay(phase)
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process)
        self._all_clocks.add(domain)

    def __enter__(self):
        """Prepare the fragment: allocate a slot per signal, register VCD variables,
        classify driven signals and compile every fragment's statements. Returns ``self``."""
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        root_fragment = self._fragment.prepare()

        # Map every fragment to its hierarchical scope; unnamed subfragments become "U<n>".
        hierarchy = {}
        domains = set()
        def add_fragment(fragment, scope=()):
            hierarchy[fragment] = scope
            domains.update(fragment.domains.values())
            for index, (subfragment, name) in enumerate(fragment.subfragments):
                if name is None:
                    add_fragment(subfragment, (*scope, "U{}".format(index)))
                else:
                    add_fragment(subfragment, (*scope, name))
        add_fragment(root_fragment, scope=("top",))
        self._domains = list(domains)
        self._clk_edges = {domain: 1 if domain.clk_edge == "pos" else 0 for domain in domains}

        def add_signal(signal):
            # Allocate a slot on first sight and grow every per-slot table in lockstep.
            if signal not in self._signals:
                self._signals.add(signal)

                signal_slot = self._state.add(normalize(signal.reset, signal.shape()))
                self._signal_slots[signal] = signal_slot
                self._slot_signals.append(signal)

                self._comb_signals.append(False)
                self._sync_signals.append(False)
                self._user_signals.append(False)
                for domain in self._domains:
                    if domain not in self._domain_signals:
                        self._domain_signals[domain] = bitarray()
                    self._domain_signals[domain].append(False)

                self._funclets.append(set())

                self._domain_triggers.append(None)
                if self._vcd_writer:
                    self._vcd_signals.append(set())
                    self._vcd_names.append(None)

            return self._signal_slots[signal]

        def add_domain_signal(signal, domain):
            signal_slot = add_signal(signal)
            self._domain_triggers[signal_slot] = domain

        # First pass: allocate slots for all signals, so that the bit sets built in the
        # second pass already have their final length.
        for fragment, fragment_scope in hierarchy.items():
            for signal in fragment.iter_signals():
                add_signal(signal)

            for domain_name, domain in fragment.domains.items():
                add_domain_signal(domain.clk, domain)
                if domain.rst is not None:
                    add_domain_signal(domain.rst, domain)

        # Second pass: VCD registration, driver classification and statement compilation.
        for fragment, fragment_scope in hierarchy.items():
            for signal in fragment.iter_signals():
                if not self._vcd_writer:
                    continue

                signal_slot = self._signal_slots[signal]

                # A signal that is a port of a subfragment is named after that subfragment.
                for i, (subfragment, name) in enumerate(fragment.subfragments):
                    if signal in subfragment.ports:
                        var_name = "{}_{}".format(name or "U{}".format(i), signal.name)
                        break
                else:
                    var_name = signal.name

                if signal.decoder:
                    var_type = "string"
                    var_size = 1
                    var_init = signal.decoder(signal.reset).expandtabs().replace(" ", "_")
                else:
                    var_type = "wire"
                    var_size = signal.width
                    var_init = signal.reset

                # On a KeyError from registration, retry with "$1", "$2", ... appended.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        self._vcd_signals[signal_slot].add(self._vcd_writer.register_var(
                            scope=".".join(fragment_scope), name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init))
                        if self._vcd_names[signal_slot] is None:
                            self._vcd_names[signal_slot] = \
                                ".".join(fragment_scope + (var_name_suffix,))
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

            for domain_name, signals in fragment.drivers.items():
                signals_bits = bitarray(len(self._signals))
                signals_bits.setall(False)
                for signal in signals:
                    signals_bits[self._signal_slots[signal]] = True

                if domain_name is None:
                    self._comb_signals |= signals_bits
                else:
                    self._sync_signals |= signals_bits
                    self._domain_signals[fragment.domains[domain_name]] |= signals_bits

            # Default statements come first so the fragment's own statements override
            # them: comb signals default to reset, sync signals hold their value (or
            # follow an asynchronous reset).
            statements = []
            for domain_name, signals in fragment.drivers.items():
                reset_stmts = []
                hold_stmts  = []
                for signal in signals:
                    reset_stmts.append(signal.eq(signal.reset))
                    hold_stmts .append(signal.eq(signal))

                if domain_name is None:
                    statements += reset_stmts
                else:
                    if fragment.domains[domain_name].async_reset:
                        statements.append(Switch(fragment.domains[domain_name].rst,
                            {0: hold_stmts, 1: reset_stmts}))
                    else:
                        statements += hold_stmts
            statements += fragment.statements

            compiler = _StatementCompiler(self._signal_slots)
            funclet = compiler(statements)

            def add_funclet(signal, funclet):
                if signal in self._signal_slots:
                    self._funclets[self._signal_slots[signal]].add(funclet)

            # Re-run the fragment whenever anything it reads, or one of its clocks or
            # resets, changes.
            for signal in compiler.sensitivity:
                add_funclet(signal, funclet)
            for domain in fragment.domains.values():
                add_funclet(domain.clk, funclet)
                if domain.rst is not None:
                    add_funclet(domain.rst, funclet)

        # Every signal driven neither combinatorially nor synchronously is a user signal.
        self._user_signals = bitarray(len(self._signals))
        self._user_signals.setall(True)
        self._user_signals &= ~self._comb_signals
        self._user_signals &= ~self._sync_signals

        return self

    def _update_dirty_signals(self):
        """Perform the statement part of IR processes (aka RTLIL case)."""
        # First, for all dirty signals, use sensitivity lists to determine the set of fragments
        # that need their statements to be reevaluated because the signals changed at the previous
        # delta cycle.
        funclets = set()
        for signal_slot in self._state.flush_curr_dirty():
            funclets.update(self._funclets[signal_slot])

        # Second, compute the values of all signals at the start of the next delta cycle, by
        # running precompiled statements.
        for funclet in funclets:
            funclet(self._state)

    def _commit_signal(self, signal_slot, domains):
        """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
        # Take the computed value (at the start of this delta cycle) of a signal (that could have
        # come from an IR process that ran earlier, or modified by a simulator process) and update
        # the value for this delta cycle.
        old, new = self._state.commit(signal_slot)
        if old == new:
            return

        # If the signal is a clock that triggers synchronous logic, record that fact.
        if (self._domain_triggers[signal_slot] is not None and
                self._clk_edges[self._domain_triggers[signal_slot]] == new):
            domains.add(self._domain_triggers[signal_slot])

        if self._vcd_writer:
            # Finally, dump the new value to the VCD file.
            for vcd_signal in self._vcd_signals[signal_slot]:
                signal = self._slot_signals[signal_slot]
                if signal.decoder:
                    var_value = signal.decoder(new).expandtabs().replace(" ", "_")
                else:
                    var_value = new
                vcd_timestamp = (self._timestamp + self._delta) / self._epsilon
                self._vcd_writer.change(vcd_signal, vcd_timestamp, var_value)

    def _commit_comb_signals(self, domains):
        """Perform the comb part of IR processes (aka RTLIL always)."""
        # Take the computed value (at the start of this delta cycle) of every comb signal and
        # update the value for this delta cycle.
        for signal_slot in self._state.iter_next_dirty():
            if self._comb_signals[signal_slot]:
                self._commit_signal(signal_slot, domains)

    def _commit_sync_signals(self, domains):
        """Perform the sync part of IR processes (aka RTLIL posedge)."""
        # At entry, `domains` contains a set of every simultaneously triggered sync update.
        while domains:
            # Advance the timeline a bit (purely for observational purposes) and commit all of them
            # at the same timestamp.
            self._delta += self._epsilon
            curr_domains, domains = domains, set()

            while curr_domains:
                domain = curr_domains.pop()

                # Wake up any simulator processes that wait for a domain tick.
                for process, wait_domain_name in list(self._wait_tick.items()):
                    if domain.name == wait_domain_name:
                        del self._wait_tick[process]
                        self._suspended.remove(process)

                        # Immediately run the process. It is important that this happens here,
                        # and not on the next step, when all the processes will run anyway,
                        # because Tick() simulates an edge triggered process. Like DFFs that latch
                        # a value from the previous clock cycle, simulator processes observe signal
                        # values from the previous clock cycle on a tick, too.
                        self._run_process(process)

                # Take the computed value (at the start of this delta cycle) of every sync signal
                # in this domain and update the value for this delta cycle. This can trigger more
                # synchronous logic, so record that.
                for signal_slot in self._state.iter_next_dirty():
                    if self._domain_signals[domain][signal_slot]:
                        self._commit_signal(signal_slot, domains)

            # Unless handling synchronous logic above has triggered more synchronous logic (which
            # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
            # Otherwise, do one more round of updates.

    def _run_process(self, process):
        """Resume ``process`` and interpret the commands it yields until it suspends
        (``Delay``/``Tick``) or finishes.

        Any exception other than ``StopIteration`` is thrown into the process.
        """
        try:
            cmd = process.send(None)
            while True:
                if type(cmd) is Delay:
                    # A delay without an interval waits for the smallest time step.
                    if cmd.interval is None:
                        interval = self._epsilon
                    else:
                        interval = cmd.interval
                    self._wait_deadline[process] = self._timestamp + interval
                    self._suspended.add(process)
                    break

                elif type(cmd) is Tick:
                    self._wait_tick[process] = cmd.domain
                    self._suspended.add(process)
                    break

                elif type(cmd) is Passive:
                    self._passive.add(process)

                elif type(cmd) is Assign:
                    lhs_signals = cmd.lhs._lhs_signals()
                    for signal in lhs_signals:
                        if signal not in self._signals:
                            raise ValueError("Process '{}' sent a request to set signal {!r}, "
                                             "which is not a part of simulation"
                                             .format(self._name_process(process), signal))
                        signal_slot = self._signal_slots[signal]
                        if self._comb_signals[signal_slot]:
                            raise ValueError("Process '{}' sent a request to set signal {!r}, "
                                             "which is a part of combinatorial assignment in "
                                             "simulation"
                                             .format(self._name_process(process), signal))

                    if type(cmd.lhs) is Signal and type(cmd.rhs) is Const:
                        # Fast path.
                        self._state.set(self._signal_slots[cmd.lhs],
                                        normalize(cmd.rhs.value, cmd.lhs.shape()))
                    else:
                        compiler = _StatementCompiler(self._signal_slots)
                        funclet = compiler(cmd)
                        funclet(self._state)

                    # The assignment takes effect at once, including any synchronous
                    # logic triggered by it.
                    domains = set()
                    for signal in lhs_signals:
                        self._commit_signal(self._signal_slots[signal], domains)
                    self._commit_sync_signals(domains)

                elif type(cmd) is Signal:
                    # Fast path.
                    cmd = process.send(self._state.curr[self._signal_slots[cmd]])
                    continue

                elif isinstance(cmd, Value):
                    compiler = _RHSValueCompiler(self._signal_slots)
                    funclet = compiler(cmd)
                    cmd = process.send(funclet(self._state))
                    continue

                else:
                    raise TypeError("Received unsupported command {!r} from process '{}'"
                                    .format(cmd, self._name_process(process)))

                cmd = process.send(None)

        except StopIteration:
            self._processes.remove(process)
            self._passive.discard(process)

        except Exception as e:
            process.throw(e)

    def step(self, run_passive=False):
        """Perform one unit of simulation work.

        Returns ``True`` if something was done and ``False`` if there is nothing left to
        do. Raises ``DeadlineError`` when signals are still changing at or past the
        closest process deadline.
        """
        # Are there any delta cycles we should run?
        if self._state.curr_dirty.any():
            # We might run some delta cycles, and we have simulator processes waiting on
            # a deadline. Take care to not exceed the closest deadline.
            if self._wait_deadline and \
                    (self._timestamp + self._delta) >= min(self._wait_deadline.values()):
                # Oops, we blew the deadline. We *could* run the processes now, but this is
                # virtually certainly a logic loop and a design bug, so bail out instead.
                raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")

            domains = set()
            while self._state.curr_dirty.any():
                self._update_dirty_signals()
                self._commit_comb_signals(domains)
            self._commit_sync_signals(domains)
            return True

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            process = (self._processes - self._suspended).pop()
            self._run_process(process)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[process]
                self._suspended.remove(process)
                self._timestamp = deadline
                self._delta = 0.
                self._run_process(process)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run(self):
        """Step the simulation until ``step()`` reports that nothing is left to do."""
        self._run_called = True

        while self.step():
            pass

    def run_until(self, deadline, run_passive=False):
        """Step until the timestamp reaches ``deadline``.

        Returns ``True`` once the deadline is reached and ``False`` if the simulation ran
        out of work first.
        """
        self._run_called = True

        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False

        return True

    def __exit__(self, *args):
        """Finalize the VCD dump, write the GTKWave save file and close both files.

        Warns (``UserWarning``) if the simulation was never run. The files are closed even
        if finalizing the output raises.
        """
        try:
            if not self._run_called:
                warnings.warn("Simulation created, but not run", UserWarning)

            if self._vcd_writer:
                vcd_timestamp = (self._timestamp + self._delta) / self._epsilon
                self._vcd_writer.close(vcd_timestamp)

            if self._vcd_file and self._gtkw_file:
                gtkw_save = GTKWSave(self._gtkw_file)
                if hasattr(self._vcd_file, "name"):
                    gtkw_save.dumpfile(self._vcd_file.name)
                if hasattr(self._vcd_file, "tell"):
                    gtkw_save.dumpfile_size(self._vcd_file.tell())

                gtkw_save.treeopen("top")
                gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)

                def add_trace(signal, **kwargs):
                    signal_slot = self._signal_slots[signal]
                    if self._vcd_names[signal_slot] is not None:
                        # Multi-bit signals without a decoder are referenced by bit range.
                        if len(signal) > 1 and not signal.decoder:
                            suffix = "[{}:0]".format(len(signal) - 1)
                        else:
                            suffix = ""
                        gtkw_save.trace(self._vcd_names[signal_slot] + suffix, **kwargs)

                for domain in self._domains:
                    with gtkw_save.group("d.{}".format(domain.name)):
                        if domain.rst is not None:
                            add_trace(domain.rst)
                        add_trace(domain.clk)

                for signal in self._traces:
                    add_trace(signal)
        finally:
            # Release both files no matter what happened above, and close the second one
            # even if closing the first one fails.
            try:
                if self._vcd_file:
                    self._vcd_file.close()
            finally:
                if self._gtkw_file:
                    self._gtkw_file.close()
889 if self._gtkw_file:
890 self._gtkw_file.close()