nir/algebraic: Print out the list of transforms in the C file
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
# Keep in sync with NIR_SEARCH_MAX_COMM_OPS in nir_search.c.
nir_search_max_comm_ops = 8

# Generic conversion opcodes that exist only inside nir_search (they have no
# nir_op_* counterpart at this level).  Maps each opcode to the base type of
# its destination.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_c_opcode(op):
    """Return the C enum name for an opcode: nir_search_op_* for the generic
    conversion opcodes above, nir_op_* for everything else."""
    prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
    return prefix + op
62
63
# Python 2/3 compatibility shims.  The generator must behave identically on
# both interpreters; `long` and `unicode` only exist on Python 2, so they are
# referenced solely inside the version-guarded branch.
if sys.version_info < (3, 0):
    integer_types = (int, long)
    string_type = unicode

else:
    integer_types = (int, )
    string_type = str
# Parses a NIR type string such as "int32", "float", "uint16" or "bool" into
# an optional base type and an optional bit count.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
    """Return the bit width encoded in a type string, or 0 if the type is
    unsized (no digits present).  Asserts that a base type is present."""
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return int(bits) if bits is not None else 0
82
# A registry of pattern variables, assigning each distinct name a unique,
# monotonically increasing id.
class VarSet(object):
    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for *name*, allocating a fresh one on first use.
        Once lock() has been called, unknown names are rejected — replacement
        expressions may not introduce variables absent from the search side."""
        try:
            return self.names[name]
        except KeyError:
            assert not self.immutable, "Unknown replacement variable: " + name
            new_id = next(self.ids)
            self.names[name] = new_id
            return new_id

    def lock(self):
        """Freeze the set: subsequent lookups of unseen names assert."""
        self.immutable = True
99
class Value(object):
    """Base class for every node in a search/replace pattern tree.

    Concrete subclasses are Constant, Variable and Expression.  A Value knows
    how to render itself as a static nir_search_* C struct initializer, and it
    carries the union-find machinery (_bit_size links) used by
    BitSizeValidator for bit-size inference.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the appropriate Value subclass for a raw pattern element:
        tuple -> Expression, string -> Variable, bool/number -> Constant.
        An existing Expression is passed through unchanged.

        NOTE(review): an unrecognized type falls off the end and returns
        None silently — callers appear to rely on only valid inputs here.
        """
        if isinstance(val, bytes):
            # Normalize b'...' pattern elements to text before dispatching.
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

    def __init__(self, val, name, type_str):
        # in_val: the original pattern text, kept only for __str__/debugging.
        self.in_val = str(val)
        # name: the C identifier this value is emitted under.
        self.name = name
        # type_str: "constant" / "variable" / "expression"; selects the
        # nir_search_* C type and enum below.
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class.  Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable.  We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression.  This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        # Chase _bit_size links until we reach either a plain int (a concrete
        # bit size) or a Value with no link (the class representative).
        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Path compression: point directly at the representative so later
        # finds are O(1).
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() return
        before calling this, or just "other" if it's a concrete bit-size.  This
        is the "union" part of the union-find algorithm.
        """

        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        # Link our representative under the other class's representative.
        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        # e.g. "nir_search_value_constant"
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        # e.g. "nir_search_constant"
        return "nir_search_" + self.type_str

    def __c_name(self, cache):
        # If render() discovered an identical struct earlier, use the cached
        # C identifier instead of this value's own name.
        if cache is not None and self.name in cache:
            return cache[self.name]
        else:
            return self.name

    def c_value_ptr(self, cache):
        # Pointer to the embedded nir_search_value header of this struct.
        return "&{0}.value".format(self.__c_name(cache))

    def c_ptr(self, cache):
        return "&{0}".format(self.__c_name(cache))

    @property
    def c_bit_size(self):
        """Encode the inferred bit size for the C struct: a positive number is
        a concrete size, a negative number (-index - 1) refers to a pattern
        variable, and 0 means "nothing to check / inherit from the search
        value"."""
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            #   case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    # Mako template producing the brace-initializer for one nir_search_*
    # struct; dispatches on the concrete subclass of `val`.
    __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def render(self, cache):
        """Render this value as a C struct definition.  Identical struct
        bodies are deduplicated through `cache` (initializer-text -> C name):
        a cache hit emits only a comment and remaps this value's name."""
        struct_init = self.__template.render(val=self, cache=cache,
                                             Constant=Constant,
                                             Variable=Variable,
                                             Expression=Expression)
        if cache is not None and struct_init in cache:
            # If it's in the cache, register a name remap in the cache and
            # render only a comment saying it's been remapped
            cache[self.name] = cache[struct_init]
            return "/* {} -> {} in the cache */\n".format(self.name,
                                                          cache[struct_init])
        else:
            if cache is not None:
                cache[struct_init] = self.name
            return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                      struct_init)
226
# Splits a constant pattern string into the literal itself and an optional
# "@bits" suffix, e.g. "0x10@32".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
    """A literal constant appearing in a pattern, e.g. 0, 1.0 or '0x3@16'."""

    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        if isinstance(val, (str)):
            # Textual form: parse the literal and the optional bit-size tag.
            parsed = _constant_re.match(val)
            self.value = ast.literal_eval(parsed.group('value'))
            bits = parsed.group('bits')
            self._bit_size = int(bits) if bits else None
        else:
            # Already a Python bool/int/float.
            self.value = val
            self._bit_size = None

        # Booleans are always 1-bit in nir_search.
        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 1
            self._bit_size = 1

    def hex(self):
        """Render the constant as C source text for the generated table."""
        if isinstance(self.value, bool):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Emit the raw IEEE-754 double bit pattern.
            raw = struct.unpack('Q', struct.pack('d', self.value))[0]
            text = hex(raw)

            # Python 2 appends an 'L' suffix to large longs automatically;
            # add it explicitly on Python 3 so the generated file is
            # byte-identical regardless of the interpreter version.
            if text[-1] != 'L' and raw > sys.maxsize:
                text += 'L'

            return text
        else:
            assert False

    def type(self):
        """The nir_type_* base type implied by the Python value's type."""
        if isinstance(self.value, bool):
            return "nir_type_bool"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two constants are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for __eq__
        would break BitSizeValidator.
        """
        if not isinstance(other, type(self)):
            return False

        return self.value == other.value
284
# Splits a variable pattern string into an optional '#' (must-be-constant)
# marker, the name, an optional "@type[bits]" tag and an optional "(cond)"
# suffix, e.g. "#a@int32(is_used_once)".
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
    """A named wildcard in a pattern; matches may be restricted to constants
    ('#'), to a base type/bit size ('@...'), or by a C predicate ('(...)')."""

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None

        self.var_name = m.group('name')

        # Prevent common cases where someone puts quotes around a literal
        # constant.  If we want to support names that have numeric or
        # punctuation characters, we can make the first assertion more
        # flexible.
        assert self.var_name.isalpha()
        # Bug fix: these two checks used "is not", which compares object
        # identity rather than string contents, so they could never fire
        # (and emit SyntaxWarning on modern CPython).  Use != instead.
        assert self.var_name != 'True'
        assert self.var_name != 'False'

        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None

        if self.required_type == 'bool':
            if self._bit_size is not None:
                assert self._bit_size in type_sizes(self.required_type)
            else:
                # Unsized bools default to 1-bit.
                self._bit_size = 1

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        # Same name -> same index, enforced by the shared VarSet.
        self.index = varset[self.var_name]

    def type(self):
        """The nir_type_* enum for the required type, or None if untyped.
        'uint' deliberately maps to nir_type_int like 'int' does."""
        if self.required_type == 'bool':
            return "nir_type_bool"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two variables are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for __eq__
        would break BitSizeValidator.
        """
        if not isinstance(other, type(self)):
            return False

        return self.index == other.index
341
# Splits an expression's opcode string into an optional '~' (inexact) marker,
# the opcode name, an optional "@bits" tag and an optional "(cond)" suffix.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
    """An opcode applied to sub-values, built from a tuple pattern such as
    ('fadd', a, 1.0).  Sources are created recursively via Value.create."""

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        # '~' marks a float transform that is only valid for inexact math.
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')

        # "many-comm-expr" isn't really a condition.  It's notification to the
        # generator that this pattern is known to have too many commutative
        # expressions, and an error should not be generated for this case.
        self.many_commutative_expressions = False
        if self.cond and self.cond.find("many-comm-expr") >= 0:
            # Split the condition into a comma-separated list.  Remove
            # "many-comm-expr".  If there is anything left, put it back together.
            c = self.cond[1:-1].split(",")
            c.remove("many-comm-expr")

            self.cond = "({})".format(",".join(c)) if c else None
            self.many_commutative_expressions = True

        # Child value names are derived from ours: "<name_base>_<i>".
        self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                         for (i, src) in enumerate(expr[1:]) ]

        if self.opcode in conv_opcode_types:
            assert self._bit_size is None, \
                   'Expression cannot use an unsized conversion opcode with ' \
                   'an explicit size; that\'s silly.'

        self.__index_comm_exprs(0)

    def equivalent(self, other):
        """Check that two expressions are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for __eq__
        would break BitSizeValidator.

        This implementation does not check for equivalence due to commutativity,
        but it could.
        """
        if not isinstance(other, type(self)):
            return False

        if len(self.sources) != len(other.sources):
            return False

        if self.opcode != other.opcode:
            return False

        return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

    def __index_comm_exprs(self, base_idx):
        """Recursively count and index commutative expressions.  Each
        commutative node gets a unique index (pre-order) used by nir_search
        to enumerate source orderings; comm_exprs is the subtree total."""
        self.comm_exprs = 0

        # A note about the explicit "len(self.sources)" check.  The list of
        # sources comes from user input, and that input might be bad.  Check
        # that the expected second source exists before accessing it.  Without
        # this check, a unit test that does "('iadd', 'a')" will crash.
        if self.opcode not in conv_opcode_types and \
           "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
           len(self.sources) >= 2 and \
           not self.sources[0].equivalent(self.sources[1]):
            # Equivalent sources need no swapped-order matching, hence the
            # equivalence check above.
            self.comm_expr_idx = base_idx
            self.comm_exprs += 1
        else:
            self.comm_expr_idx = -1

        for s in self.sources:
            if isinstance(s, Expression):
                s.__index_comm_exprs(base_idx + self.comm_exprs)
                self.comm_exprs += s.comm_exprs

        return self.comm_exprs

    def c_opcode(self):
        # nir_search_op_* for generic conversions, nir_op_* otherwise.
        return get_c_opcode(self.opcode)

    def render(self, cache):
        # Children must be emitted before this struct references them.
        srcs = "\n".join(src.render(cache) for src in self.sources)
        return srcs + super(Expression, self).render(cache)
434
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and
    you know the second source is 16-bit, then you know that the other source
    and the destination must also be 16-bit.  There are, however, cases where
    this inference can be ambiguous or contradictory.  Consider, for instance,
    the following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end
    up trying to and a 32-bit value with a 64-bit value which would be invalid

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit
    size or an equivalence class with other values that must have the same bit
    size.  The validator works by combining bit-size classes with each other
    according to the NIR rules outlined above, checking that there are no
    inconsistencies.  When doing this for the replacement expression, we make
    sure to never change the equivalence class of any of the search values.
    We could make the example transforms above work by doing some extra
    run-time checking of the search expression, but we make the user specify
    those constraints themselves, to avoid any surprises.  Since the
    replacement bitsizes can only be connected to the source bitsize via
    variables (variables must have the same bitsize in the source and
    replacement expressions) or the roots of the expression (the replacement
    expression must produce the same bit size as the search expression), we
    prevent merging a variable with anything when processing the replacement
    expression, or specializing the search bitsize with anything.  The former
    prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due
    to the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and
    b, which can't be specialized to 32 which is the bitsize of the replace
    expression.  It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression.  Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # Canonical Variable per variable index for union-find merging;
        # entries are filled in lazily by merge_variables().
        self._var_classes = [None] * len(varset.names)

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is.  When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized
        bitsize.  The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to
          each other and to physical bitsizes.  In the replace expression, we
          disallow this to avoid adding extra constraints to the search
          expression that the user didn't specify.
        - Expressions and constants without a bitsize can always be
          specialized to each other and variables, but not the other way
          around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if
        a >= b, and None if they are not comparable (neither a <= b nor
        b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                # self.is_search is set by validate(); see below.
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b.  If both
        have been assigned conflicting physical bit-sizes, call "error_msg"
        with the bit-sizes of self and other to get a message and raise an
        error.  In the replace expression, disallow merging variables with
        other variables and physical bit-sizes as well.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
               error_msg(a_bit_size, b_bit_size)

        # Point the less-specialized class at the more-specialized one.
        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the
        different uses of the same variable.  We always do this as if we're in
        the search expression, even if we're actually not, since otherwise
        we'd get errors if the search expression specified some constraint but
        the replace expression didn't, because we'd be merging a variable and
        a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                # First use of this variable becomes the canonical one.
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                    lambda other_bit_size, bit_size:
                        'Variable {} has conflicting bit size requirements: ' \
                        'it must have bit size {} and {}'.format(
                            val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate an expression by performing classic Hindley-Milner
        type inference on bitsizes.  This will detect if there are any
        conflicting requirements, and unify variables so that we know which
        variables must have the same bitsize.  If we're operating on the
        replace expression, we will refuse to merge different variables
        together or merge a variable with a constant, in order to prevent
        surprises due to rules unexpectedly not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        # Generic conversion ops are special in that they have a single unsized
        # source and an unsized destination and the two don't have to match.
        # This means there's no validation or unioning to do here besides the
        # len(val.sources) check.
        if val.opcode in conv_opcode_types:
            assert len(val.sources) == 1, \
                   "Expression {} has {} sources, expected 1".format(
                       val, len(val.sources))
            self.validate_value(val.sources[0])
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
               "Expression {} has {} sources, expected {}".format(
                   val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources.  That way, an error coming up because
        # two sources have an incompatible bit-size won't produce an error
        # message involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} ' \
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                            'of {} may not have the same bit size when building the ' \
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '\
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of ' \
                            'nir_op_{} it must have {} bits, which may not be the ' \
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        # Now unify the destination with either the unsized-source class or
        # its fixed output size.
        if dst_type_bits == 0:
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} ' \
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} ' \
                            '(bit size of {}) may not have that bit size ' \
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} ' \
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        """Check that every replacement subexpression ended up with a bit size
        deducible at build time: a constant, a variable's, or the search
        root's."""
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
               bit_size == search.get_bit_size(), \
               'Ambiguous bit size for replacement value {}: ' \
               'it cannot be deduced from a variable, a fixed bit size ' \
               'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        """Entry point: run inference over the search tree (permissive mode),
        then the replace tree (strict mode), then tie the two roots together
        and check the result."""
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace.  Note
        # that we're doing this in replace mode, disallowing merging
        # variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
               'The search expression bit size {} and replace expression ' \
               'bit size {} may not be the same'.format(
                   search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)
736
# Monotonically increasing ids, one per SearchAndReplace instance; used to
# build unique C identifiers like "search42"/"replace42".
_optimization_ids = itertools.count()

# Deduplicated list of C condition expressions gating transforms; index 0 is
# the always-enabled 'true' condition.
condition_list = ['true']
740
class SearchAndReplace(object):
    """One algebraic transform: a (search, replace[, condition]) tuple turned
    into validated pattern trees sharing a single variable set."""

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search, replace = transform[0], transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        # Deduplicate conditions globally; the generated C indexes into the
        # shared condition table.
        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        self.search = search if isinstance(search, Expression) \
            else Expression(search, "search{0}".format(self.id), varset)

        # The replace side may not introduce variables the search side
        # doesn't bind.
        varset.lock()

        self.replace = replace if isinstance(replace, Value) \
            else Value.create(replace, "replace{0}".format(self.id), varset)

        # Reject transforms with ambiguous or contradictory bit sizes up
        # front, at generation time.
        BitSizeValidator(varset).validate(self.search, self.replace)
770
771 class TreeAutomaton(object):
772 """This class calculates a bottom-up tree automaton to quickly search for
773 the left-hand sides of tranforms. Tree automatons are a generalization of
774 classical NFA's and DFA's, where the transition function determines the
775 state of the parent node based on the state of its children. We construct a
776 deterministic automaton to match patterns, using a similar algorithm to the
777 classical NFA to DFA construction. At the moment, it only matches opcodes
778 and constants (without checking the actual value), leaving more detailed
779 checking to the search function which actually checks the leaves. The
780 automaton acts as a quick filter for the search function, requiring only n
781 + 1 table lookups for each n-source operation. The implementation is based
782 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
783 In the language of that reference, this is a frontier-to-root deterministic
784 automaton using only symbol filtering. The filtering is crucial to reduce
785 both the time taken to generate the tables and the size of the tables.
786 """
def __init__(self, transforms):
    # One search pattern per transform; a pattern's list index doubles as
    # its id throughout the automaton construction.
    self.patterns = [t.search for t in transforms]
    self._compute_items()
    self._build_table()
    #print('num items: {}'.format(len(set(self.items.values()))))
    #print('num states: {}'.format(len(self.states)))
    #for state, patterns in zip(self.states, self.patterns):
    #   print('{}: num patterns: {}'.format(state, len(patterns)))
795
class IndexMap(object):
    """A list with a reverse index: objects are looked up by position via
    [], and positions are looked up by object via index() in O(1).  Unlike
    a set, iteration order is the stable insertion order."""

    def __init__(self, iterable=()):
        self.objects = []
        self.map = {}
        for obj in iterable:
            self.add(obj)

    def __getitem__(self, i):
        return self.objects[i]

    def __contains__(self, obj):
        return obj in self.map

    def __len__(self):
        return len(self.objects)

    def __iter__(self):
        return iter(self.objects)

    def clear(self):
        self.objects = []
        self.map.clear()

    def index(self, obj):
        """Position of obj; raises KeyError if absent."""
        return self.map[obj]

    def add(self, obj):
        """Append obj if unseen; return its position either way."""
        existing = self.map.get(obj)
        if existing is not None:
            return existing
        position = len(self.objects)
        self.map[obj] = position
        self.objects.append(obj)
        return position

    def __repr__(self):
        return 'IndexMap([{}])'.format(', '.join(repr(e) for e in self.objects))
838
class Item(object):
    """An "item" in the tree-automaton sense: a deduplicated pattern
    subtree representing a potential partial match at runtime.  Identical
    subtrees of different patterns share one Item, which also records the
    patterns it roots and the opcodes of its parents (for filtering)."""

    def __init__(self, opcode, children):
        self.opcode = opcode
        self.children = children
        # Indices of the patterns whose root node is this item.
        self.patterns = []
        # Opcodes of every parent of this item; speeds up filtering.
        self.parent_ops = set()

    def __str__(self):
        parts = [self.opcode]
        parts.extend(str(child) for child in self.children)
        return '(' + ', '.join(parts) + ')'

    def __repr__(self):
        return self.__str__()
860
def _compute_items(self):
    """Build a set of all possible items, deduplicating them."""
    # This is a map from (opcode, sources) to item.
    self.items = {}

    # The set of all opcodes used by the patterns.  Used later to avoid
    # building and emitting all the tables for opcodes that aren't used.
    self.opcodes = self.IndexMap()

    def get_item(opcode, children, pattern=None):
        # Note: `opcodes` here is the module-level nir opcode table, not
        # self.opcodes.
        commutative = len(children) >= 2 \
            and "2src_commutative" in opcodes[opcode].algebraic_properties
        item = self.items.setdefault((opcode, children),
                                     self.Item(opcode, children))
        if commutative:
            # Register the swapped-source key for the same item so a
            # commuted match resolves to it too.
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
        if pattern is not None:
            item.patterns.append(pattern)
        return item

    # Sentinel items: any value at all, and any constant.
    self.wildcard = get_item("__wildcard", ())
    self.const = get_item("__const", ())

    def process_subpattern(src, pattern=None):
        if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
        elif isinstance(src, Variable):
            if src.is_constant:
                return self.const
            else:
                # Note: we throw away which variable it is here!  This
                # special item is equivalent to nu in "Tree Automatons."
                return self.wildcard
        else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
                # Matches that use conversion opcodes with a specific type,
                # like f2b1, are tricky.  Either we construct the automaton to
                # match specific NIR opcodes like nir_op_f2b1, in which case we
                # need to create separate items for each possible NIR opcode
                # for patterns that have a generic opcode like f2b, or we
                # construct it to match the search opcode, in which case we
                # need to map f2b1 to f2b when constructing the automaton.
                # Here we do the latter.
                opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
                child.parent_ops.add(opcode)
            return item

    # Only the root call records the pattern index; subtrees pass None.
    for i, pattern in enumerate(self.patterns):
        process_subpattern(pattern, i)
918
   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.

      Produces self.table, self.states, self.state_patterns, self.filter and
      self.rep, which the mako template below emits as C tables.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # List of pattern matches for each state index.
      self.state_patterns = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with a index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly fitered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(lambda: 0)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]

            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))
            assert len(self.state_patterns) == self.worklist_index
            self.state_patterns.append(patterns)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  # A never-before-seen filtered state: adding it to the rep
                  # IndexMap implicitly puts it on op's filtered worklist.
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const,self.wildcard)))
      process_new_states()

      # Fixed-point loop: keep extending the transition tables until no new
      # filtered states are discovered.
      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            # Every filtered state for this opcode has now been paired.
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()
1030
# Mako template for the C source of a generated algebraic pass.  Rendered by
# AlgebraicPass.render() with: pass_name (name of the generated entry point),
# xforms (the parsed SearchAndReplace transforms), condition_list (C
# expressions gating each transform), automaton (the TreeAutomaton whose
# filter/transition tables drive state assignment), get_c_opcode and
# itertools.  Everything outside ${...}/%-directives is emitted verbatim as C.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

struct per_op_table {
   const uint16_t *filter;
   unsigned num_filtered_states;
   const uint16_t *table;
};

/* Note: these must match the start states created in
 * TreeAutomaton._build_table()
 */

/* WILDCARD_STATE = 0 is set by zeroing the state array */
static const uint16_t CONST_STATE = 1;

#endif

<% cache = {} %>
% for xform in xforms:
   ${xform.search.render(cache)}
   ${xform.replace.render(cache)}
% endfor

% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
% for i in state_xforms:
   { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
% endif
% endfor

static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
      .filter = (uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

static void
${pass_name}_pre_block(nir_block *block, uint16_t *states)
{
   nir_foreach_instr(instr, block) {
      switch (instr->type) {
      case nir_instr_type_alu: {
         nir_alu_instr *alu = nir_instr_as_alu(instr);
         nir_op op = alu->op;
         uint16_t search_op = nir_search_op_for_nir_op(op);
         const struct per_op_table *tbl = &${pass_name}_table[search_op];
         if (tbl->num_filtered_states == 0)
            continue;

         /* Calculate the index into the transition table. Note the index
          * calculated must match the iteration order of Python's
          * itertools.product(), which was used to emit the transition
          * table.
          */
         uint16_t index = 0;
         for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
            index *= tbl->num_filtered_states;
            index += tbl->filter[states[alu->src[i].src.ssa->index]];
         }
         states[alu->dest.dest.ssa.index] = tbl->table[index];
         break;
      }

      case nir_instr_type_load_const: {
         nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
         states[load_const->def.index] = CONST_STATE;
         break;
      }

      default:
         break;
      }
   }
}

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
                   const uint16_t *states, const bool *condition_flags)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (states[alu->dest.dest.ssa.index]) {
% for i in range(len(automaton.state_patterns)):
      case ${i}:
         % if automaton.state_patterns[i]:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
            const struct transform *xform = &${pass_name}_state${i}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(build, alu, xform->search, xform->replace)) {
               progress = true;
               break;
            }
         }
         % endif
         break;
% endfor
      default: assert(0);
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   bool progress = false;

   nir_builder build;
   nir_builder_init(&build, impl);

   /* Note: it's important here that we're allocating a zeroed array, since
    * state 0 is the default state, which means we don't have to visit
    * anything other than constants and ALU instructions.
    */
   uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));

   nir_foreach_block(block, impl) {
      ${pass_name}_pre_block(block, states);
   }

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(&build, block, states, condition_flags);
   }

   free(states);

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);
   } else {
#ifndef NDEBUG
      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
#endif
   }

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
1242
1243
class AlgebraicPass(object):
   """Compiles a list of search-and-replace transforms into the C source of a
   NIR algebraic-optimization pass named *pass_name*.

   Each transform is parsed into a SearchAndReplace (unless already one),
   indexed by search opcode, validated against the commutative-expression
   limit of nir_search.c, and fed into a TreeAutomaton.  Any parse or
   validation failure is reported on stderr and aborts via sys.exit(1).
   """
   def __init__(self, pass_name, transforms):
      self.xforms = []
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit still propagate.
            except Exception:
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # A generic conversion opcode (e.g. f2f) stands for every sized
            # variant, so register the transform under each sized opcode.
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               # Bug fix: this used to reference the misspelled name
               # nir_search_max_comm_op, raising NameError whenever this
               # diagnostic fired.
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}). Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      """Render and return the generated C source for this pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools)