nir/algebraic: mark some optimizations with fsat(NaN) as inexact
[mesa.git] / src / compiler / nir / nir_lower_returns.c
index 91bb2f7dfeb1c043d9fbc81f722a7dff35c118d6..56c7656aeafc7412fc7c41781087839c9eb166ca 100644 (file)
@@ -30,6 +30,15 @@ struct lower_returns_state {
    struct exec_list *cf_list;
    nir_loop *loop;
    nir_variable *return_flag;
+
+   /* This indicates that we have a return which is predicated on some form of
+    * control-flow.  Since whether or not the return happens can only be
+    * determined dynamically at run-time, everything that occurs afterwards
+    * needs to be predicated on the return flag variable.
+    */
+   bool has_predicated_return;
+
+   bool removed_unreachable_code;
 };
 
 static bool lower_returns_in_cf_list(struct exec_list *cf_list,
@@ -41,7 +50,7 @@ predicate_following(nir_cf_node *node, struct lower_returns_state *state)
    nir_builder *b = &state->builder;
    b->cursor = nir_after_cf_node_and_phis(node);
 
-   if (nir_cursors_equal(b->cursor, nir_after_cf_list(state->cf_list)))
+   if (!state->loop && nir_cursors_equal(b->cursor, nir_after_cf_list(state->cf_list)))
       return; /* Nothing to predicate */
 
    assert(state->return_flag);
@@ -82,8 +91,10 @@ lower_returns_in_loop(nir_loop *loop, struct lower_returns_state *state)
     * flag set to true.  We need to predicate everything following the loop
     * on the return flag.
     */
-   if (progress)
+   if (progress) {
       predicate_following(&loop->cf_node, state);
+      state->has_predicated_return = true;
+   }
 
    return progress;
 }
@@ -91,23 +102,48 @@ lower_returns_in_loop(nir_loop *loop, struct lower_returns_state *state)
 static bool
 lower_returns_in_if(nir_if *if_stmt, struct lower_returns_state *state)
 {
-   bool progress;
+   bool progress, then_progress, else_progress;
 
-   progress = lower_returns_in_cf_list(&if_stmt->then_list, state);
-   progress = lower_returns_in_cf_list(&if_stmt->else_list, state) || progress;
+   bool has_predicated_return = state->has_predicated_return;
+   state->has_predicated_return = false;
+
+   then_progress = lower_returns_in_cf_list(&if_stmt->then_list, state);
+   else_progress = lower_returns_in_cf_list(&if_stmt->else_list, state);
+   progress = then_progress || else_progress;
 
    /* If either of the recursive calls made progress, then there were
     * returns inside of the body of the if.  If we're in a loop, then these
     * were lowered to breaks which automatically skip to the end of the
     * loop so we don't have to do anything.  If we're not in a loop, then
-    * all we know is that the return flag is set appropreately and that the
+    * all we know is that the return flag is set appropriately and that the
     * recursive calls ensured that nothing gets executed *inside* the if
     * after a return.  In order to ensure nothing outside gets executed
     * after a return, we need to predicate everything following on the
     * return flag.
     */
-   if (progress && !state->loop)
-      predicate_following(&if_stmt->cf_node, state);
+   if (progress && !state->loop) {
+      if (state->has_predicated_return) {
+         predicate_following(&if_stmt->cf_node, state);
+      } else {
+         /* If there are no nested returns we can just add the instructions to
+          * the end of the branch that doesn't have the return.
+          */
+         nir_cf_list list;
+         nir_cf_extract(&list, nir_after_cf_node(&if_stmt->cf_node),
+                        nir_after_cf_list(state->cf_list));
+
+         if (then_progress && else_progress) {
+            /* Both branches return so delete instructions following the if */
+            nir_cf_delete(&list);
+         } else if (then_progress) {
+            nir_cf_reinsert(&list, nir_after_cf_list(&if_stmt->else_list));
+         } else {
+            nir_cf_reinsert(&list, nir_after_cf_list(&if_stmt->then_list));
+         }
+      }
+   }
+
+   state->has_predicated_return = progress || has_predicated_return;
 
    return progress;
 }
@@ -128,8 +164,9 @@ lower_returns_in_block(nir_block *block, struct lower_returns_state *state)
           */
          return false;
       } else {
+         state->removed_unreachable_code = true;
          nir_cf_delete(&list);
-         return true;
+         return false;
       }
    }
 
@@ -146,19 +183,26 @@ lower_returns_in_block(nir_block *block, struct lower_returns_state *state)
 
    nir_instr_remove(&jump->instr);
 
+   /* If this is a return in the last block of the function there is nothing
+    * more to do once it's removed.
+    */
+   if (block == nir_impl_last_block(state->builder.impl))
+      return true;
+
    nir_builder *b = &state->builder;
-   b->cursor = nir_after_block(block);
 
    /* Set the return flag */
    if (state->return_flag == NULL) {
       state->return_flag =
          nir_local_variable_create(b->impl, glsl_bool_type(), "return");
 
-      /* Set a default value of false */
-      state->return_flag->constant_initializer =
-         rzalloc(state->return_flag, nir_constant);
+      /* Initialize the return flag to false */
+      b->cursor = nir_before_cf_list(&b->impl->body);
+      nir_store_var(b, state->return_flag, nir_imm_false(b), 1);
    }
-   nir_store_var(b, state->return_flag, nir_imm_int(b, NIR_TRUE), 1);
+
+   b->cursor = nir_after_block(block);
+   nir_store_var(b, state->return_flag, nir_imm_true(b), 1);
 
    if (state->loop) {
       /* We're in a loop;  we need to break out of it. */
@@ -220,13 +264,18 @@ nir_lower_returns_impl(nir_function_impl *impl)
    state.cf_list = &impl->body;
    state.loop = NULL;
    state.return_flag = NULL;
+   state.has_predicated_return = false;
+   state.removed_unreachable_code = false;
    nir_builder_init(&state.builder, impl);
 
    bool progress = lower_returns_in_cf_list(&impl->body, &state);
+   progress = progress || state.removed_unreachable_code;
 
    if (progress) {
       nir_metadata_preserve(impl, nir_metadata_none);
       nir_repair_ssa_impl(impl);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
    }
 
    return progress;
@@ -237,7 +286,7 @@ nir_lower_returns(nir_shader *shader)
 {
    bool progress = false;
 
-   nir_foreach_function(shader, function) {
+   nir_foreach_function(function, shader) {
       if (function->impl)
          progress = nir_lower_returns_impl(function->impl) || progress;
    }