pytholite: support generator arguments
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 3 Jul 2013 14:35:07 +0000 (16:35 +0200)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Wed, 3 Jul 2013 14:35:07 +0000 (16:35 +0200)
migen/pytholite/compiler.py

index 767d86a662eaf5338240aadfedd13a993bb81f82..0df85391e547fda9377b09071d197254f7748836 100644 (file)
@@ -1,5 +1,6 @@
 import inspect
 import ast
+from collections import OrderedDict
 
 from migen.fhdl.structure import *
 from migen.fhdl.visit import TransformModule
@@ -17,6 +18,66 @@ def _is_name_used(node, name):
                        return True
        return False
 
+def _make_function_args_dict(undefined, symdict, args, defaults):
+       d = OrderedDict()
+       for argument in args:
+               d[argument.arg] = undefined
+       for default, argname in zip(defaults, reversed(list(d.keys()))):
+               default_val = eval_ast(default, symdict)
+               d[argname] = default_val
+       return d
+
+def _process_function_args(symdict, function_def, args, kwargs):
+       defargs = function_def.args
+       undefined = object()
+
+       ad_positional = _make_function_args_dict(undefined, symdict, defargs.args, defargs.defaults)
+       vararg_name = defargs.vararg
+       kwarg_name = defargs.kwarg
+       ad_kwonly = _make_function_args_dict(undefined, symdict, defargs.kwonlyargs, defargs.kw_defaults)
+
+       # grab argument values
+       current_argvalue = iter(args)
+       try:
+               for argname in ad_positional.keys():
+                       ad_positional[argname] = next(current_argvalue)
+       except StopIteration:
+               pass
+       vararg = tuple(current_argvalue)
+
+       kwarg = OrderedDict()
+       for k, v in kwarg.items():
+               if k in ad_positional:
+                       ad_positional[k] = v
+               elif k in ad_kwonly:
+                       ad_kwonly[k] = v
+               else:
+                       kwarg[k] = v
+
+       # check
+       undefined_pos = [k for k, v in ad_positional.items() if v is undefined]
+       if undefined_pos:
+               formatted = " and ".join("'" + k + "'" for k in undefined_pos)
+               raise TypeError("Missing required positional arguments: " + formatted)
+       if vararg and vararg_name is None:
+               raise TypeError("Function takes {} positional arguments but {} were given".format(len(ad_positional),
+                       len(ad_positional) + len(vararg)))
+       ad_kwonly = [k for k, v in ad_positional.items() if v is undefined]
+       if undefined_pos:
+               formatted = " and ".join("'" + k + "'" for k in undefined_pos)
+               raise TypeError("Missing required keyword-only arguments: " + formatted)
+       if kwarg and kwarg_name is None:
+               formatted = " and ".join("'" + k + "'" for k in kwarg.keys())
+               raise TypeError("Got unexpected keyword arguments: " + formatted)
+
+       # update symdict
+       symdict.update(ad_positional)
+       if vararg_name is not None:
+               symdict[vararg_name] = vararg
+       symdict.update(ad_kwonly)
+       if kwarg_name is not None:
+               symdict[kwarg_name] = kwarg
+
 class _Compiler:
        def __init__(self, ioo, symdict, registers):
                self.ioo = ioo
@@ -24,11 +85,13 @@ class _Compiler:
                self.registers = registers
                self.ec = ExprCompiler(self.symdict)
        
-       def visit_top(self, node):
+       def visit_top(self, node, args, kwargs):
                if isinstance(node, ast.Module) \
                  and len(node.body) == 1 \
                  and isinstance(node.body[0], ast.FunctionDef):
-                       states, exit_states = self.visit_block(node.body[0].body)
+                       function_def = node.body[0]
+                       _process_function_args(self.symdict, function_def, args, kwargs)
+                       states, exit_states = self.visit_block(function_def.body)
                        return states
                else:
                        raise NotImplementedError
@@ -220,8 +283,10 @@ class _Compiler:
                        raise NotImplementedError
 
 class Pytholite(UnifiedIOObject):
-       def __init__(self, func):
+       def __init__(self, func, *args, **kwargs):
                self.func = func
+               self.args = args
+               self.kwargs = kwargs
 
        def do_finalize(self):
                UnifiedIOObject.do_finalize(self)
@@ -240,7 +305,7 @@ class Pytholite(UnifiedIOObject):
                symdict = self.func.__globals__.copy()
                registers = []
                
-               states = _Compiler(self, symdict, registers).visit_top(tree)
+               states = _Compiler(self, symdict, registers).visit_top(tree, self.args, self.kwargs)
                
                for register in registers:
                        if register.source_encoding: