aco: improve 8/16-bit constants
authorRhys Perry <pendingchaos02@gmail.com>
Fri, 15 May 2020 15:28:03 +0000 (16:28 +0100)
committerMarge Bot <eric+marge@anholt.net>
Mon, 15 Jun 2020 18:24:22 +0000 (18:24 +0000)
fossil-db (Navi, fp16 enabled):
Totals from 1 (0.00% of 127638) affected shaders:
CodeSize: 4540 -> 4388 (-3.35%)
Instrs: 861 -> 830 (-3.60%)
Cycles: 3444 -> 3320 (-3.60%)
VMEM: 489 -> 465 (-4.91%)
SMEM: 107 -> 110 (+2.80%)
SClause: 31 -> 30 (-3.23%)
Copies: 58 -> 54 (-6.90%)

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5245>

src/amd/compiler/aco_builder_h.py
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_ir.h
src/amd/compiler/aco_lower_to_hw_instr.cpp
src/amd/compiler/aco_optimizer.cpp
src/amd/compiler/aco_print_ir.cpp

index edd5f3fda645f8ef5372925570de799a5d83559c..0296653efdaab664fac5e222ac0d68e0dd899796 100644 (file)
@@ -78,6 +78,8 @@ ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask)
 
 aco_ptr<Instruction> create_s_mov(Definition dst, Operand src);
 
 
 aco_ptr<Instruction> create_s_mov(Definition dst, Operand src);
 
