nir: Maintain the algebraic automaton's state as we work.
authorConnor Abbott <cwabbott0@gmail.com>
Fri, 10 May 2019 14:57:45 +0000 (16:57 +0200)
committerEric Anholt <eric@anholt.net>
Tue, 26 Nov 2019 18:13:19 +0000 (10:13 -0800)
In order to have nir_opt_algebraic be able to do further algebraic
work on the output of a replacement, we need to maintain the
automaton's state.

Reviewed-by: Eric Anholt <eric@anholt.net>
src/compiler/nir/nir_search.c
src/compiler/nir/nir_search.h

index b78d3046a7b11b6b7c70ab53384d2ee799fbdf13..e6f36493fe2d304a5702dc4cda25a548c2d15d35 100644 (file)
@@ -38,6 +38,11 @@ struct match_state {
    bool has_exact_alu;
    uint8_t comm_op_direction;
    unsigned variables_seen;
+
+   /* Used for running the automaton on newly-constructed instructions. */
+   struct util_dynarray *states;
+   const struct per_op_table *pass_op_table;
+
    nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
    struct hash_table *range_ht;
 };
@@ -46,6 +51,9 @@ static bool
 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state);
+static void
+nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
+                        const struct per_op_table *pass_op_table);
 
 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
 
@@ -490,6 +498,11 @@ construct_value(nir_builder *build,
 
       nir_builder_instr_insert(build, &alu->instr);
 
+      assert(alu->dest.dest.ssa.index ==
+             util_dynarray_num_elements(state->states, uint16_t));
+      util_dynarray_append(state->states, uint16_t, 0);
+      nir_algebraic_automaton(&alu->instr, state->states, state->pass_op_table);
+
       nir_alu_src val;
       val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
       val.negate = false;
@@ -537,6 +550,12 @@ construct_value(nir_builder *build,
          unreachable("Invalid alu source type");
       }
 
+      assert(cval->index ==
+             util_dynarray_num_elements(state->states, uint16_t));
+      util_dynarray_append(state->states, uint16_t, 0);
+      nir_algebraic_automaton(cval->parent_instr, state->states,
+                              state->pass_op_table);
+
       nir_alu_src val;
       val.src = nir_src_for_ssa(cval);
       val.negate = false;
@@ -624,6 +643,8 @@ UNUSED static void dump_value(const nir_search_value *val)
 nir_ssa_def *
 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
                   struct hash_table *range_ht,
+                  struct util_dynarray *states,
+                  const struct per_op_table *pass_op_table,
                   const nir_search_expression *search,
                   const nir_search_value *replace)
 {
@@ -638,6 +659,7 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
    state.inexact_match = false;
    state.has_exact_alu = false;
    state.range_ht = range_ht;
+   state.pass_op_table = pass_op_table;
 
    STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);
 
@@ -672,6 +694,8 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
 
    build->cursor = nir_before_instr(&instr->instr);
 
+   state.states = states;
+
    nir_alu_src val = construct_value(build, replace,
                                      instr->dest.dest.ssa.num_components,
                                      instr->dest.dest.ssa.bit_size,
@@ -682,6 +706,11 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
     */
    nir_ssa_def *ssa_val =
       nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
+   if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) {
+      util_dynarray_append(states, uint16_t, 0);
+      nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
+   }
+
    nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
 
    /* We know this one has no more uses because we just rewrote them all,
@@ -694,42 +723,43 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
 }
 
 static void
-nir_algebraic_automaton(nir_block *block, uint16_t *states,
+nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
                         const struct per_op_table *pass_op_table)
 {
-   nir_foreach_instr(instr, block) {
-      switch (instr->type) {
-      case nir_instr_type_alu: {
-         nir_alu_instr *alu = nir_instr_as_alu(instr);
-         nir_op op = alu->op;
-         uint16_t search_op = nir_search_op_for_nir_op(op);
-         const struct per_op_table *tbl = &pass_op_table[search_op];
-         if (tbl->num_filtered_states == 0)
-            continue;
-
-         /* Calculate the index into the transition table. Note the index
-          * calculated must match the iteration order of Python's
-          * itertools.product(), which was used to emit the transition
-          * table.
-          */
-         uint16_t index = 0;
-         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
-            index *= tbl->num_filtered_states;
-            index += tbl->filter[states[alu->src[i].src.ssa->index]];
-         }
-         states[alu->dest.dest.ssa.index] = tbl->table[index];
-         break;
+   switch (instr->type) {
+   case nir_instr_type_alu: {
+      nir_alu_instr *alu = nir_instr_as_alu(instr);
+      nir_op op = alu->op;
+      uint16_t search_op = nir_search_op_for_nir_op(op);
+      const struct per_op_table *tbl = &pass_op_table[search_op];
+      if (tbl->num_filtered_states == 0)
+         return;
+
+      /* Calculate the index into the transition table. Note the index
+       * calculated must match the iteration order of Python's
+       * itertools.product(), which was used to emit the transition
+       * table.
+       */
+      uint16_t index = 0;
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
+         index *= tbl->num_filtered_states;
+         index += tbl->filter[*util_dynarray_element(states, uint16_t,
+                                                     alu->src[i].src.ssa->index)];
       }
+      *util_dynarray_element(states, uint16_t, alu->dest.dest.ssa.index) =
+         tbl->table[index];
+      break;
+   }
 
-      case nir_instr_type_load_const: {
-         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
-         states[load_const->def.index] = CONST_STATE;
-         break;
-      }
+   case nir_instr_type_load_const: {
+      nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
+      *util_dynarray_element(states, uint16_t, load_const->def.index) =
+         CONST_STATE;
+      break;
+   }
 
-      default:
-         break;
-      }
+   default:
+      break;
    }
 }
 
@@ -739,7 +769,8 @@ nir_algebraic_block(nir_builder *build, nir_block *block,
                     const bool *condition_flags,
                     const struct transform **transforms,
                     const uint16_t *transform_counts,
-                    const uint16_t *states)
+                    struct util_dynarray *states,
+                    const struct per_op_table *pass_op_table)
 {
    bool progress = false;
    const unsigned execution_mode = build->shader->info.float_controls_execution_mode;
@@ -757,12 +788,13 @@ nir_algebraic_block(nir_builder *build, nir_block *block,
          nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
          nir_is_denorm_flush_to_zero(execution_mode, bit_size);
 
-      int xform_idx = states[alu->dest.dest.ssa.index];
+      int xform_idx = *util_dynarray_element(states, uint16_t,
+                                             alu->dest.dest.ssa.index);
       for (uint16_t i = 0; i < transform_counts[xform_idx]; i++) {
          const struct transform *xform = &transforms[xform_idx][i];
          if (condition_flags[xform->condition_offset] &&
              !(xform->search->inexact && ignore_inexact) &&
-             nir_replace_instr(build, alu, range_ht,
+             nir_replace_instr(build, alu, range_ht, states, pass_op_table,
                                xform->search, xform->replace)) {
             _mesa_hash_table_clear(range_ht, NULL);
             progress = true;
@@ -790,22 +822,27 @@ nir_algebraic_impl(nir_function_impl *impl,
     * state 0 is the default state, which means we don't have to visit
     * anything other than constants and ALU instructions.
     */
-   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
+   struct util_dynarray states = {0};
+   if (!util_dynarray_resize(&states, uint16_t, impl->ssa_alloc))
+      return false;
+   memset(states.data, 0, states.size);
 
    struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
 
    nir_foreach_block(block, impl) {
-      nir_algebraic_automaton(block, states, pass_op_table);
+      nir_foreach_instr(instr, block) {
+         nir_algebraic_automaton(instr, &states, pass_op_table);
+      }
    }
 
    nir_foreach_block_reverse(block, impl) {
       progress |= nir_algebraic_block(&build, block, range_ht, condition_flags,
                                       transforms, transform_counts,
-                                      states);
+                                      &states, pass_op_table);
    }
 
    ralloc_free(range_ht);
-   free(states);
+   util_dynarray_fini(&states);
 
    if (progress) {
       nir_metadata_preserve(impl, nir_metadata_block_index |
index 80d153916c809b0bfa87bce12bf4919f476e0b45..9d567f88165e8b717127a9fb816add4136f43b56 100644 (file)
@@ -29,6 +29,7 @@
 #define _NIR_SEARCH_
 
 #include "nir.h"
+#include "util/u_dynarray.h"
 
 #define NIR_SEARCH_MAX_VARIABLES 16
 
@@ -198,6 +199,8 @@ NIR_DEFINE_CAST(nir_search_value_as_expression, nir_search_value,
 nir_ssa_def *
 nir_replace_instr(struct nir_builder *b, nir_alu_instr *instr,
                   struct hash_table *range_ht,
+                  struct util_dynarray *states,
+                  const struct per_op_table *pass_op_table,
                   const nir_search_expression *search,
                   const nir_search_value *replace);
 bool