nir: simplify phi handling in divergence analysis
authorDaniel Schürmann <daniel@schuermann.dev>
Wed, 5 Feb 2020 17:36:34 +0000 (18:36 +0100)
committerMarge Bot <eric+marge@anholt.net>
Wed, 13 May 2020 18:49:22 +0000 (18:49 +0000)
This patch adds some control flow information to the
state to keep track whether a loop contains divergent
continue or break statements to not having to
recalculate this property for every phi.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4062>

src/compiler/nir/nir_divergence_analysis.c

index 1e3ead1e114885896b0906fe7d5e9d3282b23dbe..b2f34b0dfb67ebefbb96a56995c3fbbcdb278c76 100644 (file)
 struct divergence_state {
    const nir_divergence_options options;
    const gl_shader_stage stage;
+
+   /** current control flow state */
+   /* True if some loop-active invocations might take a different control-flow path.
+    * A divergent break does not cause subsequent control-flow to be considered
+    * divergent because those invocations are no longer active in the loop.
+    * For a divergent if, both sides are considered divergent flow because
+    * the other side is still loop-active. */
+   bool divergent_loop_cf;
+   /* True if a divergent continue happened since the loop header */
+   bool divergent_loop_continue;
+   /* True if a divergent break happened since the loop header */
+   bool divergent_loop_break;
 };
 
 static bool
@@ -545,6 +557,28 @@ visit_deref(nir_deref_instr *deref, struct divergence_state *state)
    return is_divergent;
 }
 
+static bool
+visit_jump(nir_jump_instr *jump, struct divergence_state *state)
+{
+   switch (jump->type) {
+   case nir_jump_continue:
+      if (state->divergent_loop_continue)
+         return false;
+      if (state->divergent_loop_cf)
+         state->divergent_loop_continue = true;
+      return state->divergent_loop_continue;
+   case nir_jump_break:
+      if (state->divergent_loop_break)
+         return false;
+      if (state->divergent_loop_cf)
+         state->divergent_loop_break = true;
+      return state->divergent_loop_break;
+   case nir_jump_return:
+      unreachable("NIR divergence analysis: Unsupported return instruction.");
+   }
+   return false;
+}
+
 static bool
 visit_block(nir_block *block, struct divergence_state *state)
 {
@@ -570,11 +604,12 @@ visit_block(nir_block *block, struct divergence_state *state)
       case nir_instr_type_deref:
          has_changed |= visit_deref(nir_instr_as_deref(instr), state);
          break;
+      case nir_instr_type_jump:
+         has_changed |= visit_jump(nir_instr_as_jump(instr), state);
+         break;
       /* phis are handled when processing the branches */
       case nir_instr_type_phi:
          break;
-      case nir_instr_type_jump:
-         break;
       case nir_instr_type_call:
       case nir_instr_type_parallel_copy:
          unreachable("NIR divergence analysis: Unsupported instruction type.");
@@ -612,6 +647,7 @@ visit_if_merge_phi(nir_phi_instr *phi, bool if_cond_divergent)
       phi->dest.ssa.divergent = true;
       return true;
    }
+
    return false;
 }
 
@@ -622,64 +658,35 @@ visit_if_merge_phi(nir_phi_instr *phi, bool if_cond_divergent)
  *     is divergent or a divergent loop continue condition
  *     is associated with a different ssa-def. */
 static bool
