aco: replace extract_vector with copies
[mesa.git] / src / amd / compiler / aco_optimizer.cpp
index 7a16fc176c9dd977117a6a3cac2cc1713c00ce3f..19e78f9e656962fffcc608b1018a61e73038b080 100644 (file)
@@ -1624,6 +1624,40 @@ bool combine_three_valu_op(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode
    return false;
 }
 
+bool combine_minmax(opt_ctx& ctx, aco_ptr<Instruction>& instr, aco_opcode opposite, aco_opcode minmax3)
+{
+   if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2))
+      return true;
+
+   uint32_t omod_clamp = ctx.info[instr->definitions[0].tempId()].label &
+                         (label_omod_success | label_clamp_success);
+
+   /* min(-max(a, b), c) -> min3(-a, -b, c) *
+    * max(-min(a, b), c) -> max3(-a, -b, c) */
+   for (unsigned swap = 0; swap < 2; swap++) {
+      Operand operands[3];
+      bool neg[3], abs[3], clamp;
+      uint8_t opsel = 0, omod = 0;
+      bool inbetween_neg;
+      if (match_op3_for_vop3(ctx, instr->opcode, opposite,
+                             instr.get(), swap, "012",
+                             operands, neg, abs, &opsel,
+                             &clamp, &omod, &inbetween_neg, NULL, NULL) &&
+          inbetween_neg) {
+         ctx.uses[instr->operands[swap].tempId()]--;
+         neg[1] = true;
+         neg[2] = true;
+         create_vop3_for_op3(ctx, minmax3, instr, operands, neg, abs, opsel, clamp, omod);
+         if (omod_clamp & label_omod_success)
+            ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
+         if (omod_clamp & label_clamp_success)
+            ctx.info[instr->definitions[0].tempId()].set_clamp_success(instr.get());
+         return true;
+      }
+   }
+   return false;
+}
+
 /* s_not_b32(s_and_b32(a, b)) -> s_nand_b32(a, b)
  * s_not_b32(s_or_b32(a, b)) -> s_nor_b32(a, b)
  * s_not_b32(s_xor_b32(a, b)) -> s_xnor_b32(a, b)
@@ -1794,6 +1828,9 @@ bool get_minmax_info(aco_opcode op, aco_opcode *min, aco_opcode *max, aco_opcode
 bool combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr,
                    aco_opcode min, aco_opcode max, aco_opcode med)
 {
+   /* TODO: GLSL's clamp(x, minVal, maxVal) and SPIR-V's
+    * FClamp(x, minVal, maxVal)/NClamp(x, minVal, maxVal) are undefined if
+    * minVal > maxVal, which means we can always select it to a v_med3_f32 */
    aco_opcode other_op;
    if (instr->opcode == min)
       other_op = max;
@@ -1807,19 +1844,18 @@ bool combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr,
 
    for (unsigned swap = 0; swap < 2; swap++) {
       Operand operands[3];
-      bool neg[3], abs[3], clamp, inbetween_neg, inbetween_abs;
+      bool neg[3], abs[3], clamp;
       uint8_t opsel = 0, omod = 0;
       if (match_op3_for_vop3(ctx, instr->opcode, other_op, instr.get(), swap,
                              "012", operands, neg, abs, &opsel,
-                             &clamp, &omod, &inbetween_neg, &inbetween_abs, NULL)) {
+                             &clamp, &omod, NULL, NULL, NULL)) {
          int const0_idx = -1, const1_idx = -1;
          uint32_t const0 = 0, const1 = 0;
          for (int i = 0; i < 3; i++) {
             uint32_t val;
             if (operands[i].isConstant()) {
                val = operands[i].constantValue();
-            } else if (operands[i].isTemp() && ctx.uses[operands[i].tempId()] == 1 &&
-                       ctx.info[operands[i].tempId()].is_constant_or_literal()) {
+            } else if (operands[i].isTemp() && ctx.info[operands[i].tempId()].is_constant_or_literal()) {
                val = ctx.info[operands[i].tempId()].val;
             } else {
                continue;
@@ -1892,11 +1928,6 @@ bool combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr,
                return false;
          }
 
-         neg[1] ^= inbetween_neg;
-         neg[2] ^= inbetween_neg;
-         abs[1] |= inbetween_abs;
-         abs[2] |= inbetween_abs;
-
          ctx.uses[instr->operands[swap].tempId()]--;
          create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod);
          if (omod_clamp & label_omod_success)
@@ -2306,7 +2337,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr
       bool some_gfx9_only;
       if (get_minmax_info(instr->opcode, &min, &max, &min3, &max3, &med3, &some_gfx9_only) &&
           (!some_gfx9_only || ctx.program->chip_class >= GFX9)) {
-         if (combine_three_valu_op(ctx, instr, instr->opcode, instr->opcode == min ? min3 : max3, "012", 1 | 2));
+         if (combine_minmax(ctx, instr, instr->opcode == min ? max : min, instr->opcode == min ? min3 : max3)) ;
          else combine_clamp(ctx, instr, min, max, med3);
       }
    }
@@ -2322,7 +2353,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       return;
    }
 
-   /* convert split_vector into extract_vector if only one definition is ever used */
+   /* convert split_vector into a copy or extract_vector if only one definition is ever used */
    if (instr->opcode == aco_opcode::p_split_vector) {
       unsigned num_used = 0;
       unsigned idx = 0;
@@ -2332,7 +2363,39 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
             idx = i;
          }
       }
-      if (num_used == 1) {
+      bool done = false;
+      if (num_used == 1 && ctx.info[instr->operands[0].tempId()].is_vec() &&
+          ctx.uses[instr->operands[0].tempId()] == 1) {
+         Instruction *vec = ctx.info[instr->operands[0].tempId()].instr;
+
+         unsigned off = 0;
+         Operand op;
+         for (Operand& vec_op : vec->operands) {
+            if (off == idx * instr->definitions[0].size()) {
+               op = vec_op;
+               break;
+            }
+            off += vec_op.size();
+         }
+         if (off != instr->operands[0].size()) {
+            ctx.uses[instr->operands[0].tempId()]--;
+            for (Operand& vec_op : vec->operands) {
+               if (vec_op.isTemp())
+                  ctx.uses[vec_op.tempId()]--;
+            }
+            if (op.isTemp())
+               ctx.uses[op.tempId()]++;
+
+            aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, 1, 1)};
+            extract->operands[0] = op;
+            extract->definitions[0] = instr->definitions[idx];
+            instr.reset(extract.release());
+
+            done = true;
+         }
+      }
+
+      if (!done && num_used == 1) {
          aco_ptr<Pseudo_instruction> extract{create_instruction<Pseudo_instruction>(aco_opcode::p_extract_vector, Format::PSEUDO, 2, 1)};
          extract->operands[0] = instr->operands[0];
          extract->operands[1] = Operand((uint32_t) idx);