X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fcompiler%2Fnir%2Fnir_algebraic.py;h=6871af36ef8f06ace44749b0bbd63ad2d9ea67f2;hb=2fcfcca842a6c3ca77f38791da88b185839f064a;hp=fe9d1051e67ecee8a74082f29325e5b1edb4d66f;hpb=2623653126985be5aca1a29e24bdecb4bb42c8b4;p=mesa.git diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index fe9d1051e67..6871af36ef8 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -35,6 +35,9 @@ import traceback from nir_opcodes import opcodes, type_sizes +# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c +nir_search_max_comm_ops = 8 + # These opcodes are only employed by nir_search. This provides a mapping from # opcode to destination type. conv_opcode_types = { @@ -51,6 +54,13 @@ conv_opcode_types = { 'f2b' : 'bool', } +def get_c_opcode(op): + if op in conv_opcode_types: + return 'nir_search_op_' + op + else: + return 'nir_op_' + op + + if sys.version_info < (3, 0): integer_types = (int, long) string_type = unicode @@ -102,24 +112,6 @@ class Value(object): elif isinstance(val, (bool, float) + integer_types): return Constant(val, name_base) - __template = mako.template.Template(""" -static const ${val.c_type} ${val.name} = { - { ${val.type_enum}, ${val.c_bit_size} }, -% if isinstance(val, Constant): - ${val.type()}, { ${val.hex()} /* ${val.value} */ }, -% elif isinstance(val, Variable): - ${val.index}, /* ${val.var_name} */ - ${'true' if val.is_constant else 'false'}, - ${val.type() or 'nir_type_invalid' }, - ${val.cond if val.cond else 'NULL'}, -% elif isinstance(val, Expression): - ${'true' if val.inexact else 'false'}, - ${val.c_opcode()}, - { ${', '.join(src.c_ptr for src in val.sources)} }, - ${val.cond if val.cond else 'NULL'}, -% endif -};""") - def __init__(self, val, name, type_str): self.in_val = str(val) self.name = name @@ -170,9 +162,17 @@ static const ${val.c_type} ${val.name} = { def c_type(self): return "nir_search_" + 
self.type_str - @property - def c_ptr(self): - return "&{0}.value".format(self.name) + def __c_name(self, cache): + if cache is not None and self.name in cache: + return cache[self.name] + else: + return self.name + + def c_value_ptr(self, cache): + return "&{0}.value".format(self.__c_name(cache)) + + def c_ptr(self, cache): + return "&{0}".format(self.__c_name(cache)) @property def c_bit_size(self): @@ -189,11 +189,41 @@ static const ${val.c_type} ${val.name} = { # We represent these cases with a 0 bit-size. return 0 - def render(self): - return self.__template.render(val=self, - Constant=Constant, - Variable=Variable, - Expression=Expression) + __template = mako.template.Template("""{ + { ${val.type_enum}, ${val.c_bit_size} }, +% if isinstance(val, Constant): + ${val.type()}, { ${val.hex()} /* ${val.value} */ }, +% elif isinstance(val, Variable): + ${val.index}, /* ${val.var_name} */ + ${'true' if val.is_constant else 'false'}, + ${val.type() or 'nir_type_invalid' }, + ${val.cond if val.cond else 'NULL'}, + ${val.swizzle()}, +% elif isinstance(val, Expression): + ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'}, + ${val.comm_expr_idx}, ${val.comm_exprs}, + ${val.c_opcode()}, + { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} }, + ${val.cond if val.cond else 'NULL'}, +% endif +};""") + + def render(self, cache): + struct_init = self.__template.render(val=self, cache=cache, + Constant=Constant, + Variable=Variable, + Expression=Expression) + if cache is not None and struct_init in cache: + # If it's in the cache, register a name remap in the cache and render + # only a comment saying it's been remapped + cache[self.name] = cache[struct_init] + return "/* {} -> {} in the cache */\n".format(self.name, + cache[struct_init]) + else: + if cache is not None: + cache[struct_init] = self.name + return "static const {} {} = {}\n".format(self.c_type, self.name, + struct_init) _constant_re = re.compile(r"(?P[^@\(]+)(?:@(?P\d+))?") @@ 
-240,16 +270,34 @@ class Constant(Value): elif isinstance(self.value, float): return "nir_type_float" + def equivalent(self, other): + """Check that two constants are equivalent. + + This check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + """ + if not isinstance(other, type(self)): + return False + + return self.value == other.value + +# The $ at the end forces there to be an error if any part of the string +# doesn't match one of the field patterns. _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)" r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?" - r"(?P<cond>\([^\)]+\))?") + r"(?P<cond>\([^\)]+\))?" + r"(?P<swiz>\.[xyzw]+)?" + r"$") class Variable(Value): def __init__(self, val, name, varset): Value.__init__(self, val, name, "variable") m = _var_name_re.match(val) - assert m and m.group('name') is not None + assert m and m.group('name') is not None, \ + "Malformed variable name \"{}\".".format(val) self.var_name = m.group('name') @@ -257,13 +305,14 @@ class Variable(Value): # constant. If we want to support names that have numeric or # punctuation characters, we can me the first assertion more flexible. assert self.var_name.isalpha() - assert self.var_name is not 'True' - assert self.var_name is not 'False' + assert self.var_name != 'True' + assert self.var_name != 'False' self.is_constant = m.group('const') is not None self.cond = m.group('cond') self.required_type = m.group('type') self._bit_size = int(m.group('bits')) if m.group('bits') else None + self.swiz = m.group('swiz') if self.required_type == 'bool': if self._bit_size is not None: @@ -284,7 +333,30 @@ class Variable(Value): elif self.required_type == 'float': return "nir_type_float" -_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" + def equivalent(self, other): + """Check that two variables are equivalent. + + This check is much weaker than equality. One generally cannot be + used in place of the other.
Using this implementation for the __eq__ + will break BitSizeValidator. + + """ + if not isinstance(other, type(self)): + return False + + return self.index == other.index + + def swizzle(self): + if self.swiz is not None: + swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3, + 'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3, + 'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7, + 'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11, + 'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 } + return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}' + return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}' + +_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" r"(?P<cond>\([^\)]+\))?") class Expression(Value): @@ -298,7 +370,25 @@ class Expression(Value): self.opcode = m.group('opcode') self._bit_size = int(m.group('bits')) if m.group('bits') else None self.inexact = m.group('inexact') is not None + self.exact = m.group('exact') is not None self.cond = m.group('cond') + + assert not self.inexact or not self.exact, \ + 'Expression cannot be both exact and inexact.' + + # "many-comm-expr" isn't really a condition. It's a notification to the + # generator that this pattern is known to have too many commutative + # expressions, and an error should not be generated for this case. + self.many_commutative_expressions = False + if self.cond and self.cond.find("many-comm-expr") >= 0: + # Split the condition into a comma-separated list. Remove + # "many-comm-expr". If there is anything left, put it back together. + c = self.cond[1:-1].split(",") + c.remove("many-comm-expr") + + self.cond = "({})".format(",".join(c)) if c else None + self.many_commutative_expressions = True + self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset) for (i, src) in enumerate(expr[1:]) ] @@ -307,16 +397,61 @@ class Expression(Value): 'Expression cannot use an unsized conversion opcode with ' \ 'an explicit size; that\'s silly.'
+ self.__index_comm_exprs(0) - def c_opcode(self): - if self.opcode in conv_opcode_types: - return 'nir_search_op_' + self.opcode + def equivalent(self, other): + """Check that two expressions are equivalent. + + This check is much weaker than equality. One generally cannot be + used in place of the other. Using this implementation for the __eq__ + will break BitSizeValidator. + + This implementation does not check for equivalence due to commutativity, + but it could. + + """ + if not isinstance(other, type(self)): + return False + + if len(self.sources) != len(other.sources): + return False + + if self.opcode != other.opcode: + return False + + return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) + + def __index_comm_exprs(self, base_idx): + """Recursively count and index commutative expressions + """ + self.comm_exprs = 0 + + # A note about the explicit "len(self.sources)" check. The list of + # sources comes from user input, and that input might be bad. Check + # that the expected second source exists before accessing it. Without + # this check, a unit test that does "('iadd', 'a')" will crash.
+ if self.opcode not in conv_opcode_types and \ + "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ + len(self.sources) >= 2 and \ + not self.sources[0].equivalent(self.sources[1]): + self.comm_expr_idx = base_idx + self.comm_exprs += 1 else: - return 'nir_op_' + self.opcode + self.comm_expr_idx = -1 - def render(self): - srcs = "\n".join(src.render() for src in self.sources) - return srcs + super(Expression, self).render() + for s in self.sources: + if isinstance(s, Expression): + s.__index_comm_exprs(base_idx + self.comm_exprs) + self.comm_exprs += s.comm_exprs + + return self.comm_exprs + + def c_opcode(self): + return get_c_opcode(self.opcode) + + def render(self, cache): + srcs = "\n".join(src.render(cache) for src in self.sources) + return srcs + super(Expression, self).render(cache) class BitSizeValidator(object): """A class for validating bit sizes of expressions. @@ -654,95 +789,338 @@ class SearchAndReplace(object): BitSizeValidator(varset).validate(self.search, self.replace) +class TreeAutomaton(object): + """This class calculates a bottom-up tree automaton to quickly search for + the left-hand sides of transforms. Tree automatons are a generalization of + classical NFA's and DFA's, where the transition function determines the + state of the parent node based on the state of its children. We construct a + deterministic automaton to match patterns, using a similar algorithm to the + classical NFA to DFA construction. At the moment, it only matches opcodes + and constants (without checking the actual value), leaving more detailed + checking to the search function which actually checks the leaves. The + automaton acts as a quick filter for the search function, requiring only n + + 1 table lookups for each n-source operation. The implementation is based + on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
+ In the language of that reference, this is a frontier-to-root deterministic + automaton using only symbol filtering. The filtering is crucial to reduce + both the time taken to generate the tables and the size of the tables. + """ + def __init__(self, transforms): + self.patterns = [t.search for t in transforms] + self._compute_items() + self._build_table() + #print('num items: {}'.format(len(set(self.items.values())))) + #print('num states: {}'.format(len(self.states))) + #for state, patterns in zip(self.states, self.patterns): + # print('{}: num patterns: {}'.format(state, len(patterns))) + + class IndexMap(object): + """An indexed list of objects, where one can either lookup an object by + index or find the index associated to an object quickly using a hash + table. Compared to a list, it has a constant time index(). Compared to a + set, it provides a stable iteration order. + """ + def __init__(self, iterable=()): + self.objects = [] + self.map = {} + for obj in iterable: + self.add(obj) + + def __getitem__(self, i): + return self.objects[i] + + def __contains__(self, obj): + return obj in self.map + + def __len__(self): + return len(self.objects) + + def __iter__(self): + return iter(self.objects) + + def clear(self): + self.objects = [] + self.map.clear() + + def index(self, obj): + return self.map[obj] + + def add(self, obj): + if obj in self.map: + return self.map[obj] + else: + index = len(self.objects) + self.objects.append(obj) + self.map[obj] = index + return index + + def __repr__(self): + return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' + + class Item(object): + """This represents an "item" in the language of "Tree Automatons." This + is just a subtree of some pattern, which represents a potential partial + match at runtime. We deduplicate them, so that identical subtrees of + different patterns share the same object, and store some extra + information needed for the main algorithm as well. 
+ """ + def __init__(self, opcode, children): + self.opcode = opcode + self.children = children + # These are the indices of patterns for which this item is the root node. + self.patterns = [] + # This the set of opcodes for parents of this item. Used to speed up + # filtering. + self.parent_ops = set() + + def __str__(self): + return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' + + def __repr__(self): + return str(self) + + def _compute_items(self): + """Build a set of all possible items, deduplicating them.""" + # This is a map from (opcode, sources) to item. + self.items = {} + + # The set of all opcodes used by the patterns. Used later to avoid + # building and emitting all the tables for opcodes that aren't used. + self.opcodes = self.IndexMap() + + def get_item(opcode, children, pattern=None): + commutative = len(children) >= 2 \ + and "2src_commutative" in opcodes[opcode].algebraic_properties + item = self.items.setdefault((opcode, children), + self.Item(opcode, children)) + if commutative: + self.items[opcode, (children[1], children[0]) + children[2:]] = item + if pattern is not None: + item.patterns.append(pattern) + return item + + self.wildcard = get_item("__wildcard", ()) + self.const = get_item("__const", ()) + + def process_subpattern(src, pattern=None): + if isinstance(src, Constant): + # Note: we throw away the actual constant value! + return self.const + elif isinstance(src, Variable): + if src.is_constant: + return self.const + else: + # Note: we throw away which variable it is here! This special + # item is equivalent to nu in "Tree Automatons." + return self.wildcard + else: + assert isinstance(src, Expression) + opcode = src.opcode + stripped = opcode.rstrip('0123456789') + if stripped in conv_opcode_types: + # Matches that use conversion opcodes with a specific type, + # like f2b1, are tricky. 
Either we construct the automaton to + # match specific NIR opcodes like nir_op_f2b1, in which case we + # need to create separate items for each possible NIR opcode + # for patterns that have a generic opcode like f2b, or we + # construct it to match the search opcode, in which case we + # need to map f2b1 to f2b when constructing the automaton. Here + # we do the latter. + opcode = stripped + self.opcodes.add(opcode) + children = tuple(process_subpattern(c) for c in src.sources) + item = get_item(opcode, children, pattern) + for i, child in enumerate(children): + child.parent_ops.add(opcode) + return item + + for i, pattern in enumerate(self.patterns): + process_subpattern(pattern, i) + + def _build_table(self): + """This is the core algorithm which builds up the transition table. It + is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . + Comp_a and Filt_{a,i} using integers to identify match sets." It + simultaneously builds up a list of all possible "match sets" or + "states", where each match set represents the set of Item's that match a + given instruction, and builds up the transition table between states. + """ + # Map from opcode + filtered state indices to transitioned state. + self.table = defaultdict(dict) + # Bijection from state to index. q in the original algorithm is + # len(self.states) + self.states = self.IndexMap() + # List of pattern matches for each state index. + self.state_patterns = [] + # Map from state index to filtered state index for each opcode. + self.filter = defaultdict(list) + # Bijections from filtered state to filtered state index for each + # opcode, called the "representor sets" in the original algorithm. + # q_{a,j} in the original algorithm is len(self.rep[op]). + self.rep = defaultdict(self.IndexMap) + + # Everything in self.states with a index at least worklist_index is part + # of the worklist of newly created states. 
There is also a worklist of + # newly filtered states for each opcode, for which worklist_indices + # serves a similar purpose. worklist_index corresponds to p in the + # original algorithm, while worklist_indices is p_{a,j} (although since + # we only filter by opcode/symbol, it's really just p_a). + self.worklist_index = 0 + worklist_indices = defaultdict(lambda: 0) + + # This is the set of opcodes for which the filtered worklist is non-empty. + # It's used to avoid scanning opcodes for which there is nothing to + # process when building the transition table. It corresponds to new_a in + # the original algorithm. + new_opcodes = self.IndexMap() + + # Process states on the global worklist, filtering them for each opcode, + # updating the filter tables, and updating the filtered worklists if any + # new filtered states are found. Similar to ComputeRepresenterSets() in + # the original algorithm, although that only processes a single state. + def process_new_states(): + while self.worklist_index < len(self.states): + state = self.states[self.worklist_index] + + # Calculate pattern matches for this state. Each pattern is + # assigned to a unique item, so we don't have to worry about + # deduplicating them here. However, we do have to sort them so + # that they're visited at runtime in the order they're specified + # in the source. + patterns = list(sorted(p for item in state for p in item.patterns)) + assert len(self.state_patterns) == self.worklist_index + self.state_patterns.append(patterns) + + # calculate filter table for this state, and update filtered + # worklists.
+ for op in self.opcodes: + filt = self.filter[op] + rep = self.rep[op] + filtered = frozenset(item for item in state if \ + op in item.parent_ops) + if filtered in rep: + rep_index = rep.index(filtered) + else: + rep_index = rep.add(filtered) + new_opcodes.add(op) + assert len(filt) == self.worklist_index + filt.append(rep_index) + self.worklist_index += 1 + + # There are two start states: one which can only match as a wildcard, + # and one which can match as a wildcard or constant. These will be the + # states of intrinsics/other instructions and load_const instructions, + # respectively. The indices of these must match the definitions of + # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can + # initialize things correctly. + self.states.add(frozenset((self.wildcard,))) + self.states.add(frozenset((self.const,self.wildcard))) + process_new_states() + + while len(new_opcodes) > 0: + for op in new_opcodes: + rep = self.rep[op] + table = self.table[op] + op_worklist_index = worklist_indices[op] + if op in conv_opcode_types: + num_srcs = 1 + else: + num_srcs = opcodes[op].num_inputs + + # Iterate over all possible source combinations where at least one + # is on the worklist. + for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): + if all(src_idx < op_worklist_index for src_idx in src_indices): + continue + + srcs = tuple(rep[src_idx] for src_idx in src_indices) + + # Try all possible pairings of source items and add the + # corresponding parent items. This is Comp_a from the paper. + parent = set(self.items[op, item_srcs] for item_srcs in + itertools.product(*srcs) if (op, item_srcs) in self.items) + + # We could always start matching something else with a + # wildcard. This is Cl from the paper. 
+ parent.add(self.wildcard) + + table[src_indices] = self.states.add(frozenset(parent)) + worklist_indices[op] = len(rep) + new_opcodes.clear() + process_new_states() + _algebraic_pass_template = mako.template.Template(""" #include "nir.h" #include "nir_builder.h" #include "nir_search.h" #include "nir_search_helpers.h" -#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS -#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS - -struct transform { - const nir_search_expression *search; - const nir_search_value *replace; - unsigned condition_offset; -}; - -#endif +/* What follows is NIR algebraic transform code for the following ${len(xforms)} + * transforms: +% for xform in xforms: + * ${xform.search} => ${xform.replace} +% endfor + */ +<% cache = {} %> % for xform in xforms: - ${xform.search.render()} - ${xform.replace.render()} + ${xform.search.render(cache)} + ${xform.replace.render(cache)} % endfor -% for (opcode, xform_list) in sorted(opcode_xforms.items()): -static const struct transform ${pass_name}_${opcode}_xforms[] = { -% for xform in xform_list: - { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} }, +% for state_id, state_xforms in enumerate(automaton.state_patterns): +% if state_xforms: # avoid emitting a 0-length array for MSVC +static const struct transform ${pass_name}_state${state_id}_xforms[] = { +% for i in state_xforms: + { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} }, % endfor }; +% endif % endfor -static bool -${pass_name}_block(nir_builder *build, nir_block *block, - const bool *condition_flags) -{ - bool progress = false; - - nir_foreach_instr_reverse_safe(instr, block) { - if (instr->type != nir_instr_type_alu) - continue; - - nir_alu_instr *alu = nir_instr_as_alu(instr); - if (!alu->dest.dest.is_ssa) - continue; - - switch (alu->op) { - % for opcode in sorted(opcode_xforms.keys()): - case nir_op_${opcode}: - for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) { - 
const struct transform *xform = &${pass_name}_${opcode}_xforms[i]; - if (condition_flags[xform->condition_offset] && - nir_replace_instr(build, alu, xform->search, xform->replace)) { - progress = true; - break; - } - } - break; +static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = { +% for op in automaton.opcodes: + [${get_c_opcode(op)}] = { + .filter = (uint16_t []) { + % for e in automaton.filter[op]: + ${e}, % endfor - default: - break; - } - } - - return progress; -} - -static bool -${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags) -{ - bool progress = false; - - nir_builder build; - nir_builder_init(&build, impl); - - nir_foreach_block_reverse(block, impl) { - progress |= ${pass_name}_block(&build, block, condition_flags); - } - - if (progress) { - nir_metadata_preserve(impl, nir_metadata_block_index | - nir_metadata_dominance); - } else { -#ifndef NDEBUG - impl->valid_metadata &= ~nir_metadata_not_properly_reset; -#endif - } + }, + <% + num_filtered = len(automaton.rep[op]) + %> + .num_filtered_states = ${num_filtered}, + .table = (uint16_t []) { + <% + num_srcs = len(next(iter(automaton.table[op]))) + %> + % for indices in itertools.product(range(num_filtered), repeat=num_srcs): + ${automaton.table[op][indices]}, + % endfor + }, + }, +% endfor +}; - return progress; -} +const struct transform *${pass_name}_transforms[] = { +% for i in range(len(automaton.state_patterns)): + % if automaton.state_patterns[i]: + ${pass_name}_state${i}_xforms, + % else: + NULL, + % endif +% endfor +}; +const uint16_t ${pass_name}_transform_counts[] = { +% for i in range(len(automaton.state_patterns)): + % if automaton.state_patterns[i]: + (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms), + % else: + 0, + % endif +% endfor +}; bool ${pass_name}(nir_shader *shader) @@ -750,21 +1128,28 @@ ${pass_name}(nir_shader *shader) bool progress = false; bool condition_flags[${len(condition_list)}]; const nir_shader_compiler_options *options = 
shader->options; + const shader_info *info = &shader->info; (void) options; + (void) info; % for index, condition in enumerate(condition_list): condition_flags[${index}] = ${condition}; % endfor nir_foreach_function(function, shader) { - if (function->impl) - progress |= ${pass_name}_impl(function->impl, condition_flags); + if (function->impl) { + progress |= nir_algebraic_impl(function->impl, condition_flags, + ${pass_name}_transforms, + ${pass_name}_transform_counts, + ${pass_name}_table); + } } return progress; } """) + class AlgebraicPass(object): def __init__(self, pass_name, transforms): self.xforms = [] @@ -794,6 +1179,34 @@ class AlgebraicPass(object): else: self.opcode_xforms[xform.search.opcode].append(xform) + # Check to make sure the search pattern does not unexpectedly contain + # more commutative expressions than match_expression (nir_search.c) + # can handle. + comm_exprs = xform.search.comm_exprs + + if xform.search.many_commutative_expressions: + if comm_exprs <= nir_search_max_comm_ops: + print("Transform expected to have too many commutative " \ + "expression but did not " \ + "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops), + file=sys.stderr) + print(" " + str(xform), file=sys.stderr) + traceback.print_exc(file=sys.stderr) + print('', file=sys.stderr) + error = True + else: + if comm_exprs > nir_search_max_comm_ops: + print("Transformation with too many commutative expressions " \ + "({} > {}).
Modify pattern or annotate with " \ + "\"many-comm-expr\".".format(comm_exprs, + nir_search_max_comm_ops), + file=sys.stderr) + print(" " + str(xform.search), file=sys.stderr) + print("{}".format(xform.search.cond), file=sys.stderr) + error = True + + self.automaton = TreeAutomaton(self.xforms) + if error: sys.exit(1) @@ -802,4 +1215,7 @@ class AlgebraicPass(object): return _algebraic_pass_template.render(pass_name=self.pass_name, xforms=self.xforms, opcode_xforms=self.opcode_xforms, - condition_list=condition_list) + condition_list=condition_list, + automaton=self.automaton, + get_c_opcode=get_c_opcode, + itertools=itertools)