radv: Remove remaining hard coded references to VS.
[mesa.git] / src / amd / vulkan / radv_pipeline.c
index 559862678e7e225ac1f729053a2d7971e5b607cc..4369c3a6b1b755a293951f380225de52fc93e354 100644 (file)
@@ -1207,6 +1207,16 @@ static void si_multiwave_lds_size_workaround(struct radv_device *device,
                *lds_size = MAX2(*lds_size, 8);
 }
 
+struct radv_shader_variant *
+radv_get_vertex_shader(struct radv_pipeline *pipeline)
+{
+       if (pipeline->shaders[MESA_SHADER_VERTEX])
+               return pipeline->shaders[MESA_SHADER_VERTEX];
+       if (pipeline->shaders[MESA_SHADER_TESS_CTRL])
+               return pipeline->shaders[MESA_SHADER_TESS_CTRL];
+       return pipeline->shaders[MESA_SHADER_GEOMETRY];
+}
+
 static void
 calculate_tess_state(struct radv_pipeline *pipeline,
                     const VkGraphicsPipelineCreateInfo *pCreateInfo)
@@ -1223,7 +1233,7 @@ calculate_tess_state(struct radv_pipeline *pipeline,
 
        /* This calculates how shader inputs and outputs among VS, TCS, and TES
         * are laid out in LDS. */
-       num_tcs_inputs = util_last_bit64(pipeline->shaders[MESA_SHADER_VERTEX]->info.vs.outputs_written);
+       num_tcs_inputs = util_last_bit64(radv_get_vertex_shader(pipeline)->info.vs.outputs_written);
 
        num_tcs_outputs = util_last_bit64(pipeline->shaders[MESA_SHADER_TESS_CTRL]->info.tcs.outputs_written); //tcs->outputs_written
        num_tcs_output_cp = pipeline->shaders[MESA_SHADER_TESS_CTRL]->info.tcs.tcs_vertices_out; //TCS VERTICES OUT
@@ -1597,10 +1607,15 @@ void radv_create_shaders(struct radv_pipeline *pipeline,
        }
 
        if (radv_create_shader_variants_from_pipeline_cache(device, cache, hash, pipeline->shaders) &&
-           (!modules[MESA_SHADER_GEOMETRY] || pipeline->gs_copy_shader))
+           (!modules[MESA_SHADER_GEOMETRY] || pipeline->gs_copy_shader)) {
+               for (unsigned i = 0; i < MESA_SHADER_STAGES; ++i) {
+                       if (pipeline->shaders[i])
+                               pipeline->active_stages |= mesa_to_vk_shader_stage(i);
+               }
                return;
+       }
 
