3 from contextlib
import contextmanager
4 from vcd
import VCDWriter
5 from vcd
.gtkw
import GTKWSave
7 from ..tools
import flatten
8 from ..hdl
.ast
import *
9 from ..hdl
.xfrm
import AbstractValueTransformer
, AbstractStatementTransformer
12 __all__
= ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"]
15 class DeadlineError(Exception):
20 __slots__
= ("curr", "curr_dirty", "next", "next_dirty")
23 self
.curr
= SignalDict()
24 self
.next
= SignalDict()
25 self
.curr_dirty
= SignalSet()
26 self
.next_dirty
= SignalSet()
28 def set(self
, signal
, value
):
29 assert isinstance(value
, int)
30 if self
.next
[signal
] != value
:
31 self
.next_dirty
.add(signal
)
32 self
.next
[signal
] = value
34 def commit(self
, signal
):
35 old_value
= self
.curr
[signal
]
36 new_value
= self
.next
[signal
]
37 if old_value
!= new_value
:
38 self
.next_dirty
.remove(signal
)
39 self
.curr_dirty
.add(signal
)
40 self
.curr
[signal
] = new_value
41 return old_value
, new_value
44 normalize
= Const
.normalize
47 class _RHSValueCompiler(AbstractValueTransformer
):
48 def __init__(self
, sensitivity
=None, mode
="rhs"):
49 self
.sensitivity
= sensitivity
50 self
.signal_mode
= mode
52 def on_Const(self
, value
):
53 return lambda state
: value
.value
55 def on_Signal(self
, value
):
56 if self
.sensitivity
is not None:
57 self
.sensitivity
.add(value
)
58 if self
.signal_mode
== "rhs":
59 return lambda state
: state
.curr
[value
]
60 elif self
.signal_mode
== "lhs":
61 return lambda state
: state
.next
[value
]
63 raise ValueError # :nocov:
65 def on_ClockSignal(self
, value
):
66 raise NotImplementedError # :nocov:
68 def on_ResetSignal(self
, value
):
69 raise NotImplementedError # :nocov:
71 def on_Operator(self
, value
):
73 if len(value
.operands
) == 1:
74 arg
, = map(self
, value
.operands
)
76 return lambda state
: normalize(~
arg(state
), shape
)
78 return lambda state
: normalize(-arg(state
), shape
)
80 return lambda state
: normalize(bool(arg(state
)), shape
)
81 elif len(value
.operands
) == 2:
82 lhs
, rhs
= map(self
, value
.operands
)
84 return lambda state
: normalize(lhs(state
) + rhs(state
), shape
)
86 return lambda state
: normalize(lhs(state
) - rhs(state
), shape
)
88 return lambda state
: normalize(lhs(state
) & rhs(state
), shape
)
90 return lambda state
: normalize(lhs(state
) |
rhs(state
), shape
)
92 return lambda state
: normalize(lhs(state
) ^
rhs(state
), shape
)
95 return lhs
<< rhs
if rhs
>= 0 else lhs
>> -rhs
96 return lambda state
: normalize(sshl(lhs(state
), rhs(state
)), shape
)
99 return lhs
>> rhs
if rhs
>= 0 else lhs
<< -rhs
100 return lambda state
: normalize(sshr(lhs(state
), rhs(state
)), shape
)
102 return lambda state
: normalize(lhs(state
) == rhs(state
), shape
)
104 return lambda state
: normalize(lhs(state
) != rhs(state
), shape
)
106 return lambda state
: normalize(lhs(state
) < rhs(state
), shape
)
108 return lambda state
: normalize(lhs(state
) <= rhs(state
), shape
)
110 return lambda state
: normalize(lhs(state
) > rhs(state
), shape
)
112 return lambda state
: normalize(lhs(state
) >= rhs(state
), shape
)
113 elif len(value
.operands
) == 3:
115 sel
, val1
, val0
= map(self
, value
.operands
)
116 return lambda state
: val1(state
) if sel(state
) else val0(state
)
117 raise NotImplementedError("Operator '{}' not implemented".format(value
.op
)) # :nocov:
119 def on_Slice(self
, value
):
120 shape
= value
.shape()
121 arg
= self(value
.value
)
123 mask
= (1 << (value
.end
- value
.start
)) - 1
124 return lambda state
: normalize((arg(state
) >> shift
) & mask
, shape
)
126 def on_Part(self
, value
):
127 shape
= value
.shape()
128 arg
= self(value
.value
)
129 shift
= self(value
.offset
)
130 mask
= (1 << value
.width
) - 1
131 return lambda state
: normalize((arg(state
) >> shift(state
)) & mask
, shape
)
133 def on_Cat(self
, value
):
134 shape
= value
.shape()
137 for opnd
in value
.operands
:
138 parts
.append((offset
, (1 << len(opnd
)) - 1, self(opnd
)))
142 for offset
, mask
, opnd
in parts
:
143 result |
= (opnd(state
) & mask
) << offset
144 return normalize(result
, shape
)
147 def on_Repl(self
, value
):
148 shape
= value
.shape()
149 offset
= len(value
.value
)
150 mask
= (1 << len(value
.value
)) - 1
152 opnd
= self(value
.value
)
155 for _
in range(count
):
157 result |
= opnd(state
)
158 return normalize(result
, shape
)
161 def on_ArrayProxy(self
, value
):
162 shape
= value
.shape()
163 elems
= list(map(self
, value
.elems
))
164 index
= self(value
.index
)
165 return lambda state
: normalize(elems
[index(state
)](state
), shape
)
168 class _LHSValueCompiler(AbstractValueTransformer
):
169 def __init__(self
, rhs_compiler
):
170 self
.rhs_compiler
= rhs_compiler
172 def on_Const(self
, value
):
173 raise TypeError # :nocov:
175 def on_Signal(self
, value
):
176 shape
= value
.shape()
177 def eval(state
, rhs
):
178 state
.set(value
, normalize(rhs
, shape
))
181 def on_ClockSignal(self
, value
):
182 raise NotImplementedError # :nocov:
184 def on_ResetSignal(self
, value
):
185 raise NotImplementedError # :nocov:
187 def on_Operator(self
, value
):
188 raise TypeError # :nocov:
190 def on_Slice(self
, value
):
191 lhs_r
= self
.rhs_compiler(value
.value
)
192 lhs_l
= self(value
.value
)
194 mask
= (1 << (value
.end
- value
.start
)) - 1
195 def eval(state
, rhs
):
196 lhs_value
= lhs_r(state
)
197 lhs_value
&= ~
(mask
<< shift
)
198 lhs_value |
= (rhs
& mask
) << shift
199 lhs_l(state
, lhs_value
)
202 def on_Part(self
, value
):
203 lhs_r
= self
.rhs_compiler(value
.value
)
204 lhs_l
= self(value
.value
)
205 shift
= self
.rhs_compiler(value
.offset
)
206 mask
= (1 << value
.width
) - 1
207 def eval(state
, rhs
):
208 lhs_value
= lhs_r(state
)
209 shift_value
= shift(state
)
210 lhs_value
&= ~
(mask
<< shift_value
)
211 lhs_value |
= (rhs
& mask
) << shift_value
212 lhs_l(state
, lhs_value
)
215 def on_Cat(self
, value
):
218 for opnd
in value
.operands
:
219 parts
.append((offset
, (1 << len(opnd
)) - 1, self(opnd
)))
221 def eval(state
, rhs
):
222 for offset
, mask
, opnd
in parts
:
223 opnd(state
, (rhs
>> offset
) & mask
)
226 def on_Repl(self
, value
):
227 raise TypeError # :nocov:
229 def on_ArrayProxy(self
, value
):
230 elems
= list(map(self
, value
.elems
))
231 index
= self
.rhs_compiler(value
.index
)
232 def eval(state
, rhs
):
233 elems
[index(state
)](state
, rhs
)
237 class _StatementCompiler(AbstractStatementTransformer
):
239 self
.sensitivity
= SignalSet()
240 self
.rrhs_compiler
= _RHSValueCompiler(self
.sensitivity
, mode
="rhs")
241 self
.lrhs_compiler
= _RHSValueCompiler(self
.sensitivity
, mode
="lhs")
242 self
.lhs_compiler
= _LHSValueCompiler(self
.lrhs_compiler
)
244 def on_Assign(self
, stmt
):
245 shape
= stmt
.lhs
.shape()
246 lhs
= self
.lhs_compiler(stmt
.lhs
)
247 rhs
= self
.rrhs_compiler(stmt
.rhs
)
249 lhs(state
, normalize(rhs(state
), shape
))
252 def on_Switch(self
, stmt
):
253 test
= self
.rrhs_compiler(stmt
.test
)
255 for value
, stmts
in stmt
.cases
.items():
257 mask
= "".join("0" if b
== "-" else "1" for b
in value
)
258 value
= "".join("0" if b
== "-" else b
for b
in value
)
260 mask
= "1" * len(value
)
262 value
= int(value
, 2)
263 def make_test(mask
, value
):
264 return lambda test
: test
& mask
== value
265 cases
.append((make_test(mask
, value
), self
.on_statements(stmts
)))
267 test_value
= test(state
)
268 for check
, body
in cases
:
269 if check(test_value
):
274 def on_statements(self
, stmts
):
275 stmts
= [self
.on_statement(stmt
) for stmt
in stmts
]
283 def __init__(self
, fragment
, vcd_file
=None, gtkw_file
=None, traces
=()):
284 self
._fragment
= fragment
286 self
._domains
= dict() # str/domain -> ClockDomain
287 self
._domain
_triggers
= SignalDict() # Signal -> str/domain
288 self
._domain
_signals
= dict() # str/domain -> {Signal}
290 self
._signals
= SignalSet() # {Signal}
291 self
._comb
_signals
= SignalSet() # {Signal}
292 self
._sync
_signals
= SignalSet() # {Signal}
293 self
._user
_signals
= SignalSet() # {Signal}
295 self
._started
= False
298 self
._epsilon
= 1e-10
299 self
._fastest
_clock
= self
._epsilon
300 self
._state
= _State()
302 self
._processes
= set() # {process}
303 self
._process
_loc
= dict() # process -> str/loc
304 self
._passive
= set() # {process}
305 self
._suspended
= set() # {process}
306 self
._wait
_deadline
= dict() # process -> float/timestamp
307 self
._wait
_tick
= dict() # process -> str/domain
309 self
._funclets
= SignalDict() # Signal -> set(lambda)
311 self
._vcd
_file
= vcd_file
312 self
._vcd
_writer
= None
313 self
._vcd
_signals
= SignalDict() # signal -> set(vcd_signal)
314 self
._vcd
_names
= SignalDict() # signal -> str/name
315 self
._gtkw
_file
= gtkw_file
316 self
._traces
= traces
319 def _check_process(process
):
320 if inspect
.isgeneratorfunction(process
):
322 if not inspect
.isgenerator(process
):
323 raise TypeError("Cannot add a process '{!r}' because it is not a generator or"
324 "a generator function"
328 def _name_process(self
, process
):
329 if process
in self
._process
_loc
:
330 return self
._process
_loc
[process
]
332 frame
= process
.gi_frame
333 return "{}:{}".format(inspect
.getfile(frame
), inspect
.getlineno(frame
))
335 def add_process(self
, process
):
336 process
= self
._check
_process
(process
)
337 self
._processes
.add(process
)
339 def add_sync_process(self
, process
, domain
="sync"):
340 process
= self
._check
_process
(process
)
345 self
._process
_loc
[sync_process
] = self
._name
_process
(process
)
346 cmd
= process
.send(result
)
350 except StopIteration:
352 sync_process
= sync_process()
353 self
.add_process(sync_process
)
355 def add_clock(self
, period
, phase
=None, domain
="sync"):
356 if self
._fastest
_clock
== self
._epsilon
or period
< self
._fastest
_clock
:
357 self
._fastest
_clock
= period
359 half_period
= period
/ 2
362 clk
= self
._domains
[domain
].clk
368 yield Delay(half_period
)
370 yield Delay(half_period
)
371 self
.add_process(clk_process
)
375 self
._vcd
_writer
= VCDWriter(self
._vcd
_file
, timescale
="100 ps",
376 comment
="Generated by nMigen")
378 root_fragment
= self
._fragment
.prepare()
380 self
._domains
= root_fragment
.domains
381 for domain
, cd
in self
._domains
.items():
382 self
._domain
_triggers
[cd
.clk
] = domain
383 if cd
.rst
is not None:
384 self
._domain
_triggers
[cd
.rst
] = domain
385 self
._domain
_signals
[domain
] = SignalSet()
388 def add_fragment(fragment
, scope
=()):
389 hierarchy
[fragment
] = scope
390 for subfragment
, name
in fragment
.subfragments
:
391 add_fragment(subfragment
, (*scope
, name
))
392 add_fragment(root_fragment
)
394 for fragment
, fragment_scope
in hierarchy
.items():
395 for signal
in fragment
.iter_signals():
396 self
._signals
.add(signal
)
398 self
._state
.curr
[signal
] = self
._state
.next
[signal
] = \
399 normalize(signal
.reset
, signal
.shape())
400 self
._state
.curr_dirty
.add(signal
)
402 if not self
._vcd
_writer
:
405 if signal
not in self
._vcd
_signals
:
406 self
._vcd
_signals
[signal
] = set()
408 for subfragment
, name
in fragment
.subfragments
:
409 if signal
in subfragment
.ports
:
410 var_name
= "{}_{}".format(name
, signal
.name
)
413 var_name
= signal
.name
418 var_init
= signal
.decoder(signal
.reset
).replace(" ", "_")
421 var_size
= signal
.nbits
422 var_init
= signal
.reset
428 var_name_suffix
= var_name
430 var_name_suffix
= "{}${}".format(var_name
, suffix
)
431 self
._vcd
_signals
[signal
].add(self
._vcd
_writer
.register_var(
432 scope
=".".join(fragment_scope
), name
=var_name_suffix
,
433 var_type
=var_type
, size
=var_size
, init
=var_init
))
434 if signal
not in self
._vcd
_names
:
435 self
._vcd
_names
[signal
] = ".".join(fragment_scope
+ (var_name_suffix
,))
438 suffix
= (suffix
or 0) + 1
440 for domain
, signals
in fragment
.drivers
.items():
442 self
._comb
_signals
.update(signals
)
444 self
._sync
_signals
.update(signals
)
445 self
._domain
_signals
[domain
].update(signals
)
448 for signal
in fragment
.iter_comb():
449 statements
.append(signal
.eq(signal
.reset
))
450 for domain
, signal
in fragment
.iter_sync():
451 statements
.append(signal
.eq(signal
))
452 statements
+= fragment
.statements
454 compiler
= _StatementCompiler()
455 funclet
= compiler(statements
)
457 def add_funclet(signal
, funclet
):
458 if signal
not in self
._funclets
:
459 self
._funclets
[signal
] = set()
460 self
._funclets
[signal
].add(funclet
)
462 for signal
in compiler
.sensitivity
:
463 add_funclet(signal
, funclet
)
464 for domain
, cd
in fragment
.domains
.items():
465 add_funclet(cd
.clk
, funclet
)
466 if cd
.rst
is not None:
467 add_funclet(cd
.rst
, funclet
)
469 self
._user
_signals
= self
._signals
- self
._comb
_signals
- self
._sync
_signals
473 def _update_dirty_signals(self
):
474 """Perform the statement part of IR processes (aka RTLIL case)."""
475 # First, for all dirty signals, use sensitivity lists to determine the set of fragments
476 # that need their statements to be reevaluated because the signals changed at the previous
479 while self
._state
.curr_dirty
:
480 signal
= self
._state
.curr_dirty
.pop()
481 if signal
in self
._funclets
:
482 funclets
.update(self
._funclets
[signal
])
484 # Second, compute the values of all signals at the start of the next delta cycle, by
485 # running precompiled statements.
486 for funclet
in funclets
:
489 def _commit_signal(self
, signal
, domains
):
490 """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
491 # Take the computed value (at the start of this delta cycle) of a signal (that could have
492 # come from an IR process that ran earlier, or modified by a simulator process) and update
493 # the value for this delta cycle.
494 old
, new
= self
._state
.commit(signal
)
496 # If the signal is a clock that triggers synchronous logic, record that fact.
497 if (old
, new
) == (0, 1) and signal
in self
._domain
_triggers
:
498 domains
.add(self
._domain
_triggers
[signal
])
500 if self
._vcd
_writer
and old
!= new
:
501 # Finally, dump the new value to the VCD file.
502 for vcd_signal
in self
._vcd
_signals
[signal
]:
504 var_value
= signal
.decoder(new
).replace(" ", "_")
507 vcd_timestamp
= (self
._timestamp
+ self
._delta
) / self
._epsilon
508 self
._vcd
_writer
.change(vcd_signal
, vcd_timestamp
, var_value
)
510 def _commit_comb_signals(self
, domains
):
511 """Perform the comb part of IR processes (aka RTLIL always)."""
512 # Take the computed value (at the start of this delta cycle) of every comb signal and
513 # update the value for this delta cycle.
514 for signal
in self
._state
.next_dirty
:
515 if signal
in self
._comb
_signals
:
516 self
._commit
_signal
(signal
, domains
)
518 def _commit_sync_signals(self
, domains
):
519 """Perform the sync part of IR processes (aka RTLIL posedge)."""
520 # At entry, `domains` contains a list of every simultaneously triggered sync update.
522 # Advance the timeline a bit (purely for observational purposes) and commit all of them
523 # at the same timestamp.
524 self
._delta
+= self
._epsilon
525 curr_domains
, domains
= domains
, set()
528 domain
= curr_domains
.pop()
530 # Take the computed value (at the start of this delta cycle) of every sync signal
531 # in this domain and update the value for this delta cycle. This can trigger more
532 # synchronous logic, so record that.
533 for signal
in self
._state
.next_dirty
:
534 if signal
in self
._domain
_signals
[domain
]:
535 self
._commit
_signal
(signal
, domains
)
537 # Wake up any simulator processes that wait for a domain tick.
538 for process
, wait_domain
in list(self
._wait
_tick
.items()):
539 if domain
== wait_domain
:
540 del self
._wait
_tick
[process
]
541 self
._suspended
.remove(process
)
543 # Immediately run the process. It is important that this happens here,
544 # and not on the next step, when all the processes will run anyway,
545 # because Tick() simulates an edge triggered process. Like DFFs that latch
546 # a value from the previous clock cycle, simulator processes observe signal
547 # values from the previous clock cycle on a tick, too.
548 self
._run
_process
(process
)
550 # Unless handling synchronous logic above has triggered more synchronous logic (which
551 # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done.
552 # Otherwise, do one more round of updates.
554 def _run_process(self
, process
):
556 cmd
= process
.send(None)
558 if isinstance(cmd
, Delay
):
559 if cmd
.interval
is None:
560 interval
= self
._epsilon
562 interval
= cmd
.interval
563 self
._wait
_deadline
[process
] = self
._timestamp
+ interval
564 self
._suspended
.add(process
)
567 elif isinstance(cmd
, Tick
):
568 self
._wait
_tick
[process
] = cmd
.domain
569 self
._suspended
.add(process
)
572 elif isinstance(cmd
, Passive
):
573 self
._passive
.add(process
)
575 elif isinstance(cmd
, Value
):
576 compiler
= _RHSValueCompiler()
577 funclet
= compiler(cmd
)
578 cmd
= process
.send(funclet(self
._state
))
581 elif isinstance(cmd
, Assign
):
582 lhs_signals
= cmd
.lhs
._lhs
_signals
()
583 for signal
in lhs_signals
:
584 if not signal
in self
._signals
:
585 raise ValueError("Process '{}' sent a request to set signal '{!r}', "
586 "which is not a part of simulation"
587 .format(self
._name
_process
(process
), signal
))
588 if signal
in self
._comb
_signals
:
589 raise ValueError("Process '{}' sent a request to set signal '{!r}', "
590 "which is a part of combinatorial assignment in "
592 .format(self
._name
_process
(process
), signal
))
594 compiler
= _StatementCompiler()
595 funclet
= compiler(cmd
)
599 for signal
in lhs_signals
:
600 self
._commit
_signal
(signal
, domains
)
601 self
._commit
_sync
_signals
(domains
)
604 raise TypeError("Received unsupported command '{!r}' from process '{}'"
605 .format(cmd
, self
._name
_process
(process
)))
607 cmd
= process
.send(None)
609 except StopIteration:
610 self
._processes
.remove(process
)
611 self
._passive
.discard(process
)
613 except Exception as e
:
616 def step(self
, run_passive
=False):
617 # Are there any delta cycles we should run?
618 if self
._state
.curr_dirty
:
619 # We might run some delta cycles, and we have simulator processes waiting on
620 # a deadline. Take care to not exceed the closest deadline.
621 if self
._wait
_deadline
and \
622 (self
._timestamp
+ self
._delta
) >= min(self
._wait
_deadline
.values()):
623 # Oops, we blew the deadline. We *could* run the processes now, but this is
624 # virtually certainly a logic loop and a design bug, so bail out instead.d
625 raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")
628 while self
._state
.curr_dirty
:
629 self
._update
_dirty
_signals
()
630 self
._commit
_comb
_signals
(domains
)
631 self
._commit
_sync
_signals
(domains
)
634 # Are there any processes that haven't had a chance to run yet?
635 if len(self
._processes
) > len(self
._suspended
):
636 # Schedule an arbitrary one.
637 process
= (self
._processes
- set(self
._suspended
)).pop()
638 self
._run
_process
(process
)
641 # All processes are suspended. Are any of them active?
642 if len(self
._processes
) > len(self
._passive
) or run_passive
:
643 # Are any of them suspended before a deadline?
644 if self
._wait
_deadline
:
645 # Schedule the one with the lowest deadline.
646 process
, deadline
= min(self
._wait
_deadline
.items(), key
=lambda x
: x
[1])
647 del self
._wait
_deadline
[process
]
648 self
._suspended
.remove(process
)
649 self
._timestamp
= deadline
651 self
._run
_process
(process
)
654 # No processes, or all processes are passive. Nothing to do!
661 def run_until(self
, deadline
, run_passive
=False):
662 while self
._timestamp
< deadline
:
663 if not self
.step(run_passive
):
668 def __exit__(self
, *args
):
670 vcd_timestamp
= (self
._timestamp
+ self
._delta
) / self
._epsilon
671 self
._vcd
_writer
.close(vcd_timestamp
)
673 if self
._vcd
_file
and self
._gtkw
_file
:
674 gtkw_save
= GTKWSave(self
._gtkw
_file
)
675 if hasattr(self
._vcd
_file
, "name"):
676 gtkw_save
.dumpfile(self
._vcd
_file
.name
)
677 if hasattr(self
._vcd
_file
, "tell"):
678 gtkw_save
.dumpfile_size(self
._vcd
_file
.tell())
680 gtkw_save
.treeopen("top")
681 gtkw_save
.zoom_markers(math
.log(self
._epsilon
/ self
._fastest
_clock
) - 14)
683 def add_trace(signal
, **kwargs
):
684 if signal
in self
._vcd
_names
:
686 suffix
= "[{}:0]".format(len(signal
) - 1)
689 gtkw_save
.trace(self
._vcd
_names
[signal
] + suffix
, **kwargs
)
691 for domain
, cd
in self
._domains
.items():
692 with gtkw_save
.group("d.{}".format(domain
)):
693 if cd
.rst
is not None:
697 for signal
in self
._traces
:
701 self
._vcd
_file
.close()
703 self
._gtkw
_file
.close()