back.pysim: add (stub) LHSValueCompiler.
[nmigen.git] / nmigen / back / pysim.py
1 import math
2 import inspect
3 from contextlib import contextmanager
4 from vcd import VCDWriter
5 from vcd.gtkw import GTKWSave
6
7 from ..tools import flatten
8 from ..hdl.ast import *
9 from ..hdl.xfrm import ValueTransformer, StatementTransformer
10
11
12 __all__ = ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"]
13
14
class DeadlineError(Exception):
    """Raised by :meth:`Simulator.step` when delta cycles reach the deadline of a process
    waiting on a :class:`Delay`, which almost always indicates a combinatorial loop."""
17
18
class _State:
    """Double-buffered signal values used by the simulator.

    ``curr`` holds the values visible in the current delta cycle and ``next`` the values
    computed for the following one. ``next_dirty`` tracks signals whose ``next`` value was
    changed by :meth:`set`; ``curr_dirty`` tracks signals whose ``curr`` value was changed
    by :meth:`commit`.
    """
    __slots__ = ("curr", "curr_dirty", "next", "next_dirty")

    def __init__(self):
        self.curr, self.next = ValueDict(), ValueDict()
        self.curr_dirty, self.next_dirty = ValueSet(), ValueSet()

    def set(self, signal, value):
        """Store ``value`` as the next value of ``signal``, marking it dirty if it differs."""
        assert isinstance(value, int)
        changed = self.next[signal] != value
        if changed:
            self.next_dirty.add(signal)
        self.next[signal] = value

    def commit(self, signal):
        """Copy the next value of ``signal`` into ``curr``.

        Returns the ``(old, new)`` pair of current values. Only when they differ is the
        signal moved from ``next_dirty`` to ``curr_dirty``.
        """
        old_value, new_value = self.curr[signal], self.next[signal]
        if old_value == new_value:
            return old_value, new_value
        self.next_dirty.remove(signal)
        self.curr_dirty.add(signal)
        self.curr[signal] = new_value
        return old_value, new_value
