nir: evaluate if condition uses inside the if branches
[mesa.git] / src / compiler / nir / nir_opt_if.c
index dacf2d6c667a5a2733747815aecb3196dd36d37d..512fd92575ba5eaffee4f60eec1b93978a3c6491 100644 (file)
@@ -369,6 +369,73 @@ opt_if_loop_terminator(nir_if *nif)
    return true;
 }
 
+static void
+replace_if_condition_use_with_const(nir_builder *b, nir_src *use,
+                                    nir_const_value nir_boolean,
+                                    bool if_condition)
+{
+   /* Create const */
+   nir_ssa_def *const_def = nir_build_imm(b, 1, 32, nir_boolean);
+
+   /* Rewrite use to use const */
+   nir_src new_src = nir_src_for_ssa(const_def);
+   if (if_condition)
+      nir_if_rewrite_condition(use->parent_if, new_src);
+   else
+      nir_instr_rewrite_src(use->parent_instr, use, new_src);
+}
+
+static bool
+evaluate_if_condition(nir_if *nif, nir_cursor cursor, uint32_t *value)
+{
+   nir_block *use_block = nir_cursor_current_block(cursor);
+   if (nir_block_dominates(nir_if_first_then_block(nif), use_block)) {
+      *value = NIR_TRUE;
+      return true;
+   } else if (nir_block_dominates(nir_if_first_else_block(nif), use_block)) {
+      *value = NIR_FALSE;
+      return true;
+   } else {
+      return false;
+   }
+}
+
+static bool
+evaluate_condition_use(nir_builder *b, nir_if *nif, nir_src *use_src,
+                       bool is_if_condition)
+{
+   bool progress = false;
+
+   nir_const_value value;
+   b->cursor = nir_before_src(use_src, is_if_condition);
+
+   if (evaluate_if_condition(nif, b->cursor, &value.u32[0])) {
+      replace_if_condition_use_with_const(b, use_src, value, is_if_condition);
+      progress = true;
+   }
+
+   return progress;
+}
+
+static bool
+opt_if_evaluate_condition_use(nir_builder *b, nir_if *nif)
+{
+   bool progress = false;
+
+   /* Evaluate any uses of the if condition inside the if branches */
+   assert(nif->condition.is_ssa);
+   nir_foreach_use_safe(use_src, nif->condition.ssa) {
+      progress |= evaluate_condition_use(b, nif, use_src, false);
+   }
+
+   nir_foreach_if_use_safe(use_src, nif->condition.ssa) {
+      if (use_src->parent_if != nif)
+         progress |= evaluate_condition_use(b, nif, use_src, true);
+   }
+
+   return progress;
+}
+
 static bool
 opt_if_cf_list(nir_builder *b, struct exec_list *cf_list)
 {
@@ -402,6 +469,41 @@ opt_if_cf_list(nir_builder *b, struct exec_list *cf_list)
    return progress;
 }
 
+/**
+ * These optimisations depend on nir_metadata_block_index and therefore must
+ * not do anything to cause the metadata to become invalid.
+ */
+static bool
+opt_if_safe_cf_list(nir_builder *b, struct exec_list *cf_list)
+{
+   bool progress = false;
+   foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
+      switch (cf_node->type) {
+      case nir_cf_node_block:
+         break;
+
+      case nir_cf_node_if: {
+         nir_if *nif = nir_cf_node_as_if(cf_node);
+         progress |= opt_if_safe_cf_list(b, &nif->then_list);
+         progress |= opt_if_safe_cf_list(b, &nif->else_list);
+         progress |= opt_if_evaluate_condition_use(b, nif);
+         break;
+      }
+
+      case nir_cf_node_loop: {
+         nir_loop *loop = nir_cf_node_as_loop(cf_node);
+         progress |= opt_if_safe_cf_list(b, &loop->body);
+         break;
+      }
+
+      case nir_cf_node_function:
+         unreachable("Invalid cf type");
+      }
+   }
+
+   return progress;
+}
+
 bool
 nir_opt_if(nir_shader *shader)
 {
@@ -414,6 +516,12 @@ nir_opt_if(nir_shader *shader)
       nir_builder b;
       nir_builder_init(&b, function->impl);
 
+      nir_metadata_require(function->impl, nir_metadata_block_index |
+                           nir_metadata_dominance);
+      progress = opt_if_safe_cf_list(&b, &function->impl->body);
+      nir_metadata_preserve(function->impl, nir_metadata_block_index |
+                            nir_metadata_dominance);
+
       if (opt_if_cf_list(&b, &function->impl->body)) {
          nir_metadata_preserve(function->impl, nir_metadata_none);