ac/nir: fix integer comparisons with pointers
[mesa.git] / src / amd / llvm / ac_nir_to_llvm.c
index e0728822800d0fd66493792091519c9e3acc0c9f..b90a7e3dcf2b58055a2bf2fbcf1090a84a11531e 100644 (file)
@@ -170,6 +170,17 @@ static LLVMValueRef emit_int_cmp(struct ac_llvm_context *ctx,
                                  LLVMIntPredicate pred, LLVMValueRef src0,
                                  LLVMValueRef src1)
 {
+       LLVMTypeRef src0_type = LLVMTypeOf(src0);
+       LLVMTypeRef src1_type = LLVMTypeOf(src1);
+
+       if (LLVMGetTypeKind(src0_type) == LLVMPointerTypeKind &&
+           LLVMGetTypeKind(src1_type) != LLVMPointerTypeKind) {
+               src1 = LLVMBuildIntToPtr(ctx->builder, src1, src0_type, "");
+       } else if (LLVMGetTypeKind(src1_type) == LLVMPointerTypeKind &&
+                  LLVMGetTypeKind(src0_type) != LLVMPointerTypeKind) {
+               src0 = LLVMBuildIntToPtr(ctx->builder, src0, src1_type, "");
+       }
+
        LLVMValueRef result = LLVMBuildICmp(ctx->builder, pred, src0, src1, "");
        return LLVMBuildSelect(ctx->builder, result,
                               LLVMConstInt(ctx->i32, 0xFFFFFFFF, false),
@@ -943,15 +954,45 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
        case nir_op_f2f16_rtz:
+       case nir_op_f2f16:
+       case nir_op_f2fmp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               if (LLVMTypeOf(src[0]) == ctx->ac.f64)
-                       src[0] = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ctx->ac.f32, "");
-               LLVMValueRef param[2] = { src[0], ctx->ac.f32_0 };
-               result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
-               result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
+
+               /* For OpenGL, we want fast packing with v_cvt_pkrtz_f16, but if we use it,
+                * all f32->f16 conversions have to round towards zero, because both scalar
+                * and vec2 down-conversions have to round equally.
+                */
+               if (ctx->ac.float_mode == AC_FLOAT_MODE_DEFAULT_OPENGL ||
+                   instr->op == nir_op_f2f16_rtz) {
+                       src[0] = ac_to_float(&ctx->ac, src[0]);
+
+                       if (LLVMTypeOf(src[0]) == ctx->ac.f64)
+                               src[0] = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ctx->ac.f32, "");
+
+                       /* Fast path conversion. This only works if NIR is vectorized
+                        * to vec2 16.
+                        */
+                       if (LLVMTypeOf(src[0]) == ctx->ac.v2f32) {
+                               LLVMValueRef args[] = {
+                                       ac_llvm_extract_elem(&ctx->ac, src[0], 0),
+                                       ac_llvm_extract_elem(&ctx->ac, src[0], 1),
+                               };
+                               result = ac_build_cvt_pkrtz_f16(&ctx->ac, args);
+                               break;
+                       }
+
+                       assert(ac_get_llvm_num_components(src[0]) == 1);
+                       LLVMValueRef param[2] = { src[0], LLVMGetUndef(ctx->ac.f32) };
+                       result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
+                       result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
+               } else {
+                       if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
+                               result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+                       else
+                               result = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+               }
                break;
        case nir_op_f2f16_rtne:
-       case nir_op_f2f16:
        case nir_op_f2f32:
        case nir_op_f2f64:
                src[0] = ac_to_float(&ctx->ac, src[0]);
@@ -962,6 +1003,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_u2u8:
        case nir_op_u2u16:
+       case nir_op_u2ump:
        case nir_op_u2u32:
        case nir_op_u2u64:
                if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
@@ -971,6 +1013,7 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        case nir_op_i2i8:
        case nir_op_i2i16:
+       case nir_op_i2imp:
        case nir_op_i2i32:
        case nir_op_i2i64:
                if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
