aco: implement 16-bit nir_op_ldexp
[mesa.git] / src / amd / compiler / aco_instruction_selection.cpp
index 6a63d5d87189f9e03ccf460721ea69b7607e5333..b3f78a12f1578a62948b1ea638749c50524032d9 100644 (file)
@@ -85,6 +85,7 @@ struct if_context {
 
    unsigned BB_if_idx;
    unsigned invert_idx;
+   bool uniform_has_then_branch;
    bool then_branch_divergent;
    Block BB_invert;
    Block BB_endif;
@@ -737,12 +738,18 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
 
    if (dst.type() == RegType::vgpr) {
       aco_ptr<Instruction> bcsel;
-      if (dst.size() == 1) {
+      if (dst.regClass() == v2b) {
+         then = as_vgpr(ctx, then);
+         els = as_vgpr(ctx, els);
+
+         Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), els, then, cond);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          then = as_vgpr(ctx, then);
          els = as_vgpr(ctx, els);
 
          bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), els, then, cond);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          Temp then_lo = bld.tmp(v1), then_hi = bld.tmp(v1);
          bld.pseudo(aco_opcode::p_split_vector, Definition(then_lo), Definition(then_hi), then);
          Temp else_lo = bld.tmp(v1), else_hi = bld.tmp(v1);
@@ -1525,11 +1532,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fmul: {
-      if (dst.size() == 1) {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1]));
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.tmp(v1);
+         emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, tmp, true);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true);
-      } else if (dst.size() == 2) {
-         bld.vop3(aco_opcode::v_mul_f64, Definition(dst), get_alu_src(ctx, instr->src[0]),
-                  as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+      } else if (dst.regClass() == v2) {
+         bld.vop3(aco_opcode::v_mul_f64, Definition(dst), src0, src1);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1538,11 +1550,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fadd: {
-      if (dst.size() == 1) {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1]));
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.tmp(v1);
+         emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, tmp, true);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true);
-      } else if (dst.size() == 2) {
-         bld.vop3(aco_opcode::v_add_f64, Definition(dst), get_alu_src(ctx, instr->src[0]),
-                  as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+      } else if (dst.regClass() == v2) {
+         bld.vop3(aco_opcode::v_add_f64, Definition(dst), src0, src1);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1552,16 +1569,22 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fsub: {
       Temp src0 = get_alu_src(ctx, instr->src[0]);
-      Temp src1 = get_alu_src(ctx, instr->src[1]);
-      if (dst.size() == 1) {
+      Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1]));
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.tmp(v1);
+         if (src1.type() == RegType::vgpr || src0.type() != RegType::vgpr)
+            emit_vop2_instruction(ctx, instr, aco_opcode::v_sub_f16, tmp, false);
+         else
+            emit_vop2_instruction(ctx, instr, aco_opcode::v_subrev_f16, tmp, true);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          if (src1.type() == RegType::vgpr || src0.type() != RegType::vgpr)
             emit_vop2_instruction(ctx, instr, aco_opcode::v_sub_f32, dst, false);
          else
             emit_vop2_instruction(ctx, instr, aco_opcode::v_subrev_f32, dst, true);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          Instruction* add = bld.vop3(aco_opcode::v_add_f64, Definition(dst),
-                                     get_alu_src(ctx, instr->src[0]),
-                                     as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+                                     src0, src1);
          VOP3A_instruction* sub = static_cast<VOP3A_instruction*>(add);
          sub->neg[1] = true;
       } else {
@@ -1572,18 +1595,21 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fmax: {
-      if (dst.size() == 1) {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1]));
+      if (dst.regClass() == v2b) {
+         // TODO: check fp_mode.must_flush_denorms16_64
+         Temp tmp = bld.tmp(v1);
+         emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, tmp, true);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) {
-            Temp tmp = bld.vop3(aco_opcode::v_max_f64, bld.def(v2),
-                                get_alu_src(ctx, instr->src[0]),
-                                as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+            Temp tmp = bld.vop3(aco_opcode::v_max_f64, bld.def(v2), src0, src1);
             bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp);
          } else {
-            bld.vop3(aco_opcode::v_max_f64, Definition(dst),
-                     get_alu_src(ctx, instr->src[0]),
-                     as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+            bld.vop3(aco_opcode::v_max_f64, Definition(dst), src0, src1);
          }
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1593,18 +1619,21 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fmin: {
-      if (dst.size() == 1) {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1]));
+      if (dst.regClass() == v2b) {
+         // TODO: check fp_mode.must_flush_denorms16_64
+         Temp tmp = bld.tmp(v1);
+         emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, tmp, true);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) {
-            Temp tmp = bld.vop3(aco_opcode::v_min_f64, bld.def(v2),
-                                get_alu_src(ctx, instr->src[0]),
-                                as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+            Temp tmp = bld.vop3(aco_opcode::v_min_f64, bld.def(v2), src0, src1);
             bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp);
          } else {
-            bld.vop3(aco_opcode::v_min_f64, Definition(dst),
-                     get_alu_src(ctx, instr->src[0]),
-                     as_vgpr(ctx, get_alu_src(ctx, instr->src[1])));
+            bld.vop3(aco_opcode::v_min_f64, Definition(dst), src0, src1);
          }
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1730,9 +1759,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_frsq: {
-      if (dst.size() == 1) {
-         emit_rsq(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
-      } else if (dst.size() == 2) {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_rsq_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
+         emit_rsq(ctx, bld, Definition(dst), src);
+      } else if (dst.regClass() == v2) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1743,11 +1776,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fneg: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (dst.size() == 1) {
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop2(aco_opcode::v_xor_b32, bld.def(v1), Operand(0x8000u), as_vgpr(ctx, src));
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          if (ctx->block->fp_mode.must_flush_denorms32)
             src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src));
          bld.vop2(aco_opcode::v_xor_b32, Definition(dst), Operand(0x80000000u), as_vgpr(ctx, src));
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          if (ctx->block->fp_mode.must_flush_denorms16_64)
             src = bld.vop3(aco_opcode::v_mul_f64, bld.def(v2), Operand(0x3FF0000000000000lu), as_vgpr(ctx, src));
          Temp upper = bld.tmp(v1), lower = bld.tmp(v1);
@@ -1763,11 +1799,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fabs: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (dst.size() == 1) {
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x7FFFu), as_vgpr(ctx, src));
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          if (ctx->block->fp_mode.must_flush_denorms32)
             src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src));
          bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0x7FFFFFFFu), as_vgpr(ctx, src));
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          if (ctx->block->fp_mode.must_flush_denorms16_64)
             src = bld.vop3(aco_opcode::v_mul_f64, bld.def(v2), Operand(0x3FF0000000000000lu), as_vgpr(ctx, src));
          Temp upper = bld.tmp(v1), lower = bld.tmp(v1);
