bump to 0.2.dev
[litex.git] / litex / gen / fhdl / tools.py
1 from litex.gen.fhdl.structure import *
2 from litex.gen.fhdl.structure import _Slice, _Assign, _Fragment
3 from litex.gen.fhdl.visit import NodeVisitor, NodeTransformer
4 from litex.gen.fhdl.bitcontainer import value_bits_sign
5 from litex.gen.util.misc import flat_iteration
6
7
class _SignalLister(NodeVisitor):
    """Visitor that gathers every Signal reachable from the visited node.

    The result accumulates in ``output_list`` (a set, so duplicates collapse).
    """

    def __init__(self):
        self.output_list = set()

    def visit_Signal(self, node):
        found = self.output_list
        found.add(node)
14
15
class _TargetLister(NodeVisitor):
    """Visitor that gathers the Signals used as assignment targets.

    A Signal is recorded only while ``target_context`` is set, which happens
    exclusively while the left-hand side of an ``_Assign`` is being walked.
    The result accumulates in ``output_list``.
    """

    def __init__(self):
        self.output_list = set()
        self.target_context = False

    def visit_Signal(self, node):
        if not self.target_context:
            return
        self.output_list.add(node)

    def visit_Assign(self, node):
        # Targets can only occur on the left-hand side; the right-hand side
        # is deliberately never visited.
        self.target_context = True
        self.visit(node.l)
        self.target_context = False

    def visit_ArrayProxy(self, node):
        # Every choice may be written, so each one is walked; the key only
        # selects among them and is therefore skipped.
        for candidate in node.choices:
            self.visit(candidate)
33
34
class _InputLister(NodeVisitor):
    """Visitor that gathers the Signals read by the visited node.

    For an ``_Assign`` only the right-hand side is walked, so Signals that
    occur solely inside an assignment's left-hand side are not listed.
    NOTE(review): this also skips the key of an ArrayProxy used as a target;
    presumably harmless once ArrayProxies are lowered to Case — confirm.
    """

    def __init__(self):
        self.output_list = set()

    def visit_Signal(self, node):
        found = self.output_list
        found.add(node)

    def visit_Assign(self, node):
        source = node.r
        self.visit(source)
44
45
def list_signals(node):
    """Return the set of every Signal referenced anywhere in ``node``."""
    visitor = _SignalLister()
    visitor.visit(node)
    signals = visitor.output_list
    return signals
50
51
def list_targets(node):
    """Return the set of Signals assigned to (left-hand sides) within ``node``."""
    visitor = _TargetLister()
    visitor.visit(node)
    targets = visitor.output_list
    return targets
56
57
def list_inputs(node):
    """Return the set of Signals read within ``node`` (assignment RHS only)."""
    visitor = _InputLister()
    visitor.visit(node)
    inputs = visitor.output_list
    return inputs
62
63
64 def _resort_statements(ol):
65 return [statement for i, statement in
66 sorted(ol, key=lambda x: x[0])]
67
68
def group_by_targets(sl):
    """Partition statements into groups whose target sets are disjoint.

    Statements are taken from ``flat_iteration(sl)``. Whenever a statement
    assigns to a Signal already claimed by existing groups, those groups are
    merged with it.

    Returns:
        A list of ``(targets, statements)`` pairs; within each group the
        statements are restored to their original order.
    """
    groups = []
    seen = set()
    for order, stmt in enumerate(flat_iteration(sl)):
        targets = set(list_targets(stmt))
        group = [(order, stmt)]
        overlaps = not targets.isdisjoint(seen)
        seen |= targets
        if overlaps:
            # Fold every group that shares a target into the new one; the
            # merged target set grows as groups are absorbed.
            survivors = []
            for old_targets, old_group in groups:
                if targets.isdisjoint(old_targets):
                    survivors.append((old_targets, old_group))
                else:
                    targets |= old_targets
                    group += old_group
            groups = survivors
        groups.append((targets, group))
    return [(group_targets, _resort_statements(members))
            for group_targets, members in groups]
88
89
def list_special_ios(f, ins, outs, inouts):
    """Union the I/O sets reported by ``list_ios`` of every special in ``f``."""
    collected = set()
    for item in f.specials:
        collected |= item.list_ios(ins, outs, inouts)
    return collected
