nir: Add a new nir_var_mem_constant variable mode
[mesa.git] / src / compiler / nir / nir_algebraic.py
index 47f374bfabd68d12edb0b6fc4425db30ab484de7..6871af36ef8f06ace44749b0bbd63ad2d9ea67f2 100644 (file)
@@ -35,6 +35,9 @@ import traceback
 
 from nir_opcodes import opcodes, type_sizes
 
+# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
+nir_search_max_comm_ops = 8
+
 # These opcodes are only employed by nir_search.  This provides a mapping from
 # opcode to destination type.
 conv_opcode_types = {
@@ -195,8 +198,9 @@ class Value(object):
    ${'true' if val.is_constant else 'false'},
    ${val.type() or 'nir_type_invalid' },
    ${val.cond if val.cond else 'NULL'},
+   ${val.swizzle()},
 % elif isinstance(val, Expression):
-   ${'true' if val.inexact else 'false'},
+   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
    ${val.comm_expr_idx}, ${val.comm_exprs},
    ${val.c_opcode()},
    { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
@@ -266,16 +270,34 @@ class Constant(Value):
       elif isinstance(self.value, float):
          return "nir_type_float"
 
+   def equivalent(self, other):
+      """Check that two constants are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      return self.value == other.value
+
+# The $ at the end forces there to be an error if any part of the string
+# doesn't match one of the field patterns.
 _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                           r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
-                          r"(?P<cond>\([^\)]+\))?")
+                          r"(?P<cond>\([^\)]+\))?"
+                          r"(?P<swiz>\.[xyzw]+)?"
+                          r"$")
 
 class Variable(Value):
    def __init__(self, val, name, varset):
       Value.__init__(self, val, name, "variable")
 
       m = _var_name_re.match(val)
-      assert m and m.group('name') is not None
+      assert m and m.group('name') is not None, \
+            "Malformed variable name \"{}\".".format(val)
 
       self.var_name = m.group('name')
 
@@ -283,13 +305,14 @@ class Variable(Value):
       # constant.  If we want to support names that have numeric or
       # punctuation characters, we can me the first assertion more flexible.
       assert self.var_name.isalpha()
-      assert self.var_name is not 'True'
-      assert self.var_name is not 'False'
+      assert self.var_name != 'True'
+      assert self.var_name != 'False'
 
       self.is_constant = m.group('const') is not None
       self.cond = m.group('cond')
       self.required_type = m.group('type')
       self._bit_size = int(m.group('bits')) if m.group('bits') else None
+      self.swiz = m.group('swiz')
 
       if self.required_type == 'bool':
          if self._bit_size is not None:
@@ -310,7 +333,30 @@ class Variable(Value):
       elif self.required_type == 'float':
          return "nir_type_float"
 
-_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
+   def equivalent(self, other):
+      """Check that two variables are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      return self.index == other.index
+
+   def swizzle(self):
+      if self.swiz is not None:
+         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
+                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
+                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
+                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
+                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
+         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
+      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'
+
+_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                         r"(?P<cond>\([^\)]+\))?")
 
 class Expression(Value):
@@ -324,7 +370,25 @@ class Expression(Value):
       self.opcode = m.group('opcode')
       self._bit_size = int(m.group('bits')) if m.group('bits') else None
       self.inexact = m.group('inexact') is not None
+      self.exact = m.group('exact') is not None
       self.cond = m.group('cond')
+
+      assert not self.inexact or not self.exact, \
+            'Expression cannot be both exact and inexact.'
+
+      # "many-comm-expr" isn't really a condition.  It's notification to the
+      # generator that this pattern is known to have too many commutative
+      # expressions, and an error should not be generated for this case.
+      self.many_commutative_expressions = False
+      if self.cond and self.cond.find("many-comm-expr") >= 0:
+         # Split the condition into a comma-separated list.  Remove
+         # "many-comm-expr".  If there is anything left, put it back together.
+         c = self.cond[1:-1].split(",")
+         c.remove("many-comm-expr")
+
+         self.cond = "({})".format(",".join(c)) if c else None
+         self.many_commutative_expressions = True
+
       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:]) ]
 
@@ -335,12 +399,41 @@ class Expression(Value):
 
       self.__index_comm_exprs(0)
 
+   def equivalent(self, other):
+      """Check that two variables are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      This implementation does not check for equivalence due to commutativity,
+      but it could.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      if len(self.sources) != len(other.sources):
+         return False
+
+      if self.opcode != other.opcode:
+         return False
+
+      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
    def __index_comm_exprs(self, base_idx):
       """Recursively count and index commutative expressions
       """
       self.comm_exprs = 0
+
+      # A note about the explicit "len(self.sources)" check. The list of
+      # sources comes from user input, and that input might be bad.  Check
+      # that the expected second source exists before accessing it. Without
+      # this check, a unit test that does "('iadd', 'a')" will crash.
       if self.opcode not in conv_opcode_types and \
-         "commutative" in opcodes[self.opcode].algebraic_properties:
+         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+         len(self.sources) >= 2 and \
+         not self.sources[0].equivalent(self.sources[1]):
          self.comm_expr_idx = base_idx
          self.comm_exprs += 1
       else:
@@ -796,12 +889,12 @@ class TreeAutomaton(object):
       self.opcodes = self.IndexMap()
 
       def get_item(opcode, children, pattern=None):
-         commutative = len(children) == 2 \
-               and "commutative" in opcodes[opcode].algebraic_properties
+         commutative = len(children) >= 2 \
+               and "2src_commutative" in opcodes[opcode].algebraic_properties
          item = self.items.setdefault((opcode, children),
                                       self.Item(opcode, children))
          if commutative:
-            self.items[opcode, (children[1], children[0])] = item
+            self.items[opcode, (children[1], children[0]) + children[2:]] = item
          if pattern is not None:
             item.patterns.append(pattern)
          return item
@@ -962,30 +1055,13 @@ _algebraic_pass_template = mako.template.Template("""
 #include "nir_search.h"
 #include "nir_search_helpers.h"
 
-#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-
-struct transform {
-   const nir_search_expression *search;
-   const nir_search_value *replace;
-   unsigned condition_offset;
-};
-
-struct per_op_table {
-   const uint16_t *filter;
-   unsigned num_filtered_states;
-   const uint16_t *table;
-};
-
-/* Note: these must match the start states created in
- * TreeAutomaton._build_table()
+/* What follows is NIR algebraic transform code for the following ${len(xforms)}
+ * transforms:
+% for xform in xforms:
+ *    ${xform.search} => ${xform.replace}
+% endfor
  */
 
-/* WILDCARD_STATE = 0 is set by zeroing the state array */
-static const uint16_t CONST_STATE = 1;
-
-#endif
-
 <% cache = {} %>
 % for xform in xforms:
    ${xform.search.render(cache)}
@@ -1026,117 +1102,25 @@ static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
 % endfor
 };
 
-static void
-${pass_name}_pre_block(nir_block *block, uint16_t *states)
-{
-   nir_foreach_instr(instr, block) {
-      switch (instr->type) {
-      case nir_instr_type_alu: {
-         nir_alu_instr *alu = nir_instr_as_alu(instr);
-         nir_op op = alu->op;
-         uint16_t search_op = nir_search_op_for_nir_op(op);
-         const struct per_op_table *tbl = &${pass_name}_table[search_op];
-         if (tbl->num_filtered_states == 0)
-            continue;
-
-         /* Calculate the index into the transition table. Note the index
-          * calculated must match the iteration order of Python's
-          * itertools.product(), which was used to emit the transition
-          * table.
-          */
-         uint16_t index = 0;
-         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
-            index *= tbl->num_filtered_states;
-            index += tbl->filter[states[alu->src[i].src.ssa->index]];
-         }
-         states[alu->dest.dest.ssa.index] = tbl->table[index];
-         break;
-      }
-
-      case nir_instr_type_load_const: {
-         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
-         states[load_const->def.index] = CONST_STATE;
-         break;
-      }
-
-      default:
-         break;
-      }
-   }
-}
-
-static bool
-${pass_name}_block(nir_builder *build, nir_block *block,
-                   const uint16_t *states, const bool *condition_flags)
-{
-   bool progress = false;
-
-   nir_foreach_instr_reverse_safe(instr, block) {
-      if (instr->type != nir_instr_type_alu)
-         continue;
-
-      nir_alu_instr *alu = nir_instr_as_alu(instr);
-      if (!alu->dest.dest.is_ssa)
-         continue;
-
-      switch (states[alu->dest.dest.ssa.index]) {
+const struct transform *${pass_name}_transforms[] = {
 % for i in range(len(automaton.state_patterns)):
-      case ${i}:
-         % if automaton.state_patterns[i]:
-         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
-            const struct transform *xform = &${pass_name}_state${i}_xforms[i];
-            if (condition_flags[xform->condition_offset] &&
-                nir_replace_instr(build, alu, xform->search, xform->replace)) {
-               progress = true;
-               break;
-            }
-         }
-         % endif
-         break;
+   % if automaton.state_patterns[i]:
+   ${pass_name}_state${i}_xforms,
+   % else:
+   NULL,
+   % endif
 % endfor
-      default: assert(0);
-      }
-   }
-
-   return progress;
-}
-
-static bool
-${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
-{
-   bool progress = false;
-
-   nir_builder build;
-   nir_builder_init(&build, impl);
-
-   /* Note: it's important here that we're allocating a zeroed array, since
-    * state 0 is the default state, which means we don't have to visit
-    * anything other than constants and ALU instructions.
-    */
-   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
-
-   nir_foreach_block(block, impl) {
-      ${pass_name}_pre_block(block, states);
-   }
-
-   nir_foreach_block_reverse(block, impl) {
-      progress |= ${pass_name}_block(&build, block, states, condition_flags);
-   }
-
-   free(states);
-
-   if (progress) {
-      nir_metadata_preserve(impl, nir_metadata_block_index |
-                                  nir_metadata_dominance);
-    } else {
-#ifndef NDEBUG
-      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
-#endif
-    }
-
-   return progress;
-}
+};
 
+const uint16_t ${pass_name}_transform_counts[] = {
+% for i in range(len(automaton.state_patterns)):
+   % if automaton.state_patterns[i]:
+   (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
+   % else:
+   0,
+   % endif
+% endfor
+};
 
 bool
 ${pass_name}(nir_shader *shader)
@@ -1153,15 +1137,18 @@ ${pass_name}(nir_shader *shader)
    % endfor
 
    nir_foreach_function(function, shader) {
-      if (function->impl)
-         progress |= ${pass_name}_impl(function->impl, condition_flags);
+      if (function->impl) {
+         progress |= nir_algebraic_impl(function->impl, condition_flags,
+                                        ${pass_name}_transforms,
+                                        ${pass_name}_transform_counts,
+                                        ${pass_name}_table);
+      }
    }
 
    return progress;
 }
 """)
 
-      
 
 class AlgebraicPass(object):
    def __init__(self, pass_name, transforms):
@@ -1192,6 +1179,32 @@ class AlgebraicPass(object):
          else:
             self.opcode_xforms[xform.search.opcode].append(xform)
 
+         # Check to make sure the search pattern does not unexpectedly contain
+         # more commutative expressions than match_expression (nir_search.c)
+         # can handle.
+         comm_exprs = xform.search.comm_exprs
+
+         if xform.search.many_commutative_expressions:
+            if comm_exprs <= nir_search_max_comm_ops:
+               print("Transform expected to have too many commutative " \
+                     "expression but did not " \
+                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_op),
+                     file=sys.stderr)
+               print("  " + str(xform), file=sys.stderr)
+               traceback.print_exc(file=sys.stderr)
+               print('', file=sys.stderr)
+               error = True
+         else:
+            if comm_exprs > nir_search_max_comm_ops:
+               print("Transformation with too many commutative expressions " \
+                     "({} > {}).  Modify pattern or annotate with " \
+                     "\"many-comm-expr\".".format(comm_exprs,
+                                                  nir_search_max_comm_ops),
+                     file=sys.stderr)
+               print("  " + str(xform.search), file=sys.stderr)
+               print("{}".format(xform.search.cond), file=sys.stderr)
+               error = True
+
       self.automaton = TreeAutomaton(self.xforms)
 
       if error: