radv: fill shader info for all stages in the pipeline
[mesa.git] / src / amd / vulkan / radv_shader.c
index a3e9b45a7529de364af8a514da8e7aa6a0e9a5be..146d85ade5a02dd73330f43e4d8bcb19d90fbcaf 100644 (file)
@@ -200,7 +200,7 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively,
                NIR_PASS(progress, shader, nir_remove_dead_variables,
                         nir_var_function_temp);
 
-                NIR_PASS_V(shader, nir_lower_alu_to_scalar, NULL);
+                NIR_PASS_V(shader, nir_lower_alu_to_scalar, NULL, NULL);
                 NIR_PASS_V(shader, nir_lower_phis_to_scalar);
 
                 NIR_PASS(progress, shader, nir_copy_prop);
@@ -389,6 +389,8 @@ radv_shader_compile_to_nir(struct radv_device *device,
                NIR_PASS_V(nir, nir_remove_dead_variables,
                           nir_var_shader_in | nir_var_shader_out | nir_var_system_value);
 
+               NIR_PASS_V(nir, nir_propagate_invariant);
+
                NIR_PASS_V(nir, nir_lower_system_values);
                NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
                NIR_PASS_V(nir, radv_nir_lower_ycbcr_textures, layout);
@@ -442,6 +444,13 @@ radv_shader_compile_to_nir(struct radv_device *device,
         */
        nir_lower_var_copies(nir);
 
+       /* Lower large variables that are always constant with load_constant
+        * intrinsics, which get turned into PC-relative loads from a data
+        * section next to the shader.
+        */
+       NIR_PASS_V(nir, nir_opt_large_constants,
+                  glsl_get_natural_size_align_bytes, 16);
+
        /* Indirect lowering must be called after the radv_optimize_nir() loop
         * has been called at least once. Otherwise indirect lowering can
         * bloat the instruction count of the loop and cause it to be
@@ -520,8 +529,8 @@ lower_view_index(nir_shader *nir)
        return progress;
 }
 
-static void
-lower_fs_io(nir_shader *nir)
+void
+radv_lower_fs_io(nir_shader *nir)
 {
        NIR_PASS_V(nir, lower_view_index);
        nir_assign_io_var_locations(&nir->inputs, &nir->num_inputs,
@@ -607,7 +616,7 @@ radv_get_shader_binary_size(size_t code_size)
 
 static void radv_postprocess_config(const struct radv_physical_device *pdevice,
                                    const struct ac_shader_config *config_in,
-                                   const struct radv_shader_variant_info *info,
+                                   const struct radv_shader_info *info,
                                    gl_shader_stage stage,
                                    struct ac_shader_config *config_out)
 {
@@ -675,14 +684,14 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
 
        config_out->rsrc2 = S_00B12C_USER_SGPR(info->num_user_sgprs) |
                            S_00B12C_SCRATCH_EN(scratch_enabled) |
-                           S_00B12C_SO_BASE0_EN(!!info->info.so.strides[0]) |
-                           S_00B12C_SO_BASE1_EN(!!info->info.so.strides[1]) |
-                           S_00B12C_SO_BASE2_EN(!!info->info.so.strides[2]) |
-                           S_00B12C_SO_BASE3_EN(!!info->info.so.strides[3]) |
-                           S_00B12C_SO_EN(!!info->info.so.num_outputs);
+                           S_00B12C_SO_BASE0_EN(!!info->so.strides[0]) |
+                           S_00B12C_SO_BASE1_EN(!!info->so.strides[1]) |
+                           S_00B12C_SO_BASE2_EN(!!info->so.strides[2]) |
+                           S_00B12C_SO_BASE3_EN(!!info->so.strides[3]) |
+                           S_00B12C_SO_EN(!!info->so.num_outputs);
 
        config_out->rsrc1 = S_00B848_VGPRS((num_vgprs - 1) /
-                                          (info->info.wave_size == 32 ? 8 : 4)) |
+                                          (info->wave_size == 32 ? 8 : 4)) |
                            S_00B848_DX10_CLAMP(1) |
                            S_00B848_FLOAT_MODE(config_out->float_mode);
 
@@ -700,11 +709,11 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
                        config_out->rsrc2 |= S_00B22C_OC_LDS_EN(1);
                } else if (info->tes.as_es) {
                        assert(pdevice->rad_info.chip_class <= GFX8);
-                       vgpr_comp_cnt = info->info.uses_prim_id ? 3 : 2;
+                       vgpr_comp_cnt = info->uses_prim_id ? 3 : 2;
 
                        config_out->rsrc2 |= S_00B12C_OC_LDS_EN(1);
                } else {
-                       bool enable_prim_id = info->tes.export_prim_id || info->info.uses_prim_id;
+                       bool enable_prim_id = info->tes.export_prim_id || info->uses_prim_id;
                        vgpr_comp_cnt = enable_prim_id ? 3 : 2;
 
                        config_out->rsrc1 |= S_00B128_MEM_ORDERED(pdevice->rad_info.chip_class >= GFX10);
@@ -718,9 +727,9 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
                         * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
                         */
                        if (pdevice->rad_info.chip_class >= GFX10) {
-                               vgpr_comp_cnt = info->info.vs.needs_instance_id ? 3 : 1;
+                               vgpr_comp_cnt = info->vs.needs_instance_id ? 3 : 1;
                        } else {
-                               vgpr_comp_cnt = info->info.vs.needs_instance_id ? 2 : 1;
+                               vgpr_comp_cnt = info->vs.needs_instance_id ? 2 : 1;
                        }
                } else {
                        config_out->rsrc2 |= S_00B12C_OC_LDS_EN(1);
@@ -737,21 +746,21 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
                         * VGPR0-3: (VertexID, RelAutoindex, InstanceID / StepRate0, InstanceID).
                         * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
                         */
-                       vgpr_comp_cnt = info->info.vs.needs_instance_id ? 2 : 1;
+                       vgpr_comp_cnt = info->vs.needs_instance_id ? 2 : 1;
                } else if (info->vs.as_es) {
                        assert(pdevice->rad_info.chip_class <= GFX8);
                        /* VGPR0-3: (VertexID, InstanceID / StepRate0, ...) */
-                       vgpr_comp_cnt = info->info.vs.needs_instance_id ? 1 : 0;
+                       vgpr_comp_cnt = info->vs.needs_instance_id ? 1 : 0;
                } else {
                        /* VGPR0-3: (VertexID, InstanceID / StepRate0, PrimID, InstanceID)
                         * If PrimID is disabled. InstanceID / StepRate1 is loaded instead.
                         * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
                         */
-                       if (info->info.vs.needs_instance_id && pdevice->rad_info.chip_class >= GFX10) {
+                       if (info->vs.needs_instance_id && pdevice->rad_info.chip_class >= GFX10) {
                                vgpr_comp_cnt = 3;
                        } else if (info->vs.export_prim_id) {
                                vgpr_comp_cnt = 2;
-                       } else if (info->info.vs.needs_instance_id) {
+                       } else if (info->vs.needs_instance_id) {
                                vgpr_comp_cnt = 1;
                        } else {
                                vgpr_comp_cnt = 0;
@@ -771,12 +780,12 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
                config_out->rsrc1 |= S_00B848_MEM_ORDERED(pdevice->rad_info.chip_class >= GFX10) |
                                     S_00B848_WGP_MODE(pdevice->rad_info.chip_class >= GFX10);
                config_out->rsrc2 |=
-                       S_00B84C_TGID_X_EN(info->info.cs.uses_block_id[0]) |
-                       S_00B84C_TGID_Y_EN(info->info.cs.uses_block_id[1]) |
-                       S_00B84C_TGID_Z_EN(info->info.cs.uses_block_id[2]) |
-                       S_00B84C_TIDIG_COMP_CNT(info->info.cs.uses_thread_id[2] ? 2 :
-                                               info->info.cs.uses_thread_id[1] ? 1 : 0) |
-                       S_00B84C_TG_SIZE_EN(info->info.cs.uses_local_invocation_idx) |
+                       S_00B84C_TGID_X_EN(info->cs.uses_block_id[0]) |
+                       S_00B84C_TGID_Y_EN(info->cs.uses_block_id[1]) |
+                       S_00B84C_TGID_Z_EN(info->cs.uses_block_id[2]) |
+                       S_00B84C_TIDIG_COMP_CNT(info->cs.uses_thread_id[2] ? 2 :
+                                               info->cs.uses_thread_id[1] ? 1 : 0) |
+                       S_00B84C_TG_SIZE_EN(info->cs.uses_local_invocation_idx) |
                        S_00B84C_LDS_SIZE(config_in->lds_size);
                break;
        default:
@@ -793,18 +802,18 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
 
                /* VGPR5-8: (VertexID, UserVGPR0, UserVGPR1, UserVGPR2 / InstanceID) */
                if (es_stage == MESA_SHADER_VERTEX) {
-                       es_vgpr_comp_cnt = info->info.vs.needs_instance_id ? 3 : 0;
+                       es_vgpr_comp_cnt = info->vs.needs_instance_id ? 3 : 0;
                } else if (es_stage == MESA_SHADER_TESS_EVAL) {
-                       bool enable_prim_id = info->tes.export_prim_id || info->info.uses_prim_id;
+                       bool enable_prim_id = info->tes.export_prim_id || info->uses_prim_id;
                        es_vgpr_comp_cnt = enable_prim_id ? 3 : 2;
                } else
                        unreachable("Unexpected ES shader stage");
 
                bool tes_triangles = stage == MESA_SHADER_TESS_EVAL &&
                        info->tes.primitive_mode >= 4; /* GL_TRIANGLES */
-               if (info->info.uses_invocation_id || stage == MESA_SHADER_VERTEX) {
+               if (info->uses_invocation_id || stage == MESA_SHADER_VERTEX) {
                        gs_vgpr_comp_cnt = 3; /* VGPR3 contains InvocationID. */
-               } else if (info->info.uses_prim_id) {
+               } else if (info->uses_prim_id) {
                        gs_vgpr_comp_cnt = 2; /* VGPR2 contains PrimitiveID. */
                } else if (info->gs.vertices_in >= 3 || tes_triangles) {
                        gs_vgpr_comp_cnt = 1; /* VGPR1 contains offsets 2, 3 */
@@ -824,13 +833,13 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
 
                if (es_type == MESA_SHADER_VERTEX) {
                        /* VGPR0-3: (VertexID, InstanceID / StepRate0, ...) */
-                       if (info->info.vs.needs_instance_id) {
+                       if (info->vs.needs_instance_id) {
                                es_vgpr_comp_cnt = pdevice->rad_info.chip_class >= GFX10 ? 3 : 1;
                        } else {
                                es_vgpr_comp_cnt = 0;
                        }
                } else if (es_type == MESA_SHADER_TESS_EVAL) {
-                       es_vgpr_comp_cnt = info->info.uses_prim_id ? 3 : 2;
+                       es_vgpr_comp_cnt = info->uses_prim_id ? 3 : 2;
                } else {
                        unreachable("invalid shader ES type");
                }
@@ -838,9 +847,9 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
                /* If offsets 4, 5 are used, GS_VGPR_COMP_CNT is ignored and
                 * VGPR[0:4] are always loaded.
                 */
-               if (info->info.uses_invocation_id) {
+               if (info->uses_invocation_id) {
                        gs_vgpr_comp_cnt = 3; /* VGPR3 contains InvocationID. */
-               } else if (info->info.uses_prim_id) {
+               } else if (info->uses_prim_id) {
                        gs_vgpr_comp_cnt = 2; /* VGPR2 contains PrimitiveID. */
                } else if (info->gs.vertices_in >= 3) {
                        gs_vgpr_comp_cnt = 1; /* VGPR1 contains offsets 2, 3 */
@@ -859,38 +868,6 @@ static void radv_postprocess_config(const struct radv_physical_device *pdevice,
        }
 }
 
-static void radv_init_llvm_target()
-{
-       LLVMInitializeAMDGPUTargetInfo();
-       LLVMInitializeAMDGPUTarget();
-       LLVMInitializeAMDGPUTargetMC();
-       LLVMInitializeAMDGPUAsmPrinter();
-
-       /* For inline assembly. */
-       LLVMInitializeAMDGPUAsmParser();
-
-       /* Workaround for bug in llvm 4.0 that causes image intrinsics
-        * to disappear.
-        * https://reviews.llvm.org/D26348
-        *
-        * Workaround for bug in llvm that causes the GPU to hang in presence
-        * of nested loops because there is an exec mask issue. The proper
-        * solution is to fix LLVM but this might require a bunch of work.
-        * https://bugs.llvm.org/show_bug.cgi?id=37744
-        *
-        * "mesa" is the prefix for error messages.
-        */
-       const char *argv[2] = { "mesa", "-simplifycfg-sink-common=false" };
-       LLVMParseCommandLineOptions(2, argv, NULL);
-}
-
-static once_flag radv_init_llvm_target_once_flag = ONCE_FLAG_INIT;
-
-static void radv_init_llvm_once(void)
-{
-       call_once(&radv_init_llvm_target_once_flag, radv_init_llvm_target);
-}
-
 struct radv_shader_variant *
 radv_shader_variant_create(struct radv_device *device,
                           const struct radv_shader_binary *binary,
@@ -917,14 +894,14 @@ radv_shader_variant_create(struct radv_device *device,
                        esgs_ring_size = 32 * 1024;
                }
 
-               if (binary->variant_info.is_ngg) {
+               if (binary->info.is_ngg) {
                        /* GS stores Primitive IDs into LDS at the address
                         * corresponding to the ES thread of the provoking
                         * vertex. All ES threads load and export PrimitiveID
                         * for their thread.
                         */
                        if (binary->stage == MESA_SHADER_VERTEX &&
-                           binary->variant_info.vs.export_prim_id) {
+                           binary->info.vs.export_prim_id) {
                                /* TODO: Do not harcode this value */
                                esgs_ring_size = 256 /* max_out_verts */ * 4;
                        }
@@ -941,14 +918,14 @@ radv_shader_variant_create(struct radv_device *device,
 
                        /* Make sure to have LDS space for NGG scratch. */
                        /* TODO: Compute this correctly somehow? */
-                       if (binary->variant_info.is_ngg)
+                       if (binary->info.is_ngg)
                                sym->size -= 32;
                }
 
                struct ac_rtld_open_info open_info = {
                        .info = &device->physical_device->rad_info,
                        .shader_type = binary->stage,
-                       .wave_size = binary->variant_info.info.wave_size,
+                       .wave_size = binary->info.wave_size,
                        .num_parts = 1,
                        .elf_ptrs = &elf_data,
                        .elf_sizes = &elf_size,
@@ -973,14 +950,16 @@ radv_shader_variant_create(struct radv_device *device,
                }
 
                variant->code_size = rtld_binary.rx_size;
+               variant->exec_size = rtld_binary.exec_size;
        } else {
                assert(binary->type == RADV_BINARY_TYPE_LEGACY);
                config = ((struct radv_shader_binary_legacy *)binary)->config;
-               variant->code_size  = radv_get_shader_binary_size(((struct radv_shader_binary_legacy *)binary)->code_size);
+               variant->code_size = radv_get_shader_binary_size(((struct radv_shader_binary_legacy *)binary)->code_size);
+               variant->exec_size = variant->code_size;
        }
 
-       variant->info = binary->variant_info;
-       radv_postprocess_config(device->physical_device, &config, &binary->variant_info,
+       variant->info = binary->info;
+       radv_postprocess_config(device->physical_device, &config, &binary->info,
                                binary->stage, &variant->config);
        
        void *dest_ptr = radv_alloc_shader_memory(device, variant);
@@ -1060,6 +1039,7 @@ shader_variant_compile(struct radv_device *device,
                       struct nir_shader * const *shaders,
                       int shader_count,
                       gl_shader_stage stage,
+                      struct radv_shader_info *info,
                       struct radv_nir_compiler_options *options,
                       bool gs_copy_shader,
                       bool keep_shader_info,
@@ -1069,12 +1049,8 @@ shader_variant_compile(struct radv_device *device,
        enum ac_target_machine_options tm_options = 0;
        struct ac_llvm_compiler ac_llvm;
        struct radv_shader_binary *binary = NULL;
-       struct radv_shader_variant_info variant_info = {0};
        bool thread_compiler;
 
-       if (shaders[0]->info.stage == MESA_SHADER_FRAGMENT)
-               lower_fs_io(shaders[0]);
-
        options->family = chip_family;
        options->chip_class = device->physical_device->rad_info.chip_class;
        options->dump_shader = radv_can_dump_shader(device, module, gs_copy_shader);
@@ -1106,7 +1082,7 @@ shader_variant_compile(struct radv_device *device,
                tm_options |= AC_TM_NO_LOAD_STORE_OPT;
 
        thread_compiler = !(device->instance->debug_flags & RADV_DEBUG_NOTHREADLLVM);
-       radv_init_llvm_once();
+       ac_init_llvm_once();
        radv_init_llvm_compiler(&ac_llvm,
                                thread_compiler,
                                chip_family, tm_options,
@@ -1114,12 +1090,12 @@ shader_variant_compile(struct radv_device *device,
        if (gs_copy_shader) {
                assert(shader_count == 1);
                radv_compile_gs_copy_shader(&ac_llvm, *shaders, &binary,
-                                           &variant_info, options);
+                                           info, options);
        } else {
-               radv_compile_nir_shader(&ac_llvm, &binary, &variant_info,
+               radv_compile_nir_shader(&ac_llvm, &binary, info,
                                        shaders, shader_count, options);
        }
-       binary->variant_info = variant_info;
+       binary->info = *info;
 
        radv_destroy_llvm_compiler(&ac_llvm, thread_compiler);
 
@@ -1158,6 +1134,7 @@ radv_shader_variant_compile(struct radv_device *device,
                           int shader_count,
                           struct radv_pipeline_layout *layout,
                           const struct radv_shader_variant_key *key,
+                          struct radv_shader_info *info,
                           bool keep_shader_info,
                           struct radv_shader_binary **binary_out)
 {
@@ -1171,13 +1148,14 @@ radv_shader_variant_compile(struct radv_device *device,
        options.supports_spill = true;
        options.robust_buffer_access = device->robust_buffer_access;
 
-       return shader_variant_compile(device, module, shaders, shader_count, shaders[shader_count - 1]->info.stage,
+       return shader_variant_compile(device, module, shaders, shader_count, shaders[shader_count - 1]->info.stage, info,
                                     &options, false, keep_shader_info, binary_out);
 }
 
 struct radv_shader_variant *
 radv_create_gs_copy_shader(struct radv_device *device,
                           struct nir_shader *shader,
+                          struct radv_shader_info *info,
                           struct radv_shader_binary **binary_out,
                           bool keep_shader_info,
                           bool multiview)
@@ -1187,7 +1165,7 @@ radv_create_gs_copy_shader(struct radv_device *device,
        options.key.has_multiview_view_index = multiview;
 
        return shader_variant_compile(device, NULL, &shader, 1, MESA_SHADER_VERTEX,
-                                    &options, true, keep_shader_info, binary_out);
+                                     info, &options, true, keep_shader_info, binary_out);
 }
 
 void
@@ -1208,7 +1186,7 @@ radv_shader_variant_destroy(struct radv_device *device,
 }
 
 const char *
-radv_get_shader_name(struct radv_shader_variant_info *info,
+radv_get_shader_name(struct radv_shader_info *info,
                     gl_shader_stage stage)
 {
        switch (stage) {
@@ -1268,16 +1246,16 @@ radv_get_max_waves(struct radv_device *device,
 {
        enum chip_class chip_class = device->physical_device->rad_info.chip_class;
        unsigned lds_increment = chip_class >= GFX7 ? 512 : 256;
-       uint8_t wave_size = variant->info.info.wave_size;
+       uint8_t wave_size = variant->info.wave_size;
        struct ac_shader_config *conf = &variant->config;
        unsigned max_simd_waves;
        unsigned lds_per_wave = 0;
 
-       max_simd_waves = ac_get_max_simd_waves(device->physical_device->rad_info.family);
+       max_simd_waves = ac_get_max_wave64_per_simd(device->physical_device->rad_info.family);
 
        if (stage == MESA_SHADER_FRAGMENT) {
                lds_per_wave = conf->lds_size * lds_increment +
-                              align(variant->info.info.ps.num_interp * 48,
+                              align(variant->info.ps.num_interp * 48,
                                     lds_increment);
        } else if (stage == MESA_SHADER_COMPUTE) {
                unsigned max_workgroup_size =
@@ -1289,7 +1267,8 @@ radv_get_max_waves(struct radv_device *device,
        if (conf->num_sgprs)
                max_simd_waves =
                        MIN2(max_simd_waves,
-                            ac_get_num_physical_sgprs(chip_class) / conf->num_sgprs);
+                            ac_get_num_physical_sgprs(&device->physical_device->rad_info) /
+                            conf->num_sgprs);
 
        if (conf->num_vgprs)
                max_simd_waves =
@@ -1334,7 +1313,7 @@ generate_shader_stats(struct radv_device *device,
                                   "********************\n\n\n",
                                   conf->num_sgprs, conf->num_vgprs,
                                   conf->spilled_sgprs, conf->spilled_vgprs,
-                                  variant->info.private_mem_vgprs, variant->code_size,
+                                  variant->info.private_mem_vgprs, variant->exec_size,
                                   conf->lds_size, conf->scratch_bytes_per_wave,
                                   max_simd_waves);
 }
@@ -1386,7 +1365,7 @@ radv_GetShaderInfoAMD(VkDevice _device,
                        VkShaderStatisticsInfoAMD statistics = {};
                        statistics.shaderStageMask = shaderStage;
                        statistics.numPhysicalVgprs = RADV_NUM_PHYSICAL_VGPRS;
-                       statistics.numPhysicalSgprs = ac_get_num_physical_sgprs(device->physical_device->rad_info.chip_class);
+                       statistics.numPhysicalSgprs = ac_get_num_physical_sgprs(&device->physical_device->rad_info);
                        statistics.numAvailableSgprs = statistics.numPhysicalSgprs;
 
                        if (stage == MESA_SHADER_COMPUTE) {