f89cd02c32cd978b6b6bdee6c4bf1484d9a3af97
[litex.git] / migen / sim / core.py
import collections
import collections.abc
import inspect
import operator

from migen.fhdl.structure import *
from migen.fhdl.structure import (_Value, _Statement,
                                  _Operator, _Slice, _ArrayProxy,
                                  _Assign, _Fragment)
from migen.fhdl.bitcontainer import value_bits_sign
from migen.fhdl.tools import list_signals, list_targets, insert_resets
from migen.fhdl.simplify import MemoryToArray
from migen.fhdl.specials import _MemoryLocation
from migen.sim.vcd import VCDWriter, DummyVCDWriter
14
15
class ClockState:
    """Mutable record of one simulated clock's state.

    The three attributes are stored exactly as given:
      high              -- current clock level (truthy means high)
      half_period       -- time units between consecutive transitions
      time_before_trans -- time units remaining until the next transition
    """

    def __init__(self, high, half_period, time_before_trans):
        self.high, self.half_period, self.time_before_trans = (
            high, half_period, time_before_trans)
21
22
class TimeManager:
    """Advance simulated time from one clock transition to the next.

    ``description`` maps a clock name to either a period or a
    ``(period, phase)`` tuple. With phase 0 a clock starts low and first
    rises after half a period; a phase of ``p`` shifts the waveform so that
    time 0 corresponds to ``p`` time units into the cycle.
    """

    def __init__(self, description):
        self.clocks = dict()

        for k, period_phase in description.items():
            if isinstance(period_phase, tuple):
                period, phase = period_phase
            else:
                period = period_phase
                phase = 0
            half_period = period//2
            if half_period < 1:
                # A zero (or negative) half-period would make tick() return
                # dt <= 0 forever, so simulated time would never advance.
                raise ValueError("Clock '{}' must have a period of at least"
                                 " 2, got {}".format(k, period))
            # Fold the phase into one cycle of the generated waveform
            # (2*half_period, which is period rounded down to even) so the
            # first transition is always strictly in the future.
            phase %= 2*half_period
            if phase >= half_period:
                phase -= half_period
                high = True
            else:
                high = False
            self.clocks[k] = ClockState(high, half_period, half_period - phase)

    def tick(self):
        """Jump to the earliest pending clock transition.

        Returns ``(dt, rising, falling)``: the time elapsed since the previous
        call, and the sets of clock names that went high and low at the new
        time. Clocks whose transitions coincide toggle in the same tick.
        Raises ValueError (from ``min``) if there are no clocks.
        """
        rising = set()
        falling = set()
        dt = min(cs.time_before_trans for cs in self.clocks.values())
        for k, cs in self.clocks.items():
            if cs.time_before_trans == dt:
                cs.high = not cs.high
                if cs.high:
                    rising.add(k)
                else:
                    falling.add(k)
            cs.time_before_trans -= dt
            if not cs.time_before_trans:
                # The transition just happened; schedule the next one.
                cs.time_before_trans += cs.half_period
        return dt, rising, falling
