nir: Report progress properly in nir_lower_bool_to_*
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
# Keep in sync with NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8

# These opcodes exist only in nir_search (they are "virtual" conversion
# opcodes that stand for a whole family of sized conversions).  The table
# maps each one to the base type of its destination.
conv_opcode_types = {
   'i2f' : 'float',
   'u2f' : 'float',
   'f2f' : 'float',
   'f2u' : 'uint',
   'f2i' : 'int',
   'u2u' : 'uint',
   'i2i' : 'int',
   'b2f' : 'float',
   'b2i' : 'int',
   'i2b' : 'bool',
   'f2b' : 'bool',
}

def get_c_opcode(op):
   """Return the C enum name for an opcode: search-only conversion opcodes
   live in the nir_search_op enum, everything else in nir_op.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
62
63
# Python 2/3 compatibility shims: the integer and string classes differ
# between the two major versions, so pick the right tuple/type once here.
if sys.version_info[0] < 3:
   integer_types = (int, long)
   string_type = unicode

else:
   integer_types = (int, )
   string_type = str
71
# Matches a base type optionally followed by a bit count, e.g. "int32",
# "float", or "uint16".  Both groups are optional in the pattern itself;
# type_bits() asserts that the base type was actually present.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit width encoded in a NIR type string, or 0 if the type
   carries no explicit size (an "unsized" type).
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
82
class VarSet(object):
   """A set of named variables, each assigned a unique integer id on first
   lookup.  After lock() is called, looking up an unknown name is an error;
   this catches replacement expressions that reference variables which never
   appeared in the search expression.
   """
   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      # EAFP: the common case is a repeat lookup of a known name.
      try:
         return self.names[name]
      except KeyError:
         assert not self.immutable, "Unknown replacement variable: " + name
         index = next(self.ids)
         self.names[name] = index
         return index

   def lock(self):
      self.immutable = True
99
class Value(object):
   """Base class for every node in a search or replace pattern.

   Concrete values are Constant, Variable, and Expression (defined below).
   A Value knows how to render itself as a C initializer for the matching
   nir_search_* struct, and carries the union-find link (_bit_size) used by
   BitSizeValidator to track bit-size equivalence classes.
   """

   @staticmethod
   def create(val, name_base, varset):
      """Build the appropriate Value subclass from a raw pattern element:
      tuple -> Expression, string -> Variable, number/bool -> Constant.
      """
      # Decode bytes so byte-string literals behave like str on Python 2/3.
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)

   def __init__(self, val, name, type_str):
      self.in_val = str(val)   # original pattern text, kept for __str__
      self.name = name         # C identifier used for the generated struct
      self.type_str = type_str # "constant", "variable", or "expression"

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Follow the chain of _bit_size links until we reach either a concrete
      # int bit-size or a Value with no link (the class representative).
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: link directly to the representative so later
      # lookups are O(1).
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      # C enum tag, e.g. nir_search_value_expression
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      # C struct name, e.g. nir_search_expression
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      # If render() found an identical struct earlier, the cache maps our
      # name to that struct's name; emit references through the remap.
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      # Pointer to the common nir_search_value header inside the struct.
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      """Encode the resolved bit-size for the generated C struct: a positive
      physical size, a negative variable reference (-index - 1), or 0.
      """
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         #   case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   # Mako template producing the C initializer for any of the three struct
   # flavors; dispatches on the concrete subclass of `val`.
   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      """Render this value as C source.  `cache` (optional dict) deduplicates
      identical struct initializers across values.
      """
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)
227
# A textual constant is any literal optionally followed by "@<bits>",
# e.g. "0x3f@32".  The value part is parsed with ast.literal_eval.
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal constant appearing in a search or replace pattern."""

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, str):
         # Textual form: split off the optional "@bits" suffix.
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         self._bit_size = int(m.group('bits')) if m.group('bits') else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         # Booleans are always 1-bit in nir_search.
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Render the constant's bit pattern as a C literal."""
      if isinstance(self.value, bool):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      if isinstance(self.value, integer_types):
         return hex(self.value)
      elif isinstance(self.value, float):
         # Bit-cast the double to a 64-bit integer.
         bits = struct.unpack('Q', struct.pack('d', self.value))[0]
         h = hex(bits)

         # On Python 2 this 'L' suffix is automatically added, but not on Python 3
         # Adding it explicitly makes the generated file identical, regardless
         # of the Python version running this script.
         if h[-1] != 'L' and bits > sys.maxsize:
            h += 'L'

         return h
      else:
         assert False

   def type(self):
      """Return the nir_alu_type enum name for this constant's value."""
      if isinstance(self.value, bool):
         return "nir_type_bool"
      elif isinstance(self.value, integer_types):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality. One generally cannot be
      used in place of the other. Using this implementation for the __eq__
      will break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.value == other.value
285
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?"
                          r"$")

class Variable(Value):
   """A named wildcard in a pattern, optionally restricted to constants
   ("#name"), a base type and/or bit size ("name@int32"), a condition
   ("name(is_used_once)"), and/or a swizzle ("name.xxyy").
   """

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant. If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      self.cond = m.group('cond')
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      if self.required_type == 'bool':
         # Booleans are 1-bit unless an explicit (valid) size was given.
         if self._bit_size is None:
            self._bit_size = 1
         else:
            assert self._bit_size in type_sizes(self.required_type)

      assert self.required_type in (None, 'float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_alu_type enum name required for this variable, or
      None when the variable is untyped.
      """
      type_map = {'bool'  : "nir_type_bool",
                  'int'   : "nir_type_int",
                  'uint'  : "nir_type_int",
                  'float' : "nir_type_float"}
      return type_map.get(self.required_type)

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality. One generally cannot be
      used in place of the other. Using this implementation for the __eq__
      will break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.index == other.index

   def swizzle(self):
      """Render the swizzle (or the identity swizzle) as a C initializer."""
      if self.swiz is None:
         return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

      # 'xyzw' name the first four components; 'a'..'p' name all sixteen.
      components = dict(zip('xyzwabcdefghijklmnop',
                            [0, 1, 2, 3] + list(range(16))))
      indices = [str(components[c]) for c in self.swiz[1:]]
      return '{' + ', '.join(indices) + '}'
358
# An expression head is an opcode with optional "~" (inexact), "!" (exact),
# "@<bits>" (explicit size), and "(cond)" (condition) decorations.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An opcode applied to source Values, parsed from a pattern tuple whose
   first element is the decorated opcode string and whose remaining elements
   are the sources.
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      # Recursively build Value objects for each source.
      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality. One generally cannot be
      used in place of the other. Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.
      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      # Depth-first: each child gets indices after the ones used so far.
      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      # Render all sources first so this struct can point at theirs.
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)
455
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything. The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression. It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      # One slot per variable index; each holds the canonical Value for that
      # variable's bit-size class once merge_variables() has run.
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
      and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            # self.is_search is set by validate(); in replace mode a variable
            # may not be pinned to a physical size it didn't already have.
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of self and other to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Point the less-specialized class at the more-specialized one.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate the an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            # All unsized sources of one ALU op must agree (NIR rule 2).
            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            # A sized source must match the opcode's declared size.
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         # Unsized destination: matches the unsized sources, if any.
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      """Check that every replacement sub-value ended up with a bit size that
      is computable at run time: a fixed size, a variable's size, or the
      search expression's size.
      """
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      """Run the full validation pipeline for one transform."""
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
               search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
757
# Assigns each SearchAndReplace a unique id, used to name its generated structs.
_optimization_ids = itertools.count()

# All condition strings seen so far; a transform stores an index into this
# list.  Index 0 is the always-true condition.
condition_list = ['true']
761
class SearchAndReplace(object):
   """One transform: a search Expression, a replacement Value, and an
   optional C condition string gating when the transform applies.
   """

   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Intern the condition string and remember its index.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      # After the search side is parsed, no new variables may appear.
      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      # Reject the transform outright if its bit sizes can't be proven sound.
      BitSizeValidator(varset).validate(self.search, self.replace)
791
792 class TreeAutomaton(object):
793 """This class calculates a bottom-up tree automaton to quickly search for
794 the left-hand sides of tranforms. Tree automatons are a generalization of
795 classical NFA's and DFA's, where the transition function determines the
796 state of the parent node based on the state of its children. We construct a
797 deterministic automaton to match patterns, using a similar algorithm to the
798 classical NFA to DFA construction. At the moment, it only matches opcodes
799 and constants (without checking the actual value), leaving more detailed
800 checking to the search function which actually checks the leaves. The
801 automaton acts as a quick filter for the search function, requiring only n
802 + 1 table lookups for each n-source operation. The implementation is based
803 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
804 In the language of that reference, this is a frontier-to-root deterministic
805 automaton using only symbol filtering. The filtering is crucial to reduce
806 both the time taken to generate the tables and the size of the tables.
807 """
   def __init__(self, transforms):
      # The left-hand (search) sides of all transforms are the patterns the
      # automaton is built to recognize.
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))
816
   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []   # index -> object
         self.map = {}       # object -> index
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         # Raises KeyError if obj was never added.
         return self.map[obj]

      def add(self, obj):
         # Returns the object's index, reusing the existing one if already
         # present (objects are deduplicated).
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
859
   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed
         # up filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)
881
882 def _compute_items(self):
883 """Build a set of all possible items, deduplicating them."""
884 # This is a map from (opcode, sources) to item.
885 self.items = {}
886
887 # The set of all opcodes used by the patterns. Used later to avoid
888 # building and emitting all the tables for opcodes that aren't used.
889 self.opcodes = self.IndexMap()
890
891 def get_item(opcode, children, pattern=None):
892 commutative = len(children) >= 2 \
893 and "2src_commutative" in opcodes[opcode].algebraic_properties
894 item = self.items.setdefault((opcode, children),
895 self.Item(opcode, children))
896 if commutative:
897 self.items[opcode, (children[1], children[0]) + children[2:]] = item
898 if pattern is not None:
899 item.patterns.append(pattern)
900 return item
901
902 self.wildcard = get_item("__wildcard", ())
903 self.const = get_item("__const", ())
904
905 def process_subpattern(src, pattern=None):
906 if isinstance(src, Constant):
907 # Note: we throw away the actual constant value!
908 return self.const
909 elif isinstance(src, Variable):
910 if src.is_constant:
911 return self.const
912 else:
913 # Note: we throw away which variable it is here! This special
914 # item is equivalent to nu in "Tree Automatons."
915 return self.wildcard
916 else:
917 assert isinstance(src, Expression)
918 opcode = src.opcode
919 stripped = opcode.rstrip('0123456789')
920 if stripped in conv_opcode_types:
921 # Matches that use conversion opcodes with a specific type,
922 # like f2b1, are tricky. Either we construct the automaton to
923 # match specific NIR opcodes like nir_op_f2b1, in which case we
924 # need to create separate items for each possible NIR opcode
925 # for patterns that have a generic opcode like f2b, or we
926 # construct it to match the search opcode, in which case we
927 # need to map f2b1 to f2b when constructing the automaton. Here
928 # we do the latter.
929 opcode = stripped
930 self.opcodes.add(opcode)
931 children = tuple(process_subpattern(c) for c in src.sources)
932 item = get_item(opcode, children, pattern)
933 for i, child in enumerate(children):
934 child.parent_ops.add(opcode)
935 return item
936
937 for i, pattern in enumerate(self.patterns):
938 process_subpattern(pattern, i)
939
   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # List of pattern matches for each state index.
      self.state_patterns = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with an index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(lambda: 0)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]

            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))
            assert len(self.state_patterns) == self.worklist_index
            self.state_patterns.append(patterns)

            # calculate filter table for this state, and update filtered
            # worklists. Only items whose parent_ops contain op can matter
            # when transitioning on op, so everything else is filtered out.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  # A never-before-seen filtered state: it lands on the
                  # filtered worklist for op.
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const,self.wildcard)))
      process_new_states()

      # Fixpoint loop: keep extending the transition table until no opcode
      # has newly filtered source states left to combine.
      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            if op in conv_opcode_types:
               # Conversion search opcodes always take a single source.
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()
1051
# Mako template that renders the generated C source for one algebraic pass:
# the nir_search expressions for each transform, the per-state transform
# arrays, the automaton's per-opcode filter/transition tables, and the pass
# entry point, which walks every function impl via nir_algebraic_impl() and
# accumulates a progress flag. The template's free variables (xforms,
# automaton, pass_name, condition_list, get_c_opcode, itertools) are supplied
# by AlgebraicPass.render() below.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {} %>
% for xform in xforms:
   ${xform.search.render(cache)}
   ${xform.replace.render(cache)}
% endfor

% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
% for i in state_xforms:
  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
% endif
% endfor

static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
      .filter = (uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

const struct transform *${pass_name}_transforms[] = {
% for i in range(len(automaton.state_patterns)):
  % if automaton.state_patterns[i]:
    ${pass_name}_state${i}_xforms,
  % else:
    NULL,
  % endif
% endfor
};

const uint16_t ${pass_name}_transform_counts[] = {
% for i in range(len(automaton.state_patterns)):
  % if automaton.state_patterns[i]:
    (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
  % else:
    0,
  % endif
% endfor
};

bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl) {
         progress |= nir_algebraic_impl(function->impl, condition_flags,
                                        ${pass_name}_transforms,
                                        ${pass_name}_transform_counts,
                                        ${pass_name}_table);
      }
   }

   return progress;
}
""")
1151
1152
class AlgebraicPass(object):
   """Compiles a list of search-and-replace transforms into a NIR pass.

   Parses each transform into a SearchAndReplace (reporting parse failures
   to stderr), validates the commutative-expression limits against
   nir_search_max_comm_ops, builds the matching TreeAutomaton, and renders
   the C source for the pass via _algebraic_pass_template. Any parse or
   validation error terminates the build with exit status 1 after all
   transforms have been checked.
   """
   def __init__(self, pass_name, transforms):
      self.xforms = []
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name

      # Collect all errors before exiting so the user sees every problem in
      # one build rather than one at a time.
      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except Exception:
               print("Failed to parse transformation:", file=sys.stderr)
               print(" " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # Generic conversion opcodes (e.g. f2f) must be expanded into
            # each sized NIR opcode (f2f16, f2f32, ...) for indexing.
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               # Bug fix: this previously referenced the undefined name
               # nir_search_max_comm_op (missing 's'), raising NameError
               # whenever this diagnostic was triggered.
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}).  Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      """Render the C source for this pass from the module-level template."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools)