nir/algebraic: Add support for unsized conversion opcodes
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
# Unsized conversion opcodes that exist only inside nir_search.  Each one is
# mapped to the base type of the value it produces; AlgebraicPass expands
# them into every sized variant of that destination type.
conv_opcode_types = dict(
   i2f='float',
   u2f='float',
   f2f='float',
   f2u='uint',
   f2i='int',
   u2u='uint',
   i2i='int',
)

# Python 2/3 compatibility aliases for "any integer" and "text string".
if sys.version_info >= (3, 0):
   integer_types = (int, )
   string_type = str
else:
   integer_types = (int, long)
   string_type = unicode
57
# Splits a NIR type name such as "float32" or "uint" into its base type and
# an optional explicit bit size.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the explicit bit size encoded in a NIR type name.

   Sized names such as "int32" yield their numeric suffix; unsized names
   such as "float" yield 0.  The name must start with int, uint, bool or
   float, otherwise an AssertionError is raised.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return int(bits) if bits is not None else 0
68
class VarSet(object):
   """A set of pattern variables, each mapped to a unique integer id.

   Ids are handed out in first-lookup order starting at 0.  Once lock() has
   been called, looking up a name that was never seen is an error (this is
   how replacement patterns are prevented from inventing new variables).
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      if name in self.names:
         return self.names[name]

      assert not self.immutable, "Unknown replacement variable: " + name
      new_id = next(self.ids)
      self.names[name] = new_id
      return new_id

   def lock(self):
      """Forbid any further names from being added."""
      self.immutable = True
85
class Value(object):
   """Base class for every node of a search or replacement pattern.

   Subclasses (Constant, Variable, Expression) render themselves as static
   nir_search_* C initializers via render().  Every value also carries a
   bit-size class in self._bit_size, which BitSizeValidator merges using the
   union-find operations get_bit_size()/set_bit_size().
   """

   @staticmethod
   def create(val, name_base, varset):
      """Wrap a raw pattern element in the matching Value subclass.

      Byte strings are first decoded as UTF-8.  Tuples become Expressions,
      text strings become Variables and bool/int/float literals become
      Constants; an existing Expression is returned unchanged.

      Raises TypeError for any other input.  Previously such input made this
      function silently return None, which only failed later, far from the
      offending pattern, with a much less helpful error.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)

      raise TypeError('Cannot use {!r} (of type {}) as a value in an '
                      'algebraic pattern'.format(val, type(val).__name__))

   # Static C initializer for one value.  The nested "{ type_enum, bit size }"
   # initializes the embedded generic value (the member c_ptr points at); the
   # remaining members depend on the concrete subclass.
   __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   ${val.c_opcode()},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def __init__(self, val, name, type_str):
      # Source text of the value; returned by __str__ for error messages.
      self.in_val = str(val)
      # C identifier of the generated static struct.
      self.name = name
      # "constant", "variable" or "expression": selects the C struct type
      # (c_type) and the value-kind enumerant (type_enum).
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: point straight at the representative next time.
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This
      is the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      """The nir_search_value_* enumerant naming this value's kind."""
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      """The C struct type used for this value's initializer."""
      return "nir_search_" + self.type_str

   @property
   def c_ptr(self):
      """C expression for a pointer to this value's generic header."""
      return "&{0}.value".format(self.name)

   @property
   def c_bit_size(self):
      """Bit size as emitted into the C initializer.

      A positive number is a physical bit size, -(index + 1) means "same bit
      size as the variable with this index", and 0 means no constraint.
      """
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         # case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   def render(self):
      """Render this value's static C initializer."""
      return self.__template.render(val=self,
                                    Constant=Constant,
                                    Variable=Variable,
                                    Expression=Expression)
193
# String form of a constant: a Python literal optionally followed by "@bits".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal bool, integer or float appearing in a pattern."""

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if not isinstance(val, str):
         self.value = val
         self._bit_size = None
      else:
         match = _constant_re.match(val)
         self.value = ast.literal_eval(match.group('value'))
         bits = match.group('bits')
         self._bit_size = int(bits) if bits else None

      # Boolean constants are pinned to 32 bits; any other explicit size is
      # rejected.
      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 32
         self._bit_size = 32

   def hex(self):
      """Return the C initializer text for this constant's bit pattern."""
      value = self.value
      if isinstance(value, bool):
         return 'NIR_TRUE' if value else 'NIR_FALSE'
      if isinstance(value, integer_types):
         return hex(value)
      if isinstance(value, float):
         raw = struct.unpack('Q', struct.pack('d', value))[0]
         text = hex(raw)

         # Python 2 appends an 'L' to long literals but Python 3 does not;
         # add it explicitly so the generated file is identical either way.
         if text[-1] != 'L' and raw > sys.maxsize:
            text += 'L'

         return text
      assert False

   def type(self):
      """Return the nir_type_* enumerant matching the constant's Python type."""
      if isinstance(self.value, bool):
         return "nir_type_bool"
      if isinstance(self.value, integer_types):
         return "nir_type_int"
      if isinstance(self.value, float):
         return "nir_type_float"
      return None
