static inline unsigned si_get_wave_size(struct si_screen *sscreen,
enum pipe_shader_type shader_type,
- bool ngg, bool es)
+ bool ngg, bool es, bool prim_discard_cs)
{
if (shader_type == PIPE_SHADER_COMPUTE)
return sscreen->compute_wave_size;
else if (shader_type == PIPE_SHADER_FRAGMENT)
return sscreen->ps_wave_size;
- else if ((shader_type == PIPE_SHADER_VERTEX && es && !ngg) ||
+ else if ((shader_type == PIPE_SHADER_VERTEX && prim_discard_cs) || /* only Wave64 implemented */
+ (shader_type == PIPE_SHADER_VERTEX && es && !ngg) ||
(shader_type == PIPE_SHADER_TESS_EVAL && es && !ngg) ||
(shader_type == PIPE_SHADER_GEOMETRY && !ngg)) /* legacy GS only supports Wave64 */
return 64;
static inline unsigned si_get_shader_wave_size(struct si_shader *shader)
{
return si_get_wave_size(shader->selector->screen, shader->selector->type,
- shader->key.as_ngg, shader->key.as_es);
+ shader->key.as_ngg, shader->key.as_es,
+ shader->key.opt.vs_as_prim_discard_cs);
}
#define PRINT_ERR(fmt, args...) \