1 from abc
import ABCMeta
, abstractmethod
2 from collections
import OrderedDict
3 from collections
.abc
import Iterable
5 from .._utils
import flatten
8 from .ast
import _StatementList
14 __all__
= ["ValueVisitor", "ValueTransformer",
15 "StatementVisitor", "StatementTransformer",
16 "FragmentTransformer",
17 "TransformedElaboratable",
18 "DomainCollector", "DomainRenamer", "DomainLowerer",
19 "SampleDomainInjector", "SampleLowerer",
20 "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
21 "ResetInserter", "EnableInserter"]
24 class ValueVisitor(metaclass
=ABCMeta
):
26 def on_Const(self
, value
):
30 def on_AnyConst(self
, value
):
34 def on_AnySeq(self
, value
):
38 def on_Signal(self
, value
):
42 def on_ClockSignal(self
, value
):
46 def on_ResetSignal(self
, value
):
50 def on_Operator(self
, value
):
54 def on_Slice(self
, value
):
58 def on_Part(self
, value
):
62 def on_Cat(self
, value
):
66 def on_Repl(self
, value
):
70 def on_ArrayProxy(self
, value
):
74 def on_Sample(self
, value
):
78 def on_Initial(self
, value
):
def on_unknown_value(self, value):
    """Fallback handler for a value that has no dedicated ``on_*`` method.

    Raises:
        TypeError: always; the message embeds ``repr(value)``.
    """
    message = "Cannot transform value {!r}".format(value)
    raise TypeError(message) # :nocov:
84 def replace_value_src_loc(self
, value
, new_value
):
87 def on_value(self
, value
):
88 if type(value
) is Const
:
89 new_value
= self
.on_Const(value
)
90 elif type(value
) is AnyConst
:
91 new_value
= self
.on_AnyConst(value
)
92 elif type(value
) is AnySeq
:
93 new_value
= self
.on_AnySeq(value
)
94 elif isinstance(value
, Signal
):
95 # Uses `isinstance()` and not `type() is` because nmigen.compat requires it.
96 new_value
= self
.on_Signal(value
)
97 elif type(value
) is ClockSignal
:
98 new_value
= self
.on_ClockSignal(value
)
99 elif type(value
) is ResetSignal
:
100 new_value
= self
.on_ResetSignal(value
)
101 elif type(value
) is Operator
:
102 new_value
= self
.on_Operator(value
)
103 elif type(value
) is Slice
:
104 new_value
= self
.on_Slice(value
)
105 elif type(value
) is Part
:
106 new_value
= self
.on_Part(value
)
107 elif type(value
) is Cat
:
108 new_value
= self
.on_Cat(value
)
109 elif type(value
) is _InternalRepl
:
110 new_value
= self
.on_Repl(value
)
111 elif type(value
) is ArrayProxy
:
112 new_value
= self
.on_ArrayProxy(value
)
113 elif type(value
) is Sample
:
114 new_value
= self
.on_Sample(value
)
115 elif type(value
) is Initial
:
116 new_value
= self
.on_Initial(value
)
117 elif isinstance(value
, UserValue
):
118 # Uses `isinstance()` and not `type() is` to allow inheriting.
119 new_value
= self
.on_value(value
._lazy
_lower
())
121 new_value
= self
.on_unknown_value(value
)
122 if isinstance(new_value
, Value
) and self
.replace_value_src_loc(value
, new_value
):
123 new_value
.src_loc
= value
.src_loc
126 def __call__(self
, value
):
127 return self
.on_value(value
)
130 class ValueTransformer(ValueVisitor
):
131 def on_Const(self
, value
):
134 def on_AnyConst(self
, value
):
137 def on_AnySeq(self
, value
):
140 def on_Signal(self
, value
):
143 def on_ClockSignal(self
, value
):
146 def on_ResetSignal(self
, value
):
def on_Operator(self, value):
    """Rebuild an ``Operator`` with the same operator and transformed operands.

    Each operand is passed through ``self.on_value`` in order.
    """
    operator = value.operator
    transformed_operands = []
    for operand in value.operands:
        transformed_operands.append(self.on_value(operand))
    return Operator(operator, transformed_operands)
def on_Slice(self, value):
    """Rebuild a ``Slice`` over the transformed inner value, keeping ``start`` and ``stop``."""
    inner = self.on_value(value.value)
    return Slice(inner, value.start, value.stop)
def on_Part(self, value):
    """Rebuild a ``Part`` with transformed value and offset, keeping ``width`` and ``stride``."""
    inner = self.on_value(value.value)
    offset = self.on_value(value.offset)
    return Part(inner, offset, value.width, value.stride)
def on_Cat(self, value):
    """Rebuild a ``Cat`` from the parts of ``value``, each passed through ``self.on_value``."""
    # A lazy iterator is handed to ``Cat``, just like the generator expression it replaces.
    transformed_parts = map(self.on_value, value.parts)
    return Cat(transformed_parts)
def on_Repl(self, value):
    """Rebuild the replication via ``Repl`` with a transformed inner value and the same count."""
    inner = self.on_value(value.value)
    return Repl(inner, value.count)
def on_ArrayProxy(self, value):
    """Rebuild an ``ArrayProxy`` from transformed elements and a transformed index.

    Elements are transformed first (in iteration order), then the index.
    """
    transformed_elems = []
    for elem in value._iter_as_values():
        transformed_elems.append(self.on_value(elem))
    transformed_index = self.on_value(value.index)
    return ArrayProxy(transformed_elems, transformed_index)
def on_Sample(self, value):
    """Rebuild a ``Sample`` over the transformed inner value, keeping ``clocks`` and ``domain``."""
    inner = self.on_value(value.value)
    return Sample(inner, value.clocks, value.domain)
172 def on_Initial(self
, value
):
176 class StatementVisitor(metaclass
=ABCMeta
):
178 def on_Assign(self
, stmt
):
182 def on_Assert(self
, stmt
):
186 def on_Assume(self
, stmt
):
190 def on_Cover(self
, stmt
):
194 def on_Switch(self
, stmt
):
198 def on_statements(self
, stmts
):
def on_unknown_statement(self, stmt):
    """Fallback handler for a statement that has no dedicated ``on_*`` method.

    Raises:
        TypeError: always; the message embeds ``repr(stmt)``.
    """
    message = "Cannot transform statement {!r}".format(stmt)
    raise TypeError(message) # :nocov:
204 def replace_statement_src_loc(self
, stmt
, new_stmt
):
207 def on_statement(self
, stmt
):
208 if type(stmt
) is _InternalAssign
:
209 new_stmt
= self
.on_Assign(stmt
)
210 elif type(stmt
) is Assert
:
211 new_stmt
= self
.on_Assert(stmt
)
212 elif type(stmt
) is Assume
:
213 new_stmt
= self
.on_Assume(stmt
)
214 elif type(stmt
) is Cover
:
215 new_stmt
= self
.on_Cover(stmt
)
216 elif isinstance(stmt
, _InternalSwitch
):
217 # Uses `isinstance()` and not `type() is` because nmigen.compat requires it.
218 new_stmt
= self
.on_Switch(stmt
)
219 elif isinstance(stmt
, Iterable
):
220 new_stmt
= self
.on_statements(stmt
)
222 new_stmt
= self
.on_unknown_statement(stmt
)
223 if isinstance(new_stmt
, Statement
) and self
.replace_statement_src_loc(stmt
, new_stmt
):
224 new_stmt
.src_loc
= stmt
.src_loc
225 if (isinstance(new_stmt
, _InternalSwitch
) and
226 isinstance(stmt
, _InternalSwitch
)):
227 new_stmt
.case_src_locs
= stmt
.case_src_locs
228 if isinstance(new_stmt
, Property
):
229 new_stmt
._MustUse
__used
= True
232 def __call__(self
, stmt
):
233 return self
.on_statement(stmt
)
236 class StatementTransformer(StatementVisitor
):
237 def on_value(self
, value
):
def on_Assign(self, stmt):
    """Build a new ``Assign`` whose LHS and RHS are passed through ``self.on_value``."""
    new_lhs = self.on_value(stmt.lhs)
    new_rhs = self.on_value(stmt.rhs)
    return Assign(new_lhs, new_rhs)
