nir: rename nir_op_fne to nir_op_fneu
[mesa.git] / src / compiler / nir / nir_opt_find_array_copies.c
index 63609715c3a6273393e07bfe73546f81ddcd6209..0e50ca18af9f1c70e260e10573cf01ef7fca10ed 100644 (file)
@@ -53,7 +53,9 @@ struct match_node {
 
 struct match_state {
    /* Map from nir_variable * -> match_node */
-   struct hash_table *table;
+   struct hash_table *var_nodes;
+   /* Map from cast nir_deref_instr * -> match_node */
+   struct hash_table *cast_nodes;
 
    unsigned cur_instr;
 
@@ -89,12 +91,25 @@ node_for_deref(nir_deref_instr *instr, struct match_node *parent,
    unsigned idx;
    switch (instr->deref_type) {
    case nir_deref_type_var: {
-      struct hash_entry *entry = _mesa_hash_table_search(state->table, instr->var);
+      struct hash_entry *entry =
+         _mesa_hash_table_search(state->var_nodes, instr->var);
       if (entry) {
          return entry->data;
       } else {
          struct match_node *node = create_match_node(instr->type, state);
-         _mesa_hash_table_insert(state->table, instr->var, node);
+         _mesa_hash_table_insert(state->var_nodes, instr->var, node);
+         return node;
+      }
+   }
+
+   case nir_deref_type_cast: {
+      struct hash_entry *entry =
+         _mesa_hash_table_search(state->cast_nodes, instr);
+      if (entry) {
+         return entry->data;
+      } else {
+         struct match_node *node = create_match_node(instr->type, state);
+         _mesa_hash_table_insert(state->cast_nodes, instr, node);
          return node;
       }
    }
@@ -223,6 +238,17 @@ _foreach_aliasing(nir_deref_instr **deref, match_cb cb,
    }
 }
 
+static void
+_foreach_child(match_cb cb, struct match_node *node, struct match_state *state)
+{
+   if (node->num_children == 0) {
+      cb(node, state);
+   } else {
+      for (unsigned i = 0; i < node->num_children; i++)
+         _foreach_child(cb, node->children[i], state);
+   }
+}
+
 /* Given a deref path, find all the leaf deref nodes that alias it. */
 
 static void
@@ -230,11 +256,32 @@ foreach_aliasing_node(nir_deref_path *path,
                       match_cb cb,
                       struct match_state *state)
 {
-   assert(path->path[0]->deref_type == nir_deref_type_var);
-   struct hash_entry *entry = _mesa_hash_table_search(state->table,
-                                                      path->path[0]->var);
-   if (entry)
-      _foreach_aliasing(&path->path[1], cb, entry->data, state);
+   if (path->path[0]->deref_type == nir_deref_type_var) {
+      struct hash_entry *entry = _mesa_hash_table_search(state->var_nodes,
+                                                         path->path[0]->var);
+      if (entry)
+         _foreach_aliasing(&path->path[1], cb, entry->data, state);
+
+      hash_table_foreach(state->cast_nodes, entry)
+         _foreach_child(cb, entry->data, state);
+   } else {
+      /* Casts automatically alias anything that isn't a cast */
+      assert(path->path[0]->deref_type == nir_deref_type_cast);
+      hash_table_foreach(state->var_nodes, entry)
+         _foreach_child(cb, entry->data, state);
+
+      /* Casts alias other casts if the casts are different or if they're the
+       * same and the path from the cast may alias as per the usual rules.
+       */
+      hash_table_foreach(state->cast_nodes, entry) {
+         const nir_deref_instr *cast = entry->key;
+         assert(cast->deref_type == nir_deref_type_cast);
+         if (cast == path->path[0])
+            _foreach_aliasing(&path->path[1], cb, entry->data, state);
+         else
+            _foreach_child(cb, entry->data, state);
+      }
+   }
 }
 
 static nir_deref_instr *
@@ -260,7 +307,8 @@ clobber(struct match_node *node, struct match_state *state)
 
 static bool
 try_match_deref(nir_deref_path *base_path, int *path_array_idx,
-                nir_deref_path *deref_path, int arr_idx)
+                nir_deref_path *deref_path, int arr_idx,
+                nir_deref_instr *dst)
 {
    for (int i = 0; ; i++) {
       nir_deref_instr *b = base_path->path[i];
@@ -292,11 +340,13 @@ try_match_deref(nir_deref_path *base_path, int *path_array_idx,
          /* If we don't have an index into the path yet or if this entry in
           * the path is at the array index, see if this is a candidate.  We're
           * looking for an index which is zero in the base deref and arr_idx
-          * in the search deref.
+          * in the search deref and has a matching array size.
           */
          if ((*path_array_idx < 0 || *path_array_idx == i) &&
              const_b_idx && b_idx == 0 &&
-             const_d_idx && d_idx == arr_idx) {
+             const_d_idx && d_idx == arr_idx &&
+             glsl_get_length(nir_deref_instr_parent(b)->type) ==
+             glsl_get_length(nir_deref_instr_parent(dst)->type)) {
             *path_array_idx = i;
             continue;
          }
@@ -398,7 +448,8 @@ handle_write(nir_deref_instr *dst, nir_deref_instr *src,
          nir_deref_path_init(&src_path, src, state->dead_ctx);
          bool result = try_match_deref(&dst_node->first_src_path,
                                        &dst_node->src_wildcard_idx,
-                                       &src_path, dst_node->next_array_idx);
+                                       &src_path, dst_node->next_array_idx,
+                                       *instr);
          nir_deref_path_finish(&src_path);
          if (!result)
             goto reset;
@@ -468,7 +519,8 @@ opt_find_array_copies_block(nir_builder *b, nir_block *block,
 
    unsigned next_index = 0;
 
-   _mesa_hash_table_clear(state->table, NULL);
+   _mesa_hash_table_clear(state->var_nodes, NULL);
+   _mesa_hash_table_clear(state->cast_nodes, NULL);
 
    nir_foreach_instr(instr, block) {
       if (instr->type != nir_instr_type_intrinsic)
@@ -541,13 +593,16 @@ opt_find_array_copies_block(nir_builder *b, nir_block *block,
       /* There must be no indirects in the source or destination and no known
        * out-of-bounds accesses in the source, and the copy must be fully
        * qualified, or else we can't build up the array copy. We handled
-       * out-of-bounds accesses to the dest above.
+       * out-of-bounds accesses to the dest above. The types must match, since
+       * copy_deref currently can't bitcast mismatched deref types.
        */
       if (src_deref &&
           (nir_deref_instr_has_indirect(src_deref) ||
            nir_deref_instr_is_known_out_of_bounds(src_deref) ||
            nir_deref_instr_has_indirect(dst_deref) ||
-           !glsl_type_is_vector_or_scalar(src_deref->type))) {
+           !glsl_type_is_vector_or_scalar(src_deref->type) ||
+           glsl_get_bare_type(src_deref->type) !=
+           glsl_get_bare_type(dst_deref->type))) {
          src_deref = NULL;
       }
 
@@ -569,7 +624,8 @@ opt_find_array_copies_impl(nir_function_impl *impl)
 
    struct match_state s;
    s.dead_ctx = ralloc_context(NULL);
-   s.table = _mesa_pointer_hash_table_create(s.dead_ctx);
+   s.var_nodes = _mesa_pointer_hash_table_create(s.dead_ctx);
+   s.cast_nodes = _mesa_pointer_hash_table_create(s.dead_ctx);
    nir_builder_init(&s.builder, impl);
 
    nir_foreach_block(block, impl) {
@@ -582,6 +638,8 @@ opt_find_array_copies_impl(nir_function_impl *impl)
    if (progress) {
       nir_metadata_preserve(impl, nir_metadata_block_index |
                                   nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
    }
 
    return progress;