1 from litex
.gen
.fhdl
.structure
import *
2 from litex
.gen
.fhdl
.structure
import _Slice
, _Assign
, _Fragment
3 from litex
.gen
.fhdl
.visit
import NodeVisitor
, NodeTransformer
4 from litex
.gen
.fhdl
.bitcontainer
import value_bits_sign
5 from litex
.gen
.util
.misc
import flat_iteration
class _SignalLister(NodeVisitor):
    """Visitor that collects every ``Signal`` node it visits."""

    def __init__(self):
        # Filled by visit_Signal; a set, so each signal is recorded once.
        self.output_list = set()

    def visit_Signal(self, node):
        """Record ``node`` in ``output_list``."""
        self.output_list.add(node)
class _TargetLister(NodeVisitor):
    """Visitor that collects signals found on the left-hand side of assignments."""

    def __init__(self):
        self.output_list = set()
        # True only while the left-hand side of an assignment is being walked.
        self.target_context = False

    def visit_Signal(self, node):
        """Record ``node`` only when it is reached from an assignment target."""
        if self.target_context:
            self.output_list.add(node)

    def visit_Assign(self, node):
        """Walk the left-hand side of ``node`` with ``target_context`` set."""
        self.target_context = True
        self.visit(node.l)
        self.target_context = False

    def visit_ArrayProxy(self, node):
        """Walk every choice of the proxy; any of them may be the target."""
        for choice in node.choices:
            self.visit(choice)
class _InputLister(NodeVisitor):
    """Visitor that collects signals found on the right-hand side of assignments."""

    def __init__(self):
        self.output_list = set()

    def visit_Signal(self, node):
        """Record ``node`` in ``output_list``."""
        self.output_list.add(node)

    def visit_Assign(self, node):
        """Walk only the right-hand side, so assignment targets are not listed."""
        self.visit(node.r)
def list_signals(node):
    """Return the set of signals collected by ``_SignalLister`` from ``node``."""
    lister = _SignalLister()
    # Without this walk the lister's set stays empty.
    lister.visit(node)
    return lister.output_list
def list_targets(node):
    """Return the set of signals collected by ``_TargetLister`` from ``node``."""
    lister = _TargetLister()
    # Without this walk the lister's set stays empty.
    lister.visit(node)
    return lister.output_list
def list_inputs(node):
    """Return the set of signals collected by ``_InputLister`` from ``node``."""
    lister = _InputLister()
    # Without this walk the lister's set stays empty.
    lister.visit(node)
    return lister.output_list
def _resort_statements(ol):
    """Return the statements of ``ol`` ordered by their recorded index.

    ``ol`` is an iterable of ``(order, statement)`` pairs. Only the first
    element is used as sort key, so pairs with equal order keep their
    relative position.
    """
    ordered_pairs = sorted(ol, key=lambda pair: pair[0])
    return [stmt for _, stmt in ordered_pairs]
def group_by_targets(sl):
    """Partition statements into groups that share target signals.

    Statements are taken from ``flat_iteration(sl)``. Two statements end up
    in the same group when their target sets overlap, either directly or
    through other statements.

    Returns a list of ``(targets, statements)`` tuples, where ``targets`` is
    the union of the group's target sets and ``statements`` are in their
    original iteration order.
    """
    groups = []
    seen = set()
    for order, stmt in enumerate(flat_iteration(sl)):
        targets = set(list_targets(stmt))
        group = [(order, stmt)]
        disjoint = targets.isdisjoint(seen)
        seen |= targets
        if not disjoint:
            # At least one existing group overlaps: rebuild the group list,
            # absorbing every overlapping group into the current one.
            groups, old_groups = [], groups
            for old_targets, old_group in old_groups:
                if targets.isdisjoint(old_targets):
                    groups.append((old_targets, old_group))
                else:
                    targets |= old_targets
                    group += old_group
        groups.append((targets, group))
    # Merging can interleave statements; restore the original order.
    return [(targets, _resort_statements(stmts))
            for targets, stmts in groups]
def list_special_ios(f, ins, outs, inouts):
    """Return the union of ``list_ios(ins, outs, inouts)`` over ``f.specials``.

    The three flags are passed unchanged to every special. An empty set is
    returned when ``f`` has no specials.
    """
    r = set()
    for special in f.specials:
        r |= special.list_ios(ins, outs, inouts)
    return r
class _ClockDomainLister(NodeVisitor):
    """Visitor that collects the clock domain names referenced by a node tree."""

    def __init__(self):
        self.clock_domains = set()

    def visit_ClockSignal(self, node):
        """Record the domain referenced by a ``ClockSignal``."""
        self.clock_domains.add(node.cd)

    def visit_ResetSignal(self, node):
        """Record the domain referenced by a ``ResetSignal``."""
        self.clock_domains.add(node.cd)

    def visit_clock_domains(self, node):
        """Record each key of the mapping and walk its statements."""
        for clockname, statements in node.items():
            self.clock_domains.add(clockname)
            self.visit(statements)
def list_clock_domains_expr(f):
    """Return the set of clock domain names that ``_ClockDomainLister`` finds in ``f``."""
    cdl = _ClockDomainLister()
    # Without this walk the lister's set stays empty.
    cdl.visit(f)
    return cdl.clock_domains
def list_clock_domains(f):
    """Return the clock domain names used anywhere in fragment ``f``.

    Combines the domains found in the fragment's expressions, those reported
    by each special's ``list_clock_domains()``, and the domains declared in
    ``f.clock_domains``.
    """
    r = list_clock_domains_expr(f)
    for special in f.specials:
        r |= special.list_clock_domains()
    for cd in f.clock_domains:
        # NOTE(review): assumes each declared domain exposes its name as
        # `cd.name` -- confirm against the ClockDomain class.
        r.add(cd.name)
    return r
def is_variable(node):
    """Tell whether ``node`` is a variable.

    A ``Signal`` answers with its own flag, a ``_Slice`` defers to the value
    it slices, and a ``Cat`` requires all of its members to agree.

    Raises:
        TypeError: if the members of a ``Cat`` disagree, or if ``node`` is
            none of the supported kinds.
    """
    if isinstance(node, Signal):
        # NOTE(review): assumes the `variable=` constructor argument is
        # stored as the `variable` attribute -- confirm against Signal.
        return node.variable
    elif isinstance(node, _Slice):
        return is_variable(node.value)
    elif isinstance(node, Cat):
        arevars = list(map(is_variable, node.l))
        r = arevars[0]
        # A concatenation mixing variables and non-variables has no single answer.
        for x in arevars:
            if x != r:
                raise TypeError
        return r
    else:
        raise TypeError
def generate_reset(rst, sl):
    """Build the assignments that return the targets of ``sl`` to their reset value.

    ``rst`` is accepted for interface symmetry with ``insert_reset`` and is
    not used here. The result is ordered by ``duid`` so that it is
    deterministic.
    """
    targets = list_targets(sl)
    # NOTE(review): the filter clause was restored from the truncated
    # comprehension; it assumes Signal carries a `reset_less` flag -- confirm.
    return [t.eq(t.reset) for t in sorted(targets, key=lambda x: x.duid)
            if not t.reset_less]
def insert_reset(rst, sl):
    """Return ``sl`` followed by an ``If(rst, ...)`` holding the reset assignments.

    ``sl`` itself is not modified; a new list is returned.
    """
    reset_assignments = generate_reset(rst, sl)
    guarded_reset = If(rst, *reset_assignments)
    return sl + [guarded_reset]
def insert_resets(f):
    """Add reset logic to every synchronous domain of ``f`` that has a reset.

    ``f.sync`` is replaced by a new dict. Domains whose ``rst`` is ``None``
    keep their statement list unchanged.
    """
    newsync = dict()
    for k, v in f.sync.items():
        if f.clock_domains[k].rst is not None:
            newsync[k] = insert_reset(ResetSignal(k), v)
        else:
            # Resetless domain: nothing to insert, but the statements must survive.
            newsync[k] = v
    f.sync = newsync
class _Lowerer(NodeTransformer):
    """Base transformer for lowering passes.

    Tracks whether the node being visited is an assignment target and
    collects the statements that lowering generates.
    """

    def __init__(self):
        # True while the left-hand side of an assignment is being visited.
        self.target_context = False
        # Statements to be emitted right after the assignment being lowered.
        self.extra_stmts = []
        # Combinatorial statements appended by subclasses and collected by
        # _apply_lowerer.
        self.comb = []

    def visit_Assign(self, node):
        """Lower both sides of ``node``.

        Returns the new ``_Assign``, or a list starting with it when lowering
        produced extra statements. The visitor state is saved and restored, so
        nested calls do not disturb the caller's context.
        """
        old_target_context, old_extra_stmts = self.target_context, self.extra_stmts
        self.extra_stmts = []

        self.target_context = True
        lhs = self.visit(node.l)
        self.target_context = False
        rhs = self.visit(node.r)
        r = _Assign(lhs, rhs)
        # Only wrap in a list when there is something to add after the assignment.
        if self.extra_stmts:
            r = [r] + self.extra_stmts

        self.target_context, self.extra_stmts = old_target_context, old_extra_stmts
        return r
# Basics are FHDL structure elements that back-ends are not required to support
# but can be expressed in terms of other elements (lowered) before conversion.
class _BasicLowerer(_Lowerer):
    """Lowers array proxies and ``ClockSignal``/``ResetSignal`` references."""

    def __init__(self, clock_domains):
        # Indexed by domain name to resolve ClockSignal/ResetSignal.
        self.clock_domains = clock_domains
        _Lowerer.__init__(self)

    def visit_ArrayProxy(self, node):
        """Replace the proxy with a new signal wired through a ``Case``.

        As a target, the ``Case`` copies the new signal into the selected
        choice and is queued in ``extra_stmts``. Otherwise the ``Case`` loads
        the selected choice into the new signal and is added to ``comb``.
        """
        # TODO: rewrite without variables
        array_muxed = Signal(value_bits_sign(node), variable=True)
        if self.target_context:
            k = self.visit(node.key)
            cases = {}
            for n, choice in enumerate(node.choices):
                cases[n] = [self.visit_Assign(_Assign(choice, array_muxed))]
            self.extra_stmts.append(Case(k, cases).makedefault())
        else:
            cases = {n: _Assign(array_muxed, self.visit(choice))
                     for n, choice in enumerate(node.choices)}
            self.comb.append(Case(self.visit(node.key), cases).makedefault())
        return array_muxed

    def visit_ClockSignal(self, node):
        """Return the ``clk`` of the domain named by ``node.cd``."""
        return self.clock_domains[node.cd].clk

    def visit_ResetSignal(self, node):
        """Return the ``rst`` of the domain named by ``node.cd``.

        Raises:
            ValueError: if the domain has no reset and the node does not
                allow that.
        """
        rst = self.clock_domains[node.cd].rst
        if rst is None:
            if node.allow_reset_less:
                # NOTE(review): constant 0 is assumed to be the intended
                # stand-in for a missing reset -- confirm.
                return 0
            else:
                raise ValueError("Attempted to get reset signal of resetless"
                                 " domain '{}'".format(node.cd))
        else:
            return rst
class _ComplexSliceLowerer(_Lowerer):
    """Rewrites slices of non-``Signal`` values as slices of a proxy signal."""

    def visit_Slice(self, node):
        """Lower ``node`` so that the sliced value is always a ``Signal``.

        When it is not, a proxy signal is created and tied to the original
        value by a combinatorial assignment. The direction of that assignment
        depends on whether the slice is an assignment target.
        """
        if not isinstance(node.value, Signal):
            slice_proxy = Signal(value_bits_sign(node.value))
            if self.target_context:
                # Slice is written: the proxy drives the original value.
                a = _Assign(node.value, slice_proxy)
            else:
                # Slice is read: the original value drives the proxy.
                a = _Assign(slice_proxy, node.value)
            self.comb.append(self.visit_Assign(a))
            node = _Slice(slice_proxy, node.start, node.stop)
        return NodeTransformer.visit_Slice(self, node)
def _apply_lowerer(l, f):
    """Run lowerer ``l`` over fragment ``f`` and over its specials' expressions.

    Statements that the lowerer generates are appended to ``f.comb``. The
    expressions of specials are replaced in place through ``setattr``.

    Returns the lowered fragment.
    """
    f = l.visit(f)
    f.comb += l.comb

    # Sorted by duid so the generated statements come out in a stable order.
    for special in sorted(f.specials, key=lambda s: s.duid):
        for obj, attr, direction in special.iter_expressions():
            if direction != SPECIAL_INOUT:
                # inouts are only supported by Migen when connected directly to top-level
                # in this case, they are Signal and never need lowering
                # Reset per-expression state so nothing is added to f.comb twice.
                l.comb = []
                l.target_context = direction != SPECIAL_INPUT
                l.extra_stmts = []
                expr = getattr(obj, attr)
                expr = l.visit(expr)
                setattr(obj, attr, expr)
                f.comb += l.comb + l.extra_stmts

    return f
# NOTE(review): the `def` header of this function was missing; the name
# `lower_basics` is inferred from the sibling `lower_complex_slices` --
# confirm against callers.
def lower_basics(f):
    """Apply ``_BasicLowerer`` to fragment ``f`` and return the result."""
    return _apply_lowerer(_BasicLowerer(f.clock_domains), f)
