nir/algebraic: Make internal classes str-able
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import OrderedDict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Python 2/3 compatibility aliases for the Python types that may appear in a
# transform tuple: integer literals and (decoded) text strings.
if sys.version_info >= (3, 0):
   integer_types = (int, )
   string_type = str
else:
   integer_types = (int, long)
   string_type = unicode
45
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the explicit bit size of a NIR type name, or 0 if unsized.

   The name must start with one of int/uint/bool/float (checked with an
   assert); any trailing digits are taken as the bit size, so "float32"
   gives 32 and "float" gives 0.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return 0 if bits is None else int(bits)
56
class VarSet(object):
   """A set of variable names, each assigned a unique integer id.

   Ids come from a counter and are handed out in first-seen order.  After
   lock() is called, looking up a name not already in the set fails an
   assertion instead of allocating a new id.
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      if name in self.names:
         return self.names[name]

      assert not self.immutable, "Unknown replacement variable: " + name
      new_id = next(self.ids)
      self.names[name] = new_id
      return new_id

   def lock(self):
      """Forbid adding any further names to the set."""
      self.immutable = True
73
class Value(object):
   """Base class for a node of a search or replace expression.

   Each node records the text it was parsed from (in_val, also returned by
   str()), the C identifier of the static struct generated for it (name),
   and which nir_search_* struct kind it is (type_str).
   """

   @staticmethod
   def create(val, name_base, varset):
      """Build the Value subclass matching the Python type of val.

      Tuples become Expressions, strings become Variables, and bool / int /
      float literals become Constants.  An existing Expression is returned
      unchanged.  bytes are decoded as UTF-8 first.

      Raises TypeError for any other kind of value.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)

      # Fail loudly here instead of returning None; a None source would only
      # surface later as an unrelated attribute error during validation.
      raise TypeError("Invalid value in transform: " + repr(val))

   __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
{ ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
${val.index}, /* ${val.var_name} */
${'true' if val.is_constant else 'false'},
${val.type() or 'nir_type_invalid' },
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
nir_op_${val.opcode},
{ ${', '.join(src.c_ptr for src in val.sources)} },
${val.cond if val.cond else 'NULL'},
% endif
};""")

   def __init__(self, val, name, type_str):
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   @property
   def type_enum(self):
      """The nir_search_value_* enumerator naming this node's kind."""
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      """The C struct type (nir_search_*) generated for this node."""
      return "nir_search_" + self.type_str

   @property
   def c_ptr(self):
      """C expression taking the address of the struct's `value` member."""
      return "&{0}.value".format(self.name)

   def render(self):
      """Render the C definition of this node's static struct."""
      return self.__template.render(val=self,
                                    Constant=Constant,
                                    Variable=Variable,
                                    Expression=Expression)
132
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal bool / int / float operand of a transform.

   A str value of the form "<literal>" or "<literal>@<bits>" is parsed with
   ast.literal_eval, and the optional suffix supplies an explicit bit size.
   Booleans are always treated as 32-bit.
   """

   def __init__(self, val, name):
      # Value.__init__ also records str(val) as self.in_val.
      Value.__init__(self, val, name, "constant")

      if not isinstance(val, (str)):
         self.value = val
         self.bit_size = 0
      else:
         match = _constant_re.match(val)
         self.value = ast.literal_eval(match.group('value'))
         bits = match.group('bits')
         self.bit_size = int(bits) if bits else 0

      if isinstance(self.value, bool):
         assert self.bit_size in (0, 32)
         self.bit_size = 32

   def hex(self):
      """Return the C initializer text for this constant's value."""
      value = self.value
      if isinstance(value, (bool)):
         return 'NIR_TRUE' if value else 'NIR_FALSE'
      if isinstance(value, integer_types):
         return hex(value)
      if isinstance(value, float):
         # Emit the raw IEEE-754 double bit pattern.
         raw_bits = struct.unpack('Q', struct.pack('d', value))[0]
         text = hex(raw_bits)

         # On Python 2 this 'L' suffix is automatically added, but not on
         # Python 3.  Adding it explicitly makes the generated file
         # identical, regardless of the Python version running this script.
         if text[-1] != 'L' and raw_bits > sys.maxsize:
            text += 'L'

         return text
      assert False

   def type(self):
      """Return the nir_type_* name for the value, or None if unknown."""
      # bool must be tested before the integer types: bool subclasses int.
      for py_types, nir_type in ((bool, "nir_type_bool"),
                                 (integer_types, "nir_type_int"),
                                 (float, "nir_type_float")):
         if isinstance(self.value, py_types):
            return nir_type
      return None
178
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")

