aco: rewrite literal combining
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 22 Nov 2019 13:43:39 +0000 (13:43 +0000)
committerMarge Bot <eric+marge@anholt.net>
Tue, 14 Jan 2020 12:56:28 +0000 (12:56 +0000)
Should make taking advantage of GFX10's increased constant bus limit and
VOP3 literals easier.

No pipeline-db changes

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/merge_requests/2883>

src/amd/compiler/aco_optimizer.cpp

index da169694ffbe18372d58bf1f3a6cb56af76e2feb..de3fdad42e7fa6e26479e9485297bae311935ba4 100644 (file)
@@ -497,19 +497,6 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
    }
 }
 
-bool valu_can_accept_literal(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned operand)
-{
-   /* instructions like v_cndmask_b32 can't take a literal because they always
-    * read SGPRs */
-   if (instr->operands.size() >= 3 &&
-       instr->operands[2].isTemp() && instr->operands[2].regClass().type() == RegType::sgpr)
-      return false;
-
-   // TODO: VOP3 can take a literal on GFX10
-   return !instr->isSDWA() && !instr->isDPP() && !instr->isVOP3() &&
-          operand == 0 && can_accept_constant(instr, operand);
-}
-
 bool valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
 {
    if (instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_readlane_b32_e64 ||
@@ -2349,43 +2336,74 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
             info->check_literal = true;
             info->literal_idx = literal_idx;
          }
+         return;
       }
    }
 
    /* check for literals */
+   if (!instr->isSALU() && !instr->isVALU())
+      return;
+
+   if (instr->isSDWA() || instr->isDPP() || instr->isVOP3())
+      return; /* some encodings can't ever take literals */
+
    /* we do not apply the literals yet as we don't know if it is profitable */
-   if (instr->isSALU()) {
-      uint32_t literal_idx = 0;
-      uint32_t literal_uses = UINT32_MAX;
-      bool has_literal = false;
-      for (unsigned i = 0; i < instr->operands.size(); i++)
-      {
-         if (instr->operands[i].isLiteral()) {
-            has_literal = true;
-            break;
-         }
-         if (!instr->operands[i].isTemp())
-            continue;
-         if (ctx.info[instr->operands[i].tempId()].is_literal() &&
-             ctx.uses[instr->operands[i].tempId()] < literal_uses) {
-            literal_uses = ctx.uses[instr->operands[i].tempId()];
-            literal_idx = i;
-         }
+   Operand current_literal(s1);
+
+   unsigned literal_id = 0;
+   unsigned literal_uses = UINT32_MAX;
+   Operand literal(s1);
+   unsigned num_operands = instr->isSALU() ? instr->operands.size() : 1;
+
+   unsigned sgpr_ids[2] = {0, 0};
+   bool is_literal_sgpr = false;
+   uint32_t mask = 0;
+
+   /* choose a literal to apply */
+   for (unsigned i = 0; i < num_operands; i++) {
+      Operand op = instr->operands[i];
+      if (op.isLiteral()) {
+         current_literal = op;
+         continue;
+      } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal()) {
+         if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
+             op.tempId() != sgpr_ids[0])
+            sgpr_ids[!!sgpr_ids[0]] = op.tempId();
+         continue;
       }
-      if (!has_literal && literal_uses < threshold) {
-         ctx.uses[instr->operands[literal_idx].tempId()]--;
-         if (ctx.uses[instr->operands[literal_idx].tempId()] == 0)
-            instr->operands[literal_idx] = Operand(ctx.info[instr->operands[literal_idx].tempId()].val);
+
+      if (!can_accept_constant(instr, i))
+         continue;
+
+      if (ctx.uses[op.tempId()] < literal_uses) {
+         is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
+         mask = 0;
+         literal = Operand(ctx.info[op.tempId()].val);
+         literal_uses = ctx.uses[op.tempId()];
+         literal_id = op.tempId();
       }
-   } else if (instr->isVALU() && valu_can_accept_literal(ctx, instr, 0) &&
-       instr->operands[0].isTemp() &&
-       ctx.info[instr->operands[0].tempId()].is_literal() &&
-       ctx.uses[instr->operands[0].tempId()] < threshold) {
-      ctx.uses[instr->operands[0].tempId()]--;
-      if (ctx.uses[instr->operands[0].tempId()] == 0)
-         instr->operands[0] = Operand(ctx.info[instr->operands[0].tempId()].val);
+
+      mask |= (op.tempId() == literal_id) << i;
    }
 
+
+   /* don't go over the constant bus limit */
+   unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
+   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
+   if (num_sgprs == const_bus_limit && !is_literal_sgpr)
+      return;
+
+   if (literal_id && literal_uses < threshold &&
+       (current_literal.isUndefined() ||
+        (current_literal.size() == literal.size() &&
+         current_literal.constantValue() == literal.constantValue()))) {
+      /* mark the literal to be applied */
+      while (mask) {
+         unsigned i = u_bit_scan(&mask);
+         if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
+            ctx.uses[instr->operands[i].tempId()]--;
+      }
+   }
 }
 
 
@@ -2395,44 +2413,40 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (!instr)
       return;
 
-   /* apply literals on SALU */
-   if (instr->isSALU()) {
-      for (Operand& op : instr->operands) {
-         if (!op.isTemp())
-            continue;
-         if (op.isLiteral())
-            break;
-         if (ctx.info[op.tempId()].is_literal() &&
-             ctx.uses[op.tempId()] == 0)
-            op = Operand(ctx.info[op.tempId()].val);
+   /* apply literals on MAD */
+   bool literals_applied = false;
+   if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
+      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
+      if (!info->needs_vop3) {
+         aco_ptr<Instruction> new_mad;
+         if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
+            if (info->literal_idx == 2) { /* add literal -> madak */
+               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
+               new_mad->operands[0] = instr->operands[0];
+               new_mad->operands[1] = instr->operands[1];
+            } else { /* mul literal -> madmk */
+               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
+               new_mad->operands[0] = instr->operands[1 - info->literal_idx];
+               new_mad->operands[1] = instr->operands[2];
+            }
+            new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
+            new_mad->definitions[0] = instr->definitions[0];
+            instr.swap(new_mad);
+         }
+         literals_applied = true;
       }
    }
 
-   /* apply literals on VALU */
-   else if (instr->isVALU() && !instr->isVOP3() &&
-       instr->operands[0].isTemp() &&
-       ctx.info[instr->operands[0].tempId()].is_literal() &&
-       ctx.uses[instr->operands[0].tempId()] == 0) {
-      instr->operands[0] = Operand(ctx.info[instr->operands[0].tempId()].val);
-   }
-
-   /* apply literals on MAD */
-   else if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
-      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
-      aco_ptr<Instruction> new_mad;
-      if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
-         if (info->literal_idx == 2) { /* add literal -> madak */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
-            new_mad->operands[0] = instr->operands[0];
-            new_mad->operands[1] = instr->operands[1];
-         } else { /* mul literal -> madmk */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
-            new_mad->operands[0] = instr->operands[1 - info->literal_idx];
-            new_mad->operands[1] = instr->operands[2];
+   /* apply literals on SALU/VALU */
+   if (!literals_applied && (instr->isSALU() || instr->isVALU())) {
+      for (unsigned i = 0; i < instr->operands.size(); i++) {
+         Operand op = instr->operands[i];
+         if (op.isTemp() && ctx.info[op.tempId()].is_literal() && ctx.uses[op.tempId()] == 0) {
+            Operand literal(ctx.info[op.tempId()].val);
+            if (instr->isVALU() && i > 0)
+               to_VOP3(ctx, instr);
+            instr->operands[i] = literal;
          }
-         new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
-         new_mad->definitions[0] = instr->definitions[0];
-         instr.swap(new_mad);
       }
    }