aco/wave32: Introduce emit_mbcnt which takes wave size into account.
authorTimur Kristóf <timur.kristof@gmail.com>
Thu, 31 Oct 2019 10:26:14 +0000 (11:26 +0100)
committerDaniel Schürmann <daniel@schuermann.dev>
Wed, 4 Dec 2019 10:36:01 +0000 (10:36 +0000)
This is relevant because in wave32 mode the v_mbcnt_hi_u32_b32
instruction is superfluous.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
src/amd/compiler/aco_instruction_selection.cpp

index f262d0f7e5a5e6e618b1b137f7ae6e5a845670df..a2b2c21170cf703ae170d61a1e1358312e4c8521 100644 (file)
@@ -123,6 +123,21 @@ Temp get_ssa_temp(struct isel_context *ctx, nir_ssa_def *def)
    return ctx->allocated[def->index];
 }
 
+Temp emit_mbcnt(isel_context *ctx, Definition dst,
+                Operand mask_lo = Operand((uint32_t) -1), Operand mask_hi = Operand((uint32_t) -1))
+{
+   Builder bld(ctx->program, ctx->block);
+   Definition lo_def = ctx->program->wave_size == 32 ? dst : bld.def(v1);
+   Temp thread_id_lo = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, lo_def, mask_lo, Operand(0u));
+
+   if (ctx->program->wave_size == 32) {
+      return thread_id_lo;
+   } else {
+      Temp thread_id_hi = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, dst, mask_hi, thread_id_lo);
+      return thread_id_hi;
+   }
+}
+
 Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_needs_wqm = false)
 {
    Builder bld(ctx->program, ctx->block);
@@ -170,8 +185,7 @@ static Temp emit_bpermute(isel_context *ctx, Builder &bld, Temp index, Temp data
       ctx->program->vgpr_limit -= 4; /* We allocate 8 shared VGPRs, so we'll have 4 fewer normal VGPRs */
    }
 
-   Temp lane_id = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u));
-   lane_id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1), lane_id);
+   Temp lane_id = emit_mbcnt(ctx, bld.def(v1));
    Temp lane_is_hi = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x20u), lane_id);
    Temp index_is_hi = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x20u), index);
    Temp cmp = bld.vopc(aco_opcode::v_cmp_eq_u32, bld.def(s2, vcc), lane_is_hi, index_is_hi);
@@ -5228,7 +5242,7 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te
       return bool_to_vector_condition(ctx, tmp);
    } else {
       //subgroupClustered{And,Or,Xor}(val, n) ->
-      //lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0))
+      //lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0)) ;  just v_mbcnt_lo_u32_b32 on wave32
       //cluster_offset = ~(n - 1) & lane_id
       //cluster_mask = ((1 << n) - 1)
       //subgroupClusteredAnd():
@@ -5237,8 +5251,7 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te
       //   return ((val & exec) >> cluster_offset) & cluster_mask != 0
       //subgroupClusteredXor():
       //   return v_bnt_u32_b32(((val & exec) >> cluster_offset) & cluster_mask, 0) & 1 != 0
-      Temp lane_id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1),
-                              bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u)));
+      Temp lane_id = emit_mbcnt(ctx, bld.def(v1));
       Temp cluster_offset = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(~uint32_t(cluster_size - 1)), lane_id);
 
       Temp tmp;
@@ -5284,8 +5297,7 @@ Temp emit_boolean_exclusive_scan(isel_context *ctx, nir_op op, Temp src)
    Builder::Result lohi = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), bld.def(s1), tmp);
    Temp lo = lohi.def(0).getTemp();
    Temp hi = lohi.def(1).getTemp();
-   Temp mbcnt = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), hi,
-                         bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), lo, Operand(0u)));
+   Temp mbcnt = emit_mbcnt(ctx, bld.def(v1), Operand(lo), Operand(hi));
 
    Definition cmp_def = Definition();
    if (op == nir_op_iand)
@@ -5645,8 +5657,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       break;
    }
    case nir_intrinsic_load_local_invocation_index: {
-      Temp id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1),
-                         bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u)));
+      Temp id = emit_mbcnt(ctx, bld.def(v1));
       Temp tg_num = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), Operand(0xfc0u),
                              get_arg(ctx, ctx->args->ac.tg_size));
       bld.vop2(aco_opcode::v_or_b32, Definition(get_ssa_temp(ctx, &instr->dest.ssa)), tg_num, id);
@@ -5662,8 +5673,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       break;
    }
    case nir_intrinsic_load_subgroup_invocation: {
-      bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, Definition(get_ssa_temp(ctx, &instr->dest.ssa)), Operand((uint32_t) -1),
-               bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u)));
+      emit_mbcnt(ctx, Definition(get_ssa_temp(ctx, &instr->dest.ssa)));
       break;
    }
    case nir_intrinsic_load_num_subgroups: {
@@ -6024,9 +6034,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       RegClass rc = RegClass(src.type(), 1);
       Temp mask_lo = bld.tmp(rc), mask_hi = bld.tmp(rc);
       bld.pseudo(aco_opcode::p_split_vector, Definition(mask_lo), Definition(mask_hi), src);
-      Temp tmp = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), mask_lo, Operand(0u));
       Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
-      Temp wqm_tmp = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), mask_hi, tmp);
+      Temp wqm_tmp = emit_mbcnt(ctx, bld.def(v1), Operand(mask_lo), Operand(mask_hi));
       emit_wqm(ctx, wqm_tmp, dst);
       break;
    }
@@ -7745,8 +7754,7 @@ static void emit_streamout(isel_context *ctx, unsigned stream)
    Temp so_vtx_count = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
                                 get_arg(ctx, ctx->args->streamout_config), Operand(0x70010u));
 
-   Temp tid = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1),
-                       bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u)));
+   Temp tid = emit_mbcnt(ctx, bld.def(v1));
 
    Temp can_emit = bld.vopc(aco_opcode::v_cmp_gt_i32, bld.def(s2), so_vtx_count, tid);
 
@@ -7925,8 +7933,7 @@ void select_program(Program *program,
       if (shader_count >= 2) {
          Builder bld(ctx.program, ctx.block);
          Temp count = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), ctx.merged_wave_info, Operand((8u << 16) | (i * 8u)));
-         Temp thread_id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1),
-                                   bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u)));
+         Temp thread_id = emit_mbcnt(&ctx, bld.def(v1));
          Temp cond = bld.vopc(aco_opcode::v_cmp_gt_u32, bld.hint_vcc(bld.def(s2)), count, thread_id);
 
          begin_divergent_if_then(&ctx, &ic, cond);