dispatcher: support arbitrary arguments
authorDmitry Selyutin <ghostmansd@gmail.com>
Wed, 28 Jun 2023 22:12:41 +0000 (01:12 +0300)
committerDmitry Selyutin <ghostmansd@gmail.com>
Wed, 28 Jun 2023 22:12:41 +0000 (01:12 +0300)
src/mdis/dispatcher.py

index 8455039ce298d31637bd6b8c584bc3840c53b5ec..71f229132294a08758cdec27f9e6c1d7c76d62c8 100644 (file)
@@ -33,9 +33,39 @@ class Hook(object):
 
     def __call__(self, call):
         class ConcreteHook(Hook):
-            def __call__(self, dispatcher, node, *arguments):
-                if (len(inspect.signature(call).parameters) > 2):
-                    return call(dispatcher, node, *arguments)
+            def __call__(self, dispatcher, node, *args, **kwargs):
+                # We do not force specific arguments other than node.
+                # API users can introduce additional *args and **kwargs.
+                # However, in case they choose not to, this is fine too.
+                parameters = tuple(inspect.signature(call).parameters.values())
+                if len(parameters) < 2:
+                    raise TypeError(f"{call.__name__}: missing required arguments")
+                if parameters[0].kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
+                    raise TypeError(f"{call.__name__}: incorrect self argument")
+                if parameters[1].kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
+                    raise TypeError(f"{call.__name__}: incorrect node argument")
+                args_present = False
+                kwargs_present = False
+                for parameter in parameters:
+                    positionals = (
+                        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                        inspect.Parameter.VAR_POSITIONAL,
+                    )
+                    keywords = (
+                        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                        inspect.Parameter.VAR_KEYWORD,
+                        inspect.Parameter.KEYWORD_ONLY,
+                    )
+                    if parameter.kind in positionals:
+                        args_present = True
+                    elif parameter.kind in keywords:
+                        kwargs_present = True
+                if args_present and kwargs_present:
+                    return call(dispatcher, node, *args, **kwargs)
+                elif args_present:
+                    return call(dispatcher, node, *args)
+                elif kwargs_present:
+                    return call(dispatcher, node, **kwargs)
                 else:
                     return call(dispatcher, node)
 
@@ -84,15 +114,15 @@ class DispatcherMeta(type):
 
 
 class Dispatcher(metaclass=DispatcherMeta):
-    def __call__(self, node, *arguments):
+    def __call__(self, node, *args, **kwargs):
         for typeid in node.__class__.__mro__:
             hook = self.__class__.dispatch(typeid=typeid)
             if hook is not None:
                 break
         if hook is None:
             hook = self.__class__.dispatch()
-        return hook(self, node, *arguments)
+        return hook(self, node, *args, **kwargs)
 
     @Hook(object)
-    def dispatch_object(self, node, *arguments):
+    def dispatch_object(self, node):
         raise NotImplementedError()