2 # Copyright (C) 2014 Intel Corporation
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
24 # Jason Ekstrand (jason@jlekstrand.net)
from __future__ import print_function

import ast
import itertools
import re
import struct
import sys
import traceback
from collections import defaultdict

import mako.template

from nir_opcodes import opcodes
# Python 2/3 compatibility aliases used by the isinstance() checks below: on
# Python 2 integers may also be `long` and text may be `unicode`.
if sys.version_info < (3, 0):
    integer_types = (int, long)  # noqa: F821 -- Python 2 only
    string_type = unicode  # noqa: F821 -- Python 2 only
else:
    integer_types = (int, )
    string_type = str
46 _type_re
= re
.compile(r
"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")
48 def type_bits(type_str
):
49 m
= _type_re
.match(type_str
)
50 assert m
.group('type')
52 if m
.group('bits') is None:
55 return int(m
.group('bits'))
# Represents a set of variables, each with a unique id
class VarSet(object):
    """Map variable names to sequential integer ids, starting at 0.

    Ids are allocated on first lookup.  After lock() has been called, looking
    up a name that has not been seen before fails an assertion ("Unknown
    replacement variable"), so the replacement side of a transform cannot
    introduce variables that the search side never bound.
    """

    def __init__(self):
        self.names = {}                 # variable name -> id
        self.ids = itertools.count()    # source of fresh ids
        self.immutable = False          # set to True by lock()

    def __getitem__(self, name):
        """Return the id for `name`, allocating one if the set is unlocked."""
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid allocating ids for names that have not been seen yet."""
        self.immutable = True
class Value(object):
    """Base class for the nodes (constants, variables, expressions) of a
    search or replace pattern.

    Each value knows the name of the C struct generated for it and carries a
    "bit-size class" in `_bit_size` (set by the subclasses), which
    BitSizeValidator manipulates through get_bit_size()/set_bit_size().
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass matching the Python object `val`.

        Tuples become Expressions, strings become Variables, and bools,
        floats and integers become Constants; an existing Expression is
        returned unchanged.  bytes are decoded as UTF-8 first.

        Raises TypeError for any other kind of object.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, string_type):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + integer_types):
            return Constant(val, name_base)

        # Falling off the end used to return None, which only failed much
        # later with an unrelated AttributeError.
        raise TypeError('Unsupported search/replace value: {0!r}'.format(val))

    # C initializer emitted for every value; the per-subclass branches fill
    # in the fields that follow the common { type, bit_size } header.
    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, val, name, type_str):
        self.in_val = str(val)      # original pattern text, used in messages
        self.name = name            # C identifier of the generated struct
        self.type_str = type_str    # "constant", "variable" or "expression"

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class. Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable. We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression. This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self
        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Path compression: point directly at the representative next time.
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() return
        before calling this, or just "other" if it's a concrete bit-size. This is
        the "union" part of the union-find algorithm.
        """
        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        """C enumerator naming this value's kind, e.g. nir_search_value_constant."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """C struct type generated for this value, e.g. nir_search_constant."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression taking the address of this value's embedded header."""
        return "&{0}.value".format(self.name)

    @property
    def c_bit_size(self):
        """Bit size as written into the generated C struct.

        An int bit size is emitted as-is, -(index + 1) means "same bit size as
        the variable with this index", and 0 means no constraint.
        """
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            # case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    def render(self):
        """Return the C definition of this value's nir_search struct."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)
# Parses constants written as strings, e.g. "1.0" or "1.0@32": a Python
# literal optionally followed by "@<bits>".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal bool, integer or float appearing in a pattern."""

    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        if isinstance(val, (str)):
            m = _constant_re.match(val)
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        # Boolean constants are emitted as NIR_TRUE/NIR_FALSE and are always
        # 32 bits wide.
        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

    def hex(self):
        """Return the C initializer text for this constant's bit pattern.

        Raises TypeError if the value is not a bool, integer or float.
        """
        # bool must be tested first: it is a subclass of int.
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            # Reinterpret the double's bits as a 64-bit unsigned integer.
            i = struct.unpack('Q', struct.pack('d', self.value))[0]
            h = hex(i)

            # On Python 2 this 'L' suffix is automatically added, but not on Python 3
            # Adding it explicitly makes the generated file identical, regardless
            # of the Python version running this script.
            if h[-1] != 'L' and i > sys.maxsize:
                h += 'L'

            return h
        else:
            raise TypeError(
                'Unsupported constant value: {0!r}'.format(self.value))

    def type(self):
        """Return the nir_type_* name matching the Python type of the value."""
        # bool must be tested first: it is a subclass of int.
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"
# Parses variable references: an optional "#" (is_constant), the name, an
# optional "@[type][bits]" constraint and an optional "(cond)" suffix.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named wildcard in a pattern; every use of a name shares one id."""

    def __init__(self, val, name, varset):
        Value.__init__(self, val, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None, \
            'Invalid variable: {0!r}'.format(val)

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None

        # Boolean variables are always 32 bits wide.
        if self.required_type == 'bool':
            assert self._bit_size is None or self._bit_size == 32
            self._bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        # Same name -> same id; raises if the varset is locked and the name
        # is new.
        self.index = varset[self.var_name]

    def type(self):
        """Return the nir_type_* name required by the "@type" suffix, or None
        when no type was requested."""
        if self.required_type == 'bool':
            return "nir_type_bool"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"
# Parses the opcode element of an expression tuple: an optional "~"
# (inexact), the opcode, an optional "@bits" and an optional "(cond)".
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An ALU operation in a pattern: (opcode, src0, src1, ...)."""

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, expr, name_base, "expression")
        assert isinstance(expr, tuple) and len(expr) > 0, \
            'Expression must be a non-empty tuple: {0!r}'.format(expr)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None, \
            'Invalid opcode: {0!r}'.format(expr[0])

        self.opcode = m.group('opcode')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        # Sources get unique C names derived from this expression's name.
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i), varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Return the C definitions of all sources followed by this one.

        Sources come first because this struct's initializer takes their
        addresses (see Value.c_ptr).
        """
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()
class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64. The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size. Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value. NIR has two simple rules for bit sizes that are
    validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match. They need not match the sized inputs or
       outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above. This is similar to type inference in many common programming
    languages. If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit. There are, however, cases where this
    inference can be ambiguous or contradictory. Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer. However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value. As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size. If we had a 64-bit b value, we would end up
    trying to AND a 32-bit value with a 64-bit value, which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code. This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values. We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises. Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything. The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression. It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of 'ineq', which is always 32 bits.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression. Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # One slot per variable id: the first Variable node seen for that id,
        # which every later use of the same variable is unified with.
        self._var_classes = [None] * len(varset.names)
        # validate() switches this between the search and replace phases;
        # start with the (more permissive) search rules so the attribute
        # always exists.
        self.is_search = True

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is. When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes. In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b. If both
        have been assigned conflicting physical bit-sizes, call "error_msg" with
        the bit-sizes of self and other to get a message and raise an error.
        In the replace expression, disallow merging variables with other
        variables and physical bit-sizes as well.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
            error_msg(a_bit_size, b_bit_size)

        # Always point the less-specialized class at the more-specialized one.
        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable. We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                    lambda other_bit_size, bit_size:
                        'Variable {} has conflicting bit size requirements: ' \
                        'it must have bit size {} and {}'.format(
                            val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate an expression by performing classic Hindley-Milner
        type inference on bitsizes. This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize. If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources. That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                # Unsized sources must all share one bit size.
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} ' \
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                            'of {} may not have the same bit size when building the ' \
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                # Sized sources must match the size the opcode requires.
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '\
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of ' \
                            'nir_op_{} it must have {} bits, which may not be the ' \
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        if dst_type_bits == 0:
            # An unsized destination shares the bit size of the unsized sources.
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} ' \
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} ' \
                            '(bit size of {}) may not have that bit size ' \
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} ' \
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        """Check that every node of the replacement has a bit size that is
        known when the replacement is built: a physical size, a variable's
        size, or the size of the search expression.
        """
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)

    def validate(self, search, replace):
        """Run bit-size inference over a search/replace pair.

        Fails an assertion with a descriptive message on any conflict.  On
        success, the bit-size classes of both trees are left unified so that
        c_bit_size can be emitted.
        """
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace. Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)
# Source of the unique ids used to name each transform's generated C structs.
_optimization_ids = itertools.count()

# Condition expressions shared by all transforms; each is evaluated once into
# condition_flags[] by the generated pass.  Index 0 is the unconditional case.
condition_list = ['true']


class SearchAndReplace(object):
    """One (search, replace[, condition]) transform, parsed and bit-size
    validated."""

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        # Identical conditions share one slot in the generated flag array.
        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            # NOTE(review): a pre-built Expression registered its variables in
            # a different VarSet, so `varset` stays empty here and the
            # validator below is sized for zero variables -- confirm this path
            # is only used with variable-free search expressions.
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset)

        # The replacement may only reference variables bound by the search.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)
# C source for a whole pass: the nir_search structs of every transform, one
# transform table per opcode, and the block/impl/shader driver functions.
# Rendered by AlgebraicPass.render() with pass_name, xforms, opcode_xforms and
# condition_list.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for xform in xforms:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

% for (opcode, xform_list) in sorted(opcode_xforms.items()):
static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_builder *build, nir_block *block,
                   const bool *condition_flags)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in sorted(opcode_xforms.keys()):
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(build, alu, xform->search, xform->replace)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   bool progress = false;

   nir_builder build;
   nir_builder_init(&build, impl);

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(&build, block, condition_flags);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")
class AlgebraicPass(object):
    """Collect the transforms of one algebraic pass and render it as C."""

    def __init__(self, pass_name, transforms):
        """Parse every transform, reporting all failures before exiting.

        pass_name: C function name of the generated pass.
        transforms: iterable of SearchAndReplace objects or raw
            (search, replace[, condition]) tuples.

        Calls sys.exit(1) if any transform fails to parse or validate.
        """
        self.xforms = []
        self.opcode_xforms = defaultdict(list)   # search opcode -> [xform]
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Keep going so every broken transform gets reported.
                    # Unlike the former bare `except:`, this lets
                    # KeyboardInterrupt and SystemExit propagate.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xforms.append(xform)
            self.opcode_xforms[xform.search.opcode].append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the generated C source of the pass."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xforms=self.xforms,
                                               opcode_xforms=self.opcode_xforms,
                                               condition_list=condition_list)