util/u_queue: track job size and limit the size of queue growth
[mesa.git] / src / gallium / drivers / radeonsi / si_state_shaders.c
index 003d116e8eb4711f0f75d18611bc77105cba7e7a..832e59828949c0da0279d2c95a2c5ffd3fba5722 100644 (file)
@@ -45,7 +45,7 @@
  * Return the IR binary in a buffer. For TGSI the first 4 bytes contain its
  * size as integer.
  */
-void *si_get_ir_binary(struct si_shader_selector *sel)
+void *si_get_ir_binary(struct si_shader_selector *sel, bool ngg, bool es)
 {
        struct blob blob;
        unsigned ir_size;
@@ -64,14 +64,29 @@ void *si_get_ir_binary(struct si_shader_selector *sel)
                ir_size = blob.size;
        }
 
-       unsigned size = 4 + ir_size + sizeof(sel->so);
+       /* These settings affect the compilation, but they are not derived
+        * from the input shader IR.
+        */
+       unsigned shader_variant_flags = 0;
+
+       if (ngg)
+               shader_variant_flags |= 1 << 0;
+       if (sel->nir)
+               shader_variant_flags |= 1 << 1;
+       if (si_get_wave_size(sel->screen, sel->type, ngg, es) == 32)
+               shader_variant_flags |= 1 << 2;
+       if (sel->force_correct_derivs_after_kill)
+               shader_variant_flags |= 1 << 3;
+
+       unsigned size = 4 + 4 + ir_size + sizeof(sel->so);
        char *result = (char*)MALLOC(size);
        if (!result)
                return NULL;
 
-       *((uint32_t*)result) = size;
-       memcpy(result + 4, ir_binary, ir_size);
-       memcpy(result + 4 + ir_size, &sel->so, sizeof(sel->so));
+       ((uint32_t*)result)[0] = size;
+       ((uint32_t*)result)[1] = shader_variant_flags;
+       memcpy(result + 8, ir_binary, ir_size);
+       memcpy(result + 8 + ir_size, &sel->so, sizeof(sel->so));
 
        if (sel->nir)
                blob_finish(&blob);
@@ -463,10 +478,34 @@ static unsigned si_get_num_vs_user_sgprs(unsigned num_always_on_user_sgprs)
        return num_always_on_user_sgprs + 1;
 }
 
