From b84d59af50a53959fcde232ee2682e77569a7da2 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 4 Dec 2019 20:18:05 +0000 Subject: [PATCH] aco: add SDWA_instruction MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann Reviewed-By: Timur Kristóf Part-of: --- src/amd/compiler/aco_assembler.cpp | 43 ++++++++++++++++++++- src/amd/compiler/aco_ir.h | 54 ++++++++++++++++++++++++++ src/amd/compiler/aco_print_asm.cpp | 3 ++ src/amd/compiler/aco_print_ir.cpp | 62 ++++++++++++++++++++++++++++-- src/amd/compiler/aco_validate.cpp | 52 ++++++++++++++++++++++++- 5 files changed, 207 insertions(+), 7 deletions(-) diff --git a/src/amd/compiler/aco_assembler.cpp b/src/amd/compiler/aco_assembler.cpp index c46208b13b4..33bc612bac4 100644 --- a/src/amd/compiler/aco_assembler.cpp +++ b/src/amd/compiler/aco_assembler.cpp @@ -547,7 +547,7 @@ void emit_instruction(asm_context& ctx, std::vector& out, Instruction* /* first emit the instruction without the DPP operand */ Operand dpp_op = instr->operands[0]; instr->operands[0] = Operand(PhysReg{250}, v1); - instr->format = (Format) ((uint32_t) instr->format & ~(1 << 14)); + instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::DPP); emit_instruction(ctx, out, instr); DPP_instruction* dpp = static_cast(instr); uint32_t encoding = (0xF & dpp->row_mask) << 28; @@ -561,6 +561,47 @@ void emit_instruction(asm_context& ctx, std::vector& out, Instruction* encoding |= (0xFF) & dpp_op.physReg(); out.push_back(encoding); return; + } else if (instr->isSDWA()) { + /* first emit the instruction without the SDWA operand */ + Operand sdwa_op = instr->operands[0]; + instr->operands[0] = Operand(PhysReg{249}, v1); + instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::SDWA); + emit_instruction(ctx, out, instr); + + SDWA_instruction* sdwa = static_cast(instr); + uint32_t encoding = 0; + + if ((uint16_t)instr->format & (uint16_t)Format::VOPC) { + if (instr->definitions[0].physReg() != vcc) { + encoding |= instr->definitions[0].physReg() << 8; + encoding |= 1 << 15; + } + encoding |= (sdwa->clamp ? 1 : 0) << 13; + } else { + encoding |= (uint32_t)(sdwa->dst_sel & sdwa_asuint) << 8; + uint32_t dst_u = sdwa->dst_sel & sdwa_sext ? 1 : 0; + encoding |= dst_u << 11; + encoding |= (sdwa->clamp ? 1 : 0) << 13; + encoding |= sdwa->omod << 14; + } + + encoding |= (uint32_t)(sdwa->sel[0] & sdwa_asuint) << 16; + encoding |= sdwa->sel[0] & sdwa_sext ? 1 << 19 : 0; + encoding |= sdwa->abs[0] << 21; + encoding |= sdwa->neg[0] << 20; + + if (instr->operands.size() >= 2) { + encoding |= (uint32_t)(sdwa->sel[1] & sdwa_asuint) << 24; + encoding |= sdwa->sel[1] & sdwa_sext ? 1 << 27 : 0; + encoding |= sdwa->abs[1] << 29; + encoding |= sdwa->neg[1] << 28; + } + + encoding |= 0xFF & sdwa_op.physReg(); + encoding |= (sdwa_op.physReg() < 256) << 23; + if (instr->operands.size() >= 2) + encoding |= (instr->operands[1].physReg() < 256) << 31; + out.push_back(encoding); } else { unreachable("unimplemented instruction format"); } diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h index 05a9754c6b1..c8b5c00e1f2 100644 --- a/src/amd/compiler/aco_ir.h +++ b/src/amd/compiler/aco_ir.h @@ -169,6 +169,11 @@ constexpr Format asVOP3(Format format) { return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format); }; +constexpr Format asSDWA(Format format) { + assert(format == Format::VOP1 || format == Format::VOP2 || format == Format::VOPC); + return (Format) ((uint32_t) Format::SDWA | (uint32_t) format); +} + enum class RegType { none = 0, sgpr, @@ -841,6 +846,55 @@ struct DPP_instruction : public Instruction { bool bound_ctrl : 1; }; +enum sdwa_sel : uint8_t { + /* masks */ + sdwa_wordnum = 0x1, + sdwa_bytenum = 0x3, + sdwa_asuint = 0x7, + + /* flags */ + sdwa_isword = 0x4, + sdwa_sext = 0x8, + + /* specific values */ + sdwa_ubyte0 = 0, + sdwa_ubyte1 = 1, + sdwa_ubyte2 = 2, + sdwa_ubyte3 = 3, + sdwa_uword0 = sdwa_isword | 0, + sdwa_uword1 = sdwa_isword | 1, + sdwa_udword = 6, + + sdwa_sbyte0 = sdwa_ubyte0 | sdwa_sext, + sdwa_sbyte1 = sdwa_ubyte1 | sdwa_sext, + sdwa_sbyte2 = sdwa_ubyte2 | sdwa_sext, + sdwa_sbyte3 = sdwa_ubyte3 | sdwa_sext, + sdwa_sword0 = sdwa_uword0 | sdwa_sext, + sdwa_sword1 = sdwa_uword1 | sdwa_sext, + sdwa_sdword = sdwa_udword | sdwa_sext, +}; + +/** + * Sub-Dword Addressing Format: + * This format can be used for VOP1, VOP2 or VOPC instructions. + * + * omod and SGPR/constant operands are only available on GFX9+. For VOPC, + * the definition doesn't have to be VCC on GFX9+. + * + */ +struct SDWA_instruction : public Instruction { + /* these destination modifiers aren't available with VOPC except for + * clamp on GFX8 */ + unsigned dst_sel:4; + bool dst_preserve:1; + bool clamp:1; + unsigned omod:2; /* GFX9+ */ + + unsigned sel[2]; + bool neg[2]; + bool abs[2]; +}; + struct Interp_instruction : public Instruction { uint8_t attribute; uint8_t component; diff --git a/src/amd/compiler/aco_print_asm.cpp b/src/amd/compiler/aco_print_asm.cpp index fead382c7cf..e2dbc5bd8b6 100644 --- a/src/amd/compiler/aco_print_asm.cpp +++ b/src/amd/compiler/aco_print_asm.cpp @@ -140,6 +140,9 @@ void print_asm(Program *program, std::vector& binary, if (!l && program->chip_class == GFX9 && ((binary[pos] & 0xffff8000) == 0xd1348000)) { /* not actually an invalid instruction */ out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_add_u32_e64 + clamp"; new_pos = pos + 2; + } else if (program->chip_class == GFX10 && l == 4 && ((binary[pos] & 0xfe0001ff) == 0x020000f9)) { + out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_cndmask_b32 + sdwa"; + new_pos = pos + 2; } else if (!l) { out << std::left << std::setw(align_width) << std::setfill(' ') << "(invalid instruction)"; new_pos = pos + 1; diff --git a/src/amd/compiler/aco_print_ir.cpp b/src/amd/compiler/aco_print_ir.cpp index 7564b52c17c..43afe0a77c0 100644 --- a/src/amd/compiler/aco_print_ir.cpp +++ b/src/amd/compiler/aco_print_ir.cpp @@ -528,7 +528,38 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output) if (dpp->bound_ctrl) fprintf(output, " bound_ctrl:1"); } else if ((int)instr->format & (int)Format::SDWA) { - fprintf(output, " (printing unimplemented)"); + SDWA_instruction* sdwa = static_cast(instr); + switch (sdwa->omod) { + case 1: + fprintf(output, " *2"); + break; + case 2: + fprintf(output, " *4"); + break; + case 3: + fprintf(output, " *0.5"); + break; + } + if (sdwa->clamp) + fprintf(output, " clamp"); + switch (sdwa->dst_sel & sdwa_asuint) { + case sdwa_udword: + break; + case sdwa_ubyte0: + case sdwa_ubyte1: + case sdwa_ubyte2: + case sdwa_ubyte3: + fprintf(output, " dst_sel:%sbyte%u", sdwa->dst_sel & sdwa_sext ? "s" : "u", + sdwa->dst_sel & sdwa_bytenum); + break; + case sdwa_uword0: + case sdwa_uword1: + fprintf(output, " dst_sel:%sword%u", sdwa->dst_sel & sdwa_sext ? "s" : "u", + sdwa->dst_sel & sdwa_wordnum); + break; + } + if (sdwa->dst_preserve) + fprintf(output, " dst_preserve"); } } @@ -546,23 +577,33 @@ void aco_print_instr(struct Instruction *instr, FILE *output) if (instr->operands.size()) { bool abs[instr->operands.size()]; bool neg[instr->operands.size()]; + uint8_t sel[instr->operands.size()]; if ((int)instr->format & (int)Format::VOP3A) { VOP3A_instruction* vop3 = static_cast(instr); for (unsigned i = 0; i < instr->operands.size(); ++i) { abs[i] = vop3->abs[i]; neg[i] = vop3->neg[i]; + sel[i] = sdwa_udword; } } else if (instr->isDPP()) { DPP_instruction* dpp = static_cast(instr); - assert(instr->operands.size() <= 2); for (unsigned i = 0; i < instr->operands.size(); ++i) { - abs[i] = dpp->abs[i]; - neg[i] = dpp->neg[i]; + abs[i] = i < 2 ? dpp->abs[i] : false; + neg[i] = i < 2 ? dpp->neg[i] : false; + sel[i] = sdwa_udword; + } + } else if (instr->isSDWA()) { + SDWA_instruction* sdwa = static_cast(instr); + for (unsigned i = 0; i < instr->operands.size(); ++i) { + abs[i] = i < 2 ? sdwa->abs[i] : false; + neg[i] = i < 2 ? sdwa->neg[i] : false; + sel[i] = i < 2 ? sdwa->sel[i] : sdwa_udword; } } else { for (unsigned i = 0; i < instr->operands.size(); ++i) { abs[i] = false; neg[i] = false; + sel[i] = sdwa_udword; } } for (unsigned i = 0; i < instr->operands.size(); ++i) { @@ -575,7 +616,20 @@ void aco_print_instr(struct Instruction *instr, FILE *output) fprintf(output, "-"); if (abs[i]) fprintf(output, "|"); + if (sel[i] & sdwa_sext) + fprintf(output, "sext("); print_operand(&instr->operands[i], output); + if (sel[i] & sdwa_sext) + fprintf(output, ")"); + if ((sel[i] & sdwa_asuint) == sdwa_udword) { + /* print nothing */ + } else if (sel[i] & sdwa_isword) { + unsigned index = sel[i] & sdwa_wordnum; + fprintf(output, "[%u:%u]", index * 16, index * 16 + 15); + } else { + unsigned index = sel[i] & sdwa_bytenum; + fprintf(output, "[%u:%u]", index * 8, index * 8 + 7); + } if (abs[i]) fprintf(output, "|"); } diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index e967f0ca9e7..4bbce14a86a 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -93,6 +93,50 @@ void validate(Program* program, FILE * output) "Format cannot have VOP3A/VOP3B applied", instr.get()); } + /* check SDWA */ + if (instr->isSDWA()) { + check(base_format == Format::VOP2 || + base_format == Format::VOP1 || + base_format == Format::VOPC, + "Format cannot have SDWA applied", instr.get()); + + check(program->chip_class >= GFX8, "SDWA is GFX8+ only", instr.get()); + + SDWA_instruction *sdwa = static_cast(instr.get()); + check(sdwa->omod == 0 || program->chip_class >= GFX9, "SDWA omod only supported on GFX9+", instr.get()); + if (base_format == Format::VOPC) { + check(sdwa->clamp == false || program->chip_class == GFX8, "SDWA VOPC clamp only supported on GFX8", instr.get()); + check((instr->definitions[0].isFixed() && instr->definitions[0].physReg() == vcc) || + program->chip_class >= GFX9, + "SDWA+VOPC definition must be fixed to vcc on GFX8", instr.get()); + } + + if (instr->operands.size() >= 3) { + check(instr->operands[2].isFixed() && instr->operands[2].physReg() == vcc, + "3rd operand must be fixed to vcc with SDWA", instr.get()); + } + if (instr->definitions.size() >= 2) { + check(instr->definitions[1].isFixed() && instr->definitions[1].physReg() == vcc, + "2nd definition must be fixed to vcc with SDWA", instr.get()); + } + + check(instr->opcode != aco_opcode::v_madmk_f32 && + instr->opcode != aco_opcode::v_madak_f32 && + instr->opcode != aco_opcode::v_madmk_f16 && + instr->opcode != aco_opcode::v_madak_f16 && + instr->opcode != aco_opcode::v_readfirstlane_b32 && + instr->opcode != aco_opcode::v_clrexcp && + instr->opcode != aco_opcode::v_swap_b32, + "SDWA can't be used with this opcode", instr.get()); + if (program->chip_class != GFX8) { + check(instr->opcode != aco_opcode::v_mac_f32 && + instr->opcode != aco_opcode::v_mac_f16 && + instr->opcode != aco_opcode::v_fmac_f32 && + instr->opcode != aco_opcode::v_fmac_f16, + "SDWA can't be used with this opcode", instr.get()); + } + } + /* check for undefs */ for (unsigned i = 0; i < instr->operands.size(); i++) { if (instr->operands[i].isUndefined()) { @@ -137,6 +181,10 @@ void validate(Program* program, FILE * output) if (program->chip_class >= GFX10 && !is_shift64) const_bus_limit = 2; + uint32_t scalar_mask = instr->isVOP3() ? 0x7 : 0x5; + if (instr->isSDWA()) + scalar_mask = program->chip_class >= GFX9 ? 0x7 : 0x4; + check(instr->definitions[0].getTemp().type() == RegType::vgpr || (int) instr->format & (int) Format::VOPC || instr->opcode == aco_opcode::v_readfirstlane_b32 || @@ -158,7 +206,7 @@ void validate(Program* program, FILE * output) continue; } if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) { - check(i != 1 || instr->isVOP3(), "Wrong source position for SGPR argument", instr.get()); + check(scalar_mask & (1 << i), "Wrong source position for SGPR argument", instr.get()); if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) { if (num_sgprs < 2) @@ -167,7 +215,7 @@ void validate(Program* program, FILE * output) } if (op.isConstant() && !op.isLiteral()) - check(i == 0 || instr->isVOP3(), "Wrong source position for constant argument", instr.get()); + check(scalar_mask & (1 << i), "Wrong source position for constant argument", instr.get()); } check(num_sgprs + (literal.isUndefined() ? 0 : 1) <= const_bus_limit, "Too many SGPRs/literals", instr.get()); } -- 2.30.2