def on_Assert(self, stmt):
    """Build a new ``Assert`` over the transformed test, reusing ``_check`` and ``_en``."""
    new_test = self.on_value(stmt.test)
    return Assert(new_test, _check=stmt._check, _en=stmt._en)
def on_Assume(self, stmt):
    """Build a new ``Assume`` over the transformed test, reusing ``_check`` and ``_en``."""
    new_test = self.on_value(stmt.test)
    return Assume(new_test, _check=stmt._check, _en=stmt._en)
def on_Cover(self, stmt):
    """Build a new ``Cover`` over the transformed test, reusing ``_check`` and ``_en``."""
    new_test = self.on_value(stmt.test)
    return Cover(new_test, _check=stmt._check, _en=stmt._en)
def on_Switch(self, stmt):
    """Build a new ``Switch`` with transformed case bodies and a transformed test.

    Case bodies are transformed first, in the order of ``stmt.cases``; the test
    value is transformed afterwards.
    """
    new_cases = OrderedDict()
    for patterns, body in stmt.cases.items():
        new_cases[patterns] = self.on_statement(body)
    new_test = self.on_value(stmt.test)
    return Switch(new_test, new_cases)
def on_statements(self, stmts):
    """Transform each statement and collect the flattened results in a ``_StatementList``."""
    # ``map`` stays lazy, like the generator expression it replaces.
    transformed = map(self.on_statement, stmts)
    return _StatementList(flatten(transformed))
260 class FragmentTransformer
:
def map_subfragments(self, fragment, new_fragment):
    """Apply this transformer to every subfragment and attach the result to ``new_fragment``.

    ``fragment.subfragments`` is iterated as ``(subfragment, name)`` pairs; each
    name is carried over unchanged.
    """
    for child, child_name in fragment.subfragments:
        transformed_child = self(child)
        new_fragment.add_subfragment(transformed_child, child_name)
def map_ports(self, fragment, new_fragment):
    """Copy every port of ``fragment`` onto ``new_fragment`` with its direction."""
    # ``direction`` rather than ``dir`` so the builtin is not shadowed locally; the
    # keyword passed to ``add_ports`` is still ``dir``.
    for port, direction in fragment.ports.items():
        new_fragment.add_ports(port, dir=direction)
269 def map_named_ports(self
, fragment
, new_fragment
):
270 if hasattr(self
, "on_value"):
271 for name
, (value
, dir) in fragment
.named_ports
.items():
272 new_fragment
.named_ports
[name
] = self
.on_value(value
), dir
274 new_fragment
.named_ports
= OrderedDict(fragment
.named_ports
.items())
def map_domains(self, fragment, new_fragment):
    """Add each domain yielded by ``fragment.iter_domains()`` to ``new_fragment``.

    The domain object is looked up in ``fragment.domains`` by name.
    """
    for domain_name in fragment.iter_domains():
        clock_domain = fragment.domains[domain_name]
        new_fragment.add_domains(clock_domain)
280 def map_statements(self
, fragment
, new_fragment
):
281 if hasattr(self
, "on_statement"):
282 new_fragment
.add_statements(map(self
.on_statement
, fragment
.statements
))
284 new_fragment
.add_statements(fragment
.statements
)
def map_drivers(self, fragment, new_fragment):
    """Copy every ``(domain, signal)`` driver pair of ``fragment`` onto ``new_fragment``."""
    for domain_name, driven_signal in fragment.iter_drivers():
        new_fragment.add_driver(driven_signal, domain_name)
290 def on_fragment(self
, fragment
):
291 if isinstance(fragment
, Instance
):
292 new_fragment
= Instance(fragment
.type)
293 new_fragment
.parameters
= OrderedDict(fragment
.parameters
)
294 self
.map_named_ports(fragment
, new_fragment
)
296 new_fragment
= Fragment()
297 new_fragment
.flatten
= fragment
.flatten
298 new_fragment
.attrs
= OrderedDict(fragment
.attrs
)
299 self
.map_ports(fragment
, new_fragment
)
300 self
.map_subfragments(fragment
, new_fragment
)
301 self
.map_domains(fragment
, new_fragment
)
302 self
.map_statements(fragment
, new_fragment
)
303 self
.map_drivers(fragment
, new_fragment
)
306 def __call__(self
, value
, *, src_loc_at
=0):
307 if isinstance(value
, Fragment
):
308 return self
.on_fragment(value
)
309 elif isinstance(value
, TransformedElaboratable
):
310 value
._transforms
_.append(self
)
312 elif hasattr(value
, "elaborate"):
313 value
= TransformedElaboratable(value
, src_loc_at
=1 + src_loc_at
)
314 value
._transforms
_.append(self
)
317 raise AttributeError("Object {!r} cannot be elaborated".format(value
))
320 class TransformedElaboratable(Elaboratable
):
321 def __init__(self
, elaboratable
, *, src_loc_at
=0):
322 assert hasattr(elaboratable
, "elaborate")
324 # Fields prefixed and suffixed with underscore to avoid as many conflicts with the inner
325 # object as possible, since we're forwarding attribute requests to it.
326 self
._elaboratable
_ = elaboratable
327 self
._transforms
_ = []
def __getattr__(self, attr):
    """Forward a failed attribute lookup to the wrapped elaboratable.

    Python calls this only after normal lookup on the wrapper has failed.

    Raises:
        AttributeError: if ``attr`` is ``_elaboratable_`` itself, or if the
            wrapped object does not have ``attr``.
    """
    # If `_elaboratable_` has not been assigned yet (for example, the instance is
    # being probed before `__init__` ran), reading `self._elaboratable_` below would
    # re-enter this method without bound. Fail with a normal AttributeError instead.
    if attr == "_elaboratable_":
        raise AttributeError(attr)
    return getattr(self._elaboratable_, attr)
332 def elaborate(self
, platform
):
333 fragment
= Fragment
.get(self
._elaboratable
_, platform
)
334 for transform
in self
._transforms
_:
335 fragment
= transform(fragment
)
339 class DomainCollector(ValueVisitor
, StatementVisitor
):
341 self
.used_domains
= set()
342 self
.defined_domains
= set()
343 self
._local
_domains
= set()
345 def _add_used_domain(self
, domain_name
):
346 if domain_name
is None:
348 if domain_name
in self
._local
_domains
:
350 self
.used_domains
.add(domain_name
)
352 def on_ignore(self
, value
):
356 on_AnyConst
= on_ignore
357 on_AnySeq
= on_ignore
358 on_Signal
= on_ignore
def on_ClockSignal(self, value):
    """Report the domain referenced by a clock signal to ``_add_used_domain``."""
    domain_name = value.domain
    self._add_used_domain(domain_name)
def on_ResetSignal(self, value):
    """Report the domain referenced by a reset signal to ``_add_used_domain``."""
    domain_name = value.domain
    self._add_used_domain(domain_name)
366 def on_Operator(self
, value
):
367 for o
in value
.operands
:
def on_Slice(self, value):
    """Visit the value being sliced; nothing is returned."""
    sliced = value.value
    self.on_value(sliced)
def on_Part(self, value):
    """Visit the part's underlying value, then its offset; nothing is returned."""
    for operand in (value.value, value.offset):
        self.on_value(operand)
377 def on_Cat(self
, value
):
378 for o
in value
.parts
:
def on_Repl(self, value):
    """Visit the value being replicated; nothing is returned."""
    replicated = value.value
    self.on_value(replicated)
384 def on_ArrayProxy(self
, value
):
385 for elem
in value
._iter
_as
_values
():
387 self
.on_value(value
.index
)
def on_Sample(self, value):
    """Visit the value being sampled; nothing is returned."""
    sampled = value.value
    self.on_value(sampled)
392 def on_Initial(self
, value
):
def on_Assign(self, stmt):
    """Visit the assignment's LHS, then its RHS; nothing is returned."""
    for side in (stmt.lhs, stmt.rhs):
        self.on_value(side)
def on_property(self, stmt):
    """Visit the test expression of a property statement; nothing is returned."""
    test_value = stmt.test
    self.on_value(test_value)
402 on_Assert
= on_property
403 on_Assume
= on_property
404 on_Cover
= on_property
def on_Switch(self, stmt):
    """Visit the switch's test value, then the body of every case in order."""
    self.on_value(stmt.test)
    case_bodies = stmt.cases.values()
    for body in case_bodies:
        self.on_statement(body)
411 def on_statements(self
, stmts
):
413 self
.on_statement(stmt
)
415 def on_fragment(self
, fragment
):
416 if isinstance(fragment
, Instance
):
417 for name
, (value
, dir) in fragment
.named_ports
.items():
420 old_local_domains
, self
._local
_domains
= self
._local
_domains
, set(self
._local
_domains
)
421 for domain_name
, domain
in fragment
.domains
.items():
423 self
._local
_domains
.add(domain_name
)
425 self
.defined_domains
.add(domain_name
)
427 self
.on_statements(fragment
.statements
)
428 for domain_name
in fragment
.drivers
:
429 self
._add
_used
_domain
(domain_name
)
430 for subfragment
, name
in fragment
.subfragments
:
431 self
.on_fragment(subfragment
)
433 self
._local
_domains
= old_local_domains
def __call__(self, fragment):
    """Run the collector over ``fragment`` through ``on_fragment``.

    Nothing is returned.
    """
    collect = self.on_fragment
    collect(fragment)
439 class DomainRenamer(FragmentTransformer
, ValueTransformer
, StatementTransformer
):
440 def __init__(self
, domain_map
):
441 if isinstance(domain_map
, str):
442 domain_map
= {"sync": domain_map
}
443 for src
, dst
in domain_map
.items():
445 raise ValueError("Domain '{}' may not be renamed".format(src
))
447 raise ValueError("Domain '{}' may not be renamed to '{}'".format(src
, dst
))
448 self
.domain_map
= OrderedDict(domain_map
)
450 def on_ClockSignal(self
, value
):
451 if value
.domain
in self
.domain_map
:
452 return ClockSignal(self
.domain_map
[value
.domain
])
455 def on_ResetSignal(self
, value
):
456 if value
.domain
in self
.domain_map
:
457 return ResetSignal(self
.domain_map
[value
.domain
],
458 allow_reset_less
=value
.allow_reset_less
)
461 def map_domains(self
, fragment
, new_fragment
):
462 for domain
in fragment
.iter_domains():
463 cd
= fragment
.domains
[domain
]
464 if domain
in self
.domain_map
:
465 if cd
.name
== domain
:
466 # Rename the actual ClockDomain object.
467 cd
.rename(self
.domain_map
[domain
])
469 assert cd
.name
== self
.domain_map
[domain
]
470 new_fragment
.add_domains(cd
)
def map_drivers(self, fragment, new_fragment):
    """Copy drivers onto ``new_fragment``, applying the domain renaming.

    A domain present in ``self.domain_map`` is replaced by its mapped name; any
    other domain is kept as is. Every driven signal is passed through
    ``self.on_value`` before being registered.
    """
    renames = self.domain_map
    for source_domain, driven in fragment.drivers.items():
        target_domain = renames[source_domain] if source_domain in renames else source_domain
        for driven_signal in driven:
            new_fragment.add_driver(self.on_value(driven_signal), target_domain)
