nir: Add a new nir_var_mem_constant variable mode
[mesa.git] / src / compiler / nir / nir_algebraic.py
index 8c7fbc819476abfad37f04bc9fe8e7e745597768..6871af36ef8f06ace44749b0bbd63ad2d9ea67f2 100644 (file)
@@ -1,4 +1,3 @@
-#! /usr/bin/env python
 #
 # Copyright (C) 2014 Intel Corporation
 #
@@ -26,6 +25,7 @@
 
 from __future__ import print_function
 import ast
+from collections import defaultdict
 import itertools
 import struct
 import sys
@@ -33,7 +33,41 @@ import mako.template
 import re
 import traceback
 
-from nir_opcodes import opcodes
+from nir_opcodes import opcodes, type_sizes
+
+# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
+nir_search_max_comm_ops = 8
+
+# These opcodes are only employed by nir_search.  This provides a mapping from
+# opcode to destination type.
+conv_opcode_types = {
+    'i2f' : 'float',
+    'u2f' : 'float',
+    'f2f' : 'float',
+    'f2u' : 'uint',
+    'f2i' : 'int',
+    'u2u' : 'uint',
+    'i2i' : 'int',
+    'b2f' : 'float',
+    'b2i' : 'int',
+    'i2b' : 'bool',
+    'f2b' : 'bool',
+}
+
+def get_c_opcode(op):
+      if op in conv_opcode_types:
+         return 'nir_search_op_' + op
+      else:
+         return 'nir_op_' + op
+
+
+if sys.version_info < (3, 0):
+    integer_types = (int, long)
+    string_type = unicode
+
+else:
+    integer_types = (int, )
+    string_type = str
 
 _type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
 
@@ -56,7 +90,7 @@ class VarSet(object):
    def __getitem__(self, name):
       if name not in self.names:
          assert not self.immutable, "Unknown replacement variable: " + name
-         self.names[name] = self.ids.next()
+         self.names[name] = next(self.ids)
 
       return self.names[name]
 
@@ -66,37 +100,60 @@ class VarSet(object):
 class Value(object):
    @staticmethod
    def create(val, name_base, varset):
+      if isinstance(val, bytes):
+         val = val.decode('utf-8')
+
       if isinstance(val, tuple):
          return Expression(val, name_base, varset)
       elif isinstance(val, Expression):
          return val
-      elif isinstance(val, (str, unicode)):
+      elif isinstance(val, string_type):
          return Variable(val, name_base, varset)
-      elif isinstance(val, (bool, int, long, float)):
+      elif isinstance(val, (bool, float) + integer_types):
          return Constant(val, name_base)
 
-   __template = mako.template.Template("""
-static const ${val.c_type} ${val.name} = {
-   { ${val.type_enum}, ${val.bit_size} },
-% if isinstance(val, Constant):
-   ${val.type()}, { ${hex(val)} /* ${val.value} */ },
-% elif isinstance(val, Variable):
-   ${val.index}, /* ${val.var_name} */
-   ${'true' if val.is_constant else 'false'},
-   ${val.type() or 'nir_type_invalid' },
-   ${val.cond if val.cond else 'NULL'},
-% elif isinstance(val, Expression):
-   ${'true' if val.inexact else 'false'},
-   nir_op_${val.opcode},
-   { ${', '.join(src.c_ptr for src in val.sources)} },
-   ${val.cond if val.cond else 'NULL'},
-% endif
-};""")
-
-   def __init__(self, name, type_str):
+   def __init__(self, val, name, type_str):
+      self.in_val = str(val)
       self.name = name
       self.type_str = type_str
 
+   def __str__(self):
+      return self.in_val
+
+   def get_bit_size(self):
+      """Get the physical bit-size that has been chosen for this value, or if
+      there is none, the canonical value which currently represents this
+      bit-size class. Variables will be preferred, i.e. if there are any
+      variables in the equivalence class, the canonical value will be a
+      variable. We do this since we'll need to know which variable each value
+      is equivalent to when constructing the replacement expression. This is
+      the "find" part of the union-find algorithm.
+      """
+      bit_size = self
+
+      while isinstance(bit_size, Value):
+         if bit_size._bit_size is None:
+            break
+         bit_size = bit_size._bit_size
+
+      if bit_size is not self:
+         self._bit_size = bit_size
+      return bit_size
+
+   def set_bit_size(self, other):
+      """Make self.get_bit_size() return what other.get_bit_size() return
+      before calling this, or just "other" if it's a concrete bit-size. This is
+      the "union" part of the union-find algorithm.
+      """
+
+      self_bit_size = self.get_bit_size()
+      other_bit_size = other if isinstance(other, int) else other.get_bit_size()
+
+      if self_bit_size == other_bit_size:
+         return
+
+      self_bit_size._bit_size = other_bit_size
+
    @property
    def type_enum(self):
       return "nir_search_value_" + self.type_str
@@ -105,72 +162,163 @@ static const ${val.c_type} ${val.name} = {
    def c_type(self):
       return "nir_search_" + self.type_str
 
+   def __c_name(self, cache):
+      if cache is not None and self.name in cache:
+         return cache[self.name]
+      else:
+         return self.name
+
+   def c_value_ptr(self, cache):
+      return "&{0}.value".format(self.__c_name(cache))
+
+   def c_ptr(self, cache):
+      return "&{0}".format(self.__c_name(cache))
+
    @property
-   def c_ptr(self):
-      return "&{0}.value".format(self.name)
+   def c_bit_size(self):
+      bit_size = self.get_bit_size()
+      if isinstance(bit_size, int):
+         return bit_size
+      elif isinstance(bit_size, Variable):
+         return -bit_size.index - 1
+      else:
+         # If the bit-size class is neither a variable, nor an actual bit-size, then
+         # - If it's in the search expression, we don't need to check anything
+         # - If it's in the replace expression, either it's ambiguous (in which
+         # case we'd reject it), or it equals the bit-size of the search value
+         # We represent these cases with a 0 bit-size.
+         return 0
+
+   __template = mako.template.Template("""{
+   { ${val.type_enum}, ${val.c_bit_size} },
+% if isinstance(val, Constant):
+   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
+% elif isinstance(val, Variable):
+   ${val.index}, /* ${val.var_name} */
+   ${'true' if val.is_constant else 'false'},
+   ${val.type() or 'nir_type_invalid' },
+   ${val.cond if val.cond else 'NULL'},
+   ${val.swizzle()},
+% elif isinstance(val, Expression):
+   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
+   ${val.comm_expr_idx}, ${val.comm_exprs},
+   ${val.c_opcode()},
+   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
+   ${val.cond if val.cond else 'NULL'},
+% endif
+};""")
 
-   def render(self):
-      return self.__template.render(val=self,
-                                    Constant=Constant,
-                                    Variable=Variable,
-                                    Expression=Expression)
+   def render(self, cache):
+      struct_init = self.__template.render(val=self, cache=cache,
+                                           Constant=Constant,
+                                           Variable=Variable,
+                                           Expression=Expression)
+      if cache is not None and struct_init in cache:
+         # If it's in the cache, register a name remap in the cache and render
+         # only a comment saying it's been remapped
+         cache[self.name] = cache[struct_init]
+         return "/* {} -> {} in the cache */\n".format(self.name,
+                                                       cache[struct_init])
+      else:
+         if cache is not None:
+            cache[struct_init] = self.name
+         return "static const {} {} = {}\n".format(self.c_type, self.name,
+                                                   struct_init)
 
 _constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
 
 class Constant(Value):
    def __init__(self, val, name):
-      Value.__init__(self, name, "constant")
+      Value.__init__(self, val, name, "constant")
 
       if isinstance(val, (str)):
          m = _constant_re.match(val)
          self.value = ast.literal_eval(m.group('value'))
-         self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+         self._bit_size = int(m.group('bits')) if m.group('bits') else None
       else:
          self.value = val
-         self.bit_size = 0
+         self._bit_size = None
 
       if isinstance(self.value, bool):
-         assert self.bit_size == 0 or self.bit_size == 32
-         self.bit_size = 32
+         assert self._bit_size is None or self._bit_size == 1
+         self._bit_size = 1
 
-   def __hex__(self):
+   def hex(self):
       if isinstance(self.value, (bool)):
          return 'NIR_TRUE' if self.value else 'NIR_FALSE'
-      if isinstance(self.value, (int, long)):
+      if isinstance(self.value, integer_types):
          return hex(self.value)
       elif isinstance(self.value, float):
-         return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
+         i = struct.unpack('Q', struct.pack('d', self.value))[0]
+         h = hex(i)
+
+         # On Python 2 this 'L' suffix is automatically added, but not on Python 3
+         # Adding it explicitly makes the generated file identical, regardless
+         # of the Python version running this script.
+         if h[-1] != 'L' and i > sys.maxsize:
+            h += 'L'
+
+         return h
       else:
          assert False
 
    def type(self):
       if isinstance(self.value, (bool)):
-         return "nir_type_bool32"
-      elif isinstance(self.value, (int, long)):
+         return "nir_type_bool"
+      elif isinstance(self.value, integer_types):
          return "nir_type_int"
       elif isinstance(self.value, float):
          return "nir_type_float"
 
+   def equivalent(self, other):
+      """Check that two constants are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      return self.value == other.value
+
+# The $ at the end forces there to be an error if any part of the string
+# doesn't match one of the field patterns.
 _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                           r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
-                          r"(?P<cond>\([^\)]+\))?")
+                          r"(?P<cond>\([^\)]+\))?"
+                          r"(?P<swiz>\.[xyzw]+)?"
+                          r"$")
 
 class Variable(Value):
    def __init__(self, val, name, varset):
-      Value.__init__(self, name, "variable")
+      Value.__init__(self, val, name, "variable")
 
       m = _var_name_re.match(val)
-      assert m and m.group('name') is not None
+      assert m and m.group('name') is not None, \
+            "Malformed variable name \"{}\".".format(val)
 
       self.var_name = m.group('name')
+
+      # Prevent common cases where someone puts quotes around a literal
+      # constant.  If we want to support names that have numeric or
+      # punctuation characters, we can me the first assertion more flexible.
+      assert self.var_name.isalpha()
+      assert self.var_name != 'True'
+      assert self.var_name != 'False'
+
       self.is_constant = m.group('const') is not None
       self.cond = m.group('cond')
       self.required_type = m.group('type')
-      self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+      self._bit_size = int(m.group('bits')) if m.group('bits') else None
+      self.swiz = m.group('swiz')
 
       if self.required_type == 'bool':
-         assert self.bit_size == 0 or self.bit_size == 32
-         self.bit_size = 32
+         if self._bit_size is not None:
+            assert self._bit_size in type_sizes(self.required_type)
+         else:
+            self._bit_size = 1
 
       if self.required_type is not None:
          assert self.required_type in ('float', 'bool', 'int', 'uint')
@@ -179,67 +327,131 @@ class Variable(Value):
 
    def type(self):
       if self.required_type == 'bool':
-         return "nir_type_bool32"
+         return "nir_type_bool"
       elif self.required_type in ('int', 'uint'):
          return "nir_type_int"
       elif self.required_type == 'float':
          return "nir_type_float"
 
-_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
+   def equivalent(self, other):
+      """Check that two variables are equivalent.
+
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      return self.index == other.index
+
+   def swizzle(self):
+      if self.swiz is not None:
+         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
+                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
+                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
+                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
+                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
+         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
+      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'
+
+_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                         r"(?P<cond>\([^\)]+\))?")
 
 class Expression(Value):
    def __init__(self, expr, name_base, varset):
-      Value.__init__(self, name_base, "expression")
+      Value.__init__(self, expr, name_base, "expression")
       assert isinstance(expr, tuple)
 
       m = _opcode_re.match(expr[0])
       assert m and m.group('opcode') is not None
 
       self.opcode = m.group('opcode')
-      self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+      self._bit_size = int(m.group('bits')) if m.group('bits') else None
       self.inexact = m.group('inexact') is not None
+      self.exact = m.group('exact') is not None
       self.cond = m.group('cond')
+
+      assert not self.inexact or not self.exact, \
+            'Expression cannot be both exact and inexact.'
+
+      # "many-comm-expr" isn't really a condition.  It's notification to the
+      # generator that this pattern is known to have too many commutative
+      # expressions, and an error should not be generated for this case.
+      self.many_commutative_expressions = False
+      if self.cond and self.cond.find("many-comm-expr") >= 0:
+         # Split the condition into a comma-separated list.  Remove
+         # "many-comm-expr".  If there is anything left, put it back together.
+         c = self.cond[1:-1].split(",")
+         c.remove("many-comm-expr")
+
+         self.cond = "({})".format(",".join(c)) if c else None
+         self.many_commutative_expressions = True
+
       self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:]) ]
 
-   def render(self):
-      srcs = "\n".join(src.render() for src in self.sources)
-      return srcs + super(Expression, self).render()
+      if self.opcode in conv_opcode_types:
+         assert self._bit_size is None, \
+                'Expression cannot use an unsized conversion opcode with ' \
+                'an explicit size; that\'s silly.'
 
-class IntEquivalenceRelation(object):
-   """A class representing an equivalence relation on integers.
+      self.__index_comm_exprs(0)
 
-   Each integer has a canonical form which is the maximum integer to which it
-   is equivalent.  Two integers are equivalent precisely when they have the
-   same canonical form.
+   def equivalent(self, other):
+      """Check that two variables are equivalent.
 
-   The convention of maximum is explicitly chosen to make using it in
-   BitSizeValidator easier because it means that an actual bit_size (if any)
-   will always be the canonical form.
-   """
-   def __init__(self):
-      self._remap = {}
+      This is check is much weaker than equality.  One generally cannot be
+      used in place of the other.  Using this implementation for the __eq__
+      will break BitSizeValidator.
+
+      This implementation does not check for equivalence due to commutativity,
+      but it could.
+
+      """
+      if not isinstance(other, type(self)):
+         return False
+
+      if len(self.sources) != len(other.sources):
+         return False
+
+      if self.opcode != other.opcode:
+         return False
 
-   def get_canonical(self, x):
-      """Get the canonical integer corresponding to x."""
-      if x in self._remap:
-         return self.get_canonical(self._remap[x])
+      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
+   def __index_comm_exprs(self, base_idx):
+      """Recursively count and index commutative expressions
+      """
+      self.comm_exprs = 0
+
+      # A note about the explicit "len(self.sources)" check. The list of
+      # sources comes from user input, and that input might be bad.  Check
+      # that the expected second source exists before accessing it. Without
+      # this check, a unit test that does "('iadd', 'a')" will crash.
+      if self.opcode not in conv_opcode_types and \
+         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+         len(self.sources) >= 2 and \
+         not self.sources[0].equivalent(self.sources[1]):
+         self.comm_expr_idx = base_idx
+         self.comm_exprs += 1
       else:
-         return x
+         self.comm_expr_idx = -1
+
+      for s in self.sources:
+         if isinstance(s, Expression):
+            s.__index_comm_exprs(base_idx + self.comm_exprs)
+            self.comm_exprs += s.comm_exprs
 
-   def add_equiv(self, a, b):
-      """Add an equivalence and return the canonical form."""
-      c = max(self.get_canonical(a), self.get_canonical(b))
-      if a != c:
-         assert a < c
-         self._remap[a] = c
+      return self.comm_exprs
 
-      if b != c:
-         assert b < c
-         self._remap[b] = c
+   def c_opcode(self):
+      return get_c_opcode(self.opcode)
 
-      return c
+   def render(self, cache):
+      srcs = "\n".join(src.render(cache) for src in self.sources)
+      return srcs + super(Expression, self).render(cache)
 
 class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.
@@ -268,7 +480,7 @@ class BitSizeValidator(object):
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:
 
-   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
+   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
 
    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
@@ -287,180 +499,261 @@ class BitSizeValidator(object):
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.
 
-   The basic operation of the validator is very similar to the bitsize_tree in
-   nir_search only a little more subtle.  Instead of simply tracking bit
-   sizes, it tracks "bit classes" where each class is represented by an
-   integer.  A value of 0 means we don't know anything yet, positive values
-   are actual bit-sizes, and negative values are used to track equivalence
-   classes of sizes that must be the same but have yet to receive an actual
-   size.  The first stage uses the bitsize_tree algorithm to assign bit
-   classes to each variable.  If it ever comes across an inconsistency, it
-   assert-fails.  Then the second stage uses that information to prove that
-   the resulting expression can always validly be constructed.
-   """
-
-   def __init__(self, varset):
-      self._num_classes = 0
-      self._var_classes = [0] * len(varset.names)
-      self._class_relation = IntEquivalenceRelation()
-
-   def validate(self, search, replace):
-      dst_class = self._propagate_bit_size_up(search)
-      if dst_class == 0:
-         dst_class = self._new_class()
-      self._propagate_bit_class_down(search, dst_class)
-
-      validate_dst_class = self._validate_bit_class_up(replace)
-      assert validate_dst_class == 0 or validate_dst_class == dst_class
-      self._validate_bit_class_down(replace, dst_class)
-
-   def _new_class(self):
-      self._num_classes += 1
-      return -self._num_classes
-
-   def _set_var_bit_class(self, var_id, bit_class):
-      assert bit_class != 0
-      var_class = self._var_classes[var_id]
-      if var_class == 0:
-         self._var_classes[var_id] = bit_class
-      else:
-         canon_class = self._class_relation.get_canonical(var_class)
-         assert canon_class < 0 or canon_class == bit_class
-         var_class = self._class_relation.add_equiv(var_class, bit_class)
-         self._var_classes[var_id] = var_class
+   Each value maintains a "bit-size class", which is either an actual bit size
+   or an equivalence class with other values that must have the same bit size.
+   The validator works by combining bit-size classes with each other according
+   to the NIR rules outlined above, checking that there are no inconsistencies.
+   When doing this for the replacement expression, we make sure to never change
+   the equivalence class of any of the search values. We could make the example
+   transforms above work by doing some extra run-time checking of the search
+   expression, but we make the user specify those constraints themselves, to
+   avoid any surprises. Since the replacement bitsizes can only be connected to
+   the source bitsize via variables (variables must have the same bitsize in
+   the source and replacment expressions) or the roots of the expression (the
+   replacement expression must produce the same bit size as the search
+   expression), we prevent merging a variable with anything when processing the
+   replacement expression, or specializing the search bitsize
+   with anything. The former prevents
 
-   def _get_var_bit_class(self, var_id):
-      return self._class_relation.get_canonical(self._var_classes[var_id])
+   (('bcsel', a, b, 0), ('iand', a, b))
 
-   def _propagate_bit_size_up(self, val):
-      if isinstance(val, (Constant, Variable)):
-         return val.bit_size
+   from being allowed, since we'd have to merge the bitsizes for a and b due to
+   the 'iand', while the latter prevents
 
-      elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         val.common_size = 0
-         for i in range(nir_op.num_inputs):
-            src_bits = self._propagate_bit_size_up(val.sources[i])
-            if src_bits == 0:
-               continue
+   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
 
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               assert src_bits == src_type_bits
-            else:
-               assert val.common_size == 0 or src_bits == val.common_size
-               val.common_size = src_bits
+   from being allowed, since the search expression has the bit size of a and b,
+   which can't be specialized to 32 which is the bitsize of the replace
+   expression. It also prevents something like:
 
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert val.bit_size == 0 or val.bit_size == dst_type_bits
-            return dst_type_bits
-         else:
-            if val.common_size != 0:
-               assert val.bit_size == 0 or val.bit_size == val.common_size
-            else:
-               val.common_size = val.bit_size
-            return val.common_size
+   (('b2i', ('i2b', a)), ('ineq', a, 0))
 
-   def _propagate_bit_class_down(self, val, bit_class):
-      if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class
+   since the bitsize of 'b2i', which can be anything, can't be specialized to
+   the bitsize of a.
 
-      elif isinstance(val, Variable):
-         assert val.bit_size == 0 or val.bit_size == bit_class
-         self._set_var_bit_class(val.index, bit_class)
+   After doing all this, we check that every subexpression of the replacement
+   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
+   of the search expresssion, since those are the things that are known when
+   constructing the replacement expresssion. Finally, we record the bitsize
+   needed in nir_search_value so that we know what to do when building the
+   replacement expression.
+   """
 
-      elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert bit_class == 0 or bit_class == dst_type_bits
+   def __init__(self, varset):
+      self._var_classes = [None] * len(varset.names)
+
+   def compare_bitsizes(self, a, b):
+      """Determines which bitsize class is a specialization of the other, or
+      whether neither is. When we merge two different bitsizes, the
+      less-specialized bitsize always points to the more-specialized one, so
+      that calling get_bit_size() always gets you the most specialized bitsize.
+      The specialization partial order is given by:
+      - Physical bitsizes are always the most specialized, and a different
+        bitsize can never specialize another.
+      - In the search expression, variables can always be specialized to each
+        other and to physical bitsizes. In the replace expression, we disallow
+        this to avoid adding extra constraints to the search expression that
+        the user didn't specify.
+      - Expressions and constants without a bitsize can always be specialized to
+        each other and variables, but not the other way around.
+
+        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
+        and None if they are not comparable (neither a <= b nor b <= a).
+      """
+      if isinstance(a, int):
+         if isinstance(b, int):
+            return 0 if a == b else None
+         elif isinstance(b, Variable):
+            return -1 if self.is_search else None
          else:
-            assert val.common_size == 0 or val.common_size == bit_class
-            val.common_size = bit_class
-
-         if val.common_size:
-            common_class = val.common_size
-         elif nir_op.num_inputs:
-            # If we got here then we have no idea what the actual size is.
-            # Instead, we use a generic class
-            common_class = self._new_class()
-
-         for i in range(nir_op.num_inputs):
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               self._propagate_bit_class_down(val.sources[i], src_type_bits)
-            else:
-               self._propagate_bit_class_down(val.sources[i], common_class)
-
-   def _validate_bit_class_up(self, val):
-      if isinstance(val, Constant):
-         return val.bit_size
-
-      elif isinstance(val, Variable):
-         var_class = self._get_var_bit_class(val.index)
-         # By the time we get to validation, every variable should have a class
-         assert var_class != 0
-
-         # If we have an explicit size provided by the user, the variable
-         # *must* exactly match the search.  It cannot be implicitly sized
-         # because otherwise we could end up with a conflict at runtime.
-         assert val.bit_size == 0 or val.bit_size == var_class
-
-         return var_class
-
+            return -1
+      elif isinstance(a, Variable):
+         if isinstance(b, int):
+            return 1 if self.is_search else None
+         elif isinstance(b, Variable):
+            return 0 if self.is_search or a.index == b.index else None
+         else:
+            return -1
+      else:
+         if isinstance(b, int):
+            return 1
+         elif isinstance(b, Variable):
+            return 1
+         else:
+            return 0
+
+   def unify_bit_size(self, a, b, error_msg):
+      """Record that a must have the same bit-size as b. If both
+      have been assigned conflicting physical bit-sizes, call "error_msg" with
+      the bit-sizes of self and other to get a message and raise an error.
+      In the replace expression, disallow merging variables with other
+      variables and physical bit-sizes as well.
+      """
+      a_bit_size = a.get_bit_size()
+      b_bit_size = b if isinstance(b, int) else b.get_bit_size()
+
+      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
+
+      assert cmp_result is not None, \
+         error_msg(a_bit_size, b_bit_size)
+
+      if cmp_result < 0:
+         b_bit_size.set_bit_size(a)
+      elif not isinstance(a_bit_size, int):
+         a_bit_size.set_bit_size(b)
+
+   def merge_variables(self, val):
+      """Perform the first part of type inference by merging all the different
+      uses of the same variable. We always do this as if we're in the search
+      expression, even if we're actually not, since otherwise we'd get errors
+      if the search expression specified some constraint but the replace
+      expression didn't, because we'd be merging a variable and a constant.
+      """
+      if isinstance(val, Variable):
+         if self._var_classes[val.index] is None:
+            self._var_classes[val.index] = val
+         else:
+            other = self._var_classes[val.index]
+            self.unify_bit_size(other, val,
+                  lambda other_bit_size, bit_size:
+                     'Variable {} has conflicting bit size requirements: ' \
+                     'it must have bit size {} and {}'.format(
+                        val.var_name, other_bit_size, bit_size))
       elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         val.common_class = 0
