aco: allow an extra SGPR with multiple uses to be applied to VOP3
[mesa.git] / src / amd / compiler / aco_optimizer.cpp
index 48a56a57f356a995c2182345f0db6b0c29f7d6b2..88075fabfb2df3bb60f6c400609b4a87535bbc2a 100644 (file)
@@ -86,7 +86,7 @@ enum Label {
 };
 
 static constexpr uint32_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success | label_add_sub | label_bitwise | label_minmax | label_fcmp;
-static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool;
+static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool | label_omod2 | label_omod4 | label_omod5 | label_clamp;
 static constexpr uint32_t val_labels = label_constant | label_literal | label_mad;
 
 struct ssa_info {
@@ -211,9 +211,10 @@ struct ssa_info {
       return label & label_mad;
    }
 
-   void set_omod2()
+   void set_omod2(Temp def)
    {
       add_label(label_omod2);
+      temp = def;
    }
 
    bool is_omod2()
@@ -221,9 +222,10 @@ struct ssa_info {
       return label & label_omod2;
    }
 
-   void set_omod4()
+   void set_omod4(Temp def)
    {
       add_label(label_omod4);
+      temp = def;
    }
 
    bool is_omod4()
@@ -231,9 +233,10 @@ struct ssa_info {
       return label & label_omod4;
    }
 
-   void set_omod5()
+   void set_omod5(Temp def)
    {
       add_label(label_omod5);
+      temp = def;
    }
 
    bool is_omod5()
@@ -252,9 +255,10 @@ struct ssa_info {
       return label & label_omod_success;
    }
 
-   void set_clamp()
+   void set_clamp(Temp def)
    {
       add_label(label_clamp);
+      temp = def;
    }
 
    bool is_clamp()
@@ -412,6 +416,9 @@ bool can_swap_operands(aco_ptr<Instruction>& instr)
 
 bool can_use_VOP3(aco_ptr<Instruction>& instr)
 {
+   if (instr->isVOP3())
+      return true;
+
    if (instr->operands.size() && instr->operands[0].isLiteral())
       return false;
 
@@ -422,6 +429,10 @@ bool can_use_VOP3(aco_ptr<Instruction>& instr)
           instr->opcode != aco_opcode::v_madak_f32 &&
           instr->opcode != aco_opcode::v_madmk_f16 &&
           instr->opcode != aco_opcode::v_madak_f16 &&
+          instr->opcode != aco_opcode::v_fmamk_f32 &&
+          instr->opcode != aco_opcode::v_fmaak_f32 &&
+          instr->opcode != aco_opcode::v_fmamk_f16 &&
+          instr->opcode != aco_opcode::v_fmaak_f16 &&
           instr->opcode != aco_opcode::v_readlane_b32 &&
           instr->opcode != aco_opcode::v_writelane_b32 &&
           instr->opcode != aco_opcode::v_readfirstlane_b32;
@@ -486,19 +497,6 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
    }
 }
 
