radeonsi/nir: always lower ballot masks as 64-bit, codegen handles it
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index c0442f2688ccff4c2233cc096126c5ad9c41b197..d97387ef13d275daafdf9d87eb645e1b792a73c2 100644 (file)
@@ -42,6 +42,8 @@ struct ac_nir_context {
 
        LLVMValueRef *ssa_defs;
 
+       LLVMValueRef scratch;
+
        struct hash_table *defs;
        struct hash_table *phis;
        struct hash_table *vars;
@@ -153,12 +155,6 @@ static LLVMBasicBlockRef get_block(struct ac_nir_context *nir,
        return (LLVMBasicBlockRef)entry->data;
 }
 
-static LLVMValueRef emit_iabs(struct ac_llvm_context *ctx,
-                             LLVMValueRef src0)
-{
-       return ac_build_imax(ctx, src0, LLVMBuildNeg(ctx->builder, src0, ""));
-}
-
 static LLVMValueRef get_alu_src(struct ac_nir_context *ctx,
                                 nir_alu_src src,
                                 unsigned num_components)
@@ -193,38 +189,8 @@ static LLVMValueRef get_alu_src(struct ac_nir_context *ctx,
                                                       swizzle, "");
                }
        }
-
-       LLVMTypeRef type = LLVMTypeOf(value);
-       if (LLVMGetTypeKind(type) == LLVMVectorTypeKind)
-               type = LLVMGetElementType(type);
-
-       if (src.abs) {
-               if (LLVMGetTypeKind(type) == LLVMIntegerTypeKind) {
-                       value = emit_iabs(&ctx->ac, value);
-               } else {
-                       char name[128];
-                       unsigned fsize = type == ctx->ac.f16 ? 16 :
-                                        type == ctx->ac.f32 ? 32 : 64;
-
-                       if (LLVMGetTypeKind(LLVMTypeOf(value)) == LLVMVectorTypeKind) {
-                               snprintf(name, sizeof(name), "llvm.fabs.v%uf%u",
-                                        LLVMGetVectorSize(LLVMTypeOf(value)), fsize);
-                       } else {
-                               snprintf(name, sizeof(name), "llvm.fabs.f%u", fsize);
-                       }
-
-                       value = ac_build_intrinsic(&ctx->ac, name, LLVMTypeOf(value),
-                                                  &value, 1, AC_FUNC_ATTR_READNONE);
-               }
-       }
-
-       if (src.negate) {
-               if (LLVMGetTypeKind(type) == LLVMIntegerTypeKind)
-                       value = LLVMBuildNeg(ctx->ac.builder, value, "");
-               else
-                       value = LLVMBuildFNeg(ctx->ac.builder, value, "");
-       }
-
+       assert(!src.negate);
+       assert(!src.abs);
        return value;
 }
 
@@ -314,6 +280,12 @@ static LLVMValueRef emit_bcsel(struct ac_llvm_context *ctx,
                               ac_to_integer_or_pointer(ctx, src2), "");
 }
 
+static LLVMValueRef emit_iabs(struct ac_llvm_context *ctx,
+                             LLVMValueRef src0)
+{
+       return ac_build_imax(ctx, src0, LLVMBuildNeg(ctx->builder, src0, ""));
+}
+
 static LLVMValueRef emit_uint_carry(struct ac_llvm_context *ctx,
                                    const char *intrin,
                                    LLVMValueRef src0, LLVMValueRef src1)
@@ -1577,6 +1549,9 @@ static unsigned get_cache_policy(struct ac_nir_context *ctx,
                cache_policy |= ac_glc;
        }
 
+       if (access & ACCESS_STREAM_CACHE_POLICY)
+               cache_policy |= ac_slc;
+
        return cache_policy;
 }
 
@@ -1668,13 +1643,71 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
        }
 }
 
