nir/split_vars: Add mode checks to list walks
[mesa.git] / src / compiler / nir / nir_split_vars.c
index 3d98b5c8805fbea9db7a0ddf7e7243b33e38c120..db320039b9332d7ff6869bb7bad1a3c3f94cf378 100644 (file)
@@ -156,6 +156,7 @@ static bool
 split_var_list_structs(nir_shader *shader,
                        nir_function_impl *impl,
                        struct exec_list *vars,
+                       nir_variable_mode mode,
                        struct hash_table *var_field_map,
                        struct set **complex_vars,
                        void *mem_ctx)
@@ -173,6 +174,9 @@ split_var_list_structs(nir_shader *shader,
     * pull all of the variables we plan to split off of the list
     */
    nir_foreach_variable_safe(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       if (!glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
          continue;
 
@@ -305,6 +309,7 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
    if (modes & nir_var_shader_temp) {
       has_global_splits = split_var_list_structs(shader, NULL,
                                                  &shader->globals,
+                                                 nir_var_shader_temp,
                                                  var_field_map,
                                                  &complex_vars,
                                                  mem_ctx);
@@ -319,6 +324,7 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
       if (modes & nir_var_function_temp) {
          has_local_splits = split_var_list_structs(shader, function->impl,
                                                    &function->impl->locals,
+                                                   nir_var_function_temp,
                                                    var_field_map,
                                                    &complex_vars,
                                                    mem_ctx);
@@ -331,6 +337,8 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
          nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                                nir_metadata_dominance);
          progress = true;
+      } else {
+         nir_metadata_preserve(function->impl, nir_metadata_all);
       }
    }
 
@@ -367,6 +375,7 @@ struct array_var_info {
 static bool
 init_var_list_array_infos(nir_shader *shader,
                           struct exec_list *vars,
+                          nir_variable_mode mode,
                           struct hash_table *var_info_map,
                           struct set **complex_vars,
                           void *mem_ctx)
@@ -374,6 +383,9 @@ init_var_list_array_infos(nir_shader *shader,
    bool has_array = false;
 
    nir_foreach_variable(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       int num_levels = num_array_levels_in_array_of_vector_type(var->type);
       if (num_levels <= 0)
          continue;
@@ -427,8 +439,11 @@ get_array_deref_info(nir_deref_instr *deref,
    if (!(deref->mode & modes))
       return NULL;
 
-   return get_array_var_info(nir_deref_instr_get_variable(deref),
-                             var_info_map);
+   nir_variable *var = nir_deref_instr_get_variable(deref);
+   if (var == NULL)
+      return NULL;
+
+   return get_array_var_info(var, var_info_map);
 }
 
 static void
@@ -532,6 +547,7 @@ static bool
 split_var_list_arrays(nir_shader *shader,
                       nir_function_impl *impl,
                       struct exec_list *vars,
+                      nir_variable_mode mode,
                       struct hash_table *var_info_map,
                       void *mem_ctx)
 {
@@ -539,6 +555,9 @@ split_var_list_arrays(nir_shader *shader,
    exec_list_make_empty(&split_vars);
 
    nir_foreach_variable_safe(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       struct array_var_info *info = get_array_var_info(var, var_info_map);
       if (!info)
          continue;
@@ -854,6 +873,7 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
    if (modes & nir_var_shader_temp) {
       has_global_array = init_var_list_array_infos(shader,
                                                    &shader->globals,
+                                                   nir_var_shader_temp,
                                                    var_info_map,
                                                    &complex_vars,
                                                    mem_ctx);
@@ -868,6 +888,7 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
       if (modes & nir_var_function_temp) {
          has_local_array = init_var_list_array_infos(shader,
                                                      &function->impl->locals,
+                                                     nir_var_function_temp,
                                                      var_info_map,
                                                      &complex_vars,
                                                      mem_ctx);
@@ -882,6 +903,7 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
    /* If we failed to find any arrays of arrays, bail early. */
    if (!has_any_array) {
       ralloc_free(mem_ctx);
+      nir_shader_preserve_all_metadata(shader);
       return false;
    }
 
@@ -889,6 +911,7 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
    if (modes & nir_var_shader_temp) {
       has_global_splits = split_var_list_arrays(shader, NULL,
                                                 &shader->globals,
+                                                nir_var_shader_temp,
                                                 var_info_map, mem_ctx);
    }
 
@@ -901,6 +924,7 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
       if (modes & nir_var_function_temp) {
          has_local_splits = split_var_list_arrays(shader, function->impl,
                                                   &function->impl->locals,
+                                                  nir_var_function_temp,
                                                   var_info_map, mem_ctx);
       }
 
@@ -911,6 +935,8 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
          nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                                nir_metadata_dominance);
          progress = true;
+      } else {
+         nir_metadata_preserve(function->impl, nir_metadata_all);
       }
    }
 
@@ -1163,9 +1189,7 @@ get_non_self_referential_store_comps(nir_intrinsic_instr *store)
                comps &= ~(1u << i);
          }
       }
-   } else if (src_alu->op == nir_op_vec2 ||
-              src_alu->op == nir_op_vec3 ||
-              src_alu->op == nir_op_vec4) {
+   } else if (nir_op_is_vec(src_alu->op)) {
       /* If it's a vec, discount any channels that are just loads from the
        * same deref put in the same spot.
        */
@@ -1227,6 +1251,7 @@ find_used_components_impl(nir_function_impl *impl,
 
 static bool
 shrink_vec_var_list(struct exec_list *vars,
+                    nir_variable_mode mode,
                     struct hash_table *var_usage_map)
 {
    /* Initialize the components kept field of each variable.  This is the
@@ -1246,6 +1271,9 @@ shrink_vec_var_list(struct exec_list *vars,
     * to leave components and array_len of any wildcards alone.
     */
    nir_foreach_variable(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       struct vec_var_usage *usage =
          get_vec_var_usage(var, var_usage_map, false, NULL);
       if (!usage)
@@ -1279,6 +1307,9 @@ shrink_vec_var_list(struct exec_list *vars,
    do {
       fp_progress = false;
       nir_foreach_variable(var, vars) {
+         if (var->data.mode != mode)
+            continue;
+
          struct vec_var_usage *var_usage =
             get_vec_var_usage(var, var_usage_map, false, NULL);
          if (!var_usage || !var_usage->vars_copied)
@@ -1316,6 +1347,9 @@ shrink_vec_var_list(struct exec_list *vars,
 
    bool vars_shrunk = false;
    nir_foreach_variable_safe(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       struct vec_var_usage *usage =
          get_vec_var_usage(var, var_usage_map, false, NULL);
       if (!usage)
@@ -1629,12 +1663,16 @@ nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
    }
    if (!has_vars_to_shrink) {
       ralloc_free(mem_ctx);
+      nir_shader_preserve_all_metadata(shader);
       return false;
    }
 
    bool globals_shrunk = false;
-   if (modes & nir_var_shader_temp)
-      globals_shrunk = shrink_vec_var_list(&shader->globals, var_usage_map);
+   if (modes & nir_var_shader_temp) {
+      globals_shrunk = shrink_vec_var_list(&shader->globals,
+                                           nir_var_shader_temp,
+                                           var_usage_map);
+   }
 
    bool progress = false;
    nir_foreach_function(function, shader) {
@@ -1644,6 +1682,7 @@ nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
       bool locals_shrunk = false;
       if (modes & nir_var_function_temp) {
          locals_shrunk = shrink_vec_var_list(&function->impl->locals,
+                                             nir_var_function_temp,
                                              var_usage_map);
       }
 
@@ -1653,6 +1692,8 @@ nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
          nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                                nir_metadata_dominance);
          progress = true;
+      } else {
+         nir_metadata_preserve(function->impl, nir_metadata_all);
       }
    }