back.pysim: implement "sync processes", like migen.sim generators.
[nmigen.git] / nmigen / back / pysim.py
1 from vcd import VCDWriter
2
3 from ..tools import flatten
4 from ..fhdl.ast import *
5 from ..fhdl.xfrm import ValueTransformer, StatementTransformer
6
7
# Names exported by ``from ... import *``: the simulator itself plus the
# Delay/Passive commands that simulation processes yield.
__all__ = ["Simulator", "Delay", "Passive"]
9
10
11 class _State:
12 __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
13
14 def __init__(self):
15 self.curr = ValueDict()
16 self.next = ValueDict()
17 self.curr_dirty = ValueSet()
18 self.next_dirty = ValueSet()
19
20 def get(self, signal):
21 return self.curr[signal]
22
23 def set_curr(self, signal, value):
24 assert isinstance(value, int)
25 if self.curr[signal] != value:
26 self.curr_dirty.add(signal)
27 self.curr[signal] = value
28
29 def set_next(self, signal, value):
30 assert isinstance(value, int)
31 if self.next[signal] != value:
32 self.next_dirty.add(signal)
33 self.next[signal] = value
34
35 def commit(self, signal):
36 old_value = self.curr[signal]
37 if self.curr[signal] != self.next[signal]:
38 self.next_dirty.remove(signal)
39 self.curr_dirty.add(signal)
40 self.curr[signal] = self.next[signal]
41 new_value = self.curr[signal]
42 return old_value, new_value
43
44 def iter_dirty(self):
45 dirty, self.dirty = self.dirty, ValueSet()
46 for signal in dirty:
47 yield signal, self.curr[signal], self.next[signal]
48
49
# Module-level shorthand for Const.normalize, which every compiled value below
# uses to wrap a Python int result to the shape of the value it computes.
normalize = Const.normalize
51
52
class _RHSValueCompiler(ValueTransformer):
    """Compile right-hand-side values into ``state -> int`` closures.

    Every signal read by a compiled value is added to ``sensitivity``. The
    caller uses it to learn which signal changes require re-evaluation.
    """
    def __init__(self, sensitivity):
        self.sensitivity = sensitivity

    def on_Const(self, value):
        return lambda state: value.value

    def on_Signal(self, value):
        self.sensitivity.add(value)
        return lambda state: state.get(value)

    def on_ClockSignal(self, value):
        raise NotImplementedError

    def on_ResetSignal(self, value):
        raise NotImplementedError

    def on_Operator(self, value):
        shape = value.shape()
        if len(value.operands) == 1:
            arg, = map(self, value.operands)
            if value.op == "~":
                return lambda state: normalize(~arg(state), shape)
            if value.op == "-":
                return lambda state: normalize(-arg(state), shape)
        elif len(value.operands) == 2:
            lhs, rhs = map(self, value.operands)
            if value.op == "+":
                return lambda state: normalize(lhs(state) + rhs(state), shape)
            if value.op == "-":
                return lambda state: normalize(lhs(state) - rhs(state), shape)
            if value.op == "&":
                return lambda state: normalize(lhs(state) & rhs(state), shape)
            if value.op == "|":
                return lambda state: normalize(lhs(state) | rhs(state), shape)
            if value.op == "^":
                return lambda state: normalize(lhs(state) ^ rhs(state), shape)
            if value.op == "==":
                return lambda state: normalize(lhs(state) == rhs(state), shape)
        elif len(value.operands) == 3:
            if value.op == "m":
                sel, val1, val0 = map(self, value.operands)
                # Wrap the selected arm to the mux's own shape, as every other
                # operator does with its result.
                return lambda state: normalize(val1(state) if sel(state) else val0(state), shape)
        raise NotImplementedError("Operator '{}' not implemented".format(value.op))

    def on_Slice(self, value):
        shape = value.shape()
        arg = self(value.value)
        shift = value.start
        mask = (1 << (value.end - value.start)) - 1
        return lambda state: normalize((arg(state) >> shift) & mask, shape)

    def on_Part(self, value):
        raise NotImplementedError

    def on_Cat(self, value):
        shape = value.shape()
        parts = []
        offset = 0
        for opnd in value.operands:
            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
            offset += len(opnd)
        def evaluate(state):
            result = 0
            for offset, mask, opnd in parts:
                # Masking drops the sign extension of signed (negative) operands.
                result |= (opnd(state) & mask) << offset
            return normalize(result, shape)
        return evaluate

    def on_Repl(self, value):
        shape = value.shape()
        width = len(value.value)
        mask = (1 << width) - 1
        count = value.count
        opnd = self(value.value)
        def evaluate(state):
            # Mask the operand before replicating it. A negative (signed)
            # value would otherwise set every higher bit of the result when
            # OR-ed in.
            part = opnd(state) & mask
            result = 0
            for _ in range(count):
                result = (result << width) | part
            return normalize(result, shape)
        return evaluate
