rhs_compiler = _RHSValueCompiler(compiler_state)
lhs_compiler = _LHSValueCompiler(compiler_state)
stmt_compiler = _StatementCompiler(compiler_state, rhs_compiler, lhs_compiler)
+ switch_cleaner = xfrm.SwitchCleaner()
verilog_trigger = None
verilog_trigger_sync_emitted = False
stmt_compiler._group = group_signals
stmt_compiler._case = case
stmt_compiler._has_rhs = False
- stmt_compiler(fragment.statements)
+ stmt_compiler(switch_cleaner(fragment.statements))
# Verilog `always @*` blocks will not run if `*` does not match anythng, i.e.
# if the implicit sensitivity list is empty. We check this while translating,
"StatementVisitor", "StatementTransformer",
"FragmentTransformer",
"DomainRenamer", "DomainLowerer",
- "LHSGroupAnalyzer",
+ "SwitchCleaner", "LHSGroupAnalyzer",
"ResetInserter", "CEInserter"]
return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
def on_Switch(self, stmt):
- cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
+ cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
return Switch(self.on_value(stmt.test), cases)
def on_statements(self, stmts):
return cd.rst
+class SwitchCleaner(StatementVisitor):
+ def on_Assign(self, stmt):
+ return stmt
+
+ def on_Switch(self, stmt):
+ cases = OrderedDict((k, self.on_statement(s)) for k, s in stmt.cases.items())
+ if any(len(s) for s in stmt.cases.values()):
+ return Switch(stmt.test, cases)
+
+ def on_statements(self, stmts):
+ stmts = flatten(self.on_statement(stmt) for stmt in stmts)
+ return _StatementList(stmt for stmt in stmts if stmt is not None)
+
+
class LHSGroupAnalyzer(StatementVisitor):
def __init__(self):
self.signals = SignalDict()
DomainLowerer({"sync": sync})(f)
+class SwitchCleanerTestCase(FHDLTestCase):
+ def test_clean(self):
+ a = Signal()
+ b = Signal()
+ c = Signal()
+ stmts = [
+ Switch(a, {
+ 1: a.eq(0),
+ 0: [
+ b.eq(1),
+ Switch(b, {1: []})
+ ]
+ })
+ ]
+
+ self.assertRepr(SwitchCleaner()(stmts), """
+ (
+ (switch (sig a)
+ (case 1
+ (eq (sig a) (const 1'd0)))
+ (case 0
+ (eq (sig b) (const 1'd1)))
+ )
+ )
+ """)
+
+
class LHSGroupAnalyzerTestCase(FHDLTestCase):
def test_no_group_unrelated(self):
a = Signal()