radeonsi/gfx10: implement NGG culling for 4x wave32 subgroups
[mesa.git] / src / gallium / drivers / radeonsi / gfx10_shader_ngg.c
index a25c89bac56a903e282366ba7513c2fe789bb689..8092b796b5d7b121f9ea07f6966280e61ecad7a7 100644 (file)
@@ -28,6 +28,7 @@
 
 #include "util/u_memory.h"
 #include "util/u_prim.h"
+#include "ac_llvm_cull.h"
 
 static LLVMValueRef get_wave_id_in_tg(struct si_shader_context *ctx)
 {
@@ -141,14 +142,44 @@ void gfx10_ngg_build_sendmsg_gs_alloc_req(struct si_shader_context *ctx)
 }
 
 void gfx10_ngg_build_export_prim(struct si_shader_context *ctx,
-                                LLVMValueRef user_edgeflags[3])
+                                LLVMValueRef user_edgeflags[3],
+                                LLVMValueRef prim_passthrough)
 {
-       if (gfx10_is_ngg_passthrough(ctx->shader)) {
+       LLVMBuilderRef builder = ctx->ac.builder;
+
+       if (gfx10_is_ngg_passthrough(ctx->shader) ||
+           ctx->shader->key.opt.ngg_culling) {
                ac_build_ifcc(&ctx->ac, si_is_gs_thread(ctx), 6001);
                {
                        struct ac_ngg_prim prim = {};
 
-                       prim.passthrough = ac_get_arg(&ctx->ac, ctx->gs_vtx01_offset);
+                       if (prim_passthrough)
+                               prim.passthrough = prim_passthrough;
+                       else
+                               prim.passthrough = ac_get_arg(&ctx->ac, ctx->gs_vtx01_offset);
+
+                       /* This is only used with NGG culling, which returns the NGG
+                        * passthrough prim export encoding.
+                        */
+                       if (ctx->shader->selector->info.writes_edgeflag) {
+                               unsigned all_bits_no_edgeflags = ~SI_NGG_PRIM_EDGE_FLAG_BITS;
+                               LLVMValueRef edgeflags = LLVMConstInt(ctx->i32, all_bits_no_edgeflags, 0);
+
+                               unsigned num_vertices;
+                               ngg_get_vertices_per_prim(ctx, &num_vertices);
+
+                               for (unsigned i = 0; i < num_vertices; i++) {
+                                       unsigned shift = 9 + i*10;
+                                       LLVMValueRef edge;
+
+                                       edge = LLVMBuildLoad(builder, user_edgeflags[i], "");
+                                       edge = LLVMBuildZExt(builder, edge, ctx->i32, "");
+                                       edge = LLVMBuildShl(builder, edge, LLVMConstInt(ctx->i32, shift, 0), "");
+                                       edgeflags = LLVMBuildOr(builder, edgeflags, edge, "");
+                               }
+                               prim.passthrough = LLVMBuildAnd(builder, prim.passthrough, edgeflags, "");
+                       }
+
                        ac_build_export_prim(&ctx->ac, &prim);
                }
                ac_build_endif(&ctx->ac, 6001);
@@ -535,6 +566,51 @@ static void build_streamout(struct si_shader_context *ctx,
        }
 }
 
+/* LDS layout of ES vertex data for NGG culling. */
+enum {
+       /* Byte 0: Boolean ES thread accepted (unculled) flag, and later the old
+        *         ES thread ID. After vertex compaction, compacted ES threads
+        *         store the old thread ID here to copy input VGPRs from uncompacted
+        *         ES threads.
+        * Byte 1: New ES thread ID, loaded by GS to prepare the prim export value.
+        * Byte 2: TES rel patch ID
+        * Byte 3: Unused
+        */
+       lds_byte0_accept_flag = 0,
+       lds_byte0_old_thread_id = 0,
+       lds_byte1_new_thread_id,
+       lds_byte2_tes_rel_patch_id,
+       lds_byte3_unused,
+
+       lds_packed_data = 0, /* lds_byteN_... */
+
+       lds_pos_x,
+       lds_pos_y,
+       lds_pos_z,
+       lds_pos_w,
+       lds_pos_x_div_w,
+       lds_pos_y_div_w,
+       /* If VS: */
+       lds_vertex_id,
+       lds_instance_id, /* optional */
+       /* If TES: */
+       lds_tes_u = lds_vertex_id,
+       lds_tes_v = lds_instance_id,
+       lds_tes_patch_id, /* optional */
+};
+
+static LLVMValueRef si_build_gep_i8(struct si_shader_context *ctx,
+                                   LLVMValueRef ptr, unsigned byte_index)
+{
+       assert(byte_index < 4);
+       LLVMTypeRef pi8 = LLVMPointerType(ctx->i8, AC_ADDR_SPACE_LDS);
+       LLVMValueRef index = LLVMConstInt(ctx->i32, byte_index, 0);
+
+       return LLVMBuildGEP(ctx->ac.builder,
+                           LLVMBuildPointerCast(ctx->ac.builder, ptr, pi8, ""),
+                           &index, 1, "");
+}
+
 static unsigned ngg_nogs_vertex_size(struct si_shader *shader)
 {
        unsigned lds_vertex_size = 0;
@@ -555,6 +631,24 @@ static unsigned ngg_nogs_vertex_size(struct si_shader *shader)
            shader->key.mono.u.vs_export_prim_id)
                lds_vertex_size = MAX2(lds_vertex_size, 1);
 
+       if (shader->key.opt.ngg_culling) {
+               if (shader->selector->type == PIPE_SHADER_VERTEX) {
+                       STATIC_ASSERT(lds_instance_id + 1 == 9);
+                       lds_vertex_size = MAX2(lds_vertex_size, 9);
+               } else {
+                       assert(shader->selector->type == PIPE_SHADER_TESS_EVAL);
+
+                       if (shader->selector->info.uses_primid ||
+                           shader->key.mono.u.vs_export_prim_id) {
+                               STATIC_ASSERT(lds_tes_patch_id + 2 == 11);
+                               lds_vertex_size = MAX2(lds_vertex_size, 11);
+                       } else {
+                               STATIC_ASSERT(lds_tes_v + 1 == 9);
+                               lds_vertex_size = MAX2(lds_vertex_size, 9);
+                       }
+               }
+       }
+
        return lds_vertex_size;
 }
 
@@ -573,6 +667,540 @@ static LLVMValueRef ngg_nogs_vertex_ptr(struct si_shader_context *ctx,
        return LLVMBuildGEP(ctx->ac.builder, tmp, &vtxid, 1, "");
 }
 
+static void load_bitmasks_2x64(struct si_shader_context *ctx,
+                              LLVMValueRef lds_ptr, unsigned dw_offset,
+                              LLVMValueRef mask[2], LLVMValueRef *total_bitcount)
+{
+       LLVMBuilderRef builder = ctx->ac.builder;
+       LLVMValueRef ptr64 = LLVMBuildPointerCast(builder, lds_ptr,
+                                                 LLVMPointerType(LLVMArrayType(ctx->i64, 2),
+                                                                 AC_ADDR_SPACE_LDS), "");
+       for (unsigned i = 0; i < 2; i++) {
+               LLVMValueRef index = LLVMConstInt(ctx->i32, dw_offset / 2 + i, 0);
+               mask[i] = LLVMBuildLoad(builder, ac_build_gep0(&ctx->ac, ptr64, index), "");
+       }
+
+       /* We get better code if we don't use the 128-bit bitcount. */
+       *total_bitcount = LLVMBuildAdd(builder, ac_build_bit_count(&ctx->ac, mask[0]),
+                                      ac_build_bit_count(&ctx->ac, mask[1]), "");
+}
+
+/**
+ * Given a total thread count, update total and per-wave thread counts in input SGPRs
+ * and return the per-wave thread count.
+ *
+ * \param new_num_threads    Total thread count on the input, per-wave thread count on the output.
+ * \param tg_info           tg_info SGPR value
+ * \param tg_info_num_bits   the bit size of thread count field in tg_info
+ * \param tg_info_shift      the bit offset of the thread count field in tg_info
+ * \param wave_info          merged_wave_info SGPR value
+ * \param wave_info_num_bits the bit size of thread count field in merged_wave_info
+ * \param wave_info_shift    the bit offset of the thread count field in merged_wave_info
+ */
+static void update_thread_counts(struct si_shader_context *ctx,
+                                LLVMValueRef *new_num_threads,
+                                LLVMValueRef *tg_info,
+                                unsigned tg_info_num_bits,
+                                unsigned tg_info_shift,
+                                LLVMValueRef *wave_info,
+                                unsigned wave_info_num_bits,
+                                unsigned wave_info_shift)
+{
+       LLVMBuilderRef builder = ctx->ac.builder;
+
+       /* Update the total thread count. */
+       unsigned tg_info_mask = ~(u_bit_consecutive(0, tg_info_num_bits) << tg_info_shift);
+       *tg_info = LLVMBuildAnd(builder, *tg_info,
+                               LLVMConstInt(ctx->i32, tg_info_mask, 0), "");
+       *tg_info = LLVMBuildOr(builder, *tg_info,
+                              LLVMBuildShl(builder, *new_num_threads,
+                                           LLVMConstInt(ctx->i32, tg_info_shift, 0), ""), "");
+
+       /* Update the per-wave thread count. */
+       LLVMValueRef prev_threads = LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
+                                                LLVMConstInt(ctx->i32, ctx->ac.wave_size, 0), "");
+       *new_num_threads = LLVMBuildSub(builder, *new_num_threads, prev_threads, "");
+       *new_num_threads = ac_build_imax(&ctx->ac, *new_num_threads, ctx->i32_0);
+       *new_num_threads = ac_build_imin(&ctx->ac, *new_num_threads,
+                                       LLVMConstInt(ctx->i32, ctx->ac.wave_size, 0));
+       unsigned wave_info_mask = ~(u_bit_consecutive(0, wave_info_num_bits) << wave_info_shift);
+       *wave_info = LLVMBuildAnd(builder, *wave_info,
+                                 LLVMConstInt(ctx->i32, wave_info_mask, 0), "");
+       *wave_info = LLVMBuildOr(builder, *wave_info,
+                                LLVMBuildShl(builder, *new_num_threads,
+                                             LLVMConstInt(ctx->i32, wave_info_shift, 0), ""), "");
+}
+
+/**
+ * Cull primitives for NGG VS or TES, then compact vertices, which happens
+ * before the VS or TES main function. Return values for the main function.
+ * Also return the position, which is passed to the shader as an input,
+ * so that we don't compute it twice.
+ */
+void gfx10_emit_ngg_culling_epilogue_4x_wave32(struct ac_shader_abi *abi,
+                                              unsigned max_outputs,
+                                              LLVMValueRef *addrs)
+{
+       struct si_shader_context *ctx = si_shader_context_from_abi(abi);
+       struct si_shader *shader = ctx->shader;
+       struct si_shader_selector *sel = shader->selector;
+       struct si_shader_info *info = &sel->info;
+       LLVMBuilderRef builder = ctx->ac.builder;
+
+       assert(shader->key.opt.ngg_culling);
+       assert(shader->key.as_ngg);
+       assert(sel->type == PIPE_SHADER_VERTEX ||
+              (sel->type == PIPE_SHADER_TESS_EVAL && !shader->key.as_es));
+
+       LLVMValueRef position[4] = {};
+       for (unsigned i = 0; i < info->num_outputs; i++) {
+               switch (info->output_semantic_name[i]) {
+               case TGSI_SEMANTIC_POSITION:
+                       for (unsigned j = 0; j < 4; j++) {
+                               position[j] = LLVMBuildLoad(ctx->ac.builder,
+                                                           addrs[4 * i + j], "");
+                       }
+                       break;
+               }
+       }
+       assert(position[0]);
+
+       /* Store Position.XYZW into LDS. */
+       LLVMValueRef es_vtxptr = ngg_nogs_vertex_ptr(ctx, get_thread_id_in_tg(ctx));
+       for (unsigned chan = 0; chan < 4; chan++) {
+               LLVMBuildStore(builder, ac_to_integer(&ctx->ac, position[chan]),
+                               ac_build_gep0(&ctx->ac, es_vtxptr,
+                                             LLVMConstInt(ctx->i32, lds_pos_x + chan, 0)));
+       }
+       /* Store Position.XY / W into LDS. */
+       for (unsigned chan = 0; chan < 2; chan++) {
+               LLVMValueRef val = ac_build_fdiv(&ctx->ac, position[chan], position[3]);
+               LLVMBuildStore(builder, ac_to_integer(&ctx->ac, val),
+                               ac_build_gep0(&ctx->ac, es_vtxptr,
+                                             LLVMConstInt(ctx->i32, lds_pos_x_div_w + chan, 0)));
+       }
+
+       /* Store VertexID and InstanceID. ES threads will have to load them
+        * from LDS after vertex compaction and use them instead of their own
+        * system values.
+        */
+       bool uses_instance_id = false;
+       bool uses_tes_prim_id = false;
+       LLVMValueRef packed_data = ctx->i32_0;
+
+       if (ctx->type == PIPE_SHADER_VERTEX) {
+               uses_instance_id = sel->info.uses_instanceid ||
+                                  shader->key.part.vs.prolog.instance_divisor_is_one ||
+                                  shader->key.part.vs.prolog.instance_divisor_is_fetched;
+
+               LLVMBuildStore(builder, ctx->abi.vertex_id,
+                              ac_build_gep0(&ctx->ac, es_vtxptr,
+                                            LLVMConstInt(ctx->i32, lds_vertex_id, 0)));
+               if (uses_instance_id) {
+                       LLVMBuildStore(builder, ctx->abi.instance_id,
+                                      ac_build_gep0(&ctx->ac, es_vtxptr,
+                                                    LLVMConstInt(ctx->i32, lds_instance_id, 0)));
+               }
+       } else {
+               uses_tes_prim_id = sel->info.uses_primid ||
+                                  shader->key.mono.u.vs_export_prim_id;
+
+               assert(ctx->type == PIPE_SHADER_TESS_EVAL);
+               LLVMBuildStore(builder, ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, ctx->tes_u)),
+                              ac_build_gep0(&ctx->ac, es_vtxptr,
+                                            LLVMConstInt(ctx->i32, lds_tes_u, 0)));
+               LLVMBuildStore(builder, ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, ctx->tes_v)),
+                              ac_build_gep0(&ctx->ac, es_vtxptr,
+                                            LLVMConstInt(ctx->i32, lds_tes_v, 0)));
+               packed_data = LLVMBuildShl(builder, ac_get_arg(&ctx->ac, ctx->tes_rel_patch_id),
+                                          LLVMConstInt(ctx->i32, lds_byte2_tes_rel_patch_id * 8, 0), "");
+               if (uses_tes_prim_id) {
+                       LLVMBuildStore(builder, ac_get_arg(&ctx->ac, ctx->args.tes_patch_id),
+                                      ac_build_gep0(&ctx->ac, es_vtxptr,
+                                                    LLVMConstInt(ctx->i32, lds_tes_patch_id, 0)));
+               }
+       }
+       /* Initialize the packed data. */
+       LLVMBuildStore(builder, packed_data,
+                      ac_build_gep0(&ctx->ac, es_vtxptr,
+                                    LLVMConstInt(ctx->i32, lds_packed_data, 0)));
+       ac_build_endif(&ctx->ac, ctx->merged_wrap_if_label);
+
+       LLVMValueRef tid = ac_get_thread_id(&ctx->ac);
+
+       /* Initialize the last 3 gs_ngg_scratch dwords to 0, because we may have less
+        * than 4 waves, but we always read all 4 values. This is where the thread
+        * bitmasks of unculled threads will be stored.
+        *
+        * gs_ngg_scratch layout: esmask[0..3]
+        */
+       ac_build_ifcc(&ctx->ac,
+                     LLVMBuildICmp(builder, LLVMIntULT, get_thread_id_in_tg(ctx),
+                                   LLVMConstInt(ctx->i32, 3, 0), ""), 16101);
+       {
+               LLVMValueRef index = LLVMBuildAdd(builder, tid, ctx->i32_1, "");
+               LLVMBuildStore(builder, ctx->i32_0,
+                              ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch, index));
+       }
+       ac_build_endif(&ctx->ac, 16101);
+       ac_build_s_barrier(&ctx->ac);
+
+       /* The hardware requires that there are no holes between unculled vertices,
+        * which means we have to pack ES threads, i.e. reduce the ES thread count
+        * and move ES input VGPRs to lower threads. The upside is that varyings
+        * are only fetched and computed for unculled vertices.
+        *
+        * Vertex compaction in GS threads:
+        *
+        * Part 1: Compute the surviving vertex mask in GS threads:
+        * - Compute 4 32-bit surviving vertex masks in LDS. (max 4 waves)
+        *   - In GS, notify ES threads whether the vertex survived.
+        *   - Barrier
+        *   - ES threads will create the mask and store it in LDS.
+        * - Barrier
+        * - Each GS thread loads the vertex masks from LDS.
+        *
+        * Part 2: Compact ES threads in GS threads:
+        * - Compute the prefix sum for all 3 vertices from the masks. These are the new
+        *   thread IDs for each vertex within the primitive.
+        * - Write the value of the old thread ID into the LDS address of the new thread ID.
+        *   The ES thread will load the old thread ID and use it to load the position, VertexID,
+        *   and InstanceID.
+        * - Update vertex indices and null flag in the GS input VGPRs.
+        * - Barrier
+        *
+        * Part 3: Update inputs GPRs
+        * - For all waves, update per-wave thread counts in input SGPRs.
+        * - In ES threads, update the ES input VGPRs (VertexID, InstanceID, TES inputs).
+        */
+
+       LLVMValueRef vtxindex[] = {
+               si_unpack_param(ctx, ctx->gs_vtx01_offset, 0, 16),
+               si_unpack_param(ctx, ctx->gs_vtx01_offset, 16, 16),
+               si_unpack_param(ctx, ctx->gs_vtx23_offset, 0, 16),
+       };
+       LLVMValueRef gs_vtxptr[] = {
+               ngg_nogs_vertex_ptr(ctx, vtxindex[0]),
+               ngg_nogs_vertex_ptr(ctx, vtxindex[1]),
+               ngg_nogs_vertex_ptr(ctx, vtxindex[2]),
+       };
+       es_vtxptr = ngg_nogs_vertex_ptr(ctx, get_thread_id_in_tg(ctx));
+
+       LLVMValueRef gs_accepted = ac_build_alloca(&ctx->ac, ctx->i32, "");
+
+       /* Do culling in GS threads. */
+       ac_build_ifcc(&ctx->ac, si_is_gs_thread(ctx), 16002);
+       {
+               /* Load positions. */
+               LLVMValueRef pos[3][4] = {};
+               for (unsigned vtx = 0; vtx < 3; vtx++) {
+                       for (unsigned chan = 0; chan < 4; chan++) {
+                               unsigned index;
+                               if (chan == 0 || chan == 1)
+                                       index = lds_pos_x_div_w + chan;
+                               else if (chan == 3)
+                                       index = lds_pos_w;
+                               else
+                                       continue;
+
+                               LLVMValueRef addr = ac_build_gep0(&ctx->ac, gs_vtxptr[vtx],
+                                                                 LLVMConstInt(ctx->i32, index, 0));
+                               pos[vtx][chan] = LLVMBuildLoad(builder, addr, "");
+                               pos[vtx][chan] = ac_to_float(&ctx->ac, pos[vtx][chan]);
+                       }
+               }
+
+               /* Load the viewport state for small prim culling. */
+               LLVMValueRef vp = ac_build_load_invariant(&ctx->ac,
+                                                         ac_get_arg(&ctx->ac, ctx->small_prim_cull_info),
+                                                         ctx->i32_0);
+               vp = LLVMBuildBitCast(builder, vp, ctx->v4f32, "");
+               LLVMValueRef vp_scale[2], vp_translate[2];
+               vp_scale[0] = ac_llvm_extract_elem(&ctx->ac, vp, 0);
+               vp_scale[1] = ac_llvm_extract_elem(&ctx->ac, vp, 1);
+               vp_translate[0] = ac_llvm_extract_elem(&ctx->ac, vp, 2);
+               vp_translate[1] = ac_llvm_extract_elem(&ctx->ac, vp, 3);
+
+               /* Get the small prim filter precision. */
+               LLVMValueRef small_prim_precision = si_unpack_param(ctx, ctx->vs_state_bits, 7, 4);
+               small_prim_precision = LLVMBuildOr(builder, small_prim_precision,
+                                                  LLVMConstInt(ctx->i32, 0x70, 0), "");
+               small_prim_precision = LLVMBuildShl(builder, small_prim_precision,
+                                                   LLVMConstInt(ctx->i32, 23, 0), "");
+               small_prim_precision = LLVMBuildBitCast(builder, small_prim_precision, ctx->f32, "");
+
+               /* Execute culling code. */
+               struct ac_cull_options options = {};
+               options.cull_front = shader->key.opt.ngg_culling & SI_NGG_CULL_FRONT_FACE;
+               options.cull_back = shader->key.opt.ngg_culling & SI_NGG_CULL_BACK_FACE;
+               options.cull_view_xy = shader->key.opt.ngg_culling & SI_NGG_CULL_VIEW_SMALLPRIMS;
+               options.cull_small_prims = options.cull_view_xy;
+               options.cull_zero_area = options.cull_front || options.cull_back;
+               options.cull_w = true;
+
+               /* Tell ES threads whether their vertex survived. */
+               ac_build_ifcc(&ctx->ac, ac_cull_triangle(&ctx->ac, pos, ctx->i1true,
+                                                        vp_scale, vp_translate,
+                                                        small_prim_precision, &options), 16003);
+               {
+                       LLVMBuildStore(builder, ctx->ac.i32_1, gs_accepted);
+                       for (unsigned vtx = 0; vtx < 3; vtx++) {
+                               LLVMBuildStore(builder, ctx->ac.i8_1,
+                                              si_build_gep_i8(ctx, gs_vtxptr[vtx], lds_byte0_accept_flag));
+                       }
+               }
+               ac_build_endif(&ctx->ac, 16003);
+       }
+       ac_build_endif(&ctx->ac, 16002);
+       ac_build_s_barrier(&ctx->ac);
+
+       gs_accepted = LLVMBuildLoad(builder, gs_accepted, "");
+
+       LLVMValueRef es_accepted = ac_build_alloca(&ctx->ac, ctx->i1, "");
+
+       /* Convert the per-vertex flag to a thread bitmask in ES threads and store it in LDS. */
+       ac_build_ifcc(&ctx->ac, si_is_es_thread(ctx), 16007);
+       {
+               LLVMValueRef es_accepted_flag =
+                       LLVMBuildLoad(builder,
+                                     si_build_gep_i8(ctx, es_vtxptr, lds_byte0_accept_flag), "");
+
+               LLVMValueRef es_accepted_bool = LLVMBuildICmp(builder, LLVMIntNE,
+                                                             es_accepted_flag, ctx->ac.i8_0, "");
+               LLVMValueRef es_mask = ac_get_i1_sgpr_mask(&ctx->ac, es_accepted_bool);
+
+               LLVMBuildStore(builder, es_accepted_bool, es_accepted);
+
+               ac_build_ifcc(&ctx->ac, LLVMBuildICmp(builder, LLVMIntEQ,
+                                                     tid, ctx->i32_0, ""), 16008);
+               {
+                       LLVMBuildStore(builder, es_mask,
+                                      ac_build_gep0(&ctx->ac, ctx->gs_ngg_scratch,
+                                                    get_wave_id_in_tg(ctx)));
+               }
+               ac_build_endif(&ctx->ac, 16008);
+       }
+       ac_build_endif(&ctx->ac, 16007);
+       ac_build_s_barrier(&ctx->ac);
+
+       /* Load the vertex masks and compute the new ES thread count. */
+       LLVMValueRef es_mask[2], new_num_es_threads, kill_wave;
+       load_bitmasks_2x64(ctx, ctx->gs_ngg_scratch, 0, es_mask, &new_num_es_threads);
+       new_num_es_threads = ac_build_readlane_no_opt_barrier(&ctx->ac, new_num_es_threads, NULL);
+
+       /* ES threads compute their prefix sum, which is the new ES thread ID.
+        * Then they write the value of the old thread ID into the LDS address
+        * of the new thread ID. It will be used it to load input VGPRs from
+        * the old thread's LDS location.
+        */
+       ac_build_ifcc(&ctx->ac, LLVMBuildLoad(builder, es_accepted, ""), 16009);
+       {
+               LLVMValueRef old_id = get_thread_id_in_tg(ctx);
+               LLVMValueRef new_id = ac_prefix_bitcount_2x64(&ctx->ac, es_mask, old_id);
+
+               LLVMBuildStore(builder, LLVMBuildTrunc(builder, old_id, ctx->i8, ""),
+                              si_build_gep_i8(ctx, ngg_nogs_vertex_ptr(ctx, new_id),
+                                              lds_byte0_old_thread_id));
+               LLVMBuildStore(builder, LLVMBuildTrunc(builder, new_id, ctx->i8, ""),
+                              si_build_gep_i8(ctx, es_vtxptr, lds_byte1_new_thread_id));
+       }
+       ac_build_endif(&ctx->ac, 16009);
+
+       /* Kill waves that have inactive threads. */
+       kill_wave = LLVMBuildICmp(builder, LLVMIntULE,
+                                 ac_build_imax(&ctx->ac, new_num_es_threads, ngg_get_prim_cnt(ctx)),
+                                 LLVMBuildMul(builder, get_wave_id_in_tg(ctx),
+                                              LLVMConstInt(ctx->i32, ctx->ac.wave_size, 0), ""), "");
+       ac_build_ifcc(&ctx->ac, kill_wave, 19202);
+       {
+               /* If we are killing wave 0, send that there are no primitives
+                * in this threadgroup.
+                */
+               ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx),
+                                             ctx->i32_0, ctx->i32_0);
+               ac_build_s_endpgm(&ctx->ac);
+       }
+       ac_build_endif(&ctx->ac, 19202);
+       ac_build_s_barrier(&ctx->ac);
+
+       /* Send the final vertex and primitive counts. */
+       ac_build_sendmsg_gs_alloc_req(&ctx->ac, get_wave_id_in_tg(ctx),
+                                     new_num_es_threads, ngg_get_prim_cnt(ctx));
+
+       /* Update thread counts in SGPRs. */
+       LLVMValueRef new_gs_tg_info = ac_get_arg(&ctx->ac, ctx->gs_tg_info);
+       LLVMValueRef new_merged_wave_info = ac_get_arg(&ctx->ac, ctx->merged_wave_info);
+
+       /* This also converts the thread count from the total count to the per-wave count. */
+       update_thread_counts(ctx, &new_num_es_threads, &new_gs_tg_info, 9, 12,
+                            &new_merged_wave_info, 8, 0);
+
+       /* Update vertex indices in VGPR0 (same format as NGG passthrough). */
+       LLVMValueRef new_vgpr0 = ac_build_alloca_undef(&ctx->ac, ctx->i32, "");
+
+       /* Set the null flag at the beginning (culled), and then
+        * overwrite it for accepted primitives.
+        */
+       LLVMBuildStore(builder, LLVMConstInt(ctx->i32, 1u << 31, 0), new_vgpr0);
+
+       /* Get vertex indices after vertex compaction. */
+       ac_build_ifcc(&ctx->ac, LLVMBuildTrunc(builder, gs_accepted, ctx->i1, ""), 16011);
+       {
+               struct ac_ngg_prim prim = {};
+               prim.num_vertices = 3;
+               prim.isnull = ctx->i1false;
+
+               for (unsigned vtx = 0; vtx < 3; vtx++) {
+                       prim.index[vtx] =
+                               LLVMBuildLoad(builder,
+                                             si_build_gep_i8(ctx, gs_vtxptr[vtx],
+                                                             lds_byte1_new_thread_id), "");
+                       prim.index[vtx] = LLVMBuildZExt(builder, prim.index[vtx], ctx->i32, "");
+                       prim.edgeflag[vtx] = ngg_get_initial_edgeflag(ctx, vtx);
+               }
+
+               /* Set the new GS input VGPR. */
+               LLVMBuildStore(builder, ac_pack_prim_export(&ctx->ac, &prim), new_vgpr0);
+       }
+       ac_build_endif(&ctx->ac, 16011);
+
+       if (gfx10_ngg_export_prim_early(shader))
+               gfx10_ngg_build_export_prim(ctx, NULL, LLVMBuildLoad(builder, new_vgpr0, ""));
+
+       /* Set the new ES input VGPRs. */
+       LLVMValueRef es_data[4];
+       LLVMValueRef old_thread_id = ac_build_alloca_undef(&ctx->ac, ctx->i32, "");
+
+       for (unsigned i = 0; i < 4; i++)
+               es_data[i] = ac_build_alloca_undef(&ctx->ac, ctx->i32, "");
+
+       ac_build_ifcc(&ctx->ac, LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, tid,
+                                             new_num_es_threads, ""), 16012);
+       {
+               LLVMValueRef old_id, old_es_vtxptr, tmp;
+
+               /* Load ES input VGPRs from the ES thread before compaction. */
+               old_id = LLVMBuildLoad(builder,
+                                      si_build_gep_i8(ctx, es_vtxptr, lds_byte0_old_thread_id), "");
+               old_id = LLVMBuildZExt(builder, old_id, ctx->i32, "");
+
+               LLVMBuildStore(builder, old_id, old_thread_id);
+               old_es_vtxptr = ngg_nogs_vertex_ptr(ctx, old_id);
+
+               for (unsigned i = 0; i < 2; i++) {
+                       tmp = LLVMBuildLoad(builder,
+                                           ac_build_gep0(&ctx->ac, old_es_vtxptr,
+                                                         LLVMConstInt(ctx->i32, lds_vertex_id + i, 0)), "");
+                       LLVMBuildStore(builder, tmp, es_data[i]);
+               }
+
+               if (ctx->type == PIPE_SHADER_TESS_EVAL) {
+                       tmp = LLVMBuildLoad(builder,
+                                           si_build_gep_i8(ctx, old_es_vtxptr,
+                                                           lds_byte2_tes_rel_patch_id), "");
+                       tmp = LLVMBuildZExt(builder, tmp, ctx->i32, "");
+                       LLVMBuildStore(builder, tmp, es_data[2]);
+
+                       if (uses_tes_prim_id) {
+                               tmp = LLVMBuildLoad(builder,
+                                                   ac_build_gep0(&ctx->ac, old_es_vtxptr,
+                                                                 LLVMConstInt(ctx->i32, lds_tes_patch_id, 0)), "");
+                               LLVMBuildStore(builder, tmp, es_data[3]);
+                       }
+               }
+       }
+       ac_build_endif(&ctx->ac, 16012);
+
+       /* Return values for the main function. */
+       LLVMValueRef ret = ctx->return_value;
+       LLVMValueRef val;
+
+       ret = LLVMBuildInsertValue(ctx->ac.builder, ret, new_gs_tg_info, 2, "");
+       ret = LLVMBuildInsertValue(ctx->ac.builder, ret, new_merged_wave_info, 3, "");
+       if (ctx->type == PIPE_SHADER_TESS_EVAL)
+               ret = si_insert_input_ret(ctx, ret, ctx->tcs_offchip_offset, 4);
+
+       ret = si_insert_input_ptr(ctx, ret, ctx->rw_buffers,
+                                 8 + SI_SGPR_RW_BUFFERS);
+       ret = si_insert_input_ptr(ctx, ret,
+                                 ctx->bindless_samplers_and_images,
+                                 8 + SI_SGPR_BINDLESS_SAMPLERS_AND_IMAGES);
+       ret = si_insert_input_ptr(ctx, ret,
+                                 ctx->const_and_shader_buffers,
+                                 8 + SI_SGPR_CONST_AND_SHADER_BUFFERS);
+       ret = si_insert_input_ptr(ctx, ret,
+                                 ctx->samplers_and_images,
+                                 8 + SI_SGPR_SAMPLERS_AND_IMAGES);
+       ret = si_insert_input_ptr(ctx, ret, ctx->vs_state_bits,
+                                 8 + SI_SGPR_VS_STATE_BITS);
+
+       if (ctx->type == PIPE_SHADER_VERTEX) {
+               ret = si_insert_input_ptr(ctx, ret, ctx->args.base_vertex,
+                                         8 + SI_SGPR_BASE_VERTEX);
+               ret = si_insert_input_ptr(ctx, ret, ctx->args.start_instance,
+                                         8 + SI_SGPR_START_INSTANCE);
+               ret = si_insert_input_ptr(ctx, ret, ctx->args.draw_id,
+                                         8 + SI_SGPR_DRAWID);
+               ret = si_insert_input_ptr(ctx, ret, ctx->vertex_buffers,
+                                         8 + SI_VS_NUM_USER_SGPR);
+       } else {
+               assert(ctx->type == PIPE_SHADER_TESS_EVAL);
+               ret = si_insert_input_ptr(ctx, ret, ctx->tcs_offchip_layout,
+                                         8 + SI_SGPR_TES_OFFCHIP_LAYOUT);
+               ret = si_insert_input_ptr(ctx, ret, ctx->tes_offchip_addr,
+                                         8 + SI_SGPR_TES_OFFCHIP_ADDR);
+       }
+
+       unsigned vgpr;
+       if (ctx->type == PIPE_SHADER_VERTEX)
+               vgpr = 8 + GFX9_VSGS_NUM_USER_SGPR + 1;
+       else
+               vgpr = 8 + GFX9_TESGS_NUM_USER_SGPR;
+
+       val = LLVMBuildLoad(builder, new_vgpr0, "");
+       ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val),
+                                  vgpr++, "");
+       vgpr++; /* gs_vtx23_offset */
+
+       ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_prim_id, vgpr++);
+       ret = si_insert_input_ret_float(ctx, ret, ctx->args.gs_invocation_id, vgpr++);
+       vgpr++; /* gs_vtx45_offset */
+
+       if (ctx->type == PIPE_SHADER_VERTEX) {
+               val = LLVMBuildLoad(builder, es_data[0], "");
+               ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val),
+                                          vgpr++, ""); /* VGPR5 - VertexID */
+               vgpr += 2;
+               if (uses_instance_id) {
+                       val = LLVMBuildLoad(builder, es_data[1], "");
+                       ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val),
+                                                  vgpr++, ""); /* VGPR8 - InstanceID */
+               } else {
+                       vgpr++;
+               }
+       } else {
+               assert(ctx->type == PIPE_SHADER_TESS_EVAL);
+               unsigned num_vgprs = uses_tes_prim_id ? 4 : 3;
+               for (unsigned i = 0; i < num_vgprs; i++) {
+                       val = LLVMBuildLoad(builder, es_data[i], "");
+                       ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val),
+                                                  vgpr++, "");
+               }
+               if (num_vgprs == 3)
+                       vgpr++;
+       }
+       /* Return the old thread ID. */
+       val = LLVMBuildLoad(builder, old_thread_id, "");
+       ret = LLVMBuildInsertValue(builder, ret, ac_to_float(&ctx->ac, val), vgpr++, "");
+
+       /* These two also use LDS. */
+       if (sel->info.writes_edgeflag ||
+           (ctx->type == PIPE_SHADER_VERTEX && shader->key.mono.u.vs_export_prim_id))
+               ac_build_s_barrier(&ctx->ac);
+
+       ctx->return_value = ret;
+}
+
 /**
  * Emit the epilogue of an API VS or TES shader compiled as ESGS shader.
  */
@@ -630,7 +1258,8 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
        }
 
        bool unterminated_es_if_block =
-               gfx10_is_ngg_passthrough(ctx->shader) &&
+               !sel->so.num_outputs &&
+               !sel->info.writes_edgeflag &&
                !ctx->screen->use_ngg_streamout && /* no query buffer */
                (ctx->type != PIPE_SHADER_VERTEX ||
                 !ctx->shader->key.mono.u.vs_export_prim_id);
@@ -640,11 +1269,17 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
 
        LLVMValueRef is_gs_thread = si_is_gs_thread(ctx);
        LLVMValueRef is_es_thread = si_is_es_thread(ctx);
-       LLVMValueRef vtxindex[] = {
-               si_unpack_param(ctx, ctx->gs_vtx01_offset, 0, 16),
-               si_unpack_param(ctx, ctx->gs_vtx01_offset, 16, 16),
-               si_unpack_param(ctx, ctx->gs_vtx23_offset, 0, 16),
-       };
+       LLVMValueRef vtxindex[3];
+
+       if (ctx->shader->key.opt.ngg_culling) {
+               vtxindex[0] = si_unpack_param(ctx, ctx->gs_vtx01_offset, 0, 9);
+               vtxindex[1] = si_unpack_param(ctx, ctx->gs_vtx01_offset, 10, 9);
+               vtxindex[2] = si_unpack_param(ctx, ctx->gs_vtx01_offset, 20, 9);
+       } else {
+               vtxindex[0] = si_unpack_param(ctx, ctx->gs_vtx01_offset, 0, 16);
+               vtxindex[1] = si_unpack_param(ctx, ctx->gs_vtx01_offset, 16, 16);
+               vtxindex[2] = si_unpack_param(ctx, ctx->gs_vtx23_offset, 0, 16);
+       }
 
        /* Determine the number of vertices per primitive. */
        unsigned num_vertices;
@@ -758,7 +1393,7 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
        /* Build the primitive export. */
        if (!gfx10_ngg_export_prim_early(ctx->shader)) {
                assert(!unterminated_es_if_block);
-               gfx10_ngg_build_export_prim(ctx, user_edgeflags);
+               gfx10_ngg_build_export_prim(ctx, user_edgeflags, NULL);
        }
 
        /* Export per-vertex data (positions and parameters). */
@@ -769,11 +1404,27 @@ void gfx10_emit_ngg_epilogue(struct ac_shader_abi *abi,
 
                /* Unconditionally (re-)load the values for proper SSA form. */
                for (i = 0; i < info->num_outputs; i++) {
-                       for (unsigned j = 0; j < 4; j++) {
-                               outputs[i].values[j] =
-                                       LLVMBuildLoad(builder,
-                                               addrs[4 * i + j],
-                                               "");
+                       /* If the NGG cull shader part computed the position, don't
+                        * use the position from the current shader part. Instead,
+                        * load it from LDS.
+                        */
+                       if (info->output_semantic_name[i] == TGSI_SEMANTIC_POSITION &&
+                           ctx->shader->key.opt.ngg_culling) {
+                               vertex_ptr = ngg_nogs_vertex_ptr(ctx,
+                                               ac_get_arg(&ctx->ac, ctx->ngg_old_thread_id));
+
+                               for (unsigned j = 0; j < 4; j++) {
+                                       tmp = LLVMConstInt(ctx->i32, lds_pos_x + j, 0);
+                                       tmp = ac_build_gep0(&ctx->ac, vertex_ptr, tmp);
+                                       tmp = LLVMBuildLoad(builder, tmp, "");
+                                       outputs[i].values[j] = ac_to_float(&ctx->ac, tmp);
+                               }
+                       } else {
+                               for (unsigned j = 0; j < 4; j++) {
+                                       outputs[i].values[j] =
+                                               LLVMBuildLoad(builder,
+                                                             addrs[4 * i + j], "");
+                               }
                        }
                }