ac: generate FMA for inexact instructions for radeonsi
authorMarek Olšák <marek.olsak@amd.com>
Thu, 23 Apr 2020 02:45:45 +0000 (22:45 -0400)
committerMarge Bot <eric+marge@anholt.net>
Mon, 27 Apr 2020 11:20:16 +0000 (11:20 +0000)
NIR mostly does this already.

Totals:
SGPRS: 2588520 -> 2591784 (0.13 %)
VGPRS: 1666984 -> 1666888 (-0.01 %)
Spilled SGPRs: 4074 -> 4131 (1.40 %)
Spilled VGPRs: 38 -> 38 (0.00 %)
Private memory VGPRs: 2176 -> 2176 (0.00 %)
Scratch size: 2228 -> 2228 (0.00 %) dwords per thread
Code Size: 52726872 -> 52715468 (-0.02 %) bytes
LDS: 92 -> 92 (0.00 %) blocks
Max Waves: 479872 -> 479897 (0.01 %)
Wait states: 0 -> 0 (0.00 %)

Reviewed-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4696>

src/amd/llvm/ac_llvm_helper.cpp
src/amd/llvm/ac_llvm_util.h
src/amd/llvm/ac_nir_to_llvm.c

index 578521a6f2d6405bd2c471c26a2b8a343a2c8287..f5383344dd4f3c37348449d0eb28c61dd6d3699d 100644 (file)
@@ -96,6 +96,11 @@ LLVMBuilderRef ac_create_builder(LLVMContextRef ctx,
                 */
                flags.setAllowReciprocal(); /* arcp */
 
+               /* Allow floating-point contraction (e.g. fusing a multiply
+                * followed by an addition into a fused multiply-and-add).
+                */
+               flags.setAllowContract(); /* contract */
+
                llvm::unwrap(builder)->setFastMathFlags(flags);
                break;
        }
@@ -103,6 +108,32 @@ LLVMBuilderRef ac_create_builder(LLVMContextRef ctx,
        return builder;
 }
 
+/* Return the original state of inexact math. */
+bool ac_disable_inexact_math(LLVMBuilderRef builder)
+{
+       auto *b = llvm::unwrap(builder);
+       llvm::FastMathFlags flags = b->getFastMathFlags();
+
+       if (!flags.allowContract())
+               return false;
+
+       flags.setAllowContract(false);
+       b->setFastMathFlags(flags);
+       return true;
+}
+
+void ac_restore_inexact_math(LLVMBuilderRef builder, bool value)
+{
+       auto *b = llvm::unwrap(builder);
+       llvm::FastMathFlags flags = b->getFastMathFlags();
+
+       if (flags.allowContract() == value)
+               return;
+
+       flags.setAllowContract(value);
+       b->setFastMathFlags(flags);
+}
+
 LLVMTargetLibraryInfoRef
 ac_create_target_library_info(const char *triple)
 {
index 4cfb3b55388831ae06e8655b2c04702883b13572..f9650bdf4f1b024293bd64aa1a3920b1b306f899 100644 (file)
@@ -109,6 +109,8 @@ LLVMModuleRef ac_create_module(LLVMTargetMachineRef tm, LLVMContextRef ctx);
 
 LLVMBuilderRef ac_create_builder(LLVMContextRef ctx,
                                 enum ac_float_mode float_mode);
+bool ac_disable_inexact_math(LLVMBuilderRef builder);
+void ac_restore_inexact_math(LLVMBuilderRef builder, bool value);
 
 void
 ac_llvm_add_target_dep_function_attr(LLVMValueRef F,
index 627f5d2d931212ccff4dfc57db393c0348acc49d..03717191e243fd8b897d42e08f7c7c5639255d0f 100644 (file)
@@ -589,6 +589,10 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        unsigned num_components = instr->dest.dest.ssa.num_components;
        unsigned src_components;
        LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.dest.ssa);
+       bool saved_inexact = false;
+
+       if (instr->exact)
+               saved_inexact = ac_disable_inexact_math(ctx->ac.builder);
 
        assert(nir_op_infos[instr->op].num_inputs <= ARRAY_SIZE(src));
        switch (instr->op) {
@@ -1182,6 +1186,9 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = ac_to_integer_or_pointer(&ctx->ac, result);
                ctx->ssa_defs[instr->dest.dest.ssa.index] = result;
        }
+
+       if (instr->exact)
+               ac_restore_inexact_math(ctx->ac.builder, saved_inexact);
 }
 
 static void visit_load_const(struct ac_nir_context *ctx,