aco: add SDWA_instruction
[mesa.git] / src / amd / compiler / aco_ir.h
index 92511975a6965adac8d94370b12bee703afdac81..c8b5c00e1f2175e9ba66c710008cf6267c6f24e0 100644 (file)
@@ -169,6 +169,11 @@ constexpr Format asVOP3(Format format) {
    return (Format) ((uint32_t) Format::VOP3 | (uint32_t) format);
 };
 
+constexpr Format asSDWA(Format format) {
+   assert(format == Format::VOP1 || format == Format::VOP2 || format == Format::VOPC);
+   return (Format) ((uint32_t) Format::SDWA | (uint32_t) format);
+}
+
 enum class RegType {
    none = 0,
    sgpr,
@@ -267,10 +272,15 @@ private:
  */
 struct PhysReg {
    constexpr PhysReg() = default;
-   explicit constexpr PhysReg(unsigned r) : reg(r) {}
-   constexpr operator unsigned() const { return reg; }
-
-   uint16_t reg = 0;
+   explicit constexpr PhysReg(unsigned r) : reg_b(r << 2) {}
+   constexpr unsigned reg() const { return reg_b >> 2; }
+   constexpr unsigned byte() const { return reg_b & 0x3; }
+   constexpr operator unsigned() const { return reg(); }
+   constexpr bool operator==(PhysReg other) const { return reg_b == other.reg_b; }
+   constexpr bool operator!=(PhysReg other) const { return reg_b != other.reg_b; }
+   constexpr bool operator <(PhysReg other) const { return reg_b < other.reg_b; }
+
+   uint16_t reg_b = 0;
 };
 
 /* helper expressions for special registers */
@@ -472,6 +482,36 @@ public:
       return isConstant() && constantValue() == cmp;
    }
 
+   constexpr uint64_t constantValue64(bool signext=false) const noexcept
+   {
+      if (is64BitConst_) {
+         if (reg_ <= 192)
+            return reg_ - 128;
+         else if (reg_ <= 208)
+            return 0xFFFFFFFFFFFFFFFF - (reg_ - 193);
+
+         switch (reg_) {
+         case 240:
+            return 0x3FE0000000000000;
+         case 241:
+            return 0xBFE0000000000000;
+         case 242:
+            return 0x3FF0000000000000;
+         case 243:
+            return 0xBFF0000000000000;
+         case 244:
+            return 0x4000000000000000;
+         case 245:
+            return 0xC000000000000000;
+         case 246:
+            return 0x4010000000000000;
+         case 247:
+            return 0xC010000000000000;
+         }
+      }
+      return (signext && (data_.i & 0x80000000u) ? 0xffffffff00000000ull : 0ull) | data_.i;
+   }
+
    /* Indicates that the killed operand's live range intersects with the
     * instruction's definitions. Unlike isKill() and isFirstKill(), this is
     * not set by liveness analysis. */
@@ -521,6 +561,23 @@ public:
       return isFirstKill() && !isLateKill();
    }
 
