back.pysim: simplify.
authorwhitequark <whitequark@whitequark.org>
Sun, 28 Jun 2020 05:04:16 +0000 (05:04 +0000)
committerwhitequark <whitequark@whitequark.org>
Sun, 28 Jun 2020 05:04:16 +0000 (05:04 +0000)
Remove _EvalContext, which was a level of indirection serving almost
no purpose. (The only case where it would be useful is repeatedly
resetting a simulation that, each time it is reset, would create new
signals to communicate with between coroutine processes. In that case
the signal states would not be persisted in _SimulatorState, but
would be removed with the _EvalContext that is recreated each time
the simulation is reset. But this could be solved with a weak map
instead.)

This regresses simulator startup time by 10-15% for unknown reasons
but is necessary to align pysim and future cxxsim.

nmigen/back/pysim.py

index 3077991e24cb40b078ea15d3bc87e5604c8a26a4..fb651e4d04d939cd7fd74db3d3d9773dbf62cd30 100644 (file)
@@ -219,6 +219,7 @@ class _SignalState:
 class _SimulatorState:
     def __init__(self):
         self.signals = SignalDict()
+        self.slots   = []
         self.pending = set()
 
         self.timestamp = 0.0
@@ -227,20 +228,32 @@ class _SimulatorState:
         self.waveform_writer = None
 
     def reset(self):
-        for signal_state in self.signals.values():
-            signal_state.reset()
+        for signal, index in self.signals.items():
+            self.slots[index].curr = self.slots[index].next = signal.reset
         self.pending.clear()
 
         self.timestamp = 0.0
         self.deadlines.clear()
 
-    def for_signal(self, signal):
+    def get_signal(self, signal):
         try:
             return self.signals[signal]
         except KeyError:
-            signal_state = _SignalState(signal, self.pending)
-            self.signals[signal] = signal_state
-            return signal_state
+            index = len(self.slots)
+            self.slots.append(_SignalState(signal, self.pending))
+            self.signals[signal] = index
+            return index
+
+    def get_in_signal(self, signal, *, trigger=None):
+        index = self.get_signal(signal)
+        self.slots[index].waiters[self] = trigger
+        return index
+
+    def get_out_signal(self, signal):
+        return self.get_signal(signal)
+
+    def for_signal(self, signal):
+        return self.slots[self.get_signal(signal)]
 
     def commit(self):
         awoken_any = False
@@ -296,32 +309,6 @@ class _SimulatorState:
         self.waveform_writer = None
 
 
-class _EvalContext:
-    __slots__ = ("state", "indexes", "slots")
-
-    def __init__(self, state):
-        self.state = state
-        self.indexes = SignalDict()
-        self.slots = []
-
-    def get_signal(self, signal):
-        try:
-            return self.indexes[signal]
-        except KeyError:
-            index = len(self.slots)
-            self.slots.append(self.state.for_signal(signal))
-            self.indexes[signal] = index
-            return index
-
-    def get_in_signal(self, signal, *, trigger=None):
-        index = self.get_signal(signal)
-        self.slots[index].waiters[self] = trigger
-        return index
-
-    def get_out_signal(self, signal):
-        return self.get_signal(signal)
-
-
 class _Emitter:
     def __init__(self):
         self._buffer = []
@@ -356,8 +343,8 @@ class _Emitter:
 
 
 class _Compiler:
-    def __init__(self, context, emitter):
-        self.context = context
+    def __init__(self, state, emitter):
+        self.state = state
         self.emitter = emitter
 
 
@@ -388,8 +375,8 @@ class _ValueCompiler(ValueVisitor, _Compiler):
 
 
 class _RHSValueCompiler(_ValueCompiler):
-    def __init__(self, context, emitter, *, mode, inputs=None):
-        super().__init__(context, emitter)
+    def __init__(self, state, emitter, *, mode, inputs=None):
+        super().__init__(state, emitter)
         assert mode in ("curr", "next")
         self.mode = mode
         # If not None, `inputs` gets populated with RHS signals.
