nir: Use a single list for all shader variables
[mesa.git] / src / compiler / nir / nir_split_vars.c
index 6cd835cb0135a551edcae3c71aaea7198fceddc5..6aee3109d4121a2b8f94455d4a513dfec97c7234 100644 (file)
 #include "nir_deref.h"
 #include "nir_vla.h"
 
-/* Needed for _mesa_bitcount() */
-#include "main/macros.h"
+#include "util/set.h"
+#include "util/u_math.h"
+
+/* Builds a pointer set (created with _mesa_pointer_set_create(mem_ctx))
+ * containing every variable that has a var deref for which
+ * nir_deref_instr_has_complex_use() reports true.  Functions without an
+ * impl are skipped.  The set is returned even when it ends up empty.
+ */
+static struct set *
+get_complex_used_vars(nir_shader *shader, void *mem_ctx)
+{
+   struct set *result = _mesa_pointer_set_create(mem_ctx);
+
+   nir_foreach_function(func, shader) {
+      if (func->impl == NULL)
+         continue;
+
+      nir_foreach_block(blk, func->impl) {
+         nir_foreach_instr(ins, blk) {
+            if (ins->type != nir_instr_type_deref)
+               continue;
+
+            /* nir_deref_instr_has_complex_use is recursive, so looking at
+             * the var derefs alone is sufficient.
+             */
+            nir_deref_instr *d = nir_instr_as_deref(ins);
+            if (d->deref_type != nir_deref_type_var)
+               continue;
+
+            if (nir_deref_instr_has_complex_use(d))
+               _mesa_set_add(result, d->var);
+         }
+      }
+   }
+
+   return result;
+}
 
 struct split_var_state {
    void *mem_ctx;
@@ -58,7 +87,8 @@ wrap_type_in_array(const struct glsl_type *type,
 
    const struct glsl_type *elem_type =
       wrap_type_in_array(type, glsl_get_array_element(array_type));
-   return glsl_array_type(elem_type, glsl_get_length(array_type));
+   assert(glsl_get_explicit_stride(array_type) == 0);
+   return glsl_array_type(elem_type, glsl_get_length(array_type), 0);
 }
 
 static int
@@ -90,7 +120,7 @@ init_field_for_type(struct field *field, struct field *parent,
    };
 
    const struct glsl_type *struct_type = glsl_without_array(type);
-   if (glsl_type_is_struct(struct_type)) {
+   if (glsl_type_is_struct_or_ifc(struct_type)) {
       field->num_fields = glsl_get_length(struct_type),
       field->fields = ralloc_array(state->mem_ctx, struct field,
                                    field->num_fields);
@@ -114,7 +144,7 @@ init_field_for_type(struct field *field, struct field *parent,
          var_type = wrap_type_in_array(var_type, f->type);
 
       nir_variable_mode mode = state->base_var->data.mode;
-      if (mode == nir_var_local) {
+      if (mode == nir_var_function_temp) {
          field->var = nir_local_variable_create(state->impl, var_type, name);
       } else {
          field->var = nir_variable_create(state->shader, mode, var_type, name);
@@ -126,7 +156,9 @@ static bool
 split_var_list_structs(nir_shader *shader,
                        nir_function_impl *impl,
                        struct exec_list *vars,
+                       nir_variable_mode mode,
                        struct hash_table *var_field_map,
+                       struct set **complex_vars,
                        void *mem_ctx)
 {
    struct split_var_state state = {
@@ -141,15 +173,27 @@ split_var_list_structs(nir_shader *shader,
    /* To avoid list confusion (we'll be adding things as we split variables),
     * pull all of the variables we plan to split off of the list
     */
-   nir_foreach_variable_safe(var, vars) {
-      if (!glsl_type_is_struct(glsl_without_array(var->type)))
+   nir_foreach_variable_in_list_safe(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
+      if (!glsl_type_is_struct_or_ifc(glsl_without_array(var->type)))
+         continue;
+
+      if (*complex_vars == NULL)
+         *complex_vars = get_complex_used_vars(shader, mem_ctx);
+
+      /* We can't split a variable that's referenced by a deref that has any
+       * sort of complex usage.
+       */
+      if (_mesa_set_search(*complex_vars, var))
          continue;
 
       exec_node_remove(&var->node);
       exec_list_push_tail(&split_vars, &var->node);
    }
 
-   nir_foreach_variable(var, &split_vars) {
+   nir_foreach_variable_in_list(var, &split_vars) {
       state.base_var = var;
 
       struct field *root_field = ralloc(mem_ctx, struct field);
@@ -204,7 +248,7 @@ split_struct_derefs_impl(nir_function_impl *impl,
                continue;
 
             assert(i > 0);
-            assert(glsl_type_is_struct(path.path[i - 1]->type));
+            assert(glsl_type_is_struct_or_ifc(path.path[i - 1]->type));
             assert(path.path[i - 1]->type ==
                    glsl_without_array(tail_field->type));
 
@@ -256,16 +300,19 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
 {
    void *mem_ctx = ralloc_context(NULL);
    struct hash_table *var_field_map =
-      _mesa_hash_table_create(mem_ctx, _mesa_hash_pointer,
-                              _mesa_key_pointer_equal);
+      _mesa_pointer_hash_table_create(mem_ctx);
+   struct set *complex_vars = NULL;
 
-   assert((modes & (nir_var_global | nir_var_local)) == modes);
+   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
 
    bool has_global_splits = false;
-   if (modes & nir_var_global) {
+   if (modes & nir_var_shader_temp) {
       has_global_splits = split_var_list_structs(shader, NULL,
-                                                 &shader->globals,
-                                                 var_field_map, mem_ctx);
+                                                 &shader->variables,
+                                                 nir_var_shader_temp,
+                                                 var_field_map,
+                                                 &complex_vars,
+                                                 mem_ctx);
    }
 
    bool progress = false;
@@ -274,10 +321,13 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
          continue;
 
       bool has_local_splits = false;
-      if (modes & nir_var_local) {
+      if (modes & nir_var_function_temp) {
          has_local_splits = split_var_list_structs(shader, function->impl,
                                                    &function->impl->locals,
-                                                   var_field_map, mem_ctx);
+                                                   nir_var_function_temp,
+                                                   var_field_map,
+                                                   &complex_vars,
+                                                   mem_ctx);
       }
 
       if (has_global_splits || has_local_splits) {
@@ -287,6 +337,8 @@ nir_split_struct_vars(nir_shader *shader, nir_variable_mode modes)
          nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                                nir_metadata_dominance);
          progress = true;
+      } else {
+         nir_metadata_preserve(function->impl, nir_metadata_all);
       }
    }
 
@@ -321,17 +373,32 @@ struct array_var_info {
 };
 
 static bool
-init_var_list_array_infos(struct exec_list *vars,
+init_var_list_array_infos(nir_shader *shader,
+                          struct exec_list *vars,
+                          nir_variable_mode mode,
                           struct hash_table *var_info_map,
+                          struct set **complex_vars,
                           void *mem_ctx)
 {
    bool has_array = false;
 
-   nir_foreach_variable(var, vars) {
+   nir_foreach_variable_in_list(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       int num_levels = num_array_levels_in_array_of_vector_type(var->type);
       if (num_levels <= 0)
          continue;
 
+      if (*complex_vars == NULL)
+         *complex_vars = get_complex_used_vars(shader, mem_ctx);
+
+      /* We can't split a variable that's referenced by a deref that has any
+       * sort of complex usage.
+       */
+      if (_mesa_set_search(*complex_vars, var))
+         continue;
+
       struct array_var_info *info =
          rzalloc_size(mem_ctx, sizeof(*info) +
                                num_levels * sizeof(info->levels[0]));
@@ -372,8 +439,11 @@ get_array_deref_info(nir_deref_instr *deref,
    if (!(deref->mode & modes))
       return NULL;
 
-   return get_array_var_info(nir_deref_instr_get_variable(deref),
-                             var_info_map);
+   nir_variable *var = nir_deref_instr_get_variable(deref);
+   if (var == NULL)
+      return NULL;
+
+   return get_array_var_info(var, var_info_map);
 }
 
 static void
@@ -396,7 +466,7 @@ mark_array_deref_used(nir_deref_instr *deref,
    for (unsigned i = 0; i < info->num_levels; i++) {
       nir_deref_instr *p = path.path[i + 1];
       if (p->deref_type == nir_deref_type_array &&
-          nir_src_as_const_value(p->arr.index) == NULL)
+          !nir_src_is_const(p->arr.index))
          info->levels[i].split = false;
    }
 }
@@ -453,7 +523,7 @@ create_split_array_vars(struct array_var_info *var_info,
       name = ralloc_asprintf(mem_ctx, "(%s)", name);
 
       nir_variable_mode mode = var_info->base_var->data.mode;
-      if (mode == nir_var_local) {
+      if (mode == nir_var_function_temp) {
          split->var = nir_local_variable_create(impl,
                                                 var_info->split_var_type, name);
       } else {
@@ -477,13 +547,17 @@ static bool
 split_var_list_arrays(nir_shader *shader,
                       nir_function_impl *impl,
                       struct exec_list *vars,
+                      nir_variable_mode mode,
                       struct hash_table *var_info_map,
                       void *mem_ctx)
 {
    struct exec_list split_vars;
    exec_list_make_empty(&split_vars);
 
-   nir_foreach_variable_safe(var, vars) {
+   nir_foreach_variable_in_list_safe(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       struct array_var_info *info = get_array_var_info(var, var_info_map);
       if (!info)
          continue;
@@ -506,7 +580,7 @@ split_var_list_arrays(nir_shader *shader,
                                           glsl_get_components(split_type),
                                           info->levels[i].array_len);
          } else {
-            split_type = glsl_array_type(split_type, info->levels[i].array_len);
+            split_type = glsl_array_type(split_type, info->levels[i].array_len, 0);
          }
       }
 
@@ -519,7 +593,7 @@ split_var_list_arrays(nir_shader *shader,
          exec_node_remove(&var->node);
          exec_list_push_tail(&split_vars, &var->node);
       } else {
-         assert(split_type == var->type);
+         assert(split_type == glsl_get_bare_type(var->type));
          /* If we're not modifying this variable, delete the info so we skip
           * it faster in later passes.
           */
@@ -527,7 +601,7 @@ split_var_list_arrays(nir_shader *shader,
       }
    }
 
-   nir_foreach_variable(var, &split_vars) {
+   nir_foreach_variable_in_list(var, &split_vars) {
       struct array_var_info *info = get_array_var_info(var, var_info_map);
       create_split_array_vars(info, 0, &info->root_split, var->name,
                               shader, impl, mem_ctx);
@@ -566,8 +640,8 @@ array_path_is_out_of_bounds(nir_deref_path *path,
       if (p->deref_type == nir_deref_type_array_wildcard)
          continue;
 
-      nir_const_value *const_index = nir_src_as_const_value(p->arr.index);
-      if (const_index && const_index->u32[0] >= info->levels[i].array_len)
+      if (nir_src_is_const(p->arr.index) &&
+          nir_src_as_uint(p->arr.index) >= info->levels[i].array_len)
          return true;
    }
 
@@ -615,11 +689,10 @@ emit_split_copies(nir_builder *b,
                 glsl_get_length(src_path->path[src_level]->type));
          unsigned len = glsl_get_length(dst_path->path[dst_level]->type);
          for (unsigned i = 0; i < len; i++) {
-            nir_ssa_def *idx = nir_imm_int(b, i);
             emit_split_copies(b, dst_info, dst_path, dst_level + 1,
-                              nir_build_deref_array(b, dst, idx),
+                              nir_build_deref_array_imm(b, dst, i),
                               src_info, src_path, src_level + 1,
-                              nir_build_deref_array(b, src, idx));
+                              nir_build_deref_array_imm(b, src, i));
          }
       } else {
          /* Neither side is being split so we just keep going */
@@ -751,7 +824,7 @@ split_array_access_impl(nir_function_impl *impl,
             for (unsigned i = 0; i < info->num_levels; i++) {
                if (info->levels[i].split) {
                   nir_deref_instr *p = path.path[i + 1];
-                  unsigned index = nir_src_as_const_value(p->arr.index)->u32[0];
+                  unsigned index = nir_src_as_uint(p->arr.index);
                   assert(index < info->levels[i].array_len);
                   split = &split->splits[index];
                }
@@ -791,16 +864,19 @@ bool
 nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
 {
    void *mem_ctx = ralloc_context(NULL);
-   struct hash_table *var_info_map =
-      _mesa_hash_table_create(mem_ctx, _mesa_hash_pointer,
-                              _mesa_key_pointer_equal);
+   struct hash_table *var_info_map = _mesa_pointer_hash_table_create(mem_ctx);
+   struct set *complex_vars = NULL;
 
-   assert((modes & (nir_var_global | nir_var_local)) == modes);
+   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
 
    bool has_global_array = false;
-   if (modes & nir_var_global) {
-      has_global_array = init_var_list_array_infos(&shader->globals,
-                                                   var_info_map, mem_ctx);
+   if (modes & nir_var_shader_temp) {
+      has_global_array = init_var_list_array_infos(shader,
+                                                   &shader->variables,
+                                                   nir_var_shader_temp,
+                                                   var_info_map,
+                                                   &complex_vars,
+                                                   mem_ctx);
    }
 
    bool has_any_array = false;
@@ -809,9 +885,13 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
          continue;
 
       bool has_local_array = false;
-      if (modes & nir_var_local) {
-         has_local_array = init_var_list_array_infos(&function->impl->locals,
-                                                     var_info_map, mem_ctx);
+      if (modes & nir_var_function_temp) {
+         has_local_array = init_var_list_array_infos(shader,
+                                                     &function->impl->locals,
+                                                     nir_var_function_temp,
+                                                     var_info_map,
+                                                     &complex_vars,
+                                                     mem_ctx);
       }
 
       if (has_global_array || has_local_array) {
@@ -823,13 +903,15 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
    /* If we failed to find any arrays of arrays, bail early. */
    if (!has_any_array) {
       ralloc_free(mem_ctx);
+      nir_shader_preserve_all_metadata(shader);
       return false;
    }
 
    bool has_global_splits = false;
-   if (modes & nir_var_global) {
+   if (modes & nir_var_shader_temp) {
       has_global_splits = split_var_list_arrays(shader, NULL,
-                                                &shader->globals,
+                                                &shader->variables,
+                                                nir_var_shader_temp,
                                                 var_info_map, mem_ctx);
    }
 
@@ -839,9 +921,10 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
          continue;
 
       bool has_local_splits = false;
-      if (modes & nir_var_local) {
+      if (modes & nir_var_function_temp) {
          has_local_splits = split_var_list_arrays(shader, function->impl,
                                                   &function->impl->locals,
+                                                  nir_var_function_temp,
                                                   var_info_map, mem_ctx);
       }
 
@@ -852,6 +935,8 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
          nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                                nir_metadata_dominance);
          progress = true;
+      } else {
+         nir_metadata_preserve(function->impl, nir_metadata_all);
       }
    }
 
@@ -883,6 +968,7 @@ struct vec_var_usage {
 
    /* True if there is a copy that isn't to/from a shrinkable vector */
    bool has_external_copy;
+   bool has_complex_use;
    struct set *vars_copied;
 
    unsigned num_levels;
@@ -942,6 +1028,32 @@ get_vec_deref_usage(nir_deref_instr *deref,
                             var_usage_map, add_usage_entry, mem_ctx);
 }
 
+/* Sets has_complex_use on the usage record of the variable behind deref.
+ *
+ * This only happens when deref's mode intersects modes, deref is a var
+ * deref, nir_deref_instr_has_complex_use() reports true for it, and
+ * get_vec_var_usage() (called with add_usage_entry = true) returns a
+ * record.  In every other case the function does nothing.
+ */
+static void
+mark_deref_if_complex(nir_deref_instr *deref,
+                      struct hash_table *var_usage_map,
+                      nir_variable_mode modes,
+                      void *mem_ctx)
+{
+   /* nir_deref_instr_has_complex_use is recursive, so only var derefs need
+    * to be examined.
+    */
+   if ((deref->mode & modes) &&
+       deref->deref_type == nir_deref_type_var &&
+       nir_deref_instr_has_complex_use(deref)) {
+      struct vec_var_usage *var_usage =
+         get_vec_var_usage(deref->var, var_usage_map, true, mem_ctx);
+
+      if (var_usage)
+         var_usage->has_complex_use = true;
+   }
+}
+
 static void
 mark_deref_used(nir_deref_instr *deref,
                 nir_component_mask_t comps_read,
@@ -955,6 +1067,8 @@ mark_deref_used(nir_deref_instr *deref,
       return;
 
    nir_variable *var = nir_deref_instr_get_variable(deref);
+   if (var == NULL)
+      return;
 
    struct vec_var_usage *usage =
       get_vec_var_usage(var, var_usage_map, true, mem_ctx);
@@ -970,8 +1084,7 @@ mark_deref_used(nir_deref_instr *deref,
                                        true, mem_ctx);
       if (copy_usage) {
          if (usage->vars_copied == NULL) {
-            usage->vars_copied = _mesa_set_create(mem_ctx, _mesa_hash_pointer,
-                                                  _mesa_key_pointer_equal);
+            usage->vars_copied = _mesa_pointer_set_create(mem_ctx);
          }
          _mesa_set_add(usage->vars_copied, copy_usage);
       } else {
@@ -995,9 +1108,8 @@ mark_deref_used(nir_deref_instr *deref,
 
       unsigned max_used;
       if (deref->deref_type == nir_deref_type_array) {
-         nir_const_value *const_index =
-            nir_src_as_const_value(deref->arr.index);
-         max_used = const_index ? const_index->u32[0] : UINT_MAX;
+         max_used = nir_src_is_const(deref->arr.index) ?
+                    nir_src_as_uint(deref->arr.index) : UINT_MAX;
       } else {
          /* For wildcards, we read or wrote the whole thing. */
          assert(deref->deref_type == nir_deref_type_array_wildcard);
@@ -1014,9 +1126,7 @@ mark_deref_used(nir_deref_instr *deref,
                &copy_usage->levels[copy_i++];
 
             if (level->levels_copied == NULL) {
-               level->levels_copied =
-                  _mesa_set_create(mem_ctx, _mesa_hash_pointer,
-                                   _mesa_key_pointer_equal);
+               level->levels_copied = _mesa_pointer_set_create(mem_ctx);
             }
             _mesa_set_add(level->levels_copied, copy_level);
          } else {
@@ -1037,14 +1147,8 @@ mark_deref_used(nir_deref_instr *deref,
 static bool
 src_is_load_deref(nir_src src, nir_src deref_src)
 {
-   assert(src.is_ssa);
-   assert(deref_src.is_ssa);
-
-   if (src.ssa->parent_instr->type != nir_instr_type_intrinsic)
-      return false;
-
-   nir_intrinsic_instr *load = nir_instr_as_intrinsic(src.ssa->parent_instr);
-   if (load->intrinsic != nir_intrinsic_load_deref)
+   nir_intrinsic_instr *load = nir_src_as_intrinsic(src);
+   if (load == NULL || load->intrinsic != nir_intrinsic_load_deref)
       return false;
 
    assert(load->src[0].is_ssa);
@@ -1075,8 +1179,7 @@ get_non_self_referential_store_comps(nir_intrinsic_instr *store)
 
    nir_alu_instr *src_alu = nir_instr_as_alu(src_instr);
 
-   if (src_alu->op == nir_op_imov ||
-       src_alu->op == nir_op_fmov) {
+   if (src_alu->op == nir_op_mov) {
       /* If it's just a swizzle of a load from the same deref, discount any
        * channels that don't move in the swizzle.
        */
@@ -1086,9 +1189,7 @@ get_non_self_referential_store_comps(nir_intrinsic_instr *store)
                comps &= ~(1u << i);
          }
       }
-   } else if (src_alu->op == nir_op_vec2 ||
-              src_alu->op == nir_op_vec3 ||
-              src_alu->op == nir_op_vec4) {
+   } else if (nir_op_is_vec(src_alu->op)) {
       /* If it's a vec, discount any channels that are just loads from the
        * same deref put in the same spot.
        */
@@ -1110,6 +1211,11 @@ find_used_components_impl(nir_function_impl *impl,
 {
    nir_foreach_block(block, impl) {
       nir_foreach_instr(instr, block) {
+         if (instr->type == nir_instr_type_deref) {
+            mark_deref_if_complex(nir_instr_as_deref(instr),
+                                  var_usage_map, modes, mem_ctx);
+         }
+
          if (instr->type != nir_instr_type_intrinsic)
             continue;
 
@@ -1145,6 +1251,7 @@ find_used_components_impl(nir_function_impl *impl,
 
 static bool
 shrink_vec_var_list(struct exec_list *vars,
+                    nir_variable_mode mode,
                     struct hash_table *var_usage_map)
 {
    /* Initialize the components kept field of each variable.  This is the
@@ -1163,14 +1270,17 @@ shrink_vec_var_list(struct exec_list *vars,
     * Also, if we have a copy that to/from something we can't shrink, we need
     * to leave components and array_len of any wildcards alone.
     */
-   nir_foreach_variable(var, vars) {
+   nir_foreach_variable_in_list(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       struct vec_var_usage *usage =
          get_vec_var_usage(var, var_usage_map, false, NULL);
       if (!usage)
          continue;
 
       assert(usage->comps_kept == 0);
-      if (usage->has_external_copy)
+      if (usage->has_external_copy || usage->has_complex_use)
          usage->comps_kept = usage->all_comps;
       else
          usage->comps_kept = usage->comps_read & usage->comps_written;
@@ -1179,7 +1289,8 @@ shrink_vec_var_list(struct exec_list *vars,
          struct array_level_usage *level = &usage->levels[i];
          assert(level->array_len > 0);
 
-         if (level->max_written == UINT_MAX || level->has_external_copy)
+         if (level->max_written == UINT_MAX || level->has_external_copy ||
+             usage->has_complex_use)
             continue; /* Can't shrink */
 
          unsigned max_used = MIN2(level->max_read, level->max_written);
@@ -1195,13 +1306,15 @@ shrink_vec_var_list(struct exec_list *vars,
    bool fp_progress;
    do {
       fp_progress = false;
-      nir_foreach_variable(var, vars) {
+      nir_foreach_variable_in_list(var, vars) {
+         if (var->data.mode != mode)
+            continue;
+
          struct vec_var_usage *var_usage =
             get_vec_var_usage(var, var_usage_map, false, NULL);
          if (!var_usage || !var_usage->vars_copied)
             continue;
 
-         struct set_entry *copy_entry;
          set_foreach(var_usage->vars_copied, copy_entry) {
             struct vec_var_usage *copy_usage = (void *)copy_entry->key;
             if (copy_usage->comps_kept != var_usage->comps_kept) {
@@ -1233,7 +1346,10 @@ shrink_vec_var_list(struct exec_list *vars,
    } while (fp_progress);
 
    bool vars_shrunk = false;
-   nir_foreach_variable_safe(var, vars) {
+   nir_foreach_variable_in_list_safe(var, vars) {
+      if (var->data.mode != mode)
+         continue;
+
       struct vec_var_usage *usage =
          get_vec_var_usage(var, var_usage_map, false, NULL);
       if (!usage)
@@ -1277,7 +1393,7 @@ shrink_vec_var_list(struct exec_list *vars,
       }
 
       /* Build the new var type */
-      unsigned new_num_comps = _mesa_bitcount(usage->comps_kept);
+      unsigned new_num_comps = util_bitcount(usage->comps_kept);
       const struct glsl_type *new_type =
          glsl_vector_type(glsl_get_base_type(vec_type), new_num_comps);
       for (int i = usage->num_levels - 1; i >= 0; i--) {
@@ -1292,7 +1408,7 @@ shrink_vec_var_list(struct exec_list *vars,
                                         new_num_comps,
                                         usage->levels[i].array_len);
          } else {
-            new_type = glsl_array_type(new_type, usage->levels[i].array_len);
+            new_type = glsl_array_type(new_type, usage->levels[i].array_len, 0);
          }
       }
       var->type = new_type;
@@ -1316,8 +1432,8 @@ vec_deref_is_oob(nir_deref_instr *deref,
       if (p->deref_type == nir_deref_type_array_wildcard)
          continue;
 
-      nir_const_value *const_index = nir_src_as_const_value(p->arr.index);
-      if (const_index && const_index->u32[0] >= usage->levels[i].array_len) {
+      if (nir_src_is_const(p->arr.index) &&
+          nir_src_as_uint(p->arr.index) >= usage->levels[i].array_len) {
          oob = true;
          break;
       }
@@ -1428,6 +1544,12 @@ shrink_vec_var_access_impl(nir_function_impl *impl,
                continue;
             }
 
+            /* If we're not dropping any components, there's no need to
+             * compact vectors.
+             */
+            if (usage->comps_kept == usage->all_comps)
+               continue;
+
             if (intrin->intrinsic == nir_intrinsic_load_deref) {
                b.cursor = nir_after_instr(&intrin->instr);
 
@@ -1472,7 +1594,7 @@ shrink_vec_var_access_impl(nir_function_impl *impl,
                b.cursor = nir_before_instr(&intrin->instr);
 
                nir_ssa_def *swizzled =
-                  nir_swizzle(&b, intrin->src[1].ssa, swizzle, c, false);
+                  nir_swizzle(&b, intrin->src[1].ssa, swizzle, c);
 
                /* Rewrite to use the compacted source */
                nir_instr_rewrite_src(&intrin->instr, &intrin->src[1],
@@ -1496,10 +1618,13 @@ function_impl_has_vars_with_modes(nir_function_impl *impl,
 {
    nir_shader *shader = impl->function->shader;
 
-   if ((modes & nir_var_global) && !exec_list_is_empty(&shader->globals))
-      return true;
+   if (modes & ~nir_var_function_temp) {
+      nir_foreach_variable_with_modes(var, shader,
+                                      modes & ~nir_var_function_temp)
+         return true;
+   }
 
-   if ((modes & nir_var_local) && !exec_list_is_empty(&impl->locals))
+   if ((modes & nir_var_function_temp) && !exec_list_is_empty(&impl->locals))
       return true;
 
    return false;
@@ -1517,13 +1642,12 @@ function_impl_has_vars_with_modes(nir_function_impl *impl,
 bool
 nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
 {
-   assert((modes & (nir_var_global | nir_var_local)) == modes);
+   assert((modes & (nir_var_shader_temp | nir_var_function_temp)) == modes);
 
    void *mem_ctx = ralloc_context(NULL);
 
    struct hash_table *var_usage_map =
-      _mesa_hash_table_create(mem_ctx, _mesa_hash_pointer,
-                              _mesa_key_pointer_equal);
+      _mesa_pointer_hash_table_create(mem_ctx);
 
    bool has_vars_to_shrink = false;
    nir_foreach_function(function, shader) {
@@ -1542,12 +1666,16 @@ nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
    }
    if (!has_vars_to_shrink) {
       ralloc_free(mem_ctx);
+      nir_shader_preserve_all_metadata(shader);
       return false;
    }
 
    bool globals_shrunk = false;
-   if (modes & nir_var_global)
-      globals_shrunk = shrink_vec_var_list(&shader->globals, var_usage_map);
+   if (modes & nir_var_shader_temp) {
+      globals_shrunk = shrink_vec_var_list(&shader->variables,
+                                           nir_var_shader_temp,
+                                           var_usage_map);
+   }
 
    bool progress = false;
    nir_foreach_function(function, shader) {
@@ -1555,8 +1683,9 @@ nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
          continue;
 
       bool locals_shrunk = false;
-      if (modes & nir_var_local) {
+      if (modes & nir_var_function_temp) {
          locals_shrunk = shrink_vec_var_list(&function->impl->locals,
+                                             nir_var_function_temp,
                                              var_usage_map);
       }
 
@@ -1566,6 +1695,8 @@ nir_shrink_vec_array_vars(nir_shader *shader, nir_variable_mode modes)
          nir_metadata_preserve(function->impl, nir_metadata_block_index |
                                                nir_metadata_dominance);
          progress = true;
+      } else {
+         nir_metadata_preserve(function->impl, nir_metadata_all);
       }
    }