radv: add shader_variant_create() helper function
[mesa.git] / src / amd / vulkan / radv_shader.c
index 9bb8f1ddf2ee642c033af6dd63f44ceffbf6f613..de7d9a275247198729acab2c4f26bbf249022e17 100644 (file)
@@ -150,8 +150,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
                           struct radv_shader_module *module,
                           const char *entrypoint_name,
                           gl_shader_stage stage,
-                          const VkSpecializationInfo *spec_info,
-                          bool dump)
+                          const VkSpecializationInfo *spec_info)
 {
        if (strcmp(entrypoint_name, "main") != 0) {
                radv_finishme("Multiple shaders per module not really supported");
@@ -263,7 +262,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
        nir_remove_dead_variables(nir, nir_var_local);
        radv_optimize_nir(nir);
 
-       if (dump)
+       if (device->debug_flags & RADV_DEBUG_DUMP_SHADERS)
                nir_print_shader(nir, stderr);
 
        return nir;
@@ -375,42 +374,47 @@ radv_fill_shader_variant(struct radv_device *device,
        memcpy(ptr, binary->code, binary->code_size);
 }
 
-struct radv_shader_variant *
-radv_shader_variant_create(struct radv_device *device,
-                          struct nir_shader *shader,
-                          struct radv_pipeline_layout *layout,
-                          const struct ac_shader_variant_key *key,
-                          void **code_out,
-                          unsigned *code_size_out,
-                          bool dump)
+static struct radv_shader_variant *
+shader_variant_create(struct radv_device *device,
+                     struct nir_shader *shader,
+                     gl_shader_stage stage,
+                     struct ac_nir_compiler_options *options,
+                     bool gs_copy_shader,
+                     void **code_out,
+                     unsigned *code_size_out)
 {
-       struct radv_shader_variant *variant = calloc(1, sizeof(struct radv_shader_variant));
        enum radeon_family chip_family = device->physical_device->rad_info.family;
+       bool dump_shaders = device->debug_flags & RADV_DEBUG_DUMP_SHADERS;
+       enum ac_target_machine_options tm_options = 0;
+       struct radv_shader_variant *variant;
+       struct ac_shader_binary binary;
        LLVMTargetMachineRef tm;
+
+       variant = calloc(1, sizeof(struct radv_shader_variant));
        if (!variant)
                return NULL;
 
-       struct ac_nir_compiler_options options = {0};
-       options.layout = layout;
-       if (key)
-               options.key = *key;
+       options->family = chip_family;
+       options->chip_class = device->physical_device->rad_info.chip_class;
 
-       struct ac_shader_binary binary;
-       enum ac_target_machine_options tm_options = 0;
-       options.unsafe_math = !!(device->debug_flags & RADV_DEBUG_UNSAFE_MATH);
-       options.family = chip_family;
-       options.chip_class = device->physical_device->rad_info.chip_class;
-       options.supports_spill = device->llvm_supports_spill;
-       if (options.supports_spill)
+       if (options->supports_spill)
                tm_options |= AC_TM_SUPPORTS_SPILL;
        if (device->instance->perftest_flags & RADV_PERFTEST_SISCHED)
                tm_options |= AC_TM_SISCHED;
        tm = ac_create_target_machine(chip_family, tm_options);
-       ac_compile_nir_shader(tm, &binary, &variant->config,
-                             &variant->info, shader, &options, dump);
+
+       if (gs_copy_shader) {
+               ac_create_gs_copy_shader(tm, shader, &binary, &variant->config,
+                                        &variant->info, options, dump_shaders);
+       } else {
+               ac_compile_nir_shader(tm, &binary, &variant->config,
+                                     &variant->info, shader, options,
+                                     dump_shaders);
+       }
+
        LLVMDisposeTargetMachine(tm);
 
-       radv_fill_shader_variant(device, variant, &binary, shader->stage);
+       radv_fill_shader_variant(device, variant, &binary, stage);
 
        if (code_out) {
                *code_out = binary.code;
@@ -427,44 +431,39 @@ radv_shader_variant_create(struct radv_device *device,
 }
 
 struct radv_shader_variant *
-radv_create_gs_copy_shader(struct radv_device *device, struct nir_shader *nir,
-                          void **code_out, unsigned *code_size_out,
-                          bool dump_shader, bool multiview)
+radv_shader_variant_create(struct radv_device *device,
+                          struct nir_shader *shader,
+                          struct radv_pipeline_layout *layout,
+                          const struct ac_shader_variant_key *key,
+                          void **code_out,
+                          unsigned *code_size_out)
 {
-       struct radv_shader_variant *variant = calloc(1, sizeof(struct radv_shader_variant));
-       enum radeon_family chip_family = device->physical_device->rad_info.family;
-       LLVMTargetMachineRef tm;
-       if (!variant)
-               return NULL;
+       struct ac_nir_compiler_options options = {0};
+
+       options.layout = layout;
+       if (key)
+               options.key = *key;
 
+       options.unsafe_math = !!(device->debug_flags & RADV_DEBUG_UNSAFE_MATH);
+       options.supports_spill = device->llvm_supports_spill;
+
+       return shader_variant_create(device, shader, shader->stage,
+                                    &options, false, code_out, code_size_out);
+}
+
+struct radv_shader_variant *
+radv_create_gs_copy_shader(struct radv_device *device,
+                          struct nir_shader *shader,
+                          void **code_out,
+                          unsigned *code_size_out,
+                          bool multiview)
+{
        struct ac_nir_compiler_options options = {0};
-       struct ac_shader_binary binary;
-       enum ac_target_machine_options tm_options = 0;
-       options.family = chip_family;
-       options.chip_class = device->physical_device->rad_info.chip_class;
-       options.key.has_multiview_view_index = multiview;
-       if (options.supports_spill)
-               tm_options |= AC_TM_SUPPORTS_SPILL;
-       if (device->instance->perftest_flags & RADV_PERFTEST_SISCHED)
-               tm_options |= AC_TM_SISCHED;
-       tm = ac_create_target_machine(chip_family, tm_options);
-       ac_create_gs_copy_shader(tm, nir, &binary, &variant->config, &variant->info, &options, dump_shader);
-       LLVMDisposeTargetMachine(tm);
 
-       radv_fill_shader_variant(device, variant, &binary, MESA_SHADER_VERTEX);
+       options.key.has_multiview_view_index = multiview;
 
-       if (code_out) {
-               *code_out = binary.code;
-               *code_size_out = binary.code_size;
-       } else
-               free(binary.code);
-       free(binary.config);
-       free(binary.rodata);
-       free(binary.global_symbol_offsets);
-       free(binary.relocs);
-       free(binary.disasm_string);
-       variant->ref_count = 1;
-       return variant; 
+       return shader_variant_create(device, shader, MESA_SHADER_VERTEX,
+                                    &options, true, code_out, code_size_out);
 }
 
 void