95
96
class _ClockDomainLister(NodeVisitor):
    """Visitor that gathers the names of clock domains a node refers to.

    Names come from ClockSignal/ResetSignal expressions and from the keys of
    any ``clock_domains`` mapping (e.g. a fragment's sync dict) it walks.
    """

    def __init__(self):
        self.clock_domains = set()

    def visit_ClockSignal(self, node):
        self.clock_domains.add(node.cd)

    # A ResetSignal names its domain exactly like a ClockSignal does.
    visit_ResetSignal = visit_ClockSignal

    def visit_clock_domains(self, node):
        for domain_name, domain_statements in node.items():
            self.clock_domains.add(domain_name)
            self.visit(domain_statements)
111
112
def list_clock_domains_expr(f):
    """Return the clock-domain names referenced by expressions/sync keys of ``f``."""
    lister = _ClockDomainLister()
    lister.visit(f)
    domains = lister.clock_domains
    return domains
117
118
def list_clock_domains(f):
    """Return every clock-domain name used by fragment ``f``.

    Combines names found in ``f``'s expressions, those reported by each
    special's ``list_clock_domains()``, and the names of ``f.clock_domains``.
    """
    domains = list_clock_domains_expr(f)
    for special in f.specials:
        domains |= special.list_clock_domains()
    domains.update(domain.name for domain in f.clock_domains)
    return domains
126
127
def is_variable(node):
    """Return the ``variable`` flag of the Signal(s) underlying ``node``.

    Slices defer to the value they slice. For a Cat, every element must agree
    and that common value is returned.

    Raises:
        TypeError: if ``node`` is not a Signal, a _Slice, or a Cat; if a Cat is
            empty; or if a Cat mixes elements that disagree.
    """
    if isinstance(node, Signal):
        return node.variable
    elif isinstance(node, _Slice):
        return is_variable(node.value)
    elif isinstance(node, Cat):
        # Evaluate every element first so nested errors surface as before.
        arevars = [is_variable(element) for element in node.l]
        if not arevars:
            # Previously this crashed with an IndexError on arevars[0].
            raise TypeError("Cannot determine whether an empty Cat is variable")
        first = arevars[0]
        if any(x != first for x in arevars):
            raise TypeError("Cat mixes variable and non-variable elements")
        return first
    else:
        raise TypeError("Cannot determine whether {} is variable"
                        .format(type(node).__name__))
142
143
def generate_reset(rst, sl):
    """Build ``t.eq(t.reset)`` for every resettable target assigned in ``sl``.

    Targets flagged ``reset_less`` are skipped; the rest are ordered by
    ``duid`` so the output is deterministic. ``rst`` is not used here.
    """
    resettable = [t for t in list_targets(sl) if not t.reset_less]
    resettable.sort(key=lambda t: t.duid)
    return [t.eq(t.reset) for t in resettable]
148
149
def insert_reset(rst, sl):
    """Return a new list: ``sl`` followed by ``If(rst, <reset assignments>)``."""
    reset_branch = If(rst, *generate_reset(rst, sl))
    return sl + [reset_branch]
152
153
def insert_resets(f):
    """Add synchronous reset logic to each sync domain of ``f`` that has a reset.

    Domains whose clock domain has ``rst`` set to None keep their statements
    unchanged. ``f.sync`` is replaced by a freshly built dict.
    """
    f.sync = {
        domain: (insert_reset(ResetSignal(domain), statements)
                 if f.clock_domains[domain].rst is not None
                 else statements)
        for domain, statements in f.sync.items()
    }
