aco: try to use fma instead of mad when denormals are enabled
author Rhys Perry <pendingchaos02@gmail.com>
Fri, 15 May 2020 13:03:15 +0000 (14:03 +0100)
committer Marge Bot <eric+marge@anholt.net>
Mon, 15 Jun 2020 18:24:22 +0000 (18:24 +0000)
v_mad_f32 doesn't support denormals but v_fma_f32 does.

No fossil-db changes.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5245>

src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_optimizer.cpp
src/amd/compiler/aco_register_allocation.cpp

index eb07e7b6a830dd2913be0ccd99df350cbb8b0f55..6bd36835ce2ca286881a56edcab0ffd2849b03c8 100644 (file)
@@ -1256,6 +1256,10 @@ setup_isel_context(Program* program,
 
    setup_xnack(program);
    program->sram_ecc_enabled = args->options->family == CHIP_ARCTURUS;
+   /* apparently gfx702 also has fast v_fma_f32 but I can't find a family for that */
+   program->has_fast_fma32 = program->chip_class >= GFX9;
+   if (args->options->family == CHIP_TAHITI || args->options->family == CHIP_CARRIZO || args->options->family == CHIP_HAWAII)
+      program->has_fast_fma32 = true;
 
    return ctx;
 }
index bd221ad6b617fece3adb16c3c9b57267c3740ad2..68d0b9bf4ceed9fce26c16e512fb6edff352818a 100644 (file)
@@ -1451,6 +1451,7 @@ public:
 
    bool xnack_enabled = false;
    bool sram_ecc_enabled = false;
+   bool has_fast_fma32 = false;
 
    bool needs_vcc = false;
    bool needs_flat_scr = false;
index 37dcb89b182fcd772316594bac42f53721f2a8cf..67d18231319f3ab225db6912d30c0a4cab3306f5 100644 (file)
@@ -2410,37 +2410,44 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
       ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
       return;
    }
+
    /* combine mul+add -> mad */
