radv: Add code to compile merged shaders.
[mesa.git] / src / amd / vulkan / radv_shader.c
index c0fbdd3d491e7d407f450f921c4f6f207add1e3c..ed76e1d024213a7b9f24cfcc9d4a5150eadd24b2 100644 (file)
@@ -39,7 +39,6 @@
 
 #include "sid.h"
 #include "gfx9d.h"
-#include "r600d_common.h"
 #include "ac_binary.h"
 #include "ac_llvm_util.h"
 #include "ac_nir_to_llvm.h"
@@ -64,6 +63,7 @@ static const struct nir_shader_compiler_options nir_options = {
        .lower_unpack_unorm_4x8 = true,
        .lower_extract_byte = true,
        .lower_extract_word = true,
+       .lower_ffma = true,
        .max_unroll_iterations = 32
 };
 
@@ -110,7 +110,7 @@ void radv_DestroyShaderModule(
        vk_free2(&device->alloc, pAllocator, module);
 }
 
-static void
+void
 radv_optimize_nir(struct nir_shader *shader)
 {
         bool progress;
@@ -129,6 +129,7 @@ radv_optimize_nir(struct nir_shader *shader)
                 if (nir_opt_trivial_continues(shader)) {
                         progress = true;
                         NIR_PASS(progress, shader, nir_copy_prop);
+                       NIR_PASS(progress, shader, nir_opt_remove_phis);
                         NIR_PASS(progress, shader, nir_opt_dce);
                 }
                 NIR_PASS(progress, shader, nir_opt_if);
@@ -173,8 +174,8 @@ radv_shader_compile_to_nir(struct radv_device *device,
                uint32_t *spirv = (uint32_t *) module->data;
                assert(module->size % 4 == 0);
 
-               if (device->debug_flags & RADV_DEBUG_DUMP_SPIRV)
-                       radv_print_spirv(module, stderr);
+               if (device->instance->debug_flags & RADV_DEBUG_DUMP_SPIRV)
+                       radv_print_spirv(spirv, module->size, stderr);
 
                uint32_t num_spec_entries = 0;
                struct nir_spirv_specialization *spec_entries = NULL;
@@ -262,7 +263,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
        nir_remove_dead_variables(nir, nir_var_local);
        radv_optimize_nir(nir);
 
-       if (device->debug_flags & RADV_DEBUG_DUMP_SHADERS)
+       if (device->instance->debug_flags & RADV_DEBUG_DUMP_SHADERS)
                nir_print_shader(nir, stderr);
 
        return nir;
@@ -374,42 +375,50 @@ radv_fill_shader_variant(struct radv_device *device,
        memcpy(ptr, binary->code, binary->code_size);
 }
 
-struct radv_shader_variant *
-radv_shader_variant_create(struct radv_device *device,
-                          struct nir_shader *shader,
-                          struct radv_pipeline_layout *layout,
-                          const struct ac_shader_variant_key *key,
-                          void **code_out,
-                          unsigned *code_size_out)
+static struct radv_shader_variant *
+shader_variant_create(struct radv_device *device,
+                     struct radv_shader_module *module,
+                     struct nir_shader * const *shaders,
+                     int shader_count,
+                     gl_shader_stage stage,
+                     struct ac_nir_compiler_options *options,
+                     bool gs_copy_shader,
+                     void **code_out,
+                     unsigned *code_size_out)
 {
-       struct radv_shader_variant *variant = calloc(1, sizeof(struct radv_shader_variant));
        enum radeon_family chip_family = device->physical_device->rad_info.family;
+       bool dump_shaders = device->instance->debug_flags & RADV_DEBUG_DUMP_SHADERS;
+       enum ac_target_machine_options tm_options = 0;
+       struct radv_shader_variant *variant;
+       struct ac_shader_binary binary;
        LLVMTargetMachineRef tm;
+
+       variant = calloc(1, sizeof(struct radv_shader_variant));
        if (!variant)
                return NULL;
 
-       struct ac_nir_compiler_options options = {0};
-       options.layout = layout;
-       if (key)
-               options.key = *key;
+       options->family = chip_family;
+       options->chip_class = device->physical_device->rad_info.chip_class;
 
-       struct ac_shader_binary binary;
-       enum ac_target_machine_options tm_options = 0;
-       options.unsafe_math = !!(device->debug_flags & RADV_DEBUG_UNSAFE_MATH);
-       options.family = chip_family;
-       options.chip_class = device->physical_device->rad_info.chip_class;
-       options.supports_spill = device->llvm_supports_spill;
-       if (options.supports_spill)
+       if (options->supports_spill)
                tm_options |= AC_TM_SUPPORTS_SPILL;
        if (device->instance->perftest_flags & RADV_PERFTEST_SISCHED)
                tm_options |= AC_TM_SISCHED;
        tm = ac_create_target_machine(chip_family, tm_options);
-       ac_compile_nir_shader(tm, &binary, &variant->config,
-                             &variant->info, shader, &options,
-                             device->debug_flags & RADV_DEBUG_DUMP_SHADERS);
+
+       if (gs_copy_shader) {
+               assert(shader_count == 1);
+               ac_create_gs_copy_shader(tm, *shaders, &binary, &variant->config,
+                                        &variant->info, options, dump_shaders);
+       } else {
+               ac_compile_nir_shader(tm, &binary, &variant->config,
+                                     &variant->info, shaders, shader_count, options,
+                                     dump_shaders);
+       }
+
        LLVMDisposeTargetMachine(tm);
 
-       radv_fill_shader_variant(device, variant, &binary, shader->stage);
+       radv_fill_shader_variant(device, variant, &binary, stage);
 
        if (code_out) {
                *code_out = binary.code;
@@ -420,52 +429,58 @@ radv_shader_variant_create(struct radv_device *device,
        free(binary.rodata);
        free(binary.global_symbol_offsets);
        free(binary.relocs);
-       free(binary.disasm_string);
        variant->ref_count = 1;
+
+       if (device->trace_bo) {
+               variant->disasm_string = binary.disasm_string;
+               if (!gs_copy_shader && !module->nir) {
+                       variant->nir = *shaders;
+                       variant->spirv = (uint32_t *)module->data;
+                       variant->spirv_size = module->size;
+               }
+       } else {
+               free(binary.disasm_string);
+       }
+
        return variant;
 }
 
 struct radv_shader_variant *
-radv_create_gs_copy_shader(struct radv_device *device, struct nir_shader *nir,
-                          void **code_out, unsigned *code_size_out,
-                          bool multiview)
+radv_shader_variant_create(struct radv_device *device, /* Builds compiler options from layout/key/device and forwards to shader_variant_create(). */
+                          struct radv_shader_module *module,
+                          struct nir_shader *const *shaders,
+                          int shader_count, /* number of entries in shaders; must be >= 1 because shaders[shader_count - 1] is dereferenced below */
+                          struct radv_pipeline_layout *layout,
+                          const struct ac_shader_variant_key *key, /* optional: copied into options.key only when non-NULL */
+                          void **code_out,
+                          unsigned *code_size_out)
 {
-       struct radv_shader_variant *variant = calloc(1, sizeof(struct radv_shader_variant));
-       enum radeon_family chip_family = device->physical_device->rad_info.family;
-       LLVMTargetMachineRef tm;
-       if (!variant)
-               return NULL;
+       struct ac_nir_compiler_options options = {0}; /* zero-initialised: any field not assigned below stays 0 */
+
+       options.layout = layout;
+       if (key)
+               options.key = *key;
+
+       options.unsafe_math = !!(device->instance->debug_flags & RADV_DEBUG_UNSAFE_MATH); /* !! collapses the flag bit to 0 or 1 */
+       options.supports_spill = device->llvm_supports_spill;
+
+       return shader_variant_create(device, module, shaders, shader_count, shaders[shader_count - 1]->stage, /* stage argument is taken from the last shader in the array */
+                                    &options, false, code_out, code_size_out); /* false: gs_copy_shader, i.e. take the ac_compile_nir_shader() path */
+}
 
+struct radv_shader_variant * /* Thin wrapper: builds a GS copy shader variant through shader_variant_create(). */
+radv_create_gs_copy_shader(struct radv_device *device,
+                          struct nir_shader *shader,
+                          void **code_out,
+                          unsigned *code_size_out,
+                          bool multiview) /* stored in options.key.has_multiview_view_index */
+{
         struct ac_nir_compiler_options options = {0};
-       struct ac_shader_binary binary;
-       enum ac_target_machine_options tm_options = 0;
-       options.family = chip_family;
-       options.chip_class = device->physical_device->rad_info.chip_class;
-       options.key.has_multiview_view_index = multiview;
-       if (options.supports_spill)
-               tm_options |= AC_TM_SUPPORTS_SPILL;
-       if (device->instance->perftest_flags & RADV_PERFTEST_SISCHED)
-               tm_options |= AC_TM_SISCHED;
-       tm = ac_create_target_machine(chip_family, tm_options);
-       ac_create_gs_copy_shader(tm, nir, &binary, &variant->config,
-                                &variant->info, &options,
-                                device->debug_flags & RADV_DEBUG_DUMP_SHADERS);
-       LLVMDisposeTargetMachine(tm);
 
-       radv_fill_shader_variant(device, variant, &binary, MESA_SHADER_VERTEX);
+       options.key.has_multiview_view_index = multiview; /* the only option set here; supports_spill etc. remain 0 from the initializer */
 
-       if (code_out) {
-               *code_out = binary.code;
-               *code_size_out = binary.code_size;
-       } else
-               free(binary.code);
-       free(binary.config);
-       free(binary.rodata);
-       free(binary.global_symbol_offsets);
-       free(binary.relocs);
-       free(binary.disasm_string);
-       variant->ref_count = 1;
-       return variant; 
+       return shader_variant_create(device, NULL, &shader, 1, MESA_SHADER_VERTEX, /* no module, a single shader, stage fixed to MESA_SHADER_VERTEX */
+                                    &options, true, code_out, code_size_out); /* true: gs_copy_shader, i.e. take the ac_create_gs_copy_shader() path */
 }
 
 void
@@ -479,6 +494,8 @@ radv_shader_variant_destroy(struct radv_device *device,
        list_del(&variant->slab_list);
        mtx_unlock(&device->shader_slab_mutex);
 
+       ralloc_free(variant->nir);
+       free(variant->disasm_string);
        free(variant);
 }
 
@@ -525,3 +542,73 @@ radv_get_shader_name(struct radv_shader_variant *var, gl_shader_stage stage)
        };
 }
 
+void /* Prints the shader name, its register/LDS/scratch usage and the computed max_simd_waves for one variant. */
+radv_shader_dump_stats(struct radv_device *device,
+                      struct radv_shader_variant *variant,
+                      gl_shader_stage stage,
+                      FILE *file) /* stream that receives the report */
+{
+       unsigned lds_increment = device->physical_device->rad_info.chip_class >= CIK ? 512 : 256; /* multiplier for conf->lds_size and alignment granule below: 512 on CIK and newer, else 256 */
+       struct ac_shader_config *conf;
+       unsigned max_simd_waves;
+       unsigned lds_per_wave = 0;
+
+       switch (device->physical_device->rad_info.family) {
+       /* These always have 8 waves: */
+       case CHIP_POLARIS10:
+       case CHIP_POLARIS11:
+       case CHIP_POLARIS12:
+               max_simd_waves = 8;
+               break;
+       default: /* every other family starts from 10 */
+               max_simd_waves = 10;
+       }
+
+       conf = &variant->config;
+
+       if (stage == MESA_SHADER_FRAGMENT) { /* lds_per_wave stays 0 for other stages, which skips the LDS limit further down */
+               lds_per_wave = conf->lds_size * lds_increment +
+                              align(variant->info.fs.num_interp * 48,
+                                    lds_increment);
+       }
+
+       if (conf->num_sgprs) { /* skipped when no SGPRs are reported, which also avoids dividing by zero */
+               if (device->physical_device->rad_info.chip_class >= VI) /* SGPR budget used: 800 on VI and newer, 512 otherwise */
+                       max_simd_waves = MIN2(max_simd_waves, 800 / conf->num_sgprs);
+               else
+                       max_simd_waves = MIN2(max_simd_waves, 512 / conf->num_sgprs);
+       }
+
+       if (conf->num_vgprs) /* same pattern for VGPRs with a budget of 256 */
+               max_simd_waves = MIN2(max_simd_waves, 256 / conf->num_vgprs);
+
+       /* LDS is 64KB per CU (4 SIMDs), divided into 16KB blocks per SIMD
+        * that PS can use.
+        */
+       if (lds_per_wave)
+               max_simd_waves = MIN2(max_simd_waves, 16384 / lds_per_wave);
+
+       fprintf(file, "\n%s:\n", radv_get_shader_name(variant, stage));
+
+       if (stage == MESA_SHADER_FRAGMENT) { /* the SPI_PS_INPUT_* values are printed for fragment shaders only */
+               fprintf(file, "*** SHADER CONFIG ***\n"
+                       "SPI_PS_INPUT_ADDR = 0x%04x\n"
+                       "SPI_PS_INPUT_ENA  = 0x%04x\n",
+                       conf->spi_ps_input_addr, conf->spi_ps_input_ena);
+       }
+
+       fprintf(file, "*** SHADER STATS ***\n"
+               "SGPRS: %d\n"
+               "VGPRS: %d\n"
+               "Spilled SGPRs: %d\n"
+               "Spilled VGPRs: %d\n"
+               "Code Size: %d bytes\n"
+               "LDS: %d blocks\n"
+               "Scratch: %d bytes per wave\n"
+               "Max Waves: %d\n"
+               "********************\n\n\n",
+               conf->num_sgprs, conf->num_vgprs,
+               conf->spilled_sgprs, conf->spilled_vgprs, variant->code_size,
+               conf->lds_size, conf->scratch_bytes_per_wave,
+               max_simd_waves);
+}