-bool valu_can_accept_literal(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned operand)
-{
-   /* instructions like v_cndmask_b32 can't take a literal because they always
-    * read SGPRs */
-   if (instr->operands.size() >= 3 &&
-       instr->operands[2].isTemp() && instr->operands[2].regClass().type() == RegType::sgpr)
-      return false;
-
-   // TODO: VOP3 can take a literal on GFX10
-   return !instr->isSDWA() && !instr->isDPP() && !instr->isVOP3() &&
-          operand == 0 && can_accept_constant(instr, operand);
-}
-
 bool valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
 {
    if (instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_readlane_b32_e64 ||
@@ -507,6 +505,33 @@ bool valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
    return true;
 }
 
+/* check constant bus and literal limitations */
+bool check_vop3_operands(opt_ctx& ctx, unsigned num_operands, Operand *operands)
+{
+   int limit = 1;
+   unsigned num_sgprs = 0;
+   unsigned sgpr[] = {0, 0};
+
+   for (unsigned i = 0; i < num_operands; i++) {
+      Operand op = operands[i];
+
+      if (op.hasRegClass() && op.regClass().type() == RegType::sgpr) {
+         /* two reads of the same SGPR count as 1 to the limit */
+         if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
+            if (num_sgprs < 2)
+               sgpr[num_sgprs++] = op.tempId();
+            limit--;
+            if (limit < 0)
+               return false;
+         }
+      } else if (op.isLiteral()) {
+         return false;
+      }
+   }
+
+   return true;
+}
+
 bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp *base, uint32_t *offset)
 {
    Operand op = instr->operands[op_index];
@@ -907,11 +932,11 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
       for (unsigned i = 0; i < 2; i++) {
          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
             if (instr->operands[!i].constantValue() == 0x40000000) { /* 2.0 */
-               ctx.info[instr->operands[i].tempId()].set_omod2();
+               ctx.info[instr->operands[i].tempId()].set_omod2(instr->definitions[0].getTemp());
             } else if (instr->operands[!i].constantValue() == 0x40800000) { /* 4.0 */
-               ctx.info[instr->operands[i].tempId()].set_omod4();
+               ctx.info[instr->operands[i].tempId()].set_omod4(instr->definitions[0].getTemp());
             } else if (instr->operands[!i].constantValue() == 0x3f000000) { /* 0.5 */
-               ctx.info[instr->operands[i].tempId()].set_omod5();
+               ctx.info[instr->operands[i].tempId()].set_omod5(instr->definitions[0].getTemp());
             } else if (instr->operands[!i].constantValue() == 0x3f800000 &&
                        !block.fp_mode.must_flush_denorms32) { /* 1.0 */
                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[i].getTemp());
@@ -967,7 +992,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
             idx = i;
       }
       if (found_zero && found_one && instr->operands[idx].isTemp()) {
-         ctx.info[instr->operands[idx].tempId()].set_clamp();
+         ctx.info[instr->operands[idx].tempId()].set_clamp(instr->definitions[0].getTemp());
       }
       break;
    }
@@ -1203,13 +1228,16 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       Temp op1 = op_instr[i]->operands[1].getTemp();
       if (original_temp_id(ctx, op0) != original_temp_id(ctx, op1))
          return false;
-      /* shouldn't happen yet, but best to be safe */
-      if (op1.type() != RegType::vgpr)
-         return false;
 
       op[i] = op1;
    }
 
+   if (op[1].type() == RegType::sgpr)
+      std::swap(op[0], op[1]);
+   //TODO: we can use two different SGPRs on GFX10
+   if (op[0].type() == RegType::sgpr && op[1].type() == RegType::sgpr)
+      return false;
+
    ctx.uses[op[0].id()]++;
    ctx.uses[op[1].id()]++;
    decrease_uses(ctx, op_instr[0]);
@@ -1366,7 +1394,8 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& in
    if (cmp->operands[constant_operand].isConstant()) {
       constant = cmp->operands[constant_operand].constantValue();
    } else if (cmp->operands[constant_operand].isTemp()) {
-      unsigned id = cmp->operands[constant_operand].tempId();
+      Temp tmp = cmp->operands[constant_operand].getTemp();
+      unsigned id = original_temp_id(ctx, tmp);
       if (!ctx.info[id].is_constant() && !ctx.info[id].is_literal())
          return false;
       constant = ctx.info[id].val;
@@ -1523,17 +1552,8 @@ bool match_op3_for_vop3(opt_ctx &ctx, aco_opcode op1, aco_opcode op2,
    }
 
    /* check operands */
-   unsigned sgpr_id = 0;
-   for (unsigned i = 0; i < 3; i++) {
-      Operand op = operands[i];
-      if (op.isLiteral()) {
-         return false;
-      } else if (op.isTemp() && op.getTemp().type() == RegType::sgpr) {
-         if (sgpr_id && sgpr_id != op.tempId())
-            return false;
-         sgpr_id = op.tempId();
-      }
-   }
+   if (!check_vop3_operands(ctx, 3, operands))
+      return false;
 
    return true;
 }
@@ -1659,6 +1679,10 @@ bool combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 && op2_instr->opcode != aco_opcode::s_not_b64))
          continue;
 
+      if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
+          instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
+         continue;
+
       ctx.uses[instr->operands[i].tempId()]--;
       instr->operands[0] = instr->operands[!i];
       instr->operands[1] = op2_instr->operands[0];
@@ -1701,6 +1725,10 @@ bool combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
       if (shift < 1 || shift > 4)
          continue;
 
+      if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
+          instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
+         continue;
+
       ctx.uses[instr->operands[i].tempId()]--;
       instr->operands[1] = instr->operands[!i];
       instr->operands[0] = op2_instr->operands[0];
@@ -1867,61 +1895,74 @@ bool combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr,
 
 void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 {
-   /* apply sgprs */
-   uint32_t sgpr_idx = 0;
-   uint32_t sgpr_info_id = 0;
-   bool has_sgpr = false;
-   uint32_t sgpr_ssa_id = 0;
-   /* find 'best' possible sgpr */
-   for (unsigned i = 0; i < instr->operands.size(); i++)
-   {
-      if (instr->operands[i].isLiteral()) {
-         has_sgpr = true;
-         break;
-      }
+   /* find candidates and create the set of sgprs already read */
+   unsigned sgpr_ids[2] = {0, 0};
+   uint32_t operand_mask = 0;
+   bool has_literal = false;
+   for (unsigned i = 0; i < instr->operands.size(); i++) {
+      if (instr->operands[i].isLiteral())
+         has_literal = true;
       if (!instr->operands[i].isTemp())
          continue;
       if (instr->operands[i].getTemp().type() == RegType::sgpr) {
-         has_sgpr = true;
-         sgpr_ssa_id = instr->operands[i].tempId();
-         continue;
+         if (instr->operands[i].tempId() != sgpr_ids[0])
+            sgpr_ids[!!sgpr_ids[0]] = instr->operands[i].tempId();
       }
       ssa_info& info = ctx.info[instr->operands[i].tempId()];
-      if (info.is_temp() && info.temp.type() == RegType::sgpr) {
+      if (info.is_temp() && info.temp.type() == RegType::sgpr)
+         operand_mask |= 1u << i;
+   }
+   unsigned max_sgprs = 1;
+   if (has_literal)
+      max_sgprs--;
+
+   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
+
+   /* keep on applying sgprs until there is nothing left to be done */
+   while (operand_mask) {
+      uint32_t sgpr_idx = 0;
+      uint32_t sgpr_info_id = 0;
+      uint32_t mask = operand_mask;
+      /* choose a sgpr */
+      while (mask) {
+         unsigned i = u_bit_scan(&mask);
          uint16_t uses = ctx.uses[instr->operands[i].tempId()];
          if (sgpr_info_id == 0 || uses < ctx.uses[sgpr_info_id]) {
             sgpr_idx = i;
             sgpr_info_id = instr->operands[i].tempId();
          }
       }
-   }
-   if (!has_sgpr && sgpr_info_id != 0) {
-      ssa_info& info = ctx.info[sgpr_info_id];
+      operand_mask &= ~(1u << sgpr_idx);
+
+      /* Applying two sgprs require making it VOP3, so don't do it unless it's
+       * definitively beneficial.
+       * TODO: this is too conservative because later the use count could be reduced to 1 */
+      if (num_sgprs && ctx.uses[sgpr_info_id] > 1 && !instr->isVOP3())
+         break;
+
+      Temp sgpr = ctx.info[sgpr_info_id].temp;
+      bool new_sgpr = sgpr.id() != sgpr_ids[0] && sgpr.id() != sgpr_ids[1];
+      if (new_sgpr && num_sgprs >= max_sgprs)
+         continue;
+
       if (sgpr_idx == 0 || instr->isVOP3()) {
-         instr->operands[sgpr_idx] = Operand(info.temp);
-         ctx.uses[sgpr_info_id]--;
-         ctx.uses[info.temp.id()]++;
+         instr->operands[sgpr_idx] = Operand(sgpr);
       } else if (can_swap_operands(instr)) {
          instr->operands[sgpr_idx] = instr->operands[0];
-         instr->operands[0] = Operand(info.temp);
-         ctx.uses[sgpr_info_id]--;
-         ctx.uses[info.temp.id()]++;
+         instr->operands[0] = Operand(sgpr);
+         /* swap bits using a 4-entry LUT */
+         uint32_t swapped = (0x3120 >> (operand_mask & 0x3)) & 0xf;
+         operand_mask = (operand_mask & ~0x3) | swapped;
       } else if (can_use_VOP3(instr)) {
          to_VOP3(ctx, instr);
-         instr->operands[sgpr_idx] = Operand(info.temp);
-         ctx.uses[sgpr_info_id]--;
-         ctx.uses[info.temp.id()]++;
+         instr->operands[sgpr_idx] = Operand(sgpr);
+      } else {
+         continue;
       }
 
-   /* we can have two sgprs on one instruction if it is the same sgpr! */
-   } else if (sgpr_info_id != 0 &&
-              sgpr_ssa_id == sgpr_info_id &&
-              ctx.uses[sgpr_info_id] == 1 &&
-              can_use_VOP3(instr)) {
-      to_VOP3(ctx, instr);
-      instr->operands[sgpr_idx] = Operand(ctx.info[sgpr_info_id].temp);
+      sgpr_ids[num_sgprs++] = sgpr.id();
       ctx.uses[sgpr_info_id]--;
-      ctx.uses[ctx.info[sgpr_info_id].temp.id()]++;
+      ctx.uses[sgpr.id()]++;
    }
 }
 
@@ -1944,7 +1985,8 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 
          Instruction* omod_instr = ctx.info[instr->operands[idx].tempId()].instr;
          /* check if we have an additional clamp modifier */
-         if (ctx.info[instr->definitions[0].tempId()].is_clamp() && ctx.uses[instr->definitions[0].tempId()] == 1) {
+         if (ctx.info[instr->definitions[0].tempId()].is_clamp() && ctx.uses[instr->definitions[0].tempId()] == 1 &&
+             ctx.uses[ctx.info[instr->definitions[0].tempId()].temp.id()]) {
             static_cast<VOP3A_instruction*>(omod_instr)->clamp = true;
             ctx.info[instr->definitions[0].tempId()].set_clamp_success(omod_instr);
          }
@@ -2000,22 +2042,23 @@ bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
    if (!instr->definitions.empty() && ctx.uses[instr->definitions[0].tempId()] == 1 &&
        can_use_VOP3(instr) && instr_info.can_use_output_modifiers[(int)instr->opcode]) {
-      if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod2()) {
+      ssa_info& def_info = ctx.info[instr->definitions[0].tempId()];
+      if (can_use_omod && def_info.is_omod2() && ctx.uses[def_info.temp.id()]) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 1;
-         ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod4()) {
+         def_info.set_omod_success(instr.get());
+      } else if (can_use_omod && def_info.is_omod4() && ctx.uses[def_info.temp.id()]) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 2;
-         ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod5()) {
+         def_info.set_omod_success(instr.get());
+      } else if (can_use_omod && def_info.is_omod5() && ctx.uses[def_info.temp.id()]) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 3;
-         ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (ctx.info[instr->definitions[0].tempId()].is_clamp()) {
+         def_info.set_omod_success(instr.get());
+      } else if (def_info.is_clamp() && ctx.uses[def_info.temp.id()]) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->clamp = true;
-         ctx.info[instr->definitions[0].tempId()].set_clamp_success(instr.get());
+         def_info.set_clamp_success(instr.get());
       }
    }
 
