From 227c29a80dec4c3209d60a6e8bccf08864ec0e7f Mon Sep 17 00:00:00 2001
From: =?utf8?q?Nicolai=20H=C3=A4hnle?=
Date: Wed, 23 May 2018 22:08:22 +0200
Subject: [PATCH] amd/common/gfx10: implement scan & reduce operations

Add ac_build_permlane16 as a wrapper around the llvm.amdgcn.permlane16 and
llvm.amdgcn.permlanex16 intrinsics and use it, together with readlane, in
ac_build_scan and ac_build_reduce on GFX10, where dpp_row_bcast{15,31} are
not supported.

ac_build_scan gains an "inclusive" parameter so that the shift needed for
exclusive scans is handled inside the function instead of in
ac_build_exclusive_scan. Before GFX10 the shift is applied to src itself,
because the DPP steps that follow read src.

Acked-by: Bas Nieuwenhuizen
---
Changes relative to the original version of this patch:
 - pre-GFX10 exclusive scan: apply the dpp_wf_sr1 shift to src, not only
   to result
 - add comments to the new permlane16 helpers
 - hunk offsets and the diffstat are updated accordingly; the blob ids in
   the "index" line below still refer to the original version

 src/amd/common/ac_llvm_build.c | 122 ++++++++++++++++++++++++++++++---
 1 file changed, 114 insertions(+), 8 deletions(-)

diff --git a/src/amd/common/ac_llvm_build.c b/src/amd/common/ac_llvm_build.c
index 92c878acdcf..f0c96442b2e 100644
--- a/src/amd/common/ac_llvm_build.c
+++ b/src/amd/common/ac_llvm_build.c
@@ -3874,6 +3874,66 @@ ac_build_dpp(struct ac_llvm_context *ctx, LLVMValueRef old, LLVMValueRef src,
 	return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
 }
 
+/* Emit one 32-bit llvm.amdgcn.permlane16 call, or llvm.amdgcn.permlanex16
+ * when exchange_rows is set. The low and high 32 bits of sel are passed as
+ * two separate i32 operands.
+ */
+static LLVMValueRef
+_ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
+		     bool exchange_rows, bool bound_ctrl)
+{
+	LLVMValueRef args[6] = {
+		src,
+		src,
+		LLVMConstInt(ctx->i32, sel, false),
+		LLVMConstInt(ctx->i32, sel >> 32, false),
+		ctx->i1true, /* fi */
+		bound_ctrl ? ctx->i1true : ctx->i1false,
+	};
+	return ac_build_intrinsic(ctx, exchange_rows ? "llvm.amdgcn.permlanex16"
+						     : "llvm.amdgcn.permlane16",
+				  ctx->i32, args, 6,
+				  AC_FUNC_ATTR_READNONE | AC_FUNC_ATTR_CONVERGENT);
+}
+
+/* Wrapper around _ac_build_permlane16 for values whose integer width is a
+ * multiple of 32 bits: wider values are processed one 32-bit component at a
+ * time. The result is bitcast back to the type of src.
+ */
+static LLVMValueRef
+ac_build_permlane16(struct ac_llvm_context *ctx, LLVMValueRef src, uint64_t sel,
+		    bool exchange_rows, bool bound_ctrl)
+{
+	LLVMTypeRef src_type = LLVMTypeOf(src);
+	src = ac_to_integer(ctx, src);
+	unsigned bits = LLVMGetIntTypeWidth(LLVMTypeOf(src));
+	LLVMValueRef ret;
+	if (bits == 32) {
+		ret = _ac_build_permlane16(ctx, src, sel, exchange_rows,
+					   bound_ctrl);
+	} else {
+		assert(bits % 32 == 0);
+		LLVMTypeRef vec_type = LLVMVectorType(ctx->i32, bits / 32);
+		LLVMValueRef src_vector =
+			LLVMBuildBitCast(ctx->builder, src, vec_type, "");
+		ret = LLVMGetUndef(vec_type);
+		for (unsigned i = 0; i < bits / 32; i++) {
+			src = LLVMBuildExtractElement(ctx->builder, src_vector,
+						      LLVMConstInt(ctx->i32, i,
+								   0), "");
+			LLVMValueRef ret_comp =
+				_ac_build_permlane16(ctx, src, sel,
+						     exchange_rows,
+						     bound_ctrl);
+			ret = LLVMBuildInsertElement(ctx->builder, ret,
+						     ret_comp,
+						     LLVMConstInt(ctx->i32, i,
+								  0), "");
+		}
+	}
+	return LLVMBuildBitCast(ctx->builder, ret, src_type, "");
+}
+
 static inline unsigned
 ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask)
 {
@@ -4037,10 +4097,20 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs,
  */
 static LLVMValueRef
 ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity,
-	      unsigned maxprefix)
+	      unsigned maxprefix, bool inclusive)
 {
 	LLVMValueRef result, tmp;
-	result = src;
+
+	if (ctx->chip_class >= GFX10) {
+		result = inclusive ? src : identity;
+	} else {
+		/* The DPP steps below read src, so the exclusive shift
+		 * has to be applied to src itself.
+		 */
+		if (!inclusive)
+			src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false);
+		result = src;
+	}
 	if (maxprefix <= 1)
 		return result;
 	tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false);
