hdl.xfrm: implement LHSGroupAnalyzer.
authorwhitequark <cz@m-labs.hk>
Sat, 22 Dec 2018 06:50:32 +0000 (06:50 +0000)
committerwhitequark <cz@m-labs.hk>
Sat, 22 Dec 2018 06:58:24 +0000 (06:58 +0000)
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_xfrm.py

index 31525d9577dcfd7d2a3f34d0d950d2eacf3b0ac7..f7f7f2602058dae53cf6ff629dde90356a5bb39f 100644 (file)
@@ -12,7 +12,9 @@ from .ir import *
 __all__ = ["ValueVisitor", "ValueTransformer",
            "StatementVisitor", "StatementTransformer",
            "FragmentTransformer",
-           "DomainRenamer", "DomainLowerer", "ResetInserter", "CEInserter"]
+           "DomainRenamer", "DomainLowerer",
+           "LHSGroupAnalyzer",
+           "ResetInserter", "CEInserter"]
 
 
 class ValueVisitor(metaclass=ABCMeta):
@@ -134,7 +136,7 @@ class StatementVisitor(metaclass=ABCMeta):
         pass # :nocov:
 
     @abstractmethod
-    def on_statements(self, stmt):
+    def on_statements(self, stmts):
         pass # :nocov:
 
     def on_unknown_statement(self, stmt):
@@ -166,8 +168,8 @@ class StatementTransformer(StatementVisitor):
         cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
         return Switch(self.on_value(stmt.test), cases)
 
-    def on_statements(self, stmt):
-        return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt))
+    def on_statements(self, stmts):
+        return _StatementList(flatten(self.on_statement(stmt) for stmt in stmts))
 
 
 class FragmentTransformer:
@@ -278,6 +280,51 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
         return cd.rst
 
 
+class LHSGroupAnalyzer(StatementVisitor):
+    def __init__(self):
+        self.signals = SignalDict()
+        self.unions  = OrderedDict()
+
+    def find(self, signal):
+        if signal not in self.signals:
+            self.signals[signal] = len(self.signals)
+        group = self.signals[signal]
+        while group in self.unions:
+            group = self.unions[group]
+        self.signals[signal] = group
+        return group
+
+    def unify(self, root, *leaves):
+        root_group = self.find(root)
+        for leaf in leaves:
+            leaf_group = self.find(leaf)
+            self.unions[leaf_group] = root_group
+
+    def groups(self):
+        groups = OrderedDict()
+        for signal in self.signals:
+            group = self.find(signal)
+            if group not in groups:
+                groups[group] = SignalSet()
+            groups[group].add(signal)
+        return groups
+
+    def on_Assign(self, stmt):
+        self.unify(*stmt._lhs_signals())
+
+    def on_Switch(self, stmt):
+        for case_stmts in stmt.cases.values():
+            self.on_statements(case_stmts)
+
+    def on_statements(self, stmts):
+        for stmt in stmts:
+            self.on_statement(stmt)
+
+    def __call__(self, stmts):
+        self.on_statements(stmts)
+        return self.groups()
+
+
 class _ControlInserter(FragmentTransformer):
     def __init__(self, controls):
         if isinstance(controls, Value):
index 802761dd71ed81abbaab3581b9e5a328502ed1b8..b811160b55b79295ea3185ecb0fe39cb0eedc867 100644 (file)
@@ -158,6 +158,52 @@ class DomainLowererTestCase(FHDLTestCase):
             DomainLowerer({"sync": sync})(f)
 
 
+class LHSGroupAnalyzerTestCase(FHDLTestCase):
+    def test_no_group_unrelated(self):
+        a = Signal()
+        b = Signal()
+        stmts = [
+            a.eq(0),
+            b.eq(0),
+        ]
+
+        groups = LHSGroupAnalyzer()(stmts)
+        self.assertEqual(list(groups.values()), [
+            SignalSet((a,)),
+            SignalSet((b,)),
+        ])
+
+    def test_group_related(self):
+        a = Signal()
+        b = Signal()
+        stmts = [
+            a.eq(0),
+            Cat(a, b).eq(0),
+        ]
+
+        groups = LHSGroupAnalyzer()(stmts)
+        self.assertEqual(list(groups.values()), [
+            SignalSet((a, b)),
+        ])
+
+    def test_switch(self):
+        a = Signal()
+        b = Signal()
+        stmts = [
+            a.eq(0),
+            Switch(a, {
+                1: b.eq(0),
+            })
+        ]
+
+        groups = LHSGroupAnalyzer()(stmts)
+        self.assertEqual(list(groups.values()), [
+            SignalSet((a,)),
+            SignalSet((b,)),
+        ])
+
+
+
 class ResetInserterTestCase(FHDLTestCase):
     def setUp(self):
         self.s1 = Signal()