pytholite: do not use ast.NodeVisitor
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Tue, 6 Nov 2012 12:52:19 +0000 (13:52 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Tue, 6 Nov 2012 12:52:19 +0000 (13:52 +0100)
migen/pytholite/compiler.py

index 5a9e0978bdae98c2e76ba44e2fd674a17022b331..e3f5a760948ea5a7b5fb316b13666554c032161f 100644 (file)
@@ -37,31 +37,76 @@ class _Register:
                sync = [Case(self.sel, *cases)]
                return Fragment(sync=sync)
 
-class _AnonymousRegister:
-       def __init__(self, nbits):
-               self.nbits = nbits
-
-class _CompileVisitor(ast.NodeVisitor):
+class _Compiler:
        def __init__(self, symdict, registers):
                self.symdict = symdict
                self.registers = registers
+               self.targetname = ""
        
-       def visit_Assign(self, node):
-               value = self.visit(node.value)
-               if isinstance(value, _AnonymousRegister):
-                       if isinstance(node.targets[0], ast.Name):
-                               name = node.targets[0].id
+       def visit_top(self, node):
+               if isinstance(node, ast.Module) \
+                 and len(node.body) == 1 \
+                 and isinstance(node.body[0], ast.FunctionDef):
+                       return self.visit_block(node.body[0].body)
+               else:
+                       raise NotImplementedError
+       
+       # blocks and statements
+       def visit_block(self, statements):
+               r = []
+               for statement in statements:
+                       if isinstance(statement, ast.Assign):
+                               r += self.visit_assign(statement)
                        else:
                                raise NotImplementedError
-                       value = _Register(name, value.nbits)
+               return r
+       
+       def visit_assign(self, node):
+               if isinstance(node.targets[0], ast.Name):
+                       self.targetname = node.targets[0].id
+               value = self.visit_expr(node.value, True)
+               self.targetname = ""
+               
+               if isinstance(value, _Register):
                        self.registers.append(value)
                        for target in node.targets:
                                if isinstance(target, ast.Name):
                                        self.symdict[target.id] = value
                                else:
                                        raise NotImplementedError
+                       return []
+               elif isinstance(value, Value):
+                       r = []
+                       for target in node.targets:
+                               if isinstance(target, ast.Attribute) and target.attr == "store":
+                                       treg = target.value
+                                       if isinstance(treg, ast.Name):
+                                               r.append(self.symdict[treg.id].load(value))
+                                       else:
+                                               raise NotImplementedError
+                               else:
+                                       raise NotImplementedError
+                       return r
+               else:
+                       raise NotImplementedError
        
-       def visit_Call(self, node):
+       # expressions
+       def visit_expr(self, node, allow_call=False):
+               if isinstance(node, ast.Call):
+                       if allow_call:
+                               return self.visit_expr_call(node)
+                       else:
+                               raise NotImplementedError
+               elif isinstance(node, ast.BinOp):
+                       return self.visit_expr_binop(node)
+               elif isinstance(node, ast.Name):
+                       return self.visit_expr_name(node)
+               elif isinstance(node, ast.Num):
+                       return self.visit_expr_num(node)
+               else:
+                       raise NotImplementedError
+       
+       def visit_expr_call(self, node):
                if isinstance(node.func, ast.Name):
                        callee = self.symdict[node.func.id]
                else:
@@ -70,19 +115,54 @@ class _CompileVisitor(ast.NodeVisitor):
                        if len(node.args) != 1:
                                raise TypeError("Register() takes exactly 1 argument")
                        nbits = ast.literal_eval(node.args[0])
-                       return _AnonymousRegister(nbits)
+                       return _Register(self.targetname, nbits)
                else:
                        raise NotImplementedError
        
+       def visit_expr_binop(self, node):
+               left = self.visit_expr(node.left)
+               right = self.visit_expr(node.right)
+               if isinstance(node.op, ast.Add):
+                       return left + right
+               elif isinstance(node.op, ast.Sub):
+                       return left - right
+               elif isinstance(node.op, ast.Mult):
+                       return left * right
+               elif isinstance(node.op, ast.LShift):
+                       return left << right
+               elif isinstance(node.op, ast.RShift):
+                       return left >> right
+               elif isinstance(node.op, ast.BitOr):
+                       return left | right
+               elif isinstance(node.op, ast.BitXor):
+                       return left ^ right
+               elif isinstance(node.op, ast.BitAnd):
+                       return left & right
+               else:
+                       raise NotImplementedError
+       
+       def visit_expr_name(self, node):
+               r = self.symdict[node.id]
+               if isinstance(r, _Register):
+                       r = r.storage
+               return r
+       
+       def visit_expr_num(self, node):
+               return node.n
+
 def make_pytholite(func):
        tree = ast.parse(inspect.getsource(func))
        symdict = func.__globals__.copy()
        registers = []
        
-       cv = _CompileVisitor(symdict, registers)
-       cv.visit(tree)
+       c = _Compiler(symdict, registers)
+       print("compilation result:")
+       print(c.visit_top(tree))
        
+       print("registers:")
        print(registers)
-       print(symdict)
+       #print("symdict:")
+       #print(symdict)
 
-       #print(ast.dump(tree))
+       print("ast:")
+       print(ast.dump(tree))