From a99ae1943d880702c8472ea9be11e4f92b6a440f Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 12 Aug 2020 15:58:32 +0100 Subject: [PATCH] aco: remove omod_success/clamp_success MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit This simplifies the optimizer and should make SDWA optimizations easier. No fossil-db changes on Navi. Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_optimizer.cpp | 247 +++++++++-------------------- 1 file changed, 77 insertions(+), 170 deletions(-) diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index a7cc0cdf6c8..a385b342345 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -95,9 +95,7 @@ enum Label { label_omod2 = 1 << 8, label_omod4 = 1 << 9, label_omod5 = 1 << 10, - label_omod_success = 1 << 11, label_clamp = 1 << 12, - label_clamp_success = 1 << 13, label_undefined = 1 << 14, label_vcc = 1 << 15, label_b2f = 1 << 16, @@ -115,10 +113,13 @@ enum Label { label_constant_16bit = 1 << 29, }; -static constexpr uint64_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success | - label_add_sub | label_bitwise | label_uniform_bitwise | label_minmax | label_vopc; +static constexpr uint64_t instr_usedef_labels = label_vec | label_mul | label_mad | label_add_sub | + label_bitwise | label_uniform_bitwise | label_minmax | label_vopc; +static constexpr uint64_t instr_mod_labels = label_omod2 | label_omod4 | label_omod5 | label_clamp; + +static constexpr uint64_t instr_labels = instr_usedef_labels | instr_mod_labels; static constexpr uint64_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool | - label_omod2 | label_omod4 | label_omod5 | label_clamp | label_scc_invert | label_b2i; + label_scc_invert | label_b2i; static constexpr uint32_t val_labels = label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal; struct ssa_info { @@ -133,11 +134,16 @@ struct ssa_info { void add_label(Label new_label) { - /* Since all labels which use "instr" use it for the same thing - * (indicating the defining instruction), there is no need to clear - * any other instr labels. */ - if (new_label & instr_labels) + /* Since all the instr_usedef_labels use instr for the same thing + * (indicating the defining instruction), there is usually no need to + * clear any other instr labels. */ + if (new_label & instr_usedef_labels) + label &= ~(instr_mod_labels | temp_labels | val_labels); /* instr, temp and val alias */ + + if (new_label & instr_mod_labels) { + label &= ~instr_labels; label &= ~(temp_labels | val_labels); /* instr, temp and val alias */ + } if (new_label & temp_labels) { label &= ~temp_labels; @@ -310,10 +316,10 @@ struct ssa_info { return label & label_mad; } - void set_omod2(Temp def) + void set_omod2(Instruction* mul) { add_label(label_omod2); - temp = def; + instr = mul; } bool is_omod2() @@ -321,10 +327,10 @@ struct ssa_info { return label & label_omod2; } - void set_omod4(Temp def) + void set_omod4(Instruction* mul) { add_label(label_omod4); - temp = def; + instr = mul; } bool is_omod4() @@ -332,10 +338,10 @@ struct ssa_info { return label & label_omod4; } - void set_omod5(Temp def) + void set_omod5(Instruction* mul) { add_label(label_omod5); - temp = def; + instr = mul; } bool is_omod5() @@ -343,21 +349,10 @@ struct ssa_info { return label & label_omod5; } - void set_omod_success(Instruction* omod_instr) - { - add_label(label_omod_success); - instr = omod_instr; - } - - bool is_omod_success() - { - return label & label_omod_success; - } - - void set_clamp(Temp def) + void set_clamp(Instruction *med3) { add_label(label_clamp); - temp = def; + instr = med3; } bool is_clamp() @@ -365,17 +360,6 @@ struct ssa_info { return label & label_clamp; } - void set_clamp_success(Instruction* clamp_instr) - { - add_label(label_clamp_success); - instr = clamp_instr; - } - - bool is_clamp_success() - { - return label & label_clamp_success; - } - void set_undefined() { add_label(label_undefined); @@ -650,10 +634,12 @@ void to_VOP3(opt_ctx& ctx, aco_ptr& instr) instr->definitions[i] = tmp->definitions[i]; if (instr->definitions[i].isTemp()) { ssa_info& info = ctx.info[instr->definitions[i].tempId()]; - if (info.label & instr_labels && info.instr == tmp.get()) + if (info.label & instr_usedef_labels && info.instr == tmp.get()) info.instr = instr.get(); } } + /* we don't need to update any instr_mod_labels because they either haven't + * been applied yet or this instruction isn't dead and so they've been ignored */ } /* only covers special cases */ @@ -1240,6 +1226,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr) } case aco_opcode::v_mul_f16: case aco_opcode::v_mul_f32: { /* omod */ + ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); + /* TODO: try to move the negate/abs modifier to the consumer instead */ if (instr->usesModifiers()) break; @@ -1249,11 +1237,11 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr) for (unsigned i = 0; i < 2; i++) { if (instr->operands[!i].isConstant() && instr->operands[i].isTemp()) { if (instr->operands[!i].constantValue() == (fp16 ? 0x4000 : 0x40000000)) { /* 2.0 */ - ctx.info[instr->operands[i].tempId()].set_omod2(instr->definitions[0].getTemp()); + ctx.info[instr->operands[i].tempId()].set_omod2(instr.get()); } else if (instr->operands[!i].constantValue() == (fp16 ? 0x4400 : 0x40800000)) { /* 4.0 */ - ctx.info[instr->operands[i].tempId()].set_omod4(instr->definitions[0].getTemp()); + ctx.info[instr->operands[i].tempId()].set_omod4(instr.get()); } else if (instr->operands[!i].constantValue() == (fp16 ? 0xb800 : 0x3f000000)) { /* 0.5 */ - ctx.info[instr->operands[i].tempId()].set_omod5(instr->definitions[0].getTemp()); + ctx.info[instr->operands[i].tempId()].set_omod5(instr.get()); } else if (instr->operands[!i].constantValue() == (fp16 ? 0x3c00 : 0x3f800000) && !(fp16 ? block.fp_mode.must_flush_denorms16_64 : block.fp_mode.must_flush_denorms32)) { /* 1.0 */ ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[i].getTemp()); @@ -1315,9 +1303,8 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr) else idx = i; } - if (found_zero && found_one && instr->operands[idx].isTemp()) { - ctx.info[instr->operands[idx].tempId()].set_clamp(instr->definitions[0].getTemp()); - } + if (found_zero && found_one && instr->operands[idx].isTemp()) + ctx.info[instr->operands[idx].tempId()].set_clamp(instr.get()); break; } case aco_opcode::v_cndmask_b32: @@ -1564,7 +1551,7 @@ void decrease_uses(opt_ctx &ctx, Instruction* instr) Instruction *follow_operand(opt_ctx &ctx, Operand op, bool ignore_uses=false) { - if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_labels)) + if (!op.isTemp() || !(ctx.info[op.tempId()].label & instr_usedef_labels)) return nullptr; if (!ignore_uses && ctx.uses[op.tempId()] > 1) return nullptr; @@ -1991,9 +1978,6 @@ void create_vop3_for_op3(opt_ctx& ctx, aco_opcode opcode, aco_ptr& bool combine_three_valu_op(opt_ctx& ctx, aco_ptr& instr, aco_opcode op2, aco_opcode new_op, const char *shuffle, uint8_t ops) { - uint64_t omod_clamp = ctx.info[instr->definitions[0].tempId()].label & - (label_omod_success | label_clamp_success); - for (unsigned swap = 0; swap < 2; swap++) { if (!((1 << swap) & ops)) continue; @@ -2007,10 +1991,6 @@ bool combine_three_valu_op(opt_ctx& ctx, aco_ptr& instr, aco_opcode &clamp, &omod, NULL, NULL, NULL)) { ctx.uses[instr->operands[swap].tempId()]--; create_vop3_for_op3(ctx, new_op, instr, operands, neg, abs, opsel, clamp, omod); - if (omod_clamp & label_omod_success) - ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get()); - if (omod_clamp & label_clamp_success) - ctx.info[instr->definitions[0].tempId()].set_clamp_success(instr.get()); return true; } } @@ -2022,9 +2002,6 @@ bool combine_minmax(opt_ctx& ctx, aco_ptr& instr, aco_opcode opposi if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2)) return true; - uint64_t omod_clamp = ctx.info[instr->definitions[0].tempId()].label & - (label_omod_success | label_clamp_success); - /* min(-max(a, b), c) -> min3(-a, -b, c) * * max(-min(a, b), c) -> max3(-a, -b, c) */ for (unsigned swap = 0; swap < 2; swap++) { @@ -2041,10 +2018,6 @@ bool combine_minmax(opt_ctx& ctx, aco_ptr& instr, aco_opcode opposi neg[1] = true; neg[2] = true; create_vop3_for_op3(ctx, minmax3, instr, operands, neg, abs, opsel, clamp, omod); - if (omod_clamp & label_omod_success) - ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get()); - if (omod_clamp & label_clamp_success) - ctx.info[instr->definitions[0].tempId()].set_clamp_success(instr.get()); return true; } } @@ -2276,9 +2249,6 @@ bool combine_clamp(opt_ctx& ctx, aco_ptr& instr, else return false; - uint64_t omod_clamp = ctx.info[instr->definitions[0].tempId()].label & - (label_omod_success | label_clamp_success); - for (unsigned swap = 0; swap < 2; swap++) { Operand operands[3]; bool neg[3], abs[3], clamp; @@ -2367,10 +2337,6 @@ bool combine_clamp(opt_ctx& ctx, aco_ptr& instr, ctx.uses[instr->operands[swap].tempId()]--; create_vop3_for_op3(ctx, med, instr, operands, neg, abs, opsel, clamp, omod); - if (omod_clamp & label_omod_success) - ctx.info[instr->definitions[0].tempId()].set_omod_success(instr.get()); - if (omod_clamp & label_clamp_success) - ctx.info[instr->definitions[0].tempId()].set_clamp_success(instr.get()); return true; } @@ -2460,117 +2426,59 @@ void apply_sgprs(opt_ctx &ctx, aco_ptr& instr) } } -bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr& instr) +bool apply_omod_clamp_helper(opt_ctx &ctx, aco_ptr& instr, ssa_info& def_info) { - /* check if we could apply omod on predecessor */ - if (instr->opcode == aco_opcode::v_mul_f32 || instr->opcode == aco_opcode::v_mul_f16) { - bool op0 = instr->operands[0].isTemp() && ctx.info[instr->operands[0].tempId()].is_omod_success(); - bool op1 = instr->operands[1].isTemp() && ctx.info[instr->operands[1].tempId()].is_omod_success(); - if (op0 || op1) { - unsigned idx = op0 ? 0 : 1; - Instruction* omod_instr = ctx.info[instr->operands[idx].tempId()].instr; - - /* omod was successfully applied */ - - /* check if we have an additional clamp modifier */ - if (ctx.info[instr->definitions[0].tempId()].is_clamp() && ctx.uses[instr->definitions[0].tempId()] == 1 && - ctx.uses[ctx.info[instr->definitions[0].tempId()].temp.id()]) { - /* if the omod instruction is v_mad, we also have to change the original add */ - if (ctx.info[instr->operands[idx].tempId()].is_mad()) { - uint32_t mad_info_idx = ctx.info[instr->operands[idx].tempId()].instr->pass_flags; - Instruction* add_instr = ctx.mad_infos[mad_info_idx].add_instr.get(); - static_cast(add_instr)->clamp = true; - } + to_VOP3(ctx, instr); - static_cast(omod_instr)->clamp = true; - ctx.info[instr->definitions[0].tempId()].set_clamp_success(omod_instr); - } + if (!def_info.is_clamp() && (static_cast(instr.get())->clamp || + static_cast(instr.get())->omod)) + return false; - /* if the omod instruction is v_mad, we also have to change the original add */ - if (ctx.info[instr->operands[idx].tempId()].is_mad()) { - uint32_t mad_info_idx = ctx.info[instr->operands[idx].tempId()].instr->pass_flags; - Instruction* add_instr = ctx.mad_infos[mad_info_idx].add_instr.get(); - add_instr->definitions[0] = instr->definitions[0]; - ctx.info[instr->definitions[0].tempId()].set_mad(omod_instr, mad_info_idx); - } + if (def_info.is_omod2()) + static_cast(instr.get())->omod = 1; + else if (def_info.is_omod4()) + static_cast(instr.get())->omod = 2; + else if (def_info.is_omod5()) + static_cast(instr.get())->omod = 3; + else if (def_info.is_clamp()) + static_cast(instr.get())->clamp = true; - /* change definition ssa-id of modified instruction */ - omod_instr->definitions[0] = instr->definitions[0]; + return true; +} - /* change the definition of instr to something unused, e.g. the original omod def */ - instr->definitions[0] = Definition(instr->operands[idx].getTemp()); - ctx.uses[instr->definitions[0].tempId()] = 0; - return true; - } - if (!ctx.info[instr->definitions[0].tempId()].label) { - /* in all other cases, label this instruction as option for multiply-add */ - ctx.info[instr->definitions[0].tempId()].set_mul(instr.get()); - } - } +/* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */ +bool apply_omod_clamp(opt_ctx &ctx, Block& block, aco_ptr& instr) +{ + if (instr->definitions.empty() || ctx.uses[instr->definitions[0].tempId()] != 1 || + !instr_info.can_use_output_modifiers[(int)instr->opcode]) + return false; - /* check if we could apply clamp on predecessor */ - if (instr->opcode == aco_opcode::v_med3_f32 || instr->opcode == aco_opcode::v_med3_f16) { - bool is_fp16 = instr->opcode == aco_opcode::v_med3_f16; - unsigned idx = 0; - bool found_zero = false, found_one = false; - for (unsigned i = 0; i < 3; i++) - { - if (instr->operands[i].constantEquals(0)) - found_zero = true; - else if (instr->operands[i].constantEquals(is_fp16 ? 0x3c00 : 0x3f800000)) /* 1.0 */ - found_one = true; - else - idx = i; - } - if (found_zero && found_one && instr->operands[idx].isTemp() && - ctx.info[instr->operands[idx].tempId()].is_clamp_success()) { - Instruction* clamp_instr = ctx.info[instr->operands[idx].tempId()].instr; + if (!can_use_VOP3(ctx, instr)) + return false; - /* clamp was successfully applied */ - /* if the clamp instruction is v_mad, we also have to change the original add */ - if (ctx.info[instr->operands[idx].tempId()].is_mad()) { - uint32_t mad_info_idx = ctx.info[instr->operands[idx].tempId()].instr->pass_flags; - Instruction* add_instr = ctx.mad_infos[mad_info_idx].add_instr.get(); - add_instr->definitions[0] = instr->definitions[0]; + /* omod has no effect if denormals are enabled */ + bool can_use_omod = (instr->definitions[0].bytes() == 4 ? block.fp_mode.denorm32 : block.fp_mode.denorm16_64) == 0; + ssa_info& def_info = ctx.info[instr->definitions[0].tempId()]; - ctx.info[instr->definitions[0].tempId()].set_mad(clamp_instr, mad_info_idx); - } - /* change definition ssa-id of modified instruction */ - clamp_instr->definitions[0] = instr->definitions[0]; + uint64_t omod_labels = label_omod2 | label_omod4 | label_omod5; + if (!def_info.is_clamp() && !(can_use_omod && (def_info.label & omod_labels))) + return false; + /* if the omod/clamp instruction is dead, then the single user of this + * instruction is a different instruction */ + if (!ctx.uses[def_info.instr->definitions[0].tempId()]) + return false; - /* change the definition of instr to something unused, e.g. the original omod def */ - instr->definitions[0] = Definition(instr->operands[idx].getTemp()); - ctx.uses[instr->definitions[0].tempId()] = 0; - return true; - } - } + /* MADs/FMAs are created later, so we don't have to update the original add */ + assert(!ctx.info[instr->definitions[0].tempId()].is_mad()); - /* omod has no effect if denormals are enabled */ - /* apply omod / clamp modifiers if the def is used only once and the instruction can have modifiers */ - if (!instr->definitions.empty() && ctx.uses[instr->definitions[0].tempId()] == 1 && - can_use_VOP3(ctx, instr) && instr_info.can_use_output_modifiers[(int)instr->opcode]) { - bool can_use_omod = (instr->definitions[0].bytes() == 4 ? block.fp_mode.denorm32 : block.fp_mode.denorm16_64) == 0; - ssa_info& def_info = ctx.info[instr->definitions[0].tempId()]; - if (can_use_omod && def_info.is_omod2() && ctx.uses[def_info.temp.id()]) { - to_VOP3(ctx, instr); - static_cast(instr.get())->omod = 1; - def_info.set_omod_success(instr.get()); - } else if (can_use_omod && def_info.is_omod4() && ctx.uses[def_info.temp.id()]) { - to_VOP3(ctx, instr); - static_cast(instr.get())->omod = 2; - def_info.set_omod_success(instr.get()); - } else if (can_use_omod && def_info.is_omod5() && ctx.uses[def_info.temp.id()]) { - to_VOP3(ctx, instr); - static_cast(instr.get())->omod = 3; - def_info.set_omod_success(instr.get()); - } else if (def_info.is_clamp() && ctx.uses[def_info.temp.id()]) { - to_VOP3(ctx, instr); - static_cast(instr.get())->clamp = true; - def_info.set_clamp_success(instr.get()); - } - } + if (!apply_omod_clamp_helper(ctx, instr, def_info)) + return false; - return false; + std::swap(instr->definitions[0], def_info.instr->definitions[0]); + ctx.info[instr->definitions[0].tempId()].label &= label_clamp; + ctx.uses[def_info.instr->definitions[0].tempId()]--; + + return true; } // TODO: we could possibly move the whole label_instruction pass to combine_instruction: @@ -2584,8 +2492,7 @@ void combine_instruction(opt_ctx &ctx, Block& block, aco_ptr& instr if (instr->isVALU()) { if (can_apply_sgprs(instr)) apply_sgprs(ctx, instr); - if (apply_omod_clamp(ctx, block, instr)) - return; + while (apply_omod_clamp(ctx, block, instr)) ; } if (ctx.info[instr->definitions[0].tempId()].is_vcc_hint()) { -- 2.30.2