ac: add prefix bitcount functions
[mesa.git] / src / amd / llvm / ac_llvm_build.c
index cf6eda30e2cc61187f4d0d0f849d77c5505428a5..f789ff5a368f650bee70a45235bb257183e70e71 100644 (file)
@@ -85,6 +85,7 @@ ac_llvm_context_init(struct ac_llvm_context *ctx,
        ctx->i16 = LLVMIntTypeInContext(ctx->context, 16);
        ctx->i32 = LLVMIntTypeInContext(ctx->context, 32);
        ctx->i64 = LLVMIntTypeInContext(ctx->context, 64);
+       ctx->i128 = LLVMIntTypeInContext(ctx->context, 128);
        ctx->intptr = ctx->i32;
        ctx->f16 = LLVMHalfTypeInContext(ctx->context);
        ctx->f32 = LLVMFloatTypeInContext(ctx->context);
@@ -108,6 +109,8 @@ ac_llvm_context_init(struct ac_llvm_context *ctx,
        ctx->i32_1 = LLVMConstInt(ctx->i32, 1, false);
        ctx->i64_0 = LLVMConstInt(ctx->i64, 0, false);
        ctx->i64_1 = LLVMConstInt(ctx->i64, 1, false);
+       ctx->i128_0 = LLVMConstInt(ctx->i128, 0, false);
+       ctx->i128_1 = LLVMConstInt(ctx->i128, 1, false);
        ctx->f16_0 = LLVMConstReal(ctx->f16, 0.0);
        ctx->f16_1 = LLVMConstReal(ctx->f16, 1.0);
        ctx->f32_0 = LLVMConstReal(ctx->f32, 0.0);
@@ -505,14 +508,23 @@ ac_build_ballot(struct ac_llvm_context *ctx,
 LLVMValueRef ac_get_i1_sgpr_mask(struct ac_llvm_context *ctx,
                                 LLVMValueRef value)
 {
-       const char *name = LLVM_VERSION_MAJOR >= 9 ? "llvm.amdgcn.icmp.i64.i1" : "llvm.amdgcn.icmp.i1";
+       const char *name;
+
+       if (LLVM_VERSION_MAJOR >= 9) {
+               if (ctx->wave_size == 64)
+                       name = "llvm.amdgcn.icmp.i64.i1";
+               else
+                       name = "llvm.amdgcn.icmp.i32.i1";
+       } else {
+               name = "llvm.amdgcn.icmp.i1";
+       }
        LLVMValueRef args[3] = {
                value,
                ctx->i1false,
                LLVMConstInt(ctx->i32, LLVMIntNE, 0),
        };
 
-       return ac_build_intrinsic(ctx, name, ctx->i64, args, 3,
+       return ac_build_intrinsic(ctx, name, ctx->iN_wavemask, args, 3,
                                  AC_FUNC_ATTR_NOUNWIND |
                                  AC_FUNC_ATTR_READNONE |
                                  AC_FUNC_ATTR_CONVERGENT);
@@ -2808,6 +2820,12 @@ LLVMValueRef ac_build_bit_count(struct ac_llvm_context *ctx, LLVMValueRef src0)
        bitsize = ac_get_elem_bits(ctx, LLVMTypeOf(src0));
 
        switch (bitsize) {
+       case 128:
+               result = ac_build_intrinsic(ctx, "llvm.ctpop.i128", ctx->i128,
+                                           (LLVMValueRef []) { src0 }, 1,
+                                           AC_FUNC_ATTR_READNONE);
+               result = LLVMBuildTrunc(ctx->builder, result, ctx->i32, "");
+               break;
        case 64:
                result = ac_build_intrinsic(ctx, "llvm.ctpop.i64", ctx->i64,
                                            (LLVMValueRef []) { src0 }, 1,
@@ -4045,18 +4063,17 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
 }
 
 /**
+ * \param src The value to shift.
+ * \param identity The value to use the first lane.
  * \param maxprefix specifies that the result only needs to be correct for a
  *     prefix of this many threads
+ * \return src, shifted 1 lane up, and identity shifted into lane 0.
  */
 static LLVMValueRef
-ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
-             unsigned maxprefix, bool inclusive)
+ac_wavefront_shift_right_1(struct ac_llvm_context *ctx, LLVMValueRef src,
+                           LLVMValueRef identity, unsigned maxprefix)
 {
-       LLVMValueRef result, tmp;
-
-       if (inclusive) {
-               result = src;
-       } else if (ctx->chip_class >= GFX10) {
+       if (ctx->chip_class >= GFX10) {
                /* wavefront shift_right by 1 on GFX10 (emulate dpp_wf_sr1) */
                LLVMValueRef active, tmp1, tmp2;
                LLVMValueRef tid = ac_get_thread_id(ctx);
@@ -4079,45 +4096,57 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
                                                           LLVMBuildAnd(ctx->builder, tid,
                                                                        LLVMConstInt(ctx->i32, 0x1f, false), ""),
                                                           LLVMConstInt(ctx->i32, 0x10, false), ""), "");
-                       src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+                       return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
                } else if (maxprefix > 16) {
                        active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid,
                                               LLVMConstInt(ctx->i32, 16, false), "");
 
-                       src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+                       return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
                }
-
-               result = src;
        } else if (ctx->chip_class >= GFX8) {
-               src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
-               result = src;
-       } else {
-               /* wavefront shift_right by 1 on SI/CI */
-               LLVMValueRef active, tmp1, tmp2;
-               LLVMValueRef tid = ac_get_thread_id(ctx);
-               tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
-               tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
-                                      LLVMConstInt(ctx->i32, 0x4, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
-                                      LLVMConstInt(ctx->i32, 0x8, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
-                                      LLVMConstInt(ctx->i32, 0x10, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
-               src = LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
-               result = src;
-        }
+               return ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
+       }
+
+       /* wavefront shift_right by 1 on SI/CI */
+       LLVMValueRef active, tmp1, tmp2;
+       LLVMValueRef tid = ac_get_thread_id(ctx);
+       tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x4, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x8, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x10, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
+       return LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
+}
+
+/**
+ * \param maxprefix specifies that the result only needs to be correct for a
+ *     prefix of this many threads
+ */
+static LLVMValueRef
+ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
+             unsigned maxprefix, bool inclusive)
+{
+       LLVMValueRef result, tmp;
+
+       if (!inclusive)
+               src = ac_wavefront_shift_right_1(ctx, src, identity, maxprefix);
+
+       result = src;
 
        if (ctx->chip_class <= GFX7) {
                assert(maxprefix == 64);
@@ -4315,12 +4344,15 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
        if (cluster_size == 32) return ac_build_wwm(ctx, result);
 
        if (ctx->chip_class >= GFX8) {
-               if (ctx->chip_class >= GFX10)
-                       swap = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, false));
-               else
-                       swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
-               result = ac_build_alu_op(ctx, result, swap, op);
-               result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 63, 0));
+               if (ctx->wave_size == 64) {
+                       if (ctx->chip_class >= GFX10)
+                               swap = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, false));
+                       else
+                               swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
+                       result = ac_build_alu_op(ctx, result, swap, op);
+                       result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 63, 0));
+               }
+
                return ac_build_wwm(ctx, result);
        } else {
                swap = ac_build_readlane(ctx, result, ctx->i32_0);
@@ -4708,6 +4740,79 @@ ac_export_mrt_z(struct ac_llvm_context *ctx, LLVMValueRef depth,
        args->enabled_channels = mask;
 }
 
+/* Send GS Alloc Req message from the first wave of the group to SPI.
+ * Message payload is:
+ * - bits 0..10: vertices in group
+ * - bits 12..22: primitives in group
+ */
+void ac_build_sendmsg_gs_alloc_req(struct ac_llvm_context *ctx, LLVMValueRef wave_id,
+                                  LLVMValueRef vtx_cnt, LLVMValueRef prim_cnt)
+{
+       LLVMBuilderRef builder = ctx->builder;
+       LLVMValueRef tmp;
+
+       ac_build_ifcc(ctx, LLVMBuildICmp(builder, LLVMIntEQ, wave_id, ctx->i32_0, ""), 5020);
+
+       tmp = LLVMBuildShl(builder, prim_cnt, LLVMConstInt(ctx->i32, 12, false),"");
+       tmp = LLVMBuildOr(builder, tmp, vtx_cnt, "");
+       ac_build_sendmsg(ctx, AC_SENDMSG_GS_ALLOC_REQ, tmp);
+
+       ac_build_endif(ctx, 5020);
+}
+
+LLVMValueRef ac_pack_prim_export(struct ac_llvm_context *ctx,
+                                const struct ac_ngg_prim *prim)
+{
+       /* The prim export format is:
+        *  - bits 0..8: index 0
+        *  - bit 9: edge flag 0
+        *  - bits 10..18: index 1
+        *  - bit 19: edge flag 1
+        *  - bits 20..28: index 2
+        *  - bit 29: edge flag 2
+        *  - bit 31: null primitive (skip)
+        */
+       LLVMBuilderRef builder = ctx->builder;
+       LLVMValueRef tmp = LLVMBuildZExt(builder, prim->isnull, ctx->i32, "");
+       LLVMValueRef result = LLVMBuildShl(builder, tmp, LLVMConstInt(ctx->i32, 31, false), "");
+
+       for (unsigned i = 0; i < prim->num_vertices; ++i) {
+               tmp = LLVMBuildShl(builder, prim->index[i],
+                                  LLVMConstInt(ctx->i32, 10 * i, false), "");
+               result = LLVMBuildOr(builder, result, tmp, "");
+               tmp = LLVMBuildZExt(builder, prim->edgeflag[i], ctx->i32, "");
+               tmp = LLVMBuildShl(builder, tmp,
+                                  LLVMConstInt(ctx->i32, 10 * i + 9, false), "");
+               result = LLVMBuildOr(builder, result, tmp, "");
+       }
+       return result;
+}
+
+void ac_build_export_prim(struct ac_llvm_context *ctx,
+                         const struct ac_ngg_prim *prim)
+{
+       struct ac_export_args args;
+
+       if (prim->passthrough) {
+               args.out[0] = prim->passthrough;
+       } else {
+               args.out[0] = ac_pack_prim_export(ctx, prim);
+       }
+
+       args.out[0] = LLVMBuildBitCast(ctx->builder, args.out[0], ctx->f32, "");
+       args.out[1] = LLVMGetUndef(ctx->f32);
+       args.out[2] = LLVMGetUndef(ctx->f32);
+       args.out[3] = LLVMGetUndef(ctx->f32);
+
+       args.target = V_008DFC_SQ_EXP_PRIM;
+       args.enabled_channels = 1;
+       args.done = true;
+       args.valid_mask = false;
+       args.compr = false;
+
+       ac_build_export(ctx, &args);
+}
+
 static LLVMTypeRef
 arg_llvm_type(enum ac_arg_type type, unsigned size, struct ac_llvm_context *ctx)
 {
@@ -4787,3 +4892,68 @@ ac_build_main(const struct ac_shader_args *args,
        return main_function;
 }
 
+void ac_build_s_endpgm(struct ac_llvm_context *ctx)
+{
+       LLVMTypeRef calltype = LLVMFunctionType(ctx->voidt, NULL, 0, false);
+       LLVMValueRef code = LLVMConstInlineAsm(calltype, "s_endpgm", "", true, false);
+       LLVMBuildCall(ctx->builder, code, NULL, 0, "");
+}
+
+LLVMValueRef ac_prefix_bitcount(struct ac_llvm_context *ctx,
+                               LLVMValueRef mask, LLVMValueRef index)
+{
+       LLVMBuilderRef builder = ctx->builder;
+       LLVMTypeRef type = LLVMTypeOf(mask);
+
+       LLVMValueRef bit = LLVMBuildShl(builder, LLVMConstInt(type, 1, 0),
+                                       LLVMBuildZExt(builder, index, type, ""), "");
+       LLVMValueRef prefix_bits = LLVMBuildSub(builder, bit, LLVMConstInt(type, 1, 0), "");
+       LLVMValueRef prefix_mask = LLVMBuildAnd(builder, mask, prefix_bits, "");
+       return ac_build_bit_count(ctx, prefix_mask);
+}
+
+/* Compute the prefix sum of the "mask" bit array with 128 elements (bits). */
+LLVMValueRef ac_prefix_bitcount_2x64(struct ac_llvm_context *ctx,
+                                    LLVMValueRef mask[2], LLVMValueRef index)
+{
+       LLVMBuilderRef builder = ctx->builder;
+#if 0
+       /* Reference version using i128. */
+       LLVMValueRef input_mask =
+               LLVMBuildBitCast(builder, ac_build_gather_values(ctx, mask, 2), ctx->i128, "");
+
+       return ac_prefix_bitcount(ctx, input_mask, index);
+#else
+       /* Optimized version using 2 64-bit masks. */
+       LLVMValueRef is_hi, is_0, c64, c128, all_bits;
+       LLVMValueRef prefix_mask[2], shift[2], mask_bcnt0, prefix_bcnt[2];
+
+       /* Compute the 128-bit prefix mask. */
+       c64 = LLVMConstInt(ctx->i32, 64, 0);
+       c128 = LLVMConstInt(ctx->i32, 128, 0);
+       all_bits = LLVMConstInt(ctx->i64, UINT64_MAX, 0);
+       /* The first index that can have non-zero high bits in the prefix mask is 65. */
+       is_hi = LLVMBuildICmp(builder, LLVMIntUGT, index, c64, "");
+       is_0 = LLVMBuildICmp(builder, LLVMIntEQ, index, ctx->i32_0, "");
+       mask_bcnt0 = ac_build_bit_count(ctx, mask[0]);
+
+       for (unsigned i = 0; i < 2; i++) {
+               shift[i] = LLVMBuildSub(builder, i ? c128 : c64, index, "");
+               /* For i==0, index==0, the right shift by 64 doesn't give the desired result,
+                * so we handle it by the is_0 select.
+                * For i==1, index==64, same story, so we handle it by the last is_hi select.
+                * For i==0, index==64, we shift by 0, which is what we want.
+                */
+               prefix_mask[i] = LLVMBuildLShr(builder, all_bits,
+                                       LLVMBuildZExt(builder, shift[i], ctx->i64, ""), "");
+               prefix_mask[i] = LLVMBuildAnd(builder, mask[i], prefix_mask[i], "");
+               prefix_bcnt[i] = ac_build_bit_count(ctx, prefix_mask[i]);
+       }
+
+       prefix_bcnt[0] = LLVMBuildSelect(builder, is_0, ctx->i32_0, prefix_bcnt[0], "");
+       prefix_bcnt[0] = LLVMBuildSelect(builder, is_hi, mask_bcnt0, prefix_bcnt[0], "");
+       prefix_bcnt[1] = LLVMBuildSelect(builder, is_hi, prefix_bcnt[1], ctx->i32_0, "");
+
+       return LLVMBuildAdd(builder, prefix_bcnt[0], prefix_bcnt[1], "");
+#endif
+}