st/mesa/glsl/nir/i965: make use of new gl_shader_program_data in gl_shader_program
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #! /usr/bin/env python
2 #
3 # Copyright (C) 2014 Intel Corporation
4 #
5 # Permission is hereby granted, free of charge, to any person obtaining a
6 # copy of this software and associated documentation files (the "Software"),
7 # to deal in the Software without restriction, including without limitation
8 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 # and/or sell copies of the Software, and to permit persons to whom the
10 # Software is furnished to do so, subject to the following conditions:
11 #
12 # The above copyright notice and this permission notice (including the next
13 # paragraph) shall be included in all copies or substantial portions of the
14 # Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 # IN THE SOFTWARE.
23 #
24 # Authors:
25 # Jason Ekstrand (jason@jlekstrand.net)
26
27 from __future__ import print_function
28 import ast
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size spelled out in a NIR type string.

    'float32' -> 32, 'uint64' -> 64; an unsized type such as 'float', 'int'
    or 'bool' yields 0.  The string must begin with one of int, uint, bool
    or float (AssertionError otherwise).
    """
    match = _type_re.match(type_str)
    assert match.group('type')

    bits = match.group('bits')
    return 0 if bits is None else int(bits)
48
# Represents a set of variables, each with a unique id
class VarSet(object):
    """Maps pattern variable names to unique integer ids.

    Ids are handed out in order of first lookup, starting at 0.  After
    ``lock()`` only names that have already been seen may be looked up;
    SearchAndReplace locks the set once the search pattern is parsed, so a
    replacement cannot introduce a variable the search does not bind.
    """

    def __init__(self):
        self.names = {}              # name -> id
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for ``name``, allocating a new one if unlocked.

        Raises AssertionError for a name not seen before once the set has
        been locked.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            # The next() builtin works on Python 2.6+ and 3; the iterator's
            # .next() method only exists on Python 2.
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid allocating ids for names that have not been seen yet."""
        self.immutable = True
65
class Value(object):
    """Base class for the nodes of a search/replace pattern.

    Every node is rendered (see ``render``) as a static C variable named
    ``self.name`` whose type is ``nir_search_<type_str>``.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the pattern node described by ``val``.

        A tuple becomes an Expression, a string becomes a Variable and a
        bool/int/long/float becomes a Constant.  A node that is already a
        Value is returned unchanged.

        Raises:
            TypeError: if ``val`` is none of the supported kinds.  (Returning
                None here used to defer the failure to a confusing error much
                later during validation or rendering.)
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Value):
            return val
        elif isinstance(val, (str, unicode)):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, int, long, float)):
            return Constant(val, name_base)
        raise TypeError("Invalid pattern value {0!r}: expected a tuple, a "
                        "string, a number or a Value".format(val))

    # Mako template for the C initializer of a single node.  The per-kind
    # fields come from the Constant/Variable/Expression attributes.
    __template = mako.template.Template("""
#include "compiler/nir/nir_search_helpers.h"
static const ${val.c_type} ${val.name} = {
{ ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
${val.type()}, { ${hex(val)} /* ${val.value} */ },
% elif isinstance(val, Variable):
${val.index}, /* ${val.var_name} */
${'true' if val.is_constant else 'false'},
${val.type() or 'nir_type_invalid' },
${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
${'true' if val.inexact else 'false'},
nir_op_${val.opcode},
{ ${', '.join(src.c_ptr for src in val.sources)} },
% endif
};""")

    def __init__(self, name, type_str):
        # name: C identifier of the generated static variable.
        # type_str: 'constant', 'variable' or 'expression' (one per subclass).
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """The nir_search_value_* enumerator for this kind of node."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """The C struct type used for this node's static variable."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression giving a pointer to this node's embedded value."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C definition of this node as a string."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
117
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal value in a search or replace pattern.

    ``val`` is either a Python bool/int/long/float, or a string of the form
    ``"<literal>"`` / ``"<literal>@<bits>"`` whose literal part is parsed with
    ``ast.literal_eval``.  Booleans are always 32 bits wide.
    """

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if not isinstance(val, str):
            self.value = val
            self.bit_size = 0
        else:
            match = _constant_re.match(val)
            self.value = ast.literal_eval(match.group('value'))
            bits = match.group('bits')
            self.bit_size = int(bits) if bits else 0

        if isinstance(self.value, bool):
            assert self.bit_size in (0, 32)
            self.bit_size = 32

    def __hex__(self):
        """C spelling of the value: NIR_TRUE/NIR_FALSE for booleans, a hex
        literal for integers and the raw IEEE-754 double bits for floats."""
        value = self.value
        if isinstance(value, bool):
            return 'NIR_TRUE' if value else 'NIR_FALSE'
        if isinstance(value, (int, long)):
            return hex(value)
        if isinstance(value, float):
            raw_bits, = struct.unpack('Q', struct.pack('d', value))
            return hex(raw_bits)
        assert False

    def type(self):
        """The nir_type_* name for the value, or None if unrecognised."""
        value = self.value
        if isinstance(value, bool):
            return "nir_type_bool32"
        if isinstance(value, (int, long)):
            return "nir_type_int"
        if isinstance(value, float):
            return "nir_type_float"
        return None
153
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A pattern variable such as ``a``, ``#b``, ``a@32``, ``a@float`` or
    ``#b(cond)``.

    Syntax: an optional ``#`` (sets ``is_constant``), a name, an optional
    ``@`` suffix giving a required type and/or bit size, and an optional
    parenthesised condition that is emitted into the generated C (``NULL``
    when absent).  Variables with the same name share one index from
    ``varset``.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None
        # re.match() only anchors at the start of the string.  Reject any
        # unparsed tail so that a typo such as 'a@flaot' fails loudly instead
        # of silently producing a variable with no type constraint.
        assert m.end() == len(val), \
            "Invalid variable {0!r}: cannot parse {1!r}".format(val, val[m.end():])

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0

        # Booleans are always 32 bits wide.
        if self.required_type == 'bool':
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """The nir_type_* name for the required type, or None if untyped."""
        if self.required_type == 'bool':
            return "nir_type_bool32"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"
187
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?")


class Expression(Value):
    """An ALU operation in a pattern, written ``(opcode, src0, src1, ...)``.

    The opcode string is a NIR opcode name, optionally prefixed by ``~``
    (sets ``inexact``) and optionally suffixed by ``@<bits>``.  Each source is
    converted with Value.create and named ``<name_base>_<index>``.
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)
        assert expr, "Empty expression tuple"

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None
        # re.match() only anchors at the start; reject an unparsed tail such
        # as the '@float' in 'fadd@float' instead of silently ignoring it.
        assert m.end() == len(expr[0]), \
            "Invalid opcode {0!r}: cannot parse {1!r}".format(expr[0],
                                                             expr[0][m.end():])

        self.opcode = m.group('opcode')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Render the sources' C definitions first, then this node's.

        This node refers to its sources by address (``c_ptr``), so their
        static definitions must come earlier in the generated C.
        """
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
207
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """

    def __init__(self):
        # Maps an integer to a strictly larger integer it is equivalent to.
        # Following the chain until it ends yields the canonical form;
        # integers absent from the map are their own canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Every remap entry points strictly upwards, so this terminates.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)
        # The old canonical forms must be re-pointed, not just a and b:
        # everything already equivalent to a or b reaches c only through
        # them.  Pointing a and b straight at c as well shortens later
        # lookups.
        for x in (a, b, canon_a, canon_b):
            if x != c:
                assert x < c
                self._remap[x] = c

        return c
241
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to AND a 32-bit value with a 64-bit value, which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search, only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable.  If it ever comes across an inconsistency, it
    assert-fails.  Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        self._num_classes = 0
        # Bit class of each variable, indexed by its VarSet id (0 = unknown).
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that replacing ``search`` by ``replace`` is bit-size safe.

        Raises AssertionError if the sizes are inconsistent or cannot be
        proven to match.
        """
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        # Propagating down may have merged dst_class with other classes,
        # possibly with an actual bit size.  _validate_bit_class_up reports
        # variables by their canonical class, so compare against the
        # canonical form here as well; otherwise a search destination that
        # is known to be e.g. 32-bit would spuriously mismatch.
        dst_class = self._class_relation.get_canonical(dst_class)

        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Return a fresh, previously unused negative bit class."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` must have class ``bit_class``."""
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            self._var_classes[var_id] = bit_class
        else:
            canon_class = self._class_relation.get_canonical(var_class)
            # A variable whose class already resolves to an actual size may
            # only be given that same size.
            assert canon_class < 0 or canon_class == bit_class
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical class of variable ``var_id`` (0 if unset)."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Stage one, leaves to root, over the search pattern.

        Returns the bit size of ``val`` when it is known (explicitly or from
        the opcode's types) and 0 otherwise.  Stores each expression's
        inferred unsized-operand size in ``val.common_size``.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits
                else:
                    assert val.common_size == 0 or src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Stage one, root to leaves, over the search pattern.

        Pushes ``bit_class`` into ``val`` and assigns a class to every
        variable underneath it.
        """
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class.  (With no inputs the loop
                # below does not run, so common_class is not needed.)
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i], common_class)

    def _validate_bit_class_up(self, val):
        """Stage two, leaves to root, over the replace pattern.

        Returns the class ``val`` is known to have (0 if unconstrained) and
        stores each expression's unsized-operand class in
        ``val.common_class``.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search. It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Stage two, root to leaves, over the replace pattern.

        Checks that ``val`` can be built with class ``bit_class`` and fills
        in any expression class left undetermined by the upward pass.
        """
        # At this point, everything *must* have a bit class. Otherwise, we have
        # a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i], src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i], val.common_class)
462
_optimization_ids = itertools.count()

# Every distinct condition used by a transform; each transform records the
# index of its condition here.  This list is module-global, so it is shared
# by every SearchAndReplace and AlgebraicPass created in this process.
condition_list = ['true']


class SearchAndReplace(object):
    """One parsed transformation: a search pattern, its replacement and an
    optional condition.

    ``transform`` is ``(search, replace)`` or ``(search, replace, condition)``;
    the condition defaults to 'true'.  ``search`` is an Expression or the
    tuple form of one, and ``replace`` is a Value or anything Value.create
    accepts.  Bit sizes are checked with BitSizeValidator on construction,
    which fails with AssertionError on an inconsistency.
    """

    def __init__(self, transform):
        # next() works on Python 2.6+ and 3; count.next() is Python 2 only.
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # The replacement may only refer to variables bound by the search.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id),
                                        varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
496
# Mako template for the C source of a complete algebraic optimization pass.
#
# Rendered by AlgebraicPass.render() with:
#   pass_name      - name of the generated entry point, which takes a
#                    nir_shader * and returns whether anything changed
#   xform_dict     - {opcode name: [SearchAndReplace, ...]}
#   condition_list - condition expressions; condition_flags[i] is set to
#                    condition_list[i] when the pass starts.  The local
#                    `options` is presumably there so conditions can refer to
#                    the shader's compiler options -- confirm with the callers.
#
# For every opcode it emits the static search/replace structures and a table
# of transforms.  The generated pass visits each function's blocks and their
# instructions in reverse order.  For every ALU instruction with an SSA
# destination, it tries that opcode's transforms in table order and stops at
# the first one whose condition flag is set and for which nir_replace_instr()
# returns true.  If anything changed, only block-index and dominance metadata
# are preserved.
#
# NOTE(review): xform_dict.iteritems() below is Python 2 only.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
const nir_search_expression *search;
const nir_search_value *replace;
unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.iteritems():
% for xform in xform_list:
${xform.search.render()}
${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
{ &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
void *mem_ctx)
{
bool progress = false;

nir_foreach_instr_reverse_safe(instr, block) {
if (instr->type != nir_instr_type_alu)
continue;

nir_alu_instr *alu = nir_instr_as_alu(instr);
if (!alu->dest.dest.is_ssa)
continue;

switch (alu->op) {
% for opcode in xform_dict.keys():
case nir_op_${opcode}:
for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
if (condition_flags[xform->condition_offset] &&
nir_replace_instr(alu, xform->search, xform->replace,
mem_ctx)) {
progress = true;
break;
}
}
break;
% endfor
default:
break;
}
}

return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
void *mem_ctx = ralloc_parent(impl);
bool progress = false;

nir_foreach_block_reverse(block, impl) {
progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
}

if (progress)
nir_metadata_preserve(impl, nir_metadata_block_index |
nir_metadata_dominance);

return progress;
}


bool
${pass_name}(nir_shader *shader)
{
bool progress = false;
bool condition_flags[${len(condition_list)}];
const nir_shader_compiler_options *options = shader->options;
(void) options;

% for index, condition in enumerate(condition_list):
condition_flags[${index}] = ${condition};
% endfor

nir_foreach_function(function, shader) {
if (function->impl)
progress |= ${pass_name}_impl(function->impl, condition_flags);
}

return progress;
}
""")
599
class AlgebraicPass(object):
    """A named algebraic optimization pass built from a list of transforms."""

    def __init__(self, pass_name, transforms):
        """Parse ``transforms`` and group them by their search opcode.

        Each entry of ``transforms`` is a SearchAndReplace or anything its
        constructor accepts.  Every entry that fails to parse is reported on
        stderr with a traceback.  If any entry failed, the process exits with
        status 1 after all entries have been tried.
        """
        self.xform_dict = {}   # search opcode -> [SearchAndReplace, ...]
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Keep going so every bad transform is reported in one
                    # run.  Exception rather than a bare except, so Ctrl-C
                    # (KeyboardInterrupt) and SystemExit still propagate.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the C source of the whole pass as a string."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)