From: Ian Romanick Date: Mon, 24 Jun 2019 22:12:56 +0000 (-0700) Subject: nir/algebraic: Fail build when too many commutative expressions are used X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=8d6b35fffd70a967b143dbe670c65d25348b8f1b;p=mesa.git nir/algebraic: Fail build when too many commutative expressions are used Search patterns that are expected to have too many (e.g., the giant bitfield_reverse pattern) can be added to a white list. This would have saved me a few hours debugging. :( v2: Implement the expected-failure annotation as a property of the search-replace pattern instead of as a property of the whole list of patterns. Suggested by Connor. Reviewed-by: Connor Abbott Reviewed-by: Dylan Baker --- diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index 4600b812a0c..d15d4ba3d67 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -35,6 +35,9 @@ import traceback from nir_opcodes import opcodes, type_sizes +# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c +nir_search_max_comm_ops = 4 + # These opcodes are only employed by nir_search. This provides a mapping from # opcode to destination type. conv_opcode_types = { @@ -325,6 +328,20 @@ class Expression(Value): self._bit_size = int(m.group('bits')) if m.group('bits') else None self.inexact = m.group('inexact') is not None self.cond = m.group('cond') + + # "many-comm-expr" isn't really a condition. It's notification to the + # generator that this pattern is known to have too many commutative + # expressions, and an error should not be generated for this case. + self.many_commutative_expressions = False + if self.cond and self.cond.find("many-comm-expr") >= 0: + # Split the condition into a comma-separated list. Remove + # "many-comm-expr". If there is anything left, put it back together. + c = self.cond[1:-1].split(",") + c.remove("many-comm-expr") + + self.cond = "({})".format(",".join(c)) if c else None + self.many_commutative_expressions = True + self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset) for (i, src) in enumerate(expr[1:]) ] @@ -1191,6 +1208,32 @@ class AlgebraicPass(object): else: self.opcode_xforms[xform.search.opcode].append(xform) + # Check to make sure the search pattern does not unexpectedly contain + # more commutative expressions than match_expression (nir_search.c) + # can handle. + comm_exprs = xform.search.comm_exprs + + if xform.search.many_commutative_expressions: + if comm_exprs <= nir_search_max_comm_ops: + print("Transform expected to have too many commutative " \ + "expression but did not " \ + "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), + file=sys.stderr) + print(" " + str(xform), file=sys.stderr) + traceback.print_exc(file=sys.stderr) + print('', file=sys.stderr) + error = True + else: + if comm_exprs > nir_search_max_comm_ops: + print("Transformation with too many commutative expressions " \ + "({} > {}). Modify pattern or annotate with " \ + "\"many-comm-expr\".".format(comm_exprs, + nir_search_max_comm_ops), + file=sys.stderr) + print(" " + str(xform.search), file=sys.stderr) + print("{}".format(xform.search.cond), file=sys.stderr) + error = True + self.automaton = TreeAutomaton(self.xforms) if error: diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 4fd9a93a22a..a8f0a83741b 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -66,6 +66,12 @@ e = 'e' # should only match that particular bit-size. In the replace half of the # expression this indicates that the constructed value should have that # bit-size. +# +# A special condition "many-comm-expr" can be used with expressions to note +# that the expression and its subexpressions have more commutative expressions +# than nir_replace_instr can handle. If this special condition is needed with +# another condition, the two can be separated by a comma (e.g., +# "(many-comm-expr,is_used_once)"). optimizations = [ @@ -1056,7 +1062,7 @@ def bitfield_reverse(u): step2 = ('ior', ('ishl', ('iand', step1, 0x00ff00ff), 8), ('ushr', ('iand', step1, 0xff00ff00), 8)) step3 = ('ior', ('ishl', ('iand', step2, 0x0f0f0f0f), 4), ('ushr', ('iand', step2, 0xf0f0f0f0), 4)) step4 = ('ior', ('ishl', ('iand', step3, 0x33333333), 2), ('ushr', ('iand', step3, 0xcccccccc), 2)) - step5 = ('ior', ('ishl', ('iand', step4, 0x55555555), 1), ('ushr', ('iand', step4, 0xaaaaaaaa), 1)) + step5 = ('ior(many-comm-expr)', ('ishl', ('iand', step4, 0x55555555), 1), ('ushr', ('iand', step4, 0xaaaaaaaa), 1)) return step5 diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index b8bedaa2013..2179ca0a311 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -30,6 +30,7 @@ #include "nir_builder.h" #include "util/half_float.h" +/* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */ #define NIR_SEARCH_MAX_COMM_OPS 4 struct match_state {