1 from vcd
import VCDWriter
3 from ..tools
import flatten
4 from ..fhdl
.ast
import *
5 from ..fhdl
.xfrm
import ValueTransformer
, StatementTransformer
# Public names re-exported by this module. ``Tick`` sits alongside ``Delay``
# and ``Passive`` because processes yield it (add_sync_process does so
# automatically, and _run_process suspends the process until that clock domain
# is triggered), so user processes need it too.
__all__ = ["Simulator", "Delay", "Tick", "Passive"]
12 __slots__
= ("curr", "curr_dirty", "next", "next_dirty")
15 self
.curr
= ValueDict()
16 self
.next
= ValueDict()
17 self
.curr_dirty
= ValueSet()
18 self
.next_dirty
= ValueSet()
def get(self, signal):
    """Return the committed (current) value of *signal*.

    Raises ``KeyError`` (from the underlying mapping) for an unknown signal.
    """
    current = self.curr
    return current[signal]
def set_curr(self, signal, value):
    """Overwrite the current value of *signal*.

    When the value actually changes, *signal* is added to ``curr_dirty`` so the
    change can be reacted to. *value* must be a plain ``int``; that is checked
    only by ``assert`` and is therefore skipped under ``python -O``.
    """
    assert isinstance(value, int)
    current = self.curr
    if current[signal] == value:
        # Unchanged: leave the dirty set alone.
        return
    self.curr_dirty.add(signal)
    current[signal] = value
def set_next(self, signal, value):
    """Stage *value* as the pending (next) value of *signal*.

    When the staged value actually changes, *signal* is added to
    ``next_dirty`` so it is picked up when pending values are committed.
    *value* must be a plain ``int``; that is checked only by ``assert``.
    """
    assert isinstance(value, int)
    pending = self.next
    if pending[signal] == value:
        # Already staged with this value: nothing to record.
        return
    self.next_dirty.add(signal)
    pending[signal] = value
def commit(self, signal):
    """Promote the pending (next) value of *signal* to its current value.

    Returns a ``(old_value, new_value)`` pair holding the current value before
    and after the commit. If the pending value differs from the current one,
    *signal* is moved from ``next_dirty`` to ``curr_dirty`` (``remove`` is used,
    so it raises if *signal* is not in ``next_dirty``). Otherwise nothing is
    modified.
    """
    before = self.curr[signal]
    pending = self.next[signal]
    if before != pending:
        self.next_dirty.remove(signal)
        self.curr_dirty.add(signal)
        self.curr[signal] = pending
    after = self.curr[signal]
    return before, after
45 dirty
, self
.dirty
= self
.dirty
, ValueSet()
47 yield signal
, self
.curr
[signal
], self
.next
[signal
]
# Module-level shorthand for ``Const.normalize``. Throughout this module it is
# called as ``normalize(value, shape)`` on plain Python ints.
normalize = Const.normalize
53 class _RHSValueCompiler(ValueTransformer
):
def __init__(self, sensitivity):
    """Keep a reference to the set that signal reads get recorded into.

    The set is stored as-is (not copied), so the caller observes every signal
    that on_Signal() adds to it.
    """
    self.sensitivity = sensitivity
def on_Const(self, value):
    """Compile a constant into a closure that ignores the simulator state.

    ``value.value`` is read each time the closure runs, not captured up front.
    """
    def read_const(state):
        return value.value
    return read_const
def on_Signal(self, value):
    """Compile a signal read and record the signal in ``self.sensitivity``.

    The returned closure fetches the signal's value through ``state.get``
    at evaluation time.
    """
    self.sensitivity.add(value)

    def read_signal(state):
        return state.get(value)
    return read_signal
def on_ClockSignal(self, value):
    """Reject ``ClockSignal`` values, which this compiler does not handle.

    Raises:
        NotImplementedError: always. Like on_Operator's error, the message
            names what was not implemented, to ease diagnosis.
    """
    raise NotImplementedError("ClockSignal {!r} not implemented".format(value))
def on_ResetSignal(self, value):
    """Reject ``ResetSignal`` values, which this compiler does not handle.

    Raises:
        NotImplementedError: always. Like on_Operator's error, the message
            names what was not implemented, to ease diagnosis.
    """
    raise NotImplementedError("ResetSignal {!r} not implemented".format(value))
70 def on_Operator(self
, value
):
72 if len(value
.operands
) == 1:
73 arg
, = map(self
, value
.operands
)
75 return lambda state
: normalize(~
arg(state
), shape
)
77 return lambda state
: normalize(-arg(state
), shape
)
78 elif len(value
.operands
) == 2:
79 lhs
, rhs
= map(self
, value
.operands
)
81 return lambda state
: normalize(lhs(state
) + rhs(state
), shape
)
83 return lambda state
: normalize(lhs(state
) - rhs(state
), shape
)
85 return lambda state
: normalize(lhs(state
) & rhs(state
), shape
)
87 return lambda state
: normalize(lhs(state
) |
rhs(state
), shape
)
89 return lambda state
: normalize(lhs(state
) ^
rhs(state
), shape
)
91 return lambda state
: normalize(lhs(state
) == rhs(state
), shape
)
92 elif len(value
.operands
) == 3:
94 sel
, val1
, val0
= map(self
, value
.operands
)
95 return lambda state
: val1(state
) if sel(state
) else val0(state
)
96 raise NotImplementedError("Operator '{}' not implemented".format(value
.op
))
98 def on_Slice(self
, value
):
100 arg
= self(value
.value
)
102 mask
= (1 << (value
.end
- value
.start
)) - 1
103 return lambda state
: normalize((arg(state
) >> shift
) & mask
, shape
)
def on_Part(self, value):
    """Reject ``Part`` values, which this compiler does not handle yet.

    Raises:
        NotImplementedError: always. Like on_Operator's error, the message
            names what was not implemented, to ease diagnosis.
    """
    raise NotImplementedError("Part {!r} not implemented".format(value))
108 def on_Cat(self
, value
):
109 shape
= value
.shape()
112 for opnd
in value
.operands
:
113 parts
.append((offset
, (1 << len(opnd
)) - 1, self(opnd
)))
117 for offset
, mask
, opnd
in parts
:
118 result |
= (opnd(state
) & mask
) << offset
119 return normalize(result
, shape
)
122 def on_Repl(self
, value
):
123 shape
= value
.shape()
124 offset
= len(value
.value
)
125 mask
= (1 << len(value
.value
)) - 1
127 opnd
= self(value
.value
)
130 for _
in range(count
):
132 result |
= opnd(state
)
133 return normalize(result
, shape
)
137 class _StatementCompiler(StatementTransformer
):
139 self
.sensitivity
= ValueSet()
140 self
.rhs_compiler
= _RHSValueCompiler(self
.sensitivity
)
142 def lhs_compiler(self
, value
):
144 return lambda state
, arg
: state
.set_next(value
, arg
)
146 def on_Assign(self
, stmt
):
147 assert isinstance(stmt
.lhs
, Signal
)
148 shape
= stmt
.lhs
.shape()
149 lhs
= self
.lhs_compiler(stmt
.lhs
)
150 rhs
= self
.rhs_compiler(stmt
.rhs
)
152 lhs(state
, normalize(rhs(state
), shape
))
155 def on_Switch(self
, stmt
):
156 test
= self
.rhs_compiler(stmt
.test
)
158 for value
, stmts
in stmt
.cases
.items():
160 mask
= "".join("0" if b
== "-" else "1" for b
in value
)
161 value
= "".join("0" if b
== "-" else b
for b
in value
)
163 mask
= "1" * len(value
)
165 value
= int(value
, 2)
166 cases
.append((lambda test
: test
& mask
== value
,
167 self
.on_statements(stmts
)))
169 test_value
= test(state
)
170 for check
, body
in cases
:
171 if check(test_value
):
176 def on_statements(self
, stmts
):
177 stmts
= [self
.on_statement(stmt
) for stmt
in stmts
]
185 def __init__(self
, fragment
=None, vcd_file
=None):
186 self
._fragments
= {} # fragment -> hierarchy
187 self
._domains
= {} # str/domain -> ClockDomain
188 self
._domain
_triggers
= ValueDict() # Signal -> str/domain
189 self
._domain
_signals
= {} # str/domain -> {Signal}
190 self
._signals
= ValueSet() # {Signal}
191 self
._comb
_signals
= ValueSet() # {Signal}
192 self
._sync
_signals
= ValueSet() # {Signal}
193 self
._user
_signals
= ValueSet() # {Signal}
195 self
._started
= False
197 self
._state
= _State()
199 self
._processes
= set() # {process}
200 self
._passive
= set() # {process}
201 self
._suspended
= set() # {process}
202 self
._wait
_deadline
= {} # process -> float/timestamp
203 self
._wait
_tick
= {} # process -> str/domain
205 self
._handlers
= ValueDict() # Signal -> set(lambda)
207 self
._vcd
_file
= vcd_file
208 self
._vcd
_writer
= None
209 self
._vcd
_signals
= ValueDict() # signal -> set(vcd_signal)
211 if fragment
is not None:
212 fragment
= fragment
.prepare()
213 self
._add
_fragment
(fragment
)
214 self
._domains
= fragment
.domains
215 for domain
, cd
in self
._domains
.items():
216 self
._domain
_triggers
[cd
.clk
] = domain
217 if cd
.rst
is not None:
218 self
._domain
_triggers
[cd
.rst
] = domain
219 self
._domain
_signals
[domain
] = ValueSet()
def _add_fragment(self, fragment, hierarchy=("top",)):
    """Record *fragment* and, depth-first, every fragment nested under it.

    ``self._fragments`` maps each fragment to its hierarchy path. The root gets
    *hierarchy* itself. Each child gets its parent's path extended by the name
    it is listed under in ``subfragments``.
    """
    self._fragments[fragment] = hierarchy
    for child, child_name in fragment.subfragments:
        child_path = tuple(hierarchy) + (child_name,)
        self._add_fragment(child, child_path)
