From 300876a9a7f849a2b165360e19ec1708a342b68c Mon Sep 17 00:00:00 2001 From: =?utf8?q?Nicolai=20H=C3=A4hnle?= Date: Wed, 23 May 2018 22:09:27 +0200 Subject: [PATCH] amd/common: scan/reduce across waves of a workgroup MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Order-aware scan/reduce can trade-off LDS traffic for external atomics memory traffic in producer/consumer compute shaders. Reviewed-by: Marek Olšák --- src/amd/common/ac_llvm_build.c | 195 ++++++++++++++++++++++++++++++++- src/amd/common/ac_llvm_build.h | 36 ++++++ 2 files changed, 227 insertions(+), 4 deletions(-) diff --git a/src/amd/common/ac_llvm_build.c b/src/amd/common/ac_llvm_build.c index 3b958d2a318..cbd715cd378 100644 --- a/src/amd/common/ac_llvm_build.c +++ b/src/amd/common/ac_llvm_build.c @@ -3112,24 +3112,44 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs, } } -/* TODO: add inclusive and excluse scan functions for SI chip class. */ +/** + * \param maxprefix specifies that the result only needs to be correct for a + * prefix of this many threads + * + * TODO: add inclusive and excluse scan functions for SI chip class. + */ static LLVMValueRef -ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity) +ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity, + unsigned maxprefix) { LLVMValueRef result, tmp; result = src; + if (maxprefix <= 1) + return result; tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false); result = ac_build_alu_op(ctx, result, tmp, op); + if (maxprefix <= 2) + return result; tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(2), 0xf, 0xf, false); result = ac_build_alu_op(ctx, result, tmp, op); + if (maxprefix <= 3) + return result; tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(3), 0xf, 0xf, false); result = ac_build_alu_op(ctx, result, tmp, op); + if (maxprefix <= 4) + return result; tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(4), 0xf, 0xe, false); result = ac_build_alu_op(ctx, result, tmp, op); + if (maxprefix <= 8) + return result; tmp = ac_build_dpp(ctx, identity, result, dpp_row_sr(8), 0xf, 0xc, false); result = ac_build_alu_op(ctx, result, tmp, op); + if (maxprefix <= 16) + return result; tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false); result = ac_build_alu_op(ctx, result, tmp, op); + if (maxprefix <= 32) + return result; tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false); result = ac_build_alu_op(ctx, result, tmp, op); return result; @@ -3144,7 +3164,7 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src))); result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity), LLVMTypeOf(identity), ""); - result = ac_build_scan(ctx, op, result, identity); + result = ac_build_scan(ctx, op, result, identity, 64); return ac_build_wwm(ctx, result); } @@ -3159,7 +3179,7 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity), LLVMTypeOf(identity), ""); result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 0xf, false); - result = ac_build_scan(ctx, op, result, identity); + result = ac_build_scan(ctx, op, result, identity, 64); return ac_build_wwm(ctx, result); } @@ -3217,6 +3237,173 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign } } +/** + * "Top half" of a scan that reduces per-wave values across an entire + * workgroup. + * + * The source value must be present in the highest lane of the wave, and the + * highest lane must be live. + */ +void +ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) +{ + if (ws->maxwaves <= 1) + return; + + const LLVMValueRef i32_63 = LLVMConstInt(ctx->i32, 63, false); + LLVMBuilderRef builder = ctx->builder; + LLVMValueRef tid = ac_get_thread_id(ctx); + LLVMValueRef tmp; + + tmp = LLVMBuildICmp(builder, LLVMIntEQ, tid, i32_63, ""); + ac_build_ifcc(ctx, tmp, 1000); + LLVMBuildStore(builder, ws->src, LLVMBuildGEP(builder, ws->scratch, &ws->waveidx, 1, "")); + ac_build_endif(ctx, 1000); +} + +/** + * "Bottom half" of a scan that reduces per-wave values across an entire + * workgroup. + * + * The caller must place a barrier between the top and bottom halves. + */ +void +ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) +{ + const LLVMTypeRef type = LLVMTypeOf(ws->src); + const LLVMValueRef identity = + get_reduction_identity(ctx, ws->op, ac_get_type_size(type)); + + if (ws->maxwaves <= 1) { + ws->result_reduce = ws->src; + ws->result_inclusive = ws->src; + ws->result_exclusive = identity; + return; + } + assert(ws->maxwaves <= 32); + + LLVMBuilderRef builder = ctx->builder; + LLVMValueRef tid = ac_get_thread_id(ctx); + LLVMBasicBlockRef bbs[2]; + LLVMValueRef phivalues_scan[2]; + LLVMValueRef tmp, tmp2; + + bbs[0] = LLVMGetInsertBlock(builder); + phivalues_scan[0] = LLVMGetUndef(type); + + if (ws->enable_reduce) + tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->numwaves, ""); + else if (ws->enable_inclusive) + tmp = LLVMBuildICmp(builder, LLVMIntULE, tid, ws->waveidx, ""); + else + tmp = LLVMBuildICmp(builder, LLVMIntULT, tid, ws->waveidx, ""); + ac_build_ifcc(ctx, tmp, 1001); + { + tmp = LLVMBuildLoad(builder, LLVMBuildGEP(builder, ws->scratch, &tid, 1, ""), ""); + + ac_build_optimization_barrier(ctx, &tmp); + + bbs[1] = LLVMGetInsertBlock(builder); + phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves); + } + ac_build_endif(ctx, 1001); + + const LLVMValueRef scan = ac_build_phi(ctx, type, 2, phivalues_scan, bbs); + + if (ws->enable_reduce) { + tmp = LLVMBuildSub(builder, ws->numwaves, ctx->i32_1, ""); + ws->result_reduce = ac_build_readlane(ctx, scan, tmp); + } + if (ws->enable_inclusive) + ws->result_inclusive = ac_build_readlane(ctx, scan, ws->waveidx); + if (ws->enable_exclusive) { + tmp = LLVMBuildSub(builder, ws->waveidx, ctx->i32_1, ""); + tmp = ac_build_readlane(ctx, scan, tmp); + tmp2 = LLVMBuildICmp(builder, LLVMIntEQ, ws->waveidx, ctx->i32_0, ""); + ws->result_exclusive = LLVMBuildSelect(builder, tmp2, identity, tmp, ""); + } +} + +/** + * Inclusive scan of a per-wave value across an entire workgroup. + * + * This implies an s_barrier instruction. + * + * Unlike ac_build_inclusive_scan, the caller \em must ensure that all threads + * of the workgroup are live. (This requirement cannot easily be relaxed in a + * useful manner because of the barrier in the algorithm.) + */ +void +ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) +{ + ac_build_wg_wavescan_top(ctx, ws); + ac_build_s_barrier(ctx); + ac_build_wg_wavescan_bottom(ctx, ws); +} + +/** + * "Top half" of a scan that reduces per-thread values across an entire + * workgroup. + * + * All lanes must be active when this code runs. + */ +void +ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) +{ + if (ws->enable_exclusive) { + ws->extra = ac_build_exclusive_scan(ctx, ws->src, ws->op); + ws->src = ac_build_alu_op(ctx, ws->extra, ws->src, ws->op); + } else { + ws->src = ac_build_inclusive_scan(ctx, ws->src, ws->op); + } + + bool enable_inclusive = ws->enable_inclusive; + bool enable_exclusive = ws->enable_exclusive; + ws->enable_inclusive = false; + ws->enable_exclusive = ws->enable_exclusive || enable_inclusive; + ac_build_wg_wavescan_top(ctx, ws); + ws->enable_inclusive = enable_inclusive; + ws->enable_exclusive = enable_exclusive; +} + +/** + * "Bottom half" of a scan that reduces per-thread values across an entire + * workgroup. + * + * The caller must place a barrier between the top and bottom halves. + */ +void +ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) +{ + bool enable_inclusive = ws->enable_inclusive; + bool enable_exclusive = ws->enable_exclusive; + ws->enable_inclusive = false; + ws->enable_exclusive = ws->enable_exclusive || enable_inclusive; + ac_build_wg_wavescan_bottom(ctx, ws); + ws->enable_inclusive = enable_inclusive; + ws->enable_exclusive = enable_exclusive; + + /* ws->result_reduce is already the correct value */ + if (ws->enable_inclusive) + ws->result_inclusive = ac_build_alu_op(ctx, ws->result_exclusive, ws->src, ws->op); + if (ws->enable_exclusive) + ws->result_exclusive = ac_build_alu_op(ctx, ws->result_exclusive, ws->extra, ws->op); +} + +/** + * A scan that reduces per-thread values across an entire workgroup. + * + * The caller must ensure that all lanes are active when this code runs + * (WWM is insufficient!), because there is an implied barrier. + */ +void +ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws) +{ + ac_build_wg_scan_top(ctx, ws); + ac_build_s_barrier(ctx); + ac_build_wg_scan_bottom(ctx, ws); +} + LLVMValueRef ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src, unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3) diff --git a/src/amd/common/ac_llvm_build.h b/src/amd/common/ac_llvm_build.h index ab063546338..e47893bbbe6 100644 --- a/src/amd/common/ac_llvm_build.h +++ b/src/amd/common/ac_llvm_build.h @@ -524,6 +524,42 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op LLVMValueRef ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsigned cluster_size); +/** + * Common arguments for a scan/reduce operation that accumulates per-wave + * values across an entire workgroup, while respecting the order of waves. + */ +struct ac_wg_scan { + bool enable_reduce; + bool enable_exclusive; + bool enable_inclusive; + nir_op op; + LLVMValueRef src; /* clobbered! */ + LLVMValueRef result_reduce; + LLVMValueRef result_exclusive; + LLVMValueRef result_inclusive; + LLVMValueRef extra; + LLVMValueRef waveidx; + LLVMValueRef numwaves; /* only needed for "reduce" operations */ + + /* T addrspace(LDS) pointer to the same type as value, at least maxwaves entries */ + LLVMValueRef scratch; + unsigned maxwaves; +}; + +void +ac_build_wg_wavescan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws); +void +ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws); +void +ac_build_wg_wavescan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws); + +void +ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws); +void +ac_build_wg_scan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws); +void +ac_build_wg_scan(struct ac_llvm_context *ctx, struct ac_wg_scan *ws); + LLVMValueRef ac_build_quad_swizzle(struct ac_llvm_context *ctx, LLVMValueRef src, unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3); -- 2.30.2