aco: Flip s_cbranch / s_cselect to optimize out an s_not if possible.
author     Timur Kristóf <timur.kristof@gmail.com>
           Tue, 19 Nov 2019 12:29:54 +0000 (13:29 +0100)
committer  Timur Kristóf <timur.kristof@gmail.com>
           Tue, 14 Jan 2020 20:21:06 +0000 (21:21 +0100)
When possible, get rid of an s_not whose only effect is to invert the SCC,
by flipping its successor s_cbranch / s_cselect instead.

Also modify some parts of instruction_selection to take advantage of
this feature.

Example:
s2: %3900,  s1: %3899:scc = s_andn2_b64 %0:exec, %406
s2: %3902 = s_cselect_b64 -1, 0, %3900:scc
s2: %407,  s1: %3903:scc = s_not_b64 %3902
s2: %3906,  s1: %3905:scc = s_and_b64 %407, %0:exec
p_cbranch_z %3905:scc
Can now be optimized to:
s2: %3900,  s1: %3899:scc = s_andn2_b64 %0:exec, %406
p_cbranch_nz %3900:scc
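
The same flip works for s_cselect by swapping its two selected values.
A hypothetical illustration in the same IR notation (these value numbers
are invented, not taken from a real shader):
s1: %50:scc = s_cmp_eq_u32 %48, %49
s2: %51 = s_cselect_b64 -1, 0, %50:scc
s2: %52,  s1: %53:scc = s_not_b64 %51
s2: %54 = s_cselect_b64 -1, 0, %53:scc
Can now be optimized to:
s1: %50:scc = s_cmp_eq_u32 %48, %49
s2: %54 = s_cselect_b64 0, -1, %50:scc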

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_optimizer.cpp

src/amd/compiler/aco_instruction_selection.cpp
index 996b06d8534d1ea34f509529b7197ed92944f8b8..abd7ffd1502050686bc787d0c18d8abdcfbec71e 100644
@@ -788,7 +788,9 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       if (instr->dest.dest.ssa.bit_size == 1) {
          assert(src.regClass() == bld.lm);
          assert(dst.regClass() == bld.lm);
-         bld.sop2(Builder::s_andn2, Definition(dst), bld.def(s1, scc), Operand(exec, bld.lm), src);
+         /* Don't use s_andn2 here, this allows the optimizer to make a better decision */
+         Temp tmp = bld.sop1(Builder::s_not, bld.def(bld.lm), bld.def(s1, scc), src);
+         bld.sop2(Builder::s_and, Definition(dst), bld.def(s1, scc), tmp, Operand(exec, bld.lm));
       } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_not_b32, dst);
       } else if (dst.type() == RegType::sgpr) {
@@ -5300,8 +5302,8 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te
    } else if (op == nir_op_iand && cluster_size == ctx->program->wave_size) {
       //subgroupAnd(val) -> (exec & ~val) == 0
       Temp tmp = bld.sop2(Builder::s_andn2, bld.def(bld.lm), bld.def(s1, scc), Operand(exec, bld.lm), src).def(1).getTemp();
-      Temp all = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), bld.scc(tmp), Operand(0u));
-      return bool_to_vector_condition(ctx, all);
+      Temp cond = bool_to_vector_condition(ctx, emit_wqm(ctx, tmp));
+      return bld.sop1(Builder::s_not, bld.def(bld.lm), bld.def(s1, scc), cond);
    } else if (op == nir_op_ior && cluster_size == ctx->program->wave_size) {
       //subgroupOr(val) -> (val & exec) != 0
       Temp tmp = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm)).def(1).getTemp();
@@ -5906,8 +5908,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       assert(dst.regClass() == bld.lm);
 
       Temp tmp = bld.sop2(Builder::s_andn2, bld.def(bld.lm), bld.def(s1, scc), Operand(exec, bld.lm), src).def(1).getTemp();
-      Temp all = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), bld.scc(tmp), Operand(0u));
-      bool_to_vector_condition(ctx, emit_wqm(ctx, all), dst);
+      Temp cond = bool_to_vector_condition(ctx, emit_wqm(ctx, tmp));
+      bld.sop1(Builder::s_not, Definition(dst), bld.def(s1, scc), cond);
       break;
    }
    case nir_intrinsic_vote_any: {

src/amd/compiler/aco_optimizer.cpp
index 4001def5b664cb58873c0b272dfb22dac2313da9..224918c172f047f0011c97e1a49f2c4e8d5b4cf0 100644
@@ -84,11 +84,13 @@ enum Label {
    label_uniform_bool = 1 << 21,
    label_constant_64bit = 1 << 22,
    label_uniform_bitwise = 1 << 23,
+   label_scc_invert = 1 << 24,
 };
 
 static constexpr uint32_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success |
                                          label_add_sub | label_bitwise | label_uniform_bitwise | label_minmax | label_fcmp;
-static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool | label_omod2 | label_omod4 | label_omod5 | label_clamp;
+static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool |
+                                        label_omod2 | label_omod4 | label_omod5 | label_clamp | label_scc_invert;
 static constexpr uint32_t val_labels = label_constant | label_constant_64bit | label_literal | label_mad;
 
 struct ssa_info {
@@ -381,6 +383,17 @@ struct ssa_info {
       return label & label_fcmp;
    }
 
+   void set_scc_invert(Temp scc_inv)
+   {
+      add_label(label_scc_invert);
+      temp = scc_inv;
+   }
+
+   bool is_scc_invert()
+   {
+      return label & label_scc_invert;
+   }
+
    void set_uniform_bool(Temp uniform_bool)
    {
       add_label(label_uniform_bool);
@@ -830,6 +843,14 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
             continue;
          }
       }
+
+      else if (instr->format == Format::PSEUDO_BRANCH) {
+         if (ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
+            /* Flip the branch instruction to get rid of the scc_invert instruction */
+            instr->opcode = instr->opcode == aco_opcode::p_cbranch_z ? aco_opcode::p_cbranch_nz : aco_opcode::p_cbranch_z;
+            instr->operands[0].setTemp(ctx.info[instr->operands[0].tempId()].temp);
+         }
+      }
    }
 
    /* if this instruction doesn't define anything, return */
@@ -1097,6 +1118,17 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    case aco_opcode::s_add_u32:
       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
       break;
+   case aco_opcode::s_not_b32:
+   case aco_opcode::s_not_b64:
+      if (ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
+         ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
+         ctx.info[instr->definitions[1].tempId()].set_scc_invert(ctx.info[instr->operands[0].tempId()].temp);
+      } else if (ctx.info[instr->operands[0].tempId()].is_uniform_bitwise()) {
+         ctx.info[instr->definitions[0].tempId()].set_uniform_bitwise();
+         ctx.info[instr->definitions[1].tempId()].set_scc_invert(ctx.info[instr->operands[0].tempId()].instr->definitions[1].getTemp());
+      }
+      ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get());
+      break;
    case aco_opcode::s_and_b32:
    case aco_opcode::s_and_b64:
       if (instr->operands[1].isFixed() && instr->operands[1].physReg() == exec && instr->operands[0].isTemp()) {
@@ -1113,8 +1145,6 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          }
       }
       /* fallthrough */
-   case aco_opcode::s_not_b32:
-   case aco_opcode::s_not_b64:
    case aco_opcode::s_or_b32:
    case aco_opcode::s_or_b64:
    case aco_opcode::s_xor_b32:
@@ -1167,6 +1197,17 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
          ctx.info[instr->definitions[0].tempId()].set_uniform_bool(instr->operands[2].getTemp());
       }
+      if (instr->operands[2].isTemp() && ctx.info[instr->operands[2].tempId()].is_scc_invert()) {
+         /* Flip the operands to get rid of the scc_invert instruction */
+         std::swap(instr->operands[0], instr->operands[1]);
+         instr->operands[2].setTemp(ctx.info[instr->operands[2].tempId()].temp);
+      }
+      break;
+   case aco_opcode::p_wqm:
+      if (instr->operands[0].isTemp() &&
+          ctx.info[instr->operands[0].tempId()].is_scc_invert()) {
+         ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
+      }
       break;
    default:
       break;
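
For readers unfamiliar with the optimizer, the following is a condensed,
freestanding sketch of the label-and-flip flow implemented by the
aco_optimizer.cpp hunks above. The types and maps here (Instr, Opcode,
scc_def_of_operand) are simplified stand-ins for illustration only; the real
pass works on aco_ptr<Instruction> and the ssa_info table shown in the diff.

/* Simplified sketch: label each s_not whose SCC result is a pure inversion,
 * then let SCC consumers absorb the inversion so the s_not becomes dead. */
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

enum class Opcode { s_not_b64, p_cbranch_z, p_cbranch_nz, s_cselect_b64, other };

struct Instr {
   Opcode op;
   std::vector<uint32_t> operands;    /* SSA ids; s_cselect: {src0, src1, scc_cond} */
   std::vector<uint32_t> definitions; /* SSA ids; s_not: {result, scc} */
};

void label_and_flip(std::vector<Instr> &block,
                    const std::unordered_map<uint32_t, uint32_t> &scc_def_of_operand)
{
   /* Maps the SCC definition of an s_not to the SCC value it inverts. */
   std::unordered_map<uint32_t, uint32_t> scc_invert;

   for (Instr &instr : block) {
      if (instr.op == Opcode::s_not_b64) {
         /* Remember that this s_not's SCC def is the inverse of the SCC def
          * belonging to the instruction that produced its operand. */
         auto src = scc_def_of_operand.find(instr.operands[0]);
         if (src != scc_def_of_operand.end())
            scc_invert[instr.definitions[1]] = src->second;
      } else if (instr.op == Opcode::p_cbranch_z || instr.op == Opcode::p_cbranch_nz) {
         auto inv = scc_invert.find(instr.operands[0]);
         if (inv != scc_invert.end()) {
            /* Flip the branch and read the pre-inversion SCC value directly;
             * dead-code elimination can then drop the s_not. */
            instr.op = instr.op == Opcode::p_cbranch_z ? Opcode::p_cbranch_nz
                                                       : Opcode::p_cbranch_z;
            instr.operands[0] = inv->second;
         }
      } else if (instr.op == Opcode::s_cselect_b64) {
         auto inv = scc_invert.find(instr.operands[2]);
         if (inv != scc_invert.end()) {
            /* Swap the selected values instead of inverting the condition. */
            std::swap(instr.operands[0], instr.operands[1]);
            instr.operands[2] = inv->second;
         }
      }
   }
}

In the real pass, the scc_invert map corresponds to the new label_scc_invert
bit in ssa_info, and the two consumer cases mirror the PSEUDO_BRANCH and
s_cselect handling in the hunks above.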