libgomp: adjust nvptx_free callback context checking
authorChung-Lin Tang <cltang@codesourcery.com>
Thu, 20 Aug 2020 14:18:51 +0000 (07:18 -0700)
committerChung-Lin Tang <cltang@codesourcery.com>
Thu, 20 Aug 2020 14:18:51 +0000 (07:18 -0700)
Change test for CUDA callback context in nvptx_free() from using
GOMP_PLUGIN_acc_thread () into checking for CUDA_ERROR_NOT_PERMITTED,
for the former only works for OpenACC, but not OpenMP offloading.

2020-08-20  Chung-Lin Tang  <cltang@codesourcery.com>

libgomp/
* plugin/plugin-nvptx.c (nvptx_free):
Change "GOMP_PLUGIN_acc_thread () == NULL" test into check of
CUDA_ERROR_NOT_PERMITTED status for cuMemGetAddressRange. Adjust
comments.

libgomp/plugin/plugin-nvptx.c

index ec103a2f40b7193285bdf394ead47584d60d6f42..390804ad1fa2f686ea2ddef85e52d118be3cecd4 100644 (file)
@@ -1040,9 +1040,17 @@ goacc_profiling_acc_ev_free (struct goacc_thread *thr, void *p)
 static bool
 nvptx_free (void *p, struct ptx_device *ptx_dev)
 {
-  /* Assume callback context if this is null.  */
-  if (GOMP_PLUGIN_acc_thread () == NULL)
+  CUdeviceptr pb;
+  size_t ps;
+
+  CUresult r = CUDA_CALL_NOCHECK (cuMemGetAddressRange, &pb, &ps,
+                                 (CUdeviceptr) p);
+  if (r == CUDA_ERROR_NOT_PERMITTED)
     {
+      /* We assume that this error indicates we are in a CUDA callback context,
+        where all CUDA calls are not allowed (see cuStreamAddCallback
+        documentation for description). Arrange to free this piece of device
+        memory later.  */
       struct ptx_free_block *n
        = GOMP_PLUGIN_malloc (sizeof (struct ptx_free_block));
       n->ptr = p;
@@ -1052,11 +1060,11 @@ nvptx_free (void *p, struct ptx_device *ptx_dev)
       pthread_mutex_unlock (&ptx_dev->free_blocks_lock);
       return true;
     }
-
-  CUdeviceptr pb;
-  size_t ps;
-
-  CUDA_CALL (cuMemGetAddressRange, &pb, &ps, (CUdeviceptr) p);
+  else if (r != CUDA_SUCCESS)
+    {
+      GOMP_PLUGIN_error ("cuMemGetAddressRange error: %s", cuda_error (r));
+      return false;
+    }
   if ((CUdeviceptr) p != pb)
     {
       GOMP_PLUGIN_error ("invalid device address");