aco/wave32: fix comparison optimizations
[mesa.git] / src / amd / compiler / aco_optimizer.cpp
index a142ccba9a1c50c768e9ee4a755b01b4175832c0..612497928dd788d43a38210a03dd6077493666d7 100644 (file)
@@ -82,10 +82,11 @@ enum Label {
    label_bitwise = 1 << 18,
    label_minmax = 1 << 19,
    label_fcmp = 1 << 20,
+   label_uniform_bool = 1 << 21,
 };
 
 static constexpr uint32_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success | label_add_sub | label_bitwise | label_minmax | label_fcmp;
-static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f;
+static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool;
 static constexpr uint32_t val_labels = label_constant | label_literal | label_mad;
 
 struct ssa_info {
@@ -353,6 +354,17 @@ struct ssa_info {
       return label & label_fcmp;
    }
 
+   void set_uniform_bool(Temp uniform_bool)
+   {
+      add_label(label_uniform_bool);
+      temp = uniform_bool;
+   }
+
+   bool is_uniform_bool()
+   {
+      return label & label_uniform_bool;
+   }
+
 };
 
 struct opt_ctx {
@@ -409,14 +421,19 @@ bool can_use_VOP3(aco_ptr<Instruction>& instr)
    return instr->opcode != aco_opcode::v_madmk_f32 &&
           instr->opcode != aco_opcode::v_madak_f32 &&
           instr->opcode != aco_opcode::v_madmk_f16 &&
-          instr->opcode != aco_opcode::v_madak_f16;
+          instr->opcode != aco_opcode::v_madak_f16 &&
+          instr->opcode != aco_opcode::v_readlane_b32 &&
+          instr->opcode != aco_opcode::v_writelane_b32 &&
+          instr->opcode != aco_opcode::v_readfirstlane_b32;
 }
 
 bool can_apply_sgprs(aco_ptr<Instruction>& instr)
 {
    return instr->opcode != aco_opcode::v_readfirstlane_b32 &&
           instr->opcode != aco_opcode::v_readlane_b32 &&
-          instr->opcode != aco_opcode::v_writelane_b32;
+          instr->opcode != aco_opcode::v_readlane_b32_e64 &&
+          instr->opcode != aco_opcode::v_writelane_b32 &&
+          instr->opcode != aco_opcode::v_writelane_b32_e64;
 }
 
 void to_VOP3(opt_ctx& ctx, aco_ptr<Instruction>& instr)
@@ -439,12 +456,6 @@ void to_VOP3(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    }
 }
 
-bool valu_can_accept_literal(opt_ctx& ctx, aco_ptr<Instruction>& instr)
-{
-   // TODO: VOP3 can take a literal on GFX10
-   return !instr->isSDWA() && !instr->isDPP() && !instr->isVOP3();
-}
-
 /* only covers special cases */
 bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
 {
@@ -452,6 +463,7 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
    case aco_opcode::v_interp_p2_f32:
    case aco_opcode::v_mac_f32:
    case aco_opcode::v_writelane_b32:
+   case aco_opcode::v_writelane_b32_e64:
    case aco_opcode::v_cndmask_b32:
       return operand != 2;
    case aco_opcode::s_addk_i32:
@@ -460,6 +472,7 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
    case aco_opcode::p_extract_vector:
    case aco_opcode::p_split_vector:
    case aco_opcode::v_readlane_b32:
+   case aco_opcode::v_readlane_b32_e64:
    case aco_opcode::v_readfirstlane_b32:
       return operand != 0;
    default:
@@ -473,6 +486,27 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
    }
 }
 
+bool valu_can_accept_literal(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned operand)
+{
+   /* instructions like v_cndmask_b32 can't take a literal because they always
+    * read SGPRs */
+   if (instr->operands.size() >= 3 &&
+       instr->operands[2].isTemp() && instr->operands[2].regClass().type() == RegType::sgpr)
+      return false;
+
+   // TODO: VOP3 can take a literal on GFX10
+   return !instr->isSDWA() && !instr->isDPP() && !instr->isVOP3() &&
+          operand == 0 && can_accept_constant(instr, operand);
+}
+
+bool valu_can_accept_vgpr(aco_ptr<Instruction>& instr, unsigned operand)
+{
+   if (instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_readlane_b32_e64 ||
+       instr->opcode == aco_opcode::v_writelane_b32 || instr->opcode == aco_opcode::v_writelane_b32_e64)
+      return operand != 1;
+   return true;
+}
+
 bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp *base, uint32_t *offset)
 {
    Operand op = instr->operands[op_index];
@@ -495,6 +529,9 @@ bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp
       return false;
    }
 
