aco: optimize add/sub(a, cndmask(b, 0, 1, cond)) -> addc/subbrev_co(0, a, b)
authorSamuel Pitoiset <samuel.pitoiset@gmail.com>
Thu, 2 Apr 2020 15:41:36 +0000 (17:41 +0200)
committerMarge Bot <eric+marge@anholt.net>
Tue, 12 May 2020 16:15:17 +0000 (16:15 +0000)
v2: outline into a separate function and also optimize additions (by Daniel Schürmann)

Totals from affected shaders: (VEGA)
SGPRS: 938888 -> 941496 (0.28 %)
VGPRS: 832068 -> 831532 (-0.06 %)
Spilled SGPRs: 618 -> 618 (0.00 %)
Spilled VGPRs: 0 -> 0 (0.00 %)
Private memory VGPRs: 0 -> 0 (0.00 %)
Scratch size: 3696 -> 3696 (0.00 %) dwords per thread
Code Size: 72893900 -> 72558928 (-0.46 %) bytes
LDS: 18201 -> 18201 (0.00 %) blocks
Max Waves: 64256 -> 64268 (0.02 %)

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Co-authored-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4419>

src/amd/compiler/aco_optimizer.cpp

index 9203f1c4b43d1dd3a76e41152027fe36c2a5e50f..ab9b0f5f6e742083dd06d743cc70bea6c6a10cb5 100644 (file)
@@ -87,12 +87,13 @@ enum Label {
    label_scc_invert = 1 << 24,
    label_vcc_hint = 1 << 25,
    label_scc_needed = 1 << 26,
+   label_b2i = 1 << 27,
 };
 
 static constexpr uint32_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success |
                                          label_add_sub | label_bitwise | label_uniform_bitwise | label_minmax | label_fcmp;
 static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool |
-                                        label_omod2 | label_omod4 | label_omod5 | label_clamp | label_scc_invert;
+                                        label_omod2 | label_omod4 | label_omod5 | label_clamp | label_scc_invert | label_b2i;
 static constexpr uint32_t val_labels = label_constant | label_constant_64bit | label_literal | label_mad;
 
 struct ssa_info {
@@ -428,6 +429,18 @@ struct ssa_info {
    {
       return label & label_vcc_hint;
    }
+
+   void set_b2i(Temp val)
+   {
+      add_label(label_b2i);
+      temp = val;
+   }
+
+   bool is_b2i()
+   {
+      return label & label_b2i;
+   }
+
 };
 
 struct opt_ctx {
@@ -1135,13 +1148,14 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    }
    case aco_opcode::v_cndmask_b32:
       if (instr->operands[0].constantEquals(0) &&
-          instr->operands[1].constantEquals(0xFFFFFFFF) &&
-          instr->operands[2].isTemp())
+          instr->operands[1].constantEquals(0xFFFFFFFF))
          ctx.info[instr->definitions[0].tempId()].set_vcc(instr->operands[2].getTemp());
       else if (instr->operands[0].constantEquals(0) &&
-               instr->operands[1].constantEquals(0x3f800000u) &&
-               instr->operands[2].isTemp())
+               instr->operands[1].constantEquals(0x3f800000u))
          ctx.info[instr->definitions[0].tempId()].set_b2f(instr->operands[2].getTemp());
+      else if (instr->operands[0].constantEquals(0) &&
+               instr->operands[1].constantEquals(1))
+         ctx.info[instr->definitions[0].tempId()].set_b2i(instr->operands[2].getTemp());
 
       ctx.info[instr->operands[2].tempId()].set_vcc_hint();
       break;
@@ -1961,6 +1975,44 @@ bool combine_salu_lshl_add(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    return false;
 }
 
