back.pysim: eliminate most dictionary lookups.
author whitequark <whitequark@whitequark.org>
Tue, 18 Dec 2018 15:28:27 +0000 (15:28 +0000)
committer whitequark <whitequark@whitequark.org>
Tue, 18 Dec 2018 16:36:54 +0000 (16:36 +0000)
This makes the Glasgow testsuite about 30% faster.

.travis.yml
nmigen/back/pysim.py
setup.py

index bb32baed806f0a045a24b3e68d34fdd809d48147..d4dccf41ece443f86bd957cbb1744db790e45add 100644 (file)
@@ -10,7 +10,7 @@ cache:
 before_install:
   - export PATH="/usr/lib/ccache:$HOME/.local/bin:$PATH"
 install:
-  - pip install coverage codecov pyvcd
+  - pip install coverage codecov pyvcd bitarray
   - git clone https://github.com/YosysHQ/yosys
   - (cd yosys && if ! yosys -V || [ $(git rev-parse HEAD $(yosys -V | awk 'match($0,/sha1 ([0-9a-f]+)/,m) { print m[1] }') | uniq | wc -l) != 1 ]; then make CONFIG=gcc ENABLE_ABC=0 PREFIX=$HOME/.local install; fi)
 script:
index 1c14ba69e53f1fc8250012b35230afb61cc26918..55ab15b6ed6dd5488989e7f45282fc1b29675663 100644 (file)
@@ -1,6 +1,7 @@
 import math
 import inspect
 from contextlib import contextmanager
+from bitarray import bitarray
 from vcd import VCDWriter
 from vcd.gtkw import GTKWSave
 
@@ -22,30 +23,50 @@ class _State:
     def __init__(self):
         self.curr = []
         self.next = []
-        self.curr_dirty = SignalSet()
-        self.next_dirty = SignalSet()
+        self.curr_dirty = bitarray()
+        self.next_dirty = bitarray()
 
-    def add(self, signal, value):
+    def add(self, value):
         slot = len(self.curr)
         self.curr.append(value)
         self.next.append(value)
-        self.curr_dirty.add(signal)
+        self.curr_dirty.append(True)
+        self.next_dirty.append(False)
         return slot
 
-    def set(self, signal, slot, value):
+    def set(self, slot, value):
         if self.next[slot] != value:
-            self.next_dirty.add(signal)
+            self.next_dirty[slot] = True
             self.next[slot] = value
 
-    def commit(self, signal, slot):
+    def commit(self, slot):
         old_value = self.curr[slot]
         new_value = self.next[slot]
         if old_value != new_value:
-            self.next_dirty.remove(signal)
-            self.curr_dirty.add(signal)
+            self.next_dirty[slot] = False
+            self.curr_dirty[slot] = True
             self.curr[slot] = new_value
         return old_value, new_value
 
+    def flush_curr_dirty(self):
+        while True:
+            try:
+                slot = self.curr_dirty.index(True)
+            except ValueError:
+                break
+            self.curr_dirty[slot] = False
+            yield slot
+
+    def iter_next_dirty(self):
+        start = 0
+        while True:
+            try:
+                slot  = self.next_dirty.index(True, start)
+                start = slot + 1
+            except ValueError:
+                break
+            yield slot
+
 
 normalize = Const.normalize
 
@@ -185,7 +206,7 @@ class _LHSValueCompiler(AbstractValueTransformer):
         shape = value.shape()
         value_slot = self.signal_slots[value]
         def eval(state, rhs):
-            state.set(value, value_slot, normalize(rhs, shape))
+            state.set(value_slot, normalize(rhs, shape))
         return eval
 
     def on_ClockSignal(self, value):
@@ -293,15 +314,17 @@ class Simulator:
     def __init__(self, fragment, vcd_file=None, gtkw_file=None, traces=()):
         self._fragment        = fragment
 
+        self._signal_slots    = SignalDict()  # Signal -> int/slot
+        self._slot_signals    = list()        # int/slot -> Signal
+
         self._domains         = dict()        # str/domain -> ClockDomain
-        self._domain_triggers = SignalDict()  # Signal -> str/domain
-        self._domain_signals  = dict()        # str/domain -> {Signal}
+        self._domain_triggers = list()        # int/slot -> str/domain
 
         self._signals         = SignalSet()   # {Signal}
-        self._comb_signals    = SignalSet()   # {Signal}
-        self._sync_signals    = SignalSet()   # {Signal}
-        self._user_signals    = SignalSet()   # {Signal}
-        self._signal_slots    = SignalDict()  # Signal -> int/slot
+        self._comb_signals    = bitarray()    # {Signal}
+        self._sync_signals    = bitarray()    # {Signal}
+        self._user_signals    = bitarray()    # {Signal}
+        self._domain_signals  = dict()        # str/domain -> {Signal}
 
         self._started         = False
         self._timestamp       = 0.
