nir/algebraic: Refactor codegen a bit
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Python 2/3 compatibility aliases, used in isinstance() checks elsewhere in
# this module.
if sys.version_info >= (3, 0):
    # Python 3: a single integer type, and text is plain ``str``.
    integer_types = (int, )
    string_type = str
else:
    # Python 2: ``long`` is a separate integer type and text is ``unicode``.
    integer_types = (int, long)
    string_type = unicode
45
# Splits a type string into an optional base type (int/uint/bool/float) and an
# optional trailing bit count.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size named by a type string, or 0 if it has none.

    The string must begin with one of the base types recognised by
    ``_type_re``; otherwise the assertion below fails.
    """
    parsed = _type_re.match(type_str)
    assert parsed.group('type')

    bits = parsed.group('bits')
    return 0 if bits is None else int(bits)
56
57 # Represents a set of variables, each with a unique id
class VarSet(object):
    """Maps variable names to unique integer ids.

    An id is allocated the first time a name is looked up, counting up from 0.
    Once lock() has been called, looking up a name that has no id yet fails
    an assertion instead of allocating one.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        known = self.names
        if name in known:
            return known[name]

        # First sighting of this name: only allowed while the set is unlocked.
        assert not self.immutable, "Unknown replacement variable: " + name
        new_id = next(self.ids)
        known[name] = new_id
        return new_id

    def lock(self):
        """Forbid the allocation of ids for any further names."""
        self.immutable = True
73
class Value(object):
    """Common base of the pattern nodes: Constant, Variable and Expression.

    Besides holding the name and kind of the node, this class implements the
    union-find bookkeeping for bit-size classes (``_bit_size``, which the
    subclasses initialise) and renders the node through a mako template.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass matching the Python type of ``val``.

        Tuples become Expressions, strings become Variables, and bools, floats
        and integers become Constants. An existing Expression is returned
        as-is. Any other type yields None.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        if isinstance(val, Expression):
            return val
        if isinstance(val, string_type):
            return Variable(val, name_base, varset)
        if isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)
        return None

    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
{ ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
${val.index}, /* ${val.var_name} */
${'true' if val.is_constant else 'false'},
${val.type() or 'nir_type_invalid' },
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
nir_op_${val.opcode},
{ ${', '.join(src.c_ptr for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, val, name, type_str):
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Return the representative of this value's bit-size class.

        The result is either the physical bit size (an int) chosen for this
        value or, if there is none yet, the canonical Value that currently
        stands for the class. Variables are preferred as canonical values,
        because the replacement expression needs to know which variable each
        value is tied to. This is the "find" half of union-find.
        """
        root = self
        while isinstance(root, Value) and root._bit_size is not None:
            root = root._bit_size

        # Path compression: remember the representative for next time.
        if root is not self:
            self._bit_size = root
        return root

    def set_bit_size(self, other):
        """Point this value's bit-size class at that of ``other``.

        Afterwards self.get_bit_size() returns what other.get_bit_size()
        returned beforehand, or ``other`` itself when it is a concrete
        (int) bit size. This is the "union" half of union-find.
        """
        mine = self.get_bit_size()
        if isinstance(other, int):
            theirs = other
        else:
            theirs = other.get_bit_size()

        if mine == theirs:
            return

        mine._bit_size = theirs

    @property
    def type_enum(self):
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        return "&{}.value".format(self.name)

    @property
    def c_bit_size(self):
        resolved = self.get_bit_size()
        if isinstance(resolved, int):
            return resolved
        if isinstance(resolved, Variable):
            # Variables are encoded as negative numbers derived from their index.
            return -resolved.index - 1

        # If the bit-size class is neither a variable, nor an actual bit-size, then
        # - If it's in the search expression, we don't need to check anything
        # - If it's in the replace expression, either it's ambiguous (in which
        #   case we'd reject it), or it equals the bit-size of the search value
        # We represent these cases with a 0 bit-size.
        return 0

    def render(self):
        """Render this value through the class template and return the text."""
        context = dict(val=self,
                       Constant=Constant,
                       Variable=Variable,
                       Expression=Expression)
        return self.__template.render(**context)
181
# Parses a constant given as a string: a literal, optionally followed by
# "@<bits>".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal bool, integer or float appearing in a pattern."""

    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        self._bit_size = None
        if isinstance(val, str):
            # String form: evaluate the literal and pick up the optional size.
            parsed = _constant_re.match(val)
            self.value = ast.literal_eval(parsed.group('value'))
            bits = parsed.group('bits')
            if bits:
                self._bit_size = int(bits)
        else:
            self.value = val

        # Booleans are always given a bit size of 32.
        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

    def hex(self):
        """Return the text emitted for this constant's value.

        Booleans map to NIR_TRUE/NIR_FALSE, integers to their hex() form, and
        floats to the hex form of the 64-bit pattern of the double.
        """
        value = self.value
        if isinstance(value, bool):
            return 'NIR_TRUE' if value else 'NIR_FALSE'
        if isinstance(value, integer_types):
            return hex(value)
        if isinstance(value, float):
            as_int, = struct.unpack('Q', struct.pack('d', value))
            text = hex(as_int)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if not text.endswith('L') and as_int > sys.maxsize:
                text += 'L'

            return text
        assert False

    def type(self):
        """Return the nir_type_* name for the value, or None if unrecognised."""
        value = self.value
        if isinstance(value, bool):
            return "nir_type_bool"
        if isinstance(value, integer_types):
            return "nir_type_int"
        if isinstance(value, float):
            return "nir_type_float"
        return None
226
# Parses a variable spec: optional leading '#', the name, an optional
# "@<type><bits>" suffix and an optional parenthesised condition.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named placeholder in a pattern, registered in a VarSet."""

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        parsed = _var_name_re.match(val)
        assert parsed and parsed.group('name') is not None

        fields = parsed.groupdict()
        self.var_name = fields['name']
        self.is_constant = fields['const'] is not None
        self.cond = fields['cond']
        self.required_type = fields['type']

        bits = fields['bits']
        self._bit_size = int(bits) if bits else None

        # Booleans are always given a bit size of 32.
        if self.required_type == 'bool':
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        # Looking the name up allocates an id unless the set is locked.
        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name for the required type, or None."""
        nir_types = {
            'bool': "nir_type_bool",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }
        return nir_types.get(self.required_type)
260
# Parses an opcode spec: optional leading '~', the opcode name, an optional
# "@<bits>" suffix and an optional parenthesised condition.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An opcode applied to sources, built from a tuple (opcode, src, ...)."""

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        parsed = _opcode_re.match(expr[0])
        assert parsed and parsed.group('opcode') is not None

        self.opcode = parsed.group('opcode')
        bits = parsed.group('bits')
        self._bit_size = int(bits) if bits else None
        self.inexact = parsed.group('inexact') is not None
        self.cond = parsed.group('cond')

        # Each source is named after this expression plus its position.
        sources = []
        for position, src in enumerate(expr[1:]):
            src_name = "{0}_{1}".format(name_base, position)
            sources.append(Value.create(src, src_name, varset))
        self.sources = sources

    def render(self):
        """Render all sources first, then this expression itself."""
        rendered_sources = [src.render() for src in self.sources]
        return "\n".join(rendered_sources) + super(Expression, self).render()