+bool combine_add_sub_b2i(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode new_op, uint8_t ops)
+{
+   if (instr->usesModifiers())
+      return false;
+
+   for (unsigned i = 0; i < 2; i++) {
+      if (!((1 << i) & ops))
+         continue;
+      if (instr->operands[i].isTemp() &&
+          ctx.info[instr->operands[i].tempId()].is_b2i() &&
+          ctx.uses[instr->operands[i].tempId()] == 1) {
+
+         aco_ptr<Instruction> new_instr;
+         if (instr->operands[!i].isTemp() && instr->operands[!i].getTemp().type() == RegType::vgpr) {
+            new_instr.reset(create_instruction<VOP2_instruction>(new_op, Format::VOP2, 3, 2));
+         } else if (ctx.program->chip_class >= GFX10 ||
+                    (instr->operands[!i].isConstant() && !instr->operands[!i].isLiteral())) {
+            new_instr.reset(create_instruction<VOP3A_instruction>(new_op, asVOP3(Format::VOP2), 3, 2));
+         } else {
+            return false;
+         }
+         ctx.uses[instr->operands[i].tempId()]--;
+         new_instr->definitions[0] = instr->definitions[0];
+         new_instr->definitions[1] = instr->definitions.size() == 2 ? instr->definitions[1] :
+                                                                      Definition(ctx.program->allocateId(), ctx.program->lane_mask);
+         new_instr->definitions[1].setHint(vcc);
+         new_instr->operands[0] = Operand(0u);
+         new_instr->operands[1] = instr->operands[!i];
+         new_instr->operands[2] = Operand(ctx.info[instr->operands[i].tempId()].temp);
+         instr = std::move(new_instr);
+         ctx.info[instr->definitions[0].tempId()].label = 0;
+         return true;
+      }
+   }
+
+   return false;
+}
+
 bool get_minmax_info(aco_opcode op, aco_opcode *min, aco_opcode *max, aco_opcode *min3, aco_opcode *max3, aco_opcode *med3, bool *some_gfx9_only)
 {
    switch (op) {
@@ -2485,14 +2537,28 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
       else if (combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ;
       else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_or_b32, "120", 1 | 2)) ;
       else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_or_b32, "210", 1 | 2);
-   } else if (instr->opcode == aco_opcode::v_add_u32 && ctx.program->chip_class >= GFX9) {
-      if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32, "120", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
-      else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_add_u32, "120", 1 | 2)) ;
-      else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_add_u32, "210", 1 | 2);
+   } else if (instr->opcode == aco_opcode::v_add_u32) {
+      if (combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2)) ;
+      else if (ctx.program->chip_class >= GFX9) {
+         if (combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xad_u32, "120", 1 | 2)) ;
+         else if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xad_u32, "120", 1 | 2)) ;
+         else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
+         else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
+         else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ;
+         else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_add_u32, "120", 1 | 2)) ;
+         else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_add_u32, "210", 1 | 2);
+      }
+   } else if (instr->opcode == aco_opcode::v_add_co_u32 ||
+              instr->opcode == aco_opcode::v_add_co_u32_e64) {
+      combine_add_sub_b2i(ctx, instr, aco_opcode::v_addc_co_u32, 1 | 2);
+   } else if (instr->opcode == aco_opcode::v_sub_u32 ||
+              instr->opcode == aco_opcode::v_sub_co_u32 ||
+              instr->opcode == aco_opcode::v_sub_co_u32_e64) {
+      combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 2);
+   } else if (instr->opcode == aco_opcode::v_subrev_u32 ||
+              instr->opcode == aco_opcode::v_subrev_co_u32 ||
+              instr->opcode == aco_opcode::v_subrev_co_u32_e64) {
+      combine_add_sub_b2i(ctx, instr, aco_opcode::v_subbrev_co_u32, 1);
    } else if (instr->opcode == aco_opcode::v_lshlrev_b32 && ctx.program->chip_class >= GFX9) {
       combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add_lshl_u32, "120", 2);
    } else if ((instr->opcode == aco_opcode::s_add_u32 || instr->opcode == aco_opcode::s_add_i32) && ctx.program->chip_class >= GFX9) {