struct radv_shader_args *args)
{
isel_context ctx = setup_isel_context(program, shader_count, shaders, config, args, false);
+ if_context ic_merged_wave_info;
for (unsigned i = 0; i < shader_count; i++) {
nir_shader *nir = shaders[i];
(nir->info.stage == MESA_SHADER_TESS_EVAL &&
ctx.stage == tess_eval_geometry_gs));
- if_context ic;
- if (shader_count >= 2 && !empty_shader) {
+ bool check_merged_wave_info = ctx.tcs_in_out_eq ? i == 0 : (shader_count >= 2 && !empty_shader);
+ bool endif_merged_wave_info = ctx.tcs_in_out_eq ? i == 1 : check_merged_wave_info;
+ if (check_merged_wave_info) {
Builder bld(ctx.program, ctx.block);
Temp count = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), get_arg(&ctx, args->merged_wave_info), Operand((8u << 16) | (i * 8u)));
Temp thread_id = emit_mbcnt(&ctx, bld.def(v1));
Temp cond = bld.vopc(aco_opcode::v_cmp_gt_u32, bld.hint_vcc(bld.def(bld.lm)), count, thread_id);
- begin_divergent_if_then(&ctx, &ic, cond);
+ begin_divergent_if_then(&ctx, &ic_merged_wave_info, cond);
}
if (i) {
if (ctx.stage == fragment_fs)
create_fs_exports(&ctx);
- if (shader_count >= 2 && !empty_shader) {
- begin_divergent_if_else(&ctx, &ic);
- end_divergent_if(&ctx, &ic);
+ if (endif_merged_wave_info) {
+ begin_divergent_if_else(&ctx, &ic_merged_wave_info);
+ end_divergent_if(&ctx, &ic_merged_wave_info);
}
ralloc_free(ctx.divergent_vals);
unsigned tcs_tess_lvl_in_loc;
uint32_t tcs_num_inputs;
uint32_t tcs_num_patches;
+ bool tcs_in_out_eq = false;
/* VS, FS or GS output information */
output_state outputs;
unreachable("Unsupported TCS shader stage");
}
+ /* When the number of TCS input and output vertices is the same (typically 3):
+ * - There is an equal number of LS and HS invocations
+ * - In case of merged LSHS shaders, the LS and HS halves of the shader
+ * always process the exact same vertex. We can use this knowledge to optimize them.
+ */
+ ctx->tcs_in_out_eq =
+ ctx->stage == vertex_tess_control_hs &&
+ ctx->args->options->key.tcs.input_vertices == nir->info.tess.tcs_vertices_out;
+
ctx->tcs_num_patches = get_tcs_num_patches(
ctx->args->options->key.tcs.input_vertices,
nir->info.tess.tcs_vertices_out,