nir: Add bit_count to lower_int64 pass
[mesa.git] / src / compiler / nir / nir_lower_int64.c
index 9d268a843633283864f2e7019915eb96e9c0c59f..07b307ea461f91d6636a5f0be5754f0521f7c581 100644 (file)
@@ -78,7 +78,7 @@ static nir_ssa_def *
 lower_i2i64(nir_builder *b, nir_ssa_def *x)
 {
    nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
-   return nir_pack_64_2x32_split(b, x32, nir_ishr(b, x32, nir_imm_int(b, 31)));
+   return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
 }
 
 static nir_ssa_def *
@@ -435,7 +435,7 @@ lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
    x32[0] = nir_unpack_64_2x32_split_x(b, x);
    x32[1] = nir_unpack_64_2x32_split_y(b, x);
    if (sign_extend) {
-      x32[2] = x32[3] = nir_ishr(b, x32[1], nir_imm_int(b, 31));
+      x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
    } else {
       x32[2] = x32[3] = nir_imm_int(b, 0);
    }
@@ -443,7 +443,7 @@ lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
    y32[0] = nir_unpack_64_2x32_split_x(b, y);
    y32[1] = nir_unpack_64_2x32_split_y(b, y);
    if (sign_extend) {
-      y32[2] = y32[3] = nir_ishr(b, y32[1], nir_imm_int(b, 31));
+      y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
    } else {
       y32[2] = y32[3] = nir_imm_int(b, 0);
    }
@@ -476,7 +476,7 @@ lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
          if (carry)
             tmp = nir_iadd(b, tmp, carry);
          res[i + j] = nir_u2u32(b, tmp);
-         carry = nir_ushr(b, tmp, nir_imm_int(b, 32));
+         carry = nir_ushr_imm(b, tmp, 32);
       }
       res[i + 4] = nir_u2u32(b, carry);
    }
@@ -491,7 +491,7 @@ lower_isign64(nir_builder *b, nir_ssa_def *x)
    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
 
    nir_ssa_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
-   nir_ssa_def *res_hi = nir_ishr(b, x_hi, nir_imm_int(b, 31));
+   nir_ssa_def *res_hi = nir_ishr_imm(b, x_hi, 31);
    nir_ssa_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));
 
    return nir_pack_64_2x32_split(b, res_lo, res_hi);
@@ -785,6 +785,16 @@ lower_f2(nir_builder *b, nir_ssa_def *x, bool dst_is_signed)
    return res;
 }
 
+static nir_ssa_def *
+lower_bit_count64(nir_builder *b, nir_ssa_def *x)
+{
+   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
+   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
+   nir_ssa_def *lo_count = nir_bit_count(b, x_lo);
+   nir_ssa_def *hi_count = nir_bit_count(b, x_hi);
+   return nir_iadd(b, lo_count, hi_count);
+}
+
 nir_lower_int64_options
 nir_lower_int64_op_to_options_mask(nir_op opcode)
 {
@@ -838,12 +848,6 @@ nir_lower_int64_op_to_options_mask(nir_op opcode)
    case nir_op_imax:
    case nir_op_umin:
    case nir_op_umax:
-   case nir_op_imin3:
-   case nir_op_imax3:
-   case nir_op_umin3:
-   case nir_op_umax3:
-   case nir_op_imed3:
-   case nir_op_umed3:
       return nir_lower_minmax64;
    case nir_op_iabs:
       return nir_lower_iabs64;
@@ -865,6 +869,8 @@ nir_lower_int64_op_to_options_mask(nir_op opcode)
       return nir_lower_extract64;
    case nir_op_ufind_msb:
       return nir_lower_ufind_msb64;
+   case nir_op_bit_count:
+      return nir_lower_bit_count64;
    default:
       return 0;
    }
@@ -944,18 +950,6 @@ lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state)
       return lower_umin64(b, src[0], src[1]);
    case nir_op_umax:
       return lower_umax64(b, src[0], src[1]);
-   case nir_op_imin3:
-      return lower_imin64(b, src[0], lower_imin64(b, src[1], src[2]));
-   case nir_op_imax3:
-      return lower_imax64(b, src[0], lower_imax64(b, src[1], src[2]));
-   case nir_op_umin3:
-      return lower_umin64(b, src[0], lower_umin64(b, src[1], src[2]));
-   case nir_op_umax3:
-      return lower_umax64(b, src[0], lower_umax64(b, src[1], src[2]));
-   case nir_op_imed3:
-      return lower_imax64(b, lower_imin64(b, lower_imax64(b, src[0], src[1]), src[2]), lower_imin64(b, src[0], src[1]));
-   case nir_op_umed3:
-      return lower_umax64(b, lower_umin64(b, lower_umax64(b, src[0], src[1]), src[2]), lower_umin64(b, src[0], src[1]));
    case nir_op_iabs:
       return lower_iabs64(b, src[0]);
    case nir_op_ineg:
@@ -981,6 +975,8 @@ lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state)
       return lower_extract(b, alu->op, src[0], src[1]);
    case nir_op_ufind_msb:
       return lower_ufind_msb64(b, src[0]);
+   case nir_op_bit_count:
+      return lower_bit_count64(b, src[0]);
    case nir_op_i2f64:
    case nir_op_i2f32:
    case nir_op_i2f16:
@@ -1046,6 +1042,7 @@ should_lower_int64_alu_instr(const nir_instr *instr, const void *_data)
          return false;
       break;
    case nir_op_ufind_msb:
+   case nir_op_bit_count:
       assert(alu->src[0].src.is_ssa);
       if (alu->src[0].src.ssa->bit_size != 64)
          return false;