nir/algebraic: refactor inexact opcode restrictions
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
# Maximum number of commutative expressions allowed in a single pattern.
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8
40
# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.  NOTE: insertion order matters — this dict is
# iterated elsewhere when emitting generated code.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_c_opcode(op):
   """Return the C enum name for an opcode: search-only conversion opcodes
   get the nir_search_op_ prefix, real NIR opcodes get nir_op_."""
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
62
63
# Type tuples used for isinstance() checks that must behave the same on
# Python 2 and Python 3.
if sys.version_info >= (3, 0):
   integer_types = (int, )
   string_type = str
else:
   # Python 2 has separate long/unicode types that must also be accepted.
   integer_types = (int, long)
   string_type = unicode
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the explicit bit width in a type string like "int32", or 0 when
   the type string carries no size (i.e. the value is unsized)."""
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
82
# Represents a set of variables, each with a unique id
class VarSet(object):
   """Maps variable names to small sequential integer ids.  After lock() is
   called, looking up a name that was never seen is an error — this catches
   replacement expressions that reference variables absent from the search
   expression."""

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      try:
         return self.names[name]
      except KeyError:
         assert not self.immutable, "Unknown replacement variable: " + name
         index = next(self.ids)
         self.names[name] = index
         return index

   def lock(self):
      self.immutable = True
99
class Value(object):
   """Base class for nodes of a search or replace pattern tree.

   Concrete subclasses are Constant, Variable and Expression.  A Value also
   participates in bit-size inference (see BitSizeValidator): the _bit_size
   attribute forms a union-find structure over values, and each Value knows
   how to render itself as a static C initializer for the matching
   nir_search_* struct.
   """

   @staticmethod
   def create(val, name_base, varset):
      # Factory dispatching on the Python type of the pattern element:
      # tuples are sub-expressions, strings are variables, and Python
      # literals are constants.
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)
      # NOTE(review): any other type falls through and returns None;
      # presumably callers only ever pass the types above — confirm.

   def __init__(self, val, name, type_str):
      self.in_val = str(val)    # original pattern text, used by __str__
      self.name = name          # name of the generated C variable
      self.type_str = type_str  # "constant", "variable" or "expression"

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Walk the chain of _bit_size links until we hit a class representative
      # (a value with no link) or a concrete int bit-size.
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: point directly at the representative so later
      # lookups are O(1).
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      # C enum tag for this value kind, e.g. nir_search_value_constant.
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      # C struct type name, e.g. nir_search_variable.
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      # Resolve through the deduplication cache: identical struct
      # initializers share a single C variable (see render()).
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         # Encode "same bit size as variable N" as the negative value -N-1.
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         #   case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   # Mako template emitting the C struct initializer for one value.
   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      """Render this value as C source.  "cache" maps struct initializer text
      to the name of the first C variable emitted with that text, allowing
      identical values to be deduplicated across patterns."""
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)
227
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal constant in a pattern, e.g. 0, 1.0, True or "0x3f@32"."""

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      self._bit_size = None
      if isinstance(val, str):
         # Textual form: a Python literal optionally followed by "@bits".
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         bits = m.group('bits')
         if bits:
            self._bit_size = int(bits)
      else:
         self.value = val

      # Booleans are always one bit wide in nir_search.
      if isinstance(self.value, bool):
         assert self._bit_size in (None, 1)
         self._bit_size = 1

   def hex(self):
      """Render the constant as a C hexadecimal literal; floats are emitted
      as the bit pattern of their 64-bit IEEE representation."""
      v = self.value
      if isinstance(v, bool):
         return 'NIR_TRUE' if v else 'NIR_FALSE'
      if isinstance(v, integer_types):
         return hex(v)
      assert isinstance(v, float)
      bits = struct.unpack('Q', struct.pack('d', v))[0]
      h = hex(bits)

      # On Python 2 this 'L' suffix is automatically added, but not on Python 3
      # Adding it explicitly makes the generated file identical, regardless
      # of the Python version running this script.
      if h[-1] != 'L' and bits > sys.maxsize:
         h += 'L'

      return h

   def type(self):
      """Map the Python type of the constant to a nir_alu_type name."""
      if isinstance(self.value, bool):
         return "nir_type_bool"
      if isinstance(self.value, integer_types):
         return "nir_type_int"
      if isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality. One generally cannot be
      used in place of the other. Using this implementation for the __eq__
      will break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.value == other.value
285
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?")

class Variable(Value):
   """A named pattern variable such as "a", "#b" (constant-only), "a@32"
   (sized), "a@int", "a(is_used_once)" (with condition) or "a.xyz"
   (swizzled)."""

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more
      # flexible.
      assert self.var_name.isalpha()
      # Bug fix: the original used "is not", which compares object identity
      # and depends on string interning — it never reliably rejected these
      # names.  Use value inequality instead.
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None  # '#' prefix
      self.cond = m.group('cond')                      # C condition string
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      # Booleans are always one bit wide; an explicit size, if given, must
      # be a legal bool size.
      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      # Unique id shared by every occurrence of this variable name in the
      # transform.
      self.index = varset[self.var_name]

   def type(self):
      """Map the required type annotation to a nir_alu_type name, or None if
      the variable is untyped."""
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality. One generally cannot be
      used in place of the other. Using this implementation for the __eq__
      will break BitSizeValidator.
      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      """Render the swizzle (if any) as a C array initializer of component
      indices; an absent swizzle is the identity."""
      if self.swiz is not None:
         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w': 3}
         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
      return '{0, 1, 2, 3}'
350
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An operation node in a pattern: a tuple whose first element is an
   opcode string (optionally prefixed with '~' for inexact, suffixed with
   '@bits' and/or a '(condition)') and whose remaining elements are the
   source patterns."""

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.cond = m.group('cond')

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      # Recursively build the source sub-patterns, naming them after this
      # node so generated C variable names stay unique.
      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality. One generally cannot be
      used in place of the other. Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.
      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      # Depth-first: children get indices following this node's.
      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      # Sources must be rendered (and thus defined in C) before this node,
      # since the initializer takes pointers to them.
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)
443
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything. The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression. It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      # One union-find representative slot per pattern variable, filled in
      # lazily by merge_variables().
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of self and other to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Point the less-specialized class at the more-specialized one.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      # Validate bottom-up so source errors are reported before ours.
      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      # Then unify the destination, either with the unsized-source class or
      # with the opcode's fixed output size.
      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      # Every replacement sub-value must have ended up with a concrete size,
      # a variable's size, or the size of the whole search expression.
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
             bit_size == search.get_bit_size(), \
             'Ambiguous bit size for replacement value {}: ' \
             'it cannot be deduced from a variable, a fixed bit size ' \
             'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      # Entry point: run inference over the search expression first, then
      # the replacement with the stricter replace-mode rules.
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
            search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
745
# Hands out a unique integer id to each SearchAndReplace instance.
_optimization_ids = itertools.count()

# Collected C condition strings; index 0 is the always-true condition.
condition_list = ['true']
749
class SearchAndReplace(object):
   """One algebraic transform: a search pattern, a replacement value, and an
   optional C condition string controlling when the transform applies."""

   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Register the condition string and remember its table index.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if not isinstance(search, Expression):
         search = Expression(search, "search{0}".format(self.id), varset)
      self.search = search

      # No new variables may be introduced by the replacement.
      varset.lock()

      if not isinstance(replace, Value):
         replace = Value.create(replace, "replace{0}".format(self.id), varset)
      self.replace = replace

      # Prove the transform is bit-size-correct before emitting any code.
      BitSizeValidator(varset).validate(self.search, self.replace)
779
780 class TreeAutomaton(object):
781 """This class calculates a bottom-up tree automaton to quickly search for
782 the left-hand sides of tranforms. Tree automatons are a generalization of
783 classical NFA's and DFA's, where the transition function determines the
784 state of the parent node based on the state of its children. We construct a
785 deterministic automaton to match patterns, using a similar algorithm to the
786 classical NFA to DFA construction. At the moment, it only matches opcodes
787 and constants (without checking the actual value), leaving more detailed
788 checking to the search function which actually checks the leaves. The
789 automaton acts as a quick filter for the search function, requiring only n
790 + 1 table lookups for each n-source operation. The implementation is based
791 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
792 In the language of that reference, this is a frontier-to-root deterministic
793 automaton using only symbol filtering. The filtering is crucial to reduce
794 both the time taken to generate the tables and the size of the tables.
795 """
   def __init__(self, transforms):
      """Build the automaton for the search patterns of "transforms"."""
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))
804
   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []  # index -> object
         self.map = {}      # object -> index
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         # KeyError if obj was never added, unlike list.index's ValueError.
         return self.map[obj]

      def add(self, obj):
         # Idempotent: re-adding an existing object returns its index.
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
847
   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode      # opcode string, or "__wildcard"/"__const"
         self.children = children  # tuple of child Items
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This the set of opcodes for parents of this item. Used to speed up
         # filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)
869
def _compute_items(self):
   """Build the deduplicated set of all possible items.

   Populates self.items, a map from (opcode, child items) to Item, and
   self.opcodes, the set of opcodes the patterns actually use, so we can
   avoid building and emitting tables for opcodes that aren't used.
   """
   # This is a map from (opcode, sources) to item.
   self.items = {}

   # The set of all opcodes used by the patterns. Used later to avoid
   # building and emitting all the tables for opcodes that aren't used.
   self.opcodes = self.IndexMap()

   def get_item(opcode, children, pattern=None):
      # For 2src_commutative opcodes, also register the item under the
      # swapped order of its first two children, so either operand order
      # reaches the same item at match time.
      commutative = len(children) >= 2 \
            and "2src_commutative" in opcodes[opcode].algebraic_properties
      item = self.items.setdefault((opcode, children),
                                   self.Item(opcode, children))
      if commutative:
         self.items[opcode, (children[1], children[0]) + children[2:]] = item
      if pattern is not None:
         item.patterns.append(pattern)
      return item

   self.wildcard = get_item("__wildcard", ())
   self.const = get_item("__const", ())

   def process_subpattern(src, pattern=None):
      if isinstance(src, Constant):
         # Note: we throw away the actual constant value!
         return self.const
      elif isinstance(src, Variable):
         if src.is_constant:
            return self.const
         else:
            # Note: we throw away which variable it is here! This special
            # item is equivalent to nu in "Tree Automatons."
            return self.wildcard
      else:
         assert isinstance(src, Expression)
         opcode = src.opcode
         stripped = opcode.rstrip('0123456789')
         if stripped in conv_opcode_types:
            # Matches that use conversion opcodes with a specific type,
            # like f2b1, are tricky. Either we construct the automaton to
            # match specific NIR opcodes like nir_op_f2b1, in which case we
            # need to create separate items for each possible NIR opcode
            # for patterns that have a generic opcode like f2b, or we
            # construct it to match the search opcode, in which case we
            # need to map f2b1 to f2b when constructing the automaton. Here
            # we do the latter.
            opcode = stripped
         self.opcodes.add(opcode)
         children = tuple(process_subpattern(c) for c in src.sources)
         item = get_item(opcode, children, pattern)
         # Fixed: the previous enumerate() here bound an index that was
         # never used; iterate the children directly.
         for child in children:
            child.parent_ops.add(opcode)
         return item

   for i, pattern in enumerate(self.patterns):
      process_subpattern(pattern, i)
927
def _build_table(self):
   """This is the core algorithm which builds up the transition table. It
   is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
   Comp_a and Filt_{a,i} using integers to identify match sets." It
   simultaneously builds up a list of all possible "match sets" or
   "states", where each match set represents the set of Item's that match a
   given instruction, and builds up the transition table between states.
   """
   # Map from opcode + filtered state indices to transitioned state.
   self.table = defaultdict(dict)
   # Bijection from state to index. q in the original algorithm is
   # len(self.states)
   self.states = self.IndexMap()
   # List of pattern matches for each state index.
   self.state_patterns = []
   # Map from state index to filtered state index for each opcode.
   self.filter = defaultdict(list)
   # Bijections from filtered state to filtered state index for each
   # opcode, called the "representor sets" in the original algorithm.
   # q_{a,j} in the original algorithm is len(self.rep[op]).
   self.rep = defaultdict(self.IndexMap)

   # Everything in self.states with an index at least worklist_index is
   # part of the worklist of newly created states. There is also a worklist
   # of newly filtered states for each opcode, for which worklist_indices
   # serves a similar purpose. worklist_index corresponds to p in the
   # original algorithm, while worklist_indices is p_{a,j} (although since
   # we only filter by opcode/symbol, it's really just p_a).
   self.worklist_index = 0
   worklist_indices = defaultdict(lambda: 0)

   # This is the set of opcodes for which the filtered worklist is non-empty.
   # It's used to avoid scanning opcodes for which there is nothing to
   # process when building the transition table. It corresponds to new_a in
   # the original algorithm.
   new_opcodes = self.IndexMap()

   # Process states on the global worklist, filtering them for each opcode,
   # updating the filter tables, and updating the filtered worklists if any
   # new filtered states are found. Similar to ComputeRepresenterSets() in
   # the original algorithm, although that only processes a single state.
   def process_new_states():
      while self.worklist_index < len(self.states):
         state = self.states[self.worklist_index]

         # Calculate pattern matches for this state. Each pattern is
         # assigned to a unique item, so we don't have to worry about
         # deduplicating them here. However, we do have to sort them so
         # that they're visited at runtime in the order they're specified
         # in the source.
         patterns = list(sorted(p for item in state for p in item.patterns))
         assert len(self.state_patterns) == self.worklist_index
         self.state_patterns.append(patterns)

         # Calculate the filter table for this state, and update the
         # filtered worklists.
         for op in self.opcodes:
            filt = self.filter[op]
            rep = self.rep[op]
            # Keep only the items that can appear as a child of 'op';
            # this is Filt_{a,i} from the original algorithm.
            filtered = frozenset(item for item in state if \
               op in item.parent_ops)
            if filtered in rep:
               rep_index = rep.index(filtered)
            else:
               # A brand-new filtered state: record that 'op' has worklist
               # entries to process in the main loop below.
               rep_index = rep.add(filtered)
               new_opcodes.add(op)
            assert len(filt) == self.worklist_index
            filt.append(rep_index)
         self.worklist_index += 1

   # There are two start states: one which can only match as a wildcard,
   # and one which can match as a wildcard or constant. These will be the
   # states of intrinsics/other instructions and load_const instructions,
   # respectively. The indices of these must match the definitions of
   # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
   # initialize things correctly.
   self.states.add(frozenset((self.wildcard,)))
   self.states.add(frozenset((self.const,self.wildcard)))
   process_new_states()

   while len(new_opcodes) > 0:
      for op in new_opcodes:
         rep = self.rep[op]
         table = self.table[op]
         op_worklist_index = worklist_indices[op]
         # Conversion search opcodes are always unary; everything else
         # takes its arity from the opcode table.
         if op in conv_opcode_types:
            num_srcs = 1
         else:
            num_srcs = opcodes[op].num_inputs

         # Iterate over all possible source combinations where at least one
         # is on the worklist.
         for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
            if all(src_idx < op_worklist_index for src_idx in src_indices):
               continue

            srcs = tuple(rep[src_idx] for src_idx in src_indices)

            # Try all possible pairings of source items and add the
            # corresponding parent items. This is Comp_a from the paper.
            parent = set(self.items[op, item_srcs] for item_srcs in
               itertools.product(*srcs) if (op, item_srcs) in self.items)

            # We could always start matching something else with a
            # wildcard. This is Cl from the paper.
            parent.add(self.wildcard)

            table[src_indices] = self.states.add(frozenset(parent))
         worklist_indices[op] = len(rep)
      new_opcodes.clear()
      process_new_states()
1039
# Mako template for the generated C source of a single algebraic pass: the
# per-state transform arrays, the per-opcode automaton transition tables, and
# the pre_block/block/impl/entry-point functions that drive matching at
# runtime. The emitted WILDCARD_STATE (implicitly 0) and CONST_STATE values
# must stay in sync with the two start states created in
# TreeAutomaton._build_table(). Do not edit the string casually: every
# character is emitted into the generated C file.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

struct per_op_table {
   const uint16_t *filter;
   unsigned num_filtered_states;
   const uint16_t *table;
};

/* Note: these must match the start states created in
 * TreeAutomaton._build_table()
 */

/* WILDCARD_STATE = 0 is set by zeroing the state array */
static const uint16_t CONST_STATE = 1;

#endif

<% cache = {} %>
% for xform in xforms:
   ${xform.search.render(cache)}
   ${xform.replace.render(cache)}
% endfor

% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
% for i in state_xforms:
  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
% endif
% endfor

static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
      .filter = (uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
      <%
      num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (uint16_t []) {
      <%
      num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

static void
${pass_name}_pre_block(nir_block *block, uint16_t *states)
{
   nir_foreach_instr(instr, block) {
      switch (instr->type) {
      case nir_instr_type_alu: {
         nir_alu_instr *alu = nir_instr_as_alu(instr);
         nir_op op = alu->op;
         uint16_t search_op = nir_search_op_for_nir_op(op);
         const struct per_op_table *tbl = &${pass_name}_table[search_op];
         if (tbl->num_filtered_states == 0)
            continue;

         /* Calculate the index into the transition table. Note the index
          * calculated must match the iteration order of Python's
          * itertools.product(), which was used to emit the transition
          * table.
          */
         uint16_t index = 0;
         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
            index *= tbl->num_filtered_states;
            index += tbl->filter[states[alu->src[i].src.ssa->index]];
         }
         states[alu->dest.dest.ssa.index] = tbl->table[index];
         break;
      }

      case nir_instr_type_load_const: {
         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
         states[load_const->def.index] = CONST_STATE;
         break;
      }

      default:
         break;
      }
   }
}

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
                   const uint16_t *states, const bool *condition_flags)
{
   bool progress = false;
   const unsigned execution_mode = build->shader->info.float_controls_execution_mode;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      unsigned bit_size = alu->dest.dest.ssa.bit_size;
      const bool ignore_inexact =
         nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
         nir_is_denorm_flush_to_zero(execution_mode, bit_size);

      switch (states[alu->dest.dest.ssa.index]) {
% for i in range(len(automaton.state_patterns)):
      case ${i}:
         % if automaton.state_patterns[i]:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
            const struct transform *xform = &${pass_name}_state${i}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                !(xform->search->inexact && ignore_inexact) &&
                nir_replace_instr(build, alu, xform->search, xform->replace)) {
               progress = true;
               break;
            }
         }
         % endif
         break;
% endfor
      default: assert(0);
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   bool progress = false;

   nir_builder build;
   nir_builder_init(&build, impl);

   /* Note: it's important here that we're allocating a zeroed array, since
    * state 0 is the default state, which means we don't have to visit
    * anything other than constants and ALU instructions.
    */
   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));

   nir_foreach_block(block, impl) {
      ${pass_name}_pre_block(block, states);
   }

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(&build, block, states, condition_flags);
   }

   free(states);

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
#ifndef NDEBUG
      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
#endif
   }

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

% for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
% endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
1258
1259
class AlgebraicPass(object):
   """Compiles a list of search-and-replace transforms into a C pass.

   Parses each transform, groups transforms by the (sized) opcode they
   match, validates the commutative-expression budget against
   nir_search_max_comm_ops, builds the matching TreeAutomaton, and renders
   the generated C source via _algebraic_pass_template.
   """
   def __init__(self, pass_name, transforms):
      self.xforms = []
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name

      # Collect every problem before exiting so the user sees all of them
      # in one run rather than one at a time.
      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except:
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # A generic conversion opcode (e.g. "f2f") stands for all of
            # its sized variants, so register the transform under each
            # sized opcode.
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly
         # contain more commutative expressions than match_expression
         # (nir_search.c) can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               # Bug fix: this message previously formatted the misspelled
               # name "nir_search_max_comm_op", raising NameError whenever
               # the branch executed. Also dropped the traceback.print_exc()
               # call here: no exception is active on this path, so it only
               # printed "NoneType: None" noise.
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}). Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      """Render and return the generated C source for this pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools)