From 0cbcfc071e32fd5fc9950a5660adb7dafb7aaef0 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Daniel=20Sch=C3=BCrmann?= Date: Wed, 20 Nov 2019 12:40:07 +0100 Subject: [PATCH] amd/llvm: Add Subgroup Scan functions for SI The idea of this implementation is taken from the ROCm Device Libs: https://github.com/RadeonOpenCompute/ROCm-Device-Libs/blob/master/ockl/src/wfredscan.cl Reviewed-by: Samuel Pitoiset --- src/amd/llvm/ac_llvm_build.c | 81 +++++++++++++++++++++++++++++++++--- 1 file changed, 75 insertions(+), 6 deletions(-) diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c index d418ee5ab71..12ee06c2678 100644 --- a/src/amd/llvm/ac_llvm_build.c +++ b/src/amd/llvm/ac_llvm_build.c @@ -4042,8 +4042,6 @@ ac_build_alu_op(struct ac_llvm_context *ctx, LLVMValueRef lhs, LLVMValueRef rhs, /** * \param maxprefix specifies that the result only needs to be correct for a * prefix of this many threads - * - * TODO: add inclusive and excluse scan functions for GFX6. */ static LLVMValueRef ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValueRef identity, @@ -4051,13 +4049,84 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu { LLVMValueRef result, tmp; - if (ctx->chip_class >= GFX10) { - result = inclusive ? src : identity; + if (inclusive) { + result = src; + } else if (ctx->chip_class >= GFX10) { + result = identity; + } else if (ctx->chip_class >= GFX8) { + src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false); + result = src; } else { - if (!inclusive) - src = ac_build_dpp(ctx, identity, src, dpp_wf_sr1, 0xf, 0xf, false); + /* wavefront shift_right by 1 on SI/CI */ + LLVMValueRef active, tmp1, tmp2; + LLVMValueRef tid = ac_get_thread_id(ctx); + tmp1 = ac_build_ds_swizzle(ctx, src, (1 << 15) | dpp_quad_perm(0, 0, 1, 2)); + tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x18, 0x03, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x7, 0), ""), + LLVMConstInt(ctx->i32, 0x4, 0), ""); + tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x10, 0x07, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0xf, 0), ""), + LLVMConstInt(ctx->i32, 0x8, 0), ""); + tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + tmp2 = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x00, 0x0f, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 0x1f, 0), ""), + LLVMConstInt(ctx->i32, 0x10, 0), ""); + tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + tmp2 = ac_build_readlane(ctx, src, LLVMConstInt(ctx->i32, 31, 0)); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 32, 0), ""); + tmp1 = LLVMBuildSelect(ctx->builder, active, tmp2, tmp1, ""); + active = LLVMBuildICmp(ctx->builder, LLVMIntEQ, tid, LLVMConstInt(ctx->i32, 0, 0), ""); + src = LLVMBuildSelect(ctx->builder, active, identity, tmp1, ""); result = src; + } + + if (ctx->chip_class <= GFX7) { + assert(maxprefix == 64); + LLVMValueRef tid = ac_get_thread_id(ctx); + LLVMValueRef active; + tmp = ac_build_ds_swizzle(ctx, src, ds_pattern_bitmode(0x1e, 0x00, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntNE, + LLVMBuildAnd(ctx->builder, tid, ctx->i32_1, ""), + ctx->i32_0, ""); + tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, ""); + result = ac_build_alu_op(ctx, result, tmp, op); + tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x1c, 0x01, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntNE, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 2, 0), ""), + ctx->i32_0, ""); + tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, ""); + result = ac_build_alu_op(ctx, result, tmp, op); + tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x18, 0x03, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntNE, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 4, 0), ""), + ctx->i32_0, ""); + tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, ""); + result = ac_build_alu_op(ctx, result, tmp, op); + tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x10, 0x07, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntNE, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 8, 0), ""), + ctx->i32_0, ""); + tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, ""); + result = ac_build_alu_op(ctx, result, tmp, op); + tmp = ac_build_ds_swizzle(ctx, result, ds_pattern_bitmode(0x00, 0x0f, 0x00)); + active = LLVMBuildICmp(ctx->builder, LLVMIntNE, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 16, 0), ""), + ctx->i32_0, ""); + tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, ""); + result = ac_build_alu_op(ctx, result, tmp, op); + tmp = ac_build_readlane(ctx, result, LLVMConstInt(ctx->i32, 31, 0)); + active = LLVMBuildICmp(ctx->builder, LLVMIntNE, + LLVMBuildAnd(ctx->builder, tid, LLVMConstInt(ctx->i32, 32, 0), ""), + ctx->i32_0, ""); + tmp = LLVMBuildSelect(ctx->builder, active, tmp, identity, ""); + result = ac_build_alu_op(ctx, result, tmp, op); + return result; } + if (maxprefix <= 1) return result; tmp = ac_build_dpp(ctx, identity, src, dpp_row_sr(1), 0xf, 0xf, false); -- 2.30.2