@@ -317,12 +340,12 @@ class Simulator:
         self._wait_deadline   = dict()        # process -> float/timestamp
         self._wait_tick       = dict()        # process -> str/domain
 
-        self._funclets        = SignalDict()  # Signal -> set(lambda)
+        self._funclets        = list()        # int/slot -> set(lambda)
 
         self._vcd_file        = vcd_file
         self._vcd_writer      = None
-        self._vcd_signals     = SignalDict()  # signal -> set(vcd_signal)
-        self._vcd_names       = SignalDict()  # signal -> str/name
+        self._vcd_signals     = list()        # int/slot -> set(vcd_signal)
+        self._vcd_names       = list()        # int/slot -> str/name
         self._gtkw_file       = gtkw_file
         self._traces          = traces
 
@@ -387,13 +410,7 @@ class Simulator:
                                          comment="Generated by nMigen")
 
         root_fragment = self._fragment.prepare()
-
         self._domains = root_fragment.domains
-        for domain, cd in self._domains.items():
-            self._domain_triggers[cd.clk] = domain
-            if cd.rst is not None:
-                self._domain_triggers[cd.rst] = domain
-            self._domain_signals[domain] = SignalSet()
 
         hierarchy = {}
         def add_fragment(fragment, scope=()):
@@ -402,21 +419,48 @@ class Simulator:
                 add_fragment(subfragment, (*scope, name))
         add_fragment(root_fragment)
 
+        def add_signal(signal):
+            if signal not in self._signals:
+                self._signals.add(signal)
+
+                signal_slot = self._state.add(normalize(signal.reset, signal.shape()))
+                self._signal_slots[signal] = signal_slot
+                self._slot_signals.append(signal)
+
+                self._comb_signals.append(False)
+                self._sync_signals.append(False)
+                self._user_signals.append(False)
+                for domain in self._domains:
+                    if domain not in self._domain_signals:
+                        self._domain_signals[domain] = bitarray()
+                    self._domain_signals[domain].append(False)
+
+                self._domain_triggers.append(None)
+                if self._vcd_writer:
+                    self._vcd_signals.append(set())
+                    self._vcd_names.append(None)
+
+            return self._signal_slots[signal]
+
+        def add_domain_signal(signal, domain):
+            signal_slot = add_signal(signal)
+            self._domain_triggers[signal_slot] = domain
+
         for fragment, fragment_scope in hierarchy.items():
             for signal in fragment.iter_signals():
-                if signal not in self._signals:
-                    self._signals.add(signal)
+                add_signal(signal)
 
-                    signal_slot = self._state.add(signal, normalize(signal.reset, signal.shape()))
-                    self._signal_slots[signal] = signal_slot
+            for domain, cd in fragment.domains.items():
+                add_domain_signal(cd.clk, domain)
+                if cd.rst is not None:
+                    add_domain_signal(cd.rst, domain)
 
         for fragment, fragment_scope in hierarchy.items():
             for signal in fragment.iter_signals():
                 if not self._vcd_writer:
                     continue
 
-                if signal not in self._vcd_signals:
-                    self._vcd_signals[signal] = set()
+                signal_slot = self._signal_slots[signal]
 
                 for subfragment, name in fragment.subfragments:
                     if signal in subfragment.ports:
@@ -441,21 +485,27 @@ class Simulator:
                             var_name_suffix = var_name
                         else:
                             var_name_suffix = "{}${}".format(var_name, suffix)
-                        self._vcd_signals[signal].add(self._vcd_writer.register_var(
+                        self._vcd_signals[signal_slot].add(self._vcd_writer.register_var(
                             scope=".".join(fragment_scope), name=var_name_suffix,
                             var_type=var_type, size=var_size, init=var_init))
-                        if signal not in self._vcd_names:
-                            self._vcd_names[signal] = ".".join(fragment_scope + (var_name_suffix,))
+                        if self._vcd_names[signal_slot] is None:
+                            self._vcd_names[signal_slot] = \
+                                ".".join(fragment_scope + (var_name_suffix,))
                         break
                     except KeyError:
                         suffix = (suffix or 0) + 1
 
             for domain, signals in fragment.drivers.items():
+                signals_bits = bitarray(len(self._signals))
+                signals_bits.setall(False)
+                for signal in signals:
+                    signals_bits[self._signal_slots[signal]] = True
+
                 if domain is None:
-                    self._comb_signals.update(signals)
+                    self._comb_signals |= signals_bits
                 else:
-                    self._sync_signals.update(signals)
-                    self._domain_signals[domain].update(signals)
+                    self._sync_signals |= signals_bits
+                    self._domain_signals[domain] |= signals_bits
 
             statements = []
             for signal in fragment.iter_comb():
@@ -468,9 +518,10 @@ class Simulator:
             funclet = compiler(statements)
 
             def add_funclet(signal, funclet):
-                if signal not in self._funclets:
-                    self._funclets[signal] = set()
-                self._funclets[signal].add(funclet)
+                signal_slot = self._signal_slots[signal]
+                while len(self._funclets) <= signal_slot:
+                    self._funclets.append(set())
+                self._funclets[signal_slot].add(funclet)
 
             for signal in compiler.sensitivity:
                 add_funclet(signal, funclet)
@@ -479,7 +530,10 @@ class Simulator:
                 if cd.rst is not None:
                     add_funclet(cd.rst, funclet)
 
-        self._user_signals = self._signals - self._comb_signals - self._sync_signals
+        self._user_signals = bitarray(len(self._signals))
+        self._user_signals.setall(True)
+        self._user_signals &= ~self._comb_signals
+        self._user_signals &= ~self._sync_signals
 
         return self
 
@@ -489,30 +543,31 @@ class Simulator:
         # that need their statements to be reevaluated because the signals changed at the previous
         # delta cycle.
         funclets = set()
-        while self._state.curr_dirty:
-            signal = self._state.curr_dirty.pop()
-            if signal in self._funclets:
-                funclets.update(self._funclets[signal])
+        for signal_slot in self._state.flush_curr_dirty():
+            funclets.update(self._funclets[signal_slot])
 
         # Second, compute the values of all signals at the start of the next delta cycle, by
         # running precompiled statements.
         for funclet in funclets:
             funclet(self._state)
 
-    def _commit_signal(self, signal, domains):
+    def _commit_signal(self, signal_slot, domains):
         """Perform the driver part of IR processes (aka RTLIL sync), for individual signals."""
         # Take the computed value (at the start of this delta cycle) of a signal (that could have
         # come from an IR process that ran earlier, or modified by a simulator process) and update
         # the value for this delta cycle.
-        old, new = self._state.commit(signal, self._signal_slots[signal])
+        old, new = self._state.commit(signal_slot)
+        if old == new:
+            return
 
         # If the signal is a clock that triggers synchronous logic, record that fact.
-        if (old, new) == (0, 1) and signal in self._domain_triggers:
-            domains.add(self._domain_triggers[signal])
+        if new == 1 and self._domain_triggers[signal_slot] is not None:
+            domains.add(self._domain_triggers[signal_slot])
 
-        if self._vcd_writer and old != new:
+        if self._vcd_writer:
             # Finally, dump the new value to the VCD file.
-            for vcd_signal in self._vcd_signals[signal]:
+            for vcd_signal in self._vcd_signals[signal_slot]:
+                signal = self._slot_signals[signal_slot]
                 if signal.decoder:
                     var_value = signal.decoder(new).replace(" ", "_")
                 else:
@@ -524,9 +579,9 @@ class Simulator:
         """Perform the comb part of IR processes (aka RTLIL always)."""
         # Take the computed value (at the start of this delta cycle) of every comb signal and
         # update the value for this delta cycle.
-        for signal in self._state.next_dirty:
-            if signal in self._comb_signals:
-                self._commit_signal(signal, domains)
+        for signal_slot in self._state.iter_next_dirty():
+            if self._comb_signals[signal_slot]:
+                self._commit_signal(signal_slot, domains)
 
     def _commit_sync_signals(self, domains):
         """Perform the sync part of IR processes (aka RTLIL posedge)."""
@@ -543,9 +598,9 @@ class Simulator:
                 # Take the computed value (at the start of this delta cycle) of every sync signal
                 # in this domain and update the value for this delta cycle. This can trigger more
                 # synchronous logic, so record that.
-                for signal in self._state.next_dirty:
-                    if signal in self._domain_signals[domain]:
-                        self._commit_signal(signal, domains)
+                for signal_slot in self._state.iter_next_dirty():
+                    if self._domain_signals[domain][signal_slot]:
+                        self._commit_signal(signal_slot, domains)
 
                 # Wake up any simulator processes that wait for a domain tick.
                 for process, wait_domain in list(self._wait_tick.items()):
@@ -568,7 +623,7 @@ class Simulator:
         try:
             cmd = process.send(None)
             while True:
-                if isinstance(cmd, Delay):
+                if type(cmd) is Delay:
                     if cmd.interval is None:
                         interval = self._epsilon
                     else:
@@ -577,42 +632,53 @@ class Simulator:
                     self._suspended.add(process)
                     break
 
-                elif isinstance(cmd, Tick):
+                elif type(cmd) is Tick:
                     self._wait_tick[process] = cmd.domain
                     self._suspended.add(process)
                     break
 
-                elif isinstance(cmd, Passive):
+                elif type(cmd) is Passive:
                     self._passive.add(process)
 
-                elif isinstance(cmd, Value):
-                    compiler = _RHSValueCompiler(self._signal_slots)
-                    funclet = compiler(cmd)
-                    cmd = process.send(funclet(self._state))
-                    continue
-
-                elif isinstance(cmd, Assign):
+                elif type(cmd) is Assign:
                     lhs_signals = cmd.lhs._lhs_signals()
                     for signal in lhs_signals:
+                        signal_slot = self._signal_slots[signal]
                         if not signal in self._signals:
                             raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                              "which is not a part of simulation"
                                              .format(self._name_process(process), signal))
-                        if signal in self._comb_signals:
+                        if self._comb_signals[signal_slot]:
                             raise ValueError("Process '{}' sent a request to set signal '{!r}', "
                                              "which is a part of combinatorial assignment in "
                                              "simulation"
                                              .format(self._name_process(process), signal))
 
-                    compiler = _StatementCompiler(self._signal_slots)
-                    funclet = compiler(cmd)
-                    funclet(self._state)
+                    if type(cmd.lhs) is Signal and type(cmd.rhs) is Const:
+                        # Fast path.
+                        self._state.set(self._signal_slots[cmd.lhs],
+                                        normalize(cmd.rhs.value, cmd.lhs.shape()))
+                    else:
+                        compiler = _StatementCompiler(self._signal_slots)
+                        funclet = compiler(cmd)
+                        funclet(self._state)
 
                     domains = set()
                     for signal in lhs_signals:
-                        self._commit_signal(signal, domains)
+                        self._commit_signal(self._signal_slots[signal], domains)
                     self._commit_sync_signals(domains)
 
+                elif type(cmd) is Signal:
+                    # Fast path.
+                    cmd = process.send(self._state.curr[self._signal_slots[cmd]])
+                    continue
+
+                elif isinstance(cmd, Value):
+                    compiler = _RHSValueCompiler(self._signal_slots)
+                    funclet = compiler(cmd)
+                    cmd = process.send(funclet(self._state))
+                    continue
+
                 else:
                     raise TypeError("Received unsupported command '{!r}' from process '{}'"
                                     .format(cmd, self._name_process(process)))
@@ -628,7 +694,7 @@ class Simulator:
 
     def step(self, run_passive=False):
         # Are there any delta cycles we should run?
-        if self._state.curr_dirty:
+        if self._state.curr_dirty.any():
             # We might run some delta cycles, and we have simulator processes waiting on
             # a deadline. Take care to not exceed the closest deadline.
             if self._wait_deadline and \
@@ -638,7 +704,7 @@ class Simulator:
                 raise DeadlineError("Delta cycles exceeded process deadline; combinatorial loop?")
 
             domains = set()
-            while self._state.curr_dirty:
+            while self._state.curr_dirty.any():
                 self._update_dirty_signals()
                 self._commit_comb_signals(domains)
             self._commit_sync_signals(domains)
@@ -694,12 +760,13 @@ class Simulator:
             gtkw_save.zoom_markers(math.log(self._epsilon / self._fastest_clock) - 14)
 
             def add_trace(signal, **kwargs):
-                if signal in self._vcd_names:
+                signal_slot = self._signal_slots[signal]
+                if self._vcd_names[signal_slot] is not None:
                     if len(signal) > 1:
                         suffix = "[{}:0]".format(len(signal) - 1)
                     else:
                         suffix = ""
-                    gtkw_save.trace(self._vcd_names[signal] + suffix, **kwargs)
+                    gtkw_save.trace(self._vcd_names[signal_slot] + suffix, **kwargs)
 
             for domain, cd in self._domains.items():
                 with gtkw_save.group("d.{}".format(domain)):
index d7aff62d8116b463945c435a3587b85b54515559..b1b27a9687dc2d0692010746eb6f9e68deaf32ab 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ setup(
     description="Python toolbox for building complex digital hardware",
     #long_description="""TODO""",
     license="BSD",
-    install_requires=["pyvcd"],
+    install_requires=["pyvcd", "bitarray"],
     packages=find_packages(),
     project_urls={
         #"Documentation": "https://glasgow.readthedocs.io/",