b77a340246f9014e1646a4dc02d23760f4afa12f
[nmigen.git] / nmigen / hdl / ir.py
1 from abc import ABCMeta, abstractmethod
2 from collections import defaultdict, OrderedDict
3 from functools import reduce
4 import warnings
5 import traceback
6 import sys
7
8 from ..tools import *
9 from .ast import *
10 from .cd import *
11
12
# Public API of this module.
__all__ = [
    "UnusedElaboratable",
    "Elaboratable",
    "DriverConflict",
    "Fragment",
    "Instance",
]
14
15
class UnusedElaboratable(Warning):
    """Warning category for an :class:`Elaboratable` that was "created but never used"."""
18
19
class Elaboratable(metaclass=ABCMeta):
    """Base class for objects that can be turned into a fragment via ``elaborate``.

    Every instance remembers the source location it was created at. If an instance is
    destroyed while still marked unused, an :class:`UnusedElaboratable` warning is
    reported at that location.
    """

    # Class-wide switch: when True, ``__del__`` never warns.
    _Elaboratable__silence = False

    def __new__(cls, *args, src_loc_at=0, **kwargs):
        instance = super().__new__(cls)
        # ``src_loc_at`` selects how many extra frames above the direct caller to attribute
        # the creation to; ``[0]`` is the outermost frame within the limit.
        frames = traceback.extract_stack(limit=2 + src_loc_at)
        instance._Elaboratable__src_loc = frames[0]
        instance._Elaboratable__used = False
        return instance

    def __del__(self):
        if self._Elaboratable__silence:
            return
        # A missing flag (e.g. ``__new__`` never ran to completion) counts as "used".
        if getattr(self, "_Elaboratable__used", True):
            return
        origin = self._Elaboratable__src_loc
        warnings.warn_explicit("{!r} created but never used".format(self), UnusedElaboratable,
                               filename=origin.filename,
                               lineno=origin.lineno,
                               source=self)
37
38
_old_excepthook = sys.excepthook


def _silence_elaboratable(exc_type, exc_value, exc_traceback):
    """``sys.excepthook`` wrapper that disables "created but never used" warnings.

    Don't show anything if the interpreter crashed; that'd just obscure the exception
    traceback instead of helping. After silencing, the previously installed hook is
    called with the same arguments.
    """
    # Parameters are deliberately not named ``type``/``traceback``: that would shadow the
    # builtin and the ``traceback`` module imported by this file.
    Elaboratable._Elaboratable__silence = True
    _old_excepthook(exc_type, exc_value, exc_traceback)


sys.excepthook = _silence_elaboratable
46
47
class DriverConflict(UserWarning):
    """Warning (or, in "error" mode, exception) for a signal driven, or a memory accessed,
    from more than one fragment of the hierarchy."""
