ac: add doubles support to isign
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index 48e2920a15855c7b447234480bba5648c86718ac..6467ed66ae5454a817d51009fb6603290b538a8d 100644 (file)
@@ -597,10 +597,12 @@ static void allocate_user_sgprs(struct nir_to_llvm_context *ctx,
                break;
        }
 
-       if (ctx->shader_info->info.needs_push_constants)
+       if (ctx->shader_info->info.loads_push_constants)
                user_sgpr_info->sgpr_count += 2;
 
-       uint32_t remaining_sgprs = 16 - user_sgpr_info->sgpr_count;
+       uint32_t available_sgprs = ctx->options->chip_class >= GFX9 ? 32 : 16;
+       uint32_t remaining_sgprs = available_sgprs - user_sgpr_info->sgpr_count;
+
        if (remaining_sgprs / 2 < util_bitcount(ctx->shader_info->info.desc_set_used_mask)) {
                user_sgpr_info->sgpr_count += 2;
                user_sgpr_info->indirect_all_descriptor_sets = true;
@@ -638,7 +640,7 @@ declare_global_input_sgprs(struct nir_to_llvm_context *ctx,
                add_array_arg(args, const_array(type, 32), desc_sets);
        }
 
-       if (ctx->shader_info->info.needs_push_constants) {
+       if (ctx->shader_info->info.loads_push_constants) {
                /* 1 for push constants and dynamic descriptors */
                add_array_arg(args, type, &ctx->push_constants);
        }
@@ -671,9 +673,14 @@ declare_vs_input_vgprs(struct nir_to_llvm_context *ctx, struct arg_info *args)
 {
        add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.vertex_id);
        if (!ctx->is_gs_copy_shader) {
-               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->rel_auto_id);
-               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->vs_prim_id);
-               add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+               if (ctx->options->key.vs.as_ls) {
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->rel_auto_id);
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+               } else {
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->abi.instance_id);
+                       add_arg(args, ARG_VGPR, ctx->ac.i32, &ctx->vs_prim_id);
+               }
+               add_arg(args, ARG_VGPR, ctx->ac.i32, NULL); /* unused */
        }
 }
 
@@ -724,7 +731,7 @@ set_global_input_locs(struct nir_to_llvm_context *ctx, gl_shader_stage stage,
                ctx->shader_info->need_indirect_descriptor_sets = true;
        }
 
-       if (ctx->shader_info->info.needs_push_constants) {
+       if (ctx->shader_info->info.loads_push_constants) {
                set_loc_shader(ctx, AC_UD_PUSH_CONSTANTS, user_sgpr_idx, 2);
        }
 }
@@ -1331,26 +1338,49 @@ static LLVMValueRef emit_iabs(struct ac_llvm_context *ctx,
 }
 
 static LLVMValueRef emit_fsign(struct ac_llvm_context *ctx,
-                              LLVMValueRef src0)
+                              LLVMValueRef src0,
+                              unsigned bitsize)
 {
-       LLVMValueRef cmp, val;
+       LLVMValueRef cmp, val, zero, one;
+       LLVMTypeRef type;
 
-       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGT, src0, ctx->f32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, ctx->f32_1, src0, "");
-       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGE, val, ctx->f32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstReal(ctx->f32, -1.0), "");
+       if (bitsize == 32) {
+               type = ctx->f32;
+               zero = ctx->f32_0;
+               one = ctx->f32_1;
+       } else {
+               type = ctx->f64;
+               zero = ctx->f64_0;
+               one = ctx->f64_1;
+       }
+
+       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGT, src0, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, one, src0, "");
+       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGE, val, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstReal(type, -1.0), "");
        return val;
 }
 
 static LLVMValueRef emit_isign(struct ac_llvm_context *ctx,
-                              LLVMValueRef src0)
+                              LLVMValueRef src0, unsigned bitsize)
 {
-       LLVMValueRef cmp, val;
+       LLVMValueRef cmp, val, zero, one;
+       LLVMTypeRef type;
 
-       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGT, src0, ctx->i32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, ctx->i32_1, src0, "");
-       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGE, val, ctx->i32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstInt(ctx->i32, -1, true), "");
+       if (bitsize == 32) {
+               type = ctx->i32;
+               zero = ctx->i32_0;
+               one = ctx->i32_1;
+       } else {
+               type = ctx->i64;
+               zero = ctx->i64_0;
+               one = ctx->i64_1;
+       }
+
+       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGT, src0, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, one, src0, "");
+       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGE, val, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstInt(type, -1, true), "");
        return val;
 }
 
@@ -1403,9 +1433,15 @@ static LLVMValueRef emit_f2b(struct ac_llvm_context *ctx,
 }
 
 static LLVMValueRef emit_b2i(struct ac_llvm_context *ctx,
-                            LLVMValueRef src0)
+                            LLVMValueRef src0,
+                            unsigned bitsize)
 {
-       return LLVMBuildAnd(ctx->builder, src0, ctx->i32_1, "");
+       LLVMValueRef result = LLVMBuildAnd(ctx->builder, src0, ctx->i32_1, "");
+
+       if (bitsize == 32)
+               return result;
+
+       return LLVMBuildZExt(ctx->builder, result, ctx->i64, "");
 }
 
 static LLVMValueRef emit_i2b(struct ac_llvm_context *ctx,
@@ -1708,7 +1744,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_frcp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, src[0]);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      src[0]);
                break;
        case nir_op_iand:
                result = LLVMBuildAnd(ctx->ac.builder, src[0], src[1], "");
@@ -1787,11 +1824,11 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = emit_minmax_int(&ctx->ac, LLVMIntULT, src[0], src[1]);
                break;
        case nir_op_isign:
-               result = emit_isign(&ctx->ac, src[0]);
+               result = emit_isign(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
        case nir_op_fsign:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = emit_fsign(&ctx->ac, src[0]);
+               result = emit_fsign(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
        case nir_op_ffloor:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.floor",
@@ -1835,7 +1872,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        case nir_op_frsq:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.sqrt",
                                              ac_to_float_type(&ctx->ac, def_type), src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, result);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      result);
                break;
        case nir_op_fpow:
                result = emit_intrin_2f_param(&ctx->ac, "llvm.pow",
@@ -1958,7 +1996,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = emit_f2b(&ctx->ac, src[0]);
                break;
        case nir_op_b2i:
-               result = emit_b2i(&ctx->ac, src[0]);
+               result = emit_b2i(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
        case nir_op_i2b:
                src[0] = ac_to_integer(&ctx->ac, src[0]);
@@ -2573,7 +2611,7 @@ static LLVMValueRef visit_load_buffer(struct ac_nir_context *ctx,
 static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
                                           const nir_intrinsic_instr *instr)
 {
-       LLVMValueRef results[8], ret;
+       LLVMValueRef ret;
        LLVMValueRef rsrc = get_src(ctx, instr->src[0]);
        LLVMValueRef offset = get_src(ctx, instr->src[1]);
        int num_components = instr->num_components;
@@ -2584,20 +2622,9 @@ static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
        if (instr->dest.ssa.bit_size == 64)
                num_components *= 2;
 
-       for (unsigned i = 0; i < num_components; ++i) {
-               LLVMValueRef params[] = {
-                       rsrc,
-                       LLVMBuildAdd(ctx->ac.builder, LLVMConstInt(ctx->ac.i32, 4 * i, 0),
-                                    offset, "")
-               };
-               results[i] = ac_build_intrinsic(&ctx->ac, "llvm.SI.load.const.v4i32", ctx->ac.f32,
-                                               params, 2,
-                                               AC_FUNC_ATTR_READNONE |
-                                               AC_FUNC_ATTR_LEGACY);
-       }
-
-
-       ret = ac_build_gather_values(&ctx->ac, results, num_components);
+       ret = ac_build_buffer_load(&ctx->ac, rsrc, num_components, NULL, offset,
+                                  NULL, 0, false, false, true, true);
+       ret = trim_vector(&ctx->ac, ret, num_components);
        return LLVMBuildBitCast(ctx->ac.builder, ret,
                                get_def_type(ctx, &instr->dest.ssa), "");
 }
@@ -3826,19 +3853,18 @@ static void emit_membar(struct nir_to_llvm_context *ctx,
                ac_build_waitcnt(&ctx->ac, waitcnt);
 }
 
-static void emit_barrier(struct nir_to_llvm_context *ctx)
+static void emit_barrier(struct ac_llvm_context *ac, gl_shader_stage stage)
 {
        /* SI only (thanks to a hw bug workaround):
         * The real barrier instruction isn’t needed, because an entire patch
         * always fits into a single wave.
         */
-       if (ctx->options->chip_class == SI &&
-           ctx->stage == MESA_SHADER_TESS_CTRL) {
-               ac_build_waitcnt(&ctx->ac, LGKM_CNT & VM_CNT);
+       if (ac->chip_class == SI && stage == MESA_SHADER_TESS_CTRL) {
+               ac_build_waitcnt(ac, LGKM_CNT & VM_CNT);
                return;
        }
-       ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.s.barrier",
-                          ctx->ac.voidt, NULL, 0, AC_FUNC_ATTR_CONVERGENT);
+       ac_build_intrinsic(ac, "llvm.amdgcn.s.barrier",
+                          ac->voidt, NULL, 0, AC_FUNC_ATTR_CONVERGENT);
 }
 
 static void emit_discard_if(struct ac_nir_context *ctx,
@@ -4171,6 +4197,13 @@ load_tess_coord(struct ac_shader_abi *abi, LLVMTypeRef type,
        return LLVMBuildBitCast(ctx->builder, result, type, "");
 }
 
+static LLVMValueRef
+load_patch_vertices_in(struct ac_shader_abi *abi)
+{
+       struct nir_to_llvm_context *ctx = nir_to_llvm_context_from_abi(abi);
+       return LLVMConstInt(ctx->ac.i32, ctx->options->key.tcs.input_vertices, false);
+}
+
 static void visit_intrinsic(struct ac_nir_context *ctx,
                             nir_intrinsic_instr *instr)
 {
@@ -4331,7 +4364,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                emit_membar(ctx->nctx, instr);
                break;
        case nir_intrinsic_barrier:
-               emit_barrier(ctx->nctx);
+               emit_barrier(&ctx->ac, ctx->stage);
                break;
        case nir_intrinsic_var_atomic_add:
        case nir_intrinsic_var_atomic_imin:
@@ -4364,8 +4397,14 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
                result = ctx->abi->load_tess_coord(ctx->abi, type, instr->num_components);
                break;
        }
+       case nir_intrinsic_load_tess_level_outer:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_OUTER);
+               break;
+       case nir_intrinsic_load_tess_level_inner:
+               result = ctx->abi->load_tess_level(ctx->abi, VARYING_SLOT_TESS_LEVEL_INNER);
+               break;
        case nir_intrinsic_load_patch_vertices_in:
-               result = LLVMConstInt(ctx->ac.i32, ctx->nctx->options->key.tcs.input_vertices, false);
+               result = ctx->abi->load_patch_vertices_in(ctx->abi);
                break;
        default:
                fprintf(stderr, "Unknown intrinsic: ");
@@ -5173,8 +5212,13 @@ handle_vs_input_decl(struct nir_to_llvm_context *ctx,
        if (ctx->options->key.vs.instance_rate_inputs & (1u << index)) {
                buffer_index = LLVMBuildAdd(ctx->builder, ctx->abi.instance_id,
                                            ctx->abi.start_instance, "");
-               ctx->shader_info->vs.vgpr_comp_cnt = MAX2(3,
-                                           ctx->shader_info->vs.vgpr_comp_cnt);
+               if (ctx->options->key.vs.as_ls) {
+                       ctx->shader_info->vs.vgpr_comp_cnt =
+                               MAX2(2, ctx->shader_info->vs.vgpr_comp_cnt);
+               } else {
+                       ctx->shader_info->vs.vgpr_comp_cnt =
+                               MAX2(1, ctx->shader_info->vs.vgpr_comp_cnt);
+               }
        } else
                buffer_index = LLVMBuildAdd(ctx->builder, ctx->abi.vertex_id,
                                            ctx->abi.base_vertex, "");
@@ -5550,6 +5594,7 @@ setup_locals(struct ac_nir_context *ctx,
        nir_foreach_variable(variable, &func->impl->locals) {
                unsigned attrib_count = glsl_count_attribute_slots(variable->type, false);
                variable->data.driver_location = ctx->num_locals * 4;
+               variable->data.location_frac = 0;
                ctx->num_locals += attrib_count;
        }
        ctx->locals = malloc(4 * ctx->num_locals * sizeof(LLVMValueRef));
@@ -6163,7 +6208,7 @@ write_tess_factors(struct nir_to_llvm_context *ctx)
        LLVMValueRef lds_base, lds_inner, lds_outer, byteoffset, buffer;
        LLVMValueRef out[6], vec0, vec1, tf_base, inner[4], outer[4];
        int i;
-       emit_barrier(ctx);
+       emit_barrier(&ctx->ac, ctx->stage);
 
        switch (ctx->options->key.tcs.primitive_mode) {
        case GL_ISOLINES:
@@ -6691,22 +6736,30 @@ LLVMModuleRef ac_translate_nir_to_llvm(LLVMTargetMachineRef tm,
                        ctx.tcs_outputs_read = shaders[i]->info.outputs_read;
                        ctx.tcs_patch_outputs_read = shaders[i]->info.patch_outputs_read;
                        ctx.abi.load_tess_inputs = load_tcs_input;
+                       ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
                        ctx.abi.store_tcs_outputs = store_tcs_output;
                } else if (shaders[i]->info.stage == MESA_SHADER_TESS_EVAL) {
                        ctx.tes_primitive_mode = shaders[i]->info.tess.primitive_mode;
                        ctx.abi.load_tess_inputs = load_tes_input;
                        ctx.abi.load_tess_coord = load_tess_coord;
+                       ctx.abi.load_patch_vertices_in = load_patch_vertices_in;
                } else if (shaders[i]->info.stage == MESA_SHADER_VERTEX) {
                        if (shader_info->info.vs.needs_instance_id) {
-                               ctx.shader_info->vs.vgpr_comp_cnt =
-                                       MAX2(3, ctx.shader_info->vs.vgpr_comp_cnt);
+                               if (ctx.ac.chip_class == GFX9 &&
+                                   shaders[shader_count - 1]->info.stage == MESA_SHADER_TESS_CTRL) {
+                                       ctx.shader_info->vs.vgpr_comp_cnt =
+                                               MAX2(2, ctx.shader_info->vs.vgpr_comp_cnt);
+                               } else {
+                                       ctx.shader_info->vs.vgpr_comp_cnt =
+                                               MAX2(1, ctx.shader_info->vs.vgpr_comp_cnt);
+                               }
                        }
                } else if (shaders[i]->info.stage == MESA_SHADER_FRAGMENT) {
                        shader_info->fs.can_discard = shaders[i]->info.fs.uses_discard;
                }
 
                if (i)
-                       emit_barrier(&ctx);
+                       emit_barrier(&ctx.ac, ctx.stage);
 
                ac_setup_rings(&ctx);
 
@@ -6938,7 +6991,7 @@ ac_fill_shader_info(struct ac_shader_variant_info *shader_info, struct nir_shade
         case MESA_SHADER_VERTEX:
                 shader_info->vs.as_es = options->key.vs.as_es;
                 shader_info->vs.as_ls = options->key.vs.as_ls;
-                /* in LS mode we need at least 1, invocation id needs 3, handled elsewhere */
+                /* in LS mode we need at least 1, invocation id needs 2, handled elsewhere */
                 if (options->key.vs.as_ls)
                         shader_info->vs.vgpr_comp_cnt = MAX2(1, shader_info->vs.vgpr_comp_cnt);
                 break;
@@ -6963,6 +7016,14 @@ void ac_compile_nir_shader(LLVMTargetMachineRef tm,
        ac_compile_llvm_module(tm, llvm_module, binary, config, shader_info, nir[0]->info.stage, dump_shader, options->supports_spill);
        for (int i = 0; i < nir_count; ++i)
                ac_fill_shader_info(shader_info, nir[i], options);
+
+       /* Determine the ES type (VS or TES) for the GS on GFX9. */
+       if (options->chip_class == GFX9) {
+               if (nir_count == 2 &&
+                   nir[1]->info.stage == MESA_SHADER_GEOMETRY) {
+                       shader_info->gs.es_type = nir[0]->info.stage;
+               }
+       }
 }
 
 static void