282
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values.  We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises.  Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything.  The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression.  It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression.  Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # One slot per variable id; each holds the first Variable seen with
        # that id (filled in by merge_variables).
        self._var_classes = [None] * len(varset.names)

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is.  When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes.  In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
        """
        a_is_size = isinstance(a, int)
        b_is_size = isinstance(b, int)

        # Two physical sizes: equal or incomparable.
        if a_is_size and b_is_size:
            return 0 if a == b else None

        a_is_var = (not a_is_size) and isinstance(a, Variable)
        b_is_var = (not b_is_size) and isinstance(b, Variable)

        # An unsized expression/constant is the least specialized kind.
        if not (a_is_size or a_is_var):
            return 1 if (b_is_size or b_is_var) else 0
        if not (b_is_size or b_is_var):
            return -1

        # From here on each side is a physical size or a variable, and at
        # least one of them is a variable.
        if a_is_var and b_is_var:
            return 0 if self.is_search or a.index == b.index else None

        # One physical size and one variable: only mergeable while searching.
        if not self.is_search:
            return None
        return -1 if a_is_size else 1

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b.  If both
        have been assigned conflicting physical bit-sizes, call "error_msg" with
        the bit-sizes of a and b to get a message and raise an error.
        In the replace expression, disallow merging variables with other
        variables and physical bit-sizes as well.
        """
        lhs_class = a.get_bit_size()
        rhs_class = b if isinstance(b, int) else b.get_bit_size()

        order = self.compare_bitsizes(lhs_class, rhs_class)

        assert order is not None, \
            error_msg(lhs_class, rhs_class)

        # The less specialized class is made to point at the more
        # specialized one.
        if order < 0:
            rhs_class.set_bit_size(a)
        elif not isinstance(lhs_class, int):
            lhs_class.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable.  We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)
            return

        if not isinstance(val, Variable):
            return

        representative = self._var_classes[val.index]
        if representative is None:
            # First use of this variable: it represents all later uses.
            self._var_classes[val.index] = val
            return

        def conflict(other_bit_size, bit_size):
            return 'Variable {} has conflicting bit size requirements: ' \
                   'it must have bit size {} and {}'.format(
                       val.var_name, other_bit_size, bit_size)

        self.unify_bit_size(representative, val, conflict)

    def validate_value(self, val):
        """Validate an expression by performing classic Hindley-Milner
        type inference on bitsizes.  This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize.  If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        # Validate bottom-up: sources before the expression itself.
        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources.  That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)

            if src_type_bits != 0:
                # Sized source: it must have exactly the opcode's bit size.
                if self.is_search:
                    sized_text = \
                        '{} must have {} bits, but as a source of nir_op_{} '\
                        'it must have {} bits'
                else:
                    sized_text = \
                        '{} has the bit size of {}, but as a source of ' \
                        'nir_op_{} it must have {} bits, which may not be the ' \
                        'same'

                def sized_conflict(src_bit_size, unused):
                    return sized_text.format(
                        src, src_bit_size, nir_op.name, src_type_bits)

                self.unify_bit_size(src, src_type_bits, sized_conflict)
                continue

            if first_unsized_src is None:
                # Every later unsized source is unified against this one.
                first_unsized_src = src
                continue

            if self.is_search:
                def unsized_conflict(first_unsized_src_bit_size, src_bit_size):
                    return 'Source {} of {} must have bit size {}, while source {} ' \
                           'must have incompatible bit size {}'.format(
                               first_unsized_src, val, first_unsized_src_bit_size,
                               src, src_bit_size)
            else:
                def unsized_conflict(first_unsized_src_bit_size, src_bit_size):
                    return 'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                           'of {} may not have the same bit size when building the ' \
                           'replacement expression.'.format(
                               first_unsized_src, first_unsized_src_bit_size, src,
                               src_bit_size, val)

            self.unify_bit_size(first_unsized_src, src, unsized_conflict)

        if dst_type_bits != 0:
            # Sized destination: the expression has the opcode's bit size.
            def dst_conflict(dst_bit_size, unused):
                return '{} must have {} bits, but as a destination of nir_op_{} ' \
                       'it must have {} bits'.format(
                           val, dst_bit_size, nir_op.name, dst_type_bits)

            self.unify_bit_size(val, dst_type_bits, dst_conflict)
            return

        if first_unsized_src is None:
            return

        # Unsized destination: it shares the bit size of the unsized sources.
        if self.is_search:
            unsized_dst_text = \
                '{} must have the bit size of {}, while its source {} ' \
                'must have incompatible bit size {}'
        else:
            unsized_dst_text = \
                '{} must have {} bits, but its source {} ' \
                '(bit size of {}) may not have that bit size ' \
                'when building the replacement.'

        def unsized_dst_conflict(val_bit_size, src_bit_size):
            return unsized_dst_text.format(
                val, val_bit_size, first_unsized_src, src_bit_size)

        self.unify_bit_size(val, first_unsized_src, unsized_dst_conflict)

    def validate_replace(self, val, search):
        """Check that every value in the replacement has a deducible bit size.

        The bit-size class of ``val`` and of each of its sources must be a
        physical size, a variable, or the class of the search expression.
        """
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, (int, Variable)) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        """Run the full bit-size validation of one search/replace pair."""
        # Search mode: merge the uses of each variable, then infer sizes
        # within the search expression.
        self.is_search = True
        for pattern in (search, replace):
            self.merge_variables(pattern)
        self.validate_value(search)

        # Replace mode: infer sizes within the replacement.
        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace. Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)
573
# Source of the unique id given to each SearchAndReplace instance.
_optimization_ids = itertools.count()

# Every distinct condition string seen so far; a transform refers to its
# condition by index into this list. Index 0 is the always-true condition.
condition_list = ['true']


class SearchAndReplace(object):
    """One transform: a search expression, its replacement and a condition.

    ``transform`` is indexed as (search, replace[, condition]); the condition
    defaults to 'true'. The pair is bit-size validated on construction.
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search, replace = transform[0], transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        # Register the condition on first use and remember where it lives.
        try:
            self.condition_index = condition_list.index(self.condition)
        except ValueError:
            condition_list.append(self.condition)
            self.condition_index = len(condition_list) - 1

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            search_name = "search{0}".format(self.id)
            self.search = Expression(search, search_name, varset)

        # Lock the set so that the replacement cannot introduce new variables.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            replace_name = "replace{0}".format(self.id)
            self.replace = Value.create(replace, replace_name, varset)

        validator = BitSizeValidator(varset)
        validator.validate(self.search, self.replace)
