back.pysim: new simulator backend (WIP).
authorwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 18:00:05 +0000 (18:00 +0000)
committerwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 18:02:46 +0000 (18:02 +0000)
.gitignore
examples/clkdiv.py
nmigen/back/pysim.py [new file with mode: 0644]
nmigen/fhdl/ast.py
nmigen/fhdl/ir.py
nmigen/test/test_fhdl_dsl.py
nmigen/test/test_fhdl_value.py
nmigen/test/test_fhdl_xfrm.py
setup.py

index 02ef4820ee201fdf19c9eb94d5ca231cb28df5bf..e63a727133e7503708a44a78f6cbecd3224e8875 100644 (file)
@@ -2,5 +2,7 @@
 *.egg-info
 *.il
 *.v
+*.vcd
+*.gtkw
 /.coverage
 /htmlcov
index 6900e1d1c5a2cd180ac6de70abf8bdc7c2aa8607..1eae4215a2f81ee6319827b037e0eb72b79d22dc 100644 (file)
@@ -1,5 +1,5 @@
 from nmigen.fhdl import *
-from nmigen.back import rtlil, verilog
+from nmigen.back import rtlil, verilog, pysim
 
 
 class ClockDivisor:
@@ -16,5 +16,10 @@ class ClockDivisor:
 
 ctr  = ClockDivisor(factor=16)
 frag = ctr.get_fragment(platform=None)
+
 # print(rtlil.convert(frag, ports=[ctr.o]))
 print(verilog.convert(frag, ports=[ctr.o]))
