nir: fix const-cast warning on MSVC
[mesa.git] / src / compiler / nir / nir_lower_goto_ifs.c
index ff7443ccfb076a13d2ccdac05f45fc2e929fdb50..44eaf729ec522a4e0ecb78a413cc5705035c0312 100644 (file)
@@ -23,6 +23,9 @@
 
 #include "nir.h"
 #include "nir_builder.h"
+#include "nir_vla.h"
+
+#define NIR_LOWER_GOTO_IFS_DEBUG 0
 
 struct path {
    /** Set of blocks which this path represents
@@ -47,6 +50,7 @@ struct path_fork {
 };
 
 struct routes {
+   struct set *outside;
    struct path regular;
    struct path brk;
    struct path cont;
@@ -54,7 +58,7 @@ struct routes {
 };
 
 struct strct_lvl {
-   struct exec_node node;
+   struct list_head link;
 
    /** Set of blocks at the current level */
    struct set *blocks;
@@ -65,6 +69,9 @@ struct strct_lvl {
    /** Reach set from inside_outside if irreducable */
    struct set *reach;
 
+   /** Outside set from inside_outside if irreducible */
+   struct set *outside;
+
    /** True if a skip region starts with this level */
    bool skip_start;
 
@@ -75,6 +82,57 @@ struct strct_lvl {
    bool irreducible;
 };
 
+static int
+nir_block_ptr_cmp(const void *_a, const void *_b)
+{
+   /* qsort() comparator: ascending block index, with no int overflow. */
+   const nir_block *const *a = _a, *const *b = _b;
+   return ((*a)->index > (*b)->index) - ((*a)->index < (*b)->index);
+}
+
+/** Debug helper: print the block indices in a set (NULL prints as empty). */
+static void
+print_block_set(const struct set *set)
+{
+   printf("{ ");
+   if (set != NULL) {
+      const char *sep = "";
+      set_foreach(set, entry) {
+         printf("%s%u", sep, ((const nir_block *)entry->key)->index);
+         sep = ", ";
+      }
+   }
+   printf(" }\n");
+}
+
+/** Return the blocks of a set as an array sorted by block index
+ *
+ * Hash set ordering is non-deterministic: we hash on pointers, so the order
+ * can change from one run to the next.  Sort the blocks before making any
+ * decision that depends on ordering and may affect the final structure.
+ *
+ * The array holds block_set->entries elements and is ralloc'ed from mem_ctx.
+ */
+static nir_block **
+sorted_block_arr_for_set(const struct set *block_set, void *mem_ctx)
+{
+   const unsigned count = block_set->entries;
+   nir_block **sorted = ralloc_array(mem_ctx, nir_block *, count);
+   unsigned filled = 0;
+   set_foreach(block_set, entry)
+      sorted[filled++] = (nir_block *)entry->key;
+   assert(filled == count);
+   qsort(sorted, count, sizeof(sorted[0]), nir_block_ptr_cmp);
+   return sorted;
+}
+
+static nir_block *
+block_for_singular_set(const struct set *block_set)
+{
+   assert(block_set->entries == 1); /* callers must pass a one-block set */
+   return (nir_block *)_mesa_set_next_entry(block_set, NULL)->key;
+}
+
 /**
  * Sets all path variables to reach the target block via a fork
  */
@@ -233,8 +291,27 @@ fork_reachable(struct path_fork *fork)
 static void
 loop_routing_start(struct routes *routing, nir_builder *b,
                    struct path loop_path, struct set *reach,
-                   void *mem_ctx)
+                   struct set *outside, void *mem_ctx)
 {
+   if (NIR_LOWER_GOTO_IFS_DEBUG) {
+      printf("loop_routing_start:\n");
+      printf("    reach =                       ");
+      print_block_set(reach);
+      printf("    outside =                     ");
+      print_block_set(outside);
+      printf("    loop_path.reachable =         ");
+      print_block_set(loop_path.reachable);
+      printf("    routing->outside =            ");
+      print_block_set(routing->outside);
+      printf("    routing->regular.reachable =  ");
+      print_block_set(routing->regular.reachable);
+      printf("    routing->brk.reachable =      ");
+      print_block_set(routing->brk.reachable);
+      printf("    routing->cont.reachable =     ");
+      print_block_set(routing->cont.reachable);
+      printf("\n");
+   }
+
    struct routes *routing_backup = ralloc(mem_ctx, struct routes);
    *routing_backup = *routing;
    bool break_needed = false;
@@ -253,6 +330,12 @@ loop_routing_start(struct routes *routing, nir_builder *b,
       continue_needed = true;
    }
 
+   if (outside && outside->entries) {
+      routing->outside = _mesa_set_clone(routing->outside, routing);
+      set_foreach(outside, entry)
+         _mesa_set_add_pre_hashed(routing->outside, entry->hash, entry->key);
+   }
+
    routing->brk = routing_backup->regular;
    routing->cont = loop_path;
    routing->regular = loop_path;
@@ -364,6 +447,20 @@ inside_outside(nir_block *block, struct set *loop_heads, struct set *outside,
          _mesa_set_add(remaining, block->dom_children[i]);
    }
 
+
+   if (NIR_LOWER_GOTO_IFS_DEBUG) {
+      printf("inside_outside(%u):\n", block->index);
+      printf("    loop_heads = ");
+      print_block_set(loop_heads);
+      printf("    reach =      ");
+      print_block_set(reach);
+      printf("    brk_reach =  ");
+      print_block_set(brk_reachable);
+      printf("    remaining =  ");
+      print_block_set(remaining);
+      printf("\n");
+   }
+
    bool progress = true;
    while (remaining->entries && progress) {
       progress = false;
@@ -409,6 +506,43 @@ inside_outside(nir_block *block, struct set *loop_heads, struct set *outside,
          _mesa_set_add(reach, block->successors[i]);
       }
    }
+
+   if (NIR_LOWER_GOTO_IFS_DEBUG) {
+      printf("outside(%u) = ", block->index);
+      print_block_set(outside);
+      printf("reach(%u) =   ", block->index);
+      print_block_set(reach);
+   }
+}
+
+static struct path_fork *
+select_fork_recur(struct nir_block **blocks, unsigned start, unsigned end,
+                  nir_function_impl *impl, bool need_var, void *mem_ctx)
+{
+   if (start == end - 1) /* one block in [start, end): leaf, no fork */
+      return NULL;
+
+   struct path_fork *fork = ralloc(mem_ctx, struct path_fork);
+   fork->is_var = need_var;
+   if (need_var)
+      fork->path_var = nir_local_variable_create(impl, glsl_bool_type(),
+                                                 "path_select");
+
+   unsigned mid = start + (end - start) / 2; /* split the range in half */
+
+   fork->paths[0].reachable = _mesa_pointer_set_create(fork);
+   for (unsigned i = start; i < mid; i++) /* paths[0] gets [start, mid) */
+      _mesa_set_add(fork->paths[0].reachable, blocks[i]);
+   fork->paths[0].fork =
+      select_fork_recur(blocks, start, mid, impl, need_var, mem_ctx);
+
+   fork->paths[1].reachable = _mesa_pointer_set_create(fork);
+   for (unsigned i = mid; i < end; i++) /* paths[1] gets [mid, end) */
+      _mesa_set_add(fork->paths[1].reachable, blocks[i]);
+   fork->paths[1].fork =
+      select_fork_recur(blocks, mid, end, impl, need_var, mem_ctx);
+
+   return fork;
+}
 
 /**
@@ -422,31 +556,15 @@ static struct path_fork *
 select_fork(struct set *reachable, nir_function_impl *impl, bool need_var,
             void *mem_ctx)
 {
-   struct path_fork *fork = NULL;
-   if (reachable->entries > 1) {
-      fork = ralloc(mem_ctx, struct path_fork);
-      fork->is_var = need_var;
-      if (need_var)
-         fork->path_var = nir_local_variable_create(impl, glsl_bool_type(),
-                                                    "path_select");
-      fork->paths[0].reachable = _mesa_pointer_set_create(fork);
-      struct set_entry *entry = NULL;
-      while (fork->paths[0].reachable->entries < reachable->entries / 2 &&
-             (entry = _mesa_set_next_entry(reachable, entry))) {
-         _mesa_set_add_pre_hashed(fork->paths[0].reachable,
-                                  entry->hash, entry->key);
-      }
-      fork->paths[0].fork = select_fork(fork->paths[0].reachable, impl,
-                                        need_var, mem_ctx);
-      fork->paths[1].reachable = _mesa_pointer_set_create(fork);
-      while ((entry = _mesa_set_next_entry(reachable, entry))) {
-         _mesa_set_add_pre_hashed(fork->paths[1].reachable,
-                                  entry->hash, entry->key);
-      }
-      fork->paths[1].fork = select_fork(fork->paths[1].reachable, impl,
-                                        need_var, mem_ctx);
-   }
-   return fork;
+   assert(reachable->entries > 0);
+   if (reachable->entries <= 1)
+      return NULL;
+
+   /* Hash set ordering is non-deterministic.  We're about to turn a set into
+    * a tree so we really want things to be in a deterministic ordering.
+    */
+   return select_fork_recur(sorted_block_arr_for_set(reachable, mem_ctx),
+                            0, reachable->entries, impl, need_var, mem_ctx);
 }
 
 /**
@@ -480,23 +598,21 @@ handle_irreducible(struct set *remaining, struct strct_lvl *curr_level,
    struct set *old_candidates = _mesa_pointer_set_create(mem_ctx);
    while (candidate) {
       _mesa_set_add(old_candidates, candidate);
-      nir_block *to_be_added = candidate;
-      candidate = NULL;
 
+      /* Start with just the candidate block */
       _mesa_set_clear(curr_level->blocks, NULL);
-      while (to_be_added) {
-         _mesa_set_add(curr_level->blocks, to_be_added);
-         to_be_added = NULL;
-
-         set_foreach(remaining, entry) {
-            nir_block *remaining_block = (nir_block *) entry->key;
-            if (!_mesa_set_search(curr_level->blocks, remaining_block)
-                && _mesa_set_intersects(remaining_block->dom_frontier,
-                                        curr_level->blocks)) {
-               if (_mesa_set_search(old_candidates, remaining_block))
-                  to_be_added = remaining_block;
-               else
-                  candidate = remaining_block;
+      _mesa_set_add(curr_level->blocks, candidate);
+
+      candidate = NULL;
+      set_foreach(remaining, entry) {
+         nir_block *remaining_block = (nir_block *) entry->key;
+         if (!_mesa_set_search(curr_level->blocks, remaining_block) &&
+             _mesa_set_intersects(remaining_block->dom_frontier,
+                                  curr_level->blocks)) {
+            if (_mesa_set_search(old_candidates, remaining_block)) {
+               _mesa_set_add(curr_level->blocks, remaining_block);
+            } else {
+               candidate = remaining_block;
                break;
             }
          }
@@ -512,6 +628,7 @@ handle_irreducible(struct set *remaining, struct strct_lvl *curr_level,
       inside_outside((nir_block *) entry->key, loop_heads, remaining,
                      curr_level->reach, brk_reachable, mem_ctx);
    }
+   curr_level->outside = remaining;
    _mesa_set_destroy(loop_heads, NULL);
 }
 
@@ -551,17 +668,28 @@ handle_irreducible(struct set *remaining, struct strct_lvl *curr_level,
  *                       zeroth level
  */
 static void
-organize_levels(struct exec_list *levels, struct set *remaining,
+organize_levels(struct list_head *levels, struct set *children,
                 struct set *reach, struct routes *routing,
                 nir_function_impl *impl, bool is_domminated, void *mem_ctx)
 {
+   if (NIR_LOWER_GOTO_IFS_DEBUG) {
+      printf("organize_levels:\n");
+      printf("    children = ");
+      print_block_set(children);
+      printf("    reach =     ");
+      print_block_set(reach);
+   }
+
+   /* Work on a copy of children; the loop below consumes the remaining set */
+   struct set *remaining = _mesa_set_clone(children, mem_ctx);
+
    /* blocks that can be reached by the remaining blocks */
    struct set *remaining_frontier = _mesa_pointer_set_create(mem_ctx);
 
    /* targets of active skip path */
    struct set *skip_targets = _mesa_pointer_set_create(mem_ctx);
 
-   exec_list_make_empty(levels);
+   list_inithead(levels);
    while (remaining->entries) {
       _mesa_set_clear(remaining_frontier, NULL);
       set_foreach(remaining, entry) {
@@ -590,23 +718,20 @@ organize_levels(struct exec_list *levels, struct set *remaining,
                             routing->brk.reachable, mem_ctx);
       }
       assert(curr_level->blocks->entries);
-      curr_level->skip_start = 0;
 
       struct strct_lvl *prev_level = NULL;
-      struct exec_node *tail;
-      if ((tail = exec_list_get_tail(levels)))
-         prev_level = exec_node_data(struct strct_lvl, tail, node);
-
-      if (skip_targets->entries) {
-         set_foreach(skip_targets, entry) {
-            if (_mesa_set_search_pre_hashed(curr_level->blocks,
-                                            entry->hash, entry->key)) {
-               _mesa_set_remove(skip_targets, entry);
-               prev_level->skip_end = 1;
-               curr_level->skip_start = !!skip_targets->entries;
-            }
+      if (!list_is_empty(levels))
+         prev_level = list_last_entry(levels, struct strct_lvl, link);
+
+      set_foreach(skip_targets, entry) {
+         if (_mesa_set_search_pre_hashed(curr_level->blocks,
+                                         entry->hash, entry->key)) {
+            _mesa_set_remove(skip_targets, entry);
+            prev_level->skip_end = 1;
          }
       }
+      curr_level->skip_start = skip_targets->entries != 0;
+
       struct set *prev_frontier = NULL;
       if (!prev_level) {
          prev_frontier = reach;
@@ -615,9 +740,16 @@ organize_levels(struct exec_list *levels, struct set *remaining,
       } else {
          set_foreach(curr_level->blocks, blocks_entry) {
             nir_block *level_block = (nir_block *) blocks_entry->key;
-            if (!prev_frontier) {
-               prev_frontier = curr_level->blocks->entries == 1 ?
-                  level_block->dom_frontier :
+            if (curr_level->blocks->entries == 1) {
+               /* If we only have one block, there's no union operation and we
+                * can just use the one from the one block.
+                */
+               prev_frontier = level_block->dom_frontier;
+               break;
+            }
+
+            if (prev_frontier == NULL) {
+               prev_frontier =
                   _mesa_set_clone(level_block->dom_frontier, prev_level);
             } else {
                set_foreach(level_block->dom_frontier, entry)
@@ -641,19 +773,26 @@ organize_levels(struct exec_list *levels, struct set *remaining,
       }
 
       curr_level->skip_end = 0;
-      exec_list_push_tail(levels, &curr_level->node);
+      list_addtail(&curr_level->link, levels);
+   }
+
+   if (NIR_LOWER_GOTO_IFS_DEBUG) {
+      printf("    levels:\n");
+      list_for_each_entry(struct strct_lvl, level, levels, link) {
+         printf("        ");
+         print_block_set(level->blocks);
+      }
+      printf("\n");
    }
 
    if (skip_targets->entries)
-      exec_node_data(struct strct_lvl, exec_list_get_tail(levels), node)
-      ->skip_end = 1;
+      list_last_entry(levels, struct strct_lvl, link)->skip_end = 1;
 
    /* Iterate throught all levels reverse and create all the paths and forks */
    struct path path_after_skip;
 
-   foreach_list_typed_reverse(struct strct_lvl, level, node, levels) {
-      bool need_var = !(is_domminated && exec_node_get_prev(&level->node)
-                                         == &levels->head_sentinel);
+   list_for_each_entry_rev(struct strct_lvl, level, levels, link) {
+      bool need_var = !(is_domminated && level->link.prev == levels);
       level->out_path = routing->regular;
       if (level->skip_end) {
          path_after_skip = routing->regular;
@@ -688,9 +827,8 @@ select_blocks(struct routes *routing, nir_builder *b,
               struct path in_path, void *mem_ctx)
 {
    if (!in_path.fork) {
-      nir_structurize(routing, b, (nir_block *)
-                      _mesa_set_next_entry(in_path.reachable, NULL)->key,
-                      mem_ctx);
+      nir_block *block = block_for_singular_set(in_path.reachable);
+      nir_structurize(routing, b, block, mem_ctx);
    } else {
       assert(!(in_path.fork->is_var &&
                strcmp(in_path.fork->path_var->name, "path_select")));
@@ -706,15 +844,12 @@ select_blocks(struct routes *routing, nir_builder *b,
  * Builds the structurized nir code by the final level list.
  */
 static void
-plant_levels(struct exec_list *levels, struct routes *routing,
+plant_levels(struct list_head *levels, struct routes *routing,
              nir_builder *b, void *mem_ctx)
 {
    /* Place all dominated blocks and build the path forks */
-   struct exec_node *list_node;
-   while ((list_node = exec_list_pop_head(levels))) {
-      struct strct_lvl *curr_level =
-         exec_node_data(struct strct_lvl, list_node, node);
-      if (curr_level->skip_start) {
+   list_for_each_entry(struct strct_lvl, level, levels, link) {
+      if (level->skip_start) {
          assert(routing->regular.fork);
          assert(!(routing->regular.fork->is_var && strcmp(
              routing->regular.fork->path_var->name, "path_conditional")));
@@ -723,13 +858,15 @@ plant_levels(struct exec_list *levels, struct routes *routing,
          routing->regular = routing->regular.fork->paths[1];
       }
       struct path in_path = routing->regular;
-      routing->regular = curr_level->out_path;
-      if (curr_level->irreducible)
-         loop_routing_start(routing, b, in_path, curr_level->reach, mem_ctx);
+      routing->regular = level->out_path;
+      if (level->irreducible) {
+         loop_routing_start(routing, b, in_path, level->reach,
+                            level->outside, mem_ctx);
+      }
       select_blocks(routing, b, in_path, mem_ctx);
-      if (curr_level->irreducible)
+      if (level->irreducible)
          loop_routing_end(routing, b);
-      if (curr_level->skip_end)
+      if (level->skip_end)
          nir_pop_if(b, NULL);
    }
 }
@@ -744,13 +881,13 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block,
 {
    struct set *remaining = _mesa_pointer_set_create(mem_ctx);
    for (int i = 0; i < block->num_dom_children; i++) {
-      if (!_mesa_set_search(routing->brk.reachable, block->dom_children[i]))
+      if (!_mesa_set_search(routing->outside, block->dom_children[i]))
          _mesa_set_add(remaining, block->dom_children[i]);
    }
 
    /* If the block can reach back to itself, it is a loop head */
    int is_looped = _mesa_set_search(block->dom_frontier, block) != NULL;
-   struct exec_list outside_levels;
+   struct list_head outside_levels;
    if (is_looped) {
       struct set *loop_heads = _mesa_pointer_set_create(mem_ctx);
       _mesa_set_add(loop_heads, block);
@@ -772,7 +909,7 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block,
       };
       _mesa_set_add(loop_path.reachable, block);
 
-      loop_routing_start(routing, b, loop_path, reach, mem_ctx);
+      loop_routing_start(routing, b, loop_path, reach, outside, mem_ctx);
    }
 
    struct set *reach = _mesa_pointer_set_create(mem_ctx);
@@ -781,7 +918,7 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block,
    if (block->successors[1] && block->successors[1]->successors[0])
       _mesa_set_add(reach, block->successors[1]);
 
-   struct exec_list levels;
+   struct list_head levels;
    organize_levels(&levels, remaining, reach, routing, b->impl, true, mem_ctx);
 
    /* Push all instructions of this block, without the jump instr */
@@ -843,6 +980,7 @@ nir_lower_goto_ifs_impl(nir_function_impl *impl)
 
    struct routes *routing = ralloc(mem_ctx, struct routes);
    *routing = (struct routes) {
+      .outside = empty_set,
       .regular.reachable = end_set,
       .brk.reachable = empty_set,
       .cont.reachable = empty_set,