self.rhs_compiler = rhs_compiler
         self.lhs_compiler = lhs_compiler
 
-        self._group = None
-        self._case  = None
-
-        self._test_cache = {}
-
-        self._has_rhs = False
-        self._has_assign = False
+        self._case        = None
+        self._test_cache  = {}
+        self._has_rhs     = False
 
     @contextmanager
     def case(self, switch, value):
         finally:
             self._case = old_case
 
-    def on_Assign(self, stmt):
-        # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
-        # on LHS are a part of the same group, so it is sufficient to check any of them.
-        any_lhs_signal = next(iter(stmt.lhs._lhs_signals()))
-        if any_lhs_signal not in self._group:
-            return
-
-        if self._has_rhs or next(iter(stmt.rhs._rhs_signals()), None) is not None:
+    def _check_rhs(self, value):
+        if self._has_rhs or next(iter(value._rhs_signals()), None) is not None:
             self._has_rhs = True
-        self._has_assign = True
+
+    def on_Assign(self, stmt):
+        self._check_rhs(stmt.rhs)
 
         lhs_bits, lhs_sign = stmt.lhs.shape()
         rhs_bits, rhs_sign = stmt.rhs.shape()
         self._case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec)
 
     def on_Switch(self, stmt):
+        self._check_rhs(stmt.test)
+
         if stmt not in self._test_cache:
             self._test_cache[stmt] = self.rhs_compiler(stmt.test)
         test_sigspec = self._test_cache[stmt]
 
-        try:
-            self._has_assign, old_has_assign = False, self._has_assign
-
-            with self._case.switch(test_sigspec) as switch:
-                for value, stmts in stmt.cases.items():
-                    with self.case(switch, value):
-                        self.on_statements(stmts)
-
-        finally:
-            if self._has_assign:
-                if self._has_rhs or next(iter(stmt.test._rhs_signals()), None) is not None:
-                    self._has_rhs = True
-
-            self._has_assign = old_has_assign
+        with self._case.switch(test_sigspec) as switch:
+            for value, stmts in stmt.cases.items():
+                with self.case(switch, value):
+                    self.on_statements(stmts)
 
     def on_statement(self, stmt):
         try:
         rhs_compiler   = _RHSValueCompiler(compiler_state)
         lhs_compiler   = _LHSValueCompiler(compiler_state)
         stmt_compiler  = _StatementCompiler(compiler_state, rhs_compiler, lhs_compiler)
-        switch_cleaner = xfrm.SwitchCleaner()
 
         verilog_trigger = None
         verilog_trigger_sync_emitted = False
         lhs_grouper.on_statements(fragment.statements)
 
         for group, group_signals in lhs_grouper.groups().items():
+            lhs_group_filter = xfrm.LHSGroupFilter(group_signals)
+
             with module.process(name="$group_{}".format(group)) as process:
                 with process.case() as case:
                     # For every signal in comb domain, assign \sig$next to the reset value.
                         case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
 
                     # Convert statements into decision trees.
-                    stmt_compiler._group = group_signals
                     stmt_compiler._case = case
                     stmt_compiler._has_rhs = False
-                    stmt_compiler(switch_cleaner(fragment.statements))
+                    stmt_compiler(lhs_group_filter(fragment.statements))
 
                     # Verilog `always @*` blocks will not run if `*` does not match anythng, i.e.
                     # if the implicit sensitivity list is empty. We check this while translating,
 
            "StatementVisitor", "StatementTransformer",
            "FragmentTransformer",
            "DomainRenamer", "DomainLowerer",
-           "SwitchCleaner", "LHSGroupAnalyzer",
+           "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
            "ResetInserter", "CEInserter"]
 
 
 
     def on_Switch(self, stmt):
         cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
-        if any(len(s) for s in stmt.cases.values()):
+        if any(len(s) for s in cases.values()):
             return Switch(stmt.test, cases)
 
     def on_statements(self, stmts):
         return self.groups()
 
 
+class LHSGroupFilter(SwitchCleaner):
+    def __init__(self, signals):
+        self.signals = signals
+
+    def on_Assign(self, stmt):
+        # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
+        # on LHS are a part of the same group, so it is sufficient to check any of them.
+        any_lhs_signal = next(iter(stmt.lhs._lhs_signals()))
+        if any_lhs_signal in self.signals:
+            return stmt
+
+
 class _ControlInserter(FragmentTransformer):
     def __init__(self, controls):
         if isinstance(controls, Value):