nir: Handle swizzle in nir_alu_srcs_negative_equal
author    Ian Romanick <ian.d.romanick@intel.com>
Mon, 10 Jun 2019 22:05:14 +0000 (15:05 -0700)
committer Ian Romanick <ian.d.romanick@intel.com>
Mon, 8 Jul 2019 18:30:11 +0000 (11:30 -0700)
When I added this function, I was not sure whether swizzles of
immediate values ever occur in NIR.  The only existing user of these
functions is the partial redundancy elimination pass for compares.
Since comparison instructions are inherently scalar, swizzled
immediates do not occur there.
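
For illustration, the comparisons that reach that pass all have this
scalar shape (an illustrative NIR snippet in the style of the test
comments below, not taken from any real shader):

   vec1 1 ssa_4 = flt ssa_0.x, ssa_3.x

Each source supplies a single component, so comparing the constants
component-by-component in storage order was sufficient in practice.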

However, two later patches, "nir/algebraic: Recognize open-coded
flrp(-1, 1, a) and flrp(1, -1, a)" and "intel/vec4: Try to emit a
single load for multiple 3-src instruction operands", combine to
create a few thousand instances.
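
For illustration, the resulting pattern reduces to two sources reading
the same immediate vector through different swizzles (hypothetical
values, mirroring the unit test added below):

   vec2 32 ssa_0 = load_const ( 1.0, -1.0)
   vec4 32 ssa_1 = fadd ssa_0.xxxx, ssa_0.yyyy

Each used channel of the first source is the negation of the
corresponding channel of the second, but the old code ignored the
swizzles and compared the constant's components in storage order.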

No shader-db changes on any Intel platform.

v2: Handle the swizzle in nir_alu_srcs_negative_equal and leave
nir_const_value_negative_equal unchanged.  Suggested by Jason.

v3: Correctly handle write masks.  Add a note (and an assertion) that
the caller is responsible for various compatibility checks.  The
single existing caller only calls this for combinations of scalar fadd
and float comparison instructions, so all of the requirements are met.
A later patch (intel/vec4: Try to emit a single load for multiple
3-src instruction operands) will call this for sources of the same
instruction, so the requirements are met there as well.
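
As a rough sketch, the checks a caller is expected to have made before
calling amount to the following (a and b are hypothetical ALU
instructions; the assertions added in this patch mirror these
conditions):

   /* Both sources must use exactly the same set of channels... */
   for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
      if (nir_alu_instr_channel_used(a, 0, i) !=
          nir_alu_instr_channel_used(b, 0, i))
         return false;
   }

   /* ...and the source base types must agree (both float or both
    * integer).
    */
   if (nir_op_infos[a->op].input_types[0] !=
       nir_op_infos[b->op].input_types[0])
      return false;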

v4: Add unit test for nir_opt_comparison_pre that is fixed by this
commit.

Reviewed-by: Matt Turner <mattst88@gmail.com>
src/compiler/nir/nir_instr_set.c
src/compiler/nir/tests/comparison_pre_tests.cpp
src/compiler/nir/tests/negative_equal_tests.cpp

src/compiler/nir/nir_instr_set.c
index 6796fcaad5ba511dccf1ab9f528927ec376a2dd0..d200412dc9c279502317027641ecdc565ea857ee 100644
@@ -352,12 +352,31 @@ nir_const_value_negative_equal(nir_const_value c1,
  * This function does not detect the general case when \p alu1 and \p alu2 are
  * SSA values that are the negations of each other (e.g., \p alu1 represents
  * (a * b) and \p alu2 represents (-a * b)).
+ *
+ * \warning
+ * It is the responsibility of the caller to ensure that the component counts,
+ * write masks, and base types of the sources being compared are compatible.
  */
 bool
 nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
                             const nir_alu_instr *alu2,
                             unsigned src1, unsigned src2)
 {
+#ifndef NDEBUG
+   for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+      assert(nir_alu_instr_channel_used(alu1, src1, i) ==
+             nir_alu_instr_channel_used(alu2, src2, i));
+   }
+
+   if (nir_op_infos[alu1->op].input_types[src1] == nir_type_float) {
+      assert(nir_op_infos[alu1->op].input_types[src1] ==
+             nir_op_infos[alu2->op].input_types[src2]);
+   } else {
+      assert(nir_op_infos[alu1->op].input_types[src1] == nir_type_int);
+      assert(nir_op_infos[alu2->op].input_types[src2] == nir_type_int);
+   }
+#endif
+
    if (alu1->src[src1].abs != alu2->src[src2].abs)
       return false;
 
@@ -385,12 +404,13 @@ nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
           nir_src_bit_size(alu2->src[src2].src))
          return false;
 
-      /* FINISHME: Apply the swizzle? */
-      const unsigned components = nir_ssa_alu_instr_src_components(alu1, src1);
       const nir_alu_type full_type = nir_op_infos[alu1->op].input_types[src1] |
                                      nir_src_bit_size(alu1->src[src1].src);
-      for (unsigned i = 0; i < components; i++) {
-         if (!nir_const_value_negative_equal(const1[i], const2[i], full_type))
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+         if (nir_alu_instr_channel_used(alu1, src1, i) &&
+             !nir_const_value_negative_equal(const1[alu1->src[src1].swizzle[i]],
+                                             const2[alu2->src[src2].swizzle[i]],
+                                             full_type))
             return false;
       }
 
src/compiler/nir/tests/comparison_pre_tests.cpp
index fe1cc23fb3bc13d2861c0d721cb17e442f05c07f..a48aeca8da4616771a0aa4885db3d89777216c8a 100644
@@ -473,6 +473,56 @@ TEST_F(comparison_pre_test, a_lt_neg_imm_vs_a_plus_imm)
    EXPECT_TRUE(nir_opt_comparison_pre_impl(bld.impl));
 }
 
+TEST_F(comparison_pre_test, swizzle_of_same_immediate_vector)
+{
+   /* Before:
+    *
+    * vec4 32 ssa_0 = load_const (-2.0, -1.0,  1.0,  2.0)
+    * vec4 32 ssa_1 = load_const ( 2.0,  1.0, -1.0, -2.0)
+    * vec4 32 ssa_2 = load_const ( 3.0,  4.0,  5.0,  6.0)
+    * vec4 32 ssa_3 = fadd ssa_0, ssa_2
+    * vec1 1 ssa_4 = flt ssa_0.x, ssa_3.x
+    *
+    * if ssa_4 {
+    *    vec1 32 ssa_5 = fadd ssa_0.w, ssa_3.x
+    * } else {
+    * }
+    */
+   nir_ssa_def *a = nir_fadd(&bld, v1, v3);
+
+   nir_alu_instr *flt = nir_alu_instr_create(bld.shader, nir_op_flt);
+
+   flt->src[0].src = nir_src_for_ssa(v1);
+   flt->src[1].src = nir_src_for_ssa(a);
+
+   memcpy(&flt->src[0].swizzle, xxxx, sizeof(xxxx));
+   memcpy(&flt->src[1].swizzle, xxxx, sizeof(xxxx));
+
+   nir_builder_alu_instr_finish_and_insert(&bld, flt);
+
+   flt->dest.dest.ssa.num_components = 1;
+   flt->dest.write_mask = 1;
+
+   nir_if *nif = nir_push_if(&bld, &flt->dest.dest.ssa);
+
+   nir_alu_instr *fadd = nir_alu_instr_create(bld.shader, nir_op_fadd);
+
+   fadd->src[0].src = nir_src_for_ssa(v1);
+   fadd->src[1].src = nir_src_for_ssa(a);
+
+   memcpy(&fadd->src[0].swizzle, wwww, sizeof(wwww));
+   memcpy(&fadd->src[1].swizzle, xxxx, sizeof(xxxx));
+
+   nir_builder_alu_instr_finish_and_insert(&bld, fadd);
+
+   fadd->dest.dest.ssa.num_components = 1;
+   fadd->dest.write_mask = 1;
+
+   nir_pop_if(&bld, nif);
+
+   EXPECT_TRUE(nir_opt_comparison_pre_impl(bld.impl));
+}
+
 TEST_F(comparison_pre_test, non_scalar_add_result)
 {
    /* The optimization pass should not do anything because the result of the
src/compiler/nir/tests/negative_equal_tests.cpp
index 5e13c8fd28a6dc3fc0a39f5d6c57e3676349232f..9fedb987166121a6af97164e6a9d400634a8bb46 100644
@@ -270,6 +270,41 @@ compare_with_negation(nir_type_uint32)
 compare_with_negation(nir_type_int64)
 compare_with_negation(nir_type_uint64)
 
+TEST_F(alu_srcs_negative_equal_test, swizzle_scalar_to_vector)
+{
+   nir_ssa_def *v = nir_imm_vec2(&bld, 1.0, -1.0);
+   const uint8_t s0[4] = { 0, 0, 0, 0 };
+   const uint8_t s1[4] = { 1, 1, 1, 1 };
+
+   /* We can't use nir_swizzle here because it inserts an extra MOV. */
+   nir_alu_instr *instr = nir_alu_instr_create(bld.shader, nir_op_fadd);
+
+   instr->src[0].src = nir_src_for_ssa(v);
+   instr->src[1].src = nir_src_for_ssa(v);
+
+   memcpy(&instr->src[0].swizzle, s0, sizeof(s0));
+   memcpy(&instr->src[1].swizzle, s1, sizeof(s1));
+
+   nir_builder_alu_instr_finish_and_insert(&bld, instr);
+
+   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
+}
+
+TEST_F(alu_srcs_negative_equal_test, unused_components_mismatch)
+{
+   nir_ssa_def *v1 = nir_imm_vec4(&bld, -2.0, 18.0, 43.0,  1.0);
+   nir_ssa_def *v2 = nir_imm_vec4(&bld,  2.0, 99.0, 76.0, -1.0);
+
+   nir_ssa_def *result = nir_fadd(&bld, v1, v2);
+
+   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
+
+   /* Disable the channels that aren't negations of each other. */
+   instr->dest.write_mask = 8 + 1;
+
+   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
+}
+
 static void
 count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
                nir_alu_type full_type, int first)