python: Teach PyBindMethod how to set return_value_policy.
authorGabe Black <gabeblack@google.com>
Thu, 7 Mar 2019 08:27:52 +0000 (00:27 -0800)
committerGabe Black <gabeblack@google.com>
Thu, 14 Mar 2019 21:03:48 +0000 (21:03 +0000)
Change-Id: Ia208e43672672556b36f905e8f71dce44b978d22
Reviewed-on: https://gem5-review.googlesource.com/c/public/gem5/+/17033
Reviewed-by: Andreas Sandberg <andreas.sandberg@arm.com>
Maintainer: Andreas Sandberg <andreas.sandberg@arm.com>

src/python/m5/util/pybind.py

index 4b5e03d3180796128220f6479a18ea22b5dd6777..4664c16023796ca7acfafdaeb5942ac54e894369 100644 (file)
@@ -58,10 +58,12 @@ class PyBindProperty(PyBindExport):
         code('.${export}("${{self.name}}", &${cname}::${{self.cxx_name}})')
 
 class PyBindMethod(PyBindExport):
-    def __init__(self, name, cxx_name=None, args=None):
+    def __init__(self, name, cxx_name=None, args=None,
+                 return_value_policy=None):
         self.name = name
         self.cxx_name = cxx_name if cxx_name else name
         self.args = args
+        self.return_value_policy = return_value_policy
 
     def _conv_arg(self, value):
         if isinstance(value, bool):
@@ -72,6 +74,10 @@ class PyBindMethod(PyBindExport):
             raise TypeError("Unsupported PyBind default value type")
 
     def export(self, code, cname):
+        arguments = [ '"${{self.name}}"', '&${cname}::${{self.cxx_name}}' ]
+        if self.return_value_policy:
+            arguments.append('pybind11::return_value_policy::'
+                             '${{self.return_value_policy}}')
         if self.args:
             def get_arg_decl(arg):
                 if isinstance(arg, tuple):
@@ -81,8 +87,5 @@ class PyBindMethod(PyBindExport):
                 else:
                     return 'py::arg("%s")' % arg
 
-            code('.def("${{self.name}}", &${cname}::${{self.name}}, ')
-            code('    ' + \
-                 ', '.join([ get_arg_decl(a) for a in self.args ]) + ')')
-        else:
-            code('.def("${{self.name}}", &${cname}::${{self.cxx_name}})')
+            arguments.extend(list([ get_arg_decl(a) for a in self.args ]))
+        code('.def(' + ', '.join(arguments) + ')')