radeonsi: return the last part's return value from @wrapper
author: Marek Olšák <marek.olsak@amd.com>
Mon, 13 Aug 2018 23:11:55 +0000 (19:11 -0400)
committer: Marek Olšák <marek.olsak@amd.com>
Thu, 16 May 2019 17:06:40 +0000 (13:06 -0400)
The primitive discard compute shader will get the position output this way.

Tested-by: Dieter Nützel <Dieter@nuetzel-hh.de>
Acked-by: Nicolai Hähnle <nicolai.haehnle@amd.com>
src/gallium/drivers/radeonsi/si_shader.c

index 98c11e1c98d8119bff28ae83008375d8cae97933..6455bb5dcd0252b7a772148aafd40546437c6ea7 100644 (file)
@@ -6399,7 +6399,26 @@ static void si_build_wrapper_function(struct si_shader_context *ctx,
                gprs += size;
        }
 
-       si_create_function(ctx, "wrapper", NULL, 0, &fninfo,
+       /* Prepare the return type. */
+       unsigned num_returns = 0;
+       LLVMTypeRef returns[32], last_func_type, return_type;
+
+       last_func_type = LLVMGetElementType(LLVMTypeOf(parts[num_parts - 1]));
+       return_type = LLVMGetReturnType(last_func_type);
+
+       switch (LLVMGetTypeKind(return_type)) {
+       case LLVMStructTypeKind:
+               num_returns = LLVMCountStructElementTypes(return_type);
+               assert(num_returns <= ARRAY_SIZE(returns));
+               LLVMGetStructElementTypes(return_type, returns);
+               break;
+       case LLVMVoidTypeKind:
+               break;
+       default:
+               unreachable("unexpected type");
+       }
+
+       si_create_function(ctx, "wrapper", returns, num_returns, &fninfo,
                           si_get_max_workgroup_size(ctx->shader));
 
        if (is_merged_shader(ctx))
@@ -6451,9 +6470,9 @@ static void si_build_wrapper_function(struct si_shader_context *ctx,
        initial_num_out_sgpr = num_out_sgpr;
 
        /* Now chain the parts. */
+       LLVMValueRef ret;
        for (unsigned part = 0; part < num_parts; ++part) {
                LLVMValueRef in[48];
-               LLVMValueRef ret;
                LLVMTypeRef ret_type;
                unsigned out_idx = 0;
                unsigned num_params = LLVMCountParams(parts[part]);
@@ -6565,7 +6584,11 @@ static void si_build_wrapper_function(struct si_shader_context *ctx,
                }
        }
 
-       LLVMBuildRetVoid(builder);
+       /* Return the value from the last part. */
+       if (LLVMGetTypeKind(LLVMTypeOf(ret)) == LLVMVoidTypeKind)
+               LLVMBuildRetVoid(builder);
+       else
+               LLVMBuildRet(builder, ret);
 }
 
 static bool si_should_optimize_less(struct ac_llvm_compiler *compiler,