- unsigned current_scratch_buffer_size =
- si_get_current_scratch_buffer_size(sctx);
- unsigned scratch_bytes_per_wave =
- si_get_max_scratch_bytes_per_wave(sctx);
- unsigned scratch_needed_size = scratch_bytes_per_wave *
- sctx->scratch_waves;
+ /* SPI_TMPRING_SIZE.WAVESIZE must be constant for each scratch buffer.
+ * There are 2 cases to handle:
+ *
+ * - If the current needed size is less than the maximum seen size,
+ * use the maximum seen size, so that WAVESIZE remains the same.
+ *
+ * - If the current needed size is greater than the maximum seen size,
+ * the scratch buffer is reallocated, so we can increase WAVESIZE.
+ *
+ * Shaders that set SCRATCH_EN=0 don't allocate scratch space.
+ * Otherwise, the number of waves that can use scratch is
+ * SPI_TMPRING_SIZE.WAVES.
+ */
+ unsigned bytes = 0;
+
+ bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->ps_shader.current));
+ bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->gs_shader.current));
+ bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->vs_shader.current));
+
+ if (sctx->tes_shader.cso) {
+ bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(sctx->tes_shader.current));
+ bytes = MAX2(bytes, si_get_scratch_buffer_bytes_per_wave(si_get_tcs_current(sctx)));
+ }
+
+ sctx->max_seen_scratch_bytes_per_wave =
+ MAX2(sctx->max_seen_scratch_bytes_per_wave, bytes);
+
+ unsigned scratch_needed_size =
+ sctx->max_seen_scratch_bytes_per_wave * sctx->scratch_waves;