aco: allow an extra SGPR with multiple uses to be applied to VOP3
[mesa.git] / src / amd / compiler / aco_optimizer.cpp
index a589fbbe73e4ef24d505ae60ff64fbf47b81e061..88075fabfb2df3bb60f6c400609b4a87535bbc2a 100644 (file)
@@ -416,6 +416,9 @@ bool can_swap_operands(aco_ptr<Instruction>& instr)
 
 bool can_use_VOP3(aco_ptr<Instruction>& instr)
 {
+   if (instr->isVOP3())
+      return true;
+
    if (instr->operands.size() && instr->operands[0].isLiteral())
       return false;
 
@@ -426,6 +429,10 @@ bool can_use_VOP3(aco_ptr<Instruction>& instr)
           instr->opcode != aco_opcode::v_madak_f32 &&
           instr->opcode != aco_opcode::v_madmk_f16 &&
           instr->opcode != aco_opcode::v_madak_f16 &&
+          instr->opcode != aco_opcode::v_fmamk_f32 &&
+          instr->opcode != aco_opcode::v_fmaak_f32 &&
+          instr->opcode != aco_opcode::v_fmamk_f16 &&
+          instr->opcode != aco_opcode::v_fmaak_f16 &&
           instr->opcode != aco_opcode::v_readlane_b32 &&
           instr->opcode != aco_opcode::v_writelane_b32 &&
           instr->opcode != aco_opcode::v_readfirstlane_b32;
@@ -490,19 +497,6 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
    }
 }
 
-bool valu_can_accept_literal(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned operand)
-{
-   /* instructions like v_cndmask_b32 can't take a literal because they always
-    * read SGPRs */
-   if (instr->operands.size() >= 3 &&
-       instr->operands[2].isTemp() && instr->operands[2].regClass().type() == RegType::sgpr)
-      return false;
-
-   // TODO: VOP3 can take a literal on GFX10
-   return !instr->isSDWA() && !instr->isDPP() && !instr->isVOP3() &&
-          operand == 0 && can_accept_constant(instr, operand);
-}
-
 bool valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
 {
    if (instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_readlane_b32_e64 ||
@@ -511,6 +505,33 @@ bool valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
    return true;
 }
 
+/* check constant bus and literal limitations */
+bool check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand *operands)
+{
+   int limit = 1;
+   unsigned num_sgprs = 0;
+   unsigned sgpr[] = {0, 0};
+
+   for (unsigned i = 0; i < num_operands; i++) {
+      Operand op = operands[i];
+
+      if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
+         /* two reads of the same SGPR count as 1 to the limit */
+         if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
+            if (num_sgprs < 2)
+               sgpr[num_sgprs++] = op.tempId();
+            limit--;
+            if (limit < 0)
+               return false;
+         }
+      } else if (op.isLiteral()) {
+         return false;
+      }
+   }
+
+   return true;
+}
+
 bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp *base, uint32_t *offset)
 {
    Operand op = instr->operands[op_index];
@@ -1373,7 +1394,8 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& in
    if (cmp->operands[constant_operand].isConstant()) {
       constant = cmp->operands[constant_operand].constantValue();
    } else if (cmp->operands[constant_operand].isTemp()) {
-      unsigned id = cmp->operands[constant_operand].tempId();
+      Temp tmp = cmp->operands[constant_operand].getTemp();
+      unsigned id = original_temp_id(ctx, tmp);
       if (!ctx.info[id].is_constant() && !ctx.info[id].is_literal())
          return false;
       constant = ctx.info[id].val;
@@ -1530,17 +1552,8 @@ bool match_op3_for_vop3(opt_ctx &ctx, aco_opcode op1, aco_opcode op2,
    }
 
    /* check operands */
-   unsigned sgpr_id = 0;
-   for (unsigned i = 0; i < 3; i++) {
-      Operand op = operands[i];
-      if (op.isLiteral()) {
-         return false;
-      } else if (op.isTemp() && op.getTemp().type() == RegType::sgpr) {
-         if (sgpr_id && sgpr_id != op.tempId())
-            return false;
-         sgpr_id = op.tempId();
-      }
-   }
+   if (!check_vop3_operands(ctx, 3, operands))
+      return false;
 
    return true;
 }
@@ -1666,6 +1679,10 @@ bool combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 && op2_instr->opcode != aco_opcode::s_not_b64))
          continue;
 
