From: whitequark Date: Sat, 22 Dec 2018 06:50:32 +0000 (+0000) Subject: hdl.xfrm: implement LHSGroupAnalyzer. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=36a59ff4af168f1e4f5399f216b896a6f11a5e51;p=nmigen.git hdl.xfrm: implement LHSGroupAnalyzer. --- diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index 31525d9..f7f7f26 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -12,7 +12,9 @@ from .ir import * __all__ = ["ValueVisitor", "ValueTransformer", "StatementVisitor", "StatementTransformer", "FragmentTransformer", - "DomainRenamer", "DomainLowerer", "ResetInserter", "CEInserter"] + "DomainRenamer", "DomainLowerer", + "LHSGroupAnalyzer", + "ResetInserter", "CEInserter"] class ValueVisitor(metaclass=ABCMeta): @@ -134,7 +136,7 @@ class StatementVisitor(metaclass=ABCMeta): pass # :nocov: @abstractmethod - def on_statements(self, stmt): + def on_statements(self, stmts): pass # :nocov: def on_unknown_statement(self, stmt): @@ -166,8 +168,8 @@ class StatementTransformer(StatementVisitor): cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items()) return Switch(self.on_value(stmt.test), cases) - def on_statements(self, stmt): - return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt)) + def on_statements(self, stmts): + return _StatementList(flatten(self.on_statement(stmt) for stmt in stmts)) class FragmentTransformer: @@ -278,6 +280,51 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer) return cd.rst +class LHSGroupAnalyzer(StatementVisitor): + def __init__(self): + self.signals = SignalDict() + self.unions = OrderedDict() + + def find(self, signal): + if signal not in self.signals: + self.signals[signal] = len(self.signals) + group = self.signals[signal] + while group in self.unions: + group = self.unions[group] + self.signals[signal] = group + return group + + def unify(self, root, *leaves): + root_group = self.find(root) + for leaf in leaves: + leaf_group = self.find(leaf) + self.unions[leaf_group] = root_group + + def groups(self): + groups = OrderedDict() + for signal in self.signals: + group = self.find(signal) + if group not in groups: + groups[group] = SignalSet() + groups[group].add(signal) + return groups + + def on_Assign(self, stmt): + self.unify(*stmt._lhs_signals()) + + def on_Switch(self, stmt): + for case_stmts in stmt.cases.values(): + self.on_statements(case_stmts) + + def on_statements(self, stmts): + for stmt in stmts: + self.on_statement(stmt) + + def __call__(self, stmts): + self.on_statements(stmts) + return self.groups() + + class _ControlInserter(FragmentTransformer): def __init__(self, controls): if isinstance(controls, Value): diff --git a/nmigen/test/test_hdl_xfrm.py b/nmigen/test/test_hdl_xfrm.py index 802761d..b811160 100644 --- a/nmigen/test/test_hdl_xfrm.py +++ b/nmigen/test/test_hdl_xfrm.py @@ -158,6 +158,52 @@ class DomainLowererTestCase(FHDLTestCase): DomainLowerer({"sync": sync})(f) +class LHSGroupAnalyzerTestCase(FHDLTestCase): + def test_no_group_unrelated(self): + a = Signal() + b = Signal() + stmts = [ + a.eq(0), + b.eq(0), + ] + + groups = LHSGroupAnalyzer()(stmts) + self.assertEqual(list(groups.values()), [ + SignalSet((a,)), + SignalSet((b,)), + ]) + + def test_group_related(self): + a = Signal() + b = Signal() + stmts = [ + a.eq(0), + Cat(a, b).eq(0), + ] + + groups = LHSGroupAnalyzer()(stmts) + self.assertEqual(list(groups.values()), [ + SignalSet((a, b)), + ]) + + def test_switch(self): + a = Signal() + b = Signal() + stmts = [ + a.eq(0), + Switch(a, { + 1: b.eq(0), + }) + ] + + groups = LHSGroupAnalyzer()(stmts) + self.assertEqual(list(groups.values()), [ + SignalSet((a,)), + SignalSet((b,)), + ]) + + + class ResetInserterTestCase(FHDLTestCase): def setUp(self): self.s1 = Signal()