radv: simplify allocating user SGPRs for descriptor sets
[mesa.git] / src / amd / vulkan / radv_nir_to_llvm.c
index 46c96dfac06312137c33752491865cce6edc83bc..690be94748572023e900d342802d09d6d89056bd 100644 (file)
@@ -33,9 +33,7 @@
 #include <llvm-c/Core.h>
 #include <llvm-c/TargetMachine.h>
 #include <llvm-c/Transforms/Scalar.h>
-#if HAVE_LLVM >= 0x0700
 #include <llvm-c/Transforms/Utils.h>
-#endif
 
 #include "sid.h"
 #include "gfx9d.h"
@@ -256,7 +254,16 @@ get_tcs_num_patches(struct radv_shader_context *ctx)
        /* Make sure that the data fits in LDS. This assumes the shaders only
         * use LDS for the inputs and outputs.
         */
-       hardware_lds_size = ctx->options->chip_class >= CIK ? 65536 : 32768;
+       hardware_lds_size = 32768;
+
+       /* Looks like STONEY hangs if we use more than 32 KiB LDS in a single
+        * threadgroup, even though the hardware has more than 32 KiB of LDS.
+        *
+        * Test: dEQP-VK.tessellation.shader_input_output.barrier
+        */
+       if (ctx->options->chip_class >= CIK && ctx->options->family != CHIP_STONEY)
+               hardware_lds_size = 65536;
+
        num_patches = MIN2(num_patches, hardware_lds_size / (input_patch_size + output_patch_size));
        /* Make sure the output data fits in the offchip buffer */
        num_patches = MIN2(num_patches, (ctx->options->tess_offchip_block_dw_size * 4) / output_patch_size);
@@ -546,11 +553,10 @@ create_llvm_function(LLVMContextRef ctx, LLVMModuleRef module,
 
 static void
 set_loc(struct radv_userdata_info *ud_info, uint8_t *sgpr_idx,
-       uint8_t num_sgprs, bool indirect)
+       uint8_t num_sgprs)
 {
        ud_info->sgpr_idx = *sgpr_idx;
        ud_info->num_sgprs = num_sgprs;
-       ud_info->indirect = indirect;
        *sgpr_idx += num_sgprs;
 }
 
@@ -562,31 +568,28 @@ set_loc_shader(struct radv_shader_context *ctx, int idx, uint8_t *sgpr_idx,
                &ctx->shader_info->user_sgprs_locs.shader_data[idx];
        assert(ud_info);
 
-       set_loc(ud_info, sgpr_idx, num_sgprs, false);
+       set_loc(ud_info, sgpr_idx, num_sgprs);
 }
 
 static void
 set_loc_shader_ptr(struct radv_shader_context *ctx, int idx, uint8_t *sgpr_idx)
 {
-       bool use_32bit_pointers = HAVE_32BIT_POINTERS &&
-                                 idx != AC_UD_SCRATCH_RING_OFFSETS;
+       bool use_32bit_pointers = idx != AC_UD_SCRATCH_RING_OFFSETS;
 
        set_loc_shader(ctx, idx, sgpr_idx, use_32bit_pointers ? 1 : 2);
 }
 
 static void