+   constexpr bool operator == (Operand other) const noexcept
+   {
+      if (other.size() != size())
+         return false;
+      if (isFixed() != other.isFixed() || isKillBeforeDef() != other.isKillBeforeDef())
+         return false;
+      if (isFixed() && other.isFixed() && physReg() != other.physReg())
+         return false;
+      if (isLiteral())
+         return other.isLiteral() && other.constantValue() == constantValue();
+      else if (isConstant())
+         return other.isConstant() && other.physReg() == physReg();
+      else if (isUndefined())
+         return other.isUndefined() && other.regClass() == regClass();
+      else
+         return other.isTemp() && other.getTemp() == getTemp();
+   }
 private:
    union {
       uint32_t i;
@@ -789,6 +846,55 @@ struct DPP_instruction : public Instruction {
    bool bound_ctrl : 1;
 };
 
+enum sdwa_sel : uint8_t {
+    /* masks */
+    sdwa_wordnum = 0x1,
+    sdwa_bytenum = 0x3,
+    sdwa_asuint = 0x7,
+
+    /* flags */
+    sdwa_isword = 0x4,
+    sdwa_sext = 0x8,
+
+    /* specific values */
+    sdwa_ubyte0 = 0,
+    sdwa_ubyte1 = 1,
+    sdwa_ubyte2 = 2,
+    sdwa_ubyte3 = 3,
+    sdwa_uword0 = sdwa_isword | 0,
+    sdwa_uword1 = sdwa_isword | 1,
+    sdwa_udword = 6,
+
+    sdwa_sbyte0 = sdwa_ubyte0 | sdwa_sext,
+    sdwa_sbyte1 = sdwa_ubyte1 | sdwa_sext,
+    sdwa_sbyte2 = sdwa_ubyte2 | sdwa_sext,
+    sdwa_sbyte3 = sdwa_ubyte3 | sdwa_sext,
+    sdwa_sword0 = sdwa_uword0 | sdwa_sext,
+    sdwa_sword1 = sdwa_uword1 | sdwa_sext,
+    sdwa_sdword = sdwa_udword | sdwa_sext,
+};
+
+/**
+ * Sub-Dword Addressing Format:
+ * This format can be used for VOP1, VOP2 or VOPC instructions.
+ *
+ * omod and SGPR/constant operands are only available on GFX9+. For VOPC,
+ * the definition doesn't have to be VCC on GFX9+.
+ *
+ */
+struct SDWA_instruction : public Instruction {
+   /* these destination modifiers aren't available with VOPC except for
+    * clamp on GFX8 */
+   unsigned dst_sel:4;
+   bool dst_preserve:1;
+   bool clamp:1;
+   unsigned omod:2; /* GFX9+ */
+
+   unsigned sel[2];
+   bool neg[2];
+   bool abs[2];
+};
+
 struct Interp_instruction : public Instruction {
    uint8_t attribute;
    uint8_t component;
@@ -1172,6 +1278,21 @@ static constexpr Stage tess_control_hs = sw_tcs | hw_hs;
 static constexpr Stage tess_eval_es = sw_tes | hw_es; /* tesselation evaluation before geometry */
 static constexpr Stage geometry_gs = sw_gs | hw_gs;
 
+enum statistic {
+   statistic_hash,
+   statistic_instructions,
+   statistic_copies,
+   statistic_branches,
+   statistic_cycles,
+   statistic_vmem_clauses,
+   statistic_smem_clauses,
+   statistic_vmem_score,
+   statistic_smem_score,
+   statistic_sgpr_presched,
+   statistic_vgpr_presched,
+   num_statistics
+};
+
 class Program final {
 public:
    float_mode next_fp_mode;
@@ -1197,16 +1318,22 @@ public:
    uint16_t min_waves = 0;
    uint16_t lds_alloc_granule;
    uint32_t lds_limit; /* in bytes */
+   bool has_16bank_lds;
    uint16_t vgpr_limit;
    uint16_t sgpr_limit;
    uint16_t physical_sgprs;
    uint16_t sgpr_alloc_granule; /* minus one. must be power of two */
    uint16_t vgpr_alloc_granule; /* minus one. must be power of two */
+   unsigned workgroup_size; /* if known; otherwise UINT_MAX */
+
+   bool xnack_enabled = false;
 
    bool needs_vcc = false;
-   bool needs_xnack_mask = false;
    bool needs_flat_scr = false;
 
+   bool collect_statistics = false;
+   uint32_t statistics[num_statistics];
+
    uint32_t allocateId()
    {
       assert(allocationID <= 16777215);
@@ -1287,6 +1414,10 @@ void perfwarn(bool cond, const char *msg, Instruction *instr=NULL);
 #define perfwarn(program, cond, msg, ...) do {} while(0)
 #endif
 
+void collect_presched_stats(Program *program);
+void collect_preasm_stats(Program *program);
+void collect_postasm_stats(Program *program, const std::vector<uint32_t>& code);
+
 void aco_print_instr(Instruction *instr, FILE *output);
 void aco_print_program(Program *program, FILE *output);