+extern uint8_t int8_mul_table[512];
+
 enum sendmsg {
    sendmsg_none = 0,
    _sendmsg_gs = 2,
 enum sendmsg {
    sendmsg_none = 0,
    _sendmsg_gs = 2,
@@ -388,6 +390,36 @@ public:
         return vop1(aco_opcode::v_mov_b32, dst, op);
       } else if (op.bytes() > 2) {
          return pseudo(aco_opcode::p_create_vector, dst, op);
         return vop1(aco_opcode::v_mov_b32, dst, op);
       } else if (op.bytes() > 2) {
          return pseudo(aco_opcode::p_create_vector, dst, op);
+      } else if (op.bytes() == 1 && op.isConstant()) {
+        uint8_t val = op.constantValue();
+        Operand op32((uint32_t)val | (val & 0x80u ? 0xffffff00u : 0u));
+        aco_ptr<SDWA_instruction> sdwa;
+        if (op32.isLiteral()) {
+            sdwa.reset(create_instruction<SDWA_instruction>(aco_opcode::v_mul_u32_u24, asSDWA(Format::VOP2), 2, 1));
+            uint32_t a = (uint32_t)int8_mul_table[val * 2];
+            uint32_t b = (uint32_t)int8_mul_table[val * 2 + 1];
+            sdwa->operands[0] = Operand(a | (a & 0x80u ? 0xffffff00u : 0x0u));
+            sdwa->operands[1] = Operand(b | (b & 0x80u ? 0xffffff00u : 0x0u));
+        } else {
+            sdwa.reset(create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1));
+            sdwa->operands[0] = op32;
+        }
+        sdwa->definitions[0] = dst;
+        sdwa->sel[0] = sdwa_udword;
+        sdwa->sel[1] = sdwa_udword;
+        sdwa->dst_sel = sdwa_ubyte;
+        sdwa->dst_preserve = true;
+        return insert(std::move(sdwa));
+      } else if (op.bytes() == 2 && op.isConstant() && !op.isLiteral()) {
+        aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_add_f16, asSDWA(Format::VOP2), 2, 1)};
+        sdwa->operands[0] = op;
+        sdwa->operands[1] = Operand(0u);
+        sdwa->definitions[0] = dst;
+        sdwa->sel[0] = sdwa_uword;
+        sdwa->sel[1] = sdwa_udword;
+        sdwa->dst_sel = dst.bytes() == 1 ? sdwa_ubyte : sdwa_uword;
+        sdwa->dst_preserve = true;
+        return insert(std::move(sdwa));
       } else if (dst.regClass().is_subdword()) {
         if (program->chip_class >= GFX8) {
             aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
       } else if (dst.regClass().is_subdword()) {
         if (program->chip_class >= GFX8) {
             aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
index f256ec6eb3a9aa9999b0966551fef0fb89771b72..c0cc445ffa38bd5fbd1863f644ee3c961304384c 100644 (file)
@@ -1925,7 +1925,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_fsat: {
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (dst.regClass() == v2b) {
    case nir_op_fsat: {
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (dst.regClass() == v2b) {
-         bld.vop3(aco_opcode::v_med3_f16, Definition(dst), Operand(0u), Operand(0x3f800000u), src);
+         bld.vop3(aco_opcode::v_med3_f16, Definition(dst), Operand((uint16_t)0u), Operand((uint16_t)0x3c00), src);
       } else if (dst.regClass() == v1) {
          bld.vop3(aco_opcode::v_med3_f32, Definition(dst), Operand(0u), Operand(0x3f800000u), src);
          /* apparently, it is not necessary to flush denorms if this instruction is used with these operands */
       } else if (dst.regClass() == v1) {
          bld.vop3(aco_opcode::v_med3_f32, Definition(dst), Operand(0u), Operand(0x3f800000u), src);
          /* apparently, it is not necessary to flush denorms if this instruction is used with these operands */
index 68d0b9bf4ceed9fce26c16e512fb6edff352818a..3db6b4b6d4385922858ceffc3c9ea1495ef1cbca 100644 (file)
@@ -337,7 +337,7 @@ class Operand final
 public:
    constexpr Operand()
       : reg_(PhysReg{128}), isTemp_(false), isFixed_(true), isConstant_(false),
 public:
    constexpr Operand()
       : reg_(PhysReg{128}), isTemp_(false), isFixed_(true), isConstant_(false),
-        isKill_(false), isUndef_(true), isFirstKill_(false), is64BitConst_(false),
+        isKill_(false), isUndef_(true), isFirstKill_(false), constSize(0),
         isLateKill_(false) {}
 
    explicit Operand(Temp r) noexcept
         isLateKill_(false) {}
 
    explicit Operand(Temp r) noexcept
@@ -350,11 +350,51 @@ public:
          setFixed(PhysReg{128});
       }
    };
          setFixed(PhysReg{128});
       }
    };
+   explicit Operand(uint8_t v) noexcept
+   {
+      /* 8-bit constants are only used for copies and copies from any 8-bit
+       * constant can be implemented with a SDWA v_mul_u32_u24. So consider all
+       * to be inline constants. */
+      data_.i = v;
+      isConstant_ = true;
+      constSize = 0;
+      setFixed(PhysReg{0u});
+   };
+   explicit Operand(uint16_t v) noexcept
+   {
+      data_.i = v;
+      isConstant_ = true;
+      constSize = 1;
+      if (v <= 64)
+         setFixed(PhysReg{128u + v});
+      else if (v >= 0xFFF0) /* [-16 .. -1] */
+         setFixed(PhysReg{192u + (0xFFFF - v)});
+      else if (v == 0x3800) /* 0.5 */
+         setFixed(PhysReg{240});
+      else if (v == 0xB800) /* -0.5 */
+         setFixed(PhysReg{241});
+      else if (v == 0x3C00) /* 1.0 */
+         setFixed(PhysReg{242});
+      else if (v == 0xBC00) /* -1.0 */
+         setFixed(PhysReg{243});
+      else if (v == 0x4000) /* 2.0 */
+         setFixed(PhysReg{244});
+      else if (v == 0xC000) /* -2.0 */
+         setFixed(PhysReg{245});
+      else if (v == 0x4400) /* 4.0 */
+         setFixed(PhysReg{246});
+      else if (v == 0xC400) /* -4.0 */
+         setFixed(PhysReg{247});
+      else if (v == 0x3118) /* 1/2 PI */
+         setFixed(PhysReg{248});
+      else /* Literal Constant */
+         setFixed(PhysReg{255});
+   };
    explicit Operand(uint32_t v, bool is64bit = false) noexcept
    {
       data_.i = v;
       isConstant_ = true;
    explicit Operand(uint32_t v, bool is64bit = false) noexcept
    {
       data_.i = v;
       isConstant_ = true;
-      is64BitConst_ = is64bit;
+      constSize = is64bit ? 3 : 2;
       if (v <= 64)
          setFixed(PhysReg{128 + v});
       else if (v >= 0xFFFFFFF0) /* [-16 .. -1] */
       if (v <= 64)
          setFixed(PhysReg{128 + v});
       else if (v >= 0xFFFFFFF0) /* [-16 .. -1] */
@@ -383,7 +423,7 @@ public:
    explicit Operand(uint64_t v) noexcept
    {
       isConstant_ = true;
    explicit Operand(uint64_t v) noexcept
    {
       isConstant_ = true;
-      is64BitConst_ = true;
+      constSize = 3;
       if (v <= 64) {
          data_.i = (uint32_t) v;
          setFixed(PhysReg{128 + (uint32_t) v});
       if (v <= 64) {
          data_.i = (uint32_t) v;
          setFixed(PhysReg{128 + (uint32_t) v});
@@ -465,7 +505,7 @@ public:
    constexpr unsigned bytes() const noexcept
    {
       if (isConstant())
    constexpr unsigned bytes() const noexcept
    {
       if (isConstant())
-         return is64BitConst_ ? 8 : 4; //TODO: sub-dword constants
+         return 1 << constSize;
       else
          return data_.temp.bytes();
    }
       else
          return data_.temp.bytes();
    }
@@ -473,7 +513,7 @@ public:
    constexpr unsigned size() const noexcept
    {
       if (isConstant())
    constexpr unsigned size() const noexcept
    {
       if (isConstant())
-         return is64BitConst_ ? 2 : 1;
+         return constSize > 2 ? 2 : 1;
       else
          return data_.temp.size();
    }
       else
          return data_.temp.size();
    }
@@ -521,7 +561,7 @@ public:
 
    constexpr uint64_t constantValue64(bool signext=false) const noexcept
    {
 
    constexpr uint64_t constantValue64(bool signext=false) const noexcept
    {
-      if (is64BitConst_) {
+      if (constSize == 3) {
          if (reg_ <= 192)
             return reg_ - 128;
          else if (reg_ <= 208)
          if (reg_ <= 192)
             return reg_ - 128;
          else if (reg_ <= 208)
@@ -545,6 +585,10 @@ public:
          case 247:
             return 0xC010000000000000;
          }
          case 247:
             return 0xC010000000000000;
          }
+      } else if (constSize == 1) {
+         return (signext && (data_.i & 0x8000u) ? 0xffffffffffff0000ull : 0ull) | data_.i;
+      } else if (constSize == 0) {
+         return (signext && (data_.i & 0x80u) ? 0xffffffffffffff00ull : 0ull) | data_.i;
       }
       return (signext && (data_.i & 0x80000000u) ? 0xffffffff00000000ull : 0ull) | data_.i;
    }
       }
       return (signext && (data_.i & 0x80000000u) ? 0xffffffff00000000ull : 0ull) | data_.i;
    }
@@ -635,11 +679,11 @@ private:
          uint8_t isKill_:1;
          uint8_t isUndef_:1;
          uint8_t isFirstKill_:1;
          uint8_t isKill_:1;
          uint8_t isUndef_:1;
          uint8_t isFirstKill_:1;
-         uint8_t is64BitConst_:1;
+         uint8_t constSize:2;
          uint8_t isLateKill_:1;
       };
       /* can't initialize bit-fields in c++11, so work around using a union */
          uint8_t isLateKill_:1;
       };
       /* can't initialize bit-fields in c++11, so work around using a union */
-      uint8_t control_ = 0;
+      uint16_t control_ = 0;
    };
 };
 
    };
 };
 
index bb63aea95d44a51fe5d55d3513960e73c7d29c4f..5e93dc603e6b65a49d33c7f5e46eb989e80adefa 100644 (file)
@@ -41,6 +41,37 @@ struct lower_context {
    std::vector<aco_ptr<Instruction>> instructions;
 };
 
    std::vector<aco_ptr<Instruction>> instructions;
 };
 
