nir: Add goto_if jump instruction
[mesa.git] / src / compiler / nir / nir_divergence_analysis.c
index b2f34b0dfb67ebefbb96a56995c3fbbcdb278c76..05892b440bfb508f0ac2a86d16c2447395f94bc6 100644 (file)
@@ -51,6 +51,9 @@ struct divergence_state {
    bool divergent_loop_continue;
    /* True if a divergent break happened since the loop header */
    bool divergent_loop_break;
+
+   /* True if we visit the block for the fist time */
+   bool first_visit;
 };
 
 static bool
@@ -142,6 +145,16 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
       else
          is_divergent = true;
       break;
+   case nir_intrinsic_load_per_vertex_input:
+      is_divergent = instr->src[0].ssa->divergent ||
+                     instr->src[1].ssa->divergent;
+      if (stage == MESA_SHADER_TESS_CTRL)
+         is_divergent |= !(options & nir_divergence_single_patch_per_tcs_subgroup);
+      if (stage == MESA_SHADER_TESS_EVAL)
+         is_divergent |= !(options & nir_divergence_single_patch_per_tes_subgroup);
+      else
+         is_divergent = true;
+      break;
    case nir_intrinsic_load_input_vertex:
       is_divergent = instr->src[1].ssa->divergent;
       assert(stage == MESA_SHADER_FRAGMENT);
@@ -155,6 +168,12 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
       else
          is_divergent = true;
       break;
+   case nir_intrinsic_load_per_vertex_output:
+      assert(stage == MESA_SHADER_TESS_CTRL);
+      is_divergent = instr->src[0].ssa->divergent ||
+                     instr->src[1].ssa->divergent ||
+                     !(options & nir_divergence_single_patch_per_tcs_subgroup);
+      break;
    case nir_intrinsic_load_layer_id:
    case nir_intrinsic_load_front_face:
       assert(stage == MESA_SHADER_FRAGMENT);
@@ -295,12 +314,11 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
    case nir_intrinsic_interp_deref_at_vertex:
    case nir_intrinsic_load_tess_coord:
    case nir_intrinsic_load_point_coord:
+   case nir_intrinsic_load_line_coord:
    case nir_intrinsic_load_frag_coord:
    case nir_intrinsic_load_sample_pos:
    case nir_intrinsic_load_vertex_id_zero_base:
    case nir_intrinsic_load_vertex_id:
-   case nir_intrinsic_load_per_vertex_input:
-   case nir_intrinsic_load_per_vertex_output:
    case nir_intrinsic_load_instance_id:
    case nir_intrinsic_load_invocation_id:
    case nir_intrinsic_load_local_invocation_id:
@@ -575,16 +593,35 @@ visit_jump(nir_jump_instr *jump, struct divergence_state *state)
       return state->divergent_loop_break;
    case nir_jump_return:
       unreachable("NIR divergence analysis: Unsupported return instruction.");
+      break;
+   case nir_jump_goto:
+   case nir_jump_goto_if:
+      unreachable("NIR divergence analysis: Unsupported goto_if instruction.");
+      break;
    }
    return false;
 }
 
+static bool
+set_ssa_def_not_divergent(nir_ssa_def *def, UNUSED void *_state)
+{
+   def->divergent = false;
+   return true;
+}
+
 static bool
 visit_block(nir_block *block, struct divergence_state *state)
 {
    bool has_changed = false;
 
    nir_foreach_instr(instr, block) {
+      /* phis are handled when processing the branches */
+      if (instr->type == nir_instr_type_phi)
+         continue;
+
+      if (state->first_visit)
+         nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
+
       switch (instr->type) {
       case nir_instr_type_alu:
          has_changed |= visit_alu(nir_instr_as_alu(instr));
@@ -607,9 +644,7 @@ visit_block(nir_block *block, struct divergence_state *state)
       case nir_instr_type_jump:
          has_changed |= visit_jump(nir_instr_as_jump(instr), state);
          break;
-      /* phis are handled when processing the branches */
       case nir_instr_type_phi:
-         break;
       case nir_instr_type_call:
       case nir_instr_type_parallel_copy:
          unreachable("NIR divergence analysis: Unsupported instruction type.");
@@ -738,6 +773,8 @@ visit_if(nir_if *if_stmt, struct divergence_state *state)
       if (instr->type != nir_instr_type_phi)
          break;
 
+      if (state->first_visit)
+         nir_instr_as_phi(instr)->dest.ssa.divergent = false;
       progress |= visit_if_merge_phi(nir_instr_as_phi(instr),
                                      if_stmt->condition.ssa->divergent);
    }
@@ -763,14 +800,23 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
    nir_block *loop_header = nir_loop_first_block(loop);
    nir_block *loop_preheader = nir_block_cf_tree_prev(loop_header);
 
-   /* handle loop header phis first */
+   /* handle loop header phis first: we have no knowledge yet about
+    * the loop's control flow or any loop-carried sources. */
    nir_foreach_instr(instr, loop_header) {
       if (instr->type != nir_instr_type_phi)
          break;
 
-      progress |= visit_loop_header_phi(nir_instr_as_phi(instr),
-                                        loop_preheader, false);
+      nir_phi_instr *phi = nir_instr_as_phi(instr);
+      if (!state->first_visit && phi->dest.ssa.divergent)
+         continue;
 
+      nir_foreach_phi_src(src, phi) {
+         if (src->pred == loop_preheader) {
+            phi->dest.ssa.divergent = src->src.ssa->divergent;
+            break;
+         }
+      }
+      progress |= phi->dest.ssa.divergent;
    }
 
    /* setup loop state */
@@ -796,6 +842,7 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
       }
 
       loop_state.divergent_loop_cf = false;
+      loop_state.first_visit = false;
    } while (repeat);
 
    /* handle phis after the loop */
@@ -803,6 +850,8 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
       if (instr->type != nir_instr_type_phi)
          break;
 
+      if (state->first_visit)
+         nir_instr_as_phi(instr)->dest.ssa.divergent = false;
       progress |= visit_loop_exit_phi(nir_instr_as_phi(instr),
                                       loop_state.divergent_loop_break);
    }
@@ -834,32 +883,18 @@ visit_cf_list(struct exec_list *list, struct divergence_state *state)
    return has_changed;
 }
 
-static bool
-set_ssa_def_not_divergent(nir_ssa_def *def, UNUSED void *_state)
-{
-   def->divergent = false;
-   return true;
-}
-
 void
 nir_divergence_analysis(nir_shader *shader, nir_divergence_options options)
 {
-   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
-
-   /* Set all SSA defs to non-divergent to start off */
-   nir_foreach_block(block, impl) {
-      nir_foreach_instr(instr, block)
-         nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
-   }
-
    struct divergence_state state = {
       .options = options,
       .stage = shader->info.stage,
       .divergent_loop_cf = false,
       .divergent_loop_continue = false,
       .divergent_loop_break = false,
+      .first_visit = true,
    };
 
-   visit_cf_list(&impl->body, &state);
+   visit_cf_list(&nir_shader_get_entrypoint(shader)->body, &state);
 }