aco: Treat all booleans as per-lane.
authorTimur Kristóf <timur.kristof@gmail.com>
Mon, 4 Nov 2019 18:28:08 +0000 (19:28 +0100)
committerTimur Kristóf <timur.kristof@gmail.com>
Thu, 14 Nov 2019 16:27:11 +0000 (17:27 +0100)
Previously, instruction selection had two kinds of booleans:
1. divergent which was per-lane and stored in s2 (VCC size)
2. uniform which was stored in s1
Additionally, uniform booleans were made per-lane when they resulted
from operations which were supported only by the VALU.

To decide which type was used, we relied on the destination size,
which was not reliable due to the per-lane uniform bools, but it
mostly works on wave64.
However, in wave32 mode (where VCC is also s1) this approach
makes it impossible keep track of which boolean is uniform and
which is divergent.

This commit makes all booleans per-lane.
The resulting excess code size will be taken care of by the optimizer.

v2 (by Daniel Schürmann):
- Better names for some functions
- Use s_andn2_b64 with exec for nir_op_inot
- Simplify code due to using s_and_b64 in bool_to_scalar_condition

v3 (by Timur Kristóf):
- Fix several subgroups regressions

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/compiler/aco_lower_bool_phis.cpp

index ab34a06867163701ef14e6a07879a6fd11bf8f5a..a7c3c7034038f2bd4f5bbdcac7d429c968ab43fb 100644 (file)
@@ -130,6 +130,8 @@ Temp emit_wqm(isel_context *ctx, Temp src, Temp dst=Temp(0, s1), bool program_ne
    if (!dst.id())
       dst = bld.tmp(src.regClass());
 
+   assert(src.size() == dst.size());
+
    if (ctx->stage != fragment_fs) {
       if (!dst.id())
          return src;
@@ -331,33 +333,31 @@ void expand_vector(isel_context* ctx, Temp vec_src, Temp dst, unsigned num_compo
    ctx->allocated_vec.emplace(dst.id(), elems);
 }
 
-Temp as_divergent_bool(isel_context *ctx, Temp val, bool vcc_hint)
+Temp bool_to_vector_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s2))
 {
-   if (val.regClass() == s2) {
-      return val;
-   } else {
-      assert(val.regClass() == s1);
-      Builder bld(ctx->program, ctx->block);
-      Definition& def = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2),
-                                 Operand((uint32_t) -1), Operand(0u), bld.scc(val)).def(0);
-      if (vcc_hint)
-         def.setHint(vcc);
-      return def.getTemp();
-   }
+   Builder bld(ctx->program, ctx->block);
+   if (!dst.id())
+      dst = bld.tmp(s2);
+
+   assert(val.regClass() == s1);
+   assert(dst.regClass() == s2);
+
+   return bld.sop2(aco_opcode::s_cselect_b64, bld.hint_vcc(Definition(dst)), Operand((uint32_t) -1), Operand(0u), bld.scc(val));
 }
 
-Temp as_uniform_bool(isel_context *ctx, Temp val)
+Temp bool_to_scalar_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s1))
 {
-   if (val.regClass() == s1) {
-      return val;
-   } else {
-      assert(val.regClass() == s2);
-      Builder bld(ctx->program, ctx->block);
-      /* if we're currently in WQM mode, ensure that the source is also computed in WQM */
-      Temp tmp = bld.tmp(s1);
-      bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2)).def(1).getTemp();
-      return emit_wqm(ctx, tmp);
-   }
+   Builder bld(ctx->program, ctx->block);
+   if (!dst.id())
+      dst = bld.tmp(s1);
+
+   assert(val.regClass() == s2);
+   assert(dst.regClass() == s1);
+
+   /* if we're currently in WQM mode, ensure that the source is also computed in WQM */
+   Temp tmp = bld.tmp(s1);
+   bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(Definition(tmp)), val, Operand(exec, s2));
+   return emit_wqm(ctx, tmp, dst);
 }
 
 Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1)
@@ -526,27 +526,44 @@ void emit_vopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o
          src1 = as_vgpr(ctx, src1);
       }
    }
+
    Builder bld(ctx->program, ctx->block);
-   bld.vopc(op, Definition(dst), src0, src1).def(0).setHint(vcc);
+   bld.vopc(op, bld.hint_vcc(Definition(dst)), src0, src1);
 }
 
-void emit_comparison(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst)
+void emit_sopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode op, Temp dst)
 {
-   if (dst.regClass() == s2) {
-      emit_vopc_instruction(ctx, instr, op, dst);
-      if (!ctx->divergent_vals[instr->dest.dest.ssa.index])
-         emit_split_vector(ctx, dst, 2);
-   } else if (dst.regClass() == s1) {
-      Temp src0 = get_alu_src(ctx, instr->src[0]);
-      Temp src1 = get_alu_src(ctx, instr->src[1]);
-      assert(src0.type() == RegType::sgpr && src1.type() == RegType::sgpr);
+   Temp src0 = get_alu_src(ctx, instr->src[0]);
+   Temp src1 = get_alu_src(ctx, instr->src[1]);
 
-      Builder bld(ctx->program, ctx->block);
-      bld.sopc(op, bld.scc(Definition(dst)), src0, src1);
+   assert(dst.regClass() == s2);
+   assert(src0.type() == RegType::sgpr);
+   assert(src1.type() == RegType::sgpr);
 
-   } else {
-      assert(false);
-   }
+   Builder bld(ctx->program, ctx->block);
+   /* Emit the SALU comparison instruction */
+   Temp cmp = bld.sopc(op, bld.scc(bld.def(s1)), src0, src1);
+   /* Turn the result into a per-lane bool */
+   bool_to_vector_condition(ctx, cmp, dst);
+}
+
+void emit_comparison(isel_context *ctx, nir_alu_instr *instr, Temp dst,
+                     aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::last_opcode, aco_opcode s64_op = aco_opcode::last_opcode)
+{
+   aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : s32_op;
+   aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : v32_op;
+   bool divergent_vals = ctx->divergent_vals[instr->dest.dest.ssa.index];
+   bool use_valu = s_op == aco_opcode::last_opcode ||
+                   divergent_vals ||
+                   ctx->allocated[instr->src[0].src.ssa->index].type() == RegType::vgpr ||
+                   ctx->allocated[instr->src[1].src.ssa->index].type() == RegType::vgpr;
+   aco_opcode op = use_valu ? v_op : s_op;
+   assert(op != aco_opcode::last_opcode);
+
+   if (use_valu)
+      emit_vopc_instruction(ctx, instr, op, dst);
+   else
+      emit_sopc_instruction(ctx, instr, op, dst);
 }
 
 void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32, aco_opcode op64, Temp dst)
@@ -554,16 +571,13 @@ void emit_boolean_logic(isel_context *ctx, nir_alu_instr *instr, aco_opcode op32
    Builder bld(ctx->program, ctx->block);
    Temp src0 = get_alu_src(ctx, instr->src[0]);
    Temp src1 = get_alu_src(ctx, instr->src[1]);
-   if (dst.regClass() == s2) {
-      bld.sop2(op64, Definition(dst), bld.def(s1, scc),
-               as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false));
-   } else {
-      assert(dst.regClass() == s1);
-      bld.sop2(op32, bld.def(s1), bld.scc(Definition(dst)),
-               as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1));
-   }
-}
 
+   assert(dst.regClass() == s2);
+   assert(src0.regClass() == s2);
+   assert(src1.regClass() == s2);
+
+   bld.sop2(op64, Definition(dst), bld.def(s1, scc), src0, src1);
+}
 
 void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
 {
@@ -572,9 +586,9 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
    Temp then = get_alu_src(ctx, instr->src[1]);
    Temp els = get_alu_src(ctx, instr->src[2]);
 
-   if (dst.type() == RegType::vgpr) {
-      cond = as_divergent_bool(ctx, cond, true);
+   assert(cond.regClass() == s2);
 
+   if (dst.type() == RegType::vgpr) {
       aco_ptr<Instruction> bcsel;
       if (dst.size() == 1) {
          then = as_vgpr(ctx, then);
@@ -599,11 +613,17 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
       return;
    }
 
-   if (instr->dest.dest.ssa.bit_size != 1) { /* uniform condition and values in sgpr */
+   if (instr->dest.dest.ssa.bit_size == 1) {
+      assert(dst.regClass() == s2);
+      assert(then.regClass() == s2);
+      assert(els.regClass() == s2);
+   }
+
+   if (!ctx->divergent_vals[instr->src[0].src.ssa->index]) { /* uniform condition and values in sgpr */
       if (dst.regClass() == s1 || dst.regClass() == s2) {
          assert((then.regClass() == s1 || then.regClass() == s2) && els.regClass() == then.regClass());
          aco_opcode op = dst.regClass() == s1 ? aco_opcode::s_cselect_b32 : aco_opcode::s_cselect_b64;
-         bld.sop2(op, Definition(dst), then, els, bld.scc(as_uniform_bool(ctx, cond)));
+         bld.sop2(op, Definition(dst), then, els, bld.scc(bool_to_scalar_condition(ctx, cond)));
       } else {
          fprintf(stderr, "Unimplemented uniform bcsel bit size: ");
          nir_print_instr(&instr->instr, stderr);
@@ -612,34 +632,10 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)
       return;
    }
 
-   /* boolean bcsel */
-   assert(instr->dest.dest.ssa.bit_size == 1);
-
-   if (dst.regClass() == s1)
-      cond = as_uniform_bool(ctx, cond);
-
-   if (cond.regClass() == s1) { /* uniform selection */
-      aco_opcode op;
-      if (dst.regClass() == s2) {
-         op = aco_opcode::s_cselect_b64;
-         then = as_divergent_bool(ctx, then, false);
-         els = as_divergent_bool(ctx, els, false);
-      } else {
-         assert(dst.regClass() == s1);
-         op = aco_opcode::s_cselect_b32;
-         then = as_uniform_bool(ctx, then);
-         els = as_uniform_bool(ctx, els);
-      }
-      bld.sop2(op, Definition(dst), then, els, bld.scc(cond));
-      return;
-   }
-
    /* divergent boolean bcsel
     * this implements bcsel on bools: dst = s0 ? s1 : s2
     * are going to be: dst = (s0 & s1) | (~s0 & s2) */
-   assert (dst.regClass() == s2);
-   then = as_divergent_bool(ctx, then, false);
-   els = as_divergent_bool(ctx, els, false);
+   assert(instr->dest.dest.ssa.bit_size == 1);
 
    if (cond.id() != then.id())
       then = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), cond, then);
@@ -700,16 +696,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_inot: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      /* uniform booleans */
-      if (instr->dest.dest.ssa.bit_size == 1 && dst.regClass() == s1) {
-         if (src.regClass() == s1) {
-            /* in this case, src is either 1 or 0 */
-            bld.sop2(aco_opcode::s_xor_b32, bld.def(s1), bld.scc(Definition(dst)), Operand(1u), src);
-         } else {
-            /* src is either exec_mask or 0 */
-            assert(src.regClass() == s2);
-            bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(Definition(dst)), Operand(0u), src);
-         }
+      if (instr->dest.dest.ssa.bit_size == 1) {
+         assert(src.regClass() == s2);
+         assert(dst.regClass() == s2);
+         bld.sop2(aco_opcode::s_andn2_b64, Definition(dst), bld.def(s1, scc), Operand(exec, s2), src);
       } else if (dst.regClass() == v1) {
          emit_vop1_instruction(ctx, instr, aco_opcode::v_not_b32, dst);
       } else if (dst.type() == RegType::sgpr) {
@@ -1919,12 +1909,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_b2f32: {
       Temp src = get_alu_src(ctx, instr->src[0]);
+      assert(src.regClass() == s2);
+
       if (dst.regClass() == s1) {
-         src = as_uniform_bool(ctx, src);
+         src = bool_to_scalar_condition(ctx, src);
          bld.sop2(aco_opcode::s_mul_i32, Definition(dst), Operand(0x3f800000u), src);
       } else if (dst.regClass() == v1) {
-         bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u),
-                      as_divergent_bool(ctx, src, true));
+         bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(0x3f800000u), src);
       } else {
          unreachable("Wrong destination register class for nir_op_b2f32.");
       }
@@ -1932,13 +1923,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_b2f64: {
       Temp src = get_alu_src(ctx, instr->src[0]);
+      assert(src.regClass() == s2);
+
       if (dst.regClass() == s2) {
-         src = as_uniform_bool(ctx, src);
+         src = bool_to_scalar_condition(ctx, src);
          bld.sop2(aco_opcode::s_cselect_b64, Definition(dst), Operand(0x3f800000u), Operand(0u), bld.scc(src));
       } else if (dst.regClass() == v2) {
          Temp one = bld.vop1(aco_opcode::v_mov_b32, bld.def(v2), Operand(0x3FF00000u));
-         Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one,
-                      as_divergent_bool(ctx, src, true));
+         Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, src);
          bld.pseudo(aco_opcode::p_create_vector, Definition(dst), Operand(0u), upper);
       } else {
          unreachable("Wrong destination register class for nir_op_b2f64.");
@@ -2000,29 +1992,31 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
    }
    case nir_op_b2i32: {
       Temp src = get_alu_src(ctx, instr->src[0]);
+      assert(src.regClass() == s2);
+
       if (dst.regClass() == s1) {
-         if (src.regClass() == s1) {
-            bld.copy(Definition(dst), src);
-         } else {
-            // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ
-            assert(src.regClass() == s2);
-            bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(Definition(dst)), Operand(0u), src);
-         }
-      } else {
-         assert(dst.regClass() == v1 && src.regClass() == s2);
+         // TODO: in a post-RA optimization, we can check if src is in VCC, and directly use VCCNZ
+         bool_to_scalar_condition(ctx, src, dst);
+      } else if (dst.regClass() == v1) {
          bld.vop2_e64(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0u), Operand(1u), src);
