hdl.ast: simplify Mux implementation.
[nmigen.git] / nmigen / hdl / ir.py
1 from abc import ABCMeta
2 from collections import defaultdict, OrderedDict
3 from functools import reduce
4 import warnings
5
6 from .._utils import *
7 from .._unused import *
8 from .ast import *
9 from .cd import *
10
11
# Public names exported by ``from ... import *``.
__all__ = [
    "UnusedElaboratable",
    "Elaboratable",
    "DriverConflict",
    "Fragment",
    "Instance",
]
13
14
class UnusedElaboratable(UnusedMustUse):
    """Warning category that :class:`Elaboratable` hands to its ``MustUse`` base.

    Presumably reported for elaboratables that are created but never marked as used.
    """
17
18
class Elaboratable(MustUse, metaclass=ABCMeta):
    """Base class for objects that :meth:`Fragment.get` can elaborate.

    Subclasses implement ``elaborate(platform)``. ``Fragment.get`` calls it and sets the
    ``MustUse`` "used" flag on the instance.
    """

    # Warning category the ``MustUse`` base uses for instances that are never used.
    _MustUse__warning = UnusedElaboratable
21
22
class DriverConflict(UserWarning):
    """A signal is driven, or a memory accessed, from more than one fragment.

    Emitted as a warning (``mode="warn"``) or raised (``mode="error"``) while resolving
    hierarchy conflicts.
    """
25
26
class Fragment:
    """A node of the elaborated design hierarchy.

    A fragment collects:

    * the statements it contains;
    * the signals those statements drive, keyed by clock domain name (``None`` stands for
      the combinatorial domain);
    * the clock domains it defines;
    * its ports and their directions;
    * an ordered list of ``(subfragment, name)`` pairs, which forms the hierarchy below it.
    """

    @staticmethod
    def get(obj, platform):
        """Elaborate ``obj`` until a :class:`Fragment` is obtained, and return that fragment.

        ``obj`` may be:

        * a ``Fragment``, which is returned unchanged;
        * an ``Elaboratable``;
        * any other object with an ``elaborate(platform)`` method. This works, but emits a
          ``RuntimeWarning`` recommending explicit inheritance from ``Elaboratable``.

        The result of every ``elaborate`` call is processed again, so an ``elaborate`` method
        may return another elaboratable.

        Raises ``AttributeError`` when an object in the chain is neither a ``Fragment`` nor
        has an ``elaborate`` method. If an ``elaborate`` method returns ``None``, a
        ``UserWarning`` attributed to that method's definition is emitted first.
        """
        code = None
        while True:
            if isinstance(obj, Fragment):
                return obj
            elif isinstance(obj, Elaboratable):
                code = obj.elaborate.__code__
                # Set the MustUse "used" flag on the elaboratable before elaborating it.
                obj._MustUse__used = True
                obj = obj.elaborate(platform)
            elif hasattr(obj, "elaborate"):
                warnings.warn(
                    message="Class {!r} is an elaboratable that does not explicitly inherit from "
                            "Elaboratable; doing so would improve diagnostics"
                            .format(type(obj)),
                    category=RuntimeWarning,
                    stacklevel=2)
                code = obj.elaborate.__code__
                obj = obj.elaborate(platform)
            else:
                raise AttributeError("Object {!r} cannot be elaborated".format(obj))
            if obj is None and code is not None:
                # Point the warning at the offending ``elaborate`` definition. The next loop
                # iteration then raises AttributeError for ``None``.
                warnings.warn_explicit(
                    message=".elaborate() returned None; missing return statement?",
                    category=UserWarning,
                    filename=code.co_filename,
                    lineno=code.co_firstlineno)

    def __init__(self):
        self.ports = SignalDict()       # signal -> direction: "i", "o" or "io"
        self.drivers = OrderedDict()    # domain name (None for comb) -> SignalSet of signals
        self.statements = []
        self.domains = OrderedDict()    # domain name -> ClockDomain
        self.subfragments = []          # list of (fragment, name or None)
        self.attrs = OrderedDict()
        self.generated = OrderedDict()
        self.flatten = False            # when True, always merged into the parent fragment

    def add_ports(self, *ports, dir):
        """Record every (flattened) signal in ``ports`` as a port with direction ``dir``."""
        assert dir in ("i", "o", "io")
        for port in flatten(ports):
            self.ports[port] = dir

    def iter_ports(self, dir=None):
        """Yield all ports, or only those whose direction equals ``dir`` if it is given."""
        if dir is None:
            yield from self.ports
        else:
            for port, port_dir in self.ports.items():
                if port_dir == dir:
                    yield port

    def add_driver(self, signal, domain=None):
        """Record that this fragment drives ``signal`` in ``domain`` (``None`` = comb)."""
        if domain not in self.drivers:
            self.drivers[domain] = SignalSet()
        self.drivers[domain].add(signal)

    def iter_drivers(self):
        """Yield ``(domain, signal)`` for every driven signal, including comb ones."""
        for domain, signals in self.drivers.items():
            for signal in signals:
                yield domain, signal

    def iter_comb(self):
        """Yield the signals driven in the combinatorial (``None``) domain."""
        if None in self.drivers:
            yield from self.drivers[None]

    def iter_sync(self):
        """Yield ``(domain, signal)`` for signals driven in named (synchronous) domains."""
        for domain, signals in self.drivers.items():
            if domain is None:
                continue
            for signal in signals:
                yield domain, signal

    def iter_signals(self):
        """Return a ``SignalSet`` of this fragment's own relevant signals.

        The set contains the ports, every driven signal, and the clock (and reset, if any)
        of each named domain this fragment drives signals in. Subfragments are not
        inspected.
        """
        signals = SignalSet()
        signals |= self.ports.keys()
        for domain, domain_signals in self.drivers.items():
            if domain is not None:
                cd = self.domains[domain]
                signals.add(cd.clk)
                if cd.rst is not None:
                    signals.add(cd.rst)
            signals |= domain_signals
        return signals

    def add_domains(self, *domains):
        """Define each (flattened) ``ClockDomain``. Names must not already be defined."""
        for domain in flatten(domains):
            assert isinstance(domain, ClockDomain)
            assert domain.name not in self.domains
            self.domains[domain.name] = domain

    def iter_domains(self):
        """Yield the names of the domains defined in this fragment."""
        yield from self.domains

    def add_statements(self, *stmts):
        """Cast ``stmts`` to statements, mark each as used, and append them in order."""
        for stmt in Statement.cast(stmts):
            stmt._MustUse__used = True
            self.statements.append(stmt)

    def add_subfragment(self, subfragment, name=None):
        """Append ``subfragment`` (optionally named) to this fragment's children."""
        assert isinstance(subfragment, Fragment)
        self.subfragments.append((subfragment, name))

    def find_subfragment(self, name_or_index):
        """Return a direct subfragment, selected by position or by name.

        Integer arguments follow Python sequence indexing, so negative values count from the
        end. Any other argument is compared against subfragment names, and the first match
        is returned.

        Raises ``NameError`` if there is no subfragment at that index or with that name.
        """
        if isinstance(name_or_index, int):
            # Check both bounds. An out-of-range negative index then reports the same
            # NameError as an out-of-range positive one, instead of leaking IndexError.
            if -len(self.subfragments) <= name_or_index < len(self.subfragments):
                subfragment, name = self.subfragments[name_or_index]
                return subfragment
            raise NameError("No subfragment at index #{}".format(name_or_index))
        else:
            for subfragment, name in self.subfragments:
                if name == name_or_index:
                    return subfragment
            raise NameError("No subfragment with name '{}'".format(name_or_index))

    def find_generated(self, *path):
        """Look up a ``generated`` entry.

        All path components except the last select subfragments via ``find_subfragment``.
        The last component is the key into that fragment's ``generated`` mapping.
        """
        if len(path) > 1:
            path_component, *path = path
            return self.find_subfragment(path_component).find_generated(*path)
        else:
            item, = path
            return self.generated[item]

    def elaborate(self, platform):
        """A fragment is already elaborated, so return it unchanged."""
        return self

    def _merge_subfragment(self, subfragment):
        """Inline a direct child ``subfragment`` into this fragment and remove the child.

        Ports, drivers, statements and grandchildren are merged. Clock domains, ``attrs``
        and ``generated`` are not.
        """
        # Merge subfragment's everything except clock domains into this fragment.
        # Flattening is done after clock domain propagation, so we can assume the domains
        # are already the same in every involved fragment in the first place.
        self.ports.update(subfragment.ports)
        for domain, signal in subfragment.iter_drivers():
            self.add_driver(signal, domain)
        self.statements += subfragment.statements
        self.subfragments += subfragment.subfragments

        # Remove the merged subfragment.
        found = False
        for i, (check_subfrag, check_name) in enumerate(self.subfragments): # :nobr:
            if subfragment == check_subfrag:
                del self.subfragments[i]
                found = True
                break
        assert found

    def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
        """Flatten subfragments until no signal or memory is shared between fragments.

        ``hierarchy`` is the tuple of names from the top down to this fragment; it is used
        in messages. ``mode`` is one of:

        * ``"silent"``: flatten without reporting;
        * ``"warn"``: emit a :class:`DriverConflict` warning per conflict;
        * ``"error"``: raise :class:`DriverConflict` on the first conflict.

        Returns ``(signals, memories)``:

        * ``signals`` is a ``SignalSet`` of the signals driven by this fragment or any of
          its subfragments;
        * ``memories`` is a set of the ``MEMID`` parameters of ``$memrd``/``$memwr``
          instances in this subtree.
        """
        assert mode in ("silent", "warn", "error")

        driver_subfrags = SignalDict()
        memory_subfrags = OrderedDict()
        def add_subfrag(registry, entity, entry):
            # Because of missing domain insertion, at the point when this code runs, we have
            # a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to
            # the actual signals (because the signal itself can be driven as well); but leave
            # the unbound ones as it is, because there's no concrete signal for it yet anyway.
            if isinstance(entity, ClockSignal) and entity.domain in self.domains:
                entity = self.domains[entity.domain].clk
            elif isinstance(entity, ResetSignal) and entity.domain in self.domains:
                entity = self.domains[entity.domain].rst

            if entity not in registry:
                registry[entity] = set()
            registry[entity].add(entry)

        # For each signal driven by this fragment and/or its subfragments, determine which
        # subfragments also drive it.
        for domain, signal in self.iter_drivers():
            add_subfrag(driver_subfrags, signal, (None, hierarchy))

        flatten_subfrags = set()
        for i, (subfrag, name) in enumerate(self.subfragments):
            if name is None:
                name = "<unnamed #{}>".format(i)
            subfrag_hierarchy = hierarchy + (name,)

            if subfrag.flatten:
                # Always flatten subfragments that explicitly request it.
                flatten_subfrags.add((subfrag, subfrag_hierarchy))

            if isinstance(subfrag, Instance):
                # For memories (which are subfragments, but semantically a part of superfragment),
                # record that this fragment is driving it.
                if subfrag.type in ("$memrd", "$memwr"):
                    memory = subfrag.parameters["MEMID"]
                    add_subfrag(memory_subfrags, memory, (None, hierarchy))

                # Never flatten instances.
                continue

            # First, recurse into subfragments and let them detect driver conflicts as well.
            subfrag_drivers, subfrag_memories = \
                subfrag._resolve_hierarchy_conflicts(subfrag_hierarchy, mode)

            # Second, classify subfragments by signals they drive and memories they use.
            for signal in subfrag_drivers:
                add_subfrag(driver_subfrags, signal, (subfrag, subfrag_hierarchy))
            for memory in subfrag_memories:
                add_subfrag(memory_subfrags, memory, (subfrag, subfrag_hierarchy))

        # Find out the set of subfragments that needs to be flattened into this fragment
        # to resolve driver-driver conflicts.
        def flatten_subfrags_if_needed(subfrags):
            # With a single entry there is no conflict. Otherwise schedule every real
            # subfragment for flattening (a None fragment stands for this fragment itself)
            # and return the sorted dotted hierarchy names of all parties, for messages.
            if len(subfrags) == 1:
                return []
            flatten_subfrags.update((f, h) for f, h in subfrags if f is not None)
            return sorted(".".join(h) for f, h in subfrags)

        for signal, subfrags in driver_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Signal '{}' is driven from multiple fragments: {}"
                       .format(signal, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *signal.src_loc)

        for memory, subfrags in memory_subfrags.items():
            subfrag_names = flatten_subfrags_if_needed(subfrags)
            if not subfrag_names:
                continue

            # While we're at it, show a message.
            message = ("Memory '{}' is accessed from multiple fragments: {}"
                       .format(memory.name, ", ".join(subfrag_names)))
            if mode == "error":
                raise DriverConflict(message)
            elif mode == "warn":
                message += "; hierarchy will be flattened"
                warnings.warn_explicit(message, DriverConflict, *memory.src_loc)

        # Flatten hierarchy.
        for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]):
            self._merge_subfragment(subfrag)

        # If we flattened anything, we might be in a situation where we have a driver conflict
        # again, e.g. if we had a tree of fragments like A --- B --- C where only fragments
        # A and C were driving a signal S. In that case, since B is not driving S itself,
        # processing B will not result in any flattening, but since B is transitively driving S,
        # processing A will flatten B into it. Afterwards, we have a tree like AB --- C, which
        # has another conflict.
        if flatten_subfrags:
            # Try flattening again.
            return self._resolve_hierarchy_conflicts(hierarchy, mode)

        # Nothing was flattened, we're done!
        return (SignalSet(driver_subfrags.keys()),
                set(memory_subfrags.keys()))

    def _propagate_domains_up(self, hierarchy=("top",)):
        """Hoist non-local domains defined by subfragments into this fragment.

        A domain defined by several sibling subfragments is renamed in each of them to
        ``<subfragment name>_<domain name>``. Raises ``DomainError`` if any of those
        subfragments is unnamed or their names are not distinct.
        """
        from .xfrm import DomainRenamer

        domain_subfrags = defaultdict(set)

        # For each domain defined by a subfragment, determine which subfragments define it.
        for i, (subfrag, name) in enumerate(self.subfragments):
            # First, recurse into subfragments and let them propagate domains up as well.
            hier_name = name
            if hier_name is None:
                hier_name = "<unnamed #{}>".format(i)
            subfrag._propagate_domains_up(hierarchy + (hier_name,))

            # Second, classify subfragments by domains they define.
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                domain_subfrags[domain_name].add((subfrag, name, i))

        # For each domain defined by more than one subfragment, rename the domain in each
        # of the subfragments such that they no longer conflict.
        for domain_name, subfrags in domain_subfrags.items():
            if len(subfrags) == 1:
                continue

            names = [n for f, n, i in subfrags]
            if not all(names):
                names = sorted("<unnamed #{}>".format(i) if n is None else "'{}'".format(n)
                               for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; "
                                  "it is necessary to either rename subfragment domains "
                                  "explicitly, or give names to subfragments"
                                  .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            if len(names) != len(set(names)):
                names = sorted("#{}".format(i) for f, n, i in subfrags)
                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', "
                                  "some of which have identical names; it is necessary to either "
                                  "rename subfragment domains explicitly, or give distinct names "
                                  "to subfragments"
                                  .format(domain_name, ", ".join(names), ".".join(hierarchy)))

            for subfrag, name, i in subfrags:
                domain_name_map = {domain_name: "{}_{}".format(name, domain_name)}
                self.subfragments[i] = (DomainRenamer(domain_name_map)(subfrag), name)

        # Finally, collect the (now unique) subfragment domains, and merge them into our domains.
        for subfrag, name in self.subfragments:
            for domain_name, domain in subfrag.domains.items():
                if domain.local:
                    continue
                self.add_domains(domain)

    def _propagate_domains_down(self):
        """Recursively add every domain of this fragment to all of its subfragments."""
        # For each domain defined in this fragment, ensure it also exists in all subfragments.
        for subfrag, name in self.subfragments:
            for domain in self.iter_domains():
                if domain in subfrag.domains:
                    assert self.domains[domain] is subfrag.domains[domain]
                else:
                    subfrag.add_domains(self.domains[domain])

            subfrag._propagate_domains_down()

    def _create_missing_domains(self, missing_domain, *, platform=None):
        """Define every domain that is used but not defined anywhere in this fragment.

        ``missing_domain(name)`` is called for each such domain (the ``None`` comb domain is
        skipped). It may return:

        * ``None``, which raises ``DomainError``;
        * a ``ClockDomain`` (exact type), which is added directly;
        * anything else, which is elaborated with :meth:`get` and added as subfragment
          ``cd_<name>``. That subfragment must define the domain, and all of its domains
          are added here.

        Returns the list of ``ClockDomain`` objects that were added directly.
        """
        from .xfrm import DomainCollector

        collector = DomainCollector()
        collector(self)

        new_domains = []
        for domain_name in collector.used_domains - collector.defined_domains:
            if domain_name is None:
                continue
            value = missing_domain(domain_name)
            if value is None:
                raise DomainError("Domain '{}' is used but not defined".format(domain_name))
            if type(value) is ClockDomain:
                self.add_domains(value)
                # And expose ports on the newly added clock domain, since it is added directly
                # and there was no chance to add any logic driving it.
                new_domains.append(value)
            else:
                new_fragment = Fragment.get(value, platform=platform)
                if domain_name not in new_fragment.domains:
                    defined = new_fragment.domains.keys()
                    raise DomainError(
                        "Fragment returned by missing domain callback does not define "
                        "requested domain '{}' (defines {})."
                        .format(domain_name, ", ".join("'{}'".format(n) for n in defined)))
                self.add_subfragment(new_fragment, "cd_{}".format(domain_name))
                self.add_domains(new_fragment.domains.values())
        return new_domains

    def _propagate_domains(self, missing_domain, *, platform=None):
        """Run all domain propagation passes in order.

        Returns the ``ClockDomain`` objects that ``_create_missing_domains`` added directly.
        """
        self._propagate_domains_up()
        self._propagate_domains_down()
        self._resolve_hierarchy_conflicts()
        new_domains = self._create_missing_domains(missing_domain, platform=platform)
        self._propagate_domains_down()
        return new_domains

    def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
        """Recursively fill the maps used by :meth:`_propagate_ports`.

        * ``parent`` maps fragment -> parent fragment.
        * ``level`` maps fragment -> depth.
        * ``uses`` maps signal -> set of fragments reading it.
        * ``defs`` maps signal -> the single fragment driving it; this is asserted unique.
        * ``ios`` maps signal -> the fragment connecting it to an ``io`` instance port.

        ``Instance`` subfragments are not recursed into. Instead, their named ports are
        turned into ports of the instance and attributed to this fragment. ``top`` is
        passed down the recursion but is otherwise unused.
        """
        def add_uses(*sigs, self=self):
            for sig in flatten(sigs):
                if sig not in uses:
                    uses[sig] = set()
                uses[sig].add(self)

        def add_defs(*sigs):
            for sig in flatten(sigs):
                if sig not in defs:
                    defs[sig] = self
                else:
                    assert defs[sig] is self

        def add_io(*sigs):
            for sig in flatten(sigs):
                if sig not in ios:
                    ios[sig] = self
                else:
                    assert ios[sig] is self

        # Collect all signals we're driving (on LHS of statements), and signals we're using
        # (on RHS of statements, or in clock domains).
        for stmt in self.statements:
            add_uses(stmt._rhs_signals())
            add_defs(stmt._lhs_signals())

        for domain, _ in self.iter_sync():
            cd = self.domains[domain]
            add_uses(cd.clk)
            if cd.rst is not None:
                add_uses(cd.rst)

        # Repeat for subfragments.
        for subfrag, name in self.subfragments:
            if isinstance(subfrag, Instance):
                for port_name, (value, dir) in subfrag.named_ports.items():
                    if dir == "i":
                        # Prioritize defs over uses.
                        rhs_without_outputs = value._rhs_signals() - subfrag.iter_ports(dir="o")
                        subfrag.add_ports(rhs_without_outputs, dir=dir)
                        add_uses(value._rhs_signals())
                    if dir == "o":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_defs(value._lhs_signals())
                    if dir == "io":
                        subfrag.add_ports(value._lhs_signals(), dir=dir)
                        add_io(value._lhs_signals())
            else:
                parent[subfrag] = self
                level[subfrag] = level[self] + 1

                subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)

    def _propagate_ports(self, ports, all_undef_as_ports):
        """Add the input/output/inout ports needed to connect signals across the hierarchy.

        ``ports`` are extra signals to expose at the top. If ``all_undef_as_ports`` is
        true, every signal that is used but driven nowhere also becomes a top-level port.
        """
        # Take this fragment graph:
        #
        #    __ B (def: q, use: p r)
        #   /
        #  A (def: p, use: q r)
        #   \
        #    \_ C (def: r, use: p q)
        #
        # We need to consider three cases.
        #   1. Signal p requires an input port in B;
        #   2. Signal r requires an output port in C;
        #   3. Signal r requires an output port in C and an input port in B.
        #
        # Adding these ports can be in general done in three steps for each signal:
        #   1. Find the least common ancestor of all uses and defs.
        #   2. Going upwards from the single def, add output ports.
        #   3. Going upwards from all uses, add input ports.

        parent = {self: None}
        level = {self: 0}
        uses = SignalDict()
        defs = SignalDict()
        ios = SignalDict()
        self._prepare_use_def_graph(parent, level, uses, defs, ios, self)

        ports = SignalSet(ports)
        if all_undef_as_ports:
            for sig in uses:
                if sig in defs:
                    continue
                ports.add(sig)
        for sig in ports:
            if sig not in uses:
                uses[sig] = set()
            uses[sig].add(self)

        @memoize
        def lca_of(fragu, fragv):
            # Normalize fragu to be deeper than fragv.
            if level[fragu] < level[fragv]:
                fragu, fragv = fragv, fragu
            # Find ancestor of fragu on the same level as fragv.
            for _ in range(level[fragu] - level[fragv]):
                fragu = parent[fragu]
            # If fragv was the ancestor of fragu, we're done.
            if fragu == fragv:
                return fragu
            # Otherwise, they are at the same level but in different branches. Step both fragu
            # and fragv until we find the common ancestor.
            while parent[fragu] != parent[fragv]:
                fragu = parent[fragu]
                fragv = parent[fragv]
            return parent[fragu]

        for sig in uses:
            if sig in defs:
                lca = reduce(lca_of, uses[sig], defs[sig])
            else:
                lca = reduce(lca_of, uses[sig])

            for frag in uses[sig]:
                if sig in defs and frag is defs[sig]:
                    continue
                while frag != lca:
                    frag.add_ports(sig, dir="i")
                    frag = parent[frag]

            if sig in defs:
                frag = defs[sig]
                while frag != lca:
                    frag.add_ports(sig, dir="o")
                    frag = parent[frag]

        for sig in ios:
            frag = ios[sig]
            while frag is not None:
                frag.add_ports(sig, dir="io")
                frag = parent[frag]

        for sig in ports:
            if sig in ios:
                continue
            if sig in defs:
                self.add_ports(sig, dir="o")
            else:
                self.add_ports(sig, dir="i")

    def prepare(self, ports=None, missing_domain=lambda name: ClockDomain(name)):
        """Lower the design and compute ports, returning the prepared fragment.

        The steps are: ``SampleLowerer``, domain propagation (``missing_domain`` creates
        missing domains), ``DomainLowerer``, then port propagation.

        If ``ports`` is ``None``, every used-but-undriven signal becomes a port. Otherwise
        ``ports`` must be a list or tuple of ``Signal``/``ClockSignal``/``ResetSignal``
        (``TypeError`` otherwise). In that case the clock and reset of any domain created
        directly by ``missing_domain`` are exposed as well.
        """
        from .xfrm import SampleLowerer, DomainLowerer

        fragment = SampleLowerer()(self)
        new_domains = fragment._propagate_domains(missing_domain)
        fragment = DomainLowerer()(fragment)
        if ports is None:
            fragment._propagate_ports(ports=(), all_undef_as_ports=True)
        else:
            if not isinstance(ports, (tuple, list)):
                msg = "`ports` must be either a list or a tuple, not {!r}"\
                    .format(ports)
                if isinstance(ports, Value):
                    msg += " (did you mean `ports=(<signal>,)`, rather than `ports=<signal>`?)"
                raise TypeError(msg)
            mapped_ports = []
            # Lower late bound signals like ClockSignal() to ports.
            port_lowerer = DomainLowerer(fragment.domains)
            for port in ports:
                if not isinstance(port, (Signal, ClockSignal, ResetSignal)):
                    raise TypeError("Only signals may be added as ports, not {!r}"
                                    .format(port))
                mapped_ports.append(port_lowerer.on_value(port))
            # Add ports for all newly created missing clock domains, since not doing so defeats
            # the purpose of domain auto-creation. (It's possible to refer to these ports before
            # the domain actually exists through late binding, but it's inconvenient.)
            for cd in new_domains:
                mapped_ports.append(cd.clk)
                if cd.rst is not None:
                    mapped_ports.append(cd.rst)
            fragment._propagate_ports(ports=mapped_ports, all_undef_as_ports=False)
        return fragment
