aco: create 16-bit mad/fma
authorRhys Perry <pendingchaos02@gmail.com>
Thu, 14 May 2020 20:09:36 +0000 (21:09 +0100)
committerMarge Bot <eric+marge@anholt.net>
Mon, 15 Jun 2020 18:24:22 +0000 (18:24 +0000)
fossil-db (Navi, fp16 enabled):
Totals from 1 (0.00% of 127638) affected shaders:
CodeSize: 4868 -> 4552 (-6.49%)
Instrs: 956 -> 863 (-9.73%)
Cycles: 3824 -> 3452 (-9.73%)
VMEM: 504 -> 490 (-2.78%)
SMEM: 109 -> 107 (-1.83%)
VClause: 19 -> 20 (+5.26%)
Copies: 54 -> 58 (+7.41%)
PreVGPRs: 43 -> 41 (-4.65%)

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5245>

src/amd/compiler/aco_optimizer.cpp
src/amd/compiler/aco_register_allocation.cpp

index 67d18231319f3ab225db6912d30c0a4cab3306f5..f56ef5f5170c75d662c1ea8abffb21c224e4b36b 100644 (file)
@@ -1103,6 +1103,10 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
       }
       break;
    }
+   case aco_opcode::v_mul_f16: {
+      ctx.info[instr->definitions[0].tempId()].set_mul(instr.get());
+      break;
+   }
    case aco_opcode::v_and_b32: /* abs */
       if (!instr->usesModifiers() && instr->operands[0].constantEquals(0x7FFFFFFF) &&
           instr->operands[1].isTemp() && instr->operands[1].getTemp().type() == RegType::vgpr)
@@ -2415,11 +2419,15 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
    bool mad32 = instr->opcode == aco_opcode::v_add_f32 ||
                 instr->opcode == aco_opcode::v_sub_f32 ||
                 instr->opcode == aco_opcode::v_subrev_f32;
-   if (mad32) {
-      bool need_fma = block.fp_mode.denorm32 != 0;
+   bool mad16 = instr->opcode == aco_opcode::v_add_f16 ||
+                instr->opcode == aco_opcode::v_sub_f16 ||
+                instr->opcode == aco_opcode::v_subrev_f16;
+   if (mad16 || mad32) {
+      bool need_fma = mad32 ? block.fp_mode.denorm32 != 0 :
+                              (block.fp_mode.denorm16_64 != 0 || ctx.program->chip_class >= GFX10);
       if (need_fma && instr->definitions[0].isPrecise())
          return;
-      if (need_fma && !ctx.program->has_fast_fma32)
+      if (need_fma && mad32 && !ctx.program->has_fast_fma32)
          return;
 
       uint32_t uses_src0 = UINT32_MAX;
@@ -2500,12 +2508,15 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
             /* neg of the multiplication result */
             neg[1] = neg[1] ^ vop3->neg[1 - add_op_idx];
          }
-         if (instr->opcode == aco_opcode::v_sub_f32)
+         if (instr->opcode == aco_opcode::v_sub_f32 || instr->opcode == aco_opcode::v_sub_f16)
             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
-         else if (instr->opcode == aco_opcode::v_subrev_f32)
+         else if (instr->opcode == aco_opcode::v_subrev_f32 || instr->opcode == aco_opcode::v_subrev_f16)
             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
 
          aco_opcode mad_op = need_fma ? aco_opcode::v_fma_f32 : aco_opcode::v_mad_f32;
+         if (mad16)
+            mad_op = need_fma ? (ctx.program->chip_class == GFX8 ? aco_opcode::v_fma_legacy_f16 : aco_opcode::v_fma_f16) :
+                                (ctx.program->chip_class == GFX8 ? aco_opcode::v_mad_legacy_f16 : aco_opcode::v_mad_f16);
 
          aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(mad_op, Format::VOP3A, 3, 1)};
          for (unsigned i = 0; i < 3; i++)
