llvmpipe: add compute shader generation.
authorDave Airlie <airlied@redhat.com>
Tue, 27 Aug 2019 04:57:54 +0000 (14:57 +1000)
committerDave Airlie <airlied@redhat.com>
Wed, 4 Sep 2019 05:22:20 +0000 (15:22 +1000)
This creates the coroutine execution environment and the
main compute shaders that get executed inside it.

Each compute shader block is executed in it's own coroutine
execution shader, which each "thread" being a coroutine executed
inside it in sequence.

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
src/gallium/drivers/llvmpipe/lp_state_cs.c
src/gallium/drivers/llvmpipe/lp_state_cs.h

index cf47d1e8f06b771d7b04e1ee79a98ad332cb017e..c93d3c5ce2b095a214cd2e5b22c11938a998574f 100644 (file)
 #include "util/os_time.h"
 #include "tgsi/tgsi_dump.h"
 #include "tgsi/tgsi_parse.h"
+#include "gallivm/lp_bld_const.h"
 #include "gallivm/lp_bld_debug.h"
+#include "gallivm/lp_bld_intr.h"
+#include "gallivm/lp_bld_flow.h"
+#include "gallivm/lp_bld_gather.h"
+#include "gallivm/lp_bld_coro.h"
 #include "lp_state_cs.h"
 #include "lp_context.h"
 #include "lp_debug.h"
 #include "lp_state.h"
 #include "lp_perf.h"
 
