nir: allow specifying a set of opcodes in lower_alu_to_scalar
[mesa.git] / src / amd / vulkan / radv_shader.c
index 13f1f9aa9dc9fcf4ec39f1ac462886cc7470bfe5..17d6c5bc33adeff28116e7354b0fbdfb08368129 100644 (file)
@@ -124,6 +124,10 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively,
                   bool allow_copies)
 {
         bool progress;
+        unsigned lower_flrp =
+                (shader->options->lower_flrp16 ? 16 : 0) |
+                (shader->options->lower_flrp32 ? 32 : 0) |
+                (shader->options->lower_flrp64 ? 64 : 0);
 
         do {
                 progress = false;
@@ -146,7 +150,7 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively,
                NIR_PASS(progress, shader, nir_opt_copy_prop_vars);
                NIR_PASS(progress, shader, nir_opt_dead_write_vars);
 
-                NIR_PASS_V(shader, nir_lower_alu_to_scalar);
+                NIR_PASS_V(shader, nir_lower_alu_to_scalar, NULL);
                 NIR_PASS_V(shader, nir_lower_phis_to_scalar);
 
                 NIR_PASS(progress, shader, nir_copy_prop);
@@ -162,8 +166,29 @@ radv_optimize_nir(struct nir_shader *shader, bool optimize_conservatively,
                 NIR_PASS(progress, shader, nir_opt_dead_cf);
                 NIR_PASS(progress, shader, nir_opt_cse);
                 NIR_PASS(progress, shader, nir_opt_peephole_select, 8, true, true);
-                NIR_PASS(progress, shader, nir_opt_algebraic);
                 NIR_PASS(progress, shader, nir_opt_constant_folding);
+                NIR_PASS(progress, shader, nir_opt_algebraic);
+
+                if (lower_flrp != 0) {
+                        bool lower_flrp_progress = false;
+                        NIR_PASS(lower_flrp_progress,
+                                 shader,
+                                 nir_lower_flrp,
+                                 lower_flrp,
+                                 false /* always_precise */,
+                                 shader->options->lower_ffma);
+                        if (lower_flrp_progress) {
+                                NIR_PASS(progress, shader,
+                                         nir_opt_constant_folding);
+                                progress = true;
+                        }
+
+                        /* Nothing should rematerialize any flrps, so we only
+                         * need to do this lowering once.
+                         */
+                        lower_flrp = 0;
+                }
+
                 NIR_PASS(progress, shader, nir_opt_undef);
                 NIR_PASS(progress, shader, nir_opt_conditional_discard);
                 if (shader->options->max_unroll_iterations) {
@@ -181,7 +206,8 @@ radv_shader_compile_to_nir(struct radv_device *device,
                           const char *entrypoint_name,
                           gl_shader_stage stage,
                           const VkSpecializationInfo *spec_info,
-                          const VkPipelineCreateFlags flags)
+                          const VkPipelineCreateFlags flags,
+                          const struct radv_pipeline_layout *layout)
 {
        nir_shader *nir;
        nir_function *entry_point;
@@ -225,6 +251,8 @@ radv_shader_compile_to_nir(struct radv_device *device,
                        .caps = {
                                .derivative_group = true,
                                .descriptor_array_dynamic_indexing = true,
+                               .descriptor_array_non_uniform_indexing = true,
+                               .descriptor_indexing = true,
                                .device_group = true,
                                .draw_parameters = true,
                                .float16 = true,
@@ -310,6 +338,7 @@ radv_shader_compile_to_nir(struct radv_device *device,
 
                NIR_PASS_V(nir, nir_lower_system_values);
                NIR_PASS_V(nir, nir_lower_clip_cull_distance_arrays);
+               NIR_PASS_V(nir, radv_nir_lower_ycbcr_textures, layout);
        }
 
        /* Vulkan uses the separate-shader linking model */