nir: Optimize mask+downcast to just downcast
[mesa.git] / src / compiler / nir / nir_lower_double_ops.c
index 45ac155de51c622e7fd50d688acda1688c4cc94e..a6a9b9584261419f797a16e2aad392ce19cb5992 100644 (file)
@@ -26,6 +26,8 @@
 #include "nir_builder.h"
 #include "c99_math.h"
 
+#include <float.h>
+
 /*
  * Lowers some unsupported double operations, using only:
  *
@@ -47,7 +49,9 @@ set_exponent(nir_builder *b, nir_ssa_def *src, nir_ssa_def *exp)
    /* The exponent is bits 52-62, or 20-30 of the high word, so set the exponent
     * to 1023
     */
-   nir_ssa_def *new_hi = nir_bfi(b, nir_imm_int(b, 0x7ff00000), exp, hi);
+   nir_ssa_def *new_hi = nir_bitfield_insert(b, hi, exp,
+                                             nir_imm_int(b, 20),
+                                             nir_imm_int(b, 11));
    /* recombine */
    return nir_pack_64_2x32_split(b, lo, new_hi);
 }
@@ -289,9 +293,20 @@ lower_sqrt_rsq(nir_builder *b, nir_ssa_def *src, bool sqrt)
        * 0 -> 0 and
        * +inf -> +inf
        */
-      res = nir_bcsel(b, nir_ior(b, nir_feq(b, src, nir_imm_double(b, 0.0)),
+      const bool preserve_denorms =
+         b->shader->info.float_controls_execution_mode &
+         FLOAT_CONTROLS_DENORM_PRESERVE_FP64;
+      nir_ssa_def *src_flushed = src;
+      if (!preserve_denorms) {
+         src_flushed = nir_bcsel(b,
+                                 nir_flt(b, nir_fabs(b, src),
+                                         nir_imm_double(b, DBL_MIN)),
+                                 nir_imm_double(b, 0.0),
+                                 src);
+      }
+      res = nir_bcsel(b, nir_ior(b, nir_feq(b, src_flushed, nir_imm_double(b, 0.0)),
                                  nir_feq(b, src, nir_imm_double(b, INFINITY))),
-                                 src, res);
+                                 src_flushed, res);
    } else {
       res = fix_inv_result(b, res, src, new_exp);
    }
@@ -413,15 +428,32 @@ lower_mod(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
     *
     * If the division is lowered, it could add some rounding errors that make
     * floor() to return the quotient minus one when x = N * y. If this is the
-    * case, we return zero because mod(x, y) output value is [0, y).
+    * case, we should return zero because mod(x, y) output value is [0, y).
+    * But fortunately Vulkan spec allows this kind of errors; from Vulkan
+    * spec, appendix A (Precision and Operation of SPIR-V instructions:
+    *
+    *   "The OpFRem and OpFMod instructions use cheap approximations of
+    *   remainder, and the error can be large due to the discontinuity in
+    *   trunc() and floor(). This can produce mathematically unexpected
+    *   results in some cases, such as FMod(x,x) computing x rather than 0,
+    *   and can also cause the result to have a different sign than the
+    *   infinitely precise result."
+    *
+    * In practice this means the output value is actually in the interval
+    * [0, y].
+    *
+    * While Vulkan states this behaviour explicitly, OpenGL does not, and thus
+    * we need to assume that value should be in range [0, y); but on the other
+    * hand, mod(a,b) is defined as "a - b * floor(a/b)" and OpenGL allows for
+    * some error in division, so a/a could actually end up being 1.0 - 1ULP;
+    * so in this case floor(a/a) would end up as 0, and hence mod(a,a) == a.
+    *
+    * In summary, in the practice mod(a,a) can be "a" both for OpenGL and
+    * Vulkan.
     */
    nir_ssa_def *floor = nir_ffloor(b, nir_fdiv(b, src0, src1));
-   nir_ssa_def *mod = nir_fsub(b, src0, nir_fmul(b, src1, floor));
 
-   return nir_bcsel(b,
-                    nir_fne(b, mod, src1),
-                    mod,
-                    nir_imm_double(b, 0.0));
+   return nir_fsub(b, src0, nir_fmul(b, src1, floor));
 }
 
 static nir_ssa_def *
@@ -439,17 +471,15 @@ lower_doubles_instr_to_soft(nir_builder *b, nir_alu_instr *instr,
 
    switch (instr->op) {
    case nir_op_f2i64:
-      if (instr->src[0].src.ssa->bit_size == 64)
-         name = "__fp64_to_int64";
-      else
-         name = "__fp32_to_int64";
+      if (instr->src[0].src.ssa->bit_size != 64)
+         return false;
+      name = "__fp64_to_int64";
       return_type = glsl_int64_t_type();
       break;
    case nir_op_f2u64:
-      if (instr->src[0].src.ssa->bit_size == 64)
-         name = "__fp64_to_uint64";
-      else
-         name = "__fp32_to_uint64";
+      if (instr->src[0].src.ssa->bit_size != 64)
+         return false;
+      name = "__fp64_to_uint64";
       break;
    case nir_op_f2f64:
       name = "__fp32_to_fp64";
@@ -474,18 +504,6 @@ lower_doubles_instr_to_soft(nir_builder *b, nir_alu_instr *instr,
    case nir_op_b2f64:
       name = "__bool_to_fp64";
       break;
-   case nir_op_i2f32:
-      if (instr->src[0].src.ssa->bit_size != 64)
-         return false;
-      name = "__int64_to_fp32";
-      return_type = glsl_float_type();
-      break;
-   case nir_op_u2f32:
-      if (instr->src[0].src.ssa->bit_size != 64)
-         return false;
-      name = "__uint64_to_fp32";
-      return_type = glsl_float_type();
-      break;
    case nir_op_i2f64:
       if (instr->src[0].src.ssa->bit_size == 64)
          name = "__int64_to_fp64";
@@ -723,6 +741,11 @@ nir_lower_doubles_impl(nir_function_impl *impl,
        * inlining.
        */
       nir_opt_deref_impl(impl);
+   } else if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
    }
 
    return progress;