pytholite/compiler: go to next state
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Fri, 9 Nov 2012 19:12:15 +0000 (20:12 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Fri, 9 Nov 2012 19:12:15 +0000 (20:12 +0100)
migen/pytholite/compiler.py

index a46176a7ca63ec214464799e94bda2402125703d..eebe247942623a25b005025f7db4c7e879270a13 100644 (file)
@@ -68,30 +68,51 @@ class _Compiler:
                if isinstance(node, ast.Module) \
                  and len(node.body) == 1 \
                  and isinstance(node.body[0], ast.FunctionDef):
-                       return self.visit_block(node.body[0].body)
+                       states, exit_states = self.visit_block(node.body[0].body)
+                       return states
                else:
                        raise NotImplementedError
        
        # blocks and statements
        def visit_block(self, statements):
                states = []
+               exit_states = []
                for statement in statements:
-                       if isinstance(statement, ast.Assign):
-                               op = self.visit_assign(statement)
-                               if op:
-                                       states.append(op)
-                       elif isinstance(statement, ast.If):
-                               test = self.visit_expr(statement.test)
-                               states_t = self.visit_block(statement.body)
-                               states_f = self.visit_block(statement.orelse)
-                               test_state_stmt = If(test, _AbstractNextState(states_t[0]))
-                               if states_f:
-                                       test_state_stmt.Else(_AbstractNextState(states_f[0]))
-                               states.append([test_state_stmt])
-                               states += states_t + states_f
+                       n_states, n_exit_states = self.visit_statement(statement)
+                       if n_states:
+                               states += n_states
+                               for exit_state in exit_states:
+                                       exit_state.insert(0, _AbstractNextState(n_states[0]))
+                               exit_states = n_exit_states
+               return states, exit_states
+       
+       # entry state is first state returned
+       def visit_statement(self, statement):
+               states = []
+               exit_states = []
+               if isinstance(statement, ast.Assign):
+                       op = self.visit_assign(statement)
+                       if op:
+                               states.append(op)
+                               exit_states.append(op)
+               elif isinstance(statement, ast.If):
+                       test = self.visit_expr(statement.test)
+                       states_t, exit_states_t = self.visit_block(statement.body)
+                       states_f, exit_states_f  = self.visit_block(statement.orelse)
+                       
+                       test_state_stmt = If(test, _AbstractNextState(states_t[0]))
+                       test_state = [test_state_stmt]
+                       if states_f:
+                               test_state_stmt.Else(_AbstractNextState(states_f[0]))
                        else:
-                               raise NotImplementedError
-               return states
+                               exit_states.append(test_state)
+                       
+                       states.append(test_state)
+                       states += states_t + states_f
+                       exit_states += exit_states_t + exit_states_f
+               else:
+                       raise NotImplementedError
+               return states, exit_states
        
        def visit_assign(self, node):
                if isinstance(node.targets[0], ast.Name):