aco: fix printing ASM on GFX6-7 if clrxdisasm is not found
diff --git a/src/amd/compiler/aco_insert_waitcnt.cpp b/src/amd/compiler/aco_insert_waitcnt.cpp
index 254eb97d1510c2038e9155115c6557b61cf61ccb..196b45fbb75c071c43826e6cd7f9b30ccbee5438 100644
--- a/src/amd/compiler/aco_insert_waitcnt.cpp
+++ b/src/amd/compiler/aco_insert_waitcnt.cpp
@@ -25,6 +25,7 @@
 #include <algorithm>
 #include <map>
 #include <stack>
+#include <math.h>
 
 #include "aco_ir.h"
 #include "vulkan/radv_shader.h"
@@ -66,6 +67,7 @@ enum wait_event : uint16_t {
    event_gds_gpr_lock = 1 << 9,
    event_vmem_gpr_lock = 1 << 10,
    event_sendmsg = 1 << 11,
+   num_events = 12,
 };
 
 enum counter_type : uint8_t {
@@ -73,6 +75,7 @@ enum counter_type : uint8_t {
    counter_lgkm = 1 << 1,
    counter_vm = 1 << 2,
    counter_vs = 1 << 3,
+   num_counters = 4,
 };
 
 static const uint16_t exp_events = event_exp_pos | event_exp_param | event_exp_mrt_null | event_gds_gpr_lock | event_vmem_gpr_lock;
@@ -105,6 +108,21 @@ uint8_t get_counters_for_event(wait_event ev)
    }
 }
 
+uint16_t get_events_for_counter(counter_type ctr)
+{
+   switch (ctr) {
+   case counter_exp:
+      return exp_events;
+   case counter_lgkm:
+      return lgkm_events;
+   case counter_vm:
+      return vm_events;
+   case counter_vs:
+      return vs_events;
+   }
+   return 0;
+}
+
 struct wait_imm {
    static const uint8_t unset_counter = 0xff;
 
@@ -182,20 +200,27 @@ struct wait_entry {
    uint8_t counters; /* use counter_type notion */
    bool wait_on_read:1;
    bool logical:1;
+   bool has_vmem_nosampler:1;
+   bool has_vmem_sampler:1;
 
    wait_entry(wait_event event, wait_imm imm, bool logical, bool wait_on_read)
            : imm(imm), events(event), counters(get_counters_for_event(event)),
-             wait_on_read(wait_on_read), logical(logical) {}
+             wait_on_read(wait_on_read), logical(logical),
+             has_vmem_nosampler(false), has_vmem_sampler(false) {}
 
    bool join(const wait_entry& other)
    {
       bool changed = (other.events & ~events) ||
                      (other.counters & ~counters) ||
-                     (other.wait_on_read && !wait_on_read);
+                     (other.wait_on_read && !wait_on_read) ||
+                     (other.has_vmem_nosampler && !has_vmem_nosampler) ||
+                     (other.has_vmem_sampler && !has_vmem_sampler);
       events |= other.events;
       counters |= other.counters;
       changed |= imm.combine(other.imm);
-      wait_on_read = wait_on_read || other.wait_on_read;
+      wait_on_read |= other.wait_on_read;
+      has_vmem_nosampler |= other.has_vmem_nosampler;
+      has_vmem_sampler |= other.has_vmem_sampler;
       assert(logical == other.logical);
       return changed;
    }
@@ -212,6 +237,8 @@ struct wait_entry {
       if (counter == counter_vm) {
          imm.vm = wait_imm::unset_counter;
          events &= ~event_vmem;
+         has_vmem_nosampler = false;
+         has_vmem_sampler = false;
       }
 
       if (counter == counter_exp) {
@@ -251,6 +278,13 @@ struct wait_ctx {
 
    std::map<PhysReg,wait_entry> gpr_map;
 
+   /* used for vmem/smem scores */
+   bool collect_statistics;
+   Instruction *gen_instr;
+   std::map<Instruction *, unsigned> unwaited_instrs[num_counters];
+   std::map<PhysReg,std::set<Instruction *>> reg_instrs[num_counters];
+   std::vector<unsigned> wait_distances[num_events];
+
    wait_ctx() {}
    wait_ctx(Program *program_)
            : program(program_),
@@ -298,8 +332,53 @@ struct wait_ctx {
          barrier_events[i] |= other->barrier_events[i];
       }
 
+      /* these are used for statistics, so don't update "changed" */
+      for (unsigned i = 0; i < num_counters; i++) {
+         for (std::pair<Instruction *, unsigned> instr : other->unwaited_instrs[i]) {
+            auto pos = unwaited_instrs[i].find(instr.first);
+            if (pos == unwaited_instrs[i].end())
+               unwaited_instrs[i].insert(instr);
+            else
+               pos->second = std::min(pos->second, instr.second);
+         }
+         /* don't use a foreach loop to avoid copies */
+         for (auto it = other->reg_instrs[i].begin(); it != other->reg_instrs[i].end(); ++it)
+            reg_instrs[i][it->first].insert(it->second.begin(), it->second.end());
+      }
+
       return changed;
    }
+
+   void wait_and_remove_from_entry(PhysReg reg, wait_entry& entry, counter_type counter) {
+      if (collect_statistics && (entry.counters & counter)) {
+         unsigned counter_idx = ffs(counter) - 1;
+         for (Instruction *instr : reg_instrs[counter_idx][reg]) {
+            auto pos = unwaited_instrs[counter_idx].find(instr);
+            if (pos == unwaited_instrs[counter_idx].end())
+               continue;
+
+            unsigned distance = pos->second;
+            unsigned events = entry.events & get_events_for_counter(counter);
+            while (events) {
+               unsigned event_idx = u_bit_scan(&events);
+               wait_distances[event_idx].push_back(distance);
+            }
+
+            unwaited_instrs[counter_idx].erase(instr);
+         }
+         reg_instrs[counter_idx][reg].clear();
+      }
+
+      entry.remove_counter(counter);
+   }
+
+   void advance_unwaited_instrs()
+   {
+      for (unsigned i = 0; i < num_counters; i++) {
+         for (auto it = unwaited_instrs[i].begin(); it != unwaited_instrs[i].end(); ++it)
+            it->second++;
+      }
+   }
 };
 
 wait_imm check_instr(Instruction* instr, wait_ctx& ctx)
@@ -332,22 +411,16 @@ wait_imm check_instr(Instruction* instr, wait_ctx& ctx)
             continue;
 
          /* Vector Memory reads and writes return in the order they were issued */
-         if (instr->isVMEM() && ((it->second.events & vm_events) == event_vmem)) {
-            it->second.remove_counter(counter_vm);
-            if (!it->second.counters)
-               it = ctx.gpr_map.erase(it);
+         bool has_sampler = instr->format == Format::MIMG && !instr->operands[1].isUndefined() && instr->operands[1].regClass() == s4;
+         if (instr->isVMEM() && ((it->second.events & vm_events) == event_vmem) &&
+             it->second.has_vmem_nosampler == !has_sampler && it->second.has_vmem_sampler == has_sampler)
             continue;
-         }
 
          /* LDS reads and writes return in the order they were issued. same for GDS */
          if (instr->format == Format::DS) {
             bool gds = static_cast<DS_instruction*>(instr)->gds;
-            if ((it->second.events & lgkm_events) == (gds ? event_gds : event_lds)) {
-               it->second.remove_counter(counter_lgkm);
-               if (!it->second.counters)
-                  it = ctx.gpr_map.erase(it);
+            if ((it->second.events & lgkm_events) == (gds ? event_gds : event_lds))
                continue;
-            }
          }
 
          wait.combine(it->second.imm);
@@ -387,7 +460,7 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
       imm.lgkm = 0;
    }
 
-   if (ctx.chip_class >= GFX10) {
+   if (ctx.chip_class >= GFX10 && instr->format == Format::SMEM) {
       /* GFX10: A store followed by a load at the same address causes a problem because
        * the load doesn't load the correct values unless we wait for the store first.
        * This is NOT mitigated by an s_nop.
@@ -403,17 +476,12 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
    }
 
    if (instr->format == Format::PSEUDO_BARRIER) {
-      uint32_t workgroup_size = UINT32_MAX;
-      if (ctx.program->stage & sw_cs) {
-         unsigned* bsize = ctx.program->info->cs.block_size;
-         workgroup_size = bsize[0] * bsize[1] * bsize[2];
-      }
       switch (instr->opcode) {
       case aco_opcode::p_memory_barrier_common:
          imm.combine(ctx.barrier_imm[ffs(barrier_atomic) - 1]);
          imm.combine(ctx.barrier_imm[ffs(barrier_buffer) - 1]);
          imm.combine(ctx.barrier_imm[ffs(barrier_image) - 1]);
-         if (workgroup_size > ctx.program->wave_size)
+         if (ctx.program->workgroup_size > ctx.program->wave_size)
             imm.combine(ctx.barrier_imm[ffs(barrier_shared) - 1]);
          break;
       case aco_opcode::p_memory_barrier_atomic:
@@ -426,7 +494,7 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
          imm.combine(ctx.barrier_imm[ffs(barrier_image) - 1]);
          break;
       case aco_opcode::p_memory_barrier_shared:
-         if (workgroup_size > ctx.program->wave_size)
+         if (ctx.program->workgroup_size > ctx.program->wave_size)
             imm.combine(ctx.barrier_imm[ffs(barrier_shared) - 1]);
          break;
       case aco_opcode::p_memory_barrier_gs_data:
@@ -482,13 +550,13 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
       while (it != ctx.gpr_map.end())
       {
          if (imm.exp != wait_imm::unset_counter && imm.exp <= it->second.imm.exp)
-            it->second.remove_counter(counter_exp);
+            ctx.wait_and_remove_from_entry(it->first, it->second, counter_exp);
          if (imm.vm != wait_imm::unset_counter && imm.vm <= it->second.imm.vm)
-            it->second.remove_counter(counter_vm);
+            ctx.wait_and_remove_from_entry(it->first, it->second, counter_vm);
          if (imm.lgkm != wait_imm::unset_counter && imm.lgkm <= it->second.imm.lgkm)
-            it->second.remove_counter(counter_lgkm);
-         if (imm.lgkm != wait_imm::unset_counter && imm.vs <= it->second.imm.vs)
-            it->second.remove_counter(counter_vs);
+            ctx.wait_and_remove_from_entry(it->first, it->second, counter_lgkm);
+         if (imm.vs != wait_imm::unset_counter && imm.vs <= it->second.imm.vs)
+            ctx.wait_and_remove_from_entry(it->first, it->second, counter_vs);
          if (!it->second.counters)
             it = ctx.gpr_map.erase(it);
          else
@@ -604,7 +672,8 @@ void update_counters_for_flat_load(wait_ctx& ctx, barrier_interaction barrier=ba
    ctx.pending_flat_vm = true;
 }
 
-void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, bool wait_on_read)
+void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, bool wait_on_read,
+                       bool has_sampler=false)
 {
    uint16_t counters = get_counters_for_event(event);
    wait_imm imm;
@@ -618,23 +687,35 @@ void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event
       imm.vs = 0;
 
    wait_entry new_entry(event, imm, !rc.is_linear(), wait_on_read);
+   new_entry.has_vmem_nosampler = (event & event_vmem) && !has_sampler;
+   new_entry.has_vmem_sampler = (event & event_vmem) && has_sampler;
 
    for (unsigned i = 0; i < rc.size(); i++) {
-      auto it = ctx.gpr_map.emplace(PhysReg{reg.reg+i}, new_entry);
+      auto it = ctx.gpr_map.emplace(PhysReg{reg.reg()+i}, new_entry);
       if (!it.second)
          it.first->second.join(new_entry);
    }
+
+   if (ctx.collect_statistics) {
+      unsigned counters_todo = counters;
+      while (counters_todo) {
+         unsigned i = u_bit_scan(&counters_todo);
+         ctx.unwaited_instrs[i].insert(std::make_pair(ctx.gen_instr, 0u));
+         for (unsigned j = 0; j < rc.size(); j++)
+            ctx.reg_instrs[i][PhysReg{reg.reg()+j}].insert(ctx.gen_instr);
+      }
+   }
 }
 
-void insert_wait_entry(wait_ctx& ctx, Operand op, wait_event event)
+void insert_wait_entry(wait_ctx& ctx, Operand op, wait_event event, bool has_sampler=false)
 {
    if (!op.isConstant() && !op.isUndefined())
-      insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false);
+      insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false, has_sampler);
 }
 
-void insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event)
+void insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event, bool has_sampler=false)
 {
-   insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true);
+   insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true, has_sampler);
 }
 
 void gen(Instruction* instr, wait_ctx& ctx)
@@ -711,8 +792,10 @@ void gen(Instruction* instr, wait_ctx& ctx)
       wait_event ev = !instr->definitions.empty() || ctx.chip_class < GFX10 ? event_vmem : event_vmem_store;
       update_counters(ctx, ev, get_barrier_interaction(instr));
 
+      bool has_sampler = instr->format == Format::MIMG && !instr->operands[1].isUndefined() && instr->operands[1].regClass() == s4;
+
       if (!instr->definitions.empty())
-         insert_wait_entry(ctx, instr->definitions[0], ev);
+         insert_wait_entry(ctx, instr->definitions[0], ev, has_sampler);
 
       if (ctx.chip_class == GFX6 &&
           instr->format != Format::MIMG &&
@@ -763,11 +846,15 @@ void handle_block(Program *program, Block& block, wait_ctx& ctx)
    std::vector<aco_ptr<Instruction>> new_instructions;
 
    wait_imm queued_imm;
+
+   ctx.collect_statistics = program->collect_statistics;
+
    for (aco_ptr<Instruction>& instr : block.instructions) {
       bool is_wait = !parse_wait_instr(ctx, instr.get()).empty();
 
       queued_imm.combine(kill(instr.get(), ctx));
 
+      ctx.gen_instr = instr.get();
       gen(instr.get(), ctx);
 
       if (instr->format != Format::PSEUDO_BARRIER && !is_wait) {
@@ -776,6 +863,9 @@ void handle_block(Program *program, Block& block, wait_ctx& ctx)
             queued_imm = wait_imm();
          }
          new_instructions.emplace_back(std::move(instr));
+
+         if (ctx.collect_statistics)
+            ctx.advance_unwaited_instrs();
       }
    }
 
@@ -787,14 +877,58 @@ void handle_block(Program *program, Block& block, wait_ctx& ctx)
 
 } /* end namespace */
 
+static uint32_t calculate_score(std::vector<wait_ctx> &ctx_vec, uint32_t event_mask)
+{
+   double result = 0.0;
+   unsigned num_waits = 0;
+   while (event_mask) {
+      unsigned event_index = u_bit_scan(&event_mask);
+      for (const wait_ctx &ctx : ctx_vec) {
+         for (unsigned dist : ctx.wait_distances[event_index]) {
+            double score = dist;
+            /* for many events, excessive distances provide little benefit, so
+             * decrease the score in that case. */
+            double threshold = INFINITY;
+            double inv_strength = 0.000001;
+            switch (1 << event_index) {
+            case event_smem:
+               threshold = 70.0;
+               inv_strength = 75.0;
+               break;
+            case event_vmem:
+            case event_vmem_store:
+            case event_flat:
+               threshold = 230.0;
+               inv_strength = 150.0;
+               break;
+            case event_lds:
+               threshold = 16.0;
+               break;
+            default:
+               break;
+            }
+            if (score > threshold) {
+               score -= threshold;
+               score = threshold + score / (1.0 + score / inv_strength);
+            }
+
+            /* we don't want increases in high scores to hide decreases in low scores,
+             * so raise to the power of 0.1 before averaging. */
+            result += pow(score, 0.1);
+            num_waits++;
+         }
+      }
+   }
+   return round(pow(result / num_waits, 10.0) * 10.0);
+}
+
 void insert_wait_states(Program* program)
 {
    /* per BB ctx */
    std::vector<bool> done(program->blocks.size());
-   wait_ctx in_ctx[program->blocks.size()];
-   wait_ctx out_ctx[program->blocks.size()];
-   for (unsigned i = 0; i < program->blocks.size(); i++)
-      in_ctx[i] = wait_ctx(program);
+   std::vector<wait_ctx> in_ctx(program->blocks.size(), wait_ctx(program));
+   std::vector<wait_ctx> out_ctx(program->blocks.size(), wait_ctx(program));
+
    std::stack<unsigned> loop_header_indices;
    unsigned loop_progress = 0;
 
@@ -822,13 +956,15 @@ void insert_wait_states(Program* program)
       for (unsigned b : current.logical_preds)
          changed |= ctx.join(&out_ctx[b], true);
 
-      in_ctx[current.index] = ctx;
-
-      if (done[current.index] && !changed)
+      if (done[current.index] && !changed) {
+         in_ctx[current.index] = std::move(ctx);
          continue;
+      } else {
+         in_ctx[current.index] = ctx;
+      }
 
       if (current.instructions.empty()) {
-         out_ctx[current.index] = ctx;
+         out_ctx[current.index] = std::move(ctx);
          continue;
       }
 
@@ -837,7 +973,14 @@ void insert_wait_states(Program* program)
 
       handle_block(program, current, ctx);
 
-      out_ctx[current.index] = ctx;
+      out_ctx[current.index] = std::move(ctx);
+   }
+
+   if (program->collect_statistics) {
+      program->statistics[statistic_vmem_score] =
+         calculate_score(out_ctx, event_vmem | event_flat | event_vmem_store);
+      program->statistics[statistic_smem_score] =
+         calculate_score(out_ctx, event_smem);
    }
 }
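
A minimal standalone sketch of what the new calculate_score() does for a single wait, using made-up numbers (the helper name example_smem_score and the distance of 100 are illustrative only, not part of the patch); it assumes the same event_smem threshold and inv_strength constants as the code above:

#include <math.h>
#include <stdint.h>

/* Mirrors the soft-cap and averaging done in calculate_score() for one
 * hypothetical lgkmcnt wait that lands 100 instructions after its SMEM load. */
static uint32_t example_smem_score(double dist)
{
   const double threshold = 70.0;    /* event_smem threshold from the patch */
   const double inv_strength = 75.0; /* event_smem inv_strength from the patch */

   double score = dist;
   if (score > threshold) {
      /* distances past the threshold are compressed, since even longer
       * distances provide little extra benefit */
      score -= threshold;
      score = threshold + score / (1.0 + score / inv_strength);
   }

   /* raise to the power of 0.1 before averaging, then undo it at the end */
   double result = pow(score, 0.1);
   unsigned num_waits = 1;
   return round(pow(result / num_waits, 10.0) * 10.0);
}

/* example_smem_score(100.0) == 914: the raw distance of 100 is softly capped
 * to roughly 91.4 before being folded into the reported statistic. */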