python: Explicitly add the 'L' suffix on Python 3
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import OrderedDict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the explicit bit width spelled in a NIR type name.

    A sized name such as ``'float32'`` yields 32; an unsized name such as
    ``'float'`` yields 0.  The name must begin with one of the base types
    int, uint, bool or float (checked by assertion).
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return 0 if bits is None else int(bits)
48
class VarSet(object):
    """Hands out a unique integer id for each distinct variable name.

    The first lookup of a name allocates the next id from a counter, and
    later lookups return the same id.  Once lock() has been called, looking
    up a name that has never been seen fails an assertion.  SearchAndReplace
    locks the set after parsing the search pattern, so the replacement can
    only use variables the search bound.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        try:
            return self.names[name]
        except KeyError:
            pass

        assert not self.immutable, "Unknown replacement variable: " + name
        new_id = next(self.ids)
        self.names[name] = new_id
        return new_id

    def lock(self):
        """Forbid allocating ids for any further new names."""
        self.immutable = True
65
class Value(object):
    """Base class for one node of a search or replace pattern.

    Each node is emitted as a ``static const nir_search_<type_str>`` C
    struct whose identifier is ``name``; ``c_ptr`` is the expression other
    generated structs use to point at it.
    """

    # Python 2 has the separate ``unicode`` and ``long`` types; Python 3 has
    # neither, so naming them unconditionally raises NameError there.  They
    # are only evaluated on the Python 2 branch below.
    if sys.version_info < (3, 0):
        _string_types = (str, unicode)
        _number_types = (bool, int, long, float)
    else:
        _string_types = (str,)
        _number_types = (bool, int, float)

    @staticmethod
    def create(val, name_base, varset):
        """Build the pattern node described by ``val``.

        Tuples become Expressions, strings become Variables (whose ids come
        from ``varset``) and bools/numbers become Constants.  An existing
        Expression is returned unchanged.  ``name_base`` becomes the node's
        C identifier (and, for Expressions, the prefix of its sources' names).

        Raises:
            TypeError: if ``val`` is none of the types above.
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, Value._string_types):
            return Variable(val, name_base, varset)
        elif isinstance(val, Value._number_types):
            return Constant(val, name_base)

        # Returning None here would only fail later, far away from the
        # offending part of the transformation.
        raise TypeError(
            "Unsupported value in transformation: {0!r}".format(val))

    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, name, type_str):
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """C enumerant naming this node's kind, e.g. nir_search_value_constant."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """C struct type used for this node, e.g. nir_search_constant."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression pointing at this node's embedded ``value`` member."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C definition of this node's struct."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
117
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """An immediate operand: a bool, integer or float literal.

    ``val`` is either the Python value itself or a string such as
    ``"1.0@64"``: a Python literal optionally followed by ``@`` and an
    explicit bit size.  Bools are always given a bit size of 32.
    """

    # ``long`` and ``unicode`` only exist on Python 2; naming them
    # unconditionally raises NameError on Python 3.
    if sys.version_info < (3, 0):
        _integer_types = (int, long)
        _string_types = (str, unicode)
    else:
        _integer_types = (int,)
        _string_types = (str,)

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if isinstance(val, self._string_types):
            m = _constant_re.match(val)
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

    @staticmethod
    def _py2_hex(i):
        """Return hex(i) spelled the way Python 2 spells it.

        Python 2's hex() appends an 'L' suffix to ``long`` values, which
        include every integer outside [-sys.maxsize - 1, sys.maxsize];
        Python 3 never appends it.  Adding it explicitly keeps the generated
        file identical regardless of the Python version running this script.
        """
        h = hex(i)
        if h[-1] != 'L' and not (-sys.maxsize - 1 <= i <= sys.maxsize):
            h += 'L'
        return h

    def hex(self):
        """Return the C initializer text for this constant's value.

        Bools map to NIR_TRUE/NIR_FALSE, integers are printed in hex, and
        floats are printed as the hex of their bits packed as a C double.
        """
        if isinstance(self.value, bool):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, self._integer_types):
            return self._py2_hex(self.value)
        elif isinstance(self.value, float):
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            return self._py2_hex(i)
        else:
            assert False, "Unsupported constant value: {0!r}".format(self.value)

    def type(self):
        """Return the C NIR type name emitted for this constant, or None."""
        if isinstance(self.value, bool):
            return "nir_type_bool32"
        elif isinstance(self.value, self._integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
162
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named operand of a pattern, such as ``'a'``, ``'#b'`` or ``'a@32'``.

    Syntax is ``[#]name[@[type][bits]][(cond)]``:

    * a leading ``#`` sets ``is_constant``;
    * ``@type`` / ``@bits`` record a required type and explicit bit size
      (``bool`` forces 32 bits);
    * a parenthesized suffix is kept verbatim, parentheses included, in
      ``cond`` and emitted into the generated C.

    Variables with the same name share the id that ``varset`` assigns.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        match = _var_name_re.match(val)
        assert match and match.group('name') is not None

        bits = match.group('bits')
        self.var_name = match.group('name')
        self.is_constant = match.group('const') is not None
        self.cond = match.group('cond')
        self.required_type = match.group('type')
        self.bit_size = int(bits) if bits else 0

        if self.required_type == 'bool':
            assert self.bit_size in (0, 32)
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the C NIR type name for the required type, or None."""
        return {
            'bool': "nir_type_bool32",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }.get(self.required_type)
196
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An ALU operation node of a pattern.

    ``expr`` is a tuple ``(opcode, src0, src1, ...)``.  The opcode string is
    ``[~]name[@bits][(cond)]``: a leading ``~`` marks the expression inexact,
    ``@bits`` gives an explicit bit size and a parenthesized suffix is kept
    in ``cond``.  Each source is built with Value.create and named
    ``<name_base>_<i>``.
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        match = _opcode_re.match(expr[0])
        assert match and match.group('opcode') is not None

        bits = match.group('bits')
        self.opcode = match.group('opcode')
        self.bit_size = int(bits) if bits else 0
        self.inexact = match.group('inexact') is not None
        self.cond = match.group('cond')

        sources = []
        for index, src in enumerate(expr[1:]):
            src_name = "{0}_{1}".format(name_base, index)
            sources.append(Value.create(src, src_name, varset))
        self.sources = sources

    def render(self):
        """Return the C definitions of all sources followed by this node."""
        rendered_sources = [src.render() for src in self.sources]
        return "\n".join(rendered_sources) + super(Expression, self).render()
218
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent. Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps an integer to a strictly larger integer it is equivalent to.
        # Following the links from any integer ends at its canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence between a and b and return the canonical form.

        The canonical forms of a and b are linked, not a and b themselves.
        Remapping a non-canonical argument directly would overwrite its
        existing link and silently drop an equivalence recorded earlier.
        """
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        for root in (canon_a, canon_b):
            if root != c:
                # Links always point upward, so no cycle can form.
                assert root < c
                self._remap[root] = c

        return c
252
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64. The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size. Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value. NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match. They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above. This is similar to type inference in many common programming
    languages. If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit. There are, however, cases where this
    inference can be ambiguous or contradictory. Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer. However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value. As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size. If we had a 64-bit b value, we would end up
    trying to AND a 32-bit value with a 64-bit value, which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code. This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle. Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer. A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size. The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable. If it ever comes across an inconsistency, it
    assert-fails. Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        # Number of generic (negative) classes handed out so far.
        self._num_classes = 0
        # Bit class of each variable, indexed by its VarSet id; 0 = unknown.
        self._var_classes = [0] * len(varset.names)
        # Records which bit classes have been found to be equal.
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that replace can be built with the bit sizes search implies.

        Any inconsistency fails an assertion.
        """
        # Stage 1: explicit sizes flow up the search pattern, then a class for
        # its result flows back down, assigning a class to every variable.
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            # Nothing fixes the search result's size; give it a generic class.
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        # Stage 2: check the replacement against the classes from stage 1.
        # NOTE(review): dst_class is compared un-canonicalized here, while
        # variable classes come back canonical (see _get_var_bit_class);
        # confirm a merged dst_class cannot trip this assert.
        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Allocate a fresh generic class: -1, then -2, and so on."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable var_id must have the (nonzero) bit_class."""
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            # First constraint seen for this variable.
            self._var_classes[var_id] = bit_class
        else:
            # The existing class must still be generic (negative) or already
            # equal bit_class; the two classes are then merged.
            canon_class = self._class_relation.get_canonical(var_class)
            assert canon_class < 0 or canon_class == bit_class
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable var_id (0 if unset)."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Return the bit size of val implied by explicit sizes, or 0.

        For an Expression this also sets val.common_size: the size shared by
        its unsized sources and destination (0 if not yet known).
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    # Sized input: the source must match the opcode's size.
                    assert src_bits == src_type_bits
                else:
                    # Unsized inputs must all agree with each other (rule 2).
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                # An unsized destination shares the unsized inputs' size.
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push bit_class down into val, assigning classes to its variables."""
        if isinstance(val, Constant):
            # NOTE(review): a constant with an explicit bit_size under a
            # still-generic (negative) class fails this assert.
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                # An unsized destination forces the unsized inputs' class.
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()
            # With no inputs, common_class stays unbound but is never read.

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Return the bit class of replace-node val implied from below, or 0.

        Mirrors _propagate_bit_size_up, except that variables report the
        class assigned to them while processing the search pattern.  For an
        Expression this also sets val.common_class.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search. It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Assert that replace-node val can be built with class bit_class."""
        # At this point, everything *must* have a bit class. Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
473
_optimization_ids = itertools.count()

condition_list = ['true']


class SearchAndReplace(object):
    """One parsed and bit-size-validated search-and-replace rule.

    ``transform`` is ``(search, replace)`` or ``(search, replace,
    condition)``, where ``condition`` is a C expression string that defaults
    to ``'true'``.  Each distinct condition occupies one slot in the
    module-level ``condition_list``, and ``condition_index`` is this rule's
    slot.
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search, replace = transform[0], transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        try:
            self.condition_index = condition_list.index(self.condition)
        except ValueError:
            condition_list.append(self.condition)
            self.condition_index = len(condition_list) - 1

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # From here on, only variables bound by the search may be referenced.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id),
                                        varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
507
# Mako template for the generated C pass, rendered by AlgebraicPass.render()
# with:
#   pass_name      - name of the generated `bool pass_name(nir_shader *)`
#                    entry point (also the prefix of its helper functions),
#   xform_dict     - maps an opcode name to the SearchAndReplace objects whose
#                    search expression has that opcode,
#   condition_list - C condition expressions; each transform's
#                    condition_index selects its slot in condition_flags.
# The generated pass visits blocks and ALU instructions in reverse and, for
# each instruction, stops at the first transform for its opcode whose
# condition flag is set and for which nir_replace_instr() succeeds.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
611
class AlgebraicPass(object):
    """A set of transformations rendered together as one C NIR pass.

    ``transforms`` holds SearchAndReplace objects or raw transform tuples;
    tuples are parsed here.  Every transformation that fails to parse or
    validate is reported on stderr, and the process then exits with status 1.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Parse/validation failures (mostly AssertionError) are
                    # reported and collected so every bad transformation is
                    # listed before exiting.  KeyboardInterrupt and SystemExit
                    # are deliberately not caught, so generation can still be
                    # aborted.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            # Group by the search opcode, keeping first-seen order.
            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the generated C source for this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)