ac/llvm: fix amdgcn.rcp for v2f16
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index ed0cb8008f1a940f2c18a0e838384326787d33b5..337ca6605fc824ca4db697040b001686dba34ec8 100644 (file)
@@ -216,6 +216,35 @@ static LLVMValueRef emit_intrin_1f_param(struct ac_llvm_context *ctx,
        return ac_build_intrinsic(ctx, name, result_type, params, 1, AC_FUNC_ATTR_READNONE);
 }
 
+static LLVMValueRef emit_intrin_1f_param_scalar(struct ac_llvm_context *ctx,
+                                               const char *intrin,
+                                               LLVMTypeRef result_type,
+                                               LLVMValueRef src0)
+{
+       if (LLVMGetTypeKind(result_type) != LLVMVectorTypeKind)
+               return emit_intrin_1f_param(ctx, intrin, result_type, src0);
+
+       LLVMTypeRef elem_type = LLVMGetElementType(result_type);
+       LLVMValueRef ret = LLVMGetUndef(result_type);
+
+       /* Scalarize the intrinsic, because vectors are not supported. */
+       for (unsigned i = 0; i < LLVMGetVectorSize(result_type); i++) {
+               char name[64], type[64];
+               LLVMValueRef params[] = {
+                       ac_to_float(ctx, ac_llvm_extract_elem(ctx, src0, i)),
+               };
+
+               ac_build_type_name_for_intr(LLVMTypeOf(params[0]), type, sizeof(type));
+               ASSERTED const int length = snprintf(name, sizeof(name), "%s.%s", intrin, type);
+               assert(length < sizeof(name));
+               ret = LLVMBuildInsertElement(ctx->builder, ret,
+                                            ac_build_intrinsic(ctx, name, elem_type, params,
+                                                               1, AC_FUNC_ATTR_READNONE),
+                                            LLVMConstInt(ctx->i32, i, 0), "");
+       }
+       return ret;
+}
+
 static LLVMValueRef emit_intrin_2f_param(struct ac_llvm_context *ctx,
                                       const char *intrin,
                                       LLVMTypeRef result_type,
@@ -601,10 +630,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        unsigned num_components = instr->dest.dest.ssa.num_components;
        unsigned src_components;
        LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.dest.ssa);
-       bool saved_inexact = false;
-
-       if (instr->exact)
-               saved_inexact = ac_disable_inexact_math(ctx->ac.builder);
 
        assert(nir_op_infos[instr->op].num_inputs <= ARRAY_SIZE(src));
        switch (instr->op) {
@@ -710,8 +735,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                        result = LLVMBuildFDiv(ctx->ac.builder, ctx->ac.f64_1,
                                               ac_to_float(&ctx->ac, src[0]), "");
                } else {
-                       result = emit_intrin_1f_param(&ctx->ac, "llvm.amdgcn.rcp",
-                                                     ac_to_float_type(&ctx->ac, def_type), src[0]);
+                       result = emit_intrin_1f_param_scalar(&ctx->ac, "llvm.amdgcn.rcp",
+                                                            ac_to_float_type(&ctx->ac, def_type), src[0]);
                }
                if (ctx->abi->clamp_div_by_zero)
                        result = ac_build_fmin(&ctx->ac, result,
@@ -1192,9 +1217,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = ac_to_integer_or_pointer(&ctx->ac, result);
                ctx->ssa_defs[instr->dest.dest.ssa.index] = result;
        }
-
-       if (instr->exact)
-               ac_restore_inexact_math(ctx->ac.builder, saved_inexact);
 }
 
 static void visit_load_const(struct ac_nir_context *ctx,
@@ -5167,7 +5189,7 @@ static void visit_deref(struct ac_nir_context *ctx,
                break;
        case nir_deref_type_ptr_as_array:
                if (instr->mode == nir_var_mem_global) {
-                       unsigned stride = nir_deref_instr_ptr_as_array_stride(instr);
+                       unsigned stride = nir_deref_instr_array_stride(instr);
 
                        LLVMValueRef index = get_src(ctx, instr->arr.index);
                        if (LLVMTypeOf(index) != ctx->ac.i64)
@@ -5571,33 +5593,26 @@ ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class chip_class)
         */
        indirect_mask |= nir_var_function_temp;
 
-       progress |= nir_lower_indirect_derefs(nir, indirect_mask);
+       progress |= nir_lower_indirect_derefs(nir, indirect_mask, UINT32_MAX);
        return progress;
 }
 
 static unsigned
 get_inst_tessfactor_writemask(nir_intrinsic_instr *intrin)
 {
-       if (intrin->intrinsic != nir_intrinsic_store_deref)
+       if (intrin->intrinsic != nir_intrinsic_store_output)
                return 0;
 
-       nir_variable *var =
-               nir_deref_instr_get_variable(nir_src_as_deref(intrin->src[0]));
+       unsigned writemask = nir_intrinsic_write_mask(intrin) <<
+                            nir_intrinsic_component(intrin);
+       unsigned location = nir_intrinsic_io_semantics(intrin).location;
 
-       if (var->data.mode != nir_var_shader_out)
-               return 0;
+       if (location == VARYING_SLOT_TESS_LEVEL_OUTER)
+               return writemask << 4;
+       else if (location == VARYING_SLOT_TESS_LEVEL_INNER)
+               return writemask;
 
-       unsigned writemask = 0;
-       const int location = var->data.location;
-       unsigned first_component = var->data.location_frac;
-       unsigned num_comps = intrin->dest.ssa.num_components;
-
-       if (location == VARYING_SLOT_TESS_LEVEL_INNER)
-               writemask = ((1 << (num_comps + 1)) - 1) << first_component;
-       else if (location == VARYING_SLOT_TESS_LEVEL_OUTER)
-               writemask = (((1 << (num_comps + 1)) - 1) << first_component) << 4;
-
-       return writemask;
+       return 0;
 }
 
 static void