50
51
class Fragment:
    """A node of the elaborated design hierarchy.

    A fragment holds its statements, the signals those statements drive (keyed by clock
    domain name, ``None`` meaning combinational), the clock domains it defines, its ports
    (signal -> ``"i"``/``"o"``/``"io"``), attributes, generated objects, and an ordered list
    of ``(subfragment, name)`` children where ``name`` may be ``None``.
    """

    @staticmethod
    def get(obj, platform):
        """Call ``elaborate(platform)`` repeatedly until a :class:`Fragment` is obtained.

        :class:`Elaboratable` objects are marked as used. Objects that only duck-type
        ``elaborate`` trigger a ``RuntimeWarning``. If ``elaborate`` returns ``None`` a
        ``UserWarning`` pointing at that method is issued, after which the next iteration
        raises.

        Raises ``AttributeError`` if an object in the chain cannot be elaborated.
        """
        code = None
        while True:
            if isinstance(obj, Fragment):
                return obj
            elif isinstance(obj, Elaboratable):
                code = obj.elaborate.__code__
                obj._Elaboratable__used = True
                obj = obj.elaborate(platform)
            elif hasattr(obj, "elaborate"):
                warnings.warn(
                    message="Class {!r} is an elaboratable that does not explicitly inherit from "
                            "Elaboratable; doing so would improve diagnostics"
                            .format(type(obj)),
                    category=RuntimeWarning,
                    stacklevel=2)
                code = obj.elaborate.__code__
                obj = obj.elaborate(platform)
            else:
                raise AttributeError("Object {!r} cannot be elaborated".format(obj))
            if obj is None and code is not None:
                warnings.warn_explicit(
                    message=".elaborate() returned None; missing return statement?",
                    category=UserWarning,
                    filename=code.co_filename,
                    lineno=code.co_firstlineno)

    def __init__(self):
        self.ports = SignalDict()        # signal -> "i" / "o" / "io"
        self.drivers = OrderedDict()     # domain name (None = comb) -> SignalSet
        self.statements = []
        self.domains = OrderedDict()     # domain name -> ClockDomain
        self.subfragments = []           # list of (Fragment, name or None)
        self.attrs = OrderedDict()
        self.generated = OrderedDict()
        self.flatten = False             # if True, always merged into the parent

    def add_ports(self, *ports, dir):
        """Mark every signal in ``ports`` (arbitrarily nested) as a port with direction ``dir``."""
        assert dir in ("i", "o", "io")
        for port in flatten(ports):
            self.ports[port] = dir

    def iter_ports(self, dir=None):
        """Yield port signals, optionally only those with direction ``dir``."""
        if dir is None:
            yield from self.ports
        else:
            for port, port_dir in self.ports.items():
                if port_dir == dir:
                    yield port

    def add_driver(self, signal, domain=None):
        """Record that this fragment drives ``signal`` in ``domain`` (``None`` = comb)."""
        if domain not in self.drivers:
            self.drivers[domain] = SignalSet()
        self.drivers[domain].add(signal)

    def iter_drivers(self):
        """Yield ``(domain, signal)`` for every driven signal, including comb (``None``)."""
        for domain, signals in self.drivers.items():
            for signal in signals:
                yield domain, signal

    def iter_comb(self):
        """Yield signals driven combinationally."""
        if None in self.drivers:
            yield from self.drivers[None]

    def iter_sync(self):
        """Yield ``(domain, signal)`` for signals driven from a named clock domain."""
        for domain, signals in self.drivers.items():
            if domain is None:
                continue
            for signal in signals:
                yield domain, signal

    def iter_signals(self):
        """Return a SignalSet of ports, driven signals, and clk/rst of driving domains."""
        signals = SignalSet()
        signals |= self.ports.keys()
        for domain, domain_signals in self.drivers.items():
            if domain is not None:
                cd = self.domains[domain]
                signals.add(cd.clk)
                if cd.rst is not None:
                    signals.add(cd.rst)
            signals |= domain_signals
        return signals

    def add_domains(self, *domains):
        """Define the given ClockDomain objects; a name may only be defined once."""
        for domain in flatten(domains):
            assert isinstance(domain, ClockDomain)
            assert domain.name not in self.domains
            self.domains[domain.name] = domain

    def iter_domains(self):
        """Yield the names of domains defined in this fragment."""
        yield from self.domains

    def add_statements(self, *stmts):
        self.statements += Statement.cast(stmts)

    def add_subfragment(self, subfragment, name=None):
        assert isinstance(subfragment, Fragment)
        self.subfragments.append((subfragment, name))

    def find_subfragment(self, name_or_index):
        """Return a direct subfragment by position (``int``) or by name.

        Python-style negative indices are accepted. Raises ``NameError`` when no
        subfragment matches, including for any out-of-range index.
        """
        if isinstance(name_or_index, int):
            # Bounds-check both ends so that e.g. -3 with two subfragments reports a
            # NameError rather than leaking the list's IndexError.
            if -len(self.subfragments) <= name_or_index < len(self.subfragments):
                subfragment, name = self.subfragments[name_or_index]
                return subfragment
            raise NameError("No subfragment at index #{}".format(name_or_index))
        else:
            for subfragment, name in self.subfragments:
                if name == name_or_index:
                    return subfragment
            raise NameError("No subfragment with name '{}'".format(name_or_index))

    def find_generated(self, *path):
        """Look up ``generated[path[-1]]`` in the subfragment reached via ``path[:-1]``."""
        if len(path) > 1:
            path_component, *path = path
            return self.find_subfragment(path_component).find_generated(*path)
        else:
            item, = path
            return self.generated[item]

    def elaborate(self, platform):
        return self

    def _merge_subfragment(self, subfragment):
        # Merge subfragment's everything except clock domains into this fragment.
        # Flattening is done after clock domain propagation, so we can assume the domains
        # are already the same in every involved fragment in the first place.
        self.ports.update(subfragment.ports)
        for domain, signal in subfragment.iter_drivers():
            self.add_driver(signal, domain)
        self.statements += subfragment.statements
        self.subfragments += subfragment.subfragments

        # Remove the merged subfragment.
        found = False
        for i, (check_subfrag, check_name) in enumerate(self.subfragments): # :nobr:
            if subfragment == check_subfrag:
                del self.subfragments[i]
                found = True
                break
        assert found

    def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
        """Flatten subfragments so that no signal/memory is driven from two fragments.

        ``mode`` is ``"silent"``, ``"warn"`` (issue :class:`DriverConflict` warnings) or
        ``"error"`` (raise :class:`DriverConflict`). Returns ``(signals, memories)``: the
        SignalSet of signals driven and the set of memories accessed in this subtree.
        """
        assert mode in ("silent", "warn", "error")

        driver_subfrags = SignalDict()
        memory_subfrags = OrderedDict()

        def add_subfrag(registry, entity, entry):
            if entity not in registry:
                registry[entity] = set()
            registry[entity].add(entry)

        # For each signal driven by this fragment and/or its subfragments, determine which
        # subfragments also drive it.
        for domain, signal in self.iter_drivers():
            add_subfrag(driver_subfrags, signal, (None, hierarchy))

        flatten_subfrags = set()
        for i, (subfrag, name) in enumerate(self.subfragments):
            if name is None:
                name = "<unnamed #{}>".format(i)
            subfrag_hierarchy = hierarchy + (name,)

            if subfrag.flatten:
                # Always flatten subfragments that explicitly request it.
                flatten_subfrags.add((subfrag, subfrag_hierarchy))

            if isinstance(subfrag, Instance):
                # For memories (which are subfragments, but semantically a part of superfragment),
                # record that this fragment is driving it.
                if subfrag.type in ("$memrd", "$memwr"):
                    memory = subfrag.parameters["MEMID"]
                    add_subfrag(memory_subfrags, memory, (None, hierarchy))

                # Never flatten instances.
                continue

            # First, recurse into subfragments and let them detect driver conflicts as well.
            subfrag_drivers, subfrag_memories = \
                subfrag._resolve_hierarchy_conflicts(subfrag_hierarchy, mode)

            # Second, classify subfragments by signals they drive and memories they use.
            for signal in subfrag_drivers:
                add_subfrag(driver_subfrags, signal, (subfrag, subfrag_hierarchy))
            for memory in subfrag_memories:
                add_subfrag(memory_subfrags, memory, (subfrag, subfrag_hierarchy))

        # Find out the set of subfragments that needs to be flattened into this fragment
        # to resolve driver-driver conflicts. Returns the sorted dotted hierarchy names of
        # all involved fragments, or [] when there is no conflict.
        def flatten_subfrags_if_needed(subfrags):
            if len(subfrags) == 1:
                return []
            flatten_subfrags.update((f, h) for f, h in subfrags if f is not None)
            return sorted(".".join(h) for f, h in subfrags)

        for signal, subfrags in driver_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Signal '{}' is driven from multiple fragments: {}"
                       .format(signal, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *signal.src_loc)

        for memory, subfrags in memory_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Memory '{}' is accessed from multiple fragments: {}"
                       .format(memory.name, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *memory.src_loc)

        # Flatten hierarchy.
        for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]):
            self._merge_subfragment(subfrag)

        # If we flattened anything, we might be in a situation where we have a driver conflict
        # again, e.g. if we had a tree of fragments like A --- B --- C where only fragments
        # A and C were driving a signal S. In that case, since B is not driving S itself,
        # processing B will not result in any flattening, but since B is transitively driving S,
        # processing A will flatten B into it. Afterwards, we have a tree like AB --- C, which
        # has another conflict.
        if flatten_subfrags:
            # Try flattening again.
            return self._resolve_hierarchy_conflicts(hierarchy, mode)

        # Nothing was flattened, we're done!
        return (SignalSet(driver_subfrags.keys()),
                set(memory_subfrags.keys()))

    def _propagate_domains_up(self, hierarchy=("top",)):
        """Recursively hoist non-local subfragment domains into this fragment.

        A domain defined by several sibling subfragments is renamed in each of them to
        ``"<subfragment name>_<domain>"``. Raises ``DomainError`` if any such sibling is
        unnamed or the siblings' names are not distinct.
        """
        from .xfrm import DomainRenamer

        domain_subfrags = defaultdict(set)

        # For each domain defined by a subfragment, determine which subfragments define it.
        for i, (subfrag, name) in enumerate(self.subfragments):
            # First, recurse into subfragments and let them propagate domains up as well.
            hier_name = name
            if hier_name is None:
                hier_name = "<unnamed #{}>".format(i)
            subfrag._propagate_domains_up(hierarchy + (hier_name,))

            # Second, classify subfragments by domains they define.
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                domain_subfrags[domain_name].add((subfrag, name, i))

        # For each domain defined by more than one subfragment, rename the domain in each
        # of the subfragments such that they no longer conflict.
        for domain_name, subfrags in domain_subfrags.items():
            if len(subfrags) == 1:
                continue

            names = [n for f, n, i in subfrags]
            if not all(names):
                names = sorted("<unnamed #{}>".format(i) if n is None else "'{}'".format(n)
                               for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; "
                                  "it is necessary to either rename subfragment domains "
                                  "explicitly, or give names to subfragments"
                                  .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            if len(names) != len(set(names)):
                names = sorted("#{}".format(i) for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', "
                                  "some of which have identical names; it is necessary to either "
                                  "rename subfragment domains explicitly, or give distinct names "
                                  "to subfragments"
                                  .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            for subfrag, name, i in subfrags:
                domain_name_map = {domain_name: "{}_{}".format(name, domain_name)}
                self.subfragments[i] = (DomainRenamer(domain_name_map)(subfrag), name)

        # Finally, collect the (now unique) subfragment domains, and merge them into our domains.
        for subfrag, name in self.subfragments:
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                self.add_domains(domain)

    def _propagate_domains_down(self):
        """Recursively define every domain of this fragment in all of its subfragments."""
        # For each domain defined in this fragment, ensure it also exists in all subfragments.
        for subfrag, name in self.subfragments:
            for domain in self.iter_domains():
                if domain in subfrag.domains:
                    assert self.domains[domain] is subfrag.domains[domain]
                else:
                    subfrag.add_domains(self.domains[domain])

            subfrag._propagate_domains_down()

    def create_missing_domains(self, missing_domain, *, platform=None):
        """Create domains that are used but not defined, as reported by DomainCollector.

        ``missing_domain(name)`` is called for each such domain, in sorted name order so
        that the result does not depend on set iteration order. A :class:`ClockDomain`
        result is added to this fragment directly; any other result is elaborated with
        ``platform`` and added as subfragment ``"cd_<name>"``, which must define the domain.

        Returns the list of ClockDomain objects that were added directly.
        Raises ``DomainError`` if the callback returns ``None`` or its fragment does not
        define the requested domain.
        """
        from .xfrm import DomainCollector

        collector = DomainCollector()
        collector(self)

        missing_names = collector.used_domains - collector.defined_domains
        new_domains = []
        for domain_name in sorted(name for name in missing_names if name is not None):
            value = missing_domain(domain_name)
            if value is None:
                raise DomainError("Domain '{}' is used but not defined".format(domain_name))
            if isinstance(value, ClockDomain):
                self.add_domains(value)
                # And expose ports on the newly added clock domain, since it is added directly
                # and there was no chance to add any logic driving it.
                new_domains.append(value)
            else:
                new_fragment = Fragment.get(value, platform=platform)
                if domain_name not in new_fragment.domains:
                    defined = new_fragment.domains.keys()
                    raise DomainError(
                        "Fragment returned by missing domain callback does not define "
                        "requested domain '{}' (defines {})."
                        .format(domain_name, ", ".join("'{}'".format(n) for n in defined)))
                self.add_subfragment(new_fragment, "cd_{}".format(domain_name))
        return new_domains

    def _propagate_domains(self, missing_domain):
        """Create missing domains, then propagate domains up and down the hierarchy.

        Returns the ClockDomain objects created directly by ``create_missing_domains``.
        """
        new_domains = self.create_missing_domains(missing_domain)
        self._propagate_domains_up()
        self._propagate_domains_down()
        return new_domains

    def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
        """Recursively fill the maps used by :meth:`_propagate_ports`.

        ``parent``/``level``: fragment -> parent fragment / depth (Instances are not added).
        ``uses``: signal -> set of fragments reading it. ``defs``: signal -> the single
        fragment driving it. ``ios``: signal -> fragment with an Instance "io" port on it.
        ``top`` is not used here.
        """
        def add_uses(*sigs, self=self):
            for sig in flatten(sigs):
                if sig not in uses:
                    uses[sig] = set()
                uses[sig].add(self)

        def add_defs(*sigs):
            for sig in flatten(sigs):
                if sig not in defs:
                    defs[sig] = self
                else:
                    assert defs[sig] is self

        def add_io(*sigs):
            for sig in flatten(sigs):
                if sig not in ios:
                    ios[sig] = self
                else:
                    assert ios[sig] is self

        # Collect all signals we're driving (on LHS of statements), and signals we're using
        # (on RHS of statements, or in clock domains).
        for stmt in self.statements:
            add_uses(stmt._rhs_signals())
            add_defs(stmt._lhs_signals())

        for domain, _ in self.iter_sync():
            cd = self.domains[domain]
            add_uses(cd.clk)
            if cd.rst is not None:
                add_uses(cd.rst)

        # Repeat for subfragments.
        for subfrag, name in self.subfragments:
            if isinstance(subfrag, Instance):
                # Instance ports are attributed to this (the containing) fragment.
                for port_name, (value, port_dir) in subfrag.named_ports.items():
                    if port_dir == "i":
                        subfrag.add_ports(value._rhs_signals(), dir=port_dir)
                        add_uses(value._rhs_signals())
                    elif port_dir == "o":
                        subfrag.add_ports(value._lhs_signals(), dir=port_dir)
                        add_defs(value._lhs_signals())
                    elif port_dir == "io":
                        subfrag.add_ports(value._lhs_signals(), dir=port_dir)
                        add_io(value._lhs_signals())
            else:
                parent[subfrag] = self
                level[subfrag] = level[self] + 1

                subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)

    def _propagate_ports(self, ports, all_undef_as_ports):
        # Take this fragment graph:
        #
        #    __ B (def: q, use: p r)
        #   /
        #  A (def: p, use: q r)
        #   \
        #    \_ C (def: r, use: p q)
        #
        # We need to consider three cases.
        #   1. Signal p requires an input port in B;
        #   2. Signal r requires an output port in C;
        #   3. Signal r requires an output port in C and an input port in B.
        #
        # Adding these ports can be in general done in three steps for each signal:
        #   1. Find the least common ancestor of all uses and defs.
        #   2. Going upwards from the single def, add output ports.
        #   3. Going upwards from all uses, add input ports.

        parent = {self: None}
        level = {self: 0}
        uses = SignalDict()
        defs = SignalDict()
        ios = SignalDict()
        self._prepare_use_def_graph(parent, level, uses, defs, ios, self)

        ports = SignalSet(ports)
        if all_undef_as_ports:
            for sig in uses:
                if sig in defs:
                    continue
                ports.add(sig)
        for sig in ports:
            if sig not in uses:
                uses[sig] = set()
            uses[sig].add(self)

        @memoize
        def lca_of(fragu, fragv):
            # Normalize fragu to be deeper than fragv.
            if level[fragu] < level[fragv]:
                fragu, fragv = fragv, fragu
            # Find ancestor of fragu on the same level as fragv.
            for _ in range(level[fragu] - level[fragv]):
                fragu = parent[fragu]
            # If fragv was the ancestor of fragu, we're done.
            if fragu == fragv:
                return fragu
            # Otherwise, they are at the same level but in different branches. Step both fragu
            # and fragv until we find the common ancestor.
            while parent[fragu] != parent[fragv]:
                fragu = parent[fragu]
                fragv = parent[fragv]
            return parent[fragu]

        for sig in uses:
            if sig in defs:
                lca = reduce(lca_of, uses[sig], defs[sig])
            else:
                lca = reduce(lca_of, uses[sig])

            for frag in uses[sig]:
                if sig in defs and frag is defs[sig]:
                    continue
                while frag != lca:
                    frag.add_ports(sig, dir="i")
                    frag = parent[frag]

            if sig in defs:
                frag = defs[sig]
                while frag != lca:
                    frag.add_ports(sig, dir="o")
                    frag = parent[frag]

        # Bidirectional signals are exposed all the way up to the top fragment.
        for sig in ios:
            frag = ios[sig]
            while frag is not None:
                frag.add_ports(sig, dir="io")
                frag = parent[frag]

        for sig in ports:
            if sig in ios:
                continue
            if sig in defs:
                self.add_ports(sig, dir="o")
            else:
                self.add_ports(sig, dir="i")

    def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)):
        """Return a copy of this fragment ready for back-end conversion.

        Applies SampleLowerer, domain propagation (creating missing domains through
        ``missing_domain``), hierarchy conflict resolution, DomainLowerer and port
        propagation. With ``ports=None`` every used-but-undriven signal becomes a port;
        otherwise ``ports`` plus clk/rst of directly created domains are used.
        """
        from .xfrm import SampleLowerer, DomainLowerer

        fragment = SampleLowerer()(self)
        new_domains = fragment._propagate_domains(missing_domain)
        fragment._resolve_hierarchy_conflicts()
        fragment = DomainLowerer()(fragment)
        if ports is None:
            fragment._propagate_ports(ports=(), all_undef_as_ports=True)
        else:
            new_ports = []
            for cd in new_domains:
                new_ports.append(cd.clk)
                if cd.rst is not None:
                    new_ports.append(cd.rst)
            fragment._propagate_ports(ports=(*ports, *new_ports), all_undef_as_ports=False)
        return fragment
