__all__ = ["ValueVisitor", "ValueTransformer",
"StatementVisitor", "StatementTransformer",
"FragmentTransformer",
- "DomainRenamer", "DomainLowerer", "ResetInserter", "CEInserter"]
+ "DomainRenamer", "DomainLowerer",
+ "LHSGroupAnalyzer",
+ "ResetInserter", "CEInserter"]
class ValueVisitor(metaclass=ABCMeta):
pass # :nocov:
@abstractmethod
- def on_statements(self, stmt):
+ def on_statements(self, stmts):
pass # :nocov:
def on_unknown_statement(self, stmt):
cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
return Switch(self.on_value(stmt.test), cases)
- def on_statements(self, stmt):
- return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt))
+ def on_statements(self, stmts):
+ return _StatementList(flatten(self.on_statement(stmt) for stmt in stmts))
class FragmentTransformer:
return cd.rst
+class LHSGroupAnalyzer(StatementVisitor):
+ def __init__(self):
+ self.signals = SignalDict()
+ self.unions = OrderedDict()
+
+ def find(self, signal):
+ if signal not in self.signals:
+ self.signals[signal] = len(self.signals)
+ group = self.signals[signal]
+ while group in self.unions:
+ group = self.unions[group]
+ self.signals[signal] = group
+ return group
+
+ def unify(self, root, *leaves):
+ root_group = self.find(root)
+ for leaf in leaves:
+ leaf_group = self.find(leaf)
+ self.unions[leaf_group] = root_group
+
+ def groups(self):
+ groups = OrderedDict()
+ for signal in self.signals:
+ group = self.find(signal)
+ if group not in groups:
+ groups[group] = SignalSet()
+ groups[group].add(signal)
+ return groups
+
+ def on_Assign(self, stmt):
+ self.unify(*stmt._lhs_signals())
+
+ def on_Switch(self, stmt):
+ for case_stmts in stmt.cases.values():
+ self.on_statements(case_stmts)
+
+ def on_statements(self, stmts):
+ for stmt in stmts:
+ self.on_statement(stmt)
+
+ def __call__(self, stmts):
+ self.on_statements(stmts)
+ return self.groups()
+
+
class _ControlInserter(FragmentTransformer):
def __init__(self, controls):
if isinstance(controls, Value):
DomainLowerer({"sync": sync})(f)
+class LHSGroupAnalyzerTestCase(FHDLTestCase):
+ def test_no_group_unrelated(self):
+ a = Signal()
+ b = Signal()
+ stmts = [
+ a.eq(0),
+ b.eq(0),
+ ]
+
+ groups = LHSGroupAnalyzer()(stmts)
+ self.assertEqual(list(groups.values()), [
+ SignalSet((a,)),
+ SignalSet((b,)),
+ ])
+
+ def test_group_related(self):
+ a = Signal()
+ b = Signal()
+ stmts = [
+ a.eq(0),
+ Cat(a, b).eq(0),
+ ]
+
+ groups = LHSGroupAnalyzer()(stmts)
+ self.assertEqual(list(groups.values()), [
+ SignalSet((a, b)),
+ ])
+
+ def test_switch(self):
+ a = Signal()
+ b = Signal()
+ stmts = [
+ a.eq(0),
+ Switch(a, {
+ 1: b.eq(0),
+ })
+ ]
+
+ groups = LHSGroupAnalyzer()(stmts)
+ self.assertEqual(list(groups.values()), [
+ SignalSet((a,)),
+ SignalSet((b,)),
+ ])
+
+
+
class ResetInserterTestCase(FHDLTestCase):
def setUp(self):
self.s1 = Signal()