+static void
+generate_compute(struct llvmpipe_context *lp,
+                 struct lp_compute_shader *shader,
+                 struct lp_compute_shader_variant *variant)
+{
+   struct gallivm_state *gallivm = variant->gallivm;
+   char func_name[64], func_name_coro[64];
+   LLVMTypeRef arg_types[13];
+   LLVMTypeRef func_type, coro_func_type;
+   LLVMTypeRef int32_type = LLVMInt32TypeInContext(gallivm->context);
+   LLVMValueRef context_ptr;
+   LLVMValueRef x_size_arg, y_size_arg, z_size_arg;
+   LLVMValueRef grid_x_arg, grid_y_arg, grid_z_arg;
+   LLVMValueRef grid_size_x_arg, grid_size_y_arg, grid_size_z_arg;
+   LLVMValueRef thread_data_ptr;
+   LLVMBasicBlockRef block;
+   LLVMBuilderRef builder;
+   LLVMValueRef function, coro;
+   struct lp_type cs_type;
+   unsigned i;
+
+   /*
+    * This function has two parts
+    * a) setup the coroutine execution environment loop.
+    * b) build the compute shader llvm for use inside the coroutine.
+    */
+   assert(lp_native_vector_width / 32 >= 4);
+
+   memset(&cs_type, 0, sizeof cs_type);
+   cs_type.floating = TRUE;      /* floating point values */
+   cs_type.sign = TRUE;          /* values are signed */
+   cs_type.norm = FALSE;         /* values are not limited to [0,1] or [-1,1] */
+   cs_type.width = 32;           /* 32-bit float */
+   cs_type.length = MIN2(lp_native_vector_width / 32, 16); /* n*4 elements per vector */
+   snprintf(func_name, sizeof(func_name), "cs%u_variant%u",
+            shader->no, variant->no);
+
+   snprintf(func_name_coro, sizeof(func_name), "cs_co_%u_variant%u",
+            shader->no, variant->no);
+
+   arg_types[0] = variant->jit_cs_context_ptr_type;       /* context */
+   arg_types[1] = int32_type;                          /* block_x_size */
+   arg_types[2] = int32_type;                          /* block_y_size */
+   arg_types[3] = int32_type;                          /* block_z_size */
+   arg_types[4] = int32_type;                          /* grid_x */
+   arg_types[5] = int32_type;                          /* grid_y */
+   arg_types[6] = int32_type;                          /* grid_z */
+   arg_types[7] = int32_type;                          /* grid_size_x */
+   arg_types[8] = int32_type;                          /* grid_size_y */
+   arg_types[9] = int32_type;                          /* grid_size_z */
+   arg_types[10] = variant->jit_cs_thread_data_ptr_type;  /* per thread data */
+   arg_types[11] = int32_type;
+   arg_types[12] = int32_type;
+   func_type = LLVMFunctionType(LLVMVoidTypeInContext(gallivm->context),
+                                arg_types, ARRAY_SIZE(arg_types) - 2, 0);
+
+   coro_func_type = LLVMFunctionType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0),
+                                     arg_types, ARRAY_SIZE(arg_types), 0);
+
+   function = LLVMAddFunction(gallivm->module, func_name, func_type);
+   LLVMSetFunctionCallConv(function, LLVMCCallConv);
+
+   coro = LLVMAddFunction(gallivm->module, func_name_coro, coro_func_type);
+   LLVMSetFunctionCallConv(coro, LLVMCCallConv);
+
+   variant->function = function;
+
+   for(i = 0; i < ARRAY_SIZE(arg_types); ++i) {
+      if(LLVMGetTypeKind(arg_types[i]) == LLVMPointerTypeKind) {
+         lp_add_function_attr(coro, i + 1, LP_FUNC_ATTR_NOALIAS);
+         lp_add_function_attr(function, i + 1, LP_FUNC_ATTR_NOALIAS);
+      }
+   }
+
+   context_ptr  = LLVMGetParam(function, 0);
+   x_size_arg = LLVMGetParam(function, 1);
+   y_size_arg = LLVMGetParam(function, 2);
+   z_size_arg = LLVMGetParam(function, 3);
+   grid_x_arg = LLVMGetParam(function, 4);
+   grid_y_arg = LLVMGetParam(function, 5);
+   grid_z_arg = LLVMGetParam(function, 6);
+   grid_size_x_arg = LLVMGetParam(function, 7);
+   grid_size_y_arg = LLVMGetParam(function, 8);
+   grid_size_z_arg = LLVMGetParam(function, 9);
+   thread_data_ptr  = LLVMGetParam(function, 10);
+
+   lp_build_name(context_ptr, "context");
+   lp_build_name(x_size_arg, "x_size");
+   lp_build_name(y_size_arg, "y_size");
+   lp_build_name(z_size_arg, "z_size");
+   lp_build_name(grid_x_arg, "grid_x");
+   lp_build_name(grid_y_arg, "grid_y");
+   lp_build_name(grid_z_arg, "grid_z");
+   lp_build_name(grid_size_x_arg, "grid_size_x");
+   lp_build_name(grid_size_y_arg, "grid_size_y");
+   lp_build_name(grid_size_z_arg, "grid_size_z");
+   lp_build_name(thread_data_ptr, "thread_data");
+
+   block = LLVMAppendBasicBlockInContext(gallivm->context, function, "entry");
+   builder = gallivm->builder;
+   assert(builder);
+   LLVMPositionBuilderAtEnd(builder, block);
+
+   struct lp_build_loop_state loop_state[4];
+   LLVMValueRef num_x_loop;
+   LLVMValueRef vec_length = lp_build_const_int32(gallivm, cs_type.length);
+   num_x_loop = LLVMBuildAdd(gallivm->builder, x_size_arg, vec_length, "");
+   num_x_loop = LLVMBuildSub(gallivm->builder, num_x_loop, lp_build_const_int32(gallivm, 1), "");
+   num_x_loop = LLVMBuildUDiv(gallivm->builder, num_x_loop, vec_length, "");
+   LLVMValueRef partials = LLVMBuildURem(gallivm->builder, x_size_arg, vec_length, "");
+
+   LLVMValueRef coro_num_hdls = LLVMBuildMul(gallivm->builder, num_x_loop, y_size_arg, "");
+   coro_num_hdls = LLVMBuildMul(gallivm->builder, coro_num_hdls, z_size_arg, "");
+
+   LLVMTypeRef hdl_ptr_type = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0);
+   LLVMValueRef coro_hdls = LLVMBuildArrayAlloca(gallivm->builder, hdl_ptr_type, coro_num_hdls, "coro_hdls");
+
+   unsigned end_coroutine = INT_MAX;
+
+   /*
+    * This is the main coroutine execution loop. It iterates over the dimensions
+    * and calls the coroutine main entrypoint on the first pass, but in subsequent
+    * passes it checks if the coroutine has completed and resumes it if not.
+    */
+   /* take x_width - round up to type.length width */
+   lp_build_loop_begin(&loop_state[3], gallivm,
+                       lp_build_const_int32(gallivm, 0)); /* coroutine reentry loop */
+   lp_build_loop_begin(&loop_state[2], gallivm,
+                       lp_build_const_int32(gallivm, 0)); /* z loop */
+   lp_build_loop_begin(&loop_state[1], gallivm,
+                       lp_build_const_int32(gallivm, 0)); /* y loop */
+   lp_build_loop_begin(&loop_state[0], gallivm,
+                       lp_build_const_int32(gallivm, 0)); /* x loop */
+   {
+      LLVMValueRef args[13];
+      args[0] = context_ptr;
+      args[1] = loop_state[0].counter;
+      args[2] = loop_state[1].counter;
+      args[3] = loop_state[2].counter;
+      args[4] = grid_x_arg;
+      args[5] = grid_y_arg;
+      args[6] = grid_z_arg;
+      args[7] = grid_size_x_arg;
+      args[8] = grid_size_y_arg;
+      args[9] = grid_size_z_arg;
+      args[10] = thread_data_ptr;
+      args[11] = num_x_loop;
+      args[12] = partials;
+
+      /* idx = (z * (size_x * size_y) + y * size_x + x */
+      LLVMValueRef coro_hdl_idx = LLVMBuildMul(gallivm->builder, loop_state[2].counter,
+                                               LLVMBuildMul(gallivm->builder, num_x_loop, y_size_arg, ""), "");
+      coro_hdl_idx = LLVMBuildAdd(gallivm->builder, coro_hdl_idx,
+                                  LLVMBuildMul(gallivm->builder, loop_state[1].counter,
+                                               num_x_loop, ""), "");
+      coro_hdl_idx = LLVMBuildAdd(gallivm->builder, coro_hdl_idx,
+                                  loop_state[0].counter, "");
+
+      LLVMValueRef coro_entry = LLVMBuildGEP(gallivm->builder, coro_hdls, &coro_hdl_idx, 1, "");
+
+      LLVMValueRef coro_hdl = LLVMBuildLoad(gallivm->builder, coro_entry, "coro_hdl");
+
+      struct lp_build_if_state ifstate;
+      LLVMValueRef cmp = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, loop_state[3].counter,
+                                       lp_build_const_int32(gallivm, 0), "");
+      /* first time here - call the coroutine function entry point */
+      lp_build_if(&ifstate, gallivm, cmp);
+      LLVMValueRef coro_ret = LLVMBuildCall(gallivm->builder, coro, args, 13, "");
+      LLVMBuildStore(gallivm->builder, coro_ret, coro_entry);
+      lp_build_else(&ifstate);
+      /* subsequent calls for this invocation - check if done. */
+      LLVMValueRef coro_done = lp_build_coro_done(gallivm, coro_hdl);
+      struct lp_build_if_state ifstate2;
+      lp_build_if(&ifstate2, gallivm, coro_done);
+      /* if done destroy and force loop exit */
+      lp_build_coro_destroy(gallivm, coro_hdl);
+      lp_build_loop_force_set_counter(&loop_state[3], lp_build_const_int32(gallivm, end_coroutine - 1));
+      lp_build_else(&ifstate2);
+      /* otherwise resume the coroutine */
+      lp_build_coro_resume(gallivm, coro_hdl);
+      lp_build_endif(&ifstate2);
+      lp_build_endif(&ifstate);
+      lp_build_loop_force_reload_counter(&loop_state[3]);
+   }
+   lp_build_loop_end_cond(&loop_state[0],
+                          num_x_loop,
+                          NULL,  LLVMIntUGE);
+   lp_build_loop_end_cond(&loop_state[1],
+                          y_size_arg,
+                          NULL,  LLVMIntUGE);
+   lp_build_loop_end_cond(&loop_state[2],
+                          z_size_arg,
+                          NULL,  LLVMIntUGE);
+   lp_build_loop_end_cond(&loop_state[3],
+                          lp_build_const_int32(gallivm, end_coroutine),
+                          NULL, LLVMIntEQ);
+   LLVMBuildRetVoid(builder);
+
+   /* This is stage (b) - generate the compute shader code inside the coroutine. */
+   context_ptr  = LLVMGetParam(coro, 0);
+   x_size_arg = LLVMGetParam(coro, 1);
+   y_size_arg = LLVMGetParam(coro, 2);
+   z_size_arg = LLVMGetParam(coro, 3);
+   grid_x_arg = LLVMGetParam(coro, 4);
+   grid_y_arg = LLVMGetParam(coro, 5);
+   grid_z_arg = LLVMGetParam(coro, 6);
+   grid_size_x_arg = LLVMGetParam(coro, 7);
+   grid_size_y_arg = LLVMGetParam(coro, 8);
+   grid_size_z_arg = LLVMGetParam(coro, 9);
+   thread_data_ptr  = LLVMGetParam(coro, 10);
+   num_x_loop = LLVMGetParam(coro, 11);
+   partials = LLVMGetParam(coro, 12);
+   block = LLVMAppendBasicBlockInContext(gallivm->context, coro, "entry");
+   LLVMPositionBuilderAtEnd(builder, block);
+   {
+      const struct tgsi_token *tokens = shader->base.tokens;
+      LLVMValueRef consts_ptr, num_consts_ptr;
+      LLVMValueRef ssbo_ptr, num_ssbo_ptr;
+      LLVMValueRef shared_ptr;
+      struct lp_build_mask_context mask;
+      struct lp_bld_tgsi_system_values system_values;
+
+      memset(&system_values, 0, sizeof(system_values));
+      consts_ptr = lp_jit_cs_context_constants(gallivm, context_ptr);
+      num_consts_ptr = lp_jit_cs_context_num_constants(gallivm, context_ptr);
+      ssbo_ptr = lp_jit_cs_context_ssbos(gallivm, context_ptr);
+      num_ssbo_ptr = lp_jit_cs_context_num_ssbos(gallivm, context_ptr);
+      shared_ptr = lp_jit_cs_thread_data_shared(gallivm, thread_data_ptr);
+
+      /* these are coroutine entrypoint necessities */
+      LLVMValueRef coro_id = lp_build_coro_id(gallivm);
+      LLVMValueRef coro_hdl = lp_build_coro_begin_alloc_mem(gallivm, coro_id);
+
+      LLVMValueRef has_partials = LLVMBuildICmp(gallivm->builder, LLVMIntNE, partials, lp_build_const_int32(gallivm, 0), "");
+      LLVMValueRef tid_vals[3];
+      LLVMValueRef tids_x[LP_MAX_VECTOR_LENGTH], tids_y[LP_MAX_VECTOR_LENGTH], tids_z[LP_MAX_VECTOR_LENGTH];
+      LLVMValueRef base_val = LLVMBuildMul(gallivm->builder, x_size_arg, vec_length, "");
+      for (i = 0; i < cs_type.length; i++) {
+         tids_x[i] = LLVMBuildAdd(gallivm->builder, base_val, lp_build_const_int32(gallivm, i), "");
+         tids_y[i] = y_size_arg;
+         tids_z[i] = z_size_arg;
+      }
+      tid_vals[0] = lp_build_gather_values(gallivm, tids_x, cs_type.length);
+      tid_vals[1] = lp_build_gather_values(gallivm, tids_y, cs_type.length);
+      tid_vals[2] = lp_build_gather_values(gallivm, tids_z, cs_type.length);
+      system_values.thread_id = LLVMGetUndef(LLVMArrayType(LLVMVectorType(int32_type, cs_type.length), 3));
+      for (i = 0; i < 3; i++)
+         system_values.thread_id = LLVMBuildInsertValue(builder, system_values.thread_id, tid_vals[i], i, "");
+
+      LLVMValueRef gtids[3] = { grid_x_arg, grid_y_arg, grid_z_arg };
+      system_values.block_id = LLVMGetUndef(LLVMVectorType(int32_type, 3));
+      for (i = 0; i < 3; i++)
+         system_values.block_id = LLVMBuildInsertElement(builder, system_values.block_id, gtids[i], lp_build_const_int32(gallivm, i), "");
+
+      LLVMValueRef gstids[3] = { grid_size_x_arg, grid_size_y_arg, grid_size_z_arg };
+      system_values.grid_size = LLVMGetUndef(LLVMVectorType(int32_type, 3));
+      for (i = 0; i < 3; i++)
+         system_values.grid_size = LLVMBuildInsertElement(builder, system_values.grid_size, gstids[i], lp_build_const_int32(gallivm, i), "");
+
+      LLVMValueRef last_x_loop = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, x_size_arg, LLVMBuildSub(gallivm->builder, num_x_loop, lp_build_const_int32(gallivm, 1), ""), "");
+      LLVMValueRef use_partial_mask = LLVMBuildAnd(gallivm->builder, last_x_loop, has_partials, "");
+      struct lp_build_if_state if_state;
+      LLVMValueRef mask_val = lp_build_alloca(gallivm, LLVMVectorType(int32_type, cs_type.length), "mask");
+      LLVMValueRef full_mask_val = lp_build_const_int_vec(gallivm, cs_type, ~0);
+      LLVMBuildStore(gallivm->builder, full_mask_val, mask_val);
+
+      lp_build_if(&if_state, gallivm, use_partial_mask);
+      struct lp_build_loop_state mask_loop_state;
+      lp_build_loop_begin(&mask_loop_state, gallivm, partials);
+      LLVMValueRef tmask_val = LLVMBuildLoad(gallivm->builder, mask_val, "");
+      tmask_val = LLVMBuildInsertElement(gallivm->builder, tmask_val, lp_build_const_int32(gallivm, 0), mask_loop_state.counter, "");
+      LLVMBuildStore(gallivm->builder, tmask_val, mask_val);
+      lp_build_loop_end_cond(&mask_loop_state, vec_length, NULL, LLVMIntUGE);
+      lp_build_endif(&if_state);
+
+      mask_val = LLVMBuildLoad(gallivm->builder, mask_val, "");
+      lp_build_mask_begin(&mask, gallivm, cs_type, mask_val);
+
+      struct lp_build_coro_suspend_info coro_info;
+
+      LLVMBasicBlockRef sus_block = LLVMAppendBasicBlockInContext(gallivm->context, coro, "suspend");
+      LLVMBasicBlockRef clean_block = LLVMAppendBasicBlockInContext(gallivm->context, coro, "cleanup");
+
+      coro_info.suspend = sus_block;
+      coro_info.cleanup = clean_block;
+
+      struct lp_build_tgsi_params params;
+      memset(&params, 0, sizeof(params));
+
+      params.type = cs_type;
+      params.mask = &mask;
+      params.consts_ptr = consts_ptr;
+      params.const_sizes_ptr = num_consts_ptr;
+      params.system_values = &system_values;
+      params.context_ptr = context_ptr;
+      params.info = &shader->info.base;
+      params.ssbo_ptr = ssbo_ptr;
+      params.ssbo_sizes_ptr = num_ssbo_ptr;
+      params.shared_ptr = shared_ptr;
+      params.coro = &coro_info;
+
+      lp_build_tgsi_soa(gallivm, tokens, &params, NULL);
+
+      mask_val = lp_build_mask_end(&mask);
+
+      lp_build_coro_suspend_switch(gallivm, &coro_info, NULL, true);
+      LLVMPositionBuilderAtEnd(builder, clean_block);
+
+      lp_build_coro_free_mem(gallivm, coro_id, coro_hdl);
+
+      LLVMBuildBr(builder, sus_block);
+      LLVMPositionBuilderAtEnd(builder, sus_block);
+
+      lp_build_coro_end(gallivm, coro_hdl);
+      LLVMBuildRet(builder, coro_hdl);
+   }
+
+   gallivm_verify_function(gallivm, coro);
+   gallivm_verify_function(gallivm, function);
+}
+
 static void *
 llvmpipe_create_compute_state(struct pipe_context *pipe,
                                      const struct pipe_compute_state *templ)
@@ -174,6 +500,14 @@ generate_variant(struct llvmpipe_context *lp,
 
    lp_jit_init_cs_types(variant);
 
+   generate_compute(lp, shader, variant);
+
+   gallivm_compile_module(variant->gallivm);
+
+   variant->nr_instrs += lp_build_count_ir_module(variant->gallivm->module);
+
+   variant->jit_function = (lp_jit_cs_func)gallivm_jit_function(variant->gallivm, variant->function);
+
    gallivm_free_ir(variant->gallivm);
    return variant;
 }
index adb2a36367973e3cb2f367d4649b983fcdb55c6d..cab0d41fe7d63b6de65caeb802761bb318eaaabf 100644 (file)
@@ -55,6 +55,9 @@ struct lp_compute_shader_variant
    LLVMTypeRef jit_cs_context_ptr_type;
    LLVMTypeRef jit_cs_thread_data_ptr_type;
 
+   LLVMValueRef function;
+   lp_jit_cs_func jit_function;
+
    /* Total number of LLVM instructions generated */
    unsigned nr_instrs;