162
163
class _Lowerer(NodeTransformer):
    """Base NodeTransformer shared by the lowering passes.

    Attributes:
        target_context: True while the left-hand side of an assignment is
            being transformed.
        extra_stmts: statements to place right after the assignment currently
            being lowered.
        comb: combinatorial statements produced by subclasses, left for the
            caller to add to the fragment.
    """

    def __init__(self):
        self.target_context = False
        self.extra_stmts = []
        self.comb = []

    def visit_Assign(self, node):
        # Subclass hooks call visit_Assign while an outer assignment is still
        # being lowered, so the caller's state is saved and restored.
        saved_context = self.target_context
        saved_extra = self.extra_stmts
        self.extra_stmts = []

        self.target_context = True
        new_lhs = self.visit(node.l)
        self.target_context = False
        new_rhs = self.visit(node.r)

        lowered = _Assign(new_lhs, new_rhs)
        if self.extra_stmts:
            # Follow-up statements go after the assignment itself.
            lowered = [lowered] + self.extra_stmts

        self.target_context = saved_context
        self.extra_stmts = saved_extra
        return lowered
184
185
# Basics are FHDL structure elements that back-ends are not required to support
# but can be expressed in terms of other elements (lowered) before conversion.
class _BasicLowerer(_Lowerer):
    """Lower ArrayProxy, ClockSignal and ResetSignal into plainer FHDL.

    ``clock_domains`` is indexed by domain name and each entry exposes ``clk``
    and ``rst`` (``rst`` may be None for a domain without reset).
    """

    def __init__(self, clock_domains):
        self.clock_domains = clock_domains
        _Lowerer.__init__(self)

    def visit_ArrayProxy(self, node):
        # TODO: rewrite without variables
        array_muxed = Signal(value_bits_sign(node), variable=True)
        if self.target_context:
            # Written element: after the enclosing assignment, copy the mux
            # signal into whichever choice the key selects.
            key = self.visit(node.key)
            cases = {index: [self.visit_Assign(_Assign(choice, array_muxed))]
                     for index, choice in enumerate(node.choices)}
            self.extra_stmts.append(Case(key, cases).makedefault())
        else:
            # Read element: drive the mux signal from the selected choice.
            cases = {index: _Assign(array_muxed, self.visit(choice))
                     for index, choice in enumerate(node.choices)}
            self.comb.append(Case(self.visit(node.key), cases).makedefault())
        return array_muxed

    def visit_ClockSignal(self, node):
        return self.clock_domains[node.cd].clk

    def visit_ResetSignal(self, node):
        rst = self.clock_domains[node.cd].rst
        if rst is not None:
            return rst
        if node.allow_reset_less:
            return 0
        raise ValueError("Attempted to get reset signal of resetless"
                         " domain '{}'".format(node.cd))
221
222
class _ComplexSliceLowerer(_Lowerer):
    """Rewrite slices of non-Signal values as slices of a proxy Signal."""

    def visit_Slice(self, node):
        if isinstance(node.value, Signal):
            return NodeTransformer.visit_Slice(self, node)
        proxy = Signal(value_bits_sign(node.value))
        # Link the proxy and the sliced expression: in a target the
        # expression is driven from the proxy, otherwise the reverse.
        # NOTE(review): the link always goes to comb, even when the slice is
        # a target inside a sync block — confirm this is intended.
        if self.target_context:
            link = _Assign(node.value, proxy)
        else:
            link = _Assign(proxy, node.value)
        self.comb.append(self.visit_Assign(link))
        return NodeTransformer.visit_Slice(
            self, _Slice(proxy, node.start, node.stop))
234
235
def _apply_lowerer(l, f):
    """Run lowerer ``l`` over fragment ``f`` and over its specials' expressions.

    Statements the lowerer emits are appended to ``f.comb``. Specials are
    processed in ``duid`` order for deterministic output. Returns the
    transformed fragment.
    """
    f = l.visit(f)
    f.comb += l.comb

    for special in sorted(f.specials, key=lambda s: s.duid):
        for obj, attr, direction in special.iter_expressions():
            if direction == SPECIAL_INOUT:
                # inouts are only supported by Migen when connected directly to top-level
                # in this case, they are Signal and never need lowering
                continue
            # Fresh lowerer state for each expression.
            l.comb = []
            l.target_context = direction != SPECIAL_INPUT
            l.extra_stmts = []
            setattr(obj, attr, l.visit(getattr(obj, attr)))
            f.comb += l.comb + l.extra_stmts

    return f
254
255
def lower_basics(f):
    """Lower ArrayProxy/ClockSignal/ResetSignal in ``f`` using its clock domains."""
    lowerer = _BasicLowerer(f.clock_domains)
    return _apply_lowerer(lowerer, f)