607
# Mako template for the C source of one algebraic pass. It is rendered with:
#   pass_name      - prefix of every generated function and table name
#   xforms         - transforms whose search/replace values are rendered first
#   opcode_xforms  - mapping from opcode to its transforms; one table and one
#                    switch case is emitted per opcode, in sorted order
#   condition_list - condition expressions; each is evaluated once into
#                    condition_flags[] in the generated pass entry point
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
const nir_search_expression *search;
const nir_search_value *replace;
unsigned condition_offset;
};

#endif

% for xform in xforms:
${xform.search.render()}
${xform.replace.render()}
% endfor

% for (opcode, xform_list) in sorted(opcode_xforms.items()):
static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
{ &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
const bool *condition_flags)
{
bool progress = false;

nir_foreach_instr_reverse_safe(instr, block) {
if (instr->type != nir_instr_type_alu)
continue;

nir_alu_instr *alu = nir_instr_as_alu(instr);
if (!alu->dest.dest.is_ssa)
continue;

switch (alu->op) {
% for opcode in sorted(opcode_xforms.keys()):
case nir_op_${opcode}:
for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
if (condition_flags[xform->condition_offset] &&
nir_replace_instr(build, alu, xform->search, xform->replace)) {
progress = true;
break;
}
}
break;
% endfor
default:
break;
}
}

return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
bool progress = false;

nir_builder build;
nir_builder_init(&build, impl);

nir_foreach_block_reverse(block, impl) {
progress |= ${pass_name}_block(&build, block, condition_flags);
}

if (progress)
nir_metadata_preserve(impl, nir_metadata_block_index |
nir_metadata_dominance);

return progress;
}


bool
${pass_name}(nir_shader *shader)
{
bool progress = false;
bool condition_flags[${len(condition_list)}];
const nir_shader_compiler_options *options = shader->options;
(void) options;

% for index, condition in enumerate(condition_list):
condition_flags[${index}] = ${condition};
% endfor

nir_foreach_function(function, shader) {
if (function->impl)
progress |= ${pass_name}_impl(function->impl, condition_flags);
}

return progress;
}
""")
713
class AlgebraicPass(object):
    """A named collection of transforms that can be rendered to C source.

    ``transforms`` may contain SearchAndReplace objects or raw transform
    tuples; raw tuples are converted on construction. Transforms are kept both
    in order (``xforms``) and grouped by the opcode of their search
    expression (``opcode_xforms``).
    """

    def __init__(self, pass_name, transforms):
        self.xforms = []
        self.opcode_xforms = defaultdict(list)
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Report every broken transform before giving up, so that
                    # all of them can be fixed in one go. Only Exception is
                    # caught: KeyboardInterrupt and SystemExit must propagate
                    # instead of being reported as a parse failure.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xforms.append(xform)
            self.opcode_xforms[xform.search.opcode].append(xform)

        # Exit with a failure status if any transform could not be parsed.
        if error:
            sys.exit(1)

    def render(self):
        """Return the text produced by rendering the pass template."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xforms=self.xforms,
                                               opcode_xforms=self.opcode_xforms,
                                               condition_list=condition_list)