nir: Make nir_search's dumping go to stderr.
[mesa.git] / src / compiler / nir / nir_algebraic.py
index d15d4ba3d67997394f6f934c6c9abaa53ae3f0af..fe66952ba16fce8623f563b27517100139023eb5 100644 (file)
@@ -36,7 +36,7 @@ import traceback
 from nir_opcodes import opcodes, type_sizes
 
 # This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
-nir_search_max_comm_ops = 4
+nir_search_max_comm_ops = 8
 
 # These opcodes are only employed by nir_search.  This provides a mapping from
 # opcode to destination type.
@@ -198,6 +198,7 @@ class Value(object):
    ${'true' if val.is_constant else 'false'},
    ${val.type() or 'nir_type_invalid' },
    ${val.cond if val.cond else 'NULL'},
+   ${val.swizzle()},
 % elif isinstance(val, Expression):
    ${'true' if val.inexact else 'false'},
    ${val.comm_expr_idx}, ${val.comm_exprs},
@@ -269,9 +270,23 @@ class Constant(Value):
       elif isinstance(self.value, float):
          return "nir_type_float"
 
+   def equivalent(self, other):
+      """Check that two constants are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      return self.value == other.value
+
 _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                           r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
-                          r"(?P<cond>\([^\)]+\))?")
+                          r"(?P<cond>\([^\)]+\))?"
+                          r"(?P<swiz>\.[xyzw]+)?")
 
 class Variable(Value):
    def __init__(self, val, name, varset):
@@ -293,6 +308,7 @@ class Variable(Value):
       self.cond = m.group('cond')
       self.required_type = m.group('type')
       self._bit_size = int(m.group('bits')) if m.group('bits') else None
+      self.swiz = m.group('swiz')
 
       if self.required_type == 'bool':
          if self._bit_size is not None:
@@ -313,6 +329,25 @@ class Variable(Value):
       elif self.required_type == 'float':
          return "nir_type_float"
 
+   def equivalent(self, other):
+      """Check that two variables are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      return self.index == other.index
+
+   def swizzle(self):
+      if self.swiz is not None:
+         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w': 3}
+         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
+      return '{0, 1, 2, 3}'
+
 _opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                         r"(?P<cond>\([^\)]+\))?")
 
@@ -352,12 +387,41 @@ class Expression(Value):
 
       self.__index_comm_exprs(0)
 
+   def equivalent(self, other):
+      """Check that two variables are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      This implementation does not check for equivalence due to commutativity,
+      but it could.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      if len(self.sources) != len(other.sources):
+         return False
+
+      if self.opcode != other.opcode:
+         return False
+
+      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
    def __index_comm_exprs(self, base_idx):
       """Recursively count and index commutative expressions
       """
       self.comm_exprs = 0
+
+      # A note about the explicit "len(self.sources)" check. The list of
+      # sources comes from user input, and that input might be bad.  Check
+      # that the expected second source exists before accessing it. Without
+      # this check, a unit test that does "('iadd', 'a')" will crash.
       if self.opcode not in conv_opcode_types and \
-         "2src_commutative" in opcodes[self.opcode].algebraic_properties:
+         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+         len(self.sources) >= 2 and \
+         not self.sources[0].equivalent(self.sources[1]):
          self.comm_expr_idx = base_idx
          self.comm_exprs += 1
       else:
@@ -979,6 +1043,13 @@ _algebraic_pass_template = mako.template.Template("""
 #include "nir_search.h"
 #include "nir_search_helpers.h"
 
+/* What follows is NIR algebraic transform code for the following ${len(xforms)}
+ * transforms:
+% for xform in xforms:
+ *    ${xform.search} => ${xform.replace}
+% endfor
+ */
+
 #ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
 #define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
 
@@ -1087,6 +1158,7 @@ ${pass_name}_block(nir_builder *build, nir_block *block,
                    const uint16_t *states, const bool *condition_flags)
 {
    bool progress = false;
+   const unsigned execution_mode = build->shader->info.float_controls_execution_mode;
 
    nir_foreach_instr_reverse_safe(instr, block) {
       if (instr->type != nir_instr_type_alu)
@@ -1096,6 +1168,11 @@ ${pass_name}_block(nir_builder *build, nir_block *block,
       if (!alu->dest.dest.is_ssa)
          continue;
 
+      unsigned bit_size = alu->dest.dest.ssa.bit_size;
+      const bool ignore_inexact =
+         nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
+         nir_is_denorm_flush_to_zero(execution_mode, bit_size);
+
       switch (states[alu->dest.dest.ssa.index]) {
 % for i in range(len(automaton.state_patterns)):
       case ${i}:
@@ -1103,6 +1180,7 @@ ${pass_name}_block(nir_builder *build, nir_block *block,
          for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
             const struct transform *xform = &${pass_name}_state${i}_xforms[i];
             if (condition_flags[xform->condition_offset] &&
+                !(xform->search->inexact && ignore_inexact) &&
                 nir_replace_instr(build, alu, xform->search, xform->replace)) {
                progress = true;
                break;