ac/nir: fix translation of nir_op_frcp for doubles
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index bdbe6f82e22922d7a45502d29a8b3d2b3189b44b..7b348d97f0a086346e393be605c91f8fb2dd1589 100644 (file)
@@ -541,19 +541,20 @@ struct user_sgpr_info {
 };
 
 static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
+                               gl_shader_stage stage,
                                struct user_sgpr_info *user_sgpr_info)
 {
        memset(user_sgpr_info, 0, sizeof(struct user_sgpr_info));
 
        /* until we sort out scratch/global buffers always assign ring offsets for gs/vs/es */
-       if (ctx->stage == MESA_SHADER_GEOMETRY ||
-           ctx->stage == MESA_SHADER_VERTEX ||
-           ctx->stage == MESA_SHADER_TESS_CTRL ||
-           ctx->stage == MESA_SHADER_TESS_EVAL ||
+       if (stage == MESA_SHADER_GEOMETRY ||
+           stage == MESA_SHADER_VERTEX ||
+           stage == MESA_SHADER_TESS_CTRL ||
+           stage == MESA_SHADER_TESS_EVAL ||
            ctx->is_gs_copy_shader)
                user_sgpr_info->need_ring_offsets = true;
 
-       if (ctx->stage == MESA_SHADER_FRAGMENT &&
+       if (stage == MESA_SHADER_FRAGMENT &&
            ctx->shader_info->info.ps.needs_sample_positions)
                user_sgpr_info->need_ring_offsets = true;
 
@@ -562,7 +563,8 @@ static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
                user_sgpr_info->sgpr_count += 2;
        }
 
-       switch (ctx->stage) {
+       /* FIXME: fix the number of user sgprs for merged shaders on GFX9 */
+       switch (stage) {
        case MESA_SHADER_COMPUTE:
                if (ctx->shader_info->info.cs.uses_grid_size)
                        user_sgpr_info->sgpr_count += 3;
@@ -595,10 +597,12 @@ static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
                break;
        }
 
-       if (ctx->shader_info->info.needs_push_constants)
+       if (ctx->shader_info->info.loads_push_constants)
                user_sgpr_info->sgpr_count += 2;
 
-       uint32_t remaining_sgprs = 16 - user_sgpr_info->sgpr_count;
+       uint32_t available_sgprs = ctx->options->chip_class >= GFX9 ? 32 : 16;
+       uint32_t remaining_sgprs = available_sgprs - user_sgpr_info->sgpr_count;
+
        if (remaining_sgprs / 2 < util_bitcount(ctx->shader_info->info.desc_set_used_mask)) {
                user_sgpr_info->sgpr_count += 2;
                user_sgpr_info->indirect_all_descriptor_sets = true;
@@ -636,7 +640,7 @@ declare_global_input_sgprs(struct nir_to_llvm_context *ctx,
                add_array_arg(args, const_array(type, 32), desc_sets);
        }
 
-       if (ctx->shader_info->info.needs_push_constants) {
+       if (ctx->shader_info->info.loads_push_constants) {
                /* 1 for push constants and dynamic descriptors */
                add_array_arg(args, type, &ctx->push_constants);
        }
@@ -669,9 +673,14 @@ declare_vs_input_vgprs(struct nir_to_llvm_context *ctx, struct arg_info *args)
 {
        add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.vertex_id);
        if (!ctx->is_gs_copy_shader) {
-               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->rel_auto_id);
-               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->vs_prim_id);
-               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+               if (ctx->options->key.vs.as_ls) {
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->rel_auto_id);
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+               } else {
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->vs_prim_id);
+               }
+               add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* unused */
        }
 }
 
@@ -722,7 +731,7 @@ set_global_input_locs(struct nir_to_llvm_context *ctx, gl_shader_stage stage,
                ctx->shader_info->need_indirect_descriptor_sets = true;
        }
 
-       if (ctx->shader_info->info.needs_push_constants) {
+       if (ctx->shader_info->info.loads_push_constants) {
                set_loc_shader(ctx, AC_UD_PUSH_CONSTANTS, user_sgpr_idx, 2);
        }
 }
