radv: Implement buffer stores with less than 4 components.
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index d79ff4399b1fe42f9071393b331f479e9aa28b45..8dea35178b36827195ba46c2e246384aae4d3982 100644 (file)
@@ -270,8 +270,9 @@ static LLVMValueRef emit_bcsel(struct ac_llvm_context *ctx,
 {
        LLVMValueRef v = LLVMBuildICmp(ctx->builder, LLVMIntNE, src0,
                                       ctx->i32_0, "");
-       return LLVMBuildSelect(ctx->builder, v, ac_to_integer(ctx, src1),
-                              ac_to_integer(ctx, src2), "");
+       return LLVMBuildSelect(ctx->builder, v,
+                              ac_to_integer_or_pointer(ctx, src1),
+                              ac_to_integer_or_pointer(ctx, src2), "");
 }
 
 static LLVMValueRef emit_minmax_int(struct ac_llvm_context *ctx,
@@ -311,9 +312,18 @@ static LLVMValueRef emit_uint_carry(struct ac_llvm_context *ctx,
 }
 
 static LLVMValueRef emit_b2f(struct ac_llvm_context *ctx,
-                            LLVMValueRef src0)
+                            LLVMValueRef src0,
+                            unsigned bitsize)
 {
-       return LLVMBuildAnd(ctx->builder, src0, LLVMBuildBitCast(ctx->builder, LLVMConstReal(ctx->f32, 1.0), ctx->i32, ""), "");
+       LLVMValueRef result = LLVMBuildAnd(ctx->builder, src0,
+                                          LLVMBuildBitCast(ctx->builder, LLVMConstReal(ctx->f32, 1.0), ctx->i32, ""),
+                                          "");
+       result = LLVMBuildBitCast(ctx->builder, result, ctx->f32, "");
+
+       if (bitsize == 32)
+               return result;
+
+       return LLVMBuildFPExt(ctx->builder, result, ctx->f64, "");
 }
 
 static LLVMValueRef emit_f2b(struct ac_llvm_context *ctx,
@@ -419,12 +429,12 @@ static LLVMValueRef emit_bitfield_extract(struct ac_llvm_context *ctx,
 {
        LLVMValueRef result;
 
-       if (HAVE_LLVM < 0x0700) {
+       if (HAVE_LLVM >= 0x0800) {
                LLVMValueRef icond = LLVMBuildICmp(ctx->builder, LLVMIntEQ, srcs[2], LLVMConstInt(ctx->i32, 32, false), "");
                result = ac_build_bfe(ctx, srcs[0], srcs[1], srcs[2], is_signed);
                result = LLVMBuildSelect(ctx->builder, icond, srcs[0], result, "");
        } else {
-               /* FIXME: LLVM 7 returns incorrect result when count is 0.
+               /* FIXME: LLVM 7+ returns incorrect result when count is 0.
                 * https://bugs.freedesktop.org/show_bug.cgi?id=107276
                 */
                LLVMValueRef zero = ctx->i32_0;
@@ -677,34 +687,34 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                                                     LLVMTypeOf(src[0]), ""),
                                       "");
                break;
-       case nir_op_ilt:
+       case nir_op_ilt32:
                result = emit_int_cmp(&ctx->ac, LLVMIntSLT, src[0], src[1]);
                break;
-       case nir_op_ine:
+       case nir_op_ine32:
                result = emit_int_cmp(&ctx->ac, LLVMIntNE, src[0], src[1]);
                break;
-       case nir_op_ieq:
+       case nir_op_ieq32:
                result = emit_int_cmp(&ctx->ac, LLVMIntEQ, src[0], src[1]);
                break;
-       case nir_op_ige:
+       case nir_op_ige32:
                result = emit_int_cmp(&ctx->ac, LLVMIntSGE, src[0], src[1]);
                break;
-       case nir_op_ult:
+       case nir_op_ult32:
                result = emit_int_cmp(&ctx->ac, LLVMIntULT, src[0], src[1]);
                break;
-       case nir_op_uge:
+       case nir_op_uge32:
                result = emit_int_cmp(&ctx->ac, LLVMIntUGE, src[0], src[1]);
                break;
-       case nir_op_feq:
+       case nir_op_feq32:
                result = emit_float_cmp(&ctx->ac, LLVMRealOEQ, src[0], src[1]);
                break;
-       case nir_op_fne:
+       case nir_op_fne32:
                result = emit_float_cmp(&ctx->ac, LLVMRealUNE, src[0], src[1]);
                break;
-       case nir_op_flt:
+       case nir_op_flt32:
                result = emit_float_cmp(&ctx->ac, LLVMRealOLT, src[0], src[1]);
                break;
-       case nir_op_fge:
+       case nir_op_fge32:
                result = emit_float_cmp(&ctx->ac, LLVMRealOGE, src[0], src[1]);
                break;
        case nir_op_fabs:
@@ -836,7 +846,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = emit_bitfield_insert(&ctx->ac, src[0], src[1], src[2], src[3]);
                break;
        case nir_op_bitfield_reverse:
-               result = ac_build_intrinsic(&ctx->ac, "llvm.bitreverse.i32", ctx->ac.i32, src, 1, AC_FUNC_ATTR_READNONE);
+               result = ac_build_bitfield_reverse(&ctx->ac, src[0]);
                break;
        case nir_op_bit_count:
                result = ac_build_bit_count(&ctx->ac, src[0]);
@@ -906,7 +916,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                else
                        result = LLVMBuildTrunc(ctx->ac.builder, src[0], def_type, "");
                break;
-       case nir_op_bcsel:
+       case nir_op_b32csel:
                result = emit_bcsel(&ctx->ac, src[0], src[1], src[2]);
                break;
        case nir_op_find_lsb:
@@ -931,16 +941,20 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                src[1] = ac_to_integer(&ctx->ac, src[1]);
                result = emit_uint_carry(&ctx->ac, "llvm.usub.with.overflow.i32", src[0], src[1]);
                break;
-       case nir_op_b2f:
-               result = emit_b2f(&ctx->ac, src[0]);
+       case nir_op_b2f16:
+       case nir_op_b2f32:
+       case nir_op_b2f64:
+               result = emit_b2f(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
-       case nir_op_f2b:
+       case nir_op_f2b32:
                result = emit_f2b(&ctx->ac, src[0]);
                break;
-       case nir_op_b2i:
+       case nir_op_b2i16:
+       case nir_op_b2i32:
+       case nir_op_b2i64:
                result = emit_b2i(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
-       case nir_op_i2b:
+       case nir_op_i2b32:
                src[0] = ac_to_integer(&ctx->ac, src[0]);
                result = emit_i2b(&ctx->ac, src[0]);
                break;
@@ -1086,7 +1100,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
 
        if (result) {
                assert(instr->dest.dest.is_ssa);
-               result = ac_to_integer(&ctx->ac, result);
+               result = ac_to_integer_or_pointer(&ctx->ac, result);
                ctx->ssa_defs[instr->dest.dest.ssa.index] = result;
        }
 }
@@ -1387,7 +1401,7 @@ static LLVMValueRef visit_load_push_constant(struct ac_nir_context *ctx,
 
        if (instr->dest.ssa.bit_size == 16) {
                unsigned load_dwords = instr->dest.ssa.num_components / 2 + 1;
-               LLVMTypeRef vec_type = LLVMVectorType(LLVMInt16Type(), 2 * load_dwords);
+               LLVMTypeRef vec_type = LLVMVectorType(LLVMInt16TypeInContext(ctx->ac.context), 2 * load_dwords);
                ptr = ac_cast_ptr(&ctx->ac, ptr, vec_type);
                LLVMValueRef res = LLVMBuildLoad(ctx->ac.builder, ptr, "");
                res = LLVMBuildBitCast(ctx->ac.builder, res, vec_type, "");
@@ -1456,6 +1470,11 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
        LLVMValueRef src_data = get_src(ctx, instr->src[0]);
        int elem_size_bytes = ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src_data)) / 8;
        unsigned writemask = nir_intrinsic_write_mask(instr);
+       enum gl_access_qualifier access = nir_intrinsic_access(instr);
+       LLVMValueRef glc = ctx->ac.i1false;
+
+       if (access & (ACCESS_VOLATILE | ACCESS_COHERENT))
+               glc = ctx->ac.i1true;
 
        LLVMValueRef rsrc = ctx->abi->load_ssbo(ctx->abi,
                                        get_src(ctx, instr->src[1]), true);
@@ -1512,7 +1531,7 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
                                ctx->ac.i32_0,
                                LLVMConstInt(ctx->ac.i32, 2, false), // dfmt (= 16bit)
                                LLVMConstInt(ctx->ac.i32, 4, false), // nfmt (= uint)
-                               ctx->ac.i1false,
+                               glc,
                                ctx->ac.i1false,
                        };
                        ac_build_intrinsic(&ctx->ac, store_name,
@@ -1540,7 +1559,7 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
                                rsrc,
                                ctx->ac.i32_0, /* vindex */
                                offset,
-                               ctx->ac.i1false,  /* glc */
+                               glc,
                                ctx->ac.i1false,  /* slc */
                        };
                        ac_build_intrinsic(&ctx->ac, store_name,
@@ -1565,7 +1584,7 @@ static LLVMValueRef visit_atomic_ssbo(struct ac_nir_context *ctx,
                                                 true);
        params[arg_count++] = ctx->ac.i32_0; /* vindex */
        params[arg_count++] = get_src(ctx, instr->src[1]);      /* voffset */
-       params[arg_count++] = LLVMConstInt(ctx->ac.i1, 0, false);  /* slc */
+       params[arg_count++] = ctx->ac.i1false;  /* slc */
 
        switch (instr->intrinsic) {
        case nir_intrinsic_ssbo_atomic_add:
@@ -1608,31 +1627,45 @@ static LLVMValueRef visit_atomic_ssbo(struct ac_nir_context *ctx,
 static LLVMValueRef visit_load_buffer(struct ac_nir_context *ctx,
                                       const nir_intrinsic_instr *instr)
 {
-       LLVMValueRef results[2];
-       int load_bytes;
        int elem_size_bytes = instr->dest.ssa.bit_size / 8;
        int num_components = instr->num_components;
-       int num_bytes = num_components * elem_size_bytes;
+       enum gl_access_qualifier access = nir_intrinsic_access(instr);
+       LLVMValueRef glc = ctx->ac.i1false;
 
-       for (int i = 0; i < num_bytes; i += load_bytes) {
-               load_bytes = MIN2(num_bytes - i, 16);
-               const char *load_name;
-               LLVMTypeRef data_type;
-               LLVMValueRef offset = get_src(ctx, instr->src[1]);
-               LLVMValueRef immoffset = LLVMConstInt(ctx->ac.i32, i, false);
-               LLVMValueRef rsrc = ctx->abi->load_ssbo(ctx->abi,
-                                                       get_src(ctx, instr->src[0]), false);
-               LLVMValueRef vindex = ctx->ac.i32_0;
+       if (access & (ACCESS_VOLATILE | ACCESS_COHERENT))
+               glc = ctx->ac.i1true;
+
+       LLVMValueRef offset = get_src(ctx, instr->src[1]);
+       LLVMValueRef rsrc = ctx->abi->load_ssbo(ctx->abi,
+                                               get_src(ctx, instr->src[0]), false);
+       LLVMValueRef vindex = ctx->ac.i32_0;
 
-               int idx = i ? 1 : 0;
+       LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.ssa);
+       LLVMTypeRef def_elem_type = num_components > 1 ? LLVMGetElementType(def_type) : def_type;
+
+       LLVMValueRef results[4];
+       for (int i = 0; i < num_components;) {
+               int num_elems = num_components - i;
+               if (elem_size_bytes < 4 && nir_intrinsic_align(instr) % 4 != 0)
+                       num_elems = 1;
+               if (num_elems * elem_size_bytes > 16)
+                       num_elems = 16 / elem_size_bytes;
+               int load_bytes = num_elems * elem_size_bytes;
+
+               LLVMValueRef immoffset = LLVMConstInt(ctx->ac.i32, i * elem_size_bytes, false);
+
+               LLVMValueRef ret;
                if (load_bytes == 2) {
-                       results[idx] = ac_build_tbuffer_load_short(&ctx->ac,
-                                                                  rsrc,
-                                                                  vindex,
-                                                                  offset,
-                                                                  ctx->ac.i32_0,
-                                                                  immoffset);
+                       ret = ac_build_tbuffer_load_short(&ctx->ac,
+                                                         rsrc,
+                                                         vindex,
+                                                         offset,
+                                                         ctx->ac.i32_0,
+                                                         immoffset,
+                                                         glc);
                } else {
+                       const char *load_name;
+                       LLVMTypeRef data_type;
                        switch (load_bytes) {
                        case 16:
                        case 12:
@@ -1655,36 +1688,26 @@ static LLVMValueRef visit_load_buffer(struct ac_nir_context *ctx,
                                rsrc,
                                vindex,
                                LLVMBuildAdd(ctx->ac.builder, offset, immoffset, ""),
-                               ctx->ac.i1false,
+                               glc,
                                ctx->ac.i1false,
                        };
-                       results[idx] = ac_build_intrinsic(&ctx->ac, load_name, data_type, params, 5, 0);
-                       unsigned num_elems = ac_get_type_size(data_type) / elem_size_bytes;
-                       LLVMTypeRef resTy = LLVMVectorType(LLVMIntType(instr->dest.ssa.bit_size), num_elems);
-                       results[idx] = LLVMBuildBitCast(ctx->ac.builder, results[idx], resTy, "");
+                       ret = ac_build_intrinsic(&ctx->ac, load_name, data_type, params, 5, 0);
                }
-       }
 
-       assume(results[0]);
-       LLVMValueRef ret = results[0];
-       if (num_bytes > 16 || num_components == 3) {
-               LLVMValueRef masks[] = {
-                       LLVMConstInt(ctx->ac.i32, 0, false), LLVMConstInt(ctx->ac.i32, 1, false),
-                       LLVMConstInt(ctx->ac.i32, 2, false), LLVMConstInt(ctx->ac.i32, 3, false),
-               };
+               LLVMTypeRef byte_vec = LLVMVectorType(ctx->ac.i8, ac_get_type_size(LLVMTypeOf(ret)));
+               ret = LLVMBuildBitCast(ctx->ac.builder, ret, byte_vec, "");
+               ret = ac_trim_vector(&ctx->ac, ret, load_bytes);
 
-               if (num_bytes > 16 && num_components == 3) {
-                       /* we end up with a v4f32 and v2f32 but shuffle fails on that */
-                       results[1] = ac_build_expand_to_vec4(&ctx->ac, results[1], 2);
-               }
+               LLVMTypeRef ret_type = LLVMVectorType(def_elem_type, num_elems);
+               ret = LLVMBuildBitCast(ctx->ac.builder, ret, ret_type, "");
 
-               LLVMValueRef swizzle = LLVMConstVector(masks, num_components);
-               ret = LLVMBuildShuffleVector(ctx->ac.builder, results[0],
-                                            results[num_bytes > 16 ? 1 : 0], swizzle, "");
+               for (unsigned j = 0; j < num_elems; j++) {
+                       results[i + j] = LLVMBuildExtractElement(ctx->ac.builder, ret, LLVMConstInt(ctx->ac.i32, j, false), "");
+               }
+               i += num_elems;
        }
 
-       return LLVMBuildBitCast(ctx->ac.builder, ret,
-                               get_def_type(ctx, &instr->dest.ssa), "");
+       return ac_build_gather_values(&ctx->ac, results, num_components);
 }
 
 static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
@@ -1709,7 +1732,8 @@ static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
                                                                 ctx->ac.i32_0,
                                                                 offset,
                                                                 ctx->ac.i32_0,
-                                                                LLVMConstInt(ctx->ac.i32, 2 * i, 0));
+                                                                LLVMConstInt(ctx->ac.i32, 2 * i, 0),
+                                                                ctx->ac.i1false);
                }
                ret = ac_build_gather_values(&ctx->ac, results, num_components);
        } else {
@@ -1839,23 +1863,32 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
        nir_variable *var = nir_deref_instr_get_variable(nir_instr_as_deref(instr->src[0].ssa->parent_instr));
 
        LLVMValueRef values[8];
-       int idx = var->data.driver_location;
+       int idx = 0;
        int ve = instr->dest.ssa.num_components;
-       unsigned comp = var->data.location_frac;
+       unsigned comp = 0;
        LLVMValueRef indir_index;
        LLVMValueRef ret;
        unsigned const_index;
-       unsigned stride = var->data.compact ? 1 : 4;
-       bool vs_in = ctx->stage == MESA_SHADER_VERTEX &&
-                    var->data.mode == nir_var_shader_in;
-
-       get_deref_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr), vs_in, NULL, NULL,
-                        &const_index, &indir_index);
+       unsigned stride = 4;
+       int mode = nir_var_shared;
+
+       if (var) {
+               bool vs_in = ctx->stage == MESA_SHADER_VERTEX &&
+                       var->data.mode == nir_var_shader_in;
+               if (var->data.compact)
+                       stride = 1;
+               idx = var->data.driver_location;
+               comp = var->data.location_frac;
+               mode = var->data.mode;
+
+               get_deref_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr), vs_in, NULL, NULL,
+                                &const_index, &indir_index);
+       }
 
        if (instr->dest.ssa.bit_size == 64)
                ve *= 2;
 
-       switch (var->data.mode) {
+       switch (mode) {
        case nir_var_shader_in:
                if (ctx->stage == MESA_SHADER_TESS_CTRL ||
                    ctx->stage == MESA_SHADER_TESS_EVAL) {
@@ -2216,7 +2249,7 @@ static void get_image_coords(struct ac_nir_context *ctx,
        bool gfx9_1d = ctx->ac.chip_class >= GFX9 && dim == GLSL_SAMPLER_DIM_1D;
        count = image_type_to_components_count(dim, is_array);
 
-       if (is_ms) {
+       if (is_ms && instr->intrinsic == nir_intrinsic_image_deref_load) {
                LLVMValueRef fmask_load_address[3];
                int chan;
 
@@ -2359,17 +2392,33 @@ static void visit_image_store(struct ac_nir_context *ctx,
                glc = ctx->ac.i1true;
 
        if (dim == GLSL_SAMPLER_DIM_BUF) {
+               char name[48];
+               const char *types[] = { "f32", "v2f32", "v4f32" };
                LLVMValueRef rsrc = get_image_buffer_descriptor(ctx, instr, true);
+               LLVMValueRef src = ac_to_float(&ctx->ac, get_src(ctx, instr->src[3]));
+               unsigned src_channels = ac_get_llvm_num_components(src);
+
+               if (src_channels == 3)
+                       src = ac_build_expand(&ctx->ac, src, 3, 4);
 
-               params[0] = ac_to_float(&ctx->ac, get_src(ctx, instr->src[3])); /* data */
+               params[0] = src; /* data */
                params[1] = rsrc;
                params[2] = LLVMBuildExtractElement(ctx->ac.builder, get_src(ctx, instr->src[1]),
                                                    ctx->ac.i32_0, ""); /* vindex */
                params[3] = ctx->ac.i32_0; /* voffset */
-               params[4] = glc;  /* glc */
-               params[5] = ctx->ac.i1false;  /* slc */
-               ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.buffer.store.format.v4f32", ctx->ac.voidt,
-                                  params, 6, 0);
+               snprintf(name, sizeof(name), "%s.%s",
+                        HAVE_LLVM >= 0x800 ? "llvm.amdgcn.struct.buffer.store.format"
+                                           : "llvm.amdgcn.buffer.store.format",
+                        types[CLAMP(src_channels, 1, 3) - 1]);
+
+               if (HAVE_LLVM >= 0x800) {
+                       params[4] = ctx->ac.i32_0; /* soffset */
+                       params[5] = glc ? ctx->ac.i32_1 : ctx->ac.i32_0;
+               } else {
+                       params[4] = glc;  /* glc */
+                       params[5] = ctx->ac.i1false;  /* slc */
+               }
+               ac_build_intrinsic(&ctx->ac, name, ctx->ac.voidt, params, 6, 0);
        } else {
                struct ac_image_args args = {};
                args.opcode = ac_image_store;
@@ -2396,7 +2445,7 @@ static LLVMValueRef visit_image_atomic(struct ac_nir_context *ctx,
 
        bool cmpswap = instr->intrinsic == nir_intrinsic_image_deref_atomic_comp_swap;
        const char *atomic_name;
-       char intrinsic_name[41];
+       char intrinsic_name[64];
        enum ac_atomic_op atomic_subop;
        const struct glsl_type *type = glsl_without_array(var->type);
        MAYBE_UNUSED int length;
@@ -2449,10 +2498,18 @@ static LLVMValueRef visit_image_atomic(struct ac_nir_context *ctx,
                params[param_count++] = LLVMBuildExtractElement(ctx->ac.builder, get_src(ctx, instr->src[1]),
                                                                ctx->ac.i32_0, ""); /* vindex */
                params[param_count++] = ctx->ac.i32_0; /* voffset */
-               params[param_count++] = ctx->ac.i1false;  /* slc */
+               if (HAVE_LLVM >= 0x800) {
+                       params[param_count++] = ctx->ac.i32_0; /* soffset */
+                       params[param_count++] = ctx->ac.i32_0;  /* slc */
 
-               length = snprintf(intrinsic_name, sizeof(intrinsic_name),
-                                 "llvm.amdgcn.buffer.atomic.%s", atomic_name);
+                       length = snprintf(intrinsic_name, sizeof(intrinsic_name),
+                                         "llvm.amdgcn.struct.buffer.atomic.%s.i32", atomic_name);
+               } else {
+                       params[param_count++] = ctx->ac.i1false;  /* slc */
+
+                       length = snprintf(intrinsic_name, sizeof(intrinsic_name),
+                                         "llvm.amdgcn.buffer.atomic.%s", atomic_name);
+               }
 
                assert(length < sizeof(intrinsic_name));
                return ac_build_intrinsic(&ctx->ac, intrinsic_name, ctx->ac.i32,
@@ -2582,7 +2639,7 @@ static void emit_discard(struct ac_nir_context *ctx,
                                     ctx->ac.i32_0, "");
        } else {
                assert(instr->intrinsic == nir_intrinsic_discard);
-               cond = LLVMConstInt(ctx->ac.i1, false, 0);
+               cond = ctx->ac.i1false;
        }
 
        ctx->abi->emit_kill(ctx->abi, cond);
@@ -2640,7 +2697,7 @@ visit_first_invocation(struct ac_nir_context *ctx)
        LLVMValueRef active_set = ac_build_ballot(&ctx->ac, ctx->ac.i32_1);
 
        /* The second argument is whether cttz(0) should be defined, but we do not care. */
-       LLVMValueRef args[] = {active_set, LLVMConstInt(ctx->ac.i1, 0, false)};
+       LLVMValueRef args[] = {active_set, ctx->ac.i1false};
        LLVMValueRef result =  ac_build_intrinsic(&ctx->ac,
                                                  "llvm.cttz.i64",
                                                  ctx->ac.i64, args, 2,
@@ -3311,7 +3368,7 @@ static LLVMValueRef apply_round_slice(struct ac_llvm_context *ctx,
                                      LLVMValueRef coord)
 {
        coord = ac_to_float(ctx, coord);
-       coord = ac_build_intrinsic(ctx, "llvm.rint.f32", ctx->f32, &coord, 1, 0);
+       coord = ac_build_round(ctx, coord);
        coord = ac_to_integer(ctx, coord);
        return coord;
 }
@@ -3623,7 +3680,6 @@ static void visit_post_phi(struct ac_nir_context *ctx,
 
 static void phi_post_pass(struct ac_nir_context *ctx)
 {
-       struct hash_entry *entry;
        hash_table_foreach(ctx->phis, entry) {
                visit_post_phi(ctx, (nir_phi_instr*)entry->key,
                               (LLVMValueRef)entry->data);
@@ -3685,6 +3741,9 @@ static void visit_deref(struct ac_nir_context *ctx,
                result = ac_build_gep0(&ctx->ac, get_src(ctx, instr->parent),
                                       get_src(ctx, instr->arr.index));
                break;
+       case nir_deref_type_cast:
+               result = get_src(ctx, instr->parent);
+               break;
        default:
                unreachable("Unhandled deref_instr deref type");
        }
@@ -3928,7 +3987,7 @@ setup_shared(struct ac_nir_context *ctx,
                        LLVMAddGlobalInAddressSpace(
                           ctx->ac.module, glsl_to_llvm_type(&ctx->ac, variable->type),
                           variable->name ? variable->name : "",
-                          AC_LOCAL_ADDR_SPACE);
+                          AC_ADDR_SPACE_LDS);
                _mesa_hash_table_insert(ctx->vars, variable, shared);
        }
 }
@@ -4017,3 +4076,164 @@ ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class chip_class)
 
        nir_lower_indirect_derefs(nir, indirect_mask);
 }
+
+/* Return the tess factor channels written by one instruction as a
+ * writemask: inner levels in the low bits, outer levels shifted up by 4.
+ * Anything that is not a store to a tess-level output returns 0.
+ */
+static unsigned
+get_inst_tessfactor_writemask(nir_intrinsic_instr *intrin)
+{
+	if (intrin->intrinsic != nir_intrinsic_store_deref)
+		return 0;
+
+	nir_variable *var =
+		nir_deref_instr_get_variable(nir_src_as_deref(intrin->src[0]));
+
+	if (var->data.mode != nir_var_shader_out)
+		return 0;
+
+	unsigned writemask = 0;
+	const int location = var->data.location;
+	unsigned first_component = var->data.location_frac;
+	/* A store intrinsic has no destination, so the number of components
+	 * written is the instruction's own num_components; reading
+	 * intrin->dest.ssa.num_components inspected an unused union member.
+	 */
+	unsigned num_comps = intrin->num_components;
+
+	/* A store of N components starting at 'first_component' covers exactly
+	 * N mask bits. The previous expression, (1 << num_comps + 1) - 1,
+	 * parsed as 1 << (num_comps + 1) because '+' binds tighter than '<<',
+	 * so the mask was one bit too wide and overstated which channels are
+	 * written — which can make the defined-in-all-invocations analysis
+	 * return true incorrectly.
+	 */
+	if (location == VARYING_SLOT_TESS_LEVEL_INNER)
+		writemask = ((1u << num_comps) - 1) << first_component;
+	else if (location == VARYING_SLOT_TESS_LEVEL_OUTER)
+		writemask = (((1u << num_comps) - 1) << first_component) << 4;
+
+	return writemask;
+}
+
+/* Recursively scan one control-flow node of a tess-ctrl shader, tracking
+ * which tess factor channels are written (see
+ * get_inst_tessfactor_writemask for the mask layout).
+ *
+ * upper_block_tf_writemask: channels written unconditionally in the
+ *     current barrier-delimited code segment.
+ * cond_block_tf_writemask: channels written on at least one conditional
+ *     path of that segment.
+ * tessfactors_are_def_in_all_invocs: accumulated result; only ever
+ *     cleared, never set, by this function.
+ * is_nested_cf: true when cf_node is inside an if or loop.
+ */
+static void
+scan_tess_ctrl(nir_cf_node *cf_node, unsigned *upper_block_tf_writemask,
+	       unsigned *cond_block_tf_writemask,
+	       bool *tessfactors_are_def_in_all_invocs, bool is_nested_cf)
+{
+	switch (cf_node->type) {
+	case nir_cf_node_block: {
+		nir_block *block = nir_cf_node_as_block(cf_node);
+		nir_foreach_instr(instr, block) {
+			if (instr->type != nir_instr_type_intrinsic)
+				continue;
+
+			nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+			if (intrin->intrinsic == nir_intrinsic_barrier) {
+
+				/* If we find a barrier in nested control flow put this in the
+				 * too hard basket. In GLSL this is not possible but it is in
+				 * SPIR-V.
+				 */
+				if (is_nested_cf) {
+					*tessfactors_are_def_in_all_invocs = false;
+					return;
+				}
+
+				/* The following case must be prevented:
+				 *    gl_TessLevelInner = ...;
+				 *    barrier();
+				 *    if (gl_InvocationID == 1)
+				 *       gl_TessLevelInner = ...;
+				 *
+				 * If you consider disjoint code segments separated by barriers, each
+				 * such segment that writes tess factor channels should write the same
+				 * channels in all codepaths within that segment.
+				 */
+				/* NOTE(review): this condition tests the pointers
+				 * themselves, which are never NULL here, so it is
+				 * always true — presumably it was meant to test
+				 * *upper_block_tf_writemask || *cond_block_tf_writemask.
+				 * Harmless as written: when both masks are 0 the body
+				 * below is a no-op. Confirm against upstream.
+				 */
+				if (upper_block_tf_writemask || cond_block_tf_writemask) {
+					/* Accumulate the result: */
+					*tessfactors_are_def_in_all_invocs &=
+						!(*cond_block_tf_writemask & ~(*upper_block_tf_writemask));
+
+					/* Analyze the next code segment from scratch. */
+					*upper_block_tf_writemask = 0;
+					*cond_block_tf_writemask = 0;
+				}
+			} else
+				*upper_block_tf_writemask |= get_inst_tessfactor_writemask(intrin);
+		}
+
+		break;
+	}
+	case nir_cf_node_if: {
+		unsigned then_tessfactor_writemask = 0;
+		unsigned else_tessfactor_writemask = 0;
+
+		nir_if *if_stmt = nir_cf_node_as_if(cf_node);
+		foreach_list_typed(nir_cf_node, nested_node, node, &if_stmt->then_list) {
+			scan_tess_ctrl(nested_node, &then_tessfactor_writemask,
+				       cond_block_tf_writemask,
+				       tessfactors_are_def_in_all_invocs, true);
+		}
+
+		foreach_list_typed(nir_cf_node, nested_node, node, &if_stmt->else_list) {
+			scan_tess_ctrl(nested_node, &else_tessfactor_writemask,
+				       cond_block_tf_writemask,
+				       tessfactors_are_def_in_all_invocs, true);
+		}
+
+		if (then_tessfactor_writemask || else_tessfactor_writemask) {
+			/* If both statements write the same tess factor channels,
+			 * we can say that the upper block writes them too.
+			 */
+			*upper_block_tf_writemask |= then_tessfactor_writemask &
+				else_tessfactor_writemask;
+			*cond_block_tf_writemask |= then_tessfactor_writemask |
+				else_tessfactor_writemask;
+		}
+
+		break;
+	}
+	case nir_cf_node_loop: {
+		/* A loop body may execute zero times, so every write inside it
+		 * is only conditional: feed the cond mask in for both roles.
+		 */
+		nir_loop *loop = nir_cf_node_as_loop(cf_node);
+		foreach_list_typed(nir_cf_node, nested_node, node, &loop->body) {
+			scan_tess_ctrl(nested_node, cond_block_tf_writemask,
+				       cond_block_tf_writemask,
+				       tessfactors_are_def_in_all_invocs, true);
+		}
+
+		break;
+	}
+	default:
+		unreachable("unknown cf node type");
+	}
+}
+
+/* Decide whether every TCS invocation is guaranteed to write the tess
+ * factors. Each tess factor channel is tracked separately; the shader is
+ * scanned one barrier-delimited segment at a time.
+ */
+bool
+ac_are_tessfactors_def_in_all_invocs(const struct nir_shader *nir)
+{
+	assert(nir->info.stage == MESA_SHADER_TESS_CTRL);
+
+	/* Channels written unconditionally in the current segment. */
+	unsigned main_block_writes = 0;
+	/* Channels written on some conditional path of the segment. */
+	unsigned cond_block_writes = 0;
+
+	/* Start optimistic; scan_tess_ctrl() and the final fold below only
+	 * ever clear this. If no tess factors are written at all, that is a
+	 * shader bug and the leftover 'true' does not matter.
+	 */
+	bool result = true;
+
+	nir_foreach_function(function, nir) {
+		if (!function->impl)
+			continue;
+
+		foreach_list_typed(nir_cf_node, node, node, &function->impl->body) {
+			scan_tess_ctrl(node, &main_block_writes,
+				       &cond_block_writes, &result, false);
+		}
+	}
+
+	/* Fold in the segment after the last barrier: a channel written only
+	 * on a conditional path there breaks the guarantee.
+	 */
+	if (main_block_writes || cond_block_writes)
+		result &= !(cond_block_writes & ~main_block_writes);
+
+	return result;
+}