nir/lower_goto_if: Rework some set union logic
[mesa.git] / src / compiler / nir / nir_lower_goto_ifs.c
index 956371f5a1ebd583e25f30bd42d6edbf9cd59069..ac202950fb63067a7f21cf6a19cc851035352e47 100644 (file)
 #include "nir_builder.h"
 
 struct path {
+   /** Set of blocks which this path represents
+    *
+    * It's "reachable" not in the sense that these are all the nodes reachable
+    * through this path but in the sense that, when you see one of these
+    * blocks, you know you've reached this path.
+    */
    struct set *reachable;
+
+   /** Fork in the path, if reachable->entries > 1 */
    struct path_fork *fork;
 };
 
@@ -46,12 +54,24 @@ struct routes {
 };
 
 struct strct_lvl {
-   struct exec_node node;
+   struct list_head link;
+
+   /** Set of blocks at the current level */
    struct set *blocks;
+
+   /** Path for the next level */
    struct path out_path;
+
+   /** Reach set from inside_outside if irreducible */
    struct set *reach;
+
+   /** True if a skip region starts with this level */
    bool skip_start;
+
+   /** True if a skip region ends with this level */
    bool skip_end;
+
+   /** True if this level is irreducible */
    bool irreducible;
 };
 
@@ -212,9 +232,10 @@ fork_reachable(struct path_fork *fork)
  */
 static void
 loop_routing_start(struct routes *routing, nir_builder *b,
-                   struct path loop_path, struct set *reach)
+                   struct path loop_path, struct set *reach,
+                   void *mem_ctx)
 {
-   struct routes *routing_backup = ralloc(routing, struct routes);
+   struct routes *routing_backup = ralloc(mem_ctx, struct routes);
    *routing_backup = *routing;
    bool break_needed = false;
    bool continue_needed = false;
@@ -238,7 +259,7 @@ loop_routing_start(struct routes *routing, nir_builder *b,
    routing->loop_backup = routing_backup;
 
    if (break_needed) {
-      struct path_fork *fork = ralloc(routing_backup, struct path_fork);
+      struct path_fork *fork = ralloc(mem_ctx, struct path_fork);
       fork->is_var = true;
       fork->path_var = nir_local_variable_create(b->impl, glsl_bool_type(),
                                                  "path_break");
@@ -248,7 +269,7 @@ loop_routing_start(struct routes *routing, nir_builder *b,
       routing->brk.reachable = fork_reachable(fork);
    }
    if (continue_needed) {
-      struct path_fork *fork = ralloc(routing_backup, struct path_fork);
+      struct path_fork *fork = ralloc(mem_ctx, struct path_fork);
       fork->is_var = true;
       fork->path_var = nir_local_variable_create(b->impl, glsl_bool_type(),
                                                  "path_continue");
@@ -334,10 +355,10 @@ loop_routing_end(struct routes *routing, nir_builder *b)
  */
 static void
 inside_outside(nir_block *block, struct set *loop_heads, struct set *outside,
-               struct set *reach, struct set *brk_reachable)
+               struct set *reach, struct set *brk_reachable, void *mem_ctx)
 {
    assert(_mesa_set_search(loop_heads, block));
-   struct set *remaining = _mesa_pointer_set_create(NULL);
+   struct set *remaining = _mesa_pointer_set_create(mem_ctx);
    for (int i = 0; i < block->num_dom_children; i++) {
       if (!_mesa_set_search(brk_reachable, block->dom_children[i]))
          _mesa_set_add(remaining, block->dom_children[i]);
@@ -379,12 +400,9 @@ inside_outside(nir_block *block, struct set *loop_heads, struct set *outside,
    /* Recurse for each remaining */
    set_foreach(remaining, entry) {
       inside_outside((nir_block *) entry->key, loop_heads, outside, reach,
-                     brk_reachable);
+                     brk_reachable, mem_ctx);
    }
 
-   _mesa_set_destroy(remaining, NULL);
-   remaining = NULL;
-
    for (int i = 0; i < 2; i++) {
       if (block->successors[i] && block->successors[i]->successors[0] &&
           !_mesa_set_search(loop_heads, block->successors[i])) {
@@ -401,11 +419,12 @@ inside_outside(nir_block *block, struct set *loop_heads, struct set *outside,
  * then the function calls itself recursively
  */
 static struct path_fork *
-select_fork(struct set *reachable, nir_function_impl *impl, bool need_var)
+select_fork(struct set *reachable, nir_function_impl *impl, bool need_var,
+            void *mem_ctx)
 {
    struct path_fork *fork = NULL;
    if (reachable->entries > 1) {
-      fork = ralloc(reachable, struct path_fork);
+      fork = ralloc(mem_ctx, struct path_fork);
       fork->is_var = need_var;
       if (need_var)
          fork->path_var = nir_local_variable_create(impl, glsl_bool_type(),
@@ -418,14 +437,14 @@ select_fork(struct set *reachable, nir_function_impl *impl, bool need_var)
                                   entry->hash, entry->key);
       }
       fork->paths[0].fork = select_fork(fork->paths[0].reachable, impl,
-                                         need_var);
+                                        need_var, mem_ctx);
       fork->paths[1].reachable = _mesa_pointer_set_create(fork);
       while ((entry = _mesa_set_next_entry(reachable, entry))) {
          _mesa_set_add_pre_hashed(fork->paths[1].reachable,
                                   entry->hash, entry->key);
       }
       fork->paths[1].fork = select_fork(fork->paths[1].reachable, impl,
-                                         need_var);
+                                        need_var, mem_ctx);
    }
    return fork;
 }
@@ -454,10 +473,11 @@ select_fork(struct set *reachable, nir_function_impl *impl, bool need_var)
  */
 static void
 handle_irreducible(struct set *remaining, struct strct_lvl *curr_level,
-                   struct set *brk_reachable) {
+                   struct set *brk_reachable, void *mem_ctx)
+{
    nir_block *candidate = (nir_block *)
       _mesa_set_next_entry(remaining, NULL)->key;
-   struct set *old_candidates = _mesa_pointer_set_create(curr_level);
+   struct set *old_candidates = _mesa_pointer_set_create(mem_ctx);
    while (candidate) {
       _mesa_set_add(old_candidates, candidate);
       nir_block *to_be_added = candidate;
@@ -490,7 +510,7 @@ handle_irreducible(struct set *remaining, struct strct_lvl *curr_level,
    set_foreach(curr_level->blocks, entry) {
       _mesa_set_remove_key(remaining, entry->key);
       inside_outside((nir_block *) entry->key, loop_heads, remaining,
-                     curr_level->reach, brk_reachable);
+                     curr_level->reach, brk_reachable, mem_ctx);
    }
    _mesa_set_destroy(loop_heads, NULL);
 }
@@ -531,19 +551,17 @@ handle_irreducible(struct set *remaining, struct strct_lvl *curr_level,
  *                       zeroth level
  */
 static void
-organize_levels(struct exec_list *levels, struct set *remaining,
+organize_levels(struct list_head *levels, struct set *remaining,
                 struct set *reach, struct routes *routing,
-                nir_function_impl *impl, bool is_domminated)
+                nir_function_impl *impl, bool is_domminated, void *mem_ctx)
 {
-   void *mem_ctx = ralloc_parent(remaining);
-
    /* blocks that can be reached by the remaining blocks */
    struct set *remaining_frontier = _mesa_pointer_set_create(mem_ctx);
 
    /* targets of active skip path */
    struct set *skip_targets = _mesa_pointer_set_create(mem_ctx);
 
-   exec_list_make_empty(levels);
+   list_inithead(levels);
    while (remaining->entries) {
       _mesa_set_clear(remaining_frontier, NULL);
       set_foreach(remaining, entry) {
@@ -556,7 +574,7 @@ organize_levels(struct exec_list *levels, struct set *remaining,
          }
       }
 
-      struct strct_lvl *curr_level = ralloc(mem_ctx, struct strct_lvl);
+      struct strct_lvl *curr_level = rzalloc(mem_ctx, struct strct_lvl);
       curr_level->blocks = _mesa_pointer_set_create(curr_level);
       set_foreach(remaining, entry) {
          nir_block *candidate = (nir_block *) entry->key;
@@ -567,26 +585,25 @@ organize_levels(struct exec_list *levels, struct set *remaining,
       }
 
       curr_level->irreducible = !curr_level->blocks->entries;
-      if (curr_level->irreducible)
-         handle_irreducible(remaining, curr_level, routing->brk.reachable);
+      if (curr_level->irreducible) {
+         handle_irreducible(remaining, curr_level,
+                            routing->brk.reachable, mem_ctx);
+      }
       assert(curr_level->blocks->entries);
-      curr_level->skip_start = 0;
 
       struct strct_lvl *prev_level = NULL;
-      struct exec_node *tail;
-      if ((tail = exec_list_get_tail(levels)))
-         prev_level = exec_node_data(struct strct_lvl, tail, node);
-
-      if (skip_targets->entries) {
-         set_foreach(skip_targets, entry) {
-            if (_mesa_set_search_pre_hashed(curr_level->blocks,
-                                            entry->hash, entry->key)) {
-               _mesa_set_remove(skip_targets, entry);
-               prev_level->skip_end = 1;
-               curr_level->skip_start = !!skip_targets->entries;
-            }
+      if (!list_is_empty(levels))
+         prev_level = list_last_entry(levels, struct strct_lvl, link);
+
+      set_foreach(skip_targets, entry) {
+         if (_mesa_set_search_pre_hashed(curr_level->blocks,
+                                         entry->hash, entry->key)) {
+            _mesa_set_remove(skip_targets, entry);
+            prev_level->skip_end = 1;
          }
       }
+      curr_level->skip_start = skip_targets->entries != 0;
+
       struct set *prev_frontier = NULL;
       if (!prev_level) {
          prev_frontier = reach;
@@ -595,9 +612,16 @@ organize_levels(struct exec_list *levels, struct set *remaining,
       } else {
          set_foreach(curr_level->blocks, blocks_entry) {
             nir_block *level_block = (nir_block *) blocks_entry->key;
-            if (!prev_frontier) {
-               prev_frontier = curr_level->blocks->entries == 1 ?
-                  level_block->dom_frontier :
+            if (curr_level->blocks->entries == 1) {
+               /* If we only have one block, there's no union operation and we
+                * can just use the one from the one block.
+                */
+               prev_frontier = level_block->dom_frontier;
+               break;
+            }
+
+            if (prev_frontier == NULL) {
+               prev_frontier =
                   _mesa_set_clone(level_block->dom_frontier, prev_level);
             } else {
                set_foreach(level_block->dom_frontier, entry)
@@ -621,32 +645,26 @@ organize_levels(struct exec_list *levels, struct set *remaining,
       }
 
       curr_level->skip_end = 0;
-      exec_list_push_tail(levels, &curr_level->node);
+      list_addtail(&curr_level->link, levels);
    }
 
    if (skip_targets->entries)
-      exec_node_data(struct strct_lvl, exec_list_get_tail(levels), node)
-      ->skip_end = 1;
-   _mesa_set_destroy(remaining_frontier, NULL);
-   remaining_frontier = NULL;
-   _mesa_set_destroy(skip_targets, NULL);
-   skip_targets = NULL;
+      list_last_entry(levels, struct strct_lvl, link)->skip_end = 1;
 
    /* Iterate throught all levels reverse and create all the paths and forks */
    struct path path_after_skip;
 
-   foreach_list_typed_reverse(struct strct_lvl, level, node, levels) {
-      bool need_var = !(is_domminated && exec_node_get_prev(&level->node)
-                                         == &levels->head_sentinel);
+   list_for_each_entry_rev(struct strct_lvl, level, levels, link) {
+      bool need_var = !(is_domminated && level->link.prev == levels);
       level->out_path = routing->regular;
       if (level->skip_end) {
          path_after_skip = routing->regular;
       }
       routing->regular.reachable = level->blocks;
       routing->regular.fork = select_fork(routing->regular.reachable, impl,
-                                          need_var);
+                                          need_var, mem_ctx);
       if (level->skip_start) {
-         struct path_fork *fork = ralloc(level, struct path_fork);
+         struct path_fork *fork = ralloc(mem_ctx, struct path_fork);
          fork->is_var = need_var;
          if (need_var)
             fork->path_var = nir_local_variable_create(impl, glsl_bool_type(),
@@ -660,24 +678,28 @@ organize_levels(struct exec_list *levels, struct set *remaining,
 }
 
 static void
-nir_structurize(struct routes *routing, nir_builder *b, nir_block *block);
+nir_structurize(struct routes *routing, nir_builder *b,
+                nir_block *block, void *mem_ctx);
 
 /**
  * Places all the if else statements to select between all blocks in a select
  * path
  */
 static void
-select_blocks(struct routes *routing, nir_builder *b, struct path in_path) {
+select_blocks(struct routes *routing, nir_builder *b,
+              struct path in_path, void *mem_ctx)
+{
    if (!in_path.fork) {
       nir_structurize(routing, b, (nir_block *)
-                      _mesa_set_next_entry(in_path.reachable, NULL)->key);
+                      _mesa_set_next_entry(in_path.reachable, NULL)->key,
+                      mem_ctx);
    } else {
       assert(!(in_path.fork->is_var &&
                strcmp(in_path.fork->path_var->name, "path_select")));
       nir_push_if_src(b, nir_src_for_ssa(fork_condition(b, in_path.fork)));
-      select_blocks(routing, b, in_path.fork->paths[1]);
+      select_blocks(routing, b, in_path.fork->paths[1], mem_ctx);
       nir_push_else(b, NULL);
-      select_blocks(routing, b, in_path.fork->paths[0]);
+      select_blocks(routing, b, in_path.fork->paths[0], mem_ctx);
       nir_pop_if(b, NULL);
    }
 }
@@ -686,15 +708,12 @@ select_blocks(struct routes *routing, nir_builder *b, struct path in_path) {
  * Builds the structurized nir code by the final level list.
  */
 static void
-plant_levels(struct exec_list *levels, struct routes *routing,
-             nir_builder *b)
+plant_levels(struct list_head *levels, struct routes *routing,
+             nir_builder *b, void *mem_ctx)
 {
    /* Place all dominated blocks and build the path forks */
-   struct exec_node *list_node;
-   while ((list_node = exec_list_pop_head(levels))) {
-      struct strct_lvl *curr_level =
-         exec_node_data(struct strct_lvl, list_node, node);
-      if (curr_level->skip_start) {
+   list_for_each_entry(struct strct_lvl, level, levels, link) {
+      if (level->skip_start) {
          assert(routing->regular.fork);
          assert(!(routing->regular.fork->is_var && strcmp(
              routing->regular.fork->path_var->name, "path_conditional")));
@@ -703,15 +722,14 @@ plant_levels(struct exec_list *levels, struct routes *routing,
          routing->regular = routing->regular.fork->paths[1];
       }
       struct path in_path = routing->regular;
-      routing->regular = curr_level->out_path;
-      if (curr_level->irreducible)
-         loop_routing_start(routing, b, in_path, curr_level->reach);
-      select_blocks(routing, b, in_path);
-      if (curr_level->irreducible)
+      routing->regular = level->out_path;
+      if (level->irreducible)
+         loop_routing_start(routing, b, in_path, level->reach, mem_ctx);
+      select_blocks(routing, b, in_path, mem_ctx);
+      if (level->irreducible)
          loop_routing_end(routing, b);
-      if (curr_level->skip_end)
+      if (level->skip_end)
          nir_pop_if(b, NULL);
-      ralloc_free(curr_level);
    }
 }
 
@@ -720,11 +738,9 @@ plant_levels(struct exec_list *levels, struct routes *routing,
  * \param  routing  the routing after the block and all dominated blocks
  */
 static void
-nir_structurize(struct routes *routing, nir_builder *b, nir_block *block)
+nir_structurize(struct routes *routing, nir_builder *b, nir_block *block,
+                void *mem_ctx)
 {
-   /* Mem context for this function; freed at the end */
-   void *mem_ctx = ralloc_context(routing);
-
    struct set *remaining = _mesa_pointer_set_create(mem_ctx);
    for (int i = 0; i < block->num_dom_children; i++) {
       if (!_mesa_set_search(routing->brk.reachable, block->dom_children[i]))
@@ -733,7 +749,7 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block)
 
    /* If the block can reach back to itself, it is a loop head */
    int is_looped = _mesa_set_search(block->dom_frontier, block) != NULL;
-   struct exec_list outside_levels;
+   struct list_head outside_levels;
    if (is_looped) {
       struct set *loop_heads = _mesa_pointer_set_create(mem_ctx);
       _mesa_set_add(loop_heads, block);
@@ -741,16 +757,13 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block)
       struct set *outside = _mesa_pointer_set_create(mem_ctx);
       struct set *reach = _mesa_pointer_set_create(mem_ctx);
       inside_outside(block, loop_heads, outside, reach,
-                     routing->brk.reachable);
-
-      _mesa_set_destroy(loop_heads, NULL);
-      loop_heads = NULL;
+                     routing->brk.reachable, mem_ctx);
 
       set_foreach(outside, entry)
          _mesa_set_remove_key(remaining, entry->key);
 
       organize_levels(&outside_levels, outside, reach, routing, b->impl,
-                      false);
+                      false, mem_ctx);
 
       struct path loop_path = {
          .reachable = _mesa_pointer_set_create(mem_ctx),
@@ -758,9 +771,7 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block)
       };
       _mesa_set_add(loop_path.reachable, block);
 
-      loop_routing_start(routing, b, loop_path, reach);
-      _mesa_set_destroy(reach, NULL);
-      reach = NULL;
+      loop_routing_start(routing, b, loop_path, reach, mem_ctx);
    }
 
    struct set *reach = _mesa_pointer_set_create(mem_ctx);
@@ -769,12 +780,8 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block)
    if (block->successors[1] && block->successors[1]->successors[0])
       _mesa_set_add(reach, block->successors[1]);
 
-   struct exec_list levels;
-   organize_levels(&levels, remaining, reach, routing, b->impl, true);
-   _mesa_set_destroy(remaining, NULL);
-   remaining = NULL;
-   _mesa_set_destroy(reach, NULL);
-   reach = NULL;
+   struct list_head levels;
+   organize_levels(&levels, remaining, reach, routing, b->impl, true, mem_ctx);
 
    /* Push all instructions of this block, without the jump instr */
    nir_jump_instr *jump_instr = NULL;
@@ -795,13 +802,11 @@ nir_structurize(struct routes *routing, nir_builder *b, nir_block *block)
       route_to(b, routing, block->successors[0]);
    }
 
-   plant_levels(&levels, routing, b);
+   plant_levels(&levels, routing, b, mem_ctx);
    if (is_looped) {
       loop_routing_end(routing, b);
-      plant_levels(&outside_levels, routing, b);
+      plant_levels(&outside_levels, routing, b, mem_ctx);
    }
-
-   ralloc_free(mem_ctx);
 }
 
 static bool
@@ -825,28 +830,30 @@ nir_lower_goto_ifs_impl(nir_function_impl *impl)
    nir_builder_init(&b, impl);
    b.cursor = nir_before_block(nir_start_block(impl));
 
-   struct routes *routing = ralloc(b.shader, struct routes);
-   routing->regular.reachable = _mesa_pointer_set_create(routing);
-   _mesa_set_add(routing->regular.reachable, impl->end_block);
-   struct set *empty = _mesa_pointer_set_create(routing);
-   routing->regular.fork = NULL;
-   routing->brk.reachable = empty;
-   routing->brk.fork = NULL;
-   routing->cont.reachable = empty;
-   routing->cont.fork = NULL;
-   nir_structurize(routing, &b,
-                   nir_cf_node_as_block(exec_node_data
-                                        (nir_cf_node,
-                                         exec_list_get_head(&cf_list.list),
-                                         node)));
+   void *mem_ctx = ralloc_context(b.shader);
+
+   struct set *end_set = _mesa_pointer_set_create(mem_ctx);
+   _mesa_set_add(end_set, impl->end_block);
+   struct set *empty_set = _mesa_pointer_set_create(mem_ctx);
+
+   nir_cf_node *start_node =
+      exec_node_data(nir_cf_node, exec_list_get_head(&cf_list.list), node);
+   nir_block *start_block = nir_cf_node_as_block(start_node);
+
+   struct routes *routing = ralloc(mem_ctx, struct routes);
+   *routing = (struct routes) {
+      .regular.reachable = end_set,
+      .brk.reachable = empty_set,
+      .cont.reachable = empty_set,
+   };
+   nir_structurize(routing, &b, start_block, mem_ctx);
    assert(routing->regular.fork == NULL);
    assert(routing->brk.fork == NULL);
    assert(routing->cont.fork == NULL);
-   assert(routing->brk.reachable == empty);
-   assert(routing->cont.reachable == empty);
-   _mesa_set_destroy(routing->regular.reachable, NULL);
-   _mesa_set_destroy(empty, NULL);
-   ralloc_free(routing);
+   assert(routing->brk.reachable == empty_set);
+   assert(routing->cont.reachable == empty_set);
+
+   ralloc_free(mem_ctx);
    nir_cf_delete(&cf_list);
 
    nir_metadata_preserve(impl, nir_metadata_none);