gallivm: rework coroutine malloc/free callouts.
authorDave Airlie <airlied@redhat.com>
Fri, 15 May 2020 00:03:32 +0000 (10:03 +1000)
committerDave Airlie <airlied@redhat.com>
Wed, 10 Jun 2020 20:05:40 +0000 (06:05 +1000)
When using cached shaders we have to relink the shader with
external symbols when it's loaded. However the way gallivm does
function calls now hardcodes the function pointer into the shader.

LLVM had a mechanism for doing this properly using global mappings,
this switches the coroutine alloc/free code to use a global mapping.

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5049>

src/gallium/auxiliary/draw/draw_llvm.c
src/gallium/auxiliary/gallivm/lp_bld_coro.c
src/gallium/auxiliary/gallivm/lp_bld_coro.h
src/gallium/auxiliary/gallivm/lp_bld_init.h
src/gallium/drivers/llvmpipe/lp_state_cs.c

index 79322e4485b6893422ea8b3946d23a88f3bd5bec..db8e24f43069e07080620d96fef0f5eeb559bf47 100644 (file)
@@ -3464,10 +3464,12 @@ draw_tcs_llvm_create_variant(struct draw_llvm *llvm,
       draw_tcs_llvm_dump_variant_key(&variant->key);
    }
 
+   lp_build_coro_declare_malloc_hooks(variant->gallivm);
    draw_tcs_llvm_generate(llvm, variant);
 
    gallivm_compile_module(variant->gallivm);
 
+   lp_build_coro_add_malloc_hooks(variant->gallivm);
    variant->jit_func = (draw_tcs_jit_func)
       gallivm_jit_function(variant->gallivm, variant->function);
 
index 8f3db6f481f221d5785c18f079a953254e93cd9d..28f722e93c23ddaad36649318eb2f525b34663c1 100644 (file)
@@ -154,6 +154,26 @@ coro_free(char *ptr)
    os_free_aligned(ptr);
 }
 
