aco: add missing conversion operations for small bitsizes
author Daniel Schürmann <daniel@schuermann.dev>
Fri, 28 Feb 2020 19:17:44 +0000 (20:17 +0100)
committer Daniel Schürmann <daniel@schuermann.dev>
Fri, 3 Apr 2020 22:13:15 +0000 (23:13 +0100)
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4002>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp

index 8b8f3fa8ce0ccc6357d887181ceb46a607395987..4589b405f6e658098619ffe9c81e94831acd796c 100644 (file)
@@ -1900,8 +1900,27 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       }
       break;
    }
+   case nir_op_f2f16:
+   case nir_op_f2f16_rtne: {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 64)
+         src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src);
+      src = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src);
+      bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
+      break;
+   }
+   case nir_op_f2f16_rtz: {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 64)
+         src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src);
+      src = bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, bld.def(v1), src, Operand(0u));
+      bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
+      break;
+   }
    case nir_op_f2f32: {
-      if (instr->src[0].src.ssa->bit_size == 64) {
+      if (instr->src[0].src.ssa->bit_size == 16) {
+         emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f16, dst);
+      } else if (instr->src[0].src.ssa->bit_size == 64) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1911,13 +1930,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_f2f64: {
-      if (instr->src[0].src.ssa->bit_size == 32) {
-         emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_f32, dst);
-      } else {
-         fprintf(stderr, "Unimplemented NIR instr bit size: ");
-         nir_print_instr(&instr->instr, stderr);
-         fprintf(stderr, "\n");
-      }
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 16)
+         src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src);
+      bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src);
       break;
    }
    case nir_op_i2f32: {
@@ -1969,6 +1985,36 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       }
       break;
    }
+   case nir_op_f2i16: {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 16)
+         src = bld.vop1(aco_opcode::v_cvt_i16_f16, bld.def(v1), src);
+      else if (instr->src[0].src.ssa->bit_size == 32)
+         src = bld.vop1(aco_opcode::v_cvt_i32_f32, bld.def(v1), src);
+      else
+         src = bld.vop1(aco_opcode::v_cvt_i32_f64, bld.def(v1), src);
+
+      if (dst.type() == RegType::vgpr)
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
+      else
+         bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src);
+      break;
+   }
+   case nir_op_f2u16: {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 16)
+         src = bld.vop1(aco_opcode::v_cvt_u16_f16, bld.def(v1), src);
+      else if (instr->src[0].src.ssa->bit_size == 32)
+         src = bld.vop1(aco_opcode::v_cvt_u32_f32, bld.def(v1), src);
+      else
+         src = bld.vop1(aco_opcode::v_cvt_u32_f64, bld.def(v1), src);
+
+      if (dst.type() == RegType::vgpr)
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src);
+      else
+         bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src);
+      break;
+   }
    case nir_op_f2i32: {
       Temp src = get_alu_src(ctx, instr->src[0]);
       if (instr->src[0].src.ssa->bit_size == 32) {
@@ -2190,9 +2236,91 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       }
       break;
    }
