ac/nir: Lower large indirect variables to scratch
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index 826a63773238097027f8c7edd1ba058e812a7f21..5e25e838f8f3f78c5d7a1590768f95bde4d0a828 100644 (file)
@@ -42,6 +42,8 @@ struct ac_nir_context {
 
        LLVMValueRef *ssa_defs;
 
+       LLVMValueRef scratch;
+
        struct hash_table *defs;
        struct hash_table *phis;
        struct hash_table *vars;
@@ -225,7 +227,7 @@ static LLVMValueRef emit_intrin_1f_param(struct ac_llvm_context *ctx,
                ac_to_float(ctx, src0),
        };
 
-       MAYBE_UNUSED const int length = snprintf(name, sizeof(name), "%s.f%d", intrin,
+       ASSERTED const int length = snprintf(name, sizeof(name), "%s.f%d", intrin,
                                                 ac_get_elem_bits(ctx, result_type));
        assert(length < sizeof(name));
        return ac_build_intrinsic(ctx, name, result_type, params, 1, AC_FUNC_ATTR_READNONE);
@@ -242,7 +244,7 @@ static LLVMValueRef emit_intrin_2f_param(struct ac_llvm_context *ctx,
                ac_to_float(ctx, src1),
        };
 
-       MAYBE_UNUSED const int length = snprintf(name, sizeof(name), "%s.f%d", intrin,
+       ASSERTED const int length = snprintf(name, sizeof(name), "%s.f%d", intrin,
                                                 ac_get_elem_bits(ctx, result_type));
        assert(length < sizeof(name));
        return ac_build_intrinsic(ctx, name, result_type, params, 2, AC_FUNC_ATTR_READNONE);
@@ -260,7 +262,7 @@ static LLVMValueRef emit_intrin_3f_param(struct ac_llvm_context *ctx,
                ac_to_float(ctx, src2),
        };
 
-       MAYBE_UNUSED const int length = snprintf(name, sizeof(name), "%s.f%d", intrin,
+       ASSERTED const int length = snprintf(name, sizeof(name), "%s.f%d", intrin,
                                                 ac_get_elem_bits(ctx, result_type));
        assert(length < sizeof(name));
        return ac_build_intrinsic(ctx, name, result_type, params, 3, AC_FUNC_ATTR_READNONE);
@@ -1638,13 +1640,71 @@ static void visit_store_ssbo(struct ac_nir_context *ctx,
        }
 }
 
+static LLVMValueRef emit_ssbo_comp_swap_64(struct ac_nir_context *ctx,
+                                           LLVMValueRef descriptor,
+                                          LLVMValueRef offset,
+                                          LLVMValueRef compare,
+                                          LLVMValueRef exchange)
+{
+       LLVMBasicBlockRef start_block = NULL, then_block = NULL;
+       if (ctx->abi->robust_buffer_access) {
+               LLVMValueRef size = ac_llvm_extract_elem(&ctx->ac, descriptor, 2);
+
+               LLVMValueRef cond = LLVMBuildICmp(ctx->ac.builder, LLVMIntULT, offset, size, "");
+               start_block = LLVMGetInsertBlock(ctx->ac.builder);
+
+               ac_build_ifcc(&ctx->ac, cond, -1);
+
+               then_block = LLVMGetInsertBlock(ctx->ac.builder);
+       }
+
+       LLVMValueRef ptr_parts[2] = {
+               ac_llvm_extract_elem(&ctx->ac, descriptor, 0),
+               LLVMBuildAnd(ctx->ac.builder,
+                            ac_llvm_extract_elem(&ctx->ac, descriptor, 1),
+                            LLVMConstInt(ctx->ac.i32, 65535, 0), "")
+       };
+
+       ptr_parts[1] = LLVMBuildTrunc(ctx->ac.builder, ptr_parts[1], ctx->ac.i16, "");
+       ptr_parts[1] = LLVMBuildSExt(ctx->ac.builder, ptr_parts[1], ctx->ac.i32, "");
+
+       offset = LLVMBuildZExt(ctx->ac.builder, offset, ctx->ac.i64, "");
+
+       LLVMValueRef ptr = ac_build_gather_values(&ctx->ac, ptr_parts, 2);
+       ptr = LLVMBuildBitCast(ctx->ac.builder, ptr, ctx->ac.i64, "");
+       ptr = LLVMBuildAdd(ctx->ac.builder, ptr, offset, "");
+       ptr = LLVMBuildIntToPtr(ctx->ac.builder, ptr, LLVMPointerType(ctx->ac.i64, AC_ADDR_SPACE_GLOBAL), "");
+
+       LLVMValueRef result = ac_build_atomic_cmp_xchg(&ctx->ac, ptr, compare, exchange, "singlethread-one-as");
+       result = LLVMBuildExtractValue(ctx->ac.builder, result, 0, "");
+
+       if (ctx->abi->robust_buffer_access) {
+               ac_build_endif(&ctx->ac, -1);
+
+               LLVMBasicBlockRef incoming_blocks[2] = {
+                       start_block,
+                       then_block,
+               };
+
+               LLVMValueRef incoming_values[2] = {
+                       LLVMConstInt(ctx->ac.i64, 0, 0),
+                       result,
+               };
+               LLVMValueRef ret = LLVMBuildPhi(ctx->ac.builder, ctx->ac.i64, "");
+               LLVMAddIncoming(ret, incoming_values, incoming_blocks, 2);
+               return ret;
+       } else {
+               return result;
+       }
+}
+
 static LLVMValueRef visit_atomic_ssbo(struct ac_nir_context *ctx,
                                       const nir_intrinsic_instr *instr)
 {
        LLVMTypeRef return_type = LLVMTypeOf(get_src(ctx, instr->src[2]));
        const char *op;
        char name[64], type[8];
-       LLVMValueRef params[6];
+       LLVMValueRef params[6], descriptor;
        int arg_count = 0;
 
        switch (instr->intrinsic) {
@@ -1682,13 +1742,22 @@ static LLVMValueRef visit_atomic_ssbo(struct ac_nir_context *ctx,
                abort();
        }
 
+       descriptor = ctx->abi->load_ssbo(ctx->abi,
+                                        get_src(ctx, instr->src[0]),
+                                        true);
+
+       if (instr->intrinsic == nir_intrinsic_ssbo_atomic_comp_swap &&
+           return_type == ctx->ac.i64) {
+               return emit_ssbo_comp_swap_64(ctx, descriptor,
+                                             get_src(ctx, instr->src[1]),
+                                             get_src(ctx, instr->src[2]),
+                                             get_src(ctx, instr->src[3]));
+       }
        if (instr->intrinsic == nir_intrinsic_ssbo_atomic_comp_swap) {
                params[arg_count++] = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[3]), 0);
        }
        params[arg_count++] = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[2]), 0);
-       params[arg_count++] = ctx->abi->load_ssbo(ctx->abi,
-                                                 get_src(ctx, instr->src[0]),
-                                                 true);
+       params[arg_count++] = descriptor;
 
        if (HAVE_LLVM >= 0x900) {
                /* XXX: The new raw/struct atomic intrinsics are buggy with
@@ -2344,7 +2413,7 @@ static void get_image_coords(struct ac_nir_context *ctx,
        LLVMValueRef sample_index = ac_llvm_extract_elem(&ctx->ac, get_src(ctx, instr->src[2]), 0);
 
        int count;
-       MAYBE_UNUSED bool add_frag_pos = (dim == GLSL_SAMPLER_DIM_SUBPASS ||
+       ASSERTED bool add_frag_pos = (dim == GLSL_SAMPLER_DIM_SUBPASS ||
                                          dim == GLSL_SAMPLER_DIM_SUBPASS_MS);
        bool is_ms = (dim == GLSL_SAMPLER_DIM_MS ||
                      dim == GLSL_SAMPLER_DIM_SUBPASS_MS);
@@ -2546,7 +2615,7 @@ static LLVMValueRef visit_image_atomic(struct ac_nir_context *ctx,
        const char *atomic_name;
        char intrinsic_name[64];
        enum ac_atomic_op atomic_subop;
-       MAYBE_UNUSED int length;
+       ASSERTED int length;
 
        enum glsl_sampler_dim dim;
        bool is_unsigned = false;
@@ -3397,7 +3466,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        }
        case nir_intrinsic_load_interpolated_input: {
                /* We assume any indirect loads have been lowered away */
-               MAYBE_UNUSED nir_const_value *offset = nir_src_as_const_value(instr->src[1]);
+               ASSERTED nir_const_value *offset = nir_src_as_const_value(instr->src[1]);
                assert(offset);
                assert(offset[0].i32 == 0);
 
@@ -3412,7 +3481,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        }
        case nir_intrinsic_load_input: {
                /* We only lower inputs for fragment shaders ATM */
-               MAYBE_UNUSED nir_const_value *offset = nir_src_as_const_value(instr->src[0]);
+               ASSERTED nir_const_value *offset = nir_src_as_const_value(instr->src[0]);
                assert(offset);
                assert(offset[0].i32 == 0);
 
@@ -3506,6 +3575,36 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_mbcnt_amd:
                result = ac_build_mbcnt(&ctx->ac, get_src(ctx, instr->src[0]));
                break;
+       case nir_intrinsic_load_scratch: {
+               LLVMValueRef offset = get_src(ctx, instr->src[0]);
+               LLVMValueRef ptr = ac_build_gep0(&ctx->ac, ctx->scratch,
+                                                offset);
+               LLVMTypeRef comp_type =
+                       LLVMIntTypeInContext(ctx->ac.context, instr->dest.ssa.bit_size);
+               LLVMTypeRef vec_type =
+                       instr->dest.ssa.num_components == 1 ? comp_type :
+                       LLVMVectorType(comp_type, instr->dest.ssa.num_components);
+               unsigned addr_space = LLVMGetPointerAddressSpace(LLVMTypeOf(ptr));
+               ptr = LLVMBuildBitCast(ctx->ac.builder, ptr,
+                                      LLVMPointerType(vec_type, addr_space), "");
+               result = LLVMBuildLoad(ctx->ac.builder, ptr, "");
+               break;
+       }
+       case nir_intrinsic_store_scratch: {
+               LLVMValueRef offset = get_src(ctx, instr->src[1]);
+               LLVMValueRef ptr = ac_build_gep0(&ctx->ac, ctx->scratch,
+                                                offset);
+               LLVMTypeRef comp_type =
+                       LLVMIntTypeInContext(ctx->ac.context, instr->src[0].ssa->bit_size);
+               LLVMTypeRef vec_type =
+                       instr->src[0].ssa->num_components == 1 ? comp_type :
+                       LLVMVectorType(comp_type, instr->src[0].ssa->num_components);
+               unsigned addr_space = LLVMGetPointerAddressSpace(LLVMTypeOf(ptr));
+               ptr = LLVMBuildBitCast(ctx->ac.builder, ptr,
+                                      LLVMPointerType(vec_type, addr_space), "");
+               LLVMBuildStore(ctx->ac.builder, get_src(ctx, instr->src[0]), ptr);
+               break;
+       }
        default:
                fprintf(stderr, "Unknown intrinsic: ");
                nir_print_instr(&instr->instr, stderr);
@@ -4246,7 +4345,6 @@ static void visit_cf_list(struct ac_nir_context *ctx,
 
 static void visit_block(struct ac_nir_context *ctx, nir_block *block)
 {
-       LLVMBasicBlockRef llvm_block = LLVMGetInsertBlock(ctx->ac.builder);
        nir_foreach_instr(instr, block)
        {
                switch (instr->type) {
@@ -4282,7 +4380,8 @@ static void visit_block(struct ac_nir_context *ctx, nir_block *block)
                }
        }
 
-       _mesa_hash_table_insert(ctx->defs, block, llvm_block);
+       _mesa_hash_table_insert(ctx->defs, block,
+                               LLVMGetInsertBlock(ctx->ac.builder));
 }
 
 static void visit_if(struct ac_nir_context *ctx, nir_if *if_stmt)
@@ -4406,6 +4505,18 @@ setup_locals(struct ac_nir_context *ctx,
        }
 }
 
+static void
+setup_scratch(struct ac_nir_context *ctx,
+             struct nir_shader *shader)
+{
+       if (shader->scratch_size == 0)
+               return;
+
+       ctx->scratch = ac_build_alloca_undef(&ctx->ac,
+                                            LLVMArrayType(ctx->ac.i8, shader->scratch_size),
+                                            "scratch");
+}
+
 static void
 setup_shared(struct ac_nir_context *ctx,
             struct nir_shader *nir)
@@ -4451,6 +4562,7 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
        ctx.ssa_defs = calloc(func->impl->ssa_alloc, sizeof(LLVMValueRef));
 
        setup_locals(&ctx, func);
+       setup_scratch(&ctx, nir);
 
        if (gl_shader_stage_is_compute(nir->info.stage))
                setup_shared(&ctx, nir);
@@ -4472,6 +4584,15 @@ void ac_nir_translate(struct ac_llvm_context *ac, struct ac_shader_abi *abi,
 void
 ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class chip_class)
 {
+       /* Lower large variables to scratch first so that we won't bloat the
+        * shader by generating large if ladders for them. We later lower
+        * scratch to alloca's, assuming LLVM won't generate VGPR indexing.
+        */
+       NIR_PASS_V(nir, nir_lower_vars_to_scratch,
+                  nir_var_function_temp,
+                  256,
+                  glsl_get_natural_size_align_bytes);
+
        /* While it would be nice not to have this flag, we are constrained
         * by the reality that LLVM 9.0 has buggy VGPR indexing on GFX9.
         */