from nir_opcodes import opcodes, type_sizes
+# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
+nir_search_max_comm_ops = 8
+
# These opcodes are only employed by nir_search. This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
'f2b' : 'bool',
}
+def get_c_opcode(op):
+ if op in conv_opcode_types:
+ return 'nir_search_op_' + op
+ else:
+ return 'nir_op_' + op
+
+
if sys.version_info < (3, 0):
integer_types = (int, long)
string_type = unicode
elif isinstance(val, (bool, float) + integer_types):
return Constant(val, name_base)
- __template = mako.template.Template("""
-static const ${val.c_type} ${val.name} = {
- { ${val.type_enum}, ${val.c_bit_size} },
-% if isinstance(val, Constant):
- ${val.type()}, { ${val.hex()} /* ${val.value} */ },
-% elif isinstance(val, Variable):
- ${val.index}, /* ${val.var_name} */
- ${'true' if val.is_constant else 'false'},
- ${val.type() or 'nir_type_invalid' },
- ${val.cond if val.cond else 'NULL'},
-% elif isinstance(val, Expression):
- ${'true' if val.inexact else 'false'},
- ${val.comm_expr_idx}, ${val.comm_exprs},
- ${val.c_opcode()},
- { ${', '.join(src.c_ptr for src in val.sources)} },
- ${val.cond if val.cond else 'NULL'},
-% endif
-};""")
-
def __init__(self, val, name, type_str):
self.in_val = str(val)
self.name = name
def c_type(self):
return "nir_search_" + self.type_str
- @property
- def c_ptr(self):
- return "&{0}.value".format(self.name)
+ def __c_name(self, cache):
+ if cache is not None and self.name in cache:
+ return cache[self.name]
+ else:
+ return self.name
+
+ def c_value_ptr(self, cache):
+ return "&{0}.value".format(self.__c_name(cache))
+
+ def c_ptr(self, cache):
+ return "&{0}".format(self.__c_name(cache))
@property
def c_bit_size(self):
# We represent these cases with a 0 bit-size.
return 0
- def render(self):
- return self.__template.render(val=self,
- Constant=Constant,
- Variable=Variable,
- Expression=Expression)
+ __template = mako.template.Template("""{
+ { ${val.type_enum}, ${val.c_bit_size} },
+% if isinstance(val, Constant):
+ ${val.type()}, { ${val.hex()} /* ${val.value} */ },
+% elif isinstance(val, Variable):
+ ${val.index}, /* ${val.var_name} */
+ ${'true' if val.is_constant else 'false'},
+ ${val.type() or 'nir_type_invalid' },
+ ${val.cond if val.cond else 'NULL'},
+ ${val.swizzle()},
+% elif isinstance(val, Expression):
+ ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
+ ${val.comm_expr_idx}, ${val.comm_exprs},
+ ${val.c_opcode()},
+ { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
+ ${val.cond if val.cond else 'NULL'},
+% endif
+};""")
+
+ def render(self, cache):
+ struct_init = self.__template.render(val=self, cache=cache,
+ Constant=Constant,
+ Variable=Variable,
+ Expression=Expression)
+ if cache is not None and struct_init in cache:
+ # If it's in the cache, register a name remap in the cache and render
+ # only a comment saying it's been remapped
+ cache[self.name] = cache[struct_init]
+ return "/* {} -> {} in the cache */\n".format(self.name,
+ cache[struct_init])
+ else:
+ if cache is not None:
+ cache[struct_init] = self.name
+ return "static const {} {} = {}\n".format(self.c_type, self.name,
+ struct_init)
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
elif isinstance(self.value, float):
return "nir_type_float"
+ def equivalent(self, other):
+ """Check that two constants are equivalent.
+
+      This check is much weaker than equality.  One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.value == other.value
+
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
- r"(?P<cond>\([^\)]+\))?")
+ r"(?P<cond>\([^\)]+\))?"
+ r"(?P<swiz>\.[xyzw]+)?")
class Variable(Value):
def __init__(self, val, name, varset):
# constant. If we want to support names that have numeric or
        # punctuation characters, we can make the first assertion more flexible.
assert self.var_name.isalpha()
- assert self.var_name is not 'True'
- assert self.var_name is not 'False'
+ assert self.var_name != 'True'
+ assert self.var_name != 'False'
self.is_constant = m.group('const') is not None
self.cond = m.group('cond')
self.required_type = m.group('type')
self._bit_size = int(m.group('bits')) if m.group('bits') else None
+ self.swiz = m.group('swiz')
if self.required_type == 'bool':
if self._bit_size is not None:
elif self.required_type == 'float':
return "nir_type_float"
-_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
+ def equivalent(self, other):
+ """Check that two variables are equivalent.
+
+      This check is much weaker than equality.  One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.index == other.index
+
+ def swizzle(self):
+ if self.swiz is not None:
+ swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w': 3}
+ return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
+ return '{0, 1, 2, 3}'
+
+_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
r"(?P<cond>\([^\)]+\))?")
class Expression(Value):
self.opcode = m.group('opcode')
self._bit_size = int(m.group('bits')) if m.group('bits') else None
self.inexact = m.group('inexact') is not None
+ self.exact = m.group('exact') is not None
self.cond = m.group('cond')
+
+ assert not self.inexact or not self.exact, \
+ 'Expression cannot be both exact and inexact.'
+
+ # "many-comm-expr" isn't really a condition. It's notification to the
+ # generator that this pattern is known to have too many commutative
+ # expressions, and an error should not be generated for this case.
+ self.many_commutative_expressions = False
+ if self.cond and self.cond.find("many-comm-expr") >= 0:
+ # Split the condition into a comma-separated list. Remove
+ # "many-comm-expr". If there is anything left, put it back together.
+ c = self.cond[1:-1].split(",")
+ c.remove("many-comm-expr")
+
+ self.cond = "({})".format(",".join(c)) if c else None
+ self.many_commutative_expressions = True
+
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
for (i, src) in enumerate(expr[1:]) ]
self.__index_comm_exprs(0)
+ def equivalent(self, other):
+      """Check that two expressions are equivalent.
+
+      This check is much weaker than equality.  One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ This implementation does not check for equivalence due to commutativity,
+ but it could.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ if len(self.sources) != len(other.sources):
+ return False
+
+ if self.opcode != other.opcode:
+ return False
+
+ return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
def __index_comm_exprs(self, base_idx):
"""Recursively count and index commutative expressions
"""
self.comm_exprs = 0
+
+ # A note about the explicit "len(self.sources)" check. The list of
+ # sources comes from user input, and that input might be bad. Check
+ # that the expected second source exists before accessing it. Without
+ # this check, a unit test that does "('iadd', 'a')" will crash.
if self.opcode not in conv_opcode_types and \
- "commutative" in opcodes[self.opcode].algebraic_properties:
+ "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+ len(self.sources) >= 2 and \
+ not self.sources[0].equivalent(self.sources[1]):
self.comm_expr_idx = base_idx
self.comm_exprs += 1
else:
return self.comm_exprs
def c_opcode(self):
- if self.opcode in conv_opcode_types:
- return 'nir_search_op_' + self.opcode
- else:
- return 'nir_op_' + self.opcode
+ return get_c_opcode(self.opcode)
- def render(self):
- srcs = "\n".join(src.render() for src in self.sources)
- return srcs + super(Expression, self).render()
+ def render(self, cache):
+ srcs = "\n".join(src.render(cache) for src in self.sources)
+ return srcs + super(Expression, self).render(cache)
class BitSizeValidator(object):
"""A class for validating bit sizes of expressions.
BitSizeValidator(varset).validate(self.search, self.replace)
+class TreeAutomaton(object):
+ """This class calculates a bottom-up tree automaton to quickly search for
+   the left-hand sides of transforms. Tree automatons are a generalization of
+ classical NFA's and DFA's, where the transition function determines the
+ state of the parent node based on the state of its children. We construct a
+ deterministic automaton to match patterns, using a similar algorithm to the
+ classical NFA to DFA construction. At the moment, it only matches opcodes
+ and constants (without checking the actual value), leaving more detailed
+ checking to the search function which actually checks the leaves. The
+ automaton acts as a quick filter for the search function, requiring only n
+ + 1 table lookups for each n-source operation. The implementation is based
+ on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
+ In the language of that reference, this is a frontier-to-root deterministic
+ automaton using only symbol filtering. The filtering is crucial to reduce
+ both the time taken to generate the tables and the size of the tables.
+ """
+ def __init__(self, transforms):
+ self.patterns = [t.search for t in transforms]
+ self._compute_items()
+ self._build_table()
+ #print('num items: {}'.format(len(set(self.items.values()))))
+ #print('num states: {}'.format(len(self.states)))
+ #for state, patterns in zip(self.states, self.patterns):
+ # print('{}: num patterns: {}'.format(state, len(patterns)))
+
+ class IndexMap(object):
+ """An indexed list of objects, where one can either lookup an object by
+ index or find the index associated to an object quickly using a hash
+ table. Compared to a list, it has a constant time index(). Compared to a
+ set, it provides a stable iteration order.
+ """
+ def __init__(self, iterable=()):
+ self.objects = []
+ self.map = {}
+ for obj in iterable:
+ self.add(obj)
+
+ def __getitem__(self, i):
+ return self.objects[i]
+
+ def __contains__(self, obj):
+ return obj in self.map
+
+ def __len__(self):
+ return len(self.objects)
+
+ def __iter__(self):
+ return iter(self.objects)
+
+ def clear(self):
+ self.objects = []
+ self.map.clear()
+
+ def index(self, obj):
+ return self.map[obj]
+
+ def add(self, obj):
+ if obj in self.map:
+ return self.map[obj]
+ else:
+ index = len(self.objects)
+ self.objects.append(obj)
+ self.map[obj] = index
+ return index
+
+ def __repr__(self):
+ return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
+
+ class Item(object):
+ """This represents an "item" in the language of "Tree Automatons." This
+ is just a subtree of some pattern, which represents a potential partial
+ match at runtime. We deduplicate them, so that identical subtrees of
+ different patterns share the same object, and store some extra
+ information needed for the main algorithm as well.
+ """
+ def __init__(self, opcode, children):
+ self.opcode = opcode
+ self.children = children
+ # These are the indices of patterns for which this item is the root node.
+ self.patterns = []
+ # This the set of opcodes for parents of this item. Used to speed up
+ # filtering.
+ self.parent_ops = set()
+
+ def __str__(self):
+ return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
+
+ def __repr__(self):
+ return str(self)
+
+ def _compute_items(self):
+ """Build a set of all possible items, deduplicating them."""
+ # This is a map from (opcode, sources) to item.
+ self.items = {}
+
+ # The set of all opcodes used by the patterns. Used later to avoid
+ # building and emitting all the tables for opcodes that aren't used.
+ self.opcodes = self.IndexMap()
+
+ def get_item(opcode, children, pattern=None):
+ commutative = len(children) >= 2 \
+ and "2src_commutative" in opcodes[opcode].algebraic_properties
+ item = self.items.setdefault((opcode, children),
+ self.Item(opcode, children))
+ if commutative:
+ self.items[opcode, (children[1], children[0]) + children[2:]] = item
+ if pattern is not None:
+ item.patterns.append(pattern)
+ return item
+
+ self.wildcard = get_item("__wildcard", ())
+ self.const = get_item("__const", ())
+
+ def process_subpattern(src, pattern=None):
+ if isinstance(src, Constant):
+ # Note: we throw away the actual constant value!
+ return self.const
+ elif isinstance(src, Variable):
+ if src.is_constant:
+ return self.const
+ else:
+ # Note: we throw away which variable it is here! This special
+ # item is equivalent to nu in "Tree Automatons."
+ return self.wildcard
+ else:
+ assert isinstance(src, Expression)
+ opcode = src.opcode
+ stripped = opcode.rstrip('0123456789')
+ if stripped in conv_opcode_types:
+ # Matches that use conversion opcodes with a specific type,
+ # like f2b1, are tricky. Either we construct the automaton to
+ # match specific NIR opcodes like nir_op_f2b1, in which case we
+ # need to create separate items for each possible NIR opcode
+ # for patterns that have a generic opcode like f2b, or we
+ # construct it to match the search opcode, in which case we
+ # need to map f2b1 to f2b when constructing the automaton. Here
+ # we do the latter.
+ opcode = stripped
+ self.opcodes.add(opcode)
+ children = tuple(process_subpattern(c) for c in src.sources)
+ item = get_item(opcode, children, pattern)
+ for i, child in enumerate(children):
+ child.parent_ops.add(opcode)
+ return item
+
+ for i, pattern in enumerate(self.patterns):
+ process_subpattern(pattern, i)
+
+ def _build_table(self):
+ """This is the core algorithm which builds up the transition table. It
+ is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
+ Comp_a and Filt_{a,i} using integers to identify match sets." It
+ simultaneously builds up a list of all possible "match sets" or
+ "states", where each match set represents the set of Item's that match a
+ given instruction, and builds up the transition table between states.
+ """
+ # Map from opcode + filtered state indices to transitioned state.
+ self.table = defaultdict(dict)
+ # Bijection from state to index. q in the original algorithm is
+ # len(self.states)
+ self.states = self.IndexMap()
+ # List of pattern matches for each state index.
+ self.state_patterns = []
+ # Map from state index to filtered state index for each opcode.
+ self.filter = defaultdict(list)
+ # Bijections from filtered state to filtered state index for each
+ # opcode, called the "representor sets" in the original algorithm.
+ # q_{a,j} in the original algorithm is len(self.rep[op]).
+ self.rep = defaultdict(self.IndexMap)
+
+ # Everything in self.states with a index at least worklist_index is part
+ # of the worklist of newly created states. There is also a worklist of
+      # newly filtered states for each opcode, for which worklist_indices
+ # serves a similar purpose. worklist_index corresponds to p in the
+ # original algorithm, while worklist_indices is p_{a,j} (although since
+ # we only filter by opcode/symbol, it's really just p_a).
+ self.worklist_index = 0
+ worklist_indices = defaultdict(lambda: 0)
+
+ # This is the set of opcodes for which the filtered worklist is non-empty.
+ # It's used to avoid scanning opcodes for which there is nothing to
+ # process when building the transition table. It corresponds to new_a in
+ # the original algorithm.
+ new_opcodes = self.IndexMap()
+
+ # Process states on the global worklist, filtering them for each opcode,
+ # updating the filter tables, and updating the filtered worklists if any
+ # new filtered states are found. Similar to ComputeRepresenterSets() in
+ # the original algorithm, although that only processes a single state.
+ def process_new_states():
+ while self.worklist_index < len(self.states):
+ state = self.states[self.worklist_index]
+
+ # Calculate pattern matches for this state. Each pattern is
+ # assigned to a unique item, so we don't have to worry about
+ # deduplicating them here. However, we do have to sort them so
+ # that they're visited at runtime in the order they're specified
+ # in the source.
+ patterns = list(sorted(p for item in state for p in item.patterns))
+ assert len(self.state_patterns) == self.worklist_index
+ self.state_patterns.append(patterns)
+
+ # calculate filter table for this state, and update filtered
+ # worklists.
+ for op in self.opcodes:
+ filt = self.filter[op]
+ rep = self.rep[op]
+ filtered = frozenset(item for item in state if \
+ op in item.parent_ops)
+ if filtered in rep:
+ rep_index = rep.index(filtered)
+ else:
+ rep_index = rep.add(filtered)
+ new_opcodes.add(op)
+ assert len(filt) == self.worklist_index
+ filt.append(rep_index)
+ self.worklist_index += 1
+
+ # There are two start states: one which can only match as a wildcard,
+ # and one which can match as a wildcard or constant. These will be the
+ # states of intrinsics/other instructions and load_const instructions,
+ # respectively. The indices of these must match the definitions of
+ # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
+ # initialize things correctly.
+ self.states.add(frozenset((self.wildcard,)))
+ self.states.add(frozenset((self.const,self.wildcard)))
+ process_new_states()
+
+ while len(new_opcodes) > 0:
+ for op in new_opcodes:
+ rep = self.rep[op]
+ table = self.table[op]
+ op_worklist_index = worklist_indices[op]
+ if op in conv_opcode_types:
+ num_srcs = 1
+ else:
+ num_srcs = opcodes[op].num_inputs
+
+ # Iterate over all possible source combinations where at least one
+ # is on the worklist.
+ for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
+ if all(src_idx < op_worklist_index for src_idx in src_indices):
+ continue
+
+ srcs = tuple(rep[src_idx] for src_idx in src_indices)
+
+ # Try all possible pairings of source items and add the
+ # corresponding parent items. This is Comp_a from the paper.
+ parent = set(self.items[op, item_srcs] for item_srcs in
+ itertools.product(*srcs) if (op, item_srcs) in self.items)
+
+ # We could always start matching something else with a
+ # wildcard. This is Cl from the paper.
+ parent.add(self.wildcard)
+
+ table[src_indices] = self.states.add(frozenset(parent))
+ worklist_indices[op] = len(rep)
+ new_opcodes.clear()
+ process_new_states()
+
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"
-#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-
-struct transform {
- const nir_search_expression *search;
- const nir_search_value *replace;
- unsigned condition_offset;
-};
-
-#endif
+/* What follows is NIR algebraic transform code for the following ${len(xforms)}
+ * transforms:
+% for xform in xforms:
+ * ${xform.search} => ${xform.replace}
+% endfor
+ */
+<% cache = {} %>
% for xform in xforms:
- ${xform.search.render()}
- ${xform.replace.render()}
+ ${xform.search.render(cache)}
+ ${xform.replace.render(cache)}
% endfor
-% for (opcode, xform_list) in sorted(opcode_xforms.items()):
-static const struct transform ${pass_name}_${opcode}_xforms[] = {
-% for xform in xform_list:
- { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
+% for state_id, state_xforms in enumerate(automaton.state_patterns):
+% if state_xforms: # avoid emitting a 0-length array for MSVC
+static const struct transform ${pass_name}_state${state_id}_xforms[] = {
+% for i in state_xforms:
+ { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
+% endif
% endfor
-static bool
-${pass_name}_block(nir_builder *build, nir_block *block,
- const bool *condition_flags)
-{
- bool progress = false;
-
- nir_foreach_instr_reverse_safe(instr, block) {
- if (instr->type != nir_instr_type_alu)
- continue;
-
- nir_alu_instr *alu = nir_instr_as_alu(instr);
- if (!alu->dest.dest.is_ssa)
- continue;
-
- switch (alu->op) {
- % for opcode in sorted(opcode_xforms.keys()):
- case nir_op_${opcode}:
- for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
- const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
- if (condition_flags[xform->condition_offset] &&
- nir_replace_instr(build, alu, xform->search, xform->replace)) {
- progress = true;
- break;
- }
- }
- break;
+static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
+% for op in automaton.opcodes:
+ [${get_c_opcode(op)}] = {
+ .filter = (uint16_t []) {
+ % for e in automaton.filter[op]:
+ ${e},
% endfor
- default:
- break;
- }
- }
-
- return progress;
-}
-
-static bool
-${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
-{
- bool progress = false;
-
- nir_builder build;
- nir_builder_init(&build, impl);
-
- nir_foreach_block_reverse(block, impl) {
- progress |= ${pass_name}_block(&build, block, condition_flags);
- }
-
- if (progress) {
- nir_metadata_preserve(impl, nir_metadata_block_index |
- nir_metadata_dominance);
- } else {
-#ifndef NDEBUG
- impl->valid_metadata &= ~nir_metadata_not_properly_reset;
-#endif
- }
+ },
+ <%
+ num_filtered = len(automaton.rep[op])
+ %>
+ .num_filtered_states = ${num_filtered},
+ .table = (uint16_t []) {
+ <%
+ num_srcs = len(next(iter(automaton.table[op])))
+ %>
+ % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
+ ${automaton.table[op][indices]},
+ % endfor
+ },
+ },
+% endfor
+};
- return progress;
-}
+const struct transform *${pass_name}_transforms[] = {
+% for i in range(len(automaton.state_patterns)):
+ % if automaton.state_patterns[i]:
+ ${pass_name}_state${i}_xforms,
+ % else:
+ NULL,
+ % endif
+% endfor
+};
+const uint16_t ${pass_name}_transform_counts[] = {
+% for i in range(len(automaton.state_patterns)):
+ % if automaton.state_patterns[i]:
+ (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
+ % else:
+ 0,
+ % endif
+% endfor
+};
bool
${pass_name}(nir_shader *shader)
bool progress = false;
bool condition_flags[${len(condition_list)}];
const nir_shader_compiler_options *options = shader->options;
+ const shader_info *info = &shader->info;
(void) options;
+ (void) info;
% for index, condition in enumerate(condition_list):
condition_flags[${index}] = ${condition};
% endfor
nir_foreach_function(function, shader) {
- if (function->impl)
- progress |= ${pass_name}_impl(function->impl, condition_flags);
+ if (function->impl) {
+ progress |= nir_algebraic_impl(function->impl, condition_flags,
+ ${pass_name}_transforms,
+ ${pass_name}_transform_counts,
+ ${pass_name}_table);
+ }
}
return progress;
}
""")
+
class AlgebraicPass(object):
def __init__(self, pass_name, transforms):
self.xforms = []
else:
self.opcode_xforms[xform.search.opcode].append(xform)
+ # Check to make sure the search pattern does not unexpectedly contain
+ # more commutative expressions than match_expression (nir_search.c)
+ # can handle.
+ comm_exprs = xform.search.comm_exprs
+
+ if xform.search.many_commutative_expressions:
+ if comm_exprs <= nir_search_max_comm_ops:
+                print("Transform expected to have too many commutative " \
+                      "expressions but did not " \
+                      "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
+                      file=sys.stderr)
+ print(" " + str(xform), file=sys.stderr)
+ traceback.print_exc(file=sys.stderr)
+ print('', file=sys.stderr)
+ error = True
+ else:
+ if comm_exprs > nir_search_max_comm_ops:
+ print("Transformation with too many commutative expressions " \
+ "({} > {}). Modify pattern or annotate with " \
+ "\"many-comm-expr\".".format(comm_exprs,
+ nir_search_max_comm_ops),
+ file=sys.stderr)
+ print(" " + str(xform.search), file=sys.stderr)
+ print("{}".format(xform.search.cond), file=sys.stderr)
+ error = True
+
+ self.automaton = TreeAutomaton(self.xforms)
+
if error:
sys.exit(1)
return _algebraic_pass_template.render(pass_name=self.pass_name,
xforms=self.xforms,
opcode_xforms=self.opcode_xforms,
- condition_list=condition_list)
+ condition_list=condition_list,
+ automaton=self.automaton,
+ get_c_opcode=get_c_opcode,
+ itertools=itertools)