aco: refactor get_reg() to take Temp instead of RegClass
authorDaniel Schürmann <daniel@schuermann.dev>
Fri, 10 Apr 2020 17:55:18 +0000 (18:55 +0100)
committerMarge Bot <eric+marge@anholt.net>
Wed, 22 Apr 2020 18:23:22 +0000 (18:23 +0000)
This patch also moves get_reg_specified() and
get_reg_vec() before get_reg() to make use of it later.

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4573>

src/amd/compiler/aco_register_allocation.cpp

index 1be80a6e8826183348eb206f5227e9b76bf05d2d..7dff6b98ff23c60c1242604d883bc9708d5149e0 100644 (file)
@@ -796,12 +796,75 @@ std::pair<PhysReg, bool> get_reg_impl(ra_ctx& ctx,
    return {PhysReg{best_pos}, true};
 }
 
+bool get_reg_specified(ra_ctx& ctx,
+                       RegisterFile& reg_file,
+                       RegClass rc,
+                       std::vector<std::pair<Operand, Definition>>& parallelcopies,
+                       aco_ptr<Instruction>& instr,
+                       PhysReg reg)
+{
+   uint32_t size = rc.size();
+   uint32_t stride = 1;
+   uint32_t lb, ub;
+
+   if (rc.type() == RegType::vgpr) {
+      lb = 256;
+      ub = 256 + ctx.program->max_reg_demand.vgpr;
+   } else {
+      if (size == 2)
+         stride = 2;
+      else if (size >= 4)
+         stride = 4;
+      if (reg % stride != 0)
+         return false;
+      lb = 0;
+      ub = ctx.program->max_reg_demand.sgpr;
+   }
+
+   if (rc.is_subdword() && reg.byte() && !instr_can_access_subdword(instr))
+      return false;
+
+   uint32_t reg_lo = reg.reg();
+   uint32_t reg_hi = reg + (size - 1);
+
+   if (reg_lo < lb || reg_hi >= ub || reg_lo > reg_hi)
+      return false;
+
+   if (reg_file.test(reg, rc.bytes()))
+      return false;
+
+   adjust_max_used_regs(ctx, rc, reg_lo);
+   return true;
+}
+
+std::pair<PhysReg, bool> get_reg_vec(ra_ctx& ctx,
+                                     RegisterFile& reg_file,
+                                     RegClass rc)
+{
+   uint32_t size = rc.size();
+   uint32_t stride = 1;
+   uint32_t lb, ub;
+   if (rc.type() == RegType::vgpr) {
+      lb = 256;
+      ub = 256 + ctx.program->max_reg_demand.vgpr;
+   } else {
+      lb = 0;
+      ub = ctx.program->max_reg_demand.sgpr;
+      if (size == 2)
+         stride = 2;
+      else if (size >= 4)
+         stride = 4;
+   }
+   return get_reg_simple(ctx, reg_file, lb, ub, size, stride, rc);
+}
+
 PhysReg get_reg(ra_ctx& ctx,
                 RegisterFile& reg_file,
-                RegClass rc,
+                Temp temp,
                 std::vector<std::pair<Operand, Definition>>& parallelcopies,
                 aco_ptr<Instruction>& instr)
 {
+   RegClass rc = temp.regClass();
    uint32_t size = rc.size();
    uint32_t stride = 1;
    uint32_t lb, ub;
@@ -852,10 +915,10 @@ PhysReg get_reg(ra_ctx& ctx,
    uint16_t max_addressible_vgpr = ctx.program->vgpr_limit;
    if (rc.type() == RegType::vgpr && ctx.program->max_reg_demand.vgpr < max_addressible_vgpr) {
       update_vgpr_sgpr_demand(ctx.program, RegisterDemand(ctx.program->max_reg_demand.vgpr + 1, ctx.program->max_reg_demand.sgpr));
-      return get_reg(ctx, reg_file, rc, parallelcopies, instr);
+      return get_reg(ctx, reg_file, temp, parallelcopies, instr);
    } else if (rc.type() == RegType::sgpr && ctx.program->max_reg_demand.sgpr < max_addressible_sgpr) {
       update_vgpr_sgpr_demand(ctx.program,  RegisterDemand(ctx.program->max_reg_demand.vgpr, ctx.program->max_reg_demand.sgpr + 1));
-      return get_reg(ctx, reg_file, rc, parallelcopies, instr);
+      return get_reg(ctx, reg_file, temp, parallelcopies, instr);
    }
 
    //FIXME: if nothing helps, shift-rotate the registers to make space
@@ -863,35 +926,13 @@ PhysReg get_reg(ra_ctx& ctx,
    unreachable("did not find a register");
 }
 
-
-std::pair<PhysReg, bool> get_reg_vec(ra_ctx& ctx,
-                                     RegisterFile& reg_file,
-                                     RegClass rc)
-{
-   uint32_t size = rc.size();
-   uint32_t stride = 1;
-   uint32_t lb, ub;
-   if (rc.type() == RegType::vgpr) {
-      lb = 256;
-      ub = 256 + ctx.program->max_reg_demand.vgpr;
-   } else {
-      lb = 0;
-      ub = ctx.program->max_reg_demand.sgpr;
-      if (size == 2)
-         stride = 2;
-      else if (size >= 4)
-         stride = 4;
-   }
-   return get_reg_simple(ctx, reg_file, lb, ub, size, stride, rc);
-}
-
-
 PhysReg get_reg_create_vector(ra_ctx& ctx,
                               RegisterFile& reg_file,
-                              RegClass rc,
+                              Temp temp,
                               std::vector<std::pair<Operand, Definition>>& parallelcopies,
                               aco_ptr<Instruction>& instr)
 {
+   RegClass rc = temp.regClass();
    /* create_vector instructions have different costs w.r.t. register coalescing */
    uint32_t size = rc.size();
    uint32_t bytes = rc.bytes();
@@ -986,7 +1027,7 @@ PhysReg get_reg_create_vector(ra_ctx& ctx,
    }
 
    if (num_moves >= bytes)
-      return get_reg(ctx, reg_file, rc, parallelcopies, instr);
+      return get_reg(ctx, reg_file, temp, parallelcopies, instr);
 
    /* collect variables to be moved */
    std::set<std::pair<unsigned, unsigned>> vars = collect_vars(ctx, reg_file, PhysReg{best_pos}, size);
@@ -1019,47 +1060,6 @@ PhysReg get_reg_create_vector(ra_ctx& ctx,
    return PhysReg{best_pos};
 }
 
-bool get_reg_specified(ra_ctx& ctx,
-                       RegisterFile& reg_file,
-                       RegClass rc,
-                       std::vector<std::pair<Operand, Definition>>& parallelcopies,
-                       aco_ptr<Instruction>& instr,
-                       PhysReg reg)
-{
-   uint32_t size = rc.size();
-   uint32_t stride = 1;
-   uint32_t lb, ub;
-
-   if (rc.type() == RegType::vgpr) {
-      lb = 256;
-      ub = 256 + ctx.program->max_reg_demand.vgpr;
-   } else {
-      if (size == 2)
-         stride = 2;
-      else if (size >= 4)
-         stride = 4;
-      if (reg % stride != 0)
-         return false;
-      lb = 0;
-      ub = ctx.program->max_reg_demand.sgpr;
-   }
-
-   if (rc.is_subdword() && reg.byte() && !instr_can_access_subdword(instr))
-      return false;
-
-   uint32_t reg_lo = reg.reg();
-   uint32_t reg_hi = reg + (size - 1);
-
-   if (reg_lo < lb || reg_hi >= ub || reg_lo > reg_hi)
-      return false;
-
-   if (reg_file.test(reg, rc.bytes()))
-      return false;
-
-   adjust_max_used_regs(ctx, rc, reg_lo);
-   return true;
-}
-
 void handle_pseudo(ra_ctx& ctx,
                    const RegisterFile& reg_file,
                    Instruction* instr)
@@ -1157,7 +1157,7 @@ void get_reg_for_operand(ra_ctx& ctx, RegisterFile& register_file,
          pc_op.setFixed(operand.physReg());
 
          /* find free reg */
-         PhysReg reg = get_reg(ctx, register_file, pc_op.regClass(), parallelcopy, instr);
+         PhysReg reg = get_reg(ctx, register_file, pc_op.getTemp(), parallelcopy, instr);
          Definition pc_def = Definition(PhysReg{reg}, pc_op.regClass());
          register_file.clear(pc_op);
          parallelcopy.emplace_back(pc_op, pc_def);
@@ -1165,7 +1165,7 @@ void get_reg_for_operand(ra_ctx& ctx, RegisterFile& register_file,
       dst = operand.physReg();
 
    } else {
-      dst = get_reg(ctx, register_file, operand.regClass(), parallelcopy, instr);
+      dst = get_reg(ctx, register_file, operand.getTemp(), parallelcopy, instr);
    }
 
    Operand pc_op = operand;
@@ -1518,7 +1518,7 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                }
             }
             if (!definition.isFixed())
-               definition.setFixed(get_reg(ctx, register_file, definition.regClass(), parallelcopy, phi));
+               definition.setFixed(get_reg(ctx, register_file, definition.getTemp(), parallelcopy, phi));
 
             /* process parallelcopy */
             for (std::pair<Operand, Definition> pc : parallelcopy) {
@@ -1721,7 +1721,7 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                }
 
                /* find a new register for the blocking variable */
-               PhysReg reg = get_reg(ctx, register_file, rc, parallelcopy, instr);
+               PhysReg reg = get_reg(ctx, register_file, pc_op.getTemp(), parallelcopy, instr);
                /* once again, disable killed operands */
                for (const Operand& op : instr->operands) {
                   if (op.isTemp() && op.isFirstKillBeforeDef())
@@ -1769,7 +1769,7 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                PhysReg reg = instr->operands[0].physReg();
                reg.reg_b += i * definition.bytes();
                if (!get_reg_specified(ctx, register_file, definition.regClass(), parallelcopy, instr, reg))
-                  reg = get_reg(ctx, register_file, definition.regClass(), parallelcopy, instr);
+                  reg = get_reg(ctx, register_file, definition.getTemp(), parallelcopy, instr);
                definition.setFixed(reg);
             } else if (instr->opcode == aco_opcode::p_wqm) {
                PhysReg reg;
@@ -1777,7 +1777,7 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                   reg = instr->operands[0].physReg();
                   assert(register_file[reg.reg()] == 0);
                } else {
-                  reg = get_reg(ctx, register_file, definition.regClass(), parallelcopy, instr);
+                  reg = get_reg(ctx, register_file, definition.getTemp(), parallelcopy, instr);
                }
                definition.setFixed(reg);
             } else if (instr->opcode == aco_opcode::p_extract_vector) {
@@ -1788,11 +1788,11 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                   reg.reg_b += definition.bytes() * instr->operands[1].constantValue();
                   assert(!register_file.test(reg, definition.bytes()));
                } else {
-                  reg = get_reg(ctx, register_file, definition.regClass(), parallelcopy, instr);
+                  reg = get_reg(ctx, register_file, definition.getTemp(), parallelcopy, instr);
                }
                definition.setFixed(reg);
             } else if (instr->opcode == aco_opcode::p_create_vector) {
-               PhysReg reg = get_reg_create_vector(ctx, register_file, definition.regClass(),
+               PhysReg reg = get_reg_create_vector(ctx, register_file, definition.getTemp(),
                                                    parallelcopy, instr);
                definition.setFixed(reg);
             } else if (ctx.affinities.find(definition.tempId()) != ctx.affinities.end() &&
@@ -1801,7 +1801,7 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                if (get_reg_specified(ctx, register_file, definition.regClass(), parallelcopy, instr, reg))
                   definition.setFixed(reg);
                else
-                  definition.setFixed(get_reg(ctx, register_file, definition.regClass(), parallelcopy, instr));
+                  definition.setFixed(get_reg(ctx, register_file, definition.getTemp(), parallelcopy, instr));
 
             } else if (vectors.find(definition.tempId()) != vectors.end()) {
                Instruction* vec = vectors[definition.tempId()];
@@ -1834,14 +1834,14 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                      reg.reg_b += byte_offset;
                      /* make sure to only use byte offset if the instruction supports it */
                      if (vec->definitions[0].regClass().is_subdword() && reg.byte() && !instr_can_access_subdword(instr))
-                        reg = get_reg(ctx, register_file, definition.regClass(), parallelcopy, instr);
+                        reg = get_reg(ctx, register_file, definition.getTemp(), parallelcopy, instr);
                   } else {
-                     reg = get_reg(ctx, register_file, definition.regClass(), parallelcopy, instr);
+                     reg = get_reg(ctx, register_file, definition.getTemp(), parallelcopy, instr);
                   }
                   definition.setFixed(reg);
                }
             } else
-               definition.setFixed(get_reg(ctx, register_file, definition.regClass(), parallelcopy, instr));
+               definition.setFixed(get_reg(ctx, register_file, definition.getTemp(), parallelcopy, instr));
 
             assert(definition.isFixed() && ((definition.getTemp().type() == RegType::vgpr && definition.physReg() >= 256) ||
                                             (definition.getTemp().type() != RegType::vgpr && definition.physReg() < 256)));
@@ -1966,10 +1966,9 @@ void register_allocation(Program *program, std::vector<TempSet>& live_out_per_bl
                   if (op.isTemp() && op.isFirstKill())
                      register_file.block(op.physReg(), op.bytes());
                }
-               RegClass rc = can_sgpr ? s1 : v1;
-               PhysReg reg = get_reg(ctx, register_file, rc, parallelcopy, instr);
-               Temp tmp = {program->allocateId(), rc};
-               ctx.assignments.emplace_back(reg, rc);
+               Temp tmp = {program->allocateId(), can_sgpr ? s1 : v1};
+               ctx.assignments.emplace_back();
+               PhysReg reg = get_reg(ctx, register_file, tmp, parallelcopy, instr);
 
                aco_ptr<Instruction> mov;
                if (can_sgpr)