nir: introduce lowering of bitfield_insert to bfm and a new opcode bitfield_select.
[mesa.git] / src / compiler / nir / nir_lower_alu_to_scalar.c
index 5b3281e0a13d1a26e5a696cef8da457180291685..71389c2f0c38ff8a18ee1671b718502988db5870 100644 (file)
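The lowering named in the commit title rewrites bitfield_insert in terms of bfm plus the new bitfield_select opcode: bitfield_insert(base, insert, offset, bits) becomes bitfield_select(bfm(bits, offset), ishl(insert, offset), base). A minimal plain-C model of those semantics (a sketch for orientation, not mesa code; note NIR's bfm takes bits first, then offset):

#include <stdint.h>

/* bfm: a mask of `bits` ones starting at bit `offset`. */
static uint32_t
bfm(uint32_t bits, uint32_t offset)
{
   return (uint32_t)(((1ull << bits) - 1) << offset);
}

/* bitfield_select: bits of `insert` where `mask` is set, bits of
 * `base` elsewhere. */
static uint32_t
bitfield_select(uint32_t mask, uint32_t insert, uint32_t base)
{
   return (mask & insert) | (~mask & base);
}

/* bitfield_insert expressed through the two ops above. */
static uint32_t
bitfield_insert(uint32_t base, uint32_t insert, uint32_t offset, uint32_t bits)
{
   return bitfield_select(bfm(bits, offset), insert << offset, base);
}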
@@ -56,6 +56,7 @@ lower_reduction(nir_alu_instr *instr, nir_op chan_op, nir_op merge_op,
          nir_alu_src_copy(&chan->src[1], &instr->src[1], chan);
          chan->src[1].swizzle[0] = chan->src[1].swizzle[i];
       }
+      chan->exact = instr->exact;
 
       nir_builder_instr_insert(builder, &chan->instr);
 
@@ -72,8 +73,8 @@ lower_reduction(nir_alu_instr *instr, nir_op chan_op, nir_op merge_op,
    nir_instr_remove(&instr->instr);
 }
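Both hunks above are in lower_reduction(), which splits a reduction such as fdot3 into one chan_op per channel and folds the per-channel results with merge_op; the added chan->exact line makes every new scalar instruction inherit the original instruction's exact flag. The shape it produces, modeled in plain C for fdot3 (illustrative, not mesa code):

/* fdot3 after lower_reduction: chan_op = fmul per channel, merge_op =
 * fadd folding the channels left to right. */
static float
fdot3_scalarized(const float a[3], const float b[3])
{
   float c0 = a[0] * b[0];
   float c1 = a[1] * b[1];
   float c2 = a[2] * b[2];
   return (c0 + c1) + c2;
}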
 
-static void
-lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
+static bool
+lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b, BITSET_WORD *lower_set)
 {
    unsigned num_src = nir_op_infos[instr->op].num_inputs;
    unsigned i, chan;
@@ -82,36 +83,42 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
    assert(instr->dest.write_mask != 0);
 
    b->cursor = nir_before_instr(&instr->instr);
+   b->exact = instr->exact;
+
+   if (lower_set && !BITSET_TEST(lower_set, instr->op))
+      return false;
 
 #define LOWER_REDUCTION(name, chan, merge) \
    case name##2: \
    case name##3: \
    case name##4: \
       lower_reduction(instr, chan, merge, b); \
-      return;
+      return true;
 
    switch (instr->op) {
    case nir_op_vec4:
    case nir_op_vec3:
    case nir_op_vec2:
+   case nir_op_cube_face_coord:
+   case nir_op_cube_face_index:
       /* We don't need to scalarize these ops: the vecN ops are the ones
        * generated to group scalars back up into a value that can be SSAed,
        * and the cube-face ops work across channels, so there is nothing to
        * split per component.
        */
-      return;
+      return false;
 
    case nir_op_pack_half_2x16:
       if (!b->shader->options->lower_pack_half_2x16)
-         return;
+         return false;
+
+      nir_ssa_def *src_vec2 = nir_ssa_for_alu_src(b, instr, 0);
 
       nir_ssa_def *val =
-         nir_pack_half_2x16_split(b, nir_channel(b, instr->src[0].src.ssa,
-                                                 instr->src[0].swizzle[0]),
-                                     nir_channel(b, instr->src[0].src.ssa,
-                                                 instr->src[0].swizzle[1]));
+         nir_pack_half_2x16_split(b, nir_channel(b, src_vec2, 0),
+                                     nir_channel(b, src_vec2, 1));
 
       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
       nir_instr_remove(&instr->instr);
-      return;
+      return true;
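Several cases in this diff switch from indexing instr->src[0].src.ssa through the source's swizzle to calling nir_ssa_for_alu_src(), which materializes the source with its swizzle already applied, so nir_channel(b, v, c) afterwards really does mean channel c. A plain-C model of that resolution step (illustrative only, not the NIR builder API):

/* Resolve a swizzled source into plain component order. */
static void
resolve_swizzle(const float *src, const uint8_t *swizzle,
                unsigned num_components, float *out)
{
   for (unsigned c = 0; c < num_components; c++)
      out[c] = src[swizzle[c]];
}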
 
    case nir_op_unpack_unorm_4x8:
    case nir_op_unpack_snorm_4x8:
@@ -120,28 +127,30 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
       /* There is no scalar version of these ops, unless we were to break it
        * down to bitshifts and math (which is definitely not intended).
        */
-      return;
+      return false;
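For reference, the bitshifts-and-math breakdown that the comment above rules out would look roughly like this for unpack_unorm_4x8 (a sketch of the GLSL-style semantics, not proposed mesa code):

#include <stdint.h>

/* unpack_unorm_4x8: each byte becomes a float in [0, 1]. */
static void
unpack_unorm_4x8(uint32_t p, float dst[4])
{
   for (unsigned i = 0; i < 4; i++)
      dst[i] = (float)((p >> (8 * i)) & 0xff) / 255.0f;
}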
 
    case nir_op_unpack_half_2x16: {
       if (!b->shader->options->lower_unpack_half_2x16)
-         return;
+         return false;
+
+      nir_ssa_def *packed = nir_ssa_for_alu_src(b, instr, 0);
 
       nir_ssa_def *comps[2];
-      comps[0] = nir_unpack_half_2x16_split_x(b, instr->src[0].src.ssa);
-      comps[1] = nir_unpack_half_2x16_split_y(b, instr->src[0].src.ssa);
+      comps[0] = nir_unpack_half_2x16_split_x(b, packed);
+      comps[1] = nir_unpack_half_2x16_split_y(b, packed);
       nir_ssa_def *vec = nir_vec(b, comps, 2);
 
       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(vec));
       nir_instr_remove(&instr->instr);
-      return;
+      return true;
    }
 
    case nir_op_pack_uvec2_to_uint: {
       assert(b->shader->options->lower_pack_snorm_2x16 ||
              b->shader->options->lower_pack_unorm_2x16);
 
-      nir_ssa_def *word =
-         nir_extract_u16(b, instr->src[0].src.ssa, nir_imm_int(b, 0));
+      nir_ssa_def *word = nir_extract_u16(b, nir_ssa_for_alu_src(b, instr, 0),
+                                             nir_imm_int(b, 0));
       nir_ssa_def *val =
          nir_ior(b, nir_ishl(b, nir_channel(b, word, 1), nir_imm_int(b, 16)),
                     nir_channel(b, word, 0));
@@ -155,8 +164,8 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
       assert(b->shader->options->lower_pack_snorm_4x8 ||
              b->shader->options->lower_pack_unorm_4x8);
 
-      nir_ssa_def *byte =
-         nir_extract_u8(b, instr->src[0].src.ssa, nir_imm_int(b, 0));
+      nir_ssa_def *byte = nir_extract_u8(b, nir_ssa_for_alu_src(b, instr, 0),
+                                            nir_imm_int(b, 0));
       nir_ssa_def *val =
          nir_ior(b, nir_ior(b, nir_ishl(b, nir_channel(b, byte, 3), nir_imm_int(b, 24)),
                                nir_ishl(b, nir_channel(b, byte, 2), nir_imm_int(b, 16))),
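Both pack cases follow the same pattern: split the source with extract_u16/extract_u8, then reassemble one 32-bit word with shifts and ors. The values they compute, in plain C (a sketch, not mesa code):

#include <stdint.h>

/* pack_uvec2_to_uint: two 16-bit words into one 32-bit value. */
static uint32_t
pack_uvec2_to_uint(const uint16_t w[2])
{
   return ((uint32_t)w[1] << 16) | w[0];
}

/* pack_uvec4_to_uint: four bytes into one 32-bit value. */
static uint32_t
pack_uvec4_to_uint(const uint8_t b[4])
{
   return ((uint32_t)b[3] << 24) | ((uint32_t)b[2] << 16) |
          ((uint32_t)b[1] << 8)  | b[0];
}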
@@ -169,28 +178,37 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
    }
 
    case nir_op_fdph: {
+      nir_ssa_def *src0_vec = nir_ssa_for_alu_src(b, instr, 0);
+      nir_ssa_def *src1_vec = nir_ssa_for_alu_src(b, instr, 1);
+
       nir_ssa_def *sum[4];
       for (unsigned i = 0; i < 3; i++) {
-         sum[i] = nir_fmul(b, nir_channel(b, instr->src[0].src.ssa,
-                                          instr->src[0].swizzle[i]),
-                              nir_channel(b, instr->src[1].src.ssa,
-                                          instr->src[1].swizzle[i]));
+         sum[i] = nir_fmul(b, nir_channel(b, src0_vec, i),
+                              nir_channel(b, src1_vec, i));
       }
-      sum[3] = nir_channel(b, instr->src[1].src.ssa, instr->src[1].swizzle[3]);
+      sum[3] = nir_channel(b, src1_vec, 3);
 
       nir_ssa_def *val = nir_fadd(b, nir_fadd(b, sum[0], sum[1]),
                                      nir_fadd(b, sum[2], sum[3]));
 
       nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
       nir_instr_remove(&instr->instr);
-      return;
+      return true;
    }
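fdph is the homogeneous dot product, dot(src0.xyz, src1.xyz) + src1.w, and the lowering pairs the additions exactly as shown above. The same computation in plain C (a sketch):

static float
fdph(const float v[3], const float p[4])
{
   float s0 = v[0] * p[0];
   float s1 = v[1] * p[1];
   float s2 = v[2] * p[2];
   /* Same pairing as the lowering: (s0 + s1) + (s2 + p.w). */
   return (s0 + s1) + (s2 + p[3]);
}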
 
+   case nir_op_unpack_64_2x32:
+   case nir_op_unpack_32_2x16:
+      /* These stay vector here; when drivers need them broken up, the
+       * split opcodes come from nir_lower_pack instead.
+       */
+      return false;
+
       LOWER_REDUCTION(nir_op_fdot, nir_op_fmul, nir_op_fadd);
       LOWER_REDUCTION(nir_op_ball_fequal, nir_op_feq, nir_op_iand);
       LOWER_REDUCTION(nir_op_ball_iequal, nir_op_ieq, nir_op_iand);
       LOWER_REDUCTION(nir_op_bany_fnequal, nir_op_fne, nir_op_ior);
       LOWER_REDUCTION(nir_op_bany_inequal, nir_op_ine, nir_op_ior);
+      LOWER_REDUCTION(nir_op_b32all_fequal, nir_op_feq32, nir_op_iand);
+      LOWER_REDUCTION(nir_op_b32all_iequal, nir_op_ieq32, nir_op_iand);
+      LOWER_REDUCTION(nir_op_b32any_fnequal, nir_op_fne32, nir_op_ior);
+      LOWER_REDUCTION(nir_op_b32any_inequal, nir_op_ine32, nir_op_ior);
       LOWER_REDUCTION(nir_op_fall_equal, nir_op_seq, nir_op_fand);
       LOWER_REDUCTION(nir_op_fany_nequal, nir_op_sne, nir_op_for);
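Each LOWER_REDUCTION line expands through the macro defined at the top of the switch, so e.g. the fdot entry is equivalent to writing:

   case nir_op_fdot2:
   case nir_op_fdot3:
   case nir_op_fdot4:
      lower_reduction(instr, nir_op_fmul, nir_op_fadd, b);
      return true;

The four new b32 entries do the same for the 32-bit-boolean variants of the all/any comparison reductions.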
 
@@ -199,12 +217,12 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
    }
 
    if (instr->dest.dest.ssa.num_components == 1)
-      return;
+      return false;
 
    unsigned num_components = instr->dest.dest.ssa.num_components;
-   nir_ssa_def *comps[] = { NULL, NULL, NULL, NULL };
+   nir_ssa_def *comps[NIR_MAX_VEC_COMPONENTS] = { NULL };
 
-   for (chan = 0; chan < 4; chan++) {
+   for (chan = 0; chan < NIR_MAX_VEC_COMPONENTS; chan++) {
       if (!(instr->dest.write_mask & (1 << chan)))
          continue;
 
@@ -218,13 +236,14 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
                               0 : chan);
 
          nir_alu_src_copy(&lower->src[i], &instr->src[i], lower);
-         for (int j = 0; j < 4; j++)
+         for (int j = 0; j < NIR_MAX_VEC_COMPONENTS; j++)
             lower->src[i].swizzle[j] = instr->src[i].swizzle[src_chan];
       }
 
       nir_alu_ssa_dest_init(lower, 1, instr->dest.dest.ssa.bit_size);
       lower->dest.saturate = instr->dest.saturate;
       comps[chan] = &lower->dest.dest.ssa;
+      lower->exact = instr->exact;
 
       nir_builder_instr_insert(b, &lower->instr);
    }
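This loop is the generic path: one scalar copy of the instruction per set write-mask channel, with each source swizzled down to the matching component (sources with a fixed input size always read channel 0, per the src_chan selection above). What it computes, modeled in plain C for a two-source op such as a vec4 fadd (names illustrative):

/* Per-channel scalarization of dst = fadd(src0.swiz0, src1.swiz1). */
static void
scalarized_fadd(const float *src0, const uint8_t swiz0[4],
                const float *src1, const uint8_t swiz1[4],
                unsigned write_mask, float dst[4])
{
   for (unsigned chan = 0; chan < 4; chan++) {
      if (!(write_mask & (1u << chan)))
         continue;
      /* One scalar fadd per live channel. */
      dst[chan] = src0[swiz0[chan]] + src1[swiz1[chan]];
   }
}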
@@ -234,33 +253,42 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(vec));
 
    nir_instr_remove(&instr->instr);
-}
-
-static bool
-lower_alu_to_scalar_block(nir_block *block, void *builder)
-{
-   nir_foreach_instr_safe(block, instr) {
-      if (instr->type == nir_instr_type_alu)
-         lower_alu_instr_scalar(nir_instr_as_alu(instr), builder);
-   }
-
    return true;
 }
 
-static void
-nir_lower_alu_to_scalar_impl(nir_function_impl *impl)
+static bool
+nir_lower_alu_to_scalar_impl(nir_function_impl *impl, BITSET_WORD *lower_set)
 {
    nir_builder builder;
    nir_builder_init(&builder, impl);
+   bool progress = false;
+
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type == nir_instr_type_alu) {
+            progress = lower_alu_instr_scalar(nir_instr_as_alu(instr),
+                                              &builder,
+                                              lower_set) || progress;
+         }
+      }
+   }
+
+   nir_metadata_preserve(impl, nir_metadata_block_index |
+                               nir_metadata_dominance);
 
-   nir_foreach_block(impl, lower_alu_to_scalar_block, &builder);
+   return progress;
 }
 
-void
-nir_lower_alu_to_scalar(nir_shader *shader)
+bool
+nir_lower_alu_to_scalar(nir_shader *shader, BITSET_WORD *lower_set)
 {
-   nir_foreach_function(shader, function) {
+   bool progress = false;
+
+   nir_foreach_function(function, shader) {
       if (function->impl)
-         nir_lower_alu_to_scalar_impl(function->impl);
+         progress = nir_lower_alu_to_scalar_impl(function->impl,
+                                                 lower_set) || progress;
    }
+
+   return progress;
 }
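With the new signature a caller scalarizes everything by passing NULL, or restricts the pass to particular opcodes through a bitset indexed by nir_op, as tested with BITSET_TEST() above. A sketch of driver-side usage, assuming mesa's util/bitset.h macros:

#include "util/bitset.h"

/* Scalarize only the fdot reductions, leaving other vector ALU ops alone. */
static bool
scalarize_dots_only(nir_shader *shader)
{
   BITSET_DECLARE(lower_set, nir_num_opcodes) = { 0 };
   BITSET_SET(lower_set, nir_op_fdot2);
   BITSET_SET(lower_set, nir_op_fdot3);
   BITSET_SET(lower_set, nir_op_fdot4);
   return nir_lower_alu_to_scalar(shader, lower_set);
}

/* Or simply: nir_lower_alu_to_scalar(shader, NULL); */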