-visit_loop_header_phi(nir_phi_instr *phi, nir_loop *loop)
+visit_loop_header_phi(nir_phi_instr *phi, nir_block *preheader, bool divergent_continue)
 {
    if (phi->dest.ssa.divergent)
       return false;
 
-   nir_cf_node *prev = nir_cf_node_prev(&loop->cf_node);
    nir_ssa_def* same = NULL;
-   bool all_same = true;
-
-   /* first, check if all loop-carried values are from the same ssa-def */
    nir_foreach_phi_src(src, phi) {
       /* if any source value is divergent, the resulting value is divergent */
       if (src->src.ssa->divergent) {
          phi->dest.ssa.divergent = true;
          return true;
       }
-      /* skip the loop preheader */
-      if (src->pred == nir_cf_node_as_block(prev))
+      /* if this loop is uniform, we're done here */
+      if (!divergent_continue)
          continue;
-      if (src->src.ssa->parent_instr->type == nir_instr_type_ssa_undef)
-         continue;
-      if (!same)
-         same = src->src.ssa;
-      else if (same != src->src.ssa)
-         all_same = false;
-   }
-
-   /* if all loop-carried values are the same, the resulting value is uniform */
-   if (all_same)
-      return false;
-
-   /* check if the loop-carried values come from different ssa-defs
-    * and the corresponding condition is divergent. */
-   nir_foreach_phi_src(src, phi) {
       /* skip the loop preheader */
-      if (src->pred == nir_cf_node_as_block(prev))
-         continue;
-
-      /* skip the unconditional back-edge */
-      if (src->pred == nir_loop_last_block(loop))
+      if (src->pred == preheader)
          continue;
-
-      /* if the value is undef, we don't need to check the condition */
+      /* skip undef values */
       if (src->src.ssa->parent_instr->type == nir_instr_type_ssa_undef)
          continue;
 
-      nir_cf_node *current = src->pred->cf_node.parent;
-      /* check recursively the conditions if any is divergent */
-      while (current->type != nir_cf_node_loop) {
-         assert (current->type == nir_cf_node_if);
-         nir_if *if_node = nir_cf_node_as_if(current);
-         if (if_node->condition.ssa->divergent) {
-            phi->dest.ssa.divergent = true;
-            return true;
-         }
-         current = current->parent;
+      /* check if all loop-carried values are from the same ssa-def */
+      if (!same)
+         same = src->src.ssa;
+      else if (same != src->src.ssa) {
+         phi->dest.ssa.divergent = true;
+         return true;
       }
-      assert(current == &loop->cf_node);
    }
 
    return false;
@@ -692,82 +699,60 @@ visit_loop_header_phi(nir_phi_instr *phi, nir_loop *loop)
  *     not loop-invariant.
  *     (note: there should be no phi for loop-invariant variables.) */
 static bool
-visit_loop_exit_phi(nir_phi_instr *phi, nir_loop *loop)
+visit_loop_exit_phi(nir_phi_instr *phi, bool divergent_break)
 {
    if (phi->dest.ssa.divergent)
       return false;
 
-   /* Check if any loop exit condition is divergent:
-    * That is any break happens under divergent condition or
-    * a break is preceeded by a divergent continue
-    */
+   if (divergent_break) {
+      phi->dest.ssa.divergent = true;
+      return true;
+   }
+
+   /* if any source value is divergent, the resulting value is divergent */
    nir_foreach_phi_src(src, phi) {
-      /* if any source value is divergent, the resulting value is divergent */
       if (src->src.ssa->divergent) {
          phi->dest.ssa.divergent = true;
          return true;
       }
-
-      nir_cf_node *current = src->pred->cf_node.parent;
-
-      /* check recursively the conditions if any is divergent */
-      while (current->type != nir_cf_node_loop) {
-         assert(current->type == nir_cf_node_if);
-         nir_if *if_node = nir_cf_node_as_if(current);
-         if (if_node->condition.ssa->divergent) {
-            phi->dest.ssa.divergent = true;
-            return true;
-         }
-         current = current->parent;
-      }
-
-      /* check if any divergent continue happened before the break */
-      nir_foreach_block_in_cf_node(block, &loop->cf_node) {
-         if (block == src->pred)
-            break;
-         if (!nir_block_ends_in_jump(block))
-            continue;
-
-         nir_jump_instr *jump = nir_instr_as_jump(nir_block_last_instr(block));
-         if (jump->type != nir_jump_continue)
-            continue;
-
-         current = block->cf_node.parent;
-         bool is_divergent = false;
-         while (current != &loop->cf_node) {
-            /* the continue belongs to an inner loop */
-            if (current->type == nir_cf_node_loop) {
-               is_divergent = false;
-               break;
-            }
-            assert(current->type == nir_cf_node_if);
-            nir_if *if_node = nir_cf_node_as_if(current);
-            is_divergent |= if_node->condition.ssa->divergent;
-            current = current->parent;
-         }
-
-         if (is_divergent) {
-            phi->dest.ssa.divergent = true;
-            return true;
-         }
-      }
    }
+
    return false;
 }
 
 static bool
 visit_if(nir_if *if_stmt, struct divergence_state *state)
 {
-   bool progress = visit_cf_list(&if_stmt->then_list, state) |
-                   visit_cf_list(&if_stmt->else_list, state);
+   bool progress = false;
+
+   struct divergence_state then_state = *state;
+   then_state.divergent_loop_cf |= if_stmt->condition.ssa->divergent;
+   progress |= visit_cf_list(&if_stmt->then_list, &then_state);
+
+   struct divergence_state else_state = *state;
+   else_state.divergent_loop_cf |= if_stmt->condition.ssa->divergent;
+   progress |= visit_cf_list(&if_stmt->else_list, &else_state);
 
    /* handle phis after the IF */
    nir_foreach_instr(instr, nir_cf_node_cf_tree_next(&if_stmt->cf_node)) {
       if (instr->type != nir_instr_type_phi)
          break;
-      progress |= visit_if_merge_phi(nir_instr_as_phi(instr), if_stmt->condition.ssa->divergent);
+
+      progress |= visit_if_merge_phi(nir_instr_as_phi(instr),
+                                     if_stmt->condition.ssa->divergent);
    }
 
+   /* join loop divergence information from both branch legs */
+   state->divergent_loop_continue |= then_state.divergent_loop_continue ||
+                                     else_state.divergent_loop_continue;
+   state->divergent_loop_break |= then_state.divergent_loop_break ||
+                                  else_state.divergent_loop_break;
+
+   /* A divergent continue makes succeeding loop CF divergent:
+    * not all loop-active invocations participate in the remaining loop-body
+    * which means that a following break might be taken by some invocations, only */
+   state->divergent_loop_cf |= state->divergent_loop_continue;
+
    return progress;
 }
 
@@ -775,36 +760,51 @@ static bool
 visit_loop(nir_loop *loop, struct divergence_state *state)
 {
    bool progress = false;
+   nir_block *loop_header = nir_loop_first_block(loop);
+   nir_block *loop_preheader = nir_block_cf_tree_prev(loop_header);
 
    /* handle loop header phis first */
-   nir_foreach_instr(instr, nir_loop_first_block(loop)) {
+   nir_foreach_instr(instr, loop_header) {
       if (instr->type != nir_instr_type_phi)
          break;
-      progress |= visit_loop_header_phi(nir_instr_as_phi(instr), loop);
+
+      progress |= visit_loop_header_phi(nir_instr_as_phi(instr),
+                                        loop_preheader, false);
+
    }
 
-   bool repeat = true;
-   while (repeat) {
-      /* process loop body */
-      repeat = visit_cf_list(&loop->body, state);
-
-      if (repeat) {
-         repeat = false;
-         /* revisit loop header phis to see if something has changed */
-         nir_foreach_instr(instr, nir_loop_first_block(loop)) {
-            if (instr->type != nir_instr_type_phi)
-               break;
-            repeat |= visit_loop_header_phi(nir_instr_as_phi(instr), loop);
-         }
-         progress = true;
+   /* setup loop state */
+   struct divergence_state loop_state = *state;
+   loop_state.divergent_loop_cf = false;
+   loop_state.divergent_loop_continue = false;
+   loop_state.divergent_loop_break = false;
+
+   /* process loop body until no further changes are made */
+   bool repeat;
+   do {
+      progress |= visit_cf_list(&loop->body, &loop_state);
+      repeat = false;
+
+      /* revisit loop header phis to see if something has changed */
+      nir_foreach_instr(instr, loop_header) {
+         if (instr->type != nir_instr_type_phi)
+            break;
+
+         repeat |= visit_loop_header_phi(nir_instr_as_phi(instr),
+                                         loop_preheader,
+                                         loop_state.divergent_loop_continue);
       }
-   }
+
+      loop_state.divergent_loop_cf = false;
+   } while (repeat);
 
    /* handle phis after the loop */
    nir_foreach_instr(instr, nir_cf_node_cf_tree_next(&loop->cf_node)) {
       if (instr->type != nir_instr_type_phi)
          break;
-      progress |= visit_loop_exit_phi(nir_instr_as_phi(instr), loop);
+
+      progress |= visit_loop_exit_phi(nir_instr_as_phi(instr),
+                                      loop_state.divergent_loop_break);
    }
 
    return progress;
@@ -855,6 +855,9 @@ nir_divergence_analysis(nir_shader *shader, nir_divergence_options options)
    struct divergence_state state = {
       .options = options,
       .stage = shader->info.stage,
+      .divergent_loop_cf = false,
+      .divergent_loop_continue = false,
+      .divergent_loop_break = false,
    };
 
    visit_cf_list(&impl->body, &state);