aco: implement 8-bit/16-bit conversions on GFX6-GFX7
authorSamuel Pitoiset <samuel.pitoiset@gmail.com>
Thu, 7 May 2020 08:55:28 +0000 (10:55 +0200)
committerMarge Bot <eric+marge@anholt.net>
Tue, 9 Jun 2020 21:25:38 +0000 (21:25 +0000)
Use v_bfe to implement small bitsize conversions because the
compiler probably optimizes this better.

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5226>

src/amd/compiler/aco_instruction_selection.cpp

index ccfb17eb6708555d3eadf983da670930cdc0b5d6..66f6c8857e222e59642770d5dddcb84286ae3a8e 100644 (file)
@@ -944,7 +944,7 @@ Temp emit_floor_f64(isel_context *ctx, Builder& bld, Definition dst, Temp val)
    return add->definitions[0].getTemp();
 }
 
-Temp convert_int(Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, bool is_signed, Temp dst=Temp()) {
+Temp convert_int(isel_context *ctx, Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, bool is_signed, Temp dst=Temp()) {
    if (!dst.id()) {
       if (dst_bits % 32 == 0 || src.type() == RegType::sgpr)
          dst = bld.tmp(src.type(), DIV_ROUND_UP(dst_bits, 32u));
@@ -967,7 +967,7 @@ Temp convert_int(Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, b
          bld.sop1(src_bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, Definition(tmp), src);
       else
          bld.sop2(aco_opcode::s_and_b32, Definition(tmp), bld.def(s1, scc), Operand(src_bits == 8 ? 0xFFu : 0xFFFFu), src);
-   } else {
+   } else if (ctx->options->chip_class >= GFX8) {
       assert(src_bits != 8 || src.regClass() == v1b);
       assert(src_bits != 16 || src.regClass() == v2b);
       aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
@@ -979,6 +979,10 @@ Temp convert_int(Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, b
          sdwa->sel[0] = src_bits == 8 ? sdwa_ubyte : sdwa_uword;
       sdwa->dst_sel = tmp.bytes() == 2 ? sdwa_uword : sdwa_udword;
       bld.insert(std::move(sdwa));
+   } else {
+      assert(ctx->options->chip_class == GFX6 || ctx->options->chip_class == GFX7);
+      aco_opcode opcode = is_signed ? aco_opcode::v_bfe_i32 : aco_opcode::v_bfe_u32;
+      bld.vop3(opcode, Definition(tmp), src, Operand(0u), Operand(src_bits == 8 ? 8u : 16u));
    }
 
    if (dst_bits == 64) {
@@ -2130,7 +2134,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       if (instr->src[0].src.ssa->bit_size == 16) {
          Temp tmp = bld.vop1(aco_opcode::v_frexp_exp_i16_f16, bld.def(v1), src);
          tmp = bld.pseudo(aco_opcode::p_extract_vector, bld.def(v1b), tmp, Operand(0u));
-         convert_int(bld, tmp, 8, 32, true, dst);
+         convert_int(ctx, bld, tmp, 8, 32, true, dst);
       } else if (instr->src[0].src.ssa->bit_size == 32) {
          bld.vop1(aco_opcode::v_frexp_exp_i32_f32, Definition(dst), src);
       } else if (instr->src[0].src.ssa->bit_size == 64) {
@@ -2211,7 +2215,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       assert(dst.regClass() == v2b);
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (instr->src[0].src.ssa->bit_size == 8)
-         src = convert_int(bld, src, 8, 16, true);
+         src = convert_int(ctx, bld, src, 8, 16, true);
       bld.vop1(aco_opcode::v_cvt_f16_i16, Definition(dst), src);
       break;
    }
@@ -2219,7 +2223,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       assert(dst.size() == 1);
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (instr->src[0].src.ssa->bit_size <= 16)
-         src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true);
+         src = convert_int(ctx, bld, src, instr->src[0].src.ssa->bit_size, 32, true);
       bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src);
       break;
    }
@@ -2227,7 +2231,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       if (instr->src[0].src.ssa->bit_size <= 32) {
          Temp src = get_alu_src(ctx, instr->src[0]);
          if (instr->src[0].src.ssa->bit_size <= 16)
-            src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true);
+            src = convert_int(ctx, bld, src, instr->src[0].src.ssa->bit_size, 32, true);
          bld.vop1(aco_opcode::v_cvt_f64_i32, Definition(dst), src);
       } else if (instr->src[0].src.ssa->bit_size == 64) {
          Temp src = get_alu_src(ctx, instr->src[0]);
@@ -2250,7 +2254,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       assert(dst.regClass() == v2b);
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (instr->src[0].src.ssa->bit_size == 8)
-         src = convert_int(bld, src, 8, 16, false);
+         src = convert_int(ctx, bld, src, 8, 16, false);
       bld.vop1(aco_opcode::v_cvt_f16_u16, Definition(dst), src);
       break;
    }
@@ -2262,7 +2266,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
          bld.vop1(aco_opcode::v_cvt_f32_ubyte0, Definition(dst), src);
       } else {
          if (instr->src[0].src.ssa->bit_size == 16)
-            src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true);
+            src = convert_int(ctx, bld, src, instr->src[0].src.ssa->bit_size, 32, true);
          bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src);
       }
       break;
@@ -2271,7 +2275,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       if (instr->src[0].src.ssa->bit_size <= 32) {
          Temp src = get_alu_src(ctx, instr->src[0]);
          if (instr->src[0].src.ssa->bit_size <= 16)
-            src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, false);
+            src = convert_int(ctx, bld, src, instr->src[0].src.ssa->bit_size, 32, false);
          bld.vop1(aco_opcode::v_cvt_f64_u32, Definition(dst), src);
       } else if (instr->src[0].src.ssa->bit_size == 64) {
          Temp src = get_alu_src(ctx, instr->src[0]);
@@ -2583,7 +2587,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_i2i16:
    case nir_op_i2i32:
    case nir_op_i2i64: {
-      convert_int(bld, get_alu_src(ctx, instr->src[0]),
+      convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]),
                   instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, true, dst);
       break;
    }
@@ -2591,7 +2595,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_u2u16:
    case nir_op_u2u32:
    case nir_op_u2u64: {
-      convert_int(bld, get_alu_src(ctx, instr->src[0]),
+      convert_int(ctx, bld, get_alu_src(ctx, instr->src[0]),
                   instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, false, dst);
       break;
    }
@@ -10053,7 +10057,7 @@ static bool export_fs_mrt_color(isel_context *ctx, int slot)
       } else if (is_16bit) {
          for (unsigned i = 0; i < 4; i++) {
             if ((write_mask >> i) & 1) {
-               Temp tmp = convert_int(bld, values[i].getTemp(), 16, 32, false);
+               Temp tmp = convert_int(ctx, bld, values[i].getTemp(), 16, 32, false);
                values[i] = Operand(tmp);
             }
          }
@@ -10084,7 +10088,7 @@ static bool export_fs_mrt_color(isel_context *ctx, int slot)
       } else if (is_16bit) {
          for (unsigned i = 0; i < 4; i++) {
             if ((write_mask >> i) & 1) {
-               Temp tmp = convert_int(bld, values[i].getTemp(), 16, 32, true);
+               Temp tmp = convert_int(ctx, bld, values[i].getTemp(), 16, 32, true);
                values[i] = Operand(tmp);
             }
          }