ac/llvm: fix nir_texop_texture_samples with NULL descriptors
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index d1c333ac73d252004c6a54a003cf9ce3394f77e5..2a495eb76e4e21a8d95ebcc676812bdc95961493 100644 (file)
@@ -245,7 +245,7 @@ static LLVMValueRef emit_bcsel(struct ac_llvm_context *ctx,
        LLVMTypeRef src1_type = LLVMTypeOf(src1);
        LLVMTypeRef src2_type = LLVMTypeOf(src2);
 
-       assert(LLVMGetTypeKind(LLVMTypeOf(src0)) != LLVMVectorTypeKind);
+       assert(LLVMGetTypeKind(LLVMTypeOf(src0)) != LLVMFixedVectorTypeKind);
 
        if (LLVMGetTypeKind(src1_type) == LLVMPointerTypeKind &&
            LLVMGetTypeKind(src2_type) != LLVMPointerTypeKind) {
@@ -589,6 +589,10 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        unsigned num_components = instr->dest.dest.ssa.num_components;
        unsigned src_components;
        LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.dest.ssa);
+       bool saved_inexact = false;
+
+       if (instr->exact)
+               saved_inexact = ac_disable_inexact_math(ctx->ac.builder);
 
        assert(nir_op_infos[instr->op].num_inputs <= ARRAY_SIZE(src));
        switch (instr->op) {
@@ -1182,6 +1186,9 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = ac_to_integer_or_pointer(&ctx->ac, result);
                ctx->ssa_defs[instr->dest.dest.ssa.index] = result;
        }
+
+       if (instr->exact)
+               ac_restore_inexact_math(ctx->ac.builder, saved_inexact);
 }
 
 static void visit_load_const(struct ac_nir_context *ctx,
@@ -1740,6 +1747,16 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
                        count = 1;
                        num_bytes = 2;
                }
+
+               /* Due to alignment issues, split stores of 8-bit/16-bit
+                * vectors.
+                */
+               if (ctx->ac.chip_class == GFX6 && count > 1 && elem_size_bytes < 4) {
+                       writemask |= ((1u << (count - 1)) - 1u) << (start + 1);
+                       count = 1;
+                       num_bytes = elem_size_bytes;
+               }
+
                data = extract_vector_range(&ctx->ac, base_data, start, count);
 
                offset = LLVMBuildAdd(ctx->ac.builder, base_offset,
@@ -2176,7 +2193,7 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx,
        LLVMTypeRef dest_type = get_def_type(ctx, &instr->dest.ssa);
 
        LLVMTypeRef src_component_type;
-       if (LLVMGetTypeKind(dest_type) == LLVMVectorTypeKind)
+       if (LLVMGetTypeKind(dest_type) == LLVMFixedVectorTypeKind)
                src_component_type = LLVMGetElementType(dest_type);
        else
                src_component_type = dest_type;
@@ -2328,14 +2345,19 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                break;
        case nir_var_mem_global:  {
                LLVMValueRef address = get_src(ctx, instr->src[0]);
+               LLVMTypeRef result_type = get_def_type(ctx, &instr->dest.ssa);
                unsigned explicit_stride = glsl_get_explicit_stride(deref->type);
                unsigned natural_stride = type_scalar_size_bytes(deref->type);
                unsigned stride = explicit_stride ? explicit_stride : natural_stride;
+               int elem_size_bytes = ac_get_elem_bits(&ctx->ac, result_type) / 8;
+               bool split_loads = ctx->ac.chip_class == GFX6 && elem_size_bytes < 4;
 
-               LLVMTypeRef result_type = get_def_type(ctx, &instr->dest.ssa);
-               if (stride != natural_stride) {
-                       LLVMTypeRef ptr_type =  LLVMPointerType(LLVMGetElementType(result_type),
-                                                               LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+               if (stride != natural_stride || split_loads) {
+                       if (LLVMGetTypeKind(result_type) == LLVMFixedVectorTypeKind)
+                               result_type = LLVMGetElementType(result_type);
+
+                       LLVMTypeRef ptr_type = LLVMPointerType(result_type,
+                                                              LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
                        address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
 
                        for (unsigned i = 0; i < instr->dest.ssa.num_components; ++i) {
@@ -2489,23 +2511,29 @@ visit_store_var(struct ac_nir_context *ctx,
                unsigned explicit_stride = glsl_get_explicit_stride(deref->type);
                unsigned natural_stride = type_scalar_size_bytes(deref->type);
                unsigned stride = explicit_stride ? explicit_stride : natural_stride;
+               int elem_size_bytes = ac_get_elem_bits(&ctx->ac, LLVMTypeOf(val)) / 8;
+               bool split_stores = ctx->ac.chip_class == GFX6 && elem_size_bytes < 4;
 
                LLVMTypeRef ptr_type =  LLVMPointerType(LLVMTypeOf(val),
                                                        LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
                address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
 
                if (writemask == (1u << ac_get_llvm_num_components(val)) - 1 &&
-                   stride == natural_stride) {
-                       LLVMTypeRef ptr_type =  LLVMPointerType(LLVMTypeOf(val),
-                                                               LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+                   stride == natural_stride && !split_stores) {
+                       LLVMTypeRef ptr_type = LLVMPointerType(LLVMTypeOf(val),
+                                                              LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
                        address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
 
                        val = LLVMBuildBitCast(ctx->ac.builder, val,
                                               LLVMGetElementType(LLVMTypeOf(address)), "");
                        LLVMBuildStore(ctx->ac.builder, val, address);
                } else {
-                       LLVMTypeRef ptr_type =  LLVMPointerType(LLVMGetElementType(LLVMTypeOf(val)),
-                                                               LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
+                       LLVMTypeRef val_type = LLVMTypeOf(val);
+                       if (LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMFixedVectorTypeKind)
+                               val_type = LLVMGetElementType(val_type);
+
+                       LLVMTypeRef ptr_type = LLVMPointerType(val_type,
+                                                              LLVMGetPointerAddressSpace(LLVMTypeOf(address)));
                        address = LLVMBuildBitCast(ctx->ac.builder, address, ptr_type , "");
                        for (unsigned chan = 0; chan < 4; chan++) {
                                if (!(writemask & (1 << chan)))
@@ -3917,7 +3945,16 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_emit_vertex:
                ctx->abi->emit_vertex(ctx->abi, nir_intrinsic_stream_id(instr), ctx->abi->outputs);
                break;
+       case nir_intrinsic_emit_vertex_with_counter: {
+               unsigned stream = nir_intrinsic_stream_id(instr);
+               LLVMValueRef next_vertex = get_src(ctx, instr->src[0]);
+               ctx->abi->emit_vertex_with_counter(ctx->abi, stream,
+                                                  next_vertex,
+                                                  ctx->abi->outputs);
+               break;
+       }
        case nir_intrinsic_end_primitive:
+       case nir_intrinsic_end_primitive_with_counter:
                ctx->abi->emit_primitive(ctx->abi, nir_intrinsic_stream_id(instr));
                break;
        case nir_intrinsic_load_tess_coord:
@@ -4432,6 +4469,8 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
 
        if (instr->op == nir_texop_texture_samples) {
                LLVMValueRef res, samples, is_msaa;
+               LLVMValueRef default_sample;
+
                res = LLVMBuildBitCast(ctx->ac.builder, args.resource, ctx->ac.v8i32, "");
                samples = LLVMBuildExtractElement(ctx->ac.builder, res,
                                                  LLVMConstInt(ctx->ac.i32, 3, false), "");
@@ -4448,8 +4487,27 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
                                       LLVMConstInt(ctx->ac.i32, 0xf, false), "");
                samples = LLVMBuildShl(ctx->ac.builder, ctx->ac.i32_1,
                                       samples, "");
+
+               if (ctx->abi->robust_buffer_access) {
+                       LLVMValueRef dword1, is_null_descriptor;
+
+                       /* Extract the second dword of the descriptor; if it is
+                        * all zero, the descriptor is a null descriptor.
+                        */
+                       dword1 = LLVMBuildExtractElement(ctx->ac.builder, res,
+                                                        LLVMConstInt(ctx->ac.i32, 1, false), "");
+                       is_null_descriptor =
+                               LLVMBuildICmp(ctx->ac.builder, LLVMIntEQ, dword1,
+                                             LLVMConstInt(ctx->ac.i32, 0, false), "");
+                       default_sample =
+                               LLVMBuildSelect(ctx->ac.builder, is_null_descriptor,
+                                               ctx->ac.i32_0, ctx->ac.i32_1, "");
+               } else {
+                       default_sample = ctx->ac.i32_1;
+               }
+
                samples = LLVMBuildSelect(ctx->ac.builder, is_msaa, samples,
-                                         ctx->ac.i32_1, "");
+                                         default_sample, "");
                result = samples;
                goto write_result;
        }
@@ -4920,7 +4978,7 @@ static void visit_deref(struct ac_nir_context *ctx,
                LLVMTypeRef type = LLVMPointerType(pointee_type, address_space);
 
                if (LLVMTypeOf(result) != type) {
-                       if (LLVMGetTypeKind(LLVMTypeOf(result)) == LLVMVectorTypeKind) {
+                       if (LLVMGetTypeKind(LLVMTypeOf(result)) == LLVMFixedVectorTypeKind) {
                                result = LLVMBuildBitCast(ctx->ac.builder, result,
                                                          type, "");
                        } else {