+      } else {
+         unreachable("Invalid register class for b2i32");
       }
       break;
    }
    case nir_op_i2b1: {
       Temp src = get_alu_src(ctx, instr->src[0]);
-      if (dst.regClass() == s2) {
+      assert(dst.regClass() == s2);
+
+      if (src.type() == RegType::vgpr) {
          assert(src.regClass() == v1 || src.regClass() == v2);
          bld.vopc(src.size() == 2 ? aco_opcode::v_cmp_lg_u64 : aco_opcode::v_cmp_lg_u32,
                   Definition(dst), Operand(0u), src).def(0).setHint(vcc);
       } else {
-         assert(src.regClass() == s1 && dst.regClass() == s1);
-         bld.sopc(aco_opcode::s_cmp_lg_u32, bld.scc(Definition(dst)), Operand(0u), src);
+         assert(src.regClass() == s1 || src.regClass() == s2);
+         Temp tmp = bld.sopc(src.size() == 2 ? aco_opcode::s_cmp_lg_u64 : aco_opcode::s_cmp_lg_u32,
+                             bld.scc(bld.def(s1)), Operand(0u), src);
+         bool_to_vector_condition(ctx, tmp, dst);
       }
       break;
    }
@@ -2228,119 +2222,49 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       break;
    }
    case nir_op_flt: {
-      if (instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f32, dst);
-      else if (instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_f64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64);
       break;
    }
    case nir_op_fge: {
-      if (instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f32, dst);
-      else if (instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_f64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64);
       break;
    }
    case nir_op_feq: {
-      if (instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f32, dst);
-      else if (instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_f64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64);
       break;
    }
    case nir_op_fne: {
-      if (instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f32, dst);
-      else if (instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_neq_f64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64);
       break;
    }
    case nir_op_ilt: {
-      if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i32, dst);
-      else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_i32, dst);
-      else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_i64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32);
       break;
    }
    case nir_op_ige: {
-      if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i32, dst);
-      else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_i32, dst);
-      else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_i64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32);
       break;
    }
    case nir_op_ieq: {
-      if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) {
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i32, dst);
-      } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) {
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_i32, dst);
-      } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) {
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_eq_i64, dst);
-      } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) {
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_eq_u64, dst);
-      } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) {
-         Temp src0 = get_alu_src(ctx, instr->src[0]);
-         Temp src1 = get_alu_src(ctx, instr->src[1]);
-         bld.sopc(aco_opcode::s_cmp_eq_i32, bld.scc(Definition(dst)),
-                  as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1));
-      } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) {
-         Temp src0 = get_alu_src(ctx, instr->src[0]);
-         Temp src1 = get_alu_src(ctx, instr->src[1]);
-         bld.sop2(aco_opcode::s_xnor_b64, Definition(dst), bld.def(s1, scc),
-                  as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false));
-      } else {
-         fprintf(stderr, "Unimplemented NIR instr bit size: ");
-         nir_print_instr(&instr->instr, stderr);
-         fprintf(stderr, "\n");
-      }
+      if (instr->src[0].src.ssa->bit_size == 1)
+         emit_boolean_logic(ctx, instr, aco_opcode::s_xnor_b32, aco_opcode::s_xnor_b64, dst);
+      else
+         emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, aco_opcode::s_cmp_eq_u64);
       break;
    }
    case nir_op_ine: {
-      if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32) {
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i32, dst);
-      } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64) {
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lg_i64, dst);
-      } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32) {
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_i32, dst);
-      } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 64) {
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_lg_u64, dst);
-      } else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 1) {
-         Temp src0 = get_alu_src(ctx, instr->src[0]);
-         Temp src1 = get_alu_src(ctx, instr->src[1]);
-         bld.sopc(aco_opcode::s_cmp_lg_i32, bld.scc(Definition(dst)),
-                  as_uniform_bool(ctx, src0), as_uniform_bool(ctx, src1));
-      } else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 1) {
-         Temp src0 = get_alu_src(ctx, instr->src[0]);
-         Temp src1 = get_alu_src(ctx, instr->src[1]);
-         bld.sop2(aco_opcode::s_xor_b64, Definition(dst), bld.def(s1, scc),
-                  as_divergent_bool(ctx, src0, false), as_divergent_bool(ctx, src1, false));
-      } else {
-         fprintf(stderr, "Unimplemented NIR instr bit size: ");
-         nir_print_instr(&instr->instr, stderr);
-         fprintf(stderr, "\n");
-      }
+      if (instr->src[0].src.ssa->bit_size == 1)
+         emit_boolean_logic(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::s_xor_b64, dst);
+      else
+         emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, aco_opcode::s_cmp_lg_u64);
       break;
    }
    case nir_op_ult: {
-      if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u32, dst);
-      else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_lt_u32, dst);
-      else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_lt_u64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32);
       break;
    }
    case nir_op_uge: {
-      if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u32, dst);
-      else if (dst.regClass() == s1 && instr->src[0].src.ssa->bit_size == 32)
-         emit_comparison(ctx, instr, aco_opcode::s_cmp_ge_u32, dst);
-      else if (dst.regClass() == s2 && instr->src[0].src.ssa->bit_size == 64)
-         emit_comparison(ctx, instr, aco_opcode::v_cmp_ge_u64, dst);
+      emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32);
       break;
    }
    case nir_op_fddx:
@@ -2387,9 +2311,13 @@ void visit_load_const(isel_context *ctx, nir_load_const_instr *instr)
    assert(instr->def.num_components == 1 && "Vector load_const should be lowered to scalar.");
    assert(dst.type() == RegType::sgpr);
 
-   if (dst.size() == 1)
-   {
-      Builder(ctx->program, ctx->block).copy(Definition(dst), Operand(instr->value[0].u32));
+   Builder bld(ctx->program, ctx->block);
+
+   if (instr->def.bit_size == 1) {
+      assert(dst.regClass() == s2);
+      bld.sop1(aco_opcode::s_mov_b64, Definition(dst), Operand((uint64_t)(instr->value[0].b ? -1 : 0)));
+   } else if (dst.size() == 1) {
+      bld.copy(Definition(dst), Operand(instr->value[0].u32));
    } else {
       assert(dst.size() != 1);
       aco_ptr<Pseudo_instruction> vec{create_instruction<Pseudo_instruction>(aco_opcode::p_create_vector, Format::PSEUDO, dst.size(), 1)};
@@ -3577,7 +3505,8 @@ void visit_discard_if(isel_context *ctx, nir_intrinsic_instr *instr)
 
    // TODO: optimize uniform conditions
    Builder bld(ctx->program, ctx->block);
-   Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false);
+   Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
+   assert(src.regClass() == s2);
    src = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2));
    bld.pseudo(aco_opcode::p_discard_if, src);
    ctx->block->kind |= block_kind_uses_discard_if;
@@ -5114,15 +5043,17 @@ Temp emit_boolean_reduce(isel_context *ctx, nir_op op, unsigned cluster_size, Te
    } else if (op == nir_op_iand && cluster_size == 64) {
       //subgroupAnd(val) -> (exec & ~val) == 0
       Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp();
-      return bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), tmp, Operand(0u));
+      return bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp));
    } else if (op == nir_op_ior && cluster_size == 64) {
       //subgroupOr(val) -> (val & exec) != 0
-      return bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp();
+      Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)).def(1).getTemp();
+      return bool_to_vector_condition(ctx, tmp);
    } else if (op == nir_op_ixor && cluster_size == 64) {
       //subgroupXor(val) -> s_bcnt1_i32_b64(val & exec) & 1
       Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2));
       tmp = bld.sop1(aco_opcode::s_bcnt1_i32_b64, bld.def(s2), bld.def(s1, scc), tmp);
-      return bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp();
+      tmp = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), tmp, Operand(1u)).def(1).getTemp();
+      return bool_to_vector_condition(ctx, tmp);
    } else {
       //subgroupClustered{And,Or,Xor}(val, n) ->
       //lane_id = v_mbcnt_hi_u32_b32(-1, v_mbcnt_lo_u32_b32(-1, 0))
@@ -5221,8 +5152,6 @@ void emit_uniform_subgroup(isel_context *ctx, nir_intrinsic_instr *instr, Temp s
    Definition dst(get_ssa_temp(ctx, &instr->dest.ssa));
    if (src.regClass().type() == RegType::vgpr) {
       bld.pseudo(aco_opcode::p_as_uniform, dst, src);
-   } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
-      bld.sopc(aco_opcode::s_cmp_lg_u64, bld.scc(dst), Operand(0u), Operand(src));
    } else if (src.regClass() == s1) {
       bld.sop1(aco_opcode::s_mov_b32, dst, src);
    } else if (src.regClass() == s2) {
@@ -5541,10 +5470,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
    case nir_intrinsic_ballot: {
       Definition tmp = bld.def(s2);
       Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
-      if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s2) {
+      if (instr->src[0].ssa->bit_size == 1) {
+         assert(src.regClass() == s2);
          bld.sop2(aco_opcode::s_and_b64, tmp, bld.def(s1, scc), Operand(exec, s2), src);
-      } else if (instr->src[0].ssa->bit_size == 1 && src.regClass() == s1) {
-         bld.sop2(aco_opcode::s_cselect_b64, tmp, Operand(exec, s2), Operand(0u), bld.scc(src));
       } else if (instr->src[0].ssa->bit_size == 32 && src.regClass() == v1) {
          bld.vopc(aco_opcode::v_cmp_lg_u32, tmp, Operand(0u), src);
       } else if (instr->src[0].ssa->bit_size == 64 && src.regClass() == v2) {
@@ -5576,9 +5504,12 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
             hi = emit_wqm(ctx, emit_bpermute(ctx, bld, tid, hi));
             bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
             emit_split_vector(ctx, dst, 2);
-         } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2 && tid.regClass() == s1) {
-            emit_wqm(ctx, bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid), dst);
-         } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
+         } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == s1) {
+            assert(src.regClass() == s2);
+            Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src, tid);
+            bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst);
+         } else if (instr->dest.ssa.bit_size == 1 && tid.regClass() == v1) {
+            assert(src.regClass() == s2);
             Temp tmp = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), tid, src);
             tmp = emit_extract_vector(ctx, tmp, 0, v1);
             tmp = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(1u), tmp);
@@ -5614,11 +5545,11 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
          hi = emit_wqm(ctx, bld.vop1(aco_opcode::v_readfirstlane_b32, bld.def(s1), hi));
          bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, hi);
          emit_split_vector(ctx, dst, 2);
-      } else if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
-         emit_wqm(ctx,
-                  bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src,
-                           bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2))),
-                  dst);
+      } else if (instr->dest.ssa.bit_size == 1) {
+         assert(src.regClass() == s2);
+         Temp tmp = bld.sopc(aco_opcode::s_bitcmp1_b64, bld.def(s1, scc), src,
+                             bld.sop1(aco_opcode::s_ff1_i32_b64, bld.def(s1), Operand(exec, s2)));
+         bool_to_vector_condition(ctx, emit_wqm(ctx, tmp), dst);
       } else if (src.regClass() == s1) {
          bld.sop1(aco_opcode::s_mov_b32, Definition(dst), src);
       } else if (src.regClass() == s2) {
@@ -5631,27 +5562,25 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       break;
    }
    case nir_intrinsic_vote_all: {
-      Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false);
+      Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
       Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
       assert(src.regClass() == s2);
-      assert(dst.regClass() == s1);
+      assert(dst.regClass() == s2);
 
-      Definition tmp = bld.def(s1);
-      bld.sopc(aco_opcode::s_cmp_eq_u64, bld.scc(tmp),
-               bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2)),
-               Operand(exec, s2));
-      emit_wqm(ctx, tmp.getTemp(), dst);
+      Temp tmp = bld.sop2(aco_opcode::s_andn2_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp();
+      Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(0u), Operand(-1u), bld.scc(tmp));
+      emit_wqm(ctx, val, dst);
       break;
    }
    case nir_intrinsic_vote_any: {
-      Temp src = as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false);
+      Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
       Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
       assert(src.regClass() == s2);
-      assert(dst.regClass() == s1);
+      assert(dst.regClass() == s2);
 
-      Definition tmp = bld.def(s1);
-      bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.scc(tmp), src, Operand(exec, s2));
-      emit_wqm(ctx, tmp.getTemp(), dst);
+      Temp tmp = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), Operand(exec, s2), src).def(1).getTemp();
+      Temp val = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), Operand(-1u), Operand(0u), bld.scc(tmp));
+      emit_wqm(ctx, val, dst);
       break;
    }
    case nir_intrinsic_reduce:
@@ -5752,7 +5681,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       } else {
          Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
          unsigned lane = nir_src_as_const_value(instr->src[1])->u32;
-         if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
+         if (instr->dest.ssa.bit_size == 1) {
+            assert(src.regClass() == s2);
             uint32_t half_mask = 0x11111111u << lane;
             Temp mask_tmp = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), Operand(half_mask), Operand(half_mask));
             Temp tmp = bld.tmp(s2);
@@ -5809,7 +5739,8 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       }
 
       Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
-      if (instr->dest.ssa.bit_size == 1 && src.regClass() == s2) {
+      if (instr->dest.ssa.bit_size == 1) {
+         assert(src.regClass() == s2);
          src = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), Operand((uint32_t)-1), src);
          src = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl);
          Temp tmp = bld.vopc(aco_opcode::v_cmp_lg_u32, bld.def(s2), Operand(0u), src);
@@ -5912,9 +5843,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
       ctx->program->needs_exact = true;
       break;
    case nir_intrinsic_demote_if: {
-      Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc),
-                           as_divergent_bool(ctx, get_ssa_temp(ctx, instr->src[0].ssa), false),
-                           Operand(exec, s2));
+      Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
+      assert(src.regClass() == s2);
+      Temp cond = bld.sop2(aco_opcode::s_and_b64, bld.def(s2), bld.def(s1, scc), src, Operand(exec, s2));
       bld.pseudo(aco_opcode::p_demote_to_helper, cond);
       ctx->block->kind |= block_kind_uses_demote;
       ctx->program->needs_exact = true;
@@ -6520,7 +6451,9 @@ void visit_tex(isel_context *ctx, nir_tex_instr *instr)
                             Operand((uint32_t)V_008F14_IMG_NUM_FORMAT_SINT),
                             bld.scc(compare_cube_wa));
          }
-         tg4_compare_cube_wa64 = as_divergent_bool(ctx, compare_cube_wa, true);
+         tg4_compare_cube_wa64 = bld.tmp(s2);
+         bool_to_vector_condition(ctx, compare_cube_wa, tg4_compare_cube_wa64);
+
          nfmt = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), nfmt, Operand(26u));
 
          desc[1] = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), desc[1],
@@ -6770,6 +6703,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
    aco_ptr<Pseudo_instruction> phi;
    unsigned num_src = exec_list_length(&instr->srcs);
    Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
+   assert(instr->dest.ssa.bit_size != 1 || dst.regClass() == s2);
 
    aco_opcode opcode = !dst.is_linear() || ctx->divergent_vals[instr->dest.ssa.index] ? aco_opcode::p_phi : aco_opcode::p_linear_phi;
 
@@ -6797,7 +6731,7 @@ void visit_phi(isel_context *ctx, nir_phi_instr *instr)
    }
 
    /* try to scalarize vector phis */
-   if (dst.size() > 1) {
+   if (instr->dest.ssa.bit_size != 1 && dst.size() > 1) {
       // TODO: scalarize linear phis on divergent ifs
       bool can_scalarize = (opcode == aco_opcode::p_phi || !(ctx->block->kind & block_kind_merge));
       std::array<Temp, 4> new_vec;
@@ -7265,10 +7199,10 @@ static void visit_if(isel_context *ctx, nir_if *if_stmt)
       ctx->block->kind |= block_kind_uniform;
 
       /* emit branch */
-      if (cond.regClass() == s2) {
-         // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction
-         cond = as_uniform_bool(ctx, cond);
-      }
+      assert(cond.regClass() == s2);
+      // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction
+      cond = bool_to_scalar_condition(ctx, cond);
+
       branch.reset(create_instruction<Pseudo_branch_instruction>(aco_opcode::p_cbranch_z, Format::PSEUDO_BRANCH, 1, 0));
       branch->operands[0] = Operand(cond);
       branch->operands[0].setFixed(scc);
index 2c349635799dc8e91bd658cc3cd7361a1f07af74..cdc8103497bbb655ef3ad710154080dfc50e7bd9 100644 (file)
@@ -244,25 +244,14 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_op_fge:
                   case nir_op_feq:
                   case nir_op_fne:
-                     size = 2;
-                     break;
                   case nir_op_ilt:
                   case nir_op_ige:
                   case nir_op_ult:
                   case nir_op_uge:
-                     size = alu_instr->src[0].src.ssa->bit_size == 64 ? 2 : 1;
-                     /* fallthrough */
                   case nir_op_ieq:
                   case nir_op_ine:
                   case nir_op_i2b1:
-                     if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) {
-                        size = 2;
-                     } else {
-                        for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
-                           if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr)
-                              size = 2;
-                        }
-                     }
+                     size = 2;
                      break;
                   case nir_op_f2i64:
                   case nir_op_f2u64:
@@ -274,13 +263,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
                      break;
                   case nir_op_bcsel:
                      if (alu_instr->dest.dest.ssa.bit_size == 1) {
-                        if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index])
-                           size = 2;
-                        else if (allocated[alu_instr->src[1].src.ssa->index].regClass() == s2 &&
-                                 allocated[alu_instr->src[2].src.ssa->index].regClass() == s2)
-                           size = 2;
-                        else
-                           size = 1;
+                        size = 2;
                      } else {
                         if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) {
                            type = RegType::vgpr;
@@ -298,32 +281,14 @@ void init_context(isel_context *ctx, nir_shader *shader)
                      break;
                   case nir_op_mov:
                      if (alu_instr->dest.dest.ssa.bit_size == 1) {
-                        size = allocated[alu_instr->src[0].src.ssa->index].size();
+                        size = 2;
                      } else {
                         type = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? RegType::vgpr : RegType::sgpr;
                      }
                      break;
-                  case nir_op_inot:
-                  case nir_op_ixor:
-                     if (alu_instr->dest.dest.ssa.bit_size == 1) {
-                        size = ctx->divergent_vals[alu_instr->dest.dest.ssa.index] ? 2 : 1;
-                        break;
-                     } else {
-                        /* fallthrough */
-                     }
                   default:
                      if (alu_instr->dest.dest.ssa.bit_size == 1) {
-                        if (ctx->divergent_vals[alu_instr->dest.dest.ssa.index]) {
-                           size = 2;
-                        } else {
-                           size = 2;
-                           for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
-                              if (allocated[alu_instr->src[i].src.ssa->index].regClass() == s1) {
-                                 size = 1;
-                                 break;
-                              }
-                           }
-                        }
+                        size = 2;
                      } else {
                         for (unsigned i = 0; i < nir_op_infos[alu_instr->op].num_inputs; i++) {
                            if (allocated[alu_instr->src[i].src.ssa->index].type() == RegType::vgpr)
@@ -339,6 +304,8 @@ void init_context(isel_context *ctx, nir_shader *shader)
                unsigned size = nir_instr_as_load_const(instr)->def.num_components;
                if (nir_instr_as_load_const(instr)->def.bit_size == 64)
                   size *= 2;
+               else if (nir_instr_as_load_const(instr)->def.bit_size == 1)
+                  size *= 2;
                allocated[nir_instr_as_load_const(instr)->def.index] = Temp(0, RegClass(RegType::sgpr, size));
                break;
             }
@@ -365,6 +332,8 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_intrinsic_read_invocation:
                   case nir_intrinsic_first_invocation:
                      type = RegType::sgpr;
+                     if (intrinsic->dest.ssa.bit_size == 1)
+                        size = 2;
                      break;
                   case nir_intrinsic_ballot:
                      type = RegType::sgpr;
@@ -433,11 +402,11 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_intrinsic_masked_swizzle_amd:
                   case nir_intrinsic_inclusive_scan:
                   case nir_intrinsic_exclusive_scan:
-                     if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) {
+                     if (intrinsic->dest.ssa.bit_size == 1) {
+                        size = 2;
                         type = RegType::sgpr;
-                     } else if (intrinsic->src[0].ssa->bit_size == 1) {
+                     } else if (!ctx->divergent_vals[intrinsic->dest.ssa.index]) {
                         type = RegType::sgpr;
-                        size = 2;
                      } else {
                         type = RegType::vgpr;
                      }
@@ -452,12 +421,12 @@ void init_context(isel_context *ctx, nir_shader *shader)
                      size = 2;
                      break;
                   case nir_intrinsic_reduce:
-                     if (nir_intrinsic_cluster_size(intrinsic) == 0 ||
-                         !ctx->divergent_vals[intrinsic->dest.ssa.index]) {
+                     if (intrinsic->dest.ssa.bit_size == 1) {
+                        size = 2;
                         type = RegType::sgpr;
-                     } else if (intrinsic->src[0].ssa->bit_size == 1) {
+                     } else if (nir_intrinsic_cluster_size(intrinsic) == 0 ||
+                         !ctx->divergent_vals[intrinsic->dest.ssa.index]) {
                         type = RegType::sgpr;
-                        size = 2;
                      } else {
                         type = RegType::vgpr;
                      }
@@ -554,7 +523,7 @@ void init_context(isel_context *ctx, nir_shader *shader)
                if (phi->dest.ssa.bit_size == 1) {
                   assert(size == 1 && "multiple components not yet supported on boolean phis.");
                   type = RegType::sgpr;
-                  size *= ctx->divergent_vals[phi->dest.ssa.index] ? 2 : 1;
+                  size *= 2;
                   allocated[phi->dest.ssa.index] = Temp(0, RegClass(type, size));
                   break;
                }
index ac4663a2ce195f9e161e22428f05826247b255c8..9e5374fe6a025a93cf587b00f6903e9d622166f4 100644 (file)
@@ -150,13 +150,6 @@ void lower_divergent_bool_phi(Program *program, Block *block, aco_ptr<Instructio
 
       assert(phi->operands[i].isTemp());
       Temp phi_src = phi->operands[i].getTemp();
-      if (phi_src.regClass() == s1) {
-         Temp new_phi_src = bld.tmp(s2);
-         insert_before_logical_end(pred,
-            bld.sop2(aco_opcode::s_cselect_b64, Definition(new_phi_src),
-                     Operand((uint32_t)-1), Operand(0u), bld.scc(phi_src)).get_ptr());
-         phi_src = new_phi_src;
-      }
       assert(phi_src.regClass() == s2);
 
       Operand cur = get_ssa(program, pred->index, &state);
@@ -218,6 +211,7 @@ void lower_bool_phis(Program* program)
    for (Block& block : program->blocks) {
       for (aco_ptr<Instruction>& phi : block.instructions) {
          if (phi->opcode == aco_opcode::p_phi) {
+            assert(phi->definitions[0].regClass() != s1);
             if (phi->definitions[0].regClass() == s2)
                lower_divergent_bool_phi(program, &block, phi);
          } else if (phi->opcode == aco_opcode::p_linear_phi) {