480 class DomainLowerer(FragmentTransformer
, ValueTransformer
, StatementTransformer
):
def __init__(self, domains=None):
    """Store the initial domain mapping.

    Args:
        domains: mapping consulted by ``_resolve`` (membership test and lookup
            by domain name); defaults to ``None``.
    """
    self.domains = domains
def _resolve(self, domain, context):
    """Look up ``domain`` in ``self.domains``.

    Args:
        domain: name of the domain to resolve.
        context: value reported in the error message when the lookup fails.

    Returns:
        The entry of ``self.domains`` stored under ``domain``.

    Raises:
        DomainError: if ``domain`` is not a key of ``self.domains``.
    """
    if domain in self.domains:
        return self.domains[domain]
    raise DomainError("Signal {!r} refers to nonexistent domain '{}'"
                      .format(context, domain))
def map_drivers(self, fragment, new_fragment):
    """Copy drivers onto ``new_fragment``, passing each driven signal through ``self.on_value``."""
    for domain_name, driven_signal in fragment.iter_drivers():
        lowered_signal = self.on_value(driven_signal)
        new_fragment.add_driver(lowered_signal, domain_name)
def replace_value_src_loc(self, value, new_value):
    """Return ``False`` for ``ClockSignal``/``ResetSignal`` inputs and ``True`` otherwise.

    ``new_value`` is not inspected.
    """
    if isinstance(value, (ClockSignal, ResetSignal)):
        return False
    return True
497 def on_ClockSignal(self
, value
):
498 domain
= self
._resolve
(value
.domain
, value
)
501 def on_ResetSignal(self
, value
):
502 domain
= self
._resolve
(value
.domain
, value
)
503 if domain
.rst
is None:
504 if value
.allow_reset_less
:
507 raise DomainError("Signal {!r} refers to reset of reset-less domain '{}'"
508 .format(value
, value
.domain
))
511 def _insert_resets(self
, fragment
):
512 for domain_name
, signals
in fragment
.drivers
.items():
513 if domain_name
is None:
515 domain
= fragment
.domains
[domain_name
]
516 if domain
.rst
is None:
518 stmts
= [signal
.eq(Const(signal
.reset
, signal
.width
))
519 for signal
in signals
if not signal
.reset_less
]
520 fragment
.add_statements(Switch(domain
.rst
, {1: stmts
}))
522 def on_fragment(self
, fragment
):
523 self
.domains
= fragment
.domains
524 new_fragment
= super().on_fragment(fragment
)
525 self
._insert
_resets
(new_fragment
)
529 class SampleDomainInjector(ValueTransformer
, StatementTransformer
):
530 def __init__(self
, domain
):
533 def on_Sample(self
, value
):
534 if value
.domain
is not None:
536 return Sample(value
.value
, value
.clocks
, self
.domain
)
def __call__(self, stmts):
    """Make the injector callable: returns ``self.on_statement(stmts)``."""
    visit = self.on_statement
    return visit(stmts)
542 class SampleLowerer(FragmentTransformer
, ValueTransformer
, StatementTransformer
):
545 self
.sample_cache
= None
546 self
.sample_stmts
= None
548 def _name_reset(self
, value
):
549 if isinstance(value
, Const
):
550 return "c${}".format(value
.value
), value
.value
551 elif isinstance(value
, Signal
):
552 return "s${}".format(value
.name
), value
.reset
553 elif isinstance(value
, ClockSignal
):
555 elif isinstance(value
, ResetSignal
):
557 elif isinstance(value
, Initial
):
558 return "init", 0 # Past(Initial()) produces 0, 1, 0, 0, ...
560 raise NotImplementedError # :nocov:
562 def on_Sample(self
, value
):
563 if value
in self
.sample_cache
:
564 return self
.sample_cache
[value
]
566 sampled_value
= self
.on_value(value
.value
)
567 if value
.clocks
== 0:
568 sample
= sampled_value
570 assert value
.domain
is not None
571 sampled_name
, sampled_reset
= self
._name
_reset
(value
.value
)
572 name
= "$sample${}${}${}".format(sampled_name
, value
.domain
, value
.clocks
)
573 sample
= Signal
.like(value
.value
, name
=name
, reset_less
=True, reset
=sampled_reset
)
574 sample
.attrs
["nmigen.sample_reg"] = True
576 prev_sample
= self
.on_Sample(Sample(sampled_value
, value
.clocks
- 1, value
.domain
))
577 if value
.domain
not in self
.sample_stmts
:
578 self
.sample_stmts
[value
.domain
] = []
579 self
.sample_stmts
[value
.domain
].append(sample
.eq(prev_sample
))
581 self
.sample_cache
[value
] = sample
584 def on_Initial(self
, value
):
585 if self
.initial
is None:
586 self
.initial
= Signal(name
="init")
589 def map_statements(self
, fragment
, new_fragment
):
591 self
.sample_cache
= ValueDict()
592 self
.sample_stmts
= OrderedDict()
593 new_fragment
.add_statements(map(self
.on_statement
, fragment
.statements
))
594 for domain
, stmts
in self
.sample_stmts
.items():
595 new_fragment
.add_statements(stmts
)
597 new_fragment
.add_driver(stmt
.lhs
, domain
)
598 if self
.initial
is not None:
599 new_fragment
.add_subfragment(Instance("$initstate", o_Y
=self
.initial
))
602 class SwitchCleaner(StatementVisitor
):
603 def on_ignore(self
, stmt
):
606 on_Assign
= on_ignore
607 on_Assert
= on_ignore
608 on_Assume
= on_ignore
def on_Switch(self, stmt):
    """Clean every case body and drop the switch if nothing is left.

    Returns:
        A new ``Switch`` over the same test with the cleaned cases, or ``None``
        when every cleaned case body has length zero.
    """
    cleaned_cases = OrderedDict()
    for patterns, body in stmt.cases.items():
        cleaned_cases[patterns] = self.on_statement(body)
    if not any(len(body) for body in cleaned_cases.values()):
        return None
    return Switch(stmt.test, cleaned_cases)
def on_statements(self, stmts):
    """Clean each statement, flatten the results, and drop ``None`` entries.

    Returns:
        A ``_StatementList`` of the surviving statements.
    """
    cleaned = flatten(self.on_statement(original) for original in stmts)
    surviving = (candidate for candidate in cleaned if candidate is not None)
    return _StatementList(surviving)
621 class LHSGroupAnalyzer(StatementVisitor
):
623 self
.signals
= SignalDict()
624 self
.unions
= OrderedDict()
626 def find(self
, signal
):
627 if signal
not in self
.signals
:
628 self
.signals
[signal
] = len(self
.signals
)
629 group
= self
.signals
[signal
]
630 while group
in self
.unions
:
631 group
= self
.unions
[group
]
632 self
.signals
[signal
] = group
635 def unify(self
, root
, *leaves
):
636 root_group
= self
.find(root
)
638 leaf_group
= self
.find(leaf
)
639 if root_group
== leaf_group
:
641 self
.unions
[leaf_group
] = root_group
644 groups
= OrderedDict()
645 for signal
in self
.signals
:
646 group
= self
.find(signal
)
647 if group
not in groups
:
648 groups
[group
] = SignalSet()
649 groups
[group
].add(signal
)
652 def on_Assign(self
, stmt
):
653 lhs_signals
= stmt
._lhs
_signals
()
655 self
.unify(*stmt
._lhs
_signals
())
657 def on_property(self
, stmt
):
658 lhs_signals
= stmt
._lhs
_signals
()
660 self
.unify(*stmt
._lhs
_signals
())
662 on_Assert
= on_property
663 on_Assume
= on_property
664 on_Cover
= on_property
def on_Switch(self, stmt):
    """Analyze the body of every case of the switch through ``on_statements``."""
    bodies = stmt.cases.values()
    for body in bodies:
        self.on_statements(body)
670 def on_statements(self
, stmts
):
672 self
.on_statement(stmt
)
674 def __call__(self
, stmts
):
675 self
.on_statements(stmts
)
679 class LHSGroupFilter(SwitchCleaner
):
def __init__(self, signals):
    """Remember the collection of signals this filter tests membership against.

    Args:
        signals: container stored as ``self.signals``.
    """
    self.signals = signals
683 def on_Assign(self
, stmt
):
684 # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
685 # on LHS are a part of the same group, so it is sufficient to check any of them.
686 lhs_signals
= stmt
.lhs
._lhs
_signals
()
688 any_lhs_signal
= next(iter(lhs_signals
))
689 if any_lhs_signal
in self
.signals
:
692 def on_property(self
, stmt
):
693 any_lhs_signal
= next(iter(stmt
._lhs
_signals
()))
694 if any_lhs_signal
in self
.signals
:
697 on_Assert
= on_property
698 on_Assume
= on_property
699 on_Cover
= on_property
702 class _ControlInserter(FragmentTransformer
):
703 def __init__(self
, controls
):
705 if isinstance(controls
, Value
):
706 controls
= {"sync": controls
}
707 self
.controls
= OrderedDict(controls
)
709 def on_fragment(self
, fragment
):
710 new_fragment
= super().on_fragment(fragment
)
711 for domain
, signals
in fragment
.drivers
.items():
712 if domain
is None or domain
not in self
.controls
:
714 self
._insert
_control
(new_fragment
, domain
, signals
)
def _insert_control(self, fragment, domain, signals):
    """Hook for subclasses; this base implementation is a placeholder.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError # :nocov:
def __call__(self, value, *, src_loc_at=0):
    """Record a source location, then defer to the parent ``__call__``.

    The location obtained from ``tracer.get_src_loc`` is stored on
    ``self.src_loc``. The parent is called with ``src_loc_at`` increased by one.
    """
    self.src_loc = tracer.get_src_loc(src_loc_at=src_loc_at)
    parent_depth = 1 + src_loc_at
    return super().__call__(value, src_loc_at=parent_depth)
class ResetInserter(_ControlInserter):
    """Control inserter that assigns reset values while the control is 1."""

    def _insert_control(self, fragment, domain, signals):
        """Add a ``Switch`` on ``self.controls[domain]`` that resets the signals.

        Under case ``1``, every signal in ``signals`` whose ``reset_less`` is
        falsy is assigned ``Const(signal.reset, signal.width)``.
        """
        reset_stmts = []
        for signal in signals:
            if signal.reset_less:
                continue
            reset_stmts.append(signal.eq(Const(signal.reset, signal.width)))
        control = self.controls[domain]
        fragment.add_statements(Switch(control, {1: reset_stmts}, src_loc=self.src_loc))
731 class EnableInserter(_ControlInserter
):
def _insert_control(self, fragment, domain, signals):
    """Add a ``Switch`` on ``self.controls[domain]`` that holds the signals.

    Under case ``0``, every signal in ``signals`` is assigned to itself.
    """
    hold_stmts = []
    for signal in signals:
        hold_stmts.append(signal.eq(signal))
    control = self.controls[domain]
    fragment.add_statements(Switch(control, {0: hold_stmts}, src_loc=self.src_loc))
736 def on_fragment(self
, fragment
):
737 new_fragment
= super().on_fragment(fragment
)
738 if isinstance(new_fragment
, Instance
) and new_fragment
.type in ("$memrd", "$memwr"):
739 clk_port
, clk_dir
= new_fragment
.named_ports
["CLK"]
740 if isinstance(clk_port
, ClockSignal
) and clk_port
.domain
in self
.controls
:
741 en_port
, en_dir
= new_fragment
.named_ports
["EN"]
742 en_port
= Mux(self
.controls
[clk_port
.domain
], en_port
, Const(0, len(en_port
)))
743 new_fragment
.named_ports
["EN"] = en_port
, en_dir