back.rtlil: extract _StatementCompiler. NFC.
authorwhitequark <whitequark@whitequark.org>
Sun, 16 Dec 2018 22:26:58 +0000 (22:26 +0000)
committerwhitequark <whitequark@whitequark.org>
Sun, 16 Dec 2018 22:26:58 +0000 (22:26 +0000)
nmigen/back/rtlil.py

index c04c531f6298f00aab5a889bd4685ea1bdcd3e21..1836d75756171358a6b19db8b82e73e3a264a005 100644 (file)
@@ -462,11 +462,42 @@ class _LHSValueCompiler(_ValueCompiler):
         raise NotImplementedError
 
 
+class _StatementCompiler(xfrm.AbstractStatementTransformer):
+    def __init__(self, rhs_compiler, lhs_compiler):
+        self.rhs_compiler = rhs_compiler
+        self.lhs_compiler = lhs_compiler
+
+    def on_Assign(self, stmt):
+        if isinstance(stmt, ast.Assign):
+            lhs_bits, lhs_sign = stmt.lhs.shape()
+            rhs_bits, rhs_sign = stmt.rhs.shape()
+            if lhs_bits == rhs_bits:
+                rhs_sigspec = self.rhs_compiler(stmt.rhs)
+            else:
+                # In RTLIL, LHS and RHS of assignment must have exactly same width.
+                rhs_sigspec = self.rhs_compiler.match_shape(
+                    stmt.rhs, lhs_bits, rhs_sign)
+            self.case.assign(self.lhs_compiler(stmt.lhs), rhs_sigspec)
+
+    def on_Switch(self, stmt):
+        with self.case.switch(self.rhs_compiler(stmt.test)) as switch:
+            for value, stmts in stmt.cases.items():
+                old_case = self.case
+                with switch.case(value) as self.case:
+                    self.on_statements(stmts)
+                self.case = old_case
+
+    def on_statements(self, stmts):
+        for stmt in stmts:
+            self.on_statement(stmt)
+
+
 def convert_fragment(builder, fragment, name, top):
     with builder.module(name or "anonymous", attrs={"top": 1} if top else {}) as module:
         compiler_state = _ValueCompilerState(module)
         rhs_compiler   = _RHSValueCompiler(compiler_state)
         lhs_compiler   = _LHSValueCompiler(compiler_state)
+        stmt_compiler  = _StatementCompiler(rhs_compiler, lhs_compiler)
 
         # Register all signals driven in the current fragment. This must be done first, as it
         # affects further codegen; e.g. whether sig$next signals will be generated and used.
@@ -509,29 +540,8 @@ def convert_fragment(builder, fragment, name, top):
                     case.assign(lhs_compiler(signal), rhs_compiler(prev_value))
 
                 # Convert statements into decision trees.
-                def _convert_stmts(case, stmts):
-                    for stmt in stmts:
-                        if isinstance(stmt, ast.Assign):
-                            lhs_bits, lhs_sign = stmt.lhs.shape()
-                            rhs_bits, rhs_sign = stmt.rhs.shape()
-                            if lhs_bits == rhs_bits:
-                                rhs_sigspec = rhs_compiler(stmt.rhs)
-                            else:
-                                # In RTLIL, LHS and RHS of assignment must have exactly same width.
-                                rhs_sigspec = rhs_compiler.match_shape(
-                                    stmt.rhs, lhs_bits, rhs_sign)
-                            case.assign(lhs_compiler(stmt.lhs), rhs_sigspec)
-
-                        elif isinstance(stmt, ast.Switch):
-                            with case.switch(rhs_compiler(stmt.test)) as switch:
-                                for value, nested_stmts in stmt.cases.items():
-                                    with switch.case(value) as nested_case:
-                                        _convert_stmts(nested_case, nested_stmts)
-
-                        else:
-                            raise TypeError
-
-                _convert_stmts(case, fragment.statements)
+                stmt_compiler.case = case
+                stmt_compiler(fragment.statements)
 
             # For every signal in the sync domain, assign \sig's initial value (which will end up
             # as the \init reg attribute) to the reset value.