def add_process(self, process):
    """Register *process* so the scheduler will run it.

    Judging by ``_run_process``, *process* is expected to be a generator-like
    object driven with ``send(None)`` that yields commands such as ``Delay``,
    ``Tick``, ``Passive`` or ``Assign``.
    """
    registered = self._processes
    registered.add(process)
229 def add_clock(self
, domain
, period
):
230 clk
= self
._domains
[domain
].clk
231 half_period
= period
/ 2
236 yield Delay(half_period
)
238 yield Delay(half_period
)
239 self
.add_process(clk_process())
241 def add_sync_process(self
, process
, domain
="sync"):
244 result
= process
.send(None)
246 result
= process
.send((yield (result
or Tick(domain
))))
247 except StopIteration:
249 self
.add_process(sync_process())
251 def _signal_name_in_fragment(self
, fragment
, signal
):
252 for subfragment
, name
in fragment
.subfragments
:
253 if signal
in subfragment
.ports
:
254 return "{}_{}".format(name
, signal
.name
)
def _add_handler(self, signal, handler):
    """Subscribe *handler* to *signal*.

    ``self._handlers`` maps each signal to the set of handlers that are run
    when that signal shows up as changed. A signal's set is created on its
    first subscription.
    """
    handlers = self._handlers
    if signal in handlers:
        subscribers = handlers[signal]
    else:
        subscribers = handlers[signal] = set()
    subscribers.add(handler)
