hdl.xfrm, back.rtlil: implement and use LHSGroupFilter.
authorwhitequark <cz@m-labs.hk>
Mon, 24 Dec 2018 02:17:28 +0000 (02:17 +0000)
committerwhitequark <cz@m-labs.hk>
Mon, 24 Dec 2018 02:17:28 +0000 (02:17 +0000)
This is a refactoring to simplify reusing the filtering code in
simulation, and separate that concern from backends in general.

nmigen/back/rtlil.py
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_xfrm.py

index aca7ff28fee515d1b38e08084cfdc6adfebd9fac..7e3fede0be4c0659c03fd808294756ada59187ad 100644 (file)
@@ -527,13 +527,9 @@ class _StatementCompiler(xfrm.StatementVisitor):
         self.rhs_compiler = rhs_compiler
         self.lhs_compiler = lhs_compiler
 
-        self._group = None
-        self._case  = None
-
-        self._test_cache = {}
-
-        self._has_rhs = False
-        self._has_assign = False
+        self._case        = None
+        self._test_cache  = {}
+        self._has_rhs     = False
 
     @contextmanager
     def case(self, switch, value):
@@ -544,16 +540,12 @@ class _StatementCompiler(xfrm.StatementVisitor):
         finally:
             self._case = old_case
 
-    def on_Assign(self, stmt):
-        # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
-        # on LHS are a part of the same group, so it is sufficient to check any of them.
-        any_lhs_signal = next(iter(stmt.lhs._lhs_signals()))
-        if any_lhs_signal not in self._group:
-            return
-
-        if self._has_rhs or next(iter(stmt.rhs._rhs_signals()), None) is not None:
+    def _check_rhs(self, value):
+        if self._has_rhs or next(iter(value._rhs_signals()), None) is not None:
             self._has_rhs = True
-        self._has_assign = True
+
+    def on_Assign(self, stmt):
+        self._check_rhs(stmt.rhs)
 
         lhs_bits, lhs_sign = stmt.lhs.shape()
         rhs_bits, rhs_sign = stmt.rhs.shape()
@@ -566,24 +558,16 @@ class _StatementCompiler(xfrm.StatementVisitor):
         self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec)
 
     def on_Switch(self, stmt):
+        self._check_rhs(stmt.test)
+
         if stmt not in self._test_cache:
             self._test_cache[stmt] = self.rhs_compiler(stmt.test)
         test_sigspec = self._test_cache[stmt]
 
-        try:
-            self._has_assign, old_has_assign = False, self._has_assign
-
-            with self._case.switch(test_sigspec) as switch:
-                for value, stmts in stmt.cases.items():
-                    with self.case(switch, value):
-                        self.on_statements(stmts)
-
-        finally:
-            if self._has_assign:
-                if self._has_rhs or next(iter(stmt.test._rhs_signals()), None) is not None:
-                    self._has_rhs = True
-
-            self._has_assign = old_has_assign
+        with self._case.switch(test_sigspec) as switch:
+            for value, stmts in stmt.cases.items():
+                with self.case(switch, value):
+                    self.on_statements(stmts)
 
     def on_statement(self, stmt):
         try:
@@ -620,7 +604,6 @@ def convert_fragment(builder, fragment, name, top):
         rhs_compiler   = _RHSValueCompiler(compiler_state)
         lhs_compiler   = _LHSValueCompiler(compiler_state)
         stmt_compiler  = _StatementCompiler(compiler_state, rhs_compiler, lhs_compiler)
-        switch_cleaner = xfrm.SwitchCleaner()
 
         verilog_trigger = None
         verilog_trigger_sync_emitted = False
@@ -703,6 +686,8 @@ def convert_fragment(builder, fragment, name, top):
         lhs_grouper.on_statements(fragment.statements)
 
         for group, group_signals in lhs_grouper.groups().items():
+            lhs_group_filter = xfrm.LHSGroupFilter(group_signals)
+
             with module.process(name="$group_{}".format(group)) as process:
                 with process.case() as case:
                     # For every signal in comb domain, assign \sig$next to the reset value.
@@ -718,10 +703,9 @@ def convert_fragment(builder, fragment, name, top):
                         case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
 
                     # Convert statements into decision trees.
-                    stmt_compiler._group = group_signals
                     stmt_compiler._case = case
                     stmt_compiler._has_rhs = False
-                    stmt_compiler(switch_cleaner(fragment.statements))
+                    stmt_compiler(lhs_group_filter(fragment.statements))
 
                     # Verilog `always @*` blocks will not run if `*` does not match anythng, i.e.
                     # if the implicit sensitivity list is empty. We check this while translating,
index 2b219b82a311b302028ee8d075a41c4fedc03411..2f1f9377f085c925a54ae63adadfee54258143cc 100644 (file)
@@ -13,7 +13,7 @@ __all__ = ["ValueVisitor", "ValueTransformer",
            "StatementVisitor", "StatementTransformer",
            "FragmentTransformer",
            "DomainRenamer", "DomainLowerer",
-           "SwitchCleaner", "LHSGroupAnalyzer",
+           "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
            "ResetInserter", "CEInserter"]
 
 
@@ -286,7 +286,7 @@ class SwitchCleaner(StatementVisitor):
 
     def on_Switch(self, stmt):
         cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
-        if any(len(s) for s in stmt.cases.values()):
+        if any(len(s) for s in cases.values()):
             return Switch(stmt.test, cases)
 
     def on_statements(self, stmts):
@@ -341,6 +341,18 @@ class LHSGroupAnalyzer(StatementVisitor):
         return self.groups()
 
 
+class LHSGroupFilter(SwitchCleaner):
+    def __init__(self, signals):
+        self.signals = signals
+
+    def on_Assign(self, stmt):
+        # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
+        # on LHS are a part of the same group, so it is sufficient to check any of them.
+        any_lhs_signal = next(iter(stmt.lhs._lhs_signals()))
+        if any_lhs_signal in self.signals:
+            return stmt
+
+
 class _ControlInserter(FragmentTransformer):
     def __init__(self, controls):
         if isinstance(controls, Value):
index 1c6d6c0603afce2b0342464d85d5666ef557b5f9..88bd78f7852e19003329648fe57ab8ffe5b030fa 100644 (file)
@@ -168,7 +168,9 @@ class SwitchCleanerTestCase(FHDLTestCase):
                 1: a.eq(0),
                 0: [
                     b.eq(1),
-                    Switch(b, {1: []})
+                    Switch(b, {1: [
+                        Switch(a|b, {})
+                    ]})
                 ]
             })
         ]
@@ -244,6 +246,31 @@ class LHSGroupAnalyzerTestCase(FHDLTestCase):
         ])
 
 
+class LHSGroupFilterTestCase(FHDLTestCase):
+    def test_filter(self):
+        a = Signal()
+        b = Signal()
+        c = Signal()
+        stmts = [
+            Switch(a, {
+                1: a.eq(0),
+                0: [
+                    b.eq(1),
+                    Switch(b, {1: []})
+                ]
+            })
+        ]
+
+        self.assertRepr(LHSGroupFilter(SignalSet((a,)))(stmts), """
+        (
+            (switch (sig a)
+                (case 1
+                    (eq (sig a) (const 1'd0)))
+                (case 0 )
+            )
+        )
+        """)
+
 
 class ResetInserterTestCase(FHDLTestCase):
     def setUp(self):