nir/lower_io: Add support for global scratch addressing
[mesa.git] / src / compiler / nir / nir_search.c
index b34b13fdb8904bd47fbfbf90108fe9244759b410..577f0be0b9278fbff69be9ada4b69ab75f541653 100644 (file)
 
 #include <inttypes.h>
 #include "nir_search.h"
+#include "nir_builder.h"
+#include "nir_worklist.h"
+#include "util/half_float.h"
+
+/* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
+#define NIR_SEARCH_MAX_COMM_OPS 8
 
/* Mutable state threaded through one match/replace attempt. */
struct match_state {
   bool inexact_match;           /* set when an inexact (~) expression matched */
   bool has_exact_alu;           /* a matched ALU instr had its exact flag set;
                                  * ORed into the exact flag of replacements */
   /* Bitfield: bit i flips sources 0 and 1 of the i-th commutative
    * expression while matching (see comm_op_flip in match_expression).
    */
   uint8_t comm_op_direction;
   unsigned variables_seen;      /* bitmask of variable slots already bound */

   /* Used for running the automaton on newly-constructed instructions. */
   struct util_dynarray *states;
   const struct per_op_table *pass_op_table;

   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
   /* Hash table handed to variable cond callbacks; presumably caches
    * range-analysis results — confirm against the cond implementations.
    */
   struct hash_table *range_ht;
};
 
 static bool
 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state);
+static bool
+nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
+                        const struct per_op_table *pass_op_table);
 
-static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
+static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
+{
+    0,  1,  2,  3,
+    4,  5,  6,  7,
+    8,  9, 10, 11,
+   12, 13, 14, 15,
+};
 
 /**
  * Check if a source produces a value of the given type.
@@ -55,10 +77,6 @@ src_is_type(nir_src src, nir_alu_type type)
    if (!src.is_ssa)
       return false;
 
-   /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
-   if (nir_alu_type_get_base_type(type) == nir_type_bool)
-      type = nir_type_bool;
-
    if (src.ssa->parent_instr->type == nir_instr_type_alu) {
       nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
       nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
@@ -91,12 +109,163 @@ src_is_type(nir_src src, nir_alu_type type)
    return false;
 }
 
+static bool
+nir_op_matches_search_op(nir_op nop, uint16_t sop)
+{
+   if (sop <= nir_last_opcode)
+      return nop == sop;
+
+#define MATCH_FCONV_CASE(op) \
+   case nir_search_op_##op: \
+      return nop == nir_op_##op##16 || \
+             nop == nir_op_##op##32 || \
+             nop == nir_op_##op##64;
+
+#define MATCH_ICONV_CASE(op) \
+   case nir_search_op_##op: \
+      return nop == nir_op_##op##8 || \
+             nop == nir_op_##op##16 || \
+             nop == nir_op_##op##32 || \
+             nop == nir_op_##op##64;
+
+#define MATCH_BCONV_CASE(op) \
+   case nir_search_op_##op: \
+      return nop == nir_op_##op##1 || \
+             nop == nir_op_##op##32;
+
+   switch (sop) {
+   MATCH_FCONV_CASE(i2f)
+   MATCH_FCONV_CASE(u2f)
+   MATCH_FCONV_CASE(f2f)
+   MATCH_ICONV_CASE(f2u)
+   MATCH_ICONV_CASE(f2i)
+   MATCH_ICONV_CASE(u2u)
+   MATCH_ICONV_CASE(i2i)
+   MATCH_FCONV_CASE(b2f)
+   MATCH_ICONV_CASE(b2i)
+   MATCH_BCONV_CASE(i2b)
+   MATCH_BCONV_CASE(f2b)
+   default:
+      unreachable("Invalid nir_search_op");
+   }
+
+#undef MATCH_FCONV_CASE
+#undef MATCH_ICONV_CASE
+#undef MATCH_BCONV_CASE
+}
+
+uint16_t
+nir_search_op_for_nir_op(nir_op nop)
+{
+#define MATCH_FCONV_CASE(op) \
+   case nir_op_##op##16: \
+   case nir_op_##op##32: \
+   case nir_op_##op##64: \
+      return nir_search_op_##op;
+
+#define MATCH_ICONV_CASE(op) \
+   case nir_op_##op##8: \
+   case nir_op_##op##16: \
+   case nir_op_##op##32: \
+   case nir_op_##op##64: \
+      return nir_search_op_##op;
+
+#define MATCH_BCONV_CASE(op) \
+   case nir_op_##op##1: \
+   case nir_op_##op##32: \
+      return nir_search_op_##op;
+
+
+   switch (nop) {
+   MATCH_FCONV_CASE(i2f)
+   MATCH_FCONV_CASE(u2f)
+   MATCH_FCONV_CASE(f2f)
+   MATCH_ICONV_CASE(f2u)
+   MATCH_ICONV_CASE(f2i)
+   MATCH_ICONV_CASE(u2u)
+   MATCH_ICONV_CASE(i2i)
+   MATCH_FCONV_CASE(b2f)
+   MATCH_ICONV_CASE(b2i)
+   MATCH_BCONV_CASE(i2b)
+   MATCH_BCONV_CASE(f2b)
+   default:
+      return nop;
+   }
+
+#undef MATCH_FCONV_CASE
+#undef MATCH_ICONV_CASE
+#undef MATCH_BCONV_CASE
+}
+
+static nir_op
+nir_op_for_search_op(uint16_t sop, unsigned bit_size)
+{
+   if (sop <= nir_last_opcode)
+      return sop;
+
+#define RET_FCONV_CASE(op) \
+   case nir_search_op_##op: \
+      switch (bit_size) { \
+      case 16: return nir_op_##op##16; \
+      case 32: return nir_op_##op##32; \
+      case 64: return nir_op_##op##64; \
+      default: unreachable("Invalid bit size"); \
+      }
+
+#define RET_ICONV_CASE(op) \
+   case nir_search_op_##op: \
+      switch (bit_size) { \
+      case 8:  return nir_op_##op##8; \
+      case 16: return nir_op_##op##16; \
+      case 32: return nir_op_##op##32; \
+      case 64: return nir_op_##op##64; \
+      default: unreachable("Invalid bit size"); \
+      }
+
+#define RET_BCONV_CASE(op) \
+   case nir_search_op_##op: \
+      switch (bit_size) { \
+      case 1: return nir_op_##op##1; \
+      case 32: return nir_op_##op##32; \
+      default: unreachable("Invalid bit size"); \
+      }
+
+   switch (sop) {
+   RET_FCONV_CASE(i2f)
+   RET_FCONV_CASE(u2f)
+   RET_FCONV_CASE(f2f)
+   RET_ICONV_CASE(f2u)
+   RET_ICONV_CASE(f2i)
+   RET_ICONV_CASE(u2u)
+   RET_ICONV_CASE(i2i)
+   RET_FCONV_CASE(b2f)
+   RET_ICONV_CASE(b2i)
+   RET_BCONV_CASE(i2b)
+   RET_BCONV_CASE(f2b)
+   default:
+      unreachable("Invalid nir_search_op");
+   }
+
+#undef RET_FCONV_CASE
+#undef RET_ICONV_CASE
+#undef RET_BCONV_CASE
+}
+
 static bool
 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
             unsigned num_components, const uint8_t *swizzle,
             struct match_state *state)
 {
-   uint8_t new_swizzle[4];
+   uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
+
+   /* Searching only works on SSA values because, if it's not SSA, we can't
+    * know if the value changed between one instance of that value in the
+    * expression and another.  Also, the replace operation will place reads of
+    * that value right before the last instruction in the expression we're
+    * replacing so those reads will happen after the original reads and may
+    * not be valid if they're register reads.
+    */
+   assert(instr->src[src].src.is_ssa);
 
    /* If the source is an explicitly sized source, then we need to reset
     * both the number of components and the swizzle.
@@ -110,15 +279,12 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
       new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
 
    /* If the value has a specific bit size and it doesn't match, bail */
