Merge branch 'master' of ../mesa into vulkan
[mesa.git] / src / glsl / nir / nir_search.c
index e69fdfd431c8f70f2a231eed5f2465cb6bbcb27e..bb1544079142421eef3e0276277e85c5897bfda9 100644 (file)
@@ -39,6 +39,33 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
 
 static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
 
+static bool alu_instr_is_bool(nir_alu_instr *instr);
+
+static bool
+src_is_bool(nir_src src)
+{
+   if (!src.is_ssa)
+      return false;
+   if (src.ssa->parent_instr->type != nir_instr_type_alu)
+      return false;
+   return alu_instr_is_bool(nir_instr_as_alu(src.ssa->parent_instr));
+}
+
+static bool
+alu_instr_is_bool(nir_alu_instr *instr)
+{
+   switch (instr->op) {
+   case nir_op_iand:
+   case nir_op_ior:
+   case nir_op_ixor:
+      return src_is_bool(instr->src[0].src) && src_is_bool(instr->src[1].src);
+   case nir_op_inot:
+      return src_is_bool(instr->src[0].src);
+   default:
+      return nir_op_infos[instr->op].output_type == nir_type_bool;
+   }
+}
+
 static bool
 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
             unsigned num_components, const uint8_t *swizzle,
@@ -46,7 +73,15 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
 {
    uint8_t new_swizzle[4];
 
-   for (int i = 0; i < num_components; ++i)
+   /* If the source is an explicitly sized source, then we need to reset
+    * both the number of components and the swizzle.
+    */
+   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
+      num_components = nir_op_infos[instr->op].input_sizes[src];
+      swizzle = identity_swizzle;
+   }
+
+   for (unsigned i = 0; i < num_components; ++i)
       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
 
    switch (value->type) {
@@ -63,6 +98,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
 
    case nir_search_value_variable: {
       nir_search_variable *var = nir_search_value_as_variable(value);
+      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
 
       if (state->variables_seen & (1 << var->variable)) {
          if (!nir_srcs_equal(state->variables[var->variable].src,
@@ -71,19 +107,35 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
 
          assert(!instr->src[src].abs && !instr->src[src].negate);
 
-         for (int i = 0; i < num_components; ++i) {
+         for (unsigned i = 0; i < num_components; ++i) {
             if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
                return false;
          }
 
          return true;
       } else {
+         if (var->is_constant &&
+             instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
+            return false;
+
+         if (var->type != nir_type_invalid) {
+            if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
+               return false;
+
+            nir_alu_instr *src_alu =
+               nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
+
+            if (nir_op_infos[src_alu->op].output_type != var->type &&
+                !(var->type == nir_type_bool && alu_instr_is_bool(src_alu)))
+               return false;
+         }
+
          state->variables_seen |= (1 << var->variable);
          state->variables[var->variable].src = instr->src[src].src;
          state->variables[var->variable].abs = false;
          state->variables[var->variable].negate = false;
 
-         for (int i = 0; i < 4; ++i) {
+         for (unsigned i = 0; i < 4; ++i) {
             if (i < num_components)
                state->variables[var->variable].swizzle[i] = new_swizzle[i];
             else
@@ -155,16 +207,13 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
       }
    }
 
+   /* Stash off the current variables_seen bitmask.  This way we can
+    * restore it prior to matching in the commutative case below.
+    */
+   unsigned variables_seen_stash = state->variables_seen;
+
    bool matched = true;
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      /* If the source is an explicitly sized source, then we need to reset
-       * both the number of components and the swizzle.
-       */
-      if (nir_op_infos[instr->op].input_sizes[i] != 0) {
-         num_components = nir_op_infos[instr->op].input_sizes[i];
-         swizzle = identity_swizzle;
-      }
-
       if (!match_value(expr->srcs[i], instr, i, num_components,
                        swizzle, state)) {
          matched = false;
@@ -175,8 +224,15 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
    if (matched)
       return true;
 
-   if (nir_op_infos[instr->op].num_inputs == 2 &&
-       (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE)) {
+   if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
+      assert(nir_op_infos[instr->op].num_inputs == 2);
+
+      /* Restore the variables_seen bitmask.  If we don't do this, then we
+       * could end up with an erroneous failure due to variables found in the
+       * first match attempt above not matching those in the second.
+       */
+      state->variables_seen = variables_seen_stash;
+
       if (!match_value(expr->srcs[0], instr, 1, num_components,
                        swizzle, state))
          return false;
@@ -201,8 +257,7 @@ construct_value(const nir_search_value *value, nir_alu_type type,
          num_components = nir_op_infos[expr->opcode].output_size;
 
       nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
-      alu->dest.dest.is_ssa = true;
-      nir_ssa_def_init(&alu->instr, &alu->dest.dest.ssa, num_components, NULL);
+      nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, NULL);
       alu->dest.write_mask = (1 << num_components) - 1;
       alu->dest.saturate = false;
 
@@ -234,8 +289,10 @@ construct_value(const nir_search_value *value, nir_alu_type type,
       const nir_search_variable *var = nir_search_value_as_variable(value);
       assert(state->variables_seen & (1 << var->variable));
 
-      nir_alu_src val = state->variables[var->variable];
-      val.src = nir_src_copy(val.src, mem_ctx);
+      nir_alu_src val = { NIR_SRC_INIT };
+      nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
+
+      assert(!var->is_constant);
 
       return val;
    }
@@ -301,9 +358,8 @@ nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
     */
    nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
    mov->dest.write_mask = instr->dest.write_mask;
-   mov->dest.dest.is_ssa = true;
-   nir_ssa_def_init(&mov->instr, &mov->dest.dest.ssa,
-                    instr->dest.dest.ssa.num_components, NULL);
+   nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
+                     instr->dest.dest.ssa.num_components, NULL);
 
    mov->src[0] = construct_value(replace, nir_op_infos[instr->op].output_type,
                                  instr->dest.dest.ssa.num_components, &state,
@@ -311,7 +367,7 @@ nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
    nir_instr_insert_before(&instr->instr, &mov->instr);
 
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
-                            nir_src_for_ssa(&mov->dest.dest.ssa), mem_ctx);
+                            nir_src_for_ssa(&mov->dest.dest.ssa));
 
    /* We know this one has no more uses because we just rewrote them all,
     * so we can remove it.  The rest of the matched expression, however, we