class Variable(Value):
   """A named operand such as "a", "#b", "a@32", "a@float" or "a(cond)".

   A leading '#' marks the variable as required to be constant.  The
   optional "@<type><bits>" suffix constrains its type and/or bit size, and
   a trailing parenthesized "(...)" is kept verbatim as its condition.  The
   variable's id comes from looking its name up in varset.
   """

   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      match = _var_name_re.match(val)
      assert match and match.group('name') is not None

      self.var_name = match.group('name')
      self.is_constant = match.group('const') is not None
      self.cond = match.group('cond')
      self.required_type = match.group('type')
      bits = match.group('bits')
      self.bit_size = int(bits) if bits else 0

      if self.required_type == 'bool':
         assert self.bit_size in (0, 32)
         self.bit_size = 32

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def __str__(self):
      return self.in_val

   def type(self):
      """Return the nir_type_* name for the required type, or None."""
      return {
         'bool': "nir_type_bool",
         'int': "nir_type_int",
         'uint': "nir_type_int",
         'float': "nir_type_float",
      }.get(self.required_type)
215
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An operation node: a tuple (opcode_string, src0, src1, ...).

   The opcode string may carry a leading '~' (inexact), an "@<bits>" bit
   size, and a trailing "(cond)" condition.  Each source is converted with
   Value.create and named "<name_base>_<index>".
   """

   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      match = _opcode_re.match(expr[0])
      assert match and match.group('opcode') is not None

      self.opcode = match.group('opcode')
      bits = match.group('bits')
      self.bit_size = int(bits) if bits else 0
      self.inexact = match.group('inexact') is not None
      self.cond = match.group('cond')

      self.sources = []
      for index, src in enumerate(expr[1:]):
         child_name = "{0}_{1}".format(name_base, index)
         self.sources.append(Value.create(src, child_name, varset))

   def render(self):
      """Render the structs for all sources, then this node's own struct."""
      # Sources are emitted first so this struct can take their addresses.
      source_defs = [src.render() for src in self.sources]
      return "\n".join(source_defs) + super(Expression, self).render()
237
class IntEquivalenceRelation(object):
   """A class representing an equivalence relation on integers.

   Each integer has a canonical form which is the maximum integer to which it
   is equivalent. Two integers are equivalent precisely when they have the
   same canonical form.

   The convention of maximum is explicitly chosen to make using it in
   BitSizeValidator easier because it means that an actual bit_size (if any)
   will always be the canonical form.
   """
   def __init__(self):
      # Maps an integer to a strictly larger equivalent one; following the
      # chain ends at the canonical form.
      self._remap = {}

   def get_canonical(self, x):
      """Get the canonical integer corresponding to x."""
      # Chains strictly increase (see add_equiv), so this terminates.
      while x in self._remap:
         x = self._remap[x]
      return x

   def add_equiv(self, a, b):
      """Merge the classes of a and b and return the new canonical form.

      The canonical form of the merged class is the larger of the two
      previous canonical forms.
      """
      canon_a = self.get_canonical(a)
      canon_b = self.get_canonical(b)
      c = max(canon_a, canon_b)

      # The old canonical representatives must be re-pointed, not just a and
      # b. Otherwise other members already equivalent to a or b would keep
      # resolving to the stale representative and drop out of the class.
      # Re-pointing a and b as well shortens their lookup chains.
      for x in (a, b, canon_a, canon_b):
         if x != c:
            assert x < c
            self._remap[x] = c

      return c
271
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64. The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size. Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value. NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match. They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above. This is similar to type inference in many common programming
   languages. If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit. There are, however, cases where this
   inference can be ambiguous or contradictory. Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer. However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value. As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size. If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code. This ensures that bugs are caught at compile time
   rather than at run time.

   The basic operation of the validator is very similar to the bitsize_tree in
   nir_search only a little more subtle. Instead of simply tracking bit
   sizes, it tracks "bit classes" where each class is represented by an
   integer. A value of 0 means we don't know anything yet, positive values
   are actual bit-sizes, and negative values are used to track equivalence
   classes of sizes that must be the same but have yet to receive an actual
   size. The first stage uses the bitsize_tree algorithm to assign bit
   classes to each variable. If it ever comes across an inconsistency, it
   assert-fails. Then the second stage uses that information to prove that
   the resulting expression can always validly be constructed.
   """

   def __init__(self, varset):
      # Counter used to hand out fresh negative placeholder classes.
      self._num_classes = 0
      # One bit class per variable id in varset; 0 means nothing known yet.
      self._var_classes = [0] * len(varset.names)
      # Records which classes have been unified with each other.
      self._class_relation = IntEquivalenceRelation()

   def validate(self, search, replace):
      """Assert that search and replace have mutually consistent bit sizes.

      Any inconsistency surfaces as an AssertionError from the helpers.
      """
      # Stage 1: infer bit classes from the search expression, bottom-up
      # and then top-down, recording a class for every variable.
      dst_class = self._propagate_bit_size_up(search)
      if dst_class == 0:
         # Destination size is unconstrained: track it as a placeholder.
         dst_class = self._new_class()
      self._propagate_bit_class_down(search, dst_class)

      # Stage 2: check the replace expression against those classes; it
      # must produce the same destination class as the search.
      validate_dst_class = self._validate_bit_class_up(replace)
      assert validate_dst_class == 0 or validate_dst_class == dst_class
      self._validate_bit_class_down(replace, dst_class)

   def _new_class(self):
      """Allocate and return a fresh negative (placeholder) bit class."""
      self._num_classes += 1
      return -self._num_classes

   def _set_var_bit_class(self, var_id, bit_class):
      """Record that the variable with id var_id must have bit_class."""
      assert bit_class != 0
      var_class = self._var_classes[var_id]
      if var_class == 0:
         self._var_classes[var_id] = bit_class
      else:
         canon_class = self._class_relation.get_canonical(var_class)
         # A variable already resolved to a real (positive) size may only be
         # set to that same size; placeholders may be unified with anything.
         assert canon_class < 0 or canon_class == bit_class
         var_class = self._class_relation.add_equiv(var_class, bit_class)
         self._var_classes[var_id] = var_class

   def _get_var_bit_class(self, var_id):
      """Return the canonical bit class currently recorded for var_id."""
      return self._class_relation.get_canonical(self._var_classes[var_id])

   def _propagate_bit_size_up(self, val):
      """Bottom-up pass over the search tree; return val's bit size or 0.

      Sets val.common_size on every Expression node to the size shared by
      its unsized sources/destination (0 when nothing is known).
      """
      if isinstance(val, (Constant, Variable)):
         return val.bit_size

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_size = 0
         for i in range(nir_op.num_inputs):
            src_bits = self._propagate_bit_size_up(val.sources[i])
            if src_bits == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               # Sized input: the source must match the opcode's size.
               assert src_bits == src_type_bits
            else:
               # Unsized inputs must all agree with each other.
               assert val.common_size == 0 or src_bits == val.common_size
               val.common_size = src_bits

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            # Unsized destination shares the unsized inputs' size.
            if val.common_size != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_size
            else:
               val.common_size = val.bit_size
            return val.common_size

   def _propagate_bit_class_down(self, val, bit_class):
      """Top-down pass over the search tree, pushing bit_class into val.

      Variables reached here get their class recorded via
      _set_var_bit_class.
      """
      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class
         self._set_var_bit_class(val.index, bit_class)

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == 0 or bit_class == dst_type_bits
         else:
            assert val.common_size == 0 or val.common_size == bit_class
            val.common_size = bit_class

         # common_class is only left unbound when num_inputs == 0, in which
         # case the loop below never reads it.
         if val.common_size:
            common_class = val.common_size
         elif nir_op.num_inputs:
            # If we got here then we have no idea what the actual size is.
            # Instead, we use a generic class
            common_class = self._new_class()

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._propagate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._propagate_bit_class_down(val.sources[i], common_class)

   def _validate_bit_class_up(self, val):
      """Bottom-up pass over the replace tree; return val's class or 0.

      Mirrors _propagate_bit_size_up, but reads variables' inferred classes
      and stores the per-node result in val.common_class.
      """
      if isinstance(val, Constant):
         return val.bit_size

      elif isinstance(val, Variable):
         var_class = self._get_var_bit_class(val.index)
         # By the time we get to validation, every variable should have a class
         assert var_class != 0

         # If we have an explicit size provided by the user, the variable
         # *must* exactly match the search. It cannot be implicitly sized
         # because otherwise we could end up with a conflict at runtime.
         assert val.bit_size == 0 or val.bit_size == var_class

         return var_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         val.common_class = 0
         for i in range(nir_op.num_inputs):
            src_class = self._validate_bit_class_up(val.sources[i])
            if src_class == 0:
               continue

            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               assert src_class == src_type_bits
            else:
               assert val.common_class == 0 or src_class == val.common_class
               val.common_class = src_class

         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert val.bit_size == 0 or val.bit_size == dst_type_bits
            return dst_type_bits
         else:
            if val.common_class != 0:
               assert val.bit_size == 0 or val.bit_size == val.common_class
            else:
               val.common_class = val.bit_size
            return val.common_class

   def _validate_bit_class_down(self, val, bit_class):
      """Top-down pass over the replace tree, checking bit_class fits val.

      NOTE(review): an unsized source of an op whose destination is sized
      receives val.common_class, which can still be 0 here and then trips
      the bit_class != 0 assert below; that appears to be intended.
      """
      # At this point, everything *must* have a bit class. Otherwise, we have
      # a value we don't know how to define.
      assert bit_class != 0

      if isinstance(val, Constant):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Variable):
         assert val.bit_size == 0 or val.bit_size == bit_class

      elif isinstance(val, Expression):
         nir_op = opcodes[val.opcode]
         dst_type_bits = type_bits(nir_op.output_type)
         if dst_type_bits != 0:
            assert bit_class == dst_type_bits
         else:
            assert val.common_class == 0 or val.common_class == bit_class
            val.common_class = bit_class

         for i in range(nir_op.num_inputs):
            src_type_bits = type_bits(nir_op.input_types[i])
            if src_type_bits != 0:
               self._validate_bit_class_down(val.sources[i], src_type_bits)
            else:
               self._validate_bit_class_down(val.sources[i], val.common_class)
492
# Source of the unique per-transform ids used to name generated C structs.
_optimization_ids = itertools.count()

# Distinct condition strings across all transforms; index 0 is always 'true'.
condition_list = ['true']

class SearchAndReplace(object):
   """A single (search, replace[, condition]) transformation.

   Raw search/replace values are parsed into Value trees here, and the pair
   is then checked with BitSizeValidator (which asserts on any mismatch).
   The optional condition string defaults to 'true'.
   """

   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Conditions are deduplicated into the module-level condition_list; the
      # generated C looks up condition_flags[] by this position.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      # NOTE(review): a pre-built Expression search does not register its
      # variables in this fresh varset, so it stays empty -- confirm callers
      # passing Expression objects do not rely on named variables.
      if not isinstance(search, Expression):
         search = Expression(search, "search{0}".format(self.id), varset)
      self.search = search

      # From here on, the replace side may only use variables the search
      # side introduced.
      varset.lock()

      if not isinstance(replace, Value):
         replace = Value.create(replace, "replace{0}".format(self.id), varset)
      self.replace = replace

      BitSizeValidator(varset).validate(self.search, self.replace)
526
# Mako template for the C source of one algebraic pass.  AlgebraicPass.render
# supplies pass_name, xform_dict (opcode -> list of SearchAndReplace) and
# condition_list.  For every opcode it emits the rendered search/replace
# structs plus a <pass>_<opcode>_xforms table.  The generated pass walks
# blocks and ALU instructions in reverse, switches on alu->op, and tries that
# opcode's transforms in order whose condition flag is set, stopping at the
# first nir_replace_instr that succeeds.  Note this string is emitted
# verbatim (modulo Mako substitution) into the generated file.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
const nir_search_expression *search;
const nir_search_value *replace;
unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
${xform.search.render()}
${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
{ &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
void *mem_ctx)
{
bool progress = false;

nir_foreach_instr_reverse_safe(instr, block) {
if (instr->type != nir_instr_type_alu)
continue;

nir_alu_instr *alu = nir_instr_as_alu(instr);
if (!alu->dest.dest.is_ssa)
continue;

switch (alu->op) {
% for opcode in xform_dict.keys():
case nir_op_${opcode}:
for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
if (condition_flags[xform->condition_offset] &&
nir_replace_instr(alu, xform->search, xform->replace,
mem_ctx)) {
progress = true;
break;
}
}
break;
% endfor
default:
break;
}
}

return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
void *mem_ctx = ralloc_parent(impl);
bool progress = false;

nir_foreach_block_reverse(block, impl) {
progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
}

if (progress)
nir_metadata_preserve(impl, nir_metadata_block_index |
nir_metadata_dominance);

return progress;
}


bool
${pass_name}(nir_shader *shader)
{
bool progress = false;
bool condition_flags[${len(condition_list)}];
const nir_shader_compiler_options *options = shader->options;
(void) options;

% for index, condition in enumerate(condition_list):
condition_flags[${index}] = ${condition};
% endfor

nir_foreach_function(function, shader) {
if (function->impl)
progress |= ${pass_name}_impl(function->impl, condition_flags);
}

return progress;
}
""")
630
class AlgebraicPass(object):
   """Collect the transformations of one algebraic pass and render its C code.

   transforms may contain SearchAndReplace objects or raw transform tuples;
   raw tuples are parsed here.  Every transformation that fails to parse is
   reported on stderr together with its traceback.  Once all of them have
   been tried, the process exits with status 1 if any failed.
   """

   def __init__(self, pass_name, transforms):
      # Ordered so the generated C lists opcodes in first-seen order.
      self.xform_dict = OrderedDict()
      self.pass_name = pass_name

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform)
            except Exception:
               # Report and keep going so every bad transformation is listed
               # in a single run.  Catching Exception (not a bare except)
               # lets KeyboardInterrupt and SystemExit propagate normally.
               print("Failed to parse transformation:", file=sys.stderr)
               print(" " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         # Group by the opcode at the root of the search expression; the
         # generated C switches on that opcode.
         self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

      if error:
         sys.exit(1)

   def render(self):
      """Return the generated C source for this pass."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xform_dict=self.xform_dict,
                                             condition_list=condition_list)