nir: Switch the arguments to nir_foreach_function
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #! /usr/bin/env python
2 #
3 # Copyright (C) 2014 Intel Corporation
4 #
5 # Permission is hereby granted, free of charge, to any person obtaining a
6 # copy of this software and associated documentation files (the "Software"),
7 # to deal in the Software without restriction, including without limitation
8 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 # and/or sell copies of the Software, and to permit persons to whom the
10 # Software is furnished to do so, subject to the following conditions:
11 #
12 # The above copyright notice and this permission notice (including the next
13 # paragraph) shall be included in all copies or substantial portions of the
14 # Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 # IN THE SOFTWARE.
23 #
24 # Authors:
25 # Jason Ekstrand (jason@jlekstrand.net)
26
27 from __future__ import print_function
28 import ast
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Splits a type string into an optional base type (int, uint, bool or float)
# followed by an optional decimal bit size, e.g. "float32" or "int".
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size given at the end of a type string.

    The string must begin with one of the base types recognised by
    ``_type_re`` (this is asserted).  If no size follows the base type the
    type is unsized and 0 is returned.
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return 0 if bits is None else int(bits)
48
49 # Represents a set of variables, each with a unique id
class VarSet(object):
    """Map variable names to unique integer ids.

    Ids are handed out on first lookup, counting up from 0.  Once lock() has
    been called the set is immutable: looking up a name that has not been
    seen before trips an assertion instead of allocating a new id.
    """

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for `name`, allocating one if the set is unlocked."""
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            # The next() builtin works on both Python 2.6+ and Python 3,
            # unlike the Python-2-only iterator.next() method.
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the creation of any further variables."""
        self.immutable = True
65
class Value(object):
    """Base class for the nodes of a search or replace pattern.

    Every value has a C identifier (`name`) and a kind string (`type_str`)
    from which the names of its C struct type and enum tag are derived.
    render() emits the value as a static C initializer.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass that corresponds to a Python object.

        A tuple becomes an Expression, an existing Expression is returned
        unchanged, a string becomes a Variable and a bool/int/long/float
        becomes a Constant.

        Raises TypeError for anything else.  Previously such values silently
        produced None, which only failed much later (when the pattern was
        rendered) with an error unrelated to the actual mistake.
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, (str, unicode)):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, int, long, float)):
            return Constant(val, name_base)

        raise TypeError("Cannot create a search value named {0} from {1!r}"
                        .format(name_base, val))

    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${hex(val)} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
% endif
};""")

    def __init__(self, name, type_str):
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """Name of the C enum tag for this kind of value."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """Name of the C struct type for this kind of value."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression that points at the `value` member of this value."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C definition of this value as a string."""
        # The subclasses are passed in explicitly because the template's
        # isinstance() checks are evaluated inside Mako's own namespace.
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
115
# A constant given as a string: a literal, optionally followed by "@<bits>".
_constant_re = re.compile(r"(?P<value>[^@]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal value in a pattern.

    `val` is either a Python bool/int/long/float, or a string containing a
    Python literal with an optional "@<bits>" suffix.  A bit_size of 0 means
    no size was given.  Boolean constants are always 32 bits wide.
    """

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if not isinstance(val, (str)):
            self.value = val
            self.bit_size = 0
        else:
            parsed = _constant_re.match(val)
            self.value = ast.literal_eval(parsed.group('value'))
            bits = parsed.group('bits')
            self.bit_size = int(bits) if bits else 0

        if isinstance(self.value, bool):
            assert self.bit_size in (0, 32)
            self.bit_size = 32

    def __hex__(self):
        """Return the text used for this constant when hex() is applied."""
        value = self.value
        # bool has to be tested first: it is a subclass of int.
        if isinstance(value, (bool)):
            return 'NIR_TRUE' if value else 'NIR_FALSE'
        if isinstance(value, (int, long)):
            return hex(value)
        if isinstance(value, float):
            # Reinterpret the bits of the double as an unsigned 64-bit integer.
            return hex(struct.unpack('Q', struct.pack('d', value))[0])
        assert False

    def type(self):
        """Return the nir_type_* name for the constant, or None if unknown."""
        value = self.value
        if isinstance(value, (bool)):
            return "nir_type_bool32"
        if isinstance(value, (int, long)):
            return "nir_type_int"
        if isinstance(value, float):
            return "nir_type_float"
        return None
151
# A variable reference: optional "#" prefix, a name, and an optional
# "@<type><bits>" suffix in which both the type and the bits are optional.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?")


class Variable(Value):
    """A named variable in a pattern.

    The variable string is parsed with `_var_name_re`.  A leading "#" sets
    is_constant, the optional type is stored in required_type and the
    optional bit count in bit_size (0 when absent).  Variables required to be
    bool are always 32 bits wide.  The variable's index is looked up in
    `varset` by name.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        parsed = _var_name_re.match(val)
        assert parsed and parsed.group('name') is not None

        self.var_name = parsed.group('name')
        self.is_constant = parsed.group('const') is not None
        self.required_type = parsed.group('type')
        bits = parsed.group('bits')
        self.bit_size = int(bits) if bits else 0

        if self.required_type == 'bool':
            assert self.bit_size in (0, 32)
            self.bit_size = 32

        assert (self.required_type is None or
                self.required_type in ('float', 'bool', 'int', 'uint'))

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name for required_type, or None if unset."""
        nir_types = {
            'bool': "nir_type_bool32",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }
        return nir_types.get(self.required_type)
183
# An expression head: optional "~" prefix, the opcode name, and an optional
# "@<bits>" suffix.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?")