def lower_complex_slices(f):
    """Apply ``_ComplexSliceLowerer`` to fragment ``f`` and return the result."""
    slice_lowerer = _ComplexSliceLowerer()
    return _apply_lowerer(slice_lowerer, f)
class _ClockDomainRenamer(NodeVisitor):
    """Visitor that retargets clock and reset references from one domain to another.

    Nodes are modified in place.
    """

    def __init__(self, old, new):
        self.old = old
        self.new = new

    def visit_ClockSignal(self, node):
        """Point ``node`` at the new domain if it referenced the old one."""
        if node.cd == self.old:
            node.cd = self.new

    def visit_ResetSignal(self, node):
        """Point ``node`` at the new domain if it referenced the old one."""
        if node.cd == self.old:
            node.cd = self.new
def rename_clock_domain_expr(f, old, new):
    """Rename domain ``old`` to ``new`` in the clock/reset references of ``f``.

    Works in place through ``_ClockDomainRenamer``; returns nothing.
    """
    cdr = _ClockDomainRenamer(old, new)
    # Without this walk no reference is ever renamed.
    cdr.visit(f)
def rename_clock_domain(f, old, new):
    """Rename clock domain ``old`` to ``new`` throughout fragment ``f``.

    Updates expression references, moves the synchronous statements of
    ``old`` to ``new`` (extending ``new`` if it already exists), forwards
    the rename to every special, and renames the declared domain if there
    is one. Works in place.
    """
    rename_clock_domain_expr(f, old, new)
    # Renaming a domain to itself must not duplicate or delete its statements.
    if new != old:
        if old in f.sync:
            if new in f.sync:
                f.sync[new].extend(f.sync[old])
            else:
                f.sync[new] = f.sync[old]
            del f.sync[old]
        for special in f.specials:
            special.rename_clock_domain(old, new)
        try:
            cd = f.clock_domains[old]
        except KeyError:
            # The fragment does not declare this domain; nothing to rename.
            pass
        else:
            # NOTE(review): assumes the domain object exposes `rename(new)`
            # -- confirm against the ClockDomain class.
            cd.rename(new)
def call_special_classmethod(overrides, obj, method, *args, **kwargs):
    """Call ``method`` for ``obj`` on its class, or on that class's override.

    ``overrides`` maps a class to the replacement class to look the method
    up on. The method is fetched from the class and called with ``obj`` as
    first argument, followed by ``*args`` and ``**kwargs``.

    Returns the method's result, or ``None`` when the chosen class has no
    attribute named ``method``.
    """
    cl = obj.__class__
    if cl in overrides:
        cl = overrides[cl]
    if hasattr(cl, method):
        return getattr(cl, method)(obj, *args, **kwargs)
    else:
        return None
def _lower_specials_step(overrides, specials):
    """Lower once every special in ``specials`` that can be lowered.

    Returns ``(f, lowered_specials)``: a fragment accumulating the fragments
    of all lowered implementations, and the set of specials that were
    lowered. Specials whose ``lower`` call yields ``None`` are left out.
    """
    f = _Fragment()
    lowered_specials = set()
    # Sorted by duid so the resulting fragment is built in a stable order.
    for special in sorted(specials, key=lambda x: x.duid):
        impl = call_special_classmethod(overrides, special, "lower")
        if impl is not None:
            f += impl.get_fragment()
            lowered_specials.add(special)
    return f, lowered_specials
def _can_lower(overrides, specials):
    """Return True if any special's class, or its override, defines ``lower``.

    ``overrides`` maps a class to the replacement class to inspect instead.
    """
    for special in specials:
        cl = special.__class__
        if cl in overrides:
            cl = overrides[cl]
        if hasattr(cl, "lower"):
            return True
    return False
def lower_specials(overrides, specials):
    """Lower ``specials`` repeatedly until no lowerable special remains.

    Lowering a special may introduce new specials, so the step is repeated
    on the accumulated fragment's own specials.

    Returns ``(f, lowered_specials)``: the accumulated fragment and the set
    of all specials lowered along the way.
    """
    f, lowered_specials = _lower_specials_step(overrides, specials)
    while _can_lower(overrides, f.specials):
        f2, lowered_specials2 = _lower_specials_step(overrides, f.specials)
        # Merge this round's result; without it the loop would not progress.
        f += f2
        lowered_specials |= lowered_specials2
        # Drop what was just lowered so it is not lowered again.
        f.specials -= lowered_specials2
    return f, lowered_specials