aco: optimize 16-bit and 64-bit float comparisons
[mesa.git] / src / amd / compiler / aco_spill.cpp
index 1f3f5ea3b52508c34649ad0411530dde81c60e0f..54b84488a0a4cba02a4f403a7aac98aaf25a4b97 100644 (file)
@@ -28,6 +28,7 @@
 #include "sid.h"
 
 #include <map>
+#include <set>
 #include <stack>
 
 /*
@@ -60,13 +61,14 @@ struct spill_ctx {
    std::vector<bool> is_reloaded;
    std::map<Temp, remat_info> remat;
    std::map<Instruction *, bool> remat_used;
+   unsigned wave_size;
 
    spill_ctx(const RegisterDemand target_pressure, Program* program,
              std::vector<std::vector<RegisterDemand>> register_demand)
       : target_pressure(target_pressure), program(program),
-        register_demand(register_demand), renames(program->blocks.size()),
+        register_demand(std::move(register_demand)), renames(program->blocks.size()),
         spills_entry(program->blocks.size()), spills_exit(program->blocks.size()),
-        processed(program->blocks.size(), false) {}
+        processed(program->blocks.size(), false), wave_size(program->wave_size) {}
 
    void add_affinity(uint32_t first, uint32_t second)
    {
@@ -212,7 +214,7 @@ void next_uses_per_block(spill_ctx& ctx, unsigned block_idx, std::set<uint32_t>&
 
 }
 
-void compute_global_next_uses(spill_ctx& ctx, std::vector<std::set<Temp>>& live_out)
+void compute_global_next_uses(spill_ctx& ctx)
 {
    ctx.next_use_distances_start.resize(ctx.program->blocks.size());
    ctx.next_use_distances_end.resize(ctx.program->blocks.size());
@@ -231,11 +233,13 @@ void compute_global_next_uses(spill_ctx& ctx, std::vector<std::set<Temp>>& live_
 bool should_rematerialize(aco_ptr<Instruction>& instr)
 {
    /* TODO: rematerialization is only supported for VOP1, SOP1 and PSEUDO */
-   if (instr->format != Format::VOP1 && instr->format != Format::SOP1 && instr->format != Format::PSEUDO)
+   if (instr->format != Format::VOP1 && instr->format != Format::SOP1 && instr->format != Format::PSEUDO && instr->format != Format::SOPK)
       return false;
    /* TODO: pseudo-instruction rematerialization is only supported for p_create_vector */
    if (instr->format == Format::PSEUDO && instr->opcode != aco_opcode::p_create_vector)
       return false;
+   if (instr->format == Format::SOPK && instr->opcode != aco_opcode::s_movk_i32)
+      return false;
 
    for (const Operand& op : instr->operands) {
       /* TODO: rematerialization using temporaries isn't yet supported */
@@ -255,7 +259,7 @@ aco_ptr<Instruction> do_reload(spill_ctx& ctx, Temp tmp, Temp new_name, uint32_t
    std::map<Temp, remat_info>::iterator remat = ctx.remat.find(tmp);
    if (remat != ctx.remat.end()) {
       Instruction *instr = remat->second.instr;
-      assert((instr->format == Format::VOP1 || instr->format == Format::SOP1 || instr->format == Format::PSEUDO) && "unsupported");
+      assert((instr->format == Format::VOP1 || instr->format == Format::SOP1 || instr->format == Format::PSEUDO || instr->format == Format::SOPK) && "unsupported");
       assert((instr->format != Format::PSEUDO || instr->opcode == aco_opcode::p_create_vector) && "unsupported");
       assert(instr->definitions.size() == 1 && "unsupported");
 
@@ -265,7 +269,10 @@ aco_ptr<Instruction> do_reload(spill_ctx& ctx, Temp tmp, Temp new_name, uint32_t
       } else if (instr->format == Format::SOP1) {
          res.reset(create_instruction<SOP1_instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
       } else if (instr->format == Format::PSEUDO) {
-         res.reset(create_instruction<Instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
+         res.reset(create_instruction<Pseudo_instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
+      } else if (instr->format == Format::SOPK) {
+         res.reset(create_instruction<SOPK_instruction>(instr->opcode, instr->format, instr->operands.size(), instr->definitions.size()));
+         static_cast<SOPK_instruction*>(res.get())->imm = static_cast<SOPK_instruction*>(instr)->imm;
       }
       for (unsigned i = 0; i < instr->operands.size(); i++) {
          res->operands[i] = instr->operands[i];
@@ -651,6 +658,18 @@ RegisterDemand init_live_in_vars(spill_ctx& ctx, Block* block, unsigned block_id
 }
 
 
+RegisterDemand get_demand_before(spill_ctx& ctx, unsigned block_idx, unsigned idx)
+{
+   if (idx == 0) {
+      RegisterDemand demand = ctx.register_demand[block_idx][idx];
+      aco_ptr<Instruction>& instr = ctx.program->blocks[block_idx].instructions[idx];
+      aco_ptr<Instruction> instr_before(nullptr);
+      return get_demand_before(demand, instr, instr_before);
+   } else {
+      return ctx.register_demand[block_idx][idx - 1];
+   }
+}
+
 void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
 {
    /* no coupling code necessary */
@@ -659,12 +678,13 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
 
    std::vector<aco_ptr<Instruction>> instructions;
    /* branch block: TODO take other branch into consideration */
-   if (block->linear_preds.size() == 1 && !(block->kind & block_kind_loop_exit)) {
+   if (block->linear_preds.size() == 1 && !(block->kind & (block_kind_loop_exit | block_kind_loop_header))) {
       assert(ctx.processed[block->linear_preds[0]]);
       assert(ctx.register_demand[block_idx].size() == block->instructions.size());
       std::vector<RegisterDemand> reg_demand;
       unsigned insert_idx = 0;
       unsigned pred_idx = block->linear_preds[0];
+      RegisterDemand demand_before = get_demand_before(ctx, block_idx, 0);
 
       for (std::pair<Temp, std::pair<uint32_t, uint32_t>> live : ctx.next_use_distances_start[block_idx]) {
          if (!live.first.is_linear())
@@ -685,7 +705,7 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
          Temp new_name = {ctx.program->allocateId(), live.first.regClass()};
          aco_ptr<Instruction> reload = do_reload(ctx, live.first, new_name, ctx.spills_exit[pred_idx][live.first]);
          instructions.emplace_back(std::move(reload));
-         reg_demand.push_back(RegisterDemand());
+         reg_demand.push_back(demand_before);
          ctx.renames[block_idx][live.first] = new_name;
       }
 
@@ -983,8 +1003,12 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
       idx++;
    }
 
-   ctx.register_demand[block->index].erase(ctx.register_demand[block->index].begin(), ctx.register_demand[block->index].begin() + idx);
-   ctx.register_demand[block->index].insert(ctx.register_demand[block->index].begin(), instructions.size(), RegisterDemand());
+   if (!ctx.processed[block_idx]) {
+      assert(!(block->kind & block_kind_loop_header));
+      RegisterDemand demand_before = get_demand_before(ctx, block_idx, idx);
+      ctx.register_demand[block->index].erase(ctx.register_demand[block->index].begin(), ctx.register_demand[block->index].begin() + idx);
+      ctx.register_demand[block->index].insert(ctx.register_demand[block->index].begin(), instructions.size(), demand_before);
+   }
 
    std::vector<aco_ptr<Instruction>>::iterator start = std::next(block->instructions.begin(), idx);
    instructions.insert(instructions.end(), std::move_iterator<std::vector<aco_ptr<Instruction>>::iterator>(start),
@@ -995,6 +1019,8 @@ void add_coupling_code(spill_ctx& ctx, Block* block, unsigned block_idx)
 void process_block(spill_ctx& ctx, unsigned block_idx, Block* block,
                    std::map<Temp, uint32_t> &current_spills, RegisterDemand spilled_registers)
 {
+   assert(!ctx.processed[block_idx]);
+
    std::vector<std::map<Temp, uint32_t>> local_next_use_distance;
    std::vector<aco_ptr<Instruction>> instructions;
    unsigned idx = 0;
@@ -1046,18 +1072,7 @@ void process_block(spill_ctx& ctx, unsigned block_idx, Block* block,
       if (block->register_demand.exceeds(ctx.target_pressure)) {
 
          RegisterDemand new_demand = ctx.register_demand[block_idx][idx];
-         if (idx == 0) {
-            RegisterDemand demand_before = new_demand;
-            for (const Definition& def : instr->definitions)
-               demand_before -= def.getTemp();
-            for (const Operand& op : instr->operands) {
-               if (op.isFirstKill())
-                  demand_before += op.getTemp();
-            }
-            new_demand.update(demand_before);
-         } else {
-            new_demand.update(ctx.register_demand[block_idx][idx - 1]);
-         }
+         new_demand.update(get_demand_before(ctx, block_idx, idx));
 
          assert(!local_next_use_distance.empty());
 
@@ -1139,7 +1154,6 @@ void process_block(spill_ctx& ctx, unsigned block_idx, Block* block,
 void spill_block(spill_ctx& ctx, unsigned block_idx)
 {
    Block* block = &ctx.program->blocks[block_idx];
-   ctx.processed[block_idx] = true;
 
    /* determine set of variables which are spilled at the beginning of the block */
    RegisterDemand spilled_registers = init_live_in_vars(ctx, block, block_idx);
@@ -1176,6 +1190,8 @@ void spill_block(spill_ctx& ctx, unsigned block_idx)
    else
       ctx.spills_exit[block_idx].insert(current_spills.begin(), current_spills.end());
 
+   ctx.processed[block_idx] = true;
+
    /* check if the next block leaves the current loop */
    if (block->loop_nest_depth == 0 || ctx.program->blocks[block_idx + 1].loop_nest_depth >= block->loop_nest_depth)
       return;
@@ -1285,15 +1301,15 @@ Temp load_scratch_resource(spill_ctx& ctx, Temp& scratch_offset,
 
    if (ctx.program->chip_class >= GFX10) {
       rsrc_conf |= S_008F0C_FORMAT(V_008F0C_IMG_FORMAT_32_FLOAT) |
-                   S_008F0C_OOB_SELECT(3) |
+                   S_008F0C_OOB_SELECT(V_008F0C_OOB_SELECT_RAW) |
                    S_008F0C_RESOURCE_LEVEL(1);
    } else if (ctx.program->chip_class <= GFX7) { /* dfmt modifies stride on GFX8/GFX9 when ADD_TID_EN=1 */
       rsrc_conf |= S_008F0C_NUM_FORMAT(V_008F0C_BUF_NUM_FORMAT_FLOAT) |
                    S_008F0C_DATA_FORMAT(V_008F0C_BUF_DATA_FORMAT_32);
    }
-   /* older generations need element size = 16 bytes. element size removed in GFX9 */
+   /* older generations need element size = 4 bytes. element size removed in GFX9 */
    if (ctx.program->chip_class <= GFX8)
-      rsrc_conf |= S_008F0C_ELEMENT_SIZE(3);
+      rsrc_conf |= S_008F0C_ELEMENT_SIZE(1);
 
    return bld.pseudo(aco_opcode::p_create_vector, bld.def(s4),
                      private_segment_buffer, Operand(-1u),
@@ -1346,7 +1362,7 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
          for (unsigned i = slot_idx; i < slot_idx + ctx.interferences[id].first.size(); i++) {
             if (i == spill_slot_interferences.size())
                spill_slot_interferences.emplace_back(std::set<uint32_t>());
-            if (spill_slot_interferences[i].find(id) != spill_slot_interferences[i].end() || i / 64 != slot_idx / 64) {
+            if (spill_slot_interferences[i].find(id) != spill_slot_interferences[i].end() || i / ctx.wave_size != slot_idx / ctx.wave_size) {
                interferes = true;
                break;
             }
@@ -1460,7 +1476,7 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
    }
 
    /* hope, we didn't mess up */
-   std::vector<Temp> vgpr_spill_temps((sgpr_spill_slots + 63) / 64);
+   std::vector<Temp> vgpr_spill_temps((sgpr_spill_slots + ctx.wave_size - 1) / ctx.wave_size);
    assert(vgpr_spill_temps.size() <= spills_to_vgpr);
 
    /* replace pseudo instructions with actual hardware instructions */
@@ -1505,7 +1521,7 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
             for (std::pair<Temp, uint32_t> pair : ctx.spills_exit[block.linear_preds[0]]) {
 
                if (sgpr_slot.find(pair.second) != sgpr_slot.end() &&
-                   sgpr_slot[pair.second] / 64 == i) {
+                   sgpr_slot[pair.second] / ctx.wave_size == i) {
                   can_destroy = false;
                   break;
                }
@@ -1530,12 +1546,12 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
                /* spill vgpr */
                ctx.program->config->spilled_vgprs += (*it)->operands[0].size();
                uint32_t spill_slot = vgpr_slot[spill_id];
-               bool add_offset = ctx.program->config->scratch_bytes_per_wave + vgpr_spill_slots * 4 > 4096;
-               unsigned base_offset = add_offset ? 0 : ctx.program->config->scratch_bytes_per_wave;
+               bool add_offset_to_sgpr = ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size + vgpr_spill_slots * 4 > 4096;
+               unsigned base_offset = add_offset_to_sgpr ? 0 : ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size;
 
                /* check if the scratch resource descriptor already exists */
                if (scratch_rsrc == Temp()) {
-                  unsigned offset = ctx.program->config->scratch_bytes_per_wave - base_offset;
+                  unsigned offset = add_offset_to_sgpr ? ctx.program->config->scratch_bytes_per_wave : 0;
                   scratch_rsrc = load_scratch_resource(ctx, scratch_offset,
                                                        last_top_level_block_idx == block.index ?
                                                        instructions : ctx.program->blocks[last_top_level_block_idx].instructions,
@@ -1544,46 +1560,30 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
                }
 
                unsigned offset = base_offset + spill_slot * 4;
-               aco_opcode opcode;
+               aco_opcode opcode = aco_opcode::buffer_store_dword;
                assert((*it)->operands[0].isTemp());
                Temp temp = (*it)->operands[0].getTemp();
                assert(temp.type() == RegType::vgpr && !temp.is_linear());
-               switch (temp.size()) {
-               case 1: opcode = aco_opcode::buffer_store_dword; break;
-               case 2: opcode = aco_opcode::buffer_store_dwordx2; break;
-               case 6: temp = bld.tmp(v3); /* fallthrough */
-               case 3: opcode = aco_opcode::buffer_store_dwordx3; break;
-               case 8: temp = bld.tmp(v4); /* fallthrough */
-               case 4: opcode = aco_opcode::buffer_store_dwordx4; break;
-               default: {
+               if (temp.size() > 1) {
                   Instruction* split{create_instruction<Pseudo_instruction>(aco_opcode::p_split_vector, Format::PSEUDO, 1, temp.size())};
                   split->operands[0] = Operand(temp);
                   for (unsigned i = 0; i < temp.size(); i++)
                      split->definitions[i] = bld.def(v1);
                   bld.insert(split);
-                  opcode = aco_opcode::buffer_store_dword;
                   for (unsigned i = 0; i < temp.size(); i++)
-                     bld.mubuf(opcode, Operand(), scratch_rsrc, scratch_offset, split->definitions[i].getTemp(), offset + i * 4, false);
-                  continue;
-               }
+                     bld.mubuf(opcode, scratch_rsrc, Operand(), scratch_offset, split->definitions[i].getTemp(), offset + i * 4, false);
+               } else {
+                  bld.mubuf(opcode, scratch_rsrc, Operand(), scratch_offset, temp, offset, false);
                }
-
-               if ((*it)->operands[0].size() > 4) {
-                  Temp temp2 = bld.pseudo(aco_opcode::p_split_vector, bld.def(temp.regClass()), Definition(temp), (*it)->operands[0]);
-                  bld.mubuf(opcode, Operand(), scratch_rsrc, scratch_offset, temp2, offset, false);
-                  offset += temp.size() * 4;
-               }
-               bld.mubuf(opcode, Operand(), scratch_rsrc, scratch_offset, temp, offset, false);
-
             } else if (sgpr_slot.find(spill_id) != sgpr_slot.end()) {
                ctx.program->config->spilled_sgprs += (*it)->operands[0].size();
 
                uint32_t spill_slot = sgpr_slot[spill_id];
 
                /* check if the linear vgpr already exists */
-               if (vgpr_spill_temps[spill_slot / 64] == Temp()) {
+               if (vgpr_spill_temps[spill_slot / ctx.wave_size] == Temp()) {
                   Temp linear_vgpr = {ctx.program->allocateId(), v1.as_linear()};
-                  vgpr_spill_temps[spill_slot / 64] = linear_vgpr;
+                  vgpr_spill_temps[spill_slot / ctx.wave_size] = linear_vgpr;
                   aco_ptr<Pseudo_instruction> create{create_instruction<Pseudo_instruction>(aco_opcode::p_start_linear_vgpr, Format::PSEUDO, 0, 1)};
                   create->definitions[0] = Definition(linear_vgpr);
                   /* find the right place to insert this definition */
@@ -1600,8 +1600,8 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
 
                /* spill sgpr: just add the vgpr temp to operands */
                Pseudo_instruction* spill = create_instruction<Pseudo_instruction>(aco_opcode::p_spill, Format::PSEUDO, 3, 0);
-               spill->operands[0] = Operand(vgpr_spill_temps[spill_slot / 64]);
-               spill->operands[1] = Operand(spill_slot % 64);
+               spill->operands[0] = Operand(vgpr_spill_temps[spill_slot / ctx.wave_size]);
+               spill->operands[1] = Operand(spill_slot % ctx.wave_size);
                spill->operands[2] = (*it)->operands[0];
                instructions.emplace_back(aco_ptr<Instruction>(spill));
             } else {
@@ -1615,12 +1615,12 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
             if (vgpr_slot.find(spill_id) != vgpr_slot.end()) {
                /* reload vgpr */
                uint32_t spill_slot = vgpr_slot[spill_id];
-               bool add_offset = ctx.program->config->scratch_bytes_per_wave + vgpr_spill_slots * 4 > 4096;
-               unsigned base_offset = add_offset ? 0 : ctx.program->config->scratch_bytes_per_wave;
+               bool add_offset_to_sgpr = ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size + vgpr_spill_slots * 4 > 4096;
+               unsigned base_offset = add_offset_to_sgpr ? 0 : ctx.program->config->scratch_bytes_per_wave / ctx.program->wave_size;
 
                /* check if the scratch resource descriptor already exists */
                if (scratch_rsrc == Temp()) {
-                  unsigned offset = ctx.program->config->scratch_bytes_per_wave - base_offset;
+                  unsigned offset = add_offset_to_sgpr ? ctx.program->config->scratch_bytes_per_wave : 0;
                   scratch_rsrc = load_scratch_resource(ctx, scratch_offset,
                                                        last_top_level_block_idx == block.index ?
                                                        instructions : ctx.program->blocks[last_top_level_block_idx].instructions,
@@ -1629,43 +1629,28 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
                }
 
                unsigned offset = base_offset + spill_slot * 4;
-               aco_opcode opcode;
+               aco_opcode opcode = aco_opcode::buffer_load_dword;
                Definition def = (*it)->definitions[0];
-               switch (def.size()) {
-               case 1: opcode = aco_opcode::buffer_load_dword; break;
-               case 2: opcode = aco_opcode::buffer_load_dwordx2; break;
-               case 6: def = bld.def(v3); /* fallthrough */
-               case 3: opcode = aco_opcode::buffer_load_dwordx3; break;
-               case 8: def = bld.def(v4); /* fallthrough */
-               case 4: opcode = aco_opcode::buffer_load_dwordx4; break;
-               default: {
+               if (def.size() > 1) {
                   Instruction* vec{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, def.size(), 1)};
                   vec->definitions[0] = def;
-                  opcode = aco_opcode::buffer_load_dword;
                   for (unsigned i = 0; i < def.size(); i++) {
                      Temp tmp = bld.tmp(v1);
                      vec->operands[i] = Operand(tmp);
-                     bld.mubuf(opcode, Definition(tmp), Operand(), scratch_rsrc, scratch_offset, offset + i * 4, false);
+                     bld.mubuf(opcode, Definition(tmp), scratch_rsrc, Operand(), scratch_offset, offset + i * 4, false);
                   }
                   bld.insert(vec);
-                  continue;
-               }
-               }
-
-               bld.mubuf(opcode, def, Operand(), scratch_rsrc, scratch_offset, offset, false);
-               if ((*it)->definitions[0].size() > 4) {
-                  Temp temp2 = bld.mubuf(opcode, bld.def(def.regClass()), Operand(), scratch_rsrc, scratch_offset, offset + def.size() * 4, false);
-                  bld.pseudo(aco_opcode::p_create_vector, (*it)->definitions[0], def.getTemp(), temp2);
+               } else {
+                  bld.mubuf(opcode, def, scratch_rsrc, Operand(), scratch_offset, offset, false);
                }
-
             } else if (sgpr_slot.find(spill_id) != sgpr_slot.end()) {
                uint32_t spill_slot = sgpr_slot[spill_id];
-               reload_in_loop[spill_slot / 64] = block.loop_nest_depth > 0;
+               reload_in_loop[spill_slot / ctx.wave_size] = block.loop_nest_depth > 0;
 
                /* check if the linear vgpr already exists */
-               if (vgpr_spill_temps[spill_slot / 64] == Temp()) {
+               if (vgpr_spill_temps[spill_slot / ctx.wave_size] == Temp()) {
                   Temp linear_vgpr = {ctx.program->allocateId(), v1.as_linear()};
-                  vgpr_spill_temps[spill_slot / 64] = linear_vgpr;
+                  vgpr_spill_temps[spill_slot / ctx.wave_size] = linear_vgpr;
                   aco_ptr<Pseudo_instruction> create{create_instruction<Pseudo_instruction>(aco_opcode::p_start_linear_vgpr, Format::PSEUDO, 0, 1)};
                   create->definitions[0] = Definition(linear_vgpr);
                   /* find the right place to insert this definition */
@@ -1682,8 +1667,8 @@ void assign_spill_slots(spill_ctx& ctx, unsigned spills_to_vgpr) {
 
                /* reload sgpr: just add the vgpr temp to operands */
                Pseudo_instruction* reload = create_instruction<Pseudo_instruction>(aco_opcode::p_reload, Format::PSEUDO, 2, 1);
-               reload->operands[0] = Operand(vgpr_spill_temps[spill_slot / 64]);
-               reload->operands[1] = Operand(spill_slot % 64);
+               reload->operands[0] = Operand(vgpr_spill_temps[spill_slot / ctx.wave_size]);
+               reload->operands[1] = Operand(spill_slot % ctx.wave_size);
                reload->definitions[0] = (*it)->definitions[0];
                instructions.emplace_back(aco_ptr<Instruction>(reload));
             } else {
@@ -1770,18 +1755,17 @@ void spill(Program* program, live& live_vars, const struct radv_nir_compiler_opt
    /* calculate target register demand */
    RegisterDemand register_target = program->max_reg_demand;
    if (register_target.sgpr > program->sgpr_limit)
-      register_target.vgpr += (register_target.sgpr - program->sgpr_limit + 63 + 32) / 64;
+      register_target.vgpr += (register_target.sgpr - program->sgpr_limit + program->wave_size - 1 + 32) / program->wave_size;
    register_target.sgpr = program->sgpr_limit;
 
    if (register_target.vgpr > program->vgpr_limit)
       register_target.sgpr = program->sgpr_limit - 5;
-   register_target.vgpr = program->vgpr_limit - (register_target.vgpr - program->max_reg_demand.vgpr);
-
-   int spills_to_vgpr = (program->max_reg_demand.sgpr - register_target.sgpr + 63 + 32) / 64;
+   int spills_to_vgpr = (program->max_reg_demand.sgpr - register_target.sgpr + program->wave_size - 1 + 32) / program->wave_size;
+   register_target.vgpr = program->vgpr_limit - spills_to_vgpr;
 
    /* initialize ctx */
    spill_ctx ctx(register_target, program, live_vars.register_demand);
-   compute_global_next_uses(ctx, live_vars.live_out);
+   compute_global_next_uses(ctx);
    get_rematerialize_info(ctx);
 
    /* create spills and reloads */
@@ -1794,7 +1778,7 @@ void spill(Program* program, live& live_vars, const struct radv_nir_compiler_opt
    /* update live variable information */
    live_vars = live_var_analysis(program, options);
 
-   assert(program->num_waves >= 0);
+   assert(program->num_waves > 0);
 }
 
 }