class Expression(Value):
    """An operation applied to a list of source values.

    `expr` is a tuple whose first element is the opcode string (parsed with
    `_opcode_re`) and whose remaining elements are the sources.  Each source
    is converted with Value.create and named "<name_base>_<index>".
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        parsed = _opcode_re.match(expr[0])
        assert parsed and parsed.group('opcode') is not None

        self.opcode = parsed.group('opcode')
        bits = parsed.group('bits')
        self.bit_size = int(bits) if bits else 0
        self.inexact = parsed.group('inexact') is not None

        sources = []
        for index, src in enumerate(expr[1:]):
            src_name = "{0}_{1}".format(name_base, index)
            sources.append(Value.create(src, src_name, varset))
        self.sources = sources

    def render(self):
        """Return the C definitions of all sources followed by this value."""
        rendered_sources = [src.render() for src in self.sources]
        own_definition = super(Expression, self).render()
        return "\n".join(rendered_sources) + own_definition
203
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps an integer to a strictly larger integer in the same class.
        # Integers that are not keys are their own canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Every link points to a strictly larger integer, so following the
        # chain always terminates.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Link the canonical forms rather than a and b themselves.  If a was
        # already equivalent to some other integer, overwriting a's own link
        # would drop the old canonical form (and everything still pointing
        # at it) out of the merged class.
        if canon_a != c:
            self._remap[canon_a] = c

        if canon_b != c:
            self._remap[canon_b] = c

        return c
237
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable.  If it ever comes across an inconsistency, it
    assert-fails.  Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        """Start with no generic classes and every variable unclassified."""
        # Number of generic (negative) classes handed out so far.
        self._num_classes = 0
        # Bit class of each variable, indexed by variable id; 0 = unknown.
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that `replace` is well-defined for everything `search` matches.

        The search expression is processed first (sizes up, then classes
        down) to classify its variables; the replacement is then checked
        against those classes.  Any inconsistency trips an assertion.
        """
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            # The search alone does not determine a destination size, so
            # stand in a fresh generic class for it.
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Allocate and return a new generic bit class (a negative integer)."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable `var_id` belongs to `bit_class`.

        If the variable already has a class, the two classes are merged.
        This asserts that the existing canonical class is either generic
        (negative) or equal to `bit_class`.
        """
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            self._var_classes[var_id] = bit_class
        else:
            canon_class = self._class_relation.get_canonical(var_class)
            assert canon_class < 0 or canon_class == bit_class
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable `var_id`."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Compute explicit bit sizes bottom-up through the search tree.

        Returns the bit size of `val`, or 0 if it is not known.  As a side
        effect, every Expression visited gets a `common_size` attribute
        holding the size shared by its unsized sources (0 if unknown).
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    # Sized input: the source must have exactly that size.
                    assert src_bits == src_type_bits
                else:
                    # Unsized input: all such sources must agree.
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push bit classes top-down through the search tree.

        `bit_class` is the class required of `val` by its parent.  Variables
        are recorded in that class; expressions pass either the fixed size
        of a sized input or their common class on to each source.  Relies on
        the `common_size` attributes set by _propagate_bit_size_up.
        """
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            # common_class is left unset only when the op has no inputs, in
            # which case the loop below never reads it.
            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Compute bit classes bottom-up through the replace tree.

        Returns the bit class of `val`, or 0 if it is not known.  As a side
        effect, every Expression visited gets a `common_class` attribute
        holding the class shared by its unsized sources (0 if unknown).
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search.  It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check top-down that every value in the replace tree has a class.

        `bit_class` is the class required of `val` by its parent and must be
        non-zero.  Relies on the `common_class` attributes set by
        _validate_bit_class_up.
        """
        # At this point, everything *must* have a bit class.  Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
458
# Source of the unique id given to each SearchAndReplace instance.
_optimization_ids = itertools.count()

# Every distinct condition string seen so far.  A transform refers to its
# condition by index into this list; index 0 is the always-true condition.
condition_list = ['true']


class SearchAndReplace(object):
    """One (search, replace[, condition]) transformation.

    `transform` is a sequence whose first element is the search expression,
    whose second element is the replacement value and whose optional third
    element is a condition string (default 'true').  Search and replace may
    be given either as ready-made Expression/Value objects or in the raw
    form accepted by Expression() and Value.create().

    Construction validates the bit sizes of the pair with BitSizeValidator.
    """

    def __init__(self, transform):
        # The next() builtin works on both Python 2.6+ and Python 3, unlike
        # the Python-2-only iterator.next() method.
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # Variables may only be introduced by the search expression; from
        # here on an unknown name in the replacement is an error.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
490
491 BitSizeValidator(varset).validate(self.search, self.replace)
492
# Mako template for the C source of one algebraic pass.  It is rendered with
# pass_name, xform_dict (opcode name -> list of transforms) and
# condition_list.
#
# The opcodes are visited in sorted order so that the generated file does not
# depend on dictionary iteration order, and without relying on the
# Python-2-only dict.iteritems().  The order of the transforms within one
# opcode is left untouched: the generated loop stops at the first transform
# that applies.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in sorted(xform_dict.items()):
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in sorted(xform_dict):
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
595
class AlgebraicPass(object):
    """A named collection of transformations, grouped by search opcode.

    `transforms` may contain SearchAndReplace objects or raw transform
    sequences, which are converted with SearchAndReplace().  Every raw
    transform that fails to convert is reported on stderr together with its
    traceback; after all transforms have been processed the program exits
    with status 1 if any of them failed.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = {}
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Deliberately not a bare "except:", which would also
                    # swallow KeyboardInterrupt and SystemExit.  Failed
                    # assertions are still caught here.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print("  " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            # Transforms for one opcode keep the order in which they were
            # given.
            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the generated C source for this pass as a string."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)
625 xform_dict=self.xform_dict,
626 condition_list=condition_list)