@@ -1783,11 +1822,15 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fsat: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (dst.size() == 1) {
+      if (dst.regClass() == v2b) {
+         Temp one = bld.copy(bld.def(s1), Operand(0x3c00u));
+         Temp tmp = bld.vop3(aco_opcode::v_med3_f16, bld.def(v1), Operand(0u), one, src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          bld.vop3(aco_opcode::v_med3_f32, Definition(dst), Operand(0u), Operand(0x3f800000u), src);
          /* apparently, it is not necessary to flush denorms if this instruction is used with these operands */
          // TODO: confirm that this holds under any circumstances
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          Instruction* add = bld.vop3(aco_opcode::v_add_f64, Definition(dst), src, Operand(0u));
          VOP3A_instruction* vop3 = static_cast<VOP3A_instruction*>(add);
          vop3->clamp = true;
@@ -1799,8 +1842,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_flog2: {
-      if (dst.size() == 1) {
-         emit_log2(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_log_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
+         emit_log2(ctx, bld, Definition(dst), src);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1809,9 +1856,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_frcp: {
-      if (dst.size() == 1) {
-         emit_rcp(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
-      } else if (dst.size() == 2) {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_rcp_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
+         emit_rcp(ctx, bld, Definition(dst), src);
+      } else if (dst.regClass() == v2) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1821,7 +1872,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fexp2: {
-      if (dst.size() == 1) {
+      if (dst.regClass() == v2b) {
+         Temp src = get_alu_src(ctx, instr->src[0]);
+         Temp tmp = bld.vop1(aco_opcode::v_exp_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f32, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1831,9 +1886,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fsqrt: {
-      if (dst.size() == 1) {
-         emit_sqrt(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
-      } else if (dst.size() == 2) {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_sqrt_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
+         emit_sqrt(ctx, bld, Definition(dst), src);
+      } else if (dst.regClass() == v2) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1843,9 +1902,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_ffract: {
-      if (dst.size() == 1) {
+      if (dst.regClass() == v2b) {
+         Temp src = get_alu_src(ctx, instr->src[0]);
+         Temp tmp = bld.vop1(aco_opcode::v_fract_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f32, dst);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f64, dst);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
@@ -1855,10 +1918,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_ffloor: {
-      if (dst.size() == 1) {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_floor_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_floor_f32, dst);
-      } else if (dst.size() == 2) {
-         emit_floor_f64(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
+      } else if (dst.regClass() == v2) {
+         emit_floor_f64(ctx, bld, Definition(dst), src);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1867,15 +1934,17 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fceil: {
-      if (dst.size() == 1) {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_ceil_f16, bld.def(v1), src0);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_ceil_f32, dst);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          if (ctx->options->chip_class >= GFX7) {
             emit_vop1_instruction(ctx, instr, aco_opcode::v_ceil_f64, dst);
          } else {
             /* GFX6 doesn't support V_CEIL_F64, lower it. */
-            Temp src0 = get_alu_src(ctx, instr->src[0]);
-
             /* trunc = trunc(src0)
              * if (src0 > 0.0 && src0 != trunc)
              *    trunc += 1.0
@@ -1896,10 +1965,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_ftrunc: {
-      if (dst.size() == 1) {
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_trunc_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_trunc_f32, dst);
-      } else if (dst.size() == 2) {
-         emit_trunc_f64(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0]));
+      } else if (dst.regClass() == v2) {
+         emit_trunc_f64(ctx, bld, Definition(dst), src);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1908,15 +1981,17 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_fround_even: {
-      if (dst.size() == 1) {
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_rndne_f16, bld.def(v1), src0);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_rndne_f32, dst);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          if (ctx->options->chip_class >= GFX7) {
             emit_vop1_instruction(ctx, instr, aco_opcode::v_rndne_f64, dst);
          } else {
             /* GFX6 doesn't support V_RNDNE_F64, lower it. */
-            Temp src0 = get_alu_src(ctx, instr->src[0]);
-
             Temp src0_lo = bld.tmp(v1), src0_hi = bld.tmp(v1);
             bld.pseudo(aco_opcode::p_split_vector, Definition(src0_lo), Definition(src0_hi), src0);
 
@@ -1948,11 +2023,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fsin:
    case nir_op_fcos: {
-      Temp src = get_alu_src(ctx, instr->src[0]);
+      Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0]));
       aco_ptr<Instruction> norm;
-      if (dst.size() == 1) {
-         Temp half_pi = bld.copy(bld.def(s1), Operand(0x3e22f983u));
-         Temp tmp = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), half_pi, as_vgpr(ctx, src));
+      Temp half_pi = bld.copy(bld.def(s1), Operand(0x3e22f983u));
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop2(aco_opcode::v_mul_f16, bld.def(v1), half_pi, src);
+         aco_opcode opcode = instr->op == nir_op_fsin ? aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16;
+         tmp = bld.vop1(opcode, bld.def(v1), tmp);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
+         Temp tmp = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), half_pi, src);
 
          /* before GFX9, v_sin_f32 and v_cos_f32 had a valid input domain of [-256, +256] */
          if (ctx->options->chip_class < GFX9)
@@ -1968,14 +2048,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_ldexp: {
-      if (dst.size() == 1) {
-         bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst),
-                  as_vgpr(ctx, get_alu_src(ctx, instr->src[0])),
-                  get_alu_src(ctx, instr->src[1]));
-      } else if (dst.size() == 2) {
-         bld.vop3(aco_opcode::v_ldexp_f64, Definition(dst),
-                  as_vgpr(ctx, get_alu_src(ctx, instr->src[0])),
-                  get_alu_src(ctx, instr->src[1]));
+      Temp src0 = get_alu_src(ctx, instr->src[0]);
+      Temp src1 = get_alu_src(ctx, instr->src[1]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.tmp(v1);
+         emit_vop2_instruction(ctx, instr, aco_opcode::v_ldexp_f16, tmp, false);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
+         bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst), as_vgpr(ctx, src0), src1);
+      } else if (dst.regClass() == v2) {
+         bld.vop3(aco_opcode::v_ldexp_f64, Definition(dst), as_vgpr(ctx, src0), src1);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1984,12 +2066,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_frexp_sig: {
-      if (dst.size() == 1) {
-         bld.vop1(aco_opcode::v_frexp_mant_f32, Definition(dst),
-                  get_alu_src(ctx, instr->src[0]));
-      } else if (dst.size() == 2) {
-         bld.vop1(aco_opcode::v_frexp_mant_f64, Definition(dst),
-                  get_alu_src(ctx, instr->src[0]));
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (dst.regClass() == v2b) {
+         Temp tmp = bld.vop1(aco_opcode::v_frexp_mant_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
+         bld.vop1(aco_opcode::v_frexp_mant_f32, Definition(dst), src);
+      } else if (dst.regClass() == v2) {
+         bld.vop1(aco_opcode::v_frexp_mant_f64, Definition(dst), src);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -1998,12 +2082,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_frexp_exp: {
-      if (instr->src[0].src.ssa->bit_size == 32) {
-         bld.vop1(aco_opcode::v_frexp_exp_i32_f32, Definition(dst),
-                  get_alu_src(ctx, instr->src[0]));
+      Temp src = get_alu_src(ctx, instr->src[0]);
+      if (instr->src[0].src.ssa->bit_size == 16) {
+         Temp tmp = bld.vop1(aco_opcode::v_frexp_exp_i16_f16, bld.def(v1), src);
+         bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), tmp, Operand(0u));
+      } else if (instr->src[0].src.ssa->bit_size == 32) {
+         bld.vop1(aco_opcode::v_frexp_exp_i32_f32, Definition(dst), src);
       } else if (instr->src[0].src.ssa->bit_size == 64) {
-         bld.vop1(aco_opcode::v_frexp_exp_i32_f64, Definition(dst),
-                  get_alu_src(ctx, instr->src[0]));
+         bld.vop1(aco_opcode::v_frexp_exp_i32_f64, Definition(dst), src);
       } else {
          fprintf(stderr, "Unimplemented NIR instr bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -2013,12 +2099,20 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_fsign: {
       Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0]));
-      if (dst.size() == 1) {
+      if (dst.regClass() == v2b) {
+         Temp one = bld.copy(bld.def(v1), Operand(0x3c00u));
+         Temp minus_one = bld.copy(bld.def(v1), Operand(0xbc00u));
+         Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f16, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src);
+         src = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), one, src, cond);
+         cond = bld.vopc(aco_opcode::v_cmp_le_f16, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src);
+         Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), minus_one, src, cond);
+         bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp);
+      } else if (dst.regClass() == v1) {
          Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f32, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src);
          src = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0x3f800000u), src, cond);
          cond = bld.vopc(aco_opcode::v_cmp_le_f32, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src);
          bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0xbf800000u), src, cond);
-      } else if (dst.size() == 2) {
+      } else if (dst.regClass() == v2) {
          Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f64, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src);
          Temp tmp = bld.vop1(aco_opcode::v_mov_b32, bld.def(v1), Operand(0x3FF00000u));
          Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), tmp, emit_extract_vector(ctx, src, 1, v1), cond);
@@ -2152,7 +2246,15 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_f2i32: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (instr->src[0].src.ssa->bit_size == 32) {
+      if (instr->src[0].src.ssa->bit_size == 16) {
+         Temp tmp = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src);
+         if (dst.type() == RegType::vgpr) {
+            bld.vop1(aco_opcode::v_cvt_i32_f32, Definition(dst), tmp);
+         } else {
+            bld.pseudo(aco_opcode::p_as_uniform, Definition(dst),
+                       bld.vop1(aco_opcode::v_cvt_i32_f32, bld.def(v1), tmp));
+         }
+      } else if (instr->src[0].src.ssa->bit_size == 32) {
          if (dst.type() == RegType::vgpr)
             bld.vop1(aco_opcode::v_cvt_i32_f32, Definition(dst), src);
          else
@@ -2175,7 +2277,15 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_f2u32: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (instr->src[0].src.ssa->bit_size == 32) {
+      if (instr->src[0].src.ssa->bit_size == 16) {
+         Temp tmp = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src);
+         if (dst.type() == RegType::vgpr) {
+            bld.vop1(aco_opcode::v_cvt_u32_f32, Definition(dst), tmp);
+         } else {
+            bld.pseudo(aco_opcode::p_as_uniform, Definition(dst),
+                       bld.vop1(aco_opcode::v_cvt_u32_f32, bld.def(v1), tmp));
+         }
+      } else if (instr->src[0].src.ssa->bit_size == 32) {
          if (dst.type() == RegType::vgpr)
             bld.vop1(aco_opcode::v_cvt_u32_f32, Definition(dst), src);
          else
@@ -3791,6 +3901,8 @@ void visit_store_output(isel_context *ctx, nir_intrinsic_instr *instr)
    if (ctx->stage == vertex_vs ||
        ctx->stage == tess_eval_vs ||
        ctx->stage == fragment_fs ||
+       ctx->stage == ngg_vertex_gs ||
+       ctx->stage == ngg_tess_eval_gs ||
        ctx->shader->info.stage == MESA_SHADER_GEOMETRY) {
       bool stored_to_temps = store_output_to_temps(ctx, instr);
       if (!stored_to_temps) {
@@ -9305,11 +9417,98 @@ static void end_divergent_if(isel_context *ctx, if_context *ic)
    }
 }
 
+static void begin_uniform_if_then(isel_context *ctx, if_context *ic, Temp cond)
+{
+   assert(cond.regClass() == s1);
+
+   append_logical_end(ctx->block);
+   ctx->block->kind |= block_kind_uniform;
+
+   aco_ptr<Pseudo_branch_instruction> branch;
+   aco_opcode branch_opcode = aco_opcode::p_cbranch_z;
+   branch.reset(create_instruction<Pseudo_branch_instruction>(branch_opcode, Format::PSEUDO_BRANCH, 1, 0));
+   branch->operands[0] = Operand(cond);
+   branch->operands[0].setFixed(scc);
+   ctx->block->instructions.emplace_back(std::move(branch));
+
+   ic->BB_if_idx = ctx->block->index;
+   ic->BB_endif = Block();
+   ic->BB_endif.loop_nest_depth = ctx->cf_info.loop_nest_depth;
+   ic->BB_endif.kind |= ctx->block->kind & block_kind_top_level;
+
+   ctx->cf_info.has_branch = false;
+   ctx->cf_info.parent_loop.has_divergent_branch = false;
+
+   /** emit then block */
+   Block* BB_then = ctx->program->create_and_insert_block();
+   BB_then->loop_nest_depth = ctx->cf_info.loop_nest_depth;
+   add_edge(ic->BB_if_idx, BB_then);
+   append_logical_start(BB_then);
+   ctx->block = BB_then;
+}
+
+static void begin_uniform_if_else(isel_context *ctx, if_context *ic)
+{
+   Block *BB_then = ctx->block;
+
+   ic->uniform_has_then_branch = ctx->cf_info.has_branch;
+   ic->then_branch_divergent = ctx->cf_info.parent_loop.has_divergent_branch;
+
+   if (!ic->uniform_has_then_branch) {
+      append_logical_end(BB_then);
+      /* branch from then block to endif block */
+      aco_ptr<Pseudo_branch_instruction> branch;
+      branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0));
+      BB_then->instructions.emplace_back(std::move(branch));
+      add_linear_edge(BB_then->index, &ic->BB_endif);
+      if (!ic->then_branch_divergent)
+         add_logical_edge(BB_then->index, &ic->BB_endif);
+      BB_then->kind |= block_kind_uniform;
+   }
+
+   ctx->cf_info.has_branch = false;
+   ctx->cf_info.parent_loop.has_divergent_branch = false;
+
+   /** emit else block */
+   Block* BB_else = ctx->program->create_and_insert_block();
+   BB_else->loop_nest_depth = ctx->cf_info.loop_nest_depth;
+   add_edge(ic->BB_if_idx, BB_else);
+   append_logical_start(BB_else);
+   ctx->block = BB_else;
+}
+
+static void end_uniform_if(isel_context *ctx, if_context *ic)
+{
+   Block *BB_else = ctx->block;
+
+   if (!ctx->cf_info.has_branch) {
+      append_logical_end(BB_else);
+      /* branch from then block to endif block */
+      aco_ptr<Pseudo_branch_instruction> branch;
+      branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0));
+      BB_else->instructions.emplace_back(std::move(branch));
+      add_linear_edge(BB_else->index, &ic->BB_endif);
+      if (!ctx->cf_info.parent_loop.has_divergent_branch)
+         add_logical_edge(BB_else->index, &ic->BB_endif);
+      BB_else->kind |= block_kind_uniform;
+   }
+
+   ctx->cf_info.has_branch &= ic->uniform_has_then_branch;
+   ctx->cf_info.parent_loop.has_divergent_branch &= ic->then_branch_divergent;
+
+   /** emit endif merge block */
+   if (!ctx->cf_info.has_branch) {
+      ctx->block = ctx->program->insert_block(std::move(ic->BB_endif));
+      append_logical_start(ctx->block);
+   }
+}
+
 static bool visit_if(isel_context *ctx, nir_if *if_stmt)
 {
    Temp cond = get_ssa_temp(ctx, if_stmt->condition.ssa);
    Builder bld(ctx->program, ctx->block);
    aco_ptr<Pseudo_branch_instruction> branch;
+   if_context ic;
 
    if (!ctx->divergent_vals[if_stmt->condition.ssa->index]) { /* uniform condition */
       /**
@@ -9327,77 +9526,19 @@ static bool visit_if(isel_context *ctx, nir_if *if_stmt)
        *    to the loop exit/entry block. Otherwise, it branches to the next
        *    merge block.
        **/
-      append_logical_end(ctx->block);
-      ctx->block->kind |= block_kind_uniform;
 
-      /* emit branch */
-      assert(cond.regClass() == bld.lm);
       // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction
+      assert(cond.regClass() == ctx->program->lane_mask);
       cond = bool_to_scalar_condition(ctx, cond);
 
-      branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_cbranch_z, Format::PSEUDO_BRANCH, 1, 0));
-      branch->operands[0] = Operand(cond);
-      branch->operands[0].setFixed(scc);
-      ctx->block->instructions.emplace_back(std::move(branch));
-
-      unsigned BB_if_idx = ctx->block->index;
-      Block BB_endif = Block();
-      BB_endif.loop_nest_depth = ctx->cf_info.loop_nest_depth;
-      BB_endif.kind |= ctx->block->kind & block_kind_top_level;
-
-      /** emit then block */
-      Block* BB_then = ctx->program->create_and_insert_block();
-      BB_then->loop_nest_depth = ctx->cf_info.loop_nest_depth;
-      add_edge(BB_if_idx, BB_then);
-      append_logical_start(BB_then);
-      ctx->block = BB_then;
+      begin_uniform_if_then(ctx, &ic, cond);
       visit_cf_list(ctx, &if_stmt->then_list);
-      BB_then = ctx->block;
-      bool then_branch = ctx->cf_info.has_branch;
-      bool then_branch_divergent = ctx->cf_info.parent_loop.has_divergent_branch;
 
-      if (!then_branch) {
-         append_logical_end(BB_then);
-         /* branch from then block to endif block */
-         branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0));
-         BB_then->instructions.emplace_back(std::move(branch));
-         add_linear_edge(BB_then->index, &BB_endif);
-         if (!then_branch_divergent)
-            add_logical_edge(BB_then->index, &BB_endif);
-         BB_then->kind |= block_kind_uniform;
-      }
-
-      ctx->cf_info.has_branch = false;
-      ctx->cf_info.parent_loop.has_divergent_branch = false;
-
-      /** emit else block */
-      Block* BB_else = ctx->program->create_and_insert_block();
-      BB_else->loop_nest_depth = ctx->cf_info.loop_nest_depth;
-      add_edge(BB_if_idx, BB_else);
-      append_logical_start(BB_else);
-      ctx->block = BB_else;
+      begin_uniform_if_else(ctx, &ic);
       visit_cf_list(ctx, &if_stmt->else_list);
-      BB_else = ctx->block;
-
-      if (!ctx->cf_info.has_branch) {
-         append_logical_end(BB_else);
-         /* branch from then block to endif block */
-         branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0));
-         BB_else->instructions.emplace_back(std::move(branch));
-         add_linear_edge(BB_else->index, &BB_endif);
-         if (!ctx->cf_info.parent_loop.has_divergent_branch)
-            add_logical_edge(BB_else->index, &BB_endif);
-         BB_else->kind |= block_kind_uniform;
-      }
 
-      ctx->cf_info.has_branch &= then_branch;
-      ctx->cf_info.parent_loop.has_divergent_branch &= then_branch_divergent;
+      end_uniform_if(ctx, &ic);
 
-      /** emit endif merge block */
-      if (!ctx->cf_info.has_branch) {
-         ctx->block = ctx->program->insert_block(std::move(BB_endif));
-         append_logical_start(ctx->block);
-      }
       return !ctx->cf_info.has_branch;
    } else { /* non-uniform condition */
       /**
@@ -9425,8 +9566,6 @@ static bool visit_if(isel_context *ctx, nir_if *if_stmt)
        * *) Exceptions may be due to break and continue statements within loops
        **/
 
-      if_context ic;
-
       begin_divergent_if_then(ctx, &ic, cond);
       visit_cf_list(ctx, &if_stmt->then_list);
 
@@ -9478,9 +9617,11 @@ static bool export_vs_varying(isel_context *ctx, int slot, bool is_pos, int *nex
 {
    assert(ctx->stage == vertex_vs ||
           ctx->stage == tess_eval_vs ||
-          ctx->stage == gs_copy_vs);
+          ctx->stage == gs_copy_vs ||
+          ctx->stage == ngg_vertex_gs ||
+          ctx->stage == ngg_tess_eval_gs);
 
-   int offset = ctx->stage == tess_eval_vs
+   int offset = (ctx->stage & sw_tes)
                 ? ctx->program->info->tes.outinfo.vs_output_param_offset[slot]
                 : ctx->program->info->vs.outinfo.vs_output_param_offset[slot];
    uint64_t mask = ctx->outputs.mask[slot];
@@ -9548,17 +9689,46 @@ static void export_vs_psiz_layer_viewport(isel_context *ctx, int *next_pos)
    ctx->block->instructions.emplace_back(std::move(exp));
 }
 
+static void create_export_phis(isel_context *ctx)
+{
+   /* Used when exports are needed, but the output temps are defined in a preceding block.
+    * This function will set up phis in order to access the outputs in the next block.
+    */
+
+   assert(ctx->block->instructions.back()->opcode == aco_opcode::p_logical_start);
+   aco_ptr<Instruction> logical_start = aco_ptr<Instruction>(ctx->block->instructions.back().release());
+   ctx->block->instructions.pop_back();
+
+   Builder bld(ctx->program, ctx->block);
+
+   for (unsigned slot = 0; slot <= VARYING_SLOT_VAR31; ++slot) {
+      uint64_t mask = ctx->outputs.mask[slot];
+      for (unsigned i = 0; i < 4; ++i) {
+         if (!(mask & (1 << i)))
+            continue;
+
+         Temp old = ctx->outputs.temps[slot * 4 + i];
+         Temp phi = bld.pseudo(aco_opcode::p_phi, bld.def(v1), old, Operand(v1));
+         ctx->outputs.temps[slot * 4 + i] = phi;
+      }
+   }
+
+   bld.insert(std::move(logical_start));
+}
+
 static void create_vs_exports(isel_context *ctx)
 {
    assert(ctx->stage == vertex_vs ||
           ctx->stage == tess_eval_vs ||
-          ctx->stage == gs_copy_vs);
+          ctx->stage == gs_copy_vs ||
+          ctx->stage == ngg_vertex_gs ||
+          ctx->stage == ngg_tess_eval_gs);
 
-   radv_vs_output_info *outinfo = ctx->stage == tess_eval_vs
+   radv_vs_output_info *outinfo = (ctx->stage & sw_tes)
                                   ? &ctx->program->info->tes.outinfo
                                   : &ctx->program->info->vs.outinfo;
 
-   if (outinfo->export_prim_id) {
+   if (outinfo->export_prim_id && !(ctx->stage & hw_ngg_gs)) {
       ctx->outputs.mask[VARYING_SLOT_PRIMITIVE_ID] |= 0x1;
       ctx->outputs.temps[VARYING_SLOT_PRIMITIVE_ID * 4u] = get_arg(ctx, ctx->args->vs_prim_id);
    }
@@ -9588,7 +9758,8 @@ static void create_vs_exports(isel_context *ctx)
    }
 
    for (unsigned i = 0; i <= VARYING_SLOT_VAR31; ++i) {
-      if (i < VARYING_SLOT_VAR0 && i != VARYING_SLOT_LAYER &&
+      if (i < VARYING_SLOT_VAR0 &&
+          i != VARYING_SLOT_LAYER &&
           i != VARYING_SLOT_PRIMITIVE_ID)
          continue;
 
@@ -10226,6 +10397,239 @@ void cleanup_cfg(Program *program)
    }
 }
 
+Temp merged_wave_info_to_mask(isel_context *ctx, unsigned i)
+{
+   Builder bld(ctx->program, ctx->block);
+
+   /* The s_bfm only cares about s0.u[5:0] so we don't need either s_bfe nor s_and here */
+   Temp count = i == 0
+                ? get_arg(ctx, ctx->args->merged_wave_info)
+                : bld.sop2(aco_opcode::s_lshr_b32, bld.def(s1), bld.def(s1, scc),
+                           get_arg(ctx, ctx->args->merged_wave_info), Operand(i * 8u));
+
+   Temp mask = bld.sop2(aco_opcode::s_bfm_b64, bld.def(s2), count, Operand(0u));
+   Temp cond;
+
+   if (ctx->program->wave_size == 64) {
+      /* Special case for 64 active invocations, because 64 doesn't work with s_bfm */
+      Temp active_64 = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), count, Operand(6u /* log2(64) */));
+      cond = bld.sop2(Builder::s_cselect, bld.def(bld.lm), Operand(-1u), mask, bld.scc(active_64));
+   } else {
+      /* We use s_bfm_b64 (not _b32) which works with 32, but we need to extract the lower half of the register */
+      cond = emit_extract_vector(ctx, mask, 0, bld.lm);
+   }
+
+   return cond;
+}
+
+bool ngg_early_prim_export(isel_context *ctx)
+{
+   /* TODO: Check edge flags, and if they are written, return false. (Needed for OpenGL, not for Vulkan.) */
+   return true;
+}
+
+void ngg_emit_sendmsg_gs_alloc_req(isel_context *ctx)
+{
+   Builder bld(ctx->program, ctx->block);
+
+   /* It is recommended to do the GS_ALLOC_REQ as soon and as quickly as possible, so we set the maximum priority (3). */
+   bld.sopp(aco_opcode::s_setprio, -1u, 0x3u);
+
+   /* Get the id of the current wave within the threadgroup (workgroup) */
+   Builder::Result wave_id_in_tg = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
+                                            get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16)));
+
+   /* Execute the following code only on the first wave (wave id 0),
+    * use the SCC def to tell if the wave id is zero or not.
+    */
+   Temp cond = wave_id_in_tg.def(1).getTemp();
+   if_context ic;
+   begin_uniform_if_then(ctx, &ic, cond);
+   begin_uniform_if_else(ctx, &ic);
+   bld.reset(ctx->block);
+
+   /* Number of vertices output by VS/TES */
+   Temp vtx_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
+                           get_arg(ctx, ctx->args->gs_tg_info), Operand(12u | (9u << 16u)));
+   /* Number of primitives output by VS/TES */
+   Temp prm_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
+                           get_arg(ctx, ctx->args->gs_tg_info), Operand(22u | (9u << 16u)));
+
+   /* Put the number of vertices and primitives into m0 for the GS_ALLOC_REQ */
+   Temp tmp = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), prm_cnt, Operand(12u));
+   tmp = bld.sop2(aco_opcode::s_or_b32, bld.m0(bld.def(s1)), bld.def(s1, scc), tmp, vtx_cnt);
+
+   /* Request the SPI to allocate space for the primitives and vertices that will be exported by the threadgroup. */
+   bld.sopp(aco_opcode::s_sendmsg, bld.m0(tmp), -1, sendmsg_gs_alloc_req);
+
+   /* After the GS_ALLOC_REQ is done, reset priority to default (0). */
+   bld.sopp(aco_opcode::s_setprio, -1u, 0x0u);
+
+   end_uniform_if(ctx, &ic);
+}
+
+Temp ngg_get_prim_exp_arg(isel_context *ctx, unsigned num_vertices, const Temp vtxindex[])
+{
+   Builder bld(ctx->program, ctx->block);
+
+   if (ctx->args->options->key.vs_common_out.as_ngg_passthrough) {
+      return get_arg(ctx, ctx->args->gs_vtx_offset[0]);
+   }
+
+   Temp gs_invocation_id = get_arg(ctx, ctx->args->ac.gs_invocation_id);
+   Temp tmp;
+
+   for (unsigned i = 0; i < num_vertices; ++i) {
+      assert(vtxindex[i].id());
+
+      if (i)
+         tmp = bld.vop3(aco_opcode::v_lshl_add_u32, bld.def(v1), vtxindex[i], Operand(10u * i), tmp);
+      else
+         tmp = vtxindex[i];
+
+      /* The initial edge flag is always false in tess eval shaders. */
+      if (ctx->stage == ngg_vertex_gs) {
+         Temp edgeflag = bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1), gs_invocation_id, Operand(8 + i), Operand(1u));
+         tmp = bld.vop3(aco_opcode::v_lshl_add_u32, bld.def(v1), edgeflag, Operand(10u * i + 9u), tmp);
+      }
+   }
+
+   /* TODO: Set isnull field in case of merged NGG VS+GS. */
+
+   return tmp;
+}
+
+void ngg_emit_prim_export(isel_context *ctx, unsigned num_vertices_per_primitive, const Temp vtxindex[])
+{
+   Builder bld(ctx->program, ctx->block);
+   Temp prim_exp_arg = ngg_get_prim_exp_arg(ctx, num_vertices_per_primitive, vtxindex);
+
+   bld.exp(aco_opcode::exp, prim_exp_arg, Operand(v1), Operand(v1), Operand(v1),
+        1 /* enabled mask */, V_008DFC_SQ_EXP_PRIM /* dest */,
+        false /* compressed */, true/* done */, false /* valid mask */);
+}
+
+void ngg_emit_nogs_gsthreads(isel_context *ctx)
+{
+   /* Emit the things that NGG GS threads need to do, for shaders that don't have SW GS.
+    * These must always come before VS exports.
+    *
+    * It is recommended to do these as early as possible. They can be at the beginning when
+    * there is no SW GS and the shader doesn't write edge flags.
+    */
+
+   if_context ic;
+   Temp is_gs_thread = merged_wave_info_to_mask(ctx, 1);
+   begin_divergent_if_then(ctx, &ic, is_gs_thread);
+
+   Builder bld(ctx->program, ctx->block);
+   constexpr unsigned max_vertices_per_primitive = 3;
+   unsigned num_vertices_per_primitive = max_vertices_per_primitive;
+
+   if (ctx->stage == ngg_vertex_gs) {
+      /* TODO: optimize for points & lines */
+   } else if (ctx->stage == ngg_tess_eval_gs) {
+      if (ctx->shader->info.tess.point_mode)
+         num_vertices_per_primitive = 1;
+      else if (ctx->shader->info.tess.primitive_mode == GL_ISOLINES)
+         num_vertices_per_primitive = 2;
+   } else {
+      unreachable("Unsupported NGG shader stage");
+   }
+
+   Temp vtxindex[max_vertices_per_primitive];
+   vtxindex[0] = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0xffffu),
+                          get_arg(ctx, ctx->args->gs_vtx_offset[0]));
+   vtxindex[1] = num_vertices_per_primitive < 2 ? Temp(0, v1) :
+                 bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1),
+                          get_arg(ctx, ctx->args->gs_vtx_offset[0]), Operand(16u), Operand(16u));
+   vtxindex[2] = num_vertices_per_primitive < 3 ? Temp(0, v1) :
+                 bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0xffffu),
+                          get_arg(ctx, ctx->args->gs_vtx_offset[2]));
+
+   /* Export primitive data to the index buffer. */
+   ngg_emit_prim_export(ctx, num_vertices_per_primitive, vtxindex);
+
+   /* Export primitive ID. */
+   if (ctx->stage == ngg_vertex_gs && ctx->args->options->key.vs_common_out.export_prim_id) {
+      /* Copy Primitive IDs from GS threads to the LDS address corresponding to the ES thread of the provoking vertex. */
+      Temp prim_id = get_arg(ctx, ctx->args->ac.gs_prim_id);
+      Temp provoking_vtx_index = vtxindex[0];
+      Temp addr = bld.v_mul_imm(bld.def(v1), provoking_vtx_index, 4u);
+
+      store_lds(ctx, 4, prim_id, 0x1u, addr, 0u, 4u);
+   }
+
+   begin_divergent_if_else(ctx, &ic);
+   end_divergent_if(ctx, &ic);
+}
+
+void ngg_emit_nogs_output(isel_context *ctx)
+{
+   /* Emits NGG GS output, for stages that don't have SW GS. */
+
+   if_context ic;
+   Builder bld(ctx->program, ctx->block);
+   bool late_prim_export = !ngg_early_prim_export(ctx);
+
+   /* NGG streamout is currently disabled by default. */
+   assert(!ctx->args->shader_info->so.num_outputs);
+
+   if (late_prim_export) {
+      /* VS exports are output to registers in a predecessor block. Emit phis to get them into this block. */
+      create_export_phis(ctx);
+      /* Do what we need to do in the GS threads. */
+      ngg_emit_nogs_gsthreads(ctx);
+
+      /* What comes next should be executed on ES threads. */
+      Temp is_es_thread = merged_wave_info_to_mask(ctx, 0);
+      begin_divergent_if_then(ctx, &ic, is_es_thread);
+      bld.reset(ctx->block);
+   }
+
+   /* Export VS outputs */
+   ctx->block->kind |= block_kind_export_end;
+   create_vs_exports(ctx);
+
+   /* Export primitive ID */
+   if (ctx->args->options->key.vs_common_out.export_prim_id) {
+      Temp prim_id;
+
+      if (ctx->stage == ngg_vertex_gs) {
+         /* Wait for GS threads to store primitive ID in LDS. */
+         bld.barrier(aco_opcode::p_memory_barrier_shared);
+         bld.sopp(aco_opcode::s_barrier);
+
+         /* Calculate LDS address where the GS threads stored the primitive ID. */
+         Temp wave_id_in_tg = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc),
+                                       get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16)));
+         Temp thread_id_in_wave = emit_mbcnt(ctx, bld.def(v1));
+         Temp wave_id_mul = bld.v_mul_imm(bld.def(v1), as_vgpr(ctx, wave_id_in_tg), ctx->program->wave_size);
+         Temp thread_id_in_tg = bld.vadd32(bld.def(v1), Operand(wave_id_mul), Operand(thread_id_in_wave));
+         Temp addr = bld.v_mul_imm(bld.def(v1), thread_id_in_tg, 4u);
+
+         /* Load primitive ID from LDS. */
+         prim_id = load_lds(ctx, 4, bld.tmp(v1), addr, 0u, 4u);
+      } else if (ctx->stage == ngg_tess_eval_gs) {
+         /* TES: Just use the patch ID as the primitive ID. */
+         prim_id = get_arg(ctx, ctx->args->ac.tes_patch_id);
+      } else {
+         unreachable("unsupported NGG shader stage.");
+      }
+
+      ctx->outputs.mask[VARYING_SLOT_PRIMITIVE_ID] |= 0x1;
+      ctx->outputs.temps[VARYING_SLOT_PRIMITIVE_ID * 4u] = prim_id;
+
+      export_vs_varying(ctx, VARYING_SLOT_PRIMITIVE_ID, false, nullptr);
+   }
+
+   if (late_prim_export) {
+      begin_divergent_if_else(ctx, &ic);
+      end_divergent_if(ctx, &ic);
+      bld.reset(ctx->block);
+   }
+}
+
 void select_program(Program *program,
                     unsigned shader_count,
                     struct nir_shader *const *shaders,
@@ -10234,6 +10638,7 @@ void select_program(Program *program,
 {
    isel_context ctx = setup_isel_context(program, shader_count, shaders, config, args, false);
    if_context ic_merged_wave_info;
+   bool ngg_no_gs = ctx.stage == ngg_vertex_gs || ctx.stage == ngg_tess_eval_gs;
 
    for (unsigned i = 0; i < shader_count; i++) {
       nir_shader *nir = shaders[i];
@@ -10252,6 +10657,13 @@ void select_program(Program *program,
          split_arguments(&ctx, startpgm);
       }
 
+      if (ngg_no_gs) {
+         ngg_emit_sendmsg_gs_alloc_req(&ctx);
+
+         if (ngg_early_prim_export(&ctx))
+            ngg_emit_nogs_gsthreads(&ctx);
+      }
+
       /* In a merged VS+TCS HS, the VS implementation can be completely empty. */
       nir_function_impl *func = nir_shader_get_entrypoint(nir);
       bool empty_shader = nir_cf_list_is_empty_block(&func->body) &&
@@ -10260,28 +10672,10 @@ void select_program(Program *program,
                            (nir->info.stage == MESA_SHADER_TESS_EVAL &&
                             ctx.stage == tess_eval_geometry_gs));
 
-      bool check_merged_wave_info = ctx.tcs_in_out_eq ? i == 0 : (shader_count >= 2 && !empty_shader);
+      bool check_merged_wave_info = ctx.tcs_in_out_eq ? i == 0 : ((shader_count >= 2 && !empty_shader) || ngg_no_gs);
       bool endif_merged_wave_info = ctx.tcs_in_out_eq ? i == 1 : check_merged_wave_info;
       if (check_merged_wave_info) {
-         Builder bld(ctx.program, ctx.block);
-
-         /* The s_bfm only cares about s0.u[5:0] so we don't need either s_bfe nor s_and here */
-         Temp count = i == 0 ? get_arg(&ctx, args->merged_wave_info)
-                             : bld.sop2(aco_opcode::s_lshr_b32, bld.def(s1), bld.def(s1, scc),
-                                        get_arg(&ctx, args->merged_wave_info), Operand(i * 8u));
-
-         Temp mask = bld.sop2(aco_opcode::s_bfm_b64, bld.def(s2), count, Operand(0u));
-         Temp cond;
-
-         if (ctx.program->wave_size == 64) {
-            /* Special case for 64 active invocations, because 64 doesn't work with s_bfm */
-            Temp active_64 = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), count, Operand(6u /* log2(64) */));
-            cond = bld.sop2(Builder::s_cselect, bld.def(bld.lm), Operand(-1u), mask, bld.scc(active_64));
-         } else {
-            /* We use s_bfm_b64 (not _b32) which works with 32, but we need to extract the lower half of the register */
-            cond = emit_extract_vector(&ctx, mask, 0, bld.lm);
-         }
-
+         Temp cond = merged_wave_info_to_mask(&ctx, i);
          begin_divergent_if_then(&ctx, &ic_merged_wave_info, cond);
       }
 
@@ -10302,11 +10696,14 @@ void select_program(Program *program,
 
       visit_cf_list(&ctx, &func->body);
 
-      if (ctx.program->info->so.num_outputs && (ctx.stage == vertex_vs || ctx.stage == tess_eval_vs))
+      if (ctx.program->info->so.num_outputs && (ctx.stage & hw_vs))
          emit_streamout(&ctx, 0);
 
-      if (ctx.stage == vertex_vs || ctx.stage == tess_eval_vs) {
+      if (ctx.stage & hw_vs) {
          create_vs_exports(&ctx);
+         ctx.block->kind |= block_kind_export_end;
+      } else if (ngg_no_gs && ngg_early_prim_export(&ctx)) {
+         ngg_emit_nogs_output(&ctx);
       } else if (nir->info.stage == MESA_SHADER_GEOMETRY) {
          Builder bld(ctx.program, ctx.block);
          bld.barrier(aco_opcode::p_memory_barrier_gs_data);
@@ -10315,14 +10712,19 @@ void select_program(Program *program,
          write_tcs_tess_factors(&ctx);
       }
 
-      if (ctx.stage == fragment_fs)
+      if (ctx.stage == fragment_fs) {
          create_fs_exports(&ctx);
+         ctx.block->kind |= block_kind_export_end;
+      }
 
       if (endif_merged_wave_info) {
          begin_divergent_if_else(&ctx, &ic_merged_wave_info);
          end_divergent_if(&ctx, &ic_merged_wave_info);
       }
 
+      if (ngg_no_gs && !ngg_early_prim_export(&ctx))
+         ngg_emit_nogs_output(&ctx);
+
       ralloc_free(ctx.divergent_vals);
 
       if (i == 0 && ctx.stage == vertex_tess_control_hs && ctx.tcs_in_out_eq) {
@@ -10335,7 +10737,7 @@ void select_program(Program *program,
    program->config->float_mode = program->blocks[0].fp_mode.val;
 
    append_logical_end(ctx.block);
-   ctx.block->kind |= block_kind_uniform | block_kind_export_end;
+   ctx.block->kind |= block_kind_uniform;
    Builder bld(ctx.program, ctx.block);
    if (ctx.program->wb_smem_l1_on_end)
       bld.smem(aco_opcode::s_dcache_wb, false);