+
+sim = pysim.Simulator(frag, vcd_file=open("clkdiv.vcd", "w"))
+sim.add_clock("sync", 1e-6)
+with sim: sim.run_until(100e-6, run_passive=True)
diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py
new file mode 100644 (file)
index 0000000..2afffab
--- /dev/null
@@ -0,0 +1,372 @@
+from vcd import VCDWriter
+
+from ..tools import flatten
+from ..fhdl.ast import *
+from ..fhdl.xfrm import ValueTransformer, StatementTransformer
+
+
+__all__ = ["Simulator", "Delay", "Passive"]
+
+
+class _State:
+    __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
+
+    def __init__(self):
+        self.curr = ValueDict()
+        self.next = ValueDict()
+        self.curr_dirty = ValueSet()
+        self.next_dirty = ValueSet()
+
+    def get(self, signal):
+        return self.curr[signal]
+
+    def set_curr(self, signal, value):
+        assert isinstance(value, Const)
+        if self.curr[signal].value != value.value:
+            self.curr_dirty.add(signal)
+            self.curr[signal] = value
+
+    def set_next(self, signal, value):
+        assert isinstance(value, Const)
+        if self.next[signal].value != value.value:
+            self.next_dirty.add(signal)
+            self.next[signal] = value
+
+    def commit(self, signal):
+        old_value = self.curr[signal]
+        if self.curr[signal].value != self.next[signal].value:
+            self.next_dirty.remove(signal)
+            self.curr_dirty.add(signal)
+            self.curr[signal] = self.next[signal]
+        new_value = self.curr[signal]
+        return old_value, new_value
+
+    def iter_dirty(self):
+        dirty, self.dirty = self.dirty, ValueSet()
+        for signal in dirty:
+            yield signal, self.curr[signal], self.next[signal]
+
+
+class _RHSValueCompiler(ValueTransformer):
+    def __init__(self, sensitivity):
+        self.sensitivity = sensitivity
+
+    def on_Const(self, value):
+        return lambda state: value
+
+    def on_Signal(self, value):
+        self.sensitivity.add(value)
+        return lambda state: state.get(value)
+
+    def on_ClockSignal(self, value):
+        raise NotImplementedError
+
+    def on_ResetSignal(self, value):
+        raise NotImplementedError
+
+    def on_Operator(self, value):
+        shape = value.shape()
+        if len(value.operands) == 1:
+            arg, = map(self, value.operands)
+            if value.op == "~":
+                return lambda state: Const(~arg(state).value, shape)
+            elif value.op == "-":
+                return lambda state: Const(-arg(state).value, shape)
+        elif len(value.operands) == 2:
+            lhs, rhs = map(self, value.operands)
+            if value.op == "+":
+                return lambda state: Const(lhs(state).value +  rhs(state).value, shape)
+            if value.op == "-":
+                return lambda state: Const(lhs(state).value -  rhs(state).value, shape)
+            if value.op == "&":
+                return lambda state: Const(lhs(state).value &  rhs(state).value, shape)
+            if value.op == "|":
+                return lambda state: Const(lhs(state).value |  rhs(state).value, shape)
+            if value.op == "^":
+                return lambda state: Const(lhs(state).value ^  rhs(state).value, shape)
+            elif value.op == "==":
+                lhs, rhs = map(self, value.operands)
+                return lambda state: Const(lhs(state).value == rhs(state).value, shape)
+        elif len(value.operands) == 3:
+            if value.op == "m":
+                sel, val1, val0 = map(self, value.operands)
+                return lambda state: val1(state) if sel(state).value else val0(state)
+        raise NotImplementedError("Operator '{}' not implemented".format(value.op))
+
+    def on_Slice(self, value):
+        shape = value.shape()
+        arg   = self(value.value)
+        shift = value.start
+        mask  = (1 << (value.end - value.start)) - 1
+        return lambda state: Const((arg(state).value >> shift) & mask, shape)
+
+    def on_Part(self, value):
+        raise NotImplementedError
+
+    def on_Cat(self, value):
+        shape  = value.shape()
+        parts  = []
+        offset = 0
+        for opnd in value.operands:
+            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
+            offset += len(opnd)
+        def eval(state):
+            result = 0
+            for offset, mask, opnd in parts:
+                result |= (opnd(state).value & mask) << offset
+            return Const(result, shape)
+        return eval
+
+    def on_Repl(self, value):
+        shape  = value.shape()
+        offset = len(value.value)
+        mask   = (1 << len(value.value)) - 1
+        count  = value.count
+        opnd   = self(value.value)
+        def eval(state):
+            result = 0
+            for _ in range(count):
+                result <<= offset
+                result  |= opnd(state).value
+            return Const(result, shape)
+        return eval
+
+
+class _StatementCompiler(StatementTransformer):
+    def __init__(self):
+        self.sensitivity  = ValueSet()
+        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)
+
+    def lhs_compiler(self, value):
+        # TODO
+        return lambda state, arg: state.set_next(value, arg)
+
+    def on_Assign(self, stmt):
+        assert isinstance(stmt.lhs, Signal)
+        shape = stmt.lhs.shape()
+        lhs   = self.lhs_compiler(stmt.lhs)
+        rhs   = self.rhs_compiler(stmt.rhs)
+        def run(state):
+            lhs(state, Const(rhs(state).value, shape))
+        return run
+
+    def on_Switch(self, stmt):
+        test  = self.rhs_compiler(stmt.test)
+        cases = []
+        for value, stmts in stmt.cases.items():
+            if "-" in value:
+                mask  = "".join("0" if b == "-" else "1" for b in value)
+                value = "".join("0" if b == "-" else  b  for b in value)
+            else:
+                mask  = "1" * len(value)
+            mask  = int(mask,  2)
+            value = int(value, 2)
+            cases.append((lambda test: test & mask == value,
+                          self.on_statements(stmts)))
+        def run(state):
+            test_value = test(state).value
+            for check, body in cases:
+                if check(test_value):
+                    body(state)
+                return
+        return run
+
+    def on_statements(self, stmts):
+        stmts = [self.on_statement(stmt) for stmt in stmts]
+        def run(state):
+            for stmt in stmts:
+                stmt(state)
+        return run
+
+
+class Simulator:
+    def __init__(self, fragment=None, vcd_file=None):
+        self._fragments       = {}            # fragment -> hierarchy
+        self._domains         = {}            # str -> ClockDomain
+        self._domain_triggers = ValueDict()   # Signal -> str
+        self._domain_signals  = {}            # str -> {Signal}
+        self._signals         = ValueSet()    # {Signal}
+        self._comb_signals    = ValueSet()    # {Signal}
+        self._sync_signals    = ValueSet()    # {Signal}
+        self._user_signals    = ValueSet()    # {Signal}
+
+        self._started         = False
+        self._timestamp       = 0.
+        self._state           = _State()
+
+        self._processes       = set()         # {process}
+        self._passive         = set()         # {process}
+        self._suspended       = {}            # process -> until
+
+        self._handlers        = ValueDict()   # Signal -> lambda
+
+        self._vcd_file        = vcd_file
+        self._vcd_writer      = None
+        self._vcd_signals     = ValueDict()   # signal -> set(vcd_signal)
+
+        if fragment is not None:
+            fragment = fragment.prepare()
+            self._add_fragment(fragment)
+            self._domains = fragment.domains
+            for domain, cd in self._domains.items():
+                self._domain_triggers[cd.clk] = domain
+                if cd.rst is not None:
+                    self._domain_triggers[cd.rst] = domain
+                self._domain_signals[domain] = ValueSet()
+
+    def _add_fragment(self, fragment, hierarchy=("top",)):
+        self._fragments[fragment] = hierarchy
+        for subfragment, name in fragment.subfragments:
+            self._add_fragment(subfragment, (*hierarchy, name))
+
+    def add_process(self, fn):
+        self._processes.add(fn)
+
+    def add_clock(self, domain, period):
+        clk = self._domains[domain].clk
+        half_period = period / 2
+        def clk_process():
+            yield Passive()
+            while True:
+                yield clk.eq(1)
+                yield Delay(half_period)
+                yield clk.eq(0)
+                yield Delay(half_period)
+        self.add_process(clk_process())
+
+    def _signal_name_in_fragment(self, fragment, signal):
+        for subfragment, name in fragment.subfragments:
+            if signal in subfragment.ports:
+                return "{}_{}".format(name, signal.name)
+        return signal.name
+
+    def __enter__(self):
+        if self._vcd_file:
+            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
+                                         comment="Generated by nMigen")
+
+        for fragment in self._fragments:
+            for signal in fragment.iter_signals():
+                self._signals.add(signal)
+
+                self._state.curr[signal] = self._state.next[signal] = \
+                    Const(signal.reset, signal.shape())
+                self._state.curr_dirty.add(signal)
+
+                if signal not in self._vcd_signals:
+                    self._vcd_signals[signal] = set()
+                name   = self._signal_name_in_fragment(fragment, signal)
+                suffix = None
+                while True:
+                    try:
+                        if suffix is None:
+                            name_suffix = name
+                        else:
+                            name_suffix = "{}${}".format(name, suffix)
+                        self._vcd_signals[signal].add(self._vcd_writer.register_var(
+                            scope=".".join(self._fragments[fragment]), name=name_suffix,
+                            var_type="wire", size=signal.nbits, init=signal.reset))
+                        break
+                    except KeyError:
+                        suffix = (suffix or 0) + 1
+
+            for domain, signals in fragment.drivers.items():
+                if domain is None:
+                    self._comb_signals.update(signals)
+                else:
+                    self._sync_signals.update(signals)
+                    self._domain_signals[domain].update(signals)
+
+            compiler = _StatementCompiler()
+            handler  = compiler(fragment.statements)
+            for signal in compiler.sensitivity:
+                self._handlers[signal] = handler
+            for domain, cd in fragment.domains.items():
+                self._handlers[cd.clk] = handler
+                if cd.rst is not None:
+                    self._handlers[cd.rst] = handler
+
+        self._user_signals = self._signals - self._comb_signals - self._sync_signals
+
+    def _commit_signal(self, signal):
+        old, new = self._state.commit(signal)
+        if old.value == 0 and new.value == 1 and signal in self._domain_triggers:
+            domain = self._domain_triggers[signal]
+            for sync_signal in self._state.next_dirty:
+                if sync_signal in self._domain_signals[domain]:
+                    self._commit_signal(sync_signal)
+
+        if self._vcd_writer:
+            for vcd_signal in self._vcd_signals[signal]:
+                self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new.value)
+
+    def _handle_event(self):
+        while self._state.curr_dirty:
+            signal = self._state.curr_dirty.pop()
+            if signal in self._handlers:
+                self._handlers[signal](self._state)
+
+        for signal in self._state.next_dirty:
+            if signal in self._comb_signals or signal in self._user_signals:
+                self._commit_signal(signal)
+
+    def _force_signal(self, signal, value):
+        assert signal in self._comb_signals or signal in self._user_signals
+        self._state.set_next(signal, value)
+        self._commit_signal(signal)
+
+    def _run_process(self, proc):
+        try:
+            stmt = proc.send(None)
+        except StopIteration:
+            self._processes.remove(proc)
+            self._passive.remove(proc)
+            self._suspended.remove(proc)
+            return
+
+        if isinstance(stmt, Delay):
+            self._suspended[proc] = self._timestamp + stmt.interval
+        elif isinstance(stmt, Passive):
+            self._passive.add(proc)
+        elif isinstance(stmt, Assign):
+            assert isinstance(stmt.lhs, Signal)
+            assert isinstance(stmt.rhs, Const)
+            self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape()))
+        else:
+            raise TypeError("Received unsupported statement '{!r}' from process {}"
+                            .format(stmt, proc))
+
+    def step(self, run_passive=False):
+        # Are there any delta cycles we should run?
+        while self._state.curr_dirty:
+            self._timestamp += 1e-10
+            self._handle_event()
+
+        # Are there any processes that haven't had a chance to run yet?
+        if len(self._processes) > len(self._suspended):
+            # Schedule an arbitrary one.
+            proc = (self._processes - set(self._suspended)).pop()
+            self._run_process(proc)
+            return True
+
+        # All processes are suspended. Are any of them active?
+        if len(self._processes) > len(self._passive) or run_passive:
+            # Schedule the one with the lowest deadline.
+            proc, deadline = min(self._suspended.items(), key=lambda x: x[1])
+            del self._suspended[proc]
+            self._timestamp = deadline
+            self._run_process(proc)
+            return True
+
+        # No processes, or all processes are passive. Nothing to do!
+        return False
+
+    def run_until(self, deadline, run_passive=False):
+        while self._timestamp < deadline:
+            if not self.step(run_passive):
+                return False
+        return True
+
+    def __exit__(self, *args):
+        if self._vcd_writer:
+            self._vcd_writer.close(self._timestamp * 1e10)
index 15f67e397c10031996df8f046b7eb5aea9bcdb27..1f00bb7e246161604496f3cccba622448080dc87 100644 (file)
@@ -10,7 +10,7 @@ from ..tools import *
 __all__ = [
     "Value", "Const", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
     "Signal", "ClockSignal", "ResetSignal",
-    "Statement", "Assign", "Switch",
+    "Statement", "Assign", "Switch", "Delay", "Passive",
     "ValueKey", "ValueDict", "ValueSet",
 ]
 
