back.pysim: new simulator backend (WIP).
[nmigen.git] / nmigen / back / pysim.py
1 from vcd import VCDWriter
2
3 from ..tools import flatten
4 from ..fhdl.ast import *
5 from ..fhdl.xfrm import ValueTransformer, StatementTransformer
6
7
# Public API of this module. Delay and Passive are not defined in this file's
# visible code; presumably they are re-exported via ``..fhdl.ast`` (verify).
__all__ = [
    "Simulator",
    "Delay",
    "Passive",
]
9
10
class _State:
    """Double-buffered signal values used by the simulator.

    ``curr`` holds the committed value of every signal and ``next`` holds the
    value it will take once committed; both map a signal to a ``Const``.
    ``curr_dirty`` collects signals whose committed value changed (so their
    readers must be re-evaluated), and ``next_dirty`` collects signals with a
    pending, not yet committed update.
    """
    __slots__ = ("curr", "curr_dirty", "next", "next_dirty")

    def __init__(self):
        self.curr = ValueDict()        # Signal -> Const (committed value)
        self.next = ValueDict()        # Signal -> Const (pending value)
        self.curr_dirty = ValueSet()   # signals whose committed value changed
        self.next_dirty = ValueSet()   # signals with a pending update

    def get(self, signal):
        """Return the committed value (a ``Const``) of ``signal``."""
        return self.curr[signal]

    def set_curr(self, signal, value):
        """Overwrite the committed value of ``signal`` with the ``Const`` ``value``.

        ``signal`` is marked in ``curr_dirty`` if its value actually changed.
        """
        assert isinstance(value, Const)
        if self.curr[signal].value != value.value:
            self.curr_dirty.add(signal)
        self.curr[signal] = value

    def set_next(self, signal, value):
        """Schedule the ``Const`` ``value`` as the pending value of ``signal``.

        ``signal`` is marked in ``next_dirty`` if its pending value changed.
        """
        assert isinstance(value, Const)
        if self.next[signal].value != value.value:
            self.next_dirty.add(signal)
        self.next[signal] = value

    def commit(self, signal):
        """Make the pending value of ``signal`` its committed value.

        Returns ``(old_value, new_value)`` as ``Const`` objects; their values
        are equal when there was nothing to commit.
        """
        old_value = self.curr[signal]
        new_value = self.next[signal]
        # Once committed the signal is no longer pending, even if its pending
        # value was set back to the committed one; discard() (not remove())
        # also tolerates signals that were never marked pending.
        self.next_dirty.discard(signal)
        if old_value.value != new_value.value:
            self.curr_dirty.add(signal)
            self.curr[signal] = new_value
        return old_value, new_value

    def iter_dirty(self):
        """Drain the pending-update set, yielding ``(signal, curr, next)``.

        NOTE(review): the original referenced a non-existent ``self.dirty``
        slot and always raised ``AttributeError``; it is assumed to have meant
        the pending-update set ``next_dirty``. Confirm against callers.
        """
        dirty, self.next_dirty = self.next_dirty, ValueSet()
        for signal in dirty:
            yield signal, self.curr[signal], self.next[signal]
