#include <inttypes.h>
#include "nir_search.h"
#include "nir_builder.h"
+#include "nir_worklist.h"
#include "util/half_float.h"
/* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
-#define NIR_SEARCH_MAX_COMM_OPS 4
+#define NIR_SEARCH_MAX_COMM_OPS 8
+/* Scratch state for a single pattern-match / replacement-construction
+ * attempt (one call to nir_replace_instr()).
+ */
struct match_state {
+/* Set when a ~-marked (inexact) search expression node was matched. */
bool inexact_match;
+/* Set when any matched ALU instruction had its "exact" flag set. */
bool has_exact_alu;
+/* One bit per commutative op in the pattern, selecting which source order
+ * is being tried (see the STATIC_ASSERT against NIR_SEARCH_MAX_COMM_OPS
+ * in nir_replace_instr()). */
uint8_t comm_op_direction;
+/* Tracks which pattern variables have been bound so far — presumably a
+ * bitmask over variables[]; confirm against match_value(). */
unsigned variables_seen;
+
+ /* Used for running the automaton on newly-constructed instructions. */
+ struct util_dynarray *states;
+ const struct per_op_table *pass_op_table;
+
nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
+ /* Cache handed to variable condition callbacks (var->cond); cleared by
+ * nir_algebraic_instr() after each successful replacement since the IR
+ * has changed. */
+ struct hash_table *range_ht;
};
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
unsigned num_components, const uint8_t *swizzle,
struct match_state *state);
+static bool
+nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
+ const struct per_op_table *pass_op_table);
+/* Identity swizzle, widened from 4 to 16 initializers to match a larger
+ * NIR_MAX_VEC_COMPONENTS. */
-static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
+static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
+{
+ 0, 1, 2, 3,
+ 4, 5, 6, 7,
+ 8, 9, 10, 11,
+ 12, 13, 14, 15,
+};
/**
* Check if a source produces a value of the given type.
instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
return false;
- if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
+ if (var->cond && !var->cond(state->range_ht, instr,
+ src, num_components, new_swizzle))
return false;
if (var->type != nir_type_invalid &&
* expression we are replacing has any exact values, the entire
* replacement should be exact.
*/
- alu->exact = state->has_exact_alu;
+ alu->exact = state->has_exact_alu || expr->exact;
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
/* If the source is an explicitly sized source, then we need to reset
nir_builder_instr_insert(build, &alu->instr);
+ assert(alu->dest.dest.ssa.index ==
+ util_dynarray_num_elements(state->states, uint16_t));
+ util_dynarray_append(state->states, uint16_t, 0);
+ nir_algebraic_automaton(&alu->instr, state->states, state->pass_op_table);
+
nir_alu_src val;
val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
val.negate = false;
(void *)build->shader);
assert(!var->is_constant);
+ for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+ val.swizzle[i] = state->variables[var->variable].swizzle[var->swizzle[i]];
+
return val;
}
unreachable("Invalid alu source type");
}
+ assert(cval->index ==
+ util_dynarray_num_elements(state->states, uint16_t));
+ util_dynarray_append(state->states, uint16_t, 0);
+ nir_algebraic_automaton(cval->parent_instr, state->states,
+ state->pass_op_table);
+
nir_alu_src val;
val.src = nir_src_for_ssa(cval);
val.negate = false;
}
}
+/* Debug helper: prints a nir_search value tree (constant, variable, or
+ * expression) in a Lisp-like form. All output now goes to stderr instead
+ * of stdout.
+ */
-MAYBE_UNUSED static void dump_value(const nir_search_value *val)
+UNUSED static void dump_value(const nir_search_value *val)
{
switch (val->type) {
case nir_search_value_constant: {
const nir_search_constant *sconst = nir_search_value_as_constant(val);
switch (sconst->type) {
case nir_type_float:
- printf("%f", sconst->data.d);
+ fprintf(stderr, "%f", sconst->data.d);
break;
case nir_type_int:
- printf("%"PRId64, sconst->data.i);
+ fprintf(stderr, "%"PRId64, sconst->data.i);
break;
case nir_type_uint:
- printf("0x%"PRIx64, sconst->data.u);
+ fprintf(stderr, "0x%"PRIx64, sconst->data.u);
+ break;
+ case nir_type_bool:
+ fprintf(stderr, "%s", sconst->data.u != 0 ? "True" : "False");
break;
default:
unreachable("bad const type");
case nir_search_value_variable: {
const nir_search_variable *var = nir_search_value_as_variable(val);
+ /* Constant-only variables are printed with a leading '#', matching the
+ * nir_algebraic pattern syntax. */
if (var->is_constant)
- printf("#");
- printf("%c", var->variable + 'a');
+ fprintf(stderr, "#");
+ fprintf(stderr, "%c", var->variable + 'a');
break;
}
case nir_search_value_expression: {
const nir_search_expression *expr = nir_search_value_as_expression(val);
- printf("(");
+ fprintf(stderr, "(");
if (expr->inexact)
- printf("~");
+ fprintf(stderr, "~");
switch (expr->opcode) {
#define CASE(n) \
- case nir_search_op_##n: printf(#n); break;
+ case nir_search_op_##n: fprintf(stderr, #n); break;
CASE(f2b)
CASE(b2f)
CASE(b2i)
CASE(i2f)
#undef CASE
default:
- printf("%s", nir_op_infos[expr->opcode].name);
+ fprintf(stderr, "%s", nir_op_infos[expr->opcode].name);
}
unsigned num_srcs = 1;
num_srcs = nir_op_infos[expr->opcode].num_inputs;
for (unsigned i = 0; i < num_srcs; i++) {
- printf(" ");
+ fprintf(stderr, " ");
dump_value(expr->srcs[i]);
}
- printf(")");
+ fprintf(stderr, ")");
break;
}
}
+ /* A non-zero bit_size means the pattern constrained it explicitly. */
if (val->bit_size > 0)
- printf("@%d", val->bit_size);
+ fprintf(stderr, "@%d", val->bit_size);
+}
+
+/* Pushes every instruction that uses instr's SSA def onto the worklist.
+ * Uses the _safe iterator since callers may rewrite uses while iterating.
+ */
+static void
+add_uses_to_worklist(nir_instr *instr, nir_instr_worklist *worklist)
+{
+ nir_ssa_def *def = nir_instr_ssa_def(instr);
+
+ nir_foreach_use_safe(use_src, def) {
+ nir_instr_worklist_push_tail(worklist, use_src->parent_instr);
+ }
+}
+
+/* Propagates automaton-state changes from new_instr through its transitive
+ * users until the states stabilize. Every instruction whose state actually
+ * changes is also pushed onto algebraic_worklist so the main pass will
+ * reconsider it for pattern matching.
+ */
+static void
+nir_algebraic_update_automaton(nir_instr *new_instr,
+ nir_instr_worklist *algebraic_worklist,
+ struct util_dynarray *states,
+ const struct per_op_table *pass_op_table)
+{
+
+ nir_instr_worklist *automaton_worklist = nir_instr_worklist_create();
+
+ /* Walk through the tree of uses of our new instruction's SSA value,
+ * recursively updating the automaton state until it stabilizes.
+ */
+ add_uses_to_worklist(new_instr, automaton_worklist);
+
+ nir_instr *instr;
+ while ((instr = nir_instr_worklist_pop_head(automaton_worklist))) {
+ if (nir_algebraic_automaton(instr, states, pass_op_table)) {
+ nir_instr_worklist_push_tail(algebraic_worklist, instr);
+
+ add_uses_to_worklist(instr, automaton_worklist);
+ }
+ }
+
+ nir_instr_worklist_destroy(automaton_worklist);
}
+/* Tries to match `search` against `instr` and, on success, replaces the
+ * instruction with a newly-built expression tree for `replace`. Returns
+ * the replacement's SSA def, or NULL when the pattern does not match.
+ * Newly constructed instructions get automaton states appended, and users
+ * of the rewritten value are re-run through the automaton (and queued on
+ * algebraic_worklist) via nir_algebraic_update_automaton().
+ */
nir_ssa_def *
nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
+ struct hash_table *range_ht,
+ struct util_dynarray *states,
+ const struct per_op_table *pass_op_table,
const nir_search_expression *search,
- const nir_search_value *replace)
+ const nir_search_value *replace,
+ nir_instr_worklist *algebraic_worklist)
{
uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
struct match_state state;
state.inexact_match = false;
state.has_exact_alu = false;
+ state.range_ht = range_ht;
+ state.pass_op_table = pass_op_table;
+
+ /* comm_op_direction is a bitmask with one bit per commutative op. */
+ STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);
unsigned comm_expr_combinations =
1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
return NULL;
#if 0
- printf("matched: ");
+ fprintf(stderr, "matched: ");
dump_value(&search->value);
- printf(" -> ");
+ fprintf(stderr, " -> ");
dump_value(replace);
- printf(" ssa_%d\n", instr->dest.dest.ssa.index);
+ fprintf(stderr, " ssa_%d\n", instr->dest.dest.ssa.index);
#endif
- build->cursor = nir_before_instr(&instr->instr);
+ /* If the instruction at the root of the expression tree being replaced is
+ * a unary operation, insert the replacement instructions at the location
+ * of the source of the unary operation. Otherwise, insert the replacement
+ * instructions at the location of the expression tree root.
+ *
+ * For the unary operation case, this is done to prevent some spurious code
+ * motion that can dramatically extend live ranges. Imagine an expression
+ * like -(A+B) where the addition and the negation are separated by flow
+ * control and thousands of instructions. If this expression is replaced
+ * with -A+-B, inserting the new instructions at the site of the negation
+ * could extend the live range of A and B dramatically. This could increase
+ * register pressure and cause spilling.
+ *
+ * It may well be that moving instructions around is a good thing, but
+ * keeping algebraic optimizations and code motion optimizations separate
+ * seems safest.
+ */
+ nir_alu_instr *const src_instr = nir_src_as_alu_instr(instr->src[0].src);
+ if (src_instr != NULL &&
+ (instr->op == nir_op_fneg || instr->op == nir_op_fabs ||
+ instr->op == nir_op_ineg || instr->op == nir_op_iabs ||
+ instr->op == nir_op_inot)) {
+ /* Insert new instructions *after*. Otherwise a hypothetical
+ * replacement fneg(X) -> fabs(X) would insert the fabs() instruction
+ * before X! This can also occur for things like fneg(X.wzyx) -> X.wzyx
+ * in vector mode. A move instruction to handle the swizzle will get
+ * inserted before X.
+ *
+ * This manifested in a single OpenGL ES 2.0 CTS vertex shader test on
+ * older Intel GPUs that use vector-mode vertex processing.
+ */
+ build->cursor = nir_after_instr(&src_instr->instr);
+ } else {
+ build->cursor = nir_before_instr(&instr->instr);
+ }
+
+ state.states = states;
nir_alu_src val = construct_value(build, replace,
instr->dest.dest.ssa.num_components,
instr->dest.dest.ssa.bit_size,
&state, &instr->instr);
- /* Inserting a mov may be unnecessary. However, it's much easier to
- * simply let copy propagation clean this up than to try to go through
- * and rewrite swizzles ourselves.
+ /* Note that NIR builder will elide the MOV if it's a no-op, which may
+ * allow more work to be done in a single pass through algebraic.
*/
nir_ssa_def *ssa_val =
nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
+ /* If the MOV was not elided, it is a brand-new def that needs an
+ * automaton state entry of its own. */
+ if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) {
+ util_dynarray_append(states, uint16_t, 0);
+ nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
+ }
+
+ /* Rewrite the uses of the old SSA value to the new one, and recurse
+ * through the uses updating the automaton's state.
+ */
nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
+ nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist,
+ states, pass_op_table);
- /* We know this one has no more uses because we just rewrote them all,
- * so we can remove it. The rest of the matched expression, however, we
- * don't know so much about. We'll just let dead code clean them up.
+ /* Nothing uses the instr any more, so drop it out of the program. Note
+ * that the instr may be in the worklist still, so we can't free it
+ * directly.
*/
nir_instr_remove(&instr->instr);
return ssa_val;
}
+
+/* Advances the pattern-matching automaton state for a single instruction,
+ * based on the states of its sources. Returns true if the instruction's
+ * state changed (so callers know its users may need re-running). Only ALU
+ * and load_const instructions carry meaningful states; everything else
+ * keeps the default state 0.
+ */
+static bool
+nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
+ const struct per_op_table *pass_op_table)
+{
+ switch (instr->type) {
+ case nir_instr_type_alu: {
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+ nir_op op = alu->op;
+ uint16_t search_op = nir_search_op_for_nir_op(op);
+ const struct per_op_table *tbl = &pass_op_table[search_op];
+ /* No transitions for this op in this pass: the state stays default. */
+ if (tbl->num_filtered_states == 0)
+ return false;
+
+ /* Calculate the index into the transition table. Note the index
+ * calculated must match the iteration order of Python's
+ * itertools.product(), which was used to emit the transition
+ * table.
+ */
+ unsigned index = 0;
+ for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
+ index *= tbl->num_filtered_states;
+ index += tbl->filter[*util_dynarray_element(states, uint16_t,
+ alu->src[i].src.ssa->index)];
+ }
+
+ uint16_t *state = util_dynarray_element(states, uint16_t,
+ alu->dest.dest.ssa.index);
+ if (*state != tbl->table[index]) {
+ *state = tbl->table[index];
+ return true;
+ }
+ return false;
+ }
+
+ case nir_instr_type_load_const: {
+ nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
+ uint16_t *state = util_dynarray_element(states, uint16_t,
+ load_const->def.index);
+ if (*state != CONST_STATE) {
+ *state = CONST_STATE;
+ return true;
+ }
+ return false;
+ }
+
+ default:
+ return false;
+ }
+}
+
+/* Tries each transform registered for instr's current automaton state and
+ * applies the first one whose condition holds and whose pattern matches.
+ * Returns true when a replacement was made. The range-analysis cache is
+ * cleared after a successful replacement because the IR has changed.
+ */
+static bool
+nir_algebraic_instr(nir_builder *build, nir_instr *instr,
+ struct hash_table *range_ht,
+ const bool *condition_flags,
+ const struct transform **transforms,
+ const uint16_t *transform_counts,
+ struct util_dynarray *states,
+ const struct per_op_table *pass_op_table,
+ nir_instr_worklist *worklist)
+{
+
+ if (instr->type != nir_instr_type_alu)
+ return false;
+
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+ if (!alu->dest.dest.is_ssa)
+ return false;
+
+ /* Inexact (~) patterns must be skipped when the shader's float-controls
+ * execution mode requires preserving signed-zero/inf/nan or flushing
+ * denorms at this bit size. */
+ unsigned bit_size = alu->dest.dest.ssa.bit_size;
+ const unsigned execution_mode =
+ build->shader->info.float_controls_execution_mode;
+ const bool ignore_inexact =
+ nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
+ nir_is_denorm_flush_to_zero(execution_mode, bit_size);
+
+ int xform_idx = *util_dynarray_element(states, uint16_t,
+ alu->dest.dest.ssa.index);
+ for (uint16_t i = 0; i < transform_counts[xform_idx]; i++) {
+ const struct transform *xform = &transforms[xform_idx][i];
+ if (condition_flags[xform->condition_offset] &&
+ !(xform->search->inexact && ignore_inexact) &&
+ nir_replace_instr(build, alu, range_ht, states, pass_op_table,
+ xform->search, xform->replace, worklist)) {
+ _mesa_hash_table_clear(range_ht, NULL);
+ return true;
+ }
+ }
+
+ return false;
+}
+
+/* Entry point for a generated algebraic pass over one function impl:
+ * seeds a per-SSA-def automaton state array, queues all instructions on a
+ * worklist (in an order that favors matching the largest patterns first),
+ * and drains the worklist applying transforms until none remain. Returns
+ * true when any transform fired, and preserves metadata accordingly.
+ */
+bool
+nir_algebraic_impl(nir_function_impl *impl,
+ const bool *condition_flags,
+ const struct transform **transforms,
+ const uint16_t *transform_counts,
+ const struct per_op_table *pass_op_table)
+{
+ bool progress = false;
+
+ nir_builder build;
+ nir_builder_init(&build, impl);
+
+ /* Note: it's important here that we're allocating a zeroed array, since
+ * state 0 is the default state, which means we don't have to visit
+ * anything other than constants and ALU instructions.
+ */
+ struct util_dynarray states = {0};
+ if (!util_dynarray_resize(&states, uint16_t, impl->ssa_alloc)) {
+ nir_metadata_preserve(impl, nir_metadata_all);
+ return false;
+ }
+ /* util_dynarray_resize() does not zero the storage, so do it here. */
+ memset(states.data, 0, states.size);
+
+ struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
+
+ nir_instr_worklist *worklist = nir_instr_worklist_create();
+
+ /* Walk top-to-bottom setting up the automaton state. */
+ nir_foreach_block(block, impl) {
+ nir_foreach_instr(instr, block) {
+ nir_algebraic_automaton(instr, &states, pass_op_table);
+ }
+ }
+
+ /* Put our instrs in the worklist such that we're popping the last instr
+ * first. This will encourage us to match the biggest source patterns when
+ * possible.
+ */
+ nir_foreach_block_reverse(block, impl) {
+ nir_foreach_instr_reverse(instr, block) {
+ nir_instr_worklist_push_tail(worklist, instr);
+ }
+ }
+
+ nir_instr *instr;
+ while ((instr = nir_instr_worklist_pop_head(worklist))) {
+ /* The worklist can have an instr pushed to it multiple times if it was
+ * the src of multiple instrs that also got optimized, so make sure that
+ * we don't try to re-optimize an instr we already handled.
+ */
+ if (exec_node_is_tail_sentinel(&instr->node))
+ continue;
+
+ progress |= nir_algebraic_instr(&build, instr,
+ range_ht, condition_flags,
+ transforms, transform_counts, &states,
+ pass_op_table, worklist);
+ }
+
+ nir_instr_worklist_destroy(worklist);
+ ralloc_free(range_ht);
+ util_dynarray_fini(&states);
+
+ if (progress) {
+ nir_metadata_preserve(impl, nir_metadata_block_index |
+ nir_metadata_dominance);
+ } else {
+ nir_metadata_preserve(impl, nir_metadata_all);
+ }
+
+ return progress;
+}