2112854570daa7f5d309749299c5014b45f15366
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
# (keep the two definitions in sync when changing either one).
nir_search_max_comm_ops = 8

# These opcodes are only employed by nir_search.  They are generic conversion
# opcodes without an explicit destination size; this maps each one to the base
# type of its destination.
conv_opcode_types = {
   'i2f' : 'float',
   'u2f' : 'float',
   'f2f' : 'float',
   'f2u' : 'uint',
   'f2i' : 'int',
   'u2u' : 'uint',
   'i2i' : 'int',
   'b2f' : 'float',
   'b2i' : 'int',
   'i2b' : 'bool',
   'f2b' : 'bool',
}

def get_c_opcode(op):
   """Return the C enum name for *op*: nir_search_op_* for the generic
   search-only conversion opcodes above, nir_op_* for everything else."""
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op

63
# Python 2/3 compatibility aliases: the tuple of integral types accepted by
# isinstance() checks, and the canonical text type.
if sys.version_info >= (3, 0):
   integer_types = (int, )
   string_type = str
else:
   integer_types = (int, long)
   string_type = unicode

# Splits a NIR type string such as "float32" or "int" into an optional base
# type and an optional bit-size suffix.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the explicit bit size encoded in *type_str* (e.g. 32 for
   "float32"), or 0 when the type carries no size suffix ("unsized")."""
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return int(bits) if bits is not None else 0

class VarSet(object):
   """A set of named variables, each assigned a unique, dense integer id.

   Indexing with a name returns its id, allocating a fresh one on first use.
   Once lock() has been called, looking up an unknown name is an error; this
   catches replacement expressions that reference variables the search
   expression never bound.
   """
   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      try:
         return self.names[name]
      except KeyError:
         assert not self.immutable, "Unknown replacement variable: " + name
         index = next(self.ids)
         self.names[name] = index
         return index

   def lock(self):
      # Freeze the set: from now on only already-known names may be looked up.
      self.immutable = True

class Value(object):
   """Base class for every node in a search or replace pattern tree.

   Concrete subclasses are Constant, Variable, and Expression.  Each Value
   also participates in bit-size inference: _bit_size is either None, a
   concrete int, or a link to another Value in the same bit-size equivalence
   class (union-find; see get_bit_size/set_bit_size and BitSizeValidator).
   """

   @staticmethod
   def create(val, name_base, varset):
      """Build the appropriate Value subclass from pattern syntax:
      tuple -> Expression, string -> Variable, bool/int/float -> Constant.
      An existing Expression is passed through unchanged.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)

   def __init__(self, val, name, type_str):
      self.in_val = str(val)    # original spelling, kept for __str__/debugging
      self.name = name          # C identifier used in the generated tables
      self.type_str = type_str  # "constant", "variable", or "expression"

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Walk the union-find chain until we hit a class representative
      # (a Value with no link) or a concrete int bit size.
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: point directly at the representative we found.
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      # The nir_search_value_type enumerant for this node.
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      # The C struct type (nir_search_constant/variable/expression).
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      # Resolve through the dedup cache: identical struct initializers are
      # emitted once and later references are remapped to the first name.
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      # Pointer to the embedded nir_search_value header of this node.
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         # Encode "same bit size as variable N" as a negative number.
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         #   case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   # Mako template producing the C struct initializer for one node; the body
   # emitted depends on which concrete subclass is being rendered.
   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      """Render this node as a static C struct definition, deduplicating
      identical initializers through *cache* when one is provided."""
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)
227
# A constant leaf in a pattern: a Python literal with an optional "@bits"
# size suffix.
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal boolean/integer/float value appearing in a pattern."""

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, str):
         # String form: parse the literal itself plus an optional @bits tag.
         match = _constant_re.match(val)
         self.value = ast.literal_eval(match.group('value'))
         bits = match.group('bits')
         self._bit_size = int(bits) if bits else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         # Booleans are always 1-bit; an explicit size, if any, must agree.
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Render the value as the C literal used in the generated tables."""
      v = self.value
      if isinstance(v, bool):
         return 'NIR_TRUE' if v else 'NIR_FALSE'
      if isinstance(v, integer_types):
         return hex(v)
      if isinstance(v, float):
         # Emit the raw IEEE-754 double bit pattern of the float.
         bits = struct.unpack('Q', struct.pack('d', v))[0]
         text = hex(bits)

         # On Python 2 this 'L' suffix is automatically added, but not on Python 3
         # Adding it explicitly makes the generated file identical, regardless
         # of the Python version running this script.
         if text[-1] != 'L' and bits > sys.maxsize:
            text += 'L'

         return text
      assert False

   def type(self):
      """Return the nir_alu_type enum name matching the Python type."""
      if isinstance(self.value, bool):
         return "nir_type_bool"
      if isinstance(self.value, integer_types):
         return "nir_type_int"
      if isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.value == other.value
