aco: handle when ACO adds new continue edges
author Rhys Perry <pendingchaos02@gmail.com>
Fri, 31 Jan 2020 13:56:26 +0000 (13:56 +0000)
committer Marge Bot <eric+marge@anholt.net>
Mon, 23 Mar 2020 15:55:12 +0000 (15:55 +0000)
Usually a loop ends with a uniform continue. If it doesn't and we end up
adding our own continue edges (because of continue_or_break or divergent
breaks at the end), we have to add extra operands to the loop header phis.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3658>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_ir.h

index 3ffa6153d4629d34e312912020b0be2086a74024..52db7f4a6b3dc3c271345c094962f61695568878 100644 (file)
@@ -8235,7 +8235,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
 
    std::vector<unsigned>& preds = logical ? ctx->block->logical_preds : ctx->block->linear_preds;
    unsigned num_operands = 0;
-   Operand operands[std::max(exec_list_length(&instr->srcs), (unsigned)preds.size())];
+   Operand operands[std::max(exec_list_length(&instr->srcs), (unsigned)preds.size()) + 1];
    unsigned num_defined = 0;
    unsigned cur_pred_idx = 0;
    for (std::pair<unsigned, nir_ssa_def *> src : phi_src) {
@@ -8266,6 +8266,17 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
    while (cur_pred_idx++ < preds.size())
       operands[num_operands++] = Operand(dst.regClass());
 
+   /* If the loop ends with a break, still add a linear continue edge in case
+    * that break is divergent or continue_or_break is used. We'll either remove
+    * this operand later in visit_loop() if it's not necessary or replace the
+    * undef with something correct. */
+   if (!logical && ctx->block->kind & block_kind_loop_header) {
+      nir_loop *loop = nir_cf_node_as_loop(instr->instr.block->cf_node.parent);
+      nir_block *last = nir_loop_last_block(loop);
+      if (last->successors[0] != instr->instr.block)
+         operands[num_operands++] = Operand(RegClass());
+   }
+
    if (num_defined == 0) {
       Builder bld(ctx->program, ctx->block);
       if (dst.regClass() == s1) {
@@ -8487,6 +8498,51 @@ void visit_block(isel_context *ctx, nir_block *block)
 
 
 
+static Operand create_continue_phis(isel_context *ctx, unsigned first, unsigned last,
+                                    aco_ptr<Instruction>& header_phi, Operand *vals)
+{ /* Computes the operand that visit_loop() stores into the last operand
+   * slot of a loop-header linear phi, i.e. the value for the extra linear
+   * continue edge. Walks the blocks first+1..last in index order, records
+   * one operand per block in vals, and inserts a p_linear_phi at the start
+   * of any block whose linear predecessors carry different operands.
+   *
+   * first, last: inclusive range of indices into ctx->program->blocks;
+   *              visit_loop() passes the loop header index and the index
+   *              of the current block.
+   * header_phi:  the linear phi in the header whose operand is computed.
+   * vals:        caller-provided scratch array indexed by
+   *              (block index - first); it needs at least
+   *              (last - first + 1) entries and is overwritten.
+   *
+   * Returns the operand recorded for block `last`. */
+   vals[0] = Operand(header_phi->definitions[0].getTemp()); /* block `first` starts from the phi's own result */
+   RegClass rc = vals[0].regClass();
+
+   unsigned loop_nest_depth = ctx->program->blocks[first].loop_nest_depth;
+
+   unsigned next_pred = 1; /* next header_phi operand to consume; operand 0 is
+                            * skipped (presumably the loop entry edge -- confirm) */
+
+   for (unsigned idx = first + 1; idx <= last; idx++) {
+      Block& block = ctx->program->blocks[idx];
+      if (block.loop_nest_depth != loop_nest_depth) { /* other nesting depth: carry over the previous index's operand */
+         vals[idx - first] = vals[idx - 1 - first];
+         continue;
+      }
+
+      if (block.kind & block_kind_continue) { /* continue block: take the next unconsumed header_phi operand */
+         vals[idx - first] = header_phi->operands[next_pred];
+         next_pred++; /* NOTE(review): assumes header_phi operands are ordered like the continue blocks -- confirm */
+         continue;
+      }
+
+      bool all_same = true; /* stays true when there are fewer than two linear predecessors */
+      for (unsigned i = 1; all_same && (i < block.linear_preds.size()); i++)
+         all_same = vals[block.linear_preds[i] - first] == vals[block.linear_preds[0] - first];
+
+      Operand val;
+      if (all_same) { /* NOTE(review): reads linear_preds[0] unconditionally -- assumes at least one
+                       * linear predecessor, with index in [first, idx); confirm */
+         val = vals[block.linear_preds[0] - first];
+      } else { /* predecessors disagree: merge them with a new linear phi */
+         aco_ptr<Instruction> phi(create_instruction<Pseudo_instruction>(
+            aco_opcode::p_linear_phi, Format::PSEUDO, block.linear_preds.size(), 1));
+         for (unsigned i = 0; i < block.linear_preds.size(); i++)
+            phi->operands[i] = vals[block.linear_preds[i] - first];
+         val = Operand(Temp(ctx->program->allocateId(), rc)); /* fresh temporary, same register class as header_phi's result */
+         phi->definitions[0] = Definition(val.getTemp());
+         block.instructions.emplace(block.instructions.begin(), std::move(phi)); /* inserted as the block's first instruction */
+      }
+      vals[idx - first] = val;
+   }
+
+   return vals[last - first];
+}
+
 static void visit_loop(isel_context *ctx, nir_loop *loop)
 {
    //TODO: we might want to wrap the loop around a branch if exec_potentially_empty=true
@@ -8570,6 +8626,24 @@ static void visit_loop(isel_context *ctx, nir_loop *loop)
       }
    }
 
+   /* Fixup linear phis in loop header from expecting a continue. Both this fixup
+    * and the previous one shouldn't both happen at once because a break in the
+    * merge block would get CSE'd */
+   if (nir_loop_last_block(loop)->successors[0] != nir_loop_first_block(loop)) {
+      unsigned num_vals = ctx->cf_info.has_branch ? 1 : (ctx->block->index - loop_header_idx + 1);
+      Operand vals[num_vals];
+      for (aco_ptr<Instruction>& instr : ctx->program->blocks[loop_header_idx].instructions) {
+         if (instr->opcode == aco_opcode::p_linear_phi) {
+            if (ctx->cf_info.has_branch)
+               instr->operands.pop_back();
+            else
+               instr->operands.back() = create_continue_phis(ctx, loop_header_idx, ctx->block->index, instr, vals);
+         } else if (!is_phi(instr)) {
+            break;
+         }
+      }
+   }
+
    ctx->cf_info.has_branch = false;
 
    // TODO: if the loop has not a single exit, we must add one °°
index 5bbe337fe1704527c43d4aaefd5e9021cb98e25c..1eae6c5d0ccb8f2f137ae13a2905de67894f50e6 100644 (file)
@@ -521,6 +521,23 @@ public:
       return isFirstKill() && !isLateKill();
    }
 
+   constexpr bool operator == (Operand other) const noexcept
+   { /* Equality of two operands. They compare equal only if they have the
+      * same size, the same fixed-ness and kill-before-def status, the same
+      * physical register when fixed, and the same payload for their kind
+      * (literal value, constant register, undefined register class, or
+      * temporary). */
+      if (other.size() != size())
+         return false;
+      if (isFixed() != other.isFixed() || isKillBeforeDef() != other.isKillBeforeDef())
+         return false;
+      if (isFixed() && other.isFixed() && physReg() != other.physReg())
+         return false;
+      if (isLiteral()) /* tested before isConstant(); presumably a literal also counts as constant -- confirm */
+         return other.isLiteral() && other.constantValue() == constantValue();
+      else if (isConstant()) /* non-literal constants are compared by register, not by value */
+         return other.isConstant() && other.physReg() == physReg();
+      else if (isUndefined())
+         return other.isUndefined() && other.regClass() == regClass();
+      else /* remaining case is treated as a temporary */
+         return other.isTemp() && other.getTemp() == getTemp();
+   }
 private:
    union {
       uint32_t i;