-       if (!modules[MESA_SHADER_FRAGMENT]) {
+       if (!modules[MESA_SHADER_FRAGMENT] && !modules[MESA_SHADER_COMPUTE]) {
                nir_builder fs_b;
                nir_builder_init_simple_shader(&fs_b, NULL, MESA_SHADER_FRAGMENT, NULL);
                fs_b.shader->info.name = ralloc_strdup(fs_b.shader, "noop_fs");
@@ -1632,7 +1647,7 @@ void radv_create_shaders(struct radv_pipeline *pipeline,
 
        if (nir[MESA_SHADER_FRAGMENT]) {
                pipeline->shaders[MESA_SHADER_FRAGMENT] =
-                       radv_shader_variant_create(device, modules[MESA_SHADER_FRAGMENT], nir[MESA_SHADER_FRAGMENT],
+                       radv_shader_variant_create(device, modules[MESA_SHADER_FRAGMENT], &nir[MESA_SHADER_FRAGMENT], 1,
                                                   pipeline->layout, keys ? keys + MESA_SHADER_FRAGMENT : 0,
                                                   &codes[MESA_SHADER_FRAGMENT], &code_sizes[MESA_SHADER_FRAGMENT]);
 
@@ -1647,14 +1662,35 @@ void radv_create_shaders(struct radv_pipeline *pipeline,
                pipeline->active_stages |= mesa_to_vk_shader_stage(MESA_SHADER_FRAGMENT);
        }
 
+       if (device->physical_device->rad_info.chip_class >= GFX9 &&
+           modules[MESA_SHADER_TESS_CTRL] && !pipeline->shaders[MESA_SHADER_TESS_CTRL]) {
+               struct nir_shader *combined_nir[] = {nir[MESA_SHADER_VERTEX], nir[MESA_SHADER_TESS_CTRL]};
+               struct ac_shader_variant_key key = keys[MESA_SHADER_TESS_CTRL];
+               key.tcs.vs_key = keys[MESA_SHADER_VERTEX].vs;
+               pipeline->shaders[MESA_SHADER_TESS_CTRL] = radv_shader_variant_create(device, modules[MESA_SHADER_TESS_CTRL], combined_nir, 2,
+                                                                                     pipeline->layout,
+                                                                                     &key, &codes[MESA_SHADER_TESS_CTRL],
+                                                                                     &code_sizes[MESA_SHADER_TESS_CTRL]);
+               modules[MESA_SHADER_VERTEX] = NULL;
+       }
+
+       if (device->physical_device->rad_info.chip_class >= GFX9 &&
+           modules[MESA_SHADER_GEOMETRY] && !pipeline->shaders[MESA_SHADER_GEOMETRY]) {
+               gl_shader_stage pre_stage = modules[MESA_SHADER_TESS_EVAL] ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;
+               struct nir_shader *combined_nir[] = {nir[pre_stage], nir[MESA_SHADER_GEOMETRY]};
+               pipeline->shaders[MESA_SHADER_GEOMETRY] = radv_shader_variant_create(device, modules[MESA_SHADER_GEOMETRY], combined_nir, 2,
+                                                                                    pipeline->layout,
+                                                                                    &keys[pre_stage] , &codes[MESA_SHADER_GEOMETRY],
+                                                                                    &code_sizes[MESA_SHADER_GEOMETRY]);
+               modules[pre_stage] = NULL;
+       }
+
        for (int i = 0; i < MESA_SHADER_STAGES; ++i) {
                if(modules[i] && !pipeline->shaders[i]) {
-                       pipeline->shaders[i] = radv_shader_variant_create(device, modules[i], nir[i],
+                       pipeline->shaders[i] = radv_shader_variant_create(device, modules[i], &nir[i], 1,
                                                                          pipeline->layout,
                                                                          keys ? keys + i : 0, &codes[i],
                                                                          &code_sizes[i]);
-
-               pipeline->active_stages |= mesa_to_vk_shader_stage(i);
                }
        }
 
@@ -1691,7 +1727,7 @@ void radv_create_shaders(struct radv_pipeline *pipeline,
 
        for (int i = 0; i < MESA_SHADER_STAGES; ++i) {
                free(codes[i]);
-               if (modules[i] && !modules[i]->nir)
+               if (modules[i] && !modules[i]->nir && !pipeline->device->trace_bo)
                        ralloc_free(nir[i]);
        }
 
@@ -1996,9 +2032,9 @@ radv_pipeline_init(struct radv_pipeline *pipeline,
        struct ac_userdata_info *loc = radv_lookup_user_sgpr(pipeline, MESA_SHADER_VERTEX,
                                                             AC_UD_VS_BASE_VERTEX_START_INSTANCE);
        if (loc->sgpr_idx != -1) {
-               pipeline->graphics.vtx_base_sgpr = radv_shader_stage_to_user_data_0(MESA_SHADER_VERTEX, radv_pipeline_has_gs(pipeline), radv_pipeline_has_tess(pipeline));
+               pipeline->graphics.vtx_base_sgpr = radv_shader_stage_to_user_data_0(MESA_SHADER_VERTEX, device->physical_device->rad_info.chip_class, radv_pipeline_has_gs(pipeline), radv_pipeline_has_tess(pipeline));
                pipeline->graphics.vtx_base_sgpr += loc->sgpr_idx * 4;
-               if (pipeline->shaders[MESA_SHADER_VERTEX]->info.info.vs.needs_draw_id)
+               if (radv_get_vertex_shader(pipeline)->info.info.vs.needs_draw_id)
                        pipeline->graphics.vtx_emit_num = 3;
                else
                        pipeline->graphics.vtx_emit_num = 2;