aco: add ACO_DEBUG=force-waitcnt to emit wait-states
[mesa.git] / src / amd / compiler / aco_insert_waitcnt.cpp
index c0a93e3a9291ac4605e3fb49a1caf4e803d32d6a..751892e44368a476d1f1ff44261826d536037691 100644 (file)
@@ -155,6 +155,7 @@ struct wait_imm {
       assert(exp == unset_counter || exp <= 0x7);
       switch (chip) {
       case GFX10:
+      case GFX10_3:
          assert(lgkm == unset_counter || lgkm <= 0x3f);
          assert(vm == unset_counter || vm <= 0x3f);
          imm = ((vm & 0x30) << 10) | ((lgkm & 0x3f) << 8) | ((exp & 0x7) << 4) | (vm & 0xf);
@@ -200,20 +201,27 @@ struct wait_entry {
    uint8_t counters; /* use counter_type notion */
    bool wait_on_read:1;
    bool logical:1;
+   bool has_vmem_nosampler:1;
+   bool has_vmem_sampler:1;
 
    wait_entry(wait_event event, wait_imm imm, bool logical, bool wait_on_read)
            : imm(imm), events(event), counters(get_counters_for_event(event)),
-             wait_on_read(wait_on_read), logical(logical) {}
+             wait_on_read(wait_on_read), logical(logical),
+             has_vmem_nosampler(false), has_vmem_sampler(false) {}
 
    bool join(const wait_entry& other)
    {
       bool changed = (other.events & ~events) ||
                      (other.counters & ~counters) ||
-                     (other.wait_on_read && !wait_on_read);
+                     (other.wait_on_read && !wait_on_read) ||
+                     (other.has_vmem_nosampler && !has_vmem_nosampler) ||
+                     (other.has_vmem_sampler && !has_vmem_sampler);
       events |= other.events;
       counters |= other.counters;
       changed |= imm.combine(other.imm);
-      wait_on_read = wait_on_read || other.wait_on_read;
+      wait_on_read |= other.wait_on_read;
+      has_vmem_nosampler |= other.has_vmem_nosampler;
+      has_vmem_sampler |= other.has_vmem_sampler;
       assert(logical == other.logical);
       return changed;
    }
@@ -230,6 +238,8 @@ struct wait_entry {
       if (counter == counter_vm) {
          imm.vm = wait_imm::unset_counter;
          events &= ~event_vmem;
+         has_vmem_nosampler = false;
+         has_vmem_sampler = false;
       }
 
       if (counter == counter_exp) {
@@ -264,8 +274,8 @@ struct wait_ctx {
    bool pending_flat_vm = false;
    bool pending_s_buffer_store = false; /* GFX10 workaround */
 
-   wait_imm barrier_imm[barrier_count];
-   uint16_t barrier_events[barrier_count] = {}; /* use wait_event notion */
+   wait_imm barrier_imm[storage_count];
+   uint16_t barrier_events[storage_count] = {}; /* use wait_event notion */
 
    std::map<PhysReg,wait_entry> gpr_map;
 
@@ -284,7 +294,8 @@ struct wait_ctx {
              max_exp_cnt(6),
              max_lgkm_cnt(program_->chip_class >= GFX10 ? 62 : 14),
              max_vs_cnt(program_->chip_class >= GFX10 ? 62 : 0),
-             unordered_events(event_smem | (program_->chip_class < GFX10 ? event_flat : 0)) {}
+             unordered_events(event_smem | (program_->chip_class < GFX10 ? event_flat : 0)),
+             collect_statistics(program_->collect_statistics) {}
 
    bool join(const wait_ctx* other, bool logical)
    {
@@ -317,7 +328,7 @@ struct wait_ctx {
          }
       }
 
-      for (unsigned i = 0; i < barrier_count; i++) {
+      for (unsigned i = 0; i < storage_count; i++) {
          changed |= barrier_imm[i].combine(other->barrier_imm[i]);
          changed |= other->barrier_events[i] & ~barrier_events[i];
          barrier_events[i] |= other->barrier_events[i];
@@ -402,22 +413,16 @@ wait_imm check_instr(Instruction* instr, wait_ctx& ctx)
             continue;
 
          /* Vector Memory reads and writes return in the order they were issued */
-         if (instr->isVMEM() && ((it->second.events & vm_events) == event_vmem)) {
-            it->second.remove_counter(counter_vm);
-            if (!it->second.counters)
-               it = ctx.gpr_map.erase(it);
+         bool has_sampler = instr->format == Format::MIMG && !instr->operands[1].isUndefined() && instr->operands[1].regClass() == s4;
+         if (instr->isVMEM() && ((it->second.events & vm_events) == event_vmem) &&
+             it->second.has_vmem_nosampler == !has_sampler && it->second.has_vmem_sampler == has_sampler)
             continue;
-         }
 
          /* LDS reads and writes return in the order they were issued. same for GDS */
          if (instr->format == Format::DS) {
             bool gds = static_cast<DS_instruction*>(instr)->gds;
-            if ((it->second.events & lgkm_events) == (gds ? event_gds : event_lds)) {
-               it->second.remove_counter(counter_lgkm);
-               if (!it->second.counters)
-                  it = ctx.gpr_map.erase(it);
+            if ((it->second.events & lgkm_events) == (gds ? event_gds : event_lds))
                continue;
-            }
          }
 
          wait.combine(it->second.imm);
@@ -440,9 +445,60 @@ wait_imm parse_wait_instr(wait_ctx& ctx, Instruction *instr)
    return wait_imm();
 }
 
-wait_imm kill(Instruction* instr, wait_ctx& ctx)
+wait_imm perform_barrier(wait_ctx& ctx, memory_sync_info sync, unsigned semantics)
 {
    wait_imm imm;
+   sync_scope subgroup_scope = ctx.program->workgroup_size <= ctx.program->wave_size ? scope_workgroup : scope_subgroup;
+   if ((sync.semantics & semantics) && sync.scope > subgroup_scope) {
+      unsigned storage = sync.storage;
+      while (storage) {
+         unsigned idx = u_bit_scan(&storage);
+
+         /* LDS is private to the workgroup */
+         sync_scope bar_scope_lds = MIN2(sync.scope, scope_workgroup);
+
+         uint16_t events = ctx.barrier_events[idx];
+         if (bar_scope_lds <= subgroup_scope)
+            events &= ~event_lds;
+
+         /* in non-WGP, the L1/L0 cache keeps all memory operations in-order for the same workgroup */
+         if (ctx.chip_class < GFX10 && sync.scope <= scope_workgroup)
+            events &= ~(event_vmem | event_vmem_store | event_smem);
+
+         if (events)
+            imm.combine(ctx.barrier_imm[idx]);
+      }
+   }
+
+   return imm;
+}
+
+void force_waitcnt(wait_ctx& ctx, wait_imm& imm)
+{
+   if (ctx.vm_cnt)
+      imm.vm = 0;
+   if (ctx.exp_cnt)
+      imm.exp = 0;
+   if (ctx.lgkm_cnt)
+      imm.lgkm = 0;
+
+   if (ctx.chip_class >= GFX10) {
+      if (ctx.vs_cnt)
+         imm.vs = 0;
+   }
+}
+
+wait_imm kill(Instruction* instr, wait_ctx& ctx, memory_sync_info sync_info)
+{
+   wait_imm imm;
+
+   if (debug_flags & DEBUG_FORCE_WAITCNT) {
+      /* Force emitting waitcnt states right after the instruction if there is
+       * something to wait for.
+       */
+      force_waitcnt(ctx, imm);
+   }
+
    if (ctx.exp_cnt || ctx.vm_cnt || ctx.lgkm_cnt)
       imm.combine(check_instr(instr, ctx));
 
@@ -457,7 +513,7 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
       imm.lgkm = 0;
    }
 
-   if (ctx.chip_class >= GFX10) {
+   if (ctx.chip_class >= GFX10 && instr->format == Format::SMEM) {
       /* GFX10: A store followed by a load at the same address causes a problem because
        * the load doesn't load the correct values unless we wait for the store first.
        * This is NOT mitigated by an s_nop.
@@ -467,44 +523,15 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
       SMEM_instruction *smem = static_cast<SMEM_instruction *>(instr);
       if (ctx.pending_s_buffer_store &&
           !smem->definitions.empty() &&
-          !smem->can_reorder && smem->barrier == barrier_buffer) {
+          !smem->sync.can_reorder()) {
          imm.lgkm = 0;
       }
    }
 
-   if (instr->format == Format::PSEUDO_BARRIER) {
-      switch (instr->opcode) {
-      case aco_opcode::p_memory_barrier_common:
-         imm.combine(ctx.barrier_imm[ffs(barrier_atomic) - 1]);
-         imm.combine(ctx.barrier_imm[ffs(barrier_buffer) - 1]);
-         imm.combine(ctx.barrier_imm[ffs(barrier_image) - 1]);
-         if (ctx.program->workgroup_size > ctx.program->wave_size)
-            imm.combine(ctx.barrier_imm[ffs(barrier_shared) - 1]);
-         break;
-      case aco_opcode::p_memory_barrier_atomic:
-         imm.combine(ctx.barrier_imm[ffs(barrier_atomic) - 1]);
-         break;
-      /* see comment in aco_scheduler.cpp's can_move_instr() on why these barriers are merged */
-      case aco_opcode::p_memory_barrier_buffer:
-      case aco_opcode::p_memory_barrier_image:
-         imm.combine(ctx.barrier_imm[ffs(barrier_buffer) - 1]);
-         imm.combine(ctx.barrier_imm[ffs(barrier_image) - 1]);
-         break;
-      case aco_opcode::p_memory_barrier_shared:
-         if (ctx.program->workgroup_size > ctx.program->wave_size)
-            imm.combine(ctx.barrier_imm[ffs(barrier_shared) - 1]);
-         break;
-      case aco_opcode::p_memory_barrier_gs_data:
-         imm.combine(ctx.barrier_imm[ffs(barrier_gs_data) - 1]);
-         break;
-      case aco_opcode::p_memory_barrier_gs_sendmsg:
-         imm.combine(ctx.barrier_imm[ffs(barrier_gs_sendmsg) - 1]);
-         break;
-      default:
-         assert(false);
-         break;
-      }
-   }
+   if (instr->opcode == aco_opcode::p_barrier)
+      imm.combine(perform_barrier(ctx, static_cast<Pseudo_barrier_instruction *>(instr)->sync, semantic_acqrel));
+   else
+      imm.combine(perform_barrier(ctx, sync_info, semantic_release));
 
    if (!imm.empty()) {
       if (ctx.pending_flat_vm && imm.vm != wait_imm::unset_counter)
@@ -519,7 +546,7 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
       ctx.vs_cnt = std::min(ctx.vs_cnt, imm.vs);
 
       /* update barrier wait imms */
-      for (unsigned i = 0; i < barrier_count; i++) {
+      for (unsigned i = 0; i < storage_count; i++) {
          wait_imm& bar = ctx.barrier_imm[i];
          uint16_t& bar_ev = ctx.barrier_events[i];
          if (bar.exp != wait_imm::unset_counter && imm.exp <= bar.exp) {
@@ -552,7 +579,7 @@ wait_imm kill(Instruction* instr, wait_ctx& ctx)
             ctx.wait_and_remove_from_entry(it->first, it->second, counter_vm);
          if (imm.lgkm != wait_imm::unset_counter && imm.lgkm <= it->second.imm.lgkm)
             ctx.wait_and_remove_from_entry(it->first, it->second, counter_lgkm);
-         if (imm.lgkm != wait_imm::unset_counter && imm.vs <= it->second.imm.vs)
+         if (imm.vs != wait_imm::unset_counter && imm.vs <= it->second.imm.vs)
             ctx.wait_and_remove_from_entry(it->first, it->second, counter_vs);
          if (!it->second.counters)
             it = ctx.gpr_map.erase(it);
@@ -577,12 +604,12 @@ void update_barrier_counter(uint8_t *ctr, unsigned max)
       (*ctr)++;
 }
 
-void update_barrier_imm(wait_ctx& ctx, uint8_t counters, wait_event event, barrier_interaction barrier)
+void update_barrier_imm(wait_ctx& ctx, uint8_t counters, wait_event event, memory_sync_info sync)
 {
-   for (unsigned i = 0; i < barrier_count; i++) {
+   for (unsigned i = 0; i < storage_count; i++) {
       wait_imm& bar = ctx.barrier_imm[i];
       uint16_t& bar_ev = ctx.barrier_events[i];
-      if (barrier & (1 << i)) {
+      if (sync.storage & (1 << i) && !(sync.semantics & semantic_private)) {
          bar_ev |= event;
          if (counters & counter_lgkm)
             bar.lgkm = 0;
@@ -605,7 +632,7 @@ void update_barrier_imm(wait_ctx& ctx, uint8_t counters, wait_event event, barri
    }
 }
 
-void update_counters(wait_ctx& ctx, wait_event event, barrier_interaction barrier=barrier_none)
+void update_counters(wait_ctx& ctx, wait_event event, memory_sync_info sync=memory_sync_info())
 {
    uint8_t counters = get_counters_for_event(event);
 
@@ -618,7 +645,7 @@ void update_counters(wait_ctx& ctx, wait_event event, barrier_interaction barrie
    if (counters & counter_vs && ctx.vs_cnt <= ctx.max_vs_cnt)
       ctx.vs_cnt++;
 
-   update_barrier_imm(ctx, counters, event, barrier);
+   update_barrier_imm(ctx, counters, event, sync);
 
    if (ctx.unordered_events & event)
       return;
@@ -647,7 +674,7 @@ void update_counters(wait_ctx& ctx, wait_event event, barrier_interaction barrie
    }
 }
 
-void update_counters_for_flat_load(wait_ctx& ctx, barrier_interaction barrier=barrier_none)
+void update_counters_for_flat_load(wait_ctx& ctx, memory_sync_info sync=memory_sync_info())
 {
    assert(ctx.chip_class < GFX10);
 
@@ -656,7 +683,7 @@ void update_counters_for_flat_load(wait_ctx& ctx, barrier_interaction barrier=ba
    if (ctx.vm_cnt <= ctx.max_vm_cnt)
       ctx.vm_cnt++;
 
-   update_barrier_imm(ctx, counter_vm | counter_lgkm, event_flat, barrier);
+   update_barrier_imm(ctx, counter_vm | counter_lgkm, event_flat, sync);
 
    for (std::pair<PhysReg,wait_entry> e : ctx.gpr_map)
    {
@@ -669,7 +696,8 @@ void update_counters_for_flat_load(wait_ctx& ctx, barrier_interaction barrier=ba
    ctx.pending_flat_vm = true;
 }
 
-void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, bool wait_on_read)
+void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, bool wait_on_read,
+                       bool has_sampler=false)
 {
    uint16_t counters = get_counters_for_event(event);
    wait_imm imm;
@@ -683,9 +711,11 @@ void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event
       imm.vs = 0;
 
    wait_entry new_entry(event, imm, !rc.is_linear(), wait_on_read);
+   new_entry.has_vmem_nosampler = (event & event_vmem) && !has_sampler;
+   new_entry.has_vmem_sampler = (event & event_vmem) && has_sampler;
 
    for (unsigned i = 0; i < rc.size(); i++) {
-      auto it = ctx.gpr_map.emplace(PhysReg{reg.reg+i}, new_entry);
+      auto it = ctx.gpr_map.emplace(PhysReg{reg.reg()+i}, new_entry);
       if (!it.second)
          it.first->second.join(new_entry);
    }
@@ -696,20 +726,20 @@ void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event
          unsigned i = u_bit_scan(&counters_todo);
          ctx.unwaited_instrs[i].insert(std::make_pair(ctx.gen_instr, 0u));
          for (unsigned j = 0; j < rc.size(); j++)
-            ctx.reg_instrs[i][PhysReg{reg.reg+j}].insert(ctx.gen_instr);
+            ctx.reg_instrs[i][PhysReg{reg.reg()+j}].insert(ctx.gen_instr);
       }
    }
 }
 
-void insert_wait_entry(wait_ctx& ctx, Operand op, wait_event event)
+void insert_wait_entry(wait_ctx& ctx, Operand op, wait_event event, bool has_sampler=false)
 {
    if (!op.isConstant() && !op.isUndefined())
-      insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false);
+      insert_wait_entry(ctx, op.physReg(), op.regClass(), event, false, has_sampler);
 }
 
-void insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event)
+void insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event, bool has_sampler=false)
 {
-   insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true);
+   insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true, has_sampler);
 }
 
 void gen(Instruction* instr, wait_ctx& ctx)
@@ -741,10 +771,11 @@ void gen(Instruction* instr, wait_ctx& ctx)
       break;
    }
    case Format::FLAT: {
+      FLAT_instruction *flat = static_cast<FLAT_instruction*>(instr);
       if (ctx.chip_class < GFX10 && !instr->definitions.empty())
-         update_counters_for_flat_load(ctx, barrier_buffer);
+         update_counters_for_flat_load(ctx, flat->sync);
       else
-         update_counters(ctx, event_flat, barrier_buffer);
+         update_counters(ctx, event_flat, flat->sync);
 
       if (!instr->definitions.empty())
          insert_wait_entry(ctx, instr->definitions[0], event_flat);
@@ -752,27 +783,26 @@ void gen(Instruction* instr, wait_ctx& ctx)
    }
    case Format::SMEM: {
       SMEM_instruction *smem = static_cast<SMEM_instruction*>(instr);
-      update_counters(ctx, event_smem, static_cast<SMEM_instruction*>(instr)->barrier);
+      update_counters(ctx, event_smem, smem->sync);
 
       if (!instr->definitions.empty())
          insert_wait_entry(ctx, instr->definitions[0], event_smem);
       else if (ctx.chip_class >= GFX10 &&
-               !smem->can_reorder &&
-               smem->barrier == barrier_buffer)
+               !smem->sync.can_reorder())
          ctx.pending_s_buffer_store = true;
 
       break;
    }
    case Format::DS: {
-      bool gds = static_cast<DS_instruction*>(instr)->gds;
-      update_counters(ctx, gds ? event_gds : event_lds, gds ? barrier_none : barrier_shared);
-      if (gds)
+      DS_instruction *ds = static_cast<DS_instruction*>(instr);
+      update_counters(ctx, ds->gds ? event_gds : event_lds, ds->sync);
+      if (ds->gds)
          update_counters(ctx, event_gds_gpr_lock);
 
       if (!instr->definitions.empty())
-         insert_wait_entry(ctx, instr->definitions[0], gds ? event_gds : event_lds);
+         insert_wait_entry(ctx, instr->definitions[0], ds->gds ? event_gds : event_lds);
 
-      if (gds) {
+      if (ds->gds) {
          for (const Operand& op : instr->operands)
             insert_wait_entry(ctx, op, event_gds_gpr_lock);
          insert_wait_entry(ctx, exec, s2, event_gds_gpr_lock, false);
@@ -784,10 +814,12 @@ void gen(Instruction* instr, wait_ctx& ctx)
    case Format::MIMG:
    case Format::GLOBAL: {
       wait_event ev = !instr->definitions.empty() || ctx.chip_class < GFX10 ? event_vmem : event_vmem_store;
-      update_counters(ctx, ev, get_barrier_interaction(instr));
+      update_counters(ctx, ev, get_sync_info(instr));
+
+      bool has_sampler = instr->format == Format::MIMG && !instr->operands[1].isUndefined() && instr->operands[1].regClass() == s4;
 
       if (!instr->definitions.empty())
-         insert_wait_entry(ctx, instr->definitions[0], ev);
+         insert_wait_entry(ctx, instr->definitions[0], ev, has_sampler);
 
       if (ctx.chip_class == GFX6 &&
           instr->format != Format::MIMG &&
@@ -808,7 +840,7 @@ void gen(Instruction* instr, wait_ctx& ctx)
    case Format::SOPP: {
       if (instr->opcode == aco_opcode::s_sendmsg ||
           instr->opcode == aco_opcode::s_sendmsghalt)
-         update_counters(ctx, event_sendmsg, get_barrier_interaction(instr));
+         update_counters(ctx, event_sendmsg);
    }
    default:
       break;
@@ -839,12 +871,11 @@ void handle_block(Program *program, Block& block, wait_ctx& ctx)
 
    wait_imm queued_imm;
 
-   ctx.collect_statistics = program->collect_statistics;
-
    for (aco_ptr<Instruction>& instr : block.instructions) {
       bool is_wait = !parse_wait_instr(ctx, instr.get()).empty();
 
-      queued_imm.combine(kill(instr.get(), ctx));
+      memory_sync_info sync_info = get_sync_info(instr.get());
+      queued_imm.combine(kill(instr.get(), ctx, sync_info));
 
       ctx.gen_instr = instr.get();
       gen(instr.get(), ctx);
@@ -856,6 +887,8 @@ void handle_block(Program *program, Block& block, wait_ctx& ctx)
          }
          new_instructions.emplace_back(std::move(instr));
 
+         queued_imm.combine(perform_barrier(ctx, sync_info, semantic_acquire));
+
          if (ctx.collect_statistics)
             ctx.advance_unwaited_instrs();
       }
@@ -869,14 +902,14 @@ void handle_block(Program *program, Block& block, wait_ctx& ctx)
 
 } /* end namespace */
 
-static uint32_t calculate_score(unsigned num_ctx, wait_ctx *ctx, uint32_t event_mask)
+static uint32_t calculate_score(std::vector<wait_ctx> &ctx_vec, uint32_t event_mask)
 {
    double result = 0.0;
    unsigned num_waits = 0;
    while (event_mask) {
       unsigned event_index = u_bit_scan(&event_mask);
-      for (unsigned i = 0; i < num_ctx; i++) {
-         for (unsigned dist : ctx[i].wait_distances[event_index]) {
+      for (const wait_ctx &ctx : ctx_vec) {
+         for (unsigned dist : ctx.wait_distances[event_index]) {
             double score = dist;
             /* for many events, excessive distances provide little benefit, so
              * decrease the score in that case. */
@@ -918,11 +951,9 @@ void insert_wait_states(Program* program)
 {
    /* per BB ctx */
    std::vector<bool> done(program->blocks.size());
-   wait_ctx in_ctx[program->blocks.size()];
-   wait_ctx out_ctx[program->blocks.size()];
+   std::vector<wait_ctx> in_ctx(program->blocks.size(), wait_ctx(program));
+   std::vector<wait_ctx> out_ctx(program->blocks.size(), wait_ctx(program));
 
-   for (unsigned i = 0; i < program->blocks.size(); i++)
-      in_ctx[i] = wait_ctx(program);
    std::stack<unsigned> loop_header_indices;
    unsigned loop_progress = 0;
 
@@ -972,9 +1003,9 @@ void insert_wait_states(Program* program)
 
    if (program->collect_statistics) {
       program->statistics[statistic_vmem_score] =
-         calculate_score(program->blocks.size(), out_ctx, event_vmem | event_flat | event_vmem_store);
+         calculate_score(out_ctx, event_vmem | event_flat | event_vmem_store);
       program->statistics[statistic_smem_score] =
-         calculate_score(program->blocks.size(), out_ctx, event_smem);
+         calculate_score(out_ctx, event_smem);
    }
 }