556
557
class Instance(Fragment):
    """A fragment that instantiates an external cell of the given ``type``.

    Items are supplied either as positional ``(kind, name, value)`` tuples or as keyword
    arguments spelled ``<kind>_<name>=value``. The kind is one of:

    * ``"a"``: attribute, stored in ``attrs``;
    * ``"p"``: parameter, stored in ``parameters``;
    * ``"i"``, ``"o"``, ``"io"``: port, stored in ``named_ports`` as
      ``(Value.cast(value), kind)``.

    Raises ``NameError`` for an unknown kind or keyword prefix.
    """

    def __init__(self, type, *args, **kwargs):
        super().__init__()

        self.type = type
        self.parameters = OrderedDict()
        self.named_ports = OrderedDict()

        def store(kind, name, value):
            # File one item under its kind; report False when the kind is not recognized.
            if kind == "a":
                self.attrs[name] = value
            elif kind == "p":
                self.parameters[name] = value
            elif kind in ("i", "o", "io"):
                self.named_ports[name] = (Value.cast(value), kind)
            else:
                return False
            return True

        for item in args:
            kind, name, value = item
            if not store(kind, name, value):
                raise NameError("Instance argument {!r} should be a tuple (kind, name, value) "
                                "where kind is one of \"a\", \"p\", \"i\", \"o\", or \"io\""
                                .format((kind, name, value)))

        for kw, arg in kwargs.items():
            # "io_x" -> ("io", "_", "x"); a keyword without "_" yields an empty separator.
            prefix, separator, name = kw.partition("_")
            if not (separator and store(prefix, name, arg)):
                raise NameError("Instance keyword argument {}={!r} does not start with one of "
                                "\"a_\", \"p_\", \"i_\", \"o_\", or \"io_\""
                                .format(kw, arg))