+void lp_build_coro_add_malloc_hooks(struct gallivm_state *gallivm)
+{
+   assert(gallivm->engine);
+
+   assert(gallivm->coro_malloc_hook);
+   assert(gallivm->coro_free_hook);
+   LLVMAddGlobalMapping(gallivm->engine, gallivm->coro_malloc_hook, coro_malloc);
+   LLVMAddGlobalMapping(gallivm->engine, gallivm->coro_free_hook, coro_free);
+}
+
+void lp_build_coro_declare_malloc_hooks(struct gallivm_state *gallivm)
+{
+   LLVMTypeRef int32_type = LLVMInt32TypeInContext(gallivm->context);
+   LLVMTypeRef mem_ptr_type = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0);
+   LLVMTypeRef malloc_type = LLVMFunctionType(mem_ptr_type, &int32_type, 1, 0);
+   gallivm->coro_malloc_hook = LLVMAddFunction(gallivm->module, "coro_malloc", malloc_type);
+   LLVMTypeRef free_type = LLVMFunctionType(LLVMVoidTypeInContext(gallivm->context), &mem_ptr_type, 1, 0);
+   gallivm->coro_free_hook = LLVMAddFunction(gallivm->module, "coro_free", free_type);
+}
+
 LLVMValueRef lp_build_coro_begin_alloc_mem(struct gallivm_state *gallivm, LLVMValueRef coro_id)
 {
    LLVMValueRef do_alloc = lp_build_coro_alloc(gallivm, coro_id);
@@ -163,13 +183,9 @@ LLVMValueRef lp_build_coro_begin_alloc_mem(struct gallivm_state *gallivm, LLVMVa
    lp_build_if(&if_state_coro, gallivm, do_alloc);
    LLVMValueRef coro_size = lp_build_coro_size(gallivm);
    LLVMValueRef alloc_mem;
-   LLVMTypeRef int32_type = LLVMInt32TypeInContext(gallivm->context);
 
-   LLVMTypeRef malloc_type = LLVMFunctionType(mem_ptr_type, &int32_type, 1, 0);
-
-   LLVMValueRef func_malloc = lp_build_const_int_pointer(gallivm, func_to_pointer((func_pointer)coro_malloc));
-   func_malloc = LLVMBuildBitCast(gallivm->builder, func_malloc, LLVMPointerType(malloc_type, 0), "coro_malloc");
-   alloc_mem = LLVMBuildCall(gallivm->builder, func_malloc, &coro_size, 1, "");
+   assert(gallivm->coro_malloc_hook);
+   alloc_mem = LLVMBuildCall(gallivm->builder, gallivm->coro_malloc_hook, &coro_size, 1, "");
 
    LLVMBuildStore(gallivm->builder, alloc_mem, alloc_mem_store);
    lp_build_endif(&if_state_coro);
@@ -181,11 +197,9 @@ LLVMValueRef lp_build_coro_begin_alloc_mem(struct gallivm_state *gallivm, LLVMVa
 void lp_build_coro_free_mem(struct gallivm_state *gallivm, LLVMValueRef coro_id, LLVMValueRef coro_hdl)
 {
    LLVMValueRef alloc_mem = lp_build_coro_free(gallivm, coro_id, coro_hdl);
-   LLVMTypeRef ptr_type = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0);
-   LLVMTypeRef free_type = LLVMFunctionType(LLVMVoidTypeInContext(gallivm->context), &ptr_type, 1, 0);
-   LLVMValueRef func_free = lp_build_const_int_pointer(gallivm, func_to_pointer((func_pointer)coro_free));
-   func_free = LLVMBuildBitCast(gallivm->builder, func_free, LLVMPointerType(free_type, 0), "coro_free");
-   alloc_mem = LLVMBuildCall(gallivm->builder, func_free, &alloc_mem, 1, "");
+
+   assert(gallivm->coro_malloc_hook);
+   alloc_mem = LLVMBuildCall(gallivm->builder, gallivm->coro_free_hook, &alloc_mem, 1, "");
 }
 
 void lp_build_coro_suspend_switch(struct gallivm_state *gallivm, const struct lp_build_coro_suspend_info *sus_info,
index 4cb1e7c6e4d0eee6c1ba9d2b080105d06a468bdb..2ffc130c9ad0d7b59262a67594d3b23628815a6c 100644 (file)
@@ -66,4 +66,7 @@ void lp_build_coro_suspend_switch(struct gallivm_state *gallivm,
                                   const struct lp_build_coro_suspend_info *sus_info,
                                   LLVMBasicBlockRef resume_block,
                                   bool final_suspend);
+
+void lp_build_coro_add_malloc_hooks(struct gallivm_state *gallivm);
+void lp_build_coro_declare_malloc_hooks(struct gallivm_state *gallivm);
 #endif
index 56279ce1aac2bc6ecff24a94302cafad88725d4c..1c0c627fb5997ce2b3ccc3997ecb0eb78beca10a 100644 (file)
@@ -52,6 +52,8 @@ struct gallivm_state
    LLVMMCJITMemoryManagerRef memorymgr;
    struct lp_generated_code *code;
    unsigned compiled;
+   LLVMValueRef coro_malloc_hook;
+   LLVMValueRef coro_free_hook;
 };
 
 
index 9c4699331a0291214cb9b4edb7e3d7f0fdc93df4..97c479ef45590bf86785aa2f473238baa62ee4a5 100644 (file)
@@ -138,6 +138,8 @@ generate_compute(struct llvmpipe_context *lp,
       }
    }
 
+   lp_build_coro_declare_malloc_hooks(gallivm);
+
    context_ptr  = LLVMGetParam(function, 0);
    x_size_arg = LLVMGetParam(function, 1);
    y_size_arg = LLVMGetParam(function, 2);
@@ -709,6 +711,7 @@ generate_variant(struct llvmpipe_context *lp,
 
    gallivm_compile_module(variant->gallivm);
 
+   lp_build_coro_add_malloc_hooks(variant->gallivm);
    variant->nr_instrs += lp_build_count_ir_module(variant->gallivm->module);
 
    variant->jit_function = (lp_jit_cs_func)gallivm_jit_function(variant->gallivm, variant->function);