-   if (value->bit_size &&
+   if (value->bit_size > 0 &&
        nir_src_bit_size(instr->src[src].src) != value->bit_size)
       return false;
 
    switch (value->type) {
    case nir_search_value_expression:
-      if (!instr->src[src].src.is_ssa)
-         return false;
-
       if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
          return false;
 
@@ -131,8 +297,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
       assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
 
       if (state->variables_seen & (1 << var->variable)) {
-         if (!nir_srcs_equal(state->variables[var->variable].src,
-                             instr->src[src].src))
+         if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
             return false;
 
          assert(!instr->src[src].abs && !instr->src[src].negate);
@@ -148,7 +313,8 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
              instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
             return false;
 
-         if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
+         if (var->cond && !var->cond(state->range_ht, instr,
+                                     src, num_components, new_swizzle))
             return false;
 
          if (var->type != nir_type_invalid &&
@@ -160,7 +326,7 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
          state->variables[var->variable].abs = false;
          state->variables[var->variable].negate = false;
 
-         for (unsigned i = 0; i < 4; ++i) {
+         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
             if (i < num_components)
                state->variables[var->variable].swizzle[i] = new_swizzle[i];
             else
@@ -174,73 +340,43 @@ match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
    case nir_search_value_constant: {
       nir_search_constant *const_val = nir_search_value_as_constant(value);
 
-      if (!instr->src[src].src.is_ssa)
+      if (!nir_src_is_const(instr->src[src].src))
          return false;
 
-      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
-         return false;
+      switch (const_val->type) {
+      case nir_type_float: {
+         nir_load_const_instr *const load =
+            nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
 
-      nir_load_const_instr *load =
-         nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
+         /* There are 8-bit and 1-bit integer types, but there are no 8-bit or
+          * 1-bit float types.  This prevents potential assertion failures in
+          * nir_src_comp_as_float.
+          */
+         if (load->def.bit_size < 16)
+            return false;
 
-      switch (const_val->type) {
-      case nir_type_float:
          for (unsigned i = 0; i < num_components; ++i) {
-            double val;
-            switch (load->def.bit_size) {
-            case 32:
-               val = load->value.f32[new_swizzle[i]];
-               break;
-            case 64:
-               val = load->value.f64[new_swizzle[i]];
-               break;
-            default:
-               unreachable("unknown bit size");
-            }
-
+            double val = nir_src_comp_as_float(instr->src[src].src,
+                                               new_swizzle[i]);
             if (val != const_val->data.d)
                return false;
          }
          return true;
+      }
 
       case nir_type_int:
-         for (unsigned i = 0; i < num_components; ++i) {
-            int64_t val;
-            switch (load->def.bit_size) {
-            case 32:
-               val = load->value.i32[new_swizzle[i]];
-               break;
-            case 64:
-               val = load->value.i64[new_swizzle[i]];
-               break;
-            default:
-               unreachable("unknown bit size");
-            }
-
-            if (val != const_val->data.i)
-               return false;
-         }
-         return true;
-
       case nir_type_uint:
-      case nir_type_bool32:
+      case nir_type_bool: {
+         unsigned bit_size = nir_src_bit_size(instr->src[src].src);
+         uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
          for (unsigned i = 0; i < num_components; ++i) {
-            uint64_t val;
-            switch (load->def.bit_size) {
-            case 32:
-               val = load->value.u32[new_swizzle[i]];
-               break;
-            case 64:
-               val = load->value.u64[new_swizzle[i]];
-               break;
-            default:
-               unreachable("unknown bit size");
-            }
-
-            if (val != const_val->data.u)
+            uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
+                                                new_swizzle[i]);
+            if ((val & mask) != (const_val->data.u & mask))
                return false;
          }
          return true;
+      }
 
       default:
          unreachable("Invalid alu source type");
@@ -257,12 +393,15 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                  unsigned num_components, const uint8_t *swizzle,
                  struct match_state *state)
 {
-   if (instr->op != expr->opcode)
+   if (expr->cond && !expr->cond(instr))
+      return false;
+
+   if (!nir_op_matches_search_op(instr->op, expr->opcode))
       return false;
 
    assert(instr->dest.dest.is_ssa);
 
-   if (expr->value.bit_size &&
+   if (expr->value.bit_size > 0 &&
        instr->dest.dest.ssa.bit_size != expr->value.bit_size)
       return false;
 
@@ -287,177 +426,61 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
       }
    }
 
-   /* Stash off the current variables_seen bitmask.  This way we can
-    * restore it prior to matching in the commutative case below.
+   /* If this is a commutative expression and it's one of the first few, look
+    * up its direction for the current search operation.  We'll use that value
+    * to possibly flip the sources for the match.
     */
-   unsigned variables_seen_stash = state->variables_seen;
+   unsigned comm_op_flip =
+      (expr->comm_expr_idx >= 0 &&
+       expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
+      ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
 
    bool matched = true;
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      if (!match_value(expr->srcs[i], instr, i, num_components,
-                       swizzle, state)) {
+      /* 2src_commutative instructions that have 3 sources are only commutative
+       * in the first two sources.  Source 2 is always source 2.
+       */
+      if (!match_value(expr->srcs[i], instr,
+                       i < 2 ? i ^ comm_op_flip : i,
+                       num_components, swizzle, state)) {
          matched = false;
          break;
       }
    }
 
-   if (matched)
-      return true;
-
-   if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
-      assert(nir_op_infos[instr->op].num_inputs == 2);
-
-      /* Restore the variables_seen bitmask.  If we don't do this, then we
-       * could end up with an erroneous failure due to variables found in the
-       * first match attempt above not matching those in the second.
-       */
-      state->variables_seen = variables_seen_stash;
-
-      if (!match_value(expr->srcs[0], instr, 1, num_components,
-                       swizzle, state))
-         return false;
-
-      return match_value(expr->srcs[1], instr, 0, num_components,
-                         swizzle, state);
-   } else {
-      return false;
-   }
-}
-
-typedef struct bitsize_tree {
-   unsigned num_srcs;
-   struct bitsize_tree *srcs[4];
-
-   unsigned common_size;
-   bool is_src_sized[4];
-   bool is_dest_sized;
-
-   unsigned dest_size;
-   unsigned src_size[4];
-} bitsize_tree;
-
-static bitsize_tree *
-build_bitsize_tree(void *mem_ctx, struct match_state *state,
-                   const nir_search_value *value)
-{
-   bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
-
-   switch (value->type) {
-   case nir_search_value_expression: {
-      nir_search_expression *expr = nir_search_value_as_expression(value);
-      nir_op_info info = nir_op_infos[expr->opcode];
-      tree->num_srcs = info.num_inputs;
-      tree->common_size = 0;
-      for (unsigned i = 0; i < info.num_inputs; i++) {
-         tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
-         if (tree->is_src_sized[i])
-            tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
-         tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
-      }
-      tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
-      if (tree->is_dest_sized)
-         tree->dest_size = nir_alu_type_get_type_size(info.output_type);
-      break;
-   }
-
-   case nir_search_value_variable: {
-      nir_search_variable *var = nir_search_value_as_variable(value);
-      tree->num_srcs = 0;
-      tree->is_dest_sized = true;
-      tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
-      break;
-   }
-
-   case nir_search_value_constant: {
-      tree->num_srcs = 0;
-      tree->is_dest_sized = false;
-      tree->common_size = 0;
-      break;
-   }
-   }
-
-   if (value->bit_size) {
-      assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
-      tree->common_size = value->bit_size;
-   }
-
-   return tree;
+   return matched;
 }
 
 static unsigned
-bitsize_tree_filter_up(bitsize_tree *tree)
-{
-   for (unsigned i = 0; i < tree->num_srcs; i++) {
-      unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
-      if (src_size == 0)
-         continue;
-
-      if (tree->is_src_sized[i]) {
-         assert(src_size == tree->src_size[i]);
-      } else if (tree->common_size != 0) {
-         assert(src_size == tree->common_size);
-         tree->src_size[i] = src_size;
-      } else {
-         tree->common_size = src_size;
-         tree->src_size[i] = src_size;
-      }
-   }
-
-   if (tree->num_srcs && tree->common_size) {
-      if (tree->dest_size == 0)
-         tree->dest_size = tree->common_size;
-      else if (!tree->is_dest_sized)
-         assert(tree->dest_size == tree->common_size);
-
-      for (unsigned i = 0; i < tree->num_srcs; i++) {
-         if (!tree->src_size[i])
-            tree->src_size[i] = tree->common_size;
-      }
-   }
-
-   return tree->dest_size;
-}
-
-static void
-bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
+replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
+                struct match_state *state)
 {
-   if (tree->dest_size)
-      assert(tree->dest_size == size);
-   else
-      tree->dest_size = size;
-
-   if (!tree->is_dest_sized) {
-      if (tree->common_size)
-         assert(tree->common_size == size);
-      else
-         tree->common_size = size;
-   }
-
-   for (unsigned i = 0; i < tree->num_srcs; i++) {
-      if (!tree->src_size[i]) {
-         assert(tree->common_size);
-         tree->src_size[i] = tree->common_size;
-      }
-      bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
-   }
+   if (value->bit_size > 0)
+      return value->bit_size;
+   if (value->bit_size < 0)
+      return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
+   return search_bitsize;
 }
 
 static nir_alu_src
-construct_value(const nir_search_value *value,
-                unsigned num_components, bitsize_tree *bitsize,
+construct_value(nir_builder *build,
+                const nir_search_value *value,
+                unsigned num_components, unsigned search_bitsize,
                 struct match_state *state,
-                nir_instr *instr, void *mem_ctx)
+                nir_instr *instr)
 {
    switch (value->type) {
    case nir_search_value_expression: {
       const nir_search_expression *expr = nir_search_value_as_expression(value);
+      unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
+      nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
 
-      if (nir_op_infos[expr->opcode].output_size != 0)
-         num_components = nir_op_infos[expr->opcode].output_size;
+      if (nir_op_infos[op].output_size != 0)
+         num_components = nir_op_infos[op].output_size;
 
-      nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
+      nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
-                        bitsize->dest_size, NULL);
+                        dst_bit_size, NULL);
       alu->dest.write_mask = (1 << num_components) - 1;
       alu->dest.saturate = false;
 
@@ -466,21 +489,26 @@ construct_value(const nir_search_value *value,
        * expression we are replacing has any exact values, the entire
        * replacement should be exact.
        */
-      alu->exact = state->has_exact_alu;
+      alu->exact = state->has_exact_alu || expr->exact;
 
-      for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
+      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
          /* If the source is an explicitly sized source, then we need to reset
           * the number of components to match.
           */
          if (nir_op_infos[alu->op].input_sizes[i] != 0)
             num_components = nir_op_infos[alu->op].input_sizes[i];
 
-         alu->src[i] = construct_value(expr->srcs[i],
-                                       num_components, bitsize->srcs[i],
-                                       state, instr, mem_ctx);
+         alu->src[i] = construct_value(build, expr->srcs[i],
+                                       num_components, search_bitsize,
+                                       state, instr);
       }
 
-      nir_instr_insert_before(instr, &alu->instr);
+      nir_builder_instr_insert(build, &alu->instr);
+
+      assert(alu->dest.dest.ssa.index ==
+             util_dynarray_num_elements(state->states, uint16_t));
+      util_dynarray_append(state->states, uint16_t, 0);
+      nir_algebraic_automaton(&alu->instr, state->states, state->pass_op_table);
 
       nir_alu_src val;
       val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
@@ -496,72 +524,47 @@ construct_value(const nir_search_value *value,
       assert(state->variables_seen & (1 << var->variable));
 
       nir_alu_src val = { NIR_SRC_INIT };
-      nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
-
+      nir_alu_src_copy(&val, &state->variables[var->variable],
+                       (void *)build->shader);
       assert(!var->is_constant);
 
+      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
+         val.swizzle[i] = state->variables[var->variable].swizzle[var->swizzle[i]];
+
       return val;
    }
 
    case nir_search_value_constant: {
       const nir_search_constant *c = nir_search_value_as_constant(value);
-      nir_load_const_instr *load =
-         nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size);
+      unsigned bit_size = replace_bitsize(value, search_bitsize, state);
 
+      nir_ssa_def *cval;
       switch (c->type) {
       case nir_type_float:
-         load->def.name = ralloc_asprintf(load, "%f", c->data.d);
-         switch (bitsize->dest_size) {
-         case 32:
-            load->value.f32[0] = c->data.d;
-            break;
-         case 64:
-            load->value.f64[0] = c->data.d;
-            break;
-         default:
-            unreachable("unknown bit size");
-         }
+         cval = nir_imm_floatN_t(build, c->data.d, bit_size);
          break;
 
       case nir_type_int:
-         load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i);
-         switch (bitsize->dest_size) {
-         case 32:
-            load->value.i32[0] = c->data.i;
-            break;
-         case 64:
-            load->value.i64[0] = c->data.i;
-            break;
-         default:
-            unreachable("unknown bit size");
-         }
-         break;
-
       case nir_type_uint:
-         load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u);
-         switch (bitsize->dest_size) {
-         case 32:
-            load->value.u32[0] = c->data.u;
-            break;
-         case 64:
-            load->value.u64[0] = c->data.u;
-            break;
-         default:
-            unreachable("unknown bit size");
-         }
+         cval = nir_imm_intN_t(build, c->data.i, bit_size);
          break;
 
-      case nir_type_bool32:
-         load->value.u32[0] = c->data.u;
+      case nir_type_bool:
+         cval = nir_imm_boolN_t(build, c->data.u, bit_size);
          break;
+
       default:
          unreachable("Invalid alu source type");
       }
 
-      nir_instr_insert_before(instr, &load->instr);
+      assert(cval->index ==
+             util_dynarray_num_elements(state->states, uint16_t));
+      util_dynarray_append(state->states, uint16_t, 0);
+      nir_algebraic_automaton(cval->parent_instr, state->states,
+                              state->pass_op_table);
 
       nir_alu_src val;
-      val.src = nir_src_for_ssa(&load->def);
+      val.src = nir_src_for_ssa(cval);
       val.negate = false;
       val.abs = false,
       memset(val.swizzle, 0, sizeof val.swizzle);
@@ -574,11 +577,122 @@ construct_value(const nir_search_value *value,
    }
 }
 
-nir_alu_instr *
-nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
-                  const nir_search_value *replace, void *mem_ctx)
+UNUSED static void dump_value(const nir_search_value *val)
+{
+   switch (val->type) {
+   case nir_search_value_constant: {
+      const nir_search_constant *sconst = nir_search_value_as_constant(val);
+      switch (sconst->type) {
+      case nir_type_float:
+         fprintf(stderr, "%f", sconst->data.d);
+         break;
+      case nir_type_int:
+         fprintf(stderr, "%"PRId64, sconst->data.i);
+         break;
+      case nir_type_uint:
+         fprintf(stderr, "0x%"PRIx64, sconst->data.u);
+         break;
+      case nir_type_bool:
+         fprintf(stderr, "%s", sconst->data.u != 0 ? "True" : "False");
+         break;
+      default:
+         unreachable("bad const type");
+      }
+      break;
+   }
+
+   case nir_search_value_variable: {
+      const nir_search_variable *var = nir_search_value_as_variable(val);
+      if (var->is_constant)
+         fprintf(stderr, "#");
+      fprintf(stderr, "%c", var->variable + 'a');
+      break;
+   }
+
+   case nir_search_value_expression: {
+      const nir_search_expression *expr = nir_search_value_as_expression(val);
+      fprintf(stderr, "(");
+      if (expr->inexact)
+         fprintf(stderr, "~");
+      switch (expr->opcode) {
+#define CASE(n) \
+      case nir_search_op_##n: fprintf(stderr, #n); break;
+      CASE(f2b)
+      CASE(b2f)
+      CASE(b2i)
+      CASE(i2b)
+      CASE(i2i)
+      CASE(f2i)
+      CASE(i2f)
+#undef CASE
+      default:
+         fprintf(stderr, "%s", nir_op_infos[expr->opcode].name);
+      }
+
+      unsigned num_srcs = 1;
+      if (expr->opcode <= nir_last_opcode)
+         num_srcs = nir_op_infos[expr->opcode].num_inputs;
+
+      for (unsigned i = 0; i < num_srcs; i++) {
+         fprintf(stderr, " ");
+         dump_value(expr->srcs[i]);
+      }
+
+      fprintf(stderr, ")");
+      break;
+   }
+   }
+
+   if (val->bit_size > 0)
+      fprintf(stderr, "@%d", val->bit_size);
+}
+
+static void
+add_uses_to_worklist(nir_instr *instr, nir_instr_worklist *worklist)
+{
+   nir_ssa_def *def = nir_instr_ssa_def(instr);
+
+   nir_foreach_use_safe(use_src, def) {
+      nir_instr_worklist_push_tail(worklist, use_src->parent_instr);
+   }
+}
+
+static void
+nir_algebraic_update_automaton(nir_instr *new_instr,
+                               nir_instr_worklist *algebraic_worklist,
+                               struct util_dynarray *states,
+                               const struct per_op_table *pass_op_table)
+{
+
+   nir_instr_worklist *automaton_worklist = nir_instr_worklist_create();
+
+   /* Walk through the tree of uses of our new instruction's SSA value,
+    * recursively updating the automaton state until it stabilizes.
+    */
+   add_uses_to_worklist(new_instr, automaton_worklist);
+
+   nir_instr *instr;
+   while ((instr = nir_instr_worklist_pop_head(automaton_worklist))) {
+      if (nir_algebraic_automaton(instr, states, pass_op_table)) {
+         nir_instr_worklist_push_tail(algebraic_worklist, instr);
+
+         add_uses_to_worklist(instr, automaton_worklist);
+      }
+   }
+
+   nir_instr_worklist_destroy(automaton_worklist);
+}
+
/* Try to match @search against @instr and, on success, replace the
 * instruction with a newly built expression tree for @replace.
 *
 * Returns the new SSA def on success, or NULL if no commutative-operand
 * ordering of @search matched.  On success the old instruction is removed
 * and the automaton states of all affected users are updated; instructions
 * whose state changed are queued on @algebraic_worklist for another pass.
 */
nir_ssa_def *
nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
                  struct hash_table *range_ht,
                  struct util_dynarray *states,
                  const struct per_op_table *pass_op_table,
                  const nir_search_expression *search,
                  const nir_search_value *replace,
                  nir_instr_worklist *algebraic_worklist)
{
   /* Identity swizzle sized to the destination's component count. */
   uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };

   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
      swizzle[i] = i;

   struct match_state state;
   state.inexact_match = false;
   state.has_exact_alu = false;
   state.range_ht = range_ht;
   state.pass_op_table = pass_op_table;

   /* comm_op_direction is a bitfield with one bit per commutative
    * expression in the pattern, so it must be wide enough to hold them all.
    */
   STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);

   /* Try every combination of operand orderings for the commutative ops in
    * the pattern: 2^comm_exprs cases, capped at 2^NIR_SEARCH_MAX_COMM_OPS.
    */
   unsigned comm_expr_combinations =
      1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);

   bool found = false;
   for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
      /* The bitfield of directions is just the current iteration.  Hooray for
       * binary.
       */
      state.comm_op_direction = comb;
      state.variables_seen = 0;

      if (match_expression(search, instr,
                           instr->dest.dest.ssa.num_components,
                           swizzle, &state)) {
         found = true;
         break;
      }
   }
   if (!found)
      return NULL;