+static LLVMValueRef emit_ssbo_comp_swap_64(struct ac_nir_context *ctx,
+                                           LLVMValueRef descriptor,
+                                          LLVMValueRef offset,
+                                          LLVMValueRef compare,
+                                          LLVMValueRef exchange)
+{
+       LLVMBasicBlockRef start_block = NULL, then_block = NULL;
+       if (ctx->abi->robust_buffer_access) {
+               LLVMValueRef size = ac_llvm_extract_elem(&ctx->ac, descriptor, 2);
+
+               LLVMValueRef cond = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, offset, size, "");
+               start_block = LLVMGetInsertBlock(ctx->ac.builder);
+
+               ac_build_ifcc(&ctx->ac, cond, -1);
+
+               then_block = LLVMGetInsertBlock(ctx->ac.builder);
+       }
+
+       LLVMValueRef ptr_parts[2] = {
+               ac_llvm_extract_elem(&ctx->ac, descriptor, 0),
+               LLVMBuildAnd(ctx->ac.builder,
+                            ac_llvm_extract_elem(&ctx->ac, descriptor, 1),
+                            LLVMConstInt(ctx->ac.i32, 65535, 0), "")
+       };
+
+       ptr_parts[1] = LLVMBuildTrunc(ctx->ac.builder, ptr_parts[1], ctx->ac.i16, "");
+       ptr_parts[1] = LLVMBuildSExt(ctx->ac.builder, ptr_parts[1], ctx->ac.i32, "");
+
+       offset = LLVMBuildZExt(ctx->ac.builder, offset, ctx->ac.i64, "");
+
+       LLVMValueRef ptr = ac_build_gather_values(&ctx->ac, ptr_parts, 2);
+       ptr = LLVMBuildBitCast(ctx->ac.builder, ptr, ctx->ac.i64, "");
+       ptr = LLVMBuildAdd(ctx->ac.builder, ptr, offset, "");
+       ptr = LLVMBuildIntToPtr(ctx->ac.builder, ptr, LLVMPointerType(ctx->ac.i64, AC_ADDR_SPACE_GLOBAL), "");
+
+       LLVMValueRef result = ac_build_atomic_cmp_xchg(&ctx->ac, ptr, compare, exchange, "singlethread-one-as");
+       result = LLVMBuildExtractValue(ctx->ac.builder, result, 0, "");
+
+       if (ctx->abi->robust_buffer_access) {
+               ac_build_endif(&ctx->ac, -1);
+
+               LLVMBasicBlockRef incoming_blocks[2] = {
+                       start_block,
+                       then_block,
+               };
+
+               LLVMValueRef incoming_values[2] = {
+                       LLVMConstInt(ctx->ac.i64, 0, 0),
+                       result,
+               };
+               LLVMValueRef ret = LLVMBuildPhi(ctx->ac.builder, ctx->ac.i64, "");
+               LLVMAddIncoming(ret, incoming_values, incoming_blocks, 2);
+               return ret;
+       } else {
+               return result;
+       }
+}
+
 static LLVMValueRef visit_atomic_ssbo(struct ac_nir_context *ctx,
                                       const nir_intrinsic_instr *instr)
 {
        LLVMTypeRef return_type = LLVMTypeOf(get_src(ctx, instr->src[2]));
        const char *op;
        char name[64], type[8];
-       LLVMValueRef params[6];
+       LLVMValueRef params[6], descriptor;
        int arg_count = 0;
 
        switch (instr->intrinsic) {
@@ -1712,13 +1745,22 @@ static LLVMValueRef visit_atomic_ssbo(struct ac_nir_context *ctx,
                abort();
        }
 
+       descriptor = ctx->abi->load_ssbo(ctx->abi,
+                                        get_src(ctx, instr->src[0]),
+                                        true);
+
+       if (instr->intrinsic == nir_intrinsic_ssbo_atomic_comp_swap &&
+           return_type == ctx->ac.i64) {
+               return emit_ssbo_comp_swap_64(ctx, descriptor,
+                                             get_src(ctx, instr->src[1]),
+                                             get_src(ctx, instr->src[2]),
+                                             get_src(ctx, instr->src[3]));
+       }
        if (instr->intrinsic == nir_intrinsic_ssbo_atomic_comp_swap) {
                params[arg_count++] = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[3]), 0);
        }
        params[arg_count++] = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[2]), 0);
-       params[arg_count++] = ctx->abi->load_ssbo(ctx->abi,
-                                                 get_src(ctx, instr->src[0]),
-                                                 true);
+       params[arg_count++] = descriptor;
 
        if (HAVE_LLVM >= 0x900) {
                /* XXX: The new raw/struct atomic intrinsics are buggy with
@@ -2399,7 +2441,7 @@ static void get_image_coords(struct ac_nir_context *ctx,
                                                               fmask_load_address[2],
                                                               sample_index,
                                                               get_sampler_desc(ctx, nir_instr_as_deref(instr->src[0].ssa->parent_instr),
-                                                                               AC_DESC_FMASK, &instr->instr, false, false));
+                                                                               AC_DESC_FMASK, &instr->instr, true, false));
        }
        if (count == 1 && !gfx9_1d) {
                if (instr->src[1].ssa->num_components)
@@ -2638,6 +2680,27 @@ static LLVMValueRef visit_image_atomic(struct ac_nir_context *ctx,
                atomic_name = "cmpswap";
                atomic_subop = 0; /* not used */
                break;
+       case nir_intrinsic_bindless_image_atomic_inc_wrap:
+       case nir_intrinsic_image_deref_atomic_inc_wrap: {
+               atomic_name = "inc";
+               atomic_subop = ac_atomic_inc_wrap;
+               /* ATOMIC_INC instruction does:
+                *      value = (value + 1) % (data + 1)
+                * but we want:
+                *      value = (value + 1) % data
+                * So replace 'data' by 'data - 1'.
+                */
+               ctx->ssa_defs[instr->src[3].ssa->index] =
+                       LLVMBuildSub(ctx->ac.builder,
+                                    ctx->ssa_defs[instr->src[3].ssa->index],
+                                    ctx->ac.i32_1, "");
+               break;
+       }
+       case nir_intrinsic_bindless_image_atomic_dec_wrap:
+       case nir_intrinsic_image_deref_atomic_dec_wrap:
+               atomic_name = "dec";
+               atomic_subop = ac_atomic_dec_wrap;
+               break;
        default:
                abort();
        }
@@ -3041,6 +3104,9 @@ static LLVMValueRef barycentric_at_sample(struct ac_nir_context *ctx,
                                          unsigned mode,
                                          LLVMValueRef sample_id)
 {
+       if (ctx->abi->interp_at_sample_force_center)
+               return barycentric_center(ctx, mode);
+
        LLVMValueRef halfval = LLVMConstReal(ctx->ac.f32, 0.5f);
 
        /* fetch sample ID */
@@ -3139,6 +3205,8 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        switch (instr->intrinsic) {
        case nir_intrinsic_ballot:
                result = ac_build_ballot(&ctx->ac, get_src(ctx, instr->src[0]));
+               if (ctx->ac.ballot_mask_bits > ctx->ac.wave_size)
+                       result = LLVMBuildZExt(ctx->ac.builder, result, ctx->ac.iN_ballotmask, "");
                break;
        case nir_intrinsic_read_invocation:
                result = ac_build_readlane(&ctx->ac, get_src(ctx, instr->src[0]),
@@ -3247,6 +3315,10 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_load_color1:
                result = ctx->abi->color1;
                break;
+       case nir_intrinsic_load_user_data_amd:
+               assert(LLVMTypeOf(ctx->abi->user_data) == ctx->ac.v4i32);
+               result = ctx->abi->user_data;
+               break;
        case nir_intrinsic_load_instance_id:
                result = ctx->abi->instance_id;
                break;
@@ -3342,6 +3414,8 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_bindless_image_atomic_xor:
        case nir_intrinsic_bindless_image_atomic_exchange:
        case nir_intrinsic_bindless_image_atomic_comp_swap:
+       case nir_intrinsic_bindless_image_atomic_inc_wrap:
+       case nir_intrinsic_bindless_image_atomic_dec_wrap:
                result = visit_image_atomic(ctx, instr, true);
                break;
        case nir_intrinsic_image_deref_atomic_add:
@@ -3352,6 +3426,8 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_image_deref_atomic_xor:
        case nir_intrinsic_image_deref_atomic_exchange:
        case nir_intrinsic_image_deref_atomic_comp_swap:
+       case nir_intrinsic_image_deref_atomic_inc_wrap:
+       case nir_intrinsic_image_deref_atomic_dec_wrap:
                result = visit_image_atomic(ctx, instr, false);
                break;
        case nir_intrinsic_bindless_image_size:
@@ -3463,10 +3539,16 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->load_tess_coord(ctx->abi);
                break;
        case nir_intrinsic_load_tess_level_outer:
-               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_OUTER);
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_OUTER, false);
                break;
        case nir_intrinsic_load_tess_level_inner:
-               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_INNER);
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_INNER, false);
+               break;
+       case nir_intrinsic_load_tess_level_outer_default:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_OUTER, true);
+               break;
+       case nir_intrinsic_load_tess_level_inner_default:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_INNER, true);
                break;
        case nir_intrinsic_load_patch_vertices_in:
                result = ctx->abi->load_patch_vertices_in(ctx->abi);
@@ -3536,6 +3618,50 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_mbcnt_amd:
                result = ac_build_mbcnt(&ctx->ac, get_src(ctx, instr->src[0]));
                break;
+       case nir_intrinsic_load_scratch: {
+               LLVMValueRef offset = get_src(ctx, instr->src[0]);
+               LLVMValueRef ptr = ac_build_gep0(&ctx->ac, ctx->scratch,
+                                                offset);
+               LLVMTypeRef comp_type =
+                       LLVMIntTypeInContext(ctx->ac.context, instr->dest.ssa.bit_size);
+               LLVMTypeRef vec_type =
+                       instr->dest.ssa.num_components == 1 ? comp_type :
+                       LLVMVectorType(comp_type, instr->dest.ssa.num_components);
+               unsigned addr_space = LLVMGetPointerAddressSpace(LLVMTypeOf(ptr));
+               ptr = LLVMBuildBitCast(ctx->ac.builder, ptr,
+                                      LLVMPointerType(vec_type, addr_space), "");
+               result = LLVMBuildLoad(ctx->ac.builder, ptr, "");
+               break;
+       }
+       case nir_intrinsic_store_scratch: {
+               LLVMValueRef offset = get_src(ctx, instr->src[1]);
+               LLVMValueRef ptr = ac_build_gep0(&ctx->ac, ctx->scratch,
+                                                offset);
+               LLVMTypeRef comp_type =
+                       LLVMIntTypeInContext(ctx->ac.context, instr->src[0].ssa->bit_size);
+               unsigned addr_space = LLVMGetPointerAddressSpace(LLVMTypeOf(ptr));
+               ptr = LLVMBuildBitCast(ctx->ac.builder, ptr,
+                                      LLVMPointerType(comp_type, addr_space), "");
+               LLVMValueRef src = get_src(ctx, instr->src[0]);
+               unsigned wrmask = nir_intrinsic_write_mask(instr);
+               while (wrmask) {
+                       int start, count;
+                       u_bit_scan_consecutive_range(&wrmask, &start, &count);
+                       
+                       LLVMValueRef offset = LLVMConstInt(ctx->ac.i32, start, false);
+                       LLVMValueRef offset_ptr = LLVMBuildGEP(ctx->ac.builder, ptr, &offset, 1, "");
+                       LLVMTypeRef vec_type =
+                               count == 1 ? comp_type : LLVMVectorType(comp_type, count);
+                       offset_ptr = LLVMBuildBitCast(ctx->ac.builder,
+                                                     offset_ptr,
+                                                     LLVMPointerType(vec_type, addr_space),
+                                                     "");
+                       LLVMValueRef offset_src =
+                               ac_extract_components(&ctx->ac, src, start, count);
+                       LLVMBuildStore(ctx->ac.builder, offset_src, offset_ptr);
+               }
+               break;
+       }
        default:
                fprintf(stderr, "Unknown intrinsic: ");
                nir_print_instr(&instr->instr, stderr);
@@ -4436,6 +4562,18 @@ setup_locals(struct ac_nir_context *ctx,
        }
 }
 
+static void
+setup_scratch(struct ac_nir_context *ctx,
+             struct nir_shader *shader)
+{
+       if (shader->scratch_size == 0)
+               return;
+
+       ctx->scratch = ac_build_alloca_undef(&ctx->ac,
+                                            LLVMArrayType(ctx->ac.i8, shader->scratch_size),
+                                            "scratch");
+}
+
 static void
 setup_shared(struct ac_nir_context *ctx,
             struct nir_shader *nir)
@@ -4481,6 +4619,7 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
        ctx.ssa_defs = calloc(func->impl->ssa_alloc, sizeof(LLVMValueRef));
 
        setup_locals(&ctx, func);
+       setup_scratch(&ctx, nir);
 
        if (gl_shader_stage_is_compute(nir->info.stage))
                setup_shared(&ctx, nir);
@@ -4502,6 +4641,15 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
 void
 ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class chip_class)
 {
+       /* Lower large variables to scratch first so that we won't bloat the
+        * shader by generating large if ladders for them. We later lower
+        * scratch to alloca's, assuming LLVM won't generate VGPR indexing.
+        */
+       NIR_PASS_V(nir, nir_lower_vars_to_scratch,
+                  nir_var_function_temp,
+                  256,
+                  glsl_get_natural_size_align_bytes);
+
        /* While it would be nice not to have this flag, we are constrained
         * by the reality that LLVM 9.0 has buggy VGPR indexing on GFX9.
         */