nir: Add goto_if jump instruction
[mesa.git] / src / compiler / nir / nir_divergence_analysis.c
index 9b8f9cb6349f185e2334d48ef18f8b933d1f67ba..05892b440bfb508f0ac2a86d16c2447395f94bc6 100644 (file)
  * ACM, 2013, 35 (4), pp.13:1-13:36. <10.1145/2523815>. <hal-00909072v2>
  */
 
+struct divergence_state {
+   const nir_divergence_options options;
+   const gl_shader_stage stage;
+
+   /** current control flow state */
+   /* True if some loop-active invocations might take a different control-flow path.
+    * A divergent break does not cause subsequent control-flow to be considered
+    * divergent because those invocations are no longer active in the loop.
+    * For a divergent if, both sides are considered divergent flow because
+    * the other side is still loop-active. */
+   bool divergent_loop_cf;
+   /* True if a divergent continue happened since the loop header */
+   bool divergent_loop_continue;
+   /* True if a divergent break happened since the loop header */
+   bool divergent_loop_break;
+
+   /* True if we visit the block for the fist time */
+   bool first_visit;
+};
+
 static bool
-visit_cf_list(bool *divergent, struct exec_list *list,
-              nir_divergence_options options, gl_shader_stage stage);
+visit_cf_list(struct exec_list *list, struct divergence_state *state);
 
 static bool