@@ -403,9 +390,9 @@ class _RHSValueCompiler(_ValueCompiler):
             self.inputs.add(value)
 
         if self.mode == "curr":
-            return f"slots[{self.context.get_signal(value)}].{self.mode}"
+            return f"slots[{self.state.get_signal(value)}].{self.mode}"
         else:
-            return f"next_{self.context.get_signal(value)}"
+            return f"next_{self.state.get_signal(value)}"
 
     def on_Operator(self, value):
         def mask(value):
@@ -531,22 +518,22 @@ class _RHSValueCompiler(_ValueCompiler):
             return f"0"
 
     @classmethod
-    def compile(cls, context, value, *, mode, inputs=None):
+    def compile(cls, state, value, *, mode, inputs=None):
         emitter = _Emitter()
-        compiler = cls(context, emitter, mode=mode, inputs=inputs)
+        compiler = cls(state, emitter, mode=mode, inputs=inputs)
         emitter.append(f"result = {compiler(value)}")
         return emitter.flush()
 
 
 class _LHSValueCompiler(_ValueCompiler):
-    def __init__(self, context, emitter, *, rhs, outputs=None):
-        super().__init__(context, emitter)
+    def __init__(self, state, emitter, *, rhs, outputs=None):
+        super().__init__(state, emitter)
         # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g.
         # the offset of a Part.
         self.rrhs = rhs
         # `lrhs` is used to translate the read part of a read-modify-write cycle during partial
         # update of an lvalue.
-        self.lrhs = _RHSValueCompiler(context, emitter, mode="next", inputs=None)
+        self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None)
         # If not None, `outputs` gets populated with signals on LHS.
         self.outputs = outputs
 
@@ -563,7 +550,7 @@ class _LHSValueCompiler(_ValueCompiler):
                 value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})"
             else: # unsigned
                 value_sign = f"{arg} & {value_mask}"
-            self.emitter.append(f"next_{self.context.get_out_signal(value)} = {value_sign}")
+            self.emitter.append(f"next_{self.state.get_out_signal(value)} = {value_sign}")
         return gen
 
     def on_Operator(self, value):
@@ -622,18 +609,18 @@ class _LHSValueCompiler(_ValueCompiler):
         return gen
 
     @classmethod
-    def compile(cls, context, stmt, *, inputs=None, outputs=None):
+    def compile(cls, state, stmt, *, inputs=None, outputs=None):
         emitter = _Emitter()
-        compiler = cls(context, emitter, inputs=inputs, outputs=outputs)
+        compiler = cls(state, emitter, inputs=inputs, outputs=outputs)
         compiler(stmt)
         return emitter.flush()
 
 
 class _StatementCompiler(StatementVisitor, _Compiler):
-    def __init__(self, context, emitter, *, inputs=None, outputs=None):
-        super().__init__(context, emitter)
-        self.rhs = _RHSValueCompiler(context, emitter, mode="curr", inputs=inputs)
-        self.lhs = _LHSValueCompiler(context, emitter, rhs=self.rhs, outputs=outputs)
+    def __init__(self, state, emitter, *, inputs=None, outputs=None):
+        super().__init__(state, emitter)
+        self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs)
+        self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs)
 
     def on_statements(self, stmts):
         for stmt in stmts:
@@ -677,12 +664,12 @@ class _StatementCompiler(StatementVisitor, _Compiler):
         raise NotImplementedError # :nocov:
 
     @classmethod
-    def compile(cls, context, stmt, *, inputs=None, outputs=None):
-        output_indexes = [context.get_signal(signal) for signal in stmt._lhs_signals()]
+    def compile(cls, state, stmt, *, inputs=None, outputs=None):
+        output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()]
         emitter = _Emitter()
         for signal_index in output_indexes:
             emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
-        compiler = cls(context, emitter, inputs=inputs, outputs=outputs)
+        compiler = cls(state, emitter, inputs=inputs, outputs=outputs)
         compiler(stmt)
         for signal_index in output_indexes:
             emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
