From: Eric Anholt Date: Wed, 2 Oct 2019 17:59:13 +0000 (-0700) Subject: nir: Make algebraic backtrack and reprocess after a replacement. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=d845dca0f5451331abca250275c3d119f5d98d0b;p=mesa.git nir: Make algebraic backtrack and reprocess after a replacement. The algebraic pass was exhibiting O(n^2) behavior in dEQP-GLES2.functional.uniform_api.random.3 and dEQP-GLES31.functional.ubo.random.all_per_block_buffers.13 (along with other code-generated tests, and likely real-world loop-unroll cases). In the process of using fmul(b2f(x), b2f(x)) -> b2f(iand(x, y)) to transform: result = b2f(a == b); result *= b2f(c == d); ... result *= b2f(z == w); -> temp = (a == b) temp = temp && (c == d) ... temp = temp && (z == w) result = b2f(temp); nir_opt_algebraic, proceeding bottom-to-top, would match and convert the top-most fmul(b2f(), b2f()) case each time, leaving the new b2f to be matched by the next fmul down on the next time algebraic got run by the optimization loop. Back in 2016 in 7be8d0773229 ("nir: Do opt_algebraic in reverse order."), Matt changed algebraic to go bottom-to-top so that we would match the biggest patterns first. This helped his cases, but I believe introduced this failure mode. Instead of reverting that, now that we've got the automaton, we can update the automaton's state recursively and just re-process any instructions whose state has changed (indicating that they might match new things). There's a small chance that the state will hash to the same value and miss out on this round of algebraic, but this seems to be good enough to fix dEQP. Effects with NIR_VALIDATE=0 (improvement is better with validation enabled): Intel shader-db runtime -0.954712% +/- 0.333844% (n=44/46, obvious throttling outliers removed) dEQP-GLES2.functional.uniform_api.random.3 runtime -65.3512% +/- 4.22369% (n=21, was 1.4s) dEQP-GLES31.functional.ubo.random.all_per_block_buffers.13 runtime -68.8066% +/- 6.49523% (was 4.8s) v2: Use two worklists, suggested by @cwabbott, to cut out a bunch of tricky code. Runtime of uniform_api.random.3 down -0.790299% +/- 0.244213% compred to v1. v3: Re-add the nir_instr_remove() that I accidentally dropped in v2, fixing infinite loops. Reviewed-by: Connor Abbott --- diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index 6bb2f35aae8..c1b179525ab 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -28,6 +28,7 @@ #include #include "nir_search.h" #include "nir_builder.h" +#include "nir_worklist.h" #include "util/half_float.h" /* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */ @@ -51,7 +52,7 @@ static bool match_expression(const nir_search_expression *expr, nir_alu_instr *instr, unsigned num_components, const uint8_t *swizzle, struct match_state *state); -static void +static bool nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states, const struct per_op_table *pass_op_table); @@ -640,13 +641,50 @@ UNUSED static void dump_value(const nir_search_value *val) fprintf(stderr, "@%d", val->bit_size); } +static void +add_uses_to_worklist(nir_instr *instr, nir_instr_worklist *worklist) +{ + nir_ssa_def *def = nir_instr_ssa_def(instr); + + nir_foreach_use_safe(use_src, def) { + nir_instr_worklist_push_tail(worklist, use_src->parent_instr); + } +} + +static void +nir_algebraic_update_automaton(nir_instr *new_instr, + nir_instr_worklist *algebraic_worklist, + struct util_dynarray *states, + const struct per_op_table *pass_op_table) +{ + + nir_instr_worklist *automaton_worklist = nir_instr_worklist_create(); + + /* Walk through the tree of uses of our new instruction's SSA value, + * recursively updating the automaton state until it stabilizes. + */ + add_uses_to_worklist(new_instr, automaton_worklist); + + nir_instr *instr; + while ((instr = nir_instr_worklist_pop_head(automaton_worklist))) { + if (nir_algebraic_automaton(instr, states, pass_op_table)) { + nir_instr_worklist_push_tail(algebraic_worklist, instr); + + add_uses_to_worklist(instr, automaton_worklist); + } + } + + nir_instr_worklist_destroy(automaton_worklist); +} + nir_ssa_def * nir_replace_instr(nir_builder *build, nir_alu_instr *instr, struct hash_table *range_ht, struct util_dynarray *states, const struct per_op_table *pass_op_table, const nir_search_expression *search, - const nir_search_value *replace) + const nir_search_value *replace, + nir_instr_worklist *algebraic_worklist) { uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 }; @@ -711,18 +749,23 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr, nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table); } + /* Rewrite the uses of the old SSA value to the new one, and recurse + * through the uses updating the automaton's state. + */ nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val)); + nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist, + states, pass_op_table); - /* We know this one has no more uses because we just rewrote them all, - * so we can remove it. The rest of the matched expression, however, we - * don't know so much about. We'll just let dead code clean them up. + /* Nothing uses the instr any more, so drop it out of the program. Note + * that the instr may be in the worklist still, so we can't free it + * directly. */ nir_instr_remove(&instr->instr); return ssa_val; } -static void +static bool nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states, const struct per_op_table *pass_op_table) { @@ -733,7 +776,7 @@ nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states, uint16_t search_op = nir_search_op_for_nir_op(op); const struct per_op_table *tbl = &pass_op_table[search_op]; if (tbl->num_filtered_states == 0) - return; + return false; /* Calculate the index into the transition table. Note the index * calculated must match the iteration order of Python's @@ -746,20 +789,29 @@ nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states, index += tbl->filter[*util_dynarray_element(states, uint16_t, alu->src[i].src.ssa->index)]; } - *util_dynarray_element(states, uint16_t, alu->dest.dest.ssa.index) = - tbl->table[index]; - break; + + uint16_t *state = util_dynarray_element(states, uint16_t, + alu->dest.dest.ssa.index); + if (*state != tbl->table[index]) { + *state = tbl->table[index]; + return true; + } + return false; } case nir_instr_type_load_const: { nir_load_const_instr *load_const = nir_instr_as_load_const(instr); - *util_dynarray_element(states, uint16_t, load_const->def.index) = - CONST_STATE; - break; + uint16_t *state = util_dynarray_element(states, uint16_t, + load_const->def.index); + if (*state != CONST_STATE) { + *state = CONST_STATE; + return true; + } + return false; } default: - break; + return false; } } @@ -770,7 +822,8 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr, const struct transform **transforms, const uint16_t *transform_counts, struct util_dynarray *states, - const struct per_op_table *pass_op_table) + const struct per_op_table *pass_op_table, + nir_instr_worklist *worklist) { if (instr->type != nir_instr_type_alu) @@ -794,7 +847,7 @@ nir_algebraic_instr(nir_builder *build, nir_instr *instr, if (condition_flags[xform->condition_offset] && !(xform->search->inexact && ignore_inexact) && nir_replace_instr(build, alu, range_ht, states, pass_op_table, - xform->search, xform->replace)) { + xform->search, xform->replace, worklist)) { _mesa_hash_table_clear(range_ht, NULL); return true; } @@ -826,21 +879,41 @@ nir_algebraic_impl(nir_function_impl *impl, struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL); + nir_instr_worklist *worklist = nir_instr_worklist_create(); + + /* Walk top-to-bottom setting up the automaton state. */ nir_foreach_block(block, impl) { nir_foreach_instr(instr, block) { nir_algebraic_automaton(instr, &states, pass_op_table); } } + /* Put our instrs in the worklist such that we're popping the last instr + * first. This will encourage us to match the biggest source patterns when + * possible. + */ nir_foreach_block_reverse(block, impl) { - nir_foreach_instr_reverse_safe(instr, block) { - progress |= nir_algebraic_instr(&build, instr, - range_ht, condition_flags, - transforms, transform_counts, &states, - pass_op_table); + nir_foreach_instr_reverse(instr, block) { + nir_instr_worklist_push_tail(worklist, instr); } } + nir_instr *instr; + while ((instr = nir_instr_worklist_pop_head(worklist))) { + /* The worklist can have an instr pushed to it multiple times if it was + * the src of multiple instrs that also got optimized, so make sure that + * we don't try to re-optimize an instr we already handled. + */ + if (exec_node_is_tail_sentinel(&instr->node)) + continue; + + progress |= nir_algebraic_instr(&build, instr, + range_ht, condition_flags, + transforms, transform_counts, &states, + pass_op_table, worklist); + } + + nir_instr_worklist_destroy(worklist); ralloc_free(range_ht); util_dynarray_fini(&states); diff --git a/src/compiler/nir/nir_search.h b/src/compiler/nir/nir_search.h index 9d567f88165..30e8b6ac7f8 100644 --- a/src/compiler/nir/nir_search.h +++ b/src/compiler/nir/nir_search.h @@ -29,6 +29,7 @@ #define _NIR_SEARCH_ #include "nir.h" +#include "nir_worklist.h" #include "util/u_dynarray.h" #define NIR_SEARCH_MAX_VARIABLES 16 @@ -202,7 +203,8 @@ nir_replace_instr(struct nir_builder *b, nir_alu_instr *instr, struct util_dynarray *states, const struct per_op_table *pass_op_table, const nir_search_expression *search, - const nir_search_value *replace); + const nir_search_value *replace, + nir_instr_worklist *algebraic_worklist); bool nir_algebraic_impl(nir_function_impl *impl, const bool *condition_flags,