From e09426ad6bea4156a02958c59804263dae8dbf82 Mon Sep 17 00:00:00 2001 From: Bas Nieuwenhuizen Date: Wed, 27 Nov 2019 23:33:59 +0100 Subject: [PATCH] amd/llvm: Refactor ac_build_scan. Split out the logic for exclusive scans into a separate function that makes clear what it does instead of having this opaque 60 line if. Reviewed-by: Samuel Pitoiset --- src/amd/llvm/ac_llvm_build.c | 91 ++++++++++++++++++++---------------- 1 file changed, 51 insertions(+), 40 deletions(-) diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c index cf6eda30e2c..47c27893fe5 100644 --- a/src/amd/llvm/ac_llvm_build.c +++ b/src/amd/llvm/ac_llvm_build.c @@ -4045,18 +4045,17 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs, } /** + * \param src The value to shift. + * \param identity The value to use for the first lane. * \param maxprefix specifies that the result only needs to be correct for a * prefix of this many threads + * \return src, shifted 1 lane up, and identity shifted into lane 0. 
*/ static LLVMValueRef -ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity, - unsigned maxprefix, bool inclusive) +ac_wavefront_shift_right_1(struct ac_llvm_context *ctx, LLVMValueRef src, + LLVMValueRef identity, unsigned maxprefix) { - LLVMValueRef result, tmp; - - if (inclusive) { - result = src; - } else if (ctx->chip_class >= GFX10) { + if (ctx->chip_class >= GFX10) { /* wavefront shift_right by 1 on GFX10 (emulate dpp_wf_sr1) */ LLVMValueRef active, tmp1, tmp2; LLVMValueRef tid = ac_get_thread_id(ctx); @@ -4079,45 +4078,57 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, false), ""), LLVMConstInt(ctx->i32, 0x10, false), ""), ""); - src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); } else if (maxprefix > 16) { active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 16, false), ""); - src = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + return LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); } - - result = src; } else if (ctx->chip_class >= GFX8) { - src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false); - result = src; - } else { - /* wavefront shift_right by 1 on SI/CI */ - LLVMValueRef active, tmp1, tmp2; - LLVMValueRef tid = ac_get_thread_id(ctx); - tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2)); - tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00)); - active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, - LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""), - LLVMConstInt(ctx->i32, 0x4, 0), ""); - tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); - tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00)); - active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, - LLVMBuildAnd(ctx->builder, tid, 
LLVMConstInt(ctx->i32, 0xf, 0), ""), - LLVMConstInt(ctx->i32, 0x8, 0), ""); - tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); - tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00)); - active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, - LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""), - LLVMConstInt(ctx->i32, 0x10, 0), ""); - tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); - tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0)); - active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), ""); - tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); - active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), ""); - src = LLVMBuildSelect(ctx->builder, active, identity, tmp1, ""); - result = src; - } + return ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false); + } + + /* wavefront shift_right by 1 on SI/CI */ + LLVMValueRef active, tmp1, tmp2; + LLVMValueRef tid = ac_get_thread_id(ctx); + tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2)); + tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""), + LLVMConstInt(ctx->i32, 0x4, 0), ""); + tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""), + LLVMConstInt(ctx->i32, 0x8, 0), ""); + tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""), + LLVMConstInt(ctx->i32, 0x10, 0), ""); + tmp1 = 
LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), ""); + tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), ""); + return LLVMBuildSelect(ctx->builder, active, identity, tmp1, ""); +} + +/** + * \param maxprefix specifies that the result only needs to be correct for a + * prefix of this many threads + */ +static LLVMValueRef +ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity, + unsigned maxprefix, bool inclusive) +{ + LLVMValueRef result, tmp; + + if (!inclusive) + src = ac_wavefront_shift_right_1(ctx, src, identity, maxprefix); + + result = src; if (ctx->chip_class <= GFX7) { assert(maxprefix == 64); -- 2.30.2