nir/lower_io: Add support for global scratch addressing
[mesa.git] / src / compiler / nir / nir_search.c
index 8578ca53c877ccc0d84d317315a807c6ae872ddb..577f0be0b9278fbff69be9ada4b69ab75f541653 100644
 #include <inttypes.h>
 #include "nir_search.h"
 #include "nir_builder.h"
+#include "nir_worklist.h"
 #include "util/half_float.h"
 
+/* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
+#define NIR_SEARCH_MAX_COMM_OPS 8
+
 struct match_state {
    bool inexact_match;
    bool has_exact_alu;
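+   /* One bit per commutative sub-expression of the pattern: bit N set means
+    * sub-expression N is matched with srcs 0 and 1 swapped.
+    */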
+   uint8_t comm_op_direction;
    unsigned variables_seen;
+
+   /* Used for running the automaton on newly-constructed instructions. */
+   struct util_dynarray *states;
+   const struct per_op_table *pass_op_table;
+
    nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
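+
+   /* Cache for analysis results (e.g. value ranges) consulted by
+    * search-variable condition callbacks.
+    */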
+   struct hash_table *range_ht;
 };
 
 static bool
 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state);
+static bool
+nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
+                        const struct per_op_table *pass_op_table);
 
-static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
+static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
+{
+    0,  1,  2,  3,
+    4,  5,  6,  7,
+    8,  9, 10, 11,
+   12, 13, 14, 15,
+};
 
 /**
  * Check if a source produces a value of the given type.
@@ -131,6 +151,50 @@ nir_op_matches_search_op(nir_op nop, uint16_t sop)
 
 #undef MATCH_FCONV_CASE
 #undef MATCH_ICONV_CASE
+#undef MATCH_BCONV_CASE
+}
+
+uint16_t
+nir_search_op_for_nir_op(nir_op nop)
+{
+#define MATCH_FCONV_CASE(op) \
+   case nir_op_##op##16: \
+   case nir_op_##op##32: \
+   case nir_op_##op##64: \
+      return nir_search_op_##op;
+
+#define MATCH_ICONV_CASE(op) \
+   case nir_op_##op##8: \
+   case nir_op_##op##16: \
+   case nir_op_##op##32: \
+   case nir_op_##op##64: \
+      return nir_search_op_##op;
+
+#define MATCH_BCONV_CASE(op) \
+   case nir_op_##op##1: \
+   case nir_op_##op##32: \
+      return nir_search_op_##op;
+
+   switch (nop) {
+   MATCH_FCONV_CASE(i2f)
+   MATCH_FCONV_CASE(u2f)
+   MATCH_FCONV_CASE(f2f)
+   MATCH_ICONV_CASE(f2u)
+   MATCH_ICONV_CASE(f2i)
+   MATCH_ICONV_CASE(u2u)
+   MATCH_ICONV_CASE(i2i)
+   MATCH_FCONV_CASE(b2f)
+   MATCH_ICONV_CASE(b2i)
+   MATCH_BCONV_CASE(i2b)
+   MATCH_BCONV_CASE(f2b)
+   default:
+      return nop;
+   }
+
+#undef MATCH_FCONV_CASE
+#undef MATCH_ICONV_CASE
+#undef MATCH_BCONV_CASE
 }
 
 static nir_op
@@ -184,6 +248,7 @@ nir_op_for_search_op(uint16_t sop, unsigned bit_size)
 
 #undef RET_FCONV_CASE
 #undef RET_ICONV_CASE
+#undef RET_BCONV_CASE
 }
 
 static bool
@@ -200,8 +265,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
     * replacing so those reads will happen after the original reads and may
     * not be valid if they're register reads.
     */
-   if (!instr->src[src].src.is_ssa)
-      return false;
+   assert(instr->src[src].src.is_ssa);
 
    /* If the source is an explicitly sized source, then we need to reset
     * both the number of components and the swizzle.
@@ -249,7 +313,8 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
              instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
             return false;
 
-         if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
+         if (var->cond && !var->cond(state->range_ht, instr,
+                                     src, num_components, new_swizzle))
             return false;
 
          if (var->type != nir_type_invalid &&
@@ -279,7 +344,17 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
          return false;
 
       switch (const_val->type) {
-      case nir_type_float:
+      case nir_type_float: {
+         nir_load_const_instr *const load =
+            nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
+
+         /* There are 8-bit and 1-bit integer types, but there are no 8-bit or
+          * 1-bit float types.  This prevents potential assertion failures in
+          * nir_src_comp_as_float.
+          */
+         if (load->def.bit_size < 16)
+            return false;
+
          for (unsigned i = 0; i < num_components; ++i) {
             double val = nir_src_comp_as_float(instr->src[src].src,
                                                new_swizzle[i]);
@@ -287,6 +362,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
                return false;
          }
          return true;
+      }
 
       case nir_type_int:
       case nir_type_uint:
@@ -350,41 +426,29 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
       }
    }
 
-   /* Stash off the current variables_seen bitmask.  This way we can
-    * restore it prior to matching in the commutative case below.
+   /* If this is a commutative expression and it's one of the first few, look
+    * up its direction for the current search operation.  We'll use that value
+    * to possibly flip the sources for the match.
     */
-   unsigned variables_seen_stash = state->variables_seen;
+   unsigned comm_op_flip =
+      (expr->comm_expr_idx >= 0 &&
+       expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
+      ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
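+   /* e.g. (illustrative values): comm_op_direction = 0b101 and
+    * comm_expr_idx = 2 give comm_op_flip = (0b101 >> 2) & 1 = 1, so the
+    * i ^ comm_op_flip below maps src 0 -> 1 and src 1 -> 0.
+    */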
 
    bool matched = true;
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      if (!match_value(expr->srcs[i], instr, i, num_components,
-                       swizzle, state)) {
+      /* 2src_commutative instructions that have 3 sources are only commutative
+       * in the first two sources.  Source 2 is always source 2.
+       */
+      if (!match_value(expr->srcs[i], instr,
+                       i < 2 ? i ^ comm_op_flip : i,
+                       num_components, swizzle, state)) {
          matched = false;
          break;
       }
    }
 
-   if (matched)
-      return true;
-
-   if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
-      assert(nir_op_infos[instr->op].num_inputs == 2);
-
-      /* Restore the variables_seen bitmask.  If we don't do this, then we
-       * could end up with an erroneous failure due to variables found in the
-       * first match attempt above not matching those in the second.
-       */
-      state->variables_seen = variables_seen_stash;
-
-      if (!match_value(expr->srcs[0], instr, 1, num_components,
-                       swizzle, state))
-         return false;
-
-      return match_value(expr->srcs[1], instr, 0, num_components,
-                         swizzle, state);
-   } else {
-      return false;
-   }
+   return matched;
 }
 
 static unsigned
@@ -425,7 +489,7 @@ construct_value(nir_builder *build,
        * expression we are replacing has any exact values, the entire
        * replacement should be exact.
        */
-      alu->exact = state->has_exact_alu;
+      alu->exact = state->has_exact_alu || expr->exact;
 
       for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
          /* If the source is an explicitly sized source, then we need to reset
@@ -441,6 +505,11 @@ construct_value(nir_builder *build,
 
       nir_builder_instr_insert(build, &alu->instr);
 
+      assert(alu->dest.dest.ssa.index ==
+             util_dynarray_num_elements(state->states, uint16_t));
+      util_dynarray_append(state->states, uint16_t, 0);
+      nir_algebraic_automaton(&alu->instr, state->states, state->pass_op_table);
+
       nir_alu_src val;
       val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
       val.negate = false;
@@ -459,6 +528,9 @@ construct_value(nir_builder *build,
                        (void *)build->shader);
       assert(!var->is_constant);
 
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         val.swizzle[i] = state->variables[var->variable].swizzle[var->swizzle[i]];
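+      /* e.g. (illustrative): if the variable was matched through swizzle
+       * {2, 3} (.zw) and the replacement requests components {1, 0} (.yx)
+       * of it, val.swizzle becomes {3, 2} (.wz) of the original value.
+       */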
+
       return val;
    }
 
@@ -485,6 +557,12 @@ construct_value(nir_builder *build,
          unreachable("Invalid alu source type");
       }
 
+      assert(cval->index ==
+             util_dynarray_num_elements(state->states, uint16_t));
+      util_dynarray_append(state->states, uint16_t, 0);
+      nir_algebraic_automaton(cval->parent_instr, state->states,
+                              state->pass_op_table);
+
       nir_alu_src val;
       val.src = nir_src_for_ssa(cval);
       val.negate = false;
@@ -499,10 +577,120 @@ construct_value(nir_builder *build,
    }
 }
 
+UNUSED static void
+dump_value(const nir_search_value *val)
+{
+   switch (val->type) {
+   case nir_search_value_constant: {
+      const nir_search_constant *sconst = nir_search_value_as_constant(val);
+      switch (sconst->type) {
+      case nir_type_float:
+         fprintf(stderr, "%f", sconst->data.d);
+         break;
+      case nir_type_int:
+         fprintf(stderr, "%"PRId64, sconst->data.i);
+         break;
+      case nir_type_uint:
+         fprintf(stderr, "0x%"PRIx64, sconst->data.u);
+         break;
+      case nir_type_bool:
+         fprintf(stderr, "%s", sconst->data.u != 0 ? "True" : "False");
+         break;
+      default:
+         unreachable("bad const type");
+      }
+      break;
+   }
+
+   case nir_search_value_variable: {
+      const nir_search_variable *var = nir_search_value_as_variable(val);
+      if (var->is_constant)
+         fprintf(stderr, "#");
+      fprintf(stderr, "%c", var->variable + 'a');
+      break;
+   }
+
+   case nir_search_value_expression: {
+      const nir_search_expression *expr = nir_search_value_as_expression(val);
+      fprintf(stderr, "(");
+      if (expr->inexact)
+         fprintf(stderr, "~");
+      switch (expr->opcode) {
+#define CASE(n) \
+      case nir_search_op_##n: fprintf(stderr, #n); break;
+      CASE(f2b)
+      CASE(b2f)
+      CASE(b2i)
+      CASE(i2b)
+      CASE(i2i)
+      CASE(f2i)
+      CASE(i2f)
+#undef CASE
+      default:
+         fprintf(stderr, "%s", nir_op_infos[expr->opcode].name);
+      }
+
+      unsigned num_srcs = 1;
+      if (expr->opcode <= nir_last_opcode)
+         num_srcs = nir_op_infos[expr->opcode].num_inputs;
+
+      for (unsigned i = 0; i < num_srcs; i++) {
+         fprintf(stderr, " ");
+         dump_value(expr->srcs[i]);
+      }
+
+      fprintf(stderr, ")");
+      break;
+   }
+   }
+
+   if (val->bit_size > 0)
+      fprintf(stderr, "@%d", val->bit_size);
+}
+
+static void
+add_uses_to_worklist(nir_instr *instr, nir_instr_worklist *worklist)
+{
+   nir_ssa_def *def = nir_instr_ssa_def(instr);
+
+   nir_foreach_use_safe(use_src, def) {
+      nir_instr_worklist_push_tail(worklist, use_src->parent_instr);
+   }
+}
+
+static void
+nir_algebraic_update_automaton(nir_instr *new_instr,
+                               nir_instr_worklist *algebraic_worklist,
+                               struct util_dynarray *states,
+                               const struct per_op_table *pass_op_table)
+{
+   nir_instr_worklist *automaton_worklist = nir_instr_worklist_create();
+
+   /* Walk through the tree of uses of our new instruction's SSA value,
+    * recursively updating the automaton state until it stabilizes.
+    */
+   add_uses_to_worklist(new_instr, automaton_worklist);
+
+   nir_instr *instr;
+   while ((instr = nir_instr_worklist_pop_head(automaton_worklist))) {
+      if (nir_algebraic_automaton(instr, states, pass_op_table)) {
+         nir_instr_worklist_push_tail(algebraic_worklist, instr);
+
+         add_uses_to_worklist(instr, automaton_worklist);
+      }
+   }
+
+   nir_instr_worklist_destroy(automaton_worklist);
+}
+
 nir_ssa_def *
 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
+                  struct hash_table *range_ht,
+                  struct util_dynarray *states,
+                  const struct per_op_table *pass_op_table,
                   const nir_search_expression *search,
-                  const nir_search_value *replace)
+                  const nir_search_value *replace,
+                  nir_instr_worklist *algebraic_worklist)
 {
    uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
 
@@ -514,32 +702,269 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
    struct match_state state;
    state.inexact_match = false;
    state.has_exact_alu = false;
-   state.variables_seen = 0;
+   state.range_ht = range_ht;
+   state.pass_op_table = pass_op_table;
+
+   STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);
+
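+   /* Each commutative sub-expression can be matched with its first two
+    * sources in either order, so try every combination: e.g. (illustrative)
+    * a pattern with 3 commutative sub-expressions yields 1 << 3 = 8 source
+    * orderings.
+    */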
+   unsigned comm_expr_combinations =
+      1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
 
-   if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
-                         swizzle, &state))
+   bool found = false;
+   for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
+      /* The bitfield of directions is just the current iteration.  Hooray for
+       * binary.
+       */
+      state.comm_op_direction = comb;
+      state.variables_seen = 0;
+
+      if (match_expression(search, instr,
+                           instr->dest.dest.ssa.num_components,
+                           swizzle, &state)) {
+         found = true;
+         break;
+      }
+   }
+   if (!found)
       return NULL;
 
-   build->cursor = nir_before_instr(&instr->instr);
+#if 0
+   fprintf(stderr, "matched: ");
+   dump_value(&search->value);
+   fprintf(stderr, " -> ");
+   dump_value(replace);
+   fprintf(stderr, " ssa_%d\n", instr->dest.dest.ssa.index);
+#endif
+
+   /* If the instruction at the root of the expression tree being replaced is
+    * a unary operation, insert the replacement instructions at the location
+    * of the source of the unary operation.  Otherwise, insert the replacement
+    * instructions at the location of the expression tree root.
+    *
+    * For the unary operation case, this is done to prevent some spurious code
+    * motion that can dramatically extend live ranges.  Imagine an expression
+    * like -(A+B) where the addtion and the negation are separated by flow
+    * control and thousands of instructions.  If this expression is replaced
+    * with -A+-B, inserting the new instructions at the site of the negation
+    * could extend the live range of A and B dramtically.  This could increase
+    * register pressure and cause spilling.
+    *
+    * It may well be that moving instructions around is a good thing, but
+    * keeping algebraic optimizations and code motion optimizations separate
+    * seems safest.
+    */
+   nir_alu_instr *const src_instr = nir_src_as_alu_instr(instr->src[0].src);
+   if (src_instr != NULL &&
+       (instr->op == nir_op_fneg || instr->op == nir_op_fabs ||
+        instr->op == nir_op_ineg || instr->op == nir_op_iabs ||
+        instr->op == nir_op_inot)) {
+      /* Insert new instructions *after*.  Otherwise a hypothetical
+       * replacement fneg(X) -> fabs(X) would insert the fabs() instruction
+       * before X!  This can also occur for things like fneg(X.wzyx) -> X.wzyx
+       * in vector mode.  A move instruction to handle the swizzle will get
+       * inserted before X.
+       *
+       * This manifested in a single OpenGL ES 2.0 CTS vertex shader test on
+       * older Intel GPUs that use vector-mode vertex processing.
+       */
+      build->cursor = nir_after_instr(&src_instr->instr);
+   } else {
+      build->cursor = nir_before_instr(&instr->instr);
+   }
+
+   state.states = states;
 
    nir_alu_src val = construct_value(build, replace,
                                      instr->dest.dest.ssa.num_components,
                                      instr->dest.dest.ssa.bit_size,
                                      &state, &instr->instr);
 
-   /* Inserting a mov may be unnecessary.  However, it's much easier to
-    * simply let copy propagation clean this up than to try to go through
-    * and rewrite swizzles ourselves.
+   /* Note that the NIR builder will elide the MOV if it's a no-op, which may
+    * allow more work to be done in a single pass through algebraic.
     */
    nir_ssa_def *ssa_val =
-      nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
+      nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
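+
+   /* If the builder emitted a new mov instead of returning an existing def,
+    * the new def's index equals the current size of the states array: grow
+    * the array and run the automaton on the new instruction.
+    */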
+   if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) {
+      util_dynarray_append(states, uint16_t, 0);
+      nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
+   }
+
+   /* Rewrite the uses of the old SSA value to the new one, and recurse
+    * through the uses updating the automaton's state.
+    */
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
+   nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist,
+                                  states, pass_op_table);
 
-   /* We know this one has no more uses because we just rewrote them all,
-    * so we can remove it.  The rest of the matched expression, however, we
-    * don't know so much about.  We'll just let dead code clean them up.
+   /* Nothing uses the instr any more, so drop it out of the program.  Note
+    * that the instr may be in the worklist still, so we can't free it
+    * directly.
     */
    nir_instr_remove(&instr->instr);
 
    return ssa_val;
 }
+
+static bool
+nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
+                        const struct per_op_table *pass_op_table)
+{
+   switch (instr->type) {
+   case nir_instr_type_alu: {
+      nir_alu_instr *alu = nir_instr_as_alu(instr);
+      nir_op op = alu->op;
+      uint16_t search_op = nir_search_op_for_nir_op(op);
+      const struct per_op_table *tbl = &pass_op_table[search_op];
+      if (tbl->num_filtered_states == 0)
+         return false;
+
+      /* Calculate the index into the transition table. Note the index
+       * calculated must match the iteration order of Python's
+       * itertools.product(), which was used to emit the transition
+       * table.
+       */
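+      /* e.g. (illustrative): with two sources whose filtered states are s0
+       * and s1, the loop below computes index = s0 * num_filtered_states +
+       * s1, which is the position of (s0, s1) in the iteration order of
+       * itertools.product(range(num_filtered_states), repeat=2).
+       */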
+      unsigned index = 0;
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
+         index *= tbl->num_filtered_states;
+         index += tbl->filter[*util_dynarray_element(states, uint16_t,
+                                                     alu->src[i].src.ssa->index)];
+      }
+
+      uint16_t *state = util_dynarray_element(states, uint16_t,
+                                              alu->dest.dest.ssa.index);
+      if (*state != tbl->table[index]) {
+         *state = tbl->table[index];
+         return true;
+      }
+      return false;
+   }
+
+   case nir_instr_type_load_const: {
+      nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
+      uint16_t *state = util_dynarray_element(states, uint16_t,
+                                              load_const->def.index);
+      if (*state != CONST_STATE) {
+         *state = CONST_STATE;
+         return true;
+      }
+      return false;
+   }
+
+   default:
+      return false;
+   }
+}
+
+static bool
+nir_algebraic_instr(nir_builder *build, nir_instr *instr,
+                    struct hash_table *range_ht,
+                    const bool *condition_flags,
+                    const struct transform **transforms,
+                    const uint16_t *transform_counts,
+                    struct util_dynarray *states,
+                    const struct per_op_table *pass_op_table,
+                    nir_instr_worklist *worklist)
+{
+   if (instr->type != nir_instr_type_alu)
+      return false;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   if (!alu->dest.dest.is_ssa)
+      return false;
+
+   unsigned bit_size = alu->dest.dest.ssa.bit_size;
+   const unsigned execution_mode =
+      build->shader->info.float_controls_execution_mode;
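+
+   /* Inexact (~) patterns may not honor signed-zero/Inf/NaN preservation or
+    * denorm flushing, so they are skipped below when the shader's float
+    * controls require those behaviors.
+    */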
+   const bool ignore_inexact =
+      nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
+      nir_is_denorm_flush_to_zero(execution_mode, bit_size);
+
+   int xform_idx = *util_dynarray_element(states, uint16_t,
+                                          alu->dest.dest.ssa.index);
+   for (uint16_t i = 0; i < transform_counts[xform_idx]; i++) {
+      const struct transform *xform = &transforms[xform_idx][i];
+      if (condition_flags[xform->condition_offset] &&
+          !(xform->search->inexact && ignore_inexact) &&
+          nir_replace_instr(build, alu, range_ht, states, pass_op_table,
+                            xform->search, xform->replace, worklist)) {
+         _mesa_hash_table_clear(range_ht, NULL);
+         return true;
+      }
+   }
+
+   return false;
+}
+
+bool
+nir_algebraic_impl(nir_function_impl *impl,
+                   const bool *condition_flags,
+                   const struct transform **transforms,
+                   const uint16_t *transform_counts,
+                   const struct per_op_table *pass_op_table)
+{
+   bool progress = false;
+
+   nir_builder build;
+   nir_builder_init(&build, impl);
+
+   /* Note: it's important here that we're allocating a zeroed array, since
+    * state 0 is the default state, which means we don't have to visit
+    * anything other than constants and ALU instructions.
+    */
+   struct util_dynarray states = {0};
+   if (!util_dynarray_resize(&states, uint16_t, impl->ssa_alloc)) {
+      nir_metadata_preserve(impl, nir_metadata_all);
+      return false;
+   }
+   memset(states.data, 0, states.size);
+
+   struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
+
+   nir_instr_worklist *worklist = nir_instr_worklist_create();
+
+   /* Walk top-to-bottom setting up the automaton state. */
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr(instr, block) {
+         nir_algebraic_automaton(instr, &states, pass_op_table);
+      }
+   }
+
+   /* Put our instrs in the worklist such that we're popping the last instr
+    * first.  This will encourage us to match the biggest source patterns when
+    * possible.
+    */
+   nir_foreach_block_reverse(block, impl) {
+      nir_foreach_instr_reverse(instr, block) {
+         nir_instr_worklist_push_tail(worklist, instr);
+      }
+   }
+
+   nir_instr *instr;
+   while ((instr = nir_instr_worklist_pop_head(worklist))) {
+      /* The worklist can have an instr pushed to it multiple times if it was
+       * the src of multiple instrs that also got optimized, so make sure that
+       * we don't try to re-optimize an instr we already handled.
+       */
+      if (exec_node_is_tail_sentinel(&instr->node))
+         continue;
+
+      progress |= nir_algebraic_instr(&build, instr,
+                                      range_ht, condition_flags,
+                                      transforms, transform_counts, &states,
+                                      pass_op_table, worklist);
+   }
+
+   nir_instr_worklist_destroy(worklist);
+   ralloc_free(range_ht);
+   util_dynarray_fini(&states);
+
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
+   }
+
+   return progress;
+}