aco: flush denorms after fmin/fmax on pre-GFX9
authorDaniel Schürmann <daniel@schuermann.dev>
Mon, 18 Nov 2019 17:44:51 +0000 (18:44 +0100)
committerDaniel Schürmann <daniel@schuermann.dev>
Sat, 7 Dec 2019 10:23:11 +0000 (11:23 +0100)
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
src/amd/compiler/aco_instruction_selection.cpp

index 0e62d22aa681ae7fb31d71458d1f423a9b3e25b4..2c386b50a147db86634a921de691895eb933030d 100644 (file)
@@ -435,7 +435,8 @@ void emit_sop2_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o
    ctx->block->instructions.emplace_back(std::move(sop2));
 }
 
-void emit_vop2_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst, bool commutative, bool swap_srcs=false)
+void emit_vop2_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst,
+                           bool commutative, bool swap_srcs=false, bool flush_denorms = false)
 {
    Builder bld(ctx->program, ctx->block);
    Temp src0 = get_alu_src(ctx, instr->src[swap_srcs ? 1 : 0]);
@@ -457,10 +458,18 @@ void emit_vop2_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o
          src1 = bld.copy(bld.def(RegType::vgpr, src1.size()), src1); //TODO: as_vgpr
       }
    }
-   bld.vop2(op, Definition(dst), src0, src1);
+
+   if (flush_denorms && ctx->program->chip_class < GFX9) {
+      assert(dst.size() == 1);
+      Temp tmp = bld.vop2(op, bld.def(v1), src0, src1);
+      bld.vop2(aco_opcode::v_mul_f32, Definition(dst), Operand(0x3f800000u), tmp);
+   } else {
+      bld.vop2(op, Definition(dst), src0, src1);
+   }
 }
 
-void emit_vop3a_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst)
+void emit_vop3a_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst,
+                            bool flush_denorms = false)
 {
    Temp src0 = get_alu_src(ctx, instr->src[0]);
    Temp src1 = get_alu_src(ctx, instr->src[1]);
@@ -476,7 +485,13 @@ void emit_vop3a_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode
       src2 = as_vgpr(ctx, src2);
 
    Builder bld(ctx->program, ctx->block);
-   bld.vop3(op, Definition(dst), src0, src1, src2);
+   if (flush_denorms && ctx->program->chip_class < GFX9) {
+      assert(dst.size() == 1);
+      Temp tmp = bld.vop3(op, Definition(dst), src0, src1, src2);
+      bld.vop2(aco_opcode::v_mul_f32, Definition(dst), Operand(0x3f800000u), tmp);
+   } else {
+      bld.vop3(op, Definition(dst), src0, src1, src2);
+   }
 }
 
 void emit_vop1_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst)
@@ -1342,11 +1357,18 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fmax: {
       if (dst.size() == 1) {
-         emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true);
+         emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32);
       } else if (dst.size() == 2) {
-         bld.vop3(aco_opcode::v_max_f64, Definition(dst),
-                  get_alu_src(ctx, instr->src[0]),
-                  as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+         if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) {
+            Temp tmp = bld.vop3(aco_opcode::v_max_f64, bld.def(v2),
+                                get_alu_src(ctx, instr->src[0]),
+                                as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+            bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp);
+         } else {
+            bld.vop3(aco_opcode::v_max_f64, Definition(dst),
+                     get_alu_src(ctx, instr->src[0]),
+                     as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+         }
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1356,11 +1378,18 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fmin: {
       if (dst.size() == 1) {
-         emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true);
+         emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32);
       } else if (dst.size() == 2) {
-         bld.vop3(aco_opcode::v_min_f64, Definition(dst),
-                  get_alu_src(ctx, instr->src[0]),
-                  as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+         if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) {
+            Temp tmp = bld.vop3(aco_opcode::v_min_f64, bld.def(v2),
+                                get_alu_src(ctx, instr->src[0]),
+                                as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+            bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp);
+         } else {
+            bld.vop3(aco_opcode::v_min_f64, Definition(dst),
+                     get_alu_src(ctx, instr->src[0]),
+                     as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+         }
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1370,7 +1399,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fmax3: {
       if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst);
+         emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1380,7 +1409,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fmin3: {
       if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst);
+         emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1390,7 +1419,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fmed3: {
       if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst);
+         emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1540,6 +1569,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (dst.size() == 1) {
          bld.vop3(aco_opcode::v_med3_f32, Definition(dst), Operand(0u), Operand(0x3f800000u), src);
+         /* apparently, it is not necessary to flush denorms if this instruction is used with these operands */
+         // TODO: confirm that this holds under any circumstances
       } else if (dst.size() == 2) {
          Instruction* add = bld.vop3(aco_opcode::v_add_f64, Definition(dst), src, Operand(0u));
          VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(add);