aco: fix image_atomic_cmp_swap
[mesa.git] / src / amd / compiler / aco_validate.cpp
index 0988d66df3ab53401119ab75fec2ca3c761044e0..a479083a54c58f56c6410b41be3becc59d972b39 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "aco_ir.h"
 
+#include <array>
 #include <map>
 
 namespace aco {
@@ -59,6 +60,12 @@ void validate(Program* program, FILE * output)
          is_valid = false;
       }
    };
+   auto check_block = [&output, &is_valid](bool check, const char * msg, aco::Block * block) -> void {
+      if (!check) {
+         fprintf(output, "%s: BB%u\n", msg, block->index);
+         is_valid = false;
+      }
+   };
 
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& instr : block.instructions) {
@@ -92,55 +99,77 @@ void validate(Program* program, FILE * output)
                bool flat = instr->format == Format::FLAT || instr->format == Format::SCRATCH || instr->format == Format::GLOBAL;
                bool can_be_undef = is_phi(instr) || instr->format == Format::EXP ||
                                    instr->format == Format::PSEUDO_REDUCTION ||
-                                   (flat && i == 1) || (instr->format == Format::MIMG && i == 2) ||
-                                   ((instr->format == Format::MUBUF || instr->format == Format::MTBUF) && i == 0);
+                                   (flat && i == 1) || (instr->format == Format::MIMG && i == 1) ||
+                                   ((instr->format == Format::MUBUF || instr->format == Format::MTBUF) && i == 1);
                check(can_be_undef, "Undefs can only be used in certain operands", instr.get());
             }
          }
 