@@ -2126,22 +2169,17 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          unsigned omod = 0;
          bool clamp = false;
          bool need_vop3 = false;
-         int num_sgpr = 0;
          op[0] = mul_instr->operands[0];
          op[1] = mul_instr->operands[1];
          op[2] = instr->operands[add_op_idx];
-         for (unsigned i = 0; i < 3; i++)
-         {
-            if (op[i].isLiteral())
-               return;
-            if (op[i].isTemp() && op[i].getTemp().type() == RegType::sgpr)
-               num_sgpr++;
+         // TODO: would be better to check this before selecting a mul instr?
+         if (!check_vop3_operands(ctx, 3, op))
+            return;
+
+         for (unsigned i = 0; i < 3; i++) {
             if (!(i == 0 || (op[i].isTemp() && op[i].getTemp().type() == RegType::vgpr)))
                need_vop3 = true;
          }
-         // TODO: would be better to check this before selecting a mul instr?
-         if (num_sgpr > 1)
-            return;
 
          if (mul_instr->isVOP3()) {
             VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*> (mul_instr);
@@ -2295,6 +2333,10 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       /* first, check profitability */
       if (ctx.uses[info->mul_temp_id]) {
          ctx.uses[info->mul_temp_id]++;
+         if (instr->operands[0].isTemp())
+            ctx.uses[instr->operands[0].tempId()]--;
+         if (instr->operands[1].isTemp())
+            ctx.uses[instr->operands[1].tempId()]--;
          instr.swap(info->add_instr);
 
       /* second, check possible literals */
@@ -2326,44 +2368,74 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
             info->check_literal = true;
             info->literal_idx = literal_idx;
          }
+         return;
       }
-      return;
    }
 
    /* check for literals */
