litex: reorganize things, first working version
[litex.git] / litex / gen / fhdl / tools.py
1 from litex.gen.fhdl.structure import *
2 from litex.gen.fhdl.structure import _Slice, _Assign
3 from litex.gen.fhdl.visit import NodeVisitor, NodeTransformer
4 from litex.gen.fhdl.bitcontainer import value_bits_sign
5 from litex.gen.util.misc import flat_iteration
6
7
class _SignalLister(NodeVisitor):
    """Collect every Signal reached while visiting a node.

    Results accumulate in ``output_list`` (a set, despite the name).
    """

    def __init__(self):
        self.output_list = set()

    def visit_Signal(self, node):
        collected = self.output_list
        collected.add(node)
14
15
class _TargetLister(NodeVisitor):
    """Collect the signals that assignments write to.

    Only the left-hand side of each assignment is traversed, and a Signal is
    recorded only while ``target_context`` is set (i.e. during that
    left-hand-side traversal). Results accumulate in the set ``output_list``.
    """

    def __init__(self):
        self.output_list = set()
        self.target_context = False

    def visit_Signal(self, node):
        if not self.target_context:
            return
        self.output_list.add(node)

    def visit_Assign(self, node):
        # The right-hand side is never traversed: nothing there is a target.
        self.target_context = True
        self.visit(node.l)
        self.target_context = False

    def visit_ArrayProxy(self, node):
        # Any choice may be the one written; the key is not traversed.
        for candidate in node.choices:
            self.visit(candidate)
33
34
class _InputLister(NodeVisitor):
    """Collect the signals that statements read.

    For an assignment only the right-hand side is traversed. Every other node
    is handled by the inherited NodeVisitor traversal (presumably including
    If/Case tests — confirm in litex.gen.fhdl.visit). Results accumulate in
    the set ``output_list``.
    """

    def __init__(self):
        self.output_list = set()

    def visit_Signal(self, node):
        seen = self.output_list
        seen.add(node)

    def visit_Assign(self, node):
        # Skip the left-hand side: it is written, not read.
        source = node.r
        self.visit(source)
44
45
def list_signals(node):
    """Return the set of every Signal referenced within ``node``."""
    collector = _SignalLister()
    collector.visit(node)
    found = collector.output_list
    return found
50
51
def list_targets(node):
    """Return the set of Signals assigned (left-hand sides) within ``node``."""
    collector = _TargetLister()
    collector.visit(node)
    found = collector.output_list
    return found
56
57
def list_inputs(node):
    """Return the set of Signals read within ``node`` (see ``_InputLister``)."""
    collector = _InputLister()
    collector.visit(node)
    found = collector.output_list
    return found
62
63
64 def _resort_statements(ol):
65 return [statement for i, statement in
66 sorted(ol, key=lambda x: x[0])]
67
68
def group_by_targets(sl):
    """Partition statements into groups that share assignment targets.

    Statements whose target sets overlap (directly or transitively) end up in
    the same group. Within each group, statements keep their original
    relative order.

    Args:
        sl: statement(s), iterated via ``flat_iteration`` (presumably
            flattening nested lists — confirm in litex.gen.util.misc).

    Returns:
        A list of ``(targets, statements)`` tuples, where ``targets`` is the
        set of Signals written by the group.
    """
    buckets = []
    assigned = set()
    for position, statement in enumerate(flat_iteration(sl)):
        stmt_targets = set(list_targets(statement))
        members = [(position, statement)]
        overlaps = not stmt_targets.isdisjoint(assigned)
        assigned |= stmt_targets
        if overlaps:
            # Absorb every existing bucket that shares a target with this
            # statement; the untouched buckets survive as-is, in order.
            survivors = []
            for bucket_targets, bucket_members in buckets:
                if stmt_targets.isdisjoint(bucket_targets):
                    survivors.append((bucket_targets, bucket_members))
                else:
                    stmt_targets |= bucket_targets
                    members += bucket_members
            buckets = survivors
        buckets.append((stmt_targets, members))
    return [(bucket_targets, _resort_statements(bucket_members))
            for bucket_targets, bucket_members in buckets]
88
89
def list_special_ios(f, ins, outs, inouts):
    """Return the union of ``list_ios(ins, outs, inouts)`` over ``f.specials``."""
    collected = set()
    per_special = (special.list_ios(ins, outs, inouts)
                   for special in f.specials)
    for ios in per_special:
        collected |= ios
    return collected
95
96
class _ClockDomainLister(NodeVisitor):
    """Collect names of clock domains referenced by a node.

    ClockSignal/ResetSignal nodes contribute their ``cd``. A visited
    ``clock_domains`` mapping (presumably a fragment's sync dict — confirm in
    litex.gen.fhdl.visit) contributes its keys, and its statements are
    traversed. Results accumulate in the set ``clock_domains``.
    """

    def __init__(self):
        self.clock_domains = set()

    def _note_domain(self, node):
        self.clock_domains.add(node.cd)

    visit_ClockSignal = _note_domain
    visit_ResetSignal = _note_domain

    def visit_clock_domains(self, node):
        for domain_name, domain_statements in node.items():
            self.clock_domains.add(domain_name)
            self.visit(domain_statements)
111
112
def list_clock_domains_expr(f):
    """Return the set of clock domain names referenced by expressions in ``f``."""
    lister = _ClockDomainLister()
    lister.visit(f)
    names = lister.clock_domains
    return names
117
118
def list_clock_domains(f):
    """Return the names of all clock domains used or declared by ``f``.

    The result combines three sources:
    - names referenced by expressions;
    - names reported by each special's ``list_clock_domains()``;
    - the ``name`` of each entry in ``f.clock_domains``.
    """
    names = list_clock_domains_expr(f)
    for special in f.specials:
        names |= special.list_clock_domains()
    names.update(cd.name for cd in f.clock_domains)
    return names
126
127
def is_variable(node):
    """Return whether ``node`` designates variable (``variable=True``) storage.

    Args:
        node: a Signal, a _Slice of a supported node, or a Cat of supported
            nodes.

    Returns:
        ``node.variable`` for a Signal. For a _Slice, the answer for the
        sliced value. For a Cat, the common answer of all its parts.

    Raises:
        TypeError: if the node type is unsupported, if a Cat is empty (there
            is no part to take the answer from), or if a Cat mixes variable
            and non-variable parts.
    """
    if isinstance(node, Signal):
        return node.variable
    elif isinstance(node, _Slice):
        return is_variable(node.value)
    elif isinstance(node, Cat):
        if not node.l:
            # Previously this fell through to arevars[0] and raised IndexError.
            raise TypeError("Cannot determine whether an empty Cat is a variable")
        arevars = [is_variable(part) for part in node.l]
        first = arevars[0]
        if any(x != first for x in arevars):
            raise TypeError("Cat mixes variable and non-variable parts")
        return first
    else:
        raise TypeError("Cannot determine whether {} is a variable"
                        .format(type(node).__name__))
142
143
def generate_reset(rst, sl):
    """Build assignments returning every target of ``sl`` to its reset value.

    Args:
        rst: accepted but not used by this function.
        sl: statement(s) whose targets are found with ``list_targets``.

    Returns:
        A list of ``target.eq(target.reset)`` statements. They are ordered by
        ``duid`` so the output does not depend on set iteration order.
    """
    ordered = sorted(list_targets(sl), key=lambda target: target.duid)
    return [target.eq(target.reset) for target in ordered]
147
148
def insert_reset(rst, sl):
    """Wrap ``sl`` so that targets are reset while ``rst`` is asserted.

    Returns a one-element list: ``If(rst, <reset assignments>).Else(*sl)``.
    """
    reset_branch = generate_reset(rst, sl)
    guarded = If(rst, *reset_branch).Else(*sl)
    return [guarded]
151
152
def insert_resets(f):
    """Add reset logic to every synchronous domain of ``f`` that has a reset.

    A domain whose ``f.clock_domains[name].rst`` is None keeps its statements
    unchanged. ``f.sync`` is replaced by a new dict.
    """
    def _maybe_reset(domain, statements):
        if f.clock_domains[domain].rst is None:
            return statements
        return insert_reset(ResetSignal(domain), statements)

    f.sync = {domain: _maybe_reset(domain, statements)
              for domain, statements in f.sync.items()}
161
162
class _Lowerer(NodeTransformer):
    """Base transformer for rewriting FHDL constructs into simpler ones.

    Attributes:
        target_context: True while the left-hand side of an assignment is
            being transformed.
        extra_stmts: statements that must follow the assignment currently
            being lowered.
        comb: combinatorial statements produced during lowering.
            ``_apply_lowerer`` appends them to the fragment.
    """

    def __init__(self):
        self.target_context = False
        self.extra_stmts = []
        self.comb = []

    def visit_Assign(self, node):
        # Assignments may nest (lowering emits new ones), so the caller's
        # context is saved and restored around this one.
        saved_context, saved_extra = self.target_context, self.extra_stmts
        self.extra_stmts = []

        self.target_context = True
        lowered_lhs = self.visit(node.l)
        self.target_context = False
        lowered_rhs = self.visit(node.r)

        assignment = _Assign(lowered_lhs, lowered_rhs)
        pending = self.extra_stmts
        self.target_context, self.extra_stmts = saved_context, saved_extra
        if pending:
            return [assignment] + pending
        return assignment
183
184
185 # Basics are FHDL structure elements that back-ends are not required to support
186 # but can be expressed in terms of other elements (lowered) before conversion.
187 class _BasicLowerer(_Lowerer):
188 def __init__(self, clock_domains):
189 self.clock_domains = clock_domains
190 _Lowerer.__init__(self)
191
192 def visit_ArrayProxy(self, node):
193 # TODO: rewrite without variables
194 array_muxed = Signal(value_bits_sign(node), variable=True)
195 if self.target_context:
196 k = self.visit(node.key)
197 cases = {}
198 for n, choice in enumerate(node.choices):
199 cases[n] = [self.visit_Assign(_Assign(choice, array_muxed))]
200 self.extra_stmts.append(Case(k, cases).makedefault())
201 else:
202 cases = dict((n, _Assign(array_muxed, self.visit(choice)))
203 for n, choice in enumerate(node.choices))
204 self.comb.append(Case(self.visit(node.key), cases).makedefault())
205 return array_muxed
206
207 def visit_ClockSignal(self, node):
208 return self.clock_domains[node.cd].clk
209
210 def visit_ResetSignal(self, node):
211 rst = self.clock_domains[node.cd].rst
212 if rst is None:
213 if node.allow_reset_less:
214 return 0
215 else:
216 raise ValueError("Attempted to get reset signal of resetless"
217 " domain '{}'".format(node.cd))
218 else:
219 return rst
220
221
class _ComplexSliceLowerer(_Lowerer):
    """Rewrite slices of non-Signal values as slices of a proxy Signal.

    The proxy is connected to the original value by a combinatorial
    assignment appended to ``comb``. In a target context the value is driven
    from the proxy; otherwise the proxy is driven from the value.
    """

    def visit_Slice(self, node):
        if isinstance(node.value, Signal):
            return NodeTransformer.visit_Slice(self, node)
        proxy = Signal(value_bits_sign(node.value))
        if self.target_context:
            link = _Assign(node.value, proxy)
        else:
            link = _Assign(proxy, node.value)
        self.comb.append(self.visit_Assign(link))
        proxied = _Slice(proxy, node.start, node.stop)
        return NodeTransformer.visit_Slice(self, proxied)
233
234
235 def _apply_lowerer(l, f):
236 f = l.visit(f)
237 f.comb += l.comb
238
239 for special in f.specials:
240 for obj, attr, direction in special.iter_expressions():
241 if direction != SPECIAL_INOUT:
242 # inouts are only supported by Migen when connected directly to top-level
243 # in this case, they are Signal and never need lowering
244 l.comb = []
245 l.target_context = direction != SPECIAL_INPUT
246 l.extra_stmts = []
247 expr = getattr(obj, attr)
248 expr = l.visit(expr)
249 setattr(obj, attr, expr)
250 f.comb += l.comb + l.extra_stmts
251
252 return f
253
254
def lower_basics(f):
    """Lower ArrayProxy/ClockSignal/ResetSignal nodes in ``f`` (see _BasicLowerer)."""
    lowerer = _BasicLowerer(f.clock_domains)
    return _apply_lowerer(lowerer, f)
257
258
def lower_complex_slices(f):
    """Replace slices of non-Signal values in ``f`` with proxy-Signal slices."""
    lowerer = _ComplexSliceLowerer()
    return _apply_lowerer(lowerer, f)
261
262
class _ClockDomainRenamer(NodeVisitor):
    """Rename clock domain ``old`` to ``new`` in ClockSignal/ResetSignal nodes.

    Matching nodes are modified in place.
    """

    def __init__(self, old, new):
        self.old = old
        self.new = new

    def _rename_if_matching(self, node):
        if node.cd == self.old:
            node.cd = self.new

    visit_ClockSignal = _rename_if_matching
    visit_ResetSignal = _rename_if_matching
275
276
def rename_clock_domain_expr(f, old, new):
    """Rename clock domain ``old`` to ``new`` in all expressions of ``f``, in place."""
    _ClockDomainRenamer(old, new).visit(f)
280
281
def rename_clock_domain(f, old, new):
    """Rename clock domain ``old`` to ``new`` throughout fragment ``f``.

    The rename covers:
    - expressions;
    - the ``f.sync`` entry, whose statements are appended to an existing
      ``new`` entry if there is one;
    - every special, via ``rename_clock_domain(old, new)``;
    - the matching ``f.clock_domains`` entry, if present, via ``rename``.
    """
    rename_clock_domain_expr(f, old, new)
    if new != old and old in f.sync:
        moved = f.sync.pop(old)
        if new in f.sync:
            f.sync[new].extend(moved)
        else:
            f.sync[new] = moved
    for special in f.specials:
        special.rename_clock_domain(old, new)
    try:
        domain = f.clock_domains[old]
    except KeyError:
        # No declared clock domain by that name: nothing more to rename.
        return
    domain.rename(new)
297 else:
298 cd.rename(new)