from __future__ import print_function
import ast
-from collections import OrderedDict
+from collections import defaultdict
import itertools
import struct
import sys
import re
import traceback
-from nir_opcodes import opcodes
+from nir_opcodes import opcodes, type_sizes
+
+# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
+nir_search_max_comm_ops = 8
+
+# These opcodes are only employed by nir_search. This provides a mapping from
+# opcode to destination type.
+conv_opcode_types = {
+ 'i2f' : 'float',
+ 'u2f' : 'float',
+ 'f2f' : 'float',
+ 'f2u' : 'uint',
+ 'f2i' : 'int',
+ 'u2u' : 'uint',
+ 'i2i' : 'int',
+ 'b2f' : 'float',
+ 'b2i' : 'int',
+ 'i2b' : 'bool',
+ 'f2b' : 'bool',
+}
+
+def get_c_opcode(op):
+ if op in conv_opcode_types:
+ return 'nir_search_op_' + op
+ else:
+ return 'nir_op_' + op
+
if sys.version_info < (3, 0):
integer_types = (int, long)
elif isinstance(val, (bool, float) + integer_types):
return Constant(val, name_base)
- __template = mako.template.Template("""
-static const ${val.c_type} ${val.name} = {
- { ${val.type_enum}, ${val.bit_size} },
-% if isinstance(val, Constant):
- ${val.type()}, { ${val.hex()} /* ${val.value} */ },
-% elif isinstance(val, Variable):
- ${val.index}, /* ${val.var_name} */
- ${'true' if val.is_constant else 'false'},
- ${val.type() or 'nir_type_invalid' },
- ${val.cond if val.cond else 'NULL'},
-% elif isinstance(val, Expression):
- ${'true' if val.inexact else 'false'},
- nir_op_${val.opcode},
- { ${', '.join(src.c_ptr for src in val.sources)} },
- ${val.cond if val.cond else 'NULL'},
-% endif
-};""")
-
def __init__(self, val, name, type_str):
self.in_val = str(val)
self.name = name
def __str__(self):
return self.in_val
+ def get_bit_size(self):
+ """Get the physical bit-size that has been chosen for this value, or if
+ there is none, the canonical value which currently represents this
+ bit-size class. Variables will be preferred, i.e. if there are any
+ variables in the equivalence class, the canonical value will be a
+ variable. We do this since we'll need to know which variable each value
+ is equivalent to when constructing the replacement expression. This is
+ the "find" part of the union-find algorithm.
+ """
+ bit_size = self
+
+ while isinstance(bit_size, Value):
+ if bit_size._bit_size is None:
+ break
+ bit_size = bit_size._bit_size
+
+ if bit_size is not self:
+ self._bit_size = bit_size
+ return bit_size
+
+ def set_bit_size(self, other):
+        """Make self.get_bit_size() return what other.get_bit_size() returned
+        before calling this, or just "other" if it's a concrete bit-size. This is
+ the "union" part of the union-find algorithm.
+ """
+
+ self_bit_size = self.get_bit_size()
+ other_bit_size = other if isinstance(other, int) else other.get_bit_size()
+
+ if self_bit_size == other_bit_size:
+ return
+
+ self_bit_size._bit_size = other_bit_size
+
@property
def type_enum(self):
return "nir_search_value_" + self.type_str
def c_type(self):
return "nir_search_" + self.type_str
+ def __c_name(self, cache):
+ if cache is not None and self.name in cache:
+ return cache[self.name]
+ else:
+ return self.name
+
+ def c_value_ptr(self, cache):
+ return "&{0}.value".format(self.__c_name(cache))
+
+ def c_ptr(self, cache):
+ return "&{0}".format(self.__c_name(cache))
+
@property
- def c_ptr(self):
- return "&{0}.value".format(self.name)
+ def c_bit_size(self):
+ bit_size = self.get_bit_size()
+ if isinstance(bit_size, int):
+ return bit_size
+ elif isinstance(bit_size, Variable):
+ return -bit_size.index - 1
+ else:
+ # If the bit-size class is neither a variable, nor an actual bit-size, then
+ # - If it's in the search expression, we don't need to check anything
+ # - If it's in the replace expression, either it's ambiguous (in which
+ # case we'd reject it), or it equals the bit-size of the search value
+ # We represent these cases with a 0 bit-size.
+ return 0
+
+ __template = mako.template.Template("""{
+ { ${val.type_enum}, ${val.c_bit_size} },
+% if isinstance(val, Constant):
+ ${val.type()}, { ${val.hex()} /* ${val.value} */ },
+% elif isinstance(val, Variable):
+ ${val.index}, /* ${val.var_name} */
+ ${'true' if val.is_constant else 'false'},
+ ${val.type() or 'nir_type_invalid' },
+ ${val.cond if val.cond else 'NULL'},
+ ${val.swizzle()},
+% elif isinstance(val, Expression):
+ ${'true' if val.inexact else 'false'},
+ ${val.comm_expr_idx}, ${val.comm_exprs},
+ ${val.c_opcode()},
+ { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
+ ${val.cond if val.cond else 'NULL'},
+% endif
+};""")
- def render(self):
- return self.__template.render(val=self,
- Constant=Constant,
- Variable=Variable,
- Expression=Expression)
+ def render(self, cache):
+ struct_init = self.__template.render(val=self, cache=cache,
+ Constant=Constant,
+ Variable=Variable,
+ Expression=Expression)
+ if cache is not None and struct_init in cache:
+ # If it's in the cache, register a name remap in the cache and render
+ # only a comment saying it's been remapped
+ cache[self.name] = cache[struct_init]
+ return "/* {} -> {} in the cache */\n".format(self.name,
+ cache[struct_init])
+ else:
+ if cache is not None:
+ cache[struct_init] = self.name
+ return "static const {} {} = {}\n".format(self.c_type, self.name,
+ struct_init)
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
def __init__(self, val, name):
Value.__init__(self, val, name, "constant")
- self.in_val = str(val)
if isinstance(val, (str)):
m = _constant_re.match(val)
self.value = ast.literal_eval(m.group('value'))
- self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+ self._bit_size = int(m.group('bits')) if m.group('bits') else None
else:
self.value = val
- self.bit_size = 0
+ self._bit_size = None
if isinstance(self.value, bool):
- assert self.bit_size == 0 or self.bit_size == 32
- self.bit_size = 32
+ assert self._bit_size is None or self._bit_size == 1
+ self._bit_size = 1
def hex(self):
if isinstance(self.value, (bool)):
elif isinstance(self.value, float):
return "nir_type_float"
+ def equivalent(self, other):
+ """Check that two constants are equivalent.
+
+        This check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.value == other.value
+
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
- r"(?P<cond>\([^\)]+\))?")
+ r"(?P<cond>\([^\)]+\))?"
+ r"(?P<swiz>\.[xyzw]+)?")
class Variable(Value):
def __init__(self, val, name, varset):
assert m and m.group('name') is not None
self.var_name = m.group('name')
+
+ # Prevent common cases where someone puts quotes around a literal
+ # constant. If we want to support names that have numeric or
+        # punctuation characters, we can make the first assertion more flexible.
+ assert self.var_name.isalpha()
+ assert self.var_name is not 'True'
+ assert self.var_name is not 'False'
+
self.is_constant = m.group('const') is not None
self.cond = m.group('cond')
self.required_type = m.group('type')
- self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+ self._bit_size = int(m.group('bits')) if m.group('bits') else None
+ self.swiz = m.group('swiz')
if self.required_type == 'bool':
- assert self.bit_size == 0 or self.bit_size == 32
- self.bit_size = 32
+ if self._bit_size is not None:
+ assert self._bit_size in type_sizes(self.required_type)
+ else:
+ self._bit_size = 1
if self.required_type is not None:
assert self.required_type in ('float', 'bool', 'int', 'uint')
self.index = varset[self.var_name]
- def __str__(self):
- return self.in_val
-
def type(self):
if self.required_type == 'bool':
return "nir_type_bool"
elif self.required_type == 'float':
return "nir_type_float"
+ def equivalent(self, other):
+ """Check that two variables are equivalent.
+
+        This check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.index == other.index
+
+ def swizzle(self):
+ if self.swiz is not None:
+ swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w': 3}
+ return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
+ return '{0, 1, 2, 3}'
+
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
r"(?P<cond>\([^\)]+\))?")
assert m and m.group('opcode') is not None
self.opcode = m.group('opcode')
- self.bit_size = int(m.group('bits')) if m.group('bits') else 0
+ self._bit_size = int(m.group('bits')) if m.group('bits') else None
self.inexact = m.group('inexact') is not None
self.cond = m.group('cond')
+
+ # "many-comm-expr" isn't really a condition. It's notification to the
+ # generator that this pattern is known to have too many commutative
+ # expressions, and an error should not be generated for this case.
+ self.many_commutative_expressions = False
+ if self.cond and self.cond.find("many-comm-expr") >= 0:
+ # Split the condition into a comma-separated list. Remove
+ # "many-comm-expr". If there is anything left, put it back together.
+ c = self.cond[1:-1].split(",")
+ c.remove("many-comm-expr")
+
+ self.cond = "({})".format(",".join(c)) if c else None
+ self.many_commutative_expressions = True
+
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
for (i, src) in enumerate(expr[1:]) ]
- def render(self):
- srcs = "\n".join(src.render() for src in self.sources)
- return srcs + super(Expression, self).render()
+ if self.opcode in conv_opcode_types:
+ assert self._bit_size is None, \
+ 'Expression cannot use an unsized conversion opcode with ' \
+ 'an explicit size; that\'s silly.'
-class IntEquivalenceRelation(object):
- """A class representing an equivalence relation on integers.
+ self.__index_comm_exprs(0)
- Each integer has a canonical form which is the maximum integer to which it
- is equivalent. Two integers are equivalent precisely when they have the
- same canonical form.
+ def equivalent(self, other):
+        """Check that two expressions are equivalent.
- The convention of maximum is explicitly chosen to make using it in
- BitSizeValidator easier because it means that an actual bit_size (if any)
- will always be the canonical form.
- """
- def __init__(self):
- self._remap = {}
+        This check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ This implementation does not check for equivalence due to commutativity,
+ but it could.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ if len(self.sources) != len(other.sources):
+ return False
+
+ if self.opcode != other.opcode:
+ return False
+
+ return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
+ def __index_comm_exprs(self, base_idx):
+ """Recursively count and index commutative expressions
+ """
+ self.comm_exprs = 0
- def get_canonical(self, x):
- """Get the canonical integer corresponding to x."""
- if x in self._remap:
- return self.get_canonical(self._remap[x])
+ # A note about the explicit "len(self.sources)" check. The list of
+ # sources comes from user input, and that input might be bad. Check
+ # that the expected second source exists before accessing it. Without
+ # this check, a unit test that does "('iadd', 'a')" will crash.
+ if self.opcode not in conv_opcode_types and \
+ "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+ len(self.sources) >= 2 and \
+ not self.sources[0].equivalent(self.sources[1]):
+ self.comm_expr_idx = base_idx
+ self.comm_exprs += 1
else:
- return x
+ self.comm_expr_idx = -1
- def add_equiv(self, a, b):
- """Add an equivalence and return the canonical form."""
- c = max(self.get_canonical(a), self.get_canonical(b))
- if a != c:
- assert a < c
- self._remap[a] = c
+ for s in self.sources:
+ if isinstance(s, Expression):
+ s.__index_comm_exprs(base_idx + self.comm_exprs)
+ self.comm_exprs += s.comm_exprs
- if b != c:
- assert b < c
- self._remap[b] = c
+ return self.comm_exprs
- return c
+ def c_opcode(self):
+ return get_c_opcode(self.opcode)
+
+ def render(self, cache):
+ srcs = "\n".join(src.render(cache) for src in self.sources)
+ return srcs + super(Expression, self).render(cache)
class BitSizeValidator(object):
"""A class for validating bit sizes of expressions.
inference can be ambiguous or contradictory. Consider, for instance, the
following transformation:
- (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
+ (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
This transformation can potentially cause a problem because usub_borrow is
well-defined for any bit-size of integer. However, b2i always generates a
generate any code. This ensures that bugs are caught at compile time
rather than at run time.
- The basic operation of the validator is very similar to the bitsize_tree in
- nir_search only a little more subtle. Instead of simply tracking bit
- sizes, it tracks "bit classes" where each class is represented by an
- integer. A value of 0 means we don't know anything yet, positive values
- are actual bit-sizes, and negative values are used to track equivalence
- classes of sizes that must be the same but have yet to receive an actual
- size. The first stage uses the bitsize_tree algorithm to assign bit
- classes to each variable. If it ever comes across an inconsistency, it
- assert-fails. Then the second stage uses that information to prove that
- the resulting expression can always validly be constructed.
- """
+ Each value maintains a "bit-size class", which is either an actual bit size
+ or an equivalence class with other values that must have the same bit size.
+ The validator works by combining bit-size classes with each other according
+ to the NIR rules outlined above, checking that there are no inconsistencies.
+ When doing this for the replacement expression, we make sure to never change
+ the equivalence class of any of the search values. We could make the example
+ transforms above work by doing some extra run-time checking of the search
+ expression, but we make the user specify those constraints themselves, to
+ avoid any surprises. Since the replacement bitsizes can only be connected to
+ the source bitsize via variables (variables must have the same bitsize in
+    the source and replacement expressions) or the roots of the expression (the
+ replacement expression must produce the same bit size as the search
+ expression), we prevent merging a variable with anything when processing the
+ replacement expression, or specializing the search bitsize
+ with anything. The former prevents
- def __init__(self, varset):
- self._num_classes = 0
- self._var_classes = [0] * len(varset.names)
- self._class_relation = IntEquivalenceRelation()
+ (('bcsel', a, b, 0), ('iand', a, b))
- def validate(self, search, replace):
- search_dst_class = self._propagate_bit_size_up(search)
- if search_dst_class == 0:
- search_dst_class = self._new_class()
- self._propagate_bit_class_down(search, search_dst_class)
-
- replace_dst_class = self._validate_bit_class_up(replace)
- assert replace_dst_class == 0 or replace_dst_class == search_dst_class
- self._validate_bit_class_down(replace, search_dst_class)
-
- def _new_class(self):
- self._num_classes += 1
- return -self._num_classes
-
- def _set_var_bit_class(self, var, bit_class):
- assert bit_class != 0
- var_class = self._var_classes[var.index]
- if var_class == 0:
- self._var_classes[var.index] = bit_class
- else:
- canon_class = self._class_relation.get_canonical(var_class)
- assert canon_class < 0 or canon_class == bit_class
- var_class = self._class_relation.add_equiv(var_class, bit_class)
- self._var_classes[var.index] = var_class
+ from being allowed, since we'd have to merge the bitsizes for a and b due to
+ the 'iand', while the latter prevents
- def _get_var_bit_class(self, var):
- return self._class_relation.get_canonical(self._var_classes[var.index])
+ (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
- def _propagate_bit_size_up(self, val):
- if isinstance(val, (Constant, Variable)):
- return val.bit_size
+ from being allowed, since the search expression has the bit size of a and b,
+ which can't be specialized to 32 which is the bitsize of the replace
+ expression. It also prevents something like:
- elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- val.common_size = 0
- for i in range(nir_op.num_inputs):
- src_bits = self._propagate_bit_size_up(val.sources[i])
- if src_bits == 0:
- continue
+ (('b2i', ('i2b', a)), ('ineq', a, 0))
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- assert src_bits == src_type_bits
- else:
- assert val.common_size == 0 or src_bits == val.common_size
- val.common_size = src_bits
+ since the bitsize of 'b2i', which can be anything, can't be specialized to
+ the bitsize of a.
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert val.bit_size == 0 or val.bit_size == dst_type_bits
- return dst_type_bits
- else:
- if val.common_size != 0:
- assert val.bit_size == 0 or val.bit_size == val.common_size
- else:
- val.common_size = val.bit_size
- return val.common_size
-
- def _propagate_bit_class_down(self, val, bit_class):
- if isinstance(val, Constant):
- assert val.bit_size == 0 or val.bit_size == bit_class
-
- elif isinstance(val, Variable):
- assert val.bit_size == 0 or val.bit_size == bit_class
- self._set_var_bit_class(val, bit_class)
+ After doing all this, we check that every subexpression of the replacement
+ was assigned a constant bitsize, the bitsize of a variable, or the bitsize
+    of the search expression, since those are the things that are known when
+    constructing the replacement expression. Finally, we record the bitsize
+ needed in nir_search_value so that we know what to do when building the
+ replacement expression.
+ """
- elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert bit_class == 0 or bit_class == dst_type_bits
+ def __init__(self, varset):
+ self._var_classes = [None] * len(varset.names)
+
+ def compare_bitsizes(self, a, b):
+ """Determines which bitsize class is a specialization of the other, or
+ whether neither is. When we merge two different bitsizes, the
+ less-specialized bitsize always points to the more-specialized one, so
+ that calling get_bit_size() always gets you the most specialized bitsize.
+ The specialization partial order is given by:
+ - Physical bitsizes are always the most specialized, and a different
+ bitsize can never specialize another.
+ - In the search expression, variables can always be specialized to each
+ other and to physical bitsizes. In the replace expression, we disallow
+ this to avoid adding extra constraints to the search expression that
+ the user didn't specify.
+ - Expressions and constants without a bitsize can always be specialized to
+ each other and variables, but not the other way around.
+
+ We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
+ and None if they are not comparable (neither a <= b nor b <= a).
+ """
+ if isinstance(a, int):
+ if isinstance(b, int):
+ return 0 if a == b else None
+ elif isinstance(b, Variable):
+ return -1 if self.is_search else None
else:
- assert val.common_size == 0 or val.common_size == bit_class
- val.common_size = bit_class
-
- if val.common_size:
- common_class = val.common_size
- elif nir_op.num_inputs:
- # If we got here then we have no idea what the actual size is.
- # Instead, we use a generic class
- common_class = self._new_class()
-
- for i in range(nir_op.num_inputs):
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- self._propagate_bit_class_down(val.sources[i], src_type_bits)
- else:
- self._propagate_bit_class_down(val.sources[i], common_class)
-
- def _validate_bit_class_up(self, val):
- if isinstance(val, Constant):
- return val.bit_size
-
- elif isinstance(val, Variable):
- var_class = self._get_var_bit_class(val)
- # By the time we get to validation, every variable should have a class
- assert var_class != 0
-
- # If we have an explicit size provided by the user, the variable
- # *must* exactly match the search. It cannot be implicitly sized
- # because otherwise we could end up with a conflict at runtime.
- assert val.bit_size == 0 or val.bit_size == var_class
-
- return var_class
-
+ return -1
+ elif isinstance(a, Variable):
+ if isinstance(b, int):
+ return 1 if self.is_search else None
+ elif isinstance(b, Variable):
+ return 0 if self.is_search or a.index == b.index else None
+ else:
+ return -1
+ else:
+ if isinstance(b, int):
+ return 1
+ elif isinstance(b, Variable):
+ return 1
+ else:
+ return 0
+
+ def unify_bit_size(self, a, b, error_msg):
+ """Record that a must have the same bit-size as b. If both
+ have been assigned conflicting physical bit-sizes, call "error_msg" with
+ the bit-sizes of self and other to get a message and raise an error.
+ In the replace expression, disallow merging variables with other
+ variables and physical bit-sizes as well.
+ """
+ a_bit_size = a.get_bit_size()
+ b_bit_size = b if isinstance(b, int) else b.get_bit_size()
+
+ cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
+
+ assert cmp_result is not None, \
+ error_msg(a_bit_size, b_bit_size)
+
+ if cmp_result < 0:
+ b_bit_size.set_bit_size(a)
+ elif not isinstance(a_bit_size, int):
+ a_bit_size.set_bit_size(b)
+
+ def merge_variables(self, val):
+ """Perform the first part of type inference by merging all the different
+ uses of the same variable. We always do this as if we're in the search
+ expression, even if we're actually not, since otherwise we'd get errors
+ if the search expression specified some constraint but the replace
+ expression didn't, because we'd be merging a variable and a constant.
+ """
+ if isinstance(val, Variable):
+ if self._var_classes[val.index] is None:
+ self._var_classes[val.index] = val
+ else:
+ other = self._var_classes[val.index]
+ self.unify_bit_size(other, val,
+ lambda other_bit_size, bit_size:
+ 'Variable {} has conflicting bit size requirements: ' \
+ 'it must have bit size {} and {}'.format(
+ val.var_name, other_bit_size, bit_size))
elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- val.common_class = 0
- for i in range(nir_op.num_inputs):
- src_class = self._validate_bit_class_up(val.sources[i])
- if src_class == 0:
+ for src in val.sources:
+ self.merge_variables(src)
+
+ def validate_value(self, val):
+        """Validate an expression by performing classic Hindley-Milner
+ type inference on bitsizes. This will detect if there are any conflicting
+ requirements, and unify variables so that we know which variables must
+ have the same bitsize. If we're operating on the replace expression, we
+ will refuse to merge different variables together or merge a variable
+ with a constant, in order to prevent surprises due to rules unexpectedly
+ not matching at runtime.
+ """
+ if not isinstance(val, Expression):
+ return
+
+ # Generic conversion ops are special in that they have a single unsized
+ # source and an unsized destination and the two don't have to match.
+ # This means there's no validation or unioning to do here besides the
+ # len(val.sources) check.
+ if val.opcode in conv_opcode_types:
+ assert len(val.sources) == 1, \
+ "Expression {} has {} sources, expected 1".format(
+ val, len(val.sources))
+ self.validate_value(val.sources[0])
+ return
+
+ nir_op = opcodes[val.opcode]
+ assert len(val.sources) == nir_op.num_inputs, \
+ "Expression {} has {} sources, expected {}".format(
+ val, len(val.sources), nir_op.num_inputs)
+
+ for src in val.sources:
+ self.validate_value(src)
+
+ dst_type_bits = type_bits(nir_op.output_type)
+
+ # First, unify all the sources. That way, an error coming up because two
+ # sources have an incompatible bit-size won't produce an error message
+ # involving the destination.
+ first_unsized_src = None
+ for src_type, src in zip(nir_op.input_types, val.sources):
+ src_type_bits = type_bits(src_type)
+ if src_type_bits == 0:
+ if first_unsized_src is None:
+ first_unsized_src = src
continue
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- assert src_class == src_type_bits
+ if self.is_search:
+ self.unify_bit_size(first_unsized_src, src,
+ lambda first_unsized_src_bit_size, src_bit_size:
+ 'Source {} of {} must have bit size {}, while source {} ' \
+ 'must have incompatible bit size {}'.format(
+ first_unsized_src, val, first_unsized_src_bit_size,
+ src, src_bit_size))
else:
- assert val.common_class == 0 or src_class == val.common_class
- val.common_class = src_class
-
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert val.bit_size == 0 or val.bit_size == dst_type_bits
- return dst_type_bits
+ self.unify_bit_size(first_unsized_src, src,
+ lambda first_unsized_src_bit_size, src_bit_size:
+ 'Sources {} (bit size of {}) and {} (bit size of {}) ' \
+ 'of {} may not have the same bit size when building the ' \
+ 'replacement expression.'.format(
+ first_unsized_src, first_unsized_src_bit_size, src,
+ src_bit_size, val))
else:
- if val.common_class != 0:
- assert val.bit_size == 0 or val.bit_size == val.common_class
+ if self.is_search:
+ self.unify_bit_size(src, src_type_bits,
+ lambda src_bit_size, unused:
+ '{} must have {} bits, but as a source of nir_op_{} '\
+ 'it must have {} bits'.format(
+ src, src_bit_size, nir_op.name, src_type_bits))
+ else:
+ self.unify_bit_size(src, src_type_bits,
+ lambda src_bit_size, unused:
+ '{} has the bit size of {}, but as a source of ' \
+ 'nir_op_{} it must have {} bits, which may not be the ' \
+ 'same'.format(
+ src, src_bit_size, nir_op.name, src_type_bits))
+
+ if dst_type_bits == 0:
+ if first_unsized_src is not None:
+ if self.is_search:
+ self.unify_bit_size(val, first_unsized_src,
+ lambda val_bit_size, src_bit_size:
+ '{} must have the bit size of {}, while its source {} ' \
+ 'must have incompatible bit size {}'.format(
+ val, val_bit_size, first_unsized_src, src_bit_size))
else:
- val.common_class = val.bit_size
- return val.common_class
+ self.unify_bit_size(val, first_unsized_src,
+ lambda val_bit_size, src_bit_size:
+ '{} must have {} bits, but its source {} ' \
+ '(bit size of {}) may not have that bit size ' \
+ 'when building the replacement.'.format(
+ val, val_bit_size, first_unsized_src, src_bit_size))
+ else:
+ self.unify_bit_size(val, dst_type_bits,
+ lambda dst_bit_size, unused:
+ '{} must have {} bits, but as a destination of nir_op_{} ' \
+ 'it must have {} bits'.format(
+ val, dst_bit_size, nir_op.name, dst_type_bits))
+
+ def validate_replace(self, val, search):
+ bit_size = val.get_bit_size()
+ assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
+ bit_size == search.get_bit_size(), \
+ 'Ambiguous bit size for replacement value {}: ' \
+ 'it cannot be deduced from a variable, a fixed bit size ' \
+ 'somewhere, or the search expression.'.format(val)
+
+ if isinstance(val, Expression):
+ for src in val.sources:
+ self.validate_replace(src, search)
+
+ def validate(self, search, replace):
+ self.is_search = True
+ self.merge_variables(search)
+ self.merge_variables(replace)
+ self.validate_value(search)
- def _validate_bit_class_down(self, val, bit_class):
- # At this point, everything *must* have a bit class. Otherwise, we have
- # a value we don't know how to define.
- assert bit_class != 0
+ self.is_search = False
+ self.validate_value(replace)
- if isinstance(val, Constant):
- assert val.bit_size == 0 or val.bit_size == bit_class
+ # Check that search is always more specialized than replace. Note that
+ # we're doing this in replace mode, disallowing merging variables.
+ search_bit_size = search.get_bit_size()
+ replace_bit_size = replace.get_bit_size()
+ cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
- elif isinstance(val, Variable):
- assert val.bit_size == 0 or val.bit_size == bit_class
+ assert cmp_result is not None and cmp_result <= 0, \
+ 'The search expression bit size {} and replace expression ' \
+ 'bit size {} may not be the same'.format(
+ search_bit_size, replace_bit_size)
- elif isinstance(val, Expression):
- nir_op = opcodes[val.opcode]
- dst_type_bits = type_bits(nir_op.output_type)
- if dst_type_bits != 0:
- assert bit_class == dst_type_bits
- else:
- assert val.common_class == 0 or val.common_class == bit_class
- val.common_class = bit_class
+ replace.set_bit_size(search)
- for i in range(nir_op.num_inputs):
- src_type_bits = type_bits(nir_op.input_types[i])
- if src_type_bits != 0:
- self._validate_bit_class_down(val.sources[i], src_type_bits)
- else:
- self._validate_bit_class_down(val.sources[i], val.common_class)
+ self.validate_replace(replace, search)
_optimization_ids = itertools.count()
BitSizeValidator(varset).validate(self.search, self.replace)
+class TreeAutomaton(object):
+ """This class calculates a bottom-up tree automaton to quickly search for
+    the left-hand sides of transforms. Tree automatons are a generalization of
+ classical NFA's and DFA's, where the transition function determines the
+ state of the parent node based on the state of its children. We construct a
+ deterministic automaton to match patterns, using a similar algorithm to the
+ classical NFA to DFA construction. At the moment, it only matches opcodes
+ and constants (without checking the actual value), leaving more detailed
+ checking to the search function which actually checks the leaves. The
+ automaton acts as a quick filter for the search function, requiring only n
+ + 1 table lookups for each n-source operation. The implementation is based
+ on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
+ In the language of that reference, this is a frontier-to-root deterministic
+ automaton using only symbol filtering. The filtering is crucial to reduce
+ both the time taken to generate the tables and the size of the tables.
+ """
+ def __init__(self, transforms):
+ self.patterns = [t.search for t in transforms]
+ self._compute_items()
+ self._build_table()
+ #print('num items: {}'.format(len(set(self.items.values()))))
+ #print('num states: {}'.format(len(self.states)))
+ #for state, patterns in zip(self.states, self.patterns):
+ # print('{}: num patterns: {}'.format(state, len(patterns)))
+
+ class IndexMap(object):
+ """An indexed list of objects, where one can either lookup an object by
+ index or find the index associated to an object quickly using a hash
+ table. Compared to a list, it has a constant time index(). Compared to a
+ set, it provides a stable iteration order.
+ """
+ def __init__(self, iterable=()):
+ self.objects = []
+ self.map = {}
+ for obj in iterable:
+ self.add(obj)
+
+ def __getitem__(self, i):
+ return self.objects[i]
+
+ def __contains__(self, obj):
+ return obj in self.map
+
+ def __len__(self):
+ return len(self.objects)
+
+ def __iter__(self):
+ return iter(self.objects)
+
+ def clear(self):
+ self.objects = []
+ self.map.clear()
+
+ def index(self, obj):
+ return self.map[obj]
+
+ def add(self, obj):
+ if obj in self.map:
+ return self.map[obj]
+ else:
+ index = len(self.objects)
+ self.objects.append(obj)
+ self.map[obj] = index
+ return index
+
+ def __repr__(self):
+ return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
+
+ class Item(object):
+ """This represents an "item" in the language of "Tree Automatons." This
+ is just a subtree of some pattern, which represents a potential partial
+ match at runtime. We deduplicate them, so that identical subtrees of
+ different patterns share the same object, and store some extra
+ information needed for the main algorithm as well.
+ """
+ def __init__(self, opcode, children):
+ # opcode: the (possibly generic, e.g. "f2b") opcode string at the
+ # root of this subtree. children: tuple of child Items.
+ self.opcode = opcode
+ self.children = children
+ # These are the indices of patterns for which this item is the root node.
+ self.patterns = []
+ # This is the set of opcodes for parents of this item. Used to speed up
+ # filtering.
+ self.parent_ops = set()
+
+ def __str__(self):
+ # Lisp-style rendering, e.g. "(fadd, (__const), (__wildcard))".
+ return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
+
+ def __repr__(self):
+ return str(self)
+
+ def _compute_items(self):
+ """Build a set of all possible items, deduplicating them."""
+ # This is a map from (opcode, sources) to item.
+ self.items = {}
+
+ # The set of all opcodes used by the patterns. Used later to avoid
+ # building and emitting all the tables for opcodes that aren't used.
+ self.opcodes = self.IndexMap()
+
+ def get_item(opcode, children, pattern=None):
+ commutative = len(children) >= 2 \
+ and "2src_commutative" in opcodes[opcode].algebraic_properties
+ item = self.items.setdefault((opcode, children),
+ self.Item(opcode, children))
+ if commutative:
+ # Alias both operand orders of the first two sources to the
+ # same Item, so commutative matches are order-insensitive.
+ self.items[opcode, (children[1], children[0]) + children[2:]] = item
+ if pattern is not None:
+ item.patterns.append(pattern)
+ return item
+
+ # Sentinel leaf items: wildcard matches anything, const matches any
+ # constant (the actual constant value is deliberately discarded).
+ self.wildcard = get_item("__wildcard", ())
+ self.const = get_item("__const", ())
+
+ def process_subpattern(src, pattern=None):
+ # Recursively convert one search (sub)pattern into an Item,
+ # registering it (and its opcode) along the way.
+ if isinstance(src, Constant):
+ # Note: we throw away the actual constant value!
+ return self.const
+ elif isinstance(src, Variable):
+ if src.is_constant:
+ return self.const
+ else:
+ # Note: we throw away which variable it is here! This special
+ # item is equivalent to nu in "Tree Automatons."
+ return self.wildcard
+ else:
+ assert isinstance(src, Expression)
+ opcode = src.opcode
+ stripped = opcode.rstrip('0123456789')
+ if stripped in conv_opcode_types:
+ # Matches that use conversion opcodes with a specific type,
+ # like f2b1, are tricky. Either we construct the automaton to
+ # match specific NIR opcodes like nir_op_f2b1, in which case we
+ # need to create separate items for each possible NIR opcode
+ # for patterns that have a generic opcode like f2b, or we
+ # construct it to match the search opcode, in which case we
+ # need to map f2b1 to f2b when constructing the automaton. Here
+ # we do the latter.
+ opcode = stripped
+ self.opcodes.add(opcode)
+ children = tuple(process_subpattern(c) for c in src.sources)
+ item = get_item(opcode, children, pattern)
+ for i, child in enumerate(children):
+ child.parent_ops.add(opcode)
+ return item
+
+ for i, pattern in enumerate(self.patterns):
+ process_subpattern(pattern, i)
+
+ def _build_table(self):
+ """This is the core algorithm which builds up the transition table. It
+ is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
+ Comp_a and Filt_{a,i} using integers to identify match sets." It
+ simultaneously builds up a list of all possible "match sets" or
+ "states", where each match set represents the set of Item's that match a
+ given instruction, and builds up the transition table between states.
+ """
+ # Map from opcode + filtered state indices to transitioned state.
+ self.table = defaultdict(dict)
+ # Bijection from state to index. q in the original algorithm is
+ # len(self.states)
+ self.states = self.IndexMap()
+ # List of pattern matches for each state index.
+ self.state_patterns = []
+ # Map from state index to filtered state index for each opcode.
+ self.filter = defaultdict(list)
+ # Bijections from filtered state to filtered state index for each
+ # opcode, called the "representor sets" in the original algorithm.
+ # q_{a,j} in the original algorithm is len(self.rep[op]).
+ self.rep = defaultdict(self.IndexMap)
+
+ # Everything in self.states with an index at least worklist_index is part
+ # of the worklist of newly created states. There is also a worklist of
+ # newly filtered states for each opcode, for which worklist_indices
+ # serves a similar purpose. worklist_index corresponds to p in the
+ # original algorithm, while worklist_indices is p_{a,j} (although since
+ # we only filter by opcode/symbol, it's really just p_a).
+ self.worklist_index = 0
+ worklist_indices = defaultdict(lambda: 0)
+
+ # This is the set of opcodes for which the filtered worklist is non-empty.
+ # It's used to avoid scanning opcodes for which there is nothing to
+ # process when building the transition table. It corresponds to new_a in
+ # the original algorithm.
+ new_opcodes = self.IndexMap()
+
+ # Process states on the global worklist, filtering them for each opcode,
+ # updating the filter tables, and updating the filtered worklists if any
+ # new filtered states are found. Similar to ComputeRepresenterSets() in
+ # the original algorithm, although that only processes a single state.
+ def process_new_states():
+ while self.worklist_index < len(self.states):
+ state = self.states[self.worklist_index]
+
+ # Calculate pattern matches for this state. Each pattern is
+ # assigned to a unique item, so we don't have to worry about
+ # deduplicating them here. However, we do have to sort them so
+ # that they're visited at runtime in the order they're specified
+ # in the source.
+ patterns = list(sorted(p for item in state for p in item.patterns))
+ assert len(self.state_patterns) == self.worklist_index
+ self.state_patterns.append(patterns)
+
+ # calculate filter table for this state, and update filtered
+ # worklists.
+ for op in self.opcodes:
+ filt = self.filter[op]
+ rep = self.rep[op]
+ # Keep only the items that can appear as a source of op;
+ # this is the Filt_{a,i} projection from the paper.
+ filtered = frozenset(item for item in state if \
+ op in item.parent_ops)
+ if filtered in rep:
+ rep_index = rep.index(filtered)
+ else:
+ # New filtered state: it lands on op's worklist.
+ rep_index = rep.add(filtered)
+ new_opcodes.add(op)
+ assert len(filt) == self.worklist_index
+ filt.append(rep_index)
+ self.worklist_index += 1
+
+ # There are two start states: one which can only match as a wildcard,
+ # and one which can match as a wildcard or constant. These will be the
+ # states of intrinsics/other instructions and load_const instructions,
+ # respectively. The indices of these must match the definitions of
+ # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
+ # initialize things correctly.
+ self.states.add(frozenset((self.wildcard,)))
+ self.states.add(frozenset((self.const,self.wildcard)))
+ process_new_states()
+
+ while len(new_opcodes) > 0:
+ for op in new_opcodes:
+ rep = self.rep[op]
+ table = self.table[op]
+ op_worklist_index = worklist_indices[op]
+ if op in conv_opcode_types:
+ num_srcs = 1
+ else:
+ num_srcs = opcodes[op].num_inputs
+
+ # Iterate over all possible source combinations where at least one
+ # is on the worklist.
+ for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
+ if all(src_idx < op_worklist_index for src_idx in src_indices):
+ continue
+
+ srcs = tuple(rep[src_idx] for src_idx in src_indices)
+
+ # Try all possible pairings of source items and add the
+ # corresponding parent items. This is Comp_a from the paper.
+ parent = set(self.items[op, item_srcs] for item_srcs in
+ itertools.product(*srcs) if (op, item_srcs) in self.items)
+
+ # We could always start matching something else with a
+ # wildcard. This is Cl from the paper.
+ parent.add(self.wildcard)
+
+ # Register the (possibly new) state and the transition to it.
+ table[src_indices] = self.states.add(frozenset(parent))
+ worklist_indices[op] = len(rep)
+ new_opcodes.clear()
+ process_new_states()
+
+# Mako template that emits the C source for one algebraic pass: the transform
+# tables, the per-opcode automaton filter/transition tables, and the pass
+# entry points. Everything inside the triple-quoted string below is generated
+# output, so no Python comments can be placed there.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
+#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"
+/* What follows is NIR algebraic transform code for the following ${len(xforms)}
+ * transforms:
+% for xform in xforms:
+ * ${xform.search} => ${xform.replace}
+% endfor
+ */
+
#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
unsigned condition_offset;
};
+struct per_op_table {
+ const uint16_t *filter;
+ unsigned num_filtered_states;
+ const uint16_t *table;
+};
+
+/* Note: these must match the start states created in
+ * TreeAutomaton._build_table()
+ */
+
+/* WILDCARD_STATE = 0 is set by zeroing the state array */
+static const uint16_t CONST_STATE = 1;
+
#endif
-% for (opcode, xform_list) in xform_dict.items():
-% for xform in xform_list:
- ${xform.search.render()}
- ${xform.replace.render()}
+<% cache = {} %>
+% for xform in xforms:
+ ${xform.search.render(cache)}
+ ${xform.replace.render(cache)}
% endfor
-static const struct transform ${pass_name}_${opcode}_xforms[] = {
-% for xform in xform_list:
- { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
+% for state_id, state_xforms in enumerate(automaton.state_patterns):
+% if state_xforms: # avoid emitting a 0-length array for MSVC
+static const struct transform ${pass_name}_state${state_id}_xforms[] = {
+% for i in state_xforms:
+ { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
+% endif
+% endfor
+
+static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
+% for op in automaton.opcodes:
+ [${get_c_opcode(op)}] = {
+ .filter = (uint16_t []) {
+ % for e in automaton.filter[op]:
+ ${e},
+ % endfor
+ },
+ <%
+ num_filtered = len(automaton.rep[op])
+ %>
+ .num_filtered_states = ${num_filtered},
+ .table = (uint16_t []) {
+ <%
+ num_srcs = len(next(iter(automaton.table[op])))
+ %>
+ % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
+ ${automaton.table[op][indices]},
+ % endfor
+ },
+ },
% endfor
+};
+
+static void
+${pass_name}_pre_block(nir_block *block, uint16_t *states)
+{
+ nir_foreach_instr(instr, block) {
+ switch (instr->type) {
+ case nir_instr_type_alu: {
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+ nir_op op = alu->op;
+ uint16_t search_op = nir_search_op_for_nir_op(op);
+ const struct per_op_table *tbl = &${pass_name}_table[search_op];
+ if (tbl->num_filtered_states == 0)
+ continue;
+
+ /* Calculate the index into the transition table. Note the index
+ * calculated must match the iteration order of Python's
+ * itertools.product(), which was used to emit the transition
+ * table.
+ */
+ uint16_t index = 0;
+ for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
+ index *= tbl->num_filtered_states;
+ index += tbl->filter[states[alu->src[i].src.ssa->index]];
+ }
+ states[alu->dest.dest.ssa.index] = tbl->table[index];
+ break;
+ }
+
+ case nir_instr_type_load_const: {
+ nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
+ states[load_const->def.index] = CONST_STATE;
+ break;
+ }
+
+ default:
+ break;
+ }
+ }
+}
static bool
-${pass_name}_block(nir_block *block, const bool *condition_flags,
- void *mem_ctx)
+${pass_name}_block(nir_builder *build, nir_block *block,
+ const uint16_t *states, const bool *condition_flags)
{
bool progress = false;
+ const unsigned execution_mode = build->shader->info.float_controls_execution_mode;
nir_foreach_instr_reverse_safe(instr, block) {
if (instr->type != nir_instr_type_alu)
if (!alu->dest.dest.is_ssa)
continue;
- switch (alu->op) {
- % for opcode in xform_dict.keys():
- case nir_op_${opcode}:
- for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
- const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
+ unsigned bit_size = alu->dest.dest.ssa.bit_size;
+ const bool ignore_inexact =
+ nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
+ nir_is_denorm_flush_to_zero(execution_mode, bit_size);
+
+ switch (states[alu->dest.dest.ssa.index]) {
+% for i in range(len(automaton.state_patterns)):
+ case ${i}:
+ % if automaton.state_patterns[i]:
+ for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
+ const struct transform *xform = &${pass_name}_state${i}_xforms[i];
if (condition_flags[xform->condition_offset] &&
- nir_replace_instr(alu, xform->search, xform->replace,
- mem_ctx)) {
+ !(xform->search->inexact && ignore_inexact) &&
+ nir_replace_instr(build, alu, xform->search, xform->replace)) {
progress = true;
break;
}
}
+ % endif
break;
- % endfor
- default:
- break;
+% endfor
+ default: assert(0);
}
}
static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
- void *mem_ctx = ralloc_parent(impl);
bool progress = false;
+ nir_builder build;
+ nir_builder_init(&build, impl);
+
+ /* Note: it's important here that we're allocating a zeroed array, since
+ * state 0 is the default state, which means we don't have to visit
+ * anything other than constants and ALU instructions.
+ */
+ uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
+
+ nir_foreach_block(block, impl) {
+ ${pass_name}_pre_block(block, states);
+ }
+
nir_foreach_block_reverse(block, impl) {
- progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
+ progress |= ${pass_name}_block(&build, block, states, condition_flags);
}
- if (progress)
+ free(states);
+
+ if (progress) {
nir_metadata_preserve(impl, nir_metadata_block_index |
nir_metadata_dominance);
+ } else {
+#ifndef NDEBUG
+ impl->valid_metadata &= ~nir_metadata_not_properly_reset;
+#endif
+ }
return progress;
}
bool progress = false;
bool condition_flags[${len(condition_list)}];
const nir_shader_compiler_options *options = shader->options;
+ const shader_info *info = &shader->info;
(void) options;
+ (void) info;
% for index, condition in enumerate(condition_list):
condition_flags[${index}] = ${condition};
}
""")
+
class AlgebraicPass(object):
def __init__(self, pass_name, transforms):
- self.xform_dict = OrderedDict()
+ # All transforms in source order; the automaton's pattern indices
+ # refer into this list.
+ self.xforms = []
+ # Transforms grouped by (sized) search opcode.
+ self.opcode_xforms = defaultdict(list)
self.pass_name = pass_name
error = False
error = True
continue
- if xform.search.opcode not in self.xform_dict:
- self.xform_dict[xform.search.opcode] = []
+ self.xforms.append(xform)
+ if xform.search.opcode in conv_opcode_types:
+ # Generic conversion opcodes (e.g. f2b) are expanded into one
+ # sized opcode per possible destination bit-size (e.g. f2b1).
+ dst_type = conv_opcode_types[xform.search.opcode]
+ for size in type_sizes(dst_type):
+ sized_opcode = xform.search.opcode + str(size)
+ self.opcode_xforms[sized_opcode].append(xform)
+ else:
+ self.opcode_xforms[xform.search.opcode].append(xform)
+
+ # Check to make sure the search pattern does not unexpectedly contain
+ # more commutative expressions than match_expression (nir_search.c)
+ # can handle.
+ comm_exprs = xform.search.comm_exprs
+
+ if xform.search.many_commutative_expressions:
+ if comm_exprs <= nir_search_max_comm_ops:
+ print("Transform expected to have too many commutative " \
+ "expressions but did not " \
+ "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
+ file=sys.stderr)
+ print(" " + str(xform), file=sys.stderr)
+ traceback.print_exc(file=sys.stderr)
+ print('', file=sys.stderr)
+ error = True
+ else:
+ if comm_exprs > nir_search_max_comm_ops:
+ print("Transformation with too many commutative expressions " \
+ "({} > {}). Modify pattern or annotate with " \
+ "\"many-comm-expr\".".format(comm_exprs,
+ nir_search_max_comm_ops),
+ file=sys.stderr)
+ print(" " + str(xform.search), file=sys.stderr)
+ print("{}".format(xform.search.cond), file=sys.stderr)
+ error = True
- self.xform_dict[xform.search.opcode].append(xform)
+ # Build the matching automaton over the full, ordered transform list.
+ self.automaton = TreeAutomaton(self.xforms)
if error:
sys.exit(1)
+
def render(self):
return _algebraic_pass_template.render(pass_name=self.pass_name,
- xform_dict=self.xform_dict,
- condition_list=condition_list)
+ xforms=self.xforms,
+ opcode_xforms=self.opcode_xforms,
+ condition_list=condition_list,
+ automaton=self.automaton,
+ get_c_opcode=get_c_opcode,
+ itertools=itertools)