nir/algebraic: Fix a typo in the bit size validation code
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import OrderedDict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Python 2/3 compatibility aliases for the operand types accepted in
# transforms: the integer types and the text-string type.
if sys.version_info >= (3, 0):
    integer_types = (int, )
    string_type = str
else:
    integer_types = (int, long)  # noqa: F821 -- Python 2 only
    string_type = unicode  # noqa: F821 -- Python 2 only
45
# Splits a NIR type string such as "float", "uint32" or "bool" into an
# optional base type and an optional explicit bit width.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
    """Return the explicit bit width of a NIR type string, or 0 if unsized.

    For example "float32" gives 32 and "float" gives 0.  The string must
    begin with one of int, uint, bool or float.
    """
    parsed = _type_re.match(type_str)
    assert parsed.group('type')

    width = parsed.group('bits')
    return 0 if width is None else int(width)
56
class VarSet(object):
    """A set of variable names, each mapped to a unique integer id.

    Ids are handed out in first-seen order starting at 0.  Once lock() has
    been called, looking up a name that was not seen before is an
    assertion failure ("Unknown replacement variable").
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        index = self.names.get(name)
        if index is None:
            # First sighting of this name: only allowed while unlocked.
            assert not self.immutable, "Unknown replacement variable: " + name
            index = next(self.ids)
            self.names[name] = index
        return index

    def lock(self):
        """Forbid registering any further variable names."""
        self.immutable = True
73
class Value(object):
    """Base class for one node of a search or replace expression tree.

    Subclasses (Constant, Variable and Expression) supply the fields that the
    shared mako template below renders into a static C nir_search_* struct.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Wrap a raw transform operand in the matching Value subclass.

        val may be a tuple (-> Expression), a text string (-> Variable), a
        bool/float/integer (-> Constant), or an already-built Value, which is
        returned unchanged.  Byte strings are decoded as UTF-8 first.
        name_base is the C identifier for the generated struct; varset
        assigns variable indices.

        Raises TypeError for any other kind of operand.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Value):
            # Accept any pre-built node.  Only Expression instances used to be
            # recognized here; a pre-built Variable or Constant fell through
            # and silently produced None.
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

        # Fail here, naming the offending operand, instead of returning None
        # and failing later far away from the cause.
        raise TypeError('Invalid value in transform: {0!r}'.format(val))

    # Renders the static C struct for a single node.  The fields after the
    # common { type_enum, bit_size } header depend on the concrete subclass.
    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
{ ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
${val.index}, /* ${val.var_name} */
${'true' if val.is_constant else 'false'},
${val.type() or 'nir_type_invalid' },
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
nir_op_${val.opcode},
{ ${', '.join(src.c_ptr for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, val, name, type_str):
        # in_val: printable form of the original operand (used in messages).
        # name: C identifier of the generated struct.
        # type_str: "constant", "variable" or "expression" as passed by the
        #           subclasses; used to build the C type and enum names.
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    @property
    def type_enum(self):
        """C enum naming this node's kind, e.g. nir_search_value_variable."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """C struct type of the generated node, e.g. nir_search_variable."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression taking the address of this node's embedded value."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Render this node's C struct definition."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
132
# Parses a string constant of the form "value[@bits]", e.g. "1.0@64".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
    """A literal bool, integer or float operand.

    A string of the form "value[@bits]" is also accepted; the value part is
    parsed with ast.literal_eval.  Booleans are always 32-bit.
    """

    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        # Handle byte and text strings the same way on Python 2 and 3, as
        # Value.create does.  Checking only `str` missed unicode strings on
        # Python 2 and left bytes unparsed on Python 3.
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        self.in_val = str(val)
        if isinstance(val, string_type):
            m = _constant_re.match(val)
            assert m, 'Invalid constant: ' + val
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

    def hex(self):
        """Return the C initializer text for the value.

        Booleans become NIR_TRUE/NIR_FALSE, integers their hex literal, and
        floats the hex of their bits when packed as a C double.
        """
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            assert False

    def type(self):
        """Return the nir_type enum name for the value, or None if unknown."""
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
178
# Parses a variable operand: optional "#" (must be a constant), the name,
# an optional "@[type][bits]" annotation and an optional "(cond)".
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
    """A named operand that binds to whatever the search matches.

    The same name always maps to the same index in the shared VarSet.  A
    "bool" type annotation forces a 32-bit size.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        match = _var_name_re.match(val)
        assert match and match.group('name') is not None

        self.var_name = match.group('name')
        self.is_constant = match.group('const') is not None
        self.cond = match.group('cond')
        self.required_type = match.group('type')

        bits = match.group('bits')
        self.bit_size = int(bits) if bits else 0

        if self.required_type == 'bool':
            assert self.bit_size in (0, 32)
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def __str__(self):
        return self.in_val

    def type(self):
        """Return the nir_type enum name for the annotated type, or None."""
        enum_for_type = {
            'bool': "nir_type_bool",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }
        return enum_for_type.get(self.required_type)
215
# Parses an expression head such as "~fadd@32(cond)": an optional "~"
# (inexact), the opcode name, an optional explicit bit size and an optional
# condition.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
    """An ALU operation node: a NIR opcode applied to child Values."""

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')

        # Reject misspelled opcodes and wrong source counts here, with a
        # message naming the expression.  Otherwise they only show up as a
        # bare KeyError/IndexError inside BitSizeValidator, and surplus
        # sources are silently ignored by its per-input loops.
        assert self.opcode in opcodes, \
            'Unknown opcode nir_op_{0}: {1}'.format(self.opcode, str(self))
        num_inputs = opcodes[self.opcode].num_inputs
        num_given = len(expr) - 1
        assert num_given == num_inputs, \
            'nir_op_{0} takes {1} sources but {2} were given: {3}' \
            .format(self.opcode, num_inputs, num_given, str(self))

        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Render every source struct first, then this expression's struct."""
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
237
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps a non-canonical integer to a strictly larger one in its class.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Remap the old canonical roots rather than a and b themselves.
        # Otherwise anything previously merged into a's or b's class keeps
        # pointing at the old root and loses transitivity.
        for root in (canon_a, canon_b):
            if root != c:
                assert root < c
                self._remap[root] = c

        return c
271
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

     (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

     (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to AND a 32-bit value with a 64-bit value, which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search, only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable.  If it ever comes across an inconsistency, it
    assert-fails.  Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        self._num_classes = 0
        # One bit class per variable, indexed by the ids handed out by varset.
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Prove that replacing `search` with `replace` is bit-size safe.

        Assert-fails with a descriptive message if it is not.
        """
        search_dst_class = self._propagate_bit_size_up(search)
        if search_dst_class == 0:
            search_dst_class = self._new_class()
        self._propagate_bit_class_down(search, search_dst_class)

        # Propagating down can merge the search's destination class with
        # other classes, possibly with an actual bit size (e.g. when a
        # variable also feeds a sized source).  The replace side is always
        # computed in canonical form, so compare in canonical form as well.
        search_dst_class = self._class_relation.get_canonical(search_dst_class)

        replace_dst_class = self._validate_bit_class_up(replace)
        if replace_dst_class != 0:
            # A fixed-size (positive) replacement is only valid if the search
            # is known to have an actual size too; a negative search class
            # means the search can match any bit size.
            assert search_dst_class > 0 or replace_dst_class < 0, \
                'Search expression matches any bit size but replace ' \
                'expression can only generate {0}-bit values' \
                .format(replace_dst_class)

            assert search_dst_class == replace_dst_class, \
                'Search expression matches any {0}-bit values but replace ' \
                'expression can only generate {1}-bit values' \
                .format(search_dst_class, replace_dst_class)

        self._validate_bit_class_down(replace, search_dst_class)

    def _new_class(self):
        """Allocate a fresh bit class with no actual size yet (negative)."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var, bit_class):
        """Record that `var` must belong to `bit_class`.

        Merges the variable's existing class with `bit_class`; assert-fails
        if both already resolve to different actual bit sizes.
        """
        assert bit_class != 0
        var_class = self._var_classes[var.index]
        if var_class == 0:
            self._var_classes[var.index] = bit_class
        else:
            canon_var_class = self._class_relation.get_canonical(var_class)
            canon_bit_class = self._class_relation.get_canonical(bit_class)
            # Report the canonical (resolved) sizes, not internal class ids.
            assert canon_var_class < 0 or canon_bit_class < 0 or \
                canon_var_class == canon_bit_class, \
                'Variable {0} cannot be both {1}-bit and {2}-bit' \
                .format(str(var), canon_bit_class, canon_var_class)
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var.index] = var_class

    def _get_var_bit_class(self, var):
        """Return the canonical bit class currently assigned to `var`."""
        return self._class_relation.get_canonical(self._var_classes[var.index])

    def _propagate_bit_size_up(self, val):
        """Search stage, bottom-up: infer sizes from explicit annotations.

        Returns the bit size of `val`, or 0 if it is not known yet.  For an
        expression, the size shared by its unsized sources/destination is
        stored in `val.common_size`.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits, \
                        'Source {0} of nir_op_{1} must be a {2}-bit value but ' \
                        'the only possible matched values are {3}-bit: {4}' \
                        .format(i, val.opcode, src_type_bits, src_bits, str(val))
                else:
                    assert val.common_size == 0 or src_bits == val.common_size, \
                        'Expression cannot have both {0}-bit and {1}-bit ' \
                        'variable-width sources: {2}' \
                        .format(src_bits, val.common_size, str(val))
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits, \
                    'nir_op_{0} produces a {1}-bit result but a {2}-bit ' \
                    'result was requested' \
                    .format(val.opcode, dst_type_bits, val.bit_size)
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size, \
                        'Variable width expression must be {0}-bit based on ' \
                        'the sources but a {1}-bit result was requested: {2}' \
                        .format(val.common_size, val.bit_size, str(val))
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Search stage, top-down: push `bit_class` into `val`'s subtree.

        Assigns a bit class to every variable, creating fresh classes for
        unsized operands whose size is still unknown.
        """
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                'Constant is {0}-bit but a {1}-bit value is required: {2}' \
                .format(val.bit_size, bit_class, str(val))

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                'Variable is {0}-bit but a {1}-bit value is required: {2}' \
                .format(val.bit_size, bit_class, str(val))
            self._set_var_bit_class(val, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits, \
                    'nir_op_{0} produces a {1}-bit result but the parent ' \
                    'expression wants a {2}-bit value' \
                    .format(val.opcode, dst_type_bits, bit_class)
            else:
                assert val.common_size == 0 or val.common_size == bit_class, \
                    'Variable-width expression produces a {0}-bit result ' \
                    'based on the source widths but the parent expression ' \
                    'wants a {1}-bit value: {2}' \
                    .format(val.common_size, bit_class, str(val))
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Replace stage, bottom-up: compute the canonical class of `val`.

        Returns 0 when `val` is built only from unsized constants.  For an
        expression, the class shared by its unsized operands is stored in
        `val.common_class`.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search. It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Replace stage, top-down: check `val` can be built in `bit_class`."""
        # At this point, everything *must* have a bit class. Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
529
# Source of unique ids; each SearchAndReplace takes the next one and uses it
# to name its generated C structs.
_optimization_ids = itertools.count()

# Every distinct condition (a C expression) seen so far, shared by all
# transforms.  Index 0 is always the unconditional 'true'.
condition_list = ['true']

class SearchAndReplace(object):
    """One (search, replace[, condition]) transformation.

    The search pattern must be an expression.  Variables used by the
    replacement must already appear in the search, and the pair is checked
    with BitSizeValidator on construction.
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search, replace = transform[0], transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        # Register the condition once and refer to it by index.
        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            search_name = "search{0}".format(self.id)
            self.search = Expression(search, search_name, varset)

        # From here on, unknown variable names are rejected.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            replace_name = "replace{0}".format(self.id)
            self.replace = Value.create(replace, replace_name, varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
563
# Mako template for the generated C optimization pass.  Render arguments:
#   pass_name      - name of the C entry point; helper functions and
#                    per-opcode transform tables are prefixed with it.
#   xform_dict     - map from search opcode to the list of SearchAndReplace
#                    objects whose search expression has that opcode.
#   condition_list - C expressions gating transforms; each transform stores
#                    the index of its condition.
# The output defines:
#   - the nir_search structs for every search and replace tree;
#   - one transform table per opcode;
#   - a per-block worker that walks ALU instructions with SSA destinations
#     in reverse and applies the first enabled transform that replaces them;
#   - ${pass_name}(), which evaluates every condition once per shader and
#     runs the worker over each function implementation, returning whether
#     any progress was made.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
const nir_search_expression *search;
const nir_search_value *replace;
unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
${xform.search.render()}
${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
{ &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
void *mem_ctx)
{
bool progress = false;

nir_foreach_instr_reverse_safe(instr, block) {
if (instr->type != nir_instr_type_alu)
continue;

nir_alu_instr *alu = nir_instr_as_alu(instr);
if (!alu->dest.dest.is_ssa)
continue;

switch (alu->op) {
% for opcode in xform_dict.keys():
case nir_op_${opcode}:
for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
if (condition_flags[xform->condition_offset] &&
nir_replace_instr(alu, xform->search, xform->replace,
mem_ctx)) {
progress = true;
break;
}
}
break;
% endfor
default:
break;
}
}

return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
void *mem_ctx = ralloc_parent(impl);
bool progress = false;

nir_foreach_block_reverse(block, impl) {
progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
}

if (progress)
nir_metadata_preserve(impl, nir_metadata_block_index |
nir_metadata_dominance);

return progress;
}


bool
${pass_name}(nir_shader *shader)
{
bool progress = false;
bool condition_flags[${len(condition_list)}];
const nir_shader_compiler_options *options = shader->options;
(void) options;

% for index, condition in enumerate(condition_list):
condition_flags[${index}] = ${condition};
% endfor

nir_foreach_function(function, shader) {
if (function->impl)
progress |= ${pass_name}_impl(function->impl, condition_flags);
}

return progress;
}
""")
667
class AlgebraicPass(object):
    """A named set of search-and-replace transforms rendered as one C pass.

    pass_name is the name of the generated C entry point.  transforms is an
    iterable of SearchAndReplace objects or raw (search, replace[, condition])
    tuples.  Transforms are grouped by the opcode of their search expression,
    preserving input order.  If any raw transform fails to parse, every
    failure is reported on stderr and the process exits with status 1.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Report every malformed transform before exiting.  The
                    # previous bare `except:` also swallowed
                    # KeyboardInterrupt and SystemExit.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Render the complete C source for this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)