back.pysim: use bare ints for signal values (-5% runtime).
authorwhitequark <whitequark@whitequark.org>
Fri, 14 Dec 2018 03:05:57 +0000 (03:05 +0000)
committerwhitequark <whitequark@whitequark.org>
Fri, 14 Dec 2018 03:05:57 +0000 (03:05 +0000)
nmigen/back/pysim.py
nmigen/fhdl/ast.py

index 7dfda138ee7cc925f09c702438e159294e9c94a1..77410b81faa98473be391e78c524040eace0508d 100644 (file)
@@ -21,20 +21,20 @@ class _State:
         return self.curr[signal]
 
     def set_curr(self, signal, value):
-        assert isinstance(value, Const)
-        if self.curr[signal].value != value.value:
+        assert isinstance(value, int)
+        if self.curr[signal] != value:
             self.curr_dirty.add(signal)
             self.curr[signal] = value
 
     def set_next(self, signal, value):
-        assert isinstance(value, Const)
-        if self.next[signal].value != value.value:
+        assert isinstance(value, int)
+        if self.next[signal] != value:
             self.next_dirty.add(signal)
             self.next[signal] = value
 
     def commit(self, signal):
         old_value = self.curr[signal]
-        if self.curr[signal].value != self.next[signal].value:
+        if self.curr[signal] != self.next[signal]:
             self.next_dirty.remove(signal)
             self.curr_dirty.add(signal)
             self.curr[signal] = self.next[signal]
@@ -47,12 +47,15 @@ class _State:
             yield signal, self.curr[signal], self.next[signal]
 
 
+normalize = Const.normalize
+
+
 class _RHSValueCompiler(ValueTransformer):
     def __init__(self, sensitivity):
         self.sensitivity = sensitivity
 
     def on_Const(self, value):
-        return lambda state: value
+        return lambda state: value.value
 
     def on_Signal(self, value):
         self.sensitivity.add(value)
@@ -69,28 +72,27 @@ class _RHSValueCompiler(ValueTransformer):
         if len(value.operands) == 1:
             arg, = map(self, value.operands)
             if value.op == "~":
-                return lambda state: Const(~arg(state).value, shape)
-            elif value.op == "-":
-                return lambda state: Const(-arg(state).value, shape)
+                return lambda state: normalize(~arg(state), shape)
+            if value.op == "-":
+                return lambda state: normalize(-arg(state), shape)
         elif len(value.operands) == 2:
             lhs, rhs = map(self, value.operands)
             if value.op == "+":
-                return lambda state: Const(lhs(state).value +  rhs(state).value, shape)
+                return lambda state: normalize(lhs(state) + rhs(state), shape)
             if value.op == "-":
-                return lambda state: Const(lhs(state).value -  rhs(state).value, shape)
+                return lambda state: normalize(lhs(state) - rhs(state), shape)
             if value.op == "&":
-                return lambda state: Const(lhs(state).value &  rhs(state).value, shape)
+                return lambda state: normalize(lhs(state) & rhs(state), shape)
             if value.op == "|":
-                return lambda state: Const(lhs(state).value |  rhs(state).value, shape)
+                return lambda state: normalize(lhs(state) | rhs(state), shape)
             if value.op == "^":
-                return lambda state: Const(lhs(state).value ^  rhs(state).value, shape)
-            elif value.op == "==":
-                lhs, rhs = map(self, value.operands)
-                return lambda state: Const(lhs(state).value == rhs(state).value, shape)
+                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
+            if value.op == "==":
+                return lambda state: normalize(lhs(state) == rhs(state), shape)
         elif len(value.operands) == 3:
             if value.op == "m":
                 sel, val1, val0 = map(self, value.operands)
-                return lambda state: val1(state) if sel(state).value else val0(state)
+                return lambda state: val1(state) if sel(state) else val0(state)
         raise NotImplementedError("Operator '{}' not implemented".format(value.op))
 
     def on_Slice(self, value):
@@ -98,7 +100,7 @@ class _RHSValueCompiler(ValueTransformer):
         arg   = self(value.value)
         shift = value.start
         mask  = (1 << (value.end - value.start)) - 1
-        return lambda state: Const((arg(state).value >> shift) & mask, shape)
+        return lambda state: normalize((arg(state) >> shift) & mask, shape)
 
     def on_Part(self, value):
         raise NotImplementedError
@@ -113,8 +115,8 @@ class _RHSValueCompiler(ValueTransformer):
         def eval(state):
             result = 0
             for offset, mask, opnd in parts:
-                result |= (opnd(state).value & mask) << offset
-            return Const(result, shape)
+                result |= (opnd(state) & mask) << offset
+            return normalize(result, shape)
         return eval
 
     def on_Repl(self, value):
@@ -127,8 +129,8 @@ class _RHSValueCompiler(ValueTransformer):
             result = 0
             for _ in range(count):
                 result <<= offset
-                result  |= opnd(state).value
-            return Const(result, shape)
+                result  |= opnd(state)
+            return normalize(result, shape)
         return eval
 
 
@@ -147,7 +149,7 @@ class _StatementCompiler(StatementTransformer):
         lhs   = self.lhs_compiler(stmt.lhs)
         rhs   = self.rhs_compiler(stmt.rhs)
         def run(state):
-            lhs(state, Const(rhs(state).value, shape))
+            lhs(state, normalize(rhs(state), shape))
         return run
 
     def on_Switch(self, stmt):
@@ -164,7 +166,7 @@ class _StatementCompiler(StatementTransformer):
             cases.append((lambda test: test & mask == value,
                           self.on_statements(stmts)))
         def run(state):
-            test_value = test(state).value
+            test_value = test(state)
             for check, body in cases:
                 if check(test_value):
                     body(state)
@@ -255,7 +257,7 @@ class Simulator:
                 self._signals.add(signal)
 
                 self._state.curr[signal] = self._state.next[signal] = \
-                    Const(signal.reset, signal.shape())
+                    normalize(signal.reset, signal.shape())
                 self._state.curr_dirty.add(signal)
 
                 if signal not in self._vcd_signals:
@@ -295,7 +297,7 @@ class Simulator:
 
     def _commit_signal(self, signal):
         old, new = self._state.commit(signal)
-        if old.value == 0 and new.value == 1 and signal in self._domain_triggers:
+        if (old, new) == (0, 1) and signal in self._domain_triggers:
             domain = self._domain_triggers[signal]
             for sync_signal in self._state.next_dirty:
                 if sync_signal in self._domain_signals[domain]:
@@ -303,7 +305,7 @@ class Simulator:
 
         if self._vcd_writer:
             for vcd_signal in self._vcd_signals[signal]:
-                self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new.value)
+                self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new)
 
     def _handle_event(self):
         handlers = set()
@@ -320,7 +322,7 @@ class Simulator:
                 self._commit_signal(signal)
 
     def _force_signal(self, signal, value):
-        assert signal in self._comb_signals or signal in self._user_signals
+        assert signal in self._user_signals
         self._state.set_next(signal, value)
         self._commit_signal(signal)
 
@@ -339,7 +341,7 @@ class Simulator:
         elif isinstance(stmt, Assign):
             assert isinstance(stmt.lhs, Signal)
             assert isinstance(stmt.rhs, Const)
-            self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape()))
+            self._force_signal(stmt.lhs, normalize(stmt.rhs.value, stmt.lhs.shape()))
         else:
             raise TypeError("Received unsupported statement '{!r}' from process {}"
                             .format(stmt, proc))
index 1f00bb7e246161604496f3cccba622448080dc87..a7b438736818ad935ec4b4f1da5f7b6e2c108cc1 100644 (file)
@@ -218,6 +218,15 @@ class Const(Value):
     """
     src_loc = None
 
+    @staticmethod
+    def normalize(value, shape):
+        nbits, signed = shape
+        mask = (1 << nbits) - 1
+        value &= mask
+        if signed and value >> (nbits - 1):
+            value |= ~mask
+        return value
+
     def __init__(self, value, shape=None):
         self.value = int(value)
         if shape is None:
@@ -227,11 +236,7 @@ class Const(Value):
         self.nbits, self.signed = shape
         if not isinstance(self.nbits, int) or self.nbits < 0:
             raise TypeError("Width must be a positive integer")
-
-        mask = (1 << self.nbits) - 1
-        self.value &= mask
-        if self.signed and self.value >> (self.nbits - 1):
-            self.value |= ~mask
+        self.value = self.normalize(self.value, shape)
 
     def shape(self):
         return self.nbits, self.signed