nir: fix crash in loop unroll corner case
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
from __future__ import print_function
import ast
import itertools
import numbers
import re
import struct
import sys
import traceback

import mako.template

from nir_opcodes import opcodes
36
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size spelled in a NIR type name, or 0 when unsized.

   ``type_str`` looks like ``"float32"`` or ``"int"``; the base-type prefix
   must be present, the trailing bit count is optional.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
47
class VarSet(object):
   """Map variable names to unique integer ids.

   Ids are handed out in first-use order, starting at 0.  After lock() has
   been called, looking up a name that has not been seen before fails an
   assertion; SearchAndReplace locks the set after parsing the search
   expression so that the replacement cannot introduce new variables.
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      """Return the id for name, allocating a new one if allowed."""
      if name not in self.names:
         assert not self.immutable, "Unknown replacement variable: " + name
         # The next() builtin works on Python 2 and 3, unlike the Python
         # 2-only iterator .next() method.
         self.names[name] = next(self.ids)

      return self.names[name]

   def lock(self):
      """Forbid allocating ids for names not already in the set."""
      self.immutable = True
64
class Value(object):
   """Base class for the nodes of a search or replace expression tree.

   render() turns a node into a static ``nir_search_*`` C initializer; the
   subclasses Constant, Variable and Expression supply the fields.
   """

   @staticmethod
   def create(val, name_base, varset):
      """Build a Value node from one element of a transform description.

      Tuples become Expressions, strings become Variables and bools, integers
      and floats become Constants.  Existing Value instances (of any
      subclass) are returned unchanged.  Any other type raises TypeError
      instead of silently yielding None.
      """
      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Value):
         return val
      # type(u'') is unicode on Python 2 and str on Python 3.
      elif isinstance(val, (str, type(u''))):
         return Variable(val, name_base, varset)
      # numbers.Integral covers int and Python 2's long (and bool).
      elif isinstance(val, (bool, numbers.Integral, float)):
         return Constant(val, name_base)

      raise TypeError("Unsupported value in transform: {0!r}".format(val))

   # val.__hex__() is called explicitly because the hex() builtin only
   # dispatches to __hex__ on Python 2.
   __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.__hex__()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def __init__(self, name, type_str):
      """name is the C identifier; type_str selects the nir_search_* kind."""
      self.name = name
      self.type_str = type_str

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      return "nir_search_" + self.type_str

   @property
   def c_ptr(self):
      return "&{0}.value".format(self.name)

   def render(self):
      """Return the C definition of this node."""
      return self.__template.render(val=self,
                                    Constant=Constant,
                                    Variable=Variable,
                                    Expression=Expression)
116
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal constant in a search or replace expression.

   ``val`` is either a Python bool/int/float or a string of the form
   ``"<python literal>[@<bits>]"``.  Boolean constants are always 32-bit.
   """

   def __init__(self, val, name):
      Value.__init__(self, name, "constant")

      if isinstance(val, (str, type(u''))):
         m = _constant_re.match(val)
         # The pattern is only anchored at the start; require it to consume
         # the whole string so a malformed suffix is reported, not ignored.
         assert m and m.end() == len(val), "Invalid constant: " + val
         self.value = ast.literal_eval(m.group('value'))
         self.bit_size = int(m.group('bits')) if m.group('bits') else 0
      else:
         self.value = val
         self.bit_size = 0

      if isinstance(self.value, bool):
         assert self.bit_size == 0 or self.bit_size == 32
         self.bit_size = 32

   def __hex__(self):
      """Return the C initializer text for this constant's value.

      Floats are emitted as the bit pattern of the equivalent C double.
      """
      if isinstance(self.value, bool):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      # numbers.Integral covers int and Python 2's long; bool is handled above.
      if isinstance(self.value, numbers.Integral):
         return hex(self.value)
      elif isinstance(self.value, float):
         return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
      else:
         assert False

   def type(self):
      """Return the nir_type_* name for this constant's value, or None."""
      if isinstance(self.value, bool):
         return "nir_type_bool32"
      elif isinstance(self.value, numbers.Integral):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"
152
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
   """A named variable in a search or replace expression.

   ``val`` has the form ``[#]name[@[type][bits]][(cond)]``.  A leading ``#``
   sets ``is_constant``.  The ``@`` suffix records a required base type
   and/or bit size.  ``cond`` (including its parentheses) is emitted verbatim
   into the rendered C initializer.  Occurrences with the same name share
   one index from ``varset``.
   """

   def __init__(self, val, name, varset):
      Value.__init__(self, name, "variable")

      m = _var_name_re.match(val)
      # The pattern is only anchored at the start, so also require it to
      # consume the whole string; otherwise a typo such as "a@flaot" would
      # be accepted as an unconstrained "a".
      assert m and m.end() == len(val), "Invalid variable: " + val

      self.var_name = m.group('name')
      self.is_constant = m.group('const') is not None
      self.cond = m.group('cond')
      self.required_type = m.group('type')
      self.bit_size = int(m.group('bits')) if m.group('bits') else 0

      if self.required_type == 'bool':
         # Booleans are always 32-bit.
         assert self.bit_size == 0 or self.bit_size == 32
         self.bit_size = 32

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_type_* name required by this variable, or None."""
      if self.required_type == 'bool':
         return "nir_type_bool32"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"
186
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An ALU operation in a search or replace expression.

   ``expr`` is a tuple.  Its first element has the form
   ``[~]opcode[@bits][(cond)]``, where a leading ``~`` sets ``inexact``.  The
   remaining elements are the sources, each converted with Value.create.
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      # Require a full match so a malformed suffix is reported, not ignored.
      assert m and m.end() == len(expr[0]), "Invalid opcode: " + expr[0]

      self.opcode = m.group('opcode')
      self.bit_size = int(m.group('bits')) if m.group('bits') else 0
      self.inexact = m.group('inexact') is not None
      self.cond = m.group('cond')
      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

   def render(self):
      """Render the sources first, since this initializer points at them."""
      srcs = "\n".join(src.render() for src in self.sources)
      return srcs + super(Expression, self).render()
208
class IntEquivalenceRelation(object):
   """A class representing an equivalence relation on integers.

   Each integer has a canonical form which is the maximum integer to which it
   is equivalent.  Two integers are equivalent precisely when they have the
   same canonical form.

   The convention of maximum is explicitly chosen to make using it in
   BitSizeValidator easier because it means that an actual bit_size (if any)
   will always be the canonical form.
   """

   def __init__(self):
      # Maps an integer to a strictly larger integer in the same class;
      # following the chain from x ends at x's canonical form.
      self._remap = {}

   def get_canonical(self, x):
      """Get the canonical integer corresponding to x."""
      # Every entry points to a strictly larger integer, so this terminates.
      while x in self._remap:
         x = self._remap[x]
      return x

   def add_equiv(self, a, b):
      """Add an equivalence and return the canonical form.

      The whole classes of a and b are merged by linking their canonical
      forms.  Re-pointing only a itself is not enough: if a had already been
      merged into another class, that class would be split off again.
      """
      canon_a = self.get_canonical(a)
      canon_b = self.get_canonical(b)
      c = max(canon_a, canon_b)

      # Linking the roots merges the classes; linking a and b directly as
      # well just shortens later lookups.
      for x in (canon_a, canon_b, a, b):
         if x != c:
            assert x < c
            self._remap[x] = c

      return c
242
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   The basic operation of the validator is very similar to the bitsize_tree in
   nir_search, only a little more subtle.  Instead of simply tracking bit
   sizes, it tracks "bit classes" where each class is represented by an
   integer.  A value of 0 means we don't know anything yet, positive values
   are actual bit-sizes, and negative values are used to track equivalence
   classes of sizes that must be the same but have yet to receive an actual
   size.  The first stage uses the bitsize_tree algorithm to assign bit
   classes to each variable.  If it ever comes across an inconsistency, it
   assert-fails.  Then the second stage uses that information to prove that
   the resulting expression can always validly be constructed.

   Failures are reported with ``assert``, so no validation happens when
   Python runs with -O.
   """

   def __init__(self, varset):
      self._num_classes = 0
      # Bit class of every variable, indexed by VarSet id; 0 = unknown yet.
      self._var_classes = [0] * len(varset.names)
      self._class_relation = IntEquivalenceRelation()

   def validate(self, search, replace):
      """Assert that replace can always be built with search's bit sizes."""
      dst_class = self._propagate_bit_size_up(search)
      if dst_class == 0:
         dst_class = self._new_class()
      self._propagate_bit_class_down(search, dst_class)

      # The down pass may have merged dst_class with other classes (possibly
      # with an actual bit size).  The replace-side passes work purely with
      # canonical classes, so compare against the canonical form too.
      dst_class = self._class_relation.get_canonical(dst_class)

      validate_dst_class = self._validate_bit_class_up(replace)
      assert validate_dst_class == 0 or validate_dst_class == dst_class
      self._validate_bit_class_down(replace, dst_class)

   def _new_class(self):
      """Allocate a fresh generic (negative) bit class."""
      self._num_classes += 1
      return -self._num_classes

   def _set_var_bit_class(self, var_id, bit_class):
      """Record that variable var_id must have bit class bit_class."""
      assert bit_class != 0
      var_class = self._var_classes[var_id]
      if var_class == 0:
         self._var_classes[var_id] = bit_class
      else:
         # The classes are compatible unless both already resolve to actual
         # bit sizes that differ.  Canonicalize *both* sides: bit_class may
         # be a generic class that is already tied to a size, and either the
         # variable or the incoming class may be the generic one, depending
         # on the order in which occurrences are visited.
         canon_var = self._class_relation.get_canonical(var_class)
         canon_new = self._class_relation.get_canonical(bit_class)
         assert canon_var < 0 or canon_new < 0 or canon_var == canon_new
         self._var_classes[var_id] = \
            self._class_relation.add_equiv(canon_var, canon_new)

   def _get_var_bit_class(self, var_id):
      return self._class_relation.get_canonical(self._var_classes[var_id])

   def _propagate_bit_size_up(self, val):
      """Return the bit size known for search value val, or 0 if unknown.

      For expressions, also stores in val.common_size the size shared by the
      unsized sources/destination (0 if unknown).
      """
      if isinstance(val, (Constant, Variable)):
         return val.bit_size

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_size = 0
         for i in range(nir_op.num_inputs):
            src_bits = self._propagate_bit_size_up(val.sources[i])
            if src_bits == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               assert src_bits == src_type_bits
            else:
               assert val.common_size == 0 or src_bits == val.common_size
               val.common_size = src_bits

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_size != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_size
            else:
               val.common_size = val.bit_size
            return val.common_size

   def _propagate_bit_class_down(self, val, bit_class):
      """Push bit_class into search value val, assigning variable classes."""
      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class
         self._set_var_bit_class(val.index, bit_class)

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == 0 or bit_class == dst_type_bits
         else:
            assert val.common_size == 0 or val.common_size == bit_class
            val.common_size = bit_class

         common_class = val.common_size
         if not common_class and nir_op.num_inputs:
            # If we got here then we have no idea what the actual size is.
            # Instead, we use a generic class.
            common_class = self._new_class()

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._propagate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._propagate_bit_class_down(val.sources[i], common_class)

   def _validate_bit_class_up(self, val):
      """Return the canonical bit class of replace value val, or 0."""
      if isinstance(val, Constant):
         return val.bit_size

      elif isinstance(val, Variable):
         var_class = self._get_var_bit_class(val.index)
         # By the time we get to validation, every variable should have a class
         assert var_class != 0

         # If we have an explicit size provided by the user, the variable
         # *must* exactly match the search.  It cannot be implicitly sized
         # because otherwise we could end up with a conflict at runtime.
         assert val.bit_size == 0 or val.bit_size == var_class

         return var_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_class = 0
         for i in range(nir_op.num_inputs):
            src_class = self._validate_bit_class_up(val.sources[i])
            if src_class == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               assert src_class == src_type_bits
            else:
               assert val.common_class == 0 or src_class == val.common_class
               val.common_class = src_class

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_class != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_class
            else:
               val.common_class = val.bit_size
            return val.common_class

   def _validate_bit_class_down(self, val, bit_class):
      """Assert that replace value val can be built with class bit_class."""
      # At this point, everything *must* have a bit class.  Otherwise, we have
      # a value we don't know how to define.
      assert bit_class != 0

      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == dst_type_bits
         else:
            assert val.common_class == 0 or val.common_class == bit_class
            val.common_class = bit_class

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._validate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._validate_bit_class_down(val.sources[i], val.common_class)
463
_optimization_ids = itertools.count()

# Distinct transform conditions (C expressions), in first-seen order.  Index 0
# is the default 'true'; each transform stores the index of its condition.
condition_list = ['true']

class SearchAndReplace(object):
   """A single search-and-replace transform.

   ``transform`` is a tuple ``(search, replace[, condition])``.
   - ``search`` is an Expression or an expression tuple.
   - ``replace`` is a Value or anything accepted by Value.create.
   - ``condition`` is a C expression string guarding the transform; it
     defaults to 'true'.

   Bit sizes are validated at construction time.
   """

   def __init__(self, transform):
      # The next() builtin works on Python 2 and 3, unlike the Python 2-only
      # iterator .next() method.
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      if len(transform) > 2:
         self.condition = transform[2]
      else:
         self.condition = 'true'

      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      # The replacement may only reuse variables introduced by the search.
      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      BitSizeValidator(varset).validate(self.search, self.replace)
497
# Mako template for the generated C pass.  Rendered by AlgebraicPass with:
#   pass_name      - name of the generated entry point
#   xform_dict     - dict mapping an opcode name to its SearchAndReplace list
#   condition_list - C condition expressions indexed by condition_index
# xform_dict.items() (not the Python 2-only iteritems()) keeps the template
# working on both Python 2 and 3.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
${xform.search.render()}
${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
% for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
% endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

% for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
% endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
601
class AlgebraicPass(object):
   """Collect transforms and render them as a C NIR optimization pass.

   ``transforms`` is an iterable of SearchAndReplace objects or raw transform
   tuples accepted by SearchAndReplace.  Every transform that fails to parse
   or validate is reported on stderr; if any failed, the process exits with
   status 1 after all of them have been reported.
   """

   def __init__(self, pass_name, transforms):
      self.xform_dict = {}
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except Exception:
               # Report and keep going so every bad transform is listed.
               # Catch Exception rather than using a bare except so that
               # KeyboardInterrupt and SystemExit still propagate.
               print("Failed to parse transformation:", file=sys.stderr)
               print(" " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

      if error:
         sys.exit(1)

   def render(self):
      """Return the generated C source for this pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xform_dict=self.xform_dict,
                                             condition_list=condition_list)