ac/llvm: add better code for fsign
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index 337ca6605fc824ca4db697040b001686dba34ec8..4b696f28f124eb774a15483da9c1d4eb84dad17e 100644 (file)
@@ -286,8 +286,6 @@ static LLVMValueRef emit_bcsel(struct ac_llvm_context *ctx,
        LLVMTypeRef src1_type = LLVMTypeOf(src1);
        LLVMTypeRef src2_type = LLVMTypeOf(src2);
 
-       assert(LLVMGetTypeKind(LLVMTypeOf(src0)) != LLVMVectorTypeKind);
-
        if (LLVMGetTypeKind(src1_type) == LLVMPointerTypeKind &&
            LLVMGetTypeKind(src2_type) != LLVMPointerTypeKind) {
                src2 = LLVMBuildIntToPtr(ctx->builder, src2, src1_type, "");
@@ -297,7 +295,7 @@ static LLVMValueRef emit_bcsel(struct ac_llvm_context *ctx,
        }
 
        LLVMValueRef v = LLVMBuildICmp(ctx->builder, LLVMIntNE, src0,
-                                      ctx->i32_0, "");
+                                      LLVMConstNull(LLVMTypeOf(src0)), "");
        return LLVMBuildSelect(ctx->builder, v,
                               ac_to_integer_or_pointer(ctx, src1),
                               ac_to_integer_or_pointer(ctx, src2), "");
@@ -703,17 +701,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        case nir_op_umod:
                result = LLVMBuildURem(ctx->ac.builder, src[0], src[1], "");
                break;
-       case nir_op_fmod:
-               /* lower_fmod only lower 16-bit and 32-bit fmod */
-               assert(instr->dest.dest.ssa.bit_size == 64);
-               src[0] = ac_to_float(&ctx->ac, src[0]);
-               src[1] = ac_to_float(&ctx->ac, src[1]);
-               result = ac_build_fdiv(&ctx->ac, src[0], src[1]);
-               result = emit_intrin_1f_param(&ctx->ac, "llvm.floor",
-                                             ac_to_float_type(&ctx->ac, def_type), result);
-               result = LLVMBuildFMul(ctx->ac.builder, src[1] , result, "");
-               result = LLVMBuildFSub(ctx->ac.builder, src[0], result, "");
-               break;
        case nir_op_irem:
                result = LLVMBuildSRem(ctx->ac.builder, src[0], src[1], "");
                break;
@@ -835,13 +822,11 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = ac_build_umin(&ctx->ac, src[0], src[1]);
                break;
        case nir_op_isign:
-               result = ac_build_isign(&ctx->ac, src[0],
-                                       instr->dest.dest.ssa.bit_size);
+               result = ac_build_isign(&ctx->ac, src[0]);
                break;
        case nir_op_fsign:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = ac_build_fsign(&ctx->ac, src[0],
-                                       instr->dest.dest.ssa.bit_size);
+               result = ac_build_fsign(&ctx->ac, src[0]);
                break;
        case nir_op_ffloor:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.floor",
@@ -860,9 +845,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                                              ac_to_float_type(&ctx->ac, def_type),src[0]);
                break;
        case nir_op_ffract:
-               src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = ac_build_fract(&ctx->ac, src[0],
-                                       instr->dest.dest.ssa.bit_size);
+               result = emit_intrin_1f_param_scalar(&ctx->ac, "llvm.amdgcn.fract",
+                                                    ac_to_float_type(&ctx->ac, def_type), src[0]);
                break;
        case nir_op_fsin:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.sin",
@@ -885,8 +869,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                                              ac_to_float_type(&ctx->ac, def_type), src[0]);
                break;
        case nir_op_frsq:
-               result = emit_intrin_1f_param(&ctx->ac, "llvm.amdgcn.rsq",
-                                             ac_to_float_type(&ctx->ac, def_type), src[0]);
+               result = emit_intrin_1f_param_scalar(&ctx->ac, "llvm.amdgcn.rsq",
+                                                    ac_to_float_type(&ctx->ac, def_type), src[0]);
                if (ctx->abi->clamp_div_by_zero)
                        result = ac_build_fmin(&ctx->ac, result,
                                               LLVMConstReal(ac_to_float_type(&ctx->ac, def_type), FLT_MAX));