-   else if ((instr->opcode == aco_opcode::v_add_f32 ||
-             instr->opcode == aco_opcode::v_sub_f32 ||
-             instr->opcode == aco_opcode::v_subrev_f32) &&
-            block.fp_mode.denorm32 == 0) {
-      //TODO: we could use fma instead when denormals are enabled if the NIR isn't marked as precise
+   bool mad32 = instr->opcode == aco_opcode::v_add_f32 ||
+                instr->opcode == aco_opcode::v_sub_f32 ||
+                instr->opcode == aco_opcode::v_subrev_f32;
+   if (mad32) {
+      bool need_fma = block.fp_mode.denorm32 != 0;
+      if (need_fma && instr->definitions[0].isPrecise())
+         return;
+      if (need_fma && !ctx.program->has_fast_fma32)
+         return;
 
       uint32_t uses_src0 = UINT32_MAX;
       uint32_t uses_src1 = UINT32_MAX;
       Instruction* mul_instr = nullptr;
       unsigned add_op_idx;
       /* check if any of the operands is a multiplication */
-      if (instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_mul())
+      ssa_info *op0_info = instr->operands[0].isTemp() ? &ctx.info[instr->operands[0].tempId()] : NULL;
+      ssa_info *op1_info = instr->operands[1].isTemp() ? &ctx.info[instr->operands[1].tempId()] : NULL;
+      if (op0_info && op0_info->is_mul() && (!need_fma || !op0_info->instr->definitions[0].isPrecise()))
          uses_src0 = ctx.uses[instr->operands[0].tempId()];
-      if (instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_mul())
+      if (op1_info && op1_info->is_mul() && (!need_fma || !op1_info->instr->definitions[0].isPrecise()))
          uses_src1 = ctx.uses[instr->operands[1].tempId()];
 
       /* find the 'best' mul instruction to combine with the add */
       if (uses_src0 < uses_src1) {
-         mul_instr = ctx.info[instr->operands[0].tempId()].instr;
+         mul_instr = op0_info->instr;
          add_op_idx = 1;
       } else if (uses_src1 < uses_src0) {
-         mul_instr = ctx.info[instr->operands[1].tempId()].instr;
+         mul_instr = op1_info->instr;
          add_op_idx = 0;
       } else if (uses_src0 != UINT32_MAX) {
          /* tiebreaker: quite random what to pick */
-         if (ctx.info[instr->operands[0].tempId()].instr->operands[0].isLiteral()) {
-            mul_instr = ctx.info[instr->operands[1].tempId()].instr;
+         if (op0_info->instr->operands[0].isLiteral()) {
+            mul_instr = op1_info->instr;
             add_op_idx = 0;
          } else {
-            mul_instr = ctx.info[instr->operands[0].tempId()].instr;
+            mul_instr = op0_info->instr;
             add_op_idx = 1;
          }
       }
@@ -2498,7 +2505,9 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          else if (instr->opcode == aco_opcode::v_subrev_f32)
             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
 
-         aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(aco_opcode::v_mad_f32, Format::VOP3A, 3, 1)};
+         aco_opcode mad_op = need_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
+
+         aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(mad_op, Format::VOP3A, 3, 1)};
          for (unsigned i = 0; i < 3; i++)
          {
             mad->operands[i] = op[i];
@@ -2706,7 +2715,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    }
 
    mad_info* mad_info = NULL;
-   if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
+   if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
       mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
       /* re-check mad instructions */
       if (ctx.uses[mad_info->mul_temp_id]) {
@@ -2720,6 +2729,10 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       }
       /* check literals */
       else if (!instr->usesModifiers()) {
+         /* FMA can only take literals on GFX10+ */
+         if (instr->opcode == aco_opcode::v_fma_f32 && ctx.program->chip_class < GFX10)
+            return;
+
          bool sgpr_used = false;
          uint32_t literal_idx = 0;
          uint32_t literal_uses = UINT32_MAX;
@@ -2881,17 +2894,21 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       return;
 
    /* apply literals on MAD */
-   if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
+   if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) {
       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
       if (info->check_literal &&
           (ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) {
          aco_ptr<Instruction> new_mad;
+
+         aco_opcode new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
+         if (instr->opcode == aco_opcode::v_fma_f32)
+            new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
+
+         new_mad.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 1));
          if (info->literal_idx == 2) { /* add literal -> madak */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
             new_mad->operands[0] = instr->operands[0];
             new_mad->operands[1] = instr->operands[1];
          } else { /* mul literal -> madmk */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
             new_mad->operands[0] = instr->operands[1 - info->literal_idx];
             new_mad->operands[1] = instr->operands[2];
          }
index 8e662a282c8e3e0222a4752a0ee1f3a20357de4d..a824e8b546cd63eee300a1ed163f08182eb9f375 100644 (file)
@@ -1734,7 +1734,8 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                Operand op = Operand();
                if (!def.isFixed() && instr->opcode == aco_opcode::p_parallelcopy)
                   op = instr->operands[i];
-               else if (instr->opcode == aco_opcode::v_mad_f32 && !instr->usesModifiers())
+               else if ((instr->opcode == aco_opcode::v_mad_f32 ||
+                        (instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10)) && !instr->usesModifiers())
                   op = instr->operands[2];
 
                if (op.isTemp() && op.isFirstKillBeforeDef() && def.regClass() == op.regClass()) {
@@ -2009,7 +2010,8 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
          }
 
          /* try to optimize v_mad_f32 -> v_mac_f32 */
-         if (instr->opcode == aco_opcode::v_mad_f32 &&
+         if ((instr->opcode == aco_opcode::v_mad_f32 ||
+              (instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10)) &&
              instr->operands[2].isTemp() &&
              instr->operands[2].isKillBeforeDef() &&
              instr->operands[2].getTemp().type() == RegType::vgpr &&
@@ -2022,13 +2024,23 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                 instr->operands[2].physReg() == ctx.assignments[it->second].reg ||
                 register_file.test(ctx.assignments[it->second].reg, instr->operands[2].bytes())) {
                instr->format = Format::VOP2;
-               instr->opcode = aco_opcode::v_mac_f32;
+               switch (instr->opcode) {
+               case aco_opcode::v_mad_f32:
+                  instr->opcode = aco_opcode::v_mac_f32;
+                  break;
+               case aco_opcode::v_fma_f32:
+                  instr->opcode = aco_opcode::v_fmac_f32;
+                  break;
+               default:
+                  break;
+               }
             }
          }
 
          /* handle definitions which must have the same register as an operand */
          if (instr->opcode == aco_opcode::v_interp_p2_f32 ||
              instr->opcode == aco_opcode::v_mac_f32 ||
+             instr->opcode == aco_opcode::v_fmac_f32 ||
              instr->opcode == aco_opcode::v_writelane_b32 ||
              instr->opcode == aco_opcode::v_writelane_b32_e64) {
             instr->definitions[0].setFixed(instr->operands[2].physReg());