@@ -1436,12 +1479,14 @@ static LLVMValueRef build_tex_intrinsic(struct ac_nir_context *ctx,
        if (instr->sampler_dim == GLSL_SAMPLER_DIM_BUF) {
                unsigned mask = nir_ssa_def_components_read(&instr->dest.ssa);
 
+               assert(instr->dest.is_ssa);
                return ac_build_buffer_load_format(&ctx->ac,
                                                   args->resource,
                                                   args->coords[0],
                                                   ctx->ac.i32_0,
                                                   util_last_bit(mask),
-                                                  0, true);
+                                                  0, true,
+                                                  instr->dest.ssa.bit_size == 16);
        }
 
        args->opcode = ac_image_sample;
@@ -2782,11 +2827,13 @@ static LLVMValueRef visit_image_load(struct ac_nir_context *ctx,
                vindex = LLVMBuildExtractElement(ctx->ac.builder, get_src(ctx, instr->src[1]),
                                                 ctx->ac.i32_0, "");
 
+               assert(instr->dest.is_ssa);
                bool can_speculate = access & ACCESS_CAN_REORDER;
                res = ac_build_buffer_load_format(&ctx->ac, rsrc, vindex,
                                                  ctx->ac.i32_0, num_channels,
                                                  args.cache_policy,
-                                                 can_speculate);
+                                                 can_speculate,
+                                                 instr->dest.ssa.bit_size == 16);
                res = ac_build_expand_to_vec4(&ctx->ac, res, num_channels);
 
                res = ac_trim_vector(&ctx->ac, res, instr->dest.ssa.num_components);
@@ -2803,6 +2850,9 @@ static LLVMValueRef visit_image_load(struct ac_nir_context *ctx,
                args.dmask = 15;
                args.attributes = AC_FUNC_ATTR_READONLY;
 
+               assert(instr->dest.is_ssa);
+               args.d16 = instr->dest.ssa.bit_size == 16;
+
                res = ac_build_image_opcode(&ctx->ac, &args);
        }
        return exit_waterfall(ctx, &wctx, res);
@@ -2857,8 +2907,7 @@ static void visit_image_store(struct ac_nir_context *ctx,
                                                 ctx->ac.i32_0, "");
 
                ac_build_buffer_store_format(&ctx->ac, rsrc, src, vindex,
-                                            ctx->ac.i32_0, src_channels,
-                                            args.cache_policy);
+                                            ctx->ac.i32_0, args.cache_policy);
        } else {
                bool level_zero = nir_src_is_const(instr->src[4]) && nir_src_as_uint(instr->src[4]) == 0;
 
@@ -2870,6 +2919,7 @@ static void visit_image_store(struct ac_nir_context *ctx,
                if (!level_zero)
                        args.lod = get_src(ctx, instr->src[4]);
                args.dmask = 15;
+               args.d16 = ac_get_elem_bits(&ctx->ac, LLVMTypeOf(args.data[0])) == 16;
 
                ac_build_image_opcode(&ctx->ac, &args);
        }
@@ -4004,7 +4054,7 @@ static void visit_intrinsic(struct ac_nir_context *ctx,
        case nir_intrinsic_shuffle:
                if (ctx->ac.chip_class == GFX8 ||
                    ctx->ac.chip_class == GFX9 ||
-                   (ctx->ac.chip_class == GFX10 && ctx->ac.wave_size == 32)) {
+                   (ctx->ac.chip_class >= GFX10 && ctx->ac.wave_size == 32)) {
                        result = ac_build_shuffle(&ctx->ac, get_src(ctx, instr->src[0]),
                                                  get_src(ctx, instr->src[1]));
                } else {
@@ -4733,6 +4783,9 @@ static void visit_tex(struct ac_nir_context *ctx, nir_tex_instr *instr)
                }
        }
 
+       assert(instr->dest.is_ssa);
+       args.d16 = instr->dest.ssa.bit_size == 16;
+
        result = build_tex_intrinsic(ctx, instr, &args);
 
        if (instr->op == nir_texop_query_levels)