nir/lower_returns: Better algorithm as per connor
authorJason Ekstrand <jason.ekstrand@intel.com>
Mon, 28 Dec 2015 06:50:45 +0000 (22:50 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Mon, 28 Dec 2015 06:50:45 +0000 (22:50 -0800)
src/glsl/nir/nir_lower_returns.c

index f36fc9dd6138cffea23928fc3f8ea5ba3cf3d86c..ce0512c770aa3246bb4d097daab445e99b834933 100644 (file)
 
 struct lower_returns_state {
    nir_builder builder;
-   struct exec_list *parent_cf_list;
    struct exec_list *cf_list;
    nir_loop *loop;
-   nir_if *if_stmt;
    nir_variable *return_flag;
 };
 
 static bool lower_returns_in_cf_list(struct exec_list *cf_list,
                                      struct lower_returns_state *state);
 
+static void
+predicate_following(nir_cf_node *node, struct lower_returns_state *state)
+{
+   nir_builder *b = &state->builder;
+   b->cursor = nir_after_cf_node_and_phis(node);
+
+   if (nir_cursors_equal(b->cursor, nir_after_cf_list(state->cf_list)))
+      return; /* Nothing to predicate */
+
+   assert(state->return_flag);
+
+   nir_if *if_stmt = nir_if_create(b->shader);
+   if_stmt->condition = nir_src_for_ssa(nir_load_var(b, state->return_flag));
+   nir_cf_node_insert(b->cursor, &if_stmt->cf_node);
+
+   if (state->loop) {
+      /* If we're inside of a loop, then all we need to do is insert a
+       * conditional break.
+       */
+      nir_jump_instr *brk =
+         nir_jump_instr_create(state->builder.shader, nir_jump_break);
+      nir_instr_insert(nir_before_cf_list(&if_stmt->then_list), &brk->instr);
+   } else {
+      /* Otherwise, we need to actually move everything into the else case
+       * of the if statement.
+       */
+      nir_cf_list list;
+      nir_cf_extract(&list, nir_after_cf_node(&if_stmt->cf_node),
+                            nir_after_cf_list(state->cf_list));
+      assert(!exec_list_is_empty(&list.list));
+      nir_cf_reinsert(&list, nir_before_cf_list(&if_stmt->else_list));
+   }
+}
+
 static bool
 lower_returns_in_loop(nir_loop *loop, struct lower_returns_state *state)
 {
@@ -45,34 +77,15 @@ lower_returns_in_loop(nir_loop *loop, struct lower_returns_state *state)
    bool progress = lower_returns_in_cf_list(&loop->body, state);
    state->loop = parent;
 
-   /* Nothing interesting */
-   if (!progress)
-      return false;
-
-   /* In this case, there was a return somewhere inside of the loop.  That
-    * return would have been turned into a write to the return_flag
-    * variable and a break.  We need to insert a predicated return right
-    * after the loop ends.
+   /* If the recursive call made progress, then there were returns inside
+    * of the loop.  These would have been lowered to breaks with the return
+    * flag set to true.  We need to predicate everything following the loop
+    * on the return flag.
     */
+   if (progress)
+      predicate_following(&loop->cf_node, state);
 
-   assert(state->return_flag);
-
-   nir_intrinsic_instr *load =
-      nir_intrinsic_instr_create(state->builder.shader, nir_intrinsic_load_var);
-   load->num_components = 1;
-   load->variables[0] = nir_deref_var_create(load, state->return_flag);
-   nir_ssa_dest_init(&load->instr, &load->dest, 1, "return");
-   nir_instr_insert(nir_after_cf_node(&loop->cf_node), &load->instr);
-
-   nir_if *if_stmt = nir_if_create(state->builder.shader);
-   if_stmt->condition = nir_src_for_ssa(&load->dest.ssa);
-   nir_cf_node_insert(nir_after_instr(&load->instr), &if_stmt->cf_node);
-
-   nir_jump_instr *ret =
-      nir_jump_instr_create(state->builder.shader, nir_jump_return);
-   nir_instr_insert(nir_before_cf_list(&if_stmt->then_list), &ret->instr);
-
-   return true;
+   return progress;
 }
 
 static bool
@@ -80,11 +93,21 @@ lower_returns_in_if(nir_if *if_stmt, struct lower_returns_state *state)
 {
    bool progress;
 
-   nir_if *parent = state->if_stmt;
-   state->if_stmt = if_stmt;
    progress = lower_returns_in_cf_list(&if_stmt->then_list, state);
    progress = lower_returns_in_cf_list(&if_stmt->else_list, state) || progress;
-   state->if_stmt = parent;
+
+   /* If either of the recursive calls made progress, then there were
+    * returns inside of the body of the if.  If we're in a loop, then these
+    * were lowered to breaks which automatically skip to the end of the
+    * loop so we don't have to do anything.  If we're not in a loop, then
+    * all we know is that the return flag is set appropreately and that the
+    * recursive calls ensured that nothing gets executed *inside* the if
+    * after a return.  In order to ensure nothing outside gets executed
+    * after a return, we need to predicate everything following on the
+    * return flag.
+    */
+   if (progress && !state->loop)
+      predicate_following(&if_stmt->cf_node, state);
 
    return progress;
 }
@@ -121,51 +144,29 @@ lower_returns_in_block(nir_block *block, struct lower_returns_state *state)
    if (jump->type != nir_jump_return)
       return false;
 
-   if (state->loop) {
-      /* We're in a loop.  Just set the return flag to true and break.
-       * lower_returns_in_loop will do the rest.
-       */
-      nir_builder *b = &state->builder;
-      b->cursor = nir_before_instr(&jump->instr);
+   nir_builder *b = &state->builder;
+   b->cursor = nir_before_instr(&jump->instr);
 
-      if (state->return_flag == NULL) {
-         state->return_flag =
-            nir_local_variable_create(b->impl, glsl_bool_type(), "return");
+   /* Set the return flag */
+   if (state->return_flag == NULL) {
+      state->return_flag =
+         nir_local_variable_create(b->impl, glsl_bool_type(), "return");
 
-         /* Set a default value of false */
-         state->return_flag->constant_initializer =
-            rzalloc(state->return_flag, nir_constant);
-      }
+      /* Set a default value of false */
+      state->return_flag->constant_initializer =
+         rzalloc(state->return_flag, nir_constant);
+   }
+   nir_store_var(b, state->return_flag, nir_imm_int(b, NIR_TRUE));
 
-      nir_store_var(b, state->return_flag, nir_imm_int(b, NIR_TRUE));
+   if (state->loop) {
+      /* We're in a loop.  Make the return a break. */
       jump->type = nir_jump_return;
-   } else if (state->if_stmt) {
-      /* If we're not in a loop but in an if, just move the rest of the CF
-       * list into the the other case of the if.
-       */
-      nir_cf_list list;
-      nir_cf_extract(&list, nir_after_cf_node(&state->if_stmt->cf_node),
-                            nir_after_cf_list(state->parent_cf_list));
-
-      nir_instr_remove(&jump->instr);
-
-      if (state->cf_list == &state->if_stmt->then_list) {
-         nir_cf_reinsert(&list,
-                         nir_after_cf_list(&state->if_stmt->else_list));
-      } else if (state->cf_list == &state->if_stmt->else_list) {
-         nir_cf_reinsert(&list,
-                         nir_after_cf_list(&state->if_stmt->then_list));
-      } else {
-         unreachable("Invalid CF list");
-      }
    } else {
+      /* Not in a loop.  Just delete the return; we'll deal with
+       * predicating later.
+       */
+      assert(nir_cf_node_next(&block->cf_node) == NULL);
       nir_instr_remove(&jump->instr);
-
-      /* No if, no nothing.  Just delete the return and whatever follows. */
-      nir_cf_list list;
-      nir_cf_extract(&list, nir_after_cf_node(&block->cf_node),
-                            nir_after_cf_list(state->parent_cf_list));
-      nir_cf_delete(&list);
    }
 
    return true;
@@ -177,10 +178,14 @@ lower_returns_in_cf_list(struct exec_list *cf_list,
 {
    bool progress = false;
 
-   struct exec_list *prev_parent_list = state->parent_cf_list;
-   state->parent_cf_list = state->cf_list;
+   struct exec_list *parent_list = state->cf_list;
    state->cf_list = cf_list;
 
+   /* We iterate over the list backwards because any given lower call may
+    * take everything following the given CF node and predicate it.  In
+    * order to avoid recursion/iteration problems, we want everything after
+    * a given node to already be lowered before this happens.
+    */
    foreach_list_typed_reverse_safe(nir_cf_node, node, node, cf_list) {
       switch (node->type) {
       case nir_cf_node_block:
@@ -203,8 +208,7 @@ lower_returns_in_cf_list(struct exec_list *cf_list,
       }
    }
 
-   state->cf_list = state->parent_cf_list;
-   state->parent_cf_list = prev_parent_list;
+   state->cf_list = parent_list;
 
    return progress;
 }
@@ -214,10 +218,8 @@ nir_lower_returns_impl(nir_function_impl *impl)
 {
    struct lower_returns_state state;
 
-   state.parent_cf_list = NULL;
    state.cf_list = &impl->body;
    state.loop = NULL;
-   state.if_stmt = NULL;
    state.return_flag = NULL;
    nir_builder_init(&state.builder, impl);