nir/search: Search for all combinations of commutative ops
authorJason Ekstrand <jason.ekstrand@intel.com>
Fri, 22 Mar 2019 22:45:29 +0000 (17:45 -0500)
committerJason Ekstrand <jason@jlekstrand.net>
Mon, 8 Apr 2019 21:38:48 +0000 (21:38 +0000)
Consider the following search expression and NIR sequence:

    ('iadd', ('imul', a, b), b)

    ssa_2 = imul ssa_0, ssa_1
    ssa_3 = iadd ssa_2, ssa_0

The current algorithm is greedy and, the moment the imul finds a match,
it commits those variable names and returns success.  In the above
example, it maps a -> ssa_0 and b -> ssa_1.  When we then try to match
the iadd, it sees that ssa_0 is not b and fails to match.  The iadd
match will attempt to flip itself and try again (which won't work) but
it cannot ask the imul to try a flipped match.

This commit instead counts the number of commutative ops in each
expression and assigns an index to each.  It then does a loop and loops
over the full combinatorial matrix of commutative operations.  In order
to keep things sane, we limit it to at most 4 commutative operations (16
combinations).  There is only one optimization in opt_algebraic that
goes over this limit and it's the bitfieldReverse detection for some UE4
demo.

Shader-db results on Kaby Lake:

    total instructions in shared programs: 15310125 -> 15302469 (-0.05%)
    instructions in affected programs: 1797123 -> 1789467 (-0.43%)
    helped: 6751
    HURT: 2264

    total cycles in shared programs: 357346617 -> 357202526 (-0.04%)
    cycles in affected programs: 15931005 -> 15786914 (-0.90%)
    helped: 6024
    HURT: 3436

    total loops in shared programs: 4360 -> 4360 (0.00%)
    loops in affected programs: 0 -> 0
    helped: 0
    HURT: 0

    total spills in shared programs: 23675 -> 23666 (-0.04%)
    spills in affected programs: 235 -> 226 (-3.83%)
    helped: 5
    HURT: 1

    total fills in shared programs: 32040 -> 32032 (-0.02%)
    fills in affected programs: 190 -> 182 (-4.21%)
    helped: 6
    HURT: 2

    LOST:   18
    GAINED: 5

Reviewed-by: Thomas Helland <thomashelland90@gmail.com>
src/compiler/nir/nir_algebraic.py
src/compiler/nir/nir_search.c
src/compiler/nir/nir_search.h

index fe9d1051e67ecee8a74082f29325e5b1edb4d66f..d4b3bb5957f4c42d0a0950600fdc5219a1143832 100644 (file)
@@ -114,6 +114,7 @@ static const ${val.c_type} ${val.name} = {
    ${val.cond if val.cond else 'NULL'},
 % elif isinstance(val, Expression):
    ${'true' if val.inexact else 'false'},
+   ${val.comm_expr_idx}, ${val.comm_exprs},
    ${val.c_opcode()},
    { ${', '.join(src.c_ptr for src in val.sources)} },
    ${val.cond if val.cond else 'NULL'},
@@ -307,6 +308,25 @@ class Expression(Value):
                 'Expression cannot use an unsized conversion opcode with ' \
                 'an explicit size; that\'s silly.'
 
+      self.__index_comm_exprs(0)
+
+   def __index_comm_exprs(self, base_idx):
+      """Recursively count and index commutative expressions
+      """
+      self.comm_exprs = 0
+      if self.opcode not in conv_opcode_types and \
+         "commutative" in opcodes[self.opcode].algebraic_properties:
+         self.comm_expr_idx = base_idx
+         self.comm_exprs += 1
+      else:
+         self.comm_expr_idx = -1
+
+      for s in self.sources:
+         if isinstance(s, Expression):
+            s.__index_comm_exprs(base_idx + self.comm_exprs)
+            self.comm_exprs += s.comm_exprs
+
+      return self.comm_exprs
 
    def c_opcode(self):
       if self.opcode in conv_opcode_types:
index d257b63918927b9b8a00686c5caa8c814a7789ae..df27a2473eeeefabb4631088a4e03b69a4a87cdd 100644 (file)
 #include "nir_builder.h"
 #include "util/half_float.h"
 
+#define NIR_SEARCH_MAX_COMM_OPS 4
+
 struct match_state {
    bool inexact_match;
    bool has_exact_alu;
+   uint8_t comm_op_direction;
    unsigned variables_seen;
    nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
 };
@@ -349,41 +352,25 @@ match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
       }
    }
 
-   /* Stash off the current variables_seen bitmask.  This way we can
-    * restore it prior to matching in the commutative case below.
+   /* If this is a commutative expression and it's one of the first few, look
+    * up its direction for the current search operation.  We'll use that value
+    * to possibly flip the sources for the match.
     */
-   unsigned variables_seen_stash = state->variables_seen;
+   unsigned comm_op_flip =
+      (expr->comm_expr_idx >= 0 &&
+       expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
+      ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
 
    bool matched = true;
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
-      if (!match_value(expr->srcs[i], instr, i, num_components,
-                       swizzle, state)) {
+      if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
+                       num_components, swizzle, state)) {
          matched = false;
          break;
       }
    }
 
-   if (matched)
-      return true;
-
-   if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
-      assert(nir_op_infos[instr->op].num_inputs == 2);
-
-      /* Restore the variables_seen bitmask.  If we don't do this, then we
-       * could end up with an erroneous failure due to variables found in the
-       * first match attempt above not matching those in the second.
-       */
-      state->variables_seen = variables_seen_stash;
-
-      if (!match_value(expr->srcs[0], instr, 1, num_components,
-                       swizzle, state))
-         return false;
-
-      return match_value(expr->srcs[1], instr, 0, num_components,
-                         swizzle, state);
-   } else {
-      return false;
-   }
+   return matched;
 }
 
 static unsigned
@@ -513,10 +500,26 @@ nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
    struct match_state state;
    state.inexact_match = false;
    state.has_exact_alu = false;
-   state.variables_seen = 0;
 
-   if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
-                         swizzle, &state))
+   unsigned comm_expr_combinations =
+      1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
+
+   bool found = false;
+   for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
+      /* The bitfield of directions is just the current iteration.  Hooray for
+       * binary.
+       */
+      state.comm_op_direction = comb;
+      state.variables_seen = 0;
+
+      if (match_expression(search, instr,
+                           instr->dest.dest.ssa.num_components,
+                           swizzle, &state)) {
+         found = true;
+         break;
+      }
+   }
+   if (!found)
       return NULL;
 
    build->cursor = nir_before_instr(&instr->instr);
index 1c78d0a3201a0986b5533e4db632935e3719dc58..9dc09d2361c738ecb23eff0f130158d18ecb7b80 100644 (file)
@@ -132,6 +132,18 @@ typedef struct {
     */
    bool inexact;
 
+   /* Commutative expression index.  This is assigned by opt_algebraic.py when
+    * search structures are constructed and is a unique (to this structure)
+    * index within the commutative operation bitfield used for searching for
+    * all combinations of expressions containing commutative operations.
+    */
+   int8_t comm_expr_idx;
+
+   /* Number of commutative expressions in this expression including this one
+    * (if it is commutative).
+    */
+   uint8_t comm_exprs;
+
    /* One of nir_op or nir_search_op */
    uint16_t opcode;
    const nir_search_value *srcs[4];