aco: keep loop live-through variables spilled
[mesa.git] src/amd/compiler/aco_spill.cpp
index b684d560abf81da2732947e57d9180a50eadb5b0..7d3055e33e097bc5917342af3161c4498a1bf0ae 100644
  */
 
 #include "aco_ir.h"
+#include "aco_builder.h"
+#include "sid.h"
+
 #include <map>
+#include <set>
 #include <stack>
+#include <unordered_set>
-#include "vulkan/radv_shader.h"
-
 
 /*
  * Implements the spilling algorithm on SSA-form from
@@ -54,18 +56,19 @@ struct spill_ctx {
    std::stack<Block*> loop_header;
    std::vector<std::map<Temp, std::pair<uint32_t, uint32_t>>> next_use_distances_start;
    std::vector<std::map<Temp, std::pair<uint32_t, uint32_t>>> next_use_distances_end;
-   std::vector<std::pair<RegClass, std::set<uint32_t>>> interferences;
+   std::vector<std::pair<RegClass, std::unordered_set<uint32_t>>> interferences;
    std::vector<std::vector<uint32_t>> affinities;
    std::vector<bool> is_reloaded;
    std::map<Temp, remat_info> remat;
    std::map<Instruction *, bool> remat_used;
+   unsigned wave_size;
 
    spill_ctx(const RegisterDemand target_pressure, Program* program,
              std::vector<std::vector<RegisterDemand>> register_demand)
       : target_pressure(target_pressure), program(program),
-        register_demand(register_demand), renames(program->blocks.size()),
+        register_demand(std::move(register_demand)), renames(program->blocks.size()),
         spills_entry(program->blocks.size()), spills_exit(program->blocks.size()),
-        processed(program->blocks.size(), false) {}
+        processed(program->blocks.size(), false), wave_size(program->wave_size) {}
 
    void add_affinity(uint32_t first, uint32_t second)
    {
@@ -95,9 +98,19 @@ struct spill_ctx {
       }
    }
 
+   void add_interference(uint32_t first, uint32_t second)
+   {
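+      /* sgpr and vgpr spill ids are assigned slots from separate spaces,
+       * so only interferences between ids of the same type matter */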
+      if (interferences[first].first.type() != interferences[second].first.type())
+         return;
+
+      bool inserted = interferences[first].second.insert(second).second;
+      if (inserted)
+         interferences[second].second.insert(first);
+   }
+
    uint32_t allocate_spill_id(RegClass rc)
    {
-      interferences.emplace_back(rc, std::set<uint32_t>());
+      interferences.emplace_back(rc, std::unordered_set<uint32_t>());
       is_reloaded.push_back(false);
       return next_spill_id++;
    }
@@ -157,6 +170,8 @@ void next_uses_per_block(spill_ctx& ctx, unsigned block_idx, std::set<uint32_t>&
          /* omit exec mask */
          if (op.isFixed() && op.physReg() == exec)
             continue;
+         if (op.regClass().type() == RegType::vgpr && op.regClass().is_linear())
+            continue;
          if (op.isTemp())
             next_uses[op.getTemp()] = {block_idx, idx};
       }
@@ -209,7 +224,7 @@ void next_uses_per_block(spill_ctx& ctx, unsigned block_idx, std::set<uint32_t>&
 
 }
 
-void compute_global_next_uses(spill_ctx& ctx, std::vector<std::set<Temp>>& live_out)
+void compute_global_next_uses(spill_ctx& ctx)
 {
    ctx.next_use_distances_start.resize(ctx.program->blocks.size());
    ctx.next_use_distances_end.resize(ctx.program->blocks.size());
@@ -228,11 +243,13 @@ void compute_global_next_uses(spill_ctx& ctx, std::vector<std::set<Temp>>& live_
 bool should_rematerialize(aco_ptr<Instruction>& instr)
 {
    /* TODO: rematerialization is only supported for VOP1, SOP1 and PSEUDO */
-   if (instr->format != Format::VOP1 && instr->format != Format::SOP1 && instr->format != Format::PSEUDO)
+   if (instr->format != Format::VOP1 && instr->format != Format::SOP1 && instr->format != Format::PSEUDO && instr->format != Format::SOPK)
       return false;
    /* TODO: pseudo-instruction rematerialization is only supported for p_create_vector */
    if (instr->format == Format::PSEUDO && instr->opcode != aco_opcode::p_create_vector)
       return false;
+   if (instr->format == Format::SOPK && instr->opcode != aco_opcode::s_movk_i32)
+      return false;
 
    for (const Operand& op : instr->operands) {
       /* TODO: rematerialization using temporaries isn't yet supported */
@@ -252,7 +269,7 @@ aco_ptr<Instruction> do_reload(spill_ctx& ctx, Temp tmp, Temp new_name, uint32_t
    std::map<Temp, remat_info>::iterator remat = ctx.remat.find(tmp);
    if (remat != ctx.remat.end()) {
       Instruction *instr = remat->second.instr;
-      assert((instr->format == Format::VOP1 || instr->format == Format::SOP1 || instr->format == Format::PSEUDO) && "unsupported");
+      assert((instr->format == Format::VOP1 || instr->format == Format::SOP1 || instr->format == Format::PSEUDO || instr->format == Format::SOPK) && "unsupported");
       assert((instr->format != Format::PSEUDO || instr->opcode == aco_opcode::p_create_vector) && "unsupported");
       assert(instr->definitions.size() == 1 && "unsupported");
 
@@ -262,7 +279,10 @@ aco_ptr<Instruction> do_reload(spill_ctx& ctx, Temp tmp, Temp new_name, uint32_t
       } else if (instr->format == Format::SOP1) {
          res.reset(create_instruction<SOP1_instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
       } else if (instr->format == Format::PSEUDO) {
-         res.reset(create_instruction<Instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
+         res.reset(create_instruction<Pseudo_instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
+      } else if (instr->format == Format::SOPK) {
+         res.reset(create_instruction<SOPK_instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
+         static_cast<SOPK_instruction*>(res.get())->imm = static_cast<SOPK_instruction*>(instr)->imm;
       }
       for (unsigned i = 0; i < instr->operands.size(); i++) {
          res->operands[i] = instr->operands[i];
@@ -322,6 +342,8 @@ std::vector<std::map<Temp, uint32_t>> local_next_uses(spill_ctx& ctx, Block* blo
       for (const Operand& op : instr->operands) {
          if (op.isFixed() && op.physReg() == exec)
             continue;
+         if (op.regClass().type() == RegType::vgpr && op.regClass().is_linear())
+            continue;
          if (op.isTemp())
             next_uses[op.getTemp()] = idx;
       }
@@ -361,6 +383,20 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
       }
       unsigned loop_end = i;
 
+      /* keep live-through variables spilled */
+      for (std::pair<Temp, std::pair<uint32_t, uint32_t>> pair : ctx.next_use_distances_end[block_idx - 1]) {
+         if (pair.second.first < loop_end)
+            continue;
+
+         Temp to_spill = pair.first;
+         auto it = ctx.spills_exit[block_idx - 1].find(to_spill);
+         if (it == ctx.spills_exit[block_idx - 1].end())
+            continue;
+
+         ctx.spills_entry[block_idx][to_spill] = it->second;
+         spilled_registers += to_spill;
+      }
+
       /* select live-through vgpr variables */
       while (new_demand.vgpr - spilled_registers.vgpr > ctx.target_pressure.vgpr) {
          unsigned distance = 0;
@@ -429,6 +465,13 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
       assert(idx != 0 && "loop without phis: TODO");
       idx--;
       RegisterDemand reg_pressure = ctx.register_demand[block_idx][idx] - spilled_registers;
+      /* Consider register pressure from linear predecessors. This can affect
+       * reg_pressure if the branch instructions define sgprs. */
+      for (unsigned pred : block->linear_preds) {
+         reg_pressure.sgpr = std::max<int16_t>(
+            reg_pressure.sgpr, ctx.register_demand[pred].back().sgpr - spilled_registers.sgpr);
+      }
+
       while (reg_pressure.sgpr > ctx.target_pressure.sgpr) {
          unsigned distance = 0;
          Temp to_spill;
@@ -473,7 +516,7 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
       for (std::pair<Temp, uint32_t> pair : ctx.spills_exit[pred_idx]) {
          if (pair.first.type() == RegType::sgpr &&
              ctx.next_use_distances_start[block_idx].find(pair.first) != ctx.next_use_distances_start[block_idx].end() &&
-             ctx.next_use_distances_start[block_idx][pair.first].second > block_idx) {
+             ctx.next_use_distances_start[block_idx][pair.first].first != block_idx) {
             ctx.spills_entry[block_idx].insert(pair);
             spilled_registers.sgpr += pair.first.size();
          }
@@ -483,7 +526,7 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
          for (std::pair<Temp, uint32_t> pair : ctx.spills_exit[pred_idx]) {
             if (pair.first.type() == RegType::vgpr &&
                 ctx.next_use_distances_start[block_idx].find(pair.first) != ctx.next_use_distances_start[block_idx].end() &&
-                ctx.next_use_distances_end[pred_idx][pair.first].second > block_idx) {
+                ctx.next_use_distances_start[block_idx][pair.first].first != block_idx) {
                ctx.spills_entry[block_idx].insert(pair);
                spilled_registers.vgpr += pair.first.size();
             }
@@ -520,7 +563,7 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
 
    /* keep variables spilled on all incoming paths */
    for (std::pair<Temp, std::pair<uint32_t, uint32_t>> pair : ctx.next_use_distances_start[block_idx]) {
-      std::vector<unsigned>& preds = pair.first.type() == RegType::vgpr ? block->logical_preds : block->linear_preds;
+      std::vector<unsigned>& preds = pair.first.is_linear() ? block->linear_preds : block->logical_preds;
       /* If it can be rematerialized, keep the variable spilled if all predecessors do not reload it.
        * Otherwise, if any predecessor reloads it, ensure it's reloaded on all other predecessors.
        * The idea is that it's better in practice to rematerialize redundantly than to create lots of phis. */
@@ -592,16 +635,34 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
          }
       }
    } else {
+      for (unsigned i = 0; i < idx; i++) {
+         aco_ptr<Instruction>& instr = block->instructions[i];
+         assert(is_phi(instr));
+         /* Killed phi definitions increase pressure in the predecessor but not
+          * in the block they're in. Since the loops below control both the
+          * pressure at the start of this block and at the ends of its
+          * predecessors, we need to count killed unspilled phi definitions here. */
+         if (instr->definitions[0].isKill() &&
+             !ctx.spills_entry[block_idx].count(instr->definitions[0].getTemp()))
+            reg_pressure += instr->definitions[0].getTemp();
+      }
       idx--;
    }
    reg_pressure += ctx.register_demand[block_idx][idx] - spilled_registers;
 
+   /* Consider register pressure from linear predecessors. This can affect
+    * reg_pressure if the branch instructions define sgprs. */
+   for (unsigned pred : block->linear_preds) {
+      reg_pressure.sgpr = std::max<int16_t>(
+         reg_pressure.sgpr, ctx.register_demand[pred].back().sgpr - spilled_registers.sgpr);
+   }
+
    while (reg_pressure.sgpr > ctx.target_pressure.sgpr) {
       assert(!partial_spills.empty());
 
       std::set<Temp>::iterator it = partial_spills.begin();
-      Temp to_spill = *it;
-      unsigned distance = ctx.next_use_distances_start[block_idx][*it].second;
+      Temp to_spill = Temp();
+      unsigned distance = 0;
       while (it != partial_spills.end()) {
          assert(ctx.spills_entry[block_idx].find(*it) == ctx.spills_entry[block_idx].end());
 
@@ -623,8 +684,8 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
       assert(!partial_spills.empty());
 
       std::set<Temp>::iterator it = partial_spills.begin();
-      Temp to_spill = *it;
-      unsigned distance = ctx.next_use_distances_start[block_idx][*it].second;
+      Temp to_spill = Temp();
+      unsigned distance = 0;
       while (it != partial_spills.end()) {
          assert(ctx.spills_entry[block_idx].find(*it) == ctx.spills_entry[block_idx].end());
 
@@ -646,6 +707,18 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
 }
 
 
+RegisterDemand get_demand_before(spill_ctx& ctx, unsigned block_idx, unsigned idx)
+{
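+   /* register demand right before instruction "idx" of block "block_idx" */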
+   if (idx == 0) {
+      RegisterDemand demand = ctx.register_demand[block_idx][idx];
+      aco_ptr<Instruction>& instr = ctx.program->blocks[block_idx].instructions[idx];
+      aco_ptr<Instruction> instr_before(nullptr);
+      return get_demand_before(demand, instr, instr_before);
+   } else {
+      return ctx.register_demand[block_idx][idx - 1];
+   }
+}
+
 void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
 {
    /* no coupling code necessary */
@@ -654,13 +727,48 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
 
    std::vector<aco_ptr<Instruction>> instructions;
    /* branch block: TODO take other branch into consideration */
-   if (block->linear_preds.size() == 1 && !(block->kind & block_kind_loop_exit)) {
+   if (block->linear_preds.size() == 1 && !(block->kind & (block_kind_loop_exit | block_kind_loop_header))) {
       assert(ctx.processed[block->linear_preds[0]]);
+      assert(ctx.register_demand[block_idx].size() == block->instructions.size());
+      std::vector<RegisterDemand> reg_demand;
+      unsigned insert_idx = 0;
+      unsigned pred_idx = block->linear_preds[0];
+      RegisterDemand demand_before = get_demand_before(ctx, block_idx, 0);
+
+      for (std::pair<Temp, std::pair<uint32_t, uint32_t>> live : ctx.next_use_distances_start[block_idx]) {
+         if (!live.first.is_linear())
+            continue;
+         /* still spilled */
+         if (ctx.spills_entry[block_idx].find(live.first) != ctx.spills_entry[block_idx].end())
+            continue;
+
+         /* in register at end of predecessor */
+         if (ctx.spills_exit[pred_idx].find(live.first) == ctx.spills_exit[pred_idx].end()) {
+            std::map<Temp, Temp>::iterator it = ctx.renames[pred_idx].find(live.first);
+            if (it != ctx.renames[pred_idx].end())
+               ctx.renames[block_idx].insert(*it);
+            continue;
+         }
+
+         /* variable is spilled at predecessor and live at current block: create reload instruction */
+         Temp new_name = {ctx.program->allocateId(), live.first.regClass()};
+         aco_ptr<Instruction> reload = do_reload(ctx, live.first, new_name, ctx.spills_exit[pred_idx][live.first]);
+         instructions.emplace_back(std::move(reload));
+         reg_demand.push_back(demand_before);
+         ctx.renames[block_idx][live.first] = new_name;
+      }
 
       if (block->logical_preds.size() == 1) {
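+         /* move the leading instructions up to (and including) p_logical_start
+          * first, so that the logical reloads below are inserted after it */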
+         do {
+            assert(insert_idx < block->instructions.size());
+            instructions.emplace_back(std::move(block->instructions[insert_idx]));
+            reg_demand.push_back(ctx.register_demand[block_idx][insert_idx]);
+            insert_idx++;
+         } while (instructions.back()->opcode != aco_opcode::p_logical_start);
+
          unsigned pred_idx = block->logical_preds[0];
          for (std::pair<Temp, std::pair<uint32_t, uint32_t>> live : ctx.next_use_distances_start[block_idx]) {
-            if (live.first.type() == RegType::sgpr)
+            if (live.first.is_linear())
                continue;
             /* still spilled */
             if (ctx.spills_entry[block_idx].find(live.first) != ctx.spills_entry[block_idx].end())
@@ -678,45 +786,20 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
             Temp new_name = {ctx.program->allocateId(), live.first.regClass()};
             aco_ptr<Instruction> reload = do_reload(ctx, live.first, new_name, ctx.spills_exit[pred_idx][live.first]);
             instructions.emplace_back(std::move(reload));
+            reg_demand.emplace_back(reg_demand.back());
             ctx.renames[block_idx][live.first] = new_name;
          }
       }
 
-      unsigned pred_idx = block->linear_preds[0];
-      for (std::pair<Temp, std::pair<uint32_t, uint32_t>> live : ctx.next_use_distances_start[block_idx]) {
-         if (live.first.type() == RegType::vgpr)
-            continue;
-         /* still spilled */
-         if (ctx.spills_entry[block_idx].find(live.first) != ctx.spills_entry[block_idx].end())
-            continue;
-
-         /* in register at end of predecessor */
-         if (ctx.spills_exit[pred_idx].find(live.first) == ctx.spills_exit[pred_idx].end()) {
-            std::map<Temp, Temp>::iterator it = ctx.renames[pred_idx].find(live.first);
-            if (it != ctx.renames[pred_idx].end())
-               ctx.renames[block_idx].insert(*it);
-            continue;
-         }
-
-         /* variable is spilled at predecessor and live at current block: create reload instruction */
-         Temp new_name = {ctx.program->allocateId(), live.first.regClass()};
-         aco_ptr<Instruction> reload = do_reload(ctx, live.first, new_name, ctx.spills_exit[pred_idx][live.first]);
-         instructions.emplace_back(std::move(reload));
-         ctx.renames[block_idx][live.first] = new_name;
-      }
-
       /* combine new reload instructions with original block */
       if (!instructions.empty()) {
-         unsigned insert_idx = 0;
-         while (block->instructions[insert_idx]->opcode == aco_opcode::p_phi ||
-                block->instructions[insert_idx]->opcode == aco_opcode::p_linear_phi) {
-            insert_idx++;
-         }
-         ctx.register_demand[block->index].insert(std::next(ctx.register_demand[block->index].begin(), insert_idx),
-                                                  instructions.size(), RegisterDemand());
-         block->instructions.insert(std::next(block->instructions.begin(), insert_idx),
-                                    std::move_iterator<std::vector<aco_ptr<Instruction>>::iterator>(instructions.begin()),
-                                    std::move_iterator<std::vector<aco_ptr<Instruction>>::iterator>(instructions.end()));
+         reg_demand.insert(reg_demand.end(), std::next(ctx.register_demand[block->index].begin(), insert_idx),
+                           ctx.register_demand[block->index].end());
+         ctx.register_demand[block_idx] = std::move(reg_demand);
+         instructions.insert(instructions.end(),
+                             std::move_iterator<std::vector<aco_ptr<Instruction>>::iterator>(std::next(block->instructions.begin(), insert_idx)),
+                             std::move_iterator<std::vector<aco_ptr<Instruction>>::iterator>(block->instructions.end()));
+         block->instructions = std::move(instructions);
       }
       return;
    }
@@ -752,8 +835,7 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
          for (std::pair<Temp, uint32_t> pair : ctx.spills_exit[pred_idx]) {
             if (var == pair.first)
                continue;
-            ctx.interferences[def_spill_id].second.emplace(pair.second);
-            ctx.interferences[pair.second].second.emplace(def_spill_id);
+            ctx.add_interference(def_spill_id, pair.second);
          }
 
          /* check if variable is already spilled at predecessor */
@@ -794,7 +876,7 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
    /* iterate all (other) spilled variables and spill them at the predecessors where needed */
    // TODO: would be better to have them sorted: first vgprs and first with longest distance
    for (std::pair<Temp, uint32_t> pair : ctx.spills_entry[block_idx]) {
-      std::vector<unsigned> preds = pair.first.type() == RegType::vgpr ? block->logical_preds : block->linear_preds;
+      std::vector<unsigned> preds = pair.first.is_linear() ? block->linear_preds : block->logical_preds;
 
       for (unsigned pred_idx : preds) {
          /* variable is already spilled at predecessor */
@@ -813,8 +895,7 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
          for (std::pair<Temp, uint32_t> exit_spill : ctx.spills_exit[pred_idx]) {
             if (exit_spill.first == pair.first)
                continue;
-            ctx.interferences[exit_spill.second].second.emplace(pair.second);
-            ctx.interferences[pair.second].second.emplace(exit_spill.second);
+            ctx.add_interference(exit_spill.second, pair.second);
          }
 
          /* variable is in register at predecessor and has to be spilled */
@@ -887,7 +968,7 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
       /* skip spilled variables */
       if (ctx.spills_entry[block_idx].find(pair.first) != ctx.spills_entry[block_idx].end())
          continue;
-      std::vector<unsigned> preds = pair.first.type() == RegType::vgpr ? block->logical_preds : block->linear_preds;
+      std::vector<unsigned> preds = pair.first.is_linear() ? block->linear_preds : block->logical_preds;
 
       /* variable is dead at predecessor, it must be from a phi */
       bool is_dead = false;
@@ -941,7 +1022,7 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
 
       if (!is_same) {
          /* the variable was renamed differently in the predecessors: we have to create a phi */
-         aco_opcode opcode = pair.first.type() == RegType::vgpr ? aco_opcode::p_phi : aco_opcode::p_linear_phi;
+         aco_opcode opcode = pair.first.is_linear() ? aco_opcode::p_linear_phi : aco_opcode::p_phi;
          aco_ptr<Pseudo_instruction> phi{create_instruction<Pseudo_instruction>(opcode, Format::PSEUDO, preds.size(), 1)};
          rename = {ctx.program->allocateId(), pair.first.regClass()};
          for (unsigned i = 0; i < phi->operands.size(); i++) {
@@ -969,8 +1050,12 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
       idx++;
    }
 
-   ctx.register_demand[block->index].erase(ctx.register_demand[block->index].begin(), ctx.register_demand[block->index].begin() + idx);
-   ctx.register_demand[block->index].insert(ctx.register_demand[block->index].begin(), instructions.size(), RegisterDemand());
+   if (!ctx.processed[block_idx]) {
+      assert(!(block->kind & block_kind_loop_header));
+      RegisterDemand demand_before = get_demand_before(ctx, block_idx, idx);
+      ctx.register_demand[block->index].erase(ctx.register_demand[block->index].begin(), ctx.register_demand[block->index].begin() + idx);
+      ctx.register_demand[block->index].insert(ctx.register_demand[block->index].begin(), instructions.size(), demand_before);
+   }
 
    std::vector<aco_ptr<Instruction>>::iterator start = std::next(block->instructions.begin(), idx);
    instructions.insert(instructions.end(), std::move_iterator<std::vector<aco_ptr<Instruction>>::iterator>(start),
@@ -981,6 +1066,8 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
 void process_block(spill_ctx& ctx, unsigned block_idx, Block* block,
                    std::map<Temp, uint32_t> &current_spills, RegisterDemand spilled_registers)
 {
+   assert(!ctx.processed[block_idx]);
+
    std::vector<std::map<Temp, uint32_t>> local_next_use_distance;
    std::vector<aco_ptr<Instruction>> instructions;
    unsigned idx = 0;
@@ -1032,18 +1119,7 @@ void process_block(spill_ctx& ctx, unsigned block_idx, Block* block,
       if (block->register_demand.exceeds(ctx.target_pressure)) {
 
          RegisterDemand new_demand = ctx.register_demand[block_idx][idx];
-         if (idx == 0) {
-            RegisterDemand demand_before = new_demand;
-            for (const Definition& def : instr->definitions)
-               demand_before -= def.getTemp();
-            for (const Operand& op : instr->operands) {
-               if (op.isFirstKill())
-                  demand_before += op.getTemp();
-            }
-            new_demand.update(demand_before);
-         } else {
-            new_demand.update(ctx.register_demand[block_idx][idx - 1]);
-         }
+         new_demand.update(get_demand_before(ctx, block_idx, idx));
 
          assert(!local_next_use_distance.empty());
 
@@ -1084,14 +1160,10 @@ void process_block(spill_ctx& ctx, unsigned block_idx, Block* block,
             uint32_t spill_id = ctx.allocate_spill_id(to_spill.regClass());
 
             /* add interferences with currently spilled variables */
-            for (std::pair<Temp, uint32_t> pair : current_spills) {
-               ctx.interferences[spill_id].second.emplace(pair.second);
-               ctx.interferences[pair.second].second.emplace(spill_id);
-            }
-            for (std::pair<Temp, std::pair<Temp, uint32_t>> pair : reloads) {
-               ctx.interferences[spill_id].second.emplace(pair.second.second);
-               ctx.interferences[pair.second.second].second.emplace(spill_id);
-            }
+            for (std::pair<Temp, uint32_t> pair : current_spills)
+               ctx.add_interference(spill_id, pair.second);
+            for (std::pair<Temp, std::pair<Temp, uint32_t>> pair : reloads)
+               ctx.add_interference(spill_id, pair.second.second);
 
             current_spills[to_spill] = spill_id;
             spilled_registers += to_spill;
@@ -1125,16 +1197,14 @@ void process_block(spill_ctx& ctx, unsigned block_idx, Block* block,
 void spill_block(spill_ctx& ctx, unsigned block_idx)
 {
    Block* block = &ctx.program->blocks[block_idx];
-   ctx.processed[block_idx] = true;
 
    /* determine set of variables which are spilled at the beginning of the block */
    RegisterDemand spilled_registers = init_live_in_vars(ctx, block, block_idx);
 
    /* add interferences for spilled variables */
-   for (std::pair<Temp, uint32_t> x : ctx.spills_entry[block_idx]) {
-      for (std::pair<Temp, uint32_t> y : ctx.spills_entry[block_idx])
-         if (x.second != y.second)
-            ctx.interferences[x.second].second.emplace(y.second);
+   for (auto it = ctx.spills_entry[block_idx].begin(); it != ctx.spills_entry[block_idx].end(); ++it) {
+      for (auto it2 = std::next(it); it2 != ctx.spills_entry[block_idx].end(); ++it2)
+         ctx.add_interference(it->second, it2->second);
    }
 
    bool is_loop_header = block->loop_nest_depth && ctx.loop_header.top()->index == block_idx;
@@ -1162,6 +1232,8 @@ void spill_block(spill_ctx& ctx, unsigned block_idx)
    else
       ctx.spills_exit[block_idx].insert(current_spills.begin(), current_spills.end());
 
+   ctx.processed[block_idx] = true;
+
    /* check if the next block leaves the current loop */
    if (block->loop_nest_depth == 0 || ctx.program->blocks[block_idx + 1].loop_nest_depth >= block->loop_nest_depth)
       return;
@@ -1244,9 +1316,148 @@ void spill_block(spill_ctx& ctx, unsigned block_idx)
    ctx.loop_header.pop();
 }
 
+Temp load_scratch_resource(spill_ctx& ctx, Temp& scratch_offset,
+                           std::vector<aco_ptr<Instruction>>& instructions,
+                           unsigned offset, bool is_top_level)
+{
+   Builder bld(ctx.program);
+   if (is_top_level) {
+      bld.reset(&instructions);
+   } else {
+      /* find p_logical_end */
+      unsigned idx = instructions.size() - 1;
+      while (instructions[idx]->opcode != aco_opcode::p_logical_end)
+         idx--;
+      bld.reset(&instructions, std::next(instructions.begin(), idx));
+   }
+
+   Temp private_segment_buffer = ctx.program->private_segment_buffer;
+   if (ctx.program->stage != compute_cs)
+      private_segment_buffer = bld.smem(aco_opcode::s_load_dwordx2, bld.def(s2), private_segment_buffer, Operand(0u));
+
+   if (offset)
+      scratch_offset = bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.def(s1, scc), scratch_offset, Operand(offset));
+
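+   /* build a swizzled scratch buffer descriptor: ADD_TID_ENABLE interleaves
+    * the lanes and INDEX_STRIDE matches the wave size */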
+   uint32_t rsrc_conf = S_008F0C_ADD_TID_ENABLE(1) |
+                        S_008F0C_INDEX_STRIDE(ctx.program->wave_size == 64 ? 3 : 2);
+
+   if (ctx.program->chip_class >= GFX10) {
+      rsrc_conf |= S_008F0C_FORMAT(V_008F0C_IMG_FORMAT_32_FLOAT) |
+                   S_008F0C_OOB_SELECT(V_008F0C_OOB_SELECT_RAW) |
+                   S_008F0C_RESOURCE_LEVEL(1);
+   } else if (ctx.program->chip_class <= GFX7) { /* dfmt modifies stride on GFX8/GFX9 when ADD_TID_EN=1 */
+      rsrc_conf |= S_008F0C_NUM_FORMAT(V_008F0C_BUF_NUM_FORMAT_FLOAT) |
+                   S_008F0C_DATA_FORMAT(V_008F0C_BUF_DATA_FORMAT_32);
+   }
+   /* older generations need element size = 4 bytes. element size removed in GFX9 */
+   if (ctx.program->chip_class <= GFX8)
+      rsrc_conf |= S_008F0C_ELEMENT_SIZE(1);
+
+   return bld.pseudo(aco_opcode::p_create_vector, bld.def(s4),
+                     private_segment_buffer, Operand(-1u),
+                     Operand(rsrc_conf));
+}
+
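+/* mark the slots of already-assigned spill ids interfering with "id" as used */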
+void add_interferences(spill_ctx& ctx, std::vector<bool>& is_assigned,
+                       std::vector<uint32_t>& slots, std::vector<bool>& slots_used,
+                       unsigned id)
+{
+   for (unsigned other : ctx.interferences[id].second) {
+      if (!is_assigned[other])
+         continue;
+
+      RegClass other_rc = ctx.interferences[other].first;
+      unsigned slot = slots[other];
+      std::fill(slots_used.begin() + slot, slots_used.begin() + slot + other_rc.size(), true);
+   }
+}
+
+unsigned find_available_slot(std::vector<bool>& used, unsigned wave_size,
+                             unsigned size, bool is_sgpr, unsigned *num_slots)
+{
+   unsigned wave_size_minus_one = wave_size - 1;
+   unsigned slot = 0;
+
+   while (true) {
+      bool available = true;
+      for (unsigned i = 0; i < size; i++) {
+         if (slot + i < used.size() && used[slot + i]) {
+            available = false;
+            break;
+         }
+      }
+      if (!available) {
+         slot++;
+         continue;
+      }
+
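+      /* sgpr spill slots are lanes of a linear VGPR, so a multi-dword slot
+       * must not straddle a wave_size boundary */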
+      if (is_sgpr && ((slot & wave_size_minus_one) > wave_size - size)) {
+         slot = align(slot, wave_size);
+         continue;
+      }
+
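+      /* the caller re-marks interferences for each spill id, so clear the
+       * bitmap and only keep its size as the total slot count */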
+      std::fill(used.begin(), used.end(), false);
+
+      if (slot + size > used.size())
+         used.resize(slot + size);
+
+      return slot;
+   }
+}
+
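+/* assign spill slots to all reloaded spill ids of the given register type,
+ * handling affinity groups first so they share a single slot */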
+void assign_spill_slots_helper(spill_ctx& ctx, RegType type,
+                               std::vector<bool>& is_assigned,
+                               std::vector<uint32_t>& slots,
+                               unsigned *num_slots)
+{
+   std::vector<bool> slots_used(*num_slots);
+
+   /* assign slots for ids with affinities first */
+   for (std::vector<uint32_t>& vec : ctx.affinities) {
+      if (ctx.interferences[vec[0]].first.type() != type)
+         continue;
+
+      for (unsigned id : vec) {
+         if (!ctx.is_reloaded[id])
+            continue;
+
+         add_interferences(ctx, is_assigned, slots, slots_used, id);
+      }
+
+      unsigned slot = find_available_slot(slots_used, ctx.wave_size,
+                                          ctx.interferences[vec[0]].first.size(),
+                                          type == RegType::sgpr, num_slots);
+
+      for (unsigned id : vec) {
+         assert(!is_assigned[id]);
+
+         if (ctx.is_reloaded[id]) {
+            slots[id] = slot;
+            is_assigned[id] = true;
+         }
+      }
+   }
+
+   /* assign slots for ids without affinities */
+   for (unsigned id = 0; id < ctx.interferences.size(); id++) {
+      if (is_assigned[id] || !ctx.is_reloaded[id] || ctx.interferences[id].first.type() != type)
+         continue;
+
+      add_interferences(ctx, is_assigned, slots, slots_used, id);
+
+      unsigned slot = find_available_slot(slots_used, ctx.wave_size,
+                                          ctx.interferences[id].first.size(),
+                                          type == RegType::sgpr, num_slots);
+
+      slots[id] = slot;
+      is_assigned[id] = true;
+   }
+
+   *num_slots = slots_used.size();
+}
+
 void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
-   std::map<uint32_t, uint32_t> sgpr_slot;
-   std::map<uint32_t, uint32_t> vgpr_slot;
+   std::vector<uint32_t> slots(ctx.interferences.size());
    std::vector<bool> is_assigned(ctx.interferences.size());
 
-   /* first, handle affinities: just merge all interferences into both spill ids */
+   /* first, handle affinities: merge the reload flags of spill ids with affinity */
@@ -1254,13 +1465,6 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
       for (unsigned i = 0; i < vec.size(); i++) {
          for (unsigned j = i + 1; j < vec.size(); j++) {
             assert(vec[i] != vec[j]);
-            for (uint32_t id : ctx.interferences[vec[i]].second)
-               ctx.interferences[id].second.insert(vec[j]);
-            for (uint32_t id : ctx.interferences[vec[j]].second)
-               ctx.interferences[id].second.insert(vec[i]);
-            ctx.interferences[vec[i]].second.insert(ctx.interferences[vec[j]].second.begin(), ctx.interferences[vec[j]].second.end());
-            ctx.interferences[vec[j]].second.insert(ctx.interferences[vec[i]].second.begin(), ctx.interferences[vec[i]].second.end());
-
             bool reloaded = ctx.is_reloaded[vec[i]] || ctx.is_reloaded[vec[j]];
             ctx.is_reloaded[vec[i]] = reloaded;
             ctx.is_reloaded[vec[j]] = reloaded;
@@ -1272,96 +1476,9 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
          assert(i != id);
 
    /* for each spill slot, assign as many spill ids as possible */
-   std::vector<std::set<uint32_t>> spill_slot_interferences;
-   unsigned slot_idx = 0;
-   bool done = false;
-
-   /* assign sgpr spill slots */
-   while (!done) {
-      done = true;
-      for (unsigned id = 0; id < ctx.interferences.size(); id++) {
-         if (is_assigned[id] || !ctx.is_reloaded[id])
-            continue;
-         if (ctx.interferences[id].first.type() != RegType::sgpr)
-            continue;
-
-         /* check interferences */
-         bool interferes = false;
-         for (unsigned i = slot_idx; i < slot_idx + ctx.interferences[id].first.size(); i++) {
-            if (i == spill_slot_interferences.size())
-               spill_slot_interferences.emplace_back(std::set<uint32_t>());
-            if (spill_slot_interferences[i].find(id) != spill_slot_interferences[i].end() || i / 64 != slot_idx / 64) {
-               interferes = true;
-               break;
-            }
-         }
-         if (interferes) {
-            done = false;
-            continue;
-         }
-
-         /* we found a spill id which can be assigned to current spill slot */
-         sgpr_slot[id] = slot_idx;
-         is_assigned[id] = true;
-         for (unsigned i = slot_idx; i < slot_idx + ctx.interferences[id].first.size(); i++)
-            spill_slot_interferences[i].insert(ctx.interferences[id].second.begin(), ctx.interferences[id].second.end());
-
-         /* add all affinities: there are no additional interferences */
-         for (std::vector<uint32_t>& vec : ctx.affinities) {
-            bool found_affinity = false;
-            for (uint32_t entry : vec) {
-               if (entry == id) {
-                  found_affinity = true;
-                  break;
-               }
-            }
-            if (!found_affinity)
-               continue;
-            for (uint32_t entry : vec) {
-               sgpr_slot[entry] = slot_idx;
-               is_assigned[entry] = true;
-            }
-         }
-      }
-      slot_idx++;
-   }
-
-   slot_idx = 0;
-   done = false;
-
-   /* assign vgpr spill slots */
-   while (!done) {
-      done = true;
-      for (unsigned id = 0; id < ctx.interferences.size(); id++) {
-         if (is_assigned[id] || !ctx.is_reloaded[id])
-            continue;
-         if (ctx.interferences[id].first.type() != RegType::vgpr)
-            continue;
-
-         /* check interferences */
-         bool interferes = false;
-         for (unsigned i = slot_idx; i < slot_idx + ctx.interferences[id].first.size(); i++) {
-            if (i == spill_slot_interferences.size())
-               spill_slot_interferences.emplace_back(std::set<uint32_t>());
-            /* check for interference and ensure that vector regs are stored next to each other */
-            if (spill_slot_interferences[i].find(id) != spill_slot_interferences[i].end() || i / 64 != slot_idx / 64) {
-               interferes = true;
-               break;
-            }
-         }
-         if (interferes) {
-            done = false;
-            continue;
-         }
-
-         /* we found a spill id which can be assigned to current spill slot */
-         vgpr_slot[id] = slot_idx;
-         is_assigned[id] = true;
-         for (unsigned i = slot_idx; i < slot_idx + ctx.interferences[id].first.size(); i++)
-            spill_slot_interferences[i].insert(ctx.interferences[id].second.begin(), ctx.interferences[id].second.end());
-      }
-      slot_idx++;
-   }
+   unsigned sgpr_spill_slots = 0, vgpr_spill_slots = 0;
+   assign_spill_slots_helper(ctx, RegType::sgpr, is_assigned, slots, &sgpr_spill_slots);
+   assign_spill_slots_helper(ctx, RegType::vgpr, is_assigned, slots, &vgpr_spill_slots);
 
    for (unsigned id = 0; id < is_assigned.size(); id++)
       assert(is_assigned[id] || !ctx.is_reloaded[id]);
@@ -1374,19 +1491,17 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
                continue;
             assert(ctx.is_reloaded[vec[i]] == ctx.is_reloaded[vec[j]]);
             assert(ctx.interferences[vec[i]].first.type() == ctx.interferences[vec[j]].first.type());
-            if (ctx.interferences[vec[i]].first.type() == RegType::sgpr)
-               assert(sgpr_slot[vec[i]] == sgpr_slot[vec[j]]);
-            else
-               assert(vgpr_slot[vec[i]] == vgpr_slot[vec[j]]);
+            assert(slots[vec[i]] == slots[vec[j]]);
          }
       }
    }
 
    /* hope, we didn't mess up */
-   std::vector<Temp> vgpr_spill_temps((spill_slot_interferences.size() + 63) / 64);
+   std::vector<Temp> vgpr_spill_temps((sgpr_spill_slots + ctx.wave_size - 1) / ctx.wave_size);
    assert(vgpr_spill_temps.size() <= spills_to_vgpr);
 
    /* replace pseudo instructions with actual hardware instructions */
+   Temp scratch_offset = ctx.program->scratch_offset, scratch_rsrc = Temp();
    unsigned last_top_level_block_idx = 0;
    std::vector<bool> reload_in_loop(vgpr_spill_temps.size());
    for (Block& block : ctx.program->blocks) {
@@ -1426,8 +1541,8 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
             bool can_destroy = true;
             for (std::pair<Temp, uint32_t> pair : ctx.spills_exit[block.linear_preds[0]]) {
 
-               if (sgpr_slot.find(pair.second) != sgpr_slot.end() &&
-                   sgpr_slot[pair.second] / 64 == i) {
+               if (ctx.interferences[pair.second].first.type() == RegType::sgpr &&
+                   slots[pair.second] / ctx.wave_size == i) {
                   can_destroy = false;
                   break;
                }
@@ -1440,6 +1555,7 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
       std::vector<aco_ptr<Instruction>>::iterator it;
       std::vector<aco_ptr<Instruction>> instructions;
       instructions.reserve(block.instructions.size());
+      Builder bld(ctx.program, &instructions);
       for (it = block.instructions.begin(); it != block.instructions.end(); ++it) {
 
          if ((*it)->opcode == aco_opcode::p_spill) {
@@ -1447,20 +1563,53 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
 
             if (!ctx.is_reloaded[spill_id]) {
                /* never reloaded, so don't spill */
-            } else if (vgpr_slot.find(spill_id) != vgpr_slot.end()) {
+            } else if (!is_assigned[spill_id]) {
+               unreachable("No spill slot assigned for spill id");
+            } else if (ctx.interferences[spill_id].first.type() == RegType::vgpr) {
                /* spill vgpr */
                ctx.program->config->spilled_vgprs += (*it)->operands[0].size();
+               uint32_t spill_slot = slots[spill_id];
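+               /* the MUBUF immediate offset is limited to 12 bits, so fold
+                * larger base offsets into the scratch sgpr offset instead */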
+               bool add_offset_to_sgpr = ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size + vgpr_spill_slots * 4 > 4096;
+               unsigned base_offset = add_offset_to_sgpr ? 0 : ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size;
+
+               /* check if the scratch resource descriptor already exists */
+               if (scratch_rsrc == Temp()) {
+                  unsigned offset = add_offset_to_sgpr ? ctx.program->config->scratch_bytes_per_wave : 0;
+                  scratch_rsrc = load_scratch_resource(ctx, scratch_offset,
+                                                       last_top_level_block_idx == block.index ?
+                                                       instructions : ctx.program->blocks[last_top_level_block_idx].instructions,
+                                                       offset,
+                                                       last_top_level_block_idx == block.index);
+               }
 
-               assert(false && "vgpr spilling not yet implemented.");
-            } else if (sgpr_slot.find(spill_id) != sgpr_slot.end()) {
+               unsigned offset = base_offset + spill_slot * 4;
+               aco_opcode opcode = aco_opcode::buffer_store_dword;
+               assert((*it)->operands[0].isTemp());
+               Temp temp = (*it)->operands[0].getTemp();
+               assert(temp.type() == RegType::vgpr && !temp.is_linear());
+               if (temp.size() > 1) {
+                  Instruction* split{create_instruction<Pseudo_instruction>(aco_opcode::p_split_vector, Format::PSEUDO, 1, temp.size())};
+                  split->operands[0] = Operand(temp);
+                  for (unsigned i = 0; i < temp.size(); i++)
+                     split->definitions[i] = bld.def(v1);
+                  bld.insert(split);
+                  for (unsigned i = 0; i < temp.size(); i++) {
+                     Instruction *instr = bld.mubuf(opcode, scratch_rsrc, Operand(v1), scratch_offset, split->definitions[i].getTemp(), offset + i * 4, false, true);
+                     static_cast<MUBUF_instruction *>(instr)->sync = memory_sync_info(storage_vgpr_spill, semantic_private);
+                  }
+               } else {
+                  Instruction *instr = bld.mubuf(opcode, scratch_rsrc, Operand(v1), scratch_offset, temp, offset, false, true);
+                  static_cast<MUBUF_instruction *>(instr)->sync = memory_sync_info(storage_vgpr_spill, semantic_private);
+               }
+            } else {
                ctx.program->config->spilled_sgprs += (*it)->operands[0].size();
 
-               uint32_t spill_slot = sgpr_slot[spill_id];
+               uint32_t spill_slot = slots[spill_id];
 
                /* check if the linear vgpr already exists */
-               if (vgpr_spill_temps[spill_slot / 64] == Temp()) {
+               if (vgpr_spill_temps[spill_slot / ctx.wave_size] == Temp()) {
                   Temp linear_vgpr = {ctx.program->allocateId(), v1.as_linear()};
-                  vgpr_spill_temps[spill_slot / 64] = linear_vgpr;
+                  vgpr_spill_temps[spill_slot / ctx.wave_size] = linear_vgpr;
                   aco_ptr<Pseudo_instruction> create{create_instruction<Pseudo_instruction>(aco_opcode::p_start_linear_vgpr, Format::PSEUDO, 0, 1)};
                   create->definitions[0] = Definition(linear_vgpr);
                   /* find the right place to insert this definition */
@@ -1477,30 +1626,59 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
 
                /* spill sgpr: just add the vgpr temp to operands */
                Pseudo_instruction* spill = create_instruction<Pseudo_instruction>(aco_opcode::p_spill, Format::PSEUDO, 3, 0);
-               spill->operands[0] = Operand(vgpr_spill_temps[spill_slot / 64]);
-               spill->operands[1] = Operand(spill_slot % 64);
+               spill->operands[0] = Operand(vgpr_spill_temps[spill_slot / ctx.wave_size]);
+               spill->operands[1] = Operand(spill_slot % ctx.wave_size);
                spill->operands[2] = (*it)->operands[0];
                instructions.emplace_back(aco_ptr<Instruction>(spill));
-            } else {
-               unreachable("No spill slot assigned for spill id");
             }
 
          } else if ((*it)->opcode == aco_opcode::p_reload) {
             uint32_t spill_id = (*it)->operands[0].constantValue();
             assert(ctx.is_reloaded[spill_id]);
 
-            if (vgpr_slot.find(spill_id) != vgpr_slot.end()) {
+            if (!is_assigned[spill_id]) {
+               unreachable("No spill slot assigned for spill id");
+            } else if (ctx.interferences[spill_id].first.type() == RegType::vgpr) {
                /* reload vgpr */
-               assert(false && "vgpr spilling not yet implemented.");
+               uint32_t spill_slot = slots[spill_id];
+               bool add_offset_to_sgpr = ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size + vgpr_spill_slots * 4 > 4096;
+               unsigned base_offset = add_offset_to_sgpr ? 0 : ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size;
+
+               /* check if the scratch resource descriptor already exists */
+               if (scratch_rsrc == Temp()) {
+                  unsigned offset = add_offset_to_sgpr ? ctx.program->config->scratch_bytes_per_wave : 0;
+                  scratch_rsrc = load_scratch_resource(ctx, scratch_offset,
+                                                       last_top_level_block_idx == block.index ?
+                                                       instructions : ctx.program->blocks[last_top_level_block_idx].instructions,
+                                                       offset,
+                                                       last_top_level_block_idx == block.index);
+               }
 
-            } else if (sgpr_slot.find(spill_id) != sgpr_slot.end()) {
-               uint32_t spill_slot = sgpr_slot[spill_id];
-               reload_in_loop[spill_slot / 64] = block.loop_nest_depth > 0;
+               unsigned offset = base_offset + spill_slot * 4;
+               aco_opcode opcode = aco_opcode::buffer_load_dword;
+               Definition def = (*it)->definitions[0];
+               if (def.size() > 1) {
+                  Instruction* vec{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, def.size(), 1)};
+                  vec->definitions[0] = def;
+                  for (unsigned i = 0; i < def.size(); i++) {
+                     Temp tmp = bld.tmp(v1);
+                     vec->operands[i] = Operand(tmp);
+                     Instruction *instr = bld.mubuf(opcode, Definition(tmp), scratch_rsrc, Operand(v1), scratch_offset, offset + i * 4, false, true);
+                     static_cast<MUBUF_instruction *>(instr)->sync = memory_sync_info(storage_vgpr_spill, semantic_private);
+                  }
+                  bld.insert(vec);
+               } else {
+                  Instruction *instr = bld.mubuf(opcode, def, scratch_rsrc, Operand(v1), scratch_offset, offset, false, true);
+                  static_cast<MUBUF_instruction *>(instr)->sync = memory_sync_info(storage_vgpr_spill, semantic_private);
+               }
+            } else {
+               uint32_t spill_slot = slots[spill_id];
+               reload_in_loop[spill_slot / ctx.wave_size] = block.loop_nest_depth > 0;
 
                /* check if the linear vgpr already exists */
-               if (vgpr_spill_temps[spill_slot / 64] == Temp()) {
+               if (vgpr_spill_temps[spill_slot / ctx.wave_size] == Temp()) {
                   Temp linear_vgpr = {ctx.program->allocateId(), v1.as_linear()};
-                  vgpr_spill_temps[spill_slot / 64] = linear_vgpr;
+                  vgpr_spill_temps[spill_slot / ctx.wave_size] = linear_vgpr;
                   aco_ptr<Pseudo_instruction> create{create_instruction<Pseudo_instruction>(aco_opcode::p_start_linear_vgpr, Format::PSEUDO, 0, 1)};
                   create->definitions[0] = Definition(linear_vgpr);
                   /* find the right place to insert this definition */
@@ -1517,12 +1695,10 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
 
                /* reload sgpr: just add the vgpr temp to operands */
                Pseudo_instruction* reload = create_instruction<Pseudo_instruction>(aco_opcode::p_reload, Format::PSEUDO, 2, 1);
-               reload->operands[0] = Operand(vgpr_spill_temps[spill_slot / 64]);
-               reload->operands[1] = Operand(spill_slot % 64);
+               reload->operands[0] = Operand(vgpr_spill_temps[spill_slot / ctx.wave_size]);
+               reload->operands[1] = Operand(spill_slot % ctx.wave_size);
                reload->definitions[0] = (*it)->definitions[0];
                instructions.emplace_back(aco_ptr<Instruction>(reload));
-            } else {
-               unreachable("No spill slot assigned for spill id");
             }
          } else if (!ctx.remat_used.count(it->get()) || ctx.remat_used[it->get()]) {
             instructions.emplace_back(std::move(*it));
@@ -1532,6 +1708,9 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
       block.instructions = std::move(instructions);
    }
 
+   /* update required scratch memory */
+   ctx.program->config->scratch_bytes_per_wave += align(vgpr_spill_slots * 4 * ctx.program->wave_size, 1024);
+
    /* SSA elimination inserts copies for logical phis right before p_logical_end
     * So if a linear vgpr is used between that p_logical_end and the branch,
     * we need to ensure logical phis don't choose a definition which aliases
@@ -1602,18 +1781,17 @@ void spill(Program* program, live& live_vars, const struct radv_nir_compiler_opt
    /* calculate target register demand */
    RegisterDemand register_target = program->max_reg_demand;
    if (register_target.sgpr > program->sgpr_limit)
-      register_target.vgpr += (register_target.sgpr - program->sgpr_limit + 63 + 32) / 64;
+      register_target.vgpr += (register_target.sgpr - program->sgpr_limit + program->wave_size - 1 + 32) / program->wave_size;
    register_target.sgpr = program->sgpr_limit;
 
    if (register_target.vgpr > program->vgpr_limit)
       register_target.sgpr = program->sgpr_limit - 5;
-   register_target.vgpr = program->vgpr_limit - (register_target.vgpr - program->max_reg_demand.vgpr);
-
-   int spills_to_vgpr = (program->max_reg_demand.sgpr - register_target.sgpr + 63 + 32) / 64;
+   int spills_to_vgpr = (program->max_reg_demand.sgpr - register_target.sgpr + program->wave_size - 1 + 32) / program->wave_size;
+   register_target.vgpr = program->vgpr_limit - spills_to_vgpr;
 
    /* initialize ctx */
    spill_ctx ctx(register_target, program, live_vars.register_demand);
-   compute_global_next_uses(ctx, live_vars.live_out);
+   compute_global_next_uses(ctx);
    get_rematerialize_info(ctx);
 
    /* create spills and reloads */
@@ -1626,7 +1804,7 @@ void spill(Program* program, live& live_vars, const struct radv_nir_compiler_opt
    /* update live variable information */
    live_vars = live_var_analysis(program, options);
 
-   assert(program->num_waves >= 0);
+   assert(program->num_waves > 0);
 }
 
 }