pytholite/compiler: visit_assign_special
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Sun, 11 Nov 2012 14:52:06 +0000 (15:52 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Sun, 11 Nov 2012 14:52:06 +0000 (15:52 +0100)
migen/pytholite/compiler.py

index 4763bd6afc4c574b691d57862e58e87d78c4113e..660719c06a942626609410ceca8c4cc8a1964498 100644 (file)
@@ -62,7 +62,6 @@ class _Compiler:
                self.ioo = ioo
                self.symdict = symdict
                self.registers = registers
-               self.targetname = ""
        
        def visit_top(self, node):
                if isinstance(node, ast.Module) \
@@ -77,38 +76,40 @@ class _Compiler:
        def visit_block(self, statements):
                sa = StateAssembler()
                statements = iter(statements)
+               statement = None
                while True:
-                       try:
-                               statement = next(statements)
-                       except StopIteration:
-                               return sa.ret()
+                       if statement is None:
+                               try:
+                                       statement = next(statements)
+                               except StopIteration:
+                                       return sa.ret()
                        if isinstance(statement, ast.Assign):
-                               self.visit_assign(sa, statement)
-                       elif isinstance(statement, ast.If):
-                               self.visit_if(sa, statement)
-                       elif isinstance(statement, ast.While):
-                               self.visit_while(sa, statement)
-                       elif isinstance(statement, ast.For):
-                               self.visit_for(sa, statement)
-                       elif isinstance(statement, ast.Expr):
-                               self.visit_expr_statement(sa, statement)
+                               # visit_assign can recognize a I/O pattern, consume several
+                               # statements from the iterator and return the first statement
+                               # that is not part of the I/O pattern anymore.
+                               statement = self.visit_assign(sa, statement, statements)
                        else:
-                               raise NotImplementedError
-       
-       def visit_assign(self, sa, node):
-               if isinstance(node.targets[0], ast.Name):
-                       self.targetname = node.targets[0].id
-               value = self.visit_expr(node.value, True)
-               self.targetname = ""
-               
-               if isinstance(value, _Register):
-                       self.registers.append(value)
-                       for target in node.targets:
-                               if isinstance(target, ast.Name):
-                                       self.symdict[target.id] = value
+                               if isinstance(statement, ast.If):
+                                       self.visit_if(sa, statement)
+                               elif isinstance(statement, ast.While):
+                                       self.visit_while(sa, statement)
+                               elif isinstance(statement, ast.For):
+                                       self.visit_for(sa, statement)
+                               elif isinstance(statement, ast.Expr):
+                                       self.visit_expr_statement(sa, statement)
                                else:
                                        raise NotImplementedError
-               elif isinstance(value, Value):
+                               statement = None
+       
+       def visit_assign(self, sa, node, statements):
+               if isinstance(node.value, ast.Call):
+                       try:
+                               value = self.visit_expr_call(node.value)
+                       except NotImplementedError:
+                               return self.visit_assign_special(sa, node, statements)
+               else:
+                       value = self.visit_expr(node.value)
+               if isinstance(value, Value):
                        r = []
                        for target in node.targets:
                                if isinstance(target, ast.Attribute) and target.attr == "store":
@@ -122,7 +123,33 @@ class _Compiler:
                        sa.assemble([r], [r])
                else:
                        raise NotImplementedError
-                       
+       
+       def visit_assign_special(self, sa, node, statements):
+               value = node.value
+               assert(isinstance(value, ast.Call))
+               if isinstance(value.func, ast.Name):
+                       callee = self.symdict[value.func.id]
+               else:
+                       raise NotImplementedError
+               
+               if callee == transel.Register:
+                       if len(value.args) != 1:
+                               raise TypeError("Register() takes exactly 1 argument")
+                       nbits = ast.literal_eval(value.args[0])
+                       if isinstance(node.targets[0], ast.Name):
+                               targetname = node.targets[0].id
+                       else:
+                               targetname = "unk"
+                       reg = _Register(targetname, nbits)
+                       self.registers.append(reg)
+                       for target in node.targets:
+                               if isinstance(target, ast.Name):
+                                       self.symdict[target.id] = reg
+                               else:
+                                       raise NotImplementedError
+               else:
+                       raise NotImplementedError
+       
        def visit_if(self, sa, node):
                test = self.visit_expr(node.test)
                states_t, exit_states_t = self.visit_block(node.body)
@@ -193,12 +220,9 @@ class _Compiler:
                        raise NotImplementedError
        
        # expressions
-       def visit_expr(self, node, allow_registers=False):
+       def visit_expr(self, node):
                if isinstance(node, ast.Call):
-                       r = self.visit_expr_call(node)
-                       if not allow_registers and isinstance(r, _Register):
-                               raise NotImplementedError
-                       return r
+                       return self.visit_expr_call(node)
                elif isinstance(node, ast.BinOp):
                        return self.visit_expr_binop(node)
                elif isinstance(node, ast.Compare):
@@ -215,12 +239,7 @@ class _Compiler:
                        callee = self.symdict[node.func.id]
                else:
                        raise NotImplementedError
-               if callee == transel.Register:
-                       if len(node.args) != 1:
-                               raise TypeError("Register() takes exactly 1 argument")
-                       nbits = ast.literal_eval(node.args[0])
-                       return _Register(self.targetname, nbits)
-               elif callee == transel.bitslice:
+               if callee == transel.bitslice:
                        if len(node.args) != 2 and len(node.args) != 3:
                                raise TypeError("bitslice() takes 2 or 3 arguments")
                        val = self.visit_expr(node.args[0])