@@ -216,17 +216,23 @@ class Const(Value):
     nbits : int
     signed : bool
     """
+    src_loc = None
+
     def __init__(self, value, shape=None):
-        super().__init__()
         self.value = int(value)
         if shape is None:
-            shape = self.value.bit_length(), self.value < 0
+            shape = bits_for(self.value), self.value < 0
         if isinstance(shape, int):
             shape = shape, self.value < 0
         self.nbits, self.signed = shape
         if not isinstance(self.nbits, int) or self.nbits < 0:
             raise TypeError("Width must be a positive integer")
 
+        mask = (1 << self.nbits) - 1
+        self.value &= mask
+        if self.signed and self.value >> (self.nbits - 1):
+            self.value |= ~mask
+
     def shape(self):
         return self.nbits, self.signed
 
@@ -347,6 +353,8 @@ class Slice(Value):
             raise IndexError("Cannot end slice {} bits into {}-bit value".format(end, n))
         if end < 0:
             end += n
+        if start > end:
+            raise IndexError("Slice start {} must be less than slice end {}".format(start, end))
 
         super().__init__()
         self.value = Value.wrap(value)
@@ -680,6 +688,25 @@ class Switch(Statement):
         return "(switch {!r} {})".format(self.test, " ".join(cases))
 
 
+class Delay(Statement):
+    def __init__(self, interval):
+        self.interval = float(interval)
+
+    def _rhs_signals(self):
+        return ValueSet()
+
+    def __repr__(self):
+        return "(delay {:.3}us)".format(self.interval * 10e6)
+
+
+class Passive(Statement):
+    def _rhs_signals(self):
+        return ValueSet()
+
+    def __repr__(self):
+        return "(passive)"
+
+
 class ValueKey:
     def __init__(self, value):
         self.value = Value.wrap(value)
index 8f4efee00c07dbb3d706c69903d84ff7c8b4bf69..8c8b7a91e89eab688d6174c8804ab755d0b7c11a 100644 (file)
@@ -52,6 +52,11 @@ class Fragment:
         signals = ValueSet()
         signals |= self.ports.keys()
         for domain, domain_signals in self.drivers.items():
+            if domain is not None:
+                cd = self.domains[domain]
+                signals.add(cd.clk)
+                if cd.rst is not None:
+                    signals.add(cd.rst)
             signals |= domain_signals
         return signals
 
index c0acaed2eb40e907ed9b39baa10f9ba328a0db60..e28ced6fecd0bc07402bd58a13cfa8a5a5e1284f 100644 (file)
@@ -116,7 +116,7 @@ class DSLTestCase(FHDLTestCase):
         (
             (switch (cat (sig s1) (sig s2))
                 (case -1 (eq (sig c1) (const 1'd1)))
-                (case 1- (eq (sig c2) (const 0'd0)))
+                (case 1- (eq (sig c2) (const 1'd0)))
             )
         )
         """)
