nir: add loop unroll support for complex wrapper loops
[mesa.git] / src / compiler / nir / nir_opt_loop_unroll.c
index 9c33267cb7290d4e97a4dac720bd3df1da6cce22..0ba150f126453296850035552a6a20761f14161e 100644 (file)
@@ -67,7 +67,6 @@ loop_prepare_for_unroll(nir_loop *loop)
    /* Remove continue if its the last instruction in the loop */
    nir_instr *last_instr = nir_block_last_instr(nir_loop_last_block(loop));
    if (last_instr && last_instr->type == nir_instr_type_jump) {
-      assert(nir_instr_as_jump(last_instr)->type == nir_jump_continue);
       nir_instr_remove(last_instr);
    }
 }
@@ -474,54 +473,91 @@ complex_unroll(nir_loop *loop, nir_loop_terminator *unlimit_term,
 static bool
 wrapper_unroll(nir_loop *loop)
 {
-   bool progress = false;
-
-   nir_block *blk_after_loop =
-      nir_cursor_current_block(nir_after_cf_node(&loop->cf_node));
-
-   /* There may still be some single src phis following the loop that
-    * have not yet been cleaned up by another pass. Tidy those up before
-    * unrolling the loop.
-    */
-   nir_foreach_instr_safe(instr, blk_after_loop) {
-      if (instr->type != nir_instr_type_phi)
-         break;
+   if (!list_empty(&loop->info->loop_terminator_list)) {
 
-      nir_phi_instr *phi = nir_instr_as_phi(instr);
-      assert(exec_list_length(&phi->srcs) == 1);
+      /* Unrolling a loop with a large number of exits can result in a
+       * large inrease in register pressure. For now we just skip
+       * unrolling if we have more than 3 exits (not including the break
+       * at the end of the loop).
+       *
+       * TODO: Most loops that fit this pattern are simply switch
+       * statements that are converted to a loop to take advantage of
+       * exiting jump instruction handling. In this case we could make
+       * use of a binary seach pattern like we do in
+       * nir_lower_indirect_derefs(), this should allow us to unroll the
+       * loops in an optimal way and should also avoid some of the
+       * register pressure that comes from simply nesting the
+       * terminators one after the other.
+       */
+      if (list_length(&loop->info->loop_terminator_list) > 3)
+         return false;
+
+      loop_prepare_for_unroll(loop);
+
+      nir_cursor loop_end = nir_after_block(nir_loop_last_block(loop));
+      list_for_each_entry(nir_loop_terminator, terminator,
+                          &loop->info->loop_terminator_list,
+                          loop_terminator_link) {
+
+         /* Remove break from the terminator */
+         nir_instr *break_instr =
+            nir_block_last_instr(terminator->break_block);
+         nir_instr_remove(break_instr);
+
+         /* Pluck out the loop body. */
+         nir_cf_list loop_body;
+         nir_cf_extract(&loop_body,
+                        nir_after_cf_node(&terminator->nif->cf_node),
+                        loop_end);
+
+         /* Reinsert loop body into continue from block */
+         nir_cf_reinsert(&loop_body,
+                         nir_after_block(terminator->continue_from_block));
+
+         loop_end = terminator->continue_from_then ?
+           nir_after_block(nir_if_last_then_block(terminator->nif)) :
+           nir_after_block(nir_if_last_else_block(terminator->nif));
+      }
+   } else {
+      nir_block *blk_after_loop =
+         nir_cursor_current_block(nir_after_cf_node(&loop->cf_node));
 
-      nir_phi_src *phi_src = exec_node_data(nir_phi_src,
-                                            exec_list_get_head(&phi->srcs),
-                                            node);
+      /* There may still be some single src phis following the loop that
+       * have not yet been cleaned up by another pass. Tidy those up
+       * before unrolling the loop.
+       */
+      nir_foreach_instr_safe(instr, blk_after_loop) {
+         if (instr->type != nir_instr_type_phi)
+            break;
 
-      nir_ssa_def_rewrite_uses(&phi->dest.ssa, phi_src->src);
-      nir_instr_remove(instr);
+         nir_phi_instr *phi = nir_instr_as_phi(instr);
+         assert(exec_list_length(&phi->srcs) == 1);
 
-      progress = true;
-   }
+         nir_phi_src *phi_src =
+            exec_node_data(nir_phi_src, exec_list_get_head(&phi->srcs), node);
 
-   nir_block *last_loop_blk = nir_loop_last_block(loop);
-   if (nir_block_ends_in_break(last_loop_blk)) {
+         nir_ssa_def_rewrite_uses(&phi->dest.ssa, phi_src->src);
+         nir_instr_remove(instr);
+      }
 
       /* Remove break at end of the loop */
+      nir_block *last_loop_blk = nir_loop_last_block(loop);
       nir_instr *break_instr = nir_block_last_instr(last_loop_blk);
       nir_instr_remove(break_instr);
+   }
 
-      /* Pluck out the loop body. */
-      nir_cf_list loop_body;
-      nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
-                     nir_after_block(nir_loop_last_block(loop)));
-
-      /* Reinsert loop body after the loop */
-      nir_cf_reinsert(&loop_body, nir_after_cf_node(&loop->cf_node));
+   /* Pluck out the loop body. */
+   nir_cf_list loop_body;
+   nir_cf_extract(&loop_body, nir_before_block(nir_loop_first_block(loop)),
+                  nir_after_block(nir_loop_last_block(loop)));
 
-      /* The loop has been unrolled so remove it. */
-      nir_cf_node_remove(&loop->cf_node);
+   /* Reinsert loop body after the loop */
+   nir_cf_reinsert(&loop_body, nir_after_cf_node(&loop->cf_node));
 
-      progress = true;
-   }
+   /* The loop has been unrolled so remove it. */
+   nir_cf_node_remove(&loop->cf_node);
 
-   return progress;
+   return true;
 }
 
 static bool
@@ -585,9 +621,12 @@ process_loops(nir_shader *sh, nir_cf_node *cf_node, bool *has_nested_loop_out)
        * statements in a loop like this.
        */
       if (loop->info->limiting_terminator == NULL &&
-          list_empty(&loop->info->loop_terminator_list) &&
           !loop->info->complex_loop) {
 
+         nir_block *last_loop_blk = nir_loop_last_block(loop);
+         if (!nir_block_ends_in_break(last_loop_blk))
+            goto exit;
+
          progress = wrapper_unroll(loop);
 
          goto exit;