from nir_opcodes import opcodes, type_sizes
+# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
+nir_search_max_comm_ops = 8
+
# These opcodes are only employed by nir_search. This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
${'true' if val.is_constant else 'false'},
${val.type() or 'nir_type_invalid' },
${val.cond if val.cond else 'NULL'},
+ ${val.swizzle()},
% elif isinstance(val, Expression):
- ${'true' if val.inexact else 'false'},
+ ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
${val.comm_expr_idx}, ${val.comm_exprs},
${val.c_opcode()},
{ ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
elif isinstance(self.value, float):
return "nir_type_float"
+ def equivalent(self, other):
+ """Check that two constants are equivalent.
+
+ This check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ Returns True only when other is an instance of this object's type and
+ both hold the same value.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.value == other.value
+
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
- r"(?P<cond>\([^\)]+\))?")
+ r"(?P<cond>\([^\)]+\))?"
+ r"(?P<swiz>\.[xyzw]+)?")
class Variable(Value):
def __init__(self, val, name, varset):
# constant. If we want to support names that have numeric or
# punctuation characters, we can make the first assertion more flexible.
assert self.var_name.isalpha()
- assert self.var_name is not 'True'
- assert self.var_name is not 'False'
+ assert self.var_name != 'True'
+ assert self.var_name != 'False'
self.is_constant = m.group('const') is not None
self.cond = m.group('cond')
self.required_type = m.group('type')
self._bit_size = int(m.group('bits')) if m.group('bits') else None
+ self.swiz = m.group('swiz')
if self.required_type == 'bool':
if self._bit_size is not None:
elif self.required_type == 'float':
return "nir_type_float"
-_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
+ def equivalent(self, other):
+ """Check that two variables are equivalent.
+
+ This check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ Returns True only when other is an instance of this object's type and
+ both have the same index.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ return self.index == other.index
+
+ def swizzle(self):
+ """Return the swizzle as a brace-enclosed list of component indices.
+
+ Each letter after the leading '.' of self.swiz is mapped x, y, z, w to
+ 0, 1, 2, 3 and the indices are joined with ', '. When no swizzle was
+ given, the identity list '{0, 1, 2, 3}' is returned.
+
+ """
+ if self.swiz is not None:
+ swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w': 3}
+ return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
+ return '{0, 1, 2, 3}'
+
+_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
r"(?P<cond>\([^\)]+\))?")
class Expression(Value):
self.opcode = m.group('opcode')
self._bit_size = int(m.group('bits')) if m.group('bits') else None
self.inexact = m.group('inexact') is not None
+ self.exact = m.group('exact') is not None
self.cond = m.group('cond')
+
+ assert not self.inexact or not self.exact, \
+ 'Expression cannot be both exact and inexact.'
+
+ # "many-comm-expr" isn't really a condition. It's a notification to the
+ # generator that this pattern is known to have too many commutative
+ # expressions, and an error should not be generated for this case.
+ self.many_commutative_expressions = False
+ if self.cond and self.cond.find("many-comm-expr") >= 0:
+ # Split the condition into a comma-separated list. Remove
+ # "many-comm-expr". If there is anything left, put it back together.
+ c = self.cond[1:-1].split(",")
+ c.remove("many-comm-expr")
+
+ self.cond = "({})".format(",".join(c)) if c else None
+ self.many_commutative_expressions = True
+
self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
for (i, src) in enumerate(expr[1:]) ]
self.__index_comm_exprs(0)
+ def equivalent(self, other):
+ """Check that two expressions are equivalent.
+
+ This check is much weaker than equality. One generally cannot be
+ used in place of the other. Using this implementation for the __eq__
+ will break BitSizeValidator.
+
+ This implementation does not check for equivalence due to commutativity,
+ but it could.
+
+ Returns True only when other is an instance of this object's type, has
+ the same number of sources and the same opcode, and every pair of
+ corresponding sources is itself equivalent.
+
+ """
+ if not isinstance(other, type(self)):
+ return False
+
+ if len(self.sources) != len(other.sources):
+ return False
+
+ if self.opcode != other.opcode:
+ return False
+
+ return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))
+
def __index_comm_exprs(self, base_idx):
"""Recursively count and index commutative expressions
"""
self.comm_exprs = 0
+
+ # A note about the explicit "len(self.sources)" check. The list of
+ # sources comes from user input, and that input might be bad. Check
+ # that the expected second source exists before accessing it. Without
+ # this check, a unit test that does "('iadd', 'a')" will crash.
if self.opcode not in conv_opcode_types and \
- "2src_commutative" in opcodes[self.opcode].algebraic_properties:
+ "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
+ len(self.sources) >= 2 and \
+ not self.sources[0].equivalent(self.sources[1]):
self.comm_expr_idx = base_idx
self.comm_exprs += 1
else:
#include "nir_search.h"
#include "nir_search_helpers.h"
-#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
-
-struct transform {
- const nir_search_expression *search;
- const nir_search_value *replace;
- unsigned condition_offset;
-};
-
-struct per_op_table {
- const uint16_t *filter;
- unsigned num_filtered_states;
- const uint16_t *table;
-};
-
-/* Note: these must match the start states created in
- * TreeAutomaton._build_table()
+/* What follows is NIR algebraic transform code for the following ${len(xforms)}
+ * transforms:
+% for xform in xforms:
+ * ${xform.search} => ${xform.replace}
+% endfor
*/
-/* WILDCARD_STATE = 0 is set by zeroing the state array */
-static const uint16_t CONST_STATE = 1;
-
-#endif
-
<% cache = {} %>
% for xform in xforms:
${xform.search.render(cache)}
% endfor
};
-static void
-${pass_name}_pre_block(nir_block *block, uint16_t *states)
-{
- nir_foreach_instr(instr, block) {
- switch (instr->type) {
- case nir_instr_type_alu: {
- nir_alu_instr *alu = nir_instr_as_alu(instr);
- nir_op op = alu->op;
- uint16_t search_op = nir_search_op_for_nir_op(op);
- const struct per_op_table *tbl = &${pass_name}_table[search_op];
- if (tbl->num_filtered_states == 0)
- continue;
-
- /* Calculate the index into the transition table. Note the index
- * calculated must match the iteration order of Python's
- * itertools.product(), which was used to emit the transition
- * table.
- */
- uint16_t index = 0;
- for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
- index *= tbl->num_filtered_states;
- index += tbl->filter[states[alu->src[i].src.ssa->index]];
- }
- states[alu->dest.dest.ssa.index] = tbl->table[index];
- break;
- }
-
- case nir_instr_type_load_const: {
- nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
- states[load_const->def.index] = CONST_STATE;
- break;
- }
-
- default:
- break;
- }
- }
-}
-
-static bool
-${pass_name}_block(nir_builder *build, nir_block *block,
- const uint16_t *states, const bool *condition_flags)
-{
- bool progress = false;
-
- nir_foreach_instr_reverse_safe(instr, block) {
- if (instr->type != nir_instr_type_alu)
- continue;
-
- nir_alu_instr *alu = nir_instr_as_alu(instr);
- if (!alu->dest.dest.is_ssa)
- continue;
-
- switch (states[alu->dest.dest.ssa.index]) {
+const struct transform *${pass_name}_transforms[] = {
% for i in range(len(automaton.state_patterns)):
- case ${i}:
- % if automaton.state_patterns[i]:
- for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
- const struct transform *xform = &${pass_name}_state${i}_xforms[i];
- if (condition_flags[xform->condition_offset] &&
- nir_replace_instr(build, alu, xform->search, xform->replace)) {
- progress = true;
- break;
- }
- }
- % endif
- break;
+ % if automaton.state_patterns[i]:
+ ${pass_name}_state${i}_xforms,
+ % else:
+ NULL,
+ % endif
% endfor
- default: assert(0);
- }
- }
-
- return progress;
-}
-
-static bool
-${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
-{
- bool progress = false;
-
- nir_builder build;
- nir_builder_init(&build, impl);
-
- /* Note: it's important here that we're allocating a zeroed array, since
- * state 0 is the default state, which means we don't have to visit
- * anything other than constants and ALU instructions.
- */
- uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
-
- nir_foreach_block(block, impl) {
- ${pass_name}_pre_block(block, states);
- }
-
- nir_foreach_block_reverse(block, impl) {
- progress |= ${pass_name}_block(&build, block, states, condition_flags);
- }
-
- free(states);
-
- if (progress) {
- nir_metadata_preserve(impl, nir_metadata_block_index |
- nir_metadata_dominance);
- } else {
-#ifndef NDEBUG
- impl->valid_metadata &= ~nir_metadata_not_properly_reset;
-#endif
- }
-
- return progress;
-}
+};
+const uint16_t ${pass_name}_transform_counts[] = {
+% for i in range(len(automaton.state_patterns)):
+ % if automaton.state_patterns[i]:
+ (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
+ % else:
+ 0,
+ % endif
+% endfor
+};
bool
${pass_name}(nir_shader *shader)
% endfor
nir_foreach_function(function, shader) {
- if (function->impl)
- progress |= ${pass_name}_impl(function->impl, condition_flags);
+ if (function->impl) {
+ progress |= nir_algebraic_impl(function->impl, condition_flags,
+ ${pass_name}_transforms,
+ ${pass_name}_transform_counts,
+ ${pass_name}_table);
+ }
}
return progress;
}
""")
-
class AlgebraicPass(object):
def __init__(self, pass_name, transforms):
else:
self.opcode_xforms[xform.search.opcode].append(xform)
+ # Check to make sure the search pattern does not unexpectedly contain
+ # more commutative expressions than match_expression (nir_search.c)
+ # can handle.
+ comm_exprs = xform.search.comm_exprs
+
+ if xform.search.many_commutative_expressions:
+ # The pattern was annotated with "many-comm-expr"; it is an error for
+ # it to fit within the limit after all.
+ if comm_exprs <= nir_search_max_comm_ops:
+ print("Transform expected to have too many commutative " \
+ "expression but did not " \
+ "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
+ file=sys.stderr)
+ print(" " + str(xform), file=sys.stderr)
+ # NOTE(review): print_exc() reports the exception currently being
+ # handled; confirm this branch actually runs inside an except handler.
+ traceback.print_exc(file=sys.stderr)
+ print('', file=sys.stderr)
+ error = True
+ else:
+ # Unannotated patterns must stay within the limit.
+ if comm_exprs > nir_search_max_comm_ops:
+ print("Transformation with too many commutative expressions " \
+ "({} > {}). Modify pattern or annotate with " \
+ "\"many-comm-expr\".".format(comm_exprs,
+ nir_search_max_comm_ops),
+ file=sys.stderr)
+ print(" " + str(xform.search), file=sys.stderr)
+ print("{}".format(xform.search.cond), file=sys.stderr)
+ error = True
+
self.automaton = TreeAutomaton(self.xforms)
if error: