pytholite/compiler: refactor visit_block
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Sun, 11 Nov 2012 13:17:52 +0000 (14:17 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Sun, 11 Nov 2012 13:17:52 +0000 (14:17 +0100)
migen/pytholite/compiler.py

index 5dc001cd25eb8faf9e83c5156353fb3108018a89..fc06923800c5960c0c6de0220bae9abc891314f7 100644 (file)
@@ -61,6 +61,21 @@ class _AbstractNextState:
        def __init__(self, target_state):
                self.target_state = target_state
 
+# entry state is first state returned
+class _StateAssembler:
+       def __init__(self):
+               self.states = []
+               self.exit_states = []
+       
+       def assemble(self, n_states, n_exit_states):
+               self.states += n_states
+               for exit_state in self.exit_states:
+                       exit_state.insert(0, _AbstractNextState(n_states[0]))
+               self.exit_states = n_exit_states
+       
+       def ret(self):
+               return self.states, self.exit_states
+               
 class _Compiler:
        def __init__(self, ioo, symdict, registers):
                self.ioo = ioo
@@ -79,93 +94,27 @@ class _Compiler:
        
        # blocks and statements
        def visit_block(self, statements):
-               states = []
-               exit_states = []
-               for statement in statements:
-                       n_states, n_exit_states = self.visit_statement(statement)
-                       if n_states:
-                               states += n_states
-                               for exit_state in exit_states:
-                                       exit_state.insert(0, _AbstractNextState(n_states[0]))
-                               exit_states = n_exit_states
-               return states, exit_states
-       
-       # entry state is first state returned
-       def visit_statement(self, statement):
-               if isinstance(statement, ast.Assign):
-                       op = self.visit_assign(statement)
-                       if op:
-                               return [op], [op]
-                       else:
-                               return [], []
-               elif isinstance(statement, ast.If):
-                       test = self.visit_expr(statement.test)
-                       states_t, exit_states_t = self.visit_block(statement.body)
-                       states_f, exit_states_f = self.visit_block(statement.orelse)
-                       exit_states = exit_states_t + exit_states_f
-                       
-                       test_state_stmt = If(test, _AbstractNextState(states_t[0]))
-                       test_state = [test_state_stmt]
-                       if states_f:
-                               test_state_stmt.Else(_AbstractNextState(states_f[0]))
+               sa = _StateAssembler()
+               statements = iter(statements)
+               while True:
+                       try:
+                               statement = next(statements)
+                       except StopIteration:
+                               return sa.ret()
+                       if isinstance(statement, ast.Assign):
+                               self.visit_assign(sa, statement)
+                       elif isinstance(statement, ast.If):
+                               self.visit_if(sa, statement)
+                       elif isinstance(statement, ast.While):
+                               self.visit_while(sa, statement)
+                       elif isinstance(statement, ast.For):
+                               self.visit_for(sa, statement)
+                       elif isinstance(statement, ast.Expr):
+                               self.visit_expr_statement(sa, statement)
                        else:
-                               exit_states.append(test_state)
-                       
-                       return [test_state] + states_t + states_f, exit_states
-               elif isinstance(statement, ast.While):
-                       test = self.visit_expr(statement.test)
-                       states_b, exit_states_b = self.visit_block(statement.body)
-
-                       test_state = [If(test, _AbstractNextState(states_b[0]))]
-                       for exit_state in exit_states_b:
-                               exit_state.insert(0, _AbstractNextState(test_state))
-                       
-                       return [test_state] + states_b, [test_state]
-               elif isinstance(statement, ast.For):
-                       if not isinstance(statement.target, ast.Name):
                                raise NotImplementedError
-                       target = statement.target.id
-                       if target in self.symdict:
-                               raise NotImplementedError("For loop target must use an available name")
-                       it = self.visit_iterator(statement.iter)
-                       states = []
-                       last_exit_states = []
-                       for iteration in it:
-                               self.symdict[target] = iteration
-                               states_b, exit_states_b = self.visit_block(statement.body)
-                               for exit_state in last_exit_states:
-                                       exit_state.insert(0, _AbstractNextState(states_b[0]))
-                               last_exit_states = exit_states_b
-                               states += states_b
-                       del self.symdict[target]
-                       return states, last_exit_states
-               elif isinstance(statement, ast.Expr):
-                       if isinstance(statement.value, ast.Yield):
-                               yvalue = statement.value.value
-                               if not isinstance(yvalue, ast.Call) or not isinstance(yvalue.func, ast.Name):
-                                       raise NotImplementedError("Unrecognized I/O sequence")
-                               callee = self.symdict[yvalue.func.id]
-                               return gen_io(self, callee, yvalue.args, [])
-                       else:
-                               raise NotImplementedError
-               else:
-                       raise NotImplementedError
-               return states, exit_states
-       
-       def visit_iterator(self, node):
-               if isinstance(node, ast.List):
-                       return ast.literal_eval(node)
-               elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
-                       funcname = node.func.id
-                       args = map(ast.literal_eval, node.args)
-                       if funcname == "range":
-                               return range(*args)
-                       else:
-                               raise NotImplementedError
-               else:
-                       raise NotImplementedError
        
-       def visit_assign(self, node):
+       def visit_assign(self, sa, node):
                if isinstance(node.targets[0], ast.Name):
                        self.targetname = node.targets[0].id
                value = self.visit_expr(node.value, True)
@@ -178,7 +127,6 @@ class _Compiler:
                                        self.symdict[target.id] = value
                                else:
                                        raise NotImplementedError
-                       return []
                elif isinstance(value, Value):
                        r = []
                        for target in node.targets:
@@ -190,7 +138,76 @@ class _Compiler:
                                                raise NotImplementedError
                                else:
                                        raise NotImplementedError
-                       return r
+                       sa.assemble([r], [r])
+               else:
+                       raise NotImplementedError
+                       
+       def visit_if(self, sa, node):
+               test = self.visit_expr(node.test)
+               states_t, exit_states_t = self.visit_block(node.body)
+               states_f, exit_states_f = self.visit_block(node.orelse)
+               exit_states = exit_states_t + exit_states_f
+               
+               test_state_stmt = If(test, _AbstractNextState(states_t[0]))
+               test_state = [test_state_stmt]
+               if states_f:
+                       test_state_stmt.Else(_AbstractNextState(states_f[0]))
+               else:
+                       exit_states.append(test_state)
+               
+               sa.assemble([test_state] + states_t + states_f,
+                       exit_states)
+       
+       def visit_while(self, sa, node):
+               test = self.visit_expr(node.test)
+               states_b, exit_states_b = self.visit_block(node.body)
+
+               test_state = [If(test, _AbstractNextState(states_b[0]))]
+               for exit_state in exit_states_b:
+                       exit_state.insert(0, _AbstractNextState(test_state))
+               
+               sa.assemble([test_state] + states_b, [test_state])
+       
+       def visit_for(self, sa, node):
+               if not isinstance(node.target, ast.Name):
+                       raise NotImplementedError
+               target = node.target.id
+               if target in self.symdict:
+                       raise NotImplementedError("For loop target must use an available name")
+               it = self.visit_iterator(node.iter)
+               states = []
+               last_exit_states = []
+               for iteration in it:
+                       self.symdict[target] = iteration
+                       states_b, exit_states_b = self.visit_block(node.body)
+                       for exit_state in last_exit_states:
+                               exit_state.insert(0, _AbstractNextState(states_b[0]))
+                       last_exit_states = exit_states_b
+                       states += states_b
+               del self.symdict[target]
+               sa.assemble(states, last_exit_states)
+       
+       def visit_iterator(self, node):
+               if isinstance(node, ast.List):
+                       return ast.literal_eval(node)
+               elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
+                       funcname = node.func.id
+                       args = map(ast.literal_eval, node.args)
+                       if funcname == "range":
+                               return range(*args)
+                       else:
+                               raise NotImplementedError
+               else:
+                       raise NotImplementedError
+       
+       def visit_expr_statement(self, sa, node):
+               if isinstance(node.value, ast.Yield):
+                       yvalue = node.value.value
+                       if not isinstance(yvalue, ast.Call) or not isinstance(yvalue.func, ast.Name):
+                               raise NotImplementedError("Unrecognized I/O sequence")
+                       callee = self.symdict[yvalue.func.id]
+                       states, exit_states = gen_io(self, callee, yvalue.args, [])
+                       sa.assemble(states, exit_states)
                else:
                        raise NotImplementedError