285
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?"
                          r"$")

class Variable(Value):
   """A named wildcard in a pattern, e.g. "a", "#b@32", or "c(cond).x".

   A leading '#' restricts the variable to matching constants, '@type' and/or
   '@bits' pin its type or bit size, a parenthesized suffix is a condition
   emitted into the generated C, and a trailing '.xyzw' suffix is a swizzle.
   """

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      match = _var_name_re.match(val)
      assert match and match.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = match.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = match.group('const') is not None
      self.cond = match.group('cond')
      self.required_type = match.group('type')
      bits = match.group('bits')
      self._bit_size = int(bits) if bits else None
      self.swiz = match.group('swiz')

      if self.required_type == 'bool':
         # Booleans default to 1-bit unless a valid explicit size was given.
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_alu_type enum name for the required type, if any."""
      remap = {'bool' : "nir_type_bool",
               'int' : "nir_type_int",
               'uint' : "nir_type_int",
               'float' : "nir_type_float"}
      return remap.get(self.required_type)

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.
      """
      return isinstance(other, type(self)) and self.index == other.index

   def swizzle(self):
      """Render the swizzle as a C array initializer (identity by default)."""
      if self.swiz is None:
         return '{0, 1, 2, 3}'
      components = ['xyzw'.index(c) for c in self.swiz[1:]]
      return '{' + ', '.join(str(i) for i in components) + '}'
354
# Parses the opcode string of an expression tuple: optional '~' (inexact) or
# '!' (exact) prefix, the opcode name, an optional '@bits' size, and an
# optional parenthesized condition.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An operation node in a pattern: an opcode plus source Values.

   Built from a tuple such as ('fadd@32(cond)', a, b): the first element is
   the opcode string (see _opcode_re), the remaining elements become the
   sources via Value.create().
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.
      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      # Either nir_op_* or nir_search_op_* depending on the opcode.
      return get_c_opcode(self.opcode)

   def render(self, cache):
      # Sources must be rendered (and thus defined) before this node, since
      # the expression's initializer takes pointers to them.
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)
451
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values.  We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises.  Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything.  The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression.  It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression.  Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      # One union-find representative slot per variable index in the pattern.
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is.  When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes.  In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
      and None if they are not comparable (neither a <= b nor b <= a).
      """
      # a and b are each either a concrete int, a Variable, or some other
      # Value (an expression/constant whose size is still open).
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            # self.is_search is set by validate() before each traversal.
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b.  If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of self and other to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Point the less-specialized class at the more-specialized one.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable.  We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate the an expression by performing classic Hindley-Milner
      type inference on bitsizes.  This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize.  If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources.  That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      # Every replacement node must end up with a concrete size, a variable's
      # size, or the overall search expression's size -- anything else is
      # ambiguous at replacement-construction time.
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      # Entry point: run inference over the search pattern (permissive mode),
      # then over the replacement (restrictive mode), then cross-check them.
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace.  Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
            search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
753
# Monotonically increasing ids giving each transform unique C symbol names.
_optimization_ids = itertools.count()

# All distinct condition strings seen so far; index 0 is the always-true one.
condition_list = ['true']

class SearchAndReplace(object):
   """One algebraic transform: a search Expression, a replacement Value, and
   an optional guard condition string shared through condition_list."""

   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Intern the condition so transforms sharing one condition also share
      # a single index into condition_list.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      # Lock the variable set: the replacement may only reference variables
      # that the search pattern bound.
      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      BitSizeValidator(varset).validate(self.search, self.replace)
787
788 class TreeAutomaton(object):
789 """This class calculates a bottom-up tree automaton to quickly search for
790 the left-hand sides of tranforms. Tree automatons are a generalization of
791 classical NFA's and DFA's, where the transition function determines the
792 state of the parent node based on the state of its children. We construct a
793 deterministic automaton to match patterns, using a similar algorithm to the
794 classical NFA to DFA construction. At the moment, it only matches opcodes
795 and constants (without checking the actual value), leaving more detailed
796 checking to the search function which actually checks the leaves. The
797 automaton acts as a quick filter for the search function, requiring only n
798 + 1 table lookups for each n-source operation. The implementation is based
799 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
800 In the language of that reference, this is a frontier-to-root deterministic
801 automaton using only symbol filtering. The filtering is crucial to reduce
802 both the time taken to generate the tables and the size of the tables.
803 """
804 def __init__(self, transforms):
805 self.patterns = [t.search for t in transforms]
806 self._compute_items()
807 self._build_table()
808 #print('num items: {}'.format(len(set(self.items.values()))))
809 #print('num states: {}'.format(len(self.states)))
810 #for state, patterns in zip(self.states, self.patterns):
811 # print('{}: num patterns: {}'.format(state, len(patterns)))
812
813 class IndexMap(object):
814 """An indexed list of objects, where one can either lookup an object by
815 index or find the index associated to an object quickly using a hash
816 table. Compared to a list, it has a constant time index(). Compared to a
817 set, it provides a stable iteration order.
818 """
819 def __init__(self, iterable=()):
820 self.objects = []
821 self.map = {}
822 for obj in iterable:
823 self.add(obj)
824
825 def __getitem__(self, i):
826 return self.objects[i]
827
828 def __contains__(self, obj):
829 return obj in self.map
830
831 def __len__(self):
832 return len(self.objects)
833
834 def __iter__(self):
835 return iter(self.objects)
836
837 def clear(self):
838 self.objects = []
839 self.map.clear()
840
841 def index(self, obj):
842 return self.map[obj]
843
844 def add(self, obj):
845 if obj in self.map:
846 return self.map[obj]
847 else:
848 index = len(self.objects)
849 self.objects.append(obj)
850 self.map[obj] = index
851 return index
852
853 def __repr__(self):
854 return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
855
class Item(object):
   """A deduplicated pattern subtree ("item" in tree-automaton terms).

   Each item stands for a potential partial match at runtime.  Identical
   subtrees of different patterns share a single Item instance, which also
   carries the extra bookkeeping the main algorithm needs.
   """
   def __init__(self, opcode, children):
      self.opcode = opcode
      self.children = children
      # Indices of the patterns for which this item is the root node.
      self.patterns = []
      # Opcodes of every parent of this item; used to speed up filtering.
      self.parent_ops = set()

   def __str__(self):
      pieces = [self.opcode]
      pieces.extend(str(child) for child in self.children)
      return '({})'.format(', '.join(pieces))

   def __repr__(self):
      return self.__str__()
877
def _compute_items(self):
   """Build a set of all possible items, deduplicating them.

   Fills in self.items (a map from (opcode, child items) to Item),
   self.opcodes (an IndexMap of every opcode the patterns use), and the
   special leaf items self.wildcard and self.const.
   """
   # This is a map from (opcode, sources) to item.
   self.items = {}

   # The set of all opcodes used by the patterns. Used later to avoid
   # building and emitting all the tables for opcodes that aren't used.
   self.opcodes = self.IndexMap()

   def get_item(opcode, children, pattern=None):
      commutative = len(children) >= 2 \
            and "2src_commutative" in opcodes[opcode].algebraic_properties
      item = self.items.setdefault((opcode, children),
                                   self.Item(opcode, children))
      if commutative:
         # Alias the swapped-source key to the same item, so either source
         # order leads to the same state at runtime.
         self.items[opcode, (children[1], children[0]) + children[2:]] = item
      if pattern is not None:
         item.patterns.append(pattern)
      return item

   self.wildcard = get_item("__wildcard", ())
   self.const = get_item("__const", ())

   def process_subpattern(src, pattern=None):
      if isinstance(src, Constant):
         # Note: we throw away the actual constant value!
         return self.const
      elif isinstance(src, Variable):
         if src.is_constant:
            return self.const
         else:
            # Note: we throw away which variable it is here! This special
            # item is equivalent to nu in "Tree Automatons."
            return self.wildcard
      else:
         assert isinstance(src, Expression)
         opcode = src.opcode
         stripped = opcode.rstrip('0123456789')
         if stripped in conv_opcode_types:
            # Matches that use conversion opcodes with a specific type,
            # like f2b1, are tricky.  Either we construct the automaton to
            # match specific NIR opcodes like nir_op_f2b1, in which case we
            # need to create separate items for each possible NIR opcode
            # for patterns that have a generic opcode like f2b, or we
            # construct it to match the search opcode, in which case we
            # need to map f2b1 to f2b when constructing the automaton. Here
            # we do the latter.
            opcode = stripped
         self.opcodes.add(opcode)
         children = tuple(process_subpattern(c) for c in src.sources)
         item = get_item(opcode, children, pattern)
         # Fixed: the index from enumerate() was never used here.
         for child in children:
            child.parent_ops.add(opcode)
         return item

   for i, pattern in enumerate(self.patterns):
      process_subpattern(pattern, i)
935
def _build_table(self):
   """This is the core algorithm which builds up the transition table. It
   is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
   Comp_a and Filt_{a,i} using integers to identify match sets." It
   simultaneously builds up a list of all possible "match sets" or
   "states", where each match set represents the set of Item's that match a
   given instruction, and builds up the transition table between states.
   """
   # Map from opcode + filtered state indices to transitioned state.
   self.table = defaultdict(dict)
   # Bijection from state to index. q in the original algorithm is
   # len(self.states)
   self.states = self.IndexMap()
   # List of pattern matches for each state index.
   self.state_patterns = []
   # Map from state index to filtered state index for each opcode.
   self.filter = defaultdict(list)
   # Bijections from filtered state to filtered state index for each
   # opcode, called the "representor sets" in the original algorithm.
   # q_{a,j} in the original algorithm is len(self.rep[op]).
   self.rep = defaultdict(self.IndexMap)

   # Everything in self.states with a index at least worklist_index is part
   # of the worklist of newly created states. There is also a worklist of
   # newly fitered states for each opcode, for which worklist_indices
   # serves a similar purpose. worklist_index corresponds to p in the
   # original algorithm, while worklist_indices is p_{a,j} (although since
   # we only filter by opcode/symbol, it's really just p_a).
   self.worklist_index = 0
   worklist_indices = defaultdict(lambda: 0)

   # This is the set of opcodes for which the filtered worklist is non-empty.
   # It's used to avoid scanning opcodes for which there is nothing to
   # process when building the transition table. It corresponds to new_a in
   # the original algorithm.
   new_opcodes = self.IndexMap()

   # Process states on the global worklist, filtering them for each opcode,
   # updating the filter tables, and updating the filtered worklists if any
   # new filtered states are found. Similar to ComputeRepresenterSets() in
   # the original algorithm, although that only processes a single state.
   def process_new_states():
      while self.worklist_index < len(self.states):
         state = self.states[self.worklist_index]

         # Calculate pattern matches for this state. Each pattern is
         # assigned to a unique item, so we don't have to worry about
         # deduplicating them here. However, we do have to sort them so
         # that they're visited at runtime in the order they're specified
         # in the source.
         patterns = list(sorted(p for item in state for p in item.patterns))
         assert len(self.state_patterns) == self.worklist_index
         self.state_patterns.append(patterns)

         # calculate filter table for this state, and update filtered
         # worklists.
         for op in self.opcodes:
            filt = self.filter[op]
            rep = self.rep[op]
            # Keep only the items that can appear as a source of 'op'.
            filtered = frozenset(item for item in state if \
               op in item.parent_ops)
            if filtered in rep:
               rep_index = rep.index(filtered)
            else:
               # First time this filtered set is seen for this opcode:
               # record it and remember that 'op' has new work to process.
               rep_index = rep.add(filtered)
               new_opcodes.add(op)
            assert len(filt) == self.worklist_index
            filt.append(rep_index)
         self.worklist_index += 1

   # There are two start states: one which can only match as a wildcard,
   # and one which can match as a wildcard or constant. These will be the
   # states of intrinsics/other instructions and load_const instructions,
   # respectively. The indices of these must match the definitions of
   # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
   # initialize things correctly.
   self.states.add(frozenset((self.wildcard,)))
   self.states.add(frozenset((self.const,self.wildcard)))
   process_new_states()

   while len(new_opcodes) > 0:
      for op in new_opcodes:
         rep = self.rep[op]
         table = self.table[op]
         op_worklist_index = worklist_indices[op]
         # Conversion search opcodes always take exactly one source.
         if op in conv_opcode_types:
            num_srcs = 1
         else:
            num_srcs = opcodes[op].num_inputs

         # Iterate over all possible source combinations where at least one
         # is on the worklist.
         for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
            if all(src_idx < op_worklist_index for src_idx in src_indices):
               continue

            srcs = tuple(rep[src_idx] for src_idx in src_indices)

            # Try all possible pairings of source items and add the
            # corresponding parent items. This is Comp_a from the paper.
            parent = set(self.items[op, item_srcs] for item_srcs in
               itertools.product(*srcs) if (op, item_srcs) in self.items)

            # We could always start matching something else with a
            # wildcard. This is Cl from the paper.
            parent.add(self.wildcard)

            table[src_indices] = self.states.add(frozenset(parent))
         worklist_indices[op] = len(rep)
      new_opcodes.clear()
      # Newly discovered states must themselves be filtered before the
      # next round of table construction.
      process_new_states()
1047
# Mako template that renders one generated nir_*_algebraic.c file: the
# per-state transform arrays, the automaton's per-opcode transition tables,
# and the pass entry point that evaluates the condition flags and walks
# every function implementation in the shader.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {} %>
% for xform in xforms:
   ${xform.search.render(cache)}
   ${xform.replace.render(cache)}
% endfor

% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
% for i in state_xforms:
  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
% endif
% endfor

static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
      .filter = (uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
      <%
      num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (uint16_t []) {
      <%
      num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

const struct transform *${pass_name}_transforms[] = {
% for i in range(len(automaton.state_patterns)):
  % if automaton.state_patterns[i]:
    ${pass_name}_state${i}_xforms,
  % else:
    NULL,
  % endif
% endfor
};

const uint16_t ${pass_name}_transform_counts[] = {
% for i in range(len(automaton.state_patterns)):
  % if automaton.state_patterns[i]:
    (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
  % else:
    0,
  % endif
% endfor
};

bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl) {
         progress |= nir_algebraic_impl(function->impl, condition_flags,
                                        ${pass_name}_transforms,
                                        ${pass_name}_transform_counts,
                                        ${pass_name}_table);
      }
   }

   return progress;
}
""")
1147
1148
class AlgebraicPass(object):
   """One generated algebraic pass.

   Parses the transform list, groups transforms by (sized) opcode,
   validates their commutative-expression counts against what
   match_expression (nir_search.c) can handle, and builds the tree
   automaton the generated C code walks.  Exits with status 1 after
   reporting every parse/validation error.
   """
   def __init__(self, pass_name, transforms):
      self.xforms = []
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name

      # Set when any transform fails to parse or validate; we keep going
      # so that all errors are reported before exiting.
      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except:
               print("Failed to parse transformation:", file=sys.stderr)
               print(" " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # A generic conversion pattern (e.g. f2f) applies to every
            # sized NIR opcode it can expand to (f2f16, f2f32, ...).
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               # Fixed NameError: this branch used to reference the
               # misspelled nir_search_max_comm_op and crashed when hit.
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print(" " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}). Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print(" " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      """Render and return the generated C source for this pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools)