amd/llvm: Refactor ac_build_scan.
[mesa.git] / src / amd / llvm / ac_llvm_build.c
index 54513d79922baf9cb55cb18aade6b235a16dfb32..47c27893fe50d2af2d98cd9498eeaf16cdcbe4d0 100644 (file)
@@ -3768,6 +3768,11 @@ static LLVMValueRef
 _ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
                     bool exchange_rows, bool bound_ctrl)
 {
+       LLVMTypeRef type = LLVMTypeOf(src);
+       LLVMValueRef result;
+
+       src = LLVMBuildZExt(ctx->builder, src, ctx->i32, "");
+
        LLVMValueRef args[6] = {
                src,
                src,
@@ -3776,10 +3781,13 @@ _ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel
                ctx->i1true, /* fi */
                bound_ctrl ? ctx->i1true : ctx->i1false,
        };
-       return ac_build_intrinsic(ctx, exchange_rows ? "llvm.amdgcn.permlanex16"
-                                                    : "llvm.amdgcn.permlane16",
-                                 ctx->i32, args, 6,
-                                 AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
+
+       result = ac_build_intrinsic(ctx, exchange_rows ? "llvm.amdgcn.permlanex16"
+                                                      : "llvm.amdgcn.permlane16",
+                                   ctx->i32, args, 6,
+                                   AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
+
+       return LLVMBuildTrunc(ctx->builder, result, type, "");
 }
 
 static LLVMValueRef
@@ -3790,10 +3798,7 @@ ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
        src = ac_to_integer(ctx, src);
        unsigned bits = LLVMGetIntTypeWidth(LLVMTypeOf(src));
        LLVMValueRef ret;
-       if (bits == 32) {
-               ret = _ac_build_permlane16(ctx, src, sel, exchange_rows,
-                                          bound_ctrl);
-       } else {
+       if (bits > 32) {
                assert(bits % 32 == 0);
                LLVMTypeRef vec_type = LLVMVectorType(ctx->i32, bits / 32);
                LLVMValueRef src_vector =
@@ -3812,6 +3817,9 @@ ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
                                                     LLVMConstInt(ctx->i32, i,
                                                                  0), "");
                }
+       } else {
+               ret = _ac_build_permlane16(ctx, src, sel, exchange_rows,
+                                          bound_ctrl);
        }
        return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
 }
@@ -4037,18 +4045,17 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
 }
 
 /**
+ * \param src The value to shift.
+ * \param identity The value to use for the first lane.
  * \param maxprefix specifies that the result only needs to be correct for a
  *     prefix of this many threads
+ * \return src, shifted 1 lane up, and identity shifted into lane 0.
  */
 static LLVMValueRef
-ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
-             unsigned maxprefix, bool inclusive)
+ac_wavefront_shift_right_1(struct ac_llvm_context *ctx, LLVMValueRef src,
+                           LLVMValueRef identity, unsigned maxprefix)
 {
-       LLVMValueRef result, tmp;
-
-       if (inclusive) {
-               result = src;
-       } else if (ctx->chip_class >= GFX10) {
+       if (ctx->chip_class >= GFX10) {
                /* wavefront shift_right by 1 on GFX10 (emulate dpp_wf_sr1) */
                LLVMValueRef active, tmp1, tmp2;
                LLVMValueRef tid = ac_get_thread_id(ctx);
@@ -4071,45 +4078,57 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
                                                           LLVMBuildAnd(ctx->builder, tid,
                                                                        LLVMConstInt(ctx->i32, 0x1f, false), ""),
                                                           LLVMConstInt(ctx->i32, 0x10, false), ""), "");
-                       src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+                       return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
                } else if (maxprefix > 16) {
                        active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid,
                                               LLVMConstInt(ctx->i32, 16, false), "");
 
-                       src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+                       return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
                }
-
-               result = src;
        } else if (ctx->chip_class >= GFX8) {
-               src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
-               result = src;
-       } else {
-               /* wavefront shift_right by 1 on SI/CI */
-               LLVMValueRef active, tmp1, tmp2;
-               LLVMValueRef tid = ac_get_thread_id(ctx);
-               tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
-               tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
-                                      LLVMConstInt(ctx->i32, 0x4, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
-                                      LLVMConstInt(ctx->i32, 0x8, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
-                                      LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
-                                      LLVMConstInt(ctx->i32, 0x10, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
-               tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
-               active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
-               src = LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
-               result = src;
-        }
+               return ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
+       }
+
+       /* wavefront shift_right by 1 on SI/CI */
+       LLVMValueRef active, tmp1, tmp2;
+       LLVMValueRef tid = ac_get_thread_id(ctx);
+       tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2));
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x4, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x8, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ,
+                              LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""),
+                              LLVMConstInt(ctx->i32, 0x10, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0));
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), "");
+       tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, "");
+       active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), "");
+       return LLVMBuildSelect(ctx->builder, active, identity, tmp1, "");
+}
+
+/**
+ * \param maxprefix specifies that the result only needs to be correct for a
+ *     prefix of this many threads
+ */
+static LLVMValueRef
+ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
+             unsigned maxprefix, bool inclusive)
+{
+       LLVMValueRef result, tmp;
+
+       if (!inclusive)
+               src = ac_wavefront_shift_right_1(ctx, src, identity, maxprefix);
+
+       result = src;
 
        if (ctx->chip_class <= GFX7) {
                assert(maxprefix == 64);