nir/algebraic: Rewrite bit-size inference
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import OrderedDict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Text and integer types differ between Python 2 and 3; the rest of this file
# uses these aliases so it runs under either interpreter.
if sys.version_info >= (3, 0):
    integer_types = (int, )
    string_type = str
else:
    integer_types = (int, long)  # noqa: F821 -- only evaluated on Python 2
    string_type = unicode  # noqa: F821 -- only evaluated on Python 2
45
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size spelled out in a NIR type name.

    ``type_str`` is a name such as ``"float32"`` or ``"uint"``.  When the name
    carries no trailing digits the type is unsized and 0 is returned.  A name
    that does not start with int/uint/bool/float trips the assertion.
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return 0 if bits is None else int(bits)
56
class VarSet(object):
    """Hands out a unique integer id for each distinct variable name.

    Ids are allocated in first-seen order starting at 0.  After lock() has
    been called, asking for a name that was never seen before trips an
    assertion instead of allocating a new id.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        if name in self.names:
            return self.names[name]

        assert not self.immutable, "Unknown replacement variable: " + name
        new_id = next(self.ids)
        self.names[name] = new_id
        return new_id

    def lock(self):
        """Stop allocating ids for names that have not been seen yet."""
        self.immutable = True
73
class Value(object):
    """Base class for every node of a search or replace expression.

    Besides rendering itself as a C ``nir_search_*`` initializer, each value
    takes part in bit-size inference through ``_bit_size``, which subclasses
    set in their constructors.  ``_bit_size`` is None when the value is the
    representative of its own bit-size class, an int once a concrete bit size
    is known, or another Value it has been unified with.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass that matches the Python object ``val``.

        Bytes are first decoded as UTF-8.  Tuples become Expressions, strings
        become Variables, and bool/int/float literals become Constants; an
        already-built Expression is returned unchanged.

        Raises TypeError for any other kind of object, so a malformed
        transform is reported where it is parsed rather than much later.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

        raise TypeError(
            "Cannot build a nir_search value from {0!r}".format(val))

    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, val, name, type_str):
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class. Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable. We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression. This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Path compression: point straight at the representative next time.
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() returned
        before this call, or just ``other`` if it is a concrete bit-size. This
        is the "union" part of the union-find algorithm.

        Only a class representative can be re-pointed; trying to change a
        class whose bit size is already a concrete int to something else
        trips an assertion.
        """
        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        assert not isinstance(self_bit_size, int), \
            'Cannot change concrete bit size {0} to {1}'.format(
                self_bit_size, other_bit_size)
        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        return "&{0}.value".format(self.name)

    @property
    def c_bit_size(self):
        """Bit size as encoded in the generated C.

        A concrete size is emitted as-is; a class represented by a variable is
        emitted as ``-(index + 1)``; anything else is emitted as 0.
        """
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            # case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    def render(self):
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
181
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal constant in a search or replace expression.

    ``val`` is either a Python bool/int/float, or a string of the form
    ``"<literal>"`` or ``"<literal>@<bits>"``, where ``<literal>`` is parsed
    with ast.literal_eval.  Boolean constants are always 32 bits wide.
    """

    def __init__(self, val, name):
        # Value.__init__ already records str(val) as self.in_val.
        Value.__init__(self, val, name, "constant")

        # Accept the native str type as well as the text type used by the
        # rest of this file (unicode on Python 2), so decoded strings parse.
        if isinstance(val, (str, string_type)):
            m = _constant_re.match(val)
            assert m is not None, "Invalid constant: " + val
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

    def hex(self):
        """Return the C spelling of the constant's bit pattern.

        Booleans map to NIR_TRUE/NIR_FALSE, integers to hex(), and floats to
        the hex of their IEEE double encoding.
        """
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            assert False

    def type(self):
        """Return the nir_alu_type name matching the Python type of the value."""
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
227
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named variable such as ``a``, ``#b``, ``c@32`` or ``d@float(cond)``.

    A leading ``#`` requires a constant, ``@type``/``@bits`` constrain the
    type and bit size, and a trailing parenthesised expression is a search
    condition.  Every occurrence of the same name shares one id from
    ``varset``.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        match = _var_name_re.match(val)
        assert match and match.group('name') is not None

        self.var_name = match.group('name')
        self.is_constant = match.group('const') is not None
        self.cond = match.group('cond')
        self.required_type = match.group('type')

        bits = match.group('bits')
        self._bit_size = int(bits) if bits else None

        # Booleans are always 32 bits wide.
        if self.required_type == 'bool':
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def __str__(self):
        return self.in_val

    def type(self):
        """Return the nir_alu_type name for the required type, or None."""
        return {
            'bool': "nir_type_bool",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }.get(self.required_type)
264
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An ALU expression: a tuple whose first element names the opcode.

    The opcode string may carry a ``~`` prefix (inexact), an ``@bits`` bit
    size and a parenthesised condition; the remaining tuple elements are the
    sources, each turned into a Value named ``<name_base>_<index>``.
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        match = _opcode_re.match(expr[0])
        assert match and match.group('opcode') is not None

        self.opcode = match.group('opcode')
        bits = match.group('bits')
        self._bit_size = int(bits) if bits else None
        self.inexact = match.group('inexact') is not None
        self.cond = match.group('cond')

        srcs = []
        for idx, src in enumerate(expr[1:]):
            src_name = "{0}_{1}".format(name_base, idx)
            srcs.append(Value.create(src, src_name, varset))
        self.sources = srcs

    def render(self):
        # Sources must be emitted first because this value points at them.
        rendered_srcs = [src.render() for src in self.sources]
        return "\n".join(rendered_srcs) + super(Expression, self).render()
286
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64. The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size. Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value. NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match. They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above. This is similar to type inference in many common programming
    languages. If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit. There are, however, cases where this
    inference can be ambiguous or contradictory. Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer. However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value. As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size. If we had a 64-bit b value, we would end up
    trying to AND a 32-bit value with a 64-bit value, which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code. This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values. We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises. Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything. The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression. It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression. Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # One slot per variable id: the first Variable object seen with that id.
        self._var_classes = [None] * len(varset.names)
        # Whether we are currently processing the search expression.  The
        # validation starts in search mode; validate() flips it for the
        # replacement.  Initialised here so the object is always consistent.
        self.is_search = True

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is. When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes. In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b,
        1 if a >= b (a can be specialized to b), and None if they are not
        comparable (neither a <= b nor b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b (a Value or an int).

        If the two bit-size classes are incomparable (e.g. conflicting
        physical bit-sizes), call ``error_msg`` with the current bit-sizes of
        a and b to build a message and raise an AssertionError.  In the
        replace expression, merging variables with other variables and with
        physical bit-sizes is disallowed as well.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
            error_msg(a_bit_size, b_bit_size)

        # Point the less specialized class at the more specialized one.
        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable. We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                    lambda other_bit_size, bit_size:
                        'Variable {} has conflicting bit size requirements: '
                        'it must have bit size {} and {}'.format(
                            val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate an expression by performing classic Hindley-Milner
        type inference on bitsizes. This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize. If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources. That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} '
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) '
                            'of {} may not have the same bit size when building the '
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of '
                            'nir_op_{} it must have {} bits, which may not be the '
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        if dst_type_bits == 0:
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} '
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} '
                            '(bit size of {}) may not have that bit size '
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} '
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        """Check every node of the replacement has a buildable bit size.

        Each node must resolve to a concrete bit size, a variable's bit-size
        class, or the bit-size class of the search expression root.
        """
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        """Run bit-size inference over a search/replace pair.

        Raises AssertionError (via the checks above) if the pair's bit sizes
        are contradictory or cannot be determined for the replacement.
        """
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace. Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)
577
_optimization_ids = itertools.count()

# Every distinct condition string seen so far, shared by all transforms so
# each condition gets one stable index; index 0 is the unconditional 'true'.
condition_list = ['true']


class SearchAndReplace(object):
    """A single parsed and bit-size-validated transform.

    ``transform`` is ``(search, replace)`` or ``(search, replace, condition)``;
    the condition defaults to ``'true'`` and is registered in condition_list.
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search, replace = transform[0], transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if not isinstance(search, Expression):
            search = Expression(search, "search{0}".format(self.id), varset)
        self.search = search

        # From here on the replacement may only use variables the search
        # expression introduced.
        varset.lock()

        if not isinstance(replace, Value):
            replace = Value.create(replace, "replace{0}".format(self.id), varset)
        self.replace = replace

        BitSizeValidator(varset).validate(self.search, self.replace)
611
# Mako template for the generated C pass.  AlgebraicPass.render() supplies:
#   pass_name      -- prefix for every emitted C function/array name
#   xform_dict     -- search-root opcode -> list of SearchAndReplace objects
#   condition_list -- C condition expressions; a transform is only tried when
#                     condition_flags[xform.condition_index] is true
# The text below is emitted verbatim apart from the ${...} / % directives.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
                   const bool *condition_flags)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(build, alu, xform->search, xform->replace)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   bool progress = false;

   nir_builder build;
   nir_builder_init(&build, impl);

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(&build, block, condition_flags);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
717
class AlgebraicPass(object):
    """Collect transforms and render them as a C optimization pass.

    ``transforms`` may contain SearchAndReplace objects or raw transform
    tuples, which are parsed here.  Transforms are grouped by the opcode at
    the root of their search expression, keeping their original order.
    Every transform that fails to parse is reported on stderr, and if any
    failed the process exits with status 1 after all have been reported.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Report and keep going so every bad transform shows up
                    # in one run.  Catching Exception (not a bare except)
                    # lets KeyboardInterrupt and SystemExit propagate instead
                    # of being reported as parse failures.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the C source of the generated pass as a string."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)