From 3eed2ef6c8b9da8e48ed1da151091c0e114ee5cb Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 13 Dec 2018 18:00:05 +0000 Subject: [PATCH] back.pysim: new simulator backend (WIP). --- .gitignore | 2 + examples/clkdiv.py | 7 +- nmigen/back/pysim.py | 372 +++++++++++++++++++++++++++++++++ nmigen/fhdl/ast.py | 33 ++- nmigen/fhdl/ir.py | 5 + nmigen/test/test_fhdl_dsl.py | 4 +- nmigen/test/test_fhdl_value.py | 15 +- nmigen/test/test_fhdl_xfrm.py | 10 +- setup.py | 6 + 9 files changed, 437 insertions(+), 17 deletions(-) create mode 100644 nmigen/back/pysim.py diff --git a/.gitignore b/.gitignore index 02ef482..e63a727 100644 --- a/.gitignore +++ b/.gitignore @@ -2,5 +2,7 @@ *.egg-info *.il *.v +*.vcd +*.gtkw /.coverage /htmlcov diff --git a/examples/clkdiv.py b/examples/clkdiv.py index 6900e1d..1eae421 100644 --- a/examples/clkdiv.py +++ b/examples/clkdiv.py @@ -1,5 +1,5 @@ from nmigen.fhdl import * -from nmigen.back import rtlil, verilog +from nmigen.back import rtlil, verilog, pysim class ClockDivisor: @@ -16,5 +16,10 @@ class ClockDivisor: ctr = ClockDivisor(factor=16) frag = ctr.get_fragment(platform=None) + # print(rtlil.convert(frag, ports=[ctr.o])) print(verilog.convert(frag, ports=[ctr.o])) + +sim = pysim.Simulator(frag, vcd_file=open("clkdiv.vcd", "w")) +sim.add_clock("sync", 1e-6) +with sim: sim.run_until(100e-6, run_passive=True) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py new file mode 100644 index 0000000..2afffab --- /dev/null +++ b/nmigen/back/pysim.py @@ -0,0 +1,372 @@ +from vcd import VCDWriter + +from ..tools import flatten +from ..fhdl.ast import * +from ..fhdl.xfrm import ValueTransformer, StatementTransformer + + +__all__ = ["Simulator", "Delay", "Passive"] + + +class _State: + __slots__ = ("curr", "curr_dirty", "next", "next_dirty") + + def __init__(self): + self.curr = ValueDict() + self.next = ValueDict() + self.curr_dirty = ValueSet() + self.next_dirty = ValueSet() + + def get(self, signal): + return self.curr[signal] + + def set_curr(self, signal, value): + assert isinstance(value, Const) + if self.curr[signal].value != value.value: + self.curr_dirty.add(signal) + self.curr[signal] = value + + def set_next(self, signal, value): + assert isinstance(value, Const) + if self.next[signal].value != value.value: + self.next_dirty.add(signal) + self.next[signal] = value + + def commit(self, signal): + old_value = self.curr[signal] + if self.curr[signal].value != self.next[signal].value: + self.next_dirty.remove(signal) + self.curr_dirty.add(signal) + self.curr[signal] = self.next[signal] + new_value = self.curr[signal] + return old_value, new_value + + def iter_dirty(self): + dirty, self.dirty = self.dirty, ValueSet() + for signal in dirty: + yield signal, self.curr[signal], self.next[signal] + + +class _RHSValueCompiler(ValueTransformer): + def __init__(self, sensitivity): + self.sensitivity = sensitivity + + def on_Const(self, value): + return lambda state: value + + def on_Signal(self, value): + self.sensitivity.add(value) + return lambda state: state.get(value) + + def on_ClockSignal(self, value): + raise NotImplementedError + + def on_ResetSignal(self, value): + raise NotImplementedError + + def on_Operator(self, value): + shape = value.shape() + if len(value.operands) == 1: + arg, = map(self, value.operands) + if value.op == "~": + return lambda state: Const(~arg(state).value, shape) + elif value.op == "-": + return lambda state: Const(-arg(state).value, shape) + elif len(value.operands) == 2: + lhs, rhs = map(self, value.operands) + if value.op == "+": + return lambda state: Const(lhs(state).value + rhs(state).value, shape) + if value.op == "-": + return lambda state: Const(lhs(state).value - rhs(state).value, shape) + if value.op == "&": + return lambda state: Const(lhs(state).value & rhs(state).value, shape) + if value.op == "|": + return lambda state: Const(lhs(state).value | rhs(state).value, shape) + if value.op == "^": + return lambda state: Const(lhs(state).value ^ rhs(state).value, shape) + elif value.op == "==": + lhs, rhs = map(self, value.operands) + return lambda state: Const(lhs(state).value == rhs(state).value, shape) + elif len(value.operands) == 3: + if value.op == "m": + sel, val1, val0 = map(self, value.operands) + return lambda state: val1(state) if sel(state).value else val0(state) + raise NotImplementedError("Operator '{}' not implemented".format(value.op)) + + def on_Slice(self, value): + shape = value.shape() + arg = self(value.value) + shift = value.start + mask = (1 << (value.end - value.start)) - 1 + return lambda state: Const((arg(state).value >> shift) & mask, shape) + + def on_Part(self, value): + raise NotImplementedError + + def on_Cat(self, value): + shape = value.shape() + parts = [] + offset = 0 + for opnd in value.operands: + parts.append((offset, (1 << len(opnd)) - 1, self(opnd))) + offset += len(opnd) + def eval(state): + result = 0 + for offset, mask, opnd in parts: + result |= (opnd(state).value & mask) << offset + return Const(result, shape) + return eval + + def on_Repl(self, value): + shape = value.shape() + offset = len(value.value) + mask = (1 << len(value.value)) - 1 + count = value.count + opnd = self(value.value) + def eval(state): + result = 0 + for _ in range(count): + result <<= offset + result |= opnd(state).value + return Const(result, shape) + return eval + + +class _StatementCompiler(StatementTransformer): + def __init__(self): + self.sensitivity = ValueSet() + self.rhs_compiler = _RHSValueCompiler(self.sensitivity) + + def lhs_compiler(self, value): + # TODO + return lambda state, arg: state.set_next(value, arg) + + def on_Assign(self, stmt): + assert isinstance(stmt.lhs, Signal) + shape = stmt.lhs.shape() + lhs = self.lhs_compiler(stmt.lhs) + rhs = self.rhs_compiler(stmt.rhs) + def run(state): + lhs(state, Const(rhs(state).value, shape)) + return run + + def on_Switch(self, stmt): + test = self.rhs_compiler(stmt.test) + cases = [] + for value, stmts in stmt.cases.items(): + if "-" in value: + mask = "".join("0" if b == "-" else "1" for b in value) + value = "".join("0" if b == "-" else b for b in value) + else: + mask = "1" * len(value) + mask = int(mask, 2) + value = int(value, 2) + cases.append((lambda test: test & mask == value, + self.on_statements(stmts))) + def run(state): + test_value = test(state).value + for check, body in cases: + if check(test_value): + body(state) + return + return run + + def on_statements(self, stmts): + stmts = [self.on_statement(stmt) for stmt in stmts] + def run(state): + for stmt in stmts: + stmt(state) + return run + + +class Simulator: + def __init__(self, fragment=None, vcd_file=None): + self._fragments = {} # fragment -> hierarchy + self._domains = {} # str -> ClockDomain + self._domain_triggers = ValueDict() # Signal -> str + self._domain_signals = {} # str -> {Signal} + self._signals = ValueSet() # {Signal} + self._comb_signals = ValueSet() # {Signal} + self._sync_signals = ValueSet() # {Signal} + self._user_signals = ValueSet() # {Signal} + + self._started = False + self._timestamp = 0. + self._state = _State() + + self._processes = set() # {process} + self._passive = set() # {process} + self._suspended = {} # process -> until + + self._handlers = ValueDict() # Signal -> lambda + + self._vcd_file = vcd_file + self._vcd_writer = None + self._vcd_signals = ValueDict() # signal -> set(vcd_signal) + + if fragment is not None: + fragment = fragment.prepare() + self._add_fragment(fragment) + self._domains = fragment.domains + for domain, cd in self._domains.items(): + self._domain_triggers[cd.clk] = domain + if cd.rst is not None: + self._domain_triggers[cd.rst] = domain + self._domain_signals[domain] = ValueSet() + + def _add_fragment(self, fragment, hierarchy=("top",)): + self._fragments[fragment] = hierarchy + for subfragment, name in fragment.subfragments: + self._add_fragment(subfragment, (*hierarchy, name)) + + def add_process(self, fn): + self._processes.add(fn) + + def add_clock(self, domain, period): + clk = self._domains[domain].clk + half_period = period / 2 + def clk_process(): + yield Passive() + while True: + yield clk.eq(1) + yield Delay(half_period) + yield clk.eq(0) + yield Delay(half_period) + self.add_process(clk_process()) + + def _signal_name_in_fragment(self, fragment, signal): + for subfragment, name in fragment.subfragments: + if signal in subfragment.ports: + return "{}_{}".format(name, signal.name) + return signal.name + + def __enter__(self): + if self._vcd_file: + self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps", + comment="Generated by nMigen") + + for fragment in self._fragments: + for signal in fragment.iter_signals(): + self._signals.add(signal) + + self._state.curr[signal] = self._state.next[signal] = \ + Const(signal.reset, signal.shape()) + self._state.curr_dirty.add(signal) + + if signal not in self._vcd_signals: + self._vcd_signals[signal] = set() + name = self._signal_name_in_fragment(fragment, signal) + suffix = None + while True: + try: + if suffix is None: + name_suffix = name + else: + name_suffix = "{}${}".format(name, suffix) + self._vcd_signals[signal].add(self._vcd_writer.register_var( + scope=".".join(self._fragments[fragment]), name=name_suffix, + var_type="wire", size=signal.nbits, init=signal.reset)) + break + except KeyError: + suffix = (suffix or 0) + 1 + + for domain, signals in fragment.drivers.items(): + if domain is None: + self._comb_signals.update(signals) + else: + self._sync_signals.update(signals) + self._domain_signals[domain].update(signals) + + compiler = _StatementCompiler() + handler = compiler(fragment.statements) + for signal in compiler.sensitivity: + self._handlers[signal] = handler + for domain, cd in fragment.domains.items(): + self._handlers[cd.clk] = handler + if cd.rst is not None: + self._handlers[cd.rst] = handler + + self._user_signals = self._signals - self._comb_signals - self._sync_signals + + def _commit_signal(self, signal): + old, new = self._state.commit(signal) + if old.value == 0 and new.value == 1 and signal in self._domain_triggers: + domain = self._domain_triggers[signal] + for sync_signal in self._state.next_dirty: + if sync_signal in self._domain_signals[domain]: + self._commit_signal(sync_signal) + + if self._vcd_writer: + for vcd_signal in self._vcd_signals[signal]: + self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new.value) + + def _handle_event(self): + while self._state.curr_dirty: + signal = self._state.curr_dirty.pop() + if signal in self._handlers: + self._handlers[signal](self._state) + + for signal in self._state.next_dirty: + if signal in self._comb_signals or signal in self._user_signals: + self._commit_signal(signal) + + def _force_signal(self, signal, value): + assert signal in self._comb_signals or signal in self._user_signals + self._state.set_next(signal, value) + self._commit_signal(signal) + + def _run_process(self, proc): + try: + stmt = proc.send(None) + except StopIteration: + self._processes.remove(proc) + self._passive.remove(proc) + self._suspended.remove(proc) + return + + if isinstance(stmt, Delay): + self._suspended[proc] = self._timestamp + stmt.interval + elif isinstance(stmt, Passive): + self._passive.add(proc) + elif isinstance(stmt, Assign): + assert isinstance(stmt.lhs, Signal) + assert isinstance(stmt.rhs, Const) + self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape())) + else: + raise TypeError("Received unsupported statement '{!r}' from process {}" + .format(stmt, proc)) + + def step(self, run_passive=False): + # Are there any delta cycles we should run? + while self._state.curr_dirty: + self._timestamp += 1e-10 + self._handle_event() + + # Are there any processes that haven't had a chance to run yet? + if len(self._processes) > len(self._suspended): + # Schedule an arbitrary one. + proc = (self._processes - set(self._suspended)).pop() + self._run_process(proc) + return True + + # All processes are suspended. Are any of them active? + if len(self._processes) > len(self._passive) or run_passive: + # Schedule the one with the lowest deadline. + proc, deadline = min(self._suspended.items(), key=lambda x: x[1]) + del self._suspended[proc] + self._timestamp = deadline + self._run_process(proc) + return True + + # No processes, or all processes are passive. Nothing to do! + return False + + def run_until(self, deadline, run_passive=False): + while self._timestamp < deadline: + if not self.step(run_passive): + return False + return True + + def __exit__(self, *args): + if self._vcd_writer: + self._vcd_writer.close(self._timestamp * 1e10) diff --git a/nmigen/fhdl/ast.py b/nmigen/fhdl/ast.py index 15f67e3..1f00bb7 100644 --- a/nmigen/fhdl/ast.py +++ b/nmigen/fhdl/ast.py @@ -10,7 +10,7 @@ from ..tools import * __all__ = [ "Value", "Const", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Signal", "ClockSignal", "ResetSignal", - "Statement", "Assign", "Switch", + "Statement", "Assign", "Switch", "Delay", "Passive", "ValueKey", "ValueDict", "ValueSet", ] @@ -216,17 +216,23 @@ class Const(Value): nbits : int signed : bool """ + src_loc = None + def __init__(self, value, shape=None): - super().__init__() self.value = int(value) if shape is None: - shape = self.value.bit_length(), self.value < 0 + shape = bits_for(self.value), self.value < 0 if isinstance(shape, int): shape = shape, self.value < 0 self.nbits, self.signed = shape if not isinstance(self.nbits, int) or self.nbits < 0: raise TypeError("Width must be a positive integer") + mask = (1 << self.nbits) - 1 + self.value &= mask + if self.signed and self.value >> (self.nbits - 1): + self.value |= ~mask + def shape(self): return self.nbits, self.signed @@ -347,6 +353,8 @@ class Slice(Value): raise IndexError("Cannot end slice {} bits into {}-bit value".format(end, n)) if end < 0: end += n + if start > end: + raise IndexError("Slice start {} must be less than slice end {}".format(start, end)) super().__init__() self.value = Value.wrap(value) @@ -680,6 +688,25 @@ class Switch(Statement): return "(switch {!r} {})".format(self.test, " ".join(cases)) +class Delay(Statement): + def __init__(self, interval): + self.interval = float(interval) + + def _rhs_signals(self): + return ValueSet() + + def __repr__(self): + return "(delay {:.3}us)".format(self.interval * 10e6) + + +class Passive(Statement): + def _rhs_signals(self): + return ValueSet() + + def __repr__(self): + return "(passive)" + + class ValueKey: def __init__(self, value): self.value = Value.wrap(value) diff --git a/nmigen/fhdl/ir.py b/nmigen/fhdl/ir.py index 8f4efee..8c8b7a9 100644 --- a/nmigen/fhdl/ir.py +++ b/nmigen/fhdl/ir.py @@ -52,6 +52,11 @@ class Fragment: signals = ValueSet() signals |= self.ports.keys() for domain, domain_signals in self.drivers.items(): + if domain is not None: + cd = self.domains[domain] + signals.add(cd.clk) + if cd.rst is not None: + signals.add(cd.rst) signals |= domain_signals return signals diff --git a/nmigen/test/test_fhdl_dsl.py b/nmigen/test/test_fhdl_dsl.py index c0acaed..e28ced6 100644 --- a/nmigen/test/test_fhdl_dsl.py +++ b/nmigen/test/test_fhdl_dsl.py @@ -116,7 +116,7 @@ class DSLTestCase(FHDLTestCase): ( (switch (cat (sig s1) (sig s2)) (case -1 (eq (sig c1) (const 1'd1))) - (case 1- (eq (sig c2) (const 0'd0))) + (case 1- (eq (sig c2) (const 1'd0))) ) ) """) @@ -134,7 +134,7 @@ class DSLTestCase(FHDLTestCase): ( (switch (cat (sig s1) (sig s2)) (case -1 (eq (sig c1) (const 1'd1))) - (case 1- (eq (sig c2) (const 0'd0))) + (case 1- (eq (sig c2) (const 1'd0))) (case -- (eq (sig c3) (const 1'd1))) ) ) diff --git a/nmigen/test/test_fhdl_value.py b/nmigen/test/test_fhdl_value.py index 892b5e0..8e7dfd1 100644 --- a/nmigen/test/test_fhdl_value.py +++ b/nmigen/test/test_fhdl_value.py @@ -59,10 +59,10 @@ class ValueTestCase(FHDLTestCase): class ConstTestCase(FHDLTestCase): def test_shape(self): - self.assertEqual(Const(0).shape(), (0, False)) + self.assertEqual(Const(0).shape(), (1, False)) self.assertEqual(Const(1).shape(), (1, False)) self.assertEqual(Const(10).shape(), (4, False)) - self.assertEqual(Const(-10).shape(), (4, True)) + self.assertEqual(Const(-10).shape(), (5, True)) self.assertEqual(Const(1, 4).shape(), (4, False)) self.assertEqual(Const(1, (4, True)).shape(), (4, True)) @@ -70,12 +70,15 @@ class ConstTestCase(FHDLTestCase): with self.assertRaises(TypeError): Const(1, -1) + def test_normalization(self): + self.assertEqual(Const(0b10110, (5, True)).value, -10) + def test_value(self): self.assertEqual(Const(10).value, 10) def test_repr(self): self.assertEqual(repr(Const(10)), "(const 4'd10)") - self.assertEqual(repr(Const(-10)), "(const 4'sd-10)") + self.assertEqual(repr(Const(-10)), "(const 5'sd-10)") def test_hash(self): with self.assertRaises(TypeError): @@ -205,7 +208,7 @@ class OperatorTestCase(FHDLTestCase): def test_mux(self): s = Const(0) v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False))) - self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))") + self.assertEqual(repr(v1), "(m (const 1'd0) (const 4'd0) (const 6'd0))") self.assertEqual(v1.shape(), (6, False)) v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True))) self.assertEqual(v2.shape(), (6, True)) @@ -216,7 +219,7 @@ class OperatorTestCase(FHDLTestCase): def test_bool(self): v = Const(0).bool() - self.assertEqual(repr(v), "(b (const 0'd0))") + self.assertEqual(repr(v), "(b (const 1'd0))") self.assertEqual(v.shape(), (1, False)) def test_hash(self): @@ -243,7 +246,7 @@ class CatTestCase(FHDLTestCase): c2 = Cat(Const(10), Const(1)) self.assertEqual(c2.shape(), (5, False)) c3 = Cat(Const(10), Const(1), Const(0)) - self.assertEqual(c3.shape(), (5, False)) + self.assertEqual(c3.shape(), (6, False)) def test_repr(self): c1 = Cat(Const(10), Const(1)) diff --git a/nmigen/test/test_fhdl_xfrm.py b/nmigen/test/test_fhdl_xfrm.py index 78346d8..faacf3b 100644 --- a/nmigen/test/test_fhdl_xfrm.py +++ b/nmigen/test/test_fhdl_xfrm.py @@ -32,7 +32,7 @@ class DomainRenamerTestCase(FHDLTestCase): ( (eq (sig s1) (clk pix)) (eq (rst pix) (sig s2)) - (eq (sig s3) (const 0'd0)) + (eq (sig s3) (const 1'd0)) (eq (sig s4) (clk other)) (eq (sig s5) (rst other)) ) @@ -127,7 +127,7 @@ class ResetInserterTestCase(FHDLTestCase): self.assertRepr(f.statements, """ ( (eq (sig s1) (const 1'd1)) - (eq (sig s2) (const 0'd0)) + (eq (sig s2) (const 1'd0)) (switch (sig c1) (case 1 (eq (sig s2) (const 1'd1))) ) @@ -144,7 +144,7 @@ class ResetInserterTestCase(FHDLTestCase): f = ResetInserter(self.c1)(f) self.assertRepr(f.statements, """ ( - (eq (sig s2) (const 0'd0)) + (eq (sig s2) (const 1'd0)) (switch (sig c1) (case 1 (eq (sig s2) (const 1'd1))) ) @@ -161,7 +161,7 @@ class ResetInserterTestCase(FHDLTestCase): f = ResetInserter(self.c1)(f) self.assertRepr(f.statements, """ ( - (eq (sig s3) (const 0'd0)) + (eq (sig s3) (const 1'd0)) (switch (sig c1) (case 1 ) ) @@ -206,7 +206,7 @@ class CEInserterTestCase(FHDLTestCase): self.assertRepr(f.statements, """ ( (eq (sig s1) (const 1'd1)) - (eq (sig s2) (const 0'd0)) + (eq (sig s2) (const 1'd0)) (switch (sig c1) (case 0 (eq (sig s2) (sig s2))) ) diff --git a/setup.py b/setup.py index c406779..d4b60e0 100644 --- a/setup.py +++ b/setup.py @@ -13,5 +13,11 @@ setup( description="Python toolbox for building complex digital hardware", #long_description="""TODO""", license="BSD", + install_requires=["pyvcd"], packages=find_packages(), + project_urls={ + #"Documentation": "https://glasgow.readthedocs.io/", + "Source Code": "https://github.com/m-labs/nmigen", + "Bug Tracker": "https://github.com/m-labs/nmigen/issues", + } ) -- 2.30.2