aco: add SDWA_instruction
author: Rhys Perry <pendingchaos02@gmail.com>
Wed, 4 Dec 2019 20:18:05 +0000 (20:18 +0000)
committer: Daniel Schürmann <daniel@schuermann.dev>
Fri, 3 Apr 2020 22:13:15 +0000 (23:13 +0100)
Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-By: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4002>

src/amd/compiler/aco_assembler.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_print_asm.cpp
src/amd/compiler/aco_print_ir.cpp
src/amd/compiler/aco_validate.cpp

index c46208b13b402166409b5eef6268e6be2a7b4ff8..33bc612bac48efe6e55afcabce95b2c9858df899 100644 (file)
@@ -547,7 +547,7 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
          /* first emit the instruction without the DPP operand */
          Operand dpp_op = instr->operands[0];
          instr->operands[0] = Operand(PhysReg{250}, v1);
-         instr->format = (Format) ((uint32_t) instr->format & ~(1 << 14));
+         instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::DPP);
          emit_instruction(ctx, out, instr);
          DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
          uint32_t encoding = (0xF & dpp->row_mask) << 28;
@@ -561,6 +561,47 @@ void emit_instruction(asm_context& ctx, std::vector<uint32_t>& out, Instruction*
          encoding |= (0xFF) & dpp_op.physReg();
          out.push_back(encoding);
          return;
+      } else if (instr->isSDWA()) {
+         /* first emit the base instruction with src0 swapped for PhysReg 249; the real src0 goes into the extra dword */
+         Operand sdwa_op = instr->operands[0];
+         instr->operands[0] = Operand(PhysReg{249}, v1);
+         instr->format = (Format) ((uint16_t) instr->format & ~(uint16_t)Format::SDWA);
+         emit_instruction(ctx, out, instr);
+         /* then build the extra SDWA dword that follows the base encoding */
+         SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
+         uint32_t encoding = 0;
+         /* VOPC: a non-vcc definition is stored at bit 8 and flagged by bit 15 */
+         if ((uint16_t)instr->format & (uint16_t)Format::VOPC) {
+            if (instr->definitions[0].physReg() != vcc) {
+               encoding |= instr->definitions[0].physReg() << 8;
+               encoding |= 1 << 15;
+            }
+            encoding |= (sdwa->clamp ? 1 : 0) << 13;
+         } else {
+            encoding |= (uint32_t)(sdwa->dst_sel & sdwa_asuint) << 8;
+            uint32_t dst_u = sdwa->dst_preserve ? 2 : (sdwa->dst_sel & sdwa_sext ? 1 : 0);
+            encoding |= dst_u << 11; /* DST_UNUSED: 0 = pad, 1 = sign-extend, 2 = preserve */
+            encoding |= (sdwa->clamp ? 1 : 0) << 13;
+            encoding |= sdwa->omod << 14;
+         }
+         /* per-source fields: select in 3 bits, then sign-extend, neg and abs flags */
+         encoding |= (uint32_t)(sdwa->sel[0] & sdwa_asuint) << 16;
+         encoding |= (sdwa->sel[0] & sdwa_sext) ? 1 << 19 : 0;
+         encoding |= sdwa->abs[0] << 21;
+         encoding |= sdwa->neg[0] << 20;
+
+         if (instr->operands.size() >= 2) {
+            encoding |= (uint32_t)(sdwa->sel[1] & sdwa_asuint) << 24;
+            encoding |= (sdwa->sel[1] & sdwa_sext) ? 1 << 27 : 0;
+            encoding |= sdwa->abs[1] << 29;
+            encoding |= sdwa->neg[1] << 28;
+         }
+         /* real src0 in the low byte; bits 23/31 flag a source whose register index is below 256 (cast keeps the shift unsigned) */
+         encoding |= 0xFF & sdwa_op.physReg();
+         encoding |= (uint32_t)(sdwa_op.physReg() < 256) << 23;
+         if (instr->operands.size() >= 2)
+            encoding |= (uint32_t)(instr->operands[1].physReg() < 256) << 31;
+         out.push_back(encoding);
       } else {
          unreachable("unimplemented instruction format");
       }
index 05a9754c6b19270c0e02a71002ea958fc32f11af..c8b5c00e1f2175e9ba66c710008cf6267c6f24e0 100644 (file)
@@ -169,6 +169,11 @@ constexpr Format asVOP3(Format format) {
    return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format);
 };
 
+constexpr Format asSDWA(Format format) { /* returns `format` with the Format::SDWA bit OR'ed in */
+   assert(format == Format::VOP1 || format == Format::VOP2 || format == Format::VOPC); /* the only base formats accepted here */
+   return (Format) ((uint32_t) Format::SDWA | (uint32_t) format);
+}
+
 enum class RegType {
    none = 0,
    sgpr,
@@ -841,6 +846,55 @@ struct DPP_instruction : public Instruction {
    bool bound_ctrl : 1;
 };
 
+enum sdwa_sel : uint8_t {
+    /* masks: AND with a selector to extract one of its fields */
+    sdwa_wordnum = 0x1, /* word index of a word selector */
+    sdwa_bytenum = 0x3, /* byte index of a byte selector */
+    sdwa_asuint = 0x7, /* selector without the sdwa_sext flag; this is the 3-bit value the assembler encodes */
+
+    /* flags */
+    sdwa_isword = 0x4, /* set on word selectors; sdwa_udword (6) also has this bit, so test for udword first */
+    sdwa_sext = 0x8, /* signed variant; the assembler encodes it as a separate bit */
+
+    /* specific values */
+    sdwa_ubyte0 = 0,
+    sdwa_ubyte1 = 1,
+    sdwa_ubyte2 = 2,
+    sdwa_ubyte3 = 3,
+    sdwa_uword0 = sdwa_isword | 0,
+    sdwa_uword1 = sdwa_isword | 1,
+    sdwa_udword = 6, /* full dword; the printer emits no sub-range for it */
+    /* the same selections with sdwa_sext set */
+    sdwa_sbyte0 = sdwa_ubyte0 | sdwa_sext,
+    sdwa_sbyte1 = sdwa_ubyte1 | sdwa_sext,
+    sdwa_sbyte2 = sdwa_ubyte2 | sdwa_sext,
+    sdwa_sbyte3 = sdwa_ubyte3 | sdwa_sext,
+    sdwa_sword0 = sdwa_uword0 | sdwa_sext,
+    sdwa_sword1 = sdwa_uword1 | sdwa_sext,
+    sdwa_sdword = sdwa_udword | sdwa_sext,
+};
+
+/**
+ * Sub-Dword Addressing Format:
+ * This format can be used for VOP1, VOP2 or VOPC instructions.
+ *
+ * omod and SGPR/constant operands are only available on GFX9+. For VOPC,
+ * the definition doesn't have to be VCC on GFX9+.
+ * sel/neg/abs carry per-source modifiers for the first two operands only.
+ */
+struct SDWA_instruction : public Instruction {
+   /* these destination modifiers aren't available with VOPC except for
+    * clamp on GFX8 */
+   unsigned dst_sel:4; /* sdwa_sel value for the destination */
+   bool dst_preserve:1; /* NOTE(review): presumably keeps the destination bits not covered by dst_sel; confirm */
+   bool clamp:1;
+   unsigned omod:2; /* GFX9+; the IR printer shows 1 as *2, 2 as *4, 3 as *0.5 */
+
+   unsigned sel[2]; /* sdwa_sel value per source operand */
+   bool neg[2];
+   bool abs[2];
+};
+
 struct Interp_instruction : public Instruction {
    uint8_t attribute;
    uint8_t component;
index fead382c7cf1a6d6e895f534b2d8e2b29901aca8..e2dbc5bd8b6119331e42c2f8c4df268e26c6a5e4 100644 (file)
@@ -140,6 +140,9 @@ void print_asm(Program *program, std::vector<uint32_t>& binary,
       if (!l && program->chip_class == GFX9 && ((binary[pos] & 0xffff8000) == 0xd1348000)) { /* not actually an invalid instruction */
          out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_add_u32_e64 + clamp";
          new_pos = pos + 2;
+      } else if (program->chip_class == GFX10 && l == 4 && ((binary[pos] & 0xfe0001ff) == 0x020000f9)) {
+         out << std::left << std::setw(align_width) << std::setfill(' ') << "\tv_cndmask_b32 + sdwa";
+         new_pos = pos + 2;
       } else if (!l) {
          out << std::left << std::setw(align_width) << std::setfill(' ') << "(invalid instruction)";
          new_pos = pos + 1;
index 7564b52c17ce8e66931f4f78d8b3998b37a1b86f..43afe0a77c0867d76dce979bdd50fef798958b2a 100644 (file)
@@ -528,7 +528,38 @@ static void print_instr_format_specific(struct Instruction *instr, FILE *output)
       if (dpp->bound_ctrl)
          fprintf(output, " bound_ctrl:1");
    } else if ((int)instr->format & (int)Format::SDWA) {
-      fprintf(output, " (printing unimplemented)");
+      SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
+      switch (sdwa->omod) {
+      case 1:
+         fprintf(output, " *2");
+         break;
+      case 2:
+         fprintf(output, " *4");
+         break;
+      case 3:
+         fprintf(output, " *0.5");
+         break;
+      }
+      if (sdwa->clamp)
+         fprintf(output, " clamp");
+      switch (sdwa->dst_sel & sdwa_asuint) {
+      case sdwa_udword:
+         break;
+      case sdwa_ubyte0:
+      case sdwa_ubyte1:
+      case sdwa_ubyte2:
+      case sdwa_ubyte3:
+         fprintf(output, " dst_sel:%sbyte%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
+                 sdwa->dst_sel & sdwa_bytenum);
+         break;
+      case sdwa_uword0:
+      case sdwa_uword1:
+         fprintf(output, " dst_sel:%sword%u", sdwa->dst_sel & sdwa_sext ? "s" : "u",
+                 sdwa->dst_sel & sdwa_wordnum);
+         break;
+      }
+      if (sdwa->dst_preserve)
+         fprintf(output, " dst_preserve");
    }
 }
 
@@ -546,23 +577,33 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
    if (instr->operands.size()) {
       bool abs[instr->operands.size()];
       bool neg[instr->operands.size()];
+      uint8_t sel[instr->operands.size()];
       if ((int)instr->format & (int)Format::VOP3A) {
          VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(instr);
          for (unsigned i = 0; i < instr->operands.size(); ++i) {
             abs[i] = vop3->abs[i];
             neg[i] = vop3->neg[i];
+            sel[i] = sdwa_udword;
          }
       } else if (instr->isDPP()) {
          DPP_instruction* dpp = static_cast<DPP_instruction*>(instr);
-         assert(instr->operands.size() <= 2);
          for (unsigned i = 0; i < instr->operands.size(); ++i) {
-            abs[i] = dpp->abs[i];
-            neg[i] = dpp->neg[i];
+            abs[i] = i < 2 ? dpp->abs[i] : false;
+            neg[i] = i < 2 ? dpp->neg[i] : false;
+            sel[i] = sdwa_udword;
+         }
+      } else if (instr->isSDWA()) {
+         SDWA_instruction* sdwa = static_cast<SDWA_instruction*>(instr);
+         for (unsigned i = 0; i < instr->operands.size(); ++i) {
+            abs[i] = i < 2 ? sdwa->abs[i] : false;
+            neg[i] = i < 2 ? sdwa->neg[i] : false;
+            sel[i] = i < 2 ? sdwa->sel[i] : sdwa_udword;
          }
       } else {
          for (unsigned i = 0; i < instr->operands.size(); ++i) {
             abs[i] = false;
             neg[i] = false;
+            sel[i] = sdwa_udword;
          }
       }
       for (unsigned i = 0; i < instr->operands.size(); ++i) {
@@ -575,7 +616,20 @@ void aco_print_instr(struct Instruction *instr, FILE *output)
             fprintf(output, "-");
          if (abs[i])
             fprintf(output, "|");
+         if (sel[i] & sdwa_sext)
+            fprintf(output, "sext(");
          print_operand(&instr->operands[i], output);
+         if (sel[i] & sdwa_sext)
+            fprintf(output, ")");
+         if ((sel[i] & sdwa_asuint) == sdwa_udword) {
+            /* print nothing */
+         } else if (sel[i] & sdwa_isword) {
+            unsigned index = sel[i] & sdwa_wordnum;
+            fprintf(output, "[%u:%u]", index * 16, index * 16 + 15);
+         } else {
+            unsigned index = sel[i] & sdwa_bytenum;
+            fprintf(output, "[%u:%u]", index * 8, index * 8 + 7);
+         }
          if (abs[i])
             fprintf(output, "|");
        }
index e967f0ca9e7d41c907b1c02c80d19d025c0cb5ae..4bbce14a86a2cae2aa97501890a8222549a01f3b 100644 (file)
@@ -93,6 +93,50 @@ void validate(Program* program, FILE * output)
                   "Format cannot have VOP3A/VOP3B applied", instr.get());
          }
 
+         /* check SDWA */
+         if (instr->isSDWA()) {
+            check(base_format == Format::VOP2 ||
+                  base_format == Format::VOP1 ||
+                  base_format == Format::VOPC,
+                  "Format cannot have SDWA applied", instr.get());
+
+            check(program->chip_class >= GFX8, "SDWA is GFX8+ only", instr.get());
+
+            SDWA_instruction *sdwa = static_cast<SDWA_instruction*>(instr.get());
+            check(sdwa->omod == 0 || program->chip_class >= GFX9, "SDWA omod only supported on GFX9+", instr.get());
+            if (base_format == Format::VOPC) {
+               check(sdwa->clamp == false || program->chip_class == GFX8, "SDWA VOPC clamp only supported on GFX8", instr.get());
+               check((instr->definitions[0].isFixed() && instr->definitions[0].physReg() == vcc) ||
+                     program->chip_class >= GFX9,
+                     "SDWA+VOPC definition must be fixed to vcc on GFX8", instr.get());
+            }
+
+            if (instr->operands.size() >= 3) {
+               check(instr->operands[2].isFixed() && instr->operands[2].physReg() == vcc,
+                     "3rd operand must be fixed to vcc with SDWA", instr.get());
+            }
+            if (instr->definitions.size() >= 2) {
+               check(instr->definitions[1].isFixed() && instr->definitions[1].physReg() == vcc,
+                     "2nd definition must be fixed to vcc with SDWA", instr.get());
+            }
+
+            check(instr->opcode != aco_opcode::v_madmk_f32 &&
+                  instr->opcode != aco_opcode::v_madak_f32 &&
+                  instr->opcode != aco_opcode::v_madmk_f16 &&
+                  instr->opcode != aco_opcode::v_madak_f16 &&
+                  instr->opcode != aco_opcode::v_readfirstlane_b32 &&
+                  instr->opcode != aco_opcode::v_clrexcp &&
+                  instr->opcode != aco_opcode::v_swap_b32,
+                  "SDWA can't be used with this opcode", instr.get());
+            if (program->chip_class != GFX8) {
+               check(instr->opcode != aco_opcode::v_mac_f32 &&
+                     instr->opcode != aco_opcode::v_mac_f16 &&
+                     instr->opcode != aco_opcode::v_fmac_f32 &&
+                     instr->opcode != aco_opcode::v_fmac_f16,
+                     "SDWA can't be used with this opcode", instr.get());
+            }
+         }
+
          /* check for undefs */
          for (unsigned i = 0; i < instr->operands.size(); i++) {
             if (instr->operands[i].isUndefined()) {
@@ -137,6 +181,10 @@ void validate(Program* program, FILE * output)
                if (program->chip_class >= GFX10 && !is_shift64)
                   const_bus_limit = 2;
 
+               uint32_t scalar_mask = instr->isVOP3() ? 0x7 : 0x5;
+               if (instr->isSDWA())
+                  scalar_mask = program->chip_class >= GFX9 ? 0x7 : 0x4;
+
                check(instr->definitions[0].getTemp().type() == RegType::vgpr ||
                      (int) instr->format & (int) Format::VOPC ||
                      instr->opcode == aco_opcode::v_readfirstlane_b32 ||
@@ -158,7 +206,7 @@ void validate(Program* program, FILE * output)
                      continue;
                   }
                   if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
-                     check(i != 1 || instr->isVOP3(), "Wrong source position for SGPR argument", instr.get());
+                     check(scalar_mask & (1 << i), "Wrong source position for SGPR argument", instr.get());
 
                      if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
                         if (num_sgprs < 2)
@@ -167,7 +215,7 @@ void validate(Program* program, FILE * output)
                   }
 
                   if (op.isConstant() && !op.isLiteral())
-                     check(i == 0 || instr->isVOP3(), "Wrong source position for constant argument", instr.get());
+                     check(scalar_mask & (1 << i), "Wrong source position for constant argument", instr.get());
                }
                check(num_sgprs + (literal.isUndefined() ? 0 : 1) <= const_bus_limit, "Too many SGPRs/literals", instr.get());
             }