nir/algebraic: Add a bit-size validator
[mesa.git] / src / compiler / nir / nir_algebraic.py
1 #! /usr/bin/env python
2 #
3 # Copyright (C) 2014 Intel Corporation
4 #
5 # Permission is hereby granted, free of charge, to any person obtaining a
6 # copy of this software and associated documentation files (the "Software"),
7 # to deal in the Software without restriction, including without limitation
8 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 # and/or sell copies of the Software, and to permit persons to whom the
10 # Software is furnished to do so, subject to the following conditions:
11 #
12 # The above copyright notice and this permission notice (including the next
13 # paragraph) shall be included in all copies or substantial portions of the
14 # Software.
15 #
16 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22 # IN THE SOFTWARE.
23 #
24 # Authors:
25 # Jason Ekstrand (jason@jlekstrand.net)
26
27 from __future__ import print_function
28 import ast
29 import itertools
30 import struct
31 import sys
32 import mako.template
33 import re
34 import traceback
35
36 from nir_opcodes import opcodes
37
# Splits a NIR type string such as "float32" into its base type and an
# optional bit size.  Both groups are optional in the pattern itself.
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size spelled out in a NIR type string.

    type_str is expected to start with one of int, uint, bool or float,
    optionally followed by a decimal bit size (e.g. "uint64").  The result
    is that size as an int, or 0 when the type carries no explicit size.

    Raises AssertionError when the string does not start with a known base
    type.
    """
    parsed = _type_re.match(type_str)
    assert parsed.group('type')

    bits = parsed.group('bits')
    return 0 if bits is None else int(bits)
48
# Represents a set of variables, each with a unique id
class VarSet(object):
    """Maps variable names to small integer ids, handed out on first use.

    Ids come from a counter starting at 0, so the first name looked up gets
    id 0, the second new name gets id 1, and so on.  Once lock() has been
    called, looking up a name that was never seen before is an error.
    """

    def __init__(self):
        self.names = {}                # variable name -> id
        self.ids = itertools.count()   # source of fresh ids
        self.immutable = False         # set by lock()

    def __getitem__(self, name):
        """Return the id for name, allocating a new one if still unlocked.

        Raises AssertionError ("Unknown replacement variable: <name>") when
        the set is locked and name has not been registered.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            # The builtin next() works on both Python 2.6+ and Python 3;
            # the iterator's .next() method only exists on Python 2.
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the registration of any further names."""
        self.immutable = True
65
class Value(object):
    """Common base of the nodes making up a search or replace pattern.

    Every node has a C identifier (name) and a kind string (type_str) from
    which the names of the generated C struct type and enum value are
    derived.  render() emits the C definition of the node.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass that matches the Python type of val.

        tuple -> Expression, an existing Expression is returned unchanged,
        a string -> Variable, and bool/int/long/float -> Constant.  Any
        other input falls through and yields None.
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        if isinstance(val, Expression):
            return val
        if isinstance(val, (str, unicode)):
            return Variable(val, name_base, varset)
        if isinstance(val, (bool, int, long, float)):
            return Constant(val, name_base)
        return None

    # C initializer for one node; the section emitted depends on the
    # concrete subclass of the value being rendered.
    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${hex(val)} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
% endif
};""")

    def __init__(self, name, type_str):
        self.type_str = type_str
        self.name = name

    @property
    def type_enum(self):
        """Name of the C enum value identifying this kind of node."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """Name of the C struct type used for this kind of node."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression taking the address of this node's 'value' member."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Return the C definition of this node as a string."""
        context = dict(val=self,
                       Constant=Constant,
                       Variable=Variable,
                       Expression=Expression)
        return self.__template.render(**context)
115
# "<literal>" optionally followed by "@<bits>", e.g. "1.0@64".
_constant_re = re.compile(r"(?P<value>[^@]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal constant appearing in a pattern.

    val is either a Python bool/int/float, or a string of the form
    "<literal>" or "<literal>@<bits>" whose literal part is evaluated with
    ast.literal_eval.  Values given directly start out unsized
    (bit_size 0); booleans are always forced to 32 bits.
    """

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        value, bit_size = val, 0
        if isinstance(val, str):
            parsed = _constant_re.match(val)
            value = ast.literal_eval(parsed.group('value'))
            bits = parsed.group('bits')
            bit_size = int(bits) if bits else 0

        if isinstance(value, bool):
            # Booleans may only be unsized or explicitly 32-bit.
            assert bit_size in (0, 32)
            bit_size = 32

        self.value = value
        self.bit_size = bit_size

    def __hex__(self):
        """Return the C spelling of the constant used in the initializer.

        Booleans become NIR_TRUE/NIR_FALSE, integers are printed in hex, and
        floats are printed as the hex of their 64-bit double bit pattern.
        """
        value = self.value
        # bool has to be tested first because it is a subclass of int.
        if isinstance(value, bool):
            return 'NIR_TRUE' if value else 'NIR_FALSE'
        if isinstance(value, (int, long)):
            return hex(value)
        if isinstance(value, float):
            raw = struct.pack('d', value)
            return hex(struct.unpack('Q', raw)[0])
        assert False

    def type(self):
        """Return the nir_type_* name for the constant (None if unknown)."""
        value = self.value
        if isinstance(value, bool):
            return "nir_type_bool32"
        if isinstance(value, (int, long)):
            return "nir_type_int"
        if isinstance(value, float):
            return "nir_type_float"
        return None
