nir: Add nir_[iu]shr_imm and nir_udiv_imm helpers and use them.
authorEric Anholt <eric@anholt.net>
Fri, 21 Aug 2020 18:21:33 +0000 (11:21 -0700)
committerEric Anholt <eric@anholt.net>
Mon, 24 Aug 2020 16:53:17 +0000 (09:53 -0700)
I was doing math manually in a lowering pass for converting a division to
a ushr, and this will let the pass be expressed more naturally.

Reviewed-by: Kristian H. Kristensen <hoegsberg@google.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6378>

src/compiler/nir/nir_builder.h
src/compiler/nir/nir_format_convert.h
src/compiler/nir/nir_lower_bit_size.c
src/compiler/nir/nir_lower_double_ops.c
src/compiler/nir/nir_lower_int64.c
src/compiler/nir/nir_opt_idiv_const.c

index 73c095838201cb2b8e52daf2548cd26260d09582..6da2e3c4cc18623339127c70914b7e3680a969b9 100644 (file)
@@ -776,6 +776,41 @@ nir_iand_imm(nir_builder *build, nir_ssa_def *x, uint64_t y)
    }
 }
 
+static inline nir_ssa_def *
+nir_ishr_imm(nir_builder *build, nir_ssa_def *x, uint32_t y)
+{
+   if (y == 0) {
+      return x;
+   } else {
+      return nir_ishr(build, x, nir_imm_int(build, y));
+   }
+}
+
+static inline nir_ssa_def *
+nir_ushr_imm(nir_builder *build, nir_ssa_def *x, uint32_t y)
+{
+   if (y == 0) {
+      return x;
+   } else {
+      return nir_ushr(build, x, nir_imm_int(build, y));
+   }
+}
+
+static inline nir_ssa_def *
+nir_udiv_imm(nir_builder *build, nir_ssa_def *x, uint64_t y)
+{
+   assert(x->bit_size <= 64);
+   y &= BITFIELD64_MASK(x->bit_size);
+
+   if (y == 1) {
+      return x;
+   } else if (util_is_power_of_two_nonzero(y)) {
+      return nir_ushr_imm(build, x, ffsll(y) - 1);
+   } else {
+      return nir_udiv(build, x, nir_imm_intN_t(build, y, x->bit_size));
+   }
+}
+
 static inline nir_ssa_def *
 nir_pack_bits(nir_builder *b, nir_ssa_def *src, unsigned dest_bit_size)
 {
@@ -838,7 +873,7 @@ nir_unpack_bits(nir_builder *b, nir_ssa_def *src, unsigned dest_bit_size)
    /* If we got here, we have no dedicated unpack opcode. */
    nir_ssa_def *dest_comps[NIR_MAX_VEC_COMPONENTS];
    for (unsigned i = 0; i < dest_num_components; i++) {
-      nir_ssa_def *val = nir_ushr(b, src, nir_imm_int(b, i * dest_bit_size));
+      nir_ssa_def *val = nir_ushr_imm(b, src, i * dest_bit_size);
       dest_comps[i] = nir_u2u(b, val, dest_bit_size);
    }
    return nir_vec(b, dest_comps, dest_num_components);
index a9de69e695fcb416acefc142c5175c0ccf195c38..1a86fead79af6db841b7377913cef53a410100cf 100644 (file)
@@ -192,8 +192,8 @@ nir_format_bitcast_uvec_unmasked(nir_builder *b, nir_ssa_def *src,
       unsigned src_idx = 0;
       unsigned shift = 0;
       for (unsigned i = 0; i < dst_components; i++) {
-         dst_chan[i] = nir_iand(b, nir_ushr(b, nir_channel(b, src, src_idx),
-                                               nir_imm_int(b, shift)),
+         dst_chan[i] = nir_iand(b, nir_ushr_imm(b, nir_channel(b, src, src_idx),
+                                                shift),
                                    mask);
          shift += dst_bits;
          if (shift >= src_bits) {
@@ -405,7 +405,7 @@ nir_format_pack_r9g9b9e5(nir_builder *b, nir_ssa_def *color)
     *              1 + RGB9E5_EXP_BIAS - 127;
     */
    nir_ssa_def *exp_shared =
-      nir_iadd(b, nir_umax(b, nir_ushr(b, maxu, nir_imm_int(b, 23)),
+      nir_iadd(b, nir_umax(b, nir_ushr_imm(b, maxu, 23),
                               nir_imm_int(b, -RGB9E5_EXP_BIAS - 1 + 127)),
                   nir_imm_int(b, 1 + RGB9E5_EXP_BIAS - 127));
 
@@ -432,8 +432,8 @@ nir_format_pack_r9g9b9e5(nir_builder *b, nir_ssa_def *color)
     * gm = (gm & 1) + (gm >> 1);
     * bm = (bm & 1) + (bm >> 1);
     */
-   mantissa = nir_iadd(b, nir_iand(b, mantissa, nir_imm_int(b, 1)),
-                          nir_ushr(b, mantissa, nir_imm_int(b, 1)));
+   mantissa = nir_iadd(b, nir_iand_imm(b, mantissa, 1),
+                          nir_ushr_imm(b, mantissa, 1));
 
    nir_ssa_def *packed = nir_channel(b, mantissa, 0);
    packed = nir_mask_shift_or(b, packed, nir_channel(b, mantissa, 1), ~0, 9);
index a6a2032b010b94f4c38e9208748fe792a6c9ae4c..6e7d57e3608821fdb7a125db7a12d080b21fdf79 100644 (file)
@@ -61,9 +61,9 @@ lower_instr(nir_builder *bld, nir_alu_instr *alu, unsigned bit_size)
       assert(dst_bit_size * 2 <= bit_size);
       nir_ssa_def *lowered_dst = nir_imul(bld, srcs[0], srcs[1]);
       if (nir_op_infos[op].output_type & nir_type_uint)
-         lowered_dst = nir_ushr(bld, lowered_dst, nir_imm_int(bld, dst_bit_size));
+         lowered_dst = nir_ushr_imm(bld, lowered_dst, dst_bit_size);
       else
-         lowered_dst = nir_ishr(bld, lowered_dst, nir_imm_int(bld, dst_bit_size));
+         lowered_dst = nir_ishr_imm(bld, lowered_dst, dst_bit_size);
    } else {
       lowered_dst = nir_build_alu_src_arr(bld, op, srcs);
    }
index 4910e1b8958e872fca4bc825ad672bec5659d6b6..0448e53e949db851825143dcd2a15ccacdbc9477 100644 (file)
@@ -177,8 +177,8 @@ lower_sqrt_rsq(nir_builder *b, nir_ssa_def *src, bool sqrt)
 
    nir_ssa_def *unbiased_exp = nir_isub(b, get_exponent(b, src),
                                         nir_imm_int(b, 1023));
-   nir_ssa_def *even = nir_iand(b, unbiased_exp, nir_imm_int(b, 1));
-   nir_ssa_def *half = nir_ishr(b, unbiased_exp, nir_imm_int(b, 1));
+   nir_ssa_def *even = nir_iand_imm(b, unbiased_exp, 1);
+   nir_ssa_def *half = nir_ishr_imm(b, unbiased_exp, 1);
 
    nir_ssa_def *src_norm = set_exponent(b, src,
                                         nir_iadd(b, nir_imm_int(b, 1023),
index 9d268a843633283864f2e7019915eb96e9c0c59f..0c14fe58853e9f55956b791843840ca7e837f145 100644 (file)
@@ -78,7 +78,7 @@ static nir_ssa_def *
 lower_i2i64(nir_builder *b, nir_ssa_def *x)
 {
    nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
-   return nir_pack_64_2x32_split(b, x32, nir_ishr(b, x32, nir_imm_int(b, 31)));
+   return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
 }
 
 static nir_ssa_def *
@@ -435,7 +435,7 @@ lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
    x32[0] = nir_unpack_64_2x32_split_x(b, x);
    x32[1] = nir_unpack_64_2x32_split_y(b, x);
    if (sign_extend) {
-      x32[2] = x32[3] = nir_ishr(b, x32[1], nir_imm_int(b, 31));
+      x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
    } else {
       x32[2] = x32[3] = nir_imm_int(b, 0);
    }
@@ -443,7 +443,7 @@ lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
    y32[0] = nir_unpack_64_2x32_split_x(b, y);
    y32[1] = nir_unpack_64_2x32_split_y(b, y);
    if (sign_extend) {
-      y32[2] = y32[3] = nir_ishr(b, y32[1], nir_imm_int(b, 31));
+      y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
    } else {
       y32[2] = y32[3] = nir_imm_int(b, 0);
    }
@@ -476,7 +476,7 @@ lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
          if (carry)
             tmp = nir_iadd(b, tmp, carry);
          res[i + j] = nir_u2u32(b, tmp);
-         carry = nir_ushr(b, tmp, nir_imm_int(b, 32));
+         carry = nir_ushr_imm(b, tmp, 32);
       }
       res[i + 4] = nir_u2u32(b, carry);
    }
@@ -491,7 +491,7 @@ lower_isign64(nir_builder *b, nir_ssa_def *x)
    nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
 
    nir_ssa_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
-   nir_ssa_def *res_hi = nir_ishr(b, x_hi, nir_imm_int(b, 31));
+   nir_ssa_def *res_hi = nir_ishr_imm(b, x_hi, 31);
    nir_ssa_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));
 
    return nir_pack_64_2x32_split(b, res_lo, res_hi);
index 2992d06031788472a26a0692e47ad1eee94e9123..49130f66dd88bb1184e92635520e8a876b18da38 100644 (file)
@@ -32,18 +32,18 @@ build_udiv(nir_builder *b, nir_ssa_def *n, uint64_t d)
    if (d == 0) {
       return nir_imm_intN_t(b, 0, n->bit_size);
    } else if (util_is_power_of_two_or_zero64(d)) {
-      return nir_ushr(b, n, nir_imm_int(b, util_logbase2_64(d)));
+      return nir_ushr_imm(b, n, util_logbase2_64(d));
    } else {
       struct util_fast_udiv_info m =
          util_compute_fast_udiv_info(d, n->bit_size, n->bit_size);
 
       if (m.pre_shift)
-         n = nir_ushr(b, n, nir_imm_int(b, m.pre_shift));
+         n = nir_ushr_imm(b, n, m.pre_shift);
       if (m.increment)
          n = nir_uadd_sat(b, n, nir_imm_intN_t(b, m.increment, n->bit_size));
       n = nir_umul_high(b, n, nir_imm_intN_t(b, m.multiplier, n->bit_size));
       if (m.post_shift)
-         n = nir_ushr(b, n, nir_imm_int(b, m.post_shift));
+         n = nir_ushr_imm(b, n, m.post_shift);
 
       return n;
    }
@@ -74,8 +74,7 @@ build_idiv(nir_builder *b, nir_ssa_def *n, int64_t d)
    } else if (d == -1) {
       return nir_ineg(b, n);
    } else if (util_is_power_of_two_or_zero64(abs_d)) {
-      nir_ssa_def *uq = nir_ushr(b, nir_iabs(b, n),
-                                    nir_imm_int(b, util_logbase2_64(abs_d)));
+      nir_ssa_def *uq = nir_ushr_imm(b, nir_iabs(b, n), util_logbase2_64(abs_d));
       nir_ssa_def *n_neg = nir_ilt(b, n, nir_imm_intN_t(b, 0, n->bit_size));
       nir_ssa_def *neg = d < 0 ? nir_inot(b, n_neg) : n_neg;
       return nir_bcsel(b, neg, nir_ineg(b, uq), uq);
@@ -90,8 +89,8 @@ build_idiv(nir_builder *b, nir_ssa_def *n, int64_t d)
       if (d < 0 && m.multiplier > 0)
          res = nir_isub(b, res, n);
       if (m.shift)
-         res = nir_ishr(b, res, nir_imm_int(b, m.shift));
-      res = nir_iadd(b, res, nir_ushr(b, res, nir_imm_int(b, n->bit_size - 1)));
+         res = nir_ishr_imm(b, res, m.shift);
+      res = nir_iadd(b, res, nir_ushr_imm(b, res, n->bit_size - 1));
 
       return res;
    }