back.pysim: simplify. NFC.
[nmigen.git] / nmigen / back / pysim.py
index f0fa2da99ec890ff328823fb81b8c4fd4e6a9a72..a9c6c162ac09cc5e9953a6f10e609c3cd69e961f 100644 (file)
@@ -132,7 +132,8 @@ class _VCDWaveformWriter(_WaveformWriter):
                     self.gtkw_names[signal] = (*var_scope, var_name_suffix)
 
     def update(self, timestamp, signal, value):
-        if signal not in self.vcd_vars:
+        vcd_vars = self.vcd_vars.get(signal)
+        if vcd_vars is None:
             return
 
         vcd_timestamp = self.timestamp_to_vcd(timestamp)
@@ -140,7 +141,7 @@ class _VCDWaveformWriter(_WaveformWriter):
             var_value = self.decode_to_vcd(signal, value)
         else:
             var_value = value
-        for vcd_var in self.vcd_vars[signal]:
+        for vcd_var in vcd_vars:
             self.vcd_writer.change(vcd_var, vcd_timestamp, var_value)
 
     def close(self, timestamp):
@@ -174,10 +175,6 @@ class _Process:
     def run(self):
         raise NotImplementedError # :nocov:
 
-    @property
-    def name(self):
-        raise NotImplementedError # :nocov:
-
 
 class _SignalState:
     __slots__ = ("signal", "curr", "next", "waiters", "pending")
@@ -186,10 +183,7 @@ class _SignalState:
         self.signal = signal
         self.pending = pending
         self.waiters = dict()
-        self.reset()
-
-    def reset(self):
-        self.curr = self.next = self.signal.reset
+        self.curr = self.next = signal.reset
 
     def set(self, value):
         if self.next == value:
@@ -197,17 +191,11 @@ class _SignalState:
         self.next = value
         self.pending.add(self)
 
-    def wait(self, task, *, trigger=None):
-        assert task not in self.waiters
-        self.waiters[task] = trigger
-
     def commit(self):
         if self.curr == self.next:
             return False
         self.curr = self.next
-        return True
 
-    def wakeup(self):
         awoken_any = False
         for process, trigger in self.waiters.items():
             if trigger is None or trigger == self.curr:
@@ -218,39 +206,47 @@ class _SignalState:
 class _SimulatorState:
     def __init__(self):
         self.signals = SignalDict()
+        self.slots   = []
         self.pending = set()
 
         self.timestamp = 0.0
         self.deadlines = dict()
 
-        self.waveform_writer = None
-
     def reset(self):
-        for signal_state in self.signals.values():
-            signal_state.reset()
+        for signal, index in self.signals.items():
+            self.slots[index].curr = self.slots[index].next = signal.reset
         self.pending.clear()
 
         self.timestamp = 0.0
         self.deadlines.clear()
 
-    def for_signal(self, signal):
+    def get_signal(self, signal):
         try:
             return self.signals[signal]
         except KeyError:
-            signal_state = _SignalState(signal, self.pending)
-            self.signals[signal] = signal_state
-            return signal_state
+            index = len(self.slots)
+            self.slots.append(_SignalState(signal, self.pending))
+            self.signals[signal] = index
+            return index
+
+    def add_trigger(self, process, signal, *, trigger=None):
+        index = self.get_signal(signal)
+        assert (process not in self.slots[index].waiters or
+                self.slots[index].waiters[process] == trigger)
+        self.slots[index].waiters[process] = trigger
+
+    def remove_trigger(self, process, signal):
+        index = self.get_signal(signal)
+        assert process in self.slots[index].waiters
+        del self.slots[index].waiters[process]
 
     def commit(self):
-        awoken_any = False
+        converged = True
         for signal_state in self.pending:
             if signal_state.commit():
-                if signal_state.wakeup():
-                    awoken_any = True
-                if self.waveform_writer is not None:
-                    self.waveform_writer.update(self.timestamp,
-                        signal_state.signal, signal_state.curr)
-        return awoken_any
+                converged = False
+        self.pending.clear()
+        return converged
 
     def advance(self):
         nearest_processes = set()
@@ -279,46 +275,6 @@ class _SimulatorState:
 
         return True
 
-    def start_waveform(self, waveform_writer):
-        if self.timestamp != 0.0:
-            raise ValueError("Cannot start writing waveforms after advancing simulation time")
-        if self.waveform_writer is not None:
-            raise ValueError("Already writing waveforms to {!r}"
-                             .format(self.waveform_writer))
-        self.waveform_writer = waveform_writer
-
-    def finish_waveform(self):
-        if self.waveform_writer is None:
-            return
-        self.waveform_writer.close(self.timestamp)
-        self.waveform_writer = None
-
-
-class _EvalContext:
-    __slots__ = ("state", "indexes", "slots")
-
-    def __init__(self, state):
-        self.state = state
-        self.indexes = SignalDict()
-        self.slots = []
-
-    def get_signal(self, signal):
-        try:
-            return self.indexes[signal]
-        except KeyError:
-            index = len(self.slots)
-            self.slots.append(self.state.for_signal(signal))
-            self.indexes[signal] = index
-            return index
-
-    def get_in_signal(self, signal, *, trigger=None):
-        index = self.get_signal(signal)
-        self.slots[index].waiters[self] = trigger
-        return index
-
-    def get_out_signal(self, signal):
-        return self.get_signal(signal)
-
 
 class _Emitter:
     def __init__(self):
@@ -354,8 +310,8 @@ class _Emitter:
 
 
 class _Compiler:
-    def __init__(self, context, emitter):
-        self.context = context
+    def __init__(self, state, emitter):
+        self.state = state
         self.emitter = emitter
 
 
@@ -372,9 +328,6 @@ class _ValueCompiler(ValueVisitor, _Compiler):
     def on_ResetSignal(self, value):
         raise NotImplementedError # :nocov:
 
-    def on_Record(self, value):
-        return self(Cat(value.fields.values()))
-
     def on_AnyConst(self, value):
         raise NotImplementedError # :nocov:
 
@@ -389,8 +342,8 @@ class _ValueCompiler(ValueVisitor, _Compiler):
 
 
 class _RHSValueCompiler(_ValueCompiler):
-    def __init__(self, context, emitter, *, mode, inputs=None):
-        super().__init__(context, emitter)
+    def __init__(self, state, emitter, *, mode, inputs=None):
+        super().__init__(state, emitter)
         assert mode in ("curr", "next")
         self.mode = mode
         # If not None, `inputs` gets populated with RHS signals.
@@ -404,9 +357,9 @@ class _RHSValueCompiler(_ValueCompiler):
             self.inputs.add(value)
 
         if self.mode == "curr":
-            return f"slots[{self.context.get_signal(value)}].{self.mode}"
+            return f"slots[{self.state.get_signal(value)}].{self.mode}"
         else:
-            return f"next_{self.context.get_signal(value)}"
+            return f"next_{self.state.get_signal(value)}"
 
     def on_Operator(self, value):
         def mask(value):
@@ -532,22 +485,22 @@ class _RHSValueCompiler(_ValueCompiler):
             return f"0"
 
     @classmethod
-    def compile(cls, context, value, *, mode, inputs=None):
+    def compile(cls, state, value, *, mode, inputs=None):
         emitter = _Emitter()
-        compiler = cls(context, emitter, mode=mode, inputs=inputs)
+        compiler = cls(state, emitter, mode=mode, inputs=inputs)
         emitter.append(f"result = {compiler(value)}")
         return emitter.flush()
 
 
 class _LHSValueCompiler(_ValueCompiler):
-    def __init__(self, context, emitter, *, rhs, outputs=None):
-        super().__init__(context, emitter)
+    def __init__(self, state, emitter, *, rhs, outputs=None):
+        super().__init__(state, emitter)
         # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
         # the offset of a Part.
         self.rrhs = rhs
         # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
         # update of an lvalue.
