intel/vec4: Try to emit a VF source in try_immediate_source
[mesa.git] / src / intel / compiler / brw_vec4_nir.cpp
index 38f92d2d6dbcc7e7fae2ac54cdb2017269ad842e..157a97bfc035a5a627d2ee8e841f9b1c6879cc70 100644 (file)
@@ -25,6 +25,7 @@
 #include "brw_vec4.h"
 #include "brw_vec4_builder.h"
 #include "brw_vec4_surface_builder.h"
+#include "brw_eu.h"
 
 using namespace brw;
 using namespace brw::surface_access;
@@ -1010,21 +1011,43 @@ vec4_visitor::emit_conversion_to_double(dst_reg dst, src_reg src,
 }
 
 /**
- * Try to use an immediate value for source 1
+ * Try to use an immediate value for a source
  *
  * In cases of flow control, constant propagation is sometimes unable to
  * determine that a register contains a constant value.  To work around this,
- * try to emit a literal as the second source here.
+ * try to emit a literal as one of the sources.  If \c try_src0_also is set,
+ * \c op[0] will also be tried for an immediate value.
+ *
+ * If \c op[0] is modified, the operands will be exchanged so that \c op[1]
+ * will always be the immediate value.
+ *
+ * \return The index of the source that was modified, 0 or 1, if successful.
+ * Otherwise, -1.
+ *
+ * \param op - Operands to the instruction
+ * \param try_src0_also - True if \c op[0] should also be a candidate for
+ *                        getting an immediate value.  This should only be set
+ *                        for commutative operations.
  */
-static void
+static int
 try_immediate_source(const nir_alu_instr *instr, src_reg *op,
+                     bool try_src0_also,
                      MAYBE_UNUSED const gen_device_info *devinfo)
 {
-   if (nir_src_bit_size(instr->src[1].src) != 32 ||
-       !nir_src_is_const(instr->src[1].src))
-      return;
+   unsigned idx;
+
+   if (nir_src_bit_size(instr->src[1].src) == 32 &&
+       nir_src_is_const(instr->src[1].src)) {
+      idx = 1;
+   } else if (try_src0_also &&
+         nir_src_bit_size(instr->src[0].src) == 32 &&
+         nir_src_is_const(instr->src[0].src)) {
+      idx = 0;
+   } else {
+      return -1;
+   }
 
-   const enum brw_reg_type old_type = op->type;
+   const enum brw_reg_type old_type = op[idx].type;
 
    switch (old_type) {
    case BRW_REGISTER_TYPE_D:
@@ -1033,22 +1056,22 @@ try_immediate_source(const nir_alu_instr *instr, src_reg *op,
       int d;
 
       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
-         if (nir_alu_instr_channel_used(instr, 1, i)) {
+         if (nir_alu_instr_channel_used(instr, idx, i)) {
             if (first_comp < 0) {
                first_comp = i;
-               d = nir_src_comp_as_int(instr->src[1].src,
-                                       instr->src[1].swizzle[i]);
-            } else if (d != nir_src_comp_as_int(instr->src[1].src,
-                                                instr->src[1].swizzle[i])) {
-               return;
+               d = nir_src_comp_as_int(instr->src[idx].src,
+                                       instr->src[idx].swizzle[i]);
+            } else if (d != nir_src_comp_as_int(instr->src[idx].src,
+                                                instr->src[idx].swizzle[i])) {
+               return -1;
             }
          }
       }
 
-      if (op->abs)
+      if (op[idx].abs)
          d = MAX2(-d, d);
 
-      if (op->negate) {
+      if (op[idx].negate) {
          /* On Gen8+ a negation source modifier on a logical operation means
           * something different.  Nothing should generate this, so assert that
           * it does not occur.
@@ -1059,41 +1082,116 @@ try_immediate_source(const nir_alu_instr *instr, src_reg *op,
          d = -d;
       }
 
-      *op = retype(src_reg(brw_imm_d(d)), old_type);
+      op[idx] = retype(src_reg(brw_imm_d(d)), old_type);
       break;
    }
 
    case BRW_REGISTER_TYPE_F: {
       int first_comp = -1;
-      float f;
+      float f[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+      bool is_scalar = true;
 
       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
-         if (nir_alu_instr_channel_used(instr, 1, i)) {
+         if (nir_alu_instr_channel_used(instr, idx, i)) {
+            f[i] = nir_src_comp_as_float(instr->src[idx].src,
+                                         instr->src[idx].swizzle[i]);
             if (first_comp < 0) {
                first_comp = i;
-               f = nir_src_comp_as_float(instr->src[1].src,
-                                         instr->src[1].swizzle[i]);
-            } else if (f != nir_src_comp_as_float(instr->src[1].src,
-                                                  instr->src[1].swizzle[i])) {
-               return;
+            } else if (f[first_comp] != f[i]) {
+               is_scalar = false;
             }
          }
       }
 
-      if (op->abs)
-         f = fabs(f);
+      if (is_scalar) {
+         if (op[idx].abs)
+            f[first_comp] = fabs(f[first_comp]);
+
+         if (op[idx].negate)
+            f[first_comp] = -f[first_comp];
+
+         op[idx] = src_reg(brw_imm_f(f[first_comp]));
+         assert(op[idx].type == old_type);
+      } else {
+         uint8_t vf_values[4] = { 0, 0, 0, 0 };
+
+         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+            if (op[idx].abs)
+               f[i] = fabs(f[i]);
 
-      if (op->negate)
-         f = -f;
+            if (op[idx].negate)
+               f[i] = -f[i];
 
-      *op = src_reg(brw_imm_f(f));
-      assert(op->type == old_type);
+            const int vf = brw_float_to_vf(f[i]);
+            if (vf == -1)
+               return -1;
+
+            vf_values[i] = vf;
+         }
+
+         op[idx] = src_reg(brw_imm_vf4(vf_values[0], vf_values[1],
+                                       vf_values[2], vf_values[3]));
+      }
       break;
    }
 
    default:
       unreachable("Non-32bit type.");
    }
+
+   /* The instruction format only allows source 1 to be an immediate value.
+    * If the immediate value was source 0, then the sources must be exchanged.
+    */
+   if (idx == 0) {
+      src_reg tmp = op[0];
+      op[0] = op[1];
+      op[1] = tmp;
+   }
+
+   return idx;
+}
+
+void
+vec4_visitor::fix_float_operands(src_reg op[3], nir_alu_instr *instr)
+{
+   bool fixed[3] = { false, false, false };
+
+   for (unsigned i = 0; i < 2; i++) {
+      if (!nir_src_is_const(instr->src[i].src))
+         continue;
+
+      for (unsigned j = i + 1; j < 3; j++) {
+         if (fixed[j])
+            continue;
+
+         if (!nir_src_is_const(instr->src[j].src))
+            continue;
+
+         if (nir_alu_srcs_equal(instr, instr, i, j)) {
+            if (!fixed[i])
+               op[i] = fix_3src_operand(op[i]);
+
+            op[j] = op[i];
+
+            fixed[i] = true;
+            fixed[j] = true;
+         } else if (nir_alu_srcs_negative_equal(instr, instr, i, j)) {
+            if (!fixed[i])
+               op[i] = fix_3src_operand(op[i]);
+
+            op[j] = op[i];
+            op[j].negate = !op[j].negate;
+
+            fixed[i] = true;
+            fixed[j] = true;
+         }
+      }
+   }
+
+   for (unsigned i = 0; i < 3; i++) {
+      if (!fixed[i])
+         op[i] = fix_3src_operand(op[i]);
+   }
 }
 
 void
@@ -1175,7 +1273,7 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
       /* fall through */
    case nir_op_fadd:
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit(ADD(dst, op[0], op[1]));
       inst->saturate = instr->dest.saturate;
       break;
@@ -1187,7 +1285,7 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       break;
 
    case nir_op_fmul:
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit(MUL(dst, op[0], op[1]));
       inst->saturate = instr->dest.saturate;
       break;
@@ -1415,7 +1513,7 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
       /* fall through */
    case nir_op_fmin:
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit_minmax(BRW_CONDITIONAL_L, dst, op[0], op[1]);
       inst->saturate = instr->dest.saturate;
       break;
@@ -1425,7 +1523,7 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
       /* fall through */
    case nir_op_fmax:
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit_minmax(BRW_CONDITIONAL_GE, dst, op[0], op[1]);
       inst->saturate = instr->dest.saturate;
       break;
@@ -1454,7 +1552,12 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
          brw_conditional_for_nir_comparison(instr->op);
 
       if (nir_src_bit_size(instr->src[0].src) < 64) {
-         try_immediate_source(instr, &op[1], devinfo);
+         /* If the order of the sources is changed due to an immediate value,
+          * then the condition must also be changed.
+          */
+         if (try_immediate_source(instr, op, true, devinfo) == 0)
+            conditional_mod = brw_swap_cmod(conditional_mod);
+
          emit(CMP(dst, op[0], op[1], conditional_mod));
       } else {
          /* Produce a 32-bit boolean result from the DF comparison by selecting
@@ -1524,7 +1627,7 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
          op[0] = resolve_source_modifiers(op[0]);
          op[1] = resolve_source_modifiers(op[1]);
       }
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, true, devinfo);
       emit(XOR(dst, op[0], op[1]));
       break;
 
@@ -1534,7 +1637,7 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
          op[0] = resolve_source_modifiers(op[0]);
          op[1] = resolve_source_modifiers(op[1]);
       }
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, true, devinfo);
       emit(OR(dst, op[0], op[1]));
       break;
 
@@ -1544,7 +1647,7 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
          op[0] = resolve_source_modifiers(op[0]);
          op[1] = resolve_source_modifiers(op[1]);
       }
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, true, devinfo);
       emit(AND(dst, op[0], op[1]));
       break;
 
@@ -1854,19 +1957,19 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
 
    case nir_op_ishl:
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, false, devinfo);
       emit(SHL(dst, op[0], op[1]));
       break;
 
    case nir_op_ishr:
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, false, devinfo);
       emit(ASR(dst, op[0], op[1]));
       break;
 
    case nir_op_ushr:
       assert(nir_dest_bit_size(instr->dest.dest) < 64);
-      try_immediate_source(instr, &op[1], devinfo);
+      try_immediate_source(instr, op, false, devinfo);
       emit(SHR(dst, op[0], op[1]));
       break;
 
@@ -1877,17 +1980,15 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
          inst = emit(ADD(dst, src_reg(mul_dst), op[2]));
          inst->saturate = instr->dest.saturate;
       } else {
-         op[0] = fix_3src_operand(op[0]);
-         op[1] = fix_3src_operand(op[1]);
-         op[2] = fix_3src_operand(op[2]);
-
+         fix_float_operands(op, instr);
          inst = emit(MAD(dst, op[2], op[1], op[0]));
          inst->saturate = instr->dest.saturate;
       }
       break;
 
    case nir_op_flrp:
-      inst = emit_lrp(dst, op[0], op[1], op[2]);
+      fix_float_operands(op, instr);
+      inst = emit(LRP(dst, op[2], op[1], op[0]));
       inst->saturate = instr->dest.saturate;
       break;
 
@@ -1918,21 +2019,25 @@ vec4_visitor::nir_emit_alu(nir_alu_instr *instr)
       break;
 
    case nir_op_fdot_replicated2:
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit(BRW_OPCODE_DP2, dst, op[0], op[1]);
       inst->saturate = instr->dest.saturate;
       break;
 
    case nir_op_fdot_replicated3:
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit(BRW_OPCODE_DP3, dst, op[0], op[1]);
       inst->saturate = instr->dest.saturate;
       break;
 
    case nir_op_fdot_replicated4:
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit(BRW_OPCODE_DP4, dst, op[0], op[1]);
       inst->saturate = instr->dest.saturate;
       break;
 
    case nir_op_fdph_replicated:
+      try_immediate_source(instr, op, true, devinfo);
       inst = emit(BRW_OPCODE_DPH, dst, op[0], op[1]);
       inst->saturate = instr->dest.saturate;
       break;