nir: Don't reassociate add/mul chains containing only constants
[mesa.git] / src / compiler / nir / nir_algebraic.py
index 39b85089b1e726c4d8d7df0d444932bddca18ea3..fe9d1051e67ecee8a74082f29325e5b1edb4d66f 100644 (file)
@@ -25,7 +25,7 @@
 
 from __future__ import print_function
 import ast
-from collections import OrderedDict
+from collections import defaultdict
 import itertools
 import struct
 import sys
@@ -33,7 +33,23 @@ import mako.template
 import re
 import traceback
 
-from nir_opcodes import opcodes
+from nir_opcodes import opcodes, type_sizes
+
+# These opcodes are only employed by nir_search.  This provides a mapping from
+# opcode to destination type.
+conv_opcode_types = {
+    'i2f' : 'float',
+    'u2f' : 'float',
+    'f2f' : 'float',
+    'f2u' : 'uint',
+    'f2i' : 'int',
+    'u2u' : 'uint',
+    'i2i' : 'int',
+    'b2f' : 'float',
+    'b2i' : 'int',
+    'i2b' : 'bool',
+    'f2b' : 'bool',
+}
 
 if sys.version_info < (3, 0):
     integer_types = (int, long)
@@ -88,7 +104,7 @@ class Value(object):
 
    __template = mako.template.Template("""
 static const ${val.c_type} ${val.name} = {
-   { ${val.type_enum}, ${val.bit_size} },
+   { ${val.type_enum}, ${val.c_bit_size} },
 % if isinstance(val, Constant):
    ${val.type()}, { ${val.hex()} /* ${val.value} */ },
 % elif isinstance(val, Variable):
@@ -98,7 +114,7 @@ static const ${val.c_type} ${val.name} = {
    ${val.cond if val.cond else 'NULL'},
 % elif isinstance(val, Expression):
    ${'true' if val.inexact else 'false'},
-   nir_op_${val.opcode},
+   ${val.c_opcode()},
    { ${', '.join(src.c_ptr for src in val.sources)} },
    ${val.cond if val.cond else 'NULL'},
 % endif
@@ -112,6 +128,40 @@ static const ${val.c_type} ${val.name} = {
    def __str__(self):
       return self.in_val
 
+   def get_bit_size(self):
+      """Get the physical bit-size that has been chosen for this value, or if
+      there is none, the canonical value which currently represents this
+      bit-size class. Variables will be preferred, i.e. if there are any
+      variables in the equivalence class, the canonical value will be a
+      variable. We do this since we'll need to know which variable each value
+      is equivalent to when constructing the replacement expression. This is
+      the "find" part of the union-find algorithm.
+      """
+      bit_size = self
+
+      while isinstance(bit_size, Value):
+         if bit_size._bit_size is None:
+            break
+         bit_size = bit_size._bit_size
+
+      if bit_size is not self:
+         self._bit_size = bit_size
+      return bit_size
+
+   def set_bit_size(self, other):
+      """Make self.get_bit_size() return what other.get_bit_size() return
+      before calling this, or just "other" if it's a concrete bit-size. This is
+      the "union" part of the union-find algorithm.
+      """
+
+      self_bit_size = self.get_bit_size()
+      other_bit_size = other if isinstance(other, int) else other.get_bit_size()
+
+      if self_bit_size == other_bit_size:
+         return
+
+      self_bit_size._bit_size = other_bit_size
+
    @property
    def type_enum(self):
       return "nir_search_value_" + self.type_str
@@ -124,6 +174,21 @@ static const ${val.c_type} ${val.name} = {
    def c_ptr(self):
       return "&{0}.value".format(self.name)
 
+   @property
+   def c_bit_size(self):
+      bit_size = self.get_bit_size()
+      if isinstance(bit_size, int):
+         return bit_size
+      elif isinstance(bit_size, Variable):
+         return -bit_size.index - 1
+      else:
+         # If the bit-size class is neither a variable, nor an actual bit-size, then
+         # - If it's in the search expression, we don't need to check anything
+         # - If it's in the replace expression, either it's ambiguous (in which
+         # case we'd reject it), or it equals the bit-size of the search value
+         # We represent these cases with a 0 bit-size.
+         return 0
+
    def render(self):
       return self.__template.render(val=self,
                                     Constant=Constant,
@@ -136,18 +201,17 @@ class Constant(Value):
    def __init__(self, val, name):
       Value.__init__(self, val, name, "constant")
 
-      self.in_val = str(val)
       if isinstance(val, (str)):
          m = _constant_re.match(val)
          self.value = ast.literal_eval(m.group('value'))
-         self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+         self._bit_size = int(m.group('bits')) if m.group('bits') else None
       else:
          self.value = val
-         self.bit_size = 0
+         self._bit_size = None
 
       if isinstance(self.value, bool):
-         assert self.bit_size == 0 or self.bit_size == 32
-         self.bit_size = 32
+         assert self._bit_size is None or self._bit_size == 1
+         self._bit_size = 1
 
    def hex(self):
       if isinstance(self.value, (bool)):
@@ -188,23 +252,30 @@ class Variable(Value):
       assert m and m.group('name') is not None
 
       self.var_name = m.group('name')
+
+      # Prevent common cases where someone puts quotes around a literal
+      # constant.  If we want to support names that have numeric or
+      # punctuation characters, we can me the first assertion more flexible.
+      assert self.var_name.isalpha()
+      assert self.var_name is not 'True'
+      assert self.var_name is not 'False'
+
       self.is_constant = m.group('const') is not None
       self.cond = m.group('cond')
       self.required_type = m.group('type')
-      self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+      self._bit_size = int(m.group('bits')) if m.group('bits') else None
 
       if self.required_type == 'bool':
-         assert self.bit_size == 0 or self.bit_size == 32
-         self.bit_size = 32
+         if self._bit_size is not None:
+            assert self._bit_size in type_sizes(self.required_type)
+         else:
+            self._bit_size = 1
 
       if self.required_type is not None:
          assert self.required_type in ('float', 'bool', 'int', 'uint')
 
       self.index = varset[self.var_name]
 
-   def __str__(self):
-      return self.in_val
-
    def type(self):
       if self.required_type == 'bool':
          return "nir_type_bool"
@@ -225,49 +296,27 @@ class Expression(Value):
       assert m and m.group('opcode') is not None
 
       self.opcode = m.group('opcode')
-      self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+      self._bit_size = int(m.group('bits')) if m.group('bits') else None
       self.inexact = m.group('inexact') is not None
       self.cond = m.group('cond')
       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:]) ]
 
-   def render(self):
-      srcs = "\n".join(src.render() for src in self.sources)
-      return srcs + super(Expression, self).render()
+      if self.opcode in conv_opcode_types:
+         assert self._bit_size is None, \
+                'Expression cannot use an unsized conversion opcode with ' \
+                'an explicit size; that\'s silly.'
 
-class IntEquivalenceRelation(object):
-   """A class representing an equivalence relation on integers.
 
-   Each integer has a canonical form which is the maximum integer to which it
-   is equivalent.  Two integers are equivalent precisely when they have the
-   same canonical form.
-
-   The convention of maximum is explicitly chosen to make using it in
-   BitSizeValidator easier because it means that an actual bit_size (if any)
-   will always be the canonical form.
-   """
-   def __init__(self):
-      self._remap = {}
-
-   def get_canonical(self, x):
-      """Get the canonical integer corresponding to x."""
-      if x in self._remap:
-         return self.get_canonical(self._remap[x])
+   def c_opcode(self):
+      if self.opcode in conv_opcode_types:
+         return 'nir_search_op_' + self.opcode
       else:
-         return x
-
-   def add_equiv(self, a, b):
-      """Add an equivalence and return the canonical form."""
-      c = max(self.get_canonical(a), self.get_canonical(b))
-      if a != c:
-         assert a < c
-         self._remap[a] = c
-
-      if b != c:
-         assert b < c
-         self._remap[b] = c
+         return 'nir_op_' + self.opcode
 
-      return c
+   def render(self):
+      srcs = "\n".join(src.render() for src in self.sources)
+      return srcs + super(Expression, self).render()
 
 class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.
@@ -296,7 +345,7 @@ class BitSizeValidator(object):
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:
 
-   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
+   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
 
    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
@@ -315,180 +364,261 @@ class BitSizeValidator(object):
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.
 
-   The basic operation of the validator is very similar to the bitsize_tree in
-   nir_search only a little more subtle.  Instead of simply tracking bit
-   sizes, it tracks "bit classes" where each class is represented by an
-   integer.  A value of 0 means we don't know anything yet, positive values
-   are actual bit-sizes, and negative values are used to track equivalence
-   classes of sizes that must be the same but have yet to receive an actual
-   size.  The first stage uses the bitsize_tree algorithm to assign bit
-   classes to each variable.  If it ever comes across an inconsistency, it
-   assert-fails.  Then the second stage uses that information to prove that
-   the resulting expression can always validly be constructed.
-   """
-
-   def __init__(self, varset):
-      self._num_classes = 0
-      self._var_classes = [0] * len(varset.names)
-      self._class_relation = IntEquivalenceRelation()
+   Each value maintains a "bit-size class", which is either an actual bit size
+   or an equivalence class with other values that must have the same bit size.
+   The validator works by combining bit-size classes with each other according
+   to the NIR rules outlined above, checking that there are no inconsistencies.
+   When doing this for the replacement expression, we make sure to never change
+   the equivalence class of any of the search values. We could make the example
+   transforms above work by doing some extra run-time checking of the search
+   expression, but we make the user specify those constraints themselves, to
+   avoid any surprises. Since the replacement bitsizes can only be connected to
+   the source bitsize via variables (variables must have the same bitsize in
+   the source and replacment expressions) or the roots of the expression (the
+   replacement expression must produce the same bit size as the search
+   expression), we prevent merging a variable with anything when processing the
+   replacement expression, or specializing the search bitsize
+   with anything. The former prevents
 
-   def validate(self, search, replace):
-      dst_class = self._propagate_bit_size_up(search)
-      if dst_class == 0:
-         dst_class = self._new_class()
-      self._propagate_bit_class_down(search, dst_class)
-
-      validate_dst_class = self._validate_bit_class_up(replace)
-      assert validate_dst_class == 0 or validate_dst_class == dst_class
-      self._validate_bit_class_down(replace, dst_class)
-
-   def _new_class(self):
-      self._num_classes += 1
-      return -self._num_classes
-
-   def _set_var_bit_class(self, var_id, bit_class):
-      assert bit_class != 0
-      var_class = self._var_classes[var_id]
-      if var_class == 0:
-         self._var_classes[var_id] = bit_class
-      else:
-         canon_class = self._class_relation.get_canonical(var_class)
-         assert canon_class < 0 or canon_class == bit_class
-         var_class = self._class_relation.add_equiv(var_class, bit_class)
-         self._var_classes[var_id] = var_class
+   (('bcsel', a, b, 0), ('iand', a, b))
 
-   def _get_var_bit_class(self, var_id):
-      return self._class_relation.get_canonical(self._var_classes[var_id])
+   from being allowed, since we'd have to merge the bitsizes for a and b due to
+   the 'iand', while the latter prevents
 
-   def _propagate_bit_size_up(self, val):
-      if isinstance(val, (Constant, Variable)):
-         return val.bit_size
+   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
 
-      elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         val.common_size = 0
-         for i in range(nir_op.num_inputs):
-            src_bits = self._propagate_bit_size_up(val.sources[i])
-            if src_bits == 0:
-               continue
+   from being allowed, since the search expression has the bit size of a and b,
+   which can't be specialized to 32 which is the bitsize of the replace
+   expression. It also prevents something like:
 
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               assert src_bits == src_type_bits
-            else:
-               assert val.common_size == 0 or src_bits == val.common_size
-               val.common_size = src_bits
+   (('b2i', ('i2b', a)), ('ineq', a, 0))
 
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert val.bit_size == 0 or val.bit_size == dst_type_bits
-            return dst_type_bits
-         else:
-            if val.common_size != 0:
-               assert val.bit_size == 0 or val.bit_size == val.common_size
-            else:
-               val.common_size = val.bit_size
-            return val.common_size
-
-   def _propagate_bit_class_down(self, val, bit_class):
-      if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class
+   since the bitsize of 'b2i', which can be anything, can't be specialized to
+   the bitsize of a.
 
-      elif isinstance(val, Variable):
-         assert val.bit_size == 0 or val.bit_size == bit_class
-         self._set_var_bit_class(val.index, bit_class)
+   After doing all this, we check that every subexpression of the replacement
+   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
+   of the search expresssion, since those are the things that are known when
+   constructing the replacement expresssion. Finally, we record the bitsize
+   needed in nir_search_value so that we know what to do when building the
+   replacement expression.
+   """
 
-      elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert bit_class == 0 or bit_class == dst_type_bits
+   def __init__(self, varset):
+      self._var_classes = [None] * len(varset.names)
+
+   def compare_bitsizes(self, a, b):
+      """Determines which bitsize class is a specialization of the other, or
+      whether neither is. When we merge two different bitsizes, the
+      less-specialized bitsize always points to the more-specialized one, so
+      that calling get_bit_size() always gets you the most specialized bitsize.
+      The specialization partial order is given by:
+      - Physical bitsizes are always the most specialized, and a different
+        bitsize can never specialize another.
+      - In the search expression, variables can always be specialized to each
+        other and to physical bitsizes. In the replace expression, we disallow
+        this to avoid adding extra constraints to the search expression that
+        the user didn't specify.
+      - Expressions and constants without a bitsize can always be specialized to
+        each other and variables, but not the other way around.
+
+        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
+        and None if they are not comparable (neither a <= b nor b <= a).
+      """
+      if isinstance(a, int):
+         if isinstance(b, int):
+            return 0 if a == b else None
+         elif isinstance(b, Variable):
+            return -1 if self.is_search else None
          else:
-            assert val.common_size == 0 or val.common_size == bit_class
-            val.common_size = bit_class
-
-         if val.common_size:
-            common_class = val.common_size
-         elif nir_op.num_inputs:
-            # If we got here then we have no idea what the actual size is.
-            # Instead, we use a generic class
-            common_class = self._new_class()
-
-         for i in range(nir_op.num_inputs):
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               self._propagate_bit_class_down(val.sources[i], src_type_bits)
-            else:
-               self._propagate_bit_class_down(val.sources[i], common_class)
-
-   def _validate_bit_class_up(self, val):
-      if isinstance(val, Constant):
-         return val.bit_size
-
-      elif isinstance(val, Variable):
-         var_class = self._get_var_bit_class(val.index)
-         # By the time we get to validation, every variable should have a class
-         assert var_class != 0
-
-         # If we have an explicit size provided by the user, the variable
-         # *must* exactly match the search.  It cannot be implicitly sized
-         # because otherwise we could end up with a conflict at runtime.
-         assert val.bit_size == 0 or val.bit_size == var_class
-
-         return var_class
-
+            return -1
+      elif isinstance(a, Variable):
+         if isinstance(b, int):
+            return 1 if self.is_search else None
+         elif isinstance(b, Variable):
+            return 0 if self.is_search or a.index == b.index else None
+         else:
+            return -1
+      else:
+         if isinstance(b, int):
+            return 1
+         elif isinstance(b, Variable):
+            return 1
+         else:
+            return 0
+
+   def unify_bit_size(self, a, b, error_msg):
+      """Record that a must have the same bit-size as b. If both
+      have been assigned conflicting physical bit-sizes, call "error_msg" with
+      the bit-sizes of self and other to get a message and raise an error.
+      In the replace expression, disallow merging variables with other
+      variables and physical bit-sizes as well.
+      """
+      a_bit_size = a.get_bit_size()
+      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
+
+      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
+
+      assert cmp_result is not None, \
+         error_msg(a_bit_size, b_bit_size)
+
+      if cmp_result < 0:
+         b_bit_size.set_bit_size(a)
+      elif not isinstance(a_bit_size, int):
+         a_bit_size.set_bit_size(b)
+
+   def merge_variables(self, val):
+      """Perform the first part of type inference by merging all the different
+      uses of the same variable. We always do this as if we're in the search
+      expression, even if we're actually not, since otherwise we'd get errors
+      if the search expression specified some constraint but the replace
+      expression didn't, because we'd be merging a variable and a constant.
+      """
+      if isinstance(val, Variable):
+         if self._var_classes[val.index] is None:
+            self._var_classes[val.index] = val
+         else:
+            other = self._var_classes[val.index]
+            self.unify_bit_size(other, val,
+                  lambda other_bit_size, bit_size:
+                     'Variable {} has conflicting bit size requirements: ' \
+                     'it must have bit size {} and {}'.format(
+                        val.var_name, other_bit_size, bit_size))
       elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         val.common_class = 0
-         for i in range(nir_op.num_inputs):
-            src_class = self._validate_bit_class_up(val.sources[i])
-            if src_class == 0:
+         for src in val.sources:
+            self.merge_variables(src)
+
+   def validate_value(self, val):
+      """Validate the an expression by performing classic Hindley-Milner
+      type inference on bitsizes. This will detect if there are any conflicting
+      requirements, and unify variables so that we know which variables must
+      have the same bitsize. If we're operating on the replace expression, we
+      will refuse to merge different variables together or merge a variable
+      with a constant, in order to prevent surprises due to rules unexpectedly
+      not matching at runtime.
+      """
+      if not isinstance(val, Expression):
+         return
+
+      # Generic conversion ops are special in that they have a single unsized
+      # source and an unsized destination and the two don't have to match.
+      # This means there's no validation or unioning to do here besides the
+      # len(val.sources) check.
+      if val.opcode in conv_opcode_types:
+         assert len(val.sources) == 1, \
+            "Expression {} has {} sources, expected 1".format(
+               val, len(val.sources))
+         self.validate_value(val.sources[0])
+         return
+
+      nir_op = opcodes[val.opcode]
+      assert len(val.sources) == nir_op.num_inputs, \
+         "Expression {} has {} sources, expected {}".format(
+            val, len(val.sources), nir_op.num_inputs)
+
+      for src in val.sources:
+         self.validate_value(src)
+
+      dst_type_bits = type_bits(nir_op.output_type)
+
+      # First, unify all the sources. That way, an error coming up because two
+      # sources have an incompatible bit-size won't produce an error message
+      # involving the destination.
+      first_unsized_src = None
+      for src_type, src in zip(nir_op.input_types, val.sources):
+         src_type_bits = type_bits(src_type)
+         if src_type_bits == 0:
+            if first_unsized_src is None:
+               first_unsized_src = src
                continue
 
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               assert src_class == src_type_bits
+            if self.is_search:
+               self.unify_bit_size(first_unsized_src, src,
+                  lambda first_unsized_src_bit_size, src_bit_size:
+                     'Source {} of {} must have bit size {}, while source {} ' \
+                     'must have incompatible bit size {}'.format(
+                        first_unsized_src, val, first_unsized_src_bit_size,
+                        src, src_bit_size))
             else:
-               assert val.common_class == 0 or src_class == val.common_class
-               val.common_class = src_class
-
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert val.bit_size == 0 or val.bit_size == dst_type_bits
-            return dst_type_bits
+               self.unify_bit_size(first_unsized_src, src,
+                  lambda first_unsized_src_bit_size, src_bit_size:
+                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
+                     'of {} may not have the same bit size when building the ' \
+                     'replacement expression.'.format(
+                        first_unsized_src, first_unsized_src_bit_size, src,
+                        src_bit_size, val))
          else:
-            if val.common_class != 0:
-               assert val.bit_size == 0 or val.bit_size == val.common_class
+            if self.is_search:
+               self.unify_bit_size(src, src_type_bits,
+                  lambda src_bit_size, unused:
+                     '{} must have {} bits, but as a source of nir_op_{} '\
+                     'it must have {} bits'.format(
+                        src, src_bit_size, nir_op.name, src_type_bits))
+            else:
+               self.unify_bit_size(src, src_type_bits,
+                  lambda src_bit_size, unused:
+                     '{} has the bit size of {}, but as a source of ' \
+                     'nir_op_{} it must have {} bits, which may not be the ' \
+                     'same'.format(
+                        src, src_bit_size, nir_op.name, src_type_bits))
+
+      if dst_type_bits == 0:
+         if first_unsized_src is not None:
+            if self.is_search:
+               self.unify_bit_size(val, first_unsized_src,
+                  lambda val_bit_size, src_bit_size:
+                     '{} must have the bit size of {}, while its source {} ' \
+                     'must have incompatible bit size {}'.format(
+                        val, val_bit_size, first_unsized_src, src_bit_size))
             else:
-               val.common_class = val.bit_size
-            return val.common_class
+               self.unify_bit_size(val, first_unsized_src,
+                  lambda val_bit_size, src_bit_size:
+                     '{} must have {} bits, but its source {} ' \
+                     '(bit size of {}) may not have that bit size ' \
+                     'when building the replacement.'.format(
+                        val, val_bit_size, first_unsized_src, src_bit_size))
+      else:
+         self.unify_bit_size(val, dst_type_bits,
+            lambda dst_bit_size, unused:
+               '{} must have {} bits, but as a destination of nir_op_{} ' \
+               'it must have {} bits'.format(
+                  val, dst_bit_size, nir_op.name, dst_type_bits))
+
+   def validate_replace(self, val, search):
+      bit_size = val.get_bit_size()
+      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
+            bit_size == search.get_bit_size(), \
+            'Ambiguous bit size for replacement value {}: ' \
+            'it cannot be deduced from a variable, a fixed bit size ' \
+            'somewhere, or the search expression.'.format(val)
+
+      if isinstance(val, Expression):
+         for src in val.sources:
+            self.validate_replace(src, search)
 
-   def _validate_bit_class_down(self, val, bit_class):
-      # At this point, everything *must* have a bit class.  Otherwise, we have
-      # a value we don't know how to define.
-      assert bit_class != 0
+   def validate(self, search, replace):
+      self.is_search = True
+      self.merge_variables(search)
+      self.merge_variables(replace)
+      self.validate_value(search)
 
-      if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class
+      self.is_search = False
+      self.validate_value(replace)
 
-      elif isinstance(val, Variable):
-         assert val.bit_size == 0 or val.bit_size == bit_class
+      # Check that search is always more specialized than replace. Note that
+      # we're doing this in replace mode, disallowing merging variables.
+      search_bit_size = search.get_bit_size()
+      replace_bit_size = replace.get_bit_size()
+      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
 
-      elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert bit_class == dst_type_bits
-         else:
-            assert val.common_class == 0 or val.common_class == bit_class
-            val.common_class = bit_class
+      assert cmp_result is not None and cmp_result <= 0, \
+         'The search expression bit size {} and replace expression ' \
+         'bit size {} may not be the same'.format(
+               search_bit_size, replace_bit_size)
 
-         for i in range(nir_op.num_inputs):
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               self._validate_bit_class_down(val.sources[i], src_type_bits)
-            else:
-               self._validate_bit_class_down(val.sources[i], val.common_class)
+      replace.set_bit_size(search)
+
+      self.validate_replace(replace, search)
 
 _optimization_ids = itertools.count()
 
@@ -526,6 +656,7 @@ class SearchAndReplace(object):
 
 _algebraic_pass_template = mako.template.Template("""
 #include "nir.h"
+#include "nir_builder.h"
 #include "nir_search.h"
 #include "nir_search_helpers.h"
 
@@ -540,12 +671,12 @@ struct transform {
 
 #endif
 
-% for (opcode, xform_list) in xform_dict.items():
-% for xform in xform_list:
+% for xform in xforms:
    ${xform.search.render()}
    ${xform.replace.render()}
 % endfor
 
+% for (opcode, xform_list) in sorted(opcode_xforms.items()):
 static const struct transform ${pass_name}_${opcode}_xforms[] = {
 % for xform in xform_list:
    { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
@@ -554,8 +685,8 @@ static const struct transform ${pass_name}_${opcode}_xforms[] = {
 % endfor
 
 static bool
-${pass_name}_block(nir_block *block, const bool *condition_flags,
-                   void *mem_ctx)
+${pass_name}_block(nir_builder *build, nir_block *block,
+                   const bool *condition_flags)
 {
    bool progress = false;
 
@@ -568,13 +699,12 @@ ${pass_name}_block(nir_block *block, const bool *condition_flags,
          continue;
 
       switch (alu->op) {
-      % for opcode in xform_dict.keys():
+      % for opcode in sorted(opcode_xforms.keys()):
       case nir_op_${opcode}:
          for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
             const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
             if (condition_flags[xform->condition_offset] &&
-                nir_replace_instr(alu, xform->search, xform->replace,
-                                  mem_ctx)) {
+                nir_replace_instr(build, alu, xform->search, xform->replace)) {
                progress = true;
                break;
             }
@@ -592,16 +722,23 @@ ${pass_name}_block(nir_block *block, const bool *condition_flags,
 static bool
 ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
 {
-   void *mem_ctx = ralloc_parent(impl);
    bool progress = false;
 
+   nir_builder build;
+   nir_builder_init(&build, impl);
+
    nir_foreach_block_reverse(block, impl) {
-      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
+      progress |= ${pass_name}_block(&build, block, condition_flags);
    }
 
-   if (progress)
+   if (progress) {
       nir_metadata_preserve(impl, nir_metadata_block_index |
                                   nir_metadata_dominance);
+    } else {
+#ifndef NDEBUG
+      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
+#endif
+    }
 
    return progress;
 }
@@ -630,7 +767,8 @@ ${pass_name}(nir_shader *shader)
 
 class AlgebraicPass(object):
    def __init__(self, pass_name, transforms):
-      self.xform_dict = OrderedDict()
+      self.xforms = []
+      self.opcode_xforms = defaultdict(lambda : [])
       self.pass_name = pass_name
 
       error = False
@@ -647,15 +785,21 @@ class AlgebraicPass(object):
                error = True
                continue
 
-         if xform.search.opcode not in self.xform_dict:
-            self.xform_dict[xform.search.opcode] = []
-
-         self.xform_dict[xform.search.opcode].append(xform)
+         self.xforms.append(xform)
+         if xform.search.opcode in conv_opcode_types:
+            dst_type = conv_opcode_types[xform.search.opcode]
+            for size in type_sizes(dst_type):
+               sized_opcode = xform.search.opcode + str(size)
+               self.opcode_xforms[sized_opcode].append(xform)
+         else:
+            self.opcode_xforms[xform.search.opcode].append(xform)
 
       if error:
          sys.exit(1)
 
+
    def render(self):
       return _algebraic_pass_template.render(pass_name=self.pass_name,
-                                             xform_dict=self.xform_dict,
+                                             xforms=self.xforms,
+                                             opcode_xforms=self.opcode_xforms,
                                              condition_list=condition_list)