aco: add a late kill flag
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 21 Feb 2020 15:46:39 +0000 (15:46 +0000)
committerMarge Bot <eric+marge@anholt.net>
Mon, 16 Mar 2020 16:09:02 +0000 (16:09 +0000)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/3914>

src/amd/compiler/aco_ir.h
src/amd/compiler/aco_live_var_analysis.cpp
src/amd/compiler/aco_print_ir.cpp
src/amd/compiler/aco_register_allocation.cpp
src/amd/compiler/aco_validate.cpp

index 3a2a2186afe405e4944cb2bd82a313b57a63bd8e..92511975a6965adac8d94370b12bee703afdac81 100644 (file)
@@ -298,7 +298,8 @@ class Operand final
 public:
    constexpr Operand()
       : reg_(PhysReg{128}), isTemp_(false), isFixed_(true), isConstant_(false),
-        isKill_(false), isUndef_(true), isFirstKill_(false), is64BitConst_(false) {}
+        isKill_(false), isUndef_(true), isFirstKill_(false), is64BitConst_(false),
+        isLateKill_(false) {}
 
    explicit Operand(Temp r) noexcept
    {
@@ -471,6 +472,19 @@ public:
       return isConstant() && constantValue() == cmp;
    }
 
+   /* Indicates that the killed operand's live range intersects with the
+    * instruction's definitions. Unlike isKill() and isFirstKill(), this is
+    * not set by liveness analysis. */
+   constexpr void setLateKill(bool flag) noexcept
+   {
+      isLateKill_ = flag;
+   }
+
+   constexpr bool isLateKill() const noexcept
+   {
+      return isLateKill_;
+   }
+
    constexpr void setKill(bool flag) noexcept
    {
       isKill_ = flag;
@@ -497,6 +511,16 @@ public:
       return isFirstKill_;
    }
 
+   constexpr bool isKillBeforeDef() const noexcept
+   {
+      return isKill() && !isLateKill();
+   }
+
+   constexpr bool isFirstKillBeforeDef() const noexcept
+   {
+      return isFirstKill() && !isLateKill();
+   }
+
 private:
    union {
       uint32_t i;
@@ -513,6 +537,7 @@ private:
          uint8_t isUndef_:1;
          uint8_t isFirstKill_:1;
          uint8_t is64BitConst_:1;
+         uint8_t isLateKill_:1;
       };
       /* can't initialize bit-fields in c++11, so work around using a union */
       uint8_t control_ = 0;
index 3c8e4472db12a49b0842ddbc167e363cd3a107a8..d4383cf588745abb4fdca4a2eb2461417879fd8d 100644 (file)
@@ -57,12 +57,19 @@ RegisterDemand get_live_changes(aco_ptr<Instruction>& instr)
 RegisterDemand get_temp_registers(aco_ptr<Instruction>& instr)
 {
    RegisterDemand temp_registers;
+
    for (Definition def : instr->definitions) {
       if (!def.isTemp())
          continue;
       if (def.isKill())
          temp_registers += def.getTemp();
    }
+
+   for (Operand op : instr->operands) {
+      if (op.isTemp() && op.isLateKill() && op.isFirstKill())
+         temp_registers += op.getTemp();
+   }
+
    return temp_registers;
 }
 
@@ -139,6 +146,7 @@ void process_live_temps_per_block(Program *program, live& lives, Block* block,
             new_demand -= temp;
             definition.setKill(false);
          } else {
+            register_demand[idx] += temp;
             definition.setKill(true);
          }
 
@@ -175,6 +183,8 @@ void process_live_temps_per_block(Program *program, live& lives, Block* block,
                      insn->operands[j].setKill(true);
                   }
                }
+               if (operand.isLateKill())
+                  register_demand[idx] += temp;
                new_demand += temp;
             }
 
@@ -183,8 +193,6 @@ void process_live_temps_per_block(Program *program, live& lives, Block* block,
          }
       }
 
-      register_demand[idx] += get_temp_registers(block->instructions[idx]);
-
       block->register_demand.update(register_demand[idx]);
    }
 
index b8dc42009e2e554e28d7b006c4ea9063dcf4883c..8f89236ff90de7f7aeebaae17ff6c741d9e475c6 100644 (file)
@@ -131,6 +131,9 @@ static void print_operand(const Operand *operand, FILE *output)
       print_reg_class(operand->regClass(), output);
       fprintf(output, "undef");
    } else {
+      if (operand->isLateKill())
+         fprintf(output, "(latekill)");
+
       fprintf(output, "%%%d", operand->tempId());
 
       if (operand->isFixed())
index e8b1069a50997cb480eddb63be2609892581df18..c3726acf1b2846eb5b2bbf02aa06e48a7f880756 100644 (file)
@@ -219,7 +219,7 @@ void update_renames(ra_ctx& ctx, RegisterFile& reg_file,
          if (!op.isTemp())
             continue;
          if (op.tempId() == copy.first.tempId()) {
-            bool omit_renaming = instr->opcode == aco_opcode::p_create_vector && !op.isKill();
+            bool omit_renaming = instr->opcode == aco_opcode::p_create_vector && !op.isKillBeforeDef();
             for (std::pair<Operand, Definition>& pc : parallelcopies) {
                PhysReg def_reg = pc.second.physReg();
                omit_renaming &= def_reg > copy.first.physReg() ?
@@ -336,8 +336,8 @@ bool get_regs_for_copies(ra_ctx& ctx,
 
       /* check if this is a dead operand, then we can re-use the space from the definition */
       bool is_dead_operand = false;
-      for (unsigned i = 0; !is_phi(instr) && !is_dead_operand && i < instr->operands.size(); i++) {
-         if (instr->operands[i].isTemp() && instr->operands[i].isKill() && instr->operands[i].tempId() == id)
+      for (unsigned i = 0; !is_phi(instr) && !is_dead_operand && (i < instr->operands.size()); i++) {
+         if (instr->operands[i].isTemp() && instr->operands[i].isKillBeforeDef() && instr->operands[i].tempId() == id)
             is_dead_operand = true;
       }
 
@@ -409,7 +409,7 @@ bool get_regs_for_copies(ra_ctx& ctx,
             }
             bool is_kill = false;
             for (const Operand& op : instr->operands) {
-               if (op.isTemp() && op.isKill() && op.tempId() == reg_file[j]) {
+               if (op.isTemp() && op.isKillBeforeDef() && op.tempId() == reg_file[j]) {
                   is_kill = true;
                   break;
                }
@@ -488,7 +488,7 @@ std::pair<PhysReg, bool> get_reg_impl(ra_ctx& ctx,
    unsigned killed_ops = 0;
    for (unsigned j = 0; !is_phi(instr) && j < instr->operands.size(); j++) {
       if (instr->operands[j].isTemp() &&
-          instr->operands[j].isFirstKill() &&
+          instr->operands[j].isFirstKillBeforeDef() &&
           instr->operands[j].physReg() >= lb &&
           instr->operands[j].physReg() < ub) {
          assert(instr->operands[j].isFixed());
@@ -573,7 +573,7 @@ std::pair<PhysReg, bool> get_reg_impl(ra_ctx& ctx,
    if (num_moves == 0xFF) {
       /* remove killed operands from reg_file once again */
       for (unsigned i = 0; !is_phi(instr) && i < instr->operands.size(); i++) {
-         if (instr->operands[i].isTemp() && instr->operands[i].isFirstKill())
+         if (instr->operands[i].isTemp() && instr->operands[i].isFirstKillBeforeDef())
             reg_file.clear(instr->operands[i]);
       }
       for (unsigned i = 0; i < instr->definitions.size(); i++) {
@@ -597,7 +597,7 @@ std::pair<PhysReg, bool> get_reg_impl(ra_ctx& ctx,
    if (instr->opcode == aco_opcode::p_create_vector) {
       /* move killed operands which aren't yet at the correct position */
       for (unsigned i = 0, offset = 0; i < instr->operands.size(); offset += instr->operands[i].size(), i++) {
-         if (instr->operands[i].isTemp() && instr->operands[i].isFirstKill() &&
+         if (instr->operands[i].isTemp() && instr->operands[i].isFirstKillBeforeDef() &&
              instr->operands[i].getTemp().type() == rc.type()) {
 
             if (instr->operands[i].physReg() != best_pos + offset) {
@@ -622,7 +622,7 @@ std::pair<PhysReg, bool> get_reg_impl(ra_ctx& ctx,
       /* remove killed operands from reg_file once again */
       if (!is_phi(instr)) {
          for (const Operand& op : instr->operands) {
-            if (op.isTemp() && op.isFirstKill())
+            if (op.isTemp() && op.isFirstKillBeforeDef())
                reg_file.clear(op);
          }
       }
@@ -646,7 +646,7 @@ std::pair<PhysReg, bool> get_reg_impl(ra_ctx& ctx,
       if (!instr->operands[i].isTemp() || !instr->operands[i].isFixed())
          continue;
       assert(!instr->operands[i].isUndefined());
-      if (instr->operands[i].isFirstKill())
+      if (instr->operands[i].isFirstKillBeforeDef())
          reg_file.clear(instr->operands[i]);
    }
    for (unsigned i = 0; i < instr->definitions.size(); i++) {
@@ -695,11 +695,12 @@ PhysReg get_reg(ra_ctx& ctx,
    if (res.second)
       return res.first;
 
+   /* try using more registers */
+
    /* We should only fail here because keeping under the limit would require
     * too many moves. */
    assert(reg_file.count_zero(PhysReg{lb}, ub-lb) >= size);
 
-   /* try using more registers */
    uint16_t max_addressible_sgpr = ctx.program->sgpr_limit;
    uint16_t max_addressible_vgpr = ctx.program->vgpr_limit;
    if (rc.type() == RegType::vgpr && ctx.program->max_reg_demand.vgpr < max_addressible_vgpr) {
@@ -767,7 +768,7 @@ PhysReg get_reg_create_vector(ra_ctx& ctx,
    /* test for each operand which definition placement causes the least shuffle instructions */
    for (unsigned i = 0, offset = 0; i < instr->operands.size(); offset += instr->operands[i].size(), i++) {
       // TODO: think about, if we can alias live operands on the same register
-      if (!instr->operands[i].isTemp() || !instr->operands[i].isKill() || instr->operands[i].getTemp().type() != rc.type())
+      if (!instr->operands[i].isTemp() || !instr->operands[i].isKillBeforeDef() || instr->operands[i].getTemp().type() != rc.type())
          continue;
 
       if (offset > instr->operands[i].physReg())
@@ -836,7 +837,7 @@ PhysReg get_reg_create_vector(ra_ctx& ctx,
 
    /* move killed operands which aren't yet at the correct position */
    for (unsigned i = 0, offset = 0; i < instr->operands.size(); offset += instr->operands[i].size(), i++) {
-      if (instr->operands[i].isTemp() && instr->operands[i].isFirstKill() && instr->operands[i].getTemp().type() == rc.type()) {
+      if (instr->operands[i].isTemp() && instr->operands[i].isFirstKillBeforeDef() && instr->operands[i].getTemp().type() == rc.type()) {
          if (instr->operands[i].physReg() != best_pos + offset)
             vars.emplace(instr->operands[i].size(), instr->operands[i].tempId());
          else
@@ -1187,7 +1188,7 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
                /* try to coalesce phi affinities with parallelcopies */
                if (!def.isFixed() && instr->opcode == aco_opcode::p_parallelcopy) {
                   Operand op = instr->operands[i];
-                  if (op.isTemp() && op.isFirstKill() && def.regClass() == op.regClass()) {
+                  if (op.isTemp() && op.isFirstKillBeforeDef() && def.regClass() == op.regClass()) {
                      phi_ressources[it->second].emplace_back(op.getTemp());
                      temp_to_phi_ressources[op.tempId()] = it->second;
                   }
@@ -1409,7 +1410,7 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
                if (phi->opcode == aco_opcode::p_phi) {
                   if (phi->operands[idx].isTemp() &&
                       phi->operands[idx].getTemp().type() == RegType::sgpr &&
-                      phi->operands[idx].isFirstKill()) {
+                      phi->operands[idx].isFirstKillBeforeDef()) {
                      Temp phi_op = read_variable(phi->operands[idx].getTemp(), block.index);
                      PhysReg reg = ctx.assignments[phi_op.id()].first;
                      assert(register_file[reg] == phi_op.id());
@@ -1512,14 +1513,14 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
          }
          /* remove dead vars from register file */
          for (const Operand& op : instr->operands) {
-            if (op.isTemp() && op.isFirstKill())
+            if (op.isTemp() && op.isFirstKillBeforeDef())
                register_file.clear(op);
          }
 
          /* try to optimize v_mad_f32 -> v_mac_f32 */
          if (instr->opcode == aco_opcode::v_mad_f32 &&
              instr->operands[2].isTemp() &&
-             instr->operands[2].isKill() &&
+             instr->operands[2].isKillBeforeDef() &&
              instr->operands[2].getTemp().type() == RegType::vgpr &&
              instr->operands[1].isTemp() &&
              instr->operands[1].getTemp().type() == RegType::vgpr) { /* TODO: swap src0 and src1 in this case */
@@ -1573,7 +1574,7 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
 
                /* re-enable the killed operands, so that we don't move the blocking var there */
                for (const Operand& op : instr->operands) {
-                  if (op.isTemp() && op.isFirstKill())
+                  if (op.isTemp() && op.isFirstKillBeforeDef())
                      register_file.fill(op.physReg(), op.size(), 0xFFFF);
                }
 
@@ -1581,7 +1582,7 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
                PhysReg reg = get_reg(ctx, register_file, rc, parallelcopy, instr);
                /* once again, disable killed operands */
                for (const Operand& op : instr->operands) {
-                  if (op.isTemp() && op.isFirstKill())
+                  if (op.isTemp() && op.isFirstKillBeforeDef())
                      register_file.clear(op);
                }
                for (unsigned k = 0; k < i; k++) {
@@ -1629,7 +1630,7 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
                definition.setFixed(reg);
             } else if (instr->opcode == aco_opcode::p_wqm) {
                PhysReg reg;
-               if (instr->operands[0].isKill() && instr->operands[0].getTemp().type() == definition.getTemp().type()) {
+               if (instr->operands[0].isKillBeforeDef() && instr->operands[0].getTemp().type() == definition.getTemp().type()) {
                   reg = instr->operands[0].physReg();
                   assert(register_file[reg.reg] == 0);
                } else {
@@ -1638,7 +1639,7 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
                definition.setFixed(reg);
             } else if (instr->opcode == aco_opcode::p_extract_vector) {
                PhysReg reg;
-               if (instr->operands[0].isKill() &&
+               if (instr->operands[0].isKillBeforeDef() &&
                    instr->operands[0].getTemp().type() == definition.getTemp().type()) {
                   reg = instr->operands[0].physReg();
                   reg.reg += definition.size() * instr->operands[1].constantValue();
@@ -1711,11 +1712,15 @@ void register_allocation(Program *program, std::vector<std::set<Temp>> live_out_
 
          handle_pseudo(ctx, register_file, instr.get());
 
-         /* kill definitions */
+         /* kill definitions and late-kill operands */
          for (const Definition& def : instr->definitions) {
              if (def.isTemp() && def.isKill())
                 register_file.clear(def);
          }
+         for (const Operand& op : instr->operands) {
+            if (op.isTemp() && op.isFirstKill() && op.isLateKill())
+               register_file.clear(op);
+         }
 
          /* emit parallelcopy */
          if (!parallelcopy.empty()) {
index 0e9b6c20ab09fe7f624b991025fb6fe848d76dfb..e967f0ca9e7d41c907b1c02c80d19d025c0cb5ae 100644 (file)
@@ -497,7 +497,7 @@ bool validate_ra(Program *program, const struct radv_nir_compiler_options *optio
             for (const Operand& op : instr->operands) {
                if (!op.isTemp())
                   continue;
-               if (op.isFirstKill()) {
+               if (op.isFirstKillBeforeDef()) {
                   for (unsigned j = 0; j < op.getTemp().size(); j++)
                      regs[op.physReg() + j] = 0;
                }
@@ -525,6 +525,17 @@ bool validate_ra(Program *program, const struct radv_nir_compiler_options *optio
                   regs[def.physReg() + j] = 0;
             }
          }
+
+         if (instr->opcode != aco_opcode::p_phi && instr->opcode != aco_opcode::p_linear_phi) {
+            for (const Operand& op : instr->operands) {
+               if (!op.isTemp())
+                  continue;
+               if (op.isLateKill() && op.isFirstKill()) {
+                  for (unsigned j = 0; j < op.getTemp().size(); j++)
+                     regs[op.physReg() + j] = 0;
+               }
+            }
+         }
       }
    }