From: Timur Kristóf Date: Mon, 4 Nov 2019 18:28:08 +0000 (+0100) Subject: aco: Treat all booleans as per-lane. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=8995c0b30a696c709fac9e5f761c101913dc92ec;p=mesa.git aco: Treat all booleans as per-lane. Previously, instruction selection had two kinds of booleans: 1. divergent which was per-lane and stored in s2 (VCC size) 2. uniform which was stored in s1 Additionally, uniform booleans were made per-lane when they resulted from operations which were supported only by the VALU. To decide which type was used, we relied on the destination size, which was not reliable due to the per-lane uniform bools, but it mostly works on wave64. However, in wave32 mode (where VCC is also s1) this approach makes it impossible keep track of which boolean is uniform and which is divergent. This commit makes all booleans per-lane. The resulting excess code size will be taken care of by the optimizer. v2 (by Daniel Schürmann): - Better names for some functions - Use s_andn2_b64 with exec for nir_op_inot - Simplify code due to using s_and_b64 in bool_to_scalar_condition v3 (by Timur Kristóf): - Fix several subgroups regressions Signed-off-by: Timur Kristóf Reviewed-by: Rhys Perry Reviewed-by: Daniel Schürmann --- diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index ab34a068671..a7c3c703403 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -130,6 +130,8 @@ Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_ne if (!dst.id()) dst = bld.tmp(src.regClass()); + assert(src.size() == dst.size()); + if (ctx->stage != fragment_fs) { if (!dst.id()) return src; @@ -331,33 +333,31 @@ void expand_vector(isel_context* ctx, Temp vec_src, Temp dst, unsigned num_compo ctx->allocated_vec.emplace(dst.id(), elems); } -Temp as_divergent_bool(isel_context *ctx, Temp val, bool vcc_hint) +Temp bool_to_vector_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s2)) { - if (val.regClass() == s2) { - return val; - } else { - assert(val.regClass() == s1); - Builder bld(ctx->program, ctx->block); - Definition& def = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), - Operand((uint32_t) -1), Operand(0u), bld.scc(val)).def(0); - if (vcc_hint) - def.setHint(vcc); - return def.getTemp(); - } + Builder bld(ctx->program, ctx->block); + if (!dst.id()) + dst = bld.tmp(s2); + + assert(val.regClass() == s1); + assert(dst.regClass() == s2); + + return bld.sop2(aco_opcode::s_cselect_b64, bld.hint_vcc(Definition(dst)), Operand((uint32_t) -1), Operand(0u), bld.scc(val)); } -Temp as_uniform_bool(isel_context *ctx, Temp val) +Temp bool_to_scalar_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s1)) { - if (val.regClass() == s1) { - return val; - } else { - assert(val.regClass() == s2); - Builder bld(ctx->program, ctx->block); - /* if we're currently in WQM mode, ensure that the source is also computed in WQM */ - Temp tmp = bld.tmp(s1); - bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2)).def(1).getTemp(); - return emit_wqm(ctx, tmp); - } + Builder bld(ctx->program, ctx->block); + if (!dst.id()) + dst = bld.tmp(s1); + + assert(val.regClass() == s2); + assert(dst.regClass() == s1); + + /* if we're currently in WQM mode, ensure that the source is also computed in WQM */ + Temp tmp = bld.tmp(s1); + bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2)); + return emit_wqm(ctx, tmp, dst); } Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1) @@ -526,27 +526,44 @@ void emit_vopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o src1 = as_vgpr(ctx, src1); } } + Builder bld(ctx->program, ctx->block); - bld.vopc(op, Definition(dst), src0, src1).def(0).setHint(vcc); + bld.vopc(op, bld.hint_vcc(Definition(dst)), src0, src1); } -void emit_comparison(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst) +void emit_sopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst) { - if (dst.regClass() == s2) { - emit_vopc_instruction(ctx, instr, op, dst); - if (!ctx->divergent_vals[instr->dest.dest.ssa.index]) - emit_split_vector(ctx, dst, 2); - } else if (dst.regClass() == s1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - assert(src0.type() == RegType::sgpr && src1.type() == RegType::sgpr); + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = get_alu_src(ctx, instr->src[1]); - Builder bld(ctx->program, ctx->block); - bld.sopc(op, bld.scc(Definition(dst)), src0, src1); + assert(dst.regClass() == s2); + assert(src0.type() == RegType::sgpr); + assert(src1.type() == RegType::sgpr); - } else { - assert(false); - } + Builder bld(ctx->program, ctx->block); + /* Emit the SALU comparison instruction */ + Temp cmp = bld.sopc(op, bld.scc(bld.def(s1)), src0, src1); + /* Turn the result into a per-lane bool */ + bool_to_vector_condition(ctx, cmp, dst); +} + +void emit_comparison(isel_context *ctx, nir_alu_instr *instr, Temp dst, + aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::last_opcode, aco_opcode s64_op = aco_opcode::last_opcode) +{ + aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : s32_op; + aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : v32_op; + bool divergent_vals = ctx->divergent_vals[instr->dest.dest.ssa.index]; + bool use_valu = s_op == aco_opcode::last_opcode || + divergent_vals || + ctx->allocated[instr->src[0].src.ssa->index].type() == RegType::vgpr || + ctx->allocated[instr->src[1].src.ssa->index].type() == RegType::vgpr; + aco_opcode op = use_valu ? v_op : s_op; + assert(op != aco_opcode::last_opcode); + + if (use_valu) + emit_vopc_instruction(ctx, instr, op, dst); + else + emit_sopc_instruction(ctx, instr, op, dst); } void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32, aco_opcode op64, Temp dst) @@ -554,16 +571,13 @@ void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32 Builder bld(ctx->program, ctx->block); Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); - if (dst.regClass() == s2) { - bld.sop2(op64, Definition(dst), bld.def(s1, scc), - as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false)); - } else { - assert(dst.regClass() == s1); - bld.sop2(op32, bld.def(s1), bld.scc(Definition(dst)), - as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1)); - } -} + assert(dst.regClass() == s2); + assert(src0.regClass() == s2); + assert(src1.regClass() == s2); + + bld.sop2(op64, Definition(dst), bld.def(s1, scc), src0, src1); +} void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) { @@ -572,9 +586,9 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) Temp then = get_alu_src(ctx, instr->src[1]); Temp els = get_alu_src(ctx, instr->src[2]); - if (dst.type() == RegType::vgpr) { - cond = as_divergent_bool(ctx, cond, true); + assert(cond.regClass() == s2); + if (dst.type() == RegType::vgpr) { aco_ptr bcsel; if (dst.size() == 1) { then = as_vgpr(ctx, then); @@ -599,11 +613,17 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) return; } - if (instr->dest.dest.ssa.bit_size != 1) { /* uniform condition and values in sgpr */ + if (instr->dest.dest.ssa.bit_size == 1) { + assert(dst.regClass() == s2); + assert(then.regClass() == s2); + assert(els.regClass() == s2); + } + + if (!ctx->divergent_vals[instr->src[0].src.ssa->index]) { /* uniform condition and values in sgpr */ if (dst.regClass() == s1 || dst.regClass() == s2) { assert((then.regClass() == s1 || then.regClass() == s2) && els.regClass() == then.regClass()); aco_opcode op = dst.regClass() == s1 ? aco_opcode::s_cselect_b32 : aco_opcode::s_cselect_b64; - bld.sop2(op, Definition(dst), then, els, bld.scc(as_uniform_bool(ctx, cond))); + bld.sop2(op, Definition(dst), then, els, bld.scc(bool_to_scalar_condition(ctx, cond))); } else { fprintf(stderr, "Unimplemented uniform bcsel bit size: "); nir_print_instr(&instr->instr, stderr); @@ -612,34 +632,10 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) return; } - /* boolean bcsel */ - assert(instr->dest.dest.ssa.bit_size == 1); - - if (dst.regClass() == s1) - cond = as_uniform_bool(ctx, cond); - - if (cond.regClass() == s1) { /* uniform selection */ - aco_opcode op; - if (dst.regClass() == s2) { - op = aco_opcode::s_cselect_b64; - then = as_divergent_bool(ctx, then, false); - els = as_divergent_bool(ctx, els, false); - } else { - assert(dst.regClass() == s1); - op = aco_opcode::s_cselect_b32; - then = as_uniform_bool(ctx, then); - els = as_uniform_bool(ctx, els); - } - bld.sop2(op, Definition(dst), then, els, bld.scc(cond)); - return; - } - /* divergent boolean bcsel * this implements bcsel on bools: dst = s0 ? s1 : s2 * are going to be: dst = (s0 & s1) | (~s0 & s2) */ - assert (dst.regClass() == s2); - then = as_divergent_bool(ctx, then, false); - els = as_divergent_bool(ctx, els, false); + assert(instr->dest.dest.ssa.bit_size == 1); if (cond.id() != then.id()) then = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), cond, then); @@ -700,16 +696,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_inot: { Temp src = get_alu_src(ctx, instr->src[0]); - /* uniform booleans */ - if (instr->dest.dest.ssa.bit_size == 1 && dst.regClass() == s1) { - if (src.regClass() == s1) { - /* in this case, src is either 1 or 0 */ - bld.sop2(aco_opcode::s_xor_b32, bld.def(s1), bld.scc(Definition(dst)), Operand(1u), src); - } else { - /* src is either exec_mask or 0 */ - assert(src.regClass() == s2); - bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(Definition(dst)), Operand(0u), src); - } + if (instr->dest.dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); + assert(dst.regClass() == s2); + bld.sop2(aco_opcode::s_andn2_b64, Definition(dst), bld.def(s1, scc), Operand(exec, s2), src); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_not_b32, dst); } else if (dst.type() == RegType::sgpr) { @@ -1919,12 +1909,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_b2f32: { Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == s2); + if (dst.regClass() == s1) { - src = as_uniform_bool(ctx, src); + src = bool_to_scalar_condition(ctx, src); bld.sop2(aco_opcode::s_mul_i32, Definition(dst), Operand(0x3f800000u), src); } else if (dst.regClass() == v1) { - bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u), - as_divergent_bool(ctx, src, true)); + bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u), src); } else { unreachable("Wrong destination register class for nir_op_b2f32."); } @@ -1932,13 +1923,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_b2f64: { Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == s2); + if (dst.regClass() == s2) { - src = as_uniform_bool(ctx, src); + src = bool_to_scalar_condition(ctx, src); bld.sop2(aco_opcode::s_cselect_b64, Definition(dst), Operand(0x3f800000u), Operand(0u), bld.scc(src)); } else if (dst.regClass() == v2) { Temp one = bld.vop1(aco_opcode::v_mov_b32, bld.def(v2), Operand(0x3FF00000u)); - Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, - as_divergent_bool(ctx, src, true)); + Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, src); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), Operand(0u), upper); } else { unreachable("Wrong destination register class for nir_op_b2f64."); @@ -2000,29 +1992,31 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_b2i32: { Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == s2); + if (dst.regClass() == s1) { - if (src.regClass() == s1) { - bld.copy(Definition(dst), src); - } else { - // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ - assert(src.regClass() == s2); - bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(Definition(dst)), Operand(0u), src); - } - } else { - assert(dst.regClass() == v1 && src.regClass() == s2); + // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ + bool_to_scalar_condition(ctx, src, dst); + } else if (dst.regClass() == v1) { bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(1u), src); + } else { + unreachable("Invalid register class for b2i32"); } break; } case nir_op_i2b1: { Temp src = get_alu_src(ctx, instr->src[0]); - if (dst.regClass() == s2) { + assert(dst.regClass() == s2); + + if (src.type() == RegType::vgpr) { assert(src.regClass() == v1 || src.regClass() == v2); bld.vopc(src.size() == 2 ? aco_opcode::v_cmp_lg_u64 : aco_opcode::v_cmp_lg_u32, Definition(dst), Operand(0u), src).def(0).setHint(vcc); } else { - assert(src.regClass() == s1 && dst.regClass() == s1); - bld.sopc(aco_opcode::s_cmp_lg_u32, bld.scc(Definition(dst)), Operand(0u), src); + assert(src.regClass() == s1 || src.regClass() == s2); + Temp tmp = bld.sopc(src.size() == 2 ? aco_opcode::s_cmp_lg_u64 : aco_opcode::s_cmp_lg_u32, + bld.scc(bld.def(s1)), Operand(0u), src); + bool_to_vector_condition(ctx, tmp, dst); } break; } @@ -2228,119 +2222,49 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_flt: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64); break; } case nir_op_fge: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64); break; } case nir_op_feq: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64); break; } case nir_op_fne: { - if (instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f32, dst); - else if (instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64); break; } case nir_op_ilt: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_i32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32); break; } case nir_op_ige: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_i32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32); break; } case nir_op_ieq: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i32, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_i32, dst); - } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_u64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sopc(aco_opcode::s_cmp_eq_i32, bld.scc(Definition(dst)), - as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1)); - } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sop2(aco_opcode::s_xnor_b64, Definition(dst), bld.def(s1, scc), - as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false)); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + if (instr->src[0].src.ssa->bit_size == 1) + emit_boolean_logic(ctx, instr, aco_opcode::s_xnor_b32, aco_opcode::s_xnor_b64, dst); + else + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, aco_opcode::s_cmp_eq_u64); break; } case nir_op_ine: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i32, dst); - } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_i32, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) { - emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_u64, dst); - } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sopc(aco_opcode::s_cmp_lg_i32, bld.scc(Definition(dst)), - as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1)); - } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = get_alu_src(ctx, instr->src[1]); - bld.sop2(aco_opcode::s_xor_b64, Definition(dst), bld.def(s1, scc), - as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false)); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + if (instr->src[0].src.ssa->bit_size == 1) + emit_boolean_logic(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::s_xor_b64, dst); + else + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, aco_opcode::s_cmp_lg_u64); break; } case nir_op_ult: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_u32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32); break; } case nir_op_uge: { - if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u32, dst); - else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) - emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_u32, dst); - else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) - emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u64, dst); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32); break; } case nir_op_fddx: @@ -2387,9 +2311,13 @@ void visit_load_const(isel_context *ctx, nir_load_const_instr *instr) assert(instr->def.num_components == 1 && "Vector load_const should be lowered to scalar."); assert(dst.type() == RegType::sgpr); - if (dst.size() == 1) - { - Builder(ctx->program, ctx->block).copy(Definition(dst), Operand(instr->value[0].u32)); + Builder bld(ctx->program, ctx->block); + + if (instr->def.bit_size == 1) { + assert(dst.regClass() == s2); + bld.sop1(aco_opcode::s_mov_b64, Definition(dst), Operand((uint64_t)(instr->value[0].b ? -1 : 0))); + } else if (dst.size() == 1) { + bld.copy(Definition(dst), Operand(instr->value[0].u32)); } else { assert(dst.size() != 1); aco_ptr vec{create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, dst.size(), 1)}; @@ -3577,7 +3505,8 @@ void visit_discard_if(isel_context *ctx, nir_intrinsic_instr *instr) // TODO: optimize uniform conditions Builder bld(ctx->program, ctx->block); - Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); + assert(src.regClass() == s2); src = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)); bld.pseudo(aco_opcode::p_discard_if, src); ctx->block->kind |= block_kind_uses_discard_if; @@ -5114,15 +5043,17 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te } else if (op == nir_op_iand && cluster_size == 64) { //subgroupAnd(val) -> (exec & ~val) == 0 Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp(); - return bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), tmp, Operand(0u)); + return bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp)); } else if (op == nir_op_ior && cluster_size == 64) { //subgroupOr(val) -> (val & exec) != 0 - return bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp(); + Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp(); + return bool_to_vector_condition(ctx, tmp); } else if (op == nir_op_ixor && cluster_size == 64) { //subgroupXor(val) -> s_bcnt1_i32_b64(val & exec) & 1 Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)); tmp = bld.sop1(aco_opcode::s_bcnt1_i32_b64, bld.def(s2), bld.def(s1, scc), tmp); - return bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp(); + tmp = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp(); + return bool_to_vector_condition(ctx, tmp); } else { //subgroupClustered{And,Or,Xor}(val, n) -> //lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0)) @@ -5221,8 +5152,6 @@ void emit_uniform_subgroup(isel_context *ctx, nir_intrinsic_instr *instr, Temp s Definition dst(get_ssa_temp(ctx, &instr->dest.ssa)); if (src.regClass().type() == RegType::vgpr) { bld.pseudo(aco_opcode::p_as_uniform, dst, src); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { - bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(dst), Operand(0u), Operand(src)); } else if (src.regClass() == s1) { bld.sop1(aco_opcode::s_mov_b32, dst, src); } else if (src.regClass() == s2) { @@ -5541,10 +5470,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) case nir_intrinsic_ballot: { Definition tmp = bld.def(s2); Temp src = get_ssa_temp(ctx, instr->src[0].ssa); - if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s2) { + if (instr->src[0].ssa->bit_size == 1) { + assert(src.regClass() == s2); bld.sop2(aco_opcode::s_and_b64, tmp, bld.def(s1, scc), Operand(exec, s2), src); - } else if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s1) { - bld.sop2(aco_opcode::s_cselect_b64, tmp, Operand(exec, s2), Operand(0u), bld.scc(src)); } else if (instr->src[0].ssa->bit_size == 32 && src.regClass() == v1) { bld.vopc(aco_opcode::v_cmp_lg_u32, tmp, Operand(0u), src); } else if (instr->src[0].ssa->bit_size == 64 && src.regClass() == v2) { @@ -5576,9 +5504,12 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) hi = emit_wqm(ctx, emit_bpermute(ctx, bld, tid, hi)); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi); emit_split_vector(ctx, dst, 2); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2 && tid.regClass() == s1) { - emit_wqm(ctx, bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid), dst); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { + } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == s1) { + assert(src.regClass() == s2); + Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid); + bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst); + } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == v1) { + assert(src.regClass() == s2); Temp tmp = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), tid, src); tmp = emit_extract_vector(ctx, tmp, 0, v1); tmp = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(1u), tmp); @@ -5614,11 +5545,11 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) hi = emit_wqm(ctx, bld.vop1(aco_opcode::v_readfirstlane_b32, bld.def(s1), hi)); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi); emit_split_vector(ctx, dst, 2); - } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { - emit_wqm(ctx, - bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, - bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2))), - dst); + } else if (instr->dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); + Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, + bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2))); + bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst); } else if (src.regClass() == s1) { bld.sop1(aco_opcode::s_mov_b32, Definition(dst), src); } else if (src.regClass() == s2) { @@ -5631,27 +5562,25 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) break; } case nir_intrinsic_vote_all: { - Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); assert(src.regClass() == s2); - assert(dst.regClass() == s1); + assert(dst.regClass() == s2); - Definition tmp = bld.def(s1); - bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(tmp), - bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)), - Operand(exec, s2)); - emit_wqm(ctx, tmp.getTemp(), dst); + Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp(); + Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp)); + emit_wqm(ctx, val, dst); break; } case nir_intrinsic_vote_any: { - Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); assert(src.regClass() == s2); - assert(dst.regClass() == s1); + assert(dst.regClass() == s2); - Definition tmp = bld.def(s1); - bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(tmp), src, Operand(exec, s2)); - emit_wqm(ctx, tmp.getTemp(), dst); + Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp(); + Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(-1u), Operand(0u), bld.scc(tmp)); + emit_wqm(ctx, val, dst); break; } case nir_intrinsic_reduce: @@ -5752,7 +5681,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) } else { Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); unsigned lane = nir_src_as_const_value(instr->src[1])->u32; - if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { + if (instr->dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); uint32_t half_mask = 0x11111111u << lane; Temp mask_tmp = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), Operand(half_mask), Operand(half_mask)); Temp tmp = bld.tmp(s2); @@ -5809,7 +5739,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) } Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); - if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) { + if (instr->dest.ssa.bit_size == 1) { + assert(src.regClass() == s2); src = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), Operand((uint32_t)-1), src); src = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl); Temp tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(s2), Operand(0u), src); @@ -5912,9 +5843,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr) ctx->program->needs_exact = true; break; case nir_intrinsic_demote_if: { - Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), - as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false), - Operand(exec, s2)); + Temp src = get_ssa_temp(ctx, instr->src[0].ssa); + assert(src.regClass() == s2); + Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)); bld.pseudo(aco_opcode::p_demote_to_helper, cond); ctx->block->kind |= block_kind_uses_demote; ctx->program->needs_exact = true; @@ -6520,7 +6451,9 @@ void visit_tex(isel_context *ctx, nir_tex_instr *instr) Operand((uint32_t)V_008F14_IMG_NUM_FORMAT_SINT), bld.scc(compare_cube_wa)); } - tg4_compare_cube_wa64 = as_divergent_bool(ctx, compare_cube_wa, true); + tg4_compare_cube_wa64 = bld.tmp(s2); + bool_to_vector_condition(ctx, compare_cube_wa, tg4_compare_cube_wa64); + nfmt = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), nfmt, Operand(26u)); desc[1] = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), desc[1], @@ -6770,6 +6703,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr) aco_ptr phi; unsigned num_src = exec_list_length(&instr->srcs); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); + assert(instr->dest.ssa.bit_size != 1 || dst.regClass() == s2); aco_opcode opcode = !dst.is_linear() || ctx->divergent_vals[instr->dest.ssa.index] ? aco_opcode::p_phi : aco_opcode::p_linear_phi; @@ -6797,7 +6731,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr) } /* try to scalarize vector phis */ - if (dst.size() > 1) { + if (instr->dest.ssa.bit_size != 1 && dst.size() > 1) { // TODO: scalarize linear phis on divergent ifs bool can_scalarize = (opcode == aco_opcode::p_phi || !(ctx->block->kind & block_kind_merge)); std::array new_vec; @@ -7265,10 +7199,10 @@ static void visit_if(isel_context *ctx, nir_if *if_stmt) ctx->block->kind |= block_kind_uniform; /* emit branch */ - if (cond.regClass() == s2) { - // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction - cond = as_uniform_bool(ctx, cond); - } + assert(cond.regClass() == s2); + // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction + cond = bool_to_scalar_condition(ctx, cond); + branch.reset(create_instruction(aco_opcode::p_cbranch_z, Format::PSEUDO_BRANCH, 1, 0)); branch->operands[0] = Operand(cond); branch->operands[0].setFixed(scc); diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index 2c349635799..cdc8103497b 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -244,25 +244,14 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_op_fge: case nir_op_feq: case nir_op_fne: - size = 2; - break; case nir_op_ilt: case nir_op_ige: case nir_op_ult: case nir_op_uge: - size = alu_instr->src[0].src.ssa->bit_size == 64 ? 2 : 1; - /* fallthrough */ case nir_op_ieq: case nir_op_ine: case nir_op_i2b1: - if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) { - size = 2; - } else { - for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) { - if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr) - size = 2; - } - } + size = 2; break; case nir_op_f2i64: case nir_op_f2u64: @@ -274,13 +263,7 @@ void init_context(isel_context *ctx, nir_shader *shader) break; case nir_op_bcsel: if (alu_instr->dest.dest.ssa.bit_size == 1) { - if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) - size = 2; - else if (allocated[alu_instr->src[1].src.ssa->index].regClass() == s2 && - allocated[alu_instr->src[2].src.ssa->index].regClass() == s2) - size = 2; - else - size = 1; + size = 2; } else { if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) { type = RegType::vgpr; @@ -298,32 +281,14 @@ void init_context(isel_context *ctx, nir_shader *shader) break; case nir_op_mov: if (alu_instr->dest.dest.ssa.bit_size == 1) { - size = allocated[alu_instr->src[0].src.ssa->index].size(); + size = 2; } else { type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr; } break; - case nir_op_inot: - case nir_op_ixor: - if (alu_instr->dest.dest.ssa.bit_size == 1) { - size = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? 2 : 1; - break; - } else { - /* fallthrough */ - } default: if (alu_instr->dest.dest.ssa.bit_size == 1) { - if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) { - size = 2; - } else { - size = 2; - for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) { - if (allocated[alu_instr->src[i].src.ssa->index].regClass() == s1) { - size = 1; - break; - } - } - } + size = 2; } else { for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) { if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr) @@ -339,6 +304,8 @@ void init_context(isel_context *ctx, nir_shader *shader) unsigned size = nir_instr_as_load_const(instr)->def.num_components; if (nir_instr_as_load_const(instr)->def.bit_size == 64) size *= 2; + else if (nir_instr_as_load_const(instr)->def.bit_size == 1) + size *= 2; allocated[nir_instr_as_load_const(instr)->def.index] = Temp(0, RegClass(RegType::sgpr, size)); break; } @@ -365,6 +332,8 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_intrinsic_read_invocation: case nir_intrinsic_first_invocation: type = RegType::sgpr; + if (intrinsic->dest.ssa.bit_size == 1) + size = 2; break; case nir_intrinsic_ballot: type = RegType::sgpr; @@ -433,11 +402,11 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_intrinsic_masked_swizzle_amd: case nir_intrinsic_inclusive_scan: case nir_intrinsic_exclusive_scan: - if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) { + if (intrinsic->dest.ssa.bit_size == 1) { + size = 2; type = RegType::sgpr; - } else if (intrinsic->src[0].ssa->bit_size == 1) { + } else if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) { type = RegType::sgpr; - size = 2; } else { type = RegType::vgpr; } @@ -452,12 +421,12 @@ void init_context(isel_context *ctx, nir_shader *shader) size = 2; break; case nir_intrinsic_reduce: - if (nir_intrinsic_cluster_size(intrinsic) == 0 || - !ctx->divergent_vals[intrinsic->dest.ssa.index]) { + if (intrinsic->dest.ssa.bit_size == 1) { + size = 2; type = RegType::sgpr; - } else if (intrinsic->src[0].ssa->bit_size == 1) { + } else if (nir_intrinsic_cluster_size(intrinsic) == 0 || + !ctx->divergent_vals[intrinsic->dest.ssa.index]) { type = RegType::sgpr; - size = 2; } else { type = RegType::vgpr; } @@ -554,7 +523,7 @@ void init_context(isel_context *ctx, nir_shader *shader) if (phi->dest.ssa.bit_size == 1) { assert(size == 1 && "multiple components not yet supported on boolean phis."); type = RegType::sgpr; - size *= ctx->divergent_vals[phi->dest.ssa.index] ? 2 : 1; + size *= 2; allocated[phi->dest.ssa.index] = Temp(0, RegClass(type, size)); break; } diff --git a/src/amd/compiler/aco_lower_bool_phis.cpp b/src/amd/compiler/aco_lower_bool_phis.cpp index ac4663a2ce1..9e5374fe6a0 100644 --- a/src/amd/compiler/aco_lower_bool_phis.cpp +++ b/src/amd/compiler/aco_lower_bool_phis.cpp @@ -150,13 +150,6 @@ void lower_divergent_bool_phi(Program *program, Block *block, aco_ptroperands[i].isTemp()); Temp phi_src = phi->operands[i].getTemp(); - if (phi_src.regClass() == s1) { - Temp new_phi_src = bld.tmp(s2); - insert_before_logical_end(pred, - bld.sop2(aco_opcode::s_cselect_b64, Definition(new_phi_src), - Operand((uint32_t)-1), Operand(0u), bld.scc(phi_src)).get_ptr()); - phi_src = new_phi_src; - } assert(phi_src.regClass() == s2); Operand cur = get_ssa(program, pred->index, &state); @@ -218,6 +211,7 @@ void lower_bool_phis(Program* program) for (Block& block : program->blocks) { for (aco_ptr& phi : block.instructions) { if (phi->opcode == aco_opcode::p_phi) { + assert(phi->definitions[0].regClass() != s1); if (phi->definitions[0].regClass() == s2) lower_divergent_bool_phi(program, &block, phi); } else if (phi->opcode == aco_opcode::p_linear_phi) {