56
57
# Python implementations of FHDL operators, keyed by the operator string
# stored in ``_Operator.op``. Evaluator.eval special-cases "-" (unary or
# binary) and "m" (mux) before consulting this table; the "-" entry here is
# plain binary subtraction.
str2op = {
    # bitwise not and arithmetic
    "~": operator.invert,
    "+": operator.add,
    "-": operator.sub,
    "*": operator.mul,
    # shifts
    ">>>": operator.rshift,
    "<<<": operator.lshift,
    # bitwise logic
    "&": operator.and_,
    "^": operator.xor,
    "|": operator.or_,
    # comparisons
    "<": operator.lt,
    "<=": operator.le,
    "==": operator.eq,
    "!=": operator.ne,
    ">": operator.gt,
    ">=": operator.ge,
}
78
79
80 def _truncate(value, nbits, signed):
81 value = value & (2**nbits - 1)
82 if signed and (value & 2**(nbits - 1)):
83 value -= 2**nbits
84 return value
85
86
class Evaluator:
    """Evaluate FHDL expressions and execute FHDL statements.

    Ordinary reads see committed values (``signal_values``). Assignments are
    staged in ``modifications`` and become visible to ordinary reads only
    after ``commit()``. A signal that has never been committed reads as its
    reset value.
    """

    def __init__(self, clock_domains, replaced_memories):
        # clock_domains: indexed by domain name to resolve ClockSignal and
        # ResetSignal. replaced_memories: maps a memory to the indexable
        # sequence of signals that stands in for its contents.
        self.clock_domains = clock_domains
        self.replaced_memories = replaced_memories
        self.signal_values = dict()
        self.modifications = dict()

    def commit(self):
        """Make staged modifications current and clear the staging area.

        Returns the set of signals whose committed value changed, including
        signals committed for the first time.
        """
        changed = set()
        for signal, value in self.modifications.items():
            if (signal not in self.signal_values
                    or self.signal_values[signal] != value):
                self.signal_values[signal] = value
                changed.add(signal)
        self.modifications.clear()
        return changed

    def eval(self, node, postcommit=False):
        """Return the integer value of the expression ``node``.

        With ``postcommit`` true, a signal with a staged (uncommitted)
        modification reads that staged value instead of its committed one.
        Results are not generally masked to the expression's width (e.g.
        ``~`` can produce a negative number); callers that need an exact
        width mask or truncate themselves.

        Raises ValueError when reading the reset of a resetless domain that
        does not allow it, and NotImplementedError for unsupported nodes.
        """
        if isinstance(node, Constant):
            return node.value
        elif isinstance(node, Signal):
            if postcommit:
                try:
                    return self.modifications[node]
                except KeyError:
                    pass
            try:
                return self.signal_values[node]
            except KeyError:
                return node.reset.value
        elif isinstance(node, _Operator):
            operands = [self.eval(o, postcommit) for o in node.operands]
            if node.op == "-":
                # "-" is both unary negation and binary subtraction.
                if len(operands) == 1:
                    return -operands[0]
                else:
                    return operands[0] - operands[1]
            elif node.op == "m":
                # Multiplexer: operands are (select, if_true, if_false).
                return operands[1] if operands[0] else operands[2]
            else:
                return str2op[node.op](*operands)
        elif isinstance(node, _Slice):
            v = self.eval(node.value, postcommit)
            # Shift the selected bits down to bit 0 and mask them. Python ints
            # act as infinite two's complement, so this also extracts the bits
            # of negative values; an empty range yields 0.
            width = max(node.stop - node.start, 0)
            return (v >> node.start) & ((1 << width) - 1)
        elif isinstance(node, Cat):
            shift = 0
            r = 0
            for element in node.l:
                nbits = len(element)
                # make value always positive
                r |= (self.eval(element, postcommit) & (2**nbits-1)) << shift
                shift += nbits
            return r
        elif isinstance(node, _ArrayProxy):
            return self.eval(node.choices[self.eval(node.key, postcommit)],
                             postcommit)
        elif isinstance(node, _MemoryLocation):
            array = self.replaced_memories[node.memory]
            return self.eval(array[self.eval(node.index, postcommit)], postcommit)
        elif isinstance(node, ClockSignal):
            return self.eval(self.clock_domains[node.cd].clk, postcommit)
        elif isinstance(node, ResetSignal):
            rst = self.clock_domains[node.cd].rst
            if rst is None:
                if node.allow_reset_less:
                    return 0
                else:
                    raise ValueError("Attempted to get reset signal of resetless"
                                     " domain '{}'".format(node.cd))
            else:
                return self.eval(rst, postcommit)
        else:
            raise NotImplementedError("Cannot evaluate node of type {}"
                                      .format(type(node).__name__))

    def assign(self, node, value):
        """Stage ``value`` for the assignment target ``node``.

        Signal targets are truncated to their width and signedness; Cat,
        slice, array and memory targets are resolved down to the underlying
        signals. The new values take effect on the next ``commit()``.
        """
        if isinstance(node, Signal):
            assert not node.variable
            self.modifications[node] = _truncate(value,
                                                 node.nbits, node.signed)
        elif isinstance(node, Cat):
            # Distribute the value over the elements, least significant first.
            for element in node.l:
                nbits = len(element)
                self.assign(element, value & (2**nbits-1))
                value >>= nbits
        elif isinstance(node, _Slice):
            # Read-modify-write starting from any value already staged, so
            # several slice assignments to one signal in a step combine.
            full_value = self.eval(node.value, True)
            # clear bits assigned to by the slice
            full_value &= ~((2**node.stop-1) - (2**node.start-1))
            # set them to the new value
            value &= 2**(node.stop - node.start)-1
            full_value |= value << node.start
            self.assign(node.value, full_value)
        elif isinstance(node, _ArrayProxy):
            self.assign(node.choices[self.eval(node.key)], value)
        elif isinstance(node, _MemoryLocation):
            array = self.replaced_memories[node.memory]
            self.assign(array[self.eval(node.index)], value)
        else:
            raise NotImplementedError("Cannot assign to node of type {}"
                                      .format(type(node).__name__))

    def execute(self, statements):
        """Execute an iterable of statements, staging their assignments.

        Right-hand sides, If conditions and Case tests read committed values
        only, so statements in one call do not observe each other's
        assignments until ``commit()``. Nested statement lists are executed
        recursively.
        """
        for s in statements:
            if isinstance(s, _Assign):
                self.assign(s.l, self.eval(s.r))
            elif isinstance(s, If):
                # Mask to the condition's width: eval may return e.g. a
                # negative number for ~x.
                if self.eval(s.cond) & (2**len(s.cond) - 1):
                    self.execute(s.t)
                else:
                    self.execute(s.f)
            elif isinstance(s, Case):
                nbits, signed = value_bits_sign(s.test)
                test = _truncate(self.eval(s.test), nbits, signed)
                for k, v in s.cases.items():
                    if isinstance(k, Constant) and k.value == test:
                        self.execute(v)
                        break
                else:
                    if "default" in s.cases:
                        self.execute(s.cases["default"])
            elif isinstance(s, collections.abc.Iterable):
                # collections.Iterable no longer exists on Python 3.10+.
                self.execute(s)
            else:
                raise NotImplementedError("Cannot execute statement of type {}"
                                          .format(type(s).__name__))
