nir/lower_atomics_to_ssbo: Also lower barriers
diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c
index 210a37a39061158ffae3fa5a968279cb0dbff509..4ae45c6204d1716e371926ac8ccac648d6e6dc13 100644
--- a/src/amd/llvm/ac_nir_to_llvm.c
+++ b/src/amd/llvm/ac_nir_to_llvm.c
@@ -101,14 +101,16 @@ static LLVMValueRef get_src(struct ac_nir_context *nir, nir_src src)
 }
 
 static LLVMValueRef
-get_memory_ptr(struct ac_nir_context *ctx, nir_src src)
+get_memory_ptr(struct ac_nir_context *ctx, nir_src src, unsigned bit_size)
 {
        LLVMValueRef ptr = get_src(ctx, src);
        ptr = LLVMBuildGEP(ctx->ac.builder, ctx->ac.lds, &ptr, 1, "");
        int addr_space = LLVMGetPointerAddressSpace(LLVMTypeOf(ptr));
 
+       LLVMTypeRef type = LLVMIntTypeInContext(ctx->ac.context, bit_size);
+
        return LLVMBuildBitCast(ctx->ac.builder, ptr,
-                               LLVMPointerType(ctx->ac.i32, addr_space), "");
+                               LLVMPointerType(type, addr_space), "");
 }
 
 static LLVMBasicBlockRef get_block(struct ac_nir_context *nir,
@@ -234,8 +236,19 @@ static LLVMValueRef emit_intrin_3f_param(struct ac_llvm_context *ctx,
 static LLVMValueRef emit_bcsel(struct ac_llvm_context *ctx,
                               LLVMValueRef src0, LLVMValueRef src1, LLVMValueRef src2)
 {
+       LLVMTypeRef src1_type = LLVMTypeOf(src1);
+       LLVMTypeRef src2_type = LLVMTypeOf(src2);
+
        assert(LLVMGetTypeKind(LLVMTypeOf(src0)) != LLVMVectorTypeKind);
 
+       if (LLVMGetTypeKind(src1_type) == LLVMPointerTypeKind &&
+           LLVMGetTypeKind(src2_type) != LLVMPointerTypeKind) {
+               src2 = LLVMBuildIntToPtr(ctx->builder, src2, src1_type, "");
+       } else if (LLVMGetTypeKind(src2_type) == LLVMPointerTypeKind &&
+                  LLVMGetTypeKind(src1_type) != LLVMPointerTypeKind) {
+               src1 = LLVMBuildIntToPtr(ctx->builder, src1, src2_type, "");
+       }
+
        LLVMValueRef v = LLVMBuildICmp(ctx->builder, LLVMIntNE, src0,
                                       ctx->i32_0, "");
        return LLVMBuildSelect(ctx->builder, v,
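
Note: LLVM's select instruction requires both arms to have the same type, but
NIR is untyped, so once pointers flow through bcsel (e.g. a global address
selected against an integer constant) the two arms can disagree and the IR
verifier would reject the select. The hunk above reconciles them by promoting
the integer arm with inttoptr. A minimal standalone restatement of that rule,
assuming only the LLVM-C API (helper name hypothetical):

    #include <llvm-c/Core.h>

    /* Sketch: make two select arms type-compatible by promoting an integer
     * arm to the other arm's pointer type; mirrors the hunk above. */
    static void match_select_arm_types(LLVMBuilderRef b,
                                       LLVMValueRef *arm1, LLVMValueRef *arm2)
    {
       LLVMTypeRef t1 = LLVMTypeOf(*arm1);
       LLVMTypeRef t2 = LLVMTypeOf(*arm2);

       if (LLVMGetTypeKind(t1) == LLVMPointerTypeKind &&
           LLVMGetTypeKind(t2) != LLVMPointerTypeKind)
          *arm2 = LLVMBuildIntToPtr(b, *arm2, t1, "");
       else if (LLVMGetTypeKind(t2) == LLVMPointerTypeKind &&
                LLVMGetTypeKind(t1) != LLVMPointerTypeKind)
          *arm1 = LLVMBuildIntToPtr(b, *arm1, t2, "");
    }
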
@@ -2000,7 +2013,9 @@ static LLVMValueRef load_tess_varyings(struct ac_nir_context *ctx,
 
        unsigned location = var->data.location;
        unsigned driver_location = var->data.driver_location;
-       const bool is_patch =  var->data.patch;
+       const bool is_patch = var->data.patch ||
+                             var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
+                             var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER;
        const bool is_compact = var->data.compact;
 
        get_deref_offset(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr),
@@ -2131,13 +2146,6 @@ static LLVMValueRef visit_load_var(struct ac_nir_context *ctx,
                        }
                }
                break;
-       case nir_var_mem_shared: {
-               LLVMValueRef address = get_src(ctx, instr->src[0]);
-               LLVMValueRef val = LLVMBuildLoad(ctx->ac.builder, address, "");
-               return LLVMBuildBitCast(ctx->ac.builder, val,
-                                       get_def_type(ctx, &instr->dest.ssa),
-                                       "");
-       }
        case nir_var_shader_out:
                if (ctx->stage == MESA_SHADER_TESS_CTRL) {
                        return load_tess_varyings(ctx, instr, false);
@@ -2247,7 +2255,9 @@ visit_store_var(struct ac_nir_context *ctx,
                        LLVMValueRef vertex_index = NULL;
                        LLVMValueRef indir_index = NULL;
                        unsigned const_index = 0;
-                       const bool is_patch = var->data.patch;
+                       const bool is_patch = var->data.patch ||
+                                             var->data.location == VARYING_SLOT_TESS_LEVEL_INNER ||
+                                             var->data.location == VARYING_SLOT_TESS_LEVEL_OUTER;
 
                        get_deref_offset(ctx, deref, false, NULL,
                                         is_patch ? NULL : &vertex_index,
@@ -2314,8 +2324,7 @@ visit_store_var(struct ac_nir_context *ctx,
                }
                break;
 
-       case nir_var_mem_global:
-       case nir_var_mem_shared: {
+       case nir_var_mem_global: {
                int writemask = instr->const_index[0];
                LLVMValueRef address = get_src(ctx, instr->src[0]);
                LLVMValueRef val = get_src(ctx, instr->src[1]);
@@ -2574,10 +2583,14 @@ static LLVMValueRef visit_image_load(struct ac_nir_context *ctx,
                res = ac_trim_vector(&ctx->ac, res, instr->dest.ssa.num_components);
                res = ac_to_integer(&ctx->ac, res);
        } else {
-               args.opcode = ac_image_load;
+               bool level_zero = nir_src_is_const(instr->src[3]) && nir_src_as_uint(instr->src[3]) == 0;
+
+               args.opcode = level_zero ? ac_image_load : ac_image_load_mip;
                args.resource = get_image_descriptor(ctx, instr, AC_DESC_IMAGE, false);
                get_image_coords(ctx, instr, &args, dim, is_array);
                args.dim = ac_get_image_dim(ctx->ac.chip_class, dim, is_array);
+               if (!level_zero)
+                       args.lod = get_src(ctx, instr->src[3]);
                args.dmask = 15;
                args.attributes = AC_FUNC_ATTR_READONLY;
 
@@ -2630,11 +2643,15 @@ static void visit_image_store(struct ac_nir_context *ctx,
                                             ctx->ac.i32_0, src_channels,
                                             args.cache_policy);
        } else {
-               args.opcode = ac_image_store;
+               bool level_zero = nir_src_is_const(instr->src[4]) && nir_src_as_uint(instr->src[4]) == 0;
+
+               args.opcode = level_zero ? ac_image_store : ac_image_store_mip;
                args.data[0] = ac_to_float(&ctx->ac, get_src(ctx, instr->src[3]));
                args.resource = get_image_descriptor(ctx, instr, AC_DESC_IMAGE, true);
                get_image_coords(ctx, instr, &args, dim, is_array);
                args.dim = ac_get_image_dim(ctx->ac.chip_class, dim, is_array);
+               if (!level_zero)
+                       args.lod = get_src(ctx, instr->src[4]);
                args.dmask = 15;
 
                ac_build_image_opcode(&ctx->ac, &args);
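
Note: both image hunks follow the same pattern: when the NIR LOD source is a
known constant zero, the plain ac_image_load/ac_image_store opcode is kept;
otherwise the *_mip variant is emitted and the LOD is forwarded. A hedged
sketch of the shared test, factored into a hypothetical helper:

    #include "nir.h"

    /* Sketch: true iff the LOD operand is provably the constant 0, in which
     * case the non-mip image opcode can be used. */
    static bool lod_is_const_zero(const nir_intrinsic_instr *instr,
                                  unsigned lod_src)
    {
       return nir_src_is_const(instr->src[lod_src]) &&
              nir_src_as_uint(instr->src[lod_src]) == 0;
    }

visit_image_load would pass lod_src = 3 and visit_image_store lod_src = 4,
matching the operand indices used in the hunks above.
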
@@ -2859,7 +2876,6 @@ static void emit_membar(struct ac_llvm_context *ac,
        case nir_intrinsic_group_memory_barrier:
                wait_flags = AC_WAIT_LGKM | AC_WAIT_VLOAD | AC_WAIT_VSTORE;
                break;
-       case nir_intrinsic_memory_barrier_atomic_counter:
        case nir_intrinsic_memory_barrier_buffer:
        case nir_intrinsic_memory_barrier_image:
                wait_flags = AC_WAIT_VLOAD | AC_WAIT_VSTORE;
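
Note: the nir_intrinsic_memory_barrier_atomic_counter case can go away here
(and in visit_intrinsic below) because, per the commit title,
nir/lower_atomics_to_ssbo now also lowers atomic-counter barriers to buffer
barriers, so backends never see the intrinsic. A hedged sketch of what that
pass-side rewrite amounts to (the pass itself is not part of this diff;
helper name hypothetical):

    #include "nir.h"

    /* Sketch: both intrinsics take no sources or destination, so the pass
     * can simply retag the instruction in place. */
    static bool lower_atomic_counter_barrier(nir_intrinsic_instr *instr)
    {
       if (instr->intrinsic != nir_intrinsic_memory_barrier_atomic_counter)
          return false;
       instr->intrinsic = nir_intrinsic_memory_barrier_buffer;
       return true;
    }
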
@@ -2968,7 +2984,8 @@ visit_load_shared(struct ac_nir_context *ctx,
 {
        LLVMValueRef values[4], derived_ptr, index, ret;
 
-       LLVMValueRef ptr = get_memory_ptr(ctx, instr->src[0]);
+       LLVMValueRef ptr = get_memory_ptr(ctx, instr->src[0],
+                                         instr->dest.ssa.bit_size);
 
        for (int chan = 0; chan < instr->num_components; chan++) {
                index = LLVMConstInt(ctx->ac.i32, chan, 0);
@@ -2987,7 +3004,8 @@ visit_store_shared(struct ac_nir_context *ctx,
        LLVMValueRef derived_ptr, data,index;
        LLVMBuilderRef builder = ctx->ac.builder;
 
-       LLVMValueRef ptr = get_memory_ptr(ctx, instr->src[1]);
+       LLVMValueRef ptr = get_memory_ptr(ctx, instr->src[1],
+                                         instr->src[0].ssa->bit_size);
        LLVMValueRef src = get_src(ctx, instr->src[0]);
 
        int writemask = nir_intrinsic_write_mask(instr);
@@ -3011,6 +3029,17 @@ static LLVMValueRef visit_var_atomic(struct ac_nir_context *ctx,
 
        const char *sync_scope = LLVM_VERSION_MAJOR >= 9 ? "workgroup-one-as" : "workgroup";
 
+       if (instr->src[0].ssa->parent_instr->type == nir_instr_type_deref) {
+               nir_deref_instr *deref = nir_instr_as_deref(instr->src[0].ssa->parent_instr);
+               if (deref->mode == nir_var_mem_global) {
+                       /* use "singlethread" sync scope to implement relaxed ordering */
+                       sync_scope = LLVM_VERSION_MAJOR >= 9 ? "singlethread-one-as" : "singlethread";
+
+                       LLVMTypeRef ptr_type = LLVMPointerType(LLVMTypeOf(src), LLVMGetPointerAddressSpace(LLVMTypeOf(ptr)));
+                       ptr = LLVMBuildBitCast(ctx->ac.builder, ptr, ptr_type, "");
+               }
+       }
+
        if (instr->intrinsic == nir_intrinsic_shared_atomic_comp_swap ||
            instr->intrinsic == nir_intrinsic_deref_atomic_comp_swap) {
                LLVMValueRef src1 = get_src(ctx, instr->src[src_idx + 1]);
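
Note: for atomics on nir_var_mem_global derefs, the hunk above narrows the
sync scope from "workgroup" to "singlethread". A single-thread scope tells
the backend the operation does not synchronize with other invocations, which
is how relaxed ordering is expressed here; the bitcast also re-types the
pointer so its pointee matches the data operand. A hedged plain LLVM-C
analogue of such an atomic (mesa itself goes through its own builder wrapper
with a named sync-scope string):

    #include <llvm-c/Core.h>

    /* Sketch: a monotonic (relaxed), single-thread-scoped atomic add. The
     * final flag is the C API's closest analogue of the "singlethread"
     * scope used above. */
    static LLVMValueRef build_relaxed_atomic_add(LLVMBuilderRef b,
                                                 LLVMValueRef ptr,
                                                 LLVMValueRef val)
    {
       return LLVMBuildAtomicRMW(b, LLVMAtomicRMWBinOpAdd, ptr, val,
                                 LLVMAtomicOrderingMonotonic,
                                 /*SingleThread=*/1);
    }
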
@@ -3517,13 +3546,14 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                break;
        case nir_intrinsic_memory_barrier:
        case nir_intrinsic_group_memory_barrier:
-       case nir_intrinsic_memory_barrier_atomic_counter:
        case nir_intrinsic_memory_barrier_buffer:
        case nir_intrinsic_memory_barrier_image:
        case nir_intrinsic_memory_barrier_shared:
                emit_membar(&ctx->ac, instr);
                break;
-       case nir_intrinsic_barrier:
+       case nir_intrinsic_memory_barrier_tcs_patch:
+               break;
+       case nir_intrinsic_control_barrier:
                ac_emit_barrier(&ctx->ac, ctx->stage);
                break;
        case nir_intrinsic_shared_atomic_add:
@@ -3536,7 +3566,8 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_shared_atomic_xor:
        case nir_intrinsic_shared_atomic_exchange:
        case nir_intrinsic_shared_atomic_comp_swap: {
-               LLVMValueRef ptr = get_memory_ptr(ctx, instr->src[0]);
+               LLVMValueRef ptr = get_memory_ptr(ctx, instr->src[0],
+                                                 instr->src[1].ssa->bit_size);
                result = visit_var_atomic(ctx, instr, ptr, 1);
                break;
        }
@@ -3735,11 +3766,21 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                break;
        }
        case nir_intrinsic_load_constant: {
+               unsigned base = nir_intrinsic_base(instr);
+               unsigned range = nir_intrinsic_range(instr);
+
                LLVMValueRef offset = get_src(ctx, instr->src[0]);
-               LLVMValueRef base = LLVMConstInt(ctx->ac.i32,
-                                                nir_intrinsic_base(instr),
-                                                false);
-               offset = LLVMBuildAdd(ctx->ac.builder, offset, base, "");
+               offset = LLVMBuildAdd(ctx->ac.builder, offset,
+                                     LLVMConstInt(ctx->ac.i32, base, false), "");
+
+               /* Clamp the offset to avoid out-of-bounds accesses because
+                * global instructions can't handle them.
+                */
+               LLVMValueRef size = LLVMConstInt(ctx->ac.i32, base + range, false);
+               LLVMValueRef cond = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT,
+                                                 offset, size, "");
+               offset = LLVMBuildSelect(ctx->ac.builder, cond, offset, size, "");
+
                LLVMValueRef ptr = ac_build_gep0(&ctx->ac, ctx->constant_data,
                                                 offset);
                LLVMTypeRef comp_type =
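
Note: offsets into the inline constant buffer are clamped to the window the
intrinsic declares, because AMD global loads have no built-in bounds checking
(unlike descriptor-based buffer loads). A standalone restatement of the
arithmetic (hypothetical helper):

    #include <stdint.h>

    /* Sketch: pin the final offset to the end of the declared
     * [base, base + range) window, mirroring the icmp+select above. */
    static uint32_t clamp_constant_offset(uint32_t src_offset,
                                          uint32_t base, uint32_t range)
    {
       uint32_t offset = src_offset + base;
       uint32_t size = base + range;
       return offset < size ? offset : size;
    }

For example, with base = 16 and range = 64, a source offset of 70 would
compute 86 but is clamped to 80, the select's safe fallback address.
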
@@ -4721,14 +4762,21 @@ static void
 setup_shared(struct ac_nir_context *ctx,
             struct nir_shader *nir)
 {
-       nir_foreach_variable(variable, &nir->shared) {
-               LLVMValueRef shared =
-                       LLVMAddGlobalInAddressSpace(
-                          ctx->ac.module, glsl_to_llvm_type(&ctx->ac, variable->type),
-                          variable->name ? variable->name : "",
-                          AC_ADDR_SPACE_LDS);
-               _mesa_hash_table_insert(ctx->vars, variable, shared);
-       }
+       if (ctx->ac.lds)
+               return;
+
+       LLVMTypeRef type = LLVMArrayType(ctx->ac.i8,
+                                        nir->info.cs.shared_size);
+
+       LLVMValueRef lds =
+               LLVMAddGlobalInAddressSpace(ctx->ac.module, type,
+                                           "compute_lds",
+                                           AC_ADDR_SPACE_LDS);
+       LLVMSetAlignment(lds, 64 * 1024);
+
+       ctx->ac.lds = LLVMBuildBitCast(ctx->ac.builder, lds,
+                                      LLVMPointerType(ctx->ac.i8,
+                                                      AC_ADDR_SPACE_LDS), "");
 }
 
 void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
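
Note: shared memory is no longer one LLVM global per NIR variable; the whole
workgroup-shared region becomes a single byte array sized by
nir->info.cs.shared_size, and shared-variable derefs are expected to have
been lowered to byte offsets beforehand. That is why the nir_var_mem_shared
load/store paths disappear earlier in this diff and why get_memory_ptr now
GEPs into ctx->ac.lds and bitcasts to the access width. A self-contained
sketch of the blob setup, with the mesa-specific address-space constant
spelled out (LDS is address space 3 on amdgcn):

    #include <llvm-c/Core.h>

    /* Sketch mirroring setup_shared above: one i8 LDS array for the whole
     * workgroup, returned as an i8 pointer for byte-offset addressing. */
    static LLVMValueRef make_lds_blob(LLVMModuleRef mod, LLVMBuilderRef b,
                                      LLVMContextRef c, unsigned shared_size)
    {
       enum { ADDR_SPACE_LDS = 3 }; /* AC_ADDR_SPACE_LDS */
       LLVMTypeRef i8 = LLVMInt8TypeInContext(c);
       LLVMValueRef lds =
          LLVMAddGlobalInAddressSpace(mod, LLVMArrayType(i8, shared_size),
                                      "compute_lds", ADDR_SPACE_LDS);
       LLVMSetAlignment(lds, 64 * 1024); /* same alignment as the hunk above */
       return LLVMBuildBitCast(b, lds,
                               LLVMPointerType(i8, ADDR_SPACE_LDS), "");
    }
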
@@ -4869,7 +4917,7 @@ scan_tess_ctrl(nir_cf_node *cf_node, unsigned *upper_block_tf_writemask,
                                continue;
 
                        nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
-                       if (intrin->intrinsic == nir_intrinsic_barrier) {
+                       if (intrin->intrinsic == nir_intrinsic_control_barrier) {
 
                                /* If we find a barrier in nested control flow put this in the
                                 * too hard basket. In GLSL this is not possible but it is in