+   if (!instr->isSALU() && !instr->isVALU())
+      return;
+
+   if (instr->isSDWA() || instr->isDPP() || instr->isVOP3())
+      return; /* some encodings can't ever take literals */
+
    /* we do not apply the literals yet as we don't know if it is profitable */
-   if (instr->isSALU()) {
-      uint32_t literal_idx = 0;
-      uint32_t literal_uses = UINT32_MAX;
-      bool has_literal = false;
-      for (unsigned i = 0; i < instr->operands.size(); i++)
-      {
-         if (instr->operands[i].isLiteral()) {
-            has_literal = true;
-            break;
-         }
-         if (!instr->operands[i].isTemp())
-            continue;
-         if (ctx.info[instr->operands[i].tempId()].is_literal() &&
-             ctx.uses[instr->operands[i].tempId()] < literal_uses) {
-            literal_uses = ctx.uses[instr->operands[i].tempId()];
-            literal_idx = i;
-         }
+   Operand current_literal(s1);
+
+   unsigned literal_id = 0;
+   unsigned literal_uses = UINT32_MAX;
+   Operand literal(s1);
+   unsigned num_operands = instr->isSALU() ? instr->operands.size() : 1;
+
+   unsigned sgpr_ids[2] = {0, 0};
+   bool is_literal_sgpr = false;
+   uint32_t mask = 0;
+
+   /* choose a literal to apply */
+   for (unsigned i = 0; i < num_operands; i++) {
+      Operand op = instr->operands[i];
+      if (op.isLiteral()) {
+         current_literal = op;
+         continue;
+      } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal()) {
+         if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
+             op.tempId() != sgpr_ids[0])
+            sgpr_ids[!!sgpr_ids[0]] = op.tempId();
+         continue;
       }
-      if (!has_literal && literal_uses < threshold) {
-         ctx.uses[instr->operands[literal_idx].tempId()]--;
-         if (ctx.uses[instr->operands[literal_idx].tempId()] == 0)
-            instr->operands[literal_idx] = Operand(ctx.info[instr->operands[literal_idx].tempId()].val);
+
+      if (!can_accept_constant(instr, i))
+         continue;
+
+      if (ctx.uses[op.tempId()] < literal_uses) {
+         is_literal_sgpr = op.getTemp().type() == RegType::sgpr;
+         mask = 0;
+         literal = Operand(ctx.info[op.tempId()].val);
+         literal_uses = ctx.uses[op.tempId()];
+         literal_id = op.tempId();
       }
-   } else if (instr->isVALU() && valu_can_accept_literal(ctx, instr, 0) &&
-       instr->operands[0].isTemp() &&
-       ctx.info[instr->operands[0].tempId()].is_literal() &&
-       ctx.uses[instr->operands[0].tempId()] < threshold) {
-      ctx.uses[instr->operands[0].tempId()]--;
-      if (ctx.uses[instr->operands[0].tempId()] == 0)
-         instr->operands[0] = Operand(ctx.info[instr->operands[0].tempId()].val);
+
+      mask |= (op.tempId() == literal_id) << i;
    }
 
+
+   /* don't go over the constant bus limit */
+   unsigned const_bus_limit = instr->isVALU() ? 1 : UINT32_MAX;
+   unsigned num_sgprs = !!sgpr_ids[0] + !!sgpr_ids[1];
+   if (num_sgprs == const_bus_limit && !is_literal_sgpr)
+      return;
+
+   if (literal_id && literal_uses < threshold &&
+       (current_literal.isUndefined() ||
+        (current_literal.size() == literal.size() &&
+         current_literal.constantValue() == literal.constantValue()))) {
+      /* mark the literal to be applied */
+      while (mask) {
+         unsigned i = u_bit_scan(&mask);
+         if (instr->operands[i].isTemp() && instr->operands[i].tempId() == literal_id)
+            ctx.uses[instr->operands[i].tempId()]--;
+      }
+   }
 }
 
 
