aco: improve creation of v_madmk_f32/v_madak_f32
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 22 Nov 2019 15:18:38 +0000 (15:18 +0000)
committerMarge Bot <eric+marge@anholt.net>
Tue, 14 Jan 2020 12:56:28 +0000 (12:56 +0000)
Using needs_vop3 check was flawed because it would only combine the
literal if the first operand is the literal. If the second or third
operand is the literal, then needs_vop3 will be true and the literal will
not be combined.

pipeline-db (Navi):
Totals from affected shaders:
SGPRS: 782051 -> 782051 (0.00 %)
VGPRS: 630048 -> 630048 (0.00 %)
Spilled SGPRs: 195 -> 195 (0.00 %)
Spilled VGPRs: 0 -> 0 (0.00 %)
Code Size: 54743740 -> 54585548 (-0.29 %) bytes
Max Waves: 67340 -> 67340 (0.00 %)
Instructions: 10182030 -> 10182030 (0.00 %)

pipeline-db (Vega):
Totals from affected shaders:
SGPRS: 701990 -> 699590 (-0.34 %)
VGPRS: 566632 -> 566784 (0.03 %)
Spilled SGPRs: 218 -> 218 (0.00 %)
Spilled VGPRs: 0 -> 0 (0.00 %)
Code Size: 49173564 -> 49007856 (-0.34 %) bytes
Max Waves: 59650 -> 59612 (-0.06 %)
Instructions: 9315135 -> 9293330 (-0.23 %)

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/merge_requests/2883>

src/amd/compiler/aco_optimizer.cpp

index c6c8931a4262f16b50c5ebcb2e875a2a1037c515..7a16fc176c9dd977117a6a3cac2cc1713c00ce3f 100644 (file)
@@ -53,11 +53,10 @@ struct mad_info {
    aco_ptr<Instruction> add_instr;
    uint32_t mul_temp_id;
    uint32_t literal_idx;
-   bool needs_vop3;
    bool check_literal;
 
-   mad_info(aco_ptr<Instruction> instr, uint32_t id, bool vop3)
-   : add_instr(std::move(instr)), mul_temp_id(id), needs_vop3(vop3), check_literal(false) {}
+   mad_info(aco_ptr<Instruction> instr, uint32_t id)
+   : add_instr(std::move(instr)), mul_temp_id(id), check_literal(false) {}
 };
 
 enum Label {
@@ -2194,7 +2193,6 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          bool abs[3] = {false, false, false};
          unsigned omod = 0;
          bool clamp = false;
-         bool need_vop3 = false;
          op[0] = mul_instr->operands[0];
          op[1] = mul_instr->operands[1];
          op[2] = instr->operands[add_op_idx];
@@ -2202,18 +2200,12 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          if (!check_vop3_operands(ctx, 3, op))
             return;
 
-         for (unsigned i = 0; i < 3; i++) {
-            if (!(i == 0 || (op[i].isTemp() && op[i].getTemp().type() == RegType::vgpr)))
-               need_vop3 = true;
-         }
-
          if (mul_instr->isVOP3()) {
             VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);
             neg[0] = vop3->neg[0];
             neg[1] = vop3->neg[1];
             abs[0] = vop3->abs[0];
             abs[1] = vop3->abs[1];
-            need_vop3 = true;
             /* we cannot use these modifiers between mul and add */
             if (vop3->clamp || vop3->omod)
                return;
@@ -2243,15 +2235,11 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
             }
             /* neg of the multiplication result */
             neg[1] = neg[1] ^ vop3->neg[1 - add_op_idx];
-            need_vop3 = true;
          }
-         if (instr->opcode == aco_opcode::v_sub_f32) {
+         if (instr->opcode == aco_opcode::v_sub_f32)
             neg[1 + add_op_idx] = neg[1 + add_op_idx] ^ true;
-            need_vop3 = true;
-         } else if (instr->opcode == aco_opcode::v_subrev_f32) {
+         else if (instr->opcode == aco_opcode::v_subrev_f32)
             neg[2 - add_op_idx] = neg[2 - add_op_idx] ^ true;
-            need_vop3 = true;
-         }
 
          aco_ptr<VOP3A_instruction> mad{create_instruction<VOP3A_instruction>(aco_opcode::v_mad_f32, Format::VOP3A, 3, 1)};
          for (unsigned i = 0; i < 3; i++)
@@ -2265,7 +2253,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          mad->definitions[0] = instr->definitions[0];
 
          /* mark this ssa_def to be re-checked for profitability and literals */
-         ctx.mad_infos.emplace_back(std::move(instr), mul_instr->definitions[0].tempId(), need_vop3);
+         ctx.mad_infos.emplace_back(std::move(instr), mul_instr->definitions[0].tempId());
          ctx.info[mad->definitions[0].tempId()].set_mad(mad.get(), ctx.mad_infos.size() - 1);
          instr.reset(mad.release());
          return;
@@ -2353,48 +2341,55 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       }
    }
 
-   /* re-check mad instructions */
+   mad_info* mad_info = NULL;
    if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
-      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
-      /* first, check profitability */
-      if (ctx.uses[info->mul_temp_id]) {
-         ctx.uses[info->mul_temp_id]++;
+      mad_info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
+      /* re-check mad instructions */
+      if (ctx.uses[mad_info->mul_temp_id]) {
+         ctx.uses[mad_info->mul_temp_id]++;
          if (instr->operands[0].isTemp())
             ctx.uses[instr->operands[0].tempId()]--;
          if (instr->operands[1].isTemp())
             ctx.uses[instr->operands[1].tempId()]--;
-         instr.swap(info->add_instr);
-
-      /* second, check possible literals */
-      } else if (!info->needs_vop3) {
+         instr.swap(mad_info->add_instr);
+         mad_info = NULL;
+      }
+      /* check literals */
+      else if (!instr->usesModifiers()) {
+         bool sgpr_used = false;
          uint32_t literal_idx = 0;
          uint32_t literal_uses = UINT32_MAX;
          for (unsigned i = 0; i < instr->operands.size(); i++)
          {
+            if (instr->operands[i].isConstant() && i > 0) {
+               literal_uses = UINT32_MAX;
+               break;
+            }
             if (!instr->operands[i].isTemp())
                continue;
-            /* if one of the operands is sgpr, we cannot add a literal somewhere else */
-            if (instr->operands[i].getTemp().type() == RegType::sgpr) {
+            /* if one of the operands is sgpr, we cannot add a literal somewhere else on pre-GFX10 or operands other than the 1st */
+            if (instr->operands[i].getTemp().type() == RegType::sgpr && (i > 0 || ctx.program->chip_class < GFX10)) {
                if (ctx.info[instr->operands[i].tempId()].is_literal()) {
                   literal_uses = ctx.uses[instr->operands[i].tempId()];
                   literal_idx = i;
                } else {
                   literal_uses = UINT32_MAX;
                }
-               break;
-            }
-            else if (ctx.info[instr->operands[i].tempId()].is_literal() &&
-                ctx.uses[instr->operands[i].tempId()] < literal_uses) {
+               sgpr_used = true;
+               /* don't break because we still need to check constants */
+            } else if (!sgpr_used &&
+                       ctx.info[instr->operands[i].tempId()].is_literal() &&
+                       ctx.uses[instr->operands[i].tempId()] < literal_uses) {
                literal_uses = ctx.uses[instr->operands[i].tempId()];
                literal_idx = i;
             }
          }
          if (literal_uses < threshold) {
             ctx.uses[instr->operands[literal_idx].tempId()]--;
-            info->check_literal = true;
-            info->literal_idx = literal_idx;
+            mad_info->check_literal = true;
+            mad_info->literal_idx = literal_idx;
+            return;
          }
-         return;
       }
    }
 
@@ -2480,31 +2475,28 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       return;
 
    /* apply literals on MAD */
-   bool literals_applied = false;
    if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
-      if (!info->needs_vop3) {
+      if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
          aco_ptr<Instruction> new_mad;
-         if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
-            if (info->literal_idx == 2) { /* add literal -> madak */
-               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
-               new_mad->operands[0] = instr->operands[0];
-               new_mad->operands[1] = instr->operands[1];
-            } else { /* mul literal -> madmk */
-               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
-               new_mad->operands[0] = instr->operands[1 - info->literal_idx];
-               new_mad->operands[1] = instr->operands[2];
-            }
-            new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
-            new_mad->definitions[0] = instr->definitions[0];
-            instr.swap(new_mad);
+         if (info->literal_idx == 2) { /* add literal -> madak */
+            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
+            new_mad->operands[0] = instr->operands[0];
+            new_mad->operands[1] = instr->operands[1];
+         } else { /* mul literal -> madmk */
+            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
+            new_mad->operands[0] = instr->operands[1 - info->literal_idx];
+            new_mad->operands[1] = instr->operands[2];
          }
-         literals_applied = true;
+         new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
+         new_mad->definitions[0] = instr->definitions[0];
+         ctx.instructions.emplace_back(std::move(new_mad));
+         return;
       }
    }
 
-   /* apply literals on SALU/VALU */
-   if (!literals_applied && (instr->isSALU() || instr->isVALU())) {
+   /* apply literals on other SALU/VALU */
+   if (instr->isSALU() || instr->isVALU()) {
       for (unsigned i = 0; i < instr->operands.size(); i++) {
          Operand op = instr->operands[i];
          if (op.isTemp() && ctx.info[op.tempId()].is_literal() && ctx.uses[op.tempId()] == 0) {