nir: Add lower_rotate flag and set to true in all drivers
[mesa.git] / src / amd / vulkan / radv_shader.c
index 37fdbfe33f53ef317a900e8c75ea51028066a3fc..3c3f761ca8966775814b4263c663248de1e7facf 100644 (file)
@@ -58,6 +58,8 @@ static const struct nir_shader_compiler_options nir_options = {
        .lower_device_index_to_zero = true,
        .lower_fsat = true,
        .lower_fdiv = true,
+       .lower_bitfield_insert_to_bitfield_select = true,
+       .lower_bitfield_extract = true,
        .lower_sub = true,
        .lower_pack_snorm_2x16 = true,
        .lower_pack_snorm_4x8 = true,
@@ -72,6 +74,7 @@ static const struct nir_shader_compiler_options nir_options = {
        .lower_ffma = true,
        .lower_fpow = true,
        .lower_mul_2x32_64 = true,
+       .lower_rotate = true,
        .max_unroll_iterations = 32
 };
 
@@ -245,6 +248,9 @@ radv_shader_compile_to_nir(struct radv_device *device,
                const struct spirv_to_nir_options spirv_options = {
                        .lower_ubo_ssbo_access_to_offsets = true,
                        .caps = {
+                               .amd_gcn_shader = true,
+                               .amd_shader_ballot = device->instance->perftest_flags & RADV_PERFTEST_SHADER_BALLOT,
+                               .amd_trinary_minmax = true,
                                .derivative_group = true,
                                .descriptor_array_dynamic_indexing = true,
                                .descriptor_array_non_uniform_indexing = true,
@@ -253,7 +259,6 @@ radv_shader_compile_to_nir(struct radv_device *device,
                                .draw_parameters = true,
                                .float16 = true,
                                .float64 = true,
-                               .gcn_shader = true,
                                .geometry_streams = true,
                                .image_read_without_format = true,
                                .image_write_without_format = true,
@@ -277,7 +282,6 @@ radv_shader_compile_to_nir(struct radv_device *device,
                                .subgroup_vote = true,
                                .tessellation = true,
                                .transform_feedback = true,
-                               .trinary_minmax = true,
                                .variable_pointers = true,
                        },
                        .ubo_addr_format = nir_address_format_32bit_index_offset,
@@ -468,6 +472,7 @@ radv_get_shader_binary_size(struct ac_shader_binary *binary)
 static void
 radv_fill_shader_variant(struct radv_device *device,
                         struct radv_shader_variant *variant,
+                        struct radv_nir_compiler_options *options,
                         struct ac_shader_binary *binary,
                         gl_shader_stage stage)
 {
@@ -492,21 +497,54 @@ radv_fill_shader_variant(struct radv_device *device,
 
        switch (stage) {
        case MESA_SHADER_TESS_EVAL:
-               vgpr_comp_cnt = 3;
+               if (options->key.tes.as_es) {
+                       assert(device->physical_device->rad_info.chip_class <= GFX8);
+                       vgpr_comp_cnt = info->uses_prim_id ? 3 : 2;
+               } else {
+                       bool enable_prim_id = options->key.tes.export_prim_id || info->uses_prim_id;
+                       vgpr_comp_cnt = enable_prim_id ? 3 : 2;
+               }
                variant->rsrc2 |= S_00B12C_OC_LDS_EN(1);
                break;
        case MESA_SHADER_TESS_CTRL:
                if (device->physical_device->rad_info.chip_class >= GFX9) {
-                       vgpr_comp_cnt = variant->info.vs.vgpr_comp_cnt;
+                       /* We need at least 2 components for LS.
+                        * VGPR0-3: (VertexID, RelAutoindex, InstanceID / StepRate0, InstanceID).
+                        * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
+                        */
+                       vgpr_comp_cnt = info->vs.needs_instance_id ? 2 : 1;
                } else {
                        variant->rsrc2 |= S_00B12C_OC_LDS_EN(1);
                }
                break;
        case MESA_SHADER_VERTEX:
-       case MESA_SHADER_GEOMETRY:
-               vgpr_comp_cnt = variant->info.vs.vgpr_comp_cnt;
+               if (variant->info.vs.as_ls) {
+                       assert(device->physical_device->rad_info.chip_class <= GFX8);
+                       /* We need at least 2 components for LS.
+                        * VGPR0-3: (VertexID, RelAutoindex, InstanceID / StepRate0, InstanceID).
+                        * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
+                        */
+                       vgpr_comp_cnt = info->vs.needs_instance_id ? 2 : 1;
+               } else if (variant->info.vs.as_es) {
+                       assert(device->physical_device->rad_info.chip_class <= GFX8);
+                       /* VGPR0-3: (VertexID, InstanceID / StepRate0, ...) */
+                       vgpr_comp_cnt = info->vs.needs_instance_id ? 1 : 0;
+               } else {
+                       /* VGPR0-3: (VertexID, InstanceID / StepRate0, PrimID, InstanceID)
+                        * If PrimID is disabled. InstanceID / StepRate1 is loaded instead.
+                        * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
+                        */
+                       if (options->key.vs.export_prim_id) {
+                               vgpr_comp_cnt = 2;
+                       } else if (info->vs.needs_instance_id) {
+                               vgpr_comp_cnt = 1;
+                       } else {
+                               vgpr_comp_cnt = 0;
+                       }
+               }
                break;
        case MESA_SHADER_FRAGMENT:
+       case MESA_SHADER_GEOMETRY:
                break;
        case MESA_SHADER_COMPUTE:
                variant->rsrc2 |=
@@ -529,9 +567,10 @@ radv_fill_shader_variant(struct radv_device *device,
                unsigned gs_vgpr_comp_cnt, es_vgpr_comp_cnt;
 
                if (es_type == MESA_SHADER_VERTEX) {
-                       es_vgpr_comp_cnt = variant->info.vs.vgpr_comp_cnt;
+                       /* VGPR0-3: (VertexID, InstanceID / StepRate0, ...) */
+                       es_vgpr_comp_cnt = info->vs.needs_instance_id ? 1 : 0;
                } else if (es_type == MESA_SHADER_TESS_EVAL) {
-                       es_vgpr_comp_cnt = 3;
+                       es_vgpr_comp_cnt = info->uses_prim_id ? 3 : 2;
                } else {
                        unreachable("invalid shader ES type");
                }
@@ -666,7 +705,7 @@ shader_variant_create(struct radv_device *device,
 
        radv_destroy_llvm_compiler(&ac_llvm, thread_compiler);
 
-       radv_fill_shader_variant(device, variant, &binary, stage);
+       radv_fill_shader_variant(device, variant, options, &binary, stage);
 
        if (code_out) {
                *code_out = binary.code;