264 self
._vcd
_writer
= VCDWriter(self
._vcd
_file
, timescale
="100 ps",
265 comment
="Generated by nMigen")
267 for fragment
in self
._fragments
:
268 for signal
in fragment
.iter_signals():
269 self
._signals
.add(signal
)
271 self
._state
.curr
[signal
] = self
._state
.next
[signal
] = \
272 normalize(signal
.reset
, signal
.shape())
273 self
._state
.curr_dirty
.add(signal
)
275 if signal
not in self
._vcd
_signals
:
276 self
._vcd
_signals
[signal
] = set()
277 name
= self
._signal
_name
_in
_fragment
(fragment
, signal
)
284 name_suffix
= "{}${}".format(name
, suffix
)
285 self
._vcd
_signals
[signal
].add(self
._vcd
_writer
.register_var(
286 scope
=".".join(self
._fragments
[fragment
]), name
=name_suffix
,
287 var_type
="wire", size
=signal
.nbits
, init
=signal
.reset
))
290 suffix
= (suffix
or 0) + 1
292 for domain
, signals
in fragment
.drivers
.items():
294 self
._comb
_signals
.update(signals
)
296 self
._sync
_signals
.update(signals
)
297 self
._domain
_signals
[domain
].update(signals
)
299 compiler
= _StatementCompiler()
300 handler
= compiler(fragment
.statements
)
301 for signal
in compiler
.sensitivity
:
302 self
._add
_handler
(signal
, handler
)
303 for domain
, cd
in fragment
.domains
.items():
304 self
._add
_handler
(cd
.clk
, handler
)
305 if cd
.rst
is not None:
306 self
._add
_handler
(cd
.rst
, handler
)
308 self
._user
_signals
= self
._signals
- self
._comb
_signals
- self
._sync
_signals
310 def _commit_signal(self
, signal
):
311 old
, new
= self
._state
.commit(signal
)
312 if (old
, new
) == (0, 1) and signal
in self
._domain
_triggers
:
313 domain
= self
._domain
_triggers
[signal
]
314 for sync_signal
in self
._state
.next_dirty
:
315 if sync_signal
in self
._domain
_signals
[domain
]:
316 self
._commit
_signal
(sync_signal
)
318 for proc
, wait_domain
in list(self
._wait
_tick
.items()):
319 if domain
== wait_domain
:
320 del self
._wait
_tick
[proc
]
321 self
._suspended
.remove(proc
)
324 for vcd_signal
in self
._vcd
_signals
[signal
]:
325 self
._vcd
_writer
.change(vcd_signal
, self
._timestamp
* 1e10
, new
)
327 def _handle_event(self
):
329 while self
._state
.curr_dirty
:
330 signal
= self
._state
.curr_dirty
.pop()
331 if signal
in self
._handlers
:
332 handlers
.update(self
._handlers
[signal
])
334 for handler
in handlers
:
337 for signal
in self
._state
.next_dirty
:
338 if signal
in self
._comb
_signals
or signal
in self
._user
_signals
:
339 self
._commit
_signal
(signal
)
def _force_signal(self, signal, value):
    """Drive *signal* to *value* immediately on behalf of a process.

    Only signals that no fragment drives (``self._user_signals``) may be
    forced; that is checked by ``assert``. The value is staged with
    ``set_next`` and committed straight away through ``_commit_signal``.
    """
    assert signal in self._user_signals
    state = self._state
    state.set_next(signal, value)
    self._commit_signal(signal)
346 def _run_process(self
, proc
):
348 stmt
= proc
.send(None)
349 except StopIteration:
350 self
._processes
.remove(proc
)
351 self
._passive
.discard(proc
)
354 if isinstance(stmt
, Delay
):
355 self
._wait
_deadline
[proc
] = self
._timestamp
+ stmt
.interval
356 self
._suspended
.add(proc
)
357 elif isinstance(stmt
, Tick
):
358 self
._wait
_tick
[proc
] = stmt
.domain
359 self
._suspended
.add(proc
)
360 elif isinstance(stmt
, Passive
):
361 self
._passive
.add(proc
)
362 elif isinstance(stmt
, Assign
):
363 assert isinstance(stmt
.lhs
, Signal
)
364 assert isinstance(stmt
.rhs
, Const
)
365 self
._force
_signal
(stmt
.lhs
, normalize(stmt
.rhs
.value
, stmt
.lhs
.shape()))
367 raise TypeError("Received unsupported statement '{!r}' from process {}"
370 def step(self
, run_passive
=False):
371 # Are there any delta cycles we should run?
372 while self
._state
.curr_dirty
:
373 self
._timestamp
+= 1e-10
376 # Are there any processes that haven't had a chance to run yet?
377 if len(self
._processes
) > len(self
._suspended
):
378 # Schedule an arbitrary one.
379 proc
= (self
._processes
- set(self
._suspended
)).pop()
380 self
._run
_process
(proc
)
383 # All processes are suspended. Are any of them active?
384 if len(self
._processes
) > len(self
._passive
) or run_passive
:
385 # Are any of them suspended before a deadline?
386 if self
._wait
_deadline
:
387 # Schedule the one with the lowest deadline.
388 proc
, deadline
= min(self
._wait
_deadline
.items(), key
=lambda x
: x
[1])
389 del self
._wait
_deadline
[proc
]
390 self
._suspended
.remove(proc
)
391 self
._timestamp
= deadline
392 self
._run
_process
(proc
)
395 # No processes, or all processes are passive. Nothing to do!
398 def run_until(self
, deadline
, run_passive
=False):
399 while self
._timestamp
< deadline
:
400 if not self
.step(run_passive
):
404 def __exit__(self
, *args
):
406 self
._vcd
_writer
.close(self
._timestamp
* 1e10
)