-set_loc_desc(struct radv_shader_context *ctx, int idx,  uint8_t *sgpr_idx,
-            bool indirect)
+set_loc_desc(struct radv_shader_context *ctx, int idx, uint8_t *sgpr_idx)
 {
        struct radv_userdata_locations *locs =
                &ctx->shader_info->user_sgprs_locs;
        struct radv_userdata_info *ud_info = &locs->descriptor_sets[idx];
        assert(ud_info);
 
-       set_loc(ud_info, sgpr_idx, HAVE_32BIT_POINTERS ? 1 : 2, indirect);
+       set_loc(ud_info, sgpr_idx, 1);
 
-       if (!indirect)
-               locs->descriptor_sets_enabled |= 1 << idx;
+       locs->descriptor_sets_enabled |= 1 << idx;
 }
 
 struct user_sgpr_info {
@@ -624,7 +627,7 @@ count_vs_user_sgprs(struct radv_shader_context *ctx)
        uint8_t count = 0;
 
        if (ctx->shader_info->info.vs.has_vertex_buffers)
-               count += HAVE_32BIT_POINTERS ? 1 : 2;
+               count++;
        count += ctx->shader_info->info.vs.needs_draw_id ? 3 : 2;
 
        return count;
@@ -693,43 +696,37 @@ static void allocate_user_sgprs(struct radv_shader_context *ctx,
                user_sgpr_count++;
 
        if (ctx->shader_info->info.loads_push_constants)
-               user_sgpr_count += HAVE_32BIT_POINTERS ? 1 : 2;
+               user_sgpr_count++;
+
+       if (ctx->streamout_buffers)
+               user_sgpr_count++;
 
        uint32_t available_sgprs = ctx->options->chip_class >= GFX9 && stage != MESA_SHADER_COMPUTE ? 32 : 16;
        uint32_t remaining_sgprs = available_sgprs - user_sgpr_count;
        uint32_t num_desc_set =
                util_bitcount(ctx->shader_info->info.desc_set_used_mask);
 
-       if (remaining_sgprs / (HAVE_32BIT_POINTERS ? 1 : 2) < num_desc_set) {
+       if (remaining_sgprs < num_desc_set) {
                user_sgpr_info->indirect_all_descriptor_sets = true;
        }
 }
 
 static void
 declare_global_input_sgprs(struct radv_shader_context *ctx,
-                          gl_shader_stage stage,
-                          bool has_previous_stage,
-                          gl_shader_stage previous_stage,
                           const struct user_sgpr_info *user_sgpr_info,
                           struct arg_info *args,
                           LLVMValueRef *desc_sets)
 {
        LLVMTypeRef type = ac_array_in_const32_addr_space(ctx->ac.i8);
-       unsigned num_sets = ctx->options->layout ?
-                           ctx->options->layout->num_sets : 0;
-       unsigned stage_mask = 1 << stage;
-
-       if (has_previous_stage)
-               stage_mask |= 1 << previous_stage;
 
        /* 1 for each descriptor set */
        if (!user_sgpr_info->indirect_all_descriptor_sets) {
-               for (unsigned i = 0; i < num_sets; ++i) {
-                       if ((ctx->shader_info->info.desc_set_used_mask & (1 << i)) &&
-                           ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
-                               add_array_arg(args, type,
-                                             &ctx->descriptor_sets[i]);
-                       }
+               uint32_t mask = ctx->shader_info->info.desc_set_used_mask;
+
+               while (mask) {
+                       int i = u_bit_scan(&mask);
+
+                       add_array_arg(args, type, &ctx->descriptor_sets[i]);
                }
        } else {
                add_array_arg(args, ac_array_in_const32_addr_space(type), desc_sets);
@@ -826,41 +823,31 @@ declare_tes_input_vgprs(struct radv_shader_context *ctx, struct arg_info *args)
 }
 
 static void
-set_global_input_locs(struct radv_shader_context *ctx, gl_shader_stage stage,
-                     bool has_previous_stage, gl_shader_stage previous_stage,
+set_global_input_locs(struct radv_shader_context *ctx,
                      const struct user_sgpr_info *user_sgpr_info,
                      LLVMValueRef desc_sets, uint8_t *user_sgpr_idx)
 {
-       unsigned num_sets = ctx->options->layout ?
-                           ctx->options->layout->num_sets : 0;
-       unsigned stage_mask = 1 << stage;
-
-       if (has_previous_stage)
-               stage_mask |= 1 << previous_stage;
+       uint32_t mask = ctx->shader_info->info.desc_set_used_mask;
 
        if (!user_sgpr_info->indirect_all_descriptor_sets) {
-               for (unsigned i = 0; i < num_sets; ++i) {
-                       if ((ctx->shader_info->info.desc_set_used_mask & (1 << i)) &&
-                           ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
-                               set_loc_desc(ctx, i, user_sgpr_idx, false);
-                       } else
-                               ctx->descriptor_sets[i] = NULL;
+               while (mask) {
+                       int i = u_bit_scan(&mask);
+
+                       set_loc_desc(ctx, i, user_sgpr_idx);
                }
        } else {
                set_loc_shader_ptr(ctx, AC_UD_INDIRECT_DESCRIPTOR_SETS,
                                   user_sgpr_idx);
 
-               for (unsigned i = 0; i < num_sets; ++i) {
-                       if ((ctx->shader_info->info.desc_set_used_mask & (1 << i)) &&
-                           ctx->options->layout->set[i].layout->shader_stages & stage_mask) {
-                               ctx->descriptor_sets[i] =
-                                       ac_build_load_to_sgpr(&ctx->ac,
-                                                             desc_sets,
-                                                             LLVMConstInt(ctx->ac.i32, i, false));
+               while (mask) {
+                       int i = u_bit_scan(&mask);
+
+                       ctx->descriptor_sets[i] =
+                               ac_build_load_to_sgpr(&ctx->ac, desc_sets,
+                                                     LLVMConstInt(ctx->ac.i32, i, false));
 
-                       } else
-                               ctx->descriptor_sets[i] = NULL;
                }
+
                ctx->shader_info->need_indirect_descriptor_sets = true;
        }
 
@@ -946,9 +933,8 @@ static void create_function(struct radv_shader_context *ctx,
 
        switch (stage) {
        case MESA_SHADER_COMPUTE:
-               declare_global_input_sgprs(ctx, stage, has_previous_stage,
-                                          previous_stage, &user_sgpr_info,
-                                          &args, &desc_sets);
+               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
+                                          &desc_sets);
 
                if (ctx->shader_info->info.cs.uses_grid_size) {
                        add_arg(&args, ARG_SGPR, ctx->ac.v3i32,
@@ -969,9 +955,9 @@ static void create_function(struct radv_shader_context *ctx,
                        &ctx->abi.local_invocation_ids);
                break;
        case MESA_SHADER_VERTEX:
-               declare_global_input_sgprs(ctx, stage, has_previous_stage,
-                                          previous_stage, &user_sgpr_info,
-                                          &args, &desc_sets);
+               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
+                                          &desc_sets);
+
                declare_vs_specific_input_sgprs(ctx, stage, has_previous_stage,
                                                previous_stage, &args);
 
@@ -1002,11 +988,9 @@ static void create_function(struct radv_shader_context *ctx,
                        add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
                        add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
 
-                       declare_global_input_sgprs(ctx, stage,
-                                                  has_previous_stage,
-                                                  previous_stage,
-                                                  &user_sgpr_info, &args,
+                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
                                                   &desc_sets);
+
                        declare_vs_specific_input_sgprs(ctx, stage,
                                                        has_previous_stage,
                                                        previous_stage, &args);
@@ -1022,10 +1006,7 @@ static void create_function(struct radv_shader_context *ctx,
 
                        declare_vs_input_vgprs(ctx, &args);
                } else {
-                       declare_global_input_sgprs(ctx, stage,
-                                                  has_previous_stage,
-                                                  previous_stage,
-                                                  &user_sgpr_info, &args,
+                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
                                                   &desc_sets);
 
                        if (needs_view_index)
@@ -1042,9 +1023,8 @@ static void create_function(struct radv_shader_context *ctx,
                }
                break;
        case MESA_SHADER_TESS_EVAL:
-               declare_global_input_sgprs(ctx, stage, has_previous_stage,
-                                          previous_stage, &user_sgpr_info,
-                                          &args, &desc_sets);
+               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
+                                          &desc_sets);
 
                if (needs_view_index)
                        add_arg(&args, ARG_SGPR, ctx->ac.i32,
@@ -1075,10 +1055,7 @@ static void create_function(struct radv_shader_context *ctx,
                        add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
                        add_arg(&args, ARG_SGPR, ctx->ac.i32, NULL); // unknown
 
-                       declare_global_input_sgprs(ctx, stage,
-                                                  has_previous_stage,
-                                                  previous_stage,
-                                                  &user_sgpr_info, &args,
+                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
                                                   &desc_sets);
 
                        if (previous_stage != MESA_SHADER_TESS_EVAL) {
@@ -1109,10 +1086,7 @@ static void create_function(struct radv_shader_context *ctx,
                                declare_tes_input_vgprs(ctx, &args);
                        }
                } else {
-                       declare_global_input_sgprs(ctx, stage,
-                                                  has_previous_stage,
-                                                  previous_stage,
-                                                  &user_sgpr_info, &args,
+                       declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
                                                   &desc_sets);
 
                        if (needs_view_index)
@@ -1140,9 +1114,8 @@ static void create_function(struct radv_shader_context *ctx,
                }
                break;
        case MESA_SHADER_FRAGMENT:
-               declare_global_input_sgprs(ctx, stage, has_previous_stage,
-                                          previous_stage, &user_sgpr_info,
-                                          &args, &desc_sets);
+               declare_global_input_sgprs(ctx, &user_sgpr_info, &args,
+                                          &desc_sets);
 
                add_arg(&args, ARG_SGPR, ctx->ac.i32, &ctx->abi.prim_mask);
                add_arg(&args, ARG_VGPR, ctx->ac.v2i32, &ctx->persp_sample);
@@ -1201,8 +1174,7 @@ static void create_function(struct radv_shader_context *ctx,
        if (has_previous_stage)
                user_sgpr_idx = 0;
 
-       set_global_input_locs(ctx, stage, has_previous_stage, previous_stage,
-                             &user_sgpr_info, desc_sets, &user_sgpr_idx);
+       set_global_input_locs(ctx, &user_sgpr_info, desc_sets, &user_sgpr_idx);
 
        switch (stage) {
        case MESA_SHADER_COMPUTE:
@@ -1684,9 +1656,6 @@ radv_get_sample_pos_offset(uint32_t num_samples)
        case 8:
                sample_pos_offset = 7;
                break;
-       case 16:
-               sample_pos_offset = 15;
-               break;
        default:
                break;
        }
@@ -2160,7 +2129,7 @@ handle_fs_input_decl(struct radv_shader_context *ctx,
 
                interp = lookup_interp_param(&ctx->abi, variable->data.interpolation, interp_type);
        }
-       bool is_16bit = glsl_type_is_16bit(variable->type);
+       bool is_16bit = glsl_type_is_16bit(glsl_without_array(variable->type));
        LLVMTypeRef type = is_16bit ? ctx->ac.i16 : ctx->ac.i32;
        if (interp == NULL)
                interp = LLVMGetUndef(type);
@@ -2242,6 +2211,8 @@ handle_fs_inputs(struct radv_shader_context *ctx,
 
                        if (LLVMIsUndef(interp_param))
                                ctx->shader_info->fs.flat_shaded_mask |= 1u << index;
+                       if (i >= VARYING_SLOT_VAR0)
+                               ctx->abi.fs_input_attr_indices[i - VARYING_SLOT_VAR0] = index;
                        ++index;
                } else if (i == VARYING_SLOT_CLIP_DIST0) {
                        int length = ctx->shader_info->info.ps.num_input_clips_culls;
@@ -2512,9 +2483,6 @@ radv_emit_stream_output(struct radv_shader_context *ctx,
        /* Get the first component. */
        start = ffs(output->component_mask) - 1;
 
-       /* Adjust the destination offset. */
-       offset += start * 4;
-
        /* Load the output as int. */
        for (int i = 0; i < num_comps; i++) {
                out[i] = ac_to_integer(&ctx->ac,
@@ -2723,8 +2691,11 @@ handle_vs_outputs_post(struct radv_shader_context *ctx,
                viewport_index_value = radv_load_output(ctx, VARYING_SLOT_VIEWPORT, 0);
        }
 
-       if (ctx->shader_info->info.so.num_outputs)
+       if (ctx->shader_info->info.so.num_outputs &&
+           !ctx->is_gs_copy_shader) {
+               /* The GS copy shader emission already emits streamout. */
                radv_emit_streamout(ctx, 0);
+       }
 
        if (outinfo->writes_pointsize ||
            outinfo->writes_layer ||
@@ -3500,7 +3471,7 @@ LLVMModuleRef ac_translate_nir_to_llvm(struct ac_llvm_compiler *ac_llvm,
        ctx.abi.load_sampler_desc = radv_get_sampler_desc;
        ctx.abi.load_resource = radv_load_resource;
        ctx.abi.clamp_shadow_reference = false;
-       ctx.abi.gfx9_stride_size_workaround = ctx.ac.chip_class == GFX9;
+       ctx.abi.gfx9_stride_size_workaround = ctx.ac.chip_class == GFX9 && HAVE_LLVM < 0x800;
 
        if (shader_count >= 2)
                ac_init_exec_full_mask(&ctx.ac);
@@ -3829,45 +3800,92 @@ ac_gs_copy_shader_emit(struct radv_shader_context *ctx)
        LLVMValueRef vtx_offset =
                LLVMBuildMul(ctx->ac.builder, ctx->abi.vertex_id,
                             LLVMConstInt(ctx->ac.i32, 4, false), "");
-       unsigned offset = 0;
+       LLVMValueRef stream_id;
 
-       for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
-               unsigned output_usage_mask =
-                       ctx->shader_info->info.gs.output_usage_mask[i];
-               int length = util_last_bit(output_usage_mask);
+       /* Fetch the vertex stream ID. */
+       if (ctx->shader_info->info.so.num_outputs) {
+               stream_id =
+                       ac_unpack_param(&ctx->ac, ctx->streamout_config, 24, 2);
+       } else {
+               stream_id = ctx->ac.i32_0;
+       }
 
-               if (!(ctx->output_mask & (1ull << i)))
+       LLVMBasicBlockRef end_bb;
+       LLVMValueRef switch_inst;
+
+       end_bb = LLVMAppendBasicBlockInContext(ctx->ac.context,
+                                              ctx->main_function, "end");
+       switch_inst = LLVMBuildSwitch(ctx->ac.builder, stream_id, end_bb, 4);
+
+       for (unsigned stream = 0; stream < 4; stream++) {
+               unsigned num_components =
+                       ctx->shader_info->info.gs.num_stream_output_components[stream];
+               LLVMBasicBlockRef bb;
+               unsigned offset;
+
+               if (!num_components)
                        continue;
 
-               for (unsigned j = 0; j < length; j++) {
-                       LLVMValueRef value, soffset;
+               if (stream > 0 && !ctx->shader_info->info.so.num_outputs)
+                       continue;
 
-                       if (!(output_usage_mask & (1 << j)))
+               bb = LLVMInsertBasicBlockInContext(ctx->ac.context, end_bb, "out");
+               LLVMAddCase(switch_inst, LLVMConstInt(ctx->ac.i32, stream, 0), bb);
+               LLVMPositionBuilderAtEnd(ctx->ac.builder, bb);
+
+               offset = 0;
+               for (unsigned i = 0; i < AC_LLVM_MAX_OUTPUTS; ++i) {
+                       unsigned output_usage_mask =
+                               ctx->shader_info->info.gs.output_usage_mask[i];
+                       unsigned output_stream =
+                               ctx->shader_info->info.gs.output_streams[i];
+                       int length = util_last_bit(output_usage_mask);
+
+                       if (!(ctx->output_mask & (1ull << i)) ||
+                           output_stream != stream)
                                continue;
 
-                       soffset = LLVMConstInt(ctx->ac.i32,
-                                              offset *
-                                              ctx->gs_max_out_vertices * 16 * 4, false);
+                       for (unsigned j = 0; j < length; j++) {
+                               LLVMValueRef value, soffset;
 
-                       offset++;
+                               if (!(output_usage_mask & (1 << j)))
+                                       continue;
+
+                               soffset = LLVMConstInt(ctx->ac.i32,
+                                                      offset *
+                                                      ctx->gs_max_out_vertices * 16 * 4, false);
+
+                               offset++;
+
+                               value = ac_build_buffer_load(&ctx->ac,
+                                                            ctx->gsvs_ring[0],
+                                                            1, ctx->ac.i32_0,
+                                                            vtx_offset, soffset,
+                                                            0, 1, 1, true, false);
 
-                       value = ac_build_buffer_load(&ctx->ac,
-                                                    ctx->gsvs_ring[0],
-                                                    1, ctx->ac.i32_0,
-                                                    vtx_offset, soffset,
-                                                    0, 1, 1, true, false);
+                               LLVMTypeRef type = LLVMGetAllocatedType(ctx->abi.outputs[ac_llvm_reg_index_soa(i, j)]);
+                               if (ac_get_type_size(type) == 2) {
+                                       value = LLVMBuildBitCast(ctx->ac.builder, value, ctx->ac.i32, "");
+                                       value = LLVMBuildTrunc(ctx->ac.builder, value, ctx->ac.i16, "");
+                               }
 
-                       LLVMTypeRef type = LLVMGetAllocatedType(ctx->abi.outputs[ac_llvm_reg_index_soa(i, j)]);
-                       if (ac_get_type_size(type) == 2) {
-                               value = LLVMBuildBitCast(ctx->ac.builder, value, ctx->ac.i32, "");
-                               value = LLVMBuildTrunc(ctx->ac.builder, value, ctx->ac.i16, "");
+                               LLVMBuildStore(ctx->ac.builder,
+                                              ac_to_float(&ctx->ac, value), ctx->abi.outputs[ac_llvm_reg_index_soa(i, j)]);
                        }
+               }
+
+               if (ctx->shader_info->info.so.num_outputs)
+                       radv_emit_streamout(ctx, stream);
 
-                       LLVMBuildStore(ctx->ac.builder,
-                                      ac_to_float(&ctx->ac, value), ctx->abi.outputs[ac_llvm_reg_index_soa(i, j)]);
+               if (stream == 0) {
+                       handle_vs_outputs_post(ctx, false, false,
+                                              &ctx->shader_info->vs.outinfo);
                }
+
+               LLVMBuildBr(ctx->ac.builder, end_bb);
        }
-       handle_vs_outputs_post(ctx, false, false, &ctx->shader_info->vs.outinfo);
+
+       LLVMPositionBuilderAtEnd(ctx->ac.builder, end_bb);
 }
 
 void