@@ -2373,44 +2445,40 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (!instr)
       return;
 
-   /* apply literals on SALU */
-   if (instr->isSALU()) {
-      for (Operand& op : instr->operands) {
-         if (!op.isTemp())
-            continue;
-         if (op.isLiteral())
-            break;
-         if (ctx.info[op.tempId()].is_literal() &&
-             ctx.uses[op.tempId()] == 0)
-            op = Operand(ctx.info[op.tempId()].val);
+   /* apply literals on MAD */
+   bool literals_applied = false;
+   if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
+      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
+      if (!info->needs_vop3) {
+         aco_ptr<Instruction> new_mad;
+         if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
+            if (info->literal_idx == 2) { /* add literal -> madak */
+               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
+               new_mad->operands[0] = instr->operands[0];
+               new_mad->operands[1] = instr->operands[1];
+            } else { /* mul literal -> madmk */
+               new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
+               new_mad->operands[0] = instr->operands[1 - info->literal_idx];
+               new_mad->operands[1] = instr->operands[2];
+            }
+            new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
+            new_mad->definitions[0] = instr->definitions[0];
+            instr.swap(new_mad);
+         }
+         literals_applied = true;
       }
    }
 
-   /* apply literals on VALU */
-   else if (instr->isVALU() && !instr->isVOP3() &&
-       instr->operands[0].isTemp() &&
-       ctx.info[instr->operands[0].tempId()].is_literal() &&
-       ctx.uses[instr->operands[0].tempId()] == 0) {
-      instr->operands[0] = Operand(ctx.info[instr->operands[0].tempId()].val);
-   }
-
-   /* apply literals on MAD */
-   else if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
-      mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
-      aco_ptr<Instruction> new_mad;
-      if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
-         if (info->literal_idx == 2) { /* add literal -> madak */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));
-            new_mad->operands[0] = instr->operands[0];
-            new_mad->operands[1] = instr->operands[1];
-         } else { /* mul literal -> madmk */
-            new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madmk_f32, Format::VOP2, 3, 1));
-            new_mad->operands[0] = instr->operands[1 - info->literal_idx];
-            new_mad->operands[1] = instr->operands[2];
+   /* apply literals on SALU/VALU */
+   if (!literals_applied && (instr->isSALU() || instr->isVALU())) {
+      for (unsigned i = 0; i < instr->operands.size(); i++) {
+         Operand op = instr->operands[i];
+         if (op.isTemp() && ctx.info[op.tempId()].is_literal() && ctx.uses[op.tempId()] == 0) {
+            Operand literal(ctx.info[op.tempId()].val);
+            if (instr->isVALU() && i > 0)
+               to_VOP3(ctx, instr);
+            instr->operands[i] = literal;
          }
-         new_mad->operands[2] = Operand(ctx.info[instr->operands[info->literal_idx].tempId()].val);
-         new_mad->definitions[0] = instr->definitions[0];
-         instr.swap(new_mad);
       }
    }