nir/split_vars: Properly bail in the presence of complex derefs
author Jason Ekstrand <jason@jlekstrand.net>
Wed, 22 May 2019 20:54:39 +0000 (15:54 -0500)
committer Jason Ekstrand <jason@jlekstrand.net>
Fri, 31 May 2019 01:08:03 +0000 (01:08 +0000)
Reviewed-by: Dave Airlie <airlied@redhat.com>
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
src/compiler/nir/nir_split_vars.c

index 2ff8257020310a1b132c3017b567651f73147ce3..3d98b5c8805fbea9db7a0ddf7e7243b33e38c120 100644 (file)
 #include "nir_deref.h"
 #include "nir_vla.h"
 
+#include "util/set.h"
 #include "util/u_math.h"
 
+/**
+ * Collects the variables that are referenced through a deref with a complex
+ * use.
+ *
+ * Visits every instruction of every function in the shader that has an impl
+ * and, for each variable deref for which nir_deref_instr_has_complex_use()
+ * returns true, adds that deref's variable to a pointer set.
+ *
+ * \param shader   shader whose functions are scanned
+ * \param mem_ctx  context passed to _mesa_pointer_set_create() for the set
+ *
+ * \return the newly created set; it contains no entries when no variable
+ *         deref has a complex use
+ */
+static struct set *
+get_complex_used_vars(nir_shader *shader, void *mem_ctx)
+{
+   struct set *complex_vars = _mesa_pointer_set_create(mem_ctx);
+
+   nir_foreach_function(function, shader) {
+      /* A function without an impl has no blocks to walk. */
+      if (!function->impl)
+         continue;
+
+      nir_foreach_block(block, function->impl) {
+         nir_foreach_instr(instr, block) {
+            if (instr->type != nir_instr_type_deref)
+               continue;
+
+            nir_deref_instr *deref = nir_instr_as_deref(instr);
+
+            /* We only need to consider var derefs because
+             * nir_deref_instr_has_complex_use is recursive.
+             */
+            if (deref->deref_type == nir_deref_type_var &&
+                nir_deref_instr_has_complex_use(deref))
+               _mesa_set_add(complex_vars, deref->var);
+         }
+      }
+   }
+
+   return complex_vars;
+}
 
 struct split_var_state {
    void *mem_ctx;
@@ -128,6 +157,7 @@ split_var_list_structs(nir_shader *shader,
                        nir_function_impl *impl,
                        struct exec_list *vars,
                        struct hash_table *var_field_map,
+                       struct set **complex_vars,
                        void *mem_ctx)
 {
    struct split_var_state state = {
@@ -146,6 +176,15 @@ split_var_list_structs(nir_shader *shader,
       if (!glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
          continue;
 
+      if (*complex_vars == NULL)
+         *complex_vars = get_complex_used_vars(shader, mem_ctx);
+
+      /* We can't split a variable that's referenced with deref that has any
+       * sort of complex usage.
+       */
+      if (_mesa_set_search(*complex_vars, var))
+         continue;
+
       exec_node_remove(&var->node);
       exec_list_push_tail(&split_vars, &var->node);
    }
@@ -258,6 +297,7 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
    void *mem_ctx = ralloc_context(NULL);
    struct hash_table *var_field_map =
       _mesa_pointer_hash_table_create(mem_ctx);
+   struct set *complex_vars = NULL;
 
    assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
 
@@ -265,7 +305,9 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
    if (modes & nir_var_shader_temp) {
       has_global_splits = split_var_list_structs(shader, NULL,
                                                  &shader->globals,
-                                                 var_field_map, mem_ctx);
+                                                 var_field_map,
+                                                 &complex_vars,
+                                                 mem_ctx);
    }
 
    bool progress = false;
@@ -277,7 +319,9 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
       if (modes & nir_var_function_temp) {
          has_local_splits = split_var_list_structs(shader, function->impl,
                                                    &function->impl->locals,
-                                                   var_field_map, mem_ctx);
+                                                   var_field_map,
+                                                   &complex_vars,
+                                                   mem_ctx);
       }
 
       if (has_global_splits || has_local_splits) {
@@ -321,8 +365,10 @@ struct array_var_info {
 };
 
 static bool
-init_var_list_array_infos(struct exec_list *vars,
+init_var_list_array_infos(nir_shader *shader,
+                          struct exec_list *vars,
                           struct hash_table *var_info_map,
+                          struct set **complex_vars,
                           void *mem_ctx)
 {
    bool has_array = false;
@@ -332,6 +378,15 @@ init_var_list_array_infos(struct exec_list *vars,
       if (num_levels <= 0)
          continue;
 
+      if (*complex_vars == NULL)
+         *complex_vars = get_complex_used_vars(shader, mem_ctx);
+
+      /* We can't split a variable that's referenced with deref that has any
+       * sort of complex usage.
+       */
+      if (_mesa_set_search(*complex_vars, var))
+         continue;
+
       struct array_var_info *info =
          rzalloc_size(mem_ctx, sizeof(*info) +
                                num_levels * sizeof(info->levels[0]));
@@ -791,13 +846,17 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
 {
    void *mem_ctx = ralloc_context(NULL);
    struct hash_table *var_info_map = _mesa_pointer_hash_table_create(mem_ctx);
+   struct set *complex_vars = NULL;
 
    assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
 
    bool has_global_array = false;
    if (modes & nir_var_shader_temp) {
-      has_global_array = init_var_list_array_infos(&shader->globals,
-                                                   var_info_map, mem_ctx);
+      has_global_array = init_var_list_array_infos(shader,
+                                                   &shader->globals,
+                                                   var_info_map,
+                                                   &complex_vars,
+                                                   mem_ctx);
    }
 
    bool has_any_array = false;
@@ -807,8 +866,11 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
 
       bool has_local_array = false;
       if (modes & nir_var_function_temp) {
-         has_local_array = init_var_list_array_infos(&function->impl->locals,
-                                                     var_info_map, mem_ctx);
+         has_local_array = init_var_list_array_infos(shader,
+                                                     &function->impl->locals,
+                                                     var_info_map,
+                                                     &complex_vars,
+                                                     mem_ctx);
       }
 
       if (has_global_array || has_local_array) {
@@ -880,6 +942,7 @@ struct vec_var_usage {
 
    /* True if there is a copy that isn't to/from a shrinkable vector */
    bool has_external_copy;
+   bool has_complex_use;
    struct set *vars_copied;
 
    unsigned num_levels;
@@ -939,6 +1002,32 @@ get_vec_deref_usage(nir_deref_instr *deref,
                             var_usage_map, add_usage_entry, mem_ctx);
 }
 
+/**
+ * Sets has_complex_use on a variable's usage entry when its variable deref
+ * has a complex use.
+ *
+ * Returns without doing anything when the deref's mode is not in \p modes,
+ * when the deref is not a variable deref, when
+ * nir_deref_instr_has_complex_use() returns false, or when
+ * get_vec_var_usage() yields no entry for the variable.
+ *
+ * \param deref          deref instruction to inspect
+ * \param var_usage_map  table forwarded to get_vec_var_usage()
+ * \param modes          variable modes to consider
+ * \param mem_ctx        context forwarded to get_vec_var_usage()
+ */
+static void
+mark_deref_if_complex(nir_deref_instr *deref,
+                      struct hash_table *var_usage_map,
+                      nir_variable_mode modes,
+                      void *mem_ctx)
+{
+   if (!(deref->mode & modes))
+      return;
+
+   /* Only bother with var derefs because nir_deref_instr_has_complex_use is
+    * recursive.
+    */
+   if (deref->deref_type != nir_deref_type_var)
+      return;
+
+   if (!nir_deref_instr_has_complex_use(deref))
+      return;
+
+   /* NOTE(review): the `true` argument presumably asks get_vec_var_usage()
+    * to create the entry if it is missing -- confirm against its definition.
+    */
+   struct vec_var_usage *usage =
+      get_vec_var_usage(deref->var, var_usage_map, true, mem_ctx);
+   if (!usage)
+      return;
+
+   usage->has_complex_use = true;
+}
+
 static void
 mark_deref_used(nir_deref_instr *deref,
                 nir_component_mask_t comps_read,
@@ -952,6 +1041,8 @@ mark_deref_used(nir_deref_instr *deref,
       return;
 
    nir_variable *var = nir_deref_instr_get_variable(deref);
+   if (var == NULL)
+      return;
 
    struct vec_var_usage *usage =
       get_vec_var_usage(var, var_usage_map, true, mem_ctx);
@@ -1096,6 +1187,11 @@ find_used_components_impl(nir_function_impl *impl,
 {
    nir_foreach_block(block, impl) {
       nir_foreach_instr(instr, block) {
+         if (instr->type == nir_instr_type_deref) {
+            mark_deref_if_complex(nir_instr_as_deref(instr),
+                                  var_usage_map, modes, mem_ctx);
+         }
+
          if (instr->type != nir_instr_type_intrinsic)
             continue;
 
@@ -1156,7 +1252,7 @@ shrink_vec_var_list(struct exec_list *vars,
          continue;
 
       assert(usage->comps_kept == 0);
-      if (usage->has_external_copy)
+      if (usage->has_external_copy || usage->has_complex_use)
          usage->comps_kept = usage->all_comps;
       else
          usage->comps_kept = usage->comps_read & usage->comps_written;
@@ -1165,7 +1261,8 @@ shrink_vec_var_list(struct exec_list *vars,
          struct array_level_usage *level = &usage->levels[i];
          assert(level->array_len > 0);
 
-         if (level->max_written == UINT_MAX || level->has_external_copy)
+         if (level->max_written == UINT_MAX || level->has_external_copy ||
+             usage->has_complex_use)
             continue; /* Can't shrink */
 
          unsigned max_used = MIN2(level->max_read, level->max_written);