@@ -690,10 +677,10 @@ class _StatementCompiler(StatementVisitor, _Compiler):
 
 
 class _CompiledProcess(_Process):
-    __slots__ = ("context", "comb", "name", "run")
+    __slots__ = ("state", "comb", "name", "run")
 
     def __init__(self, state, *, comb, name):
-        self.context = _EvalContext(state)
+        self.state = state
         self.comb = comb
         self.name = name
         self.run = None # set by _FragmentCompiler
@@ -730,11 +717,11 @@ class _FragmentCompiler:
 
             if domain_name is None:
                 for signal in domain_signals:
-                    signal_index = domain_process.context.get_signal(signal)
+                    signal_index = domain_process.state.get_signal(signal)
                     emitter.append(f"next_{signal_index} = {signal.reset}")
 
                 inputs = SignalSet()
-                _StatementCompiler(domain_process.context, emitter, inputs=inputs)(domain_stmts)
+                _StatementCompiler(domain_process.state, emitter, inputs=inputs)(domain_stmts)
 
                 for input in inputs:
                     self.state.for_signal(input).wait(domain_process)
@@ -752,21 +739,21 @@ class _FragmentCompiler:
                     self.state.for_signal(domain.rst).wait(domain_process, trigger=rst_trigger)
 
                 gen_asserts = []
-                clk_index = domain_process.context.get_signal(domain.clk)
+                clk_index = domain_process.state.get_signal(domain.clk)
                 gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}")
                 if domain.rst is not None and domain.async_reset:
-                    rst_index = domain_process.context.get_signal(domain.rst)
+                    rst_index = domain_process.state.get_signal(domain.rst)
                     gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}")
                 emitter.append(f"assert {' or '.join(gen_asserts)}")
 
                 for signal in domain_signals:
-                    signal_index = domain_process.context.get_signal(signal)
+                    signal_index = domain_process.state.get_signal(signal)
                     emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
 
-                _StatementCompiler(domain_process.context, emitter)(domain_stmts)
+                _StatementCompiler(domain_process.state, emitter)(domain_stmts)
 
             for signal in domain_signals:
-                signal_index = domain_process.context.get_signal(signal)
+                signal_index = domain_process.state.get_signal(signal)
                 emitter.append(f"slots[{signal_index}].set(next_{signal_index})")
 
             # There shouldn't be any exceptions raised by the generated code, but if there are
@@ -780,13 +767,13 @@ class _FragmentCompiler:
             else:
                 filename = "<string>"
 
-            exec_locals = {"slots": domain_process.context.slots, **_ValueCompiler.helpers}
+            exec_locals = {"slots": domain_process.state.slots, **_ValueCompiler.helpers}
             exec(compile(code, filename, "exec"), exec_locals)
             domain_process.run = exec_locals["run"]
 
             processes.add(domain_process)
 
-            for used_signal in domain_process.context.indexes:
+            for used_signal in domain_process.state.signals:
                 add_signal_name(used_signal)
 
         for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments):
@@ -809,9 +796,8 @@ class _CoroutineProcess(_Process):
         self.runnable = True
         self.passive = False
         self.coroutine = self.constructor()
-        self.eval_context = _EvalContext(self.state)
         self.exec_locals = {
-            "slots": self.eval_context.slots,
+            "slots": self.state.slots,
             "result": None,
             **_ValueCompiler.helpers
         }
@@ -853,12 +839,12 @@ class _CoroutineProcess(_Process):
                 response = None
 
                 if isinstance(command, Value):
-                    exec(_RHSValueCompiler.compile(self.eval_context, command, mode="curr"),
+                    exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
                         self.exec_locals)
                     response = Const.normalize(self.exec_locals["result"], command.shape())
 
                 elif isinstance(command, Statement):
-                    exec(_StatementCompiler.compile(self.eval_context, command),
+                    exec(_StatementCompiler.compile(self.state, command),
                         self.exec_locals)
 
                 elif type(command) is Tick: