nir: add d2i, d2u, d2b opcodes
[mesa.git] / src / compiler / nir / nir_lower_alu_to_scalar.c
index 0a27e66cf0f596b24714172b60799eaf6e59895e..e8ba640fe0baabc54cbd70d346a1230b8b15a0d4 100644 (file)
  */
 
 static void
-nir_alu_ssa_dest_init(nir_alu_instr *instr, unsigned num_components)
+nir_alu_ssa_dest_init(nir_alu_instr *instr, unsigned num_components,
+                      unsigned bit_size)
 {
-   nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components, NULL);
+   nir_ssa_dest_init(&instr->instr, &instr->dest.dest, num_components,
+                     bit_size, NULL);
    instr->dest.write_mask = (1 << num_components) - 1;
 }
 
@@ -46,7 +48,7 @@ lower_reduction(nir_alu_instr *instr, nir_op chan_op, nir_op merge_op,
    nir_ssa_def *last = NULL;
    for (unsigned i = 0; i < num_components; i++) {
       nir_alu_instr *chan = nir_alu_instr_create(builder->shader, chan_op);
-      nir_alu_ssa_dest_init(chan, 1);
+      nir_alu_ssa_dest_init(chan, 1, instr->dest.dest.ssa.bit_size);
       nir_alu_src_copy(&chan->src[0], &instr->src[0], chan);
       chan->src[0].swizzle[0] = chan->src[0].swizzle[i];
       if (nir_op_infos[chan_op].num_inputs > 1) {
@@ -80,6 +82,7 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
    assert(instr->dest.write_mask != 0);
 
    b->cursor = nir_before_instr(&instr->instr);
+   b->exact = instr->exact;
 
 #define LOWER_REDUCTION(name, chan, merge) \
    case name##2: \
@@ -97,6 +100,20 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
        */
       return;
 
+   case nir_op_pack_half_2x16:
+      if (!b->shader->options->lower_pack_half_2x16)
+         return;
+
+      nir_ssa_def *val =
+         nir_pack_half_2x16_split(b, nir_channel(b, instr->src[0].src.ssa,
+                                                 instr->src[0].swizzle[0]),
+                                     nir_channel(b, instr->src[0].src.ssa,
+                                                 instr->src[0].swizzle[1]));
+
+      nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
+      nir_instr_remove(&instr->instr);
+      return;
+
    case nir_op_unpack_unorm_4x8:
    case nir_op_unpack_snorm_4x8:
    case nir_op_unpack_unorm_2x16:
@@ -106,11 +123,51 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
        */
       return;
 
-   case nir_op_unpack_half_2x16:
-      /* We could split this into unpack_half_2x16_split_[xy], but should
-       * we?
-       */
+   case nir_op_unpack_half_2x16: {
+      if (!b->shader->options->lower_unpack_half_2x16)
+         return;
+
+      nir_ssa_def *comps[2];
+      comps[0] = nir_unpack_half_2x16_split_x(b, instr->src[0].src.ssa);
+      comps[1] = nir_unpack_half_2x16_split_y(b, instr->src[0].src.ssa);
+      nir_ssa_def *vec = nir_vec(b, comps, 2);
+
+      nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(vec));
+      nir_instr_remove(&instr->instr);
       return;
+   }
+
+   case nir_op_pack_uvec2_to_uint: {
+      assert(b->shader->options->lower_pack_snorm_2x16 ||
+             b->shader->options->lower_pack_unorm_2x16);
+
+      nir_ssa_def *word =
+         nir_extract_u16(b, instr->src[0].src.ssa, nir_imm_int(b, 0));
+      nir_ssa_def *val =
+         nir_ior(b, nir_ishl(b, nir_channel(b, word, 1), nir_imm_int(b, 16)),
+                                nir_channel(b, word, 0));
+
+      nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
+      nir_instr_remove(&instr->instr);
+      break;
+   }
+
+   case nir_op_pack_uvec4_to_uint: {
+      assert(b->shader->options->lower_pack_snorm_4x8 ||
+             b->shader->options->lower_pack_unorm_4x8);
+
+      nir_ssa_def *byte =
+         nir_extract_u8(b, instr->src[0].src.ssa, nir_imm_int(b, 0));
+      nir_ssa_def *val =
+         nir_ior(b, nir_ior(b, nir_ishl(b, nir_channel(b, byte, 3), nir_imm_int(b, 24)),
+                               nir_ishl(b, nir_channel(b, byte, 2), nir_imm_int(b, 16))),
+                    nir_ior(b, nir_ishl(b, nir_channel(b, byte, 1), nir_imm_int(b, 8)),
+                               nir_channel(b, byte, 0)));
+
+      nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(val));
+      nir_instr_remove(&instr->instr);
+      break;
+   }
 
    case nir_op_fdph: {
       nir_ssa_def *sum[4];
@@ -166,7 +223,7 @@ lower_alu_instr_scalar(nir_alu_instr *instr, nir_builder *b)
             lower->src[i].swizzle[j] = instr->src[i].swizzle[src_chan];
       }
 
-      nir_alu_ssa_dest_init(lower, 1);
+      nir_alu_ssa_dest_init(lower, 1, instr->dest.dest.ssa.bit_size);
       lower->dest.saturate = instr->dest.saturate;
       comps[chan] = &lower->dest.dest.ssa;