-         for i in range(nir_op.num_inputs):
-            src_class = self._validate_bit_class_up(val.sources[i])
-            if src_class == 0:
+         for src in val.sources:
+            self.merge_variables(src)
+
+   def validate_value(self, val):
+      """Validate the an expression by performing classic Hindley-Milner
+      type inference on bitsizes. This will detect if there are any conflicting
+      requirements, and unify variables so that we know which variables must
+      have the same bitsize. If we're operating on the replace expression, we
+      will refuse to merge different variables together or merge a variable
+      with a constant, in order to prevent surprises due to rules unexpectedly
+      not matching at runtime.
+      """
+      if not isinstance(val, Expression):
+         return
+
+      # Generic conversion ops are special in that they have a single unsized
+      # source and an unsized destination and the two don't have to match.
+      # This means there's no validation or unioning to do here besides the
+      # len(val.sources) check.
+      if val.opcode in conv_opcode_types:
+         assert len(val.sources) == 1, \
+            "Expression {} has {} sources, expected 1".format(
+               val, len(val.sources))
+         self.validate_value(val.sources[0])
+         return
+
+      nir_op = opcodes[val.opcode]
+      assert len(val.sources) == nir_op.num_inputs, \
+         "Expression {} has {} sources, expected {}".format(
+            val, len(val.sources), nir_op.num_inputs)
+
+      for src in val.sources:
+         self.validate_value(src)
+
+      dst_type_bits = type_bits(nir_op.output_type)
+
+      # First, unify all the sources. That way, an error coming up because two
+      # sources have an incompatible bit-size won't produce an error message
+      # involving the destination.
+      first_unsized_src = None
+      for src_type, src in zip(nir_op.input_types, val.sources):
+         src_type_bits = type_bits(src_type)
+         if src_type_bits == 0:
+            if first_unsized_src is None:
+               first_unsized_src = src
                continue
 
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               assert src_class == src_type_bits
+            if self.is_search:
+               self.unify_bit_size(first_unsized_src, src,
+                  lambda first_unsized_src_bit_size, src_bit_size:
+                     'Source {} of {} must have bit size {}, while source {} ' \
+                     'must have incompatible bit size {}'.format(
+                        first_unsized_src, val, first_unsized_src_bit_size,
+                        src, src_bit_size))
             else:
-               assert val.common_class == 0 or src_class == val.common_class
-               val.common_class = src_class
-
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert val.bit_size == 0 or val.bit_size == dst_type_bits
-            return dst_type_bits
+               self.unify_bit_size(first_unsized_src, src,
+                  lambda first_unsized_src_bit_size, src_bit_size:
+                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
+                     'of {} may not have the same bit size when building the ' \
+                     'replacement expression.'.format(
+                        first_unsized_src, first_unsized_src_bit_size, src,
+                        src_bit_size, val))
          else:
-            if val.common_class != 0:
-               assert val.bit_size == 0 or val.bit_size == val.common_class
+            if self.is_search:
+               self.unify_bit_size(src, src_type_bits,
+                  lambda src_bit_size, unused:
+                     '{} must have {} bits, but as a source of nir_op_{} '\
+                     'it must have {} bits'.format(
+                        src, src_bit_size, nir_op.name, src_type_bits))
+            else:
+               self.unify_bit_size(src, src_type_bits,
+                  lambda src_bit_size, unused:
+                     '{} has the bit size of {}, but as a source of ' \
+                     'nir_op_{} it must have {} bits, which may not be the ' \
+                     'same'.format(
+                        src, src_bit_size, nir_op.name, src_type_bits))
+
+      if dst_type_bits == 0:
+         if first_unsized_src is not None:
+            if self.is_search:
+               self.unify_bit_size(val, first_unsized_src,
+                  lambda val_bit_size, src_bit_size:
+                     '{} must have the bit size of {}, while its source {} ' \
+                     'must have incompatible bit size {}'.format(
+                        val, val_bit_size, first_unsized_src, src_bit_size))
             else:
-               val.common_class = val.bit_size
-            return val.common_class
+               self.unify_bit_size(val, first_unsized_src,
+                  lambda val_bit_size, src_bit_size:
+                     '{} must have {} bits, but its source {} ' \
+                     '(bit size of {}) may not have that bit size ' \
+                     'when building the replacement.'.format(
+                        val, val_bit_size, first_unsized_src, src_bit_size))
+      else:
+         self.unify_bit_size(val, dst_type_bits,
+            lambda dst_bit_size, unused:
+               '{} must have {} bits, but as a destination of nir_op_{} ' \
+               'it must have {} bits'.format(
+                  val, dst_bit_size, nir_op.name, dst_type_bits))
+
+   def validate_replace(self, val, search):
+      bit_size = val.get_bit_size()
+      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
+            bit_size == search.get_bit_size(), \
+            'Ambiguous bit size for replacement value {}: ' \
+            'it cannot be deduced from a variable, a fixed bit size ' \
+            'somewhere, or the search expression.'.format(val)
+
+      if isinstance(val, Expression):
+         for src in val.sources:
+            self.validate_replace(src, search)
 
-   def _validate_bit_class_down(self, val, bit_class):
-      # At this point, everything *must* have a bit class.  Otherwise, we have
-      # a value we don't know how to define.
-      assert bit_class != 0
+   def validate(self, search, replace):
+      self.is_search = True
+      self.merge_variables(search)
+      self.merge_variables(replace)
+      self.validate_value(search)
 
-      if isinstance(val, Constant):
-         assert val.bit_size == 0 or val.bit_size == bit_class
+      self.is_search = False
+      self.validate_value(replace)
 
-      elif isinstance(val, Variable):
-         assert val.bit_size == 0 or val.bit_size == bit_class
+      # Check that search is always more specialized than replace. Note that
+      # we're doing this in replace mode, disallowing merging variables.
+      search_bit_size = search.get_bit_size()
+      replace_bit_size = replace.get_bit_size()
+      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
 
-      elif isinstance(val, Expression):
-         nir_op = opcodes[val.opcode]
-         dst_type_bits = type_bits(nir_op.output_type)
-         if dst_type_bits != 0:
-            assert bit_class == dst_type_bits
-         else:
-            assert val.common_class == 0 or val.common_class == bit_class
-            val.common_class = bit_class
+      assert cmp_result is not None and cmp_result <= 0, \
+         'The search expression bit size {} and replace expression ' \
+         'bit size {} may not be the same'.format(
+               search_bit_size, replace_bit_size)
 
-         for i in range(nir_op.num_inputs):
-            src_type_bits = type_bits(nir_op.input_types[i])
-            if src_type_bits != 0:
-               self._validate_bit_class_down(val.sources[i], src_type_bits)
-            else:
-               self._validate_bit_class_down(val.sources[i], val.common_class)
+      replace.set_bit_size(search)
+
+      self.validate_replace(replace, search)
 
 _optimization_ids = itertools.count()
 
@@ -468,7 +761,7 @@ condition_list = ['true']
 
 class SearchAndReplace(object):
    def __init__(self, transform):
-      self.id = _optimization_ids.next()
+      self.id = next(_optimization_ids)
 
       search = transform[0]
       replace = transform[1]
@@ -496,88 +789,338 @@ class SearchAndReplace(object):
 
       BitSizeValidator(varset).validate(self.search, self.replace)
 
+class TreeAutomaton(object):
+   """This class calculates a bottom-up tree automaton to quickly search for
+   the left-hand sides of tranforms. Tree automatons are a generalization of
+   classical NFA's and DFA's, where the transition function determines the
+   state of the parent node based on the state of its children. We construct a
+   deterministic automaton to match patterns, using a similar algorithm to the
+   classical NFA to DFA construction. At the moment, it only matches opcodes
+   and constants (without checking the actual value), leaving more detailed
+   checking to the search function which actually checks the leaves. The
+   automaton acts as a quick filter for the search function, requiring only n
+   + 1 table lookups for each n-source operation. The implementation is based
+   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
+   In the language of that reference, this is a frontier-to-root deterministic
+   automaton using only symbol filtering. The filtering is crucial to reduce
+   both the time taken to generate the tables and the size of the tables.
+   """
+   def __init__(self, transforms):
+      self.patterns = [t.search for t in transforms]
+      self._compute_items()
+      self._build_table()
+      #print('num items: {}'.format(len(set(self.items.values()))))
+      #print('num states: {}'.format(len(self.states)))
+      #for state, patterns in zip(self.states, self.patterns):
+      #   print('{}: num patterns: {}'.format(state, len(patterns)))
+
+   class IndexMap(object):
+      """An indexed list of objects, where one can either lookup an object by
+      index or find the index associated to an object quickly using a hash
+      table. Compared to a list, it has a constant time index(). Compared to a
+      set, it provides a stable iteration order.
+      """
+      def __init__(self, iterable=()):
+         self.objects = []
+         self.map = {}
+         for obj in iterable:
+            self.add(obj)
+
+      def __getitem__(self, i):
+         return self.objects[i]
+
+      def __contains__(self, obj):
+         return obj in self.map
+
+      def __len__(self):
+         return len(self.objects)
+
+      def __iter__(self):
+         return iter(self.objects)
+
+      def clear(self):
+         self.objects = []
+         self.map.clear()
+
+      def index(self, obj):
+         return self.map[obj]
+
+      def add(self, obj):
+         if obj in self.map:
+            return self.map[obj]
+         else:
+            index = len(self.objects)
+            self.objects.append(obj)
+            self.map[obj] = index
+            return index
+
+      def __repr__(self):
+         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
+
+   class Item(object):
+      """This represents an "item" in the language of "Tree Automatons." This
+      is just a subtree of some pattern, which represents a potential partial
+      match at runtime. We deduplicate them, so that identical subtrees of
+      different patterns share the same object, and store some extra
+      information needed for the main algorithm as well.
+      """
+      def __init__(self, opcode, children):
+         self.opcode = opcode
+         self.children = children
+         # These are the indices of patterns for which this item is the root node.
+         self.patterns = []
+         # This the set of opcodes for parents of this item. Used to speed up
+         # filtering.
+         self.parent_ops = set()
+
+      def __str__(self):
+         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
+
+      def __repr__(self):
+         return str(self)
+
+   def _compute_items(self):
+      """Build a set of all possible items, deduplicating them."""
+      # This is a map from (opcode, sources) to item.
+      self.items = {}
+
+      # The set of all opcodes used by the patterns. Used later to avoid
+      # building and emitting all the tables for opcodes that aren't used.
+      self.opcodes = self.IndexMap()
+
+      def get_item(opcode, children, pattern=None):
+         commutative = len(children) >= 2 \
+               and "2src_commutative" in opcodes[opcode].algebraic_properties
+         item = self.items.setdefault((opcode, children),
+                                      self.Item(opcode, children))
+         if commutative:
+            self.items[opcode, (children[1], children[0]) + children[2:]] = item
+         if pattern is not None:
+            item.patterns.append(pattern)
+         return item
+
+      self.wildcard = get_item("__wildcard", ())
+      self.const = get_item("__const", ())
+
+      def process_subpattern(src, pattern=None):
+         if isinstance(src, Constant):
+            # Note: we throw away the actual constant value!
+            return self.const
+         elif isinstance(src, Variable):
+            if src.is_constant:
+               return self.const
+            else:
+               # Note: we throw away which variable it is here! This special
+               # item is equivalent to nu in "Tree Automatons."
+               return self.wildcard
+         else:
+            assert isinstance(src, Expression)
+            opcode = src.opcode
+            stripped = opcode.rstrip('0123456789')
+            if stripped in conv_opcode_types:
+               # Matches that use conversion opcodes with a specific type,
+               # like f2b1, are tricky.  Either we construct the automaton to
+               # match specific NIR opcodes like nir_op_f2b1, in which case we
+               # need to create separate items for each possible NIR opcode
+               # for patterns that have a generic opcode like f2b, or we
+               # construct it to match the search opcode, in which case we
+               # need to map f2b1 to f2b when constructing the automaton. Here
+               # we do the latter.
+               opcode = stripped
+            self.opcodes.add(opcode)
+            children = tuple(process_subpattern(c) for c in src.sources)
+            item = get_item(opcode, children, pattern)
+            for i, child in enumerate(children):
+               child.parent_ops.add(opcode)
+            return item
+
+      for i, pattern in enumerate(self.patterns):
+         process_subpattern(pattern, i)
+
+   def _build_table(self):
+      """This is the core algorithm which builds up the transition table. It
+      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
+      Comp_a and Filt_{a,i} using integers to identify match sets." It
+      simultaneously builds up a list of all possible "match sets" or
+      "states", where each match set represents the set of Item's that match a
+      given instruction, and builds up the transition table between states.
+      """
+      # Map from opcode + filtered state indices to transitioned state.
+      self.table = defaultdict(dict)
+      # Bijection from state to index. q in the original algorithm is
+      # len(self.states)
+      self.states = self.IndexMap()
+      # List of pattern matches for each state index.
+      self.state_patterns = []
+      # Map from state index to filtered state index for each opcode.
+      self.filter = defaultdict(list)
+      # Bijections from filtered state to filtered state index for each
+      # opcode, called the "representor sets" in the original algorithm.
+      # q_{a,j} in the original algorithm is len(self.rep[op]).
+      self.rep = defaultdict(self.IndexMap)
+
+      # Everything in self.states with a index at least worklist_index is part
+      # of the worklist of newly created states. There is also a worklist of
+      # newly fitered states for each opcode, for which worklist_indices
+      # serves a similar purpose. worklist_index corresponds to p in the
+      # original algorithm, while worklist_indices is p_{a,j} (although since
+      # we only filter by opcode/symbol, it's really just p_a).
+      self.worklist_index = 0
+      worklist_indices = defaultdict(lambda: 0)
+
+      # This is the set of opcodes for which the filtered worklist is non-empty.
+      # It's used to avoid scanning opcodes for which there is nothing to
+      # process when building the transition table. It corresponds to new_a in
+      # the original algorithm.
+      new_opcodes = self.IndexMap()
+
+      # Process states on the global worklist, filtering them for each opcode,
+      # updating the filter tables, and updating the filtered worklists if any
+      # new filtered states are found. Similar to ComputeRepresenterSets() in
+      # the original algorithm, although that only processes a single state.
+      def process_new_states():
+         while self.worklist_index < len(self.states):
+            state = self.states[self.worklist_index]
+
+            # Calculate pattern matches for this state. Each pattern is
+            # assigned to a unique item, so we don't have to worry about
+            # deduplicating them here. However, we do have to sort them so
+            # that they're visited at runtime in the order they're specified
+            # in the source.
+            patterns = list(sorted(p for item in state for p in item.patterns))
+            assert len(self.state_patterns) == self.worklist_index
+            self.state_patterns.append(patterns)
+
+            # calculate filter table for this state, and update filtered
+            # worklists.
+            for op in self.opcodes:
+               filt = self.filter[op]
+               rep = self.rep[op]
+               filtered = frozenset(item for item in state if \
+                  op in item.parent_ops)
+               if filtered in rep:
+                  rep_index = rep.index(filtered)
+               else:
+                  rep_index = rep.add(filtered)
+                  new_opcodes.add(op)
+               assert len(filt) == self.worklist_index
+               filt.append(rep_index)
+            self.worklist_index += 1
+
+      # There are two start states: one which can only match as a wildcard,
+      # and one which can match as a wildcard or constant. These will be the
+      # states of intrinsics/other instructions and load_const instructions,
+      # respectively. The indices of these must match the definitions of
+      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
+      # initialize things correctly.
+      self.states.add(frozenset((self.wildcard,)))
+      self.states.add(frozenset((self.const,self.wildcard)))
+      process_new_states()
+
+      while len(new_opcodes) > 0:
+         for op in new_opcodes:
+            rep = self.rep[op]
+            table = self.table[op]
+            op_worklist_index = worklist_indices[op]
+            if op in conv_opcode_types:
+               num_srcs = 1
+            else:
+               num_srcs = opcodes[op].num_inputs
+
+            # Iterate over all possible source combinations where at least one
+            # is on the worklist.
+            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
+               if all(src_idx < op_worklist_index for src_idx in src_indices):
+                  continue
+
+               srcs = tuple(rep[src_idx] for src_idx in src_indices)
+
+               # Try all possible pairings of source items and add the
+               # corresponding parent items. This is Comp_a from the paper.
+               parent = set(self.items[op, item_srcs] for item_srcs in
+                  itertools.product(*srcs) if (op, item_srcs) in self.items)
+
+               # We could always start matching something else with a
+               # wildcard. This is Cl from the paper.
+               parent.add(self.wildcard)
+
+               table[src_indices] = self.states.add(frozenset(parent))
+            worklist_indices[op] = len(rep)
+         new_opcodes.clear()
+         process_new_states()
+
 _algebraic_pass_template = mako.template.Template("""
 #include "nir.h"
+#include "nir_builder.h"
 #include "nir_search.h"
 #include "nir_search_helpers.h"
 
-#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-
-struct transform {
-   const nir_search_expression *search;
-   const nir_search_value *replace;
-   unsigned condition_offset;
-};
-
-#endif
+/* What follows is NIR algebraic transform code for the following ${len(xforms)}
+ * transforms:
+% for xform in xforms:
+ *    ${xform.search} => ${xform.replace}
+% endfor
+ */
 
-% for (opcode, xform_list) in xform_dict.iteritems():
-% for xform in xform_list:
-   ${xform.search.render()}
-   ${xform.replace.render()}
+<% cache = {} %>
+% for xform in xforms:
+   ${xform.search.render(cache)}
+   ${xform.replace.render(cache)}
 % endfor
 
-static const struct transform ${pass_name}_${opcode}_xforms[] = {
-% for xform in xform_list:
-   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
+% for state_id, state_xforms in enumerate(automaton.state_patterns):
+% if state_xforms: # avoid emitting a 0-length array for MSVC
+static const struct transform ${pass_name}_state${state_id}_xforms[] = {
+% for i in state_xforms:
+  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
 % endfor
 };
+% endif
 % endfor
 
-static bool
-${pass_name}_block(nir_block *block, const bool *condition_flags,
-                   void *mem_ctx)
-{
-   bool progress = false;
-
-   nir_foreach_instr_reverse_safe(instr, block) {
-      if (instr->type != nir_instr_type_alu)
-         continue;
-
-      nir_alu_instr *alu = nir_instr_as_alu(instr);
-      if (!alu->dest.dest.is_ssa)
-         continue;
-
-      switch (alu->op) {
-      % for opcode in xform_dict.keys():
-      case nir_op_${opcode}:
-         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
-            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
-            if (condition_flags[xform->condition_offset] &&
-                nir_replace_instr(alu, xform->search, xform->replace,
-                                  mem_ctx)) {
-               progress = true;
-               break;
-            }
-         }
-         break;
+static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
+% for op in automaton.opcodes:
+   [${get_c_opcode(op)}] = {
+      .filter = (uint16_t []) {
+      % for e in automaton.filter[op]:
+         ${e},
       % endfor
-      default:
-         break;
-      }
-   }
-
-   return progress;
-}
-
-static bool
-${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
-{
-   void *mem_ctx = ralloc_parent(impl);
-   bool progress = false;
-
-   nir_foreach_block_reverse(block, impl) {
-      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
-   }
-
-   if (progress)
-      nir_metadata_preserve(impl, nir_metadata_block_index |
-                                  nir_metadata_dominance);
+      },
+      <%
+        num_filtered = len(automaton.rep[op])
+      %>
+      .num_filtered_states = ${num_filtered},
+      .table = (uint16_t []) {
+      <%
+        num_srcs = len(next(iter(automaton.table[op])))
+      %>
+      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
+         ${automaton.table[op][indices]},
+      % endfor
+      },
+   },
+% endfor
+};
 
-   return progress;
-}
+const struct transform *${pass_name}_transforms[] = {
+% for i in range(len(automaton.state_patterns)):
+   % if automaton.state_patterns[i]:
+   ${pass_name}_state${i}_xforms,
+   % else:
+   NULL,
+   % endif
+% endfor
+};
 
+const uint16_t ${pass_name}_transform_counts[] = {
+% for i in range(len(automaton.state_patterns)):
+   % if automaton.state_patterns[i]:
+   (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
+   % else:
+   0,
+   % endif
+% endfor
+};
 
 bool
 ${pass_name}(nir_shader *shader)
@@ -585,24 +1128,32 @@ ${pass_name}(nir_shader *shader)
    bool progress = false;
    bool condition_flags[${len(condition_list)}];
    const nir_shader_compiler_options *options = shader->options;
+   const shader_info *info = &shader->info;
    (void) options;
+   (void) info;
 
    % for index, condition in enumerate(condition_list):
    condition_flags[${index}] = ${condition};
    % endfor
 
    nir_foreach_function(function, shader) {
-      if (function->impl)
-         progress |= ${pass_name}_impl(function->impl, condition_flags);
+      if (function->impl) {
+         progress |= nir_algebraic_impl(function->impl, condition_flags,
+                                        ${pass_name}_transforms,
+                                        ${pass_name}_transform_counts,
+                                        ${pass_name}_table);
+      }
    }
 
    return progress;
 }
 """)
 
+
 class AlgebraicPass(object):
    def __init__(self, pass_name, transforms):
-      self.xform_dict = {}
+      self.xforms = []
+      self.opcode_xforms = defaultdict(lambda : [])
       self.pass_name = pass_name
 
       error = False
@@ -619,15 +1170,52 @@ class AlgebraicPass(object):
                error = True
                continue
 
-         if xform.search.opcode not in self.xform_dict:
-            self.xform_dict[xform.search.opcode] = []
+         self.xforms.append(xform)
+         if xform.search.opcode in conv_opcode_types:
+            dst_type = conv_opcode_types[xform.search.opcode]
+            for size in type_sizes(dst_type):
+               sized_opcode = xform.search.opcode + str(size)
+               self.opcode_xforms[sized_opcode].append(xform)
+         else:
+            self.opcode_xforms[xform.search.opcode].append(xform)
+
+         # Check to make sure the search pattern does not unexpectedly contain
+         # more commutative expressions than match_expression (nir_search.c)
+         # can handle.
+         comm_exprs = xform.search.comm_exprs
+
+         if xform.search.many_commutative_expressions:
+            if comm_exprs <= nir_search_max_comm_ops:
+               print("Transform expected to have too many commutative " \
+                     "expression but did not " \
+                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_op),
+                     file=sys.stderr)
+               print("  " + str(xform), file=sys.stderr)
+               traceback.print_exc(file=sys.stderr)
+               print('', file=sys.stderr)
+               error = True
+         else:
+            if comm_exprs > nir_search_max_comm_ops:
+               print("Transformation with too many commutative expressions " \
+                     "({} > {}).  Modify pattern or annotate with " \
+                     "\"many-comm-expr\".".format(comm_exprs,
+                                                  nir_search_max_comm_ops),
+                     file=sys.stderr)
+               print("  " + str(xform.search), file=sys.stderr)
+               print("{}".format(xform.search.cond), file=sys.stderr)
+               error = True
 
-         self.xform_dict[xform.search.opcode].append(xform)
+      self.automaton = TreeAutomaton(self.xforms)
 
       if error:
          sys.exit(1)
 
+
    def render(self):
       return _algebraic_pass_template.render(pass_name=self.pass_name,
-                                             xform_dict=self.xform_dict,
-                                             condition_list=condition_list)
+                                             xforms=self.xforms,
+                                             opcode_xforms=self.opcode_xforms,
+                                             condition_list=condition_list,
+                                             automaton=self.automaton,
+                                             get_c_opcode=get_c_opcode,
+                                             itertools=itertools)