ac/nir: select v_cvt_pkrtz for all conversions from f32 to f16 for radeonsi
authorMarek Olšák <marek.olsak@amd.com>
Sun, 10 May 2020 03:00:44 +0000 (23:00 -0400)
committerMarek Olšák <marek.olsak@amd.com>
Tue, 2 Jun 2020 20:29:25 +0000 (16:29 -0400)
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5003>

src/amd/llvm/ac_nir_to_llvm.c

index d47624a07e1fd96f5977fba858c0e9040bbaf510..8a707c93666becfeb71505de5be66faf1e60e9ce 100644 (file)
@@ -943,18 +943,47 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
        case nir_op_f2f16_rtz:
                result = LLVMBuildUIToFP(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                break;
        case nir_op_f2f16_rtz:
+       case nir_op_f2f16:
+       case nir_op_f2fmp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
                src[0] = ac_to_float(&ctx->ac, src[0]);
-               if (LLVMTypeOf(src[0]) == ctx->ac.f64)
-                       src[0] = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ctx->ac.f32, "");
-               LLVMValueRef param[2] = { src[0], ctx->ac.f32_0 };
-               result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
-               result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
+
+               /* For OpenGL, we want fast packing with v_cvt_pkrtz_f16, but if we use it,
+                * all f32->f16 conversions have to round towards zero, because both scalar
+                * and vec2 down-conversions have to round equally.
+                */
+               if (ctx->ac.float_mode == AC_FLOAT_MODE_DEFAULT_OPENGL ||
+                   instr->op == nir_op_f2f16_rtz) {
+                       src[0] = ac_to_float(&ctx->ac, src[0]);
+
+                       if (LLVMTypeOf(src[0]) == ctx->ac.f64)
+                               src[0] = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ctx->ac.f32, "");
+
+                       /* Fast path conversion. This only works if NIR is vectorized
+                        * to vec2 16.
+                        */
+                       if (LLVMTypeOf(src[0]) == ctx->ac.v2f32) {
+                               LLVMValueRef args[] = {
+                                       ac_llvm_extract_elem(&ctx->ac, src[0], 0),
+                                       ac_llvm_extract_elem(&ctx->ac, src[0], 1),
+                               };
+                               result = ac_build_cvt_pkrtz_f16(&ctx->ac, args);
+                               break;
+                       }
+
+                       assert(ac_get_llvm_num_components(src[0]) == 1);
+                       LLVMValueRef param[2] = { src[0], ctx->ac.f32_0 };
+                       result = ac_build_cvt_pkrtz_f16(&ctx->ac, param);
+                       result = LLVMBuildExtractElement(ctx->ac.builder, result, ctx->ac.i32_0, "");
+               } else {
+                       if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
+                               result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+                       else
+                               result = LLVMBuildFPTrunc(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
+               }
                break;
        case nir_op_f2f16_rtne:
                break;
        case nir_op_f2f16_rtne:
-       case nir_op_f2f16:
        case nir_op_f2f32:
        case nir_op_f2f64:
        case nir_op_f2f32:
        case nir_op_f2f64:
-       case nir_op_f2fmp:
                src[0] = ac_to_float(&ctx->ac, src[0]);
                if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
                        result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");
                src[0] = ac_to_float(&ctx->ac, src[0]);
                if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(src[0])) < ac_get_elem_bits(&ctx->ac, def_type))
                        result = LLVMBuildFPExt(ctx->ac.builder, src[0], ac_to_float_type(&ctx->ac, def_type), "");