nir,intel: Add support for lowering 64-bit nir_opt_extract_*
[mesa.git] / src / compiler / nir / nir_lower_int64.c
index b3b78c6649a14de1c5a35cbdafd7d99c0cd9f303..84ec2a77f1e36aec680bbedbeeb70b80751b380f 100644 (file)
@@ -629,6 +629,34 @@ lower_irem64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
    return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
 }
 
+static nir_ssa_def *
+lower_extract(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *c)
+{
+   assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
+          op == nir_op_extract_u16 || op == nir_op_extract_i16);
+
+   const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
+   const int chunk_bits =
+      (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
+   const int num_chunks_in_32 = 32 / chunk_bits;
+
+   nir_ssa_def *extract32;
+   if (chunk < num_chunks_in_32) {
+      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
+                                   nir_imm_int(b, chunk),
+                                   NULL, NULL);
+   } else {
+      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
+                                   nir_imm_int(b, chunk - num_chunks_in_32),
+                                   NULL, NULL);
+   }
+
+   if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
+      return lower_i2i64(b, extract32);
+   else
+      return lower_u2u64(b, extract32);
+}
+
 nir_lower_int64_options
 nir_lower_int64_op_to_options_mask(nir_op opcode)
 {
@@ -685,6 +713,11 @@ nir_lower_int64_op_to_options_mask(nir_op opcode)
    case nir_op_ishr:
    case nir_op_ushr:
       return nir_lower_shift64;
+   case nir_op_extract_u8:
+   case nir_op_extract_i8:
+   case nir_op_extract_u16:
+   case nir_op_extract_i16:
+      return nir_lower_extract64;
    default:
       return 0;
    }
@@ -779,6 +812,11 @@ lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
       return lower_ishr64(b, src[0], src[1]);
    case nir_op_ushr:
       return lower_ushr64(b, src[0], src[1]);
+   case nir_op_extract_u8:
+   case nir_op_extract_i8:
+   case nir_op_extract_u16:
+   case nir_op_extract_i16:
+      return lower_extract(b, alu->op, src[0], src[1]);
    default:
       unreachable("Invalid ALU opcode to lower");
    }