aco: implement VK_KHR_shader_float_controls
authorRhys Perry <pendingchaos02@gmail.com>
Sat, 9 Nov 2019 20:51:45 +0000 (20:51 +0000)
committerRhys Perry <pendingchaos02@gmail.com>
Fri, 15 Nov 2019 17:36:21 +0000 (17:36 +0000)
This actually supports more of the extension than the LLVM backend but we
can't enable it because ACO doesn't work with all stages yet.

With more of it enabled, some CTS tests fail because our 64-bit sqrt
is very imprecise. I can't find any precision requirements for it
anywhere, so I'm thinking it might be a CTS issue.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
src/amd/compiler/aco_assembler.cpp
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_lower_to_hw_instr.cpp
src/amd/compiler/aco_opt_value_numbering.cpp
src/amd/compiler/aco_optimizer.cpp
src/amd/vulkan/radv_device.c
src/amd/vulkan/radv_extensions.py

index ee575e882c982d500e9bd8c5cd5c4899d57361bd..9b76ba740dd9cfa749bed14a107fb90941a400f0 100644 (file)
@@ -105,7 +105,7 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
       encoding |=
          !instr->definitions.empty() && !(instr->definitions[0].physReg() == scc) ?
          instr->definitions[0].physReg() << 16 :
-         !instr->operands.empty() && !(instr->operands[0].physReg() == scc) ?
+         !instr->operands.empty() && instr->operands[0].physReg() <= 127 ?
          instr->operands[0].physReg() << 16 : 0;
       encoding |= sopk->imm;
       out.push_back(encoding);
index b094340b02fe8a46e619b86ff86a9edcd1539365..63fd36f372432b2fe114af747af7e39a884af170 100644 (file)
@@ -647,6 +647,61 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
                bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), els, cond));
 }
 
+void emit_scaled_op(isel_context *ctx, Builder& bld, Definition dst, Temp val,
+                    aco_opcode op, uint32_t undo)
+{
+   /* multiply by 16777216 to handle denormals */
+   Temp is_denormal = bld.vopc(aco_opcode::v_cmp_class_f32, bld.hint_vcc(bld.def(s2)),
+                               as_vgpr(ctx, val), bld.copy(bld.def(v1), Operand((1u << 7) | (1u << 4))));
+   Temp scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x4b800000u), val);
+   scaled = bld.vop1(op, bld.def(v1), scaled);
+   scaled = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(undo), scaled);
+
+   Temp not_scaled = bld.vop1(op, bld.def(v1), val);
+
+   bld.vop2(aco_opcode::v_cndmask_b32, dst, not_scaled, scaled, is_denormal);
+}
+
+void emit_rcp(isel_context *ctx, Builder& bld, Definition dst, Temp val)
+{
+   if (ctx->block->fp_mode.denorm32 == 0) {
+      bld.vop1(aco_opcode::v_rcp_f32, dst, val);
+      return;
+   }
+
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rcp_f32, 0x4b800000u);
+}
+
+void emit_rsq(isel_context *ctx, Builder& bld, Definition dst, Temp val)
+{
+   if (ctx->block->fp_mode.denorm32 == 0) {
+      bld.vop1(aco_opcode::v_rsq_f32, dst, val);
+      return;
+   }
+
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_rsq_f32, 0x45800000u);
+}
+
+void emit_sqrt(isel_context *ctx, Builder& bld, Definition dst, Temp val)
+{
+   if (ctx->block->fp_mode.denorm32 == 0) {
+      bld.vop1(aco_opcode::v_sqrt_f32, dst, val);
+      return;
+   }
+
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_sqrt_f32, 0x39800000u);
+}
+
+void emit_log2(isel_context *ctx, Builder& bld, Definition dst, Temp val)
+{
+   if (ctx->block->fp_mode.denorm32 == 0) {
+      bld.vop1(aco_opcode::v_log_f32, dst, val);
+      return;
+   }
+
+   emit_scaled_op(ctx, bld, dst, val, aco_opcode::v_log_f32, 0xc1c00000u);
+}
+
 void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
 {
    if (!instr->dest.dest.is_ssa) {
@@ -1399,7 +1454,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_frsq: {
       if (dst.size() == 1) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f32, dst);
+         emit_rsq(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
       } else if (dst.size() == 2) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f64, dst);
       } else {
@@ -1412,8 +1467,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_fneg: {
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (dst.size() == 1) {
+         if (ctx->block->fp_mode.must_flush_denorms32)
+            src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src));
          bld.vop2(aco_opcode::v_xor_b32, Definition(dst), Operand(0x80000000u), as_vgpr(ctx, src));
       } else if (dst.size() == 2) {
+         if (ctx->block->fp_mode.must_flush_denorms16_64)
+            src = bld.vop3(aco_opcode::v_mul_f64, bld.def(v2), Operand(0x3FF0000000000000lu), as_vgpr(ctx, src));
          Temp upper = bld.tmp(v1), lower = bld.tmp(v1);
          bld.pseudo(aco_opcode::p_split_vector, Definition(lower), Definition(upper), src);
          upper = bld.vop2(aco_opcode::v_xor_b32, bld.def(v1), Operand(0x80000000u), upper);
@@ -1428,8 +1487,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_fabs: {
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (dst.size() == 1) {
+         if (ctx->block->fp_mode.must_flush_denorms32)
+            src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src));
          bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0x7FFFFFFFu), as_vgpr(ctx, src));
       } else if (dst.size() == 2) {
+         if (ctx->block->fp_mode.must_flush_denorms16_64)
+            src = bld.vop3(aco_opcode::v_mul_f64, bld.def(v2), Operand(0x3FF0000000000000lu), as_vgpr(ctx, src));
          Temp upper = bld.tmp(v1), lower = bld.tmp(v1);
          bld.pseudo(aco_opcode::p_split_vector, Definition(lower), Definition(upper), src);
          upper = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x7FFFFFFFu), upper);
@@ -1458,7 +1521,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_flog2: {
       if (dst.size() == 1) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f32, dst);
+         emit_log2(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1468,7 +1531,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_frcp: {
       if (dst.size() == 1) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f32, dst);
+         emit_rcp(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
       } else if (dst.size() == 2) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f64, dst);
       } else {
@@ -1490,7 +1553,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fsqrt: {
       if (dst.size() == 1) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f32, dst);
+         emit_sqrt(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
       } else if (dst.size() == 2) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f64, dst);
       } else {
@@ -2040,8 +2103,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          Temp src0 = bld.tmp(v1);
          Temp src1 = bld.tmp(v1);
          bld.pseudo(aco_opcode::p_split_vector, Definition(src0), Definition(src1), src);
-         bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1);
-
+         if (!ctx->block->fp_mode.care_about_round32 || ctx->block->fp_mode.round32 == fp_round_tz)
+            bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1);
+         else
+            bld.vop3(aco_opcode::v_cvt_pk_u16_u32, Definition(dst),
+                     bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src0),
+                     bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src1));
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -2074,7 +2141,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fquantize2f16: {
-      Temp f16 = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), get_alu_src(ctx, instr->src[0]));
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      Temp f16 = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src);
 
       Temp mask = bld.copy(bld.def(s1), Operand(0x36Fu)); /* value is NOT negative/positive denormal value */
 
@@ -2083,7 +2151,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
 
       Temp f32 = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), f16);
 
-      bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), f32, cmp_res);
+      if (ctx->block->fp_mode.preserve_signed_zero_inf_nan32) {
+         Temp copysign_0 = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0u), as_vgpr(ctx, src));
+         bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), copysign_0, f32, cmp_res);
+      } else {
+         bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), f32, cmp_res);
+      }
       break;
    }
    case nir_op_bfm: {
@@ -7593,6 +7666,56 @@ void handle_bc_optimize(isel_context *ctx)
    }
 }
 
+void setup_fp_mode(isel_context *ctx, nir_shader *shader)
+{
+   Program *program = ctx->program;
+
+   unsigned float_controls = shader->info.float_controls_execution_mode;
+
+   program->next_fp_mode.preserve_signed_zero_inf_nan32 =
+      float_controls & FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32;
+   program->next_fp_mode.preserve_signed_zero_inf_nan16_64 =
+      float_controls & (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 |
+                        FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64);
+
+   program->next_fp_mode.must_flush_denorms32 =
+      float_controls & FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP32;
+   program->next_fp_mode.must_flush_denorms16_64 =
+      float_controls & (FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP16 |
+                        FLOAT_CONTROLS_DENORM_FLUSH_TO_ZERO_FP64);
+
+   program->next_fp_mode.care_about_round32 =
+      float_controls & (FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP32 | FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP32);
+
+   program->next_fp_mode.care_about_round16_64 =
+      float_controls & (FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16 | FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP64 |
+                        FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16 | FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP64);
+
+   /* default to preserving fp16 and fp64 denorms, since it's free */
+   if (program->next_fp_mode.must_flush_denorms16_64)
+      program->next_fp_mode.denorm16_64 = 0;
+   else
+      program->next_fp_mode.denorm16_64 = fp_denorm_keep;
+
+   /* preserving fp32 denorms is expensive, so only do it if asked */
+   if (float_controls & FLOAT_CONTROLS_DENORM_PRESERVE_FP32)
+      program->next_fp_mode.denorm32 = fp_denorm_keep;
+   else
+      program->next_fp_mode.denorm32 = 0;
+
+   if (float_controls & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP32)
+      program->next_fp_mode.round32 = fp_round_tz;
+   else
+      program->next_fp_mode.round32 = fp_round_ne;
+
+   if (float_controls & (FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16 | FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP64))
+      program->next_fp_mode.round16_64 = fp_round_tz;
+   else
+      program->next_fp_mode.round16_64 = fp_round_ne;
+
+   ctx->block->fp_mode = program->next_fp_mode;
+}
+
 void select_program(Program *program,
                     unsigned shader_count,
                     struct nir_shader *const *shaders,
@@ -7606,6 +7729,8 @@ void select_program(Program *program,
       nir_shader *nir = shaders[i];
       init_context(&ctx, nir);
 
+      setup_fp_mode(&ctx, nir);
+
       if (!i) {
          add_startpgm(&ctx); /* needs to be after init_context() for FS */
          append_logical_start(ctx.block);
@@ -7648,6 +7773,8 @@ void select_program(Program *program,
       ralloc_free(ctx.divergent_vals);
    }
 
+   program->config->float_mode = program->blocks[0].fp_mode.val;
+
    append_logical_end(ctx.block);
    ctx.block->kind |= block_kind_uniform;
    Builder bld(ctx.program, ctx.block);
index cdc8103497bbb655ef3ad710154080dfc50e7bd9..807ce74686839cddcb5d5dd1894036207589a8da 100644 (file)
@@ -1360,7 +1360,6 @@ setup_isel_context(Program* program,
       scratch_size = std::max(scratch_size, shaders[i]->scratch_size);
    ctx.scratch_enabled = scratch_size > 0;
    ctx.program->config->scratch_bytes_per_wave = align(scratch_size * ctx.program->wave_size, 1024);
-   ctx.program->config->float_mode = V_00B028_FP_64_DENORMS;
 
    ctx.block = ctx.program->create_and_insert_block();
    ctx.block->loop_nest_depth = 0;
index a6fe846c74d031ed663448a2eb1525bd70b99e40..59e77feffe52fe338a52c44f38576998c9272d81 100644 (file)
@@ -110,6 +110,53 @@ enum barrier_interaction {
    barrier_count = 4,
 };
 
+enum fp_round {
+   fp_round_ne = 0,
+   fp_round_pi = 1,
+   fp_round_ni = 2,
+   fp_round_tz = 3,
+};
+
+enum fp_denorm {
+   /* Note that v_rcp_f32, v_exp_f32, v_log_f32, v_sqrt_f32, v_rsq_f32 and
+    * v_mad_f32/v_madak_f32/v_madmk_f32/v_mac_f32 always flush denormals. */
+   fp_denorm_flush = 0x0,
+   fp_denorm_keep = 0x3,
+};
+
+struct float_mode {
+   /* matches encoding of the MODE register */
+   union {
+      struct {
+          fp_round round32:2;
+          fp_round round16_64:2;
+          unsigned denorm32:2;
+          unsigned denorm16_64:2;
+      };
+      uint8_t val = 0;
+   };
+   /* if false, optimizations which may remove infs/nan/-0.0 can be done */
+   bool preserve_signed_zero_inf_nan32:1;
+   bool preserve_signed_zero_inf_nan16_64:1;
+   /* if false, optimizations which may remove denormal flushing can be done */
+   bool must_flush_denorms32:1;
+   bool must_flush_denorms16_64:1;
+   bool care_about_round32:1;
+   bool care_about_round16_64:1;
+
+   /* Returns true if instructions using the mode "other" can safely use the
+    * current one instead. */
+   bool canReplace(float_mode other) const noexcept {
+      return val == other.val &&
+             (preserve_signed_zero_inf_nan32 || !other.preserve_signed_zero_inf_nan32) &&
+             (preserve_signed_zero_inf_nan16_64 || !other.preserve_signed_zero_inf_nan16_64) &&
+             (must_flush_denorms32  || !other.must_flush_denorms32) &&
+             (must_flush_denorms16_64 || !other.must_flush_denorms16_64) &&
+             (care_about_round32 || !other.care_about_round32) &&
+             (care_about_round16_64 || !other.care_about_round16_64);
+   }
+};
+
 constexpr Format asVOP3(Format format) {
    return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format);
 };
@@ -1019,6 +1066,7 @@ struct RegisterDemand {
 
 /* CFG */
 struct Block {
+   float_mode fp_mode;
    unsigned index;
    unsigned offset = 0;
    std::vector<aco_ptr<Instruction>> instructions;
@@ -1086,6 +1134,7 @@ static constexpr Stage geometry_gs = sw_gs | hw_gs;
 
 class Program final {
 public:
+   float_mode next_fp_mode;
    std::vector<Block> blocks;
    RegisterDemand max_reg_demand = RegisterDemand();
    uint16_t num_waves = 0;
@@ -1133,11 +1182,13 @@ public:
 
    Block* create_and_insert_block() {
       blocks.emplace_back(blocks.size());
+      blocks.back().fp_mode = next_fp_mode;
       return &blocks.back();
    }
 
    Block* insert_block(Block&& block) {
       block.index = blocks.size();
+      block.fp_mode = next_fp_mode;
       blocks.emplace_back(std::move(block));
       return &blocks.back();
    }
index 1502619b9db36c3cf96b0448fe5a72ffdfa27701..8db54064202a7425bfef21bd5f2446b7bebf6632 100644 (file)
@@ -592,6 +592,22 @@ void lower_to_hw_instr(Program* program)
       ctx.program = program;
       Builder bld(program, &ctx.instructions);
 
+      bool set_mode = i == 0 && block->fp_mode.val != program->config->float_mode;
+      for (unsigned pred : block->linear_preds) {
+         if (program->blocks[pred].fp_mode.val != block->fp_mode.val) {
+            set_mode = true;
+            break;
+         }
+      }
+      if (set_mode) {
+         /* only allow changing modes at top-level blocks so this doesn't break
+          * the "jump over empty blocks" optimization */
+         assert(block->kind & block_kind_top_level);
+         uint32_t mode = block->fp_mode.val;
+         /* "((size - 1) << 11) | register" (MODE is encoded as register 1) */
+         bld.sopk(aco_opcode::s_setreg_imm32_b32, Operand(mode), (7 << 11) | 1);
+      }
+
       for (size_t j = 0; j < block->instructions.size(); j++) {
          aco_ptr<Instruction>& instr = block->instructions[j];
          aco_ptr<Instruction> mov;
index 803249637d5bf3596f890d1dafff9357c0b4f6e8..40823da3c3633bee2cba48d2e33759e7f6924783 100644 (file)
@@ -303,7 +303,8 @@ void process_block(vn_ctx& ctx, Block& block)
          Instruction* orig_instr = res.first->first;
          assert(instr->definitions.size() == orig_instr->definitions.size());
          /* check if the original instruction dominates the current one */
-         if (dominates(ctx, res.first->second, block.index)) {
+         if (dominates(ctx, res.first->second, block.index) &&
+             ctx.program->blocks[res.first->second].fp_mode.canReplace(block.fp_mode)) {
             for (unsigned i = 0; i < instr->definitions.size(); i++) {
                assert(instr->definitions[i].regClass() == orig_instr->definitions[i].regClass());
                ctx.renames[instr->definitions[i].tempId()] = orig_instr->definitions[i].getTemp();
index 5b4fcf751262c6366747ac83cb3d5ff3e2fd77f2..7b66aa1eeb3de933522ddede65e20b702a2059fa 100644 (file)
@@ -548,7 +548,7 @@ bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp
    return false;
 }
 
-void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    if (instr->isSALU() || instr->isVALU() || instr->format == Format::PSEUDO) {
       ASSERTED bool all_const = false;
@@ -888,7 +888,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
                ctx.info[instr->operands[i].tempId()].set_omod4();
             } else if (instr->operands[!i].constantValue() == 0x3f000000) { /* 0.5 */
                ctx.info[instr->operands[i].tempId()].set_omod5();
-            } else if (instr->operands[!i].constantValue() == 0x3f800000) { /* 1.0 */
+            } else if (instr->operands[!i].constantValue() == 0x3f800000 &&
+                       !block.fp_mode.must_flush_denorms32) { /* 1.0 */
                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[i].getTemp());
             } else {
                continue;
@@ -1892,7 +1893,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    }
 }
 
-bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    /* check if we could apply omod on predecessor */
    if (instr->opcode == aco_opcode::v_mul_f32) {
@@ -1959,18 +1960,21 @@ bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       }
    }
 
+   /* omod has no effect if denormals are enabled */
+   bool can_use_omod = block.fp_mode.denorm32 == 0;
+
    /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
    if (!instr->definitions.empty() && ctx.uses[instr->definitions[0].tempId()] == 1 &&
        can_use_VOP3(instr) && instr_info.can_use_output_modifiers[(int)instr->opcode]) {
-      if(ctx.info[instr->definitions[0].tempId()].is_omod2()) {
+      if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod2()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 1;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (ctx.info[instr->definitions[0].tempId()].is_omod4()) {
+      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod4()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 2;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (ctx.info[instr->definitions[0].tempId()].is_omod5()) {
+      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod5()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 3;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
@@ -1987,7 +1991,7 @@ bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
 // this would mean that we'd have to fix the instruction uses while value propagation
 
-void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    if (instr->definitions.empty() || !ctx.uses[instr->definitions[0].tempId()])
       return;
@@ -1995,7 +1999,7 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (instr->isVALU()) {
       if (can_apply_sgprs(instr))
          apply_sgprs(ctx, instr);
-      if (apply_omod_clamp(ctx, instr))
+      if (apply_omod_clamp(ctx, block, instr))
          return;
    }
 
@@ -2048,9 +2052,11 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       return;
    }
    /* combine mul+add -> mad */
-   else if (instr->opcode == aco_opcode::v_add_f32 ||
-       instr->opcode == aco_opcode::v_sub_f32 ||
-       instr->opcode == aco_opcode::v_subrev_f32) {
+   else if ((instr->opcode == aco_opcode::v_add_f32 ||
+             instr->opcode == aco_opcode::v_sub_f32 ||
+             instr->opcode == aco_opcode::v_subrev_f32) &&
+            block.fp_mode.denorm32 == 0 && !block.fp_mode.preserve_signed_zero_inf_nan32) {
+      //TODO: we could use fma instead when denormals are enabled if the NIR isn't marked as precise
 
       uint32_t uses_src0 = UINT32_MAX;
       uint32_t uses_src1 = UINT32_MAX;
@@ -2394,7 +2400,7 @@ void optimize(Program* program)
    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& instr : block.instructions)
-         label_instruction(ctx, instr);
+         label_instruction(ctx, block, instr);
    }
 
    ctx.uses = std::move(dead_code_analysis(program));
@@ -2402,7 +2408,7 @@ void optimize(Program* program)
    /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& instr : block.instructions)
-         combine_instruction(ctx, instr);
+         combine_instruction(ctx, block, instr);
    }
 
    /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */
index 4775609629f96ebe99c943d251d8dc3c8089fea3..b561980c123ce4914d6a6b836601f7ec4cbd76fd 100644 (file)
@@ -1557,6 +1557,8 @@ void radv_GetPhysicalDeviceProperties2(
                         * support for changing the register. The same logic
                         * applies for the rounding modes because they are
                         * configured with the same config register.
+                        * TODO: we can enable a lot of these for ACO when it
+                        * supports all stages
                         */
                        properties->shaderDenormFlushToZeroFloat32 = true;
                        properties->shaderDenormPreserveFloat32 = false;
index 587e9820844dd8b388c43f02b487f3d395606e04..a4983ba0f6156d88616c924aae65a53c8e14d24b 100644 (file)
@@ -89,7 +89,7 @@ EXTENSIONS = [
     Extension('VK_KHR_shader_atomic_int64',               1, 'LLVM_VERSION_MAJOR >= 9'),
     Extension('VK_KHR_shader_clock',                      1, True),
     Extension('VK_KHR_shader_draw_parameters',            1, True),
-    Extension('VK_KHR_shader_float_controls',             1, '!device->use_aco'),
+    Extension('VK_KHR_shader_float_controls',             1, True),
     Extension('VK_KHR_shader_float16_int8',               1, '!device->use_aco'),
     Extension('VK_KHR_spirv_1_4',                         1, True),
     Extension('VK_KHR_storage_buffer_storage_class',      1, True),