+      if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
+          instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
+         continue;
+
       ctx.uses[instr->operands[i].tempId()]--;
       instr->operands[0] = instr->operands[!i];
       instr->operands[1] = op2_instr->operands[0];
@@ -1708,6 +1725,10 @@ bool combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       if (shift < 1 || shift > 4)
          continue;
 
+      if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
+          instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
+         continue;
+
       ctx.uses[instr->operands[i].tempId()]--;
       instr->operands[1] = instr->operands[!i];
       instr->operands[0] = op2_instr->operands[0];
@@ -1874,61 +1895,74 @@ bool combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr,
 
 void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 {
-   /* apply sgprs */
-   uint32_t sgpr_idx = 0;
-   uint32_t sgpr_info_id = 0;
-   bool has_sgpr = false;
-   uint32_t sgpr_ssa_id = 0;
-   /* find 'best' possible sgpr */
-   for (unsigned i = 0; i < instr->operands.size(); i++)
-   {
-      if (instr->operands[i].isLiteral()) {
-         has_sgpr = true;
-         break;
-      }
+   /* find candidates and create the set of sgprs already read */
+   unsigned sgpr_ids[2] = {0, 0};
+   uint32_t operand_mask = 0;
+   bool has_literal = false;
+   for (unsigned i = 0; i < instr->operands.size(); i++) {
+      if (instr->operands[i].isLiteral())
+         has_literal = true;
       if (!instr->operands[i].isTemp())
          continue;
       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
-         has_sgpr = true;
-         sgpr_ssa_id = instr->operands[i].tempId();
-         continue;
+         if (instr->operands[i].tempId() != sgpr_ids[0])
+            sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
       }
       ssa_info& info = ctx.info[instr->operands[i].tempId()];
-      if (info.is_temp() && info.temp.type() == RegType::sgpr) {
+      if (info.is_temp() && info.temp.type() == RegType::sgpr)
+         operand_mask |= 1u << i;
+   }
+   unsigned max_sgprs = 1;
+   if (has_literal)
+      max_sgprs--;
+
+   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
+
+   /* keep on applying sgprs until there is nothing left to be done */
+   while (operand_mask) {
+      uint32_t sgpr_idx = 0;
+      uint32_t sgpr_info_id = 0;
+      uint32_t mask = operand_mask;
+      /* choose a sgpr */
+      while (mask) {
+         unsigned i = u_bit_scan(&mask);
          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
             sgpr_idx = i;
             sgpr_info_id = instr->operands[i].tempId();
          }
       }
-   }
-   if (!has_sgpr && sgpr_info_id != 0) {
-      ssa_info& info = ctx.info[sgpr_info_id];
+      operand_mask &= ~(1u << sgpr_idx);
+
+      /* Applying two sgprs require making it VOP3, so don't do it unless it's
+       * definitively beneficial.
+       * TODO: this is too conservative because later the use count could be reduced to 1 */
+      if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3())
+         break;
+
+      Temp sgpr = ctx.info[sgpr_info_id].temp;
+      bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
+      if (new_sgpr && num_sgprs >= max_sgprs)
+         continue;
+
       if (sgpr_idx == 0 || instr->isVOP3()) {
-         instr->operands[sgpr_idx] = Operand(info.temp);
-         ctx.uses[sgpr_info_id]--;
-         ctx.uses[info.temp.id()]++;
+         instr->operands[sgpr_idx] = Operand(sgpr);
       } else if (can_swap_operands(instr)) {
          instr->operands[sgpr_idx] = instr->operands[0];
-         instr->operands[0] = Operand(info.temp);
-         ctx.uses[sgpr_info_id]--;
-         ctx.uses[info.temp.id()]++;
+         instr->operands[0] = Operand(sgpr);
+         /* swap bits using a 4-entry LUT */
+         uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
+         operand_mask = (operand_mask & ~0x3) | swapped;
       } else if (can_use_VOP3(instr)) {
          to_VOP3(ctx, instr);
-         instr->operands[sgpr_idx] = Operand(info.temp);
-         ctx.uses[sgpr_info_id]--;
-         ctx.uses[info.temp.id()]++;
+         instr->operands[sgpr_idx] = Operand(sgpr);
+      } else {
+         continue;
       }
 
-   /* we can have two sgprs on one instruction if it is the same sgpr! */
-   } else if (sgpr_info_id != 0 &&
-              sgpr_ssa_id == ctx.info[sgpr_info_id].temp.id() &&
-              ctx.uses[sgpr_info_id] == 1 &&
-              can_use_VOP3(instr)) {
-      to_VOP3(ctx, instr);
-      instr->operands[sgpr_idx] = Operand(ctx.info[sgpr_info_id].temp);
+      sgpr_ids[num_sgprs++] = sgpr.id();
       ctx.uses[sgpr_info_id]--;
-      ctx.uses[ctx.info[sgpr_info_id].temp.id()]++;
+      ctx.uses[sgpr.id()]++;
    }
 }
 