@@ -760,7 +769,7 @@ static void create_function(struct nir_to_llvm_context *ctx,
        struct arg_info args = {};
        LLVMValueRef desc_sets;
 
-       allocate_user_sgprs(ctx, &user_sgpr_info);
+       allocate_user_sgprs(ctx, stage, &user_sgpr_info);
 
        if (user_sgpr_info.need_ring_offsets && !ctx->options->supports_spill) {
                add_arg(&args, ARG_SGPR, const_array(ctx->ac.v4i32, 16),
@@ -1706,7 +1715,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_frcp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, src[0]);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      src[0]);
                break;
        case nir_op_iand:
                result = LLVMBuildAnd(ctx->ac.builder, src[0], src[1], "");
@@ -1833,7 +1843,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        case nir_op_frsq:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.sqrt",
                                              ac_to_float_type(&ctx->ac, def_type), src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, result);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      result);
                break;
        case nir_op_fpow:
                result = emit_intrin_2f_param(&ctx->ac, "llvm.pow",
@@ -2252,7 +2263,9 @@ static LLVMValueRef build_tex_intrinsic(struct ac_nir_context *ctx,
        case nir_texop_txf:
        case nir_texop_txf_ms:
        case nir_texop_samples_identical:
-               args->opcode = instr->sampler_dim == GLSL_SAMPLER_DIM_MS ? ac_image_load : ac_image_load_mip;
+               args->opcode = lod_is_zero ||
+                              instr->sampler_dim == GLSL_SAMPLER_DIM_MS ?
+                                       ac_image_load : ac_image_load_mip;
                args->compare = false;
                args->offset = false;
                break;
@@ -2569,7 +2582,7 @@ static LLVMValueRef visit_load_buffer(struct ac_nir_context *ctx,
 static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
                                           const nir_intrinsic_instr *instr)
 {
-       LLVMValueRef results[8], ret;
+       LLVMValueRef ret;
        LLVMValueRef rsrc = get_src(ctx, instr->src[0]);
        LLVMValueRef offset = get_src(ctx, instr->src[1]);
        int num_components = instr->num_components;
@@ -2580,20 +2593,9 @@ static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
        if (instr->dest.ssa.bit_size == 64)
                num_components *= 2;
 
-       for (unsigned i = 0; i < num_components; ++i) {
-               LLVMValueRef params[] = {
-                       rsrc,
-                       LLVMBuildAdd(ctx->ac.builder, LLVMConstInt(ctx->ac.i32, 4 * i, 0),
-                                    offset, "")
-               };
-               results[i] = ac_build_intrinsic(&ctx->ac, "llvm.SI.load.const.v4i32", ctx->ac.f32,
-                                               params, 2,
-                                               AC_FUNC_ATTR_READNONE |
-                                               AC_FUNC_ATTR_LEGACY);
-       }
+       ret = ac_build_buffer_load(&ctx->ac, rsrc, num_components, NULL, offset,
+                                  NULL, 0, false, false, true, true);
 
-
-       ret = ac_build_gather_values(&ctx->ac, results, num_components);
        return LLVMBuildBitCast(ctx->ac.builder, ret,
                                get_def_type(ctx, &instr->dest.ssa), "");
 }
@@ -3822,19 +3824,18 @@ static void emit_membar(struct nir_to_llvm_context *ctx,
                ac_build_waitcnt(&ctx->ac, waitcnt);
 }
 
-static void emit_barrier(struct nir_to_llvm_context *ctx)
+static void emit_barrier(struct ac_llvm_context *ac, gl_shader_stage stage)
 {
        /* SI only (thanks to a hw bug workaround):
         * The real barrier instruction isn’t needed, because an entire patch
         * always fits into a single wave.
         */
-       if (ctx->options->chip_class == SI &&
-           ctx->stage == MESA_SHADER_TESS_CTRL) {
-               ac_build_waitcnt(&ctx->ac, LGKM_CNT & VM_CNT);
+       if (ac->chip_class == SI && stage == MESA_SHADER_TESS_CTRL) {
+               ac_build_waitcnt(ac, LGKM_CNT & VM_CNT);
                return;
        }
-       ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.s.barrier",
-                          ctx->ac.voidt, NULL, 0, AC_FUNC_ATTR_CONVERGENT);
+       ac_build_intrinsic(ac, "llvm.amdgcn.s.barrier",
+                          ac->voidt, NULL, 0, AC_FUNC_ATTR_CONVERGENT);
 }
 
 static void emit_discard_if(struct ac_nir_context *ctx,
@@ -4147,9 +4148,11 @@ visit_end_primitive(struct nir_to_llvm_context *ctx,
 }
 
 static LLVMValueRef
-visit_load_tess_coord(struct nir_to_llvm_context *ctx,
-                     const nir_intrinsic_instr *instr)
+load_tess_coord(struct ac_shader_abi *abi, LLVMTypeRef type,
+               unsigned num_components)
 {
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
+
        LLVMValueRef coord[4] = {
                ctx->tes_u,
                ctx->tes_v,
@@ -4161,9 +4164,15 @@ visit_load_tess_coord(struct nir_to_llvm_context *ctx,
                coord[2] = LLVMBuildFSub(ctx->builder, ctx->ac.f32_1,
                                        LLVMBuildFAdd(ctx->builder, coord[0], coord[1], ""), "");
 
-       LLVMValueRef result = ac_build_gather_values(&ctx->ac, coord, instr->num_components);
-       return LLVMBuildBitCast(ctx->builder, result,
-                               get_def_type(ctx->nir, &instr->dest.ssa), "");
+       LLVMValueRef result = ac_build_gather_values(&ctx->ac, coord, num_components);
+       return LLVMBuildBitCast(ctx->builder, result, type, "");
+}
+
+static LLVMValueRef
+load_patch_vertices_in(struct ac_shader_abi *abi)
+{
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
+       return LLVMConstInt(ctx->ac.i32, ctx->options->key.tcs.input_vertices, false);
 }
 
 static void visit_intrinsic(struct ac_nir_context *ctx,
@@ -4326,7 +4335,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                emit_membar(ctx->nctx, instr);
                break;
        case nir_intrinsic_barrier:
-               emit_barrier(ctx->nctx);
+               emit_barrier(&ctx->ac, ctx->stage);
                break;
        case nir_intrinsic_var_atomic_add:
        case nir_intrinsic_var_atomic_imin:
@@ -4352,11 +4361,21 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_end_primitive:
                visit_end_primitive(ctx->nctx, instr);
                break;
-       case nir_intrinsic_load_tess_coord:
-               result = visit_load_tess_coord(ctx->nctx, instr);
+       case nir_intrinsic_load_tess_coord: {
+               LLVMTypeRef type = ctx->nctx ?
+                       get_def_type(ctx->nctx->nir, &instr->dest.ssa) :
+                       NULL;
+               result = ctx->abi->load_tess_coord(ctx->abi, type, instr->num_components);
+               break;
+       }
+       case nir_intrinsic_load_tess_level_outer:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_OUTER);
+               break;
+       case nir_intrinsic_load_tess_level_inner:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_INNER);
                break;
        case nir_intrinsic_load_patch_vertices_in:
-               result = LLVMConstInt(ctx->ac.i32, ctx->nctx->options->key.tcs.input_vertices, false);
+               result = ctx->abi->load_patch_vertices_in(ctx->abi);
                break;
        default:
                fprintf(stderr, "Unknown intrinsic: ");
@@ -5164,8 +5183,13 @@ handle_vs_input_decl(struct nir_to_llvm_context *ctx,
        if (ctx->options->key.vs.instance_rate_inputs & (1u << index)) {
                buffer_index = LLVMBuildAdd(ctx->builder, ctx->abi.instance_id,
                                            ctx->abi.start_instance, "");
-               ctx->shader_info->vs.vgpr_comp_cnt = MAX2(3,
-                                           ctx->shader_info->vs.vgpr_comp_cnt);
+               if (ctx->options->key.vs.as_ls) {
+                       ctx->shader_info->vs.vgpr_comp_cnt =
+                               MAX2(2, ctx->shader_info->vs.vgpr_comp_cnt);
+               } else {
+                       ctx->shader_info->vs.vgpr_comp_cnt =
+                               MAX2(1, ctx->shader_info->vs.vgpr_comp_cnt);
+               }
        } else
                buffer_index = LLVMBuildAdd(ctx->builder, ctx->abi.vertex_id,
                                            ctx->abi.base_vertex, "");
@@ -5541,6 +5565,7 @@ setup_locals(struct ac_nir_context *ctx,
        nir_foreach_variable(variable, &func->impl->locals) {
                unsigned attrib_count = glsl_count_attribute_slots(variable->type, false);
                variable->data.driver_location = ctx->num_locals * 4;
+               variable->data.location_frac = 0;
                ctx->num_locals += attrib_count;
        }
        ctx->locals = malloc(4 * ctx->num_locals * sizeof(LLVMValueRef));
@@ -6154,7 +6179,7 @@ write_tess_factors(struct nir_to_llvm_context *ctx)
        LLVMValueRef lds_base, lds_inner, lds_outer, byteoffset, buffer;
        LLVMValueRef out[6], vec0, vec1, tf_base, inner[4], outer[4];
        int i;
-       emit_barrier(ctx);
+       emit_barrier(&ctx->ac, ctx->stage);
 
        switch (ctx->options->key.tcs.primitive_mode) {
        case GL_ISOLINES:
@@ -6682,21 +6707,30 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
                        ctx.tcs_outputs_read = shaders[i]->info.outputs_read;
                        ctx.tcs_patch_outputs_read = shaders[i]->info.patch_outputs_read;
                        ctx.abi.load_tess_inputs = load_tcs_input;
+                       ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
                        ctx.abi.store_tcs_outputs = store_tcs_output;
                } else if (shaders[i]->info.stage == MESA_SHADER_TESS_EVAL) {
                        ctx.tes_primitive_mode = shaders[i]->info.tess.primitive_mode;
                        ctx.abi.load_tess_inputs = load_tes_input;
+                       ctx.abi.load_tess_coord = load_tess_coord;
+                       ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
                } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX) {
                        if (shader_info->info.vs.needs_instance_id) {
-                               ctx.shader_info->vs.vgpr_comp_cnt =
-                                       MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
+                               if (ctx.ac.chip_class == GFX9 &&
+                                   shaders[shader_count - 1]->info.stage == MESA_SHADER_TESS_CTRL) {
+                                       ctx.shader_info->vs.vgpr_comp_cnt =
+                                               MAX2(2, ctx.shader_info->vs.vgpr_comp_cnt);
+                               } else {
+                                       ctx.shader_info->vs.vgpr_comp_cnt =
+                                               MAX2(1, ctx.shader_info->vs.vgpr_comp_cnt);
+                               }
                        }
                } else if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT) {
                        shader_info->fs.can_discard = shaders[i]->info.fs.uses_discard;
                }
 
                if (i)
-                       emit_barrier(&ctx);
+                       emit_barrier(&ctx.ac, ctx.stage);
 
                ac_setup_rings(&ctx);
 
@@ -6882,6 +6916,20 @@ static void ac_compile_llvm_module(LLVMTargetMachineRef tm,
        /* +3 for scratch wave offset and VCC */
        config->num_sgprs = MAX2(config->num_sgprs,
                                 shader_info->num_input_sgprs + 3);
+
+       /* Enable 64-bit and 16-bit denormals, because there is no performance
+        * cost.
+        *
+        * If denormals are enabled, all floating-point output modifiers are
+        * ignored.
+        *
+        * Don't enable denormals for 32-bit floats, because:
+        * - Floating-point output modifiers would be ignored by the hw.
+        * - Some opcodes don't support denormals, such as v_mad_f32. We would
+        *   have to stop using those.
+        * - SI & CI would be very slow.
+        */
+       config->float_mode |= V_00B028_FP_64_DENORMS;
 }
 
 static void
@@ -6914,7 +6962,7 @@ ac_fill_shader_info(struct ac_shader_variant_info *shader_info, struct nir_shade
         case MESA_SHADER_VERTEX:
                 shader_info->vs.as_es = options->key.vs.as_es;
                 shader_info->vs.as_ls = options->key.vs.as_ls;
-                /* in LS mode we need at least 1, invocation id needs 3, handled elsewhere */
+                /* in LS mode we need at least 1, invocation id needs 2, handled elsewhere */
                 if (options->key.vs.as_ls)
                         shader_info->vs.vgpr_comp_cnt = MAX2(1, shader_info->vs.vgpr_comp_cnt);
                 break;
@@ -6939,6 +6987,14 @@ void ac_compile_nir_shader(LLVMTargetMachineRef tm,
        ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, nir[0]->info.stage, dump_shader, options->supports_spill);
        for (int i = 0; i < nir_count; ++i)
                ac_fill_shader_info(shader_info, nir[i], options);
+
+       /* Determine the ES type (VS or TES) for the GS on GFX9. */
+       if (options->chip_class == GFX9) {
+               if (nir_count == 2 &&
+                   nir[1]->info.stage == MESA_SHADER_GEOMETRY) {
+                       shader_info->gs.es_type = nir[0]->info.stage;
+               }
+       }
 }
 
 static void