Rename nir_lower_constant_initializers to nir_lower_variable_initalizers
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import defaultdict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes, type_sizes
37
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8

# Opcodes in this table exist only inside nir_search; the table maps each of
# them to the type of the value it produces.
conv_opcode_types = {
    'i2f': 'float',
    'u2f': 'float',
    'f2f': 'float',
    'f2u': 'uint',
    'f2i': 'int',
    'u2u': 'uint',
    'i2i': 'int',
    'b2f': 'float',
    'b2i': 'int',
    'i2b': 'bool',
    'f2b': 'bool',
}

def get_c_opcode(op):
    """Return the C enum name for an opcode, using the search-only enum for
    the generic conversion opcodes and nir_op_* for everything else."""
    prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
    return prefix + op
62
63
# Python 2/3 compatibility shims.  On Python 2 an integer literal may be int
# or long and pattern strings may arrive as unicode; Python 3 has only
# int and str.
if sys.version_info < (3, 0):
    integer_types = (int, long)
    string_type = unicode

else:
    integer_types = (int, )
    string_type = str
71
# Matches a NIR type string such as "int", "uint32", or "float64": an optional
# base type followed by an optional bit width.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
    """Return the bit width encoded in a type string, or 0 if it is unsized."""
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return int(bits) if bits is not None else 0
82
class VarSet(object):
    """A mapping from variable names to unique, sequentially assigned ids.

    Looking up a name that has not been seen before allocates the next id for
    it, unless the set has been locked, in which case unknown names are an
    error (the replace expression may not introduce new variables).
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        try:
            return self.names[name]
        except KeyError:
            assert not self.immutable, "Unknown replacement variable: " + name
            index = next(self.ids)
            self.names[name] = index
            return index

    def lock(self):
        # From here on, every variable looked up must already be known.
        self.immutable = True
99
class Value(object):
    """Base class for nodes in a search or replace pattern tree.

    Concrete subclasses are Constant, Variable, and Expression.  Values also
    carry the union-find state (_bit_size) that BitSizeValidator uses to
    infer and check bit sizes, and know how to render themselves as C struct
    initializers for the generated pass.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the appropriate Value subclass from the Python literal used
        to write a pattern: tuples become Expressions, strings become
        Variables, and bool/int/float literals become Constants.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

    def __init__(self, val, name, type_str):
        # Keep the original pattern text for use in diagnostics (__str__).
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class. Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable. We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression. This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        # Walk the chain of _bit_size links until we hit either a concrete
        # int or a Value that is its own representative.
        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Path compression: point directly at the representative so later
        # lookups are O(1).
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() returned
        before calling this, or just "other" if it's a concrete bit-size. This
        is the "union" part of the union-find algorithm.
        """

        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        # nir_search_value_type enumerant for this node kind.
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        # C struct type used when emitting this node.
        return "nir_search_" + self.type_str

    def __c_name(self, cache):
        # If an identical struct was already emitted, reuse its C identifier.
        if cache is not None and self.name in cache:
            return cache[self.name]
        else:
            return self.name

    def c_value_ptr(self, cache):
        return "&{0}.value".format(self.__c_name(cache))

    def c_ptr(self, cache):
        return "&{0}".format(self.__c_name(cache))

    @property
    def c_bit_size(self):
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            # Encode "same bit size as variable #n" as a negative number.
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            #   case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    # Mako template that renders one nir_search_value struct initializer,
    # dispatching on the concrete subclass of `val`.
    __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def render(self, cache):
        """Render this value as a C struct definition.  When a cache dict is
        supplied, structurally identical values are deduplicated: a repeat is
        rendered as only a comment and its name is remapped to the first
        occurrence.
        """
        struct_init = self.__template.render(val=self, cache=cache,
                                             Constant=Constant,
                                             Variable=Variable,
                                             Expression=Expression)
        if cache is not None and struct_init in cache:
            # If it's in the cache, register a name remap in the cache and render
            # only a comment saying it's been remapped
            cache[self.name] = cache[struct_init]
            return "/* {} -> {} in the cache */\n".format(self.name,
                                                          cache[struct_init])
        else:
            if cache is not None:
                cache[struct_init] = self.name
            return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                      struct_init)
227
# A constant is written either as a bare Python literal or as "<literal>@<bits>".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
    """A literal bool/int/float appearing in a search or replace pattern."""

    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        if isinstance(val, str):
            m = _constant_re.match(val)
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        if isinstance(self.value, bool):
            # NIR booleans are always 1-bit.
            assert self._bit_size is None or self._bit_size == 1
            self._bit_size = 1

    def hex(self):
        """Return the C spelling of this constant's bit pattern."""
        if isinstance(self.value, bool):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        if isinstance(self.value, float):
            # Emit floats as the hex of their IEEE-754 double bit pattern.
            pattern = struct.unpack('Q', struct.pack('d', self.value))[0]
            text = hex(pattern)

            # On Python 2 this 'L' suffix is automatically added, but not on
            # Python 3.  Adding it explicitly makes the generated file
            # identical, regardless of the Python version running this script.
            if text[-1] != 'L' and pattern > sys.maxsize:
                text += 'L'

            return text
        assert False

    def type(self):
        """Return the nir_alu_type name matching the literal's Python type."""
        if isinstance(self.value, bool):
            return "nir_type_bool"
        if isinstance(self.value, integer_types):
            return "nir_type_int"
        if isinstance(self.value, float):
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two constants are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for __eq__
        will break BitSizeValidator.
        """
        if not isinstance(other, type(self)):
            return False

        return self.value == other.value
285
# Variables look like "#name@type32(cond).xyzw" where everything but the name
# is optional: '#' requires the match to be constant, '@...' constrains the
# type and/or bit size, '(...)' names a C predicate, '.xyzw' is a swizzle.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?")

class Variable(Value):
    """A named wildcard in a pattern, possibly carrying extra constraints."""

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        match = _var_name_re.match(val)
        assert match and match.group('name') is not None

        self.var_name = match.group('name')

        # Prevent common cases where someone puts quotes around a literal
        # constant.  If we want to support names that have numeric or
        # punctuation characters, we can make the first assertion more
        # flexible.
        assert self.var_name.isalpha()
        assert self.var_name != 'True'
        assert self.var_name != 'False'

        self.is_constant = match.group('const') is not None
        self.cond = match.group('cond')
        self.required_type = match.group('type')
        bits = match.group('bits')
        self._bit_size = int(bits) if bits else None
        self.swiz = match.group('swiz')

        if self.required_type == 'bool':
            if self._bit_size is None:
                # Unsized booleans default to NIR's 1-bit booleans.
                self._bit_size = 1
            else:
                assert self._bit_size in type_sizes(self.required_type)

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_alu_type constraint, or None when unconstrained."""
        if self.required_type == 'bool':
            return "nir_type_bool"
        if self.required_type in ('int', 'uint'):
            return "nir_type_int"
        if self.required_type == 'float':
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two variables are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for __eq__
        will break BitSizeValidator.
        """
        if not isinstance(other, type(self)):
            return False

        return self.index == other.index

    def swizzle(self):
        """Render this variable's swizzle as a C array initializer."""
        if self.swiz is None:
            return '{0, 1, 2, 3}'
        components = {'x': 0, 'y': 1, 'z': 2, 'w': 3}
        return '{' + ', '.join(str(components[c]) for c in self.swiz[1:]) + '}'
350
# Expression heads look like "~!opcode@bits(cond)" where everything except the
# opcode is optional.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
    """An opcode applied to source Values, written as a tuple such as
    ('fadd', a, b).  The opcode string may carry an inexact marker (~), an
    exact marker (!), a bit size (@N), and a C condition ((...))."""

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.inexact = m.group('inexact') is not None
        self.exact = m.group('exact') is not None
        self.cond = m.group('cond')

        assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

        # "many-comm-expr" isn't really a condition.  It's notification to the
        # generator that this pattern is known to have too many commutative
        # expressions, and an error should not be generated for this case.
        self.many_commutative_expressions = False
        if self.cond and self.cond.find("many-comm-expr") >= 0:
            # Split the condition into a comma-separated list.  Remove
            # "many-comm-expr".  If there is anything left, put it back
            # together.
            c = self.cond[1:-1].split(",")
            c.remove("many-comm-expr")

            self.cond = "({})".format(",".join(c)) if c else None
            self.many_commutative_expressions = True

        # Recursively convert the tuple's tail into source Values.
        self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                         for (i, src) in enumerate(expr[1:]) ]

        if self.opcode in conv_opcode_types:
            assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

        self.__index_comm_exprs(0)

    def equivalent(self, other):
        """Check that two expressions are equivalent.

        This check is much weaker than equality.  One generally cannot be
        used in place of the other.  Using this implementation for __eq__
        will break BitSizeValidator.

        This implementation does not check for equivalence due to commutativity,
        but it could.
        """
        if not isinstance(other, type(self)):
            return False

        if len(self.sources) != len(other.sources):
            return False

        if self.opcode != other.opcode:
            return False

        return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

    def __index_comm_exprs(self, base_idx):
        """Recursively count and index commutative expressions, assigning each
        one the index the runtime matcher uses to track which source order it
        matched.
        """
        self.comm_exprs = 0

        # A note about the explicit "len(self.sources)" check. The list of
        # sources comes from user input, and that input might be bad.  Check
        # that the expected second source exists before accessing it. Without
        # this check, a unit test that does "('iadd', 'a')" will crash.
        if self.opcode not in conv_opcode_types and \
           "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
           len(self.sources) >= 2 and \
           not self.sources[0].equivalent(self.sources[1]):
            self.comm_expr_idx = base_idx
            self.comm_exprs += 1
        else:
            # Sources that are trivially identical need no commutative
            # matching, so they don't get an index.
            self.comm_expr_idx = -1

        for s in self.sources:
            if isinstance(s, Expression):
                s.__index_comm_exprs(base_idx + self.comm_exprs)
                self.comm_exprs += s.comm_exprs

        return self.comm_exprs

    def c_opcode(self):
        # Conversion opcodes map to nir_search_op_*, everything else to
        # nir_op_*.
        return get_c_opcode(self.opcode)

    def render(self, cache):
        # Emit all sources before the expression itself so the C initializers
        # they reference are already defined.
        srcs = "\n".join(src.render(cache) for src in self.sources)
        return srcs + super(Expression, self).render(cache)
447
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values.  We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises.  Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything.  The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression.  It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression.  Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # Union-find representative for each variable index; filled in lazily
        # by merge_variables().
        self._var_classes = [None] * len(varset.names)

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is.  When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes.  In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                # Variables can only absorb a physical size in search mode.
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            # a is an unsized expression/constant class: least specialized.
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b.  If both
        have been assigned conflicting physical bit-sizes, call "error_msg" with
        the bit-sizes of self and other to get a message and raise an error.
        In the replace expression, disallow merging variables with other
        variables and physical bit-sizes as well.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
            error_msg(a_bit_size, b_bit_size)

        # Point the less-specialized class at the more-specialized one.
        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable.  We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                # First use of this variable becomes the representative.
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                    lambda other_bit_size, bit_size:
                        'Variable {} has conflicting bit size requirements: ' \
                        'it must have bit size {} and {}'.format(
                            val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate the an expression by performing classic Hindley-Milner
        type inference on bitsizes.  This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize.  If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        # Generic conversion ops are special in that they have a single unsized
        # source and an unsized destination and the two don't have to match.
        # This means there's no validation or unioning to do here besides the
        # len(val.sources) check.
        if val.opcode in conv_opcode_types:
            assert len(val.sources) == 1, \
                "Expression {} has {} sources, expected 1".format(
                    val, len(val.sources))
            self.validate_value(val.sources[0])
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources.  That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} ' \
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                            'of {} may not have the same bit size when building the ' \
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '\
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of ' \
                            'nir_op_{} it must have {} bits, which may not be the ' \
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        if dst_type_bits == 0:
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} ' \
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} ' \
                            '(bit size of {}) may not have that bit size ' \
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} ' \
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        # Every replacement value must end up with a deducible bit size: a
        # concrete size, a variable's size, or the whole search expression's.
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        # Entry point: infer and check the search side, then the replace side.
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace.  Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)
749
# Global counter handing each SearchAndReplace a unique id.
_optimization_ids = itertools.count()

# Index 0 is the always-true condition; transforms append their C condition
# strings here and refer to them by index.
condition_list = ['true']
753
class SearchAndReplace(object):
    """One algebraic transform: a search pattern, its replacement, and an
    optional C condition string guarding the rewrite."""

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search_expr, replace_expr = transform[0], transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        variables = VarSet()
        if isinstance(search_expr, Expression):
            self.search = search_expr
        else:
            self.search = Expression(search_expr,
                                     "search{0}".format(self.id), variables)

        # The replacement may only use variables bound by the search pattern.
        variables.lock()

        if isinstance(replace_expr, Value):
            self.replace = replace_expr
        else:
            self.replace = Value.create(replace_expr,
                                        "replace{0}".format(self.id),
                                        variables)

        BitSizeValidator(variables).validate(self.search, self.replace)
783
784 class TreeAutomaton(object):
785 """This class calculates a bottom-up tree automaton to quickly search for
786 the left-hand sides of tranforms. Tree automatons are a generalization of
787 classical NFA's and DFA's, where the transition function determines the
788 state of the parent node based on the state of its children. We construct a
789 deterministic automaton to match patterns, using a similar algorithm to the
790 classical NFA to DFA construction. At the moment, it only matches opcodes
791 and constants (without checking the actual value), leaving more detailed
792 checking to the search function which actually checks the leaves. The
793 automaton acts as a quick filter for the search function, requiring only n
794 + 1 table lookups for each n-source operation. The implementation is based
795 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
796 In the language of that reference, this is a frontier-to-root deterministic
797 automaton using only symbol filtering. The filtering is crucial to reduce
798 both the time taken to generate the tables and the size of the tables.
799 """
    def __init__(self, transforms):
        # Only the search (left-hand) sides are matched by the automaton.
        self.patterns = [t.search for t in transforms]
        self._compute_items()
        self._build_table()
        # Debugging aids for inspecting the size of the generated automaton.
        #print('num items: {}'.format(len(set(self.items.values()))))
        #print('num states: {}'.format(len(self.states)))
        #for state, patterns in zip(self.states, self.patterns):
        #   print('{}: num patterns: {}'.format(state, len(patterns)))
808
    class IndexMap(object):
        """An indexed list of objects, where one can either lookup an object by
        index or find the index associated to an object quickly using a hash
        table.  Compared to a list, it has a constant time index().  Compared to a
        set, it provides a stable iteration order.
        """
        def __init__(self, iterable=()):
            # Parallel structures: objects[i] is the object with index i, and
            # map is its inverse.
            self.objects = []
            self.map = {}
            for obj in iterable:
                self.add(obj)

        def __getitem__(self, i):
            return self.objects[i]

        def __contains__(self, obj):
            return obj in self.map

        def __len__(self):
            return len(self.objects)

        def __iter__(self):
            return iter(self.objects)

        def clear(self):
            self.objects = []
            self.map.clear()

        def index(self, obj):
            # Note: raises KeyError (not ValueError like list.index) when obj
            # is unknown.
            return self.map[obj]

        def add(self, obj):
            # Returns the existing index when obj is already present, so add()
            # also serves as an interning operation.
            if obj in self.map:
                return self.map[obj]
            else:
                index = len(self.objects)
                self.objects.append(obj)
                self.map[obj] = index
                return index

        def __repr__(self):
            return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'
851
    class Item(object):
        """This represents an "item" in the language of "Tree Automatons."  This
        is just a subtree of some pattern, which represents a potential partial
        match at runtime.  We deduplicate them, so that identical subtrees of
        different patterns share the same object, and store some extra
        information needed for the main algorithm as well.
        """
        def __init__(self, opcode, children):
            self.opcode = opcode
            self.children = children
            # These are the indices of patterns for which this item is the root node.
            self.patterns = []
            # This is the set of opcodes for parents of this item.  Used to
            # speed up filtering.
            self.parent_ops = set()

        def __str__(self):
            return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

        def __repr__(self):
            return str(self)
873
def _compute_items(self):
   """Build a set of all possible items, deduplicating them.

   Populates self.items (map from (opcode, child items) to Item),
   self.opcodes (every opcode used by any pattern), and the two special
   leaf items self.wildcard and self.const.
   """
   # This is a map from (opcode, sources) to item.
   self.items = {}

   # The set of all opcodes used by the patterns. Used later to avoid
   # building and emitting all the tables for opcodes that aren't used.
   self.opcodes = self.IndexMap()

   def get_item(opcode, children, pattern=None):
      # Only 2-source-commutative opcodes with at least two children get
      # the swapped-source aliasing below.
      commutative = len(children) >= 2 \
            and "2src_commutative" in opcodes[opcode].algebraic_properties
      item = self.items.setdefault((opcode, children),
                                   self.Item(opcode, children))
      if commutative:
         # Also register the item under the key with the first two
         # children swapped, so both source orderings deduplicate to the
         # same item.
         self.items[opcode, (children[1], children[0]) + children[2:]] = item
      if pattern is not None:
         item.patterns.append(pattern)
      return item

   self.wildcard = get_item("__wildcard", ())
   self.const = get_item("__const", ())

   def process_subpattern(src, pattern=None):
      # Recursively convert one pattern subtree into (deduplicated) items.
      if isinstance(src, Constant):
         # Note: we throw away the actual constant value!
         return self.const
      elif isinstance(src, Variable):
         if src.is_constant:
            return self.const
         else:
            # Note: we throw away which variable it is here! This special
            # item is equivalent to nu in "Tree Automatons."
            return self.wildcard
      else:
         assert isinstance(src, Expression)
         opcode = src.opcode
         stripped = opcode.rstrip('0123456789')
         if stripped in conv_opcode_types:
            # Matches that use conversion opcodes with a specific type,
            # like f2b1, are tricky. Either we construct the automaton to
            # match specific NIR opcodes like nir_op_f2b1, in which case we
            # need to create separate items for each possible NIR opcode
            # for patterns that have a generic opcode like f2b, or we
            # construct it to match the search opcode, in which case we
            # need to map f2b1 to f2b when constructing the automaton. Here
            # we do the latter.
            opcode = stripped
         self.opcodes.add(opcode)
         children = tuple(process_subpattern(c) for c in src.sources)
         item = get_item(opcode, children, pattern)
         for i, child in enumerate(children):
            child.parent_ops.add(opcode)
         return item

   for i, pattern in enumerate(self.patterns):
      process_subpattern(pattern, i)
931
def _build_table(self):
   """This is the core algorithm which builds up the transition table. It
   is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
   Comp_a and Filt_{a,i} using integers to identify match sets." It
   simultaneously builds up a list of all possible "match sets" or
   "states", where each match set represents the set of Item's that match a
   given instruction, and builds up the transition table between states.
   """
   # Map from opcode + filtered state indices to transitioned state.
   self.table = defaultdict(dict)
   # Bijection from state to index. q in the original algorithm is
   # len(self.states)
   self.states = self.IndexMap()
   # List of pattern matches for each state index.
   self.state_patterns = []
   # Map from state index to filtered state index for each opcode.
   self.filter = defaultdict(list)
   # Bijections from filtered state to filtered state index for each
   # opcode, called the "representor sets" in the original algorithm.
   # q_{a,j} in the original algorithm is len(self.rep[op]).
   self.rep = defaultdict(self.IndexMap)

   # Everything in self.states with an index at least worklist_index is part
   # of the worklist of newly created states. There is also a worklist of
   # newly filtered states for each opcode, for which worklist_indices
   # serves a similar purpose. worklist_index corresponds to p in the
   # original algorithm, while worklist_indices is p_{a,j} (although since
   # we only filter by opcode/symbol, it's really just p_a).
   self.worklist_index = 0
   worklist_indices = defaultdict(lambda: 0)

   # This is the set of opcodes for which the filtered worklist is non-empty.
   # It's used to avoid scanning opcodes for which there is nothing to
   # process when building the transition table. It corresponds to new_a in
   # the original algorithm.
   new_opcodes = self.IndexMap()

   # Process states on the global worklist, filtering them for each opcode,
   # updating the filter tables, and updating the filtered worklists if any
   # new filtered states are found. Similar to ComputeRepresenterSets() in
   # the original algorithm, although that only processes a single state.
   def process_new_states():
      while self.worklist_index < len(self.states):
         state = self.states[self.worklist_index]

         # Calculate pattern matches for this state. Each pattern is
         # assigned to a unique item, so we don't have to worry about
         # deduplicating them here. However, we do have to sort them so
         # that they're visited at runtime in the order they're specified
         # in the source.
         patterns = list(sorted(p for item in state for p in item.patterns))
         assert len(self.state_patterns) == self.worklist_index
         self.state_patterns.append(patterns)

         # calculate filter table for this state, and update filtered
         # worklists.
         for op in self.opcodes:
            filt = self.filter[op]
            rep = self.rep[op]
            # Keep only the items that can appear as a source of 'op';
            # everything else is irrelevant for transitions through 'op'.
            filtered = frozenset(item for item in state if \
               op in item.parent_ops)
            if filtered in rep:
               rep_index = rep.index(filtered)
            else:
               rep_index = rep.add(filtered)
               new_opcodes.add(op)
            assert len(filt) == self.worklist_index
            filt.append(rep_index)
         self.worklist_index += 1

   # There are two start states: one which can only match as a wildcard,
   # and one which can match as a wildcard or constant. These will be the
   # states of intrinsics/other instructions and load_const instructions,
   # respectively. The indices of these must match the definitions of
   # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
   # initialize things correctly.
   self.states.add(frozenset((self.wildcard,)))
   self.states.add(frozenset((self.const,self.wildcard)))
   process_new_states()

   # Fixed-point iteration: keep expanding transitions until no opcode has
   # newly discovered filtered states.
   while len(new_opcodes) > 0:
      for op in new_opcodes:
         rep = self.rep[op]
         table = self.table[op]
         op_worklist_index = worklist_indices[op]
         if op in conv_opcode_types:
            num_srcs = 1
         else:
            num_srcs = opcodes[op].num_inputs

         # Iterate over all possible source combinations where at least one
         # is on the worklist.
         for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
            if all(src_idx < op_worklist_index for src_idx in src_indices):
               continue

            srcs = tuple(rep[src_idx] for src_idx in src_indices)

            # Try all possible pairings of source items and add the
            # corresponding parent items. This is Comp_a from the paper.
            parent = set(self.items[op, item_srcs] for item_srcs in
               itertools.product(*srcs) if (op, item_srcs) in self.items)

            # We could always start matching something else with a
            # wildcard. This is Cl from the paper.
            parent.add(self.wildcard)

            table[src_indices] = self.states.add(frozenset(parent))
         worklist_indices[op] = len(rep)
      new_opcodes.clear()
      process_new_states()
1043
# Mako template rendering the C source of one generated algebraic pass:
# per-state arrays of transforms, the automaton's per-opcode filter and
# transition tables, and the nir_shader-level entry point that computes the
# condition flags and runs nir_algebraic_impl over every function.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {} %>
% for xform in xforms:
   ${xform.search.render(cache)}
   ${xform.replace.render(cache)}
% endfor

% for state_id, state_xforms in enumerate(automaton.state_patterns):
% if state_xforms: # avoid emitting a 0-length array for MSVC
static const struct transform ${pass_name}_state${state_id}_xforms[] = {
% for i in state_xforms:
  { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} },
% endfor
};
% endif
% endfor

static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
      .filter = (uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

const struct transform *${pass_name}_transforms[] = {
% for i in range(len(automaton.state_patterns)):
  % if automaton.state_patterns[i]:
  ${pass_name}_state${i}_xforms,
  % else:
  NULL,
  % endif
% endfor
};

const uint16_t ${pass_name}_transform_counts[] = {
% for i in range(len(automaton.state_patterns)):
  % if automaton.state_patterns[i]:
  (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms),
  % else:
  0,
  % endif
% endfor
};

bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl) {
         progress |= nir_algebraic_impl(function->impl, condition_flags,
                                        ${pass_name}_transforms,
                                        ${pass_name}_transform_counts,
                                        ${pass_name}_table);
      }
   }

   return progress;
}
""")
1143
1144
class AlgebraicPass(object):
   """Compiles a list of search-and-replace transforms into C source for a
   NIR algebraic pass.

   Each transform is parsed into a SearchAndReplace, grouped by search
   opcode (generic conversion opcodes are expanded to every sized
   variant), checked against the commutative-expression limit of
   match_expression (nir_search.c), and fed into the TreeAutomaton used
   for matching at runtime.  Any parse or validation failure is reported
   to stderr and terminates the generator with sys.exit(1).
   """
   def __init__(self, pass_name, transforms):
      self.xforms = []
      # defaultdict(list) replaces the equivalent defaultdict(lambda: []).
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            # Narrowed from a bare `except:` so KeyboardInterrupt and
            # SystemExit still propagate; keep reporting and continue so
            # every broken transform is shown in one run.
            except Exception:
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # A generic conversion opcode (e.g. f2f) matches any
            # destination size, so register the transform under every
            # sized variant of it.
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               # Bug fix: this previously referenced the misspelled name
               # nir_search_max_comm_op, raising NameError whenever this
               # diagnostic fired.
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}). Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      """Render the generated C source for this pass via the mako template."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools)