nir: Extend nir_lower_int64() to support i2f/f2i lowering
authorBoris Brezillon <boris.brezillon@collabora.com>
Fri, 19 Jun 2020 15:28:09 +0000 (17:28 +0200)
committerMarge Bot <eric+marge@anholt.net>
Thu, 30 Jul 2020 16:54:24 +0000 (16:54 +0000)
That's an attempt at replacing the complex __int64_to_float() and
__float_to_int64() implementations found in float64.glsl by a simpler
native NIR equivalent.

Thanks to that, we can have lower those conversion without having to
compile a GLSL shader, which would be quite annoying for OpenCL
kernels.

Signed-off-by: Boris Brezillon <boris.brezillon@collabora.com>
Reviewed-by: Matt Turner <mattst88@gmail.com>
Acked-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/5588>

src/compiler/nir/nir_lower_int64.c

index 2c188d5eaf63ad38e64789ced8ef03c18a6c31fe..03d7c1568071ea9053a55691ea1ba05199689878 100644 (file)
 #include "nir.h"
 #include "nir_builder.h"
 
+#define COND_LOWER_OP(b, name, ...)                                   \
+        (b->shader->options->lower_int64_options &                    \
+         nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
+        lower_##name##64(b, __VA_ARGS__) : nir_##name(b, __VA_ARGS__)
+
+#define COND_LOWER_CMP(b, name, ...)                                  \
+        (b->shader->options->lower_int64_options &                    \
+         nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
+        lower_int64_compare(b, nir_op_##name, __VA_ARGS__) :          \
+        nir_##name(b, __VA_ARGS__)
+
+#define COND_LOWER_CAST(b, name, ...)                                 \
+        (b->shader->options->lower_int64_options &                    \
+         nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
+        lower_##name(b, __VA_ARGS__) :                                \
+        nir_##name(b, __VA_ARGS__)
+
 static nir_ssa_def *
 lower_b2i64(nir_builder *b, nir_ssa_def *x)
 {
@@ -670,6 +687,109 @@ lower_ufind_msb64(nir_builder *b, nir_ssa_def *x)
    return nir_bcsel(b, valid_hi_bits, hi_res, lo_count);
 }
 
+static nir_ssa_def *
+lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size,
+         bool src_is_signed)
+{
+   nir_ssa_def *x_sign = NULL;
+
+   if (src_is_signed) {
+      x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
+                         nir_imm_floatN_t(b, -1, dest_bit_size),
+                         nir_imm_floatN_t(b, 1, dest_bit_size));
+      x = COND_LOWER_OP(b, iabs, x);
+   }
+
+   nir_ssa_def *exp = COND_LOWER_OP(b, ufind_msb, x);
+   unsigned significand_bits;
+
+   switch (dest_bit_size) {
+   case 32:
+      significand_bits = 23;
+      break;
+   case 16:
+      significand_bits = 10;
+      break;
+   default:
+      unreachable("Invalid dest_bit_size");
+   }
+
+   /* We keep one more bit than can fit in the significand field to let the
+    * u2f32 conversion do the rounding for us.
+    */
+   nir_ssa_def *discard =
+      nir_imax(b, nir_isub(b, exp, nir_imm_int(b, significand_bits + 1)),
+                  nir_imm_int(b, 0));
+
+   /* Part of the "round to nearest" has to be taken care of before we discard
+    * the LSB, and that's what this extra iadd is for.
+    * "Round to nearest even" is handled by u2f. That works because the
+    * shifted value either fits in the significand field (which means no
+    * rounding is required) or contains one extra bit that forces the
+    * conversion op to round things properly.
+    */
+   nir_ssa_def *add = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
+   add = COND_LOWER_OP(b, isub, add, nir_imm_int64(b, 1));
+   nir_ssa_def *rounded_x = COND_LOWER_OP(b, iadd, x, add);
+
+   /* Signed Values can't overflow because we've saved the sign and promoted
+    * them to unsigned values.
+    */
+   if (!src_is_signed) {
+      nir_ssa_def *overflow = COND_LOWER_CMP(b, ult, rounded_x, x);
+      rounded_x = COND_LOWER_OP(b, bcsel, overflow,
+                                nir_imm_int64(b, UINT64_MAX), rounded_x);
+   }
+
+   nir_ssa_def *significand = COND_LOWER_OP(b, ushr, rounded_x, discard);
+   significand = COND_LOWER_CAST(b, u2u32, significand);
+
+   nir_ssa_def *res;
+
+   if (dest_bit_size == 32)
+      res = nir_fmul(b, nir_u2f32(b, significand),
+                     nir_fexp2(b, nir_u2f32(b, discard)));
+   else
+      res = nir_fmul(b, nir_u2f16(b, significand),
+                     nir_fexp2(b, nir_u2f16(b, discard)));
+
+   if (src_is_signed)
+      res = nir_fmul(b, res, x_sign);
+
+   return res;
+}
+
+static nir_ssa_def *
+lower_f2(nir_builder *b, nir_ssa_def *x, bool dst_is_signed)
+{
+   assert(x->bit_size == 16 || x->bit_size == 32);
+   nir_ssa_def *x_sign = NULL;
+
+   if (dst_is_signed)
+      x_sign = nir_fsign(b, x);
+   else
+      x = nir_fmin(b, x, nir_imm_floatN_t(b, UINT64_MAX, x->bit_size));
+
+   x = nir_ftrunc(b, x);
+
+   if (dst_is_signed) {
+      x = nir_fmin(b, x, nir_imm_floatN_t(b, INT64_MAX, x->bit_size));
+      x = nir_fmax(b, x, nir_imm_floatN_t(b, INT64_MIN, x->bit_size));
+      x = nir_fabs(b, x);
+   }
+
+   nir_ssa_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
+   nir_ssa_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
+   nir_ssa_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
+   nir_ssa_def *res = nir_pack_64_2x32_split(b, res_lo, res_hi);
+
+   if (dst_is_signed)
+      res = nir_bcsel(b, nir_flt(b, x_sign, nir_imm_float(b, 0)),
+                      nir_ineg(b, res), res);
+
+   return res;
+}
+
 nir_lower_int64_options
 nir_lower_int64_op_to_options_mask(nir_op opcode)
 {
@@ -701,6 +821,12 @@ nir_lower_int64_op_to_options_mask(nir_op opcode)
    case nir_op_u2u16:
    case nir_op_u2u32:
    case nir_op_u2u64:
+   case nir_op_i2f32:
+   case nir_op_u2f32:
+   case nir_op_i2f16:
+   case nir_op_u2f16:
+   case nir_op_f2i64:
+   case nir_op_f2u64:
    case nir_op_bcsel:
       return nir_lower_mov64;
    case nir_op_ieq:
@@ -860,7 +986,21 @@ lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state)
       return lower_extract(b, alu->op, src[0], src[1]);
    case nir_op_ufind_msb:
       return lower_ufind_msb64(b, src[0]);
-      break;
+   case nir_op_i2f64:
+   case nir_op_i2f32:
+   case nir_op_i2f16:
+      return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), true);
+   case nir_op_u2f64:
+   case nir_op_u2f32:
+   case nir_op_u2f16:
+      return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), false);
+   case nir_op_f2i64:
+   case nir_op_f2u64:
+      /* We don't support f64toi64 (yet?). */
+      if (src[0]->bit_size > 32)
+         return false;
+
+      return lower_f2(b, src[0], alu->op == nir_op_f2i64);
    default:
       unreachable("Invalid ALU opcode to lower");
    }
@@ -922,6 +1062,19 @@ should_lower_int64_alu_instr(const nir_instr *instr, const void *_data)
       if (alu->dest.dest.ssa.bit_size != 64)
          return false;
       break;
+   case nir_op_i2f64:
+   case nir_op_u2f64:
+   case nir_op_i2f32:
+   case nir_op_u2f32:
+   case nir_op_i2f16:
+   case nir_op_u2f16:
+      assert(alu->src[0].src.is_ssa);
+      if (alu->src[0].src.ssa->bit_size != 64)
+         return false;
+      break;
+   case nir_op_f2u64:
+   case nir_op_f2i64:
+      /* fall-through */
    default:
       assert(alu->dest.dest.is_ssa);
       if (alu->dest.dest.ssa.bit_size != 64)