X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=nmigen%2Fback%2Fpysim.py;h=a9c6c162ac09cc5e9953a6f10e609c3cd69e961f;hb=e435a217;hp=3f9f4b1eb0de8cc886c62d2dad65c9d6fe200b1f;hpb=39605ef5510582d6585d6e689b4abf3b937481a8;p=nmigen.git diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 3f9f4b1..a9c6c16 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -1,66 +1,326 @@ -import math +import os +import tempfile +import warnings import inspect from contextlib import contextmanager +import itertools from vcd import VCDWriter from vcd.gtkw import GTKWSave -from ..tools import flatten +from .._utils import deprecated from ..hdl.ast import * -from ..hdl.xfrm import AbstractValueTransformer, AbstractStatementTransformer +from ..hdl.cd import * +from ..hdl.ir import * +from ..hdl.xfrm import ValueVisitor, StatementVisitor, LHSGroupFilter -__all__ = ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"] +class Command: + pass -class DeadlineError(Exception): - pass +class Settle(Command): + def __repr__(self): + return "(settle)" -class _State: - __slots__ = ("curr", "curr_dirty", "next", "next_dirty") +class Delay(Command): + def __init__(self, interval=None): + self.interval = None if interval is None else float(interval) - def __init__(self): - self.curr = SignalDict() - self.next = SignalDict() - self.curr_dirty = SignalSet() - self.next_dirty = SignalSet() + def __repr__(self): + if self.interval is None: + return "(delay ε)" + else: + return "(delay {:.3}us)".format(self.interval * 1e6) - def set(self, signal, value): - assert isinstance(value, int) - if self.next[signal] != value: - self.next_dirty.add(signal) - self.next[signal] = value - def commit(self, signal): - old_value = self.curr[signal] - new_value = self.next[signal] - if old_value != new_value: - self.next_dirty.remove(signal) - self.curr_dirty.add(signal) - self.curr[signal] = new_value - return old_value, new_value +class Tick(Command): + def __init__(self, domain="sync"): + if not isinstance(domain, (str, ClockDomain)): + raise TypeError("Domain must be a string or a ClockDomain instance, not {!r}" + .format(domain)) + assert domain != "comb" + self.domain = domain + def __repr__(self): + return "(tick {})".format(self.domain) -normalize = Const.normalize +class Passive(Command): + def __repr__(self): + return "(passive)" -class _RHSValueCompiler(AbstractValueTransformer): - def __init__(self, sensitivity=None, mode="rhs"): - self.sensitivity = sensitivity - self.signal_mode = mode - def on_Const(self, value): - return lambda state: value.value +class Active(Command): + def __repr__(self): + return "(active)" - def on_Signal(self, value): - if self.sensitivity is not None: - self.sensitivity.add(value) - if self.signal_mode == "rhs": - return lambda state: state.curr[value] - elif self.signal_mode == "lhs": - return lambda state: state.next[value] + +class _WaveformWriter: + def update(self, timestamp, signal, value): + raise NotImplementedError # :nocov: + + def close(self, timestamp): + raise NotImplementedError # :nocov: + + +class _VCDWaveformWriter(_WaveformWriter): + @staticmethod + def timestamp_to_vcd(timestamp): + return timestamp * (10 ** 10) # 1/(100 ps) + + @staticmethod + def decode_to_vcd(signal, value): + return signal.decoder(value).expandtabs().replace(" ", "_") + + def __init__(self, signal_names, *, vcd_file, gtkw_file=None, traces=()): + if isinstance(vcd_file, str): + vcd_file = open(vcd_file, "wt") + if isinstance(gtkw_file, str): + gtkw_file = open(gtkw_file, "wt") + + self.vcd_vars = SignalDict() + self.vcd_file = vcd_file + self.vcd_writer = vcd_file and VCDWriter(self.vcd_file, + timescale="100 ps", comment="Generated by nMigen") + + self.gtkw_names = SignalDict() + self.gtkw_file = gtkw_file + self.gtkw_save = gtkw_file and GTKWSave(self.gtkw_file) + + self.traces = [] + + trace_names = SignalDict() + for trace in traces: + if trace not in signal_names: + trace_names[trace] = {("top", trace.name)} + self.traces.append(trace) + + if self.vcd_writer is None: + return + + for signal, names in itertools.chain(signal_names.items(), trace_names.items()): + if signal.decoder: + var_type = "string" + var_size = 1 + var_init = self.decode_to_vcd(signal, signal.reset) + else: + var_type = "wire" + var_size = signal.width + var_init = signal.reset + + for (*var_scope, var_name) in names: + suffix = None + while True: + try: + if suffix is None: + var_name_suffix = var_name + else: + var_name_suffix = "{}${}".format(var_name, suffix) + vcd_var = self.vcd_writer.register_var( + scope=var_scope, name=var_name_suffix, + var_type=var_type, size=var_size, init=var_init) + break + except KeyError: + suffix = (suffix or 0) + 1 + + if signal not in self.vcd_vars: + self.vcd_vars[signal] = set() + self.vcd_vars[signal].add(vcd_var) + + if signal not in self.gtkw_names: + self.gtkw_names[signal] = (*var_scope, var_name_suffix) + + def update(self, timestamp, signal, value): + vcd_vars = self.vcd_vars.get(signal) + if vcd_vars is None: + return + + vcd_timestamp = self.timestamp_to_vcd(timestamp) + if signal.decoder: + var_value = self.decode_to_vcd(signal, value) else: - raise ValueError # :nocov: + var_value = value + for vcd_var in vcd_vars: + self.vcd_writer.change(vcd_var, vcd_timestamp, var_value) + + def close(self, timestamp): + if self.vcd_writer is not None: + self.vcd_writer.close(self.timestamp_to_vcd(timestamp)) + + if self.gtkw_save is not None: + self.gtkw_save.dumpfile(self.vcd_file.name) + self.gtkw_save.dumpfile_size(self.vcd_file.tell()) + + self.gtkw_save.treeopen("top") + for signal in self.traces: + if len(signal) > 1 and not signal.decoder: + suffix = "[{}:0]".format(len(signal) - 1) + else: + suffix = "" + self.gtkw_save.trace(".".join(self.gtkw_names[signal]) + suffix) + + if self.vcd_file is not None: + self.vcd_file.close() + if self.gtkw_file is not None: + self.gtkw_file.close() + + +class _Process: + __slots__ = ("runnable", "passive") + + def reset(self): + raise NotImplementedError # :nocov: + + def run(self): + raise NotImplementedError # :nocov: + + +class _SignalState: + __slots__ = ("signal", "curr", "next", "waiters", "pending") + + def __init__(self, signal, pending): + self.signal = signal + self.pending = pending + self.waiters = dict() + self.curr = self.next = signal.reset + + def set(self, value): + if self.next == value: + return + self.next = value + self.pending.add(self) + + def commit(self): + if self.curr == self.next: + return False + self.curr = self.next + + awoken_any = False + for process, trigger in self.waiters.items(): + if trigger is None or trigger == self.curr: + process.runnable = awoken_any = True + return awoken_any + + +class _SimulatorState: + def __init__(self): + self.signals = SignalDict() + self.slots = [] + self.pending = set() + + self.timestamp = 0.0 + self.deadlines = dict() + + def reset(self): + for signal, index in self.signals.items(): + self.slots[index].curr = self.slots[index].next = signal.reset + self.pending.clear() + + self.timestamp = 0.0 + self.deadlines.clear() + + def get_signal(self, signal): + try: + return self.signals[signal] + except KeyError: + index = len(self.slots) + self.slots.append(_SignalState(signal, self.pending)) + self.signals[signal] = index + return index + + def add_trigger(self, process, signal, *, trigger=None): + index = self.get_signal(signal) + assert (process not in self.slots[index].waiters or + self.slots[index].waiters[process] == trigger) + self.slots[index].waiters[process] = trigger + + def remove_trigger(self, process, signal): + index = self.get_signal(signal) + assert process in self.slots[index].waiters + del self.slots[index].waiters[process] + + def commit(self): + converged = True + for signal_state in self.pending: + if signal_state.commit(): + converged = False + self.pending.clear() + return converged + + def advance(self): + nearest_processes = set() + nearest_deadline = None + for process, deadline in self.deadlines.items(): + if deadline is None: + if nearest_deadline is not None: + nearest_processes.clear() + nearest_processes.add(process) + nearest_deadline = self.timestamp + break + elif nearest_deadline is None or deadline <= nearest_deadline: + assert deadline >= self.timestamp + if nearest_deadline is not None and deadline < nearest_deadline: + nearest_processes.clear() + nearest_processes.add(process) + nearest_deadline = deadline + + if not nearest_processes: + return False + + for process in nearest_processes: + process.runnable = True + del self.deadlines[process] + self.timestamp = nearest_deadline + + return True + + +class _Emitter: + def __init__(self): + self._buffer = [] + self._suffix = 0 + self._level = 0 + + def append(self, code): + self._buffer.append(" " * self._level) + self._buffer.append(code) + self._buffer.append("\n") + + @contextmanager + def indent(self): + self._level += 1 + yield + self._level -= 1 + + def flush(self, indent=""): + code = "".join(self._buffer) + self._buffer.clear() + return code + + def gen_var(self, prefix): + name = f"{prefix}_{self._suffix}" + self._suffix += 1 + return name + + def def_var(self, prefix, value): + name = self.gen_var(prefix) + self.append(f"{name} = {value}") + return name + + +class _Compiler: + def __init__(self, state, emitter): + self.state = state + self.emitter = emitter + + +class _ValueCompiler(ValueVisitor, _Compiler): + helpers = { + "sign": lambda value, sign: value | sign if value & sign else value, + "zdiv": lambda lhs, rhs: 0 if rhs == 0 else lhs // rhs, + "zmod": lambda lhs, rhs: 0 if rhs == 0 else lhs % rhs, + } def on_ClockSignal(self, value): raise NotImplementedError # :nocov: @@ -68,636 +328,725 @@ class _RHSValueCompiler(AbstractValueTransformer): def on_ResetSignal(self, value): raise NotImplementedError # :nocov: + def on_AnyConst(self, value): + raise NotImplementedError # :nocov: + + def on_AnySeq(self, value): + raise NotImplementedError # :nocov: + + def on_Sample(self, value): + raise NotImplementedError # :nocov: + + def on_Initial(self, value): + raise NotImplementedError # :nocov: + + +class _RHSValueCompiler(_ValueCompiler): + def __init__(self, state, emitter, *, mode, inputs=None): + super().__init__(state, emitter) + assert mode in ("curr", "next") + self.mode = mode + # If not None, `inputs` gets populated with RHS signals. + self.inputs = inputs + + def on_Const(self, value): + return f"{value.value}" + + def on_Signal(self, value): + if self.inputs is not None: + self.inputs.add(value) + + if self.mode == "curr": + return f"slots[{self.state.get_signal(value)}].{self.mode}" + else: + return f"next_{self.state.get_signal(value)}" + def on_Operator(self, value): - shape = value.shape() + def mask(value): + value_mask = (1 << len(value)) - 1 + return f"({self(value)} & {value_mask})" + + def sign(value): + if value.shape().signed: + return f"sign({mask(value)}, {-1 << (len(value) - 1)})" + else: # unsigned + return mask(value) + if len(value.operands) == 1: - arg, = map(self, value.operands) - if value.op == "~": - return lambda state: normalize(~arg(state), shape) - if value.op == "-": - return lambda state: normalize(-arg(state), shape) - if value.op == "b": - return lambda state: normalize(bool(arg(state)), shape) + arg, = value.operands + if value.operator == "~": + return f"(~{self(arg)})" + if value.operator == "-": + return f"(-{self(arg)})" + if value.operator == "b": + return f"bool({mask(arg)})" + if value.operator == "r|": + return f"({mask(arg)} != 0)" + if value.operator == "r&": + return f"({mask(arg)} == {(1 << len(arg)) - 1})" + if value.operator == "r^": + # Believe it or not, this is the fastest way to compute a sideways XOR in Python. + return f"(format({mask(arg)}, 'b').count('1') % 2)" + if value.operator in ("u", "s"): + # These operators don't change the bit pattern, only its interpretation. + return self(arg) elif len(value.operands) == 2: - lhs, rhs = map(self, value.operands) - if value.op == "+": - return lambda state: normalize(lhs(state) + rhs(state), shape) - if value.op == "-": - return lambda state: normalize(lhs(state) - rhs(state), shape) - if value.op == "&": - return lambda state: normalize(lhs(state) & rhs(state), shape) - if value.op == "|": - return lambda state: normalize(lhs(state) | rhs(state), shape) - if value.op == "^": - return lambda state: normalize(lhs(state) ^ rhs(state), shape) - if value.op == "<<": - def sshl(lhs, rhs): - return lhs << rhs if rhs >= 0 else lhs >> -rhs - return lambda state: normalize(sshl(lhs(state), rhs(state)), shape) - if value.op == ">>": - def sshr(lhs, rhs): - return lhs >> rhs if rhs >= 0 else lhs << -rhs - return lambda state: normalize(sshr(lhs(state), rhs(state)), shape) - if value.op == "==": - return lambda state: normalize(lhs(state) == rhs(state), shape) - if value.op == "!=": - return lambda state: normalize(lhs(state) != rhs(state), shape) - if value.op == "<": - return lambda state: normalize(lhs(state) < rhs(state), shape) - if value.op == "<=": - return lambda state: normalize(lhs(state) <= rhs(state), shape) - if value.op == ">": - return lambda state: normalize(lhs(state) > rhs(state), shape) - if value.op == ">=": - return lambda state: normalize(lhs(state) >= rhs(state), shape) + lhs, rhs = value.operands + lhs_mask = (1 << len(lhs)) - 1 + rhs_mask = (1 << len(rhs)) - 1 + if value.operator == "+": + return f"({sign(lhs)} + {sign(rhs)})" + if value.operator == "-": + return f"({sign(lhs)} - {sign(rhs)})" + if value.operator == "*": + return f"({sign(lhs)} * {sign(rhs)})" + if value.operator == "//": + return f"zdiv({sign(lhs)}, {sign(rhs)})" + if value.operator == "%": + return f"zmod({sign(lhs)}, {sign(rhs)})" + if value.operator == "&": + return f"({self(lhs)} & {self(rhs)})" + if value.operator == "|": + return f"({self(lhs)} | {self(rhs)})" + if value.operator == "^": + return f"({self(lhs)} ^ {self(rhs)})" + if value.operator == "<<": + return f"({sign(lhs)} << {sign(rhs)})" + if value.operator == ">>": + return f"({sign(lhs)} >> {sign(rhs)})" + if value.operator == "==": + return f"({sign(lhs)} == {sign(rhs)})" + if value.operator == "!=": + return f"({sign(lhs)} != {sign(rhs)})" + if value.operator == "<": + return f"({sign(lhs)} < {sign(rhs)})" + if value.operator == "<=": + return f"({sign(lhs)} <= {sign(rhs)})" + if value.operator == ">": + return f"({sign(lhs)} > {sign(rhs)})" + if value.operator == ">=": + return f"({sign(lhs)} >= {sign(rhs)})" elif len(value.operands) == 3: - if value.op == "m": - sel, val1, val0 = map(self, value.operands) - return lambda state: val1(state) if sel(state) else val0(state) - raise NotImplementedError("Operator '{}' not implemented".format(value.op)) # :nocov: + if value.operator == "m": + sel, val1, val0 = value.operands + return f"({self(val1)} if {self(sel)} else {self(val0)})" + raise NotImplementedError("Operator '{}' not implemented".format(value.operator)) # :nocov: def on_Slice(self, value): - shape = value.shape() - arg = self(value.value) - shift = value.start - mask = (1 << (value.end - value.start)) - 1 - return lambda state: normalize((arg(state) >> shift) & mask, shape) + return f"(({self(value.value)} >> {value.start}) & {(1 << len(value)) - 1})" def on_Part(self, value): - shape = value.shape() - arg = self(value.value) - shift = self(value.offset) - mask = (1 << value.width) - 1 - return lambda state: normalize((arg(state) >> shift(state)) & mask, shape) + offset_mask = (1 << len(value.offset)) - 1 + offset = f"(({self(value.offset)} & {offset_mask}) * {value.stride})" + return f"({self(value.value)} >> {offset} & " \ + f"{(1 << value.width) - 1})" def on_Cat(self, value): - shape = value.shape() - parts = [] + gen_parts = [] offset = 0 - for opnd in value.operands: - parts.append((offset, (1 << len(opnd)) - 1, self(opnd))) - offset += len(opnd) - def eval(state): - result = 0 - for offset, mask, opnd in parts: - result |= (opnd(state) & mask) << offset - return normalize(result, shape) - return eval + for part in value.parts: + part_mask = (1 << len(part)) - 1 + gen_parts.append(f"(({self(part)} & {part_mask}) << {offset})") + offset += len(part) + if gen_parts: + return f"({' | '.join(gen_parts)})" + return f"0" def on_Repl(self, value): - shape = value.shape() - offset = len(value.value) - mask = (1 << len(value.value)) - 1 - count = value.count - opnd = self(value.value) - def eval(state): - result = 0 - for _ in range(count): - result <<= offset - result |= opnd(state) - return normalize(result, shape) - return eval + part_mask = (1 << len(value.value)) - 1 + gen_part = self.emitter.def_var("repl", f"{self(value.value)} & {part_mask}") + gen_parts = [] + offset = 0 + for _ in range(value.count): + gen_parts.append(f"({gen_part} << {offset})") + offset += len(value.value) + if gen_parts: + return f"({' | '.join(gen_parts)})" + return f"0" def on_ArrayProxy(self, value): - shape = value.shape() - elems = list(map(self, value.elems)) - index = self(value.index) - return lambda state: normalize(elems[index(state)](state), shape) - - -class _LHSValueCompiler(AbstractValueTransformer): - def __init__(self, rhs_compiler): - self.rhs_compiler = rhs_compiler + index_mask = (1 << len(value.index)) - 1 + gen_index = self.emitter.def_var("rhs_index", f"{self(value.index)} & {index_mask}") + gen_value = self.emitter.gen_var("rhs_proxy") + if value.elems: + gen_elems = [] + for index, elem in enumerate(value.elems): + if index == 0: + self.emitter.append(f"if {gen_index} == {index}:") + else: + self.emitter.append(f"elif {gen_index} == {index}:") + with self.emitter.indent(): + self.emitter.append(f"{gen_value} = {self(elem)}") + self.emitter.append(f"else:") + with self.emitter.indent(): + self.emitter.append(f"{gen_value} = {self(value.elems[-1])}") + return gen_value + else: + return f"0" + + @classmethod + def compile(cls, state, value, *, mode, inputs=None): + emitter = _Emitter() + compiler = cls(state, emitter, mode=mode, inputs=inputs) + emitter.append(f"result = {compiler(value)}") + return emitter.flush() + + +class _LHSValueCompiler(_ValueCompiler): + def __init__(self, state, emitter, *, rhs, outputs=None): + super().__init__(state, emitter) + # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g. + # the offset of a Part. + self.rrhs = rhs + # `lrhs` is used to translate the read part of a read-modify-write cycle during partial + # update of an lvalue. + self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None) + # If not None, `outputs` gets populated with signals on LHS. + self.outputs = outputs def on_Const(self, value): raise TypeError # :nocov: def on_Signal(self, value): - shape = value.shape() - def eval(state, rhs): - state.set(value, normalize(rhs, shape)) - return eval - - def on_ClockSignal(self, value): - raise NotImplementedError # :nocov: - - def on_ResetSignal(self, value): - raise NotImplementedError # :nocov: + if self.outputs is not None: + self.outputs.add(value) + + def gen(arg): + value_mask = (1 << len(value)) - 1 + if value.shape().signed: + value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})" + else: # unsigned + value_sign = f"{arg} & {value_mask}" + self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}") + return gen def on_Operator(self, value): raise TypeError # :nocov: def on_Slice(self, value): - lhs_r = self.rhs_compiler(value.value) - lhs_l = self(value.value) - shift = value.start - mask = (1 << (value.end - value.start)) - 1 - def eval(state, rhs): - lhs_value = lhs_r(state) - lhs_value &= ~(mask << shift) - lhs_value |= (rhs & mask) << shift - lhs_l(state, lhs_value) - return eval + def gen(arg): + width_mask = (1 << (value.stop - value.start)) - 1 + self(value.value)(f"({self.lrhs(value.value)} & " \ + f"{~(width_mask << value.start)} | " \ + f"(({arg} & {width_mask}) << {value.start}))") + return gen def on_Part(self, value): - lhs_r = self.rhs_compiler(value.value) - lhs_l = self(value.value) - shift = self.rhs_compiler(value.offset) - mask = (1 << value.width) - 1 - def eval(state, rhs): - lhs_value = lhs_r(state) - shift_value = shift(state) - lhs_value &= ~(mask << shift_value) - lhs_value |= (rhs & mask) << shift_value - lhs_l(state, lhs_value) - return eval + def gen(arg): + width_mask = (1 << value.width) - 1 + offset_mask = (1 << len(value.offset)) - 1 + offset = f"(({self.rrhs(value.offset)} & {offset_mask}) * {value.stride})" + self(value.value)(f"({self.lrhs(value.value)} & " \ + f"~({width_mask} << {offset}) | " \ + f"(({arg} & {width_mask}) << {offset}))") + return gen def on_Cat(self, value): - parts = [] - offset = 0 - for opnd in value.operands: - parts.append((offset, (1 << len(opnd)) - 1, self(opnd))) - offset += len(opnd) - def eval(state, rhs): - for offset, mask, opnd in parts: - opnd(state, (rhs >> offset) & mask) - return eval + def gen(arg): + gen_arg = self.emitter.def_var("cat", arg) + gen_parts = [] + offset = 0 + for part in value.parts: + part_mask = (1 << len(part)) - 1 + self(part)(f"(({gen_arg} >> {offset}) & {part_mask})") + offset += len(part) + return gen def on_Repl(self, value): raise TypeError # :nocov: def on_ArrayProxy(self, value): - elems = list(map(self, value.elems)) - index = self.rhs_compiler(value.index) - def eval(state, rhs): - elems[index(state)](state, rhs) - return eval + def gen(arg): + index_mask = (1 << len(value.index)) - 1 + gen_index = self.emitter.def_var("index", f"{self.rrhs(value.index)} & {index_mask}") + if value.elems: + gen_elems = [] + for index, elem in enumerate(value.elems): + if index == 0: + self.emitter.append(f"if {gen_index} == {index}:") + else: + self.emitter.append(f"elif {gen_index} == {index}:") + with self.emitter.indent(): + self(elem)(arg) + self.emitter.append(f"else:") + with self.emitter.indent(): + self(value.elems[-1])(arg) + else: + self.emitter.append(f"pass") + return gen + @classmethod + def compile(cls, state, stmt, *, inputs=None, outputs=None): + emitter = _Emitter() + compiler = cls(state, emitter, inputs=inputs, outputs=outputs) + compiler(stmt) + return emitter.flush() -class _StatementCompiler(AbstractStatementTransformer): - def __init__(self): - self.sensitivity = SignalSet() - self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs") - self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs") - self.lhs_compiler = _LHSValueCompiler(self.lrhs_compiler) + +class _StatementCompiler(StatementVisitor, _Compiler): + def __init__(self, state, emitter, *, inputs=None, outputs=None): + super().__init__(state, emitter) + self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs) + self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs) + + def on_statements(self, stmts): + for stmt in stmts: + self(stmt) + if not stmts: + self.emitter.append("pass") def on_Assign(self, stmt): - shape = stmt.lhs.shape() - lhs = self.lhs_compiler(stmt.lhs) - rhs = self.rrhs_compiler(stmt.rhs) - def run(state): - lhs(state, normalize(rhs(state), shape)) - return run + return self.lhs(stmt.lhs)(self.rhs(stmt.rhs)) def on_Switch(self, stmt): - test = self.rrhs_compiler(stmt.test) - cases = [] - for value, stmts in stmt.cases.items(): - if "-" in value: - mask = "".join("0" if b == "-" else "1" for b in value) - value = "".join("0" if b == "-" else b for b in value) + gen_test = self.emitter.def_var("test", + f"{self.rhs(stmt.test)} & {(1 << len(stmt.test)) - 1}") + for index, (patterns, stmts) in enumerate(stmt.cases.items()): + gen_checks = [] + if not patterns: + gen_checks.append(f"True") + else: + for pattern in patterns: + if "-" in pattern: + mask = int("".join("0" if b == "-" else "1" for b in pattern), 2) + value = int("".join("0" if b == "-" else b for b in pattern), 2) + gen_checks.append(f"({gen_test} & {mask}) == {value}") + else: + value = int(pattern, 2) + gen_checks.append(f"{gen_test} == {value}") + if index == 0: + self.emitter.append(f"if {' or '.join(gen_checks)}:") + else: + self.emitter.append(f"elif {' or '.join(gen_checks)}:") + with self.emitter.indent(): + self(stmts) + + def on_Assert(self, stmt): + raise NotImplementedError # :nocov: + + def on_Assume(self, stmt): + raise NotImplementedError # :nocov: + + def on_Cover(self, stmt): + raise NotImplementedError # :nocov: + + @classmethod + def compile(cls, state, stmt, *, inputs=None, outputs=None): + output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()] + emitter = _Emitter() + for signal_index in output_indexes: + emitter.append(f"next_{signal_index} = slots[{signal_index}].next") + compiler = cls(state, emitter, inputs=inputs, outputs=outputs) + compiler(stmt) + for signal_index in output_indexes: + emitter.append(f"slots[{signal_index}].set(next_{signal_index})") + return emitter.flush() + + +class _CompiledProcess(_Process): + __slots__ = ("state", "comb", "run") + + def __init__(self, state, *, comb): + self.state = state + self.comb = comb + self.run = None # set by _FragmentCompiler + self.reset() + + def reset(self): + self.runnable = self.comb + self.passive = True + + +class _FragmentCompiler: + def __init__(self, state, signal_names): + self.state = state + self.signal_names = signal_names + + def __call__(self, fragment, *, hierarchy=("top",)): + processes = set() + + def add_signal_name(signal): + hierarchical_signal_name = (*hierarchy, signal.name) + if signal not in self.signal_names: + self.signal_names[signal] = {hierarchical_signal_name} else: - mask = "1" * len(value) - mask = int(mask, 2) - value = int(value, 2) - def make_test(mask, value): - return lambda test: test & mask == value - cases.append((make_test(mask, value), self.on_statements(stmts))) - def run(state): - test_value = test(state) - for check, body in cases: - if check(test_value): - body(state) + self.signal_names[signal].add(hierarchical_signal_name) + + for domain_name, domain_signals in fragment.drivers.items(): + domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements) + domain_process = _CompiledProcess(self.state, comb=domain_name is None) + + emitter = _Emitter() + emitter.append(f"def run():") + emitter._level += 1 + + if domain_name is None: + for signal in domain_signals: + signal_index = domain_process.state.get_signal(signal) + emitter.append(f"next_{signal_index} = {signal.reset}") + + inputs = SignalSet() + _StatementCompiler(domain_process.state, emitter, inputs=inputs)(domain_stmts) + + for input in inputs: + self.state.add_trigger(domain_process, input) + + else: + domain = fragment.domains[domain_name] + add_signal_name(domain.clk) + if domain.rst is not None: + add_signal_name(domain.rst) + + clk_trigger = 1 if domain.clk_edge == "pos" else 0 + self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger) + if domain.rst is not None and domain.async_reset: + rst_trigger = 1 + self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger) + + gen_asserts = [] + clk_index = domain_process.state.get_signal(domain.clk) + gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}") + if domain.rst is not None and domain.async_reset: + rst_index = domain_process.state.get_signal(domain.rst) + gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}") + emitter.append(f"assert {' or '.join(gen_asserts)}") + + for signal in domain_signals: + signal_index = domain_process.state.get_signal(signal) + emitter.append(f"next_{signal_index} = slots[{signal_index}].next") + + _StatementCompiler(domain_process.state, emitter)(domain_stmts) + + for signal in domain_signals: + signal_index = domain_process.state.get_signal(signal) + emitter.append(f"slots[{signal_index}].set(next_{signal_index})") + + # There shouldn't be any exceptions raised by the generated code, but if there are + # (almost certainly due to a bug in the code generator), use this environment variable + # to make backtraces useful. + code = emitter.flush() + if os.getenv("NMIGEN_pysim_dump"): + file = tempfile.NamedTemporaryFile("w", prefix="nmigen_pysim_", delete=False) + file.write(code) + filename = file.name + else: + filename = "" + + exec_locals = {"slots": domain_process.state.slots, **_ValueCompiler.helpers} + exec(compile(code, filename, "exec"), exec_locals) + domain_process.run = exec_locals["run"] + + processes.add(domain_process) + + for used_signal in domain_process.state.signals: + add_signal_name(used_signal) + + for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments): + if subfragment_name is None: + subfragment_name = "U${}".format(subfragment_index) + processes.update(self(subfragment, hierarchy=(*hierarchy, subfragment_name))) + + return processes + + +class _CoroutineProcess(_Process): + def __init__(self, state, domains, constructor, *, default_cmd=None): + self.state = state + self.domains = domains + self.constructor = constructor + self.default_cmd = default_cmd + self.reset() + + def reset(self): + self.runnable = True + self.passive = False + self.coroutine = self.constructor() + self.exec_locals = { + "slots": self.state.slots, + "result": None, + **_ValueCompiler.helpers + } + self.waits_on = set() + + def src_loc(self): + coroutine = self.coroutine + while coroutine.gi_yieldfrom is not None: + coroutine = coroutine.gi_yieldfrom + if inspect.isgenerator(coroutine): + frame = coroutine.gi_frame + if inspect.iscoroutine(coroutine): + frame = coroutine.cr_frame + return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame)) + + def run(self): + if self.coroutine is None: + return + + if self.waits_on: + for signal in self.waits_on: + self.state.remove_trigger(self, signal) + self.waits_on.clear() + + response = None + while True: + try: + command = self.coroutine.send(response) + if command is None: + command = self.default_cmd + response = None + + if isinstance(command, Value): + exec(_RHSValueCompiler.compile(self.state, command, mode="curr"), + self.exec_locals) + response = Const.normalize(self.exec_locals["result"], command.shape()) + + elif isinstance(command, Statement): + exec(_StatementCompiler.compile(self.state, command), + self.exec_locals) + + elif type(command) is Tick: + domain = command.domain + if isinstance(domain, ClockDomain): + pass + elif domain in self.domains: + domain = self.domains[domain] + else: + raise NameError("Received command {!r} that refers to a nonexistent " + "domain {!r} from process {!r}" + .format(command, command.domain, self.src_loc())) + self.state.add_trigger(self, domain.clk, + trigger=1 if domain.clk_edge == "pos" else 0) + if domain.rst is not None and domain.async_reset: + self.state.add_trigger(self, domain.rst, trigger=1) return - return run - def on_statements(self, stmts): - stmts = [self.on_statement(stmt) for stmt in stmts] - def run(state): - for stmt in stmts: - stmt(state) - return run + elif type(command) is Settle: + self.state.deadlines[self] = None + return + elif type(command) is Delay: + if command.interval is None: + self.state.deadlines[self] = None + else: + self.state.deadlines[self] = self.state.timestamp + command.interval + return -class Simulator: - def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()): - self._fragment = fragment - - self._domains = dict() # str/domain -> ClockDomain - self._domain_triggers = SignalDict() # Signal -> str/domain - self._domain_signals = dict() # str/domain -> {Signal} - - self._signals = SignalSet() # {Signal} - self._comb_signals = SignalSet() # {Signal} - self._sync_signals = SignalSet() # {Signal} - self._user_signals = SignalSet() # {Signal} - - self._started = False - self._timestamp = 0. - self._delta = 0. - self._epsilon = 1e-10 - self._fastest_clock = self._epsilon - self._state = _State() - - self._processes = set() # {process} - self._process_loc = dict() # process -> str/loc - self._passive = set() # {process} - self._suspended = set() # {process} - self._wait_deadline = dict() # process -> float/timestamp - self._wait_tick = dict() # process -> str/domain - - self._funclets = SignalDict() # Signal -> set(lambda) - - self._vcd_file = vcd_file - self._vcd_writer = None - self._vcd_signals = SignalDict() # signal -> set(vcd_signal) - self._vcd_names = SignalDict() # signal -> str/name - self._gtkw_file = gtkw_file - self._traces = traces + elif type(command) is Passive: + self.passive = True - @staticmethod - def _check_process(process): - if inspect.isgeneratorfunction(process): - process = process() - if not inspect.isgenerator(process): - raise TypeError("Cannot add a process '{!r}' because it is not a generator or" - "a generator function" + elif type(command) is Active: + self.passive = False + + elif command is None: # only possible if self.default_cmd is None + raise TypeError("Received default command from process {!r} that was added " + "with add_process(); did you mean to add this process with " + "add_sync_process() instead?" + .format(self.src_loc())) + + else: + raise TypeError("Received unsupported command {!r} from process {!r}" + .format(command, self.src_loc())) + + except StopIteration: + self.passive = True + self.coroutine = None + return + + except Exception as exn: + self.coroutine.throw(exn) + + +class Simulator: + def __init__(self, fragment): + self._state = _SimulatorState() + self._signal_names = SignalDict() + self._fragment = Fragment.get(fragment, platform=None).prepare() + self._processes = _FragmentCompiler(self._state, self._signal_names)(self._fragment) + self._clocked = set() + self._waveform_writers = [] + + def _check_process(self, process): + if not (inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process)): + raise TypeError("Cannot add a process {!r} because it is not a generator function" .format(process)) return process - def _name_process(self, process): - if process in self._process_loc: - return self._process_loc[process] - else: - frame = process.gi_frame - return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame)) + def _add_coroutine_process(self, process, *, default_cmd): + self._processes.add(_CoroutineProcess(self._state, self._fragment.domains, process, + default_cmd=default_cmd)) def add_process(self, process): process = self._check_process(process) - self._processes.add(process) + def wrapper(): + # Only start a bench process after comb settling, so that the reset values are correct. + yield Settle() + yield from process() + self._add_coroutine_process(wrapper, default_cmd=None) - def add_sync_process(self, process, domain="sync"): + def add_sync_process(self, process, *, domain="sync"): process = self._check_process(process) - def sync_process(): - try: - result = None - while True: - self._process_loc[sync_process] = self._name_process(process) - cmd = process.send(result) - if cmd is None: - cmd = Tick(domain) - result = yield cmd - except StopIteration: - pass - sync_process = sync_process() - self.add_process(sync_process) - - def add_clock(self, period, phase=None, domain="sync"): - if self._fastest_clock == self._epsilon or period < self._fastest_clock: - self._fastest_clock = period + def wrapper(): + # Only start a sync process after the first clock edge (or reset edge, if the domain + # uses an asynchronous reset). This matches the behavior of synchronous FFs. + yield Tick(domain) + yield from process() + return self._add_coroutine_process(wrapper, default_cmd=Tick(domain)) + + def add_clock(self, period, *, phase=None, domain="sync", if_exists=False): + """Add a clock process. + + Adds a process that drives the clock signal of ``domain`` at a 50% duty cycle. + + Arguments + --------- + period : float + Clock period. The process will toggle the ``domain`` clock signal every ``period / 2`` + seconds. + phase : None or float + Clock phase. The process will wait ``phase`` seconds before the first clock transition. + If not specified, defaults to ``period / 2``. + domain : str or ClockDomain + Driven clock domain. If specified as a string, the domain with that name is looked up + in the root fragment of the simulation. + if_exists : bool + If ``False`` (the default), raise an error if the driven domain is specified as + a string and the root fragment does not have such a domain. If ``True``, do nothing + in this case. + """ + if isinstance(domain, ClockDomain): + pass + elif domain in self._fragment.domains: + domain = self._fragment.domains[domain] + elif if_exists: + return + else: + raise ValueError("Domain {!r} is not present in simulation" + .format(domain)) + if domain in self._clocked: + raise ValueError("Domain {!r} already has a clock driving it" + .format(domain.name)) half_period = period / 2 if phase is None: + # By default, delay the first edge by half period. This causes any synchronous activity + # to happen at a non-zero time, distinguishing it from the reset values in the waveform + # viewer. phase = half_period - clk = self._domains[domain].clk def clk_process(): yield Passive() yield Delay(phase) + # Behave correctly if the process is added after the clock signal is manipulated, or if + # its reset state is high. + initial = (yield domain.clk) + steps = ( + domain.clk.eq(~initial), + Delay(half_period), + domain.clk.eq(initial), + Delay(half_period), + ) while True: - yield clk.eq(1) - yield Delay(half_period) - yield clk.eq(0) - yield Delay(half_period) - self.add_process(clk_process) - - def __enter__(self): - if self._vcd_file: - self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps", - comment="Generated by nMigen") - - root_fragment = self._fragment.prepare() - - self._domains = root_fragment.domains - for domain, cd in self._domains.items(): - self._domain_triggers[cd.clk] = domain - if cd.rst is not None: - self._domain_triggers[cd.rst] = domain - self._domain_signals[domain] = SignalSet() - - hierarchy = {} - def add_fragment(fragment, scope=()): - hierarchy[fragment] = scope - for subfragment, name in fragment.subfragments: - add_fragment(subfragment, (*scope, name)) - add_fragment(root_fragment) - - for fragment, fragment_scope in hierarchy.items(): - for signal in fragment.iter_signals(): - self._signals.add(signal) - - self._state.curr[signal] = self._state.next[signal] = \ - normalize(signal.reset, signal.shape()) - self._state.curr_dirty.add(signal) - - if not self._vcd_writer: - continue - - if signal not in self._vcd_signals: - self._vcd_signals[signal] = set() - - for subfragment, name in fragment.subfragments: - if signal in subfragment.ports: - var_name = "{}_{}".format(name, signal.name) - break - else: - var_name = signal.name + yield from iter(steps) + self._add_coroutine_process(clk_process, default_cmd=None) + self._clocked.add(domain) + + def reset(self): + """Reset the simulation. + + Assign the reset value to every signal in the simulation, and restart every user process. + """ + self._state.reset() + for process in self._processes: + process.reset() + + def _real_step(self): + """Step the simulation. + + Run every process and commit changes until a fixed point is reached. If there is + an unstable combinatorial loop, this function will never return. + """ + # Performs the two phases of a delta cycle in a loop: + converged = False + while not converged: + # 1. eval: run and suspend every non-waiting process once, queueing signal changes + for process in self._processes: + if process.runnable: + process.runnable = False + process.run() + + for waveform_writer in self._waveform_writers: + for signal_state in self._state.pending: + waveform_writer.update(self._state.timestamp, + signal_state.signal, signal_state.curr) + + # 2. commit: apply every queued signal change, waking up any waiting processes + converged = self._state.commit() + + # TODO(nmigen-0.4): replace with _real_step + @deprecated("instead of `sim.step()`, use `sim.advance()`") + def step(self): + return self.advance() + + def advance(self): + """Advance the simulation. + + Run every process and commit changes until a fixed point is reached, then advance time + to the closest deadline (if any). If there is an unstable combinatorial loop, + this function will never return. + + Returns ``True`` if there are any active processes, ``False`` otherwise. + """ + self._real_step() + self._state.advance() + return any(not process.passive for process in self._processes) - if signal.decoder: - var_type = "string" - var_size = 1 - var_init = signal.decoder(signal.reset).replace(" ", "_") - else: - var_type = "wire" - var_size = signal.nbits - var_init = signal.reset + def run(self): + """Run the simulation while any processes are active. - suffix = None - while True: - try: - if suffix is None: - var_name_suffix = var_name - else: - var_name_suffix = "{}${}".format(var_name, suffix) - self._vcd_signals[signal].add(self._vcd_writer.register_var( - scope=".".join(fragment_scope), name=var_name_suffix, - var_type=var_type, size=var_size, init=var_init)) - if signal not in self._vcd_names: - self._vcd_names[signal] = ".".join(fragment_scope + (var_name_suffix,)) - break - except KeyError: - suffix = (suffix or 0) + 1 + Processes added with :meth:`add_process` and :meth:`add_sync_process` are initially active, + and may change their status using the ``yield Passive()`` and ``yield Active()`` commands. + Processes compiled from HDL and added with :meth:`add_clock` are always passive. + """ + while self.advance(): + pass - for domain, signals in fragment.drivers.items(): - if domain is None: - self._comb_signals.update(signals) - else: - self._sync_signals.update(signals) - self._domain_signals[domain].update(signals) - - statements = [] - for signal in fragment.iter_comb(): - statements.append(signal.eq(signal.reset)) - for domain, signal in fragment.iter_sync(): - statements.append(signal.eq(signal)) - statements += fragment.statements - - compiler = _StatementCompiler() - funclet = compiler(statements) - - def add_funclet(signal, funclet): - if signal not in self._funclets: - self._funclets[signal] = set() - self._funclets[signal].add(funclet) - - for signal in compiler.sensitivity: - add_funclet(signal, funclet) - for domain, cd in fragment.domains.items(): - add_funclet(cd.clk, funclet) - if cd.rst is not None: - add_funclet(cd.rst, funclet) - - self._user_signals = self._signals - self._comb_signals - self._sync_signals - - return self - - def _update_dirty_signals(self): - """Perform the statement part of IR processes (aka RTLIL case).""" - # First, for all dirty signals, use sensitivity lists to determine the set of fragments - # that need their statements to be reevaluated because the signals changed at the previous - # delta cycle. - funclets = set() - while self._state.curr_dirty: - signal = self._state.curr_dirty.pop() - if signal in self._funclets: - funclets.update(self._funclets[signal]) - - # Second, compute the values of all signals at the start of the next delta cycle, by - # running precompiled statements. - for funclet in funclets: - funclet(self._state) - - def _commit_signal(self, signal, domains): - """Perform the driver part of IR processes (aka RTLIL sync), for individual signals.""" - # Take the computed value (at the start of this delta cycle) of a signal (that could have - # come from an IR process that ran earlier, or modified by a simulator process) and update - # the value for this delta cycle. - old, new = self._state.commit(signal) - - # If the signal is a clock that triggers synchronous logic, record that fact. - if (old, new) == (0, 1) and signal in self._domain_triggers: - domains.add(self._domain_triggers[signal]) - - if self._vcd_writer and old != new: - # Finally, dump the new value to the VCD file. - for vcd_signal in self._vcd_signals[signal]: - if signal.decoder: - var_value = signal.decoder(new).replace(" ", "_") - else: - var_value = new - vcd_timestamp = (self._timestamp + self._delta) / self._epsilon - self._vcd_writer.change(vcd_signal, vcd_timestamp, var_value) - - def _commit_comb_signals(self, domains): - """Perform the comb part of IR processes (aka RTLIL always).""" - # Take the computed value (at the start of this delta cycle) of every comb signal and - # update the value for this delta cycle. - for signal in self._state.next_dirty: - if signal in self._comb_signals: - self._commit_signal(signal, domains) - - def _commit_sync_signals(self, domains): - """Perform the sync part of IR processes (aka RTLIL posedge).""" - # At entry, `domains` contains a list of every simultaneously triggered sync update. - while domains: - # Advance the timeline a bit (purely for observational purposes) and commit all of them - # at the same timestamp. - self._delta += self._epsilon - curr_domains, domains = domains, set() - - while curr_domains: - domain = curr_domains.pop() - - # Take the computed value (at the start of this delta cycle) of every sync signal - # in this domain and update the value for this delta cycle. This can trigger more - # synchronous logic, so record that. - for signal in self._state.next_dirty: - if signal in self._domain_signals[domain]: - self._commit_signal(signal, domains) - - # Wake up any simulator processes that wait for a domain tick. - for process, wait_domain in list(self._wait_tick.items()): - if domain == wait_domain: - del self._wait_tick[process] - self._suspended.remove(process) - - # Immediately run the process. It is important that this happens here, - # and not on the next step, when all the processes will run anyway, - # because Tick() simulates an edge triggered process. Like DFFs that latch - # a value from the previous clock cycle, simulator processes observe signal - # values from the previous clock cycle on a tick, too. - self._run_process(process) - - # Unless handling synchronous logic above has triggered more synchronous logic (which - # can happen e.g. if a domain is clocked off a clock divisor in fabric), we're done. - # Otherwise, do one more round of updates. - - def _run_process(self, process): - try: - cmd = process.send(None) - while True: - if isinstance(cmd, Delay): - if cmd.interval is None: - interval = self._epsilon - else: - interval = cmd.interval - self._wait_deadline[process] = self._timestamp + interval - self._suspended.add(process) - break - - elif isinstance(cmd, Tick): - self._wait_tick[process] = cmd.domain - self._suspended.add(process) - break - - elif isinstance(cmd, Passive): - self._passive.add(process) - - elif isinstance(cmd, Value): - compiler = _RHSValueCompiler() - funclet = compiler(cmd) - cmd = process.send(funclet(self._state)) - continue - - elif isinstance(cmd, Assign): - lhs_signals = cmd.lhs._lhs_signals() - for signal in lhs_signals: - if not signal in self._signals: - raise ValueError("Process '{}' sent a request to set signal '{!r}', " - "which is not a part of simulation" - .format(self._name_process(process), signal)) - if signal in self._comb_signals: - raise ValueError("Process '{}' sent a request to set signal '{!r}', " - "which is a part of combinatorial assignment in " - "simulation" - .format(self._name_process(process), signal)) - - compiler = _StatementCompiler() - funclet = compiler(cmd) - funclet(self._state) - - domains = set() - for signal in lhs_signals: - self._commit_signal(signal, domains) - self._commit_sync_signals(domains) + def run_until(self, deadline, *, run_passive=False): + """Run the simulation until it advances to ``deadline``. - else: - raise TypeError("Received unsupported command '{!r}' from process '{}'" - .format(cmd, self._name_process(process))) - - cmd = process.send(None) - - except StopIteration: - self._processes.remove(process) - self._passive.discard(process) - - except Exception as e: - process.throw(e) - - def step(self, run_passive=False): - # Are there any delta cycles we should run? - if self._state.curr_dirty: - # We might run some delta cycles, and we have simulator processes waiting on - # a deadline. Take care to not exceed the closest deadline. - if self._wait_deadline and \ - (self._timestamp + self._delta) >= min(self._wait_deadline.values()): - # Oops, we blew the deadline. We *could* run the processes now, but this is - # virtually certainly a logic loop and a design bug, so bail out instead.d - raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?") - - domains = set() - while self._state.curr_dirty: - self._update_dirty_signals() - self._commit_comb_signals(domains) - self._commit_sync_signals(domains) - return True - - # Are there any processes that haven't had a chance to run yet? - if len(self._processes) > len(self._suspended): - # Schedule an arbitrary one. - process = (self._processes - set(self._suspended)).pop() - self._run_process(process) - return True - - # All processes are suspended. Are any of them active? - if len(self._processes) > len(self._passive) or run_passive: - # Are any of them suspended before a deadline? - if self._wait_deadline: - # Schedule the one with the lowest deadline. - process, deadline = min(self._wait_deadline.items(), key=lambda x: x[1]) - del self._wait_deadline[process] - self._suspended.remove(process) - self._timestamp = deadline - self._delta = 0. - self._run_process(process) - return True - - # No processes, or all processes are passive. Nothing to do! - return False + If ``run_passive`` is ``False``, the simulation also stops when there are no active + processes, similar to :meth:`run`. Otherwise, the simulation will stop only after it + advances to or past ``deadline``. - def run(self): - while self.step(): + If the simulation stops advancing, this function will never return. + """ + assert self._state.timestamp <= deadline + while (self.advance() or run_passive) and self._state.timestamp < deadline: pass - def run_until(self, deadline, run_passive=False): - while self._timestamp < deadline: - if not self.step(run_passive): - return False - - return True - - def __exit__(self, *args): - if self._vcd_writer: - vcd_timestamp = (self._timestamp + self._delta) / self._epsilon - self._vcd_writer.close(vcd_timestamp) - - if self._vcd_file and self._gtkw_file: - gtkw_save = GTKWSave(self._gtkw_file) - if hasattr(self._vcd_file, "name"): - gtkw_save.dumpfile(self._vcd_file.name) - if hasattr(self._vcd_file, "tell"): - gtkw_save.dumpfile_size(self._vcd_file.tell()) - - gtkw_save.treeopen("top") - gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14) - - def add_trace(signal, **kwargs): - if signal in self._vcd_names: - if len(signal) > 1: - suffix = "[{}:0]".format(len(signal) - 1) - else: - suffix = "" - gtkw_save.trace(self._vcd_names[signal] + suffix, **kwargs) - - for domain, cd in self._domains.items(): - with gtkw_save.group("d.{}".format(domain)): - if cd.rst is not None: - add_trace(cd.rst) - add_trace(cd.clk) - - for signal in self._traces: - add_trace(signal) - - if self._vcd_file: - self._vcd_file.close() - if self._gtkw_file: - self._gtkw_file.close() + @contextmanager + def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()): + """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file. + + This method returns a context manager. It can be used as: :: + + sim = Simulator(frag) + sim.add_clock(1e-6) + with sim.write_vcd("dump.vcd", "dump.gtkw"): + sim.run_until(1e-3) + + Arguments + --------- + vcd_file : str or file-like object + Verilog Value Change Dump file or filename. + gtkw_file : str or file-like object + GTKWave save file or filename. + traces : iterable of Signal + Signals to display traces for. + """ + if self._state.timestamp != 0.0: + raise ValueError("Cannot start writing waveforms after advancing simulation time") + waveform_writer = _VCDWaveformWriter(self._signal_names, + vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces) + self._waveform_writers.append(waveform_writer) + yield + waveform_writer.close(self._state.timestamp) + self._waveform_writers.remove(waveform_writer)