nir/dead_variables: Respect the modes passed to remove_dead_vars
[mesa.git] / src / compiler / nir / nir_remove_dead_variables.c
index 6baa66a57670d127b02cd0e3f62f40e55a735a3e..dfeaa24959087987fc883135fa5131027bd2ef56 100644 (file)
@@ -143,12 +143,15 @@ remove_dead_var_writes(nir_shader *shader, struct set *live)
 }
 
 static bool
-remove_dead_vars(struct exec_list *var_list, struct set *live,
-                 bool (*can_remove_var)(nir_variable *var))
+remove_dead_vars(struct exec_list *var_list, nir_variable_mode modes,
+                 struct set *live, bool (*can_remove_var)(nir_variable *var))
 {
    bool progress = false;
 
-   foreach_list_typed_safe(nir_variable, var, node, var_list) {
+   nir_foreach_variable_safe(var, var_list) {
+      if (!(var->data.mode & modes))
+         continue;
+
       if (can_remove_var && !can_remove_var(var))
          continue;
 
@@ -174,40 +177,41 @@ nir_remove_dead_variables(nir_shader *shader, nir_variable_mode modes,
    add_var_use_shader(shader, live, modes);
 
    if (modes & nir_var_uniform) {
-      progress = remove_dead_vars(&shader->uniforms, live, can_remove_var) ||
+      progress = remove_dead_vars(&shader->uniforms, modes, live, can_remove_var) ||
          progress;
    }
 
    if (modes & nir_var_shader_in) {
-      progress = remove_dead_vars(&shader->inputs, live, can_remove_var) ||
+      progress = remove_dead_vars(&shader->inputs, modes, live, can_remove_var) ||
          progress;
    }
 
    if (modes & nir_var_shader_out) {
-      progress = remove_dead_vars(&shader->outputs, live, can_remove_var) ||
+      progress = remove_dead_vars(&shader->outputs, modes, live, can_remove_var) ||
          progress;
    }
 
    if (modes & nir_var_shader_temp) {
-      progress = remove_dead_vars(&shader->globals, live, can_remove_var) ||
+      progress = remove_dead_vars(&shader->globals, modes, live, can_remove_var) ||
          progress;
    }
 
    if (modes & nir_var_system_value) {
-      progress = remove_dead_vars(&shader->system_values, live,
+      progress = remove_dead_vars(&shader->system_values, modes, live,
                                   can_remove_var) || progress;
    }
 
    if (modes & nir_var_mem_shared) {
-      progress = remove_dead_vars(&shader->shared, live, can_remove_var) ||
+      progress = remove_dead_vars(&shader->shared, modes, live, can_remove_var) ||
          progress;
    }
 
    if (modes & nir_var_function_temp) {
       nir_foreach_function(function, shader) {
          if (function->impl) {
-            if (remove_dead_vars(&function->impl->locals, live,
-                                 can_remove_var))
+            if (remove_dead_vars(&function->impl->locals,
+                                 nir_var_function_temp,
+                                 live, can_remove_var))
                progress = true;
          }
       }