+/* used by handle_operands() indirectly through Builder::copy */
+uint8_t int8_mul_table[512] = {
+    0, 20, 1, 1, 1, 2, 1, 3, 1, 4, 1, 5, 1, 6, 1, 7, 1, 8, 1, 9, 1, 10, 1, 11,
+    1, 12, 1, 13, 1, 14, 1, 15, 1, 16, 1, 17, 1, 18, 1, 19, 1, 20, 1, 21,
+    1, 22, 1, 23, 1, 24, 1, 25, 1, 26, 1, 27, 1, 28, 1, 29, 1, 30, 1, 31,
+    1, 32, 1, 33, 1, 34, 1, 35, 1, 36, 1, 37, 1, 38, 1, 39, 1, 40, 1, 41,
+    1, 42, 1, 43, 1, 44, 1, 45, 1, 46, 1, 47, 1, 48, 1, 49, 1, 50, 1, 51,
+    1, 52, 1, 53, 1, 54, 1, 55, 1, 56, 1, 57, 1, 58, 1, 59, 1, 60, 1, 61,
+    1, 62, 1, 63, 1, 64, 5, 13, 2, 33, 17, 19, 2, 34, 3, 23, 2, 35, 11, 53,
+    2, 36, 7, 47, 2, 37, 3, 25, 2, 38, 7, 11, 2, 39, 53, 243, 2, 40, 3, 27,
+    2, 41, 17, 35, 2, 42, 5, 17, 2, 43, 3, 29, 2, 44, 15, 23, 2, 45, 7, 13,
+    2, 46, 3, 31, 2, 47, 5, 19, 2, 48, 19, 59, 2, 49, 3, 33, 2, 50, 7, 51,
+    2, 51, 15, 41, 2, 52, 3, 35, 2, 53, 11, 33, 2, 54, 23, 27, 2, 55, 3, 37,
+    2, 56, 9, 41, 2, 57, 5, 23, 2, 58, 3, 39, 2, 59, 7, 17, 2, 60, 9, 241,
+    2, 61, 3, 41, 2, 62, 5, 25, 2, 63, 35, 245, 2, 64, 3, 43, 5, 26, 9, 43,
+    3, 44, 7, 19, 10, 39, 3, 45, 4, 34, 11, 59, 3, 46, 9, 243, 4, 35, 3, 47,
+    22, 53, 7, 57, 3, 48, 5, 29, 10, 245, 3, 49, 4, 37, 9, 45, 3, 50, 7, 241,
+    4, 38, 3, 51, 7, 22, 5, 31, 3, 52, 7, 59, 7, 242, 3, 53, 4, 40, 7, 23,
+    3, 54, 15, 45, 4, 41, 3, 55, 6, 241, 9, 47, 3, 56, 13, 13, 5, 34, 3, 57,
+    4, 43, 11, 39, 3, 58, 5, 35, 4, 44, 3, 59, 6, 243, 7, 245, 3, 60, 5, 241,
+    7, 26, 3, 61, 4, 46, 5, 37, 3, 62, 11, 17, 4, 47, 3, 63, 5, 38, 5, 243,
+    3, 64, 7, 247, 9, 50, 5, 39, 4, 241, 33, 37, 6, 33, 13, 35, 4, 242, 5, 245,
+    6, 247, 7, 29, 4, 51, 5, 41, 5, 246, 7, 249, 3, 240, 11, 19, 5, 42, 3, 241,
+    4, 245, 25, 29, 3, 242, 5, 43, 4, 246, 3, 243, 17, 58, 17, 43, 3, 244,
+    5, 249, 6, 37, 3, 245, 2, 240, 5, 45, 2, 241, 21, 23, 2, 242, 3, 247,
+    2, 243, 5, 251, 2, 244, 29, 61, 2, 245, 3, 249, 2, 246, 17, 29, 2, 247,
+    9, 55, 1, 240, 1, 241, 1, 242, 1, 243, 1, 244, 1, 245, 1, 246, 1, 247,
+    1, 248, 1, 249, 1, 250, 1, 251, 1, 252, 1, 253, 1, 254, 1, 255
+};
+
+
 aco_opcode get_reduce_opcode(chip_class chip, ReduceOp op) {
    /* Because some 16-bit instructions are already VOP3 on GFX10, we use the
     * 32-bit opcodes (VOP2) which allows to remove the tempory VGPR and to use
 aco_opcode get_reduce_opcode(chip_class chip, ReduceOp op) {
    /* Because some 16-bit instructions are already VOP3 on GFX10, we use the
     * 32-bit opcodes (VOP2) which allows to remove the tempory VGPR and to use
@@ -999,11 +1030,15 @@ void split_copy(unsigned offset, Definition *def, Operand *op, const copy_operat
                       RegClass(src.def.regClass().type(), bytes).as_subdword();
    *def = Definition(src.def.tempId(), def_reg, def_cls);
    if (src.op.isConstant()) {
                       RegClass(src.def.regClass().type(), bytes).as_subdword();
    *def = Definition(src.def.tempId(), def_reg, def_cls);
    if (src.op.isConstant()) {
-      assert(offset == 0 || (offset == 4 && src.op.bytes() == 8));
-      if (src.op.bytes() == 8 && bytes == 4)
+      assert(bytes >= 1 && bytes <= 8);
+      if (bytes == 8)
+         *op = Operand(src.op.constantValue64() >> (offset * 8u));
+      else if (bytes == 4)
          *op = Operand(uint32_t(src.op.constantValue64() >> (offset * 8u)));
          *op = Operand(uint32_t(src.op.constantValue64() >> (offset * 8u)));
-      else
-         *op  = src.op;
+      else if (bytes == 2)
+         *op = Operand(uint16_t(src.op.constantValue64() >> (offset * 8u)));
+      else if (bytes == 1)
+         *op = Operand(uint8_t(src.op.constantValue64() >> (offset * 8u)));
    } else {
       RegClass op_cls = bytes % 4 == 0 ? RegClass(src.op.regClass().type(), bytes / 4u) :
                         RegClass(src.op.regClass().type(), bytes).as_subdword();
    } else {
       RegClass op_cls = bytes % 4 == 0 ? RegClass(src.op.regClass().type(), bytes / 4u) :
                         RegClass(src.op.regClass().type(), bytes).as_subdword();
index 37564b7e993fe60fe5814b25e198ab503c173e91..58d22910150a2ee1456796c19e326994b3c63382 100644 (file)
@@ -61,7 +61,7 @@ struct mad_info {
 
 enum Label {
    label_vec = 1 << 0,
 
 enum Label {
    label_vec = 1 << 0,
-   label_constant = 1 << 1,
+   label_constant_32bit = 1 << 1,
    /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
     * 32-bit operations but this shouldn't cause any issues because we don't
     * look through any conversions */
    /* label_{abs,neg,mul,omod2,omod4,omod5,clamp} are used for both 16 and
     * 32-bit operations but this shouldn't cause any issues because we don't
     * look through any conversions */
@@ -91,13 +91,14 @@ enum Label {
    label_vcc_hint = 1 << 25,
    label_scc_needed = 1 << 26,
    label_b2i = 1 << 27,
    label_vcc_hint = 1 << 25,
    label_scc_needed = 1 << 26,
    label_b2i = 1 << 27,
+   label_constant_16bit = 1 << 29,
 };
 
 static constexpr uint32_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success |
                                          label_add_sub | label_bitwise | label_uniform_bitwise | label_minmax | label_fcmp;
 static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool |
                                         label_omod2 | label_omod4 | label_omod5 | label_clamp | label_scc_invert | label_b2i;
 };
 
 static constexpr uint32_t instr_labels = label_vec | label_mul | label_mad | label_omod_success | label_clamp_success |
                                          label_add_sub | label_bitwise | label_uniform_bitwise | label_minmax | label_fcmp;
 static constexpr uint32_t temp_labels = label_abs | label_neg | label_temp | label_vcc | label_b2f | label_uniform_bool |
                                         label_omod2 | label_omod4 | label_omod5 | label_clamp | label_scc_invert | label_b2i;
-static constexpr uint32_t val_labels = label_constant | label_constant_64bit | label_literal | label_mad;
+static constexpr uint32_t val_labels = label_constant_32bit | label_constant_64bit | label_constant_16bit | label_literal | label_mad;
 
 struct ssa_info {
    uint32_t val;
 
 struct ssa_info {
    uint32_t val;
@@ -122,7 +123,10 @@ struct ssa_info {
          label &= ~instr_labels; /* instr and temp alias */
       }
 
          label &= ~instr_labels; /* instr and temp alias */
       }
 
-      if (new_label & val_labels)
+      uint32_t const_labels = label_literal | label_constant_32bit | label_constant_64bit | label_constant_16bit;
+      if (new_label & const_labels)
+         label &= ~val_labels | const_labels;
+      else if (new_label & val_labels)
          label &= ~val_labels;
 
       label |= new_label;
          label &= ~val_labels;
 
       label |= new_label;
@@ -139,26 +143,85 @@ struct ssa_info {
       return label & label_vec;
    }
 
       return label & label_vec;
    }
 
-   void set_constant(uint32_t constant)
+   void set_constant(chip_class chip, uint64_t constant)
    {
    {
-      add_label(label_constant);
+      Operand op16((uint16_t)constant);
+      Operand op32((uint32_t)constant);
+      add_label(label_literal);
       val = constant;
       val = constant;
-   }
 
 
-   bool is_constant()
+      if (chip >= GFX8 && !op16.isLiteral())
+         add_label(label_constant_16bit);
+
+      if (!op32.isLiteral() || ((uint32_t)constant == 0x3e22f983 && chip >= GFX8))
+         add_label(label_constant_32bit);
+
+      if (constant <= 64) {
+         add_label(label_constant_64bit);
+      } else if (constant >= 0xFFFFFFFFFFFFFFF0) { /* [-16 .. -1] */
+         add_label(label_constant_64bit);
+      } else if (constant == 0x3FE0000000000000) { /* 0.5 */
+         add_label(label_constant_64bit);
+      } else if (constant == 0xBFE0000000000000) { /* -0.5 */
+         add_label(label_constant_64bit);
+      } else if (constant == 0x3FF0000000000000) { /* 1.0 */
+         add_label(label_constant_64bit);
+      } else if (constant == 0xBFF0000000000000) { /* -1.0 */
+         add_label(label_constant_64bit);
+      } else if (constant == 0x4000000000000000) { /* 2.0 */
+         add_label(label_constant_64bit);
+      } else if (constant == 0xC000000000000000) { /* -2.0 */
+         add_label(label_constant_64bit);
+      } else if (constant == 0x4010000000000000) { /* 4.0 */
+         add_label(label_constant_64bit);
+      } else if (constant == 0xC010000000000000) { /* -4.0 */
+         add_label(label_constant_64bit);
+      }
+
+      if (label & label_constant_64bit) {
+         val = Operand(constant).constantValue();
+         if (val != constant)
+            label &= ~(label_literal | label_constant_16bit | label_constant_32bit);
+      }
+   }
+
+   bool is_constant(unsigned bits)
    {
    {
-      return label & label_constant;
+      switch (bits) {
+      case 8:
+         return label & label_literal;
+      case 16:
+         return label & label_constant_16bit;
+      case 32:
+         return label & label_constant_32bit;
+      case 64:
+         return label & label_constant_64bit;
+      }
+      return false;
    }
 
    }
 
-   void set_constant_64bit(uint32_t constant)
+   bool is_literal(unsigned bits)
    {
    {
-      add_label(label_constant_64bit);
-      val = constant;
+      bool is_lit = label & label_literal;
+      switch (bits) {
+      case 8:
+         return false;
+      case 16:
+         return is_lit && ~(label & label_constant_16bit);
+      case 32:
+         return is_lit && ~(label & label_constant_32bit);
+      case 64:
+         return false;
+      }
+      return false;
    }
 
    }
 
-   bool is_constant_64bit()
+   bool is_constant_or_literal(unsigned bits)
    {
    {
-      return label & label_constant_64bit;
+      if (bits == 64)
+         return label & label_constant_64bit;
+      else
+         return label & label_literal;
    }
 
    void set_abs(Temp abs_temp)
    }
 
    void set_abs(Temp abs_temp)
@@ -211,17 +274,6 @@ struct ssa_info {
       return label & label_temp;
    }
 
       return label & label_temp;
    }
 
-   void set_literal(uint32_t lit)
-   {
-      add_label(label_literal);
-      val = lit;
-   }
-
-   bool is_literal()
-   {
-      return label & label_literal;
-   }
-
    void set_mad(Instruction* mad, uint32_t mad_info_idx)
    {
       add_label(label_mad);
    void set_mad(Instruction* mad, uint32_t mad_info_idx)
    {
       add_label(label_mad);
@@ -321,11 +373,6 @@ struct ssa_info {
       return label & label_vcc;
    }
 
       return label & label_vcc;
    }
 
-   bool is_constant_or_literal()
-   {
-      return is_constant() || is_literal();
-   }
-
    void set_b2f(Temp val)
    {
       add_label(label_b2f);
    void set_b2f(Temp val)
    {
       add_label(label_b2f);
@@ -655,7 +702,7 @@ bool parse_base_offset(opt_ctx &ctx, Instruction* instr, unsigned op_index, Temp
       if (add_instr->operands[i].isConstant()) {
          *offset = add_instr->operands[i].constantValue();
       } else if (add_instr->operands[i].isTemp() &&
       if (add_instr->operands[i].isConstant()) {
          *offset = add_instr->operands[i].constantValue();
       } else if (add_instr->operands[i].isTemp() &&
-                 ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal()) {
+                 ctx.info[add_instr->operands[i].tempId()].is_constant_or_literal(32)) {
          *offset = ctx.info[add_instr->operands[i].tempId()].val;
       } else {
          continue;
          *offset = ctx.info[add_instr->operands[i].tempId()].val;
       } else {
          continue;
@@ -687,11 +734,15 @@ unsigned get_operand_size(aco_ptr<Instruction>& instr, unsigned index)
       return 0;
 }
 
       return 0;
 }
 
-Operand get_constant_op(opt_ctx &ctx, uint32_t val, bool is64bit = false)
+Operand get_constant_op(opt_ctx &ctx, ssa_info info, uint32_t bits)
 {
 {
+   if (bits == 8)
+      return Operand((uint8_t)info.val);
+   if (bits == 16)
+      return Operand((uint16_t)info.val);
    // TODO: this functions shouldn't be needed if we store Operand instead of value.
    // TODO: this functions shouldn't be needed if we store Operand instead of value.
-   Operand op(val, is64bit);
-   if (val == 0x3e22f983 && ctx.program->chip_class >= GFX8)
+   Operand op(info.val, bits == 64);
+   if (info.is_literal(32) && info.val == 0x3e22f983 && ctx.program->chip_class >= GFX8)
       op.setFixed(PhysReg{248}); /* 1/2 PI can be an inline constant on GFX8+ */
    return op;
 }
       op.setFixed(PhysReg{248}); /* 1/2 PI can be an inline constant on GFX8+ */
    return op;
 }
@@ -706,7 +757,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
    if (instr->isSALU() || instr->isVALU() || instr->format == Format::PSEUDO) {
       ASSERTED bool all_const = false;
       for (Operand& op : instr->operands)
    if (instr->isSALU() || instr->isVALU() || instr->format == Format::PSEUDO) {
       ASSERTED bool all_const = false;
       for (Operand& op : instr->operands)
-         all_const = all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal());
+         all_const = all_const && (!op.isTemp() || ctx.info[op.tempId()].is_constant_or_literal(32));
       perfwarn(all_const, "All instruction operands are constant", instr.get());
    }
 
       perfwarn(all_const, "All instruction operands are constant", instr.get());
    }
 
@@ -728,13 +779,13 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
       /* SALU / PSEUDO: propagate inline constants */
       if (instr->isSALU() || instr->format == Format::PSEUDO) {
          bool is_subdword = false;
       /* SALU / PSEUDO: propagate inline constants */
       if (instr->isSALU() || instr->format == Format::PSEUDO) {
          bool is_subdword = false;
-         // TODO: optimize SGPR and constant propagation for subdword pseudo instructions on gfx9+
+         // TODO: optimize SGPR propagation for subdword pseudo instructions on gfx9+
          if (instr->format == Format::PSEUDO) {
             is_subdword = std::any_of(instr->definitions.begin(), instr->definitions.end(),
                                       [] (const Definition& def) { return def.regClass().is_subdword();});
             is_subdword = is_subdword || std::any_of(instr->operands.begin(), instr->operands.end(),
                                                      [] (const Operand& op) { return op.hasRegClass() && op.regClass().is_subdword();});
          if (instr->format == Format::PSEUDO) {
             is_subdword = std::any_of(instr->definitions.begin(), instr->definitions.end(),
                                       [] (const Definition& def) { return def.regClass().is_subdword();});
             is_subdword = is_subdword || std::any_of(instr->operands.begin(), instr->operands.end(),
                                                      [] (const Operand& op) { return op.hasRegClass() && op.regClass().is_subdword();});
-            if (is_subdword)
+            if (is_subdword && ctx.program->chip_class < GFX9)
                continue;
          }
 
                continue;
          }
 
@@ -760,9 +811,10 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
                break;
             }
          }
                break;
             }
          }
-         if ((info.is_constant() || info.is_constant_64bit() || (info.is_literal() && instr->format == Format::PSEUDO)) &&
+         unsigned bits = get_operand_size(instr, i);
+         if ((info.is_constant(bits) || (!is_subdword && info.is_literal(bits) && instr->format == Format::PSEUDO)) &&
              !instr->operands[i].isFixed() && alu_can_accept_constant(instr->opcode, i)) {
              !instr->operands[i].isFixed() && alu_can_accept_constant(instr->opcode, i)) {
-            instr->operands[i] = get_constant_op(ctx, info.val, info.is_constant_64bit());
+            instr->operands[i] = get_constant_op(ctx, info, bits);
             continue;
          }
       }
             continue;
          }
       }