+   if (add_instr->usesModifiers())
+      return false;
+
    for (unsigned i = 0; i < 2; i++) {
       if (add_instr->operands[i].isConstant()) {
          *offset = add_instr->operands[i].constantValue();
@@ -519,7 +556,16 @@ bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp
    return false;
 }
 
-void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+Operand get_constant_op(opt_ctx &ctx, uint32_t val)
+{
+   // TODO: this functions shouldn't be needed if we store Operand instead of value.
+   Operand op(val);
+   if (val == 0x3e22f983 && ctx.program->chip_class >= GFX8)
+      op.setFixed(PhysReg{248}); /* 1/2 PI can be an inline constant on GFX8+ */
+   return op;
+}
+
+void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    if (instr->isSALU() || instr->isVALU() || instr->format == Format::PSEUDO) {
       ASSERTED bool all_const = false;
@@ -568,14 +614,14 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
             }
          }
          if ((info.is_constant() || (info.is_literal() && instr->format == Format::PSEUDO)) && !instr->operands[i].isFixed() && can_accept_constant(instr, i)) {
-            instr->operands[i] = Operand(info.val);
+            instr->operands[i] = get_constant_op(ctx, info.val);
             continue;
          }
       }
 
       /* VALU: propagate neg, abs & inline constants */
       else if (instr->isVALU()) {
-         if (info.is_temp() && info.temp.type() == RegType::vgpr) {
+         if (info.is_temp() && info.temp.type() == RegType::vgpr && valu_can_accept_vgpr(instr, i)) {
             instr->operands[i].setTemp(info.temp);
             info = ctx.info[info.temp.id()];
          }
@@ -604,16 +650,16 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          }
          if (info.is_constant() && can_accept_constant(instr, i)) {
             perfwarn(instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
-            if (i == 0) {
-               instr->operands[i] = Operand(info.val);
+            if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) {
+               instr->operands[i] = get_constant_op(ctx, info.val);
                continue;
             } else if (!instr->isVOP3() && can_swap_operands(instr)) {
                instr->operands[i] = instr->operands[0];
-               instr->operands[0] = Operand(info.val);
+               instr->operands[0] = get_constant_op(ctx, info.val);
                continue;
             } else if (can_use_VOP3(instr)) {
                to_VOP3(ctx, instr);
-               instr->operands[i] = Operand(info.val);
+               instr->operands[i] = get_constant_op(ctx, info.val);
                continue;
             }
          }
@@ -656,7 +702,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          Temp base;
          uint32_t offset;
          if (i == 0 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == instr->operands[i].regClass()) {
-            if (instr->opcode == aco_opcode::ds_write2_b32 || instr->opcode == aco_opcode::ds_read2_b32) {
+            if (instr->opcode == aco_opcode::ds_write2_b32 || instr->opcode == aco_opcode::ds_read2_b32 ||
+                instr->opcode == aco_opcode::ds_write2_b64 || instr->opcode == aco_opcode::ds_read2_b64) {
                if (offset % 4 == 0 &&
                    ds->offset0 + (offset >> 2) <= 255 &&
                    ds->offset1 + (offset >> 2) <= 255) {
@@ -679,7 +726,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          SMEM_instruction *smem = static_cast<SMEM_instruction *>(instr.get());
          Temp base;
          uint32_t offset;
-         if (i == 1 && info.is_constant_or_literal() && info.val <= 0xFFFFF) {
+         if (i == 1 && info.is_constant_or_literal() &&
+             (ctx.program->chip_class < GFX8 || info.val <= 0xFFFFF)) {
             instr->operands[i] = Operand(info.val);
             continue;
          } else if (i == 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->chip_class >= GFX9) {
@@ -701,6 +749,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
                new_instr->operands.back() = Operand(base);
                if (!smem->definitions.empty())
                   new_instr->definitions[0] = smem->definitions[0];
+               new_instr->can_reorder = smem->can_reorder;
+               new_instr->barrier = smem->barrier;
                instr.reset(new_instr);
                smem = static_cast<SMEM_instruction *>(instr.get());
             }
@@ -727,8 +777,13 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          unsigned k = 0;
          for (Operand& old_op : old_vec->operands) {
             if (old_op.isTemp() && ctx.info[old_op.tempId()].is_vec()) {
-               for (unsigned j = 0; j < ctx.info[old_op.tempId()].instr->operands.size(); j++)
-                  instr->operands[k++] = ctx.info[old_op.tempId()].instr->operands[j];
+               for (unsigned j = 0; j < ctx.info[old_op.tempId()].instr->operands.size(); j++) {
+                  Operand op = ctx.info[old_op.tempId()].instr->operands[j];
+                  if (op.isTemp() && ctx.info[op.tempId()].is_temp() &&
+                      ctx.info[op.tempId()].temp.type() == instr->definitions[0].regClass().type())
+                     op.setTemp(ctx.info[op.tempId()].temp);
+                  instr->operands[k++] = op;
+               }
             } else {
                instr->operands[k++] = old_op;
             }
@@ -751,7 +806,7 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          if (vec_op.isConstant()) {
             if (vec_op.isLiteral())
                ctx.info[instr->definitions[i].tempId()].set_literal(vec_op.constantValue());
-            else
+            else if (vec_op.size() == 1)
                ctx.info[instr->definitions[i].tempId()].set_constant(vec_op.constantValue());
          } else {
             assert(vec_op.isTemp());
@@ -764,7 +819,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       if (!ctx.info[instr->operands[0].tempId()].is_vec())
          break;
       Instruction* vec = ctx.info[instr->operands[0].tempId()].instr;
-      if (vec->definitions[0].getTemp().size() == vec->operands.size()) { /* TODO: what about 64bit or other combinations? */
+      if (vec->definitions[0].getTemp().size() == vec->operands.size() && /* TODO: what about 64bit or other combinations? */
+          vec->operands[0].size() == instr->definitions[0].size()) {
 
          /* convert this extract into a mov instruction */
          Operand vec_op = vec->operands[instr->operands[1].constantValue()];
@@ -779,7 +835,7 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          if (vec_op.isConstant()) {
             if (vec_op.isLiteral())
                ctx.info[instr->definitions[0].tempId()].set_literal(vec_op.constantValue());
-            else
+            else if (vec_op.size() == 1)
                ctx.info[instr->definitions[0].tempId()].set_constant(vec_op.constantValue());
          } else {
             assert(vec_op.isTemp());
@@ -794,13 +850,13 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    case aco_opcode::p_as_uniform:
       if (instr->definitions[0].isFixed()) {
          /* don't copy-propagate copies into fixed registers */
+      } else if (instr->usesModifiers()) {
+         // TODO
       } else if (instr->operands[0].isConstant()) {
          if (instr->operands[0].isLiteral())
             ctx.info[instr->definitions[0].tempId()].set_literal(instr->operands[0].constantValue());
-         else
+         else if (instr->operands[0].size() == 1)
             ctx.info[instr->definitions[0].tempId()].set_constant(instr->operands[0].constantValue());
-      } else if (instr->isDPP()) {
-         // TODO
       } else if (instr->operands[0].isTemp()) {
          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
       } else {
@@ -844,11 +900,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    }
    case aco_opcode::v_mul_f32: { /* omod */
       /* TODO: try to move the negate/abs modifier to the consumer instead */
-      if (instr->isVOP3()) {
-         VOP3A_instruction *vop3 = static_cast<VOP3A_instruction*>(instr.get());
-         if (vop3->abs[0] || vop3->abs[1] || vop3->neg[0] || vop3->neg[1] || vop3->omod || vop3->clamp)
-            break;
-      }
+      if (instr->usesModifiers())
+         break;
 
       for (unsigned i = 0; i < 2; i++) {
          if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) {
@@ -858,7 +911,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
                ctx.info[instr->operands[i].tempId()].set_omod4();
             } else if (instr->operands[!i].constantValue() == 0x3f000000) { /* 0.5 */
                ctx.info[instr->operands[i].tempId()].set_omod5();
-            } else if (instr->operands[!i].constantValue() == 0x3f800000) { /* 1.0 */
+            } else if (instr->operands[!i].constantValue() == 0x3f800000 &&
+                       !block.fp_mode.must_flush_denorms32) { /* 1.0 */
                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[i].getTemp());
             } else {
                continue;
@@ -962,10 +1016,15 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    case aco_opcode::s_add_u32:
       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
       break;
-   case aco_opcode::s_not_b32:
-   case aco_opcode::s_not_b64:
    case aco_opcode::s_and_b32:
    case aco_opcode::s_and_b64:
+      if (instr->operands[1].isFixed() && instr->operands[1].physReg() == exec &&
+          instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
+         ctx.info[instr->definitions[1].tempId()].set_temp(ctx.info[instr->operands[0].tempId()].temp);
+      }
+      /* fallthrough */
+   case aco_opcode::s_not_b32:
+   case aco_opcode::s_not_b64:
    case aco_opcode::s_or_b32:
    case aco_opcode::s_or_b64:
    case aco_opcode::s_xor_b32:
@@ -1005,6 +1064,14 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    case aco_opcode::v_cmp_nlt_f32:
       ctx.info[instr->definitions[0].tempId()].set_fcmp(instr.get());
       break;
+   case aco_opcode::s_cselect_b64:
+   case aco_opcode::s_cselect_b32:
+      if (instr->operands[0].constantEquals((unsigned) -1) &&
+          instr->operands[1].constantEquals(0)) {
+         /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
+         ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
+      }
+      break;
    default:
       break;
    }
@@ -1097,11 +1164,13 @@ Instruction *follow_operand(opt_ctx &ctx, Operand op, bool ignore_uses=false)
  * s_and_b64(eq(a, a), eq(b, b)) -> v_cmp_o_f32(a, b) */
 bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 {
-   if (instr->opcode != aco_opcode::s_or_b64 && instr->opcode != aco_opcode::s_and_b64)
+   if (instr->definitions[0].regClass() != ctx.program->lane_mask)
       return false;
    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
       return false;
 
+   bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
+
    bool neg[2] = {false, false};
    bool abs[2] = {false, false};
    bool opsel[2] = {false, false};
@@ -1113,8 +1182,7 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       if (!op_instr[i])
          return false;
 
-      aco_opcode expected_cmp = instr->opcode == aco_opcode::s_or_b64 ?
-                                aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
+      aco_opcode expected_cmp = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
 
       if (op_instr[i]->opcode != expected_cmp)
          return false;
@@ -1146,8 +1214,7 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    decrease_uses(ctx, op_instr[0]);
    decrease_uses(ctx, op_instr[1]);
 
-   aco_opcode new_op = instr->opcode == aco_opcode::s_or_b64 ?
-                       aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
+   aco_opcode new_op = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
    Instruction *new_instr;
    if (neg[0] || neg[1] || abs[0] || abs[1] || opsel[0] || opsel[1]) {
       VOP3A_instruction *vop3 = create_instruction<VOP3A_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
@@ -1176,13 +1243,13 @@ bool combine_ordering_test(opt_ctx &ctx, aco_ptr<Instruction>& instr)
  * s_and_b64(v_cmp_o_f32(a, b), cmp(a, b)) -> get_ordered(cmp)(a, b) */
 bool combine_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 {
-   if (instr->opcode != aco_opcode::s_or_b64 && instr->opcode != aco_opcode::s_and_b64)
+   if (instr->definitions[0].regClass() != ctx.program->lane_mask)
       return false;
    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
       return false;
 
-   aco_opcode expected_nan_test = instr->opcode == aco_opcode::s_or_b64 ?
-                                  aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
+   bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
+   aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_u_f32 : aco_opcode::v_cmp_o_f32;
 
    Instruction *nan_test = follow_operand(ctx, instr->operands[0], true);
    Instruction *cmp = follow_operand(ctx, instr->operands[1], true);
@@ -1216,8 +1283,7 @@ bool combine_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    decrease_uses(ctx, nan_test);
    decrease_uses(ctx, cmp);
 
-   aco_opcode new_op = instr->opcode == aco_opcode::s_or_b64 ?
-                       get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
+   aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
    Instruction *new_instr;
    if (cmp->isVOP3()) {
       VOP3A_instruction *new_vop3 = create_instruction<VOP3A_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
@@ -1247,19 +1313,20 @@ bool combine_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& instr)
  * s_and_b64(v_cmp_eq_f32(a, a), cmp(a, #b)) and b is not NaN -> get_ordered(cmp)(a, b) */
 bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 {
-   if (instr->opcode != aco_opcode::s_or_b64 && instr->opcode != aco_opcode::s_and_b64)
+   if (instr->definitions[0].regClass() != ctx.program->lane_mask)
       return false;
    if (instr->definitions[1].isTemp() && ctx.uses[instr->definitions[1].tempId()])
       return false;
 
+   bool is_or = instr->opcode == aco_opcode::s_or_b64 || instr->opcode == aco_opcode::s_or_b32;
+
    Instruction *nan_test = follow_operand(ctx, instr->operands[0], true);
    Instruction *cmp = follow_operand(ctx, instr->operands[1], true);
 
    if (!nan_test || !cmp)
       return false;
 
-   aco_opcode expected_nan_test = instr->opcode == aco_opcode::s_or_b64 ?
-                                  aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
+   aco_opcode expected_nan_test = is_or ? aco_opcode::v_cmp_neq_f32 : aco_opcode::v_cmp_eq_f32;
    if (cmp->opcode == expected_nan_test)
       std::swap(nan_test, cmp);
    else if (nan_test->opcode != expected_nan_test)
@@ -1312,8 +1379,7 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& in
    decrease_uses(ctx, nan_test);
    decrease_uses(ctx, cmp);
 
-   aco_opcode new_op = instr->opcode == aco_opcode::s_or_b64 ?
-                       get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
+   aco_opcode new_op = is_or ? get_unordered(cmp->opcode) : get_ordered(cmp->opcode);
    Instruction *new_instr;
    if (cmp->isVOP3()) {
       VOP3A_instruction *new_vop3 = create_instruction<VOP3A_instruction>(new_op, asVOP3(Format::VOPC), 2, 1);
@@ -1482,34 +1548,6 @@ void create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr<Instruction>&
    instr.reset(new_instr);
 }
 
-bool combine_minmax3(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op)
-{
-   uint32_t omod_clamp = ctx.info[instr->definitions[0].tempId()].label &
-                         (label_omod_success | label_clamp_success);
-
-   for (unsigned swap = 0; swap < 2; swap++) {
-      Operand operands[3];
-      bool neg[3], abs[3], opsel[3], clamp, inbetween_neg, inbetween_abs;
-      unsigned omod;
-      if (match_op3_for_vop3(ctx, instr->opcode, instr->opcode, instr.get(), swap,
-                             "012", operands, neg, abs, opsel,
-                             &clamp, &omod, &inbetween_neg, &inbetween_abs, NULL)) {
-         ctx.uses[instr->operands[swap].tempId()]--;
-         neg[1] ^= inbetween_neg;
-         neg[2] ^= inbetween_neg;
-         abs[1] |= inbetween_abs;
-         abs[2] |= inbetween_abs;
-         create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod);
-         if (omod_clamp & label_omod_success)
-            ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-         if (omod_clamp & label_clamp_success)
-            ctx.info[instr->definitions[0].tempId()].set_clamp_success(instr.get());
-         return true;
-      }
-   }
-   return false;
-}
-
 bool combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode op2, aco_opcode new_op, const char *shuffle, uint8_t ops)
 {
    uint32_t omod_clamp = ctx.info[instr->definitions[0].tempId()].label &
@@ -1878,7 +1916,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    }
 }
 
-bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    /* check if we could apply omod on predecessor */
    if (instr->opcode == aco_opcode::v_mul_f32) {
@@ -1945,18 +1983,21 @@ bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       }
    }
 
+   /* omod has no effect if denormals are enabled */
+   bool can_use_omod = block.fp_mode.denorm32 == 0;
+
    /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
    if (!instr->definitions.empty() && ctx.uses[instr->definitions[0].tempId()] == 1 &&
        can_use_VOP3(instr) && instr_info.can_use_output_modifiers[(int)instr->opcode]) {
-      if(ctx.info[instr->definitions[0].tempId()].is_omod2()) {
+      if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod2()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 1;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (ctx.info[instr->definitions[0].tempId()].is_omod4()) {
+      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod4()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 2;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (ctx.info[instr->definitions[0].tempId()].is_omod5()) {
+      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod5()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 3;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
@@ -1973,7 +2014,7 @@ bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
 // this would mean that we'd have to fix the instruction uses while value propagation
 
-void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    if (instr->definitions.empty() || !ctx.uses[instr->definitions[0].tempId()])
       return;
@@ -1981,7 +2022,7 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (instr->isVALU()) {
       if (can_apply_sgprs(instr))
          apply_sgprs(ctx, instr);
-      if (apply_omod_clamp(ctx, instr))
+      if (apply_omod_clamp(ctx, block, instr))
          return;
    }
 
@@ -2034,9 +2075,11 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       return;
    }
    /* combine mul+add -> mad */
-   else if (instr->opcode == aco_opcode::v_add_f32 ||
-       instr->opcode == aco_opcode::v_sub_f32 ||
-       instr->opcode == aco_opcode::v_subrev_f32) {
+   else if ((instr->opcode == aco_opcode::v_add_f32 ||
+             instr->opcode == aco_opcode::v_sub_f32 ||
+             instr->opcode == aco_opcode::v_subrev_f32) &&
+            block.fp_mode.denorm32 == 0 && !block.fp_mode.preserve_signed_zero_inf_nan32) {
+      //TODO: we could use fma instead when denormals are enabled if the NIR isn't marked as precise
 
       uint32_t uses_src0 = UINT32_MAX;
       uint32_t uses_src1 = UINT32_MAX;
@@ -2189,9 +2232,8 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    } else if (instr->opcode == aco_opcode::s_not_b64) {
       if (combine_inverse_comparison(ctx, instr)) ;
       else combine_salu_not_bitwise(ctx, instr);
-   } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32) {
-      combine_salu_n2(ctx, instr);
-   } else if (instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
+   } else if (instr->opcode == aco_opcode::s_and_b32 || instr->opcode == aco_opcode::s_or_b32 ||
+              instr->opcode == aco_opcode::s_and_b64 || instr->opcode == aco_opcode::s_or_b64) {
       if (combine_ordering_test(ctx, instr)) ;
       else if (combine_comparison_ordering(ctx, instr)) ;
       else if (combine_constant_comparison_ordering(ctx, instr)) ;
@@ -2201,7 +2243,7 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       bool some_gfx9_only;
       if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &some_gfx9_only) &&
           (!some_gfx9_only || ctx.program->chip_class >= GFX9)) {
-         if (combine_minmax3(ctx, instr, instr->opcode == min ? min3 : max3)) ;
+         if (combine_three_valu_op(ctx, instr, instr->opcode, instr->opcode == min ? min3 : max3, "012", 1 | 2));
          else combine_clamp(ctx, instr, min, max, med3);
       }
    }
@@ -2307,7 +2349,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          if (ctx.uses[instr->operands[literal_idx].tempId()] == 0)
             instr->operands[literal_idx] = Operand(ctx.info[instr->operands[literal_idx].tempId()].val);
       }
-   } else if (instr->isVALU() && valu_can_accept_literal(ctx, instr) &&
+   } else if (instr->isVALU() && valu_can_accept_literal(ctx, instr, 0) &&
        instr->operands[0].isTemp() &&
        ctx.info[instr->operands[0].tempId()].is_literal() &&
        ctx.uses[instr->operands[0].tempId()] < threshold) {
@@ -2380,7 +2422,7 @@ void optimize(Program* program)
    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& instr : block.instructions)
-         label_instruction(ctx, instr);
+         label_instruction(ctx, block, instr);
    }
 
    ctx.uses = std::move(dead_code_analysis(program));
@@ -2388,7 +2430,7 @@ void optimize(Program* program)
    /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& instr : block.instructions)
-         combine_instruction(ctx, instr);
+         combine_instruction(ctx, block, instr);
    }
 
    /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */