aco: increase accuracy of SGPR limits
[mesa.git] / src / amd / compiler / aco_live_var_analysis.cpp
index f99e57c8b3a5ff0fc32765f809003edf6d60c8ba..3fe413256e75bbbdfbf8aa176ac1b490935be38e 100644 (file)
@@ -28,6 +28,7 @@
  */
 
 #include "aco_ir.h"
+#include "util/u_math.h"
 
 #include <set>
 #include <vector>
@@ -190,25 +191,62 @@ void process_live_temps_per_block(Program *program, live& lives, Block* block,
 }
 } /* end namespace */
 
+uint16_t get_extra_sgprs(Program *program)
+{
+   if (program->chip_class >= GFX10) {
+      assert(!program->needs_flat_scr);
+      assert(!program->needs_xnack_mask);
+      return 2;
+   } else if (program->chip_class >= GFX8) {
+      if (program->needs_flat_scr)
+         return 6;
+      else if (program->needs_xnack_mask)
+         return 4;
+      else if (program->needs_vcc)
+         return 2;
+      else
+         return 0;
+   } else {
+      assert(!program->needs_xnack_mask);
+      if (program->needs_flat_scr)
+         return 4;
+      else if (program->needs_vcc)
+         return 2;
+      else
+         return 0;
+   }
+}
+
+uint16_t get_sgpr_alloc(Program *program, uint16_t addressable_sgprs)
+{
+   assert(addressable_sgprs <= program->sgpr_limit);
+   uint16_t sgprs = addressable_sgprs + get_extra_sgprs(program);
+   uint16_t granule = program->sgpr_alloc_granule + 1;
+   return align(std::max(sgprs, granule), granule);
+}
+
+uint16_t get_addr_sgpr_from_waves(Program *program, uint16_t max_waves)
+{
+    uint16_t sgprs = program->physical_sgprs / max_waves & ~program->sgpr_alloc_granule;
+    sgprs -= get_extra_sgprs(program);
+    return std::min(sgprs, program->sgpr_limit);
+}
+
 void update_vgpr_sgpr_demand(Program* program, const RegisterDemand new_demand)
 {
    // TODO: also take shared mem into account
-   const int16_t total_sgpr_regs = program->chip_class >= GFX8 ? 800 : 512;
-   const int16_t max_addressible_sgpr = program->sgpr_limit;
-   /* VGPRs are allocated in chunks of 4 */
-   const int16_t rounded_vgpr_demand = std::max<int16_t>(4, (new_demand.vgpr + 3) & ~3);
-   /* SGPRs are allocated in chunks of 16 between 8 and 104. VCC occupies the last 2 registers */
-   const int16_t rounded_sgpr_demand = std::min(std::max<int16_t>(8, (new_demand.sgpr + 2 + 7) & ~7), max_addressible_sgpr);
+   const int16_t vgpr_alloc = std::max<int16_t>(4, (new_demand.vgpr + 3) & ~3);
    /* this won't compile, register pressure reduction necessary */
-   if (new_demand.vgpr > 256 || new_demand.sgpr > max_addressible_sgpr) {
+   if (new_demand.vgpr > 256 || new_demand.sgpr > program->sgpr_limit) {
       program->num_waves = 0;
       program->max_reg_demand = new_demand;
    } else {
-      program->num_waves = std::min<uint16_t>(10,
-                                              std::min<uint16_t>(256 / rounded_vgpr_demand,
-                                                                 total_sgpr_regs / rounded_sgpr_demand));
+      program->num_waves = program->physical_sgprs / get_sgpr_alloc(program, new_demand.sgpr);
+      program->num_waves = std::min<uint16_t>(program->num_waves, 256 / vgpr_alloc);
+      program->num_waves = std::min<uint16_t>(program->num_waves, 10);
 
-      program->max_reg_demand = {  int16_t((256 / program->num_waves) & ~3), std::min<int16_t>(((total_sgpr_regs / program->num_waves) & ~7) - 2, max_addressible_sgpr)};
+      program->max_reg_demand.vgpr = int16_t((256 / program->num_waves) & ~3);
+      program->max_reg_demand.sgpr = get_addr_sgpr_from_waves(program, program->num_waves);
    }
 }