aco: Optimize out trivial code from uniform bools.
[mesa.git] / src / amd / compiler / aco_optimizer.cpp
index 68f17569e329b2e0318b0834242304013e073067..d3b35761704874253796fe4d8b9445bac53d4e15 100644 (file)
@@ -82,10 +82,11 @@ enum Label {
    label_bitwise = 1 << 18,
    label_minmax = 1 << 19,
    label_fcmp = 1 << 20,
+   label_uniform_bool = 1 << 21,
 };
 
 static constexpr uint32_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success | label_add_sub | label_bitwise | label_minmax | label_fcmp;
-static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f;
+static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool;
 static constexpr uint32_t val_labels = label_constant | label_literal | label_mad;
 
 struct ssa_info {
@@ -353,6 +354,17 @@ struct ssa_info {
       return label & label_fcmp;
    }
 
+   void set_uniform_bool(Temp uniform_bool) /* tag this entry with label_uniform_bool and record uniform_bool */
+   {
+      add_label(label_uniform_bool);
+      temp = uniform_bool; /* read back through .temp by the s_and_b64 labelling once is_uniform_bool() holds */
+   }
+
+   bool is_uniform_bool() /* true iff the label_uniform_bool bit is set in label */
+   {
+      return label & label_uniform_bool;
+   }
+
 };
 
 struct opt_ctx {
@@ -765,7 +777,7 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          if (vec_op.isConstant()) {
             if (vec_op.isLiteral())
                ctx.info[instr->definitions[i].tempId()].set_literal(vec_op.constantValue());
-            else
+            else if (vec_op.size() == 1)
                ctx.info[instr->definitions[i].tempId()].set_constant(vec_op.constantValue());
          } else {
             assert(vec_op.isTemp());
@@ -794,7 +806,7 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          if (vec_op.isConstant()) {
             if (vec_op.isLiteral())
                ctx.info[instr->definitions[0].tempId()].set_literal(vec_op.constantValue());
-            else
+            else if (vec_op.size() == 1)
                ctx.info[instr->definitions[0].tempId()].set_constant(vec_op.constantValue());
          } else {
             assert(vec_op.isTemp());
@@ -814,7 +826,7 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       } else if (instr->operands[0].isConstant()) {
          if (instr->operands[0].isLiteral())
             ctx.info[instr->definitions[0].tempId()].set_literal(instr->operands[0].constantValue());
-         else
+         else if (instr->operands[0].size() == 1)
             ctx.info[instr->definitions[0].tempId()].set_constant(instr->operands[0].constantValue());
       } else if (instr->operands[0].isTemp()) {
          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
@@ -974,10 +986,15 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    case aco_opcode::s_add_u32:
       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
       break;
+   case aco_opcode::s_and_b64:
+      if (instr->operands[1].isFixed() && instr->operands[1].physReg() == exec &&
+          instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
+         ctx.info[instr->definitions[1].tempId()].set_temp(ctx.info[instr->operands[0].tempId()].temp);
+      }
+      /* fallthrough */
+   case aco_opcode::s_and_b32:
    case aco_opcode::s_not_b32:
    case aco_opcode::s_not_b64:
-   case aco_opcode::s_and_b32:
-   case aco_opcode::s_and_b64:
    case aco_opcode::s_or_b32:
    case aco_opcode::s_or_b64:
    case aco_opcode::s_xor_b32:
@@ -1017,6 +1034,13 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    case aco_opcode::v_cmp_nlt_f32:
       ctx.info[instr->definitions[0].tempId()].set_fcmp(instr.get());
       break;
+   case aco_opcode::s_cselect_b64:
+      if (instr->operands[0].constantEquals((unsigned) -1) &&
+          instr->operands[1].constantEquals(0)) {
+         /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
+         ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
+      }
+      break;
    default:
       break;
    }