@@ -2135,25 +2169,17 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          unsigned omod = 0;
          bool clamp = false;
          bool need_vop3 = false;
-         int num_sgpr = 0;
-         unsigned cur_sgpr = 0;
          op[0] = mul_instr->operands[0];
          op[1] = mul_instr->operands[1];
          op[2] = instr->operands[add_op_idx];
-         for (unsigned i = 0; i < 3; i++)
-         {
-            if (op[i].isLiteral())
-               return;
-            if (op[i].isTemp() && op[i].getTemp().type() == RegType::sgpr && op[i].tempId() != cur_sgpr) {
-               num_sgpr++;
-               cur_sgpr = op[i].tempId();
-            }
+         // TODO: would be better to check this before selecting a mul instr?
+         if (!check_vop3_operands(ctx, 3, op))
+            return;
+
+         for (unsigned i = 0; i < 3; i++) {
             if (!(i == 0 || (op[i].isTemp() && op[i].getTemp().type() == RegType::vgpr)))
                need_vop3 = true;
          }
-         // TODO: would be better to check this before selecting a mul instr?
-         if (num_sgpr > 1)
-            return;
 
          if (mul_instr->isVOP3()) {
             VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);
@@ -2342,43 +2368,74 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
             info->check_literal = true;
             info->literal_idx = literal_idx;
          }
+         return;
       }
    }
 
    /* check for literals */
+   if (!instr->isSALU() && !instr->isVALU())
+      return;
+
+   if (instr->isSDWA() || instr->isDPP() || instr->isVOP3())
+      return; /* some encodings can't ever take literals */
+
    /* we do not apply the literals yet as we don't know if it is profitable */
-   if (instr->isSALU()) {
-      uint32_t literal_idx = 0;
-      uint32_t literal_uses = UINT32_MAX;
-      bool has_literal = false;
-      for (unsigned i = 0; i < instr->operands.size(); i++)
-      {
-         if (instr->operands[i].isLiteral()) {
-            has_literal = true;
-            break;
-         }
-         if (!instr->operands[i].isTemp())
-            continue;
-         if (ctx.info[instr->operands[i].tempId()].is_literal() &&
-             ctx.uses[instr->operands[i].tempId()] < literal_uses) {
-            literal_uses = ctx.uses[instr->operands[i].tempId()];
-            literal_idx = i;
-         }
+   Operand current_literal(s1);
+
+   unsigned literal_id = 0;
+   unsigned literal_uses = UINT32_MAX;
+   Operand literal(s1);
+   unsigned num_operands = instr->isSALU() ? instr->operands.size() : 1;
+
+   unsigned sgpr_ids[2] = {0, 0};
+   bool is_literal_sgpr = false;
+   uint32_t mask = 0;
+
+   /* choose a literal to apply */
+   for (unsigned i = 0; i < num_operands; i++) {
+      Operand op = instr->operands[i];
+      if (op.isLiteral()) {
+         current_literal = op;
+         continue;
+      } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal()) {
+         if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
+             op.tempId() != sgpr_ids[0])
+            sgpr_ids[!!sgpr_ids[0]] = op.tempId();
+         continue;
       }
-      if (!has_literal && literal_uses < threshold) {
-         ctx.uses[instr->operands[literal_idx].tempId()]--;
-         if (ctx.uses[instr->operands[literal_idx].tempId()] == 0)
-            instr->operands[literal_idx] = Operand(ctx.info[instr->operands[literal_idx].tempId()].val);
+
+      if (!can_accept_constant(instr, i))
+         continue;
+
+      if (ctx.uses[op.tempId()] < literal_uses) {
+         is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
+         mask = 0;
+         literal = Operand(ctx.info[op.tempId()].val);
+         literal_uses = ctx.uses[op.tempId()];
+         literal_id = op.tempId();
       }
-   } else if (instr->isVALU() && valu_can_accept_literal(ctx, instr, 0) &&
-       instr->operands[0].isTemp() &&
-       ctx.info[instr->operands[0].tempId()].is_literal() &&
-       ctx.uses[instr->operands[0].tempId()] < threshold) {
-      ctx.uses[instr->operands[0].tempId()]--;
-      if (ctx.uses[instr->operands[0].tempId()] == 0)
-         instr->operands[0] = Operand(ctx.info[instr->operands[0].tempId()].val);
+
+      mask |= (op.tempId() == literal_id) << i;
    }
 
+
+   /* don't go over the constant bus limit */
+   unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
+   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
+   if (num_sgprs == const_bus_limit && !is_literal_sgpr)
+      return;
+
+   if (literal_id && literal_uses < threshold &&
+       (current_literal.isUndefined() ||
+        (current_literal.size() == literal.size() &&
+         current_literal.constantValue() == literal.constantValue()))) {
+      /* mark the literal to be applied */
+      while (mask) {
+         unsigned i = u_bit_scan(&mask);
+         if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
+            ctx.uses[instr->operands[i].tempId()]--;
+      }
+   }
 }
 
 
