ac/llvm: fix amdgcn.rcp for v2f16
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index 4537ec7126774efcd192f24c7448fc64c0a058b7..337ca6605fc824ca4db697040b001686dba34ec8 100644 (file)
@@ -216,6 +216,35 @@ static LLVMValueRef emit_intrin_1f_param(struct ac_llvm_context *ctx,
        return ac_build_intrinsic(ctx, name, result_type, params, 1, AC_FUNC_ATTR_READNONE);
 }
 
+static LLVMValueRef emit_intrin_1f_param_scalar(struct ac_llvm_context *ctx,
+                                               const char *intrin,
+                                               LLVMTypeRef result_type,
+                                               LLVMValueRef src0)
+{
+       if (LLVMGetTypeKind(result_type) != LLVMVectorTypeKind)
+               return emit_intrin_1f_param(ctx, intrin, result_type, src0);
+
+       LLVMTypeRef elem_type = LLVMGetElementType(result_type);
+       LLVMValueRef ret = LLVMGetUndef(result_type);
+
+       /* Scalarize the intrinsic, because vectors are not supported. */
+       for (unsigned i = 0; i < LLVMGetVectorSize(result_type); i++) {
+               char name[64], type[64];
+               LLVMValueRef params[] = {
+                       ac_to_float(ctx, ac_llvm_extract_elem(ctx, src0, i)),
+               };
+
+               ac_build_type_name_for_intr(LLVMTypeOf(params[0]), type, sizeof(type));
+               ASSERTED const int length = snprintf(name, sizeof(name), "%s.%s", intrin, type);
+               assert(length < sizeof(name));
+               ret = LLVMBuildInsertElement(ctx->builder, ret,
+                                            ac_build_intrinsic(ctx, name, elem_type, params,
+                                                               1, AC_FUNC_ATTR_READNONE),
+                                            LLVMConstInt(ctx->i32, i, 0), "");
+       }
+       return ret;
+}
+
 static LLVMValueRef emit_intrin_2f_param(struct ac_llvm_context *ctx,
                                       const char *intrin,
                                       LLVMTypeRef result_type,
@@ -706,8 +735,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                        result = LLVMBuildFDiv(ctx->ac.builder, ctx->ac.f64_1,
                                               ac_to_float(&ctx->ac, src[0]), "");
                } else {
-                       result = emit_intrin_1f_param(&ctx->ac, "llvm.amdgcn.rcp",
-                                                     ac_to_float_type(&ctx->ac, def_type), src[0]);
+                       result = emit_intrin_1f_param_scalar(&ctx->ac, "llvm.amdgcn.rcp",
+                                                            ac_to_float_type(&ctx->ac, def_type), src[0]);
                }
                if (ctx->abi->clamp_div_by_zero)
                        result = ac_build_fmin(&ctx->ac, result,