From c44af6cbc7731f8f482da38298887198d975e245 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Timur=20Krist=C3=B3f?= Date: Thu, 31 Oct 2019 11:26:14 +0100 Subject: [PATCH] aco/wave32: Introduce emit_mbcnt which takes wave size into account. MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit This is relevant because in wave32 mode the v_mbcnt_hi_u32_b32 instruction is superfluous. Signed-off-by: Timur Kristóf Reviewed-by: Daniel Schürmann --- .../compiler/aco_instruction_selection.cpp | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index f262d0f7e5a..a2b2c21170c 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -123,6 +123,21 @@ Temp get_ssa_temp(struct isel_context *ctx, nir_ssa_def *def) return ctx->allocated[def->index]; } +Temp emit_mbcnt(isel_context *ctx, Definition dst, + Operand mask_lo = Operand((uint32_t) -1), Operand mask_hi = Operand((uint32_t) -1)) +{ + Builder bld(ctx->program, ctx->block); + Definition lo_def = ctx->program->wave_size == 32 ? dst : bld.def(v1); + Temp thread_id_lo = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, lo_def, mask_lo, Operand(0u)); + + if (ctx->program->wave_size == 32) { + return thread_id_lo; + } else { + Temp thread_id_hi = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, dst, mask_hi, thread_id_lo); + return thread_id_hi; + } +} + Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_needs_wqm = false) { Builder bld(ctx->program, ctx->block); @@ -170,8 +185,7 @@ static Temp emit_bpermute(isel_context *ctx, Builder &bld, Temp index, Temp data ctx->program->vgpr_limit -= 4; /* We allocate 8 shared VGPRs, so we'll have 4 fewer normal VGPRs */ } - Temp lane_id = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u)); - lane_id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1), lane_id); + Temp lane_id = emit_mbcnt(ctx, bld.def(v1)); Temp lane_is_hi = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x20u), lane_id); Temp index_is_hi = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x20u), index); Temp cmp = bld.vopc(aco_opcode::v_cmp_eq_u32, bld.def(s2, vcc), lane_is_hi, index_is_hi); @@ -5228,7 +5242,7 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te return bool_to_vector_condition(ctx, tmp); } else { //subgroupClustered{And,Or,Xor}(val, n) -> - //lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0)) + //lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0)) ; just v_mbcnt_lo_u32_b32 on wave32 //cluster_offset = ~(n - 1) & lane_id //cluster_mask = ((1 << n) - 1) //subgroupClusteredAnd(): @@ -5237,8 +5251,7 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te // return ((val & exec) >> cluster_offset) & cluster_mask != 0 //subgroupClusteredXor(): // return v_bnt_u32_b32(((val & exec) >> cluster_offset) & cluster_mask, 0) & 1 != 0 - Temp lane_id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1), - bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u))); + Temp lane_id = emit_mbcnt(ctx, bld.def(v1)); Temp cluster_offset = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(~uint32_t(cluster_size - 1)), lane_id); Temp tmp; @@ -5284,8 +5297,7 @@ Temp emit_boolean_exclusive_scan(isel_context *ctx, nir_op op, Temp src) Builder::Result lohi = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), bld.def(s1), tmp); Temp lo = lohi.def(0).getTemp(); Temp hi = lohi.def(1).getTemp(); - Temp mbcnt = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), hi, - bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), lo, Operand(0u))); + Temp mbcnt = emit_mbcnt(ctx, bld.def(v1), Operand(lo), Operand(hi)); Definition cmp_def = Definition(); if (op == nir_op_iand) @@ -5645,8 +5657,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) break; } case nir_intrinsic_load_local_invocation_index: { - Temp id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1), - bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u))); + Temp id = emit_mbcnt(ctx, bld.def(v1)); Temp tg_num = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), Operand(0xfc0u), get_arg(ctx, ctx->args->ac.tg_size)); bld.vop2(aco_opcode::v_or_b32, Definition(get_ssa_temp(ctx, &instr->dest.ssa)), tg_num, id); @@ -5662,8 +5673,7 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) break; } case nir_intrinsic_load_subgroup_invocation: { - bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, Definition(get_ssa_temp(ctx, &instr->dest.ssa)), Operand((uint32_t) -1), - bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u))); + emit_mbcnt(ctx, Definition(get_ssa_temp(ctx, &instr->dest.ssa))); break; } case nir_intrinsic_load_num_subgroups: { @@ -6024,9 +6034,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) RegClass rc = RegClass(src.type(), 1); Temp mask_lo = bld.tmp(rc), mask_hi = bld.tmp(rc); bld.pseudo(aco_opcode::p_split_vector, Definition(mask_lo), Definition(mask_hi), src); - Temp tmp = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), mask_lo, Operand(0u)); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); - Temp wqm_tmp = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), mask_hi, tmp); + Temp wqm_tmp = emit_mbcnt(ctx, bld.def(v1), Operand(mask_lo), Operand(mask_hi)); emit_wqm(ctx, wqm_tmp, dst); break; } @@ -7745,8 +7754,7 @@ static void emit_streamout(isel_context *ctx, unsigned stream) Temp so_vtx_count = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), get_arg(ctx, ctx->args->streamout_config), Operand(0x70010u)); - Temp tid = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1), - bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u))); + Temp tid = emit_mbcnt(ctx, bld.def(v1)); Temp can_emit = bld.vopc(aco_opcode::v_cmp_gt_i32, bld.def(s2), so_vtx_count, tid); @@ -7925,8 +7933,7 @@ void select_program(Program *program, if (shader_count >= 2) { Builder bld(ctx.program, ctx.block); Temp count = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), ctx.merged_wave_info, Operand((8u << 16) | (i * 8u))); - Temp thread_id = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), Operand((uint32_t) -1), - bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), Operand((uint32_t) -1), Operand(0u))); + Temp thread_id = emit_mbcnt(&ctx, bld.def(v1)); Temp cond = bld.vopc(aco_opcode::v_cmp_gt_u32, bld.hint_vcc(bld.def(s2)), count, thread_id); begin_divergent_if_then(&ctx, &ic, cond); -- 2.30.2