258
259
def lower_complex_slices(f):
    """Replace slices of non-Signal values in ``f`` with slices of proxy Signals."""
    lowerer = _ComplexSliceLowerer()
    return _apply_lowerer(lowerer, f)
262
263
class _ClockDomainRenamer(NodeVisitor):
    """Visitor that retargets ClockSignal/ResetSignal from domain ``old`` to ``new``.

    Nodes are modified in place.
    """

    def __init__(self, old, new):
        self.old = old
        self.new = new

    def _retarget(self, node):
        if node.cd == self.old:
            node.cd = self.new

    def visit_ClockSignal(self, node):
        self._retarget(node)

    def visit_ResetSignal(self, node):
        self._retarget(node)
276
277
def rename_clock_domain_expr(f, old, new):
    """Rename domain ``old`` to ``new`` in the Clock/ResetSignals of ``f`` (in place)."""
    renamer = _ClockDomainRenamer(old, new)
    renamer.visit(f)
281
282
def rename_clock_domain(f, old, new):
    """Rename clock domain ``old`` to ``new`` throughout fragment ``f`` in place.

    Updates expressions, moves (or merges) the sync statements of ``old``
    under ``new``, notifies every special, and renames the matching entry of
    ``f.clock_domains`` when one exists.
    """
    rename_clock_domain_expr(f, old, new)
    if new != old and old in f.sync:
        moved = f.sync[old]
        del f.sync[old]
        if new in f.sync:
            f.sync[new].extend(moved)
        else:
            f.sync[new] = moved
    for special in f.specials:
        special.rename_clock_domain(old, new)
    try:
        domain = f.clock_domains[old]
    except KeyError:
        return
    domain.rename(new)
300
301
def call_special_classmethod(overrides, obj, method, *args, **kwargs):
    """Call ``method`` on ``obj`` through its class, honouring ``overrides``.

    If ``obj``'s class is a key of ``overrides``, the mapped class is used
    instead. The method is fetched from that class and called with ``obj`` as
    first argument. Returns None when the class has no such attribute.
    """
    klass = obj.__class__
    if klass in overrides:
        klass = overrides[klass]
    if not hasattr(klass, method):
        return None
    return getattr(klass, method)(obj, *args, **kwargs)
310
311
def _lower_specials_step(overrides, specials):
    """Perform one lowering round over ``specials`` (in ``duid`` order).

    Returns ``(fragment, lowered)``: the sum of the fragments of every
    implementation returned by ``lower``, and the set of specials whose
    ``lower`` returned something other than None.
    """
    fragment = _Fragment()
    lowered = set()
    for special in sorted(specials, key=lambda s: s.duid):
        impl = call_special_classmethod(overrides, special, "lower")
        if impl is None:
            continue
        fragment += impl.get_fragment()
        lowered.add(special)
    return fragment, lowered
321
322
323 def _can_lower(overrides, specials):
324 for special in specials:
325 cl = special.__class__
326 if cl in overrides:
327 cl = overrides[cl]
328 if hasattr(cl, "lower"):
329 return True
330 return False
331
332
def lower_specials(overrides, specials):
    """Lower specials to fragments, repeating for specials produced by lowering.

    Args:
        overrides: mapping from a special's class to a replacement class whose
            ``lower`` is used instead.
        specials: the initial collection of specials.

    Returns:
        ``(f, lowered_specials)``. ``f`` is the fragment accumulated over all
        rounds; its ``specials`` keep any special produced by lowering that
        was not itself lowered. ``lowered_specials`` is the set of specials
        lowered in any round.
    """
    f, lowered_specials = _lower_specials_step(overrides, specials)
    while _can_lower(overrides, f.specials):
        f2, lowered_specials2 = _lower_specials_step(overrides, f.specials)
        if not lowered_specials2:
            # _can_lower only checks that a ``lower`` attribute exists, while a
            # ``lower`` returning None leaves the special untouched. Without
            # this guard such a special made the loop spin forever.
            break
        f += f2
        lowered_specials |= lowered_specials2
        f.specials -= lowered_specials2
    return f, lowered_specials