aco: Implement subgroup shuffle in GFX10 wave64 mode.
author Timur Kristóf <timur.kristof@gmail.com>
Sat, 21 Sep 2019 16:03:56 +0000 (18:03 +0200)
committer Rhys Perry <pendingchaos02@gmail.com>
Mon, 28 Oct 2019 23:52:50 +0000 (23:52 +0000)
Previously, subgroup shuffle was implemented using the bpermute
instruction, which only works across half-waves (32 lanes), so by
itself it is not suitable for implementing subgroup shuffle when the
shader runs in wave64 mode.

This commit adds a trick using shared VGPRs that still allows subgroup
shuffle to be implemented relatively efficiently in this mode.
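
For illustration, the trick can be modeled on the CPU roughly as
follows (a sketch only; the names are made up, and the half-swap is
done through a shared VGPR on real hardware):

    #include <array>
    #include <cstdint>

    /* Model of ds_bpermute_b32 on GFX10 in wave64: a lane can only
     * read lanes within its own 32-lane half; the index is in bytes. */
    static uint32_t half_bpermute(const std::array<uint32_t, 64> &data,
                                  unsigned lane, uint32_t index_x4)
    {
       return data[(lane & 32u) | ((index_x4 >> 2) & 31u)];
    }

    /* Permute both the original data and a copy with the two halves
     * swapped, then select per lane whichever result was actually
     * read from the half the index points into. */
    static uint32_t wave64_shuffle(const std::array<uint32_t, 64> &data,
                                   unsigned lane, uint32_t index)
    {
       std::array<uint32_t, 64> swapped;
       for (unsigned i = 0; i < 64; i++)
          swapped[i] = data[i ^ 32u]; /* shared-VGPR copy on real HW */

       uint32_t same_half  = half_bpermute(data, lane, index << 2);
       uint32_t cross_half = half_bpermute(swapped, lane, index << 2);
       return ((lane ^ index) & 32u) ? cross_half : same_half;
    }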

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
src/amd/compiler/aco_builder_h.py
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_lower_to_hw_instr.cpp
src/amd/compiler/aco_opcodes.py
src/amd/compiler/aco_reduce_assign.cpp

index f6fccfec2b2ad099571a985b1e6db2dbf6062982..8ee86716a291f86b5622875d808ba8f9606bdf8d 100644
@@ -358,7 +358,7 @@ formats = [("pseudo", [Format.PSEUDO], 'Pseudo_instruction', list(itertools.prod
            ("exp", [Format.EXP], 'Export_instruction', [(0, 4)]),
            ("branch", [Format.PSEUDO_BRANCH], 'Pseudo_branch_instruction', itertools.product([0], [0, 1])),
            ("barrier", [Format.PSEUDO_BARRIER], 'Pseudo_barrier_instruction', [(0, 0)]),
-           ("reduction", [Format.PSEUDO_REDUCTION], 'Pseudo_reduction_instruction', [(3, 2)]),
+           ("reduction", [Format.PSEUDO_REDUCTION], 'Pseudo_reduction_instruction', [(3, 2), (3, 4)]),
            ("vop1", [Format.VOP1], 'VOP1_instruction', [(1, 1), (2, 2)]),
            ("vop2", [Format.VOP2], 'VOP2_instruction', itertools.product([1, 2], [2, 3])),
            ("vopc", [Format.VOPC], 'VOPC_instruction', itertools.product([1, 2], [2])),
index fc1838724e10434698dc27bb8fe3fef7366d16e6..768860a2c9b7252be6a7dce07bbcb7f7cabbbfbb 100644
@@ -146,6 +146,33 @@ Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_ne
    return dst;
 }
 
+static Temp emit_bpermute(isel_context *ctx, Builder &bld, Temp index, Temp data)
+{
+   Temp index_x4 = bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand(2u), index);
+
+   /* Currently not implemented on GFX6-7 */
+   assert(ctx->options->chip_class >= GFX8);
+
+   if (ctx->options->chip_class <= GFX9 || ctx->options->wave_size == 32) {
+      return bld.ds(aco_opcode::ds_bpermute_b32, bld.def(v1), index_x4, data);
+   }
+
+   /* GFX10, wave64 mode:
+    * The bpermute instruction is limited to half-wave operation, which means that it can't
+    * properly support subgroup shuffle like older generations (or wave32 mode), so we
+    * emulate it here.
+    */
+
+   Temp lane_id = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u));
+   lane_id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1), lane_id);
+   Temp lane_is_hi = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x20u), lane_id);
+   Temp index_is_hi = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x20u), index);
+   Temp cmp = bld.vopc(aco_opcode::v_cmp_eq_u32, bld.def(s2, vcc), lane_is_hi, index_is_hi);
+
+   return bld.reduction(aco_opcode::p_wave64_bpermute, bld.def(v1), bld.def(s2), bld.def(s1, scc),
+                        bld.vcc(cmp), Operand(v2.as_linear()), index_x4, data, gfx10_wave64_bpermute);
+}
+
 Temp as_vgpr(isel_context *ctx, Temp val)
 {
    if (val.type() == RegType::sgpr) {
@@ -5528,15 +5555,12 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
          assert(tid.regClass() == v1);
          Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
          if (src.regClass() == v1) {
-            tid = bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand(2u), tid);
-            emit_wqm(ctx, bld.ds(aco_opcode::ds_bpermute_b32, bld.def(v1), tid, src), dst);
+            emit_wqm(ctx, emit_bpermute(ctx, bld, tid, src), dst);
          } else if (src.regClass() == v2) {
-            tid = bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand(2u), tid);
-
             Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
             bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
-            lo = emit_wqm(ctx, bld.ds(aco_opcode::ds_bpermute_b32, bld.def(v1), tid, lo));
-            hi = emit_wqm(ctx, bld.ds(aco_opcode::ds_bpermute_b32, bld.def(v1), tid, hi));
+            lo = emit_wqm(ctx, emit_bpermute(ctx, bld, tid, lo));
+            hi = emit_wqm(ctx, emit_bpermute(ctx, bld, tid, hi));
             bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
             emit_split_vector(ctx, dst, 2);
          } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
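
A sketch of the per-lane condition that emit_bpermute builds above:
with an all-ones mask operand, v_mbcnt_lo/v_mbcnt_hi count the set
mask bits below the current lane and therefore yield the lane id;
comparing bit 5 of the lane id with bit 5 of the index then tells
whether the plain bpermute already read from the correct half (the
helper name below is hypothetical):

    #include <cstdint>

    static bool bpermute_reads_correct_half(unsigned lane_id, uint32_t index)
    {
       bool lane_is_hi  = (lane_id & 0x20u) != 0; /* v_and_b32 0x20, lane_id */
       bool index_is_hi = (index   & 0x20u) != 0; /* v_and_b32 0x20, index   */
       return lane_is_hi == index_is_hi;          /* v_cmp_eq_u32 -> VCC     */
    }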
index 90fc3c6fe3687c1793f5273e27c88b9a19c4b880..58d67ef293b6eab3060b9d2467d63adbd9413565 100644
@@ -831,6 +831,7 @@ enum ReduceOp {
    iand32, iand64,
    ior32, ior64,
    ixor32, ixor64,
+   gfx10_wave64_bpermute
 };
 
 /**
index 2572916380f9d142f0f53c2d4b38d6f18b4b5c3d..2fe865e2a90b86bc090351ce220a68c256172bc0 100644
@@ -31,6 +31,7 @@
 #include "aco_builder.h"
 #include "util/u_math.h"
 #include "sid.h"
+#include "vulkan/radv_shader.h"
 
 
 namespace aco {
@@ -143,8 +144,11 @@ uint32_t get_reduction_identity(ReduceOp op, unsigned idx)
       return 0xff800000u; /* negative infinity */
    case fmax64:
       return idx ? 0xfff00000u : 0u; /* negative infinity */
+   default:
+      unreachable("Invalid reduction operation");
+      break;
    }
-   unreachable("Invalid reduction operation");
+   return 0;
 }
 
 aco_opcode get_reduction_opcode(lower_context *ctx, ReduceOp op, bool *clobber_vcc, Format *format)
@@ -207,8 +211,10 @@ aco_opcode get_reduction_opcode(lower_context *ctx, ReduceOp op, bool *clobber_v
    case ixor64:
       assert(false);
       break;
+   default:
+      unreachable("Invalid reduction operation");
+      break;
    }
-   unreachable("Invalid reduction operation");
    return aco_opcode::v_min_u32;
 }
 
@@ -804,12 +810,74 @@ void lower_to_hw_instr(Program* program)
 
          } else if (instr->format == Format::PSEUDO_REDUCTION) {
             Pseudo_reduction_instruction* reduce = static_cast<Pseudo_reduction_instruction*>(instr.get());
-            emit_reduction(&ctx, reduce->opcode, reduce->reduce_op, reduce->cluster_size,
-                           reduce->operands[1].physReg(), // tmp
-                           reduce->definitions[1].physReg(), // stmp
-                           reduce->operands[2].physReg(), // vtmp
-                           reduce->definitions[2].physReg(), // sitmp
-                           reduce->operands[0], reduce->definitions[0]);
+            if (reduce->reduce_op == gfx10_wave64_bpermute) {
+               /* Only makes sense on GFX10 wave64 */
+               assert(program->chip_class >= GFX10);
+               assert(program->info->wave_size == 64);
+               assert(instr->definitions[0].regClass() == v1); /* Destination */
+               assert(instr->definitions[1].regClass() == s2); /* Temp EXEC */
+               assert(instr->definitions[1].physReg() != vcc);
+               assert(instr->definitions[2].physReg() == scc); /* SCC clobber */
+               assert(instr->operands[0].physReg() == vcc); /* Compare */
+               assert(instr->operands[1].regClass() == v2.as_linear()); /* Temp VGPR pair */
+               assert(instr->operands[2].regClass() == v1); /* Indices x4 */
+               assert(instr->operands[3].regClass() == v1); /* Input data */
+
+               /* Shared VGPRs are allocated in groups of 8 */
+               program->config->num_shared_vgprs = 8;
+
+               PhysReg shared_vgpr_reg_lo = PhysReg(align(program->config->num_vgprs, 4) + 256);
+               PhysReg shared_vgpr_reg_hi = PhysReg(shared_vgpr_reg_lo + 1);
+               Operand compare = instr->operands[0];
+               Operand tmp1(instr->operands[1].physReg(), v1);
+               Operand tmp2(PhysReg(instr->operands[1].physReg() + 1), v1);
+               Operand index_x4 = instr->operands[2];
+               Operand input_data = instr->operands[3];
+               Definition shared_vgpr_lo(shared_vgpr_reg_lo, v1);
+               Definition shared_vgpr_hi(shared_vgpr_reg_hi, v1);
+               Definition def_temp1(tmp1.physReg(), v1);
+               Definition def_temp2(tmp2.physReg(), v1);
+
+               /* Save EXEC and clear it */
+               bld.sop1(aco_opcode::s_and_saveexec_b64, instr->definitions[1], instr->definitions[2],
+                        Definition(exec, s2), Operand(0u), Operand(exec, s2));
+
+               /* Set EXEC to enable HI lanes only */
+               bld.sop1(aco_opcode::s_mov_b32, Definition(exec_hi, s1), Operand((uint32_t)-1));
+               /* HI: Copy data from high lanes 32-63 to shared vgpr */
+               bld.vop1(aco_opcode::v_mov_b32, shared_vgpr_hi, input_data);
+
+               /* Invert EXEC to enable LO lanes only */
+               bld.sop1(aco_opcode::s_not_b64, Definition(exec, s2), Operand(exec, s2));
+               /* LO: Copy data from low lanes 0-31 to shared vgpr */
+               bld.vop1(aco_opcode::v_mov_b32, shared_vgpr_lo, input_data);
+               /* LO: Copy shared vgpr (high lanes' data) to output vgpr */
+               bld.vop1(aco_opcode::v_mov_b32, def_temp1, Operand(shared_vgpr_reg_hi, v1));
+
+               /* Invert EXEC to enable HI lanes only */
+               bld.sop1(aco_opcode::s_not_b64, Definition(exec, s2), Operand(exec, s2));
+               /* HI: Copy shared vgpr (low lanes' data) to output vgpr */
+               bld.vop1(aco_opcode::v_mov_b32, def_temp1, Operand(shared_vgpr_reg_lo, v1));
+
+               /* Enable exec mask for all lanes */
+               bld.sop1(aco_opcode::s_mov_b64, Definition(exec, s2), Operand((uint32_t)-1));
+               /* Permute the original input */
+               bld.ds(aco_opcode::ds_bpermute_b32, def_temp2, index_x4, input_data);
+               /* Permute the swapped input */
+               bld.ds(aco_opcode::ds_bpermute_b32, def_temp1, index_x4, tmp1);
+
+               /* Restore saved EXEC */
+               bld.sop1(aco_opcode::s_mov_b64, Definition(exec, s2), Operand(instr->definitions[1].physReg(), s2));
+               /* Choose whether to use the original or swapped */
+               bld.vop2(aco_opcode::v_cndmask_b32, instr->definitions[0], tmp1, tmp2, compare);
+            } else {
+               emit_reduction(&ctx, reduce->opcode, reduce->reduce_op, reduce->cluster_size,
+                              reduce->operands[1].physReg(), // tmp
+                              reduce->definitions[1].physReg(), // stmp
+                              reduce->operands[2].physReg(), // vtmp
+                              reduce->definitions[2].physReg(), // sitmp
+                              reduce->operands[0], reduce->definitions[0]);
+            }
          } else {
             ctx.instructions.emplace_back(std::move(instr));
          }
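
What makes the sequence above work is the shared VGPR: in wave64, a
regular VGPR is backed by two 32-lane registers (one per wave half),
while a shared VGPR is a single 32-lane register visible to both
halves; it sits right above the wave's regular VGPR allocation, which
is what align(num_vgprs, 4) + 256 appears to encode (ACO numbers VGPRs
from PhysReg 256). A CPU model of the emitted sequence, as a sketch
under that assumption:

    #include <array>
    #include <cstdint>

    static std::array<uint32_t, 64>
    lowered_wave64_bpermute(const std::array<uint32_t, 64> &data,
                            const std::array<uint32_t, 64> &index_x4)
    {
       uint32_t shared_lo[32], shared_hi[32];       /* the shared VGPR pair     */
       std::array<uint32_t, 64> tmp1, tmp2, result; /* linear VGPR pair and dst */

       /* EXEC = hi half only: stash lanes 32-63 in the shared VGPR */
       for (unsigned i = 32; i < 64; i++) shared_hi[i - 32] = data[i];
       /* EXEC inverted (lo half): stash lanes 0-31, read back hi data */
       for (unsigned i = 0; i < 32; i++) shared_lo[i] = data[i];
       for (unsigned i = 0; i < 32; i++) tmp1[i] = shared_hi[i];
       /* EXEC inverted again (hi half): read back lo data */
       for (unsigned i = 32; i < 64; i++) tmp1[i] = shared_lo[i - 32];
       /* tmp1 now holds data[lane ^ 32] in every lane */

       /* EXEC = all lanes: bpermute the original and the swapped input;
        * each lane can only source from its own 32-lane half */
       std::array<uint32_t, 64> swapped = tmp1;
       for (unsigned i = 0; i < 64; i++) {
          unsigned src = (i & 32u) | ((index_x4[i] >> 2) & 31u);
          tmp2[i] = data[src];    /* permute of the original input */
          tmp1[i] = swapped[src]; /* permute of the swapped input  */
       }

       /* EXEC restored, then v_cndmask on the precomputed VCC: keep
        * the original permute where lane and index share a half */
       for (unsigned i = 0; i < 64; i++) {
          bool same_half = ((i ^ (index_x4[i] >> 2)) & 32u) == 0;
          result[i] = same_half ? tmp2[i] : tmp1[i];
       }
       return result;
    }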
index a358527e60b33550493df9fcd4455fecceb73b7d..08337a18d222781be3e24a10dca9a70738f44819 100644
@@ -212,6 +212,8 @@ opcode("p_reduce", format=Format.PSEUDO_REDUCTION)
 opcode("p_inclusive_scan", format=Format.PSEUDO_REDUCTION)
 # e.g. subgroupExclusiveMin()
 opcode("p_exclusive_scan", format=Format.PSEUDO_REDUCTION)
+# simulates proper bpermute behavior on GFX10 wave64
+opcode("p_wave64_bpermute", format=Format.PSEUDO_REDUCTION)
 
 opcode("p_branch", format=Format.PSEUDO_BRANCH)
 opcode("p_cbranch", format=Format.PSEUDO_BRANCH)
index 66a3ec64c044e7583ac32e68dbb98e593bb37400..d9c762a65dbbea4030a6d871d57af727d6ad4265 100644
@@ -118,10 +118,12 @@ void setup_reduce_temp(Program* program)
          unsigned cluster_size = static_cast<Pseudo_reduction_instruction *>(instr)->cluster_size;
          bool need_vtmp = op == imul32 || op == fadd64 || op == fmul64 ||
                           op == fmin64 || op == fmax64;
-         if (program->chip_class >= GFX10 && cluster_size == 64)
+
+         if (program->chip_class >= GFX10 && cluster_size == 64 && op != gfx10_wave64_bpermute)
             need_vtmp = true;
 
          need_vtmp |= cluster_size == 32;
+
          vtmp_in_loop |= need_vtmp && block.loop_nest_depth > 0;
          if (need_vtmp && (int)last_top_level_block_idx != vtmp_inserted_at) {
             vtmp = {program->allocateId(), vtmp.regClass()};