nir: Make algebraic backtrack and reprocess after a replacement.
authorEric Anholt <eric@anholt.net>
Wed, 2 Oct 2019 17:59:13 +0000 (10:59 -0700)
committerEric Anholt <eric@anholt.net>
Tue, 26 Nov 2019 18:13:46 +0000 (10:13 -0800)
The algebraic pass was exhibiting O(n^2) behavior in
dEQP-GLES2.functional.uniform_api.random.3 and
dEQP-GLES31.functional.ubo.random.all_per_block_buffers.13 (along with
other code-generated tests, and likely real-world loop-unroll cases).
In the process of using fmul(b2f(x), b2f(x)) -> b2f(iand(x, y)) to
transform:

result = b2f(a == b);
result *= b2f(c == d);
...
result *= b2f(z == w);

->

temp = (a == b)
temp = temp && (c == d)
...
temp = temp && (z == w)
result = b2f(temp);

nir_opt_algebraic, proceeding bottom-to-top, would match and convert
the top-most fmul(b2f(), b2f()) case each time, leaving the new b2f to
be matched by the next fmul down on the next time algebraic got run by
the optimization loop.

Back in 2016 in 7be8d0773229 ("nir: Do opt_algebraic in reverse
order."), Matt changed algebraic to go bottom-to-top so that we would
match the biggest patterns first.  This helped his cases, but I
believe introduced this failure mode.  Instead of reverting that, now
that we've got the automaton, we can update the automaton's state
recursively and just re-process any instructions whose state has
changed (indicating that they might match new things).  There's a
small chance that the state will hash to the same value and miss out
on this round of algebraic, but this seems to be good enough to fix
dEQP.

Effects with NIR_VALIDATE=0 (improvement is better with validation enabled):

Intel shader-db runtime -0.954712% +/- 0.333844% (n=44/46, obvious throttling
  outliers removed)
dEQP-GLES2.functional.uniform_api.random.3 runtime
  -65.3512% +/- 4.22369% (n=21, was 1.4s)
dEQP-GLES31.functional.ubo.random.all_per_block_buffers.13 runtime
  -68.8066% +/- 6.49523% (was 4.8s)

v2: Use two worklists, suggested by @cwabbott, to cut out a bunch of
    tricky code.  Runtime of uniform_api.random.3 down -0.790299% +/-
    0.244213% compred to v1.
v3: Re-add the nir_instr_remove() that I accidentally dropped in v2,
    fixing infinite loops.

Reviewed-by: Connor Abbott <cwabbott0@gmail.com>
src/compiler/nir/nir_search.c
src/compiler/nir/nir_search.h

index 6bb2f35aae83a528680cd2297a1de7a0bb4c16b5..c1b179525abaf9b94337e76698ccb17bccd76d70 100644 (file)
@@ -28,6 +28,7 @@
 #include <inttypes.h>
 #include "nir_search.h"
 #include "nir_builder.h"
+#include "nir_worklist.h"
 #include "util/half_float.h"
 
 /* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
@@ -51,7 +52,7 @@ static bool
 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state);
-static void
+static bool
 nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
                         const struct per_op_table *pass_op_table);
 
@@ -640,13 +641,50 @@ UNUSED static void dump_value(const nir_search_value *val)
       fprintf(stderr, "@%d", val->bit_size);
 }
 
+static void
+add_uses_to_worklist(nir_instr *instr, nir_instr_worklist *worklist)
+{
+   nir_ssa_def *def = nir_instr_ssa_def(instr);
+
+   nir_foreach_use_safe(use_src, def) {
+      nir_instr_worklist_push_tail(worklist, use_src->parent_instr);
+   }
+}
+
+static void
+nir_algebraic_update_automaton(nir_instr *new_instr,
+                               nir_instr_worklist *algebraic_worklist,
+                               struct util_dynarray *states,
+                               const struct per_op_table *pass_op_table)
+{
+
+   nir_instr_worklist *automaton_worklist = nir_instr_worklist_create();
+
+   /* Walk through the tree of uses of our new instruction's SSA value,
+    * recursively updating the automaton state until it stabilizes.
+    */
+   add_uses_to_worklist(new_instr, automaton_worklist);
+
+   nir_instr *instr;
+   while ((instr = nir_instr_worklist_pop_head(automaton_worklist))) {
+      if (nir_algebraic_automaton(instr, states, pass_op_table)) {
+         nir_instr_worklist_push_tail(algebraic_worklist, instr);
+
+         add_uses_to_worklist(instr, automaton_worklist);
+      }
+   }
+
+   nir_instr_worklist_destroy(automaton_worklist);
+}
+
 nir_ssa_def *
 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
                   struct hash_table *range_ht,
                   struct util_dynarray *states,
                   const struct per_op_table *pass_op_table,
                   const nir_search_expression *search,
-                  const nir_search_value *replace)
+                  const nir_search_value *replace,
+                  nir_instr_worklist *algebraic_worklist)
 {
    uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
 
@@ -711,18 +749,23 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
       nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
    }
 
+   /* Rewrite the uses of the old SSA value to the new one, and recurse
+    * through the uses updating the automaton's state.
+    */
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
+   nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist,
+                                  states, pass_op_table);
 
-   /* We know this one has no more uses because we just rewrote them all,
-    * so we can remove it.  The rest of the matched expression, however, we
-    * don't know so much about.  We'll just let dead code clean them up.
+   /* Nothing uses the instr any more, so drop it out of the program.  Note
+    * that the instr may be in the worklist still, so we can't free it
+    * directly.
     */
    nir_instr_remove(&instr->instr);
 
    return ssa_val;
 }
 
-static void
+static bool
 nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
                         const struct per_op_table *pass_op_table)
 {
@@ -733,7 +776,7 @@ nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
       uint16_t search_op = nir_search_op_for_nir_op(op);
       const struct per_op_table *tbl = &pass_op_table[search_op];
       if (tbl->num_filtered_states == 0)
-         return;
+         return false;
 
       /* Calculate the index into the transition table. Note the index
        * calculated must match the iteration order of Python's
@@ -746,20 +789,29 @@ nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
          index += tbl->filter[*util_dynarray_element(states, uint16_t,
                                                      alu->src[i].src.ssa->index)];
       }
-      *util_dynarray_element(states, uint16_t, alu->dest.dest.ssa.index) =
-         tbl->table[index];
-      break;
+
+      uint16_t *state = util_dynarray_element(states, uint16_t,
+                                              alu->dest.dest.ssa.index);
+      if (*state != tbl->table[index]) {
+         *state = tbl->table[index];
+         return true;
+      }
+      return false;
    }
 
    case nir_instr_type_load_const: {
       nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
-      *util_dynarray_element(states, uint16_t, load_const->def.index) =
-         CONST_STATE;
-      break;
+      uint16_t *state = util_dynarray_element(states, uint16_t,
+                                              load_const->def.index);
+      if (*state != CONST_STATE) {
+         *state = CONST_STATE;
+         return true;
+      }
+      return false;
    }
 
    default:
-      break;
+      return false;
    }
 }
 
@@ -770,7 +822,8 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr,
                     const struct transform **transforms,
                     const uint16_t *transform_counts,
                     struct util_dynarray *states,
-                    const struct per_op_table *pass_op_table)
+                    const struct per_op_table *pass_op_table,
+                    nir_instr_worklist *worklist)
 {
 
    if (instr->type != nir_instr_type_alu)
@@ -794,7 +847,7 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr,
       if (condition_flags[xform->condition_offset] &&
           !(xform->search->inexact && ignore_inexact) &&
           nir_replace_instr(build, alu, range_ht, states, pass_op_table,
-                            xform->search, xform->replace)) {
+                            xform->search, xform->replace, worklist)) {
          _mesa_hash_table_clear(range_ht, NULL);
          return true;
       }
@@ -826,21 +879,41 @@ nir_algebraic_impl(nir_function_impl *impl,
 
    struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
 
+   nir_instr_worklist *worklist = nir_instr_worklist_create();
+
+   /* Walk top-to-bottom setting up the automaton state. */
    nir_foreach_block(block, impl) {
       nir_foreach_instr(instr, block) {
          nir_algebraic_automaton(instr, &states, pass_op_table);
       }
    }
 
+   /* Put our instrs in the worklist such that we're popping the last instr
+    * first.  This will encourage us to match the biggest source patterns when
+    * possible.
+    */
    nir_foreach_block_reverse(block, impl) {
-      nir_foreach_instr_reverse_safe(instr, block) {
-         progress |= nir_algebraic_instr(&build, instr,
-                                         range_ht, condition_flags,
-                                         transforms, transform_counts, &states,
-                                         pass_op_table);
+      nir_foreach_instr_reverse(instr, block) {
+         nir_instr_worklist_push_tail(worklist, instr);
       }
    }
 
+   nir_instr *instr;
+   while ((instr = nir_instr_worklist_pop_head(worklist))) {
+      /* The worklist can have an instr pushed to it multiple times if it was
+       * the src of multiple instrs that also got optimized, so make sure that
+       * we don't try to re-optimize an instr we already handled.
+       */
+      if (exec_node_is_tail_sentinel(&instr->node))
+         continue;
+
+      progress |= nir_algebraic_instr(&build, instr,
+                                      range_ht, condition_flags,
+                                      transforms, transform_counts, &states,
+                                      pass_op_table, worklist);
+   }
+
+   nir_instr_worklist_destroy(worklist);
    ralloc_free(range_ht);
    util_dynarray_fini(&states);
 
index 9d567f88165e8b717127a9fb816add4136f43b56..30e8b6ac7f88dbf93c674acb1f356e6a473342af 100644 (file)
@@ -29,6 +29,7 @@
 #define _NIR_SEARCH_
 
 #include "nir.h"
+#include "nir_worklist.h"
 #include "util/u_dynarray.h"
 
 #define NIR_SEARCH_MAX_VARIABLES 16
@@ -202,7 +203,8 @@ nir_replace_instr(struct nir_builder *b, nir_alu_instr *instr,
                   struct util_dynarray *states,
                   const struct per_op_table *pass_op_table,
                   const nir_search_expression *search,
-                  const nir_search_value *replace);
+                  const nir_search_value *replace,
+                  nir_instr_worklist *algebraic_worklist);
 bool
 nir_algebraic_impl(nir_function_impl *impl,
                    const bool *condition_flags,