550
551
class Instance(Fragment):
    """A fragment instantiating a black-box cell of the given ``type``.

    Attributes, parameters and ports may be given positionally as ``(kind, name, value)``
    tuples with ``kind`` one of ``"a"``, ``"p"``, ``"i"``, ``"o"``, ``"io"``, or as keyword
    arguments prefixed with ``a_``, ``p_``, ``i_``, ``o_`` or ``io_``.

    Raises ``NameError`` for an unknown kind or keyword prefix.
    """
    def __init__(self, type, *args, **kwargs):
        super().__init__()

        self.type = type
        self.parameters = OrderedDict()
        self.named_ports = OrderedDict()    # port name -> (value, "i" / "o" / "io")

        for (kind, name, value) in args:
            if kind == "a":
                self.attrs[name] = value
            elif kind == "p":
                self.parameters[name] = value
            elif kind in ("i", "o", "io"):
                self.named_ports[name] = (value, kind)
            else:
                # The message lists "a" too, since attribute tuples are accepted above.
                raise NameError("Instance argument {!r} should be a tuple (kind, name, value) "
                                "where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\""
                                .format((kind, name, value)))

        for kw, arg in kwargs.items():
            if kw.startswith("a_"):
                self.attrs[kw[2:]] = arg
            elif kw.startswith("p_"):
                self.parameters[kw[2:]] = arg
            elif kw.startswith("i_"):
                self.named_ports[kw[2:]] = (arg, "i")
            elif kw.startswith("o_"):
                self.named_ports[kw[2:]] = (arg, "o")
            elif kw.startswith("io_"):
                self.named_ports[kw[3:]] = (arg, "io")
            else:
                # The message lists "a_" too, since attribute keywords are accepted above.
                raise NameError("Instance keyword argument {}={!r} does not start with one of "
                                "\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\""
                                .format(kw, arg))