151
# "[#]name[@[type][bits]]": a leading '#' marks a constant-only variable,
# and the optional suffix restricts the type and/or the bit size.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?")


class Variable(Value):
    """A named wildcard in a pattern, parsed from its string spelling.

    The string is matched against _var_name_re.  The variable's index is
    looked up in (and, while the set is unlocked, allocated from) varset.
    A 'bool' type restriction forces the bit size to 32.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        parsed = _var_name_re.match(val)
        assert parsed and parsed.group('name') is not None

        bits = parsed.group('bits')
        self.var_name = parsed.group('name')
        self.is_constant = parsed.group('const') is not None
        self.required_type = parsed.group('type')
        self.bit_size = int(bits) if bits else 0

        if self.required_type == 'bool':
            # Booleans may only be unsized or explicitly 32-bit.
            assert self.bit_size in (0, 32)
            self.bit_size = 32

        assert (self.required_type is None or
                self.required_type in ('float', 'bool', 'int', 'uint'))

        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name required of the variable, or None."""
        by_required_type = {
            'bool': "nir_type_bool32",
            'int': "nir_type_int",
            'uint': "nir_type_int",
            'float': "nir_type_float",
        }
        return by_required_type.get(self.required_type)
183
# "[~]opcode[@bits]": a leading '~' marks the expression as inexact.
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?")


class Expression(Value):
    """An operation node: an opcode applied to a list of source values.

    expr is a tuple whose first element is the opcode string (see
    _opcode_re) and whose remaining elements are the sources.  Each source
    is converted with Value.create and named "<name_base>_<position>".
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        parsed = _opcode_re.match(expr[0])
        assert parsed and parsed.group('opcode') is not None

        bits = parsed.group('bits')
        self.opcode = parsed.group('opcode')
        self.bit_size = int(bits) if bits else 0
        self.inexact = parsed.group('inexact') is not None

        sources = []
        for position, operand in enumerate(expr[1:]):
            source_name = "{0}_{1}".format(name_base, position)
            sources.append(Value.create(operand, source_name, varset))
        self.sources = sources

    def render(self):
        """Return the C definitions of all sources followed by this node's."""
        rendered_sources = [src.render() for src in self.sources]
        own_definition = super(Expression, self).render()
        return "\n".join(rendered_sources) + own_definition
203
class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which
    it is equivalent.  Two integers are equivalent precisely when they have
    the same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if
    any) will always be the canonical form.
    """

    def __init__(self):
        # Maps an integer to a strictly larger integer it is equivalent to.
        # Integers that are absent are their own canonical form.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        # Follow the chain of remappings; every step goes to a larger
        # integer, so the walk always terminates.
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Link the *canonical* representatives rather than a and b
        # themselves.  Overwriting the entry of an already-remapped integer
        # would detach it from its previous class and silently drop the
        # equivalences recorded earlier.
        if canon_a != c:
            self._remap[canon_a] = c

        if canon_b != c:
            self._remap[canon_b] = c

        return c
237
class BitSizeValidator(object):
    """Checks that a search/replace pair is consistent in its bit sizes.

    NIR supports multiple bit sizes on expressions in order to handle things
    such as fp64.  Every source and destination of an ALU operation has a
    type which may or may not specify a bit size.  Those that do not are
    "unsized" and take on the bit size of the register or SSA value they
    are attached to.  nir_validator enforces two simple rules:

     1) A given SSA def or register has a single bit size that is respected
        by everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs of any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    nir_search infers bit sizes from these two rules, much like type
    inference in many programming languages: when constructing an add whose
    second source is known to be 16-bit, the other source and the
    destination must be 16-bit as well.  That inference can however be
    ambiguous or contradictory.  Consider:

        (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    usub_borrow is well-defined for any integer bit size, but b2i always
    produces a 32-bit result, so a 64-bit expression could end up being
    replaced by one that takes two 64-bit values and yields a 32-bit value.
    Likewise, in

        (('bcsel', a, b, 0), ('iand', a, b))

    the search expression forces a to be 32-bit while b may have any size;
    with a 64-bit b the replacement would and a 32-bit value with a 64-bit
    one, which is invalid.

    This class proves that a given search-and-replace operation is well
    defined before any code is generated, so that such bugs are caught at
    compile time rather than at run time.

    It works much like the bitsize_tree in nir_search, except that it tracks
    "bit classes" rather than plain sizes.  A class is an integer: 0 means
    nothing is known yet, a positive value is an actual bit size, and a
    negative value names a set of sizes that must be equal but have not
    received an actual size so far.  The first stage assigns a bit class to
    every variable of the search expression and assert-fails on any
    inconsistency; the second stage uses that information to prove that the
    replacement can always be validly constructed.
    """

    def __init__(self, varset):
        # Bit class of each variable, indexed by variable id; 0 = unknown.
        self._var_classes = [0] * len(varset.names)
        # Number of generic (negative) classes handed out so far.
        self._num_classes = 0
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that replace is well-sized for every match of search."""
        search_class = self._propagate_bit_size_up(search)
        if search_class == 0:
            search_class = self._new_class()
        self._propagate_bit_class_down(search, search_class)

        replace_class = self._validate_bit_class_up(replace)
        assert replace_class == 0 or replace_class == search_class
        self._validate_bit_class_down(replace, search_class)

    def _new_class(self):
        """Allocate and return a fresh generic (negative) bit class."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable var_id belongs to bit_class.

        A variable that already has a class is merged with the new one; this
        is only allowed when its current canonical class is still generic or
        is exactly bit_class.
        """
        assert bit_class != 0
        current = self._var_classes[var_id]
        if current == 0:
            self._var_classes[var_id] = bit_class
            return

        relation = self._class_relation
        canonical = relation.get_canonical(current)
        assert canonical < 0 or canonical == bit_class
        self._var_classes[var_id] = relation.add_equiv(current, bit_class)

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable var_id."""
        stored = self._var_classes[var_id]
        return self._class_relation.get_canonical(stored)

    def _propagate_bit_size_up(self, val):
        """Bottom-up pass over the search tree using explicit sizes only.

        Stores the size shared by the unsized sources/destination of each
        expression in its common_size attribute (0 if still unknown) and
        returns the size of val itself (0 if unknown).
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size
        if not isinstance(val, Expression):
            return None

        op_info = opcodes[val.opcode]
        val.common_size = 0
        for idx in range(op_info.num_inputs):
            src_bits = self._propagate_bit_size_up(val.sources[idx])
            if src_bits == 0:
                continue

            declared_bits = type_bits(op_info.input_types[idx])
            if declared_bits != 0:
                # Sized input: the source has to have exactly that size.
                assert src_bits == declared_bits
            else:
                # Unsized input: all of them have to agree with each other.
                assert val.common_size == 0 or src_bits == val.common_size
                val.common_size = src_bits

        out_bits = type_bits(op_info.output_type)
        if out_bits != 0:
            assert val.bit_size == 0 or val.bit_size == out_bits
            return out_bits

        if val.common_size != 0:
            assert val.bit_size == 0 or val.bit_size == val.common_size
        else:
            val.common_size = val.bit_size
        return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Top-down pass over the search tree assigning variable classes."""
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class
            return

        if isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)
            return

        if not isinstance(val, Expression):
            return

        op_info = opcodes[val.opcode]
        out_bits = type_bits(op_info.output_type)
        if out_bits != 0:
            assert bit_class == 0 or bit_class == out_bits
        else:
            assert val.common_size == 0 or val.common_size == bit_class
            val.common_size = bit_class

        if val.common_size:
            shared_class = val.common_size
        elif op_info.num_inputs:
            # If we got here then we have no idea what the actual size is.
            # Instead, we use a generic class
            shared_class = self._new_class()

        for idx in range(op_info.num_inputs):
            declared_bits = type_bits(op_info.input_types[idx])
            # Sized inputs dictate their own class; unsized ones share.
            child_class = declared_bits if declared_bits != 0 else shared_class
            self._propagate_bit_class_down(val.sources[idx], child_class)

    def _validate_bit_class_up(self, val):
        """Bottom-up pass over the replacement tree.

        Stores the class shared by the unsized sources/destination of each
        expression in its common_class attribute (0 if still unknown) and
        returns the class of val itself (0 if unknown).
        """
        if isinstance(val, Constant):
            return val.bit_size

        if isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a
            # class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search.  It cannot be implicitly
            # sized because otherwise we could end up with a conflict at
            # runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        if not isinstance(val, Expression):
            return None

        op_info = opcodes[val.opcode]
        val.common_class = 0
        for idx in range(op_info.num_inputs):
            src_class = self._validate_bit_class_up(val.sources[idx])
            if src_class == 0:
                continue

            declared_bits = type_bits(op_info.input_types[idx])
            if declared_bits != 0:
                assert src_class == declared_bits
            else:
                assert val.common_class == 0 or src_class == val.common_class
                val.common_class = src_class

        out_bits = type_bits(op_info.output_type)
        if out_bits != 0:
            assert val.bit_size == 0 or val.bit_size == out_bits
            return out_bits

        if val.common_class != 0:
            assert val.bit_size == 0 or val.bit_size == val.common_class
        else:
            val.common_class = val.bit_size
        return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Top-down pass over the replacement tree checking every node."""
        # At this point, everything *must* have a bit class.  Otherwise, we
        # have a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, (Constant, Variable)):
            assert val.bit_size == 0 or val.bit_size == bit_class
            return

        if not isinstance(val, Expression):
            return

        op_info = opcodes[val.opcode]
        out_bits = type_bits(op_info.output_type)
        if out_bits != 0:
            assert bit_class == out_bits
        else:
            assert val.common_class == 0 or val.common_class == bit_class
            val.common_class = bit_class

        for idx in range(op_info.num_inputs):
            declared_bits = type_bits(op_info.input_types[idx])
            child_class = (declared_bits if declared_bits != 0
                           else val.common_class)
            self._validate_bit_class_down(val.sources[idx], child_class)
458
# Source of the unique id given to every SearchAndReplace instance.
_optimization_ids = itertools.count()

# All distinct condition strings seen so far.  A transform stores the index
# of its condition in this list; index 0 is the always-true condition.
condition_list = ['true']


class SearchAndReplace(object):
    """One algebraic transformation: a search pattern and its replacement.

    transform is a sequence (search, replace[, condition]).  search is an
    Expression or a tuple describing one; replace is a Value or anything
    accepted by Value.create; condition is a string that defaults to
    'true'.  New conditions are appended to the module-level condition_list.

    The variable set is locked after the search expression has been built,
    so a replacement built here may only use variables that appear in the
    search expression.  The pair is finally checked by BitSizeValidator.
    """

    def __init__(self, transform):
        # The builtin next() works on both Python 2.6+ and Python 3; the
        # iterator's .next() method only exists on Python 2.
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        self.condition = transform[2] if len(transform) > 2 else 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id),
                                     varset)

        # From here on, unknown variable names are errors.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace,
                                        "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
492
# Mako template for the C source of a complete algebraic pass.  It is
# rendered with three variables: pass_name (used as the prefix of the
# generated C symbols), xform_dict (opcode name -> list of transforms) and
# condition_list (C expressions evaluated once per shader).
#
# xform_dict is walked with .items() rather than the Python-2-only
# .iteritems() so that the template renders under either Python version.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

struct opt_state {
   void *mem_ctx;
   bool progress;
   const bool *condition_flags;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, void *void_state)
{
   struct opt_state *state = void_state;

   nir_foreach_instr_reverse_safe(block, instr) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (state->condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  state->mem_ctx)) {
               state->progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return true;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   struct opt_state state;

   state.mem_ctx = ralloc_parent(impl);
   state.progress = false;
   state.condition_flags = condition_flags;

   nir_foreach_block_reverse_call(impl, ${pass_name}_block, &state);

   if (state.progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return state.progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(shader, function) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
601
class AlgebraicPass(object):
    """A named collection of transformations rendered as one C pass.

    transforms is an iterable whose items are either SearchAndReplace
    instances or raw descriptions accepted by SearchAndReplace.  The
    transforms are grouped by the opcode of their search expression in
    xform_dict.

    Every description that fails to parse is reported on stderr together
    with its traceback; once all of them have been tried, the process exits
    with status 1 if any failed.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = {}
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Only ordinary errors (including failed asserts) are
                    # reported here; KeyboardInterrupt and SystemExit must
                    # keep propagating instead of being treated as a parse
                    # failure.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            opcode = xform.search.opcode
            self.xform_dict.setdefault(opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the C source of the pass as a string."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)