nir: Get rid of nir_shader::stage
[mesa.git] / src / compiler / nir / nir_lower_double_ops.c
index 75032a6f766da0442163b4de5fe4fa09bd57b6d1..b3543bc6963ee4c2161a72aa72db1af2e4bfcd9e 100644 (file)
@@ -41,22 +41,22 @@ static nir_ssa_def *
 set_exponent(nir_builder *b, nir_ssa_def *src, nir_ssa_def *exp)
 {
    /* Split into bits 0-31 and 32-63 */
-   nir_ssa_def *lo = nir_unpack_double_2x32_split_x(b, src);
-   nir_ssa_def *hi = nir_unpack_double_2x32_split_y(b, src);
+   nir_ssa_def *lo = nir_unpack_64_2x32_split_x(b, src);
+   nir_ssa_def *hi = nir_unpack_64_2x32_split_y(b, src);
 
    /* The exponent is bits 52-62, or 20-30 of the high word, so set the exponent
     * to 1023
     */
    nir_ssa_def *new_hi = nir_bfi(b, nir_imm_int(b, 0x7ff00000), exp, hi);
    /* recombine */
-   return nir_pack_double_2x32_split(b, lo, new_hi);
+   return nir_pack_64_2x32_split(b, lo, new_hi);
 }
 
 static nir_ssa_def *
 get_exponent(nir_builder *b, nir_ssa_def *src)
 {
    /* get bits 32-63 */
-   nir_ssa_def *hi = nir_unpack_double_2x32_split_y(b, src);
+   nir_ssa_def *hi = nir_unpack_64_2x32_split_y(b, src);
 
    /* extract bits 20-30 of the high word */
    return nir_ubitfield_extract(b, hi, nir_imm_int(b, 20), nir_imm_int(b, 11));
@@ -67,7 +67,7 @@ get_exponent(nir_builder *b, nir_ssa_def *src)
 static nir_ssa_def *
 get_signed_inf(nir_builder *b, nir_ssa_def *zero)
 {
-   nir_ssa_def *zero_hi = nir_unpack_double_2x32_split_y(b, zero);
+   nir_ssa_def *zero_hi = nir_unpack_64_2x32_split_y(b, zero);
 
    /* The bit pattern for infinity is 0x7ff0000000000000, where the sign bit
     * is the highest bit. Only the sign bit can be non-zero in the passed in
@@ -76,7 +76,7 @@ get_signed_inf(nir_builder *b, nir_ssa_def *zero)
     * bits and then pack it together with zero low 32 bits.
     */
    nir_ssa_def *inf_hi = nir_ior(b, nir_imm_int(b, 0x7ff00000), zero_hi);
-   return nir_pack_double_2x32_split(b, nir_imm_int(b, 0), inf_hi);
+   return nir_pack_64_2x32_split(b, nir_imm_int(b, 0), inf_hi);
 }
 
 /*
@@ -116,7 +116,7 @@ lower_rcp(nir_builder *b, nir_ssa_def *src)
    /* cast to float, do an rcp, and then cast back to get an approximate
     * result
     */
-   nir_ssa_def *ra = nir_f2d(b, nir_frcp(b, nir_d2f(b, src_norm)));
+   nir_ssa_def *ra = nir_f2f64(b, nir_frcp(b, nir_f2f32(b, src_norm)));
 
    /* Fixup the exponent of the result - note that we check if this is too
     * small below.
@@ -180,7 +180,7 @@ lower_sqrt_rsq(nir_builder *b, nir_ssa_def *src, bool sqrt)
                                         nir_iadd(b, nir_imm_int(b, 1023),
                                                  even));
 
-   nir_ssa_def *ra = nir_f2d(b, nir_frsq(b, nir_d2f(b, src_norm)));
+   nir_ssa_def *ra = nir_f2f64(b, nir_frsq(b, nir_f2f32(b, src_norm)));
    nir_ssa_def *new_exp = nir_isub(b, get_exponent(b, ra), half);
    ra = set_exponent(b, ra, new_exp);
 
@@ -267,36 +267,36 @@ lower_sqrt_rsq(nir_builder *b, nir_ssa_def *src, bool sqrt)
     * (https://en.wikipedia.org/wiki/Methods_of_computing_square_roots).
     */
 
-    nir_ssa_def *one_half = nir_imm_double(b, 0.5);
-    nir_ssa_def *h_0 = nir_fmul(b, one_half, ra);
-    nir_ssa_def *g_0 = nir_fmul(b, src, ra);
-    nir_ssa_def *r_0 = nir_ffma(b, nir_fneg(b, h_0), g_0, one_half);
-    nir_ssa_def *h_1 = nir_ffma(b, h_0, r_0, h_0);
-    nir_ssa_def *res;
-    if (sqrt) {
-       nir_ssa_def *g_1 = nir_ffma(b, g_0, r_0, g_0);
-       nir_ssa_def *r_1 = nir_ffma(b, nir_fneg(b, g_1), g_1, src);
-       res = nir_ffma(b, h_1, r_1, g_1);
-    } else {
-       nir_ssa_def *y_1 = nir_fmul(b, nir_imm_double(b, 2.0), h_1);
-       nir_ssa_def *r_1 = nir_ffma(b, nir_fneg(b, y_1), nir_fmul(b, h_1, src),
-                                   one_half);
-       res = nir_ffma(b, y_1, r_1, y_1);
-    }
-
-    if (sqrt) {
-       /* Here, the special cases we need to handle are
-        * 0 -> 0 and
-        * +inf -> +inf
-        */
-       res = nir_bcsel(b, nir_ior(b, nir_feq(b, src, nir_imm_double(b, 0.0)),
-                                  nir_feq(b, src, nir_imm_double(b, INFINITY))),
-                       src, res);
-    } else {
-       res = fix_inv_result(b, res, src, new_exp);
-    }
-
-    return res;
+   nir_ssa_def *one_half = nir_imm_double(b, 0.5);
+   nir_ssa_def *h_0 = nir_fmul(b, one_half, ra);
+   nir_ssa_def *g_0 = nir_fmul(b, src, ra);
+   nir_ssa_def *r_0 = nir_ffma(b, nir_fneg(b, h_0), g_0, one_half);
+   nir_ssa_def *h_1 = nir_ffma(b, h_0, r_0, h_0);
+   nir_ssa_def *res;
+   if (sqrt) {
+      nir_ssa_def *g_1 = nir_ffma(b, g_0, r_0, g_0);
+      nir_ssa_def *r_1 = nir_ffma(b, nir_fneg(b, g_1), g_1, src);
+      res = nir_ffma(b, h_1, r_1, g_1);
+   } else {
+      nir_ssa_def *y_1 = nir_fmul(b, nir_imm_double(b, 2.0), h_1);
+      nir_ssa_def *r_1 = nir_ffma(b, nir_fneg(b, y_1), nir_fmul(b, h_1, src),
+                                  one_half);
+      res = nir_ffma(b, y_1, r_1, y_1);
+   }
+
+   if (sqrt) {
+      /* Here, the special cases we need to handle are
+       * 0 -> 0 and
+       * +inf -> +inf
+       */
+      res = nir_bcsel(b, nir_ior(b, nir_feq(b, src, nir_imm_double(b, 0.0)),
+                                 nir_feq(b, src, nir_imm_double(b, INFINITY))),
+                                 src, res);
+   } else {
+      res = fix_inv_result(b, res, src, new_exp);
+   }
+
+   return res;
 }
 
 static nir_ssa_def *
@@ -337,8 +337,8 @@ lower_trunc(nir_builder *b, nir_ssa_def *src)
                          nir_imm_int(b, ~0),
                          nir_isub(b, frac_bits, nir_imm_int(b, 32))));
 
-   nir_ssa_def *src_lo = nir_unpack_double_2x32_split_x(b, src);
-   nir_ssa_def *src_hi = nir_unpack_double_2x32_split_y(b, src);
+   nir_ssa_def *src_lo = nir_unpack_64_2x32_split_x(b, src);
+   nir_ssa_def *src_hi = nir_unpack_64_2x32_split_y(b, src);
 
    return
       nir_bcsel(b,
@@ -346,9 +346,9 @@ lower_trunc(nir_builder *b, nir_ssa_def *src)
                 nir_imm_double(b, 0.0),
                 nir_bcsel(b, nir_ige(b, unbiased_exp, nir_imm_int(b, 53)),
                           src,
-                          nir_pack_double_2x32_split(b,
-                                                     nir_iand(b, mask_lo, src_lo),
-                                                     nir_iand(b, mask_hi, src_hi))));
+                          nir_pack_64_2x32_split(b,
+                                                 nir_iand(b, mask_lo, src_lo),
+                                                 nir_iand(b, mask_hi, src_hi))));
 }
 
 static nir_ssa_def *
@@ -438,56 +438,79 @@ lower_round_even(nir_builder *b, nir_ssa_def *src)
                                         nir_fsub(b, src, nir_imm_double(b, 0.5)))));
 }
 
-static void
+static nir_ssa_def *
+lower_mod(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
+{
+   /* mod(x,y) = x - y * floor(x/y)
+    *
+    * If the division is lowered, it could add some rounding errors that make
+    * floor() to return the quotient minus one when x = N * y. If this is the
+    * case, we return zero because mod(x, y) output value is [0, y).
+    */
+   nir_ssa_def *floor = nir_ffloor(b, nir_fdiv(b, src0, src1));
+   nir_ssa_def *mod = nir_fsub(b, src0, nir_fmul(b, src1, floor));
+
+   return nir_bcsel(b,
+                    nir_fne(b, mod, src1),
+                    mod,
+                    nir_imm_double(b, 0.0));
+}
+
+static bool
 lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
 {
    assert(instr->dest.dest.is_ssa);
    if (instr->dest.dest.ssa.bit_size != 64)
-      return;
+      return false;
 
    switch (instr->op) {
    case nir_op_frcp:
       if (!(options & nir_lower_drcp))
-         return;
+         return false;
       break;
 
    case nir_op_fsqrt:
       if (!(options & nir_lower_dsqrt))
-         return;
+         return false;
       break;
 
    case nir_op_frsq:
       if (!(options & nir_lower_drsq))
-         return;
+         return false;
       break;
 
    case nir_op_ftrunc:
       if (!(options & nir_lower_dtrunc))
-         return;
+         return false;
       break;
 
    case nir_op_ffloor:
       if (!(options & nir_lower_dfloor))
-         return;
+         return false;
       break;
 
    case nir_op_fceil:
       if (!(options & nir_lower_dceil))
-         return;
+         return false;
       break;
 
    case nir_op_ffract:
       if (!(options & nir_lower_dfract))
-         return;
+         return false;
       break;
 
    case nir_op_fround_even:
       if (!(options & nir_lower_dround_even))
-         return;
+         return false;
+      break;
+
+   case nir_op_fmod:
+      if (!(options & nir_lower_dmod))
+         return false;
       break;
 
    default:
-      return;
+      return false;
    }
 
    nir_builder bld;
@@ -525,40 +548,52 @@ lower_doubles_instr(nir_alu_instr *instr, nir_lower_doubles_options options)
       result = lower_round_even(&bld, src);
       break;
 
+   case nir_op_fmod: {
+      nir_ssa_def *src1 = nir_fmov_alu(&bld, instr->src[1],
+                                      instr->dest.dest.ssa.num_components);
+      result = lower_mod(&bld, src, src1);
+   }
+      break;
    default:
       unreachable("unhandled opcode");
    }
 
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(result));
    nir_instr_remove(&instr->instr);
+   return true;
 }
 
 static bool
-lower_doubles_block(nir_block *block, void *ctx)
+nir_lower_doubles_impl(nir_function_impl *impl,
+                       nir_lower_doubles_options options)
 {
-   nir_lower_doubles_options options = *((nir_lower_doubles_options *) ctx);
-
-   nir_foreach_instr_safe(instr, block) {
-      if (instr->type != nir_instr_type_alu)
-         continue;
-
-      lower_doubles_instr(nir_instr_as_alu(instr), options);
+   bool progress = false;
+
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr_safe(instr, block) {
+         if (instr->type == nir_instr_type_alu)
+            progress |= lower_doubles_instr(nir_instr_as_alu(instr),
+                                            options);
+      }
    }
 
-   return true;
-}
+   if (progress)
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
 
-static void
-lower_doubles_impl(nir_function_impl *impl, nir_lower_doubles_options options)
-{
-   nir_foreach_block_call(impl, lower_doubles_block, &options);
+   return progress;
 }
 
-void
+bool
 nir_lower_doubles(nir_shader *shader, nir_lower_doubles_options options)
 {
-   nir_foreach_function(shader, function) {
-      if (function->impl)
-         lower_doubles_impl(function->impl, options);
+   bool progress = false;
+
+   nir_foreach_function(function, shader) {
+      if (function->impl) {
+         progress |= nir_lower_doubles_impl(function->impl, options);
+      }
    }
+
+   return progress;
 }