@@ -4063,6 +4133,38 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
 	result = ac_build_alu_op(ctx, result, tmp, op);
 	if (maxprefix <= 16)
 		return result;
+
+	if (ctx->chip_class >= GFX10) {
+		/* dpp_row_bcast{15,31} are not supported on gfx10. */
+		LLVMBuilderRef builder = ctx->builder;
+		LLVMValueRef tid = ac_get_thread_id(ctx);
+		LLVMValueRef cc;
+		/* TODO-GFX10: Can we get better code-gen by putting this into
+		 * a branch so that LLVM generates EXEC mask manipulations? */
+		if (inclusive)
+			tmp = result;
+		else
+			tmp = ac_build_alu_op(ctx, result, src, op);
+		tmp = ac_build_permlane16(ctx, tmp, ~(uint64_t)0, true, false);
+		tmp = ac_build_alu_op(ctx, result, tmp, op);
+		cc = LLVMBuildAnd(builder, tid, LLVMConstInt(ctx->i32, 16, false), "");
+		cc = LLVMBuildICmp(builder, LLVMIntNE, cc, ctx->i32_0, "");
+		result = LLVMBuildSelect(builder, cc, tmp, result, "");
+		if (maxprefix <= 32)
+			return result;
+
+		if (inclusive)
+			tmp = result;
+		else
+			tmp = ac_build_alu_op(ctx, result, src, op);
+		tmp = ac_build_readlane(ctx, tmp, LLVMConstInt(ctx->i32, 31, false));
+		tmp = ac_build_alu_op(ctx, result, tmp, op);
+		cc = LLVMBuildICmp(builder, LLVMIntUGE, tid,
+				   LLVMConstInt(ctx->i32, 32, false), "");
+		result = LLVMBuildSelect(builder, cc, tmp, result, "");
+		return result;
+	}
+
 	tmp = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
 	result = ac_build_alu_op(ctx, result, tmp, op);
 	if (maxprefix <= 32)
@@ -4092,7 +4194,7 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
 		get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
 	result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
 				  LLVMTypeOf(identity), "");
-	result = ac_build_scan(ctx, op, result, identity, 64);
+	result = ac_build_scan(ctx, op, result, identity, 64, true);
 
 	return ac_build_wwm(ctx, result);
 }
@@ -4116,8 +4218,7 @@ ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
 		get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
 	result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
 				  LLVMTypeOf(identity), "");
-	result = ac_build_dpp(ctx, identity, result, dpp_wf_sr1, 0xf, 0xf, false);
-	result = ac_build_scan(ctx, op, result, identity, 64);
+	result = ac_build_scan(ctx, op, result, identity, 64, false);
 
 	return ac_build_wwm(ctx, result);
 }
@@ -4155,7 +4256,9 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
 	result = ac_build_alu_op(ctx, result, swap, op);
 	if (cluster_size == 16) return ac_build_wwm(ctx, result);
 
-	if (ctx->chip_class >= GFX8 && cluster_size != 32)
+	if (ctx->chip_class >= GFX10)
+		swap = ac_build_permlane16(ctx, result, 0, true, false);
+	else if (ctx->chip_class >= GFX8 && cluster_size != 32)
 		swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast15, 0xa, 0xf, false);
 	else
 		swap = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x1f, 0, 0x10));
@@ -4163,7 +4266,10 @@ ac_build_reduce(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op, unsign
 	if (cluster_size == 32) return ac_build_wwm(ctx, result);
 
 	if (ctx->chip_class >= GFX8) {
-		swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
+		if (ctx->chip_class >= GFX10)
+			swap = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, false));
+		else
+			swap = ac_build_dpp(ctx, identity, result, dpp_row_bcast31, 0xc, 0xf, false);
 		result = ac_build_alu_op(ctx, result, swap, op);
 		result = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 63, 0));
 		return ac_build_wwm(ctx, result);
@@ -4242,7 +4348,7 @@ ac_build_wg_wavescan_bottom(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
 		ac_build_optimization_barrier(ctx, &tmp);
 
 		bbs[1] = LLVMGetInsertBlock(builder);
-		phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves);
+		phivalues_scan[1] = ac_build_scan(ctx, ws->op, tmp, identity, ws->maxwaves, true);
 	}
 	ac_build_endif(ctx, 1001);
 
-- 
2.30.2