nir: copy intrinsic type when lowering load input/uniform and store output
[mesa.git] / src / compiler / nir / nir_lower_double_ops.c
index 054fce9c168219508ff283226b515304a3379fbb..4ddd91f9054e9f06319950b007ce1be97202da16 100644 (file)
@@ -142,8 +142,8 @@ lower_rcp(nir_builder *b, nir_ssa_def *src)
     * See https://en.wikipedia.org/wiki/Division_algorithm for more details.
     */
 
-   ra = nir_ffma(b, ra, nir_ffma(b, ra, src, nir_imm_double(b, -1)), ra);
-   ra = nir_ffma(b, ra, nir_ffma(b, ra, src, nir_imm_double(b, -1)), ra);
+   ra = nir_ffma(b, nir_fneg(b, ra), nir_ffma(b, ra, src, nir_imm_double(b, -1)), ra);
+   ra = nir_ffma(b, nir_fneg(b, ra), nir_ffma(b, ra, src, nir_imm_double(b, -1)), ra);
 
    return fix_inv_result(b, ra, src, new_exp);
 }
@@ -426,6 +426,7 @@ lower_mod(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
 
 static bool
 lower_doubles_instr_to_soft(nir_builder *b, nir_alu_instr *instr,
+                            const nir_shader *softfp64,
                             nir_lower_doubles_options options)
 {
    if (!(options & nir_lower_fp64_full_software))
@@ -549,39 +550,41 @@ lower_doubles_instr_to_soft(nir_builder *b, nir_alu_instr *instr,
    case nir_op_ffma:
       name = "__ffma64";
       break;
+   case nir_op_fsat:
+      name = "__fsat64";
+      break;
    default:
       return false;
    }
 
-   nir_shader *shader = b->shader;
    nir_function *func = NULL;
-
-   nir_foreach_function(function, shader) {
+   nir_foreach_function(function, softfp64) {
       if (strcmp(function->name, name) == 0) {
          func = function;
          break;
       }
    }
-   if (!func) {
+   if (!func || !func->impl) {
       fprintf(stderr, "Cannot find function \"%s\"\n", name);
       assert(func);
    }
 
    b->cursor = nir_before_instr(&instr->instr);
 
-   nir_call_instr *call = nir_call_instr_create(shader, func);
+   nir_ssa_def *params[4] = { NULL, };
 
    nir_variable *ret_tmp =
       nir_local_variable_create(b->impl, return_type, "return_tmp");
    nir_deref_instr *ret_deref = nir_build_deref_var(b, ret_tmp);
-   call->params[0] = nir_src_for_ssa(&ret_deref->dest.ssa);
+   params[0] = &ret_deref->dest.ssa;
 
+   assert(nir_op_infos[instr->op].num_inputs + 1 == func->num_params);
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      nir_src arg = nir_src_for_ssa(nir_imov_alu(b, instr->src[i], 1));
-      nir_src_copy(&call->params[i + 1], &arg, call);
+      assert(i + 1 < ARRAY_SIZE(params));
+      params[i + 1] = nir_mov_alu(b, instr->src[i], 1);
    }
 
-   nir_builder_instr_insert(b, &call->instr);
+   nir_inline_function_impl(b, func->impl, params);
 
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
                             nir_src_for_ssa(nir_load_deref(b, ret_deref)));
@@ -589,8 +592,26 @@ lower_doubles_instr_to_soft(nir_builder *b, nir_alu_instr *instr,
    return true;
 }
 
+/* Option bit in nir_lower_doubles_options that gates lowering of opcode. */
+nir_lower_doubles_options
+nir_lower_doubles_op_to_options_mask(nir_op opcode)
+{
+   if (opcode == nir_op_frcp)        return nir_lower_drcp;
+   if (opcode == nir_op_fsqrt)       return nir_lower_dsqrt;
+   if (opcode == nir_op_frsq)        return nir_lower_drsq;
+   if (opcode == nir_op_ftrunc)      return nir_lower_dtrunc;
+   if (opcode == nir_op_ffloor)      return nir_lower_dfloor;
+   if (opcode == nir_op_fceil)       return nir_lower_dceil;
+   if (opcode == nir_op_ffract)      return nir_lower_dfract;
+   if (opcode == nir_op_fround_even) return nir_lower_dround_even;
+   if (opcode == nir_op_fmod)        return nir_lower_dmod;
+   /* Any other opcode has no per-op option bit, so the mask is empty. */
+   return 0;
+}
+
 static bool
 lower_doubles_instr(nir_builder *b, nir_alu_instr *instr,
+                    const nir_shader *softfp64,
                     nir_lower_doubles_options options)
 {
    assert(instr->dest.dest.is_ssa);
@@ -604,63 +625,16 @@ lower_doubles_instr(nir_builder *b, nir_alu_instr *instr,
    if (!is_64)
       return false;
 
-   if (lower_doubles_instr_to_soft(b, instr, options))
+   if (lower_doubles_instr_to_soft(b, instr, softfp64, options))
       return true;
 
-   switch (instr->op) {
-   case nir_op_frcp:
-      if (!(options & nir_lower_drcp))
-         return false;
-      break;
-
-   case nir_op_fsqrt:
-      if (!(options & nir_lower_dsqrt))
-         return false;
-      break;
-
-   case nir_op_frsq:
-      if (!(options & nir_lower_drsq))
-         return false;
-      break;
-
-   case nir_op_ftrunc:
-      if (!(options & nir_lower_dtrunc))
-         return false;
-      break;
-
-   case nir_op_ffloor:
-      if (!(options & nir_lower_dfloor))
-         return false;
-      break;
-
-   case nir_op_fceil:
-      if (!(options & nir_lower_dceil))
-         return false;
-      break;
-
-   case nir_op_ffract:
-      if (!(options & nir_lower_dfract))
-         return false;
-      break;
-
-   case nir_op_fround_even:
-      if (!(options & nir_lower_dround_even))
-         return false;
-      break;
-
-   case nir_op_fmod:
-      if (!(options & nir_lower_dmod))
-         return false;
-      break;
-
-   default:
+   if (!(options & nir_lower_doubles_op_to_options_mask(instr->op)))
       return false;
-   }
 
    b->cursor = nir_before_instr(&instr->instr);
 
-   nir_ssa_def *src = nir_fmov_alu(b, instr->src[0],
-                                   instr->dest.dest.ssa.num_components);
+   nir_ssa_def *src = nir_mov_alu(b, instr->src[0],
+                                  instr->dest.dest.ssa.num_components);
 
    nir_ssa_def *result;
 
@@ -691,8 +665,8 @@ lower_doubles_instr(nir_builder *b, nir_alu_instr *instr,
       break;
 
    case nir_op_fmod: {
-      nir_ssa_def *src1 = nir_fmov_alu(b, instr->src[1],
-                                       instr->dest.dest.ssa.num_components);
+      nir_ssa_def *src1 = nir_mov_alu(b, instr->src[1],
+                                      instr->dest.dest.ssa.num_components);
       result = lower_mod(b, src, src1);
    }
       break;
@@ -707,6 +681,7 @@ lower_doubles_instr(nir_builder *b, nir_alu_instr *instr,
 
 static bool
 nir_lower_doubles_impl(nir_function_impl *impl,
+                       const nir_shader *softfp64,
                        nir_lower_doubles_options options)
 {
    bool progress = false;
@@ -714,17 +689,30 @@ nir_lower_doubles_impl(nir_function_impl *impl,
    nir_builder b;
    nir_builder_init(&b, impl);
 
-   nir_foreach_block(block, impl) {
+   nir_foreach_block_safe(block, impl) {
       nir_foreach_instr_safe(instr, block) {
          if (instr->type == nir_instr_type_alu)
             progress |= lower_doubles_instr(&b, nir_instr_as_alu(instr),
-                                            options);
+                                            softfp64, options);
       }
    }
 
    if (progress) {
-      nir_metadata_preserve(impl, nir_metadata_block_index |
-                                  nir_metadata_dominance);
+      if (options & nir_lower_fp64_full_software) {
+         /* SSA and register indices are completely messed up now */
+         nir_index_ssa_defs(impl);
+         nir_index_local_regs(impl);
+
+         nir_metadata_preserve(impl, nir_metadata_none);
+
+         /* And we have deref casts we need to clean up thanks to function
+          * inlining.
+          */
+         nir_opt_deref_impl(impl);
+      } else {
+         nir_metadata_preserve(impl, nir_metadata_block_index |
+                                     nir_metadata_dominance);
+      }
     } else {
 #ifndef NDEBUG
       impl->valid_metadata &= ~nir_metadata_not_properly_reset;
@@ -735,13 +723,15 @@ nir_lower_doubles_impl(nir_function_impl *impl,
 }
 
 bool
-nir_lower_doubles(nir_shader *shader, nir_lower_doubles_options options)
+nir_lower_doubles(nir_shader *shader,
+                  const nir_shader *softfp64,
+                  nir_lower_doubles_options options)
 {
    bool progress = false;
 
    nir_foreach_function(function, shader) {
       if (function->impl) {
-         progress |= nir_lower_doubles_impl(function->impl, options);
+         progress |= nir_lower_doubles_impl(function->impl, softfp64, options);
       }
    }