238
# String form of a variable: an optional "#" (is_constant), the name, an
# optional "@" followed by a type and/or bit size, and an optional
# parenthesized condition that is kept verbatim.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
   """A named placeholder that binds to a value in the matched expression."""

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      match = _var_name_re.match(val)
      assert match and match.group('name') is not None

      self.var_name = match.group('name')
      self.is_constant = match.group('const') is not None
      self.cond = match.group('cond')
      self.required_type = match.group('type')
      bits = match.group('bits')
      self._bit_size = int(bits) if bits else None

      # Boolean variables are pinned to 32 bits.
      if self.required_type == 'bool':
         assert self._bit_size is None or self._bit_size == 32
         self._bit_size = 32

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      # Every occurrence of the same name shares one id from the varset.
      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_type_* enumerant required of this variable, or None
      if no type was given (int and uint both map to nir_type_int)."""
      return {
         'bool': "nir_type_bool",
         'int': "nir_type_int",
         'uint': "nir_type_int",
         'float': "nir_type_float",
      }.get(self.required_type)
272
# String form of an opcode: an optional "~" (inexact), the opcode name, an
# optional "@bits" and an optional parenthesized condition.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An opcode applied to a list of source values.

   expr is a tuple whose first element is the opcode string and whose
   remaining elements are the sources (each passed through Value.create).
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      match = _opcode_re.match(expr[0])
      assert match and match.group('opcode') is not None

      self.opcode = match.group('opcode')
      bits = match.group('bits')
      self._bit_size = int(bits) if bits else None
      self.inexact = match.group('inexact') is not None
      self.cond = match.group('cond')

      self.sources = []
      for idx, raw_src in enumerate(expr[1:]):
         src_name = "{0}_{1}".format(name_base, idx)
         self.sources.append(Value.create(raw_src, src_name, varset))

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
            'Expression cannot use an unsized conversion opcode with ' \
            'an explicit size; that\'s silly.'

   def c_opcode(self):
      """C enumerant for the opcode; unsized conversions use the
      nir_search_op_* namespace instead of nir_op_*."""
      if self.opcode not in conv_opcode_types:
         return 'nir_op_' + self.opcode
      return 'nir_search_op_' + self.opcode

   def render(self):
      """Render all source initializers first, then this expression's."""
      rendered_sources = [src.render() for src in self.sources]
      return "\n".join(rendered_sources) + super(Expression, self).render()
306
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64. The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size. Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value. NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match. They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above. This is similar to type inference in many common programming
   languages. If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit. There are, however, cases where this
   inference can be ambiguous or contradictory. Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer. However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value. As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size. If we had a 64-bit b value, we would end up
   trying to 'iand' a 32-bit value with a 64-bit value, which would be
   invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code. This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything. The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression. It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      # One slot per variable id: the first occurrence of each variable seen
      # by merge_variables(), which all later occurrences are unified with.
      self._var_classes = [None] * len(varset.names)
      # validate() toggles this between the search and the replace pass.
      # Initialise it so the comparison/unification helpers apply the
      # (permissive) search rules instead of raising AttributeError when
      # they are used before validate() has run.
      self.is_search = True

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
      and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both have been
      assigned conflicting physical bit-sizes, call "error_msg" with the
      bit-sizes of a and b to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # The less specialized class is made to point at the more specialized
      # one; a concrete (int) bit size is never re-pointed.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      # Then tie the destination either to the unsized sources or to the
      # fixed size dictated by the opcode's output type.
      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      """Check that every node of the replacement has a bit size that is
      known when building it: a fixed size, a variable's size, or the size
      of the search expression."""
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
             bit_size == search.get_bit_size(), \
         'Ambiguous bit size for replacement value {}: ' \
         'it cannot be deduced from a variable, a fixed bit size ' \
         'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      """Run full bit-size inference and validation on one transform."""
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
            search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
608
# Source of unique ids; each transform's id names its generated C structs.
_optimization_ids = itertools.count()

# Distinct C condition expressions guarding the transforms.  Each transform
# stores the index of its condition; the generated pass evaluates every entry
# into its condition_flags array.
condition_list = ['true']

class SearchAndReplace(object):
   """One rewrite rule built from a (search, replace[, condition]) tuple.

   The search pattern must be an expression.  The replacement may only use
   variables that the search pattern binds.  The bit sizes of both sides are
   validated on construction.
   """

   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      self.search = (search if isinstance(search, Expression) else
                     Expression(search, "search{0}".format(self.id), varset))

      # From here on, looking up an unknown variable name is an error.
      varset.lock()

      self.replace = (replace if isinstance(replace, Value) else
                      Value.create(replace, "replace{0}".format(self.id),
                                   varset))

      BitSizeValidator(varset).validate(self.search, self.replace)
642
# Mako template for the generated C pass.  Rendered by AlgebraicPass.render()
# with:
#   pass_name      - name of the generated entry point
#   xforms         - every SearchAndReplace, whose value structs are emitted
#   opcode_xforms  - opcode name -> list of transforms rooted at that opcode
#   condition_list - C expressions evaluated into condition_flags[]
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for xform in xforms:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

% for (opcode, xform_list) in sorted(opcode_xforms.items()):
static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
                   const bool *condition_flags)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in sorted(opcode_xforms.keys()):
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(build, alu, xform->search, xform->replace)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   bool progress = false;

   nir_builder build;
   nir_builder_init(&build, impl);

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(&build, block, condition_flags);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
748
class AlgebraicPass(object):
   """A named collection of transforms that renders to a C optimization pass.

   pass_name is the name of the generated C entry point.  transforms is an
   iterable whose items are SearchAndReplace objects or raw tuples accepted
   by SearchAndReplace.  Every transform that fails to parse is reported on
   stderr, and the process then exits with status 1.
   """

   def __init__(self, pass_name, transforms):
      self.xforms = []
      # Concrete opcode name -> transforms whose search expression is rooted
      # at that opcode.
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except Exception:
               # Report the failure but keep going so that all broken
               # transforms are listed in a single run.  Catching Exception
               # rather than everything lets Ctrl-C and sys.exit() through.
               print("Failed to parse transformation:", file=sys.stderr)
               print(" " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # An unsized conversion opcode is registered under every sized
            # variant of its destination type reported by type_sizes().
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

      if error:
         sys.exit(1)


   def render(self):
      """Return the complete C source of the generated pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list)