nir: Relax opt_if logic to prevent re-merging 64bit phis for loop headers
[mesa.git] / src / compiler / nir / nir_lower_returns.c
index 3ea69e25204855b1cc79157d0083ab8b4353e2e1..56c7656aeafc7412fc7c41781087839c9eb166ca 100644 (file)
@@ -37,6 +37,8 @@ struct lower_returns_state {
     * needs to be predicated on the return flag variable.
     */
    bool has_predicated_return;
+
+   bool removed_unreachable_code;
 };
 
 static bool lower_returns_in_cf_list(struct exec_list *cf_list,
@@ -48,7 +50,7 @@ predicate_following(nir_cf_node *node, struct lower_returns_state *state)
    nir_builder *b = &state->builder;
    b->cursor = nir_after_cf_node_and_phis(node);
 
-   if (nir_cursors_equal(b->cursor, nir_after_cf_list(state->cf_list)))
+   if (!state->loop && nir_cursors_equal(b->cursor, nir_after_cf_list(state->cf_list)))
       return; /* Nothing to predicate */
 
    assert(state->return_flag);
@@ -162,8 +164,9 @@ lower_returns_in_block(nir_block *block, struct lower_returns_state *state)
           */
          return false;
       } else {
+         state->removed_unreachable_code = true;
          nir_cf_delete(&list);
-         return true;
+         return false;
       }
    }
 
@@ -195,11 +198,11 @@ lower_returns_in_block(nir_block *block, struct lower_returns_state *state)
 
       /* Initialize the variable to 0 */
       b->cursor = nir_before_cf_list(&b->impl->body);
-      nir_store_var(b, state->return_flag, nir_imm_int(b, NIR_FALSE), 1);
+      nir_store_var(b, state->return_flag, nir_imm_false(b), 1);
    }
 
    b->cursor = nir_after_block(block);
-   nir_store_var(b, state->return_flag, nir_imm_int(b, NIR_TRUE), 1);
+   nir_store_var(b, state->return_flag, nir_imm_true(b), 1);
 
    if (state->loop) {
       /* We're in a loop;  we need to break out of it. */
@@ -262,13 +265,17 @@ nir_lower_returns_impl(nir_function_impl *impl)
    state.loop = NULL;
    state.return_flag = NULL;
    state.has_predicated_return = false;
+   state.removed_unreachable_code = false;
    nir_builder_init(&state.builder, impl);
 
    bool progress = lower_returns_in_cf_list(&impl->body, &state);
+   progress = progress || state.removed_unreachable_code;
 
    if (progress) {
       nir_metadata_preserve(impl, nir_metadata_none);
       nir_repair_ssa_impl(impl);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
    }
 
    return progress;