amd/common: add i1 special case to ac_build_{inclusive,exclusive}_scan
authorNicolai Hähnle <nicolai.haehnle@amd.com>
Thu, 20 Sep 2018 17:09:50 +0000 (19:09 +0200)
committerNicolai Hähnle <nicolai.haehnle@amd.com>
Wed, 19 Dec 2018 11:01:19 +0000 (12:01 +0100)
Allow for a unified but efficient treatment of adding a bitmask over a
wave or an entire threadgroup.

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
src/amd/common/ac_llvm_build.c

index cbd715cd378798492ec837e8bc549f3b84d70423..55ead314d680121197c9f79ac511e5105eba5c50 100644 (file)
@@ -3158,8 +3158,19 @@ ac_build_scan(struct ac_llvm_context *ctx, nir_op op, LLVMValueRef src, LLVMValu
 LLVMValueRef
 ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op)
 {
-       ac_build_optimization_barrier(ctx, &src);
        LLVMValueRef result;
+
+       if (LLVMTypeOf(src) == ctx->i1 && op == nir_op_iadd) {
+               LLVMBuilderRef builder = ctx->builder;
+               src = LLVMBuildZExt(builder, src, ctx->i32, "");
+               result = ac_build_ballot(ctx, src);
+               result = ac_build_mbcnt(ctx, result);
+               result = LLVMBuildAdd(builder, result, src, "");
+               return result;
+       }
+
+       ac_build_optimization_barrier(ctx, &src);
+
        LLVMValueRef identity =
                get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
        result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
@@ -3172,8 +3183,18 @@ ac_build_inclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op
 LLVMValueRef
 ac_build_exclusive_scan(struct ac_llvm_context *ctx, LLVMValueRef src, nir_op op)
 {
-       ac_build_optimization_barrier(ctx, &src);
        LLVMValueRef result;
+
+       if (LLVMTypeOf(src) == ctx->i1 && op == nir_op_iadd) {
+               LLVMBuilderRef builder = ctx->builder;
+               src = LLVMBuildZExt(builder, src, ctx->i32, "");
+               result = ac_build_ballot(ctx, src);
+               result = ac_build_mbcnt(ctx, result);
+               return result;
+       }
+
+       ac_build_optimization_barrier(ctx, &src);
+
        LLVMValueRef identity =
                get_reduction_identity(ctx, op, ac_get_type_size(LLVMTypeOf(src)));
        result = LLVMBuildBitCast(ctx->builder, ac_build_set_inactive(ctx, src, identity),
@@ -3352,6 +3373,8 @@ ac_build_wg_scan_top(struct ac_llvm_context *ctx, struct ac_wg_scan *ws)
 {
        if (ws->enable_exclusive) {
                ws->extra = ac_build_exclusive_scan(ctx, ws->src, ws->op);
+               if (LLVMTypeOf(ws->src) == ctx->i1 && ws->op == nir_op_iadd)
+                       ws->src = LLVMBuildZExt(ctx->builder, ws->src, ctx->i32, "");
                ws->src = ac_build_alu_op(ctx, ws->extra, ws->src, ws->op);
        } else {
                ws->src = ac_build_inclusive_scan(ctx, ws->src, ws->op);