@@ -2388,44 +2445,40 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (!instr)
       return;
 
-   /* apply literals on SALU */
-   if (instr->isSALU()) {
-      for (Operand& op : instr->operands) {
-         if (!op.isTemp())
-            continue;
-         if (op.isLiteral())
-            break;
-         if (ctx.info[op.tempId()].is_literal() &&
-             ctx.uses[op.tempId()] == 0)
-            op = Operand(ctx.info[op.tempId()].val);
+   /* apply literals on MAD */
+   bool literals_applied = false;
+   if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
+      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
+      if (!info->needs_vop3) {
+         aco_ptr<Instruction> new_mad;
+         if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
+            if (info->literal_idx == 2) { /* add literal -> madak */
+               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
+               new_mad->operands[0] = instr->operands[0];
+               new_mad->operands[1] = instr->operands[1];
+            } else { /* mul literal -> madmk */
+               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
+               new_mad->operands[0] = instr->operands[1 - info->literal_idx];
+               new_mad->operands[1] = instr->operands[2];
+            }
+            new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
+            new_mad->definitions[0] = instr->definitions[0];
+            instr.swap(new_mad);
+         }
+         literals_applied = true;
       }
    }
 
-   /* apply literals on VALU */
-   else if (instr->isVALU() && !instr->isVOP3() &&
-       instr->operands[0].isTemp() &&
-       ctx.info[instr->operands[0].tempId()].is_literal() &&
-       ctx.uses[instr->operands[0].tempId()] == 0) {
-      instr->operands[0] = Operand(ctx.info[instr->operands[0].tempId()].val);
-   }
-
-   /* apply literals on MAD */
-   else if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
-      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
-      aco_ptr<Instruction> new_mad;
-      if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
-         if (info->literal_idx == 2) { /* add literal -> madak */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
-            new_mad->operands[0] = instr->operands[0];
-            new_mad->operands[1] = instr->operands[1];
-         } else { /* mul literal -> madmk */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
-            new_mad->operands[0] = instr->operands[1 - info->literal_idx];
-            new_mad->operands[1] = instr->operands[2];
+   /* apply literals on SALU/VALU */
+   if (!literals_applied && (instr->isSALU() || instr->isVALU())) {
+      for (unsigned i = 0; i < instr->operands.size(); i++) {
+         Operand op = instr->operands[i];
+         if (op.isTemp() && ctx.info[op.tempId()].is_literal() && ctx.uses[op.tempId()] == 0) {
+            Operand literal(ctx.info[op.tempId()].val);
+            if (instr->isVALU() && i > 0)
+               to_VOP3(ctx, instr);
+            instr->operands[i] = literal;
          }
-         new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
-         new_mad->definitions[0] = instr->definitions[0];
-         instr.swap(new_mad);
       }
    }