-visit_alu(bool *divergent, nir_alu_instr *instr)
+visit_alu(nir_alu_instr *instr)
 {
-   if (divergent[instr->dest.dest.ssa.index])
+   if (instr->dest.dest.ssa.divergent)
       return false;
 
    unsigned num_src = nir_op_infos[instr->op].num_inputs;
 
    for (unsigned i = 0; i < num_src; i++) {
-      if (divergent[instr->src[i].src.ssa->index]) {
-         divergent[instr->dest.dest.ssa.index] = true;
+      if (instr->src[i].src.ssa->divergent) {
+         instr->dest.dest.ssa.divergent = true;
          return true;
       }
    }
@@ -59,15 +78,16 @@ visit_alu(bool *divergent, nir_alu_instr *instr)
 }
 
 static bool
-visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
-                nir_divergence_options options, gl_shader_stage stage)
+visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
 {
    if (!nir_intrinsic_infos[instr->intrinsic].has_dest)
       return false;
 
-   if (divergent[instr->dest.ssa.index])
+   if (instr->dest.ssa.divergent)
       return false;
 
+   nir_divergence_options options = state->options;
+   gl_shader_stage stage = state->stage;
    bool is_divergent = false;
    switch (instr->intrinsic) {
    /* Intrinsics which are always uniform */
@@ -117,7 +137,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
 
    /* Intrinsics with divergence depending on shader stage and hardware */
    case nir_intrinsic_load_input:
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       if (stage == MESA_SHADER_FRAGMENT)
          is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
       else if (stage == MESA_SHADER_TESS_EVAL)
@@ -125,14 +145,35 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
       else
          is_divergent = true;
       break;
+   case nir_intrinsic_load_per_vertex_input:
+      is_divergent = instr->src[0].ssa->divergent ||
+                     instr->src[1].ssa->divergent;
+      if (stage == MESA_SHADER_TESS_CTRL)
+         is_divergent |= !(options & nir_divergence_single_patch_per_tcs_subgroup);
+      if (stage == MESA_SHADER_TESS_EVAL)
+         is_divergent |= !(options & nir_divergence_single_patch_per_tes_subgroup);
+      else
+         is_divergent = true;
+      break;
+   case nir_intrinsic_load_input_vertex:
+      is_divergent = instr->src[1].ssa->divergent;
+      assert(stage == MESA_SHADER_FRAGMENT);
+      is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
+      break;
    case nir_intrinsic_load_output:
       assert(stage == MESA_SHADER_TESS_CTRL || stage == MESA_SHADER_FRAGMENT);
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       if (stage == MESA_SHADER_TESS_CTRL)
          is_divergent |= !(options & nir_divergence_single_patch_per_tcs_subgroup);
       else
          is_divergent = true;
       break;
+   case nir_intrinsic_load_per_vertex_output:
+      assert(stage == MESA_SHADER_TESS_CTRL);
+      is_divergent = instr->src[0].ssa->divergent ||
+                     instr->src[1].ssa->divergent ||
+                     !(options & nir_divergence_single_patch_per_tcs_subgroup);
+      break;
    case nir_intrinsic_load_layer_id:
    case nir_intrinsic_load_front_face:
       assert(stage == MESA_SHADER_FRAGMENT);
@@ -147,7 +188,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
       break;
    case nir_intrinsic_load_fs_input_interp_deltas:
       assert(stage == MESA_SHADER_FRAGMENT);
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
       break;
    case nir_intrinsic_load_primitive_id:
@@ -157,6 +198,8 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
          is_divergent = !(options & nir_divergence_single_patch_per_tcs_subgroup);
       else if (stage == MESA_SHADER_TESS_EVAL)
          is_divergent = !(options & nir_divergence_single_patch_per_tes_subgroup);
+      else if (stage == MESA_SHADER_GEOMETRY)
+         is_divergent = true;
       else
          unreachable("Invalid stage for load_primitive_id");
       break;
@@ -187,7 +230,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
       /* fallthrough */
    case nir_intrinsic_inclusive_scan: {
       nir_op op = nir_intrinsic_reduction_op(instr);
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       if (op != nir_op_umin && op != nir_op_imin && op != nir_op_fmin &&
           op != nir_op_umax && op != nir_op_imax && op != nir_op_fmax &&
           op != nir_op_iand && op != nir_op_ior)
@@ -238,7 +281,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
    case nir_intrinsic_masked_swizzle_amd: {
       unsigned num_srcs = nir_intrinsic_infos[instr->intrinsic].num_srcs;
       for (unsigned i = 0; i < num_srcs; i++) {
-         if (divergent[instr->src[i].ssa->index]) {
+         if (instr->src[i].ssa->divergent) {
             is_divergent = true;
             break;
          }
@@ -247,8 +290,8 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
    }
 
    case nir_intrinsic_shuffle:
-      is_divergent = divergent[instr->src[0].ssa->index] &&
-                     divergent[instr->src[1].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent &&
+                     instr->src[1].ssa->divergent;
       break;
 
    /* Intrinsics which are always divergent */
@@ -262,19 +305,20 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
    case nir_intrinsic_load_barycentric_pixel:
    case nir_intrinsic_load_barycentric_centroid:
    case nir_intrinsic_load_barycentric_sample:
+   case nir_intrinsic_load_barycentric_model:
    case nir_intrinsic_load_barycentric_at_sample:
    case nir_intrinsic_load_barycentric_at_offset:
    case nir_intrinsic_interp_deref_at_offset:
    case nir_intrinsic_interp_deref_at_sample:
    case nir_intrinsic_interp_deref_at_centroid:
+   case nir_intrinsic_interp_deref_at_vertex:
    case nir_intrinsic_load_tess_coord:
    case nir_intrinsic_load_point_coord:
+   case nir_intrinsic_load_line_coord:
    case nir_intrinsic_load_frag_coord:
    case nir_intrinsic_load_sample_pos:
    case nir_intrinsic_load_vertex_id_zero_base:
    case nir_intrinsic_load_vertex_id:
-   case nir_intrinsic_load_per_vertex_input:
-   case nir_intrinsic_load_per_vertex_output:
    case nir_intrinsic_load_instance_id:
    case nir_intrinsic_load_invocation_id:
    case nir_intrinsic_load_local_invocation_id:
@@ -401,6 +445,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
    case nir_intrinsic_ballot_bit_count_inclusive:
    case nir_intrinsic_write_invocation_amd:
    case nir_intrinsic_mbcnt_amd:
+   case nir_intrinsic_elect:
       is_divergent = true;
       break;
 
@@ -414,14 +459,14 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
 #endif
    }
 
-   divergent[instr->dest.ssa.index] = is_divergent;
+   instr->dest.ssa.divergent = is_divergent;
    return is_divergent;
 }
 
 static bool
-visit_tex(bool *divergent, nir_tex_instr *instr)
+visit_tex(nir_tex_instr *instr)
 {
-   if (divergent[instr->dest.ssa.index])
+   if (instr->dest.ssa.divergent)
       return false;
 
    bool is_divergent = false;
@@ -431,196 +476,33 @@ visit_tex(bool *divergent, nir_tex_instr *instr)
       case nir_tex_src_sampler_deref:
       case nir_tex_src_sampler_handle:
       case nir_tex_src_sampler_offset:
-         is_divergent |= divergent[instr->src[i].src.ssa->index] &&
+         is_divergent |= instr->src[i].src.ssa->divergent &&
                          instr->sampler_non_uniform;
          break;
       case nir_tex_src_texture_deref:
       case nir_tex_src_texture_handle:
       case nir_tex_src_texture_offset:
-         is_divergent |= divergent[instr->src[i].src.ssa->index] &&
+         is_divergent |= instr->src[i].src.ssa->divergent &&
                          instr->texture_non_uniform;
          break;
       default:
-         is_divergent |= divergent[instr->src[i].src.ssa->index];
+         is_divergent |= instr->src[i].src.ssa->divergent;
          break;
       }
    }
 
-   divergent[instr->dest.ssa.index] = is_divergent;
+   instr->dest.ssa.divergent = is_divergent;
    return is_divergent;
 }
 
 static bool
-visit_phi(bool *divergent, nir_phi_instr *instr)
-{
-   /* There are 3 types of phi instructions:
-    * (1) gamma: represent the joining point of different paths
-    *     created by an “if-then-else” branch.
-    *     The resulting value is divergent if the branch condition
-    *     or any of the source values is divergent.
-    *
-    * (2) mu: which only exist at loop headers,
-    *     merge initial and loop-carried values.
-    *     The resulting value is divergent if any source value
-    *     is divergent or a divergent loop continue condition
-    *     is associated with a different ssa-def.
-    *
-    * (3) eta: represent values that leave a loop.
-    *     The resulting value is divergent if the source value is divergent
-    *     or any loop exit condition is divergent for a value which is
-    *     not loop-invariant.
-    *     (note: there should be no phi for loop-invariant variables.)
-    */
-
-   if (divergent[instr->dest.ssa.index])
-      return false;
-
-   nir_foreach_phi_src(src, instr) {
-      /* if any source value is divergent, the resulting value is divergent */
-      if (divergent[src->src.ssa->index]) {
-         divergent[instr->dest.ssa.index] = true;
-         return true;
-      }
-   }
-
-   nir_cf_node *prev = nir_cf_node_prev(&instr->instr.block->cf_node);
-
-   if (!prev) {
-      /* mu: if no predecessor node exists, the phi must be at a loop header */
-      nir_loop *loop = nir_cf_node_as_loop(instr->instr.block->cf_node.parent);
-      prev = nir_cf_node_prev(&loop->cf_node);
-      nir_ssa_def* same = NULL;
-      bool all_same = true;
-
-      /* first, check if all loop-carried values are from the same ssa-def */
-      nir_foreach_phi_src(src, instr) {
-         if (src->pred == nir_cf_node_as_block(prev))
-            continue;
-         if (src->src.ssa->parent_instr->type == nir_instr_type_ssa_undef)
-            continue;
-         if (!same)
-            same = src->src.ssa;
-         else if (same != src->src.ssa)
-            all_same = false;
-      }
-
-      /* if all loop-carried values are the same, the resulting value is uniform */
-      if (all_same)
-         return false;
-
-      /* check if the loop-carried values come from different ssa-defs
-       * and the corresponding condition is divergent. */
-      nir_foreach_phi_src(src, instr) {
-         /* skip the loop preheader */
-         if (src->pred == nir_cf_node_as_block(prev))
-            continue;
-
-         /* skip the unconditional back-edge */
-         if (src->pred == nir_loop_last_block(loop))
-            continue;
-
-         /* if the value is undef, we don't need to check the condition */
-         if (src->src.ssa->parent_instr->type == nir_instr_type_ssa_undef)
-            continue;
-
-         nir_cf_node *current = src->pred->cf_node.parent;
-         /* check recursively the conditions if any is divergent */
-         while (current->type != nir_cf_node_loop) {
-            assert (current->type == nir_cf_node_if);
-            nir_if *if_node = nir_cf_node_as_if(current);
-            if (divergent[if_node->condition.ssa->index]) {
-               divergent[instr->dest.ssa.index] = true;
-               return true;
-            }
-            current = current->parent;
-         }
-         assert(current == &loop->cf_node);
-      }
-
-   } else if (prev->type == nir_cf_node_if) {
-      /* if only one of the incoming values is defined, the resulting value is uniform */
-      unsigned defined_srcs = 0;
-      nir_foreach_phi_src(src, instr) {
-         if (src->src.ssa->parent_instr->type != nir_instr_type_ssa_undef)
-            defined_srcs++;
-      }
-      if (defined_srcs <= 1)
-         return false;
-
-      /* gamma: check if the condition is divergent */
-      nir_if *if_node = nir_cf_node_as_if(prev);
-      if (divergent[if_node->condition.ssa->index]) {
-         divergent[instr->dest.ssa.index] = true;
-         return true;
-      }
-
-   } else {
-      /* eta: the predecessor must be a loop */
-      assert(prev->type == nir_cf_node_loop);
-
-      /* Check if any loop exit condition is divergent:
-       * That is any break happens under divergent condition or
-       * a break is preceeded by a divergent continue
-       */
-      nir_foreach_phi_src(src, instr) {
-         nir_cf_node *current = src->pred->cf_node.parent;
-
-         /* check recursively the conditions if any is divergent */
-         while (current->type != nir_cf_node_loop) {
-            assert(current->type == nir_cf_node_if);
-            nir_if *if_node = nir_cf_node_as_if(current);
-            if (divergent[if_node->condition.ssa->index]) {
-               divergent[instr->dest.ssa.index] = true;
-               return true;
-            }
-            current = current->parent;
-         }
-         assert(current == prev);
-
-         /* check if any divergent continue happened before the break */
-         nir_foreach_block_in_cf_node(block, prev) {
-            if (block == src->pred)
-               break;
-            if (!nir_block_ends_in_jump(block))
-               continue;
-
-            nir_jump_instr *jump = nir_instr_as_jump(nir_block_last_instr(block));
-            if (jump->type != nir_jump_continue)
-               continue;
-
-            current = block->cf_node.parent;
-            bool is_divergent = false;
-            while (current != prev) {
-               /* the continue belongs to an inner loop */
-               if (current->type == nir_cf_node_loop) {
-                  is_divergent = false;
-                  break;
-               }
-               assert(current->type == nir_cf_node_if);
-               nir_if *if_node = nir_cf_node_as_if(current);
-               is_divergent |= divergent[if_node->condition.ssa->index];
-               current = current->parent;
-            }
-
-            if (is_divergent) {
-               divergent[instr->dest.ssa.index] = true;
-               return true;
-            }
-         }
-      }
-   }
-
-   return false;
-}
-
-static bool
-visit_load_const(bool *divergent, nir_load_const_instr *instr)
+visit_load_const(nir_load_const_instr *instr)
 {
    return false;
 }
 
 static bool
-visit_ssa_undef(bool *divergent, nir_ssa_undef_instr *instr)
+visit_ssa_undef(nir_ssa_undef_instr *instr)
 {
    return false;
 }
@@ -640,25 +522,24 @@ nir_variable_mode_is_uniform(nir_variable_mode mode) {
 }
 
 static bool
-nir_variable_is_uniform(nir_variable *var, nir_divergence_options options,
-                        gl_shader_stage stage)
+nir_variable_is_uniform(nir_variable *var, struct divergence_state *state)
 {
    if (nir_variable_mode_is_uniform(var->data.mode))
       return true;
 
-   if (stage == MESA_SHADER_FRAGMENT &&
-       (options & nir_divergence_single_prim_per_subgroup) &&
+   if (state->stage == MESA_SHADER_FRAGMENT &&
+       (state->options & nir_divergence_single_prim_per_subgroup) &&
        var->data.mode == nir_var_shader_in &&
        var->data.interpolation == INTERP_MODE_FLAT)
       return true;
 
-   if (stage == MESA_SHADER_TESS_CTRL &&
-       (options & nir_divergence_single_patch_per_tcs_subgroup) &&
+   if (state->stage == MESA_SHADER_TESS_CTRL &&
+       (state->options & nir_divergence_single_patch_per_tcs_subgroup) &&
        var->data.mode == nir_var_shader_out && var->data.patch)
       return true;
 
-   if (stage == MESA_SHADER_TESS_EVAL &&
-       (options & nir_divergence_single_patch_per_tes_subgroup) &&
+   if (state->stage == MESA_SHADER_TESS_EVAL &&
+       (state->options & nir_divergence_single_patch_per_tes_subgroup) &&
        var->data.mode == nir_var_shader_in && var->data.patch)
       return true;
 
@@ -666,68 +547,104 @@ nir_variable_is_uniform(nir_variable *var, nir_divergence_options options,
 }
 
 static bool
-visit_deref(bool *divergent, nir_deref_instr *deref,
-            nir_divergence_options options, gl_shader_stage stage)
+visit_deref(nir_deref_instr *deref, struct divergence_state *state)
 {
-   if (divergent[deref->dest.ssa.index])
+   if (deref->dest.ssa.divergent)
       return false;
 
    bool is_divergent = false;
    switch (deref->deref_type) {
    case nir_deref_type_var:
-      is_divergent = !nir_variable_is_uniform(deref->var, options, stage);
+      is_divergent = !nir_variable_is_uniform(deref->var, state);
       break;
    case nir_deref_type_array:
    case nir_deref_type_ptr_as_array:
-      is_divergent = divergent[deref->arr.index.ssa->index];
+      is_divergent = deref->arr.index.ssa->divergent;
       /* fallthrough */
    case nir_deref_type_struct:
    case nir_deref_type_array_wildcard:
-      is_divergent |= divergent[deref->parent.ssa->index];
+      is_divergent |= deref->parent.ssa->divergent;
       break;
    case nir_deref_type_cast:
       is_divergent = !nir_variable_mode_is_uniform(deref->var->data.mode) ||
-                     divergent[deref->parent.ssa->index];
+                     deref->parent.ssa->divergent;
       break;
    }
 
-   divergent[deref->dest.ssa.index] = is_divergent;
+   deref->dest.ssa.divergent = is_divergent;
    return is_divergent;
 }
 
 static bool
-visit_block(bool *divergent, nir_block *block, nir_divergence_options options,
-            gl_shader_stage stage)
+visit_jump(nir_jump_instr *jump, struct divergence_state *state)
+{
+   switch (jump->type) {
+   case nir_jump_continue:
+      if (state->divergent_loop_continue)
+         return false;
+      if (state->divergent_loop_cf)
+         state->divergent_loop_continue = true;
+      return state->divergent_loop_continue;
+   case nir_jump_break:
+      if (state->divergent_loop_break)
+         return false;
+      if (state->divergent_loop_cf)
+         state->divergent_loop_break = true;
+      return state->divergent_loop_break;
+   case nir_jump_return:
+      unreachable("NIR divergence analysis: Unsupported return instruction.");
+      break;
+   case nir_jump_goto:
+   case nir_jump_goto_if:
+      unreachable("NIR divergence analysis: Unsupported goto_if instruction.");
+      break;
+   }
+   return false;
+}
+
+static bool
+set_ssa_def_not_divergent(nir_ssa_def *def, UNUSED void *_state)
+{
+   def->divergent = false;
+   return true;
+}
+
+static bool
+visit_block(nir_block *block, struct divergence_state *state)
 {
    bool has_changed = false;
 
    nir_foreach_instr(instr, block) {
+      /* phis are handled when processing the branches */
+      if (instr->type == nir_instr_type_phi)
+         continue;
+
+      if (state->first_visit)
+         nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
+
       switch (instr->type) {
       case nir_instr_type_alu:
-         has_changed |= visit_alu(divergent, nir_instr_as_alu(instr));
+         has_changed |= visit_alu(nir_instr_as_alu(instr));
          break;
       case nir_instr_type_intrinsic:
-         has_changed |= visit_intrinsic(divergent, nir_instr_as_intrinsic(instr),
-                                        options, stage);
+         has_changed |= visit_intrinsic(nir_instr_as_intrinsic(instr), state);
          break;
       case nir_instr_type_tex:
-         has_changed |= visit_tex(divergent, nir_instr_as_tex(instr));
-         break;
-      case nir_instr_type_phi:
-         has_changed |= visit_phi(divergent, nir_instr_as_phi(instr));
+         has_changed |= visit_tex(nir_instr_as_tex(instr));
          break;
       case nir_instr_type_load_const:
-         has_changed |= visit_load_const(divergent, nir_instr_as_load_const(instr));
+         has_changed |= visit_load_const(nir_instr_as_load_const(instr));
          break;
       case nir_instr_type_ssa_undef:
-         has_changed |= visit_ssa_undef(divergent, nir_instr_as_ssa_undef(instr));
+         has_changed |= visit_ssa_undef(nir_instr_as_ssa_undef(instr));
          break;
       case nir_instr_type_deref:
-         has_changed |= visit_deref(divergent, nir_instr_as_deref(instr),
-                                    options, stage);
+         has_changed |= visit_deref(nir_instr_as_deref(instr), state);
          break;
       case nir_instr_type_jump:
+         has_changed |= visit_jump(nir_instr_as_jump(instr), state);
          break;
+      case nir_instr_type_phi:
       case nir_instr_type_call:
       case nir_instr_type_parallel_copy:
          unreachable("NIR divergence analysis: Unsupported instruction type.");
@@ -737,47 +654,226 @@ visit_block(bool *divergent, nir_block *block, nir_divergence_options options,
    return has_changed;
 }
 
+/* There are 3 types of phi instructions:
+ * (1) gamma: represent the joining point of different paths
+ *     created by an “if-then-else” branch.
+ *     The resulting value is divergent if the branch condition
+ *     or any of the source values is divergent. */
 static bool
-visit_if(bool *divergent, nir_if *if_stmt, nir_divergence_options options, gl_shader_stage stage)
+visit_if_merge_phi(nir_phi_instr *phi, bool if_cond_divergent)
 {
-   return visit_cf_list(divergent, &if_stmt->then_list, options, stage) |
-          visit_cf_list(divergent, &if_stmt->else_list, options, stage);
+   if (phi->dest.ssa.divergent)
+      return false;
+
+   unsigned defined_srcs = 0;
+   nir_foreach_phi_src(src, phi) {
+      /* if any source value is divergent, the resulting value is divergent */
+      if (src->src.ssa->divergent) {
+         phi->dest.ssa.divergent = true;
+         return true;
+      }
+      if (src->src.ssa->parent_instr->type != nir_instr_type_ssa_undef) {
+         defined_srcs++;
+      }
+   }
+
+   /* if the condition is divergent and two sources defined, the definition is divergent */
+   if (defined_srcs > 1 && if_cond_divergent) {
+      phi->dest.ssa.divergent = true;
+      return true;
+   }
+
+   return false;
 }
 
+/* There are 3 types of phi instructions:
+ * (2) mu: which only exist at loop headers,
+ *     merge initial and loop-carried values.
+ *     The resulting value is divergent if any source value
+ *     is divergent or a divergent loop continue condition
+ *     is associated with a different ssa-def. */
 static bool
-visit_loop(bool *divergent, nir_loop *loop, nir_divergence_options options, gl_shader_stage stage)
+visit_loop_header_phi(nir_phi_instr *phi, nir_block *preheader, bool divergent_continue)
 {
-   bool has_changed = false;
-   bool repeat = true;
+   if (phi->dest.ssa.divergent)
+      return false;
 
-   /* TODO: restructure this and the phi handling more efficiently */
-   while (repeat) {
-      repeat = visit_cf_list(divergent, &loop->body, options, stage);
-      has_changed |= repeat;
+   nir_ssa_def* same = NULL;
+   nir_foreach_phi_src(src, phi) {
+      /* if any source value is divergent, the resulting value is divergent */
+      if (src->src.ssa->divergent) {
+         phi->dest.ssa.divergent = true;
+         return true;
+      }
+      /* if this loop is uniform, we're done here */
+      if (!divergent_continue)
+         continue;
+      /* skip the loop preheader */
+      if (src->pred == preheader)
+         continue;
+      /* skip undef values */
+      if (src->src.ssa->parent_instr->type == nir_instr_type_ssa_undef)
+         continue;
+
+      /* check if all loop-carried values are from the same ssa-def */
+      if (!same)
+         same = src->src.ssa;
+      else if (same != src->src.ssa) {
+         phi->dest.ssa.divergent = true;
+         return true;
+      }
    }
 
-   return has_changed;
+   return false;
+}
+
+/* There are 3 types of phi instructions:
+ * (3) eta: represent values that leave a loop.
+ *     The resulting value is divergent if the source value is divergent
+ *     or any loop exit condition is divergent for a value which is
+ *     not loop-invariant.
+ *     (note: there should be no phi for loop-invariant variables.) */
+static bool
+visit_loop_exit_phi(nir_phi_instr *phi, bool divergent_break)
+{
+   if (phi->dest.ssa.divergent)
+      return false;
+
+   if (divergent_break) {
+      phi->dest.ssa.divergent = true;
+      return true;
+   }
+
+   /* if any source value is divergent, the resulting value is divergent */
+   nir_foreach_phi_src(src, phi) {
+      if (src->src.ssa->divergent) {
+         phi->dest.ssa.divergent = true;
+         return true;
+      }
+   }
+
+   return false;
+}
+
+static bool
+visit_if(nir_if *if_stmt, struct divergence_state *state)
+{
+   bool progress = false;
+
+   struct divergence_state then_state = *state;
+   then_state.divergent_loop_cf |= if_stmt->condition.ssa->divergent;
+   progress |= visit_cf_list(&if_stmt->then_list, &then_state);
+
+   struct divergence_state else_state = *state;
+   else_state.divergent_loop_cf |= if_stmt->condition.ssa->divergent;
+   progress |= visit_cf_list(&if_stmt->else_list, &else_state);
+
+   /* handle phis after the IF */
+   nir_foreach_instr(instr, nir_cf_node_cf_tree_next(&if_stmt->cf_node)) {
+      if (instr->type != nir_instr_type_phi)
+         break;
+
+      if (state->first_visit)
+         nir_instr_as_phi(instr)->dest.ssa.divergent = false;
+      progress |= visit_if_merge_phi(nir_instr_as_phi(instr),
+                                     if_stmt->condition.ssa->divergent);
+   }
+
+   /* join loop divergence information from both branch legs */
+   state->divergent_loop_continue |= then_state.divergent_loop_continue ||
+                                     else_state.divergent_loop_continue;
+   state->divergent_loop_break |= then_state.divergent_loop_break ||
+                                  else_state.divergent_loop_break;
+
+   /* A divergent continue makes succeeding loop CF divergent:
+    * not all loop-active invocations participate in the remaining loop-body
+    * which means that a following break might be taken by some invocations, only */
+   state->divergent_loop_cf |= state->divergent_loop_continue;
+
+   return progress;
 }
 
 static bool
-visit_cf_list(bool *divergent, struct exec_list *list,
-              nir_divergence_options options, gl_shader_stage stage)
+visit_loop(nir_loop *loop, struct divergence_state *state)
+{
+   bool progress = false;
+   nir_block *loop_header = nir_loop_first_block(loop);
+   nir_block *loop_preheader = nir_block_cf_tree_prev(loop_header);
+
+   /* handle loop header phis first: we have no knowledge yet about
+    * the loop's control flow or any loop-carried sources. */
+   nir_foreach_instr(instr, loop_header) {
+      if (instr->type != nir_instr_type_phi)
+         break;
+
+      nir_phi_instr *phi = nir_instr_as_phi(instr);
+      if (!state->first_visit && phi->dest.ssa.divergent)
+         continue;
+
+      nir_foreach_phi_src(src, phi) {
+         if (src->pred == loop_preheader) {
+            phi->dest.ssa.divergent = src->src.ssa->divergent;
+            break;
+         }
+      }
+      progress |= phi->dest.ssa.divergent;
+   }
+
+   /* setup loop state */
+   struct divergence_state loop_state = *state;
+   loop_state.divergent_loop_cf = false;
+   loop_state.divergent_loop_continue = false;
+   loop_state.divergent_loop_break = false;
+
+   /* process loop body until no further changes are made */
+   bool repeat;
+   do {
+      progress |= visit_cf_list(&loop->body, &loop_state);
+      repeat = false;
+
+      /* revisit loop header phis to see if something has changed */
+      nir_foreach_instr(instr, loop_header) {
+         if (instr->type != nir_instr_type_phi)
+            break;
+
+         repeat |= visit_loop_header_phi(nir_instr_as_phi(instr),
+                                         loop_preheader,
+                                         loop_state.divergent_loop_continue);
+      }
+
+      loop_state.divergent_loop_cf = false;
+      loop_state.first_visit = false;
+   } while (repeat);
+
+   /* handle phis after the loop */
+   nir_foreach_instr(instr, nir_cf_node_cf_tree_next(&loop->cf_node)) {
+      if (instr->type != nir_instr_type_phi)
+         break;
+
+      if (state->first_visit)
+         nir_instr_as_phi(instr)->dest.ssa.divergent = false;
+      progress |= visit_loop_exit_phi(nir_instr_as_phi(instr),
+                                      loop_state.divergent_loop_break);
+   }
+
+   return progress;
+}
+
+static bool
+visit_cf_list(struct exec_list *list, struct divergence_state *state)
 {
    bool has_changed = false;
 
    foreach_list_typed(nir_cf_node, node, node, list) {
       switch (node->type) {
       case nir_cf_node_block:
-         has_changed |= visit_block(divergent, nir_cf_node_as_block(node),
-                                    options, stage);
+         has_changed |= visit_block(nir_cf_node_as_block(node), state);
          break;
       case nir_cf_node_if:
-         has_changed |= visit_if(divergent, nir_cf_node_as_if(node),
-                                 options, stage);
+         has_changed |= visit_if(nir_cf_node_as_if(node), state);
          break;
       case nir_cf_node_loop:
-         has_changed |= visit_loop(divergent, nir_cf_node_as_loop(node),
-                                   options, stage);
+         has_changed |= visit_loop(nir_cf_node_as_loop(node), state);
          break;
       case nir_cf_node_function:
          unreachable("NIR divergence analysis: Unsupported cf_node type.");
@@ -787,14 +883,18 @@ visit_cf_list(bool *divergent, struct exec_list *list,
    return has_changed;
 }
 
-
-bool*
+void
 nir_divergence_analysis(nir_shader *shader, nir_divergence_options options)
 {
-   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
-   bool *t = rzalloc_array(shader, bool, impl->ssa_alloc);
-
-   visit_cf_list(t, &impl->body, options, shader->info.stage);
-
-   return t;
+   struct divergence_state state = {
+      .options = options,
+      .stage = shader->info.stage,
+      .divergent_loop_cf = false,
+      .divergent_loop_continue = false,
+      .divergent_loop_break = false,
+      .first_visit = true,
+   };
+
+   visit_cf_list(&nir_shader_get_entrypoint(shader)->body, &state);
 }
+