nir/lower_indirect_derefs: Add a threshold
[mesa.git] / src / compiler / nir / nir_lower_indirect_derefs.c
index d9dcba842988af15a1ba8ebd89c9df0a0031f410..08dbcb1f0d0dc0e9790cd893920f0e9687ae3e83 100644 (file)
@@ -113,7 +113,8 @@ emit_load_store_deref(nir_builder *b, nir_intrinsic_instr *orig_instr,
 
 static bool
 lower_indirect_derefs_block(nir_block *block, nir_builder *b,
-                            nir_variable_mode modes)
+                            nir_variable_mode modes,
+                            uint32_t max_lower_array_len)
 {
    bool progress = false;
 
@@ -133,17 +134,21 @@ lower_indirect_derefs_block(nir_block *block, nir_builder *b,
       nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
 
       /* Walk the deref chain back to the base and look for indirects */
+      uint32_t indirect_array_len = 1;
       bool has_indirect = false;
       nir_deref_instr *base = deref;
       while (base && base->deref_type != nir_deref_type_var) {
+         nir_deref_instr *parent = nir_deref_instr_parent(base);
          if (base->deref_type == nir_deref_type_array &&
-             !nir_src_is_const(base->arr.index))
+             !nir_src_is_const(base->arr.index)) {
+            indirect_array_len *= glsl_get_length(parent->type);
             has_indirect = true;
+         }
 
-         base = nir_deref_instr_parent(base);
+         base = parent;
       }
 
-      if (!has_indirect || !base)
+      if (!has_indirect || !base || indirect_array_len > max_lower_array_len)
          continue;
 
       /* Only lower variables whose mode is in the mask, or compact
@@ -179,14 +184,16 @@ lower_indirect_derefs_block(nir_block *block, nir_builder *b,
 }
 
 static bool
-lower_indirects_impl(nir_function_impl *impl, nir_variable_mode modes)
+lower_indirects_impl(nir_function_impl *impl, nir_variable_mode modes,
+                     uint32_t max_lower_array_len)
 {
    nir_builder builder;
    nir_builder_init(&builder, impl);
    bool progress = false;
 
    nir_foreach_block_safe(block, impl) {
-      progress |= lower_indirect_derefs_block(block, &builder, modes);
+      progress |= lower_indirect_derefs_block(block, &builder, modes,
+                                              max_lower_array_len);
    }
 
    if (progress)
@@ -203,13 +210,16 @@ lower_indirects_impl(nir_function_impl *impl, nir_variable_mode modes)
  * that does a binary search on the array index.
  */
 bool
-nir_lower_indirect_derefs(nir_shader *shader, nir_variable_mode modes)
+nir_lower_indirect_derefs(nir_shader *shader, nir_variable_mode modes,
+                          uint32_t max_lower_array_len)
 {
    bool progress = false;
 
    nir_foreach_function(function, shader) {
-      if (function->impl)
-         progress = lower_indirects_impl(function->impl, modes) || progress;
+      if (function->impl) {
+         progress = lower_indirects_impl(function->impl, modes,
+                                         max_lower_array_len) || progress;
+      }
    }
 
    return progress;