42
43
# Module-level shorthand for ``Const.normalize``. The compilers below call it as
# ``normalize(value, shape)`` to bring a computed Python int into the ``.shape()`` of the
# expression or signal being evaluated (presumably truncating to the width and, for signed
# shapes, sign-extending -- see ``Const.normalize`` to confirm).
normalize = Const.normalize
45
46
class _RHSValueCompiler(ValueTransformer):
    """Compile a right-hand-side value into a closure ``state -> int``.

    The returned closures read signal values from ``state.curr``. Apart from constants and
    plain signals, they pass their result through ``normalize`` with the shape of the
    compiled value.

    If ``sensitivity`` is given, every signal encountered during compilation is added to it,
    so the caller knows which signals the compiled expression reads.
    """

    def __init__(self, sensitivity=None):
        self.sensitivity = sensitivity

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        if self.sensitivity is not None:
            self.sensitivity.add(value)
        return lambda state: state.curr[value]

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Operator(self, value):
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.op == "-":
                return lambda state: normalize(-arg(state), shape)
            if value.op == "b":
                return lambda state: normalize(bool(arg(state)), shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.op == "+":
                return lambda state: normalize(lhs(state) + rhs(state), shape)
            if value.op == "-":
                return lambda state: normalize(lhs(state) - rhs(state), shape)
            if value.op == "*":
                return lambda state: normalize(lhs(state) * rhs(state), shape)
            if value.op == "&":
                return lambda state: normalize(lhs(state) & rhs(state), shape)
            if value.op == "|":
                return lambda state: normalize(lhs(state) | rhs(state), shape)
            if value.op == "^":
                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
            if value.op == "<<":
                # A negative (signed) shift amount shifts in the opposite direction.
                def sshl(lhs, rhs):
                    return lhs << rhs if rhs >= 0 else lhs >> -rhs
                return lambda state: normalize(sshl(lhs(state), rhs(state)), shape)
            if value.op == ">>":
                def sshr(lhs, rhs):
                    return lhs >> rhs if rhs >= 0 else lhs << -rhs
                return lambda state: normalize(sshr(lhs(state), rhs(state)), shape)
            if value.op == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
            if value.op == "!=":
                return lambda state: normalize(lhs(state) != rhs(state), shape)
            if value.op == "<":
                return lambda state: normalize(lhs(state) < rhs(state), shape)
            if value.op == "<=":
                return lambda state: normalize(lhs(state) <= rhs(state), shape)
            if value.op == ">":
                return lambda state: normalize(lhs(state) > rhs(state), shape)
            if value.op == ">=":
                return lambda state: normalize(lhs(state) >= rhs(state), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                sel, val1, val0 = map(self, value.operands)
                # Normalize like every other operator so the result always has the mux shape.
                return lambda state: normalize(val1(state) if sel(state) else val0(state),
                                               shape)
        raise NotImplementedError("Operator '{}' not implemented".format(value.op)) # :nocov:

    def on_Slice(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = self(value.offset)
        mask = (1 << value.width) - 1
        return lambda state: normalize((arg(state) >> shift(state)) & mask, shape)

    def on_Cat(self, value):
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            # (bit offset in the result, mask of the operand width, operand closure)
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)
        def evaluate(state):
            result = 0
            for offset, mask, opnd in parts:
                result |= (opnd(state) & mask) << offset
            return normalize(result, shape)
        return evaluate

    def on_Repl(self, value):
        shape = value.shape()
        offset = len(value.value)
        mask = (1 << offset) - 1
        count = value.count
        opnd = self(value.value)
        def evaluate(state):
            # Mask the operand to its width first: a negative (signed) value is a negative
            # Python int and would otherwise set every higher bit of the accumulated result.
            part = opnd(state) & mask
            result = 0
            for _ in range(count):
                result = (result << offset) | part
            return normalize(result, shape)
        return evaluate

    def on_ArrayProxy(self, value):
        shape = value.shape()
        elems = list(map(self, value.elems))
        index = self(value.index)
        last = len(elems) - 1
        def evaluate(state):
            index_value = index(state)
            # An out-of-range index selects the last element instead of raising IndexError.
            # NOTE(review): chosen to match how out-of-range Array reads are usually lowered
            # (last element as the default case); confirm against the other backends.
            if index_value > last:
                index_value = last
            return normalize(elems[index_value](state), shape)
        return evaluate
160
161
class _LHSValueCompiler(ValueTransformer):
    """Compile an assignment target into a closure ``(state, value) -> None``.

    Only :class:`Signal` targets are implemented; their closure stores the value with
    ``state.set``. Slices, parts, concatenations and array proxies are not implemented
    yet. Constants, operators and replications can never be assigned to and raise
    :exc:`TypeError`.
    """

    # Supported target.

    def on_Signal(self, value):
        def assign(state, arg):
            return state.set(value, arg)
        return assign

    # Values that can never be assignment targets.

    def on_Const(self, value):
        raise TypeError # :nocov:

    def on_Operator(self, value):
        raise TypeError # :nocov:

    def on_Repl(self, value):
        raise TypeError # :nocov:

    # Targets the simulator does not support (yet).

    def on_ClockSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_ResetSignal(self, value):
        raise NotImplementedError # :nocov:

    def on_Slice(self, value):
        raise NotImplementedError

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        raise NotImplementedError

    def on_ArrayProxy(self, value):
        raise NotImplementedError
192
193
class _StatementCompiler(StatementTransformer):
    """Compile statements into closures ``state -> None``.

    ``sensitivity`` accumulates every signal read by the right-hand sides and switch tests
    of the statements compiled so far.
    """

    def __init__(self):
        self.sensitivity = ValueSet()
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)
        self.lhs_compiler = _LHSValueCompiler()

    def on_Assign(self, stmt):
        shape = stmt.lhs.shape()
        store = self.lhs_compiler(stmt.lhs)
        load = self.rhs_compiler(stmt.rhs)
        def assign(state):
            # The right-hand side is normalized to the shape of the assignment target.
            store(state, normalize(load(state), shape))
        return assign

    def on_Switch(self, stmt):
        test = self.rhs_compiler(stmt.test)

        def parse_pattern(pattern):
            # Each case key is a bit pattern where "-" is a don't-care bit. Wildcards are
            # 0 in both the mask and the expected value; every other character is kept in
            # the value and is 1 in the mask.
            mask_bits = "".join("0" if bit == "-" else "1" for bit in pattern)
            value_bits = "".join("0" if bit == "-" else bit for bit in pattern)
            return int(mask_bits, 2), int(value_bits, 2)

        cases = []
        for pattern, stmts in stmt.cases.items():
            mask, expected = parse_pattern(pattern)
            cases.append((mask, expected, self.on_statements(stmts)))

        def run(state):
            test_value = test(state)
            # Only the first matching case runs.
            for mask, expected, body in cases:
                if (test_value & mask) == expected:
                    body(state)
                    break
        return run

    def on_statements(self, stmts):
        compiled = list(map(self.on_statement, stmts))
        def run(state):
            for compiled_stmt in compiled:
                compiled_stmt(state)
        return run
236
237
class Simulator:
    """Event-driven simulator for an nMigen fragment.

    The simulator is used as a context manager:

    * ``__enter__`` prepares the fragment, resets every signal, registers VCD variables
      (when ``vcd_file`` is given), and compiles each fragment's statements into funclets.
    * ``__exit__`` closes the VCD writer and, when ``gtkw_file`` is also given, writes a
      GTKWave save file.

    Processes are Python generators yielding commands: ``Delay``, ``Tick``, ``Passive``,
    a ``Value`` to read, or an ``Assign`` to write (see :meth:`_run_process`).

    ``fragment`` is the design to simulate; ``fragment.prepare()`` is called on entry.
    ``vcd_file`` and ``gtkw_file`` are optional file-like objects; both are closed on exit.
    ``traces`` lists extra signals to add to the GTKWave save file, in addition to each
    domain's clock and reset.
    """

    def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()):
        self._fragment = fragment

        self._domains = dict()              # str/domain -> ClockDomain
        self._domain_triggers = ValueDict() # Signal -> str/domain
        self._domain_signals = dict()       # str/domain -> {Signal}

        self._signals = ValueSet()          # {Signal}
        self._comb_signals = ValueSet()     # {Signal}
        self._sync_signals = ValueSet()     # {Signal}
        self._user_signals = ValueSet()     # {Signal}

        self._started = False
        # Simulation time in seconds. VCD timestamps are this divided by `_epsilon` (1e-10),
        # matching the "100 ps" VCD timescale.
        self._timestamp = 0.
        # Extra offset added within one timestamp so that successive sync updates appear at
        # distinct times in the VCD dump (purely observational).
        self._delta = 0.
        self._epsilon = 1e-10
        # `_epsilon` doubles as a "no clock added yet" sentinel in `add_clock`.
        self._fastest_clock = self._epsilon
        self._state = _State()

        self._processes = set()             # {process}
        self._process_loc = dict()          # process -> str/loc
        self._passive = set()               # {process}
        self._suspended = set()             # {process}
        self._wait_deadline = dict()        # process -> float/timestamp
        self._wait_tick = dict()            # process -> str/domain

        self._funclets = ValueDict()        # Signal -> set(lambda)

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = ValueDict()     # signal -> set(vcd_signal)
        self._vcd_names = ValueDict()       # signal -> str/name
        self._gtkw_file = gtkw_file
        self._traces = traces

    @staticmethod
    def _check_process(process):
        """Return ``process`` as a generator object.

        A generator function is called with no arguments to obtain its generator. A
        generator object is returned unchanged. Anything else raises :exc:`TypeError`.
        """
        if inspect.isgeneratorfunction(process):
            process = process()
        if not inspect.isgenerator(process):
            raise TypeError("Cannot add a process '{!r}' because it is not a generator or "
                            "a generator function"
                            .format(process))
        return process

    def _name_process(self, process):
        """Return a ``file:line`` location of ``process``, used in error messages.

        Wrappers created by :meth:`add_sync_process` report the location of the wrapped
        user process, which is recorded in ``_process_loc``.
        """
        if process in self._process_loc:
            return self._process_loc[process]
        else:
            frame = process.gi_frame
            return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))

    def add_process(self, process):
        """Add a process (a generator, or a generator function taking no arguments)."""
        process = self._check_process(process)
        self._processes.add(process)

    def add_sync_process(self, process, domain="sync"):
        """Add a process that runs in lockstep with the clock of ``domain``.

        The wrapper first waits for a ``Tick(domain)`` and only then starts ``process``.
        Whenever ``process`` yields ``None``, the wrapper waits for the next tick. Any other
        command is forwarded to the simulator, and the simulator's response is sent back
        into ``process``.
        """
        process = self._check_process(process)
        def sync_process():
            try:
                result = None
                while True:
                    if result is None:
                        result = Tick(domain)
                    # By the time this body runs, `sync_process` has been rebound below to
                    # the generator object, so the location is keyed by the wrapper.
                    self._process_loc[sync_process] = self._name_process(process)
                    result = process.send((yield result))
            except StopIteration:
                pass
        sync_process = sync_process()
        self.add_process(sync_process)

    def add_clock(self, period, phase=None, domain="sync"):
        """Add a passive process that toggles the clock of ``domain``.

        The clock is set to 1 after ``phase`` (default: half of ``period``) and then
        toggles every half period. The domain is looked up in ``self._domains``, which is
        populated by :meth:`__enter__`.
        """
        if self._fastest_clock == self._epsilon or period < self._fastest_clock:
            self._fastest_clock = period

        half_period = period / 2
        if phase is None:
            phase = half_period
        clk = self._domains[domain].clk
        def clk_process():
            yield Passive()
            yield Delay(phase)
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process)

    def __enter__(self):
        """Elaborate the design and set up the simulation state.

        This method:

        * prepares the fragment and records each domain's trigger signals;
        * resets every signal and registers VCD variables (when a VCD file was given);
        * classifies driven signals as comb or sync;
        * compiles each fragment's statements into a funclet, which is re-run when a
          signal it reads, or one of its domains' clock/reset signals, changes.
        """
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        root_fragment = self._fragment.prepare()

        self._domains = root_fragment.domains
        for domain, cd in self._domains.items():
            self._domain_triggers[cd.clk] = domain
            if cd.rst is not None:
                self._domain_triggers[cd.rst] = domain
            self._domain_signals[domain] = ValueSet()

        # Map every fragment in the hierarchy to its scope (tuple of subfragment names).
        hierarchy = {}
        def add_fragment(fragment, scope=()):
            hierarchy[fragment] = scope
            for subfragment, name in fragment.subfragments:
                add_fragment(subfragment, (*scope, name))
        add_fragment(root_fragment)

        for fragment, fragment_scope in hierarchy.items():
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    normalize(signal.reset, signal.shape())
                # Mark the signal dirty so that funclets sensitive to it run in the first
                # delta cycle.
                self._state.curr_dirty.add(signal)

                if not self._vcd_writer:
                    continue

                if signal not in self._vcd_signals:
                    self._vcd_signals[signal] = set()

                # A signal that is a port of a subfragment is prefixed with that
                # subfragment's name.
                for subfragment, name in fragment.subfragments:
                    if signal in subfragment.ports:
                        var_name = "{}_{}".format(name, signal.name)
                        break
                else:
                    var_name = signal.name

                if signal.decoder:
                    var_type = "string"
                    var_size = 1
                    var_init = signal.decoder(signal.reset).replace(" ", "_")
                else:
                    var_type = "wire"
                    var_size = signal.nbits
                    var_init = signal.reset

                # A KeyError from register_var is treated as a name clash within the scope;
                # retry with "$1", "$2", ... suffixes until registration succeeds.
                suffix = None
                while True:
                    try:
                        if suffix is None:
                            var_name_suffix = var_name
                        else:
                            var_name_suffix = "{}${}".format(var_name, suffix)
                        self._vcd_signals[signal].add(self._vcd_writer.register_var(
                            scope=".".join(fragment_scope), name=var_name_suffix,
                            var_type=var_type, size=var_size, init=var_init))
                        if signal not in self._vcd_names:
                            self._vcd_names[signal] = \
                                ".".join(fragment_scope + (var_name_suffix,))
                        break
                    except KeyError:
                        suffix = (suffix or 0) + 1

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            # Each funclet run first resets comb signals to their reset values and makes
            # sync signals hold their current values; the fragment's own statements follow
            # and may override either.
            statements = []
            for signal in fragment.iter_comb():
                statements.append(signal.eq(signal.reset))
            for _domain, signal in fragment.iter_sync():
                statements.append(signal.eq(signal))
            statements += fragment.statements

            compiler = _StatementCompiler()
            funclet = compiler(statements)

            def add_funclet(signal, funclet):
                if signal not in self._funclets:
                    self._funclets[signal] = set()
                self._funclets[signal].add(funclet)

            for signal in compiler.sensitivity:
                add_funclet(signal, funclet)
            for domain, cd in fragment.domains.items():
                add_funclet(cd.clk, funclet)
                if cd.rst is not None:
                    add_funclet(cd.rst, funclet)

        self._user_signals = self._signals - self._comb_signals - self._sync_signals

        return self

    def _update_dirty_signals(self):
        """Perform the statement part of IR processes (aka RTLIL case)."""
        # First, for all dirty signals, use sensitivity lists to determine the set of fragments
        # that need their statements to be reevaluated because the signals changed at the previous
        # delta cycle.
        funclets = set()
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._funclets:
                funclets.update(self._funclets[signal])

        # Second, compute the values of all signals at the start of the next delta cycle, by
        # running precompiled statements.
        for funclet in funclets:
            funclet(self._state)

    def _commit_signal(self, signal, domains):
        """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
        # Take the computed value (at the start of this delta cycle) of a signal (that could have
        # come from an IR process that ran earlier, or modified by a simulator process) and update
        # the value for this delta cycle.
        old, new = self._state.commit(signal)

        # A 0 -> 1 edge of a domain trigger (its clock, or its reset -- see __enter__) schedules
        # that domain's synchronous update.
        # NOTE(review): triggering on a reset edge behaves like an asynchronous reset; confirm
        # that this is intended for synchronous-reset domains.
        if (old, new) == (0, 1) and signal in self._domain_triggers:
            domains.add(self._domain_triggers[signal])

        if self._vcd_writer and old != new:
            # Finally, dump the new value to the VCD file.
            for vcd_signal in self._vcd_signals[signal]:
                if signal.decoder:
                    var_value = signal.decoder(new).replace(" ", "_")
                else:
                    var_value = new
                vcd_timestamp = (self._timestamp + self._delta) / self._epsilon
                self._vcd_writer.change(vcd_signal, vcd_timestamp, var_value)

    def _commit_comb_signals(self, domains):
        """Perform the comb part of IR processes (aka RTLIL always)."""
        # Take the computed value (at the start of this delta cycle) of every comb signal and
        # update the value for this delta cycle. Committing a changed signal removes it from
        # `next_dirty`, so iterate over a snapshot instead of the live container (mutating a
        # set or dict while iterating over it raises RuntimeError).
        comb_dirty = [signal for signal in self._state.next_dirty
                      if signal in self._comb_signals]
        for signal in comb_dirty:
            self._commit_signal(signal, domains)

    def _commit_sync_signals(self, domains):
        """Perform the sync part of IR processes (aka RTLIL posedge)."""
        # At entry, `domains` contains a list of every simultaneously triggered sync update.
        while domains:
            # Advance the timeline a bit (purely for observational purposes) and commit all of them
            # at the same timestamp.
            self._delta += self._epsilon
            curr_domains, domains = domains, set()

            while curr_domains:
                domain = curr_domains.pop()

                # Take the computed value (at the start of this delta cycle) of every sync signal
                # in this domain and update the value for this delta cycle. This can trigger more
                # synchronous logic, so record that. As in `_commit_comb_signals`, take a
                # snapshot first because committing mutates `next_dirty`.
                sync_dirty = [signal for signal in self._state.next_dirty
                              if signal in self._domain_signals[domain]]
                for signal in sync_dirty:
                    self._commit_signal(signal, domains)

                # Wake up any simulator processes that wait for a domain tick.
                for process, wait_domain in list(self._wait_tick.items()):
                    if domain == wait_domain:
                        del self._wait_tick[process]
                        self._suspended.remove(process)

            # Unless handling synchronous logic above has triggered more synchronous logic (which
            # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
            # Otherwise, do one more round of updates.

    def _run_process(self, process):
        """Resume ``process`` and execute the command(s) it yields.

        * ``Delay`` and ``Tick`` suspend the process until its deadline or domain tick.
        * ``Passive`` marks the process passive; the process stays runnable.
        * A ``Value`` is evaluated immediately and its result sent back into the process.
        * An ``Assign`` is validated, executed, committed, and any triggered sync logic
          propagated; the process stays runnable.

        A finished process is removed. Any other exception is thrown into the process so the
        traceback points at its ``yield``.
        """
        try:
            cmd = process.send(None)
            while True:
                if isinstance(cmd, Delay):
                    if cmd.interval is None:
                        interval = self._epsilon
                    else:
                        interval = cmd.interval
                    self._wait_deadline[process] = self._timestamp + interval
                    self._suspended.add(process)

                elif isinstance(cmd, Tick):
                    self._wait_tick[process] = cmd.domain
                    self._suspended.add(process)

                elif isinstance(cmd, Passive):
                    self._passive.add(process)

                elif isinstance(cmd, Value):
                    compiler = _RHSValueCompiler()
                    funclet = compiler(cmd)
                    cmd = process.send(funclet(self._state))
                    continue

                elif isinstance(cmd, Assign):
                    lhs_signals = cmd.lhs._lhs_signals()
                    for signal in lhs_signals:
                        if not signal in self._signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is not a part of simulation"
                                             .format(self._name_process(process), signal))
                        if signal in self._comb_signals:
                            raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                             "which is a part of combinatorial assignment in "
                                             "simulation"
                                             .format(self._name_process(process), signal))

                    compiler = _StatementCompiler()
                    funclet = compiler(cmd)
                    funclet(self._state)

                    # Commit immediately so the assignment (e.g. a clock edge) takes effect
                    # now, propagating any sync updates it triggers.
                    domains = set()
                    for signal in lhs_signals:
                        self._commit_signal(signal, domains)
                    self._commit_sync_signals(domains)

                else:
                    raise TypeError("Received unsupported command '{!r}' from process '{}'"
                                    .format(cmd, self._name_process(process)))

                break

        except StopIteration:
            self._processes.remove(process)
            self._passive.discard(process)

        except Exception as e:
            process.throw(e)

    def step(self, run_passive=False):
        """Perform one scheduling action.

        In order of priority, this runs:

        1. pending delta cycles;
        2. a runnable process;
        3. the process with the earliest deadline, advancing time to that deadline.

        Step 3 happens only if some process is non-passive or ``run_passive`` is set.
        Returns True if anything was done and False otherwise.
        """
        # Are there any delta cycles we should run?
        if self._state.curr_dirty:
            # We might run some delta cycles, and we have simulator processes waiting on
            # a deadline. Take care to not exceed the closest deadline.
            if self._wait_deadline and \
                    (self._timestamp + self._delta) >= min(self._wait_deadline.values()):
                # Oops, we blew the deadline. We *could* run the processes now, but this is
                # virtually certainly a logic loop and a design bug, so bail out instead.
                raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")

            domains = set()
            while self._state.curr_dirty:
                self._update_dirty_signals()
                self._commit_comb_signals(domains)
            self._commit_sync_signals(domains)
            return True

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            process = (self._processes - set(self._suspended)).pop()
            self._run_process(process)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[process]
                self._suspended.remove(process)
                self._timestamp = deadline
                self._delta = 0.
                self._run_process(process)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run(self):
        """Step until only passive processes (or none) remain."""
        while self.step():
            pass

    def run_until(self, deadline, run_passive=False):
        """Step until simulation time reaches ``deadline``.

        Returns False if the simulation ran out of work before ``deadline``, True otherwise.
        """
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False

        return True

    def __exit__(self, *args):
        """Close the VCD writer, write the GTKWave save file if requested, close output files."""
        if self._vcd_writer:
            vcd_timestamp = (self._timestamp + self._delta) / self._epsilon
            self._vcd_writer.close(vcd_timestamp)

        if self._vcd_file and self._gtkw_file:
            gtkw_save = GTKWSave(self._gtkw_file)
            if hasattr(self._vcd_file, "name"):
                gtkw_save.dumpfile(self._vcd_file.name)
            if hasattr(self._vcd_file, "tell"):
                gtkw_save.dumpfile_size(self._vcd_file.tell())

            gtkw_save.treeopen("top")
            gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)

            def add_trace(signal, **kwargs):
                if signal in self._vcd_names:
                    if len(signal) > 1:
                        suffix = "[{}:0]".format(len(signal) - 1)
                    else:
                        suffix = ""
                    gtkw_save.trace(self._vcd_names[signal] + suffix, **kwargs)

            for domain, cd in self._domains.items():
                with gtkw_save.group("d.{}".format(domain)):
                    if cd.rst is not None:
                        add_trace(cd.rst)
                    add_trace(cd.clk)

            for signal in self._traces:
                add_trace(signal)

        if self._vcd_file:
            self._vcd_file.close()
        if self._gtkw_file:
            self._gtkw_file.close()