nir: Make "divergent" a property of an SSA value
authorJason Ekstrand <jason@jlekstrand.net>
Tue, 15 Oct 2019 19:48:10 +0000 (14:48 -0500)
committerMarge Bot <eric+marge@anholt.net>
Wed, 13 May 2020 18:49:22 +0000 (18:49 +0000)
v2: fix usage in ACO (by Daniel Schürmann)

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4062>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/compiler/nir/nir.c
src/compiler/nir/nir.h
src/compiler/nir/nir_divergence_analysis.c
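
Context for the diff below: the commit drops the separately allocated `bool *divergent_vals` array (indexed by SSA index) and instead stores a `divergent` flag on each `nir_ssa_def`, queried through the new `nir_src_is_divergent()` / `nir_dest_is_divergent()` helpers after `nir_divergence_analysis()` has run. The following minimal sketch (not part of this commit) shows how a backend might consume the analysis after this change; the function name `pick_reg_types` and the loop body are hypothetical, while the NIR calls mirror the API introduced in this diff.

/* Hypothetical consumer of the new per-SSA-value divergence flag. */
#include "nir.h"

static void
pick_reg_types(nir_shader *shader)
{
   /* Divergence is now written directly into each nir_ssa_def... */
   nir_divergence_analysis(shader, nir_divergence_view_index_uniform);

   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_foreach_block(block, impl) {
      nir_foreach_instr(instr, block) {
         if (instr->type != nir_instr_type_alu)
            continue;
         nir_alu_instr *alu = nir_instr_as_alu(instr);

         /* ...so callers query the flag via the new helpers instead of
          * indexing a backend-owned bool array. */
         bool needs_vgpr = nir_dest_is_divergent(alu->dest.dest);
         (void)needs_vgpr; /* a real backend would record vgpr/sgpr here */
      }
   }
}
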

src/amd/compiler/aco_instruction_selection.cpp
index 65cf38f57e9851e0783db430f324906a65e2484d..fd018622a4b171a4110205de698f6160396129fd 100644
@@ -711,9 +711,8 @@ void emit_comparison(isel_context *ctx, nir_alu_instr *instr, Temp dst,
 {
    aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : instr->src[0].src.ssa->bit_size == 32 ? s32_op : aco_opcode::num_opcodes;
    aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : instr->src[0].src.ssa->bit_size == 32 ? v32_op : v16_op;
-   bool divergent_vals = ctx->divergent_vals[instr->dest.dest.ssa.index];
    bool use_valu = s_op == aco_opcode::num_opcodes ||
-                   divergent_vals ||
+                   nir_dest_is_divergent(instr->dest.dest) ||
                    ctx->allocated[instr->src[0].src.ssa->index].type() == RegType::vgpr ||
                    ctx->allocated[instr->src[1].src.ssa->index].type() == RegType::vgpr;
    aco_opcode op = use_valu ? v_op : s_op;
@@ -779,7 +778,7 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
       assert(els.regClass() == bld.lm);
    }
 
-   if (!ctx->divergent_vals[instr->src[0].src.ssa->index]) { /* uniform condition and values in sgpr */
+   if (!nir_src_is_divergent(instr->src[0].src)) { /* uniform condition and values in sgpr */
       if (dst.regClass() == s1 || dst.regClass() == s2) {
          assert((then.regClass() == s1 || then.regClass() == s2) && els.regClass() == then.regClass());
          assert(dst.size() == then.size());
@@ -5010,7 +5009,7 @@ void visit_load_resource(isel_context *ctx, nir_intrinsic_instr *instr)
 {
    Builder bld(ctx->program, ctx->block);
    Temp index = get_ssa_temp(ctx, instr->src[0].ssa);
-   if (!ctx->divergent_vals[instr->dest.ssa.index])
+   if (!nir_dest_is_divergent(instr->dest))
       index = bld.as_uniform(index);
    unsigned desc_set = nir_intrinsic_desc_set(instr);
    unsigned binding = nir_intrinsic_binding(instr);
@@ -6086,7 +6085,7 @@ void visit_store_ssbo(isel_context *ctx, nir_intrinsic_instr *instr)
    Temp rsrc = convert_pointer_to_64_bit(ctx, get_ssa_temp(ctx, instr->src[1].ssa));
    rsrc = bld.smem(aco_opcode::s_load_dwordx4, bld.def(s4), rsrc, Operand(0u));
 
-   bool smem = !ctx->divergent_vals[instr->src[2].ssa->index] &&
+   bool smem = !nir_src_is_divergent(instr->src[2]) &&
                ctx->options->chip_class >= GFX8 &&
                elem_size_bytes >= 4;
    if (smem)
@@ -7477,11 +7476,11 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
    case nir_intrinsic_shuffle:
    case nir_intrinsic_read_invocation: {
       Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
-      if (!ctx->divergent_vals[instr->src[0].ssa->index]) {
+      if (!nir_src_is_divergent(instr->src[0])) {
          emit_uniform_subgroup(ctx, instr, src);
       } else {
          Temp tid = get_ssa_temp(ctx, instr->src[1].ssa);
-         if (instr->intrinsic == nir_intrinsic_read_invocation || !ctx->divergent_vals[instr->src[1].ssa->index])
+         if (instr->intrinsic == nir_intrinsic_read_invocation || !nir_src_is_divergent(instr->src[1]))
             tid = bld.as_uniform(tid);
          Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
          if (src.regClass() == v1) {
@@ -7587,7 +7586,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
          nir_intrinsic_cluster_size(instr) : 0;
       cluster_size = util_next_power_of_two(MIN2(cluster_size ? cluster_size : ctx->program->wave_size, ctx->program->wave_size));
 
-      if (!ctx->divergent_vals[instr->src[0].ssa->index] && (op == nir_op_ior || op == nir_op_iand)) {
+      if (!nir_src_is_divergent(instr->src[0]) && (op == nir_op_ior || op == nir_op_iand)) {
          emit_uniform_subgroup(ctx, instr, src);
       } else if (instr->dest.ssa.bit_size == 1) {
          if (op == nir_op_imul || op == nir_op_umin || op == nir_op_imin)
@@ -7670,7 +7669,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
    }
    case nir_intrinsic_quad_broadcast: {
       Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
-      if (!ctx->divergent_vals[instr->dest.ssa.index]) {
+      if (!nir_dest_is_divergent(instr->dest)) {
          emit_uniform_subgroup(ctx, instr, src);
       } else {
          Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
@@ -7717,7 +7716,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
    case nir_intrinsic_quad_swap_diagonal:
    case nir_intrinsic_quad_swizzle_amd: {
       Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
-      if (!ctx->divergent_vals[instr->dest.ssa.index]) {
+      if (!nir_dest_is_divergent(instr->dest)) {
          emit_uniform_subgroup(ctx, instr, src);
          break;
       }
@@ -7779,7 +7778,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
    }
    case nir_intrinsic_masked_swizzle_amd: {
       Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
-      if (!ctx->divergent_vals[instr->dest.ssa.index]) {
+      if (!nir_dest_is_divergent(instr->dest)) {
          emit_uniform_subgroup(ctx, instr, src);
          break;
       }
@@ -8774,7 +8773,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
    Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
    assert(instr->dest.ssa.bit_size != 1 || dst.regClass() == ctx->program->lane_mask);
 
-   bool logical = !dst.is_linear() || ctx->divergent_vals[instr->dest.ssa.index];
+   bool logical = !dst.is_linear() || nir_dest_is_divergent(instr->dest);
    logical |= ctx->block->kind & block_kind_merge;
    aco_opcode opcode = logical ? aco_opcode::p_phi : aco_opcode::p_linear_phi;
 
@@ -9468,7 +9467,7 @@ static bool visit_if(isel_context *ctx, nir_if *if_stmt)
    aco_ptr<Pseudo_branch_instruction> branch;
    if_context ic;
 
-   if (!ctx->divergent_vals[if_stmt->condition.ssa->index]) { /* uniform condition */
+   if (!nir_src_is_divergent(if_stmt->condition)) { /* uniform condition */
       /**
        * Uniform conditionals are represented in the following way*) :
        *
@@ -10683,8 +10682,6 @@ void select_program(Program *program,
       if (ngg_no_gs && !ngg_early_prim_export(&ctx))
          ngg_emit_nogs_output(&ctx);
 
-      ralloc_free(ctx.divergent_vals);
-
       if (i == 0 && ctx.stage == vertex_tess_control_hs && ctx.tcs_in_out_eq) {
          /* Outputs of the previous stage are inputs to the next stage */
          ctx.inputs = ctx.outputs;
src/amd/compiler/aco_instruction_selection_setup.cpp
index 44659b462294cf40f293a580d5f58100d8851180..90a92232343f706c3c8f3742ac65b13d96af4fd7 100644
@@ -57,7 +57,6 @@ struct isel_context {
    nir_shader *shader;
    uint32_t constant_data_offset;
    Block *block;
-   bool *divergent_vals;
    std::unique_ptr<Temp[]> allocated;
    std::unordered_map<unsigned, std::array<Temp,NIR_MAX_VEC_COMPONENTS>> allocated_vec;
    Stage stage; /* Stage */
@@ -152,7 +151,7 @@ unsigned get_interp_input(nir_intrinsic_op intrin, enum glsl_interp_mode interp)
  * block instead. This is so that we can use any SGPR live-out of the side
  * without the branch without creating a linear phi in the invert or merge block. */
 bool
-sanitize_if(nir_function_impl *impl, bool *divergent, nir_if *nif)
+sanitize_if(nir_function_impl *impl, nir_if *nif)
 {
    //TODO: skip this if the condition is uniform and there are no divergent breaks/continues?
 
@@ -197,7 +196,7 @@ sanitize_if(nir_function_impl *impl, bool *divergent, nir_if *nif)
 }
 
 bool
-sanitize_cf_list(nir_function_impl *impl, bool *divergent, struct exec_list *cf_list)
+sanitize_cf_list(nir_function_impl *impl, struct exec_list *cf_list)
 {
    bool progress = false;
    foreach_list_typed(nir_cf_node, cf_node, node, cf_list) {
@@ -206,14 +205,14 @@ sanitize_cf_list(nir_function_impl *impl, bool *divergent, struct exec_list *cf_
          break;
       case nir_cf_node_if: {
          nir_if *nif = nir_cf_node_as_if(cf_node);
-         progress |= sanitize_cf_list(impl, divergent, &nif->then_list);
-         progress |= sanitize_cf_list(impl, divergent, &nif->else_list);
-         progress |= sanitize_if(impl, divergent, nif);
+         progress |= sanitize_cf_list(impl, &nif->then_list);
+         progress |= sanitize_cf_list(impl, &nif->else_list);
+         progress |= sanitize_if(impl, nif);
          break;
       }
       case nir_cf_node_loop: {
          nir_loop *loop = nir_cf_node_as_loop(cf_node);
-         progress |= sanitize_cf_list(impl, divergent, &loop->body);
+         progress |= sanitize_cf_list(impl, &loop->body);
          break;
       }
       case nir_cf_node_function:
@@ -238,11 +237,11 @@ void init_context(isel_context *ctx, nir_shader *shader)
    unsigned lane_mask_size = ctx->program->lane_mask.size();
 
    ctx->shader = shader;
-   ctx->divergent_vals = nir_divergence_analysis(shader, nir_divergence_view_index_uniform);
+   nir_divergence_analysis(shader, nir_divergence_view_index_uniform);
 
    /* sanitize control flow */
    nir_metadata_require(impl, nir_metadata_dominance);
-   sanitize_cf_list(impl, ctx->divergent_vals, &impl->body);
+   sanitize_cf_list(impl, &impl->body);
    nir_metadata_preserve(impl, (nir_metadata)~nir_metadata_block_index);
 
    /* we'll need this for isel */
@@ -332,10 +331,10 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_op_b2f16:
                   case nir_op_b2f32:
                   case nir_op_mov:
-                     type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr;
+                     type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr;
                      break;
                   case nir_op_bcsel:
-                     type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr;
+                     type = nir_dest_is_divergent(alu_instr->dest.dest) ? RegType::vgpr : RegType::sgpr;
                      /* fallthrough */
                   default:
                      for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
@@ -465,7 +464,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_intrinsic_load_global:
                   case nir_intrinsic_vulkan_resource_index:
                   case nir_intrinsic_load_shared:
-                     type = ctx->divergent_vals[intrinsic->dest.ssa.index] ? RegType::vgpr : RegType::sgpr;
+                     type = nir_dest_is_divergent(intrinsic->dest) ? RegType::vgpr : RegType::sgpr;
                      break;
                   case nir_intrinsic_load_view_index:
                      type = ctx->stage == fragment_fs ? RegType::vgpr : RegType::sgpr;
@@ -524,9 +523,10 @@ void init_context(isel_context *ctx, nir_shader *shader)
 
                if (tex->dest.ssa.bit_size == 64)
                   size *= 2;
-               if (tex->op == nir_texop_texture_samples)
-                  assert(!ctx->divergent_vals[tex->dest.ssa.index]);
-               if (ctx->divergent_vals[tex->dest.ssa.index])
+               if (tex->op == nir_texop_texture_samples) {
+                  assert(!tex->dest.ssa.divergent);
+               }
+               if (nir_dest_is_divergent(tex->dest))
                   allocated[tex->dest.ssa.index] = Temp(0, RegClass(RegType::vgpr, size));
                else
                   allocated[tex->dest.ssa.index] = Temp(0, RegClass(RegType::sgpr, size));
@@ -558,7 +558,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   break;
                }
 
-               if (ctx->divergent_vals[phi->dest.ssa.index]) {
+               if (nir_dest_is_divergent(phi->dest)) {
                   type = RegType::vgpr;
                } else {
                   type = RegType::sgpr;
src/compiler/nir/nir.c
index e3569d1f80b2fed50ff095f78f2d4788c1197e9a..0f64c458535f0ffd6595db05ab50e9fa200040c7 100644
@@ -1451,6 +1451,7 @@ nir_ssa_def_init(nir_instr *instr, nir_ssa_def *def,
    list_inithead(&def->if_uses);
    def->num_components = num_components;
    def->bit_size = bit_size;
+   def->divergent = true; /* This is the safer default */
 
    if (instr->block) {
       nir_function_impl *impl =
src/compiler/nir/nir.h
index 25caa370f13277b32d06f237cbb2133a5f3db69d..281422fc6cf2ecf4361be9d3e115d59dce0e043f 100644
@@ -740,6 +740,12 @@ typedef struct nir_ssa_def {
 
    /* The bit-size of each channel; must be one of 8, 16, 32, or 64 */
    uint8_t bit_size;
+
+   /**
+    * True if this SSA value may have different values in different SIMD
+    * invocations of the shader.  This is set by nir_divergence_analysis.
+    */
+   bool divergent;
 } nir_ssa_def;
 
 struct nir_src;
@@ -880,6 +886,13 @@ nir_src_is_const(nir_src src)
           src.ssa->parent_instr->type == nir_instr_type_load_const;
 }
 
+static inline bool
+nir_src_is_divergent(nir_src src)
+{
+   assert(src.is_ssa);
+   return src.ssa->divergent;
+}
+
 static inline unsigned
 nir_dest_bit_size(nir_dest dest)
 {
@@ -892,6 +905,13 @@ nir_dest_num_components(nir_dest dest)
    return dest.is_ssa ? dest.ssa.num_components : dest.reg.reg->num_components;
 }
 
+static inline bool
+nir_dest_is_divergent(nir_dest dest)
+{
+   assert(dest.is_ssa);
+   return dest.ssa.divergent;
+}
+
 /* Are all components the same, ie. .xxxx */
 static inline bool
 nir_is_same_comp_swizzle(uint8_t *swiz, unsigned nr_comp)
@@ -4321,7 +4341,7 @@ bool nir_repair_ssa(nir_shader *shader);
 
 void nir_convert_loop_to_lcssa(nir_loop *loop);
 bool nir_convert_to_lcssa(nir_shader *shader, bool skip_invariants, bool skip_bool_invariants);
-bool* nir_divergence_analysis(nir_shader *shader, nir_divergence_options options);
+void nir_divergence_analysis(nir_shader *shader, nir_divergence_options options);
 
 /* If phi_webs_only is true, only convert SSA values involved in phi nodes to
  * registers.  If false, convert all values (even those not involved in a phi
src/compiler/nir/nir_divergence_analysis.c
index bfa12d1a57d1fb940fd684f7d094439f305f0f38..03201e0faf8ec21eb04010e779d164b37fb4e644 100644
  */
 
 static bool
-visit_cf_list(bool *divergent, struct exec_list *list,
+visit_cf_list(struct exec_list *list,
               nir_divergence_options options, gl_shader_stage stage);
 
 static bool
-visit_alu(bool *divergent, nir_alu_instr *instr)
+visit_alu(nir_alu_instr *instr)
 {
-   if (divergent[instr->dest.dest.ssa.index])
+   if (instr->dest.dest.ssa.divergent)
       return false;
 
    unsigned num_src = nir_op_infos[instr->op].num_inputs;
 
    for (unsigned i = 0; i < num_src; i++) {
-      if (divergent[instr->src[i].src.ssa->index]) {
-         divergent[instr->dest.dest.ssa.index] = true;
+      if (instr->src[i].src.ssa->divergent) {
+         instr->dest.dest.ssa.divergent = true;
          return true;
       }
    }
@@ -59,13 +59,13 @@ visit_alu(bool *divergent, nir_alu_instr *instr)
 }
 
 static bool
-visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
+visit_intrinsic(nir_intrinsic_instr *instr,
                 nir_divergence_options options, gl_shader_stage stage)
 {
    if (!nir_intrinsic_infos[instr->intrinsic].has_dest)
       return false;
 
-   if (divergent[instr->dest.ssa.index])
+   if (instr->dest.ssa.divergent)
       return false;
 
    bool is_divergent = false;
@@ -117,7 +117,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
 
    /* Intrinsics with divergence depending on shader stage and hardware */
    case nir_intrinsic_load_input:
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       if (stage == MESA_SHADER_FRAGMENT)
          is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
       else if (stage == MESA_SHADER_TESS_EVAL)
@@ -126,13 +126,13 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
          is_divergent = true;
       break;
    case nir_intrinsic_load_input_vertex:
-      is_divergent = divergent[instr->src[1].ssa->index];
+      is_divergent = instr->src[1].ssa->divergent;
       assert(stage == MESA_SHADER_FRAGMENT);
       is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
       break;
    case nir_intrinsic_load_output:
       assert(stage == MESA_SHADER_TESS_CTRL || stage == MESA_SHADER_FRAGMENT);
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       if (stage == MESA_SHADER_TESS_CTRL)
          is_divergent |= !(options & nir_divergence_single_patch_per_tcs_subgroup);
       else
@@ -152,7 +152,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
       break;
    case nir_intrinsic_load_fs_input_interp_deltas:
       assert(stage == MESA_SHADER_FRAGMENT);
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
       break;
    case nir_intrinsic_load_primitive_id:
@@ -194,7 +194,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
       /* fallthrough */
    case nir_intrinsic_inclusive_scan: {
       nir_op op = nir_intrinsic_reduction_op(instr);
-      is_divergent = divergent[instr->src[0].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent;
       if (op != nir_op_umin && op != nir_op_imin && op != nir_op_fmin &&
           op != nir_op_umax && op != nir_op_imax && op != nir_op_fmax &&
           op != nir_op_iand && op != nir_op_ior)
@@ -245,7 +245,7 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
    case nir_intrinsic_masked_swizzle_amd: {
       unsigned num_srcs = nir_intrinsic_infos[instr->intrinsic].num_srcs;
       for (unsigned i = 0; i < num_srcs; i++) {
-         if (divergent[instr->src[i].ssa->index]) {
+         if (instr->src[i].ssa->divergent) {
             is_divergent = true;
             break;
          }
@@ -254,8 +254,8 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
    }
 
    case nir_intrinsic_shuffle:
-      is_divergent = divergent[instr->src[0].ssa->index] &&
-                     divergent[instr->src[1].ssa->index];
+      is_divergent = instr->src[0].ssa->divergent &&
+                     instr->src[1].ssa->divergent;
       break;
 
    /* Intrinsics which are always divergent */
@@ -423,14 +423,14 @@ visit_intrinsic(bool *divergent, nir_intrinsic_instr *instr,
 #endif
    }
 
-   divergent[instr->dest.ssa.index] = is_divergent;
+   instr->dest.ssa.divergent = is_divergent;
    return is_divergent;
 }
 
 static bool
-visit_tex(bool *divergent, nir_tex_instr *instr)
+visit_tex(nir_tex_instr *instr)
 {
-   if (divergent[instr->dest.ssa.index])
+   if (instr->dest.ssa.divergent)
       return false;
 
    bool is_divergent = false;
@@ -440,27 +440,27 @@ visit_tex(bool *divergent, nir_tex_instr *instr)
       case nir_tex_src_sampler_deref:
       case nir_tex_src_sampler_handle:
       case nir_tex_src_sampler_offset:
-         is_divergent |= divergent[instr->src[i].src.ssa->index] &&
+         is_divergent |= instr->src[i].src.ssa->divergent &&
                          instr->sampler_non_uniform;
          break;
       case nir_tex_src_texture_deref:
       case nir_tex_src_texture_handle:
       case nir_tex_src_texture_offset:
-         is_divergent |= divergent[instr->src[i].src.ssa->index] &&
+         is_divergent |= instr->src[i].src.ssa->divergent &&
                          instr->texture_non_uniform;
          break;
       default:
-         is_divergent |= divergent[instr->src[i].src.ssa->index];
+         is_divergent |= instr->src[i].src.ssa->divergent;
          break;
       }
    }
 
-   divergent[instr->dest.ssa.index] = is_divergent;
+   instr->dest.ssa.divergent = is_divergent;
    return is_divergent;
 }
 
 static bool
-visit_phi(bool *divergent, nir_phi_instr *instr)
+visit_phi(nir_phi_instr *instr)
 {
    /* There are 3 types of phi instructions:
     * (1) gamma: represent the joining point of different paths
@@ -481,13 +481,13 @@ visit_phi(bool *divergent, nir_phi_instr *instr)
     *     (note: there should be no phi for loop-invariant variables.)
     */
 
-   if (divergent[instr->dest.ssa.index])
+   if (instr->dest.ssa.divergent)
       return false;
 
    nir_foreach_phi_src(src, instr) {
       /* if any source value is divergent, the resulting value is divergent */
-      if (divergent[src->src.ssa->index]) {
-         divergent[instr->dest.ssa.index] = true;
+      if (src->src.ssa->divergent) {
+         instr->dest.ssa.divergent = true;
          return true;
       }
    }
@@ -537,8 +537,8 @@ visit_phi(bool *divergent, nir_phi_instr *instr)
          while (current->type != nir_cf_node_loop) {
             assert (current->type == nir_cf_node_if);
             nir_if *if_node = nir_cf_node_as_if(current);
-            if (divergent[if_node->condition.ssa->index]) {
-               divergent[instr->dest.ssa.index] = true;
+            if (if_node->condition.ssa->divergent) {
+               instr->dest.ssa.divergent = true;
                return true;
             }
             current = current->parent;
@@ -558,8 +558,8 @@ visit_phi(bool *divergent, nir_phi_instr *instr)
 
       /* gamma: check if the condition is divergent */
       nir_if *if_node = nir_cf_node_as_if(prev);
-      if (divergent[if_node->condition.ssa->index]) {
-         divergent[instr->dest.ssa.index] = true;
+      if (if_node->condition.ssa->divergent) {
+         instr->dest.ssa.divergent = true;
          return true;
       }
 
@@ -578,8 +578,8 @@ visit_phi(bool *divergent, nir_phi_instr *instr)
          while (current->type != nir_cf_node_loop) {
             assert(current->type == nir_cf_node_if);
             nir_if *if_node = nir_cf_node_as_if(current);
-            if (divergent[if_node->condition.ssa->index]) {
-               divergent[instr->dest.ssa.index] = true;
+            if (if_node->condition.ssa->divergent) {
+               instr->dest.ssa.divergent = true;
                return true;
             }
             current = current->parent;
@@ -607,12 +607,12 @@ visit_phi(bool *divergent, nir_phi_instr *instr)
                }
                assert(current->type == nir_cf_node_if);
                nir_if *if_node = nir_cf_node_as_if(current);
-               is_divergent |= divergent[if_node->condition.ssa->index];
+               is_divergent |= if_node->condition.ssa->divergent;
                current = current->parent;
             }
 
             if (is_divergent) {
-               divergent[instr->dest.ssa.index] = true;
+               instr->dest.ssa.divergent = true;
                return true;
             }
          }
@@ -623,13 +623,13 @@ visit_phi(bool *divergent, nir_phi_instr *instr)
 }
 
 static bool
-visit_load_const(bool *divergent, nir_load_const_instr *instr)
+visit_load_const(nir_load_const_instr *instr)
 {
    return false;
 }
 
 static bool
-visit_ssa_undef(bool *divergent, nir_ssa_undef_instr *instr)
+visit_ssa_undef(nir_ssa_undef_instr *instr)
 {
    return false;
 }
@@ -675,10 +675,10 @@ nir_variable_is_uniform(nir_variable *var, nir_divergence_options options,
 }
 
 static bool
-visit_deref(bool *divergent, nir_deref_instr *deref,
+visit_deref(nir_deref_instr *deref,
             nir_divergence_options options, gl_shader_stage stage)
 {
-   if (divergent[deref->dest.ssa.index])
+   if (deref->dest.ssa.divergent)
       return false;
 
    bool is_divergent = false;
@@ -688,24 +688,24 @@ visit_deref(bool *divergent, nir_deref_instr *deref,
       break;
    case nir_deref_type_array:
    case nir_deref_type_ptr_as_array:
-      is_divergent = divergent[deref->arr.index.ssa->index];
+      is_divergent = deref->arr.index.ssa->divergent;
       /* fallthrough */
    case nir_deref_type_struct:
    case nir_deref_type_array_wildcard:
-      is_divergent |= divergent[deref->parent.ssa->index];
+      is_divergent |= deref->parent.ssa->divergent;
       break;
    case nir_deref_type_cast:
       is_divergent = !nir_variable_mode_is_uniform(deref->var->data.mode) ||
-                     divergent[deref->parent.ssa->index];
+                     deref->parent.ssa->divergent;
       break;
    }
 
-   divergent[deref->dest.ssa.index] = is_divergent;
+   deref->dest.ssa.divergent = is_divergent;
    return is_divergent;
 }
 
 static bool
-visit_block(bool *divergent, nir_block *block, nir_divergence_options options,
+visit_block(nir_block *block, nir_divergence_options options,
             gl_shader_stage stage)
 {
    bool has_changed = false;
@@ -713,26 +713,26 @@ visit_block(bool *divergent, nir_block *block, nir_divergence_options options,
    nir_foreach_instr(instr, block) {
       switch (instr->type) {
       case nir_instr_type_alu:
-         has_changed |= visit_alu(divergent, nir_instr_as_alu(instr));
+         has_changed |= visit_alu(nir_instr_as_alu(instr));
          break;
       case nir_instr_type_intrinsic:
-         has_changed |= visit_intrinsic(divergent, nir_instr_as_intrinsic(instr),
+         has_changed |= visit_intrinsic(nir_instr_as_intrinsic(instr),
                                         options, stage);
          break;
       case nir_instr_type_tex:
-         has_changed |= visit_tex(divergent, nir_instr_as_tex(instr));
+         has_changed |= visit_tex(nir_instr_as_tex(instr));
          break;
       case nir_instr_type_phi:
-         has_changed |= visit_phi(divergent, nir_instr_as_phi(instr));
+         has_changed |= visit_phi(nir_instr_as_phi(instr));
          break;
       case nir_instr_type_load_const:
-         has_changed |= visit_load_const(divergent, nir_instr_as_load_const(instr));
+         has_changed |= visit_load_const(nir_instr_as_load_const(instr));
          break;
       case nir_instr_type_ssa_undef:
-         has_changed |= visit_ssa_undef(divergent, nir_instr_as_ssa_undef(instr));
+         has_changed |= visit_ssa_undef(nir_instr_as_ssa_undef(instr));
          break;
       case nir_instr_type_deref:
-         has_changed |= visit_deref(divergent, nir_instr_as_deref(instr),
+         has_changed |= visit_deref(nir_instr_as_deref(instr),
                                     options, stage);
          break;
       case nir_instr_type_jump:
@@ -747,21 +747,21 @@ visit_block(bool *divergent, nir_block *block, nir_divergence_options options,
 }
 
 static bool
-visit_if(bool *divergent, nir_if *if_stmt, nir_divergence_options options, gl_shader_stage stage)
+visit_if(nir_if *if_stmt, nir_divergence_options options, gl_shader_stage stage)
 {
-   return visit_cf_list(divergent, &if_stmt->then_list, options, stage) |
-          visit_cf_list(divergent, &if_stmt->else_list, options, stage);
+   return visit_cf_list(&if_stmt->then_list, options, stage) |
+          visit_cf_list(&if_stmt->else_list, options, stage);
 }
 
 static bool
-visit_loop(bool *divergent, nir_loop *loop, nir_divergence_options options, gl_shader_stage stage)
+visit_loop(nir_loop *loop, nir_divergence_options options, gl_shader_stage stage)
 {
    bool has_changed = false;
    bool repeat = true;
 
    /* TODO: restructure this and the phi handling more efficiently */
    while (repeat) {
-      repeat = visit_cf_list(divergent, &loop->body, options, stage);
+      repeat = visit_cf_list(&loop->body, options, stage);
       has_changed |= repeat;
    }
 
@@ -769,7 +769,7 @@ visit_loop(bool *divergent, nir_loop *loop, nir_divergence_options options, gl_s
 }
 
 static bool
-visit_cf_list(bool *divergent, struct exec_list *list,
+visit_cf_list(struct exec_list *list,
               nir_divergence_options options, gl_shader_stage stage)
 {
    bool has_changed = false;
@@ -777,15 +777,15 @@ visit_cf_list(bool *divergent, struct exec_list *list,
    foreach_list_typed(nir_cf_node, node, node, list) {
       switch (node->type) {
       case nir_cf_node_block:
-         has_changed |= visit_block(divergent, nir_cf_node_as_block(node),
+         has_changed |= visit_block(nir_cf_node_as_block(node),
                                     options, stage);
          break;
       case nir_cf_node_if:
-         has_changed |= visit_if(divergent, nir_cf_node_as_if(node),
+         has_changed |= visit_if(nir_cf_node_as_if(node),
                                  options, stage);
          break;
       case nir_cf_node_loop:
-         has_changed |= visit_loop(divergent, nir_cf_node_as_loop(node),
+         has_changed |= visit_loop(nir_cf_node_as_loop(node),
                                    options, stage);
          break;
       case nir_cf_node_function:
@@ -796,14 +796,23 @@ visit_cf_list(bool *divergent, struct exec_list *list,
    return has_changed;
 }
 
+static bool
+set_ssa_def_not_divergent(nir_ssa_def *def, UNUSED void *_state)
+{
+   def->divergent = false;
+   return true;
+}
 
-bool*
+void
 nir_divergence_analysis(nir_shader *shader, nir_divergence_options options)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
-   bool *t = rzalloc_array(shader, bool, impl->ssa_alloc);
 
-   visit_cf_list(t, &impl->body, options, shader->info.stage);
+   /* Set all SSA defs to non-divergent to start off */
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr(instr, block)
+         nir_foreach_ssa_def(instr, set_ssa_def_not_divergent, NULL);
+   }
 
-   return t;
+   visit_cf_list(&impl->body, options, shader->info.stage);
 }