1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
38 # These opcodes are only employed by nir_search. This provides a mapping from
39 # opcode to destination type.
40 conv_opcode_types = {
41 'i2f' : 'float',
42 'u2f' : 'float',
43 'f2f' : 'float',
44 'f2u' : 'uint',
45 'f2i' : 'int',
46 'u2u' : 'uint',
47 'i2i' : 'int',
48 'b2f' : 'float',
49 'b2i' : 'int',
50 'i2b' : 'bool',
51 'f2b' : 'bool',
52 }
53
54 def get_c_opcode(op):
55 if op in conv_opcode_types:
56 return 'nir_search_op_' + op
57 else:
58 return 'nir_op_' + op
59
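# Purely illustrative (comments only): generic conversion opcodes resolve to
# nir_search's own enum while everything else resolves to a real NIR opcode,
# e.g. get_c_opcode('f2i') returns 'nir_search_op_f2i' and
# get_c_opcode('fadd') returns 'nir_op_fadd'.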
60
61 if sys.version_info < (3, 0):
62 integer_types = (int, long)
63 string_type = unicode
64
65 else:
66 integer_types = (int, )
67 string_type = str
68
69 _type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
70
71 def type_bits(type_str):
72 m = _type_re.match(type_str)
73 assert m.group('type')
74
75 if m.group('bits') is None:
76 return 0
77 else:
78 return int(m.group('bits'))
79
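# For example (comments only): type_bits('int32') == 32,
# type_bits('uint16') == 16, and type_bits('float') == 0, since an unsized
# type carries no bit count.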
80 # Represents a set of variables, each with a unique id
81 class VarSet(object):
82 def __init__(self):
83 self.names = {}
84 self.ids = itertools.count()
85 self.immutable = False
86
87 def __getitem__(self, name):
88 if name not in self.names:
89 assert not self.immutable, "Unknown replacement variable: " + name
90 self.names[name] = next(self.ids)
91
92 return self.names[name]
93
94 def lock(self):
95 self.immutable = True
96
97 class Value(object):
98 @staticmethod
99 def create(val, name_base, varset):
100 if isinstance(val, bytes):
101 val = val.decode('utf-8')
102
103 if isinstance(val, tuple):
104 return Expression(val, name_base, varset)
105 elif isinstance(val, Expression):
106 return val
107 elif isinstance(val, string_type):
108 return Variable(val, name_base, varset)
109 elif isinstance(val, (bool, float) + integer_types):
110 return Constant(val, name_base)
111
112 def __init__(self, val, name, type_str):
113 self.in_val = str(val)
114 self.name = name
115 self.type_str = type_str
116
117 def __str__(self):
118 return self.in_val
119
120 def get_bit_size(self):
121 """Get the physical bit-size that has been chosen for this value, or if
122 there is none, the canonical value which currently represents this
123 bit-size class. Variables will be preferred, i.e. if there are any
124 variables in the equivalence class, the canonical value will be a
125 variable. We do this since we'll need to know which variable each value
126 is equivalent to when constructing the replacement expression. This is
127 the "find" part of the union-find algorithm.
128 """
129 bit_size = self
130
131 while isinstance(bit_size, Value):
132 if bit_size._bit_size is None:
133 break
134 bit_size = bit_size._bit_size
135
136 if bit_size is not self:
137 self._bit_size = bit_size
138 return bit_size
139
140 def set_bit_size(self, other):
141 """Make self.get_bit_size() return what other.get_bit_size() return
142 before calling this, or just "other" if it's a concrete bit-size. This is
143 the "union" part of the union-find algorithm.
144 """
145
146 self_bit_size = self.get_bit_size()
147 other_bit_size = other if isinstance(other, int) else other.get_bit_size()
148
149 if self_bit_size == other_bit_size:
150 return
151
152 self_bit_size._bit_size = other_bit_size
153
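# A rough sketch of the union-find behaviour described above (comments only,
# not executed): with c0 = Constant('1', 'c0') and c1 = Constant('2@32', 'c1'),
# calling c0.set_bit_size(c1) points c0's bit-size class at the concrete size
# 32 that c1 already resolves to, so a later c0.get_bit_size() returns 32; the
# "find" step also flattens longer chains onto the final representative.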
154 @property
155 def type_enum(self):
156 return "nir_search_value_" + self.type_str
157
158 @property
159 def c_type(self):
160 return "nir_search_" + self.type_str
161
162 def __c_name(self, cache):
163 if cache is not None and self.name in cache:
164 return cache[self.name]
165 else:
166 return self.name
167
168 def c_value_ptr(self, cache):
169 return "&{0}.value".format(self.__c_name(cache))
170
171 def c_ptr(self, cache):
172 return "&{0}".format(self.__c_name(cache))
173
174 @property
175 def c_bit_size(self):
176 bit_size = self.get_bit_size()
177 if isinstance(bit_size, int):
178 return bit_size
179 elif isinstance(bit_size, Variable):
180 return -bit_size.index - 1
181 else:
182 # If the bit-size class is neither a variable, nor an actual bit-size, then
183 # - If it's in the search expression, we don't need to check anything
184 # - If it's in the replace expression, either it's ambiguous (in which
185 # case we'd reject it), or it equals the bit-size of the search value
186 # We represent these cases with a 0 bit-size.
187 return 0
188
189 __template = mako.template.Template("""{
190 { ${val.type_enum}, ${val.c_bit_size} },
191 % if isinstance(val, Constant):
192 ${val.type()}, { ${val.hex()} /* ${val.value} */ },
193 % elif isinstance(val, Variable):
194 ${val.index}, /* ${val.var_name} */
195 ${'true' if val.is_constant else 'false'},
196 ${val.type() or 'nir_type_invalid' },
197 ${val.cond if val.cond else 'NULL'},
198 % elif isinstance(val, Expression):
199 ${'true' if val.inexact else 'false'},
200 ${val.comm_expr_idx}, ${val.comm_exprs},
201 ${val.c_opcode()},
202 { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
203 ${val.cond if val.cond else 'NULL'},
204 % endif
205 };""")
206
207 def render(self, cache):
208 struct_init = self.__template.render(val=self, cache=cache,
209 Constant=Constant,
210 Variable=Variable,
211 Expression=Expression)
212 if cache is not None and struct_init in cache:
213 # If it's in the cache, register a name remap in the cache and render
214 # only a comment saying it's been remapped
215 cache[self.name] = cache[struct_init]
216 return "/* {} -> {} in the cache */\n".format(self.name,
217 cache[struct_init])
218 else:
219 if cache is not None:
220 cache[struct_init] = self.name
221 return "static const {} {} = {}\n".format(self.c_type, self.name,
222 struct_init)
223
224 _constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
225
226 class Constant(Value):
227 def __init__(self, val, name):
228 Value.__init__(self, val, name, "constant")
229
230 if isinstance(val, (str)):
231 m = _constant_re.match(val)
232 self.value = ast.literal_eval(m.group('value'))
233 self._bit_size = int(m.group('bits')) if m.group('bits') else None
234 else:
235 self.value = val
236 self._bit_size = None
237
238 if isinstance(self.value, bool):
239 assert self._bit_size is None or self._bit_size == 1
240 self._bit_size = 1
241
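# For example (comments only), the constant syntax parsed above behaves
# roughly like this:
#
#   Constant('-2@16', 'c')   # .value == -2, ._bit_size == 16
#   Constant(True, 'c')      # .value is True, ._bit_size == 1
#   Constant(0.5, 'c')       # .value == 0.5, ._bit_size is None (unsized)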
242 def hex(self):
243 if isinstance(self.value, (bool)):
244 return 'NIR_TRUE' if self.value else 'NIR_FALSE'
245 if isinstance(self.value, integer_types):
246 return hex(self.value)
247 elif isinstance(self.value, float):
248 i = struct.unpack('Q', struct.pack('d', self.value))[0]
249 h = hex(i)
250
251 # On Python 2 this 'L' suffix is automatically added, but not on Python 3
252 # Adding it explicitly makes the generated file identical, regardless
253 # of the Python version running this script.
254 if h[-1] != 'L' and i > sys.maxsize:
255 h += 'L'
256
257 return h
258 else:
259 assert False
260
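# A worked example of the float path above (not executed here): 1.0 packs to
# the IEEE-754 double bit pattern 0x3ff0000000000000, i.e.
#
#   struct.unpack('Q', struct.pack('d', 1.0))[0] == 0x3ff0000000000000
#
# so Constant(1.0, ...).hex() emits that literal into the generated tables.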
261 def type(self):
262 if isinstance(self.value, (bool)):
263 return "nir_type_bool"
264 elif isinstance(self.value, integer_types):
265 return "nir_type_int"
266 elif isinstance(self.value, float):
267 return "nir_type_float"
268
269 _var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
270 r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
271 r"(?P<cond>\([^\)]+\))?")
272
273 class Variable(Value):
274 def __init__(self, val, name, varset):
275 Value.__init__(self, val, name, "variable")
276
277 m = _var_name_re.match(val)
278 assert m and m.group('name') is not None
279
280 self.var_name = m.group('name')
281
282 # Prevent common cases where someone puts quotes around a literal
283 # constant. If we want to support names that have numeric or
284 # punctuation characters, we can make the first assertion more flexible.
285 assert self.var_name.isalpha()
286 assert self.var_name != 'True'
287 assert self.var_name != 'False'
288
289 self.is_constant = m.group('const') is not None
290 self.cond = m.group('cond')
291 self.required_type = m.group('type')
292 self._bit_size = int(m.group('bits')) if m.group('bits') else None
293
294 if self.required_type == 'bool':
295 if self._bit_size is not None:
296 assert self._bit_size in type_sizes(self.required_type)
297 else:
298 self._bit_size = 1
299
300 if self.required_type is not None:
301 assert self.required_type in ('float', 'bool', 'int', 'uint')
302
303 self.index = varset[self.var_name]
304
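# For example (illustrative only; 'some_c_predicate' is a made-up helper name):
#
#   Variable('#b@32(some_c_predicate)', 'v0', varset)
#
# parses as a variable named "b" that must be a constant (the leading '#'),
# must be 32 bits wide, and must satisfy the named C-side predicate, while a
# plain 'a' is an unsized, unconstrained variable.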
305 def type(self):
306 if self.required_type == 'bool':
307 return "nir_type_bool"
308 elif self.required_type in ('int', 'uint'):
309 return "nir_type_int"
310 elif self.required_type == 'float':
311 return "nir_type_float"
312
313 _opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
314 r"(?P<cond>\([^\)]+\))?")
315
316 class Expression(Value):
317 def __init__(self, expr, name_base, varset):
318 Value.__init__(self, expr, name_base, "expression")
319 assert isinstance(expr, tuple)
320
321 m = _opcode_re.match(expr[0])
322 assert m and m.group('opcode') is not None
323
324 self.opcode = m.group('opcode')
325 self._bit_size = int(m.group('bits')) if m.group('bits') else None
326 self.inexact = m.group('inexact') is not None
327 self.cond = m.group('cond')
328 self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
329 for (i, src) in enumerate(expr[1:]) ]
330
331 if self.opcode in conv_opcode_types:
332 assert self._bit_size is None, \
333 'Expression cannot use an unsized conversion opcode with ' \
334 'an explicit size; that\'s silly.'
335
336 self.__index_comm_exprs(0)
337
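# For example (comments only), given the opcode syntax parsed above,
#
#   ('~fadd@32', ('fmul', 'a', 'b'), 'c')
#
# matches an inexact ('~') 32-bit fadd whose first source is an fmul; the
# commutative-expression indexing below lets the generated matcher try both
# source orders of fadd and fmul without duplicating patterns.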
338 def __index_comm_exprs(self, base_idx):
339 """Recursively count and index commutative expressions
340 """
341 self.comm_exprs = 0
342 if self.opcode not in conv_opcode_types and \
343 "2src_commutative" in opcodes[self.opcode].algebraic_properties:
344 self.comm_expr_idx = base_idx
345 self.comm_exprs += 1
346 else:
347 self.comm_expr_idx = -1
348
349 for s in self.sources:
350 if isinstance(s, Expression):
351 s.__index_comm_exprs(base_idx + self.comm_exprs)
352 self.comm_exprs += s.comm_exprs
353
354 return self.comm_exprs
355
356 def c_opcode(self):
357 return get_c_opcode(self.opcode)
358
359 def render(self, cache):
360 srcs = "\n".join(src.render(cache) for src in self.sources)
361 return srcs + super(Expression, self).render(cache)
362
363 class BitSizeValidator(object):
364 """A class for validating bit sizes of expressions.
365
366 NIR supports multiple bit-sizes on expressions in order to handle things
367 such as fp64. The source and destination of every ALU operation is
368 assigned a type and that type may or may not specify a bit size. Sources
369 and destinations whose type does not specify a bit size are considered
370 "unsized" and automatically take on the bit size of the corresponding
371 register or SSA value. NIR has two simple rules for bit sizes that are
372 validated by nir_validator:
373
374 1) A given SSA def or register has a single bit size that is respected by
375 everything that reads from it or writes to it.
376
377 2) The bit sizes of all unsized inputs/outputs on any given ALU
378 instruction must match. They need not match the sized inputs or
379 outputs but they must match each other.
380
381 In order to keep nir_algebraic relatively simple and easy-to-use,
382 nir_search supports a type of bit-size inference based on the two rules
383 above. This is similar to type inference in many common programming
384 languages. If, for instance, you are constructing an add operation and you
385 know the second source is 16-bit, then you know that the other source and
386 the destination must also be 16-bit. There are, however, cases where this
387 inference can be ambiguous or contradictory. Consider, for instance, the
388 following transformation:
389
390 (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
391
392 This transformation can potentially cause a problem because usub_borrow is
393 well-defined for any bit-size of integer. However, b2i always generates a
394 32-bit result so it could end up replacing a 64-bit expression with one
395 that takes two 64-bit values and produces a 32-bit value. As another
396 example, consider this expression:
397
398 (('bcsel', a, b, 0), ('iand', a, b))
399
400 In this case, in the search expression a must be 32-bit but b can
401 potentially have any bit size. If we had a 64-bit b value, we would end up
402 trying to 'iand' a 32-bit value with a 64-bit value, which would be invalid.
403
404 This class solves that problem by providing a validation layer that proves
405 that a given search-and-replace operation is 100% well-defined before we
406 generate any code. This ensures that bugs are caught at compile time
407 rather than at run time.
408
409 Each value maintains a "bit-size class", which is either an actual bit size
410 or an equivalence class with other values that must have the same bit size.
411 The validator works by combining bit-size classes with each other according
412 to the NIR rules outlined above, checking that there are no inconsistencies.
413 When doing this for the replacement expression, we make sure to never change
414 the equivalence class of any of the search values. We could make the example
415 transforms above work by doing some extra run-time checking of the search
416 expression, but we make the user specify those constraints themselves, to
417 avoid any surprises. Since the replacement bitsizes can only be connected to
418 the source bitsize via variables (variables must have the same bitsize in
419 the source and replacement expressions) or the roots of the expression (the
420 replacement expression must produce the same bit size as the search
421 expression), we prevent merging a variable with anything when processing the
422 replacement expression, or specializing the search bitsize
423 with anything. The former prevents
424
425 (('bcsel', a, b, 0), ('iand', a, b))
426
427 from being allowed, since we'd have to merge the bitsizes for a and b due to
428 the 'iand', while the latter prevents
429
430 (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))
431
432 from being allowed, since the search expression has the bit size of a and b,
433 which can't be specialized to 32 which is the bitsize of the replace
434 expression. It also prevents something like:
435
436 (('b2i', ('i2b', a)), ('ineq', a, 0))
437
438 since the bitsize of 'b2i', which can be anything, can't be specialized to
439 the bitsize of a.
440
441 After doing all this, we check that every subexpression of the replacement
442 was assigned a constant bitsize, the bitsize of a variable, or the bitsize
443 of the search expression, since those are the things that are known when
444 constructing the replacement expression. Finally, we record the bitsize
445 needed in nir_search_value so that we know what to do when building the
446 replacement expression.
447 """
448
449 def __init__(self, varset):
450 self._var_classes = [None] * len(varset.names)
451
452 def compare_bitsizes(self, a, b):
453 """Determines which bitsize class is a specialization of the other, or
454 whether neither is. When we merge two different bitsizes, the
455 less-specialized bitsize always points to the more-specialized one, so
456 that calling get_bit_size() always gets you the most specialized bitsize.
457 The specialization partial order is given by:
458 - Physical bitsizes are always the most specialized, and a different
459 bitsize can never specialize another.
460 - In the search expression, variables can always be specialized to each
461 other and to physical bitsizes. In the replace expression, we disallow
462 this to avoid adding extra constraints to the search expression that
463 the user didn't specify.
464 - Expressions and constants without a bitsize can always be specialized to
465 each other and variables, but not the other way around.
466
467 We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
468 and None if they are not comparable (neither a <= b nor b <= a).
469 """
470 if isinstance(a, int):
471 if isinstance(b, int):
472 return 0 if a == b else None
473 elif isinstance(b, Variable):
474 return -1 if self.is_search else None
475 else:
476 return -1
477 elif isinstance(a, Variable):
478 if isinstance(b, int):
479 return 1 if self.is_search else None
480 elif isinstance(b, Variable):
481 return 0 if self.is_search or a.index == b.index else None
482 else:
483 return -1
484 else:
485 if isinstance(b, int):
486 return 1
487 elif isinstance(b, Variable):
488 return 1
489 else:
490 return 0
491
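# Illustration of the partial order above (comments only): in the search
# expression, compare_bitsizes(some_variable, 32) == 1 (the variable may be
# specialized to 32 bits), compare_bitsizes(32, 32) == 0, and
# compare_bitsizes(32, 64) is None, since two different physical sizes can
# never be unified.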
492 def unify_bit_size(self, a, b, error_msg):
493 """Record that a must have the same bit-size as b. If both
494 have been assigned conflicting physical bit-sizes, call "error_msg" with
495 the bit-sizes of a and b to get a message and raise an error.
496 In the replace expression, disallow merging variables with other
497 variables and physical bit-sizes as well.
498 """
499 a_bit_size = a.get_bit_size()
500 b_bit_size = b if isinstance(b, int) else b.get_bit_size()
501
502 cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)
503
504 assert cmp_result is not None, \
505 error_msg(a_bit_size, b_bit_size)
506
507 if cmp_result < 0:
508 b_bit_size.set_bit_size(a)
509 elif not isinstance(a_bit_size, int):
510 a_bit_size.set_bit_size(b)
511
512 def merge_variables(self, val):
513 """Perform the first part of type inference by merging all the different
514 uses of the same variable. We always do this as if we're in the search
515 expression, even if we're actually not, since otherwise we'd get errors
516 if the search expression specified some constraint but the replace
517 expression didn't, because we'd be merging a variable and a constant.
518 """
519 if isinstance(val, Variable):
520 if self._var_classes[val.index] is None:
521 self._var_classes[val.index] = val
522 else:
523 other = self._var_classes[val.index]
524 self.unify_bit_size(other, val,
525 lambda other_bit_size, bit_size:
526 'Variable {} has conflicting bit size requirements: ' \
527 'it must have bit size {} and {}'.format(
528 val.var_name, other_bit_size, bit_size))
529 elif isinstance(val, Expression):
530 for src in val.sources:
531 self.merge_variables(src)
532
533 def validate_value(self, val):
534 """Validate the an expression by performing classic Hindley-Milner
535 type inference on bitsizes. This will detect if there are any conflicting
536 requirements, and unify variables so that we know which variables must
537 have the same bitsize. If we're operating on the replace expression, we
538 will refuse to merge different variables together or merge a variable
539 with a constant, in order to prevent surprises due to rules unexpectedly
540 not matching at runtime.
541 """
542 if not isinstance(val, Expression):
543 return
544
545 # Generic conversion ops are special in that they have a single unsized
546 # source and an unsized destination and the two don't have to match.
547 # This means there's no validation or unioning to do here besides the
548 # len(val.sources) check.
549 if val.opcode in conv_opcode_types:
550 assert len(val.sources) == 1, \
551 "Expression {} has {} sources, expected 1".format(
552 val, len(val.sources))
553 self.validate_value(val.sources[0])
554 return
555
556 nir_op = opcodes[val.opcode]
557 assert len(val.sources) == nir_op.num_inputs, \
558 "Expression {} has {} sources, expected {}".format(
559 val, len(val.sources), nir_op.num_inputs)
560
561 for src in val.sources:
562 self.validate_value(src)
563
564 dst_type_bits = type_bits(nir_op.output_type)
565
566 # First, unify all the sources. That way, an error coming up because two
567 # sources have an incompatible bit-size won't produce an error message
568 # involving the destination.
569 first_unsized_src = None
570 for src_type, src in zip(nir_op.input_types, val.sources):
571 src_type_bits = type_bits(src_type)
572 if src_type_bits == 0:
573 if first_unsized_src is None:
574 first_unsized_src = src
575 continue
576
577 if self.is_search:
578 self.unify_bit_size(first_unsized_src, src,
579 lambda first_unsized_src_bit_size, src_bit_size:
580 'Source {} of {} must have bit size {}, while source {} ' \
581 'must have incompatible bit size {}'.format(
582 first_unsized_src, val, first_unsized_src_bit_size,
583 src, src_bit_size))
584 else:
585 self.unify_bit_size(first_unsized_src, src,
586 lambda first_unsized_src_bit_size, src_bit_size:
587 'Sources {} (bit size of {}) and {} (bit size of {}) ' \
588 'of {} may not have the same bit size when building the ' \
589 'replacement expression.'.format(
590 first_unsized_src, first_unsized_src_bit_size, src,
591 src_bit_size, val))
592 else:
593 if self.is_search:
594 self.unify_bit_size(src, src_type_bits,
595 lambda src_bit_size, unused:
596 '{} must have {} bits, but as a source of nir_op_{} '\
597 'it must have {} bits'.format(
598 src, src_bit_size, nir_op.name, src_type_bits))
599 else:
600 self.unify_bit_size(src, src_type_bits,
601 lambda src_bit_size, unused:
602 '{} has the bit size of {}, but as a source of ' \
603 'nir_op_{} it must have {} bits, which may not be the ' \
604 'same'.format(
605 src, src_bit_size, nir_op.name, src_type_bits))
606
607 if dst_type_bits == 0:
608 if first_unsized_src is not None:
609 if self.is_search:
610 self.unify_bit_size(val, first_unsized_src,
611 lambda val_bit_size, src_bit_size:
612 '{} must have the bit size of {}, while its source {} ' \
613 'must have incompatible bit size {}'.format(
614 val, val_bit_size, first_unsized_src, src_bit_size))
615 else:
616 self.unify_bit_size(val, first_unsized_src,
617 lambda val_bit_size, src_bit_size:
618 '{} must have {} bits, but its source {} ' \
619 '(bit size of {}) may not have that bit size ' \
620 'when building the replacement.'.format(
621 val, val_bit_size, first_unsized_src, src_bit_size))
622 else:
623 self.unify_bit_size(val, dst_type_bits,
624 lambda dst_bit_size, unused:
625 '{} must have {} bits, but as a destination of nir_op_{} ' \
626 'it must have {} bits'.format(
627 val, dst_bit_size, nir_op.name, dst_type_bits))
628
629 def validate_replace(self, val, search):
630 bit_size = val.get_bit_size()
631 assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
632 bit_size == search.get_bit_size(), \
633 'Ambiguous bit size for replacement value {}: ' \
634 'it cannot be deduced from a variable, a fixed bit size ' \
635 'somewhere, or the search expression.'.format(val)
636
637 if isinstance(val, Expression):
638 for src in val.sources:
639 self.validate_replace(src, search)
640
641 def validate(self, search, replace):
642 self.is_search = True
643 self.merge_variables(search)
644 self.merge_variables(replace)
645 self.validate_value(search)
646
647 self.is_search = False
648 self.validate_value(replace)
649
650 # Check that search is always more specialized than replace. Note that
651 # we're doing this in replace mode, disallowing merging variables.
652 search_bit_size = search.get_bit_size()
653 replace_bit_size = replace.get_bit_size()
654 cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)
655
656 assert cmp_result is not None and cmp_result <= 0, \
657 'The search expression bit size {} and replace expression ' \
658 'bit size {} may not be the same'.format(
659 search_bit_size, replace_bit_size)
660
661 replace.set_bit_size(search)
662
663 self.validate_replace(replace, search)
664
665 _optimization_ids = itertools.count()
666
667 condition_list = ['true']
668
669 class SearchAndReplace(object):
670 def __init__(self, transform):
671 self.id = next(_optimization_ids)
672
673 search = transform[0]
674 replace = transform[1]
675 if len(transform) > 2:
676 self.condition = transform[2]
677 else:
678 self.condition = 'true'
679
680 if self.condition not in condition_list:
681 condition_list.append(self.condition)
682 self.condition_index = condition_list.index(self.condition)
683
684 varset = VarSet()
685 if isinstance(search, Expression):
686 self.search = search
687 else:
688 self.search = Expression(search, "search{0}".format(self.id), varset)
689
690 varset.lock()
691
692 if isinstance(replace, Value):
693 self.replace = replace
694 else:
695 self.replace = Value.create(replace, "replace{0}".format(self.id), varset)
696
697 BitSizeValidator(varset).validate(self.search, self.replace)
698
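# Purely illustrative of the tuple format consumed here (the option flag name
# below is made up): a transform is (search, replace) with an optional C
# condition string appended, e.g.
#
#   (('fadd', 'a', 0.0), 'a')
#   (('fmul', 'a', 2.0), ('fadd', 'a', 'a'), 'options->lower_some_made_up_flag')
#
# The condition string is evaluated in the generated pass and gates the
# transform at run time via condition_flags.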
699 class TreeAutomaton(object):
700 """This class calculates a bottom-up tree automaton to quickly search for
701 the left-hand sides of transforms. Tree automatons are a generalization of
702 classical NFA's and DFA's, where the transition function determines the
703 state of the parent node based on the state of its children. We construct a
704 deterministic automaton to match patterns, using a similar algorithm to the
705 classical NFA to DFA construction. At the moment, it only matches opcodes
706 and constants (without checking the actual value), leaving more detailed
707 checking to the search function which actually checks the leaves. The
708 automaton acts as a quick filter for the search function, requiring only n
709 + 1 table lookups for each n-source operation. The implementation is based
710 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
711 In the language of that reference, this is a frontier-to-root deterministic
712 automaton using only symbol filtering. The filtering is crucial to reduce
713 both the time taken to generate the tables and the size of the tables.
714 """
715 def __init__(self, transforms):
716 self.patterns = [t.search for t in transforms]
717 self._compute_items()
718 self._build_table()
719 #print('num items: {}'.format(len(set(self.items.values()))))
720 #print('num states: {}'.format(len(self.states)))
721 #for state, patterns in zip(self.states, self.patterns):
722 # print('{}: num patterns: {}'.format(state, len(patterns)))
723
724 class IndexMap(object):
725 """An indexed list of objects, where one can either lookup an object by
726 index or find the index associated to an object quickly using a hash
727 table. Compared to a list, it has a constant time index(). Compared to a
728 set, it provides a stable iteration order.
729 """
730 def __init__(self, iterable=()):
731 self.objects = []
732 self.map = {}
733 for obj in iterable:
734 self.add(obj)
735
736 def __getitem__(self, i):
737 return self.objects[i]
738
739 def __contains__(self, obj):
740 return obj in self.map
741
742 def __len__(self):
743 return len(self.objects)
744
745 def __iter__(self):
746 return iter(self.objects)
747
748 def clear(self):
749 self.objects = []
750 self.map.clear()
751
752 def index(self, obj):
753 return self.map[obj]
754
755 def add(self, obj):
756 if obj in self.map:
757 return self.map[obj]
758 else:
759 index = len(self.objects)
760 self.objects.append(obj)
761 self.map[obj] = index
762 return index
763
764 def __repr__(self):
765 return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
766
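# Rough usage sketch of the contract described above (comments only):
#
#   m = TreeAutomaton.IndexMap(['fadd', 'fmul'])
#   m.index('fmul')          # == 1, constant time
#   m.add('fadd')            # == 0, re-adding returns the existing index
#   list(m)                  # == ['fadd', 'fmul'], stable iteration order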
767 class Item(object):
768 """This represents an "item" in the language of "Tree Automatons." This
769 is just a subtree of some pattern, which represents a potential partial
770 match at runtime. We deduplicate them, so that identical subtrees of
771 different patterns share the same object, and store some extra
772 information needed for the main algorithm as well.
773 """
774 def __init__(self, opcode, children):
775 self.opcode = opcode
776 self.children = children
777 # These are the indices of patterns for which this item is the root node.
778 self.patterns = []
779 # This is the set of opcodes for parents of this item. Used to speed up
780 # filtering.
781 self.parent_ops = set()
782
783 def __str__(self):
784 return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'
785
786 def __repr__(self):
787 return str(self)
788
789 def _compute_items(self):
790 """Build a set of all possible items, deduplicating them."""
791 # This is a map from (opcode, sources) to item.
792 self.items = {}
793
794 # The set of all opcodes used by the patterns. Used later to avoid
795 # building and emitting all the tables for opcodes that aren't used.
796 self.opcodes = self.IndexMap()
797
798 def get_item(opcode, children, pattern=None):
799 commutative = len(children) == 2 \
800 and "2src_commutative" in opcodes[opcode].algebraic_properties
801 item = self.items.setdefault((opcode, children),
802 self.Item(opcode, children))
803 if commutative:
804 self.items[opcode, (children[1], children[0])] = item
805 if pattern is not None:
806 item.patterns.append(pattern)
807 return item
808
809 self.wildcard = get_item("__wildcard", ())
810 self.const = get_item("__const", ())
811
812 def process_subpattern(src, pattern=None):
813 if isinstance(src, Constant):
814 # Note: we throw away the actual constant value!
815 return self.const
816 elif isinstance(src, Variable):
817 if src.is_constant:
818 return self.const
819 else:
820 # Note: we throw away which variable it is here! This special
821 # item is equivalent to nu in "Tree Automatons."
822 return self.wildcard
823 else:
824 assert isinstance(src, Expression)
825 opcode = src.opcode
826 stripped = opcode.rstrip('0123456789')
827 if stripped in conv_opcode_types:
828 # Matches that use conversion opcodes with a specific type,
829 # like f2b1, are tricky. Either we construct the automaton to
830 # match specific NIR opcodes like nir_op_f2b1, in which case we
831 # need to create separate items for each possible NIR opcode
832 # for patterns that have a generic opcode like f2b, or we
833 # construct it to match the search opcode, in which case we
834 # need to map f2b1 to f2b when constructing the automaton. Here
835 # we do the latter.
836 opcode = stripped
837 self.opcodes.add(opcode)
838 children = tuple(process_subpattern(c) for c in src.sources)
839 item = get_item(opcode, children, pattern)
840 for i, child in enumerate(children):
841 child.parent_ops.add(opcode)
842 return item
843
844 for i, pattern in enumerate(self.patterns):
845 process_subpattern(pattern, i)
846
847 def _build_table(self):
848 """This is the core algorithm which builds up the transition table. It
849 is based on Algorithm 5.7.38 "Reachability-based tabulation of Cl .
850 Comp_a and Filt_{a,i} using integers to identify match sets." It
851 simultaneously builds up a list of all possible "match sets" or
852 "states", where each match set represents the set of Item's that match a
853 given instruction, and builds up the transition table between states.
854 """
855 # Map from opcode + filtered state indices to transitioned state.
856 self.table = defaultdict(dict)
857 # Bijection from state to index. q in the original algorithm is
858 # len(self.states)
859 self.states = self.IndexMap()
860 # List of pattern matches for each state index.
861 self.state_patterns = []
862 # Map from state index to filtered state index for each opcode.
863 self.filter = defaultdict(list)
864 # Bijections from filtered state to filtered state index for each
865 # opcode, called the "representor sets" in the original algorithm.
866 # q_{a,j} in the original algorithm is len(self.rep[op]).
867 self.rep = defaultdict(self.IndexMap)
868
869 # Everything in self.states with an index at least worklist_index is part
870 # of the worklist of newly created states. There is also a worklist of
871 # newly filtered states for each opcode, for which worklist_indices
872 # serves a similar purpose. worklist_index corresponds to p in the
873 # original algorithm, while worklist_indices is p_{a,j} (although since
874 # we only filter by opcode/symbol, it's really just p_a).
875 self.worklist_index = 0
876 worklist_indices = defaultdict(lambda: 0)
877
878 # This is the set of opcodes for which the filtered worklist is non-empty.
879 # It's used to avoid scanning opcodes for which there is nothing to
880 # process when building the transition table. It corresponds to new_a in
881 # the original algorithm.
882 new_opcodes = self.IndexMap()
883
884 # Process states on the global worklist, filtering them for each opcode,
885 # updating the filter tables, and updating the filtered worklists if any
886 # new filtered states are found. Similar to ComputeRepresenterSets() in
887 # the original algorithm, although that only processes a single state.
888 def process_new_states():
889 while self.worklist_index < len(self.states):
890 state = self.states[self.worklist_index]
891
892 # Calculate pattern matches for this state. Each pattern is
893 # assigned to a unique item, so we don't have to worry about
894 # deduplicating them here. However, we do have to sort them so
895 # that they're visited at runtime in the order they're specified
896 # in the source.
897 patterns = list(sorted(p for item in state for p in item.patterns))
898 assert len(self.state_patterns) == self.worklist_index
899 self.state_patterns.append(patterns)
900
901 # calculate filter table for this state, and update filtered
902 # worklists.
903 for op in self.opcodes:
904 filt = self.filter[op]
905 rep = self.rep[op]
906 filtered = frozenset(item for item in state if \
907 op in item.parent_ops)
908 if filtered in rep:
909 rep_index = rep.index(filtered)
910 else:
911 rep_index = rep.add(filtered)
912 new_opcodes.add(op)
913 assert len(filt) == self.worklist_index
914 filt.append(rep_index)
915 self.worklist_index += 1
916
917 # There are two start states: one which can only match as a wildcard,
918 # and one which can match as a wildcard or constant. These will be the
919 # states of intrinsics/other instructions and load_const instructions,
920 # respectively. The indices of these must match the definitions of
921 # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
922 # initialize things correctly.
923 self.states.add(frozenset((self.wildcard,)))
924 self.states.add(frozenset((self.const,self.wildcard)))
925 process_new_states()
926
927 while len(new_opcodes) > 0:
928 for op in new_opcodes:
929 rep = self.rep[op]
930 table = self.table[op]
931 op_worklist_index = worklist_indices[op]
932 if op in conv_opcode_types:
933 num_srcs = 1
934 else:
935 num_srcs = opcodes[op].num_inputs
936
937 # Iterate over all possible source combinations where at least one
938 # is on the worklist.
939 for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
940 if all(src_idx < op_worklist_index for src_idx in src_indices):
941 continue
942
943 srcs = tuple(rep[src_idx] for src_idx in src_indices)
944
945 # Try all possible pairings of source items and add the
946 # corresponding parent items. This is Comp_a from the paper.
947 parent = set(self.items[op, item_srcs] for item_srcs in
948 itertools.product(*srcs) if (op, item_srcs) in self.items)
949
950 # We could always start matching something else with a
951 # wildcard. This is Cl from the paper.
952 parent.add(self.wildcard)
953
954 table[src_indices] = self.states.add(frozenset(parent))
955 worklist_indices[op] = len(rep)
956 new_opcodes.clear()
957 process_new_states()
958
959 _algebraic_pass_template = mako.template.Template("""
960 #include "nir.h"
961 #include "nir_builder.h"
962 #include "nir_search.h"
963 #include "nir_search_helpers.h"
964
965 #ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
966 #define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
967
968 struct transform {
969 const nir_search_expression *search;
970 const nir_search_value *replace;
971 unsigned condition_offset;
972 };
973
974 struct per_op_table {
975 const uint16_t *filter;
976 unsigned num_filtered_states;
977 const uint16_t *table;
978 };
979
980 /* Note: these must match the start states created in
981 * TreeAutomaton._build_table()
982 */
983
984 /* WILDCARD_STATE = 0 is set by zeroing the state array */
985 static const uint16_t CONST_STATE = 1;
986
987 #endif
988
989 <% cache = {} %>
990 % for xform in xforms:
991 ${xform.search.render(cache)}
992 ${xform.replace.render(cache)}
993 % endfor
994
995 % for state_id, state_xforms in enumerate(automaton.state_patterns):
996 % if state_xforms: # avoid emitting a 0-length array for MSVC
997 static const struct transform ${pass_name}_state${state_id}_xforms[] = {
998 % for i in state_xforms:
999 { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
1000 % endfor
1001 };
1002 % endif
1003 % endfor
1004
1005 static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
1006 % for op in automaton.opcodes:
1007 [${get_c_opcode(op)}] = {
1008 .filter = (uint16_t []) {
1009 % for e in automaton.filter[op]:
1010 ${e},
1011 % endfor
1012 },
1013 <%
1014 num_filtered = len(automaton.rep[op])
1015 %>
1016 .num_filtered_states = ${num_filtered},
1017 .table = (uint16_t []) {
1018 <%
1019 num_srcs = len(next(iter(automaton.table[op])))
1020 %>
1021 % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
1022 ${automaton.table[op][indices]},
1023 % endfor
1024 },
1025 },
1026 % endfor
1027 };
1028
1029 static void
1030 ${pass_name}_pre_block(nir_block *block, uint16_t *states)
1031 {
1032 nir_foreach_instr(instr, block) {
1033 switch (instr->type) {
1034 case nir_instr_type_alu: {
1035 nir_alu_instr *alu = nir_instr_as_alu(instr);
1036 nir_op op = alu->op;
1037 uint16_t search_op = nir_search_op_for_nir_op(op);
1038 const struct per_op_table *tbl = &${pass_name}_table[search_op];
1039 if (tbl->num_filtered_states == 0)
1040 continue;
1041
1042 /* Calculate the index into the transition table. Note the index
1043 * calculated must match the iteration order of Python's
1044 * itertools.product(), which was used to emit the transition
1045 * table.
1046 */
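/* For example (illustrative numbers): with num_filtered_states == 3 and a
* two-source op whose sources sit in filtered states s0 and s1, the loop
* below computes index = s0 * 3 + s1, which is exactly the position of
* (s0, s1) in itertools.product(range(3), repeat=2).
*/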
1047 uint16_t index = 0;
1048 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
1049 index *= tbl->num_filtered_states;
1050 index += tbl->filter[states[alu->src[i].src.ssa->index]];
1051 }
1052 states[alu->dest.dest.ssa.index] = tbl->table[index];
1053 break;
1054 }
1055
1056 case nir_instr_type_load_const: {
1057 nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
1058 states[load_const->def.index] = CONST_STATE;
1059 break;
1060 }
1061
1062 default:
1063 break;
1064 }
1065 }
1066 }
1067
1068 static bool
1069 ${pass_name}_block(nir_builder *build, nir_block *block,
1070 const uint16_t *states, const bool *condition_flags)
1071 {
1072 bool progress = false;
1073
1074 nir_foreach_instr_reverse_safe(instr, block) {
1075 if (instr->type != nir_instr_type_alu)
1076 continue;
1077
1078 nir_alu_instr *alu = nir_instr_as_alu(instr);
1079 if (!alu->dest.dest.is_ssa)
1080 continue;
1081
1082 switch (states[alu->dest.dest.ssa.index]) {
1083 % for i in range(len(automaton.state_patterns)):
1084 case ${i}:
1085 % if automaton.state_patterns[i]:
1086 for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_state${i}_xforms); i++) {
1087 const struct transform *xform = &${pass_name}_state${i}_xforms[i];
1088 if (condition_flags[xform->condition_offset] &&
1089 nir_replace_instr(build, alu, xform->search, xform->replace)) {
1090 progress = true;
1091 break;
1092 }
1093 }
1094 % endif
1095 break;
1096 % endfor
1097 default: assert(0);
1098 }
1099 }
1100
1101 return progress;
1102 }
1103
1104 static bool
1105 ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
1106 {
1107 bool progress = false;
1108
1109 nir_builder build;
1110 nir_builder_init(&build, impl);
1111
1112 /* Note: it's important here that we're allocating a zeroed array, since
1113 * state 0 is the default state, which means we don't have to visit
1114 * anything other than constants and ALU instructions.
1115 */
1116 uint16_t *states = calloc(impl->ssa_alloc, sizeof(*states));
1117
1118 nir_foreach_block(block, impl) {
1119 ${pass_name}_pre_block(block, states);
1120 }
1121
1122 nir_foreach_block_reverse(block, impl) {
1123 progress |= ${pass_name}_block(&build, block, states, condition_flags);
1124 }
1125
1126 free(states);
1127
1128 if (progress) {
1129 nir_metadata_preserve(impl, nir_metadata_block_index |
1130 nir_metadata_dominance);
1131 } else {
1132 #ifndef NDEBUG
1133 impl->valid_metadata &= ~nir_metadata_not_properly_reset;
1134 #endif
1135 }
1136
1137 return progress;
1138 }
1139
1140
1141 bool
1142 ${pass_name}(nir_shader *shader)
1143 {
1144 bool progress = false;
1145 bool condition_flags[${len(condition_list)}];
1146 const nir_shader_compiler_options *options = shader->options;
1147 const shader_info *info = &shader->info;
1148 (void) options;
1149 (void) info;
1150
1151 % for index, condition in enumerate(condition_list):
1152 condition_flags[${index}] = ${condition};
1153 % endfor
1154
1155 nir_foreach_function(function, shader) {
1156 if (function->impl)
1157 progress |= ${pass_name}_impl(function->impl, condition_flags);
1158 }
1159
1160 return progress;
1161 }
1162 """)
1163
1164
1165
1166 class AlgebraicPass(object):
1167 def __init__(self, pass_name, transforms):
1168 self.xforms = []
1169 self.opcode_xforms = defaultdict(lambda : [])
1170 self.pass_name = pass_name
1171
1172 error = False
1173
1174 for xform in transforms:
1175 if not isinstance(xform, SearchAndReplace):
1176 try:
1177 xform = SearchAndReplace(xform)
1178 except:
1179 print("Failed to parse transformation:", file=sys.stderr)
1180 print(" " + str(xform), file=sys.stderr)
1181 traceback.print_exc(file=sys.stderr)
1182 print('', file=sys.stderr)
1183 error = True
1184 continue
1185
1186 self.xforms.append(xform)
1187 if xform.search.opcode in conv_opcode_types:
1188 dst_type = conv_opcode_types[xform.search.opcode]
1189 for size in type_sizes(dst_type):
1190 sized_opcode = xform.search.opcode + str(size)
1191 self.opcode_xforms[sized_opcode].append(xform)
1192 else:
1193 self.opcode_xforms[xform.search.opcode].append(xform)
1194
1195 self.automaton = TreeAutomaton(self.xforms)
1196
1197 if error:
1198 sys.exit(1)
1199
1200
1201 def render(self):
1202 return _algebraic_pass_template.render(pass_name=self.pass_name,
1203 xforms=self.xforms,
1204 opcode_xforms=self.opcode_xforms,
1205 condition_list=condition_list,
1206 automaton=self.automaton,
1207 get_c_opcode=get_c_opcode,
1208 itertools=itertools)