nir/lower_doubles: Inline functions directly in lower_doubles
[mesa.git] / src / compiler / nir / nir_lower_double_ops.c
index 69f4b3a78db3db68bbe5c9ac1dc30610e71b4b75..863046e65c7e9b5ef9800437b494b35e696e4eb1 100644 (file)
@@ -426,6 +426,7 @@ lower_mod(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
 
 static bool
 lower_doubles_instr_to_soft(nir_builder *b, nir_alu_instr *instr,
+                            const nir_shader *softfp64,
                             nir_lower_doubles_options options)
 {
    if (!(options & nir_lower_fp64_full_software))
@@ -553,35 +554,34 @@ lower_doubles_instr_to_soft(nir_builder *b, nir_alu_instr *instr,
       return false;
    }
 
-   nir_shader *shader = b->shader;
    nir_function *func = NULL;
-
-   nir_foreach_function(function, shader) {
+   nir_foreach_function(function, softfp64) {
       if (strcmp(function->name, name) == 0) {
          func = function;
          break;
       }
    }
-   if (!func) {
+   if (!func || !func->impl) {
       fprintf(stderr, "Cannot find function \"%s\"\n", name);
       assert(func);
    }
 
    b->cursor = nir_before_instr(&instr->instr);
 
-   nir_call_instr *call = nir_call_instr_create(shader, func);
+   nir_ssa_def *params[4] = { NULL, };
 
    nir_variable *ret_tmp =
       nir_local_variable_create(b->impl, return_type, "return_tmp");
    nir_deref_instr *ret_deref = nir_build_deref_var(b, ret_tmp);
-   call->params[0] = nir_src_for_ssa(&ret_deref->dest.ssa);
+   params[0] = &ret_deref->dest.ssa;
 
+   assert(nir_op_infos[instr->op].num_inputs + 1 == func->num_params);
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      nir_src arg = nir_src_for_ssa(nir_imov_alu(b, instr->src[i], 1));
-      nir_src_copy(&call->params[i + 1], &arg, call);
+      assert(i + 1 < ARRAY_SIZE(params));
+      params[i + 1] = nir_imov_alu(b, instr->src[i], 1);
    }
 
-   nir_builder_instr_insert(b, &call->instr);
+   nir_inline_function_impl(b, func->impl, params);
 
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
                             nir_src_for_ssa(nir_load_deref(b, ret_deref)));
@@ -608,6 +608,7 @@ nir_lower_doubles_op_to_options_mask(nir_op opcode)
 
 static bool
 lower_doubles_instr(nir_builder *b, nir_alu_instr *instr,
+                    const nir_shader *softfp64,
                     nir_lower_doubles_options options)
 {
    assert(instr->dest.dest.is_ssa);
@@ -621,7 +622,7 @@ lower_doubles_instr(nir_builder *b, nir_alu_instr *instr,
    if (!is_64)
       return false;
 
-   if (lower_doubles_instr_to_soft(b, instr, options))
+   if (lower_doubles_instr_to_soft(b, instr, softfp64, options))
       return true;
 
    if (!(options & nir_lower_doubles_op_to_options_mask(instr->op)))
@@ -677,6 +678,7 @@ lower_doubles_instr(nir_builder *b, nir_alu_instr *instr,
 
 static bool
 nir_lower_doubles_impl(nir_function_impl *impl,
+                       const nir_shader *softfp64,
                        nir_lower_doubles_options options)
 {
    bool progress = false;
@@ -684,17 +686,30 @@ nir_lower_doubles_impl(nir_function_impl *impl,
    nir_builder b;
    nir_builder_init(&b, impl);
 
-   nir_foreach_block(block, impl) {
+   nir_foreach_block_safe(block, impl) {
       nir_foreach_instr_safe(instr, block) {
          if (instr->type == nir_instr_type_alu)
             progress |= lower_doubles_instr(&b, nir_instr_as_alu(instr),
-                                            options);
+                                            softfp64, options);
       }
    }
 
    if (progress) {
-      nir_metadata_preserve(impl, nir_metadata_block_index |
-                                  nir_metadata_dominance);
+      if (options & nir_lower_fp64_full_software) {
+         /* SSA and register indices are completely messed up now */
+         nir_index_ssa_defs(impl);
+         nir_index_local_regs(impl);
+
+         nir_metadata_preserve(impl, nir_metadata_none);
+
+         /* And we have deref casts we need to clean up thanks to function
+          * inlining.
+          */
+         nir_opt_deref_impl(impl);
+      } else {
+         nir_metadata_preserve(impl, nir_metadata_block_index |
+                                     nir_metadata_dominance);
+      }
     } else {
 #ifndef NDEBUG
       impl->valid_metadata &= ~nir_metadata_not_properly_reset;
@@ -705,13 +720,15 @@ nir_lower_doubles_impl(nir_function_impl *impl,
 }
 
 bool
-nir_lower_doubles(nir_shader *shader, nir_lower_doubles_options options)
+nir_lower_doubles(nir_shader *shader,
+                  const nir_shader *softfp64,
+                  nir_lower_doubles_options options)
 {
    bool progress = false;
 
    nir_foreach_function(function, shader) {
       if (function->impl) {
-         progress |= nir_lower_doubles_impl(function->impl, options);
+         progress |= nir_lower_doubles_impl(function->impl, softfp64, options);
       }
    }