ac/llvm: fix amdgcn.rcp for v2f16
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index 8fc9734e4a1f0b6742fc68eb85cdcf6b3a0171ce..337ca6605fc824ca4db697040b001686dba34ec8 100644 (file)
@@ -216,6 +216,35 @@ static LLVMValueRef emit_intrin_1f_param(struct ac_llvm_context *ctx,
        return ac_build_intrinsic(ctx, name, result_type, params, 1, AC_FUNC_ATTR_READNONE);
 }
 
+static LLVMValueRef emit_intrin_1f_param_scalar(struct ac_llvm_context *ctx,
+                                               const char *intrin,
+                                               LLVMTypeRef result_type,
+                                               LLVMValueRef src0)
+{
+       if (LLVMGetTypeKind(result_type) != LLVMVectorTypeKind)
+               return emit_intrin_1f_param(ctx, intrin, result_type, src0);
+
+       LLVMTypeRef elem_type = LLVMGetElementType(result_type);
+       LLVMValueRef ret = LLVMGetUndef(result_type);
+
+       /* Scalarize the intrinsic, because vectors are not supported. */
+       for (unsigned i = 0; i < LLVMGetVectorSize(result_type); i++) {
+               char name[64], type[64];
+               LLVMValueRef params[] = {
+                       ac_to_float(ctx, ac_llvm_extract_elem(ctx, src0, i)),
+               };
+
+               ac_build_type_name_for_intr(LLVMTypeOf(params[0]), type, sizeof(type));
+               ASSERTED const int length = snprintf(name, sizeof(name), "%s.%s", intrin, type);
+               assert(length < sizeof(name));
+               ret = LLVMBuildInsertElement(ctx->builder, ret,
+                                            ac_build_intrinsic(ctx, name, elem_type, params,
+                                                               1, AC_FUNC_ATTR_READNONE),
+                                            LLVMConstInt(ctx->i32, i, 0), "");
+       }
+       return ret;
+}
+
 static LLVMValueRef emit_intrin_2f_param(struct ac_llvm_context *ctx,
                                       const char *intrin,
                                       LLVMTypeRef result_type,
@@ -601,10 +630,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        unsigned num_components = instr->dest.dest.ssa.num_components;
        unsigned src_components;
        LLVMTypeRef def_type = get_def_type(ctx, &instr->dest.dest.ssa);
-       bool saved_inexact = false;
-
-       if (instr->exact)
-               saved_inexact = ac_disable_inexact_math(ctx->ac.builder);
 
        assert(nir_op_infos[instr->op].num_inputs <= ARRAY_SIZE(src));
        switch (instr->op) {
@@ -710,8 +735,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                        result = LLVMBuildFDiv(ctx->ac.builder, ctx->ac.f64_1,
                                               ac_to_float(&ctx->ac, src[0]), "");
                } else {
-                       result = emit_intrin_1f_param(&ctx->ac, "llvm.amdgcn.rcp",
-                                                     ac_to_float_type(&ctx->ac, def_type), src[0]);
+                       result = emit_intrin_1f_param_scalar(&ctx->ac, "llvm.amdgcn.rcp",
+                                                            ac_to_float_type(&ctx->ac, def_type), src[0]);
                }
                if (ctx->abi->clamp_div_by_zero)
                        result = ac_build_fmin(&ctx->ac, result,
@@ -1192,9 +1217,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = ac_to_integer_or_pointer(&ctx->ac, result);
                ctx->ssa_defs[instr->dest.dest.ssa.index] = result;
        }
-
-       if (instr->exact)
-               ac_restore_inexact_math(ctx->ac.builder, saved_inexact);
 }
 
 static void visit_load_const(struct ac_nir_context *ctx,
@@ -5167,7 +5189,7 @@ static void visit_deref(struct ac_nir_context *ctx,
                break;
        case nir_deref_type_ptr_as_array:
                if (instr->mode == nir_var_mem_global) {
-                       unsigned stride = nir_deref_instr_ptr_as_array_stride(instr);
+                       unsigned stride = nir_deref_instr_array_stride(instr);
 
                        LLVMValueRef index = get_src(ctx, instr->arr.index);
                        if (LLVMTypeOf(index) != ctx->ac.i64)
@@ -5571,7 +5593,7 @@ ac_lower_indirect_derefs(struct nir_shader *nir, enum chip_class chip_class)
         */
        indirect_mask |= nir_var_function_temp;
 
-       progress |= nir_lower_indirect_derefs(nir, indirect_mask);
+       progress |= nir_lower_indirect_derefs(nir, indirect_mask, UINT32_MAX);
        return progress;
 }