ac/nir: fix integer comparisons with pointers
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index 24be5ca866e1e2981212f341e4f479ff351e1b78..b90a7e3dcf2b58055a2bf2fbcf1090a84a11531e 100644 (file)
@@ -170,6 +170,17 @@ static LLVMValueRef emit_int_cmp(struct ac_llvm_context *ctx,
                                  LLVMIntPredicate pred, LLVMValueRef src0,
                                  LLVMValueRef src1)
 {
+       LLVMTypeRef src0_type = LLVMTypeOf(src0);
+       LLVMTypeRef src1_type = LLVMTypeOf(src1);
+
+       if (LLVMGetTypeKind(src0_type) == LLVMPointerTypeKind &&
+           LLVMGetTypeKind(src1_type) != LLVMPointerTypeKind) {
+               src1 = LLVMBuildIntToPtr(ctx->builder, src1, src0_type, "");
+       } else if (LLVMGetTypeKind(src1_type) == LLVMPointerTypeKind &&
+                  LLVMGetTypeKind(src0_type) != LLVMPointerTypeKind) {
+               src0 = LLVMBuildIntToPtr(ctx->builder, src0, src1_type, "");
+       }
+
        LLVMValueRef result = LLVMBuildICmp(ctx->builder, pred, src0, src1, "");
        return LLVMBuildSelect(ctx->builder, result,
                               LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
@@ -943,15 +954,45 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
        case nir_op_f2f16_rtz:
+       case nir_op_f2f16:
+       case nir_op_f2fmp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               if (LLVMTypeOf(src[0]) == ctx->ac.f64)
-                       src[0] = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ctx->ac.f32, "");
-               LLVMValueRef param[2] = { src[0], ctx->ac.f32_0 };
-               result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
-               result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
+
+               /* For OpenGL, we want fast packing with v_cvt_pkrtz_f16, but if we use it,
+                * all f32->f16 conversions have to round towards zero, because both scalar
+                * and vec2 down-conversions have to round equally.
+                */
+               if (ctx->ac.float_mode == AC_FLOAT_MODE_DEFAULT_OPENGL ||
+                   instr->op == nir_op_f2f16_rtz) {
+                       src[0] = ac_to_float(&ctx->ac, src[0]);
+
+                       if (LLVMTypeOf(src[0]) == ctx->ac.f64)
+                               src[0] = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ctx->ac.f32, "");
+
+                       /* Fast path conversion. This only works if NIR is vectorized
+                        * to vec2 16.
+                        */
+                       if (LLVMTypeOf(src[0]) == ctx->ac.v2f32) {
+                               LLVMValueRef args[] = {
+                                       ac_llvm_extract_elem(&ctx->ac, src[0], 0),
+                                       ac_llvm_extract_elem(&ctx->ac, src[0], 1),
+                               };
+                               result = ac_build_cvt_pkrtz_f16(&ctx->ac, args);
+                               break;
+                       }
+
+                       assert(ac_get_llvm_num_components(src[0]) == 1);
+                       LLVMValueRef param[2] = { src[0], LLVMGetUndef(ctx->ac.f32) };
+                       result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
+                       result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
+               } else {
+                       if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
+                               result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+                       else
+                               result = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+               }
                break;
        case nir_op_f2f16_rtne:
-       case nir_op_f2f16:
        case nir_op_f2f32:
        case nir_op_f2f64:
                src[0] = ac_to_float(&ctx->ac, src[0]);
@@ -962,6 +1003,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_u2u8:
        case nir_op_u2u16:
+       case nir_op_u2ump:
        case nir_op_u2u32:
        case nir_op_u2u64:
                if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
@@ -971,6 +1013,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_i2i8:
        case nir_op_i2i16:
+       case nir_op_i2imp:
        case nir_op_i2i32:
        case nir_op_i2i64:
                if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
@@ -4011,7 +4054,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_shuffle:
                if (ctx->ac.chip_class == GFX8 ||
                    ctx->ac.chip_class == GFX9 ||
-                   (ctx->ac.chip_class == GFX10 && ctx->ac.wave_size == 32)) {
+                   (ctx->ac.chip_class >= GFX10 && ctx->ac.wave_size == 32)) {
                        result = ac_build_shuffle(&ctx->ac, get_src(ctx, instr->src[0]),
                                                  get_src(ctx, instr->src[1]));
                } else {