2 # Copyright (C) 2014 Intel Corporation
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 # Jason Ekstrand (jason@jlekstrand.net)
26 from __future__
import print_function
28 from collections
import OrderedDict
36 from nir_opcodes
import opcodes
38 if sys
.version_info
< (3, 0):
39 integer_types
= (int, long)
43 integer_types
= (int, )
46 _type_re
= re
.compile(r
"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
48 def type_bits(type_str
):
49 m
= _type_re
.match(type_str
)
50 assert m
.group('type')
52 if m
.group('bits') is None:
55 return int(m
.group('bits'))
57 # Represents a set of variables, each with a unique id
61 self
.ids
= itertools
.count()
62 self
.immutable
= False;
64 def __getitem__(self
, name
):
65 if name
not in self
.names
:
66 assert not self
.immutable
, "Unknown replacement variable: " + name
67 self
.names
[name
] = next(self
.ids
)
69 return self
.names
[name
]
76 def create(val
, name_base
, varset
):
77 if isinstance(val
, bytes
):
78 val
= val
.decode('utf-8')
80 if isinstance(val
, tuple):
81 return Expression(val
, name_base
, varset
)
82 elif isinstance(val
, Expression
):
84 elif isinstance(val
, string_type
):
85 return Variable(val
, name_base
, varset
)
86 elif isinstance(val
, (bool, float) + integer_types
):
87 return Constant(val
, name_base
)
89 __template
= mako
.template
.Template("""
90 static const ${val.c_type} ${val.name} = {
91 { ${val.type_enum}, ${val.bit_size} },
92 % if isinstance(val, Constant):
93 ${val.type()}, { ${val.hex()} /* ${val.value} */ },
94 % elif isinstance(val, Variable):
95 ${val.index}, /* ${val.var_name} */
96 ${'true' if val.is_constant else 'false'},
97 ${val.type() or 'nir_type_invalid' },
98 ${val.cond if val.cond else 'NULL'},
99 % elif isinstance(val, Expression):
100 ${'true' if val.inexact else 'false'},
101 nir_op_${val.opcode},
102 { ${', '.join(src.c_ptr for src in val.sources)} },
103 ${val.cond if val.cond else 'NULL'},
107 def __init__(self
, val
, name
, type_str
):
108 self
.in_val
= str(val
)
110 self
.type_str
= type_str
117 return "nir_search_value_" + self
.type_str
121 return "nir_search_" + self
.type_str
125 return "&{0}.value".format(self
.name
)
128 return self
.__template
.render(val
=self
,
131 Expression
=Expression
)
133 _constant_re
= re
.compile(r
"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")
135 class Constant(Value
):
136 def __init__(self
, val
, name
):
137 Value
.__init
__(self
, val
, name
, "constant")
139 self
.in_val
= str(val
)
140 if isinstance(val
, (str)):
141 m
= _constant_re
.match(val
)
142 self
.value
= ast
.literal_eval(m
.group('value'))
143 self
.bit_size
= int(m
.group('bits')) if m
.group('bits') else 0
148 if isinstance(self
.value
, bool):
149 assert self
.bit_size
== 0 or self
.bit_size
== 32
153 if isinstance(self
.value
, (bool)):
154 return 'NIR_TRUE' if self
.value
else 'NIR_FALSE'
155 if isinstance(self
.value
, integer_types
):
156 return hex(self
.value
)
157 elif isinstance(self
.value
, float):
158 i
= struct
.unpack('Q', struct
.pack('d', self
.value
))[0]
161 # On Python 2 this 'L' suffix is automatically added, but not on Python 3
162 # Adding it explicitly makes the generated file identical, regardless
163 # of the Python version running this script.
164 if h
[-1] != 'L' and i
> sys
.maxsize
:
172 if isinstance(self
.value
, (bool)):
173 return "nir_type_bool"
174 elif isinstance(self
.value
, integer_types
):
175 return "nir_type_int"
176 elif isinstance(self
.value
, float):
177 return "nir_type_float"
179 _var_name_re
= re
.compile(r
"(?P<const>#)?(?P<name>\w+)"
180 r
"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
181 r
"(?P<cond>\([^\)]+\))?")
183 class Variable(Value
):
184 def __init__(self
, val
, name
, varset
):
185 Value
.__init
__(self
, val
, name
, "variable")
187 m
= _var_name_re
.match(val
)
188 assert m
and m
.group('name') is not None
190 self
.var_name
= m
.group('name')
191 self
.is_constant
= m
.group('const') is not None
192 self
.cond
= m
.group('cond')
193 self
.required_type
= m
.group('type')
194 self
.bit_size
= int(m
.group('bits')) if m
.group('bits') else 0
196 if self
.required_type
== 'bool':
197 assert self
.bit_size
== 0 or self
.bit_size
== 32
200 if self
.required_type
is not None:
201 assert self
.required_type
in ('float', 'bool', 'int', 'uint')
203 self
.index
= varset
[self
.var_name
]
209 if self
.required_type
== 'bool':
210 return "nir_type_bool"
211 elif self
.required_type
in ('int', 'uint'):
212 return "nir_type_int"
213 elif self
.required_type
== 'float':
214 return "nir_type_float"
216 _opcode_re
= re
.compile(r
"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
217 r
"(?P<cond>\([^\)]+\))?")
219 class Expression(Value
):
220 def __init__(self
, expr
, name_base
, varset
):
221 Value
.__init
__(self
, expr
, name_base
, "expression")
222 assert isinstance(expr
, tuple)
224 m
= _opcode_re
.match(expr
[0])
225 assert m
and m
.group('opcode') is not None
227 self
.opcode
= m
.group('opcode')
228 self
.bit_size
= int(m
.group('bits')) if m
.group('bits') else 0
229 self
.inexact
= m
.group('inexact') is not None
230 self
.cond
= m
.group('cond')
231 self
.sources
= [ Value
.create(src
, "{0}_{1}".format(name_base
, i
), varset
)
232 for (i
, src
) in enumerate(expr
[1:]) ]
235 srcs
= "\n".join(src
.render() for src
in self
.sources
)
236 return srcs
+ super(Expression
, self
).render()
238 class IntEquivalenceRelation(object):
239 """A class representing an equivalence relation on integers.
241 Each integer has a canonical form which is the maximum integer to which it
242 is equivalent. Two integers are equivalent precisely when they have the
245 The convention of maximum is explicitly chosen to make using it in
246 BitSizeValidator easier because it means that an actual bit_size (if any)
247 will always be the canonical form.
252 def get_canonical(self
, x
):
253 """Get the canonical integer corresponding to x."""
255 return self
.get_canonical(self
._remap
[x
])
259 def add_equiv(self
, a
, b
):
260 """Add an equivalence and return the canonical form."""
261 c
= max(self
.get_canonical(a
), self
.get_canonical(b
))
272 class BitSizeValidator(object):
273 """A class for validating bit sizes of expressions.
275 NIR supports multiple bit-sizes on expressions in order to handle things
276 such as fp64. The source and destination of every ALU operation is
277 assigned a type and that type may or may not specify a bit size. Sources
278 and destinations whose type does not specify a bit size are considered
279 "unsized" and automatically take on the bit size of the corresponding
280 register or SSA value. NIR has two simple rules for bit sizes that are
281 validated by nir_validator:
283 1) A given SSA def or register has a single bit size that is respected by
284 everything that reads from it or writes to it.
286 2) The bit sizes of all unsized inputs/outputs on any given ALU
287 instruction must match. They need not match the sized inputs or
288 outputs but they must match each other.
290 In order to keep nir_algebraic relatively simple and easy-to-use,
291 nir_search supports a type of bit-size inference based on the two rules
292 above. This is similar to type inference in many common programming
293 languages. If, for instance, you are constructing an add operation and you
294 know the second source is 16-bit, then you know that the other source and
295 the destination must also be 16-bit. There are, however, cases where this
296 inference can be ambiguous or contradictory. Consider, for instance, the
297 following transformation:
299 (('usub_borrow', a, b), ('b2i', ('ult', a, b)))
301 This transformation can potentially cause a problem because usub_borrow is
302 well-defined for any bit-size of integer. However, b2i always generates a
303 32-bit result so it could end up replacing a 64-bit expression with one
304 that takes two 64-bit values and produces a 32-bit value. As another
305 example, consider this expression:
307 (('bcsel', a, b, 0), ('iand', a, b))
309 In this case, in the search expression a must be 32-bit but b can
310 potentially have any bit size. If we had a 64-bit b value, we would end up
311 trying to and a 32-bit value with a 64-bit value which would be invalid
313 This class solves that problem by providing a validation layer that proves
314 that a given search-and-replace operation is 100% well-defined before we
315 generate any code. This ensures that bugs are caught at compile time
316 rather than at run time.
318 The basic operation of the validator is very similar to the bitsize_tree in
319 nir_search only a little more subtle. Instead of simply tracking bit
320 sizes, it tracks "bit classes" where each class is represented by an
321 integer. A value of 0 means we don't know anything yet, positive values
322 are actual bit-sizes, and negative values are used to track equivalence
323 classes of sizes that must be the same but have yet to receive an actual
324 size. The first stage uses the bitsize_tree algorithm to assign bit
325 classes to each variable. If it ever comes across an inconsistency, it
326 assert-fails. Then the second stage uses that information to prove that
327 the resulting expression can always validly be constructed.
330 def __init__(self
, varset
):
331 self
._num
_classes
= 0
332 self
._var
_classes
= [0] * len(varset
.names
)
333 self
._class
_relation
= IntEquivalenceRelation()
335 def validate(self
, search
, replace
):
336 search_dst_class
= self
._propagate
_bit
_size
_up
(search
)
337 if search_dst_class
== 0:
338 search_dst_class
= self
._new
_class
()
339 self
._propagate
_bit
_class
_down
(search
, search_dst_class
)
341 replace_dst_class
= self
._validate
_bit
_class
_up
(replace
)
342 if replace_dst_class
!= 0:
343 assert search_dst_class
!= 0, \
344 'Search expression matches any bit size but replace ' \
345 'expression can only generate {0}-bit values' \
346 .format(replace_dst_class
)
348 assert search_dst_class
== replace_dst_class
, \
349 'Search expression matches any {0}-bit values but replace ' \
350 'expression can only generates {1}-bit values' \
351 .format(search_dst_class
, replace_dst_class
)
353 self
._validate
_bit
_class
_down
(replace
, search_dst_class
)
355 def _new_class(self
):
356 self
._num
_classes
+= 1
357 return -self
._num
_classes
359 def _set_var_bit_class(self
, var
, bit_class
):
360 assert bit_class
!= 0
361 var_class
= self
._var
_classes
[var
.index
]
363 self
._var
_classes
[var
.index
] = bit_class
365 canon_var_class
= self
._class
_relation
.get_canonical(var_class
)
366 canon_bit_class
= self
._class
_relation
.get_canonical(bit_class
)
367 assert canon_var_class
< 0 or canon_bit_class
< 0 or \
368 canon_var_class
== canon_bit_class
, \
369 'Variable {0} cannot be both {1}-bit and {2}-bit' \
370 .format(str(var
), bit_class
, var_class
)
371 var_class
= self
._class
_relation
.add_equiv(var_class
, bit_class
)
372 self
._var
_classes
[var
.index
] = var_class
374 def _get_var_bit_class(self
, var
):
375 return self
._class
_relation
.get_canonical(self
._var
_classes
[var
.index
])
377 def _propagate_bit_size_up(self
, val
):
378 if isinstance(val
, (Constant
, Variable
)):
381 elif isinstance(val
, Expression
):
382 nir_op
= opcodes
[val
.opcode
]
384 for i
in range(nir_op
.num_inputs
):
385 src_bits
= self
._propagate
_bit
_size
_up
(val
.sources
[i
])
389 src_type_bits
= type_bits(nir_op
.input_types
[i
])
390 if src_type_bits
!= 0:
391 assert src_bits
== src_type_bits
, \
392 'Source {0} of nir_op_{1} must be a {2}-bit value but ' \
393 'the only possible matched values are {3}-bit: {4}' \
394 .format(i
, val
.opcode
, src_type_bits
, src_bits
, str(val
))
396 assert val
.common_size
== 0 or src_bits
== val
.common_size
, \
397 'Expression cannot have both {0}-bit and {1}-bit ' \
398 'variable-width sources: {2}' \
399 .format(src_bits
, val
.common_size
, str(val
))
400 val
.common_size
= src_bits
402 dst_type_bits
= type_bits(nir_op
.output_type
)
403 if dst_type_bits
!= 0:
404 assert val
.bit_size
== 0 or val
.bit_size
== dst_type_bits
, \
405 'nir_op_{0} produces a {1}-bit result but a {2}-bit ' \
406 'result was requested' \
407 .format(val
.opcode
, dst_type_bits
, val
.bit_size
)
410 if val
.common_size
!= 0:
411 assert val
.bit_size
== 0 or val
.bit_size
== val
.common_size
, \
412 'Variable width expression musr be {0}-bit based on ' \
413 'the sources but a {1}-bit result was requested: {2}' \
414 .format(val
.common_size
, val
.bit_size
, str(val
))
416 val
.common_size
= val
.bit_size
417 return val
.common_size
419 def _propagate_bit_class_down(self
, val
, bit_class
):
420 if isinstance(val
, Constant
):
421 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
, \
422 'Constant is {0}-bit but a {1}-bit value is required: {2}' \
423 .format(val
.bit_size
, bit_class
, str(val
))
425 elif isinstance(val
, Variable
):
426 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
, \
427 'Variable is {0}-bit but a {1}-bit value is required: {2}' \
428 .format(val
.bit_size
, bit_class
, str(val
))
429 self
._set
_var
_bit
_class
(val
, bit_class
)
431 elif isinstance(val
, Expression
):
432 nir_op
= opcodes
[val
.opcode
]
433 dst_type_bits
= type_bits(nir_op
.output_type
)
434 if dst_type_bits
!= 0:
435 assert bit_class
== 0 or bit_class
== dst_type_bits
, \
436 'nir_op_{0} produces a {1}-bit result but the parent ' \
437 'expression wants a {2}-bit value' \
438 .format(val
.opcode
, dst_type_bits
, bit_class
)
440 assert val
.common_size
== 0 or val
.common_size
== bit_class
, \
441 'Variable-width expression produces a {0}-bit result ' \
442 'based on the source widths but the parent expression ' \
443 'wants a {1}-bit value: {2}' \
444 .format(val
.common_size
, bit_class
, str(val
))
445 val
.common_size
= bit_class
448 common_class
= val
.common_size
449 elif nir_op
.num_inputs
:
450 # If we got here then we have no idea what the actual size is.
451 # Instead, we use a generic class
452 common_class
= self
._new
_class
()
454 for i
in range(nir_op
.num_inputs
):
455 src_type_bits
= type_bits(nir_op
.input_types
[i
])
456 if src_type_bits
!= 0:
457 self
._propagate
_bit
_class
_down
(val
.sources
[i
], src_type_bits
)
459 self
._propagate
_bit
_class
_down
(val
.sources
[i
], common_class
)
461 def _validate_bit_class_up(self
, val
):
462 if isinstance(val
, Constant
):
465 elif isinstance(val
, Variable
):
466 var_class
= self
._get
_var
_bit
_class
(val
)
467 # By the time we get to validation, every variable should have a class
468 assert var_class
!= 0
470 # If we have an explicit size provided by the user, the variable
471 # *must* exactly match the search. It cannot be implicitly sized
472 # because otherwise we could end up with a conflict at runtime.
473 assert val
.bit_size
== 0 or val
.bit_size
== var_class
477 elif isinstance(val
, Expression
):
478 nir_op
= opcodes
[val
.opcode
]
480 for i
in range(nir_op
.num_inputs
):
481 src_class
= self
._validate
_bit
_class
_up
(val
.sources
[i
])
485 src_type_bits
= type_bits(nir_op
.input_types
[i
])
486 if src_type_bits
!= 0:
487 assert src_class
== src_type_bits
489 assert val
.common_class
== 0 or src_class
== val
.common_class
490 val
.common_class
= src_class
492 dst_type_bits
= type_bits(nir_op
.output_type
)
493 if dst_type_bits
!= 0:
494 assert val
.bit_size
== 0 or val
.bit_size
== dst_type_bits
497 if val
.common_class
!= 0:
498 assert val
.bit_size
== 0 or val
.bit_size
== val
.common_class
500 val
.common_class
= val
.bit_size
501 return val
.common_class
503 def _validate_bit_class_down(self
, val
, bit_class
):
504 # At this point, everything *must* have a bit class. Otherwise, we have
505 # a value we don't know how to define.
506 assert bit_class
!= 0
508 if isinstance(val
, Constant
):
509 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
511 elif isinstance(val
, Variable
):
512 assert val
.bit_size
== 0 or val
.bit_size
== bit_class
514 elif isinstance(val
, Expression
):
515 nir_op
= opcodes
[val
.opcode
]
516 dst_type_bits
= type_bits(nir_op
.output_type
)
517 if dst_type_bits
!= 0:
518 assert bit_class
== dst_type_bits
520 assert val
.common_class
== 0 or val
.common_class
== bit_class
521 val
.common_class
= bit_class
523 for i
in range(nir_op
.num_inputs
):
524 src_type_bits
= type_bits(nir_op
.input_types
[i
])
525 if src_type_bits
!= 0:
526 self
._validate
_bit
_class
_down
(val
.sources
[i
], src_type_bits
)
528 self
._validate
_bit
_class
_down
(val
.sources
[i
], val
.common_class
)
530 _optimization_ids
= itertools
.count()
532 condition_list
= ['true']
534 class SearchAndReplace(object):
535 def __init__(self
, transform
):
536 self
.id = next(_optimization_ids
)
538 search
= transform
[0]
539 replace
= transform
[1]
540 if len(transform
) > 2:
541 self
.condition
= transform
[2]
543 self
.condition
= 'true'
545 if self
.condition
not in condition_list
:
546 condition_list
.append(self
.condition
)
547 self
.condition_index
= condition_list
.index(self
.condition
)
550 if isinstance(search
, Expression
):
553 self
.search
= Expression(search
, "search{0}".format(self
.id), varset
)
557 if isinstance(replace
, Value
):
558 self
.replace
= replace
560 self
.replace
= Value
.create(replace
, "replace{0}".format(self
.id), varset
)
562 BitSizeValidator(varset
).validate(self
.search
, self
.replace
)
564 _algebraic_pass_template
= mako
.template
.Template("""
566 #include "nir_search.h"
567 #include "nir_search_helpers.h"
569 #ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
570 #define NIR_OPT_ALGEBRAIC_STRUCT_DEFS
573 const nir_search_expression *search;
574 const nir_search_value *replace;
575 unsigned condition_offset;
580 % for (opcode, xform_list) in xform_dict.items():
581 % for xform in xform_list:
582 ${xform.search.render()}
583 ${xform.replace.render()}
586 static const struct transform ${pass_name}_${opcode}_xforms[] = {
587 % for xform in xform_list:
588 { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
594 ${pass_name}_block(nir_block *block, const bool *condition_flags,
597 bool progress = false;
599 nir_foreach_instr_reverse_safe(instr, block) {
600 if (instr->type != nir_instr_type_alu)
603 nir_alu_instr *alu = nir_instr_as_alu(instr);
604 if (!alu->dest.dest.is_ssa)
608 % for opcode in xform_dict.keys():
609 case nir_op_${opcode}:
610 for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
611 const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
612 if (condition_flags[xform->condition_offset] &&
613 nir_replace_instr(alu, xform->search, xform->replace,
630 ${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
632 void *mem_ctx = ralloc_parent(impl);
633 bool progress = false;
635 nir_foreach_block_reverse(block, impl) {
636 progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
640 nir_metadata_preserve(impl, nir_metadata_block_index |
641 nir_metadata_dominance);
648 ${pass_name}(nir_shader *shader)
650 bool progress = false;
651 bool condition_flags[${len(condition_list)}];
652 const nir_shader_compiler_options *options = shader->options;
655 % for index, condition in enumerate(condition_list):
656 condition_flags[${index}] = ${condition};
659 nir_foreach_function(function, shader) {
661 progress |= ${pass_name}_impl(function->impl, condition_flags);
668 class AlgebraicPass(object):
669 def __init__(self
, pass_name
, transforms
):
670 self
.xform_dict
= OrderedDict()
671 self
.pass_name
= pass_name
675 for xform
in transforms
:
676 if not isinstance(xform
, SearchAndReplace
):
678 xform
= SearchAndReplace(xform
)
680 print("Failed to parse transformation:", file=sys
.stderr
)
681 print(" " + str(xform
), file=sys
.stderr
)
682 traceback
.print_exc(file=sys
.stderr
)
683 print('', file=sys
.stderr
)
687 if xform
.search
.opcode
not in self
.xform_dict
:
688 self
.xform_dict
[xform
.search
.opcode
] = []
690 self
.xform_dict
[xform
.search
.opcode
].append(xform
)
696 return _algebraic_pass_template
.render(pass_name
=self
.pass_name
,
697 xform_dict
=self
.xform_dict
,
698 condition_list
=condition_list
)