+   case nir_op_i2i8:
+   case nir_op_u2u8: {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      /* we can actually just say dst = src */
+      if (src.regClass() == s1)
+         bld.copy(Definition(dst), src);
+      else
+         emit_extract_vector(ctx, src, 0, dst);
+      break;
+   }
+   case nir_op_i2i16: {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 8) {
+         if (dst.regClass() == s1) {
+            bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src));
+         } else {
+            assert(src.regClass() == v1b);
+            aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
+            sdwa->operands[0] = Operand(src);
+            sdwa->definitions[0] = Definition(dst);
+            sdwa->sel[0] = sdwa_sbyte;
+            sdwa->dst_sel = sdwa_sword;
+            ctx->block->instructions.emplace_back(std::move(sdwa));
+         }
+      } else {
+         Temp src = get_alu_src(ctx, instr->src[0]);
+         /* we can actually just say dst = src */
+         if (src.regClass() == s1)
+            bld.copy(Definition(dst), src);
+         else
+            emit_extract_vector(ctx, src, 0, dst);
+      }
+      break;
+   }
+   case nir_op_u2u16: {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 8) {
+         if (dst.regClass() == s1)
+            bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src);
+         else {
+            assert(src.regClass() == v1b);
+            aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
+            sdwa->operands[0] = Operand(src);
+            sdwa->definitions[0] = Definition(dst);
+            sdwa->sel[0] = sdwa_ubyte;
+            sdwa->dst_sel = sdwa_uword;
+            ctx->block->instructions.emplace_back(std::move(sdwa));
+         }
+      } else {
+         Temp src = get_alu_src(ctx, instr->src[0]);
+         /* we can actually just say dst = src */
+         if (src.regClass() == s1)
+            bld.copy(Definition(dst), src);
+         else
+            emit_extract_vector(ctx, src, 0, dst);
+      }
+      break;
+   }
    case nir_op_i2i32: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (instr->src[0].src.ssa->bit_size == 64) {
+      if (instr->src[0].src.ssa->bit_size == 8) {
+         if (dst.regClass() == s1) {
+            bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src));
+         } else {
+            assert(src.regClass() == v1b);
+            aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
+            sdwa->operands[0] = Operand(src);
+            sdwa->definitions[0] = Definition(dst);
+            sdwa->sel[0] = sdwa_sbyte;
+            sdwa->dst_sel = sdwa_sdword;
+            ctx->block->instructions.emplace_back(std::move(sdwa));
+         }
+      } else if (instr->src[0].src.ssa->bit_size == 16) {
+         if (dst.regClass() == s1) {
+            bld.sop1(aco_opcode::s_sext_i32_i16, Definition(dst), Operand(src));
+         } else {
+            assert(src.regClass() == v2b);
+            aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
+            sdwa->operands[0] = Operand(src);
+            sdwa->definitions[0] = Definition(dst);
+            sdwa->sel[0] = sdwa_sword;
+            sdwa->dst_sel = sdwa_udword;
+            ctx->block->instructions.emplace_back(std::move(sdwa));
+         }
+      } else if (instr->src[0].src.ssa->bit_size == 64) {
          /* we can actually just say dst = src, as it would map the lower register */
          emit_extract_vector(ctx, src, 0, dst);
       } else {
@@ -2204,12 +2332,29 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_u2u32: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (instr->src[0].src.ssa->bit_size == 16) {
+      if (instr->src[0].src.ssa->bit_size == 8) {
+         if (dst.regClass() == s1)
+            bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src);
+         else {
+            assert(src.regClass() == v1b);
+            aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
+            sdwa->operands[0] = Operand(src);
+            sdwa->definitions[0] = Definition(dst);
+            sdwa->sel[0] = sdwa_ubyte;
+            sdwa->dst_sel = sdwa_udword;
+            ctx->block->instructions.emplace_back(std::move(sdwa));
+         }
+      } else if (instr->src[0].src.ssa->bit_size == 16) {
          if (dst.regClass() == s1) {
             bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFFFu), src);
          } else {
-            // TODO: do better with SDWA
-            bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0xFFFFu), src);
+            assert(src.regClass() == v2b);
+            aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)};
+            sdwa->operands[0] = Operand(src);
+            sdwa->definitions[0] = Definition(dst);
+            sdwa->sel[0] = sdwa_uword;
+            sdwa->dst_sel = sdwa_udword;
+            ctx->block->instructions.emplace_back(std::move(sdwa));
          }
       } else if (instr->src[0].src.ssa->bit_size == 64) {
          /* we can actually just say dst = src, as it would map the lower register */
@@ -2298,6 +2443,32 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    case nir_op_unpack_64_2x32_split_y:
       bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0]));
       break;
+   case nir_op_unpack_32_2x16_split_x:
+      if (dst.type() == RegType::vgpr) {
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0]));
+      } else {
+         bld.copy(Definition(dst), get_alu_src(ctx, instr->src[0]));
+      }
+      break;
+   case nir_op_unpack_32_2x16_split_y:
+      if (dst.type() == RegType::vgpr) {
+         bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0]));
+      } else {
+         bld.sop2(aco_opcode::s_bfe_u32, Definition(dst), get_alu_src(ctx, instr->src[0]), Operand(uint32_t(16 << 16 | 16)));
+      }
+      break;
+   case nir_op_pack_32_2x16_split: {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = get_alu_src(ctx, instr->src[1]);
+      if (dst.regClass() == v1) {
+         bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src0, src1);
+      } else {
+         src0 = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), src0, Operand(0xFFFFu));
+         src1 = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), src1, Operand(16u));
+         bld.sop2(aco_opcode::s_or_b32, Definition(dst), bld.def(s1, scc), src0, src1);
+      }
+      break;
+   }
    case nir_op_pack_half_2x16: {
       Temp src = get_alu_src(ctx, instr->src[0], 2);
 
index d7d294e7044dbd44bcf8e6ada3a3a097a0402da6..ac7df9201007225ce6ef0af588831ca8b2008bf8 100644 (file)
@@ -305,6 +305,9 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_op_fround_even:
                   case nir_op_fsin:
                   case nir_op_fcos:
+                  case nir_op_f2f16:
+                  case nir_op_f2f16_rtz:
+                  case nir_op_f2f16_rtne:
                   case nir_op_f2f32:
                   case nir_op_f2f64:
                   case nir_op_u2f32:
@@ -328,13 +331,15 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_op_cube_face_coord:
                      type = RegType::vgpr;
                      break;
+                  case nir_op_f2i16:
+                  case nir_op_f2u16:
+                  case nir_op_f2i32:
+                  case nir_op_f2u32:
                   case nir_op_f2i64:
                   case nir_op_f2u64:
                   case nir_op_b2i32:
                   case nir_op_b2b32:
                   case nir_op_b2f32:
-                  case nir_op_f2i32:
-                  case nir_op_f2u32:
                   case nir_op_mov:
                      type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr;
                      break;