@@ -2730,7 +2741,8 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       /* check literals */
       else if (!instr->usesModifiers()) {
          /* FMA can only take literals on GFX10+ */
-         if (instr->opcode == aco_opcode::v_fma_f32 && ctx.program->chip_class < GFX10)
+         if ((instr->opcode == aco_opcode::v_fma_f32 || instr->opcode == aco_opcode::v_fma_f16) &&
+             ctx.program->chip_class < GFX10)
             return;
 
          bool sgpr_used = false;
@@ -2903,6 +2915,10 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          aco_opcode new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32;
          if (instr->opcode == aco_opcode::v_fma_f32)
             new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32;
+         else if (instr->opcode == aco_opcode::v_mad_f16 || instr->opcode == aco_opcode::v_mad_legacy_f16)
+            new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16;
+         else if (instr->opcode == aco_opcode::v_fma_f16)
+            new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16;
 
          new_mad.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 1));
          if (info->literal_idx == 2) { /* add literal -> madak */
index a824e8b546cd63eee300a1ed163f08182eb9f375..505f5cb613df31358bf8abc359576a09e7c34aa3 100644 (file)
@@ -1735,7 +1735,10 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                if (!def.isFixed() && instr->opcode == aco_opcode::p_parallelcopy)
                   op = instr->operands[i];
                else if ((instr->opcode == aco_opcode::v_mad_f32 ||
-                        (instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10)) && !instr->usesModifiers())
+                        (instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10) ||
+                        instr->opcode == aco_opcode::v_mad_f16 ||
+                        instr->opcode == aco_opcode::v_mad_legacy_f16 ||
+                        (instr->opcode == aco_opcode::v_fma_f16 && program->chip_class >= GFX10)) && !instr->usesModifiers())
                   op = instr->operands[2];
 
                if (op.isTemp() && op.isFirstKillBeforeDef() && def.regClass() == op.regClass()) {
@@ -2011,13 +2014,19 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
 
          /* try to optimize v_mad_f32 -> v_mac_f32 */
          if ((instr->opcode == aco_opcode::v_mad_f32 ||
-              (instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10)) &&
+              (instr->opcode == aco_opcode::v_fma_f32 && program->chip_class >= GFX10) ||
+              instr->opcode == aco_opcode::v_mad_f16 ||
+              instr->opcode == aco_opcode::v_mad_legacy_f16 ||
+              (instr->opcode == aco_opcode::v_fma_f16 && program->chip_class >= GFX10)) &&
              instr->operands[2].isTemp() &&
              instr->operands[2].isKillBeforeDef() &&
              instr->operands[2].getTemp().type() == RegType::vgpr &&
              instr->operands[1].isTemp() &&
              instr->operands[1].getTemp().type() == RegType::vgpr &&
-             !instr->usesModifiers()) {
+             !instr->usesModifiers() &&
+             instr->operands[0].physReg().byte() == 0 &&
+             instr->operands[1].physReg().byte() == 0 &&
+             instr->operands[2].physReg().byte() == 0) {
             unsigned def_id = instr->definitions[0].tempId();
             auto it = ctx.affinities.find(def_id);
             if (it == ctx.affinities.end() || !ctx.assignments[it->second].assigned ||
@@ -2031,6 +2040,13 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                case aco_opcode::v_fma_f32:
                   instr->opcode = aco_opcode::v_fmac_f32;
                   break;
+               case aco_opcode::v_mad_f16:
+               case aco_opcode::v_mad_legacy_f16:
+                  instr->opcode = aco_opcode::v_mac_f16;
+                  break;
+               case aco_opcode::v_fma_f16:
+                  instr->opcode = aco_opcode::v_fmac_f16;
+                  break;
                default:
                   break;
                }
@@ -2041,6 +2057,8 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
          if (instr->opcode == aco_opcode::v_interp_p2_f32 ||
              instr->opcode == aco_opcode::v_mac_f32 ||
              instr->opcode == aco_opcode::v_fmac_f32 ||
+             instr->opcode == aco_opcode::v_mac_f16 ||
+             instr->opcode == aco_opcode::v_fmac_f16 ||
              instr->opcode == aco_opcode::v_writelane_b32 ||
              instr->opcode == aco_opcode::v_writelane_b32_e64) {
             instr->definitions[0].setFixed(instr->operands[2].physReg());