+/* Return VGPR_COMP_CNT for the API vertex shader. This can be hw LS, LSHS, ES, ESGS, VS. */
+static unsigned si_get_vs_vgpr_comp_cnt(struct si_screen *sscreen,
+                                       struct si_shader *shader, bool legacy_vs_prim_id)
+{
+       assert(shader->selector->type == PIPE_SHADER_VERTEX ||
+              (shader->previous_stage_sel &&
+               shader->previous_stage_sel->type == PIPE_SHADER_VERTEX));
+
+       /* GFX6-9 LS    (VertexID, RelAutoindex,                InstanceID / StepRate0(==1), ...).
+        * GFX6-9 ES,VS (VertexID, InstanceID / StepRate0(==1), VSPrimID,                    ...)
+        * GFX10  LS    (VertexID, RelAutoindex,                UserVGPR1,                   InstanceID).
+        * GFX10  ES,VS (VertexID, UserVGPR0,                   UserVGPR1 or VSPrimID,       UserVGPR2 or InstanceID)
+        */
+       bool is_ls = shader->selector->type == PIPE_SHADER_TESS_CTRL || shader->key.as_ls;
+
+       if (sscreen->info.chip_class >= GFX10 && shader->info.uses_instanceid)
+               return 3;
+       else if ((is_ls && shader->info.uses_instanceid) || legacy_vs_prim_id)
+               return 2;
+       else if (is_ls || shader->info.uses_instanceid)
+               return 1;
+       else
+               return 0;
+}
+
 static void si_shader_ls(struct si_screen *sscreen, struct si_shader *shader)
 {
        struct si_pm4_state *pm4;
-       unsigned vgpr_comp_cnt;
        uint64_t va;
 
        assert(sscreen->info.chip_class <= GFX8);
@@ -478,18 +517,12 @@ static void si_shader_ls(struct si_screen *sscreen, struct si_shader *shader)
        va = shader->bo->gpu_address;
        si_pm4_add_bo(pm4, shader->bo, RADEON_USAGE_READ, RADEON_PRIO_SHADER_BINARY);
 
-       /* We need at least 2 components for LS.
-        * VGPR0-3: (VertexID, RelAutoindex, InstanceID / StepRate0, InstanceID).
-        * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
-        */
-       vgpr_comp_cnt = shader->info.uses_instanceid ? 2 : 1;
-
        si_pm4_set_reg(pm4, R_00B520_SPI_SHADER_PGM_LO_LS, va >> 8);
        si_pm4_set_reg(pm4, R_00B524_SPI_SHADER_PGM_HI_LS, S_00B524_MEM_BASE(va >> 40));
 
        shader->config.rsrc1 = S_00B528_VGPRS((shader->config.num_vgprs - 1) / 4) |
                           S_00B528_SGPRS((shader->config.num_sgprs - 1) / 8) |
-                          S_00B528_VGPR_COMP_CNT(vgpr_comp_cnt) |
+                          S_00B528_VGPR_COMP_CNT(si_get_vs_vgpr_comp_cnt(sscreen, shader, false)) |
                           S_00B528_DX10_CLAMP(1) |
                           S_00B528_FLOAT_MODE(shader->config.float_mode);
        shader->config.rsrc2 = S_00B52C_USER_SGPR(si_get_num_vs_user_sgprs(SI_VS_NUM_USER_SGPR)) |
@@ -500,7 +533,6 @@ static void si_shader_hs(struct si_screen *sscreen, struct si_shader *shader)
 {
        struct si_pm4_state *pm4;
        uint64_t va;
-       unsigned ls_vgpr_comp_cnt = 0;
 
        pm4 = si_get_shader_pm4_state(shader);
        if (!pm4)
@@ -518,20 +550,6 @@ static void si_shader_hs(struct si_screen *sscreen, struct si_shader *shader)
                        si_pm4_set_reg(pm4, R_00B414_SPI_SHADER_PGM_HI_LS, S_00B414_MEM_BASE(va >> 40));
                }
 
-               /* We need at least 2 components for LS.
-                * GFX9  VGPR0-3: (VertexID, RelAutoindex, InstanceID / StepRate0, InstanceID).
-                * GFX10 VGPR0-3: (VertexID, RelAutoindex, UserVGPR1, InstanceID).
-                * On gfx9, StepRate0 is set to 1 so that VGPR3 doesn't have to
-                * be loaded.
-                */
-               ls_vgpr_comp_cnt = 1;
-               if (shader->info.uses_instanceid) {
-                       if (sscreen->info.chip_class >= GFX10)
-                               ls_vgpr_comp_cnt = 3;
-                       else
-                               ls_vgpr_comp_cnt = 2;
-               }
-
                unsigned num_user_sgprs =
                        si_get_num_vs_user_sgprs(GFX9_TCS_NUM_USER_SGPR);
 
@@ -562,7 +580,8 @@ static void si_shader_hs(struct si_screen *sscreen, struct si_shader *shader)
                       S_00B428_MEM_ORDERED(sscreen->info.chip_class >= GFX10) |
                       S_00B428_WGP_MODE(sscreen->info.chip_class >= GFX10) |
                       S_00B428_FLOAT_MODE(shader->config.float_mode) |
-                      S_00B428_LS_VGPR_COMP_CNT(ls_vgpr_comp_cnt));
+                      S_00B428_LS_VGPR_COMP_CNT(sscreen->info.chip_class >= GFX9 ?
+                                                si_get_vs_vgpr_comp_cnt(sscreen, shader, false) : 0));
 
        if (sscreen->info.chip_class <= GFX8) {
                si_pm4_set_reg(pm4, R_00B42C_SPI_SHADER_PGM_RSRC2_HS,
@@ -615,8 +634,7 @@ static void si_shader_es(struct si_screen *sscreen, struct si_shader *shader)
        si_pm4_add_bo(pm4, shader->bo, RADEON_USAGE_READ, RADEON_PRIO_SHADER_BINARY);
 
        if (shader->selector->type == PIPE_SHADER_VERTEX) {
-               /* VGPR0-3: (VertexID, InstanceID / StepRate0, ...) */
-               vgpr_comp_cnt = shader->info.uses_instanceid ? 1 : 0;
+               vgpr_comp_cnt = si_get_vs_vgpr_comp_cnt(sscreen, shader, false);
                num_user_sgprs = si_get_num_vs_user_sgprs(SI_VS_NUM_USER_SGPR);
        } else if (shader->selector->type == PIPE_SHADER_TESS_EVAL) {
                vgpr_comp_cnt = shader->selector->info.uses_primid ? 3 : 2;
@@ -863,10 +881,9 @@ static void si_shader_gs(struct si_screen *sscreen, struct si_shader *shader)
                unsigned es_type = shader->key.part.gs.es->type;
                unsigned es_vgpr_comp_cnt, gs_vgpr_comp_cnt;
 
-               if (es_type == PIPE_SHADER_VERTEX)
-                       /* VGPR0-3: (VertexID, InstanceID / StepRate0, ...) */
-                       es_vgpr_comp_cnt = shader->info.uses_instanceid ? 1 : 0;
-               else if (es_type == PIPE_SHADER_TESS_EVAL)
+               if (es_type == PIPE_SHADER_VERTEX) {
+                       es_vgpr_comp_cnt = si_get_vs_vgpr_comp_cnt(sscreen, shader, false);
+               } else if (es_type == PIPE_SHADER_TESS_EVAL)
                        es_vgpr_comp_cnt = shader->key.part.gs.es->info.uses_primid ? 3 : 2;
                else
                        unreachable("invalid shader selector type");
@@ -987,6 +1004,11 @@ static void gfx10_emit_shader_ngg_tail(struct si_context *sctx,
                                   SI_TRACKED_PA_CL_NGG_CNTL,
                                   shader->ctx_reg.ngg.pa_cl_ngg_cntl);
 
+       radeon_opt_set_context_reg_rmw(sctx, R_02881C_PA_CL_VS_OUT_CNTL,
+                                      SI_TRACKED_PA_CL_VS_OUT_CNTL__VS,
+                                      shader->pa_cl_vs_out_cntl,
+                                      SI_TRACKED_PA_CL_VS_OUT_CNTL__VS_MASK);
+
        if (initial_cdw != sctx->gfx_cs->current.cdw)
                sctx->context_roll = true;
 }
@@ -1067,6 +1089,19 @@ unsigned si_get_input_prim(const struct si_shader_selector *gs)
        return PIPE_PRIM_TRIANGLES; /* worst case for all callers */
 }
 
+static unsigned si_get_vs_out_cntl(const struct si_shader_selector *sel, bool ngg)
+{
+       bool misc_vec_ena =
+               sel->info.writes_psize || (sel->info.writes_edgeflag && !ngg) ||
+               sel->info.writes_layer || sel->info.writes_viewport_index;
+       return S_02881C_USE_VTX_POINT_SIZE(sel->info.writes_psize) |
+              S_02881C_USE_VTX_EDGE_FLAG(sel->info.writes_edgeflag && !ngg) |
+              S_02881C_USE_VTX_RENDER_TARGET_INDX(sel->info.writes_layer) |
+              S_02881C_USE_VTX_VIEWPORT_INDX(sel->info.writes_viewport_index) |
+              S_02881C_VS_OUT_MISC_VEC_ENA(misc_vec_ena) |
+              S_02881C_VS_OUT_MISC_SIDE_BUS_ENA(misc_vec_ena);
+}
+
 /**
  * Prepare the PM4 image for \p shader, which will run as a merged ESGS shader
  * in NGG mode.
@@ -1105,8 +1140,7 @@ static void gfx10_shader_ngg(struct si_screen *sscreen, struct si_shader *shader
        si_pm4_add_bo(pm4, shader->bo, RADEON_USAGE_READ, RADEON_PRIO_SHADER_BINARY);
 
        if (es_type == PIPE_SHADER_VERTEX) {
-               /* VGPR5-8: (VertexID, UserVGPR0, UserVGPR1, UserVGPR2 / InstanceID) */
-               es_vgpr_comp_cnt = shader->info.uses_instanceid ? 3 : 0;
+               es_vgpr_comp_cnt = si_get_vs_vgpr_comp_cnt(sscreen, shader, false);
 
                if (es_info->properties[TGSI_PROPERTY_VS_BLIT_SGPRS_AMD]) {
                        num_user_sgprs = SI_SGPR_VS_BLIT_DATA +
@@ -1212,6 +1246,7 @@ static void gfx10_shader_ngg(struct si_screen *sscreen, struct si_shader *shader
         */
        shader->ctx_reg.ngg.pa_cl_ngg_cntl =
                S_028838_INDEX_BUF_EDGE_FLAG_ENA(gs_type == PIPE_SHADER_VERTEX);
+       shader->pa_cl_vs_out_cntl = si_get_vs_out_cntl(gs_sel, true);
 
        shader->ge_cntl =
                S_03096C_PRIM_GRP_SIZE(shader->ngg.max_gsprims) |
@@ -1293,6 +1328,23 @@ static void si_emit_shader_vs(struct si_context *sctx)
 
        if (initial_cdw != sctx->gfx_cs->current.cdw)
                sctx->context_roll = true;
+
+       /* Required programming for tessellation. (legacy pipeline only) */
+       if (sctx->chip_class == GFX10 &&
+           shader->selector->type == PIPE_SHADER_TESS_EVAL) {
+               radeon_opt_set_context_reg(sctx, R_028A44_VGT_GS_ONCHIP_CNTL,
+                                          SI_TRACKED_VGT_GS_ONCHIP_CNTL,
+                                          S_028A44_ES_VERTS_PER_SUBGRP(250) |
+                                          S_028A44_GS_PRIMS_PER_SUBGRP(126) |
+                                          S_028A44_GS_INST_PRIMS_IN_SUBGRP(126));
+       }
+
+       if (sctx->chip_class >= GFX10) {
+               radeon_opt_set_context_reg_rmw(sctx, R_02881C_PA_CL_VS_OUT_CNTL,
+                                              SI_TRACKED_PA_CL_VS_OUT_CNTL__VS,
+                                              shader->pa_cl_vs_out_cntl,
+                                              SI_TRACKED_PA_CL_VS_OUT_CNTL__VS_MASK);
+       }
 }
 
 /**
@@ -1355,15 +1407,7 @@ static void si_shader_vs(struct si_screen *sscreen, struct si_shader *shader,
                vgpr_comp_cnt = 0; /* only VertexID is needed for GS-COPY. */
                num_user_sgprs = SI_GSCOPY_NUM_USER_SGPR;
        } else if (shader->selector->type == PIPE_SHADER_VERTEX) {
-               if (sscreen->info.chip_class >= GFX10) {
-                       vgpr_comp_cnt = shader->info.uses_instanceid ? 3 : (enable_prim_id ? 2 : 0);
-               } else {
-                       /* VGPR0-3: (VertexID, InstanceID / StepRate0, PrimID, InstanceID)
-                        * If PrimID is disabled. InstanceID / StepRate1 is loaded instead.
-                        * StepRate0 is set to 1. so that VGPR3 doesn't have to be loaded.
-                        */
-                       vgpr_comp_cnt = enable_prim_id ? 2 : (shader->info.uses_instanceid ? 1 : 0);
-               }
+               vgpr_comp_cnt = si_get_vs_vgpr_comp_cnt(sscreen, shader, enable_prim_id);
 
                if (info->properties[TGSI_PROPERTY_VS_BLIT_SGPRS_AMD]) {
                        num_user_sgprs = SI_SGPR_VS_BLIT_DATA +
@@ -1397,6 +1441,7 @@ static void si_shader_vs(struct si_screen *sscreen, struct si_shader *shader,
                        S_02870C_POS3_EXPORT_FORMAT(shader->info.nr_pos_exports > 3 ?
                                                    V_02870C_SPI_SHADER_4COMP :
                                                    V_02870C_SPI_SHADER_NONE);
+       shader->pa_cl_vs_out_cntl = si_get_vs_out_cntl(shader->selector, false);
 
        oc_lds_en = shader->selector->type == PIPE_SHADER_TESS_EVAL ? 1 : 0;
 
@@ -1809,9 +1854,10 @@ static inline void si_shader_selector_key(struct pipe_context *ctx,
 
                if (sctx->tes_shader.cso)
                        key->as_ls = 1;
-               else if (sctx->gs_shader.cso)
+               else if (sctx->gs_shader.cso) {
                        key->as_es = 1;
-               else {
+                       key->as_ngg = stages_key.u.ngg;
+               } else {
                        key->as_ngg = stages_key.u.ngg;
                        si_shader_selector_key_hw_vs(sctx, sel, key);
 
@@ -2259,16 +2305,14 @@ current_not_ready:
                if (previous_stage_sel) {
                        struct si_shader_key shader1_key = zeroed;
 
-                       if (sel->type == PIPE_SHADER_TESS_CTRL)
+                       if (sel->type == PIPE_SHADER_TESS_CTRL) {
                                shader1_key.as_ls = 1;
-                       else if (sel->type == PIPE_SHADER_GEOMETRY)
+                       } else if (sel->type == PIPE_SHADER_GEOMETRY) {
                                shader1_key.as_es = 1;
-                       else
+                               shader1_key.as_ngg = key->as_ngg; /* for Wave32 vs Wave64 */
+                       } else {
                                assert(0);
-
-                       if (sel->type == PIPE_SHADER_GEOMETRY &&
-                           previous_stage_sel->type == PIPE_SHADER_TESS_EVAL)
-                               shader1_key.as_ngg = key->as_ngg;
+                       }
 
                        mtx_lock(&previous_stage_sel->mutex);
                        ok = si_check_missing_main_part(sscreen,
@@ -2314,7 +2358,8 @@ current_not_ready:
                /* Compile it asynchronously. */
                util_queue_add_job(&sscreen->shader_compiler_queue_low_priority,
                                   shader, &shader->ready,
-                                  si_build_shader_variant_low_priority, NULL);
+                                  si_build_shader_variant_low_priority, NULL,
+                                  0);
 
                /* Add only after the ready fence was reset, to guard against a
                 * race with si_bind_XX_shader. */
@@ -2455,14 +2500,15 @@ static void si_init_shader_selector_async(void *job, int thread_index)
 
                if (sscreen->use_ngg &&
                    (!sel->so.num_outputs || sscreen->use_ngg_streamout) &&
-                   ((sel->type == PIPE_SHADER_VERTEX &&
-                     !shader->key.as_ls && !shader->key.as_es) ||
+                   ((sel->type == PIPE_SHADER_VERTEX && !shader->key.as_ls) ||
                     sel->type == PIPE_SHADER_TESS_EVAL ||
                     sel->type == PIPE_SHADER_GEOMETRY))
                        shader->key.as_ngg = 1;
 
-               if (sel->tokens || sel->nir)
-                       ir_binary = si_get_ir_binary(sel);
+               if (sel->tokens || sel->nir) {
+                       ir_binary = si_get_ir_binary(sel, shader->key.as_ngg,
+                                                    shader->key.as_es);
+               }
 
                /* Try to load the shader from the shader cache. */
                mtx_lock(&sscreen->shader_cache_mutex);
@@ -2538,7 +2584,9 @@ static void si_init_shader_selector_async(void *job, int thread_index)
 
        /* The GS copy shader is always pre-compiled. */
        if (sel->type == PIPE_SHADER_GEOMETRY &&
-           (!sscreen->use_ngg || sel->tess_turns_off_ngg)) {
+           (!sscreen->use_ngg ||
+            !sscreen->use_ngg_streamout || /* also for PRIMITIVES_GENERATED */
+            sel->tess_turns_off_ngg)) {
                sel->gs_copy_shader = si_generate_gs_copy_shader(sscreen, compiler, sel, debug);
                if (!sel->gs_copy_shader) {
                        fprintf(stderr, "radeonsi: can't create GS copy shader\n");
@@ -2568,7 +2616,7 @@ void si_schedule_initial_compile(struct si_context *sctx, unsigned processor,
        }
 
        util_queue_add_job(&sctx->screen->shader_compiler_queue, job,
-                          ready_fence, execute, NULL);
+                          ready_fence, execute, NULL, 0);
 
        if (debug) {
                util_queue_fence_wait(ready_fence);
@@ -2694,14 +2742,6 @@ static void *si_create_shader_selector(struct pipe_context *ctx,
                !sel->info.properties[TGSI_PROPERTY_VS_WINDOW_SPACE_POSITION] &&
                !sel->so.num_outputs;
 
-       if (sel->type == PIPE_SHADER_VERTEX &&
-           sel->info.writes_edgeflag) {
-               if (sscreen->info.chip_class >= GFX10)
-                       sel->ngg_writes_edgeflag = true;
-               else
-                       sel->pos_writes_edgeflag = true;
-       }
-
        switch (sel->type) {
        case PIPE_SHADER_GEOMETRY:
                sel->gs_output_prim =
@@ -2832,16 +2872,9 @@ static void *si_create_shader_selector(struct pipe_context *ctx,
        }
 
        /* PA_CL_VS_OUT_CNTL */
-       bool misc_vec_ena =
-               sel->info.writes_psize || sel->pos_writes_edgeflag ||
-               sel->info.writes_layer || sel->info.writes_viewport_index;
-       sel->pa_cl_vs_out_cntl =
-               S_02881C_USE_VTX_POINT_SIZE(sel->info.writes_psize) |
-               S_02881C_USE_VTX_EDGE_FLAG(sel->pos_writes_edgeflag) |
-               S_02881C_USE_VTX_RENDER_TARGET_INDX(sel->info.writes_layer) |
-               S_02881C_USE_VTX_VIEWPORT_INDX(sel->info.writes_viewport_index) |
-               S_02881C_VS_OUT_MISC_VEC_ENA(misc_vec_ena) |
-               S_02881C_VS_OUT_MISC_SIDE_BUS_ENA(misc_vec_ena);
+       if (sctx->chip_class <= GFX9)
+               sel->pa_cl_vs_out_cntl = si_get_vs_out_cntl(sel, false);
+
        sel->clipdist_mask = sel->info.writes_clipvertex ?
                                     SIX_BITS : sel->info.clipdist_writemask;
        sel->culldist_mask = sel->info.culldist_writemask <<
@@ -2957,8 +2990,6 @@ static void si_update_common_shader_state(struct si_context *sctx)
        sctx->do_update_shaders = true;
 }
 
-static bool si_update_ngg(struct si_context *sctx);
-
 static void si_bind_vs_shader(struct pipe_context *ctx, void *state)
 {
        struct si_context *sctx = (struct si_context *)ctx;
@@ -2997,7 +3028,7 @@ static void si_update_tess_uses_prim_id(struct si_context *sctx)
                 sctx->ps_shader.cso->info.uses_primid);
 }
 
-static bool si_update_ngg(struct si_context *sctx)
+bool si_update_ngg(struct si_context *sctx)
 {
        if (!sctx->screen->use_ngg) {
                assert(!sctx->ngg);
@@ -3012,7 +3043,8 @@ static bool si_update_ngg(struct si_context *sctx)
        } else if (!sctx->screen->use_ngg_streamout) {
                struct si_shader_selector *last = si_get_vs(sctx)->cso;
 
-               if (last && last->so.num_outputs)
+               if ((last && last->so.num_outputs) ||
+                   sctx->streamout.prims_gen_query_enabled)
                        new_ngg = false;
        }
 
@@ -3456,7 +3488,8 @@ static bool si_update_gs_ring_buffers(struct si_context *sctx)
                        pipe_aligned_buffer_create(sctx->b.screen,
                                                   SI_RESOURCE_FLAG_UNMAPPABLE,
                                                   PIPE_USAGE_DEFAULT,
-                                                  esgs_ring_size, alignment);
+                                                  esgs_ring_size,
+                                                  sctx->screen->info.pte_fragment_size);
                if (!sctx->esgs_ring)
                        return false;
        }
@@ -3467,7 +3500,8 @@ static bool si_update_gs_ring_buffers(struct si_context *sctx)
                        pipe_aligned_buffer_create(sctx->b.screen,
                                                   SI_RESOURCE_FLAG_UNMAPPABLE,
                                                   PIPE_USAGE_DEFAULT,
-                                                  gsvs_ring_size, alignment);
+                                                  gsvs_ring_size,
+                                                  sctx->screen->info.pte_fragment_size);
                if (!sctx->gsvs_ring)
                        return false;
        }
@@ -3592,11 +3626,6 @@ static int si_update_scratch_buffer(struct si_context *sctx,
        return 1;
 }
 
-static unsigned si_get_current_scratch_buffer_size(struct si_context *sctx)
-{
-       return sctx->scratch_buffer ? sctx->scratch_buffer->b.b.width0 : 0;
-}
-
 static unsigned si_get_scratch_buffer_bytes_per_wave(struct si_shader *shader)
 {
        return shader ? shader->config.scratch_bytes_per_wave : 0;
@@ -3611,23 +3640,6 @@ static struct si_shader *si_get_tcs_current(struct si_context *sctx)
                                      sctx->fixed_func_tcs_shader.current;
 }
 
-static unsigned si_get_max_scratch_bytes_per_wave(struct si_context *sctx)
-{
-       unsigned bytes = 0;
-
-       bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->ps_shader.current));
-       bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->gs_shader.current));
-       bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->vs_shader.current));
-       bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->tes_shader.current));
-
-       if (sctx->tes_shader.cso) {
-               struct si_shader *tcs = si_get_tcs_current(sctx);
-
-               bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(tcs));
-       }
-       return bytes;
-}
-
 static bool si_update_scratch_relocs(struct si_context *sctx)
 {
        struct si_shader *tcs = si_get_tcs_current(sctx);
@@ -3689,24 +3701,49 @@ static bool si_update_scratch_relocs(struct si_context *sctx)
 
 static bool si_update_spi_tmpring_size(struct si_context *sctx)
 {
-       unsigned current_scratch_buffer_size =
-               si_get_current_scratch_buffer_size(sctx);
-       unsigned scratch_bytes_per_wave =
-               si_get_max_scratch_bytes_per_wave(sctx);
-       unsigned scratch_needed_size = scratch_bytes_per_wave *
-               sctx->scratch_waves;
+       /* SPI_TMPRING_SIZE.WAVESIZE must be constant for each scratch buffer.
+        * There are 2 cases to handle:
+        *
+        * - If the current needed size is less than the maximum seen size,
+        *   use the maximum seen size, so that WAVESIZE remains the same.
+        *
+        * - If the current needed size is greater than the maximum seen size,
+        *   the scratch buffer is reallocated, so we can increase WAVESIZE.
+        *
+        * Shaders that set SCRATCH_EN=0 don't allocate scratch space.
+        * Otherwise, the number of waves that can use scratch is
+        * SPI_TMPRING_SIZE.WAVES.
+        */
+       unsigned bytes = 0;
+
+       bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->ps_shader.current));
+       bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->gs_shader.current));
+       bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->vs_shader.current));
+
+       if (sctx->tes_shader.cso) {
+               bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->tes_shader.current));
+               bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(si_get_tcs_current(sctx)));
+       }
+
+       sctx->max_seen_scratch_bytes_per_wave =
+               MAX2(sctx->max_seen_scratch_bytes_per_wave, bytes);
+
+       unsigned scratch_needed_size =
+               sctx->max_seen_scratch_bytes_per_wave * sctx->scratch_waves;
        unsigned spi_tmpring_size;
 
        if (scratch_needed_size > 0) {
-               if (scratch_needed_size > current_scratch_buffer_size) {
+               if (!sctx->scratch_buffer ||
+                   scratch_needed_size > sctx->scratch_buffer->b.b.width0) {
                        /* Create a bigger scratch buffer */
                        si_resource_reference(&sctx->scratch_buffer, NULL);
 
                        sctx->scratch_buffer =
                                si_aligned_buffer_create(&sctx->screen->b,
-                                                          SI_RESOURCE_FLAG_UNMAPPABLE,
-                                                          PIPE_USAGE_DEFAULT,
-                                                          scratch_needed_size, 256);
+                                                        SI_RESOURCE_FLAG_UNMAPPABLE,
+                                                        PIPE_USAGE_DEFAULT,
+                                                        scratch_needed_size,
+                                                        sctx->screen->info.pte_fragment_size);
                        if (!sctx->scratch_buffer)
                                return false;
 
@@ -3724,7 +3761,7 @@ static bool si_update_spi_tmpring_size(struct si_context *sctx)
                "scratch size should already be aligned correctly.");
 
        spi_tmpring_size = S_0286E8_WAVES(sctx->scratch_waves) |
-                          S_0286E8_WAVESIZE(scratch_bytes_per_wave >> 10);
+                          S_0286E8_WAVESIZE(sctx->max_seen_scratch_bytes_per_wave >> 10);
        if (spi_tmpring_size != sctx->spi_tmpring_size) {
                sctx->spi_tmpring_size = spi_tmpring_size;
                si_mark_atom_dirty(sctx, &sctx->atoms.s.scratch_state);
@@ -3992,7 +4029,7 @@ bool si_update_shaders(struct si_context *sctx)
                        si_mark_atom_dirty(sctx, &sctx->atoms.s.spi_map);
                }
 
-               if (sctx->screen->rbplus_allowed &&
+               if (sctx->screen->info.rbplus_allowed &&
                    si_pm4_state_changed(sctx, ps) &&
                    (!old_ps ||
                     old_spi_shader_col_format !=