python: Stabilize some script outputs
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import OrderedDict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Matches an optional base type name followed by an optional bit size.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size encoded in a type string, or 0 when unsized.

    ``type_str`` must begin with one of ``int``, ``uint``, ``bool`` or
    ``float``; this is enforced with an assertion.  A trailing decimal
    number, if present, is the bit size (``"float32"`` gives 32, ``"int"``
    gives 0).
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return 0 if bits is None else int(bits)
48
class VarSet(object):
    """A set of named variables, each mapped to a unique integer id.

    Ids are handed out sequentially, starting at 0, the first time a name
    is looked up.  After ``lock()`` has been called, looking up an unknown
    name trips an assertion instead of allocating a new id.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for ``name``, allocating one if the set is unlocked."""
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            # The next() builtin works on both Python 2.6+ and Python 3,
            # unlike the Python-2-only iterator.next() method.
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the creation of any further variables."""
        self.immutable = True
65
class Value(object):
    """Base class for the nodes of a search or replace pattern.

    Each node has a C identifier (``name``) and a kind (``type_str``, one
    of "constant", "variable" or "expression" as passed by the subclasses)
    and can render itself as the definition of a static ``nir_search_*``
    C structure.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value matching the Python object ``val``.

        Tuples become Expressions, existing Expressions are returned as
        they are, strings become Variables and plain numbers or booleans
        become Constants.  Any other object yields None.
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        if isinstance(val, Expression):
            return val
        if isinstance(val, (str, unicode)):
            return Variable(val, name_base, varset)
        if isinstance(val, (bool, int, long, float)):
            return Constant(val, name_base)
        return None

    # Renders any of the three Value kinds as a C struct definition.
    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${hex(val)} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, name, type_str):
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """Name of the C enumerant identifying this kind of value."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """Name of the C struct type used to store this value."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression pointing at the embedded ``value`` member."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C source defining this value."""
        context = dict(val=self,
                       Constant=Constant,
                       Variable=Variable,
                       Expression=Expression)
        return self.__template.render(**context)
117
# A constant literal optionally followed by "@<bits>".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal value appearing in a search or replace pattern.

    ``val`` is either a Python bool/number, or a string of the form
    ``"<literal>[@<bits>]"`` whose literal part is evaluated with
    ``ast.literal_eval``.  Boolean constants are always 32 bits wide.
    """

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if isinstance(val, str):
            parsed = _constant_re.match(val)
            self.value = ast.literal_eval(parsed.group('value'))
            bits = parsed.group('bits')
            self.bit_size = int(bits) if bits else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size in (0, 32)
            self.bit_size = 32

    def __hex__(self):
        """Return the C spelling of the constant.

        Booleans map to NIR_TRUE/NIR_FALSE, integers to their hexadecimal
        form and floats to the hexadecimal form of their 64-bit pattern.
        """
        value = self.value
        # bool must be tested first because it is a subclass of int.
        if isinstance(value, bool):
            return 'NIR_TRUE' if value else 'NIR_FALSE'
        if isinstance(value, (int, long)):
            return hex(value)
        if isinstance(value, float):
            raw_bits = struct.unpack('Q', struct.pack('d', value))[0]
            return hex(raw_bits)
        assert False

    def type(self):
        """Return the nir_alu_type name for the constant, or None."""
        value = self.value
        if isinstance(value, bool):
            return "nir_type_bool32"
        if isinstance(value, (int, long)):
            return "nir_type_int"
        if isinstance(value, float):
            return "nir_type_float"
        return None
153
# Syntax: [#]name[@[type][bits]][(condition)]
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named placeholder in a search or replace pattern.

    The textual form is ``[#]name[@[type][bits]][(cond)]``: a leading ``#``
    sets ``is_constant``, the ``@`` suffix gives a required type and/or bit
    size, and a trailing parenthesised string is kept verbatim as ``cond``.
    The variable's index is looked up in (or allocated from) ``varset``.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        parsed = _var_name_re.match(val)
        assert parsed and parsed.group('name') is not None

        bits = parsed.group('bits')
        self.var_name = parsed.group('name')
        self.is_constant = parsed.group('const') is not None
        self.cond = parsed.group('cond')
        self.required_type = parsed.group('type')
        self.bit_size = int(bits) if bits else 0

        # Booleans are always 32 bits wide.
        if self.required_type == 'bool':
            assert self.bit_size in (0, 32)
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_alu_type name for the required type, or None."""
        required = self.required_type
        if required == 'bool':
            return "nir_type_bool32"
        if required in ('int', 'uint'):
            return "nir_type_int"
        if required == 'float':
            return "nir_type_float"
        return None
187
# Syntax: [~]opcode[@bits][(condition)]
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An operation applied to a list of source values.

    ``expr`` is a tuple whose first element names the opcode using the
    syntax ``[~]opcode[@bits][(cond)]`` (a leading ``~`` sets ``inexact``)
    and whose remaining elements are the sources, each converted with
    ``Value.create`` and named ``<name_base>_<index>``.
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        parsed = _opcode_re.match(expr[0])
        assert parsed and parsed.group('opcode') is not None

        bits = parsed.group('bits')
        self.opcode = parsed.group('opcode')
        self.bit_size = int(bits) if bits else 0
        self.inexact = parsed.group('inexact') is not None
        self.cond = parsed.group('cond')

        sources = []
        for index, src in enumerate(expr[1:]):
            src_name = "{0}_{1}".format(name_base, index)
            sources.append(Value.create(src, src_name, varset))
        self.sources = sources

    def render(self):
        """Return the C source for every source value, then for this one."""
        rendered_sources = "\n".join(src.render() for src in self.sources)
        own_definition = super(Expression, self).render()
        return rendered_sources + own_definition
209
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps an integer to a strictly larger integer of the same class.
        # Integers that do not appear as keys are their own canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Every link points to a strictly larger integer, so the chain is
        # finite and ends at the maximum of the class.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        root_a = self.get_canonical(a)
        root_b = self.get_canonical(b)
        c = max(root_a, root_b)

        # Link the class representatives rather than a and b themselves.
        # Re-pointing a non-canonical member directly at c would detach it
        # from its previous class and leave that class's other members
        # with a stale canonical form.
        for root in (root_a, root_b):
            if root != c:
                assert root < c
                self._remap[root] = c

        return c
243
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected
        by everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and
    you know the second source is 16-bit, then you know that the other source
    and the destination must also be 16-bit.  There are, however, cases where
    this inference can be ambiguous or contradictory.  Consider, for
    instance, the following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow
    is well-defined for any bit-size of integer.  However, b2i always
    generates a 32-bit result so it could end up replacing a 64-bit
    expression with one that takes two 64-bit values and produces a 32-bit
    value.  As another example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end
    up trying to and a 32-bit value with a 64-bit value which would be
    invalid.

    This class solves that problem by providing a validation layer that
    proves that a given search-and-replace operation is 100% well-defined
    before we generate any code.  This ensures that bugs are caught at
    compile time rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree
    in nir_search only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable.  If it ever comes across an inconsistency, it
    assert-fails.  Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        # Number of generic (negative) classes handed out so far.
        self._num_classes = 0
        # Bit class of every variable, indexed by variable id; 0 = unknown.
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that replacing ``search`` by ``replace`` is well-sized.

        Stage one infers bit classes from the search expression; stage two
        checks the replacement against them.
        """
        search_class = self._propagate_bit_size_up(search)
        if search_class == 0:
            search_class = self._new_class()
        self._propagate_bit_class_down(search, search_class)

        replace_class = self._validate_bit_class_up(replace)
        assert replace_class == 0 or replace_class == search_class
        self._validate_bit_class_down(replace, search_class)

    def _new_class(self):
        """Allocate a fresh generic bit class (a new negative integer)."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` belongs to ``bit_class``."""
        assert bit_class != 0
        current = self._var_classes[var_id]
        if current == 0:
            self._var_classes[var_id] = bit_class
            return

        # A variable that already has an actual size may only be given that
        # same size again; generic classes are merged with the new class.
        canonical = self._class_relation.get_canonical(current)
        assert canonical < 0 or canonical == bit_class
        merged = self._class_relation.add_equiv(current, bit_class)
        self._var_classes[var_id] = merged

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable ``var_id``."""
        raw_class = self._var_classes[var_id]
        return self._class_relation.get_canonical(raw_class)

    def _propagate_bit_size_up(self, val):
        """Compute explicit bit sizes bottom-up through a search tree.

        Stores the size shared by the unsized sources/destination of each
        expression in ``common_size`` and returns the destination size
        (0 when unknown).
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            op_info = opcodes[val.opcode]
            val.common_size = 0
            for idx in range(op_info.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[idx])
                if src_bits == 0:
                    continue

                sized_bits = type_bits(op_info.input_types[idx])
                if sized_bits == 0:
                    # Unsized input: must agree with the other unsized ones.
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits
                else:
                    assert src_bits == sized_bits

            dst_bits = type_bits(op_info.output_type)
            if dst_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_bits
                return dst_bits

            if val.common_size != 0:
                assert val.bit_size == 0 or val.bit_size == val.common_size
            else:
                val.common_size = val.bit_size
            return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push bit classes top-down through a search tree.

        Every variable reached is assigned a bit class on the way.
        """
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            op_info = opcodes[val.opcode]
            dst_bits = type_bits(op_info.output_type)
            if dst_bits == 0:
                # Unsized destination: it shares the common size.
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class
            else:
                assert bit_class == 0 or bit_class == dst_bits

            if val.common_size:
                common_class = val.common_size
            elif op_info.num_inputs:
                # If we got here then we have no idea what the actual size
                # is.  Instead, we use a generic class.
                common_class = self._new_class()

            for idx in range(op_info.num_inputs):
                sized_bits = type_bits(op_info.input_types[idx])
                if sized_bits == 0:
                    self._propagate_bit_class_down(val.sources[idx],
                                                   common_class)
                else:
                    self._propagate_bit_class_down(val.sources[idx],
                                                   sized_bits)

    def _validate_bit_class_up(self, val):
        """Compute bit classes bottom-up through a replacement tree.

        Stores the class shared by the unsized sources/destination of each
        expression in ``common_class`` and returns the destination class
        (0 when unknown).
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a
            # class.
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search.  It cannot be implicitly
            # sized because otherwise we could end up with a conflict at
            # runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            op_info = opcodes[val.opcode]
            val.common_class = 0
            for idx in range(op_info.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[idx])
                if src_class == 0:
                    continue

                sized_bits = type_bits(op_info.input_types[idx])
                if sized_bits == 0:
                    assert (val.common_class == 0 or
                            src_class == val.common_class)
                    val.common_class = src_class
                else:
                    assert src_class == sized_bits

            dst_bits = type_bits(op_info.output_type)
            if dst_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_bits
                return dst_bits

            if val.common_class != 0:
                assert val.bit_size == 0 or val.bit_size == val.common_class
            else:
                val.common_class = val.bit_size
            return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check top-down that every replacement value gets a bit class."""
        # At this point, everything *must* have a bit class.  Otherwise, we
        # have a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            op_info = opcodes[val.opcode]
            dst_bits = type_bits(op_info.output_type)
            if dst_bits == 0:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class
            else:
                assert bit_class == dst_bits

            for idx in range(op_info.num_inputs):
                sized_bits = type_bits(op_info.input_types[idx])
                if sized_bits == 0:
                    self._validate_bit_class_down(val.sources[idx],
                                                  val.common_class)
                else:
                    self._validate_bit_class_down(val.sources[idx],
                                                  sized_bits)
464
# Source of the unique id given to every SearchAndReplace instance.
_optimization_ids = itertools.count()

# Every distinct condition string seen so far.  A transformation refers to
# its condition by index into this list; index 0 is the always-true one.
condition_list = ['true']


class SearchAndReplace(object):
    """One algebraic transformation: a search pattern and its replacement.

    ``transform`` is a sequence ``(search, replace[, condition])``.  The
    search is either an Expression or a tuple describing one; the
    replacement is either a Value or anything accepted by ``Value.create``.
    The optional condition is a string (it defaults to ``'true'``) that is
    registered in the module-level ``condition_list``.

    The search/replace pair is checked with BitSizeValidator, which
    asserts on any bit-size inconsistency.
    """

    def __init__(self, transform):
        # The next() builtin works on both Python 2.6+ and Python 3,
        # unlike the Python-2-only iterator.next() method.
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id),
                                     varset)

        # The replacement may only use variables bound by the search.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace,
                                        "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
498
# Mako template for the C source of a complete algebraic pass.  It expects
# `pass_name`, `xform_dict` (search opcode -> list of transformations) and
# `condition_list`.  The dictionary is walked with items() and plain
# iteration, which behave identically on Python 2 and Python 3, unlike the
# Python-2-only iteritems().
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict:
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
602
class AlgebraicPass(object):
    """A named collection of transformations, grouped by search opcode.

    ``transforms`` is an iterable whose items are either SearchAndReplace
    instances or raw tuples accepted by the SearchAndReplace constructor.
    Transformations are stored in ``xform_dict``, an OrderedDict keyed by
    the opcode of their search expression, in the order they are given.

    Every transformation that fails to parse is reported on stderr with a
    traceback; once all of them have been processed the program exits with
    status 1 if any failed.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Only real parse/validation failures are reported
                    # here; KeyboardInterrupt and SystemExit propagate
                    # instead of being mistaken for a bad transformation.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            opcode = xform.search.opcode
            self.xform_dict.setdefault(opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the C source code of the whole pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)