46 for signal in dirty:
47 yield signal, self.curr[signal], self.next[signal]
48
49
class _RHSValueCompiler(ValueTransformer):
    """Compile right-hand-side expressions into evaluation closures.

    Every ``on_*`` handler returns a callable ``fn(state) -> Const`` that
    evaluates the expression against a ``_State``. Every signal the expression
    reads is added to ``sensitivity``, so the caller knows which signal
    changes require re-evaluation.
    """

    # Operator name -> implementation on plain Python ints. Results are
    # wrapped into a Const of the operator's shape; that Const is assumed to
    # truncate/sign-extend the value to the shape (TODO confirm in fhdl.ast).
    _UNARY_OPS = {
        "~": lambda a: ~a,
        "-": lambda a: -a,
    }
    _BINARY_OPS = {
        "+":  lambda a, b: a + b,
        "-":  lambda a, b: a - b,
        "&":  lambda a, b: a & b,
        "|":  lambda a, b: a | b,
        "^":  lambda a, b: a ^ b,
        "==": lambda a, b: int(a == b),
        "!=": lambda a, b: int(a != b),
        "<":  lambda a, b: int(a < b),
        "<=": lambda a, b: int(a <= b),
        ">":  lambda a, b: int(a > b),
        ">=": lambda a, b: int(a >= b),
    }

    def __init__(self, sensitivity):
        # Signal set shared with the owner; filled in as signals are compiled.
        self.sensitivity = sensitivity

    def on_Const(self, value):
        return lambda state: value

    def on_Signal(self, value):
        self.sensitivity.add(value)
        return lambda state: state.get(value)

    def on_ClockSignal(self, value):
        raise NotImplementedError

    def on_ResetSignal(self, value):
        raise NotImplementedError

    def on_Operator(self, value):
        """Compile a unary, binary or mux (``"m"``) operator.

        Raises ``NotImplementedError`` for operators not handled here.
        """
        shape = value.shape()
        operands = [self(operand) for operand in value.operands]
        arity = len(operands)
        if arity == 1 and value.op in self._UNARY_OPS:
            fn = self._UNARY_OPS[value.op]
            arg, = operands
            return lambda state: Const(fn(arg(state).value), shape)
        if arity == 2 and value.op in self._BINARY_OPS:
            fn = self._BINARY_OPS[value.op]
            lhs, rhs = operands
            return lambda state: Const(fn(lhs(state).value, rhs(state).value), shape)
        if arity == 3 and value.op == "m":
            sel, val1, val0 = operands
            # Re-wrap the selected operand so the result carries the mux's
            # shape rather than that of whichever input was selected.
            def mux(state):
                chosen = val1(state) if sel(state).value else val0(state)
                return Const(chosen.value, shape)
            return mux
        raise NotImplementedError("Operator '{}' not implemented".format(value.op))

    def on_Slice(self, value):
        """Compile ``value[start:end]`` as a shift followed by a mask."""
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: Const((arg(state).value >> shift) & mask, shape)

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        """Compile a concatenation; the first operand occupies the low bits."""
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)
        def evaluate(state):
            result = 0
            for part_offset, part_mask, part in parts:
                # Masking drops the sign bits of negative (signed) operands.
                result |= (part(state).value & part_mask) << part_offset
            return Const(result, shape)
        return evaluate

    def on_Repl(self, value):
        """Compile ``Repl(value, count)``: ``count`` copies of ``value`` side by side."""
        shape = value.shape()
        width = len(value.value)
        mask = (1 << width) - 1
        count = value.count
        opnd = self(value.value)
        def evaluate(state):
            # Mask the operand once: without it a negative (signed) value
            # would smear its sign bits over all the other copies.
            part = opnd(state).value & mask
            result = 0
            for _ in range(count):
                result = (result << width) | part
            return Const(result, shape)
        return evaluate
131 return Const(result, shape)
132 return eval
133
134
class _StatementCompiler(StatementTransformer):
    """Compile statements into closures ``fn(state)`` that update a ``_State``.

    ``sensitivity`` accumulates every signal read by the compiled statements.
    """
    def __init__(self):
        self.sensitivity = ValueSet()
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)

    def lhs_compiler(self, value):
        # TODO: only whole-Signal targets are supported (on_Assign asserts
        # this); other left-hand-side forms are not handled yet.
        return lambda state, arg: state.set_next(value, arg)

    def on_Assign(self, stmt):
        """Compile ``lhs.eq(rhs)``: schedule rhs, cast to lhs's shape, as lhs's next value."""
        assert isinstance(stmt.lhs, Signal)
        shape = stmt.lhs.shape()
        lhs = self.lhs_compiler(stmt.lhs)
        rhs = self.rhs_compiler(stmt.rhs)
        def run(state):
            lhs(state, Const(rhs(state).value, shape))
        return run

    @staticmethod
    def _parse_case(pattern):
        """Turn a case pattern such as ``"1-0"`` into a ``(mask, value)`` pair.

        A ``-`` marks a don't-care bit; it is cleared in both ``mask`` and
        ``value``. An empty pattern yields ``(0, 0)``, which matches anything.
        """
        mask = "".join("0" if bit == "-" else "1" for bit in pattern)
        bits = "".join("0" if bit == "-" else bit for bit in pattern)
        return int(mask or "0", 2), int(bits or "0", 2)

    def on_Switch(self, stmt):
        """Compile a switch: run the body of the first case matching the test value."""
        test = self.rhs_compiler(stmt.test)
        cases = []
        for pattern, stmts in stmt.cases.items():
            mask, value = self._parse_case(pattern)
            # Keep mask/value as data instead of capturing the loop variables
            # in a lambda: closures bind late, so every case would otherwise
            # be checked against the *last* pattern's mask and value.
            cases.append((mask, value, self.on_statements(stmts)))
        def run(state):
            test_value = test(state).value
            for mask, value, body in cases:
                if (test_value & mask) == value:
                    body(state)
                    return
        return run

    def on_statements(self, stmts):
        """Compile a statement list into one closure running them in order."""
        stmts = [self.on_statement(stmt) for stmt in stmts]
        def run(state):
            for stmt in stmts:
                stmt(state)
        return run
180
181
class Simulator:
    """Event-driven simulator for a prepared fragment hierarchy (work in progress).

    Meant to be used as a context manager:
    - ``__enter__`` compiles every fragment's statements and, when
      ``vcd_file`` is given, opens a VCD trace.
    - ``step()`` / ``run_until()`` advance the simulation.
    - ``__exit__`` closes the trace.

    ``_timestamp`` is in seconds. Every delta cycle advances it by 1e-10
    (100 ps), which is also the VCD timescale.
    """
    def __init__(self, fragment=None, vcd_file=None):
        self._fragments = {}                 # fragment -> hierarchy
        self._domains = {}                   # str -> ClockDomain
        self._domain_triggers = ValueDict()  # Signal -> str
        self._domain_signals = {}            # str -> {Signal}
        self._signals = ValueSet()           # {Signal}
        self._comb_signals = ValueSet()      # {Signal}
        self._sync_signals = ValueSet()      # {Signal}
        self._user_signals = ValueSet()      # {Signal}

        self._started = False
        self._timestamp = 0.
        self._state = _State()

        self._processes = set()              # {process}
        self._passive = set()                # {process}
        self._suspended = {}                 # process -> until

        self._handlers = ValueDict()         # Signal -> [handler]

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = ValueDict()      # signal -> set(vcd_signal)

        if fragment is not None:
            fragment = fragment.prepare()
            self._add_fragment(fragment)
            self._domains = fragment.domains
            for domain, cd in self._domains.items():
                self._domain_triggers[cd.clk] = domain
                if cd.rst is not None:
                    self._domain_triggers[cd.rst] = domain
                self._domain_signals[domain] = ValueSet()

    def _add_fragment(self, fragment, hierarchy=("top",)):
        """Record ``fragment`` and, recursively, its subfragments with their hierarchy path."""
        self._fragments[fragment] = hierarchy
        for subfragment, name in fragment.subfragments:
            self._add_fragment(subfragment, (*hierarchy, name))

    def add_process(self, fn):
        """Add a generator process; it yields ``Delay``, ``Passive`` or ``Assign`` statements."""
        self._processes.add(fn)

    def add_clock(self, domain, period):
        """Add a passive process toggling ``domain``'s clock with the given period."""
        clk = self._domains[domain].clk
        half_period = period / 2
        def clk_process():
            yield Passive()
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process())

    def _signal_name_in_fragment(self, fragment, signal):
        """Return the trace name of ``signal``, prefixed with the subfragment name for ports."""
        for subfragment, name in fragment.subfragments:
            if signal in subfragment.ports:
                return "{}_{}".format(name, signal.name)
        return signal.name

    def _register_vcd_signal(self, fragment, signal):
        """Register ``signal`` as a VCD variable in ``fragment``'s scope.

        A ``KeyError`` from ``register_var`` is treated as a name clash. The
        registration is then retried with a ``$1``, ``$2``, ... suffix.
        """
        name = self._signal_name_in_fragment(fragment, signal)
        scope = ".".join(self._fragments[fragment])
        suffix = None
        while True:
            name_suffix = name if suffix is None else "{}${}".format(name, suffix)
            try:
                vcd_signal = self._vcd_writer.register_var(
                    scope=scope, name=name_suffix,
                    var_type="wire", size=signal.nbits, init=signal.reset)
            except KeyError:
                suffix = (suffix or 0) + 1
            else:
                self._vcd_signals[signal].add(vcd_signal)
                return

    def _add_handler(self, signal, handler):
        """Arrange for ``handler`` to run when ``signal`` changes (at most once per signal)."""
        if signal not in self._handlers:
            self._handlers[signal] = []
        if handler not in self._handlers[signal]:
            self._handlers[signal].append(handler)

    def __enter__(self):
        """Compile the design, reset all signals and open the VCD trace (if any).

        Returns the simulator, so ``with Simulator(...) as sim:`` works.
        """
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        for fragment in self._fragments:
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    Const(signal.reset, signal.shape())
                # Start with every signal dirty, so the first delta cycle
                # evaluates every handler against the reset values.
                self._state.curr_dirty.add(signal)

                if signal not in self._vcd_signals:
                    self._vcd_signals[signal] = set()
                # Without a VCD file there is no writer to register with.
                if self._vcd_writer:
                    self._register_vcd_signal(fragment, signal)

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            compiler = _StatementCompiler()
            handler = compiler(fragment.statements)
            # Re-run this fragment's statements whenever a signal they read
            # changes, and on every clock/reset change of its domains. A
            # signal may be read by several fragments, so each signal keeps a
            # list of handlers rather than just the last one registered.
            triggers = list(compiler.sensitivity)
            for cd in fragment.domains.values():
                triggers.append(cd.clk)
                if cd.rst is not None:
                    triggers.append(cd.rst)
            for signal in triggers:
                self._add_handler(signal, handler)

        self._user_signals = self._signals - self._comb_signals - self._sync_signals
        return self

    def _vcd_timestamp(self):
        """Return the current time as an integer count of 100 ps VCD units."""
        # round() absorbs the floating-point error accumulated by repeatedly
        # adding 1e-10 to the timestamp.
        return round(self._timestamp * 1e10)

    def _commit_signal(self, signal):
        """Commit the pending value of ``signal``.

        On a 0 -> 1 transition of a domain trigger (clock or reset), also
        commit every pending synchronous signal of that domain.
        """
        old, new = self._state.commit(signal)
        if old.value == 0 and new.value == 1 and signal in self._domain_triggers:
            domain = self._domain_triggers[signal]
            # Iterate over a snapshot: committing removes entries from
            # next_dirty, and mutating a set while iterating it raises.
            for sync_signal in list(self._state.next_dirty):
                if sync_signal in self._domain_signals[domain]:
                    self._commit_signal(sync_signal)

        # Only trace real changes; a signal may be committed more than once
        # via the snapshots above.
        if self._vcd_writer and old.value != new.value:
            for vcd_signal in self._vcd_signals[signal]:
                self._vcd_writer.change(vcd_signal, self._vcd_timestamp(), new.value)

    def _handle_event(self):
        """Run one delta cycle.

        First re-evaluate every handler sensitive to a changed signal, then
        commit the pending combinational and user-driven updates.
        """
        # Handlers only read committed values and write pending ones, so each
        # handler needs to run at most once per delta cycle. A dict keeps
        # handlers unique while preserving discovery order.
        pending = {}
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            for handler in self._handlers.get(signal, ()):
                pending[handler] = None
        for handler in pending:
            handler(self._state)

        # Snapshot: _commit_signal removes entries from next_dirty.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals or signal in self._user_signals:
                self._commit_signal(signal)

    def _force_signal(self, signal, value):
        """Immediately set and commit a combinational or user-driven ``signal``."""
        assert signal in self._comb_signals or signal in self._user_signals
        self._state.set_next(signal, value)
        self._commit_signal(signal)

    def _run_process(self, proc):
        """Resume ``proc`` once and act on the statement it yields."""
        try:
            stmt = proc.send(None)
        except StopIteration:
            # The process finished; forget it. It may or may not have been
            # passive, and step() never runs a process while it is suspended.
            self._processes.remove(proc)
            self._passive.discard(proc)
            self._suspended.pop(proc, None)
            return

        if isinstance(stmt, Delay):
            self._suspended[proc] = self._timestamp + stmt.interval
        elif isinstance(stmt, Passive):
            self._passive.add(proc)
        elif isinstance(stmt, Assign):
            assert isinstance(stmt.lhs, Signal)
            assert isinstance(stmt.rhs, Const)
            self._force_signal(stmt.lhs, Const(stmt.rhs.value, stmt.lhs.shape()))
        else:
            raise TypeError("Received unsupported statement '{!r}' from process {}"
                            .format(stmt, proc))

    def step(self, run_passive=False):
        """Advance the simulation by one scheduling decision.

        Returns ``True`` if anything ran, and ``False`` when there is nothing
        left to do. Suspended passive processes only count when
        ``run_passive`` is set.
        """
        # Are there any delta cycles we should run?
        while self._state.curr_dirty:
            self._timestamp += 1e-10
            self._handle_event()

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            proc = (self._processes - set(self._suspended)).pop()
            self._run_process(proc)
            return True

        # All processes are suspended. Are any of them active? (Also guards
        # against calling min() on an empty mapping when run_passive is set
        # but no process is left.)
        if self._suspended and (len(self._processes) > len(self._passive) or run_passive):
            # Schedule the one with the lowest deadline.
            proc, deadline = min(self._suspended.items(), key=lambda item: item[1])
            del self._suspended[proc]
            # NOTE(review): delta cycles advance _timestamp, so ``deadline``
            # can be earlier than the current time here, moving time
            # backwards. Confirm this is intended.
            self._timestamp = deadline
            self._run_process(proc)
            return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run_until(self, deadline, run_passive=False):
        """Step until ``deadline`` (seconds) is reached.

        Returns ``False`` if the simulation ran out of work first.
        """
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False
        return True

    def __exit__(self, *args):
        if self._vcd_writer:
            self._vcd_writer.close(self._vcd_timestamp())
371 if self._vcd_writer:
372 self._vcd_writer.close(self._timestamp * 1e10)