ac: add doubles support to isign
[mesa.git] / src / amd / common / ac_nir_to_llvm.c
index 7153c9708d418c723ffa3c5c352b6ffdcb4e2ed0..6467ed66ae5454a817d51009fb6603290b538a8d 100644 (file)
@@ -1338,26 +1338,49 @@ static LLVMValueRef emit_iabs(struct ac_llvm_context *ctx,
 }
 
 static LLVMValueRef emit_fsign(struct ac_llvm_context *ctx,
-                              LLVMValueRef src0)
+                              LLVMValueRef src0,
+                              unsigned bitsize)
 {
-       LLVMValueRef cmp, val;
+       LLVMValueRef cmp, val, zero, one;
+       LLVMTypeRef type;
 
-       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGT, src0, ctx->f32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, ctx->f32_1, src0, "");
-       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGE, val, ctx->f32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstReal(ctx->f32, -1.0), "");
+       if (bitsize == 32) {
+               type = ctx->f32;
+               zero = ctx->f32_0;
+               one = ctx->f32_1;
+       } else {
+               type = ctx->f64;
+               zero = ctx->f64_0;
+               one = ctx->f64_1;
+       }
+
+       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGT, src0, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, one, src0, "");
+       cmp = LLVMBuildFCmp(ctx->builder, LLVMRealOGE, val, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstReal(type, -1.0), "");
        return val;
 }
 
 static LLVMValueRef emit_isign(struct ac_llvm_context *ctx,
-                              LLVMValueRef src0)
+                              LLVMValueRef src0, unsigned bitsize)
 {
-       LLVMValueRef cmp, val;
+       LLVMValueRef cmp, val, zero, one;
+       LLVMTypeRef type;
 
-       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGT, src0, ctx->i32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, ctx->i32_1, src0, "");
-       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGE, val, ctx->i32_0, "");
-       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstInt(ctx->i32, -1, true), "");
+       if (bitsize == 32) {
+               type = ctx->i32;
+               zero = ctx->i32_0;
+               one = ctx->i32_1;
+       } else {
+               type = ctx->i64;
+               zero = ctx->i64_0;
+               one = ctx->i64_1;
+       }
+
+       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGT, src0, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, one, src0, "");
+       cmp = LLVMBuildICmp(ctx->builder, LLVMIntSGE, val, zero, "");
+       val = LLVMBuildSelect(ctx->builder, cmp, val, LLVMConstInt(type, -1, true), "");
        return val;
 }
 
@@ -1410,9 +1433,15 @@ static LLVMValueRef emit_f2b(struct ac_llvm_context *ctx,
 }
 
 static LLVMValueRef emit_b2i(struct ac_llvm_context *ctx,
-                            LLVMValueRef src0)
+                            LLVMValueRef src0,
+                            unsigned bitsize)
 {
-       return LLVMBuildAnd(ctx->builder, src0, ctx->i32_1, "");
+       LLVMValueRef result = LLVMBuildAnd(ctx->builder, src0, ctx->i32_1, "");
+
+       if (bitsize == 32)
+               return result;
+
+       return LLVMBuildZExt(ctx->builder, result, ctx->i64, "");
 }
 
 static LLVMValueRef emit_i2b(struct ac_llvm_context *ctx,
@@ -1715,7 +1744,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_frcp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, src[0]);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      src[0]);
                break;
        case nir_op_iand:
                result = LLVMBuildAnd(ctx->ac.builder, src[0], src[1], "");
@@ -1794,11 +1824,11 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = emit_minmax_int(&ctx->ac, LLVMIntULT, src[0], src[1]);
                break;
        case nir_op_isign:
-               result = emit_isign(&ctx->ac, src[0]);
+               result = emit_isign(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
        case nir_op_fsign:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               result = emit_fsign(&ctx->ac, src[0]);
+               result = emit_fsign(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
        case nir_op_ffloor:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.floor",
@@ -1842,7 +1872,8 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
        case nir_op_frsq:
                result = emit_intrin_1f_param(&ctx->ac, "llvm.sqrt",
                                              ac_to_float_type(&ctx->ac, def_type), src[0]);
-               result = ac_build_fdiv(&ctx->ac, ctx->ac.f32_1, result);
+               result = ac_build_fdiv(&ctx->ac, instr->dest.dest.ssa.bit_size == 32 ? ctx->ac.f32_1 : ctx->ac.f64_1,
+                                      result);
                break;
        case nir_op_fpow:
                result = emit_intrin_2f_param(&ctx->ac, "llvm.pow",
@@ -1965,7 +1996,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = emit_f2b(&ctx->ac, src[0]);
                break;
        case nir_op_b2i:
-               result = emit_b2i(&ctx->ac, src[0]);
+               result = emit_b2i(&ctx->ac, src[0], instr->dest.dest.ssa.bit_size);
                break;
        case nir_op_i2b:
                src[0] = ac_to_integer(&ctx->ac, src[0]);
@@ -2593,7 +2624,7 @@ static LLVMValueRef visit_load_ubo_buffer(struct ac_nir_context *ctx,
 
        ret = ac_build_buffer_load(&ctx->ac, rsrc, num_components, NULL, offset,
                                   NULL, 0, false, false, true, true);
-
+       ret = trim_vector(&ctx->ac, ret, num_components);
        return LLVMBuildBitCast(ctx->ac.builder, ret,
                                get_def_type(ctx, &instr->dest.ssa), "");
 }