nir: Add floating point atomic add intrinsics
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #
2 # Copyright (C) 2014 Intel Corporation
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Jason Ekstrand (jason@jlekstrand.net)
25
26 from __future__ import print_function
27 import ast
28 from collections import OrderedDict
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Python 2/3 compatibility aliases used by the isinstance() checks in this
# module.
if sys.version_info < (3, 0):
    # Python 2 has a separate ``long`` integer type and ``unicode`` text type.
    integer_types = (int, long)
    string_type = unicode
else:
    integer_types = (int, )
    string_type = str
45
# Matches a type string such as "int", "uint32" or "float64".  Both groups
# are optional, so the pattern itself never fails to match; callers must check
# that the 'type' group was actually found.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size encoded in a type string, or 0 if it has none.

    ``type_str`` must start with one of ``int``, ``uint``, ``bool`` or
    ``float``, optionally followed by a decimal bit size ("float32").

    Raises AssertionError if the string does not start with a known type.
    """
    m = _type_re.match(type_str)
    assert m.group('type'), "Invalid type string: {0!r}".format(type_str)

    bits = m.group('bits')
    if bits is None:
        return 0
    return int(bits)
56
class VarSet(object):
    """Represents a set of variables, each with a unique id.

    Ids are handed out sequentially, starting at 0, the first time a name is
    looked up.  After lock() has been called, looking up a name that has not
    been seen before is an error.
    """

    def __init__(self):
        self.names = {}                 # variable name -> id
        self.ids = itertools.count()    # source of fresh ids
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for ``name``, allocating one if the set is unlocked.

        Raises AssertionError for an unknown name once the set is locked.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the creation of any further variables."""
        self.immutable = True
73
class Value(object):
    """Base class for the nodes of a search/replace pattern tree.

    Each node has a C identifier (``name``) and a kind string (``type_str``:
    "constant", "variable" or "expression") from which the names of the
    generated C type and enum value are derived.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass that corresponds to ``val``.

        * tuple                  -> Expression
        * Expression             -> returned unchanged
        * text string            -> Variable
        * bool / float / integer -> Constant

        ``bytes`` input is decoded as UTF-8 first and then treated as text.

        Raises TypeError for any other kind of input instead of silently
        returning None, which would only fail later and far from the cause.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

        raise TypeError("Cannot create a search value from {0!r} "
                        "(type {1})".format(val, type(val).__name__))

    # Mako template that emits the C definition of a single value.  The
    # initializer fields depend on which subclass is being rendered.
    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, name, type_str):
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """Name of the C enum value identifying this kind of value."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """Name of the C struct type used to store this value."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression for a pointer to the embedded ``value`` member."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C source text that defines this value."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
128
# "<literal>" optionally followed by "@<bits>", e.g. "1.0@32".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal bool, integer or float appearing in a pattern."""

    def __init__(self, val, name):
        """Create a constant from a Python value or a "<literal>[@bits]" string.

        A string is parsed with ast.literal_eval and may carry an explicit
        bit size; any other value is stored as-is with no bit size.  Booleans
        are always given a bit size of 32.
        """
        Value.__init__(self, name, "constant")

        # Accept both str and the module's text type so that Python 2
        # unicode strings are parsed rather than stored verbatim.
        if isinstance(val, (str, string_type)):
            m = _constant_re.match(val)
            assert m, "Invalid constant: {0!r}".format(val)
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size == 0 or self.bit_size == 32, \
                "Boolean constants must be 32-bit"
            self.bit_size = 32

    def hex(self):
        """Return the C initializer text for this constant's value.

        Booleans become NIR_TRUE / NIR_FALSE, integers are printed in hex and
        floats are printed as the hex of their 64-bit double representation.
        """
        # bool must be tested first: it is a subclass of int.
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Reinterpret the bytes of the double as an unsigned 64-bit int.
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on
            # Python 3.  Adding it explicitly makes the generated file
            # identical, regardless of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            assert False, "Unsupported constant type: {0!r}".format(self.value)

    def type(self):
        """Return the nir_type_* name for this constant's Python type.

        Returns None if the value is not a bool, integer or float.
        """
        if isinstance(self.value, (bool)):
            return "nir_type_bool32"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
173
# "[#]name[@[type][bits]][(cond)]", e.g. "a", "#b", "c@32", "d@float64",
# "#e(is_pos_power_of_two)".
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named placeholder in a pattern."""

    def __init__(self, val, name, varset):
        """Parse the variable description ``val``.

        A leading '#' sets ``is_constant``; an '@' suffix gives an optional
        required type and/or bit size; a trailing parenthesised group is kept
        (parentheses included) as ``cond``.  The variable's index is looked
        up in, or allocated from, ``varset``.
        """
        Value.__init__(self, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None, \
            "Invalid variable description: {0!r}".format(val)

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0

        # Booleans are always 32-bit.
        if self.required_type == 'bool':
            assert self.bit_size == 0 or self.bit_size == 32, \
                "Boolean variables must be 32-bit"
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name for the required type, or None."""
        if self.required_type == 'bool':
            return "nir_type_bool32"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"
207
# "[~]opcode[@bits][(cond)]", e.g. "fadd", "~fmul", "iadd@32".
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An opcode applied to a list of source values."""

    def __init__(self, expr, name_base, varset):
        """Build an expression from the tuple ``(opcode_desc, src0, src1, ...)``.

        The opcode description may start with '~' (sets ``inexact``) and may
        be followed by "@bits" and a parenthesised condition.  Every source
        is converted with Value.create() and named "<name_base>_<index>".
        """
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None, \
            "Invalid opcode description: {0!r}".format(expr[0])

        self.opcode = m.group('opcode')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Return the C definitions of all sources followed by this node's."""
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
229
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps an integer to a strictly larger integer it is equivalent to.
        # Integers that do not appear as keys are their own canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Every link points at a strictly larger value, so this terminates.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Link the *canonical representatives*, not a and b themselves.
        # Re-pointing a non-canonical a (or b) straight at c would drop its
        # existing link and cut the rest of its old class off from c.
        if canon_a != c:
            self._remap[canon_a] = c

        if canon_b != c:
            self._remap[canon_b] = c

        return c
263
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable.  If it ever comes across an inconsistency, it
    assert-fails.  Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        # Number of placeholder (negative) classes handed out so far.
        self._num_classes = 0
        # Bit class of each variable, indexed by variable id; 0 = unknown.
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Check that ``replace`` is well-sized for every match of ``search``.

        Raises AssertionError as soon as an inconsistency is found.
        """
        # Stage 1: assign a bit class to every variable of the search tree.
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        # Stage 2: check the replacement against those classes.
        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class, \
            "Replacement has bit class {0} but search has {1}".format(
                validate_dst_class, dst_class)
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Return a fresh negative placeholder class."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` belongs to ``bit_class``.

        If the variable already has a class, the two classes are merged.
        """
        assert bit_class != 0, "Cannot assign the unknown bit class"
        var_class = self._var_classes[var_id]
        if var_class == 0:
            self._var_classes[var_id] = bit_class
        else:
            canon_class = self._class_relation.get_canonical(var_class)
            assert canon_class < 0 or canon_class == bit_class, \
                "Variable {0} is {1}-bit but is used as class {2}".format(
                    var_id, canon_class, bit_class)
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable ``var_id``."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Compute explicit bit sizes bottom-up through the search tree.

        Returns the bit size of ``val``'s result (0 if unknown) and stores on
        every Expression a ``common_size`` attribute holding the size shared
        by its unsized sources/destination (0 if unknown).
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits, \
                        "Source {0} of {1} is {2}-bit, expected {3}-bit".format(
                            i, val.opcode, src_bits, src_type_bits)
                else:
                    assert val.common_size == 0 or src_bits == val.common_size, \
                        "Unsized sources of {0} disagree: {1} vs {2}".format(
                            val.opcode, src_bits, val.common_size)
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits, \
                    "{0}@{1} conflicts with its {2}-bit destination".format(
                        val.opcode, val.bit_size, dst_type_bits)
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size, \
                        "{0}@{1} conflicts with its {2}-bit sources".format(
                            val.opcode, val.bit_size, val.common_size)
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push ``bit_class`` top-down through the search tree.

        Variables are assigned the class they are reached with; unsized
        sources of an expression whose common size is unknown share a fresh
        placeholder class.
        """
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                "Constant {0}@{1} used as bit class {2}".format(
                    val.value, val.bit_size, bit_class)

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                "Variable {0}@{1} used as bit class {2}".format(
                    val.var_name, val.bit_size, bit_class)
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits, \
                    "{0} produces {1} bits but is used as bit class {2}".format(
                        val.opcode, dst_type_bits, bit_class)
            else:
                assert val.common_size == 0 or val.common_size == bit_class, \
                    "{0} has common size {1} but is used as bit class {2}".format(
                        val.opcode, val.common_size, bit_class)
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()
            # With no inputs common_class stays unset; the loop below does
            # not run in that case, so it is never read.

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Compute bit classes bottom-up through the replacement tree.

        Returns the bit class of ``val``'s result (0 if unknown) and stores on
        every Expression a ``common_class`` attribute.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a
            # class
            assert var_class != 0, \
                "Variable {0} has no bit class".format(val.var_name)

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search.  It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class, \
                "Variable {0}@{1} has bit class {2} in the search".format(
                    val.var_name, val.bit_size, var_class)

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits, \
                        "Source {0} of {1} has bit class {2}, expected {3}".format(
                            i, val.opcode, src_class, src_type_bits)
                else:
                    assert val.common_class == 0 or src_class == val.common_class, \
                        "Unsized sources of {0} disagree: {1} vs {2}".format(
                            val.opcode, src_class, val.common_class)
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits, \
                    "{0}@{1} conflicts with its {2}-bit destination".format(
                        val.opcode, val.bit_size, dst_type_bits)
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class, \
                        "{0}@{1} conflicts with source bit class {2}".format(
                            val.opcode, val.bit_size, val.common_class)
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check top-down that every replacement node fits ``bit_class``."""
        # At this point, everything *must* have a bit class.  Otherwise, we
        # have a value we don't know how to define.
        assert bit_class != 0, "Replacement value has no known bit class"

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                "Constant {0}@{1} used as bit class {2}".format(
                    val.value, val.bit_size, bit_class)

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class, \
                "Variable {0}@{1} used as bit class {2}".format(
                    val.var_name, val.bit_size, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits, \
                    "{0} produces {1} bits but is used as bit class {2}".format(
                        val.opcode, dst_type_bits, bit_class)
            else:
                assert val.common_class == 0 or val.common_class == bit_class, \
                    "{0} has bit class {1} but is used as bit class {2}".format(
                        val.opcode, val.common_class, bit_class)
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
484
# Source of the unique id given to each SearchAndReplace instance.
_optimization_ids = itertools.count()

# Every distinct condition string seen so far.  A transform refers to its
# condition by index into this list; index 0 is the always-true condition.
condition_list = ['true']


class SearchAndReplace(object):
    """A single (search, replace[, condition]) transformation."""

    def __init__(self, transform):
        """Parse and validate ``transform``.

        ``transform`` is a sequence ``(search, replace)`` or
        ``(search, replace, condition)``.  ``search`` and ``replace`` may be
        already-built Expression/Value objects or raw descriptions, which are
        converted here.  A missing condition defaults to 'true'.

        The condition is registered in the module-level ``condition_list``
        and the search/replace pair is checked by BitSizeValidator, which
        raises AssertionError on inconsistent bit sizes.
        """
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # The replacement may only use variables that the search introduced.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
518
# Mako template for the generated C source of one algebraic pass.  It is
# rendered with ``pass_name``, ``xform_dict`` (opcode -> list of transforms)
# and ``condition_list``.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
622
class AlgebraicPass(object):
    """A named collection of transformations, grouped by search opcode."""

    def __init__(self, pass_name, transforms):
        """Parse ``transforms`` and group them by the opcode they search for.

        Each entry is either a SearchAndReplace or a raw description that is
        converted into one.  Entries that fail to convert are reported on
        stderr with a traceback; after all entries have been processed the
        process exits with status 1 if any of them failed.
        """
        self.xform_dict = OrderedDict()
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Catch Exception rather than everything so that Ctrl-C
                    # and sys.exit() still stop the script.  Keep going after
                    # a failure so all bad transforms are reported in one run.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the generated C source for this pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)