@@ -805,8 +857,9 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
                static_cast<VOP3A_instruction*>(instr.get())->neg[i] = true;
             continue;
          }
                static_cast<VOP3A_instruction*>(instr.get())->neg[i] = true;
             continue;
          }
-         if ((info.is_constant() || info.is_constant_64bit()) && alu_can_accept_constant(instr->opcode, i)) {
-            Operand op = get_constant_op(ctx, info.val, info.is_constant_64bit());
+         unsigned bits = get_operand_size(instr, i);
+         if (info.is_constant(bits) && alu_can_accept_constant(instr->opcode, i)) {
+            Operand op = get_constant_op(ctx, info, bits);
             perfwarn(instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
             if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) {
                instr->operands[i] = op;
             perfwarn(instr->opcode == aco_opcode::v_cndmask_b32 && i == 2, "v_cndmask_b32 with a constant selector", instr.get());
             if (i == 0 || instr->opcode == aco_opcode::v_readlane_b32 || instr->opcode == aco_opcode::v_writelane_b32) {
                instr->operands[i] = op;
@@ -831,13 +884,13 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          while (info.is_temp())
             info = ctx.info[info.temp.id()];
 
          while (info.is_temp())
             info = ctx.info[info.temp.id()];
 
-         if (mubuf->offen && i == 1 && info.is_constant_or_literal() && mubuf->offset + info.val < 4096) {
+         if (mubuf->offen && i == 1 && info.is_constant_or_literal(32) && mubuf->offset + info.val < 4096) {
             assert(!mubuf->idxen);
             instr->operands[1] = Operand(v1);
             mubuf->offset += info.val;
             mubuf->offen = false;
             continue;
             assert(!mubuf->idxen);
             instr->operands[1] = Operand(v1);
             mubuf->offset += info.val;
             mubuf->offen = false;
             continue;
-         } else if (i == 2 && info.is_constant_or_literal() && mubuf->offset + info.val < 4096) {
+         } else if (i == 2 && info.is_constant_or_literal(32) && mubuf->offset + info.val < 4096) {
             instr->operands[2] = Operand((uint32_t) 0);
             mubuf->offset += info.val;
             continue;
             instr->operands[2] = Operand((uint32_t) 0);
             mubuf->offset += info.val;
             continue;
@@ -891,7 +944,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          SMEM_instruction *smem = static_cast<SMEM_instruction *>(instr.get());
          Temp base;
          uint32_t offset;
          SMEM_instruction *smem = static_cast<SMEM_instruction *>(instr.get());
          Temp base;
          uint32_t offset;
-         if (i == 1 && info.is_constant_or_literal() &&
+         if (i == 1 && info.is_constant_or_literal(32) &&
              ((ctx.program->chip_class == GFX6 && info.val <= 0x3FF) ||
               (ctx.program->chip_class == GFX7 && info.val <= 0xFFFFFFFF) ||
               (ctx.program->chip_class >= GFX8 && info.val <= 0xFFFFF))) {
              ((ctx.program->chip_class == GFX6 && info.val <= 0x3FF) ||
               (ctx.program->chip_class == GFX7 && info.val <= 0xFFFFFFFF) ||
               (ctx.program->chip_class >= GFX8 && info.val <= 0xFFFFF))) {
@@ -900,7 +953,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          } else if (i == 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->chip_class >= GFX9) {
             bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
             if (soe &&
          } else if (i == 1 && parse_base_offset(ctx, instr.get(), i, &base, &offset) && base.regClass() == s1 && offset <= 0xFFFFF && ctx.program->chip_class >= GFX9) {
             bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4);
             if (soe &&
-                (!ctx.info[smem->operands.back().tempId()].is_constant_or_literal() ||
+                (!ctx.info[smem->operands.back().tempId()].is_constant_or_literal(32) ||
                  ctx.info[smem->operands.back().tempId()].val != 0)) {
                continue;
             }
                  ctx.info[smem->operands.back().tempId()].val != 0)) {
                continue;
             }
@@ -996,12 +1049,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
 
          Operand vec_op = vec->operands[vec_index];
          if (vec_op.isConstant()) {
 
          Operand vec_op = vec->operands[vec_index];
          if (vec_op.isConstant()) {
-            if (vec_op.isLiteral())
-               ctx.info[instr->definitions[i].tempId()].set_literal(vec_op.constantValue());
-            else if (vec_op.size() == 1)
-               ctx.info[instr->definitions[i].tempId()].set_constant(vec_op.constantValue());
-            else if (vec_op.size() == 2)
-               ctx.info[instr->definitions[i].tempId()].set_constant_64bit(vec_op.constantValue());
+            ctx.info[instr->definitions[i].tempId()].set_constant(ctx.program->chip_class, vec_op.constantValue64());
          } else if (vec_op.isUndefined()) {
             ctx.info[instr->definitions[i].tempId()].set_undefined();
          } else {
          } else if (vec_op.isUndefined()) {
             ctx.info[instr->definitions[i].tempId()].set_undefined();
          } else {
@@ -1035,12 +1083,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          instr->operands[0] = op;
 
          if (op.isConstant()) {
          instr->operands[0] = op;
 
          if (op.isConstant()) {
-            if (op.isLiteral())
-               ctx.info[instr->definitions[0].tempId()].set_literal(op.constantValue());
-            else if (op.size() == 1)
-               ctx.info[instr->definitions[0].tempId()].set_constant(op.constantValue());
-            else if (op.size() == 2)
-               ctx.info[instr->definitions[0].tempId()].set_constant_64bit(op.constantValue());
+            ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, op.constantValue64());
          } else if (op.isUndefined()) {
             ctx.info[instr->definitions[0].tempId()].set_undefined();
          } else {
          } else if (op.isUndefined()) {
             ctx.info[instr->definitions[0].tempId()].set_undefined();
          } else {
@@ -1060,12 +1103,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
       } else if (instr->usesModifiers()) {
          // TODO
       } else if (instr->operands[0].isConstant()) {
       } else if (instr->usesModifiers()) {
          // TODO
       } else if (instr->operands[0].isConstant()) {
-         if (instr->operands[0].isLiteral())
-            ctx.info[instr->definitions[0].tempId()].set_literal(instr->operands[0].constantValue());
-         else if (instr->operands[0].size() == 1)
-            ctx.info[instr->definitions[0].tempId()].set_constant(instr->operands[0].constantValue());
-         else if (instr->operands[0].size() == 2)
-            ctx.info[instr->definitions[0].tempId()].set_constant_64bit(instr->operands[0].constantValue());
+         ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, instr->operands[0].constantValue64());
       } else if (instr->operands[0].isTemp()) {
          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
       } else {
       } else if (instr->operands[0].isTemp()) {
          ctx.info[instr->definitions[0].tempId()].set_temp(instr->operands[0].getTemp());
       } else {
@@ -1074,25 +1112,19 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
       break;
    case aco_opcode::p_is_helper:
       if (!ctx.program->needs_wqm)
       break;
    case aco_opcode::p_is_helper:
       if (!ctx.program->needs_wqm)
-         ctx.info[instr->definitions[0].tempId()].set_constant(0u);
+         ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, 0u);
       break;
    case aco_opcode::s_movk_i32: {
       uint32_t v = static_cast<SOPK_instruction*>(instr.get())->imm;
       v = v & 0x8000 ? (v | 0xffff0000) : v;
       break;
    case aco_opcode::s_movk_i32: {
       uint32_t v = static_cast<SOPK_instruction*>(instr.get())->imm;
       v = v & 0x8000 ? (v | 0xffff0000) : v;
-      if (v <= 64 || v >= 0xfffffff0)
-         ctx.info[instr->definitions[0].tempId()].set_constant(v);
-      else
-         ctx.info[instr->definitions[0].tempId()].set_literal(v);
+      ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, v);
       break;
    }
    case aco_opcode::v_bfrev_b32:
    case aco_opcode::s_brev_b32: {
       if (instr->operands[0].isConstant()) {
          uint32_t v = util_bitreverse(instr->operands[0].constantValue());
       break;
    }
    case aco_opcode::v_bfrev_b32:
    case aco_opcode::s_brev_b32: {
       if (instr->operands[0].isConstant()) {
          uint32_t v = util_bitreverse(instr->operands[0].constantValue());
-         if (v <= 64 || v >= 0xfffffff0)
-            ctx.info[instr->definitions[0].tempId()].set_constant(v);
-         else
-            ctx.info[instr->definitions[0].tempId()].set_literal(v);
+         ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, v);
       }
       break;
    }
       }
       break;
    }
@@ -1101,10 +1133,7 @@ void label_instruction(opt_ctx &ctx, Block& block, aco_ptr<Instruction>& instr)
          unsigned size = instr->operands[0].constantValue() & 0x1f;
          unsigned start = instr->operands[1].constantValue() & 0x1f;
          uint32_t v = ((1u << size) - 1u) << start;
          unsigned size = instr->operands[0].constantValue() & 0x1f;
          unsigned start = instr->operands[1].constantValue() & 0x1f;
          uint32_t v = ((1u << size) - 1u) << start;
-         if (v <= 64 || v >= 0xfffffff0)
-            ctx.info[instr->definitions[0].tempId()].set_constant(v);
-         else
-            ctx.info[instr->definitions[0].tempId()].set_literal(v);
+         ctx.info[instr->definitions[0].tempId()].set_constant(ctx.program->chip_class, v);
       }
       break;
    }
       }
       break;
    }
@@ -1629,7 +1658,7 @@ bool combine_constant_comparison_ordering(opt_ctx &ctx, aco_ptr<Instruction>& in
    } else if (cmp->operands[constant_operand].isTemp()) {
       Temp tmp = cmp->operands[constant_operand].getTemp();
       unsigned id = original_temp_id(ctx, tmp);
    } else if (cmp->operands[constant_operand].isTemp()) {
       Temp tmp = cmp->operands[constant_operand].getTemp();
       unsigned id = original_temp_id(ctx, tmp);
-      if (!ctx.info[id].is_constant() && !ctx.info[id].is_literal())
+      if (!ctx.info[id].is_constant_or_literal(32))
          return false;
       constant = ctx.info[id].val;
    } else {
          return false;
       constant = ctx.info[id].val;
    } else {
@@ -2115,7 +2144,7 @@ bool combine_clamp(opt_ctx& ctx, aco_ptr<Instruction>& instr,
             uint32_t val;
             if (operands[i].isConstant()) {
                val = operands[i].constantValue();
             uint32_t val;
             if (operands[i].isConstant()) {
                val = operands[i].constantValue();
-            } else if (operands[i].isTemp() && ctx.info[operands[i].tempId()].is_constant_or_literal()) {
+            } else if (operands[i].isTemp() && ctx.info[operands[i].tempId()].is_constant_or_literal(32)) {
                val = ctx.info[operands[i].tempId()].val;
             } else {
                continue;
                val = ctx.info[operands[i].tempId()].val;
             } else {
                continue;
@@ -2791,9 +2820,10 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
             }
             if (!instr->operands[i].isTemp())
                continue;
             }
             if (!instr->operands[i].isTemp())
                continue;
+            unsigned bits = get_operand_size(instr, i);
             /* if one of the operands is sgpr, we cannot add a literal somewhere else on pre-GFX10 or operands other than the 1st */
             if (instr->operands[i].getTemp().type() == RegType::sgpr && (i > 0 || ctx.program->chip_class < GFX10)) {
             /* if one of the operands is sgpr, we cannot add a literal somewhere else on pre-GFX10 or operands other than the 1st */
             if (instr->operands[i].getTemp().type() == RegType::sgpr && (i > 0 || ctx.program->chip_class < GFX10)) {
-               if (!sgpr_used && ctx.info[instr->operands[i].tempId()].is_literal()) {
+               if (!sgpr_used && ctx.info[instr->operands[i].tempId()].is_literal(bits)) {
                   literal_uses = ctx.uses[instr->operands[i].tempId()];
                   literal_idx = i;
                } else {
                   literal_uses = ctx.uses[instr->operands[i].tempId()];
                   literal_idx = i;
                } else {
@@ -2802,7 +2832,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
                sgpr_used = true;
                /* don't break because we still need to check constants */
             } else if (!sgpr_used &&
                sgpr_used = true;
                /* don't break because we still need to check constants */
             } else if (!sgpr_used &&
-                       ctx.info[instr->operands[i].tempId()].is_literal() &&
+                       ctx.info[instr->operands[i].tempId()].is_literal(bits) &&
                        ctx.uses[instr->operands[i].tempId()] < literal_uses) {
                literal_uses = ctx.uses[instr->operands[i].tempId()];
                literal_idx = i;
                        ctx.uses[instr->operands[i].tempId()] < literal_uses) {
                literal_uses = ctx.uses[instr->operands[i].tempId()];
                literal_idx = i;
@@ -2881,6 +2911,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    /* choose a literal to apply */
    for (unsigned i = 0; i < num_operands; i++) {
       Operand op = instr->operands[i];
    /* choose a literal to apply */
    for (unsigned i = 0; i < num_operands; i++) {
       Operand op = instr->operands[i];
+      unsigned bits = get_operand_size(instr, i);
 
       if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
           op.tempId() != sgpr_ids[0])
 
       if (instr->isVALU() && op.isTemp() && op.getTemp().type() == RegType::sgpr &&
           op.tempId() != sgpr_ids[0])
@@ -2889,7 +2920,7 @@ void select_instruction(opt_ctx &ctx, aco_ptr<Instruction>& instr)
       if (op.isLiteral()) {
          current_literal = op;
          continue;
       if (op.isLiteral()) {
          current_literal = op;
          continue;
-      } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal()) {
+      } else if (!op.isTemp() || !ctx.info[op.tempId()].is_literal(bits)) {
          continue;
       }
 
          continue;
       }
 
@@ -2974,7 +3005,8 @@ void apply_literals(opt_ctx &ctx, aco_ptr<Instruction>& instr)
    if (instr->isSALU() || instr->isVALU()) {
       for (unsigned i = 0; i < instr->operands.size(); i++) {
          Operand op = instr->operands[i];
    if (instr->isSALU() || instr->isVALU()) {
       for (unsigned i = 0; i < instr->operands.size(); i++) {
          Operand op = instr->operands[i];
-         if (op.isTemp() && ctx.info[op.tempId()].is_literal() && ctx.uses[op.tempId()] == 0) {
+         unsigned bits = get_operand_size(instr, i);
+         if (op.isTemp() && ctx.info[op.tempId()].is_literal(bits) && ctx.uses[op.tempId()] == 0) {
             Operand literal(ctx.info[op.tempId()].val);
             if (instr->isVALU() && i > 0)
                to_VOP3(ctx, instr);
             Operand literal(ctx.info[op.tempId()].val);
             if (instr->isVALU() && i > 0)
                to_VOP3(ctx, instr);
index 0fb0ceb186d5d5bda1f5a91575b06e6e5cdce351..3daa60b71c14574b6e3130cb4961aa4b8161e17a 100644 (file)
@@ -153,8 +153,13 @@ static void print_constant(uint8_t reg, FILE *output)
 
 static void print_operand(const Operand *operand, FILE *output)
 {
 
 static void print_operand(const Operand *operand, FILE *output)
 {
-   if (operand->isLiteral()) {
-      fprintf(output, "0x%x", operand->constantValue());
+   if (operand->isLiteral() || (operand->isConstant() && operand->bytes() == 1)) {
+      if (operand->bytes() == 1)
+         fprintf(output, "0x%.2x", operand->constantValue());
+      else if (operand->bytes() == 2)
+         fprintf(output, "0x%.4x", operand->constantValue());
+      else
+         fprintf(output, "0x%x", operand->constantValue());
    } else if (operand->isConstant()) {
       print_constant(operand->physReg().reg(), output);
    } else if (operand->isUndefined()) {
    } else if (operand->isConstant()) {
       print_constant(operand->physReg().reg(), output);
    } else if (operand->isUndefined()) {