From 2606ee33ad548bed1b9294bbca3962e834d12fd0 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 28 Jun 2020 05:04:16 +0000 Subject: [PATCH] back.pysim: simplify. Remove _EvalContext, which was a level of indirection serving almost no purpose. (The only case where it would be useful is repeatedly resetting a simulation that, each time it is reset, would create new signals to communicate with between coroutine processes. In that case the signal states would not be persisted in _SimulatorState, but would be removed with the _EvalContext that is recreated each time the simulation is reset. But this could be solved with a weak map instead.) This regresses simulator startup time by 10-15% for unknown reasons but is necessary to align pysim and future cxxsim. --- nmigen/back/pysim.py | 122 +++++++++++++++++++------------------------ 1 file changed, 54 insertions(+), 68 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 3077991..fb651e4 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -219,6 +219,7 @@ class _SignalState: class _SimulatorState: def __init__(self): self.signals = SignalDict() + self.slots = [] self.pending = set() self.timestamp = 0.0 @@ -227,20 +228,32 @@ class _SimulatorState: self.waveform_writer = None def reset(self): - for signal_state in self.signals.values(): - signal_state.reset() + for signal, index in self.signals.items(): + self.slots[index].curr = self.slots[index].next = signal.reset self.pending.clear() self.timestamp = 0.0 self.deadlines.clear() - def for_signal(self, signal): + def get_signal(self, signal): try: return self.signals[signal] except KeyError: - signal_state = _SignalState(signal, self.pending) - self.signals[signal] = signal_state - return signal_state + index = len(self.slots) + self.slots.append(_SignalState(signal, self.pending)) + self.signals[signal] = index + return index + + def get_in_signal(self, signal, *, trigger=None): + index = self.get_signal(signal) + self.slots[index].waiters[self] = trigger + return index + + def get_out_signal(self, signal): + return self.get_signal(signal) + + def for_signal(self, signal): + return self.slots[self.get_signal(signal)] def commit(self): awoken_any = False @@ -296,32 +309,6 @@ class _SimulatorState: self.waveform_writer = None -class _EvalContext: - __slots__ = ("state", "indexes", "slots") - - def __init__(self, state): - self.state = state - self.indexes = SignalDict() - self.slots = [] - - def get_signal(self, signal): - try: - return self.indexes[signal] - except KeyError: - index = len(self.slots) - self.slots.append(self.state.for_signal(signal)) - self.indexes[signal] = index - return index - - def get_in_signal(self, signal, *, trigger=None): - index = self.get_signal(signal) - self.slots[index].waiters[self] = trigger - return index - - def get_out_signal(self, signal): - return self.get_signal(signal) - - class _Emitter: def __init__(self): self._buffer = [] @@ -356,8 +343,8 @@ class _Emitter: class _Compiler: - def __init__(self, context, emitter): - self.context = context + def __init__(self, state, emitter): + self.state = state self.emitter = emitter @@ -388,8 +375,8 @@ class _ValueCompiler(ValueVisitor, _Compiler): class _RHSValueCompiler(_ValueCompiler): - def __init__(self, context, emitter, *, mode, inputs=None): - super().__init__(context, emitter) + def __init__(self, state, emitter, *, mode, inputs=None): + super().__init__(state, emitter) assert mode in ("curr", "next") self.mode = mode # If not None, `inputs` gets populated with RHS signals. @@ -403,9 +390,9 @@ class _RHSValueCompiler(_ValueCompiler): self.inputs.add(value) if self.mode == "curr": - return f"slots[{self.context.get_signal(value)}].{self.mode}" + return f"slots[{self.state.get_signal(value)}].{self.mode}" else: - return f"next_{self.context.get_signal(value)}" + return f"next_{self.state.get_signal(value)}" def on_Operator(self, value): def mask(value): @@ -531,22 +518,22 @@ class _RHSValueCompiler(_ValueCompiler): return f"0" @classmethod - def compile(cls, context, value, *, mode, inputs=None): + def compile(cls, state, value, *, mode, inputs=None): emitter = _Emitter() - compiler = cls(context, emitter, mode=mode, inputs=inputs) + compiler = cls(state, emitter, mode=mode, inputs=inputs) emitter.append(f"result = {compiler(value)}") return emitter.flush() class _LHSValueCompiler(_ValueCompiler): - def __init__(self, context, emitter, *, rhs, outputs=None): - super().__init__(context, emitter) + def __init__(self, state, emitter, *, rhs, outputs=None): + super().__init__(state, emitter) # `rrhs` is used to translate rvalues that are syntactically a part of an lvalue, e.g. # the offset of a Part. self.rrhs = rhs # `lrhs` is used to translate the read part of a read-modify-write cycle during partial # update of an lvalue. - self.lrhs = _RHSValueCompiler(context, emitter, mode="next", inputs=None) + self.lrhs = _RHSValueCompiler(state, emitter, mode="next", inputs=None) # If not None, `outputs` gets populated with signals on LHS. self.outputs = outputs @@ -563,7 +550,7 @@ class _LHSValueCompiler(_ValueCompiler): value_sign = f"sign({arg} & {value_mask}, {-1 << (len(value) - 1)})" else: # unsigned value_sign = f"{arg} & {value_mask}" - self.emitter.append(f"next_{self.context.get_out_signal(value)} = {value_sign}") + self.emitter.append(f"next_{self.state.get_out_signal(value)} = {value_sign}") return gen def on_Operator(self, value): @@ -622,18 +609,18 @@ class _LHSValueCompiler(_ValueCompiler): return gen @classmethod - def compile(cls, context, stmt, *, inputs=None, outputs=None): + def compile(cls, state, stmt, *, inputs=None, outputs=None): emitter = _Emitter() - compiler = cls(context, emitter, inputs=inputs, outputs=outputs) + compiler = cls(state, emitter, inputs=inputs, outputs=outputs) compiler(stmt) return emitter.flush() class _StatementCompiler(StatementVisitor, _Compiler): - def __init__(self, context, emitter, *, inputs=None, outputs=None): - super().__init__(context, emitter) - self.rhs = _RHSValueCompiler(context, emitter, mode="curr", inputs=inputs) - self.lhs = _LHSValueCompiler(context, emitter, rhs=self.rhs, outputs=outputs) + def __init__(self, state, emitter, *, inputs=None, outputs=None): + super().__init__(state, emitter) + self.rhs = _RHSValueCompiler(state, emitter, mode="curr", inputs=inputs) + self.lhs = _LHSValueCompiler(state, emitter, rhs=self.rhs, outputs=outputs) def on_statements(self, stmts): for stmt in stmts: @@ -677,12 +664,12 @@ class _StatementCompiler(StatementVisitor, _Compiler): raise NotImplementedError # :nocov: @classmethod - def compile(cls, context, stmt, *, inputs=None, outputs=None): - output_indexes = [context.get_signal(signal) for signal in stmt._lhs_signals()] + def compile(cls, state, stmt, *, inputs=None, outputs=None): + output_indexes = [state.get_signal(signal) for signal in stmt._lhs_signals()] emitter = _Emitter() for signal_index in output_indexes: emitter.append(f"next_{signal_index} = slots[{signal_index}].next") - compiler = cls(context, emitter, inputs=inputs, outputs=outputs) + compiler = cls(state, emitter, inputs=inputs, outputs=outputs) compiler(stmt) for signal_index in output_indexes: emitter.append(f"slots[{signal_index}].set(next_{signal_index})") @@ -690,10 +677,10 @@ class _StatementCompiler(StatementVisitor, _Compiler): class _CompiledProcess(_Process): - __slots__ = ("context", "comb", "name", "run") + __slots__ = ("state", "comb", "name", "run") def __init__(self, state, *, comb, name): - self.context = _EvalContext(state) + self.state = state self.comb = comb self.name = name self.run = None # set by _FragmentCompiler @@ -730,11 +717,11 @@ class _FragmentCompiler: if domain_name is None: for signal in domain_signals: - signal_index = domain_process.context.get_signal(signal) + signal_index = domain_process.state.get_signal(signal) emitter.append(f"next_{signal_index} = {signal.reset}") inputs = SignalSet() - _StatementCompiler(domain_process.context, emitter, inputs=inputs)(domain_stmts) + _StatementCompiler(domain_process.state, emitter, inputs=inputs)(domain_stmts) for input in inputs: self.state.for_signal(input).wait(domain_process) @@ -752,21 +739,21 @@ class _FragmentCompiler: self.state.for_signal(domain.rst).wait(domain_process, trigger=rst_trigger) gen_asserts = [] - clk_index = domain_process.context.get_signal(domain.clk) + clk_index = domain_process.state.get_signal(domain.clk) gen_asserts.append(f"slots[{clk_index}].curr == {clk_trigger}") if domain.rst is not None and domain.async_reset: - rst_index = domain_process.context.get_signal(domain.rst) + rst_index = domain_process.state.get_signal(domain.rst) gen_asserts.append(f"slots[{rst_index}].curr == {rst_trigger}") emitter.append(f"assert {' or '.join(gen_asserts)}") for signal in domain_signals: - signal_index = domain_process.context.get_signal(signal) + signal_index = domain_process.state.get_signal(signal) emitter.append(f"next_{signal_index} = slots[{signal_index}].next") - _StatementCompiler(domain_process.context, emitter)(domain_stmts) + _StatementCompiler(domain_process.state, emitter)(domain_stmts) for signal in domain_signals: - signal_index = domain_process.context.get_signal(signal) + signal_index = domain_process.state.get_signal(signal) emitter.append(f"slots[{signal_index}].set(next_{signal_index})") # There shouldn't be any exceptions raised by the generated code, but if there are @@ -780,13 +767,13 @@ class _FragmentCompiler: else: filename = "" - exec_locals = {"slots": domain_process.context.slots, **_ValueCompiler.helpers} + exec_locals = {"slots": domain_process.state.slots, **_ValueCompiler.helpers} exec(compile(code, filename, "exec"), exec_locals) domain_process.run = exec_locals["run"] processes.add(domain_process) - for used_signal in domain_process.context.indexes: + for used_signal in domain_process.state.signals: add_signal_name(used_signal) for subfragment_index, (subfragment, subfragment_name) in enumerate(fragment.subfragments): @@ -809,9 +796,8 @@ class _CoroutineProcess(_Process): self.runnable = True self.passive = False self.coroutine = self.constructor() - self.eval_context = _EvalContext(self.state) self.exec_locals = { - "slots": self.eval_context.slots, + "slots": self.state.slots, "result": None, **_ValueCompiler.helpers } @@ -853,12 +839,12 @@ class _CoroutineProcess(_Process): response = None if isinstance(command, Value): - exec(_RHSValueCompiler.compile(self.eval_context, command, mode="curr"), + exec(_RHSValueCompiler.compile(self.state, command, mode="curr"), self.exec_locals) response = Const.normalize(self.exec_locals["result"], command.shape()) elif isinstance(command, Statement): - exec(_StatementCompiler.compile(self.eval_context, command), + exec(_StatementCompiler.compile(self.state, command), self.exec_locals) elif type(command) is Tick: -- 2.30.2