-        self.lrhs = _RHSValueCompiler(context, emitter, mode="next", inputs=None)
+        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
         # If not None, `outputs` gets populated with signals on LHS.
         self.outputs = outputs
 
@@ -564,7 +517,7 @@ class _LHSValueCompiler(_ValueCompiler):
                 value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
             else: # unsigned
                 value_sign = f"{arg} & {value_mask}"
-            self.emitter.append(f"next_{self.context.get_out_signal(value)} = {value_sign}")
+            self.emitter.append(f"next_{self.state.get_signal(value)} = {value_sign}")
         return gen
 
     def on_Operator(self, value):
@@ -623,18 +576,18 @@ class _LHSValueCompiler(_ValueCompiler):
         return gen
 
     @classmethod
-    def compile(cls, context, stmt, *, inputs=None, outputs=None):
+    def compile(cls, state, stmt, *, inputs=None, outputs=None):
         emitter = _Emitter()
-        compiler = cls(context, emitter, inputs=inputs, outputs=outputs)
+        compiler = cls(state, emitter, inputs=inputs, outputs=outputs)
         compiler(stmt)
         return emitter.flush()
 
 
 class _StatementCompiler(StatementVisitor, _Compiler):
-    def __init__(self, context, emitter, *, inputs=None, outputs=None):
-        super().__init__(context, emitter)
-        self.rhs = _RHSValueCompiler(context, emitter, mode="curr", inputs=inputs)
-        self.lhs = _LHSValueCompiler(context, emitter, rhs=self.rhs, outputs=outputs)
+    def __init__(self, state, emitter, *, inputs=None, outputs=None):
+        super().__init__(state, emitter)
+        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
+        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)
 
     def on_statements(self, stmts):
         for stmt in stmts:
@@ -678,12 +631,12 @@ class _StatementCompiler(StatementVisitor, _Compiler):
         raise NotImplementedError # :nocov:
 
     @classmethod
-    def compile(cls, context, stmt, *, inputs=None, outputs=None):
-        output_indexes = [context.get_signal(signal) for signal in stmt._lhs_signals()]
+    def compile(cls, state, stmt, *, inputs=None, outputs=None):
+        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
         emitter = _Emitter()
         for signal_index in output_indexes:
             emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
-        compiler = cls(context, emitter, inputs=inputs, outputs=outputs)
+        compiler = cls(state, emitter, inputs=inputs, outputs=outputs)
         compiler(stmt)
         for signal_index in output_indexes:
             emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
@@ -691,12 +644,11 @@ class _StatementCompiler(StatementVisitor, _Compiler):
 
 
 class _CompiledProcess(_Process):
-    __slots__ = ("context", "comb", "name", "run")
+    __slots__ = ("state", "comb", "run")
 
-    def __init__(self, state, *, comb, name):
-        self.context = _EvalContext(state)
+    def __init__(self, state, *, comb):
+        self.state = state
         self.comb = comb
-        self.name = name
         self.run = None # set by _FragmentCompiler
         self.reset()
 
@@ -722,8 +674,7 @@ class _FragmentCompiler:
 
         for domain_name, domain_signals in fragment.drivers.items():
             domain_stmts = LHSGroupFilter(domain_signals)(fragment.statements)
-            domain_process = _CompiledProcess(self.state, comb=domain_name is None,
-                name=".".join((*hierarchy, "<{}>".format(domain_name or "comb"))))
+            domain_process = _CompiledProcess(self.state, comb=domain_name is None)
 
             emitter = _Emitter()
             emitter.append(f"def run():")
@@ -731,14 +682,14 @@ class _FragmentCompiler:
 
             if domain_name is None:
                 for signal in domain_signals:
-                    signal_index = domain_process.context.get_signal(signal)
+                    signal_index = domain_process.state.get_signal(signal)
                     emitter.append(f"next_{signal_index} = {signal.reset}")
 
                 inputs = SignalSet()
-                _StatementCompiler(domain_process.context, emitter, inputs=inputs)(domain_stmts)
+                _StatementCompiler(domain_process.state, emitter, inputs=inputs)(domain_stmts)
 
                 for input in inputs:
-                    self.state.for_signal(input).wait(domain_process)
+                    self.state.add_trigger(domain_process, input)
 
             else:
                 domain = fragment.domains[domain_name]
@@ -747,27 +698,27 @@ class _FragmentCompiler:
                     add_signal_name(domain.rst)
 
                 clk_trigger = 1 if domain.clk_edge == "pos" else 0
-                self.state.for_signal(domain.clk).wait(domain_process, trigger=clk_trigger)
+                self.state.add_trigger(domain_process, domain.clk, trigger=clk_trigger)
                 if domain.rst is not None and domain.async_reset:
                     rst_trigger = 1
-                    self.state.for_signal(domain.rst).wait(domain_process, trigger=rst_trigger)
+                    self.state.add_trigger(domain_process, domain.rst, trigger=rst_trigger)
 
                 gen_asserts = []
-                clk_index = domain_process.context.get_signal(domain.clk)
+                clk_index = domain_process.state.get_signal(domain.clk)
                 gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                 if domain.rst is not None and domain.async_reset:
-                    rst_index = domain_process.context.get_signal(domain.rst)
+                    rst_index = domain_process.state.get_signal(domain.rst)
                     gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                 emitter.append(f"assert {' or '.join(gen_asserts)}")
 
                 for signal in domain_signals:
-                    signal_index = domain_process.context.get_signal(signal)
+                    signal_index = domain_process.state.get_signal(signal)
                     emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
 
-                _StatementCompiler(domain_process.context, emitter)(domain_stmts)
+                _StatementCompiler(domain_process.state, emitter)(domain_stmts)
 
             for signal in domain_signals:
-                signal_index = domain_process.context.get_signal(signal)
+                signal_index = domain_process.state.get_signal(signal)
                 emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
 
             # There shouldn't be any exceptions raised by the generated code, but if there are
@@ -781,13 +732,13 @@ class _FragmentCompiler:
             else:
                 filename = "<string>"
 
-            exec_locals = {"slots": domain_process.context.slots, **_ValueCompiler.helpers}
+            exec_locals = {"slots": domain_process.state.slots, **_ValueCompiler.helpers}
             exec(compile(code, filename, "exec"), exec_locals)
             domain_process.run = exec_locals["run"]
 
             processes.add(domain_process)
 
-            for used_signal in domain_process.context.indexes:
+            for used_signal in domain_process.state.signals:
                 add_signal_name(used_signal)
 
         for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
@@ -810,16 +761,14 @@ class _CoroutineProcess(_Process):
         self.runnable = True
         self.passive = False
         self.coroutine = self.constructor()
-        self.eval_context = _EvalContext(self.state)
         self.exec_locals = {
-            "slots": self.eval_context.slots,
+            "slots": self.state.slots,
             "result": None,
             **_ValueCompiler.helpers
         }
         self.waits_on = set()
 
-    @property
-    def name(self):
+    def src_loc(self):
         coroutine = self.coroutine
         while coroutine.gi_yieldfrom is not None:
             coroutine = coroutine.gi_yieldfrom
@@ -829,20 +778,13 @@ class _CoroutineProcess(_Process):
             frame = coroutine.cr_frame
         return "{}:{}".format(inspect.getfile(frame), inspect.getlineno(frame))
 
-    def get_in_signal(self, signal, *, trigger=None):
-        signal_state = self.state.for_signal(signal)
-        assert self not in signal_state.waiters
-        signal_state.waiters[self] = trigger
-        self.waits_on.add(signal_state)
-        return signal_state
-
     def run(self):
         if self.coroutine is None:
             return
 
         if self.waits_on:
-            for signal_state in self.waits_on:
-                del signal_state.waiters[self]
+            for signal in self.waits_on:
+                self.state.remove_trigger(self, signal)
             self.waits_on.clear()
 
         response = None
@@ -854,12 +796,12 @@ class _CoroutineProcess(_Process):
                 response = None
 
                 if isinstance(command, Value):
-                    exec(_RHSValueCompiler.compile(self.eval_context, command, mode="curr"),
+                    exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
                         self.exec_locals)
                     response = Const.normalize(self.exec_locals["result"], command.shape())
 
                 elif isinstance(command, Statement):
-                    exec(_StatementCompiler.compile(self.eval_context, command),
+                    exec(_StatementCompiler.compile(self.state, command),
                         self.exec_locals)
 
                 elif type(command) is Tick:
@@ -871,10 +813,11 @@ class _CoroutineProcess(_Process):
                     else:
                         raise NameError("Received command {!r} that refers to a nonexistent "
                                         "domain {!r} from process {!r}"
-                                        .format(command, command.domain, self.name))
-                    self.get_in_signal(domain.clk, trigger=1 if domain.clk_edge == "pos" else 0)
+                                        .format(command, command.domain, self.src_loc()))
+                    self.state.add_trigger(self, domain.clk,
+                                           trigger=1 if domain.clk_edge == "pos" else 0)
                     if domain.rst is not None and domain.async_reset:
-                        self.get_in_signal(domain.rst, trigger=1)
+                        self.state.add_trigger(self, domain.rst, trigger=1)
                     return
 
                 elif type(command) is Settle:
@@ -898,11 +841,11 @@ class _CoroutineProcess(_Process):
                     raise TypeError("Received default command from process {!r} that was added "
                                     "with add_process(); did you mean to add this process with "
                                     "add_sync_process() instead?"
-                                    .format(self.name))
+                                    .format(self.src_loc()))
 
                 else:
                     raise TypeError("Received unsupported command {!r} from process {!r}"
-                                    .format(command, self.name))
+                                    .format(command, self.src_loc()))
 
             except StopIteration:
                 self.passive = True
@@ -913,45 +856,19 @@ class _CoroutineProcess(_Process):
                 self.coroutine.throw(exn)
 
 
-class _WaveformContextManager:
-    def __init__(self, state, waveform_writer):
-        self._state = state
-        self._waveform_writer = waveform_writer
-
-    def __enter__(self):
-        try:
-            self._state.start_waveform(self._waveform_writer)
-        except:
-            self._waveform_writer.close(0)
-            raise
-
-    def __exit__(self, *args):
-        self._state.finish_waveform()
-
-
 class Simulator:
-    def __init__(self, fragment, **kwargs):
+    def __init__(self, fragment):
         self._state = _SimulatorState()
         self._signal_names = SignalDict()
         self._fragment = Fragment.get(fragment, platform=None).prepare()
         self._processes = _FragmentCompiler(self._state, self._signal_names)(self._fragment)
-        if kwargs: # :nocov:
-            # TODO(nmigen-0.3): remove
-            self._state.start_waveform(_VCDWaveformWriter(self._signal_names, **kwargs))
         self._clocked = set()
+        self._waveform_writers = []
 
     def _check_process(self, process):
         if not (inspect.isgeneratorfunction(process) or inspect.iscoroutinefunction(process)):
-            if inspect.isgenerator(process) or inspect.iscoroutine(process):
-                warnings.warn("instead of generators, use generator functions as processes; "
-                              "this allows the simulator to be repeatedly reset",
-                              DeprecationWarning, stacklevel=3)
-                def wrapper():
-                    yield from process
-                return wrapper
-            else:
-                raise TypeError("Cannot add a process {!r} because it is not a generator function"
-                                .format(process))
+            raise TypeError("Cannot add a process {!r} because it is not a generator function"
+                            .format(process))
         return process
 
     def _add_coroutine_process(self, process, *, default_cmd):
@@ -1041,31 +958,36 @@ class Simulator:
         for process in self._processes:
             process.reset()
 
-    def _delta(self):
-        """Perform a delta cycle.
-
-        Performs the two phases of a delta cycle:
-            1. run and suspend every non-waiting process once, queueing signal changes;
-            2. commit every queued signal change, waking up any waiting process.
-        """
-        for process in self._processes:
-            if process.runnable:
-                process.runnable = False
-                process.run()
-
-        return self._state.commit()
-
-    def _settle(self):
-        """Settle the simulation.
+    def _real_step(self):
+        """Step the simulation.
 
         Run every process and commit changes until a fixed point is reached. If there is
         an unstable combinatorial loop, this function will never return.
         """
-        while self._delta():
-            pass
+        # Performs the two phases of a delta cycle in a loop:
+        converged = False
+        while not converged:
+            # 1. eval: run and suspend every non-waiting process once, queueing signal changes
+            for process in self._processes:
+                if process.runnable:
+                    process.runnable = False
+                    process.run()
+
+            for waveform_writer in self._waveform_writers:
+                for signal_state in self._state.pending:
+                    waveform_writer.update(self._state.timestamp,
+                        signal_state.signal, signal_state.curr)
 
+            # 2. commit: apply every queued signal change, waking up any waiting processes
+            converged = self._state.commit()
+
+    # TODO(nmigen-0.4): replace with _real_step
+    @deprecated("instead of `sim.step()`, use `sim.advance()`")
     def step(self):
-        """Step the simulation.
+        return self.advance()
+
+    def advance(self):
+        """Advance the simulation.
 
         Run every process and commit changes until a fixed point is reached, then advance time
         to the closest deadline (if any). If there is an unstable combinatorial loop,
@@ -1073,7 +995,7 @@ class Simulator:
 
         Returns ``True`` if there are any active processes, ``False`` otherwise.
         """
-        self._settle()
+        self._real_step()
         self._state.advance()
         return any(not process.passive for process in self._processes)
 
@@ -1084,7 +1006,7 @@ class Simulator:
         and may change their status using the ``yield Passive()`` and ``yield Active()`` commands.
         Processes compiled from HDL and added with :meth:`add_clock` are always passive.
         """
-        while self.step():
+        while self.advance():
             pass
 
     def run_until(self, deadline, *, run_passive=False):
@@ -1097,9 +1019,10 @@ class Simulator:
         If the simulation stops advancing, this function will never return.
         """
         assert self._state.timestamp <= deadline
-        while (self.step() or run_passive) and self._state.timestamp < deadline:
+        while (self.advance() or run_passive) and self._state.timestamp < deadline:
             pass
 
+    @contextmanager
     def write_vcd(self, vcd_file, gtkw_file=None, *, traces=()):
         """Write waveforms to a Value Change Dump file, optionally populating a GTKWave save file.
 
@@ -1119,16 +1042,11 @@ class Simulator:
         traces : iterable of Signal
             Signals to display traces for.
         """
+        if self._state.timestamp != 0.0:
+            raise ValueError("Cannot start writing waveforms after advancing simulation time")
         waveform_writer = _VCDWaveformWriter(self._signal_names,
             vcd_file=vcd_file, gtkw_file=gtkw_file, traces=traces)
-        return _WaveformContextManager(self._state, waveform_writer)
-
-    # TODO(nmigen-0.3): remove
-    @deprecated("instead of `with Simulator(fragment, ...) as sim:`, use "
-                "`sim = Simulator(fragment); with sim.write_vcd(...):`")
-    def __enter__(self): # :nocov:
-        return self
-
-    # TODO(nmigen-0.3): remove
-    def __exit__(self, *args): # :nocov:
-        self._state.finish_waveform()
+        self._waveform_writers.append(waveform_writer)
+        yield
+        waveform_writer.close(self._state.timestamp)
+        self._waveform_writers.remove(waveform_writer)