aco: fix SMEM offsets for SI/CI
[mesa.git] / src / amd / compiler / aco_optimizer.cpp
index d3b35761704874253796fe4d8b9445bac53d4e15..e92531030c82919df75d9a06476cba7ba11547d4 100644 (file)
@@ -481,6 +481,12 @@ bool can_accept_constant(aco_ptr<Instruction>& instr, unsigned operand)
 
 bool valu_can_accept_literal(opt_ctx& ctx, aco_ptr<Instruction>& instr, unsigned operand)
 {
+   /* instructions like v_cndmask_b32 can't take a literal because they always
+    * read SGPRs */
+   if (instr->operands.size() >= 3 &&
+       instr->operands[2].isTemp() && instr->operands[2].regClass().type() == RegType::sgpr)
+      return false;
+
    // TODO: VOP3 can take a literal on GFX10
    return !instr->isSDWA() && !instr->isDPP() && !instr->isVOP3() &&
           operand == 0 && can_accept_constant(instr, operand);
@@ -542,7 +548,7 @@ bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp
    return false;
 }
 
-void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    if (instr->isSALU() || instr->isVALU() || instr->format == Format::PSEUDO) {
       ASSERTED bool all_const = false;
@@ -703,7 +709,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
          SMEM_instruction *smem = static_cast<SMEM_instruction *>(instr.get());
          Temp base;
          uint32_t offset;
-         if (i == 1 && info.is_constant_or_literal() && info.val <= 0xFFFFF) {
+         if (i == 1 && info.is_constant_or_literal() &&
+             (ctx.program->chip_class < GFX8 || info.val <= 0xFFFFF)) {
             instr->operands[i] = Operand(info.val);
             continue;
          } else if (i == 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->chip_class >= GFX9) {
@@ -882,7 +889,8 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
                ctx.info[instr->operands[i].tempId()].set_omod4();
             } else if (instr->operands[!i].constantValue() == 0x3f000000) { /* 0.5 */
                ctx.info[instr->operands[i].tempId()].set_omod5();
-            } else if (instr->operands[!i].constantValue() == 0x3f800000) { /* 1.0 */
+            } else if (instr->operands[!i].constantValue() == 0x3f800000 &&
+                       !block.fp_mode.must_flush_denorms32) { /* 1.0 */
                ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[i].getTemp());
             } else {
                continue;
@@ -986,13 +994,13 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    case aco_opcode::s_add_u32:
       ctx.info[instr->definitions[0].tempId()].set_add_sub(instr.get());
       break;
+   case aco_opcode::s_and_b32:
    case aco_opcode::s_and_b64:
       if (instr->operands[1].isFixed() && instr->operands[1].physReg() == exec &&
           instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_uniform_bool()) {
          ctx.info[instr->definitions[1].tempId()].set_temp(ctx.info[instr->operands[0].tempId()].temp);
       }
       /* fallthrough */
-   case aco_opcode::s_and_b32:
    case aco_opcode::s_not_b32:
    case aco_opcode::s_not_b64:
    case aco_opcode::s_or_b32:
@@ -1035,6 +1043,7 @@ void label_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       ctx.info[instr->definitions[0].tempId()].set_fcmp(instr.get());
       break;
    case aco_opcode::s_cselect_b64:
+   case aco_opcode::s_cselect_b32:
       if (instr->operands[0].constantEquals((unsigned) -1) &&
           instr->operands[1].constantEquals(0)) {
          /* Found a cselect that operates on a uniform bool that comes from eg. s_cmp */
@@ -1886,7 +1895,7 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    }
 }
 
-bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    /* check if we could apply omod on predecessor */
    if (instr->opcode == aco_opcode::v_mul_f32) {
@@ -1953,18 +1962,21 @@ bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       }
    }
 
+   /* omod has no effect if denormals are enabled */
+   bool can_use_omod = block.fp_mode.denorm32 == 0;
+
    /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */
    if (!instr->definitions.empty() && ctx.uses[instr->definitions[0].tempId()] == 1 &&
        can_use_VOP3(instr) && instr_info.can_use_output_modifiers[(int)instr->opcode]) {
-      if(ctx.info[instr->definitions[0].tempId()].is_omod2()) {
+      if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod2()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 1;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (ctx.info[instr->definitions[0].tempId()].is_omod4()) {
+      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod4()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 2;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
-      } else if (ctx.info[instr->definitions[0].tempId()].is_omod5()) {
+      } else if (can_use_omod && ctx.info[instr->definitions[0].tempId()].is_omod5()) {
          to_VOP3(ctx, instr);
          static_cast<VOP3A_instruction*>(instr.get())->omod = 3;
          ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get());
@@ -1981,7 +1993,7 @@ bool apply_omod_clamp(opt_ctx &ctx, aco_ptr<Instruction>& instr)
 // TODO: we could possibly move the whole label_instruction pass to combine_instruction:
 // this would mean that we'd have to fix the instruction uses while value propagation
 
-void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
+void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 {
    if (instr->definitions.empty() || !ctx.uses[instr->definitions[0].tempId()])
       return;
@@ -1989,7 +2001,7 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (instr->isVALU()) {
       if (can_apply_sgprs(instr))
          apply_sgprs(ctx, instr);
-      if (apply_omod_clamp(ctx, instr))
+      if (apply_omod_clamp(ctx, block, instr))
          return;
    }
 
@@ -2042,9 +2054,11 @@ void combine_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       return;
    }
    /* combine mul+add -> mad */
-   else if (instr->opcode == aco_opcode::v_add_f32 ||
-       instr->opcode == aco_opcode::v_sub_f32 ||
-       instr->opcode == aco_opcode::v_subrev_f32) {
+   else if ((instr->opcode == aco_opcode::v_add_f32 ||
+             instr->opcode == aco_opcode::v_sub_f32 ||
+             instr->opcode == aco_opcode::v_subrev_f32) &&
+            block.fp_mode.denorm32 == 0 && !block.fp_mode.preserve_signed_zero_inf_nan32) {
+      //TODO: we could use fma instead when denormals are enabled if the NIR isn't marked as precise
 
       uint32_t uses_src0 = UINT32_MAX;
       uint32_t uses_src1 = UINT32_MAX;
@@ -2388,7 +2402,7 @@ void optimize(Program* program)
    /* 1. Bottom-Up DAG pass (forward) to label all ssa-defs */
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& instr : block.instructions)
-         label_instruction(ctx, instr);
+         label_instruction(ctx, block, instr);
    }
 
    ctx.uses = std::move(dead_code_analysis(program));
@@ -2396,7 +2410,7 @@ void optimize(Program* program)
    /* 2. Combine v_mad, omod, clamp and propagate sgpr on VALU instructions */
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& instr : block.instructions)
-         combine_instruction(ctx, instr);
+         combine_instruction(ctx, block, instr);
    }
 
    /* 3. Top-Down DAG pass (backward) to select instructions (includes DCE) */