nir: Rename Boolean-related opcodes to include 32 in the name
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
# These opcodes are only employed by nir_search. This provides a mapping from
# opcode to destination type.
#
# The keys are the size-less ("generic") conversion opcode names that may be
# written in patterns; the values are the base type of the conversion's
# destination.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}
53
# Python 2/3 compatibility shim: pick the tuple of integer classes and the
# text-string class that the rest of this module uses in isinstance() checks.
if sys.version_info < (3, 0):
    # Python 2 only: these names do not exist on Python 3, but this branch is
    # never executed there.
    integer_types = (int, long)
    string_type = unicode
else:
    integer_types = (int, )
    string_type = str
61
# Splits a NIR type string such as "float32" into its base type and an
# optional bit size.  Both groups are optional in the pattern itself;
# type_bits() below insists on the base type being present.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
    """Return the bit size encoded in a NIR type string.

    type_str -- a string starting with one of "int", "uint", "bool" or
                "float", optionally followed by a decimal bit size.

    Returns the bit size as an int, or 0 when the type carries no explicit
    size (i.e. it is "unsized").

    Raises AssertionError when type_str does not start with a known base type.
    """
    m = _type_re.match(type_str)
    assert m.group('type'), "Unknown NIR type: " + str(type_str)

    bits = m.group('bits')
    return 0 if bits is None else int(bits)
72
# Represents a set of variables, each with a unique id
class VarSet(object):
    """Maps variable names to small, unique, consecutive integer ids.

    Looking up an unknown name allocates the next id (starting at 0) unless
    the set has been locked, in which case the lookup fails.
    """

    def __init__(self):
        self.names = {}                 # variable name -> id
        self.ids = itertools.count()    # source of fresh ids
        self.immutable = False          # set by lock()

    def __getitem__(self, name):
        """Return the id of `name`, allocating a new one on first use.

        Raises AssertionError if `name` is unknown and the set is locked.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the creation of any further variables."""
        self.immutable = True
89
class Value(object):
    """Base class for the nodes of a search or replace pattern.

    A Value remembers the textual form of the input it was built from, the C
    identifier (`name`) under which it is emitted, and a kind string
    (`type_str`: "constant", "variable" or "expression") from which the C type
    and enum names are derived.

    Subclasses are expected to set `self._bit_size`, which is either None, a
    concrete integer bit size, or another Value in the same bit-size
    equivalence class (see get_bit_size() / set_bit_size()).
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the appropriate Value subclass for a pattern element.

        val       -- a tuple (expression), an existing Expression, a string
                     (variable), or a bool/int/float (constant).  Byte strings
                     are decoded as UTF-8 first.
        name_base -- C identifier to give the generated value.
        varset    -- VarSet used to assign variable indices.

        Raises TypeError for any other kind of input.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

        # Falling off the end used to return None, which only blew up much
        # later with an unrelated AttributeError.  Fail here instead.
        raise TypeError("Cannot create a nir_search value from {!r} "
                        "(type {})".format(val, type(val).__name__))

    # Mako template producing the C definition of one value.  Which fields
    # are emitted depends on the concrete subclass of `val`.
    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   ${val.c_opcode()},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, val, name, type_str):
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class. Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable. We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression. This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        # Follow the _bit_size links until we reach either a concrete integer
        # or a Value that has no link of its own (the class representative).
        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Point directly at the result so that later lookups are short.
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() returned
        before calling this, or just "other" if it's a concrete bit-size. This
        is the "union" part of the union-find algorithm.
        """

        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        """Name of the C enum constant identifying this kind of value."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """Name of the C type used to declare this value."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression for a pointer to the `value` member of this value."""
        return "&{0}.value".format(self.name)

    @property
    def c_bit_size(self):
        """Bit-size field to emit in C.

        A concrete bit size is emitted as-is; a bit size tied to a variable is
        encoded as the negative number -(index) - 1; anything else is 0.
        """
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            # case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    def render(self):
        """Return the C source defining this value."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
197
# Textual constant syntax: "<value>[@<bits>]".  The value part is everything
# up to the first '@' or '('.
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
    """A literal bool, integer or float appearing in a pattern."""

    def __init__(self, val, name):
        """Create a constant.

        val  -- a Python bool/int/float, or a string of the form
                "<literal>[@<bits>]" whose literal part is parsed with
                ast.literal_eval().
        name -- C identifier for the generated value.
        """
        Value.__init__(self, val, name, "constant")

        if isinstance(val, (str)):
            m = _constant_re.match(val)
            # Mirror the checks done for variables and expressions instead of
            # failing with an AttributeError on None below.
            assert m is not None, "Malformed constant: " + val
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        # Booleans are always given a bit size of 32 here.
        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

    def hex(self):
        """Return the C initializer text for this constant's value.

        Booleans become NIR_TRUE / NIR_FALSE, integers their hex() form, and
        floats the hex form of the 64-bit pattern of the double.
        """
        # bool must be tested first: it is a subclass of int.
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Reinterpret the bytes of the double as an unsigned 64-bit integer.
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            # Raised explicitly so the check is not stripped under "python -O".
            raise AssertionError(
                "Unsupported constant type: " + type(self.value).__name__)

    def type(self):
        """Return the nir_type_* name for this constant, or None if the value
        is not a bool, integer or float."""
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
        return None
242
# Variable syntax: "[#]<name>[@[<type>][<bits>]][(<cond>)]"
#   '#'      -- sets the variable's is_constant flag
#   <type>   -- one of int, uint, bool, float
#   <bits>   -- required bit size
#   (<cond>) -- stored verbatim, parentheses included, as the `cond` attribute
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
    """A named placeholder in a pattern."""

    def __init__(self, val, name, varset):
        """Parse the variable description `val`.

        val    -- string following the syntax of _var_name_re.
        name   -- C identifier for the generated value.
        varset -- VarSet that supplies the variable's index.
        """
        Value.__init__(self, val, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None, \
            "Malformed variable: " + val

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None

        # Booleans are always given a bit size of 32 here.
        if self.required_type == 'bool':
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name required for this variable, or None when
        no type was requested."""
        if self.required_type == 'bool':
            return "nir_type_bool"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"
        return None
276
277 _opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
278 r"(?P<cond>\([^\)]+\))?")
279
# Maps the opcode names written in patterns to the opcode names that carry
# "32" in them.  Expression applies this table to every opcode it parses, so
# patterns may keep using the short spellings on the left.
opcode_remap = {
    'flt' : 'flt32',
    'fge' : 'fge32',
    'feq' : 'feq32',
    'fne' : 'fne32',
    'ilt' : 'ilt32',
    'ige' : 'ige32',
    'ieq' : 'ieq32',
    'ine' : 'ine32',
    'ult' : 'ult32',
    'uge' : 'uge32',

    'ball_iequal2' : 'b32all_iequal2',
    'ball_iequal3' : 'b32all_iequal3',
    'ball_iequal4' : 'b32all_iequal4',
    'bany_inequal2' : 'b32any_inequal2',
    'bany_inequal3' : 'b32any_inequal3',
    'bany_inequal4' : 'b32any_inequal4',
    'ball_fequal2' : 'b32all_fequal2',
    'ball_fequal3' : 'b32all_fequal3',
    'ball_fequal4' : 'b32all_fequal4',
    'bany_fnequal2' : 'b32any_fnequal2',
    'bany_fnequal3' : 'b32any_fnequal3',
    'bany_fnequal4' : 'b32any_fnequal4',

    'bcsel' : 'b32csel',
}
307
class Expression(Value):
    """An operation applied to a list of source values."""

    def __init__(self, expr, name_base, varset):
        """Build an expression from its tuple form.

        expr      -- tuple whose first element is the head string
                     "[~]<opcode>[@<bits>][(<cond>)]" and whose remaining
                     elements are the sources, each accepted by Value.create().
        name_base -- C identifier for this expression; sources are named
                     "<name_base>_<i>".
        varset    -- VarSet shared by all variables of the pattern.
        """
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        # Translate the spelling used in the pattern through opcode_remap;
        # names without an entry are kept unchanged.
        parsed_opcode = m.group('opcode')
        self.opcode = opcode_remap.get(parsed_opcode, parsed_opcode)
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                         for (i, src) in enumerate(expr[1:]) ]

        if self.opcode in conv_opcode_types:
            assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

    def c_opcode(self):
        """Return the C enum name of the opcode.

        Generic conversion opcodes (keys of conv_opcode_types) use the
        nir_search_op_ prefix; everything else uses nir_op_.
        """
        if self.opcode in conv_opcode_types:
            return 'nir_search_op_' + self.opcode
        else:
            return 'nir_op_' + self.opcode

    def render(self):
        """Return the C source for all sources followed by this expression."""
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
340
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values.  We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises.  Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything.  The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression.  It also prevents something like:

    (('b2i', ('i2b', a)), ('ine', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression.  Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        """Prepare to validate a transform whose variables live in `varset`."""
        # First Variable object seen for each variable index.
        self._var_classes = [None] * len(varset.names)
        # Whether we are currently processing the search expression.  Set
        # here so that the comparison helpers are usable before validate()
        # runs; validate() sets it explicitly for each phase.
        self.is_search = True

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is.  When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes.  In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b.  If both
        have been assigned conflicting physical bit-sizes, call "error_msg" with
        the bit-sizes of a and b to get a message and raise an error.
        In the replace expression, disallow merging variables with other
        variables and physical bit-sizes as well.

        a         -- a Value.
        b         -- a Value or a concrete integer bit size.
        error_msg -- callable taking the two bit-size classes and returning
                     the assertion message; only invoked on failure.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
            error_msg(a_bit_size, b_bit_size)

        # Make the less specialized class point at the more specialized one.
        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable.  We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                    lambda other_bit_size, bit_size:
                        'Variable {} has conflicting bit size requirements: ' \
                        'it must have bit size {} and {}'.format(
                            val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate an expression by performing classic Hindley-Milner
        type inference on bitsizes.  This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize.  If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        # Generic conversion ops are special in that they have a single unsized
        # source and an unsized destination and the two don't have to match.
        # This means there's no validation or unioning to do here besides the
        # len(val.sources) check.
        if val.opcode in conv_opcode_types:
            assert len(val.sources) == 1, \
                "Expression {} has {} sources, expected 1".format(
                    val, len(val.sources))
            self.validate_value(val.sources[0])
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources.  That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                # Unsized source: all of them must agree with the first one.
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} ' \
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                            'of {} may not have the same bit size when building the ' \
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                # Sized source: it must have exactly the opcode's bit size.
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '\
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of ' \
                            'nir_op_{} it must have {} bits, which may not be the ' \
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        if dst_type_bits == 0:
            # Unsized destination: it follows the unsized sources, if any.
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} ' \
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} ' \
                            '(bit size of {}) may not have that bit size ' \
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} ' \
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        """Check, recursively, that every value in the replacement has a bit
        size that is a concrete integer, tied to a variable, or equal to the
        bit size of the search expression."""
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        """Validate the (search, replace) pair, raising AssertionError with a
        descriptive message if its bit sizes are inconsistent or ambiguous."""
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace.  Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)
640
641 self.validate_replace(replace, search)
642
643 _optimization_ids = itertools.count()
644
645 condition_list = ['true']
646
647 class SearchAndReplace(object):
648 def __init__(self, transform):
649 self.id = next(_optimization_ids)
650
651 search = transform[0]
652 replace = transform[1]
653 if len(transform) > 2:
654 self.condition = transform[2]
655 else:
656 self.condition = 'true'
657
658 if self.condition not in condition_list:
659 condition_list.append(self.condition)
660 self.condition_index = condition_list.index(self.condition)
661
662 varset = VarSet()
663 if isinstance(search, Expression):
664 self.search = search
665 else:
666 self.search = Expression(search, "search{0}".format(self.id), varset)
667
668 varset.lock()
669
670 if isinstance(replace, Value):
671 self.replace = replace
672 else:
673 self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
674
675 BitSizeValidator(varset).validate(self.search, self.replace)
676
# Mako template for the generated C file.  It is rendered with:
#   pass_name      -- name of the generated entry point
#   xforms         -- every transform, in order
#   opcode_xforms  -- dict mapping opcode name to its list of transforms
#   condition_list -- list of C condition expressions, indexed by
#                     each transform's condition_index
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for xform in xforms:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

% for (opcode, xform_list) in sorted(opcode_xforms.items()):
static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
                   const bool *condition_flags)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in sorted(opcode_xforms.keys()):
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(build, alu, xform->search, xform->replace)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   bool progress = false;

   nir_builder build;
   nir_builder_init(&build, impl);

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(&build, block, condition_flags);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
782
class AlgebraicPass(object):
    """A named collection of transforms that can be rendered to C source."""

    def __init__(self, pass_name, transforms):
        """Parse all transforms and group them by the opcode they search for.

        pass_name  -- name used for the generated C functions and tables.
        transforms -- iterable of SearchAndReplace objects or of the raw
                      sequences SearchAndReplace accepts.

        Every transform that fails to parse is reported on stderr together
        with a traceback; if any failed, the process exits with status 1
        after all transforms have been tried.
        """
        self.xforms = []
        self.opcode_xforms = defaultdict(list)
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Only genuine parse/validation failures are reported
                    # here; KeyboardInterrupt and SystemExit propagate.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print("  " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xforms.append(xform)
            if xform.search.opcode in conv_opcode_types:
                # A generic conversion opcode stands for every sized variant
                # of that conversion, so register the transform under each.
                dst_type = conv_opcode_types[xform.search.opcode]
                for size in type_sizes(dst_type):
                    sized_opcode = xform.search.opcode + str(size)
                    self.opcode_xforms[sized_opcode].append(xform)
            else:
                self.opcode_xforms[xform.search.opcode].append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the generated C source for this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xforms=self.xforms,
                                               opcode_xforms=self.opcode_xforms,
                                               condition_list=condition_list)