v3d/tex: only look up the 2nd texture gather offset for 1d non-arrays
[mesa.git] / src / amd / compiler / aco_optimizer.cpp
index fbeb76df9902ceed8fe4e12e0b30a1cc49adfa10..9203f1c4b43d1dd3a76e41152027fe36c2a5e50f 100644 (file)
@@ -103,6 +103,8 @@ struct ssa_info {
    };
    uint32_t label;
 
+   ssa_info() : label(0) {}
+
    void add_label(Label new_label)
    {
       /* Since all labels which use "instr" use it for the same thing
@@ -528,9 +530,9 @@ void to_VOP3(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 }
 
 /* only covers special cases */
-bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
+bool alu_can_accept_constant(aco_opcode opcode, unsigned operand)
 {
-   switch (instr->opcode) {
+   switch (opcode) {
    case aco_opcode::v_interp_p2_f32:
    case aco_opcode::v_mac_f32:
    case aco_opcode::v_writelane_b32:
@@ -547,12 +549,6 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
    case aco_opcode::v_readfirstlane_b32:
       return operand != 0;
    default:
-      if ((instr->format == Format::MUBUF ||
-           instr->format == Format::MIMG) &&
-          instr->definitions.size() == 1 &&
-          instr->operands.size() == 4) {
-         return operand != 3;
-      }
       return true;
    }
 }
@@ -628,6 +624,7 @@ bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp
    switch (add_instr->opcode) {
    case aco_opcode::v_add_u32:
    case aco_opcode::v_add_co_u32:
+   case aco_opcode::v_add_co_u32_e64:
    case aco_opcode::s_add_i32:
    case aco_opcode::s_add_u32:
       break;
@@ -671,6 +668,11 @@ Operand get_constant_op(opt_ctx &ctx, uint32_t val, bool is64bit = false)
    return op;
 }
 
+bool fixed_to_exec(Operand op)
+{
+   return op.isFixed() && op.physReg() == exec;
+}
+
 void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    if (instr->isSALU() || instr->isVALU() || instr->format == Format::PSEUDO) {
@@ -697,6 +699,12 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 
       /* SALU / PSEUDO: propagate inline constants */
       if (instr->isSALU() || instr->format == Format::PSEUDO) {
+         const bool is_subdword = std::any_of(instr->definitions.begin(), instr->definitions.end(),
+                                              [] (const Definition& def) { return def.regClass().is_subdword();});
+         // TODO: optimize SGPR and constant propagation for subdword pseudo instructions on gfx9+
+         if (is_subdword)
+            continue;
+
          if (info.is_temp() && info.temp.type() == RegType::sgpr) {
             instr->operands[i].setTemp(info.temp);
             info = ctx.info[info.temp.id()];
@@ -719,7 +727,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
                break;
             }
          }
-         if ((info.is_constant() || info.is_constant_64bit() || (info.is_literal() && instr->format == Format::PSEUDO)) && !instr->operands[i].isFixed() && can_accept_constant(instr, i)) {
+         if ((info.is_constant() || info.is_constant_64bit() || (info.is_literal() && instr->format == Format::PSEUDO)) &&
+             !instr->operands[i].isFixed() && alu_can_accept_constant(instr->opcode, i)) {
             instr->operands[i] = get_constant_op(ctx, info.val, info.is_constant_64bit());
             continue;
          }
@@ -754,7 +763,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
                static_cast<VOP3A_instruction*>(instr.get())->neg[i] = true;
             continue;
          }
-         if ((info.is_constant() || info.is_constant_64bit()) && can_accept_constant(instr, i)) {
+         if ((info.is_constant() || info.is_constant_64bit()) && alu_can_accept_constant(instr->opcode, i)) {
             Operand op = get_constant_op(ctx, info.val, info.is_constant_64bit());
             perfwarn(instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
             if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) {
@@ -780,9 +789,9 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          while (info.is_temp())
             info = ctx.info[info.temp.id()];
 
-         if (mubuf->offen && i == 0 && info.is_constant_or_literal() && mubuf->offset + info.val < 4096) {
+         if (mubuf->offen && i == 1 && info.is_constant_or_literal() && mubuf->offset + info.val < 4096) {
             assert(!mubuf->idxen);
-            instr->operands[i] = Operand(v1);
+            instr->operands[1] = Operand(v1);
             mubuf->offset += info.val;
             mubuf->offen = false;
             continue;
@@ -790,9 +799,9 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
             instr->operands[2] = Operand((uint32_t) 0);
             mubuf->offset += info.val;
             continue;
-         } else if (mubuf->offen && i == 0 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == v1 && mubuf->offset + offset < 4096) {
+         } else if (mubuf->offen && i == 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == v1 && mubuf->offset + offset < 4096) {
             assert(!mubuf->idxen);
-            instr->operands[i].setTemp(base);
+            instr->operands[1].setTemp(base);
             mubuf->offset += offset;
             continue;
          } else if (i == 2 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == s1 && mubuf->offset + offset < 4096) {
@@ -815,12 +824,15 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
              instr->opcode != aco_opcode::ds_swizzle_b32) {
             if (instr->opcode == aco_opcode::ds_write2_b32 || instr->opcode == aco_opcode::ds_read2_b32 ||
                 instr->opcode == aco_opcode::ds_write2_b64 || instr->opcode == aco_opcode::ds_read2_b64) {
-               if (offset % 4 == 0 &&
-                   ds->offset0 + (offset >> 2) <= 255 &&
-                   ds->offset1 + (offset >> 2) <= 255) {
+               unsigned mask = (instr->opcode == aco_opcode::ds_write2_b64 || instr->opcode == aco_opcode::ds_read2_b64) ? 0x7 : 0x3;
+               unsigned shifts = (instr->opcode == aco_opcode::ds_write2_b64 || instr->opcode == aco_opcode::ds_read2_b64) ? 3 : 2;
+
+               if ((offset & mask) == 0 &&
+                   ds->offset0 + (offset >> shifts) <= 255 &&
+                   ds->offset1 + (offset >> shifts) <= 255) {
                   instr->operands[i].setTemp(base);
-                  ds->offset0 += offset >> 2;
-                  ds->offset1 += offset >> 2;
+                  ds->offset0 += offset >> shifts;
+                  ds->offset1 += offset >> shifts;
                }
             } else {
                if (ds->offset0 + offset <= 65535) {
@@ -886,6 +898,13 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 
    switch (instr->opcode) {
    case aco_opcode::p_create_vector: {
+      bool copy_prop = instr->operands.size() == 1 && instr->operands[0].isTemp() &&
+                       instr->operands[0].regClass() == instr->definitions[0].regClass();
+      if (copy_prop) {
+         ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
+         break;
+      }
+
       unsigned num_ops = instr->operands.size();
       for (const Operand& op : instr->operands) {
          if (op.isTemp() && ctx.info[op.tempId()].is_vec())
@@ -911,19 +930,25 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          }
          assert(k == num_ops);
       }
-      if (instr->operands.size() == 1 && instr->operands[0].isTemp())
-         ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
-      else if (instr->definitions[0].getTemp().size() == instr->operands.size())
-         ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
+
+      ctx.info[instr->definitions[0].tempId()].set_vec(instr.get());
       break;
    }
    case aco_opcode::p_split_vector: {
       if (!ctx.info[instr->operands[0].tempId()].is_vec())
          break;
       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
-      assert(instr->definitions.size() == vec->operands.size());
-      for (unsigned i = 0; i < instr->definitions.size(); i++) {
-         Operand vec_op = vec->operands[i];
+      unsigned split_offset = 0;
+      unsigned vec_offset = 0;
+      unsigned vec_index = 0;
+      for (unsigned i = 0; i < instr->definitions.size(); split_offset += instr->definitions[i++].bytes()) {
+         while (vec_offset < split_offset && vec_index < vec->operands.size())
+            vec_offset += vec->operands[vec_index++].bytes();
+
+         if (vec_offset != split_offset || vec->operands[vec_index].bytes() != instr->definitions[i].bytes())
+            continue;
+
+         Operand vec_op = vec->operands[vec_index];
          if (vec_op.isConstant()) {
             if (vec_op.isLiteral())
                ctx.info[instr->definitions[i].tempId()].set_literal(vec_op.constantValue());
@@ -931,6 +956,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
                ctx.info[instr->definitions[i].tempId()].set_constant(vec_op.constantValue());
             else if (vec_op.size() == 2)
                ctx.info[instr->definitions[i].tempId()].set_constant_64bit(vec_op.constantValue());
+         } else if (vec_op.isUndefined()) {
+            ctx.info[instr->definitions[i].tempId()].set_undefined();
          } else {
             assert(vec_op.isTemp());
             ctx.info[instr->definitions[i].tempId()].set_temp(vec_op.getTemp());
@@ -941,33 +968,40 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    case aco_opcode::p_extract_vector: { /* mov */
       if (!ctx.info[instr->operands[0].tempId()].is_vec())
          break;
+
+      /* check if we index directly into a vector element */
       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
-      if (vec->definitions[0].getTemp().size() == vec->operands.size() && /* TODO: what about 64bit or other combinations? */
-          vec->operands[0].size() == instr->definitions[0].size()) {
-
-         /* convert this extract into a mov instruction */
-         Operand vec_op = vec->operands[instr->operands[1].constantValue()];
-         bool is_vgpr = instr->definitions[0].getTemp().type() == RegType::vgpr;
-         aco_opcode opcode = is_vgpr ? aco_opcode::v_mov_b32 : aco_opcode::s_mov_b32;
-         Format format = is_vgpr ? Format::VOP1 : Format::SOP1;
-         instr->opcode = opcode;
-         instr->format = format;
-         while (instr->operands.size() > 1)
-            instr->operands.pop_back();
-         instr->operands[0] = vec_op;
+      const unsigned index = instr->operands[1].constantValue();
+      const unsigned dst_offset = index * instr->definitions[0].bytes();
+      unsigned offset = 0;
 
-         if (vec_op.isConstant()) {
-            if (vec_op.isLiteral())
-               ctx.info[instr->definitions[0].tempId()].set_literal(vec_op.constantValue());
-            else if (vec_op.size() == 1)
-               ctx.info[instr->definitions[0].tempId()].set_constant(vec_op.constantValue());
-            else if (vec_op.size() == 2)
-               ctx.info[instr->definitions[0].tempId()].set_constant_64bit(vec_op.constantValue());
+      for (const Operand& op : vec->operands) {
+         if (offset < dst_offset) {
+            offset += op.bytes();
+            continue;
+         } else if (offset != dst_offset || op.bytes() != instr->definitions[0].bytes()) {
+            break;
+         }
 
+         /* convert this extract into a copy instruction */
+         instr->opcode = aco_opcode::p_parallelcopy;
+         instr->operands.pop_back();
+         instr->operands[0] = op;
+
+         if (op.isConstant()) {
+            if (op.isLiteral())
+               ctx.info[instr->definitions[0].tempId()].set_literal(op.constantValue());
+            else if (op.size() == 1)
+               ctx.info[instr->definitions[0].tempId()].set_constant(op.constantValue());
+            else if (op.size() == 2)
+               ctx.info[instr->definitions[0].tempId()].set_constant_64bit(op.constantValue());
+         } else if (op.isUndefined()) {
+            ctx.info[instr->definitions[0].tempId()].set_undefined();
          } else {
-            assert(vec_op.isTemp());
-            ctx.info[instr->definitions[0].tempId()].set_temp(vec_op.getTemp());
+            assert(op.isTemp());
+            ctx.info[instr->definitions[0].tempId()].set_temp(op.getTemp());
          }
+         break;
       }
       break;
    }
@@ -1143,6 +1177,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    }
    case aco_opcode::v_add_u32:
    case aco_opcode::v_add_co_u32:
+   case aco_opcode::v_add_co_u32_e64:
    case aco_opcode::s_add_i32:
    case aco_opcode::s_add_u32:
       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
@@ -1160,7 +1195,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
       break;
    case aco_opcode::s_and_b32:
    case aco_opcode::s_and_b64:
-      if (instr->operands[1].isFixed() && instr->operands[1].physReg() == exec && instr->operands[0].isTemp()) {
+      if (fixed_to_exec(instr->operands[1]) && instr->operands[0].isTemp()) {
          if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
             /* Try to get rid of the superfluous s_cselect + s_and_b64 that comes from turning a uniform bool into divergent */
             ctx.info[instr->definitions[1].tempId()].set_temp(ctx.info[instr->operands[0].tempId()].temp);
@@ -1645,6 +1680,8 @@ bool match_op3_for_vop3(opt_ctx &ctx, aco_opcode op1, aco_opcode op2,
    Instruction *op2_instr = follow_operand(ctx, op1_instr->operands[swap]);
    if (!op2_instr || op2_instr->opcode != op2)
       return false;
+   if (fixed_to_exec(op2_instr->operands[0]) || fixed_to_exec(op2_instr->operands[1]))
+      return false;
 
    VOP3A_instruction *op1_vop3 = op1_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op1_instr) : NULL;
    VOP3A_instruction *op2_vop3 = op2_instr->isVOP3() ? static_cast<VOP3A_instruction *>(op2_instr) : NULL;
@@ -1810,6 +1847,7 @@ bool combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 
    /* create instruction */
    std::swap(instr->definitions[0], op2_instr->definitions[0]);
+   std::swap(instr->definitions[1], op2_instr->definitions[1]);
    ctx.uses[instr->operands[0].tempId()]--;
    ctx.info[op2_instr->definitions[0].tempId()].label = 0;
 
@@ -1845,13 +1883,15 @@ bool combine_salu_not_bitwise(opt_ctx& ctx, aco_ptr<Instruction>& instr)
  * s_or_b64(a, s_not_b64(b)) -> s_orn2_b64(a, b) */
 bool combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 {
-   if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
+   if (instr->definitions[0].isTemp() && ctx.info[instr->definitions[0].tempId()].is_uniform_bool())
       return false;
 
    for (unsigned i = 0; i < 2; i++) {
       Instruction *op2_instr = follow_operand(ctx, instr->operands[i]);
       if (!op2_instr || (op2_instr->opcode != aco_opcode::s_not_b32 && op2_instr->opcode != aco_opcode::s_not_b64))
          continue;
+      if (ctx.uses[op2_instr->definitions[1].tempId()] || fixed_to_exec(op2_instr->operands[0]))
+         continue;
 
       if (instr->operands[!i].isLiteral() && op2_instr->operands[0].isLiteral() &&
           instr->operands[!i].constantValue() != op2_instr->operands[0].constantValue())
@@ -1887,12 +1927,15 @@ bool combine_salu_n2(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 /* s_add_{i32,u32}(a, s_lshl_b32(b, <n>)) -> s_lshl<n>_add_u32(a, b) */
 bool combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
 {
-   if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
+   if (instr->opcode == aco_opcode::s_add_i32 && ctx.uses[instr->definitions[1].tempId()])
       return false;
 
    for (unsigned i = 0; i < 2; i++) {
       Instruction *op2_instr = follow_operand(ctx, instr->operands[i]);
-      if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 || !op2_instr->operands[1].isConstant())
+      if (!op2_instr || op2_instr->opcode != aco_opcode::s_lshl_b32 ||
+          ctx.uses[op2_instr->definitions[1].tempId()])
+         continue;
+      if (!op2_instr->operands[1].isConstant() || fixed_to_exec(op2_instr->operands[0]))
          continue;
 
       uint32_t shift = op2_instr->operands[1].constantValue();
@@ -2436,12 +2479,19 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
          }
       }
    } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->chip_class >= GFX9) {
-      if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ;
+      if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ;
+      else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ;
+      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ;
       else if (combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ;
+      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_or_b32, "120", 1 | 2)) ;
       else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_or_b32, "210", 1 | 2);
    } else if (instr->opcode == aco_opcode::v_add_u32 && ctx.program->chip_class >= GFX9) {
-      if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32, "120", 1 | 2)) ;
+      if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120", 1 | 2)) ;
+      else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32, "120", 1 | 2)) ;
+      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
+      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
       else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
+      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_add_u32, "120", 1 | 2)) ;
       else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_add_u32, "210", 1 | 2);
    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->chip_class >= GFX9) {
       combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120", 2);
@@ -2530,10 +2580,12 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (instr->opcode == aco_opcode::p_split_vector) {
       unsigned num_used = 0;
       unsigned idx = 0;
-      for (unsigned i = 0; i < instr->definitions.size(); i++) {
+      unsigned split_offset = 0;
+      for (unsigned i = 0, offset = 0; i < instr->definitions.size(); offset += instr->definitions[i++].bytes()) {
          if (ctx.uses[instr->definitions[i].tempId()]) {
             num_used++;
             idx = i;
+            split_offset = offset;
          }
       }
       bool done = false;
@@ -2544,13 +2596,13 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          unsigned off = 0;
          Operand op;
          for (Operand& vec_op : vec->operands) {
-            if (off == idx * instr->definitions[0].size()) {
+            if (off == split_offset) {
                op = vec_op;
                break;
             }
-            off += vec_op.size();
+            off += vec_op.bytes();
          }
-         if (off != instr->operands[0].size()) {
+         if (off != instr->operands[0].bytes() && op.bytes() == instr->definitions[idx].bytes()) {
             ctx.uses[instr->operands[0].tempId()]--;
             for (Operand& vec_op : vec->operands) {
                if (vec_op.isTemp())
@@ -2568,10 +2620,12 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          }
       }
 
-      if (!done && num_used == 1) {
+      if (!done && num_used == 1 &&
+          instr->operands[0].bytes() % instr->definitions[idx].bytes() == 0 &&
+          split_offset % instr->definitions[idx].bytes() == 0) {
          aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
          extract->operands[0] = instr->operands[0];
-         extract->operands[1] = Operand((uint32_t) idx);
+         extract->operands[1] = Operand((uint32_t) split_offset / instr->definitions[idx].bytes());
          extract->definitions[0] = instr->definitions[idx];
          instr.reset(extract.release());
       }
@@ -2605,7 +2659,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
                continue;
             /* if one of the operands is sgpr, we cannot add a literal somewhere else on pre-GFX10 or operands other than the 1st */
             if (instr->operands[i].getTemp().type() == RegType::sgpr && (i > 0 || ctx.program->chip_class < GFX10)) {
-               if (ctx.info[instr->operands[i].tempId()].is_literal()) {
+               if (!sgpr_used && ctx.info[instr->operands[i].tempId()].is_literal()) {
                   literal_uses = ctx.uses[instr->operands[i].tempId()];
                   literal_idx = i;
                } else {
@@ -2620,7 +2674,15 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
                literal_idx = i;
             }
          }
-         if (literal_uses < threshold) {
+
+         /* Limit the number of literals to apply to not increase the code
+          * size too much, but always apply literals for v_mad->v_madak
+          * because both instructions are 64-bit and this doesn't increase
+          * code size.
+          * TODO: try to apply the literals earlier to lower the number of
+          * uses below threshold
+          */
+         if (literal_uses < threshold || literal_idx == 2) {
             ctx.uses[instr->operands[literal_idx].tempId()]--;
             mad_info->check_literal = true;
             mad_info->literal_idx = literal_idx;
@@ -2674,6 +2736,9 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    unsigned num_operands = 1;
    if (instr->isSALU() || (ctx.program->chip_class >= GFX10 && can_use_VOP3(ctx, instr)))
       num_operands = instr->operands.size();
+   /* catch VOP2 with a 3rd SGPR operand (e.g. v_cndmask_b32, v_addc_co_u32) */
+   else if (instr->isVALU() && instr->operands.size() >= 3)
+      return;
 
    unsigned sgpr_ids[2] = {0, 0};
    bool is_literal_sgpr = false;
@@ -2694,7 +2759,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          continue;
       }
 
-      if (!can_accept_constant(instr, i))
+      if (!alu_can_accept_constant(instr->opcode, i))
          continue;
 
       if (ctx.uses[op.tempId()] < literal_uses) {
@@ -2744,7 +2809,8 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    /* apply literals on MAD */
    if (instr->opcode == aco_opcode::v_mad_f32 && ctx.info[instr->definitions[0].tempId()].is_mad()) {
       mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].val];
-      if (info->check_literal && ctx.uses[instr->operands[info->literal_idx].tempId()] == 0) {
+      if (info->check_literal &&
+          (ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) {
          aco_ptr<Instruction> new_mad;
          if (info->literal_idx == 2) { /* add literal -> madak */
             new_mad.reset(create_instruction<VOP2_instruction>(aco_opcode::v_madak_f32, Format::VOP2, 3, 1));