aco: reserve 2 sgprs for each branch
[mesa.git] / src / amd / compiler / aco_insert_NOPs.cpp
index 7c6e100faf1da684d1efc287bc6e47831d440087..a609c18d5dc304bdc3e00e6af240d1f58e113791 100644 (file)
@@ -179,7 +179,12 @@ struct NOP_ctx_gfx10 {
 
 int get_wait_states(aco_ptr<Instruction>& instr)
 {
-   return 1;
+   if (instr->opcode == aco_opcode::s_nop)
+      return static_cast<SOPP_instruction*>(instr.get())->imm + 1;
+   else if (instr->opcode == aco_opcode::p_constaddr)
+      return 3; /* lowered to 3 instructions in the assembler */
+   else
+      return 1;
 }
 
 bool regs_intersect(PhysReg a_reg, unsigned a_size, PhysReg b_reg, unsigned b_size)
@@ -213,13 +218,23 @@ int handle_raw_hazard_internal(Program *program, Block *block,
       if (is_hazard)
          return nops_needed;
 
+      mask &= ~writemask;
       nops_needed -= get_wait_states(pred);
 
-      if (nops_needed <= 0)
+      if (nops_needed <= 0 || mask == 0)
          return 0;
    }
 
-   return 0;
+   int res = 0;
+
+   /* Loops require branch instructions, which count towards the wait
+    * states. So even with loops this should finish unless nops_needed is some
+    * huge value. */
+   for (unsigned lin_pred : block->linear_preds) {
+      res = std::max(res, handle_raw_hazard_internal<Valu, Vintrp, Salu>(
+         program, &program->blocks[lin_pred], nops_needed, reg, mask));
+   }
+   return res;
 }
 
 template <bool Valu, bool Vintrp, bool Salu>
@@ -259,6 +274,41 @@ bool test_bitset_range(BITSET_WORD *words, unsigned start, unsigned size) {
    }
 }
 
+/* A SMEM clause is any group of consecutive SMEM instructions. The
+ * instructions in this group may return out of order and/or may be replayed.
+ *
+ * To fix this potential hazard correctly, we have to make sure that when a
+ * clause has more than one instruction, no instruction in the clause writes
+ * to a register that is read by another instruction in the clause (including
+ * itself). In this case, we have to break the SMEM clause by inserting non
+ * SMEM instructions.
+ *
+ * SMEM clauses are only present on GFX8+, and only matter when XNACK is set.
+ */
+void handle_smem_clause_hazards(Program *program, NOP_ctx_gfx6 &ctx,
+                                aco_ptr<Instruction>& instr, int *NOPs)
+{
+   /* break off from previous SMEM clause if needed */
+   if (!*NOPs & (ctx.smem_clause || ctx.smem_write)) {
+      /* Don't allow clauses with store instructions since the clause's
+       * instructions may use the same address. */
+      if (ctx.smem_write || instr->definitions.empty() || instr_info.is_atomic[(unsigned)instr->opcode]) {
+         *NOPs = 1;
+      } else if (program->xnack_enabled) {
+         for (Operand op : instr->operands) {
+            if (!op.isConstant() && test_bitset_range(ctx.smem_clause_write, op.physReg(), op.size())) {
+               *NOPs = 1;
+               break;
+            }
+         }
+
+         Definition def = instr->definitions[0];
+         if (!*NOPs && test_bitset_range(ctx.smem_clause_read_write, def.physReg(), def.size()))
+            *NOPs = 1;
+      }
+   }
+}
+
 /* TODO: we don't handle accessing VCC using the actual SGPR instead of using the alias */
 void handle_instruction_gfx6(Program *program, Block *cur_block, NOP_ctx_gfx6 &ctx,
                              aco_ptr<Instruction>& instr, std::vector<aco_ptr<Instruction>>& new_instructions)
@@ -285,24 +335,7 @@ void handle_instruction_gfx6(Program *program, Block *cur_block, NOP_ctx_gfx6 &c
          }
       }
 
-      /* break off from prevous SMEM clause if needed */
-      if (!NOPs & (ctx.smem_clause || ctx.smem_write)) {
-         /* Don't allow clauses with store instructions since the clause's
-          * instructions may use the same address. */
-         if (ctx.smem_write || instr->definitions.empty() || instr_info.is_atomic[(unsigned)instr->opcode]) {
-            NOPs = 1;
-         } else {
-            for (Operand op : instr->operands) {
-               if (!op.isConstant() && test_bitset_range(ctx.smem_clause_write, op.physReg(), op.size())) {
-                  NOPs = 1;
-                  break;
-               }
-            }
-            Definition def = instr->definitions[0];
-            if (!NOPs && test_bitset_range(ctx.smem_clause_read_write, def.physReg(), def.size()))
-               NOPs = 1;
-         }
-      }
+      handle_smem_clause_hazards(program, ctx, instr, &NOPs);
    } else if (instr->isSALU()) {
       if (instr->opcode == aco_opcode::s_setreg_b32 || instr->opcode == aco_opcode::s_setreg_imm32_b32 ||
           instr->opcode == aco_opcode::s_getreg_b32) {
@@ -399,8 +432,11 @@ void handle_instruction_gfx6(Program *program, Block *cur_block, NOP_ctx_gfx6 &c
    if ((ctx.smem_clause || ctx.smem_write) && (NOPs || instr->format != Format::SMEM)) {
       ctx.smem_clause = false;
       ctx.smem_write = false;
-      BITSET_ZERO(ctx.smem_clause_read_write);
-      BITSET_ZERO(ctx.smem_clause_write);
+
+      if (program->xnack_enabled) {
+         BITSET_ZERO(ctx.smem_clause_read_write);
+         BITSET_ZERO(ctx.smem_clause_write);
+      }
    }
 
    if (instr->format == Format::SMEM) {
@@ -409,15 +445,17 @@ void handle_instruction_gfx6(Program *program, Block *cur_block, NOP_ctx_gfx6 &c
       } else {
          ctx.smem_clause = true;
 
-         for (Operand op : instr->operands) {
-            if (!op.isConstant()) {
-               set_bitset_range(ctx.smem_clause_read_write, op.physReg(), op.size());
+         if (program->xnack_enabled) {
+            for (Operand op : instr->operands) {
+               if (!op.isConstant()) {
+                  set_bitset_range(ctx.smem_clause_read_write, op.physReg(), op.size());
+               }
             }
-         }
 
-         Definition def = instr->definitions[0];
-         set_bitset_range(ctx.smem_clause_read_write, def.physReg(), def.size());
-         set_bitset_range(ctx.smem_clause_write, def.physReg(), def.size());
+            Definition def = instr->definitions[0];
+            set_bitset_range(ctx.smem_clause_read_write, def.physReg(), def.size());
+            set_bitset_range(ctx.smem_clause_write, def.physReg(), def.size());
+         }
       }
    } else if (instr->isVALU()) {
       for (Definition def : instr->definitions) {
@@ -564,20 +602,29 @@ void handle_instruction_gfx10(Program *program, Block *cur_block, NOP_ctx_gfx10
       if (program->wave_size == 64)
          ctx.sgprs_read_by_VMEM.set(exec_hi);
    } else if (instr->isSALU() || instr->format == Format::SMEM) {
+      if (instr->opcode == aco_opcode::s_waitcnt) {
+         /* Hazard is mitigated by "s_waitcnt vmcnt(0)" */
+         uint16_t imm = static_cast<SOPP_instruction*>(instr.get())->imm;
+         unsigned vmcnt = (imm & 0xF) | ((imm & (0x3 << 14)) >> 10);
+         if (vmcnt == 0)
+            ctx.sgprs_read_by_VMEM.reset();
+      } else if (instr->opcode == aco_opcode::s_waitcnt_depctr) {
+         /* Hazard is mitigated by a s_waitcnt_depctr with a magic imm */
+         const SOPP_instruction *sopp = static_cast<const SOPP_instruction *>(instr.get());
+         if (sopp->imm == 0xffe3)
+            ctx.sgprs_read_by_VMEM.reset();
+      }
+
       /* Check if SALU writes an SGPR that was previously read by the VALU */
       if (check_written_regs(instr, ctx.sgprs_read_by_VMEM)) {
          ctx.sgprs_read_by_VMEM.reset();
 
-         /* Insert v_nop to mitigate the problem */
-         aco_ptr<VOP1_instruction> nop{create_instruction<VOP1_instruction>(aco_opcode::v_nop, Format::VOP1, 0, 0)};
-         new_instructions.emplace_back(std::move(nop));
+         /* Insert s_waitcnt_depctr instruction with magic imm to mitigate the problem */
+         aco_ptr<SOPP_instruction> depctr{create_instruction<SOPP_instruction>(aco_opcode::s_waitcnt_depctr, Format::SOPP, 0, 0)};
+         depctr->imm = 0xffe3;
+         depctr->block = -1;
+         new_instructions.emplace_back(std::move(depctr));
       }
-   } else if (instr->opcode == aco_opcode::s_waitcnt) {
-      /* Hazard is mitigated by "s_waitcnt vmcnt(0)" */
-      uint16_t imm = static_cast<SOPP_instruction*>(instr.get())->imm;
-      unsigned vmcnt = (imm & 0xF) | ((imm & (0x3 << 14)) >> 10);
-      if (vmcnt == 0)
-         ctx.sgprs_read_by_VMEM.reset();
    } else if (instr->isVALU()) {
       /* Hazard is mitigated by any VALU instruction */
       ctx.sgprs_read_by_VMEM.reset();
@@ -757,14 +804,12 @@ void mitigate_hazards(Program *program)
 
 void insert_NOPs(Program* program)
 {
-   if (program->chip_class >= GFX10) {
+   if (program->chip_class >= GFX10_3)
+      ; /* no hazards/bugs to mitigate */
+   else if (program->chip_class >= GFX10)
       mitigate_hazards<NOP_ctx_gfx10, handle_instruction_gfx10>(program);
-   } else {
-      for (Block& block : program->blocks) {
-         NOP_ctx_gfx6 ctx;
-         handle_block<NOP_ctx_gfx6, handle_instruction_gfx6>(program, ctx, block);
-      }
-   }
+   else
+      mitigate_hazards<NOP_ctx_gfx6, handle_instruction_gfx6>(program);
 }
 
 }