#if 0
   /* Debug aid: print each applied transform.  Kept disabled by default. */
   fprintf(stderr, "matched: ");
   dump_value(&search->value);
   fprintf(stderr, " -> ");
   dump_value(replace);
   fprintf(stderr, " ssa_%d\n", instr->dest.dest.ssa.index);
#endif

   /* If the instruction at the root of the expression tree being replaced is
    * a unary operation, insert the replacement instructions at the location
    * of the source of the unary operation.  Otherwise, insert the replacement
    * instructions at the location of the expression tree root.
    *
    * For the unary operation case, this is done to prevent some spurious code
    * motion that can dramatically extend live ranges.  Imagine an expression
    * like -(A+B) where the addtion and the negation are separated by flow
    * control and thousands of instructions.  If this expression is replaced
    * with -A+-B, inserting the new instructions at the site of the negation
    * could extend the live range of A and B dramtically.  This could increase
    * register pressure and cause spilling.
    *
    * It may well be that moving instructions around is a good thing, but
    * keeping algebraic optimizations and code motion optimizations separate
    * seems safest.
    */
   nir_alu_instr *const src_instr = nir_src_as_alu_instr(instr->src[0].src);
   if (src_instr != NULL &&
       (instr->op == nir_op_fneg || instr->op == nir_op_fabs ||
        instr->op == nir_op_ineg || instr->op == nir_op_iabs ||
        instr->op == nir_op_inot)) {
      /* Insert new instructions *after*.  Otherwise a hypothetical
       * replacement fneg(X) -> fabs(X) would insert the fabs() instruction
       * before X!  This can also occur for things like fneg(X.wzyx) -> X.wzyx
       * in vector mode.  A move instruction to handle the swizzle will get
       * inserted before X.
       *
       * This manifested in a single OpenGL ES 2.0 CTS vertex shader test on
       * older Intel GPU that use vector-mode vertex processing.
       */
      build->cursor = nir_after_instr(&src_instr->instr);
   } else {
      build->cursor = nir_before_instr(&instr->instr);
   }

   /* construct_value() needs the states array to keep the automaton in sync
    * for each instruction it builds.
    */
   state.states = states;

   nir_alu_src val = construct_value(build, replace,
                                     instr->dest.dest.ssa.num_components,
                                     instr->dest.dest.ssa.bit_size,
                                     &state, &instr->instr);

   /* Note that NIR builder will elide the MOV if it's a no-op, which may
    * allow more work to be done in a single pass through algebraic.
    */
   nir_ssa_def *ssa_val =
      nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
   /* If the MOV was not elided, it is a brand-new SSA def whose index is one
    * past the end of the states array: grow the array and compute its state.
    */
   if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) {
      util_dynarray_append(states, uint16_t, 0);
      nir_algebraic_automaton(ssa_val->parent_instr, states, pass_op_table);
   }

   /* Rewrite the uses of the old SSA value to the new one, and recurse
    * through the uses updating the automaton's state.
    */
   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
   nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist,
                                  states, pass_op_table);

   /* Nothing uses the instr any more, so drop it out of the program.  Note
    * that the instr may be in the worklist still, so we can't free it
    * directly.
    */
   nir_instr_remove(&instr->instr);

   return ssa_val;
}
+
/* Compute the automaton state for one instruction from the states of its
 * sources, via the generated per-op transition tables.
 *
 * Returns true iff the instruction's stored state changed (callers use
 * this to drive fixed-point propagation through the use graph).  Only ALU
 * and load_const instructions can leave the default state 0; everything
 * else stays at 0 and always returns false.
 */
static bool
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
                        const struct per_op_table *pass_op_table)
{
   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      nir_op op = alu->op;
      uint16_t search_op = nir_search_op_for_nir_op(op);
      const struct per_op_table *tbl = &pass_op_table[search_op];
      /* No pattern in this pass cares about this opcode. */
      if (tbl->num_filtered_states == 0)
         return false;

      /* Calculate the index into the transition table. Note the index
       * calculated must match the iteration order of Python's
       * itertools.product(), which was used to emit the transition
       * table.
       */
      unsigned index = 0;
      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         index *= tbl->num_filtered_states;
         index += tbl->filter[*util_dynarray_element(states, uint16_t,
                                                     alu->src[i].src.ssa->index)];
      }

      uint16_t *state = util_dynarray_element(states, uint16_t,
                                              alu->dest.dest.ssa.index);
      if (*state != tbl->table[index]) {
         *state = tbl->table[index];
         return true;
      }
      return false;
   }

   case nir_instr_type_load_const: {
      /* Constants get a dedicated state so patterns can match on them. */
      nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
      uint16_t *state = util_dynarray_element(states, uint16_t,
                                              load_const->def.index);
      if (*state != CONST_STATE) {
         *state = CONST_STATE;
         return true;
      }
      return false;
   }

   default:
      return false;
   }
}
+
+static bool
+nir_algebraic_instr(nir_builder *build, nir_instr *instr,
+                    struct hash_table *range_ht,
+                    const bool *condition_flags,
+                    const struct transform **transforms,
+                    const uint16_t *transform_counts,
+                    struct util_dynarray *states,
+                    const struct per_op_table *pass_op_table,
+                    nir_instr_worklist *worklist)
+{
+
+   if (instr->type != nir_instr_type_alu)
+      return false;
+
+   nir_alu_instr *alu = nir_instr_as_alu(instr);
+   if (!alu->dest.dest.is_ssa)
+      return false;
+
+   unsigned bit_size = alu->dest.dest.ssa.bit_size;
+   const unsigned execution_mode =
+      build->shader->info.float_controls_execution_mode;
+   const bool ignore_inexact =
+      nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
+      nir_is_denorm_flush_to_zero(execution_mode, bit_size);
+
+   int xform_idx = *util_dynarray_element(states, uint16_t,
+                                          alu->dest.dest.ssa.index);
+   for (uint16_t i = 0; i < transform_counts[xform_idx]; i++) {
+      const struct transform *xform = &transforms[xform_idx][i];
+      if (condition_flags[xform->condition_offset] &&
+          !(xform->search->inexact && ignore_inexact) &&
+          nir_replace_instr(build, alu, range_ht, states, pass_op_table,
+                            xform->search, xform->replace, worklist)) {
+         _mesa_hash_table_clear(range_ht, NULL);
+         return true;
+      }
+   }
+
+   return false;
+}
+
+bool
+nir_algebraic_impl(nir_function_impl *impl,
+                   const bool *condition_flags,
+                   const struct transform **transforms,
+                   const uint16_t *transform_counts,
+                   const struct per_op_table *pass_op_table)
+{
+   bool progress = false;
+
+   nir_builder build;
+   nir_builder_init(&build, impl);
+
+   /* Note: it's important here that we're allocating a zeroed array, since
+    * state 0 is the default state, which means we don't have to visit
+    * anything other than constants and ALU instructions.
+    */
+   struct util_dynarray states = {0};
+   if (!util_dynarray_resize(&states, uint16_t, impl->ssa_alloc)) {
+      nir_metadata_preserve(impl, nir_metadata_all);
+      return false;
+   }
+   memset(states.data, 0, states.size);
+
+   struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
+
+   nir_instr_worklist *worklist = nir_instr_worklist_create();
+
+   /* Walk top-to-bottom setting up the automaton state. */
+   nir_foreach_block(block, impl) {
+      nir_foreach_instr(instr, block) {
+         nir_algebraic_automaton(instr, &states, pass_op_table);
+      }
+   }
+
+   /* Put our instrs in the worklist such that we're popping the last instr
+    * first.  This will encourage us to match the biggest source patterns when
+    * possible.
+    */
+   nir_foreach_block_reverse(block, impl) {
+      nir_foreach_instr_reverse(instr, block) {
+         nir_instr_worklist_push_tail(worklist, instr);
+      }
+   }
+
+   nir_instr *instr;
+   while ((instr = nir_instr_worklist_pop_head(worklist))) {
+      /* The worklist can have an instr pushed to it multiple times if it was
+       * the src of multiple instrs that also got optimized, so make sure that
+       * we don't try to re-optimize an instr we already handled.
+       */
+      if (exec_node_is_tail_sentinel(&instr->node))
+         continue;
+
+      progress |= nir_algebraic_instr(&build, instr,
+                                      range_ht, condition_flags,
+                                      transforms, transform_counts, &states,
+                                      pass_op_table, worklist);
+   }
+
+   nir_instr_worklist_destroy(worklist);
+   ralloc_free(range_ht);
+   util_dynarray_fini(&states);
+
+   if (progress) {
+      nir_metadata_preserve(impl, nir_metadata_block_index |
+                                  nir_metadata_dominance);
+   } else {
+      nir_metadata_preserve(impl, nir_metadata_all);
+   }
 
-   return mov;
+   return progress;
 }