@@ -134,7 +134,7 @@ class DSLTestCase(FHDLTestCase):
         (
             (switch (cat (sig s1) (sig s2))
                 (case -1 (eq (sig c1) (const 1'd1)))
-                (case 1- (eq (sig c2) (const 0'd0)))
+                (case 1- (eq (sig c2) (const 1'd0)))
                 (case -- (eq (sig c3) (const 1'd1)))
             )
         )
index 892b5e073a520b576328bec8970952d8bdb9b956..8e7dfd1a736b0dd3dcb127cf45d7a5be8f333ab2 100644 (file)
@@ -59,10 +59,10 @@ class ValueTestCase(FHDLTestCase):
 
 class ConstTestCase(FHDLTestCase):
     def test_shape(self):
-        self.assertEqual(Const(0).shape(),   (0, False))
+        self.assertEqual(Const(0).shape(),   (1, False))
         self.assertEqual(Const(1).shape(),   (1, False))
         self.assertEqual(Const(10).shape(),  (4, False))
-        self.assertEqual(Const(-10).shape(), (4, True))
+        self.assertEqual(Const(-10).shape(), (5, True))
 
         self.assertEqual(Const(1, 4).shape(),         (4, False))
         self.assertEqual(Const(1, (4, True)).shape(), (4, True))
@@ -70,12 +70,15 @@ class ConstTestCase(FHDLTestCase):
         with self.assertRaises(TypeError):
             Const(1, -1)
 
+    def test_normalization(self):
+        self.assertEqual(Const(0b10110, (5, True)).value, -10)
+
     def test_value(self):
         self.assertEqual(Const(10).value, 10)
 
     def test_repr(self):
         self.assertEqual(repr(Const(10)),  "(const 4'd10)")
-        self.assertEqual(repr(Const(-10)), "(const 4'sd-10)")
+        self.assertEqual(repr(Const(-10)), "(const 5'sd-10)")
 
     def test_hash(self):
         with self.assertRaises(TypeError):
@@ -205,7 +208,7 @@ class OperatorTestCase(FHDLTestCase):
     def test_mux(self):
         s  = Const(0)
         v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False)))
-        self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))")
+        self.assertEqual(repr(v1), "(m (const 1'd0) (const 4'd0) (const 6'd0))")
         self.assertEqual(v1.shape(), (6, False))
         v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True)))
         self.assertEqual(v2.shape(), (6, True))
@@ -216,7 +219,7 @@ class OperatorTestCase(FHDLTestCase):
 
     def test_bool(self):
         v = Const(0).bool()
-        self.assertEqual(repr(v), "(b (const 0'd0))")
+        self.assertEqual(repr(v), "(b (const 1'd0))")
         self.assertEqual(v.shape(), (1, False))
 
     def test_hash(self):
@@ -243,7 +246,7 @@ class CatTestCase(FHDLTestCase):
         c2 = Cat(Const(10), Const(1))
         self.assertEqual(c2.shape(), (5, False))
         c3 = Cat(Const(10), Const(1), Const(0))
-        self.assertEqual(c3.shape(), (5, False))
+        self.assertEqual(c3.shape(), (6, False))
 
     def test_repr(self):
         c1 = Cat(Const(10), Const(1))
index 78346d8aacbd108e94d70012c4e3bdaf8b4d0429..faacf3b8bc809e86a0f51f4c7fa238ee1ff0de87 100644 (file)
@@ -32,7 +32,7 @@ class DomainRenamerTestCase(FHDLTestCase):
         (
             (eq (sig s1) (clk pix))
             (eq (rst pix) (sig s2))
-            (eq (sig s3) (const 0'd0))
+            (eq (sig s3) (const 1'd0))
             (eq (sig s4) (clk other))
             (eq (sig s5) (rst other))
         )
@@ -127,7 +127,7 @@ class ResetInserterTestCase(FHDLTestCase):
         self.assertRepr(f.statements, """
         (
             (eq (sig s1) (const 1'd1))
-            (eq (sig s2) (const 0'd0))
+            (eq (sig s2) (const 1'd0))
             (switch (sig c1)
                 (case 1 (eq (sig s2) (const 1'd1)))
             )
@@ -144,7 +144,7 @@ class ResetInserterTestCase(FHDLTestCase):
         f = ResetInserter(self.c1)(f)
         self.assertRepr(f.statements, """
         (
-            (eq (sig s2) (const 0'd0))
+            (eq (sig s2) (const 1'd0))
             (switch (sig c1)
                 (case 1 (eq (sig s2) (const 1'd1)))
             )
@@ -161,7 +161,7 @@ class ResetInserterTestCase(FHDLTestCase):
         f = ResetInserter(self.c1)(f)
         self.assertRepr(f.statements, """
         (
-            (eq (sig s3) (const 0'd0))
+            (eq (sig s3) (const 1'd0))
             (switch (sig c1)
                 (case 1 )
             )
@@ -206,7 +206,7 @@ class CEInserterTestCase(FHDLTestCase):
         self.assertRepr(f.statements, """
         (
             (eq (sig s1) (const 1'd1))
-            (eq (sig s2) (const 0'd0))
+            (eq (sig s2) (const 1'd0))
             (switch (sig c1)
                 (case 0 (eq (sig s2) (sig s2)))
             )
index c4067791e8cf0d06aaf81703d726ca41e1a66031..d4b60e0ca38bdd0ca4d6e09404099c9142ab9326 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -13,5 +13,11 @@ setup(
     description="Python toolbox for building complex digital hardware",
     #long_description="""TODO""",
     license="BSD",
+    install_requires=["pyvcd"],
     packages=find_packages(),
+    project_urls={
+        #"Documentation": "https://glasgow.readthedocs.io/",
+        "Source Code": "https://github.com/m-labs/nmigen",
+        "Bug Tracker": "https://github.com/m-labs/nmigen/issues",
+    }
 )