135
136
class _StatementCompiler(StatementTransformer):
    """Compile statements into ``state -> None`` closures.

    Signals read by any compiled statement accumulate in ``sensitivity``.
    Assignments write the *next* value of their target via ``state.set_next``.
    """
    def __init__(self):
        self.sensitivity = ValueSet()
        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)

    def lhs_compiler(self, value):
        # TODO: only plain Signal targets are handled (see on_Assign).
        return lambda state, arg: state.set_next(value, arg)

    def on_Assign(self, stmt):
        assert isinstance(stmt.lhs, Signal)
        shape = stmt.lhs.shape()
        lhs = self.lhs_compiler(stmt.lhs)
        rhs = self.rhs_compiler(stmt.rhs)
        def run(state):
            lhs(state, normalize(rhs(state), shape))
        return run

    def on_Switch(self, stmt):
        test = self.rhs_compiler(stmt.test)
        cases = []
        for pattern, stmts in stmt.cases.items():
            # A "-" in a case pattern is a don't-care bit. It is left out of
            # the mask and taken as 0 in the value compared against.
            mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
            value = int("".join("0" if b == "-" else b for b in pattern), 2)
            # Bind mask/value now via defaults. A plain closure would see only
            # the last iteration's values, so every case would test the same
            # pattern.
            def check(test_value, mask=mask, value=value):
                return (test_value & mask) == value
            cases.append((check, self.on_statements(stmts)))
        def run(state):
            test_value = test(state)
            # The first matching case wins; the others are skipped.
            for check, body in cases:
                if check(test_value):
                    body(state)
                    return
        return run

    def on_statements(self, stmts):
        stmts = [self.on_statement(stmt) for stmt in stmts]
        def run(state):
            for stmt in stmts:
                stmt(state)
        return run
182
183
class Simulator:
    """Simulate a fragment hierarchy, driven by Python generator processes.

    Processes are generators that yield commands:

    * ``Delay(interval)`` suspends the process until simulated time has
      advanced by ``interval``.
    * ``Tick(domain)`` suspends it until the next rising edge of one of that
      domain's trigger signals (clock, or reset if present).
    * ``Passive()`` marks it as a process that does not by itself keep
      :meth:`step` running.
    * ``signal.eq(const)`` forces a signal that the design does not drive.

    Use the simulator as a context manager. ``__enter__`` builds the signal
    state, the statement handlers and the optional VCD output. ``__exit__``
    closes the VCD output.
    """
    def __init__(self, fragment=None, vcd_file=None):
        self._fragments = {}                # fragment -> hierarchy
        self._domains = {}                  # str/domain -> ClockDomain
        self._domain_triggers = ValueDict() # Signal -> str/domain
        self._domain_signals = {}           # str/domain -> {Signal}
        self._signals = ValueSet()          # {Signal}
        self._comb_signals = ValueSet()     # {Signal}
        self._sync_signals = ValueSet()     # {Signal}
        self._user_signals = ValueSet()     # {Signal}

        self._started = False
        self._timestamp = 0.
        self._state = _State()

        self._processes = set()             # {process}
        self._passive = set()               # {process}
        self._suspended = set()             # {process}
        self._wait_deadline = {}            # process -> float/timestamp
        self._wait_tick = {}                # process -> str/domain

        self._handlers = ValueDict()        # Signal -> set(lambda)

        self._vcd_file = vcd_file
        self._vcd_writer = None
        self._vcd_signals = ValueDict()     # signal -> set(vcd_signal)

        if fragment is not None:
            fragment = fragment.prepare()
            self._add_fragment(fragment)
            self._domains = fragment.domains
            for domain, cd in self._domains.items():
                # A rising edge on clk (or rst) commits the domain's registers.
                self._domain_triggers[cd.clk] = domain
                if cd.rst is not None:
                    self._domain_triggers[cd.rst] = domain
                self._domain_signals[domain] = ValueSet()

    def _add_fragment(self, fragment, hierarchy=("top",)):
        """Record ``fragment`` and all its subfragments with their hierarchical names."""
        self._fragments[fragment] = hierarchy
        for subfragment, name in fragment.subfragments:
            self._add_fragment(subfragment, (*hierarchy, name))

    def add_process(self, process):
        """Register a generator ``process`` to be scheduled by :meth:`step`."""
        self._processes.add(process)

    def add_clock(self, domain, period):
        """Add a passive process that toggles ``domain``'s clock every ``period / 2``."""
        clk = self._domains[domain].clk
        half_period = period / 2
        def clk_process():
            yield Passive()
            while True:
                yield clk.eq(1)
                yield Delay(half_period)
                yield clk.eq(0)
                yield Delay(half_period)
        self.add_process(clk_process())

    def add_sync_process(self, process, domain="sync"):
        """Add ``process``; each bare ``yield`` (of None) waits for a tick of ``domain``."""
        def sync_process():
            try:
                result = process.send(None)
                while True:
                    # Test for None explicitly: truth-testing arbitrary
                    # yielded objects (e.g. hardware values) is not meaningful.
                    if result is None:
                        result = Tick(domain)
                    result = process.send((yield result))
            except StopIteration:
                pass
        self.add_process(sync_process())

    def _signal_name_in_fragment(self, fragment, signal):
        """Return the VCD name of ``signal`` as seen from ``fragment``."""
        for subfragment, name in fragment.subfragments:
            if signal in subfragment.ports:
                return "{}_{}".format(name, signal.name)
        return signal.name

    def _add_handler(self, signal, handler):
        """Run ``handler`` whenever the current value of ``signal`` changes."""
        if signal not in self._handlers:
            self._handlers[signal] = set()
        self._handlers[signal].add(handler)

    def _register_vcd_var(self, fragment, signal):
        """Register ``signal``, as seen in ``fragment``, as a VCD variable.

        The retry loop assumes ``register_var`` raises ``KeyError`` when the
        name is already taken in the scope. In that case ``$1``, ``$2``, ...
        is appended until registration succeeds.
        """
        name = self._signal_name_in_fragment(fragment, signal)
        scope = ".".join(self._fragments[fragment])
        suffix = None
        while True:
            name_suffix = name if suffix is None else "{}${}".format(name, suffix)
            try:
                vcd_var = self._vcd_writer.register_var(
                    scope=scope, name=name_suffix,
                    var_type="wire", size=signal.nbits, init=signal.reset)
            except KeyError:
                suffix = (suffix or 0) + 1
            else:
                self._vcd_signals[signal].add(vcd_var)
                return

    def __enter__(self):
        """Initialize signal state, compile statement handlers and open VCD output.

        Returns the simulator itself, so that ``with Simulator(...) as sim:`` works.
        """
        if self._vcd_file:
            self._vcd_writer = VCDWriter(self._vcd_file, timescale="100 ps",
                                         comment="Generated by nMigen")

        for fragment in self._fragments:
            for signal in fragment.iter_signals():
                self._signals.add(signal)

                self._state.curr[signal] = self._state.next[signal] = \
                    normalize(signal.reset, signal.shape())
                # Mark everything dirty so that every handler runs once at start.
                self._state.curr_dirty.add(signal)

                if signal not in self._vcd_signals:
                    self._vcd_signals[signal] = set()
                # Without a VCD file there is no writer to register with.
                if self._vcd_writer is not None:
                    self._register_vcd_var(fragment, signal)

            for domain, signals in fragment.drivers.items():
                if domain is None:
                    self._comb_signals.update(signals)
                else:
                    self._sync_signals.update(signals)
                    self._domain_signals[domain].update(signals)

            compiler = _StatementCompiler()
            handler = compiler(fragment.statements)
            for signal in compiler.sensitivity:
                self._add_handler(signal, handler)
            for domain, cd in fragment.domains.items():
                self._add_handler(cd.clk, handler)
                if cd.rst is not None:
                    self._add_handler(cd.rst, handler)

        # Signals not driven by the design can be forced by processes.
        self._user_signals = self._signals - self._comb_signals - self._sync_signals
        return self

    def _commit_signal(self, signal):
        """Commit ``signal``'s next value; on a trigger's rising edge, clock its domain."""
        old, new = self._state.commit(signal)
        if (old, new) == (0, 1) and signal in self._domain_triggers:
            domain = self._domain_triggers[signal]
            # Iterate a snapshot: committing removes signals from next_dirty,
            # and mutating a set while iterating over it raises RuntimeError.
            for sync_signal in list(self._state.next_dirty):
                if sync_signal in self._domain_signals[domain]:
                    self._commit_signal(sync_signal)

            # Wake up every process waiting for a tick of this domain.
            for proc, wait_domain in list(self._wait_tick.items()):
                if domain == wait_domain:
                    del self._wait_tick[proc]
                    self._suspended.remove(proc)

        if self._vcd_writer is not None:
            for vcd_signal in self._vcd_signals[signal]:
                self._vcd_writer.change(vcd_signal, self._timestamp * 1e10, new)

    def _handle_event(self):
        """Run one delta cycle: re-run affected handlers, then commit comb/user signals."""
        handlers = set()
        while self._state.curr_dirty:
            signal = self._state.curr_dirty.pop()
            if signal in self._handlers:
                handlers.update(self._handlers[signal])

        for handler in handlers:
            handler(self._state)

        # Snapshot again: _commit_signal removes entries from next_dirty.
        # Sync signals stay pending until their domain's trigger rises.
        for signal in list(self._state.next_dirty):
            if signal in self._comb_signals or signal in self._user_signals:
                self._commit_signal(signal)

    def _force_signal(self, signal, value):
        """Set a process-driven signal to ``value`` immediately."""
        assert signal in self._user_signals
        self._state.set_next(signal, value)
        self._commit_signal(signal)

    def _run_process(self, proc):
        """Resume ``proc`` once and act on the command it yields."""
        try:
            stmt = proc.send(None)
        except StopIteration:
            self._processes.remove(proc)
            self._passive.discard(proc)
            return

        if isinstance(stmt, Delay):
            self._wait_deadline[proc] = self._timestamp + stmt.interval
            self._suspended.add(proc)
        elif isinstance(stmt, Tick):
            self._wait_tick[proc] = stmt.domain
            self._suspended.add(proc)
        elif isinstance(stmt, Passive):
            self._passive.add(proc)
        elif isinstance(stmt, Assign):
            assert isinstance(stmt.lhs, Signal)
            assert isinstance(stmt.rhs, Const)
            self._force_signal(stmt.lhs, normalize(stmt.rhs.value, stmt.lhs.shape()))
        else:
            raise TypeError("Received unsupported statement '{!r}' from process {}"
                            .format(stmt, proc))

    def step(self, run_passive=False):
        """Advance the simulation by one event; return False when nothing is left to do."""
        # Are there any delta cycles we should run?
        while self._state.curr_dirty:
            self._timestamp += 1e-10
            self._handle_event()

        # Are there any processes that haven't had a chance to run yet?
        if len(self._processes) > len(self._suspended):
            # Schedule an arbitrary one.
            proc = (self._processes - set(self._suspended)).pop()
            self._run_process(proc)
            return True

        # All processes are suspended. Are any of them active?
        if len(self._processes) > len(self._passive) or run_passive:
            # Are any of them suspended before a deadline?
            if self._wait_deadline:
                # Schedule the one with the lowest deadline.
                proc, deadline = min(self._wait_deadline.items(), key=lambda x: x[1])
                del self._wait_deadline[proc]
                self._suspended.remove(proc)
                self._timestamp = deadline
                self._run_process(proc)
                return True

        # No processes, or all processes are passive. Nothing to do!
        return False

    def run_until(self, deadline, run_passive=False):
        """Step until ``deadline``; return False if the simulation ran out of work first."""
        while self._timestamp < deadline:
            if not self.step(run_passive):
                return False
        return True

    def __exit__(self, *args):
        if self._vcd_writer is not None:
            self._vcd_writer.close(self._timestamp * 1e10)