aco/wave32: Fix reductions.
authorTimur Kristóf <timur.kristof@gmail.com>
Wed, 27 Nov 2019 15:59:11 +0000 (16:59 +0100)
committerDaniel Schürmann <daniel@schuermann.dev>
Wed, 4 Dec 2019 10:36:01 +0000 (10:36 +0000)
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/compiler/aco_lower_to_hw_instr.cpp

index 84c88e4eaa502e2c21675c844fef41ebdc985c0d..0f89cb1aee56ee247787881c28ee4629db67e18b 100644 (file)
@@ -5232,15 +5232,15 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te
       //subgroupClusteredOr(val, 4) -> wqm(val & exec)
       return bld.sop1(Builder::s_wqm, bld.def(bld.lm), bld.def(s1, scc),
                       bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm)));
-   } else if (op == nir_op_iand && cluster_size == 64) {
+   } else if (op == nir_op_iand && cluster_size == ctx->program->wave_size) {
       //subgroupAnd(val) -> (exec & ~val) == 0
       Temp tmp = bld.sop2(Builder::s_andn2, bld.def(bld.lm), bld.def(s1, scc), Operand(exec, bld.lm), src).def(1).getTemp();
       return bld.sop2(Builder::s_cselect, bld.def(bld.lm), Operand(0u), Operand(-1u), bld.scc(tmp));
-   } else if (op == nir_op_ior && cluster_size == 64) {
+   } else if (op == nir_op_ior && cluster_size == ctx->program->wave_size) {
       //subgroupOr(val) -> (val & exec) != 0
       Temp tmp = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm)).def(1).getTemp();
       return bool_to_vector_condition(ctx, tmp);
-   } else if (op == nir_op_ixor && cluster_size == 64) {
+   } else if (op == nir_op_ixor && cluster_size == ctx->program->wave_size) {
       //subgroupXor(val) -> s_bcnt1_i32_b64(val & exec) & 1
       Temp tmp = bld.sop2(Builder::s_and, bld.def(bld.lm), bld.def(s1, scc), src, Operand(exec, bld.lm));
       tmp = bld.sop1(Builder::s_bcnt1_i32, bld.def(s1), bld.def(s1, scc), tmp);
@@ -5839,7 +5839,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       nir_op op = (nir_op) nir_intrinsic_reduction_op(instr);
       unsigned cluster_size = instr->intrinsic == nir_intrinsic_reduce ?
          nir_intrinsic_cluster_size(instr) : 0;
-      cluster_size = util_next_power_of_two(MIN2(cluster_size ? cluster_size : 64, 64));
+      cluster_size = util_next_power_of_two(MIN2(cluster_size ? cluster_size : ctx->program->wave_size, ctx->program->wave_size));
 
       if (!ctx->divergent_vals[instr->src[0].ssa->index] && (op == nir_op_ior || op == nir_op_iand)) {
          emit_uniform_subgroup(ctx, instr, src);
index 47f5778822f8e5f1266a3f0a401c433668ee4a97..469aebbb8d92af6e5f197dd89d35216d16383903 100644 (file)
@@ -390,8 +390,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
                      if (intrinsic->dest.ssa.bit_size == 1) {
                         size = lane_mask_size;
                         type = RegType::sgpr;
-                     } else if (nir_intrinsic_cluster_size(intrinsic) == 0 ||
-                         !ctx->divergent_vals[intrinsic->dest.ssa.index]) {
+                     } else if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) {
                         type = RegType::sgpr;
                      } else {
                         type = RegType::vgpr;
index e9c2d66d8233818922f820610563f09b7ff1b5ff..19e0f598074f0bb06a89147f75fc526e8edddb45 100644 (file)
@@ -412,7 +412,8 @@ uint32_t get_reduction_identity(ReduceOp op, unsigned idx)
 void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsigned cluster_size, PhysReg tmp,
                     PhysReg stmp, PhysReg vtmp, PhysReg sitmp, Operand src, Definition dst)
 {
-   assert(cluster_size == 64 || op == aco_opcode::p_reduce);
+   assert(cluster_size == ctx->program->wave_size || op == aco_opcode::p_reduce);
+   assert(cluster_size <= ctx->program->wave_size);
 
    Builder bld(ctx->program, &ctx->instructions);
 
@@ -462,23 +463,34 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig
       emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(),
                   dpp_row_mirror, 0xf, 0xf, false);
       if (cluster_size == 16) break;
-      if (cluster_size == 32) {
+
+      if (ctx->program->chip_class >= GFX10) {
+         /* GFX10+ doesn't support row_bcast15 and row_bcast31 */
+
+         for (unsigned i = 0; i < src.size(); i++)
+            bld.vop3(aco_opcode::v_permlanex16_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{tmp+i}, v1), Operand(0u), Operand(0u));
+
+         if (cluster_size == 32 && dst.regClass().type() == RegType::vgpr) {
+            bld.sop1(Builder::s_mov, Definition(exec, bld.lm), Operand(stmp, bld.lm));
+            exec_restored = true;
+            emit_op(ctx, dst.physReg(), tmp, vtmp, PhysReg{0}, reduce_op, src.size());
+            dst_written = true;
+         } else {
+            emit_op(ctx, tmp, tmp, vtmp, PhysReg{0}, reduce_op, src.size());
+         }
+
+         if (cluster_size == 64) {
+            for (unsigned i = 0; i < src.size(); i++)
+               bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u));
+            emit_op(ctx, tmp, sitmp, tmp, vtmp, reduce_op, src.size());
+         }
+      } else if (cluster_size == 32) {
          for (unsigned i = 0; i < src.size(); i++)
             bld.ds(aco_opcode::ds_swizzle_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{tmp+i}, s1), ds_pattern_bitmode(0x1f, 0, 0x10));
          bld.sop1(Builder::s_mov, Definition(exec, bld.lm), Operand(stmp, bld.lm));
          exec_restored = true;
          emit_op(ctx, dst.physReg(), vtmp, tmp, PhysReg{0}, reduce_op, src.size());
          dst_written = true;
-      } else if (ctx->program->chip_class >= GFX10) {
-         assert(cluster_size == 64);
-         /* GFX10+ doesn't support row_bcast15 and row_bcast31 */
-         for (unsigned i = 0; i < src.size(); i++)
-            bld.vop3(aco_opcode::v_permlanex16_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{tmp+i}, v1), Operand(0u), Operand(0u));
-         emit_op(ctx, tmp, tmp, vtmp, PhysReg{0}, reduce_op, src.size());
-
-         for (unsigned i = 0; i < src.size(); i++)
-            bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u));
-         emit_op(ctx, tmp, sitmp, tmp, vtmp, reduce_op, src.size());
       } else {
          assert(cluster_size == 64);
          emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(),
@@ -504,10 +516,12 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig
          }
          bld.sop1(Builder::s_mov, Definition(exec, bld.lm), Operand(UINT64_MAX));
 
-         /* fill in the gap in row 2 */
-         for (unsigned i = 0; i < src.size(); i++) {
-            bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u));
-            bld.vop3(aco_opcode::v_writelane_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{sitmp+i}, s1), Operand(32u));
+         if (ctx->program->wave_size == 64) {
+            /* fill in the gap in row 2 */
+            for (unsigned i = 0; i < src.size(); i++) {
+               bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u));
+               bld.vop3(aco_opcode::v_writelane_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{sitmp+i}, s1), Operand(32u));
+            }
          }
          std::swap(tmp, vtmp);
       } else {
@@ -523,7 +537,7 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig
       }
       /* fall through */
    case aco_opcode::p_inclusive_scan:
-      assert(cluster_size == 64);
+      assert(cluster_size == ctx->program->wave_size);
       emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(),
                   dpp_row_sr(1), 0xf, 0xf, false, identity);
       emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(),
@@ -544,11 +558,13 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig
          }
          emit_op(ctx, tmp, tmp, vtmp, PhysReg{0}, reduce_op, src.size());
 
-         bld.sop1(aco_opcode::s_mov_b32, Definition(exec_lo, s1), Operand(0u));
-         bld.sop1(aco_opcode::s_mov_b32, Definition(exec_hi, s1), Operand(0xffffffffu));
-         for (unsigned i = 0; i < src.size(); i++)
-            bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u));
-         emit_op(ctx, tmp, sitmp, tmp, vtmp, reduce_op, src.size());
+         if (ctx->program->wave_size == 64) {
+            bld.sop1(aco_opcode::s_mov_b32, Definition(exec_lo, s1), Operand(0u));
+            bld.sop1(aco_opcode::s_mov_b32, Definition(exec_hi, s1), Operand(0xffffffffu));
+            for (unsigned i = 0; i < src.size(); i++)
+               bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u));
+            emit_op(ctx, tmp, sitmp, tmp, vtmp, reduce_op, src.size());
+         }
       } else {
          emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(),
                      dpp_row_bcast15, 0xa, 0xf, false, identity);
@@ -563,10 +579,10 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig
    if (!exec_restored)
       bld.sop1(Builder::s_mov, Definition(exec, bld.lm), Operand(stmp, bld.lm));
 
-   if (op == aco_opcode::p_reduce && cluster_size == 64) {
+   if (op == aco_opcode::p_reduce && dst.regClass().type() == RegType::sgpr) {
       for (unsigned k = 0; k < src.size(); k++) {
          bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{dst.physReg() + k}, s1),
-                  Operand(PhysReg{tmp + k}, v1), Operand(63u));
+                  Operand(PhysReg{tmp + k}, v1), Operand(ctx->program->wave_size - 1));
       }
    } else if (!(dst.physReg() == tmp) && !dst_written) {
       for (unsigned k = 0; k < src.size(); k++) {