nir: Properly preserve metadata in more cases
[mesa.git] / src / compiler / nir / nir_opt_loop_unroll.c
index 9ab0a924c82812c56a630dbc625ab51c651f817e..c5e7b8c9b5169e5c2962371a14166477bff661bd 100644 (file)
@@ -42,7 +42,7 @@
  * to keep track of and update phis along the way which gets tricky and
  * doesn't add much value over converting to regs.
  *
- * The loop may have a continue instruction at the end of the loop which does
+ * The loop may have a jump instruction at the end of the loop which does
  * nothing.  Once we're out of SSA, we can safely delete it so we don't have
  * to deal with it later.
  */
@@ -67,7 +67,7 @@ loop_prepare_for_unroll(nir_loop *loop)
 
    nir_lower_phis_to_regs_block(block_after_loop);
 
-   /* Remove continue if its the last instruction in the loop */
+   /* Remove jump if it's the last instruction in the loop */
    nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
    if (last_instr && last_instr->type == nir_instr_type_jump) {
       nir_instr_remove(last_instr);
@@ -491,7 +491,7 @@ complex_unroll_single_terminator(nir_loop *loop)
    unsigned num_times_to_clone = loop->info->max_trip_count + 1;
 
    nir_cf_list lp_body;
-   nir_cf_node *unroll_loc =
+   UNUSED nir_cf_node *unroll_loc =
       complex_unroll_loop_body(loop, terminator, &lp_header, &lp_body,
                                remap_table, num_times_to_clone);
 
@@ -514,7 +514,7 @@ complex_unroll_single_terminator(nir_loop *loop)
 static bool
 wrapper_unroll(nir_loop *loop)
 {
-   if (!list_empty(&loop->info->loop_terminator_list)) {
+   if (!list_is_empty(&loop->info->loop_terminator_list)) {
 
       /* Unrolling a loop with a large number of exits can result in a
        * large inrease in register pressure. For now we just skip
@@ -560,31 +560,7 @@ wrapper_unroll(nir_loop *loop)
            nir_after_block(nir_if_last_else_block(terminator->nif));
       }
    } else {
-      nir_block *blk_after_loop =
-         nir_cursor_current_block(nir_after_cf_node(&loop->cf_node));
-
-      /* There may still be some single src phis following the loop that
-       * have not yet been cleaned up by another pass. Tidy those up
-       * before unrolling the loop.
-       */
-      nir_foreach_instr_safe(instr, blk_after_loop) {
-         if (instr->type != nir_instr_type_phi)
-            break;
-
-         nir_phi_instr *phi = nir_instr_as_phi(instr);
-         assert(exec_list_length(&phi->srcs) == 1);
-
-         nir_phi_src *phi_src =
-            exec_node_data(nir_phi_src, exec_list_get_head(&phi->srcs), node);
-
-         nir_ssa_def_rewrite_uses(&phi->dest.ssa, phi_src->src);
-         nir_instr_remove(instr);
-      }
-
-      /* Remove break at end of the loop */
-      nir_block *last_loop_blk = nir_loop_last_block(loop);
-      nir_instr *break_instr = nir_block_last_instr(last_loop_blk);
-      nir_instr_remove(break_instr);
+      loop_prepare_for_unroll(loop);
    }
 
    /* Pluck out the loop body. */
@@ -670,11 +646,9 @@ remove_out_of_bounds_induction_use(nir_shader *shader, nir_loop *loop,
             if (is_access_out_of_bounds(term, nir_src_as_deref(intrin->src[0]),
                                         trip_count)) {
                if (intrin->intrinsic == nir_intrinsic_load_deref) {
-                  assert(intrin->src[0].is_ssa);
-                  nir_ssa_def *a_ssa = intrin->src[0].ssa;
                   nir_ssa_def *undef =
-                     nir_ssa_undef(&b, intrin->num_components,
-                                   a_ssa->bit_size);
+                     nir_ssa_undef(&b, intrin->dest.ssa.num_components,
+                                   intrin->dest.ssa.bit_size);
                   nir_ssa_def_rewrite_uses(&intrin->dest.ssa,
                                            nir_src_for_ssa(undef));
                } else {
@@ -686,14 +660,6 @@ remove_out_of_bounds_induction_use(nir_shader *shader, nir_loop *loop,
             if (intrin->intrinsic == nir_intrinsic_copy_deref &&
                 is_access_out_of_bounds(term, nir_src_as_deref(intrin->src[1]),
                                         trip_count)) {
-               assert(intrin->src[1].is_ssa);
-               nir_ssa_def *a_ssa = intrin->src[1].ssa;
-               nir_ssa_def *undef =
-                  nir_ssa_undef(&b, intrin->num_components, a_ssa->bit_size);
-
-               /* Replace the copy with a store of the undefined value */
-               b.cursor = nir_before_instr(instr);
-               nir_store_deref(&b, nir_src_as_deref(intrin->src[0]), undef, ~0);
                nir_instr_remove(instr);
             }
          }
@@ -780,11 +746,20 @@ partial_unroll(nir_shader *shader, nir_loop *loop, unsigned trip_count)
    _mesa_hash_table_destroy(remap_table, NULL);
 }
 
+/*
+ * Returns true if we should unroll the loop, otherwise false.
+ */
 static bool
-is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
+check_unrolling_restrictions(nir_shader *shader, nir_loop *loop)
 {
-   unsigned max_iter = shader->options->max_unroll_iterations;
+   if (loop->control == nir_loop_control_unroll)
+      return true;
+
+   if (loop->control == nir_loop_control_dont_unroll)
+      return false;
 
+   nir_loop_info *li = loop->info;
+   unsigned max_iter = shader->options->max_unroll_iterations;
    unsigned trip_count =
       li->max_trip_count ? li->max_trip_count : li->guessed_trip_count;
 
@@ -801,7 +776,63 @@ is_loop_small_enough_to_unroll(nir_shader *shader, nir_loop_info *li)
 }
 
 static bool
-process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out)
+process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out,
+              bool *unrolled_this_block);
+
+static bool
+process_loops_in_block(nir_shader *sh, struct exec_list *block,
+                       bool *has_nested_loop_out)
+{
+   /* We try to unroll as many loops in one pass as possible.
+    * E.g. we can safely unroll both loops in this block:
+    *
+    *    if (...) {
+    *       loop {...}
+    *    }
+    *
+    *    if (...) {
+    *       loop {...}
+    *    }
+    *
+    * Unrolling one loop doesn't affect the other one.
+    *
+    * On the other hand for block with:
+    *
+    *    loop {...}
+    *    ...
+    *    loop {...}
+    *
+    * It is unsafe to unroll both loops in one pass without taking
+    * complicating precautions, since the structure of the block would
+    * change after unrolling the first loop. So in such a case we leave
+    * the second loop for the next iteration of unrolling to handle.
+    */
+
+   bool progress = false;
+   bool unrolled_this_block = false;
+
+   foreach_list_typed(nir_cf_node, nested_node, node, block) {
+      if (process_loops(sh, nested_node,
+                        has_nested_loop_out, &unrolled_this_block)) {
+         progress = true;
+
+         /* If current node is unrolled we could not safely continue
+          * our iteration since we don't know the next node
+          * and it's hard to guarantee that we won't end up unrolling
+          * inner loop of the currently unrolled one, if such exists.
+          */
+         if (unrolled_this_block) {
+            break;
+         }
+      }
+   }
+
+   return progress;
+}
+
+static bool
+process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out,
+              bool *unrolled_this_block)
 {
    bool progress = false;
    bool has_nested_loop = false;
@@ -812,16 +843,15 @@ process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out)
       return progress;
    case nir_cf_node_if: {
       nir_if *if_stmt = nir_cf_node_as_if(cf_node);
-      foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->then_list)
-         progress |= process_loops(sh, nested_node, has_nested_loop_out);
-      foreach_list_typed_safe(nir_cf_node, nested_node, node, &if_stmt->else_list)
-         progress |= process_loops(sh, nested_node, has_nested_loop_out);
+      progress |= process_loops_in_block(sh, &if_stmt->then_list,
+                                         has_nested_loop_out);
+      progress |= process_loops_in_block(sh, &if_stmt->else_list,
+                                         has_nested_loop_out);
       return progress;
    }
    case nir_cf_node_loop: {
       loop = nir_cf_node_as_loop(cf_node);
-      foreach_list_typed_safe(nir_cf_node, nested_node, node, &loop->body)
-         progress |= process_loops(sh, nested_node, &has_nested_loop);
+      progress |= process_loops_in_block(sh, &loop->body, &has_nested_loop);
 
       break;
    }
@@ -829,10 +859,12 @@ process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out)
       unreachable("unknown cf node type");
    }
 
+   const bool unrolled_child_block = progress;
+
    /* Don't attempt to unroll a second inner loop in this pass, wait until the
     * next pass as we have altered the cf.
     */
-   if (!progress) {
+   if (!progress && loop->control != nir_loop_control_dont_unroll) {
 
       /* Check for the classic
        *
@@ -858,7 +890,7 @@ process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out)
          unsigned num_lt = list_length(&loop->info->loop_terminator_list);
          if (!has_nested_loop && num_lt == 1 && !loop->partially_unrolled &&
              loop->info->guessed_trip_count &&
-             is_loop_small_enough_to_unroll(sh, loop->info)) {
+             check_unrolling_restrictions(sh, loop)) {
             partial_unroll(sh, loop, loop->info->guessed_trip_count);
             progress = true;
          }
@@ -867,7 +899,7 @@ process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out)
       if (has_nested_loop || !loop->info->limiting_terminator)
          goto exit;
 
-      if (!is_loop_small_enough_to_unroll(sh, loop->info))
+      if (!check_unrolling_restrictions(sh, loop))
          goto exit;
 
       if (loop->info->exact_trip_count_known) {
@@ -913,6 +945,9 @@ process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out)
 
 exit:
    *has_nested_loop_out = true;
+   if (progress && !unrolled_child_block)
+      *unrolled_this_block = true;
+
    return progress;
 }
 
@@ -924,14 +959,16 @@ nir_opt_loop_unroll_impl(nir_function_impl *impl,
    nir_metadata_require(impl, nir_metadata_loop_analysis, indirect_mask);
    nir_metadata_require(impl, nir_metadata_block_index);
 
-   foreach_list_typed_safe(nir_cf_node, node, node, &impl->body) {
-      bool has_nested_loop = false;
-      progress |= process_loops(impl->function->shader, node,
-                                &has_nested_loop);
-   }
+   bool has_nested_loop = false;
+   progress |= process_loops_in_block(impl->function->shader, &impl->body,
+                                      &has_nested_loop);
 
-   if (progress)
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_none);
       nir_lower_regs_to_ssa_impl(impl);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
+   }
 
    return progress;
 }