+ keys[MESA_SHADER_FRAGMENT].fs.is_dual_src = key->is_dual_src;
+
+ if (nir[MESA_SHADER_COMPUTE]) {
+ keys[MESA_SHADER_COMPUTE].cs.subgroup_size = key->compute_subgroup_size;
+ }
+}
+
+static uint8_t
+radv_get_wave_size(struct radv_device *device,
+ const VkPipelineShaderStageCreateInfo *pStage,
+ gl_shader_stage stage,
+ const struct radv_shader_variant_key *key)
+{
+ if (stage == MESA_SHADER_GEOMETRY && !key->vs_common_out.as_ngg)
+ return 64;
+ else if (stage == MESA_SHADER_COMPUTE) {
+ if (key->cs.subgroup_size) {
+ /* Return the required subgroup size if specified. */
+ return key->cs.subgroup_size;
+ }
+ return device->physical_device->cs_wave_size;
+ }
+ else if (stage == MESA_SHADER_FRAGMENT)
+ return device->physical_device->ps_wave_size;
+ else
+ return device->physical_device->ge_wave_size;
+}
+
+static uint8_t
+radv_get_ballot_bit_size(struct radv_device *device,
+ const VkPipelineShaderStageCreateInfo *pStage,
+ gl_shader_stage stage,
+ const struct radv_shader_variant_key *key)
+{
+ if (stage == MESA_SHADER_COMPUTE && key->cs.subgroup_size)
+ return key->cs.subgroup_size;
+ return 64;
+}
+
+static void
+radv_fill_shader_info(struct radv_pipeline *pipeline,
+ const VkPipelineShaderStageCreateInfo **pStages,
+ struct radv_shader_variant_key *keys,
+ struct radv_shader_info *infos,
+ nir_shader **nir)
+{
+ unsigned active_stages = 0;
+ unsigned filled_stages = 0;
+
+ for (int i = 0; i < MESA_SHADER_STAGES; i++) {
+ if (nir[i])
+ active_stages |= (1 << i);
+ }
+
+ if (nir[MESA_SHADER_FRAGMENT]) {
+ radv_nir_shader_info_init(&infos[MESA_SHADER_FRAGMENT]);
+ radv_nir_shader_info_pass(nir[MESA_SHADER_FRAGMENT],
+ pipeline->layout,
+ &keys[MESA_SHADER_FRAGMENT],
+ &infos[MESA_SHADER_FRAGMENT],
+ pipeline->device->physical_device->use_llvm);
+
+ /* TODO: These are no longer used as keys we should refactor this */
+ keys[MESA_SHADER_VERTEX].vs_common_out.export_prim_id =
+ infos[MESA_SHADER_FRAGMENT].ps.prim_id_input;
+ keys[MESA_SHADER_VERTEX].vs_common_out.export_layer_id =
+ infos[MESA_SHADER_FRAGMENT].ps.layer_input;
+ keys[MESA_SHADER_VERTEX].vs_common_out.export_clip_dists =
+ !!infos[MESA_SHADER_FRAGMENT].ps.num_input_clips_culls;
+ keys[MESA_SHADER_VERTEX].vs_common_out.export_viewport_index =
+ infos[MESA_SHADER_FRAGMENT].ps.viewport_index_input;
+ keys[MESA_SHADER_TESS_EVAL].vs_common_out.export_prim_id =
+ infos[MESA_SHADER_FRAGMENT].ps.prim_id_input;
+ keys[MESA_SHADER_TESS_EVAL].vs_common_out.export_layer_id =
+ infos[MESA_SHADER_FRAGMENT].ps.layer_input;
+ keys[MESA_SHADER_TESS_EVAL].vs_common_out.export_clip_dists =
+ !!infos[MESA_SHADER_FRAGMENT].ps.num_input_clips_culls;
+ keys[MESA_SHADER_TESS_EVAL].vs_common_out.export_viewport_index =
+ infos[MESA_SHADER_FRAGMENT].ps.viewport_index_input;
+
+ /* NGG passthrough mode can't be enabled for vertex shaders
+ * that export the primitive ID.
+ *
+ * TODO: I should really refactor the keys logic.
+ */
+ if (nir[MESA_SHADER_VERTEX] &&
+ keys[MESA_SHADER_VERTEX].vs_common_out.export_prim_id) {
+ keys[MESA_SHADER_VERTEX].vs_common_out.as_ngg_passthrough = false;
+ }
+
+ filled_stages |= (1 << MESA_SHADER_FRAGMENT);
+ }
+
+ if (nir[MESA_SHADER_TESS_CTRL]) {
+ infos[MESA_SHADER_TESS_CTRL].tcs.tes_inputs_read =
+ nir[MESA_SHADER_TESS_EVAL]->info.inputs_read;
+ infos[MESA_SHADER_TESS_CTRL].tcs.tes_patch_inputs_read =
+ nir[MESA_SHADER_TESS_EVAL]->info.patch_inputs_read;
+ }
+
+ if (pipeline->device->physical_device->rad_info.chip_class >= GFX9 &&
+ nir[MESA_SHADER_TESS_CTRL]) {
+ struct nir_shader *combined_nir[] = {nir[MESA_SHADER_VERTEX], nir[MESA_SHADER_TESS_CTRL]};
+ struct radv_shader_variant_key key = keys[MESA_SHADER_TESS_CTRL];
+ key.tcs.vs_key = keys[MESA_SHADER_VERTEX].vs;
+
+ radv_nir_shader_info_init(&infos[MESA_SHADER_TESS_CTRL]);
+
+ for (int i = 0; i < 2; i++) {
+ radv_nir_shader_info_pass(combined_nir[i],
+ pipeline->layout, &key,
+ &infos[MESA_SHADER_TESS_CTRL],
+ pipeline->device->physical_device->use_llvm);
+ }
+
+ keys[MESA_SHADER_TESS_EVAL].tes.num_patches =
+ infos[MESA_SHADER_TESS_CTRL].tcs.num_patches;
+ keys[MESA_SHADER_TESS_EVAL].tes.tcs_num_outputs =
+ util_last_bit64(infos[MESA_SHADER_TESS_CTRL].tcs.outputs_written);
+
+ filled_stages |= (1 << MESA_SHADER_VERTEX);
+ filled_stages |= (1 << MESA_SHADER_TESS_CTRL);
+ }
+
+ if (pipeline->device->physical_device->rad_info.chip_class >= GFX9 &&
+ nir[MESA_SHADER_GEOMETRY]) {
+ gl_shader_stage pre_stage = nir[MESA_SHADER_TESS_EVAL] ? MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;
+ struct nir_shader *combined_nir[] = {nir[pre_stage], nir[MESA_SHADER_GEOMETRY]};
+
+ radv_nir_shader_info_init(&infos[MESA_SHADER_GEOMETRY]);
+
+ for (int i = 0; i < 2; i++) {
+ radv_nir_shader_info_pass(combined_nir[i],
+ pipeline->layout,
+ &keys[pre_stage],
+ &infos[MESA_SHADER_GEOMETRY],
+ pipeline->device->physical_device->use_llvm);
+ }
+
+ filled_stages |= (1 << pre_stage);
+ filled_stages |= (1 << MESA_SHADER_GEOMETRY);
+ }
+
+ active_stages ^= filled_stages;
+ while (active_stages) {
+ int i = u_bit_scan(&active_stages);
+
+ if (i == MESA_SHADER_TESS_CTRL) {
+ keys[MESA_SHADER_TESS_CTRL].tcs.num_inputs =
+ util_last_bit64(infos[MESA_SHADER_VERTEX].vs.ls_outputs_written);
+ }
+
+ if (i == MESA_SHADER_TESS_EVAL) {
+ keys[MESA_SHADER_TESS_EVAL].tes.num_patches =
+ infos[MESA_SHADER_TESS_CTRL].tcs.num_patches;
+ keys[MESA_SHADER_TESS_EVAL].tes.tcs_num_outputs =
+ util_last_bit64(infos[MESA_SHADER_TESS_CTRL].tcs.outputs_written);
+ }
+
+ radv_nir_shader_info_init(&infos[i]);
+ radv_nir_shader_info_pass(nir[i], pipeline->layout,
+ &keys[i], &infos[i], pipeline->device->physical_device->use_llvm);
+ }
+
+ for (int i = 0; i < MESA_SHADER_STAGES; i++) {
+ if (nir[i]) {
+ infos[i].wave_size =
+ radv_get_wave_size(pipeline->device, pStages[i],
+ i, &keys[i]);
+ infos[i].ballot_bit_size =
+ radv_get_ballot_bit_size(pipeline->device,
+ pStages[i], i,
+ &keys[i]);
+ }
+ }