211
212
# TODO: instances via Iverilog/VPI
class Simulator:
    """Simulate a migen fragment (or module) driven by Python generators.

    ``generators`` is either a dict mapping clock domain names to a generator
    or an iterable of generators, or such a value on its own, which is then
    bound to the "sys" domain. Each generator is resumed on every rising edge
    of its domain. It may yield FHDL statements (executed), FHDL values
    (evaluated, with the result sent back), or nested lists of these.
    Yielding None ends its turn for the current cycle.

    ``clocks`` maps domain names to a period or a ``(period, phase)`` tuple
    and defaults to ``{"sys": 10}``. Domains listed in ``clocks`` but absent
    from the fragment are added as reset-less domains. When ``vcd_name`` is
    given, signal changes are recorded through ``VCDWriter(vcd_name, ...)``.

    Raises ValueError if a non-empty set of generators is bound to a domain
    that has no entry in ``clocks``, since those generators could never run.
    """

    def __init__(self, fragment_or_module, generators, clocks=None,
                 vcd_name=None):
        # None sentinel instead of a shared mutable default dict.
        if clocks is None:
            clocks = {"sys": 10}

        if isinstance(fragment_or_module, _Fragment):
            self.fragment = fragment_or_module
        else:
            self.fragment = fragment_or_module.get_fragment()

        if not isinstance(generators, dict):
            generators = {"sys": generators}
        self.generators = dict()
        for k, v in generators.items():
            if (isinstance(v, collections.abc.Iterable)
                    and not inspect.isgenerator(v)):
                self.generators[k] = list(v)
            else:
                self.generators[k] = [v]
            # Generators only run on rising edges of clocks in ``clocks``.
            # Any bound elsewhere would never run, and run() would then loop
            # forever waiting for them to be exhausted.
            if self.generators[k] and k not in clocks:
                raise ValueError("Generators are bound to clock domain '{}',"
                                 " which has no entry in clocks".format(k))

        self.time = TimeManager(clocks)
        for clock in clocks.keys():
            if clock not in self.fragment.clock_domains:
                cd = ClockDomain(name=clock, reset_less=True)
                cd.clk.reset = C(self.time.clocks[clock].high)
                self.fragment.clock_domains.append(cd)

        mta = MemoryToArray()
        mta.transform_fragment(None, self.fragment)
        insert_resets(self.fragment)
        # comb signals return to their reset value if nothing assigns them
        self.fragment.comb[0:0] = [s.eq(s.reset)
                                   for s in list_targets(self.fragment.comb)]
        self.evaluator = Evaluator(self.fragment.clock_domains,
                                   mta.replacements)

        if vcd_name is None:
            self.vcd = DummyVCDWriter()
        else:
            signals = list_signals(self.fragment)
            for cd in self.fragment.clock_domains:
                signals.add(cd.clk)
                if cd.rst is not None:
                    signals.add(cd.rst)
            for memory_array in mta.replacements.values():
                signals |= set(memory_array)
            signals = sorted(signals, key=lambda x: x.duid)
            self.vcd = VCDWriter(vcd_name, signals)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()

    def close(self):
        """Close the VCD writer."""
        self.vcd.close()

    def _commit_and_comb_propagate(self):
        """Commit staged values and settle the combinatorial logic.

        Re-executes the comb statements until a commit changes nothing, then
        records every signal that changed along the way in the VCD. Does not
        terminate if the comb logic never settles.
        """
        # TODO: optimize
        all_modified = set()
        modified = self.evaluator.commit()
        all_modified |= modified
        while modified:
            self.evaluator.execute(self.fragment.comb)
            modified = self.evaluator.commit()
            all_modified |= modified
        for signal in all_modified:
            self.vcd.set(signal, self.evaluator.signal_values[signal])

    def _evalexec_nested_lists(self, x):
        """Serve one generator request.

        Values are evaluated, statements executed (returning None), and lists
        are handled element-wise, preserving their nesting.
        """
        if isinstance(x, list):
            return [self._evalexec_nested_lists(e) for e in x]
        elif isinstance(x, _Value):
            return self.evaluator.eval(x)
        elif isinstance(x, _Statement):
            self.evaluator.execute([x])
            return None
        else:
            raise ValueError("Unsupported generator request: {!r}".format(x))

    def _process_generators(self, cd):
        """Run every generator of domain ``cd`` until it yields None.

        Exhausted generators are removed from the domain's list.
        """
        exhausted = []
        for generator in self.generators[cd]:
            reply = None
            while True:
                try:
                    request = generator.send(reply)
                    if request is None:
                        break  # next cycle
                    else:
                        reply = self._evalexec_nested_lists(request)
                except StopIteration:
                    exhausted.append(generator)
                    break
        for generator in exhausted:
            self.generators[cd].remove(generator)

    def _continue_simulation(self):
        """Return True while any domain still has a live generator."""
        # TODO: passive generators
        return any(self.generators.values())

    def run(self):
        """Simulate until every generator has been exhausted."""
        self.evaluator.execute(self.fragment.comb)
        self._commit_and_comb_propagate()

        while True:
            dt, rising, falling = self.time.tick()
            self.vcd.delay(dt)
            for cd in rising:
                self.evaluator.assign(self.fragment.clock_domains[cd].clk, 1)
                if cd in self.fragment.sync:
                    self.evaluator.execute(self.fragment.sync[cd])
                if cd in self.generators:
                    self._process_generators(cd)
            for cd in falling:
                self.evaluator.assign(self.fragment.clock_domains[cd].clk, 0)
            self._commit_and_comb_propagate()

            if not self._continue_simulation():
                break
331
332
def run_simulation(*args, **kwargs):
    """Build a Simulator from the given arguments, run it, then close it.

    All positional and keyword arguments are forwarded unchanged to
    ``Simulator``. The simulator is closed even if ``run()`` raises.
    """
    simulator = Simulator(*args, **kwargs)
    with simulator:
        simulator.run()