radv: Add code to compile merged shaders.
[mesa.git] / src / amd / vulkan / radv_pipeline.c
index 7102d07a03a19670bd44f04d6ea3ecdea31ee507..4f840d156fa895853706d10f1c98ae42183d9531 100644 (file)
@@ -1637,7 +1637,7 @@ void radv_create_shaders(struct radv_pipeline *pipeline,
 
        if (nir[MESA_SHADER_FRAGMENT]) {
                pipeline->shaders[MESA_SHADER_FRAGMENT] =
-                       radv_shader_variant_create(device, modules[MESA_SHADER_FRAGMENT], nir[MESA_SHADER_FRAGMENT],
+                       radv_shader_variant_create(device, modules[MESA_SHADER_FRAGMENT], &nir[MESA_SHADER_FRAGMENT], 1,
                                                   pipeline->layout, keys ? keys + MESA_SHADER_FRAGMENT : 0,
                                                   &codes[MESA_SHADER_FRAGMENT], &code_sizes[MESA_SHADER_FRAGMENT]);
 
@@ -1652,14 +1652,35 @@ void radv_create_shaders(struct radv_pipeline *pipeline,
                pipeline->active_stages |= mesa_to_vk_shader_stage(MESA_SHADER_FRAGMENT);
        }
 
+       if (device->physical_device->rad_info.chip_class >= GFX9 &&
+           modules[MESA_SHADER_TESS_CTRL] && !pipeline->shaders[MESA_SHADER_TESS_CTRL]) {
+               struct nir_shader *combined_nir[] = {nir[MESA_SHADER_VERTEX], nir[MESA_SHADER_TESS_CTRL]};
+               struct ac_shader_variant_key key = keys[MESA_SHADER_TESS_CTRL];
+               key.tcs.vs_key = keys[MESA_SHADER_VERTEX].vs;
+               pipeline->shaders[MESA_SHADER_TESS_CTRL] = radv_shader_variant_create(device, modules[MESA_SHADER_TESS_CTRL], combined_nir, 2,
+                                                                                     pipeline->layout,
+                                                                                     &key, &codes[MESA_SHADER_TESS_CTRL],
+                                                                                     &code_sizes[MESA_SHADER_TESS_CTRL]);
+               modules[MESA_SHADER_VERTEX] = NULL;
+       }
+
+       if (device->physical_device->rad_info.chip_class >= GFX9 &&
+           modules[MESA_SHADER_GEOMETRY] && !pipeline->shaders[MESA_SHADER_GEOMETRY]) {
+               gl_shader_stage pre_stage = modules[MESA_SHADER_TESS_EVAL] ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;
+               struct nir_shader *combined_nir[] = {nir[pre_stage], nir[MESA_SHADER_GEOMETRY]};
+               pipeline->shaders[MESA_SHADER_GEOMETRY] = radv_shader_variant_create(device, modules[MESA_SHADER_GEOMETRY], combined_nir, 2,
+                                                                                    pipeline->layout,
+                                                                                    &keys[pre_stage] , &codes[MESA_SHADER_GEOMETRY],
+                                                                                    &code_sizes[MESA_SHADER_GEOMETRY]);
+               modules[pre_stage] = NULL;
+       }
+
        for (int i = 0; i < MESA_SHADER_STAGES; ++i) {
                if(modules[i] && !pipeline->shaders[i]) {
-                       pipeline->shaders[i] = radv_shader_variant_create(device, modules[i], nir[i],
+                       pipeline->shaders[i] = radv_shader_variant_create(device, modules[i], &nir[i], 1,
                                                                          pipeline->layout,
                                                                          keys ? keys + i : 0, &codes[i],
                                                                          &code_sizes[i]);
-
-               pipeline->active_stages |= mesa_to_vk_shader_stage(i);
                }
        }