-         /* check num literals */
          if (instr->isSALU() || instr->isVALU()) {
-            unsigned num_literals = 0;
+            /* check literals */
+            Operand literal(s1);
             for (unsigned i = 0; i < instr->operands.size(); i++)
             {
-               if (instr->operands[i].isLiteral()) {
-                  check(instr->format == Format::SOP1 ||
-                        instr->format == Format::SOP2 ||
-                        instr->format == Format::SOPC ||
-                        instr->format == Format::VOP1 ||
-                        instr->format == Format::VOP2 ||
-                        instr->format == Format::VOPC,
-                        "Literal applied on wrong instruction format", instr.get());
-
-                  num_literals++;
-                  check(!instr->isVALU() || i == 0 || i == 2, "Wrong source position for Literal argument", instr.get());
-               }
+               Operand op = instr->operands[i];
+               if (!op.isLiteral())
+                  continue;
+
+               check(instr->format == Format::SOP1 ||
+                     instr->format == Format::SOP2 ||
+                     instr->format == Format::SOPC ||
+                     instr->format == Format::VOP1 ||
+                     instr->format == Format::VOP2 ||
+                     instr->format == Format::VOPC ||
+                     (instr->isVOP3() && program->chip_class >= GFX10),
+                     "Literal applied on wrong instruction format", instr.get());
+
+               check(literal.isUndefined() || (literal.size() == op.size() && literal.constantValue() == op.constantValue()), "Only 1 Literal allowed", instr.get());
+               literal = op;
+               check(!instr->isVALU() || instr->isVOP3() || i == 0 || i == 2, "Wrong source position for Literal argument", instr.get());
             }
-            check(num_literals <= 1, "Only 1 Literal allowed", instr.get());
 
             /* check num sgprs for VALU */
             if (instr->isVALU()) {
+               bool is_shift64 = instr->opcode == aco_opcode::v_lshlrev_b64 ||
+                                 instr->opcode == aco_opcode::v_lshrrev_b64 ||
+                                 instr->opcode == aco_opcode::v_ashrrev_i64;
+               unsigned const_bus_limit = 1;
+               if (program->chip_class >= GFX10 && !is_shift64)
+                  const_bus_limit = 2;
+
                check(instr->definitions[0].getTemp().type() == RegType::vgpr ||
                      (int) instr->format & (int) Format::VOPC ||
                      instr->opcode == aco_opcode::v_readfirstlane_b32 ||
-                     instr->opcode == aco_opcode::v_readlane_b32,
+                     instr->opcode == aco_opcode::v_readlane_b32 ||
+                     instr->opcode == aco_opcode::v_readlane_b32_e64,
                      "Wrong Definition type for VALU instruction", instr.get());
-               unsigned num_sgpr = 0;
-               unsigned sgpr_idx = instr->operands.size();
+               unsigned num_sgprs = 0;
+               unsigned sgpr[] = {0, 0};
                for (unsigned i = 0; i < instr->operands.size(); i++)
                {
-                  if (instr->operands[i].isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
-                     check(i != 1 || (int) instr->format & (int) Format::VOP3A, "Wrong source position for SGPR argument", instr.get());
+                  Operand op = instr->operands[i];
+                  if (instr->opcode == aco_opcode::v_readfirstlane_b32 ||
+                      instr->opcode == aco_opcode::v_readlane_b32 ||
+                      instr->opcode == aco_opcode::v_readlane_b32_e64 ||
+                      instr->opcode == aco_opcode::v_writelane_b32 ||
+                      instr->opcode == aco_opcode::v_writelane_b32_e64) {
+                     check(!op.isLiteral(), "No literal allowed on VALU instruction", instr.get());
+                     check(i == 1 || (op.isTemp() && op.regClass() == v1), "Wrong Operand type for VALU instruction", instr.get());
+                     continue;
+                  }
+                  if (op.isTemp() && instr->operands[i].regClass().type() == RegType::sgpr) {
+                     check(i != 1 || instr->isVOP3(), "Wrong source position for SGPR argument", instr.get());
 
-                     if (sgpr_idx == instr->operands.size() || instr->operands[sgpr_idx].tempId() != instr->operands[i].tempId())
-                        num_sgpr++;
-                     sgpr_idx = i;
+                     if (op.tempId() != sgpr[0] && op.tempId() != sgpr[1]) {
+                        if (num_sgprs < 2)
+                           sgpr[num_sgprs++] = op.tempId();
+                     }
                   }
 
-                  if (instr->operands[i].isConstant() && !instr->operands[i].isLiteral())
-                     check(i == 0 || (int) instr->format & (int) Format::VOP3A, "Wrong source position for constant argument", instr.get());
+                  if (op.isConstant() && !op.isLiteral())
+                     check(i == 0 || instr->isVOP3(), "Wrong source position for constant argument", instr.get());
                }
-               check(num_sgpr + num_literals <= 1, "Only 1 Literal OR 1 SGPR allowed", instr.get());
+               check(num_sgprs + (literal.isUndefined() ? 0 : 1) <= const_bus_limit, "Too many SGPRs/literals", instr.get());
             }
 
             if (instr->format == Format::SOP1 || instr->format == Format::SOP2) {
@@ -181,7 +210,7 @@ void validate(Program* program, FILE * output)
                }
             } else if (instr->opcode == aco_opcode::p_phi) {
                check(instr->operands.size() == block.logical_preds.size(), "Number of Operands does not match number of predecessors", instr.get());
-               check(instr->definitions[0].getTemp().type() == RegType::vgpr || instr->definitions[0].getTemp().regClass() == s2, "Logical Phi Definition must be vgpr or divergent boolean", instr.get());
+               check(instr->definitions[0].getTemp().type() == RegType::vgpr || instr->definitions[0].getTemp().regClass() == program->lane_mask, "Logical Phi Definition must be vgpr or divergent boolean", instr.get());
             } else if (instr->opcode == aco_opcode::p_linear_phi) {
                for (const Operand& op : instr->operands)
                   check(!op.isTemp() || op.getTemp().is_linear(), "Wrong Operand type", instr.get());
@@ -200,15 +229,30 @@ void validate(Program* program, FILE * output)
             break;
          }
          case Format::MTBUF:
-         case Format::MUBUF:
-         case Format::MIMG: {
+         case Format::MUBUF: {
             check(instr->operands.size() > 1, "VMEM instructions must have at least one operand", instr.get());
-            check(instr->operands[0].hasRegClass() && instr->operands[0].regClass().type() == RegType::vgpr,
+            check(instr->operands[1].hasRegClass() && instr->operands[1].regClass().type() == RegType::vgpr,
                   "VADDR must be in vgpr for VMEM instructions", instr.get());
-            check(instr->operands[1].isTemp() && instr->operands[1].regClass().type() == RegType::sgpr, "VMEM resource constant must be sgpr", instr.get());
+            check(instr->operands[0].isTemp() && instr->operands[0].regClass().type() == RegType::sgpr, "VMEM resource constant must be sgpr", instr.get());
             check(instr->operands.size() < 4 || (instr->operands[3].isTemp() && instr->operands[3].regClass().type() == RegType::vgpr), "VMEM write data must be vgpr", instr.get());
             break;
          }
+         case Format::MIMG: {
+            check(instr->operands.size() == 3, "MIMG instructions must have exactly 3 operands", instr.get());
+            check(instr->operands[0].hasRegClass() && (instr->operands[0].regClass() == s4 || instr->operands[0].regClass() == s8),
+                  "MIMG operands[0] (resource constant) must be in 4 or 8 SGPRs", instr.get());
+            if (instr->operands[1].hasRegClass() && instr->operands[1].regClass().type() == RegType::sgpr)
+               check(instr->operands[1].regClass() == s4, "MIMG operands[1] (sampler constant) must be 4 SGPRs", instr.get());
+            else if (instr->operands[1].hasRegClass() && instr->operands[1].regClass().type() == RegType::vgpr)
+               check((instr->definitions.empty() || instr->definitions[0].regClass() == instr->operands[1].regClass() ||
+                     instr->opcode == aco_opcode::image_atomic_cmpswap || instr->opcode == aco_opcode::image_atomic_fcmpswap),
+                     "MIMG operands[1] (VDATA) must be the same as definitions[0] for atomics", instr.get());
+            check(instr->operands[2].hasRegClass() && instr->operands[2].regClass().type() == RegType::vgpr,
+                  "MIMG operands[2] (VADDR) must be VGPR", instr.get());
+            check(instr->definitions.empty() || (instr->definitions[0].isTemp() && instr->definitions[0].regClass().type() == RegType::vgpr),
+                  "MIMG definitions[0] (VDATA) must be VGPR", instr.get());
+            break;
+         }
          case Format::DS: {
             for (const Operand& op : instr->operands) {
                check((op.isTemp() && op.regClass().type() == RegType::vgpr) || op.physReg() == m0,
@@ -243,6 +287,31 @@ void validate(Program* program, FILE * output)
          }
       }
    }
+
+   /* validate CFG */
+   for (unsigned i = 0; i < program->blocks.size(); i++) {
+      Block& block = program->blocks[i];
+      check_block(block.index == i, "block.index must match actual index", &block);
+
+      /* predecessors/successors should be sorted */
+      for (unsigned j = 0; j + 1 < block.linear_preds.size(); j++)
+         check_block(block.linear_preds[j] < block.linear_preds[j + 1], "linear predecessors must be sorted", &block);
+      for (unsigned j = 0; j + 1 < block.logical_preds.size(); j++)
+         check_block(block.logical_preds[j] < block.logical_preds[j + 1], "logical predecessors must be sorted", &block);
+      for (unsigned j = 0; j + 1 < block.linear_succs.size(); j++)
+         check_block(block.linear_succs[j] < block.linear_succs[j + 1], "linear successors must be sorted", &block);
+      for (unsigned j = 0; j + 1 < block.logical_succs.size(); j++)
+         check_block(block.logical_succs[j] < block.logical_succs[j + 1], "logical successors must be sorted", &block);
+
+      /* critical edges are not allowed */
+      if (block.linear_preds.size() > 1) {
+         for (unsigned pred : block.linear_preds)
+            check_block(program->blocks[pred].linear_succs.size() == 1, "linear critical edges are not allowed", &program->blocks[pred]);
+         for (unsigned pred : block.logical_preds)
+            check_block(program->blocks[pred].logical_succs